feat: support tls

This commit is contained in:
sigoden
2022-06-02 11:06:41 +08:00
parent 97978719b3
commit e2d7f996c7
6 changed files with 307 additions and 108 deletions

View File

@@ -4,6 +4,7 @@ use async_walkdir::WalkDir;
use async_zip::read::seek::ZipFileReader;
use async_zip::write::{EntryOptions, ZipFileWriter};
use async_zip::Compression;
use chrono::Local;
use futures::stream::StreamExt;
use futures::TryStreamExt;
use get_if_addrs::get_if_addrs;
@@ -19,6 +20,7 @@ use hyper::header::{
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, StatusCode};
use percent_encoding::percent_decode;
use rustls::ServerConfig;
use serde::Serialize;
use std::convert::Infallible;
use std::fs::Metadata;
@@ -28,7 +30,9 @@ use std::sync::Arc;
use std::time::SystemTime;
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWrite};
use tokio::net::TcpListener;
use tokio::{fs, io};
use tokio_rustls::TlsAcceptor;
use tokio_util::codec::{BytesCodec, FramedRead};
use tokio_util::io::{ReaderStream, StreamReader};
@@ -49,33 +53,61 @@ macro_rules! status {
}
pub async fn serve(args: Args) -> BoxResult<()> {
let args = Arc::new(args);
let socket_addr = args.address()?;
let address = args.address.clone();
let port = args.port;
let inner = Arc::new(InnerService::new(args));
let make_svc = make_service_fn(move |_| {
let inner = inner.clone();
async {
Ok::<_, Infallible>(service_fn(move |req| {
let inner = inner.clone();
inner.call(req)
}))
}
});
let server = hyper::Server::try_bind(&socket_addr)?.serve(make_svc);
print_listening(&address, port);
server.await?;
let inner = Arc::new(InnerService::new(args.clone()));
if let Some((certs, key)) = args.tls.as_ref() {
let config = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs.clone(), key.clone())?;
let tls_acceptor = TlsAcceptor::from(Arc::new(config));
let arc_acceptor = Arc::new(tls_acceptor);
let listener = TcpListener::bind(&socket_addr).await.unwrap();
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
let incoming = hyper::server::accept::from_stream(incoming.filter_map(|socket| async {
match socket {
Ok(stream) => match arc_acceptor.clone().accept(stream).await {
Ok(val) => Some(Ok::<_, Infallible>(val)),
Err(_) => None,
},
Err(_) => None,
}
}));
let server = hyper::Server::builder(incoming).serve(make_service_fn(move |_| {
let inner = inner.clone();
async move {
Ok::<_, Infallible>(service_fn(move |req| {
let inner = inner.clone();
inner.call(req)
}))
}
}));
print_listening(args.address.as_str(), args.port, true);
server.await?;
} else {
let server = hyper::Server::bind(&socket_addr).serve(make_service_fn(move |_| {
let inner = inner.clone();
async move {
Ok::<_, Infallible>(service_fn(move |req| {
let inner = inner.clone();
inner.call(req)
}))
}
}));
print_listening(args.address.as_str(), args.port, false);
server.await?;
}
Ok(())
}
struct InnerService {
args: Args,
args: Arc<Args>,
}
impl InnerService {
pub fn new(args: Args) -> Self {
pub fn new(args: Arc<Args>) -> Self {
Self { args }
}
@@ -84,15 +116,20 @@ impl InnerService {
let uri = req.uri().clone();
let cors = self.args.cors;
let timestamp = Local::now().format("%d/%b/%Y %H:%M:%S");
let mut res = match self.handle(req).await {
Ok(res) => {
info!(r#""{} {}" - {}"#, method, uri, res.status());
println!(r#"[{}] "{} {}" - {}"#, timestamp, method, uri, res.status());
res
}
Err(err) => {
let mut res = Response::default();
status!(res, StatusCode::INTERNAL_SERVER_ERROR);
error!(r#""{} {}" - {} {}"#, method, uri, res.status(), err);
let status = StatusCode::INTERNAL_SERVER_ERROR;
status!(res, status);
eprintln!(
r#"[{}] "{} {}" - {} {}"#,
timestamp, method, uri, status, err
);
res
}
};
@@ -314,7 +351,7 @@ impl InnerService {
let path = path.to_owned();
tokio::spawn(async move {
if let Err(e) = zip_dir(&mut writer, &path).await {
error!("Fail to zip {}, {}", path.display(), e.to_string());
eprintln!("Failed to zip {}, {}", path.display(), e);
}
});
let stream = ReaderStream::new(reader);
@@ -678,14 +715,15 @@ fn to_content_range(range: &Range, complete_length: u64) -> Option<ContentRange>
})
}
fn print_listening(address: &str, port: u16) {
fn print_listening(address: &str, port: u16, tls: bool) {
let addrs = retrive_listening_addrs(address);
let protocol = if tls { "https" } else { "http" };
if addrs.len() == 1 {
eprintln!("Listening on http://{}:{}", addrs[0], port);
eprintln!("Listening on {}://{}:{}", protocol, addrs[0], port);
} else {
eprintln!("Listening on:");
for addr in addrs {
eprintln!(" http://{}:{}", addr, port);
eprintln!(" {}://{}:{}", protocol, addr, port);
}
}
}