mirror of
https://github.com/sigoden/dufs.git
synced 2026-04-09 09:09:03 +03:00
feat: support tls
This commit is contained in:
57
src/args.rs
57
src/args.rs
@@ -1,8 +1,9 @@
|
||||
use clap::crate_description;
|
||||
use clap::{Arg, ArgMatches};
|
||||
use std::env;
|
||||
use rustls::{Certificate, PrivateKey};
|
||||
use std::net::SocketAddr;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::{env, fs, io};
|
||||
|
||||
use crate::BoxResult;
|
||||
|
||||
@@ -87,6 +88,18 @@ fn app() -> clap::Command<'static> {
|
||||
.long("cors")
|
||||
.help("Enable CORS, sets `Access-Control-Allow-Origin: *`"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("tls-cert")
|
||||
.long("tls-cert")
|
||||
.value_name("path")
|
||||
.help("Path to an SSL/TLS certificate to serve with HTTPS"),
|
||||
)
|
||||
.arg(
|
||||
Arg::new("tls-key")
|
||||
.long("tls-key")
|
||||
.value_name("path")
|
||||
.help("Path to the SSL/TLS certificate's private key"),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn matches() -> ArgMatches {
|
||||
@@ -107,6 +120,7 @@ pub struct Args {
|
||||
pub render_index: bool,
|
||||
pub render_spa: bool,
|
||||
pub cors: bool,
|
||||
pub tls: Option<(Vec<Certificate>, PrivateKey)>,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
@@ -127,6 +141,14 @@ impl Args {
|
||||
let allow_symlink = matches.is_present("allow-all") || matches.is_present("allow-symlink");
|
||||
let render_index = matches.is_present("render-index");
|
||||
let render_spa = matches.is_present("render-spa");
|
||||
let tls = match (matches.value_of("tls-cert"), matches.value_of("tls-key")) {
|
||||
(Some(certs_file), Some(key_file)) => {
|
||||
let certs = load_certs(certs_file)?;
|
||||
let key = load_private_key(key_file)?;
|
||||
Some((certs, key))
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
Ok(Args {
|
||||
address,
|
||||
@@ -141,6 +163,7 @@ impl Args {
|
||||
allow_symlink,
|
||||
render_index,
|
||||
render_spa,
|
||||
tls,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -179,3 +202,35 @@ impl Args {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Load public certificate from file.
|
||||
pub fn load_certs(filename: &str) -> BoxResult<Vec<Certificate>> {
|
||||
// Open certificate file.
|
||||
let certfile =
|
||||
fs::File::open(&filename).map_err(|e| format!("Failed to open {}: {}", &filename, e))?;
|
||||
let mut reader = io::BufReader::new(certfile);
|
||||
|
||||
// Load and return certificate.
|
||||
let certs = rustls_pemfile::certs(&mut reader).map_err(|_| "Failed to load certificate")?;
|
||||
if certs.is_empty() {
|
||||
return Err("Expected at least one certificate".into());
|
||||
}
|
||||
Ok(certs.into_iter().map(Certificate).collect())
|
||||
}
|
||||
|
||||
// Load private key from file.
|
||||
pub fn load_private_key(filename: &str) -> BoxResult<PrivateKey> {
|
||||
// Open keyfile.
|
||||
let keyfile =
|
||||
fs::File::open(&filename).map_err(|e| format!("Failed to open {}: {}", &filename, e))?;
|
||||
let mut reader = io::BufReader::new(keyfile);
|
||||
|
||||
// Load and return a single private key.
|
||||
let keys = rustls_pemfile::rsa_private_keys(&mut reader)
|
||||
.map_err(|e| format!("There was a problem with reading private key: {:?}", e))?;
|
||||
|
||||
if keys.len() != 1 {
|
||||
return Err("Expected a single private key".into());
|
||||
}
|
||||
Ok(PrivateKey(keys[0].to_owned()))
|
||||
}
|
||||
|
||||
13
src/main.rs
13
src/main.rs
@@ -4,16 +4,11 @@ macro_rules! bail {
|
||||
}
|
||||
}
|
||||
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
mod args;
|
||||
mod server;
|
||||
|
||||
pub type BoxResult<T> = Result<T, Box<dyn std::error::Error>>;
|
||||
|
||||
use log::LevelFilter;
|
||||
|
||||
use crate::args::{matches, Args};
|
||||
use crate::server::serve;
|
||||
|
||||
@@ -24,14 +19,6 @@ async fn main() {
|
||||
|
||||
async fn run() -> BoxResult<()> {
|
||||
let args = Args::parse(matches())?;
|
||||
|
||||
if std::env::var("RUST_LOG").is_ok() {
|
||||
simple_logger::init()?;
|
||||
} else {
|
||||
simple_logger::SimpleLogger::default()
|
||||
.with_level(LevelFilter::Info)
|
||||
.init()?;
|
||||
}
|
||||
serve(args).await
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ use async_walkdir::WalkDir;
|
||||
use async_zip::read::seek::ZipFileReader;
|
||||
use async_zip::write::{EntryOptions, ZipFileWriter};
|
||||
use async_zip::Compression;
|
||||
use chrono::Local;
|
||||
use futures::stream::StreamExt;
|
||||
use futures::TryStreamExt;
|
||||
use get_if_addrs::get_if_addrs;
|
||||
@@ -19,6 +20,7 @@ use hyper::header::{
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{Body, Method, StatusCode};
|
||||
use percent_encoding::percent_decode;
|
||||
use rustls::ServerConfig;
|
||||
use serde::Serialize;
|
||||
use std::convert::Infallible;
|
||||
use std::fs::Metadata;
|
||||
@@ -28,7 +30,9 @@ use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWrite};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::{fs, io};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tokio_util::codec::{BytesCodec, FramedRead};
|
||||
use tokio_util::io::{ReaderStream, StreamReader};
|
||||
|
||||
@@ -49,33 +53,61 @@ macro_rules! status {
|
||||
}
|
||||
|
||||
pub async fn serve(args: Args) -> BoxResult<()> {
|
||||
let args = Arc::new(args);
|
||||
let socket_addr = args.address()?;
|
||||
let address = args.address.clone();
|
||||
let port = args.port;
|
||||
let inner = Arc::new(InnerService::new(args));
|
||||
let make_svc = make_service_fn(move |_| {
|
||||
let inner = inner.clone();
|
||||
async {
|
||||
Ok::<_, Infallible>(service_fn(move |req| {
|
||||
let inner = inner.clone();
|
||||
inner.call(req)
|
||||
}))
|
||||
}
|
||||
});
|
||||
|
||||
let server = hyper::Server::try_bind(&socket_addr)?.serve(make_svc);
|
||||
print_listening(&address, port);
|
||||
server.await?;
|
||||
let inner = Arc::new(InnerService::new(args.clone()));
|
||||
if let Some((certs, key)) = args.tls.as_ref() {
|
||||
let config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs.clone(), key.clone())?;
|
||||
let tls_acceptor = TlsAcceptor::from(Arc::new(config));
|
||||
let arc_acceptor = Arc::new(tls_acceptor);
|
||||
let listener = TcpListener::bind(&socket_addr).await.unwrap();
|
||||
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
||||
let incoming = hyper::server::accept::from_stream(incoming.filter_map(|socket| async {
|
||||
match socket {
|
||||
Ok(stream) => match arc_acceptor.clone().accept(stream).await {
|
||||
Ok(val) => Some(Ok::<_, Infallible>(val)),
|
||||
Err(_) => None,
|
||||
},
|
||||
Err(_) => None,
|
||||
}
|
||||
}));
|
||||
let server = hyper::Server::builder(incoming).serve(make_service_fn(move |_| {
|
||||
let inner = inner.clone();
|
||||
async move {
|
||||
Ok::<_, Infallible>(service_fn(move |req| {
|
||||
let inner = inner.clone();
|
||||
inner.call(req)
|
||||
}))
|
||||
}
|
||||
}));
|
||||
print_listening(args.address.as_str(), args.port, true);
|
||||
server.await?;
|
||||
} else {
|
||||
let server = hyper::Server::bind(&socket_addr).serve(make_service_fn(move |_| {
|
||||
let inner = inner.clone();
|
||||
async move {
|
||||
Ok::<_, Infallible>(service_fn(move |req| {
|
||||
let inner = inner.clone();
|
||||
inner.call(req)
|
||||
}))
|
||||
}
|
||||
}));
|
||||
print_listening(args.address.as_str(), args.port, false);
|
||||
server.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct InnerService {
|
||||
args: Args,
|
||||
args: Arc<Args>,
|
||||
}
|
||||
|
||||
impl InnerService {
|
||||
pub fn new(args: Args) -> Self {
|
||||
pub fn new(args: Arc<Args>) -> Self {
|
||||
Self { args }
|
||||
}
|
||||
|
||||
@@ -84,15 +116,20 @@ impl InnerService {
|
||||
let uri = req.uri().clone();
|
||||
let cors = self.args.cors;
|
||||
|
||||
let timestamp = Local::now().format("%d/%b/%Y %H:%M:%S");
|
||||
let mut res = match self.handle(req).await {
|
||||
Ok(res) => {
|
||||
info!(r#""{} {}" - {}"#, method, uri, res.status());
|
||||
println!(r#"[{}] "{} {}" - {}"#, timestamp, method, uri, res.status());
|
||||
res
|
||||
}
|
||||
Err(err) => {
|
||||
let mut res = Response::default();
|
||||
status!(res, StatusCode::INTERNAL_SERVER_ERROR);
|
||||
error!(r#""{} {}" - {} {}"#, method, uri, res.status(), err);
|
||||
let status = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
status!(res, status);
|
||||
eprintln!(
|
||||
r#"[{}] "{} {}" - {} {}"#,
|
||||
timestamp, method, uri, status, err
|
||||
);
|
||||
res
|
||||
}
|
||||
};
|
||||
@@ -314,7 +351,7 @@ impl InnerService {
|
||||
let path = path.to_owned();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = zip_dir(&mut writer, &path).await {
|
||||
error!("Fail to zip {}, {}", path.display(), e.to_string());
|
||||
eprintln!("Failed to zip {}, {}", path.display(), e);
|
||||
}
|
||||
});
|
||||
let stream = ReaderStream::new(reader);
|
||||
@@ -678,14 +715,15 @@ fn to_content_range(range: &Range, complete_length: u64) -> Option<ContentRange>
|
||||
})
|
||||
}
|
||||
|
||||
fn print_listening(address: &str, port: u16) {
|
||||
fn print_listening(address: &str, port: u16, tls: bool) {
|
||||
let addrs = retrive_listening_addrs(address);
|
||||
let protocol = if tls { "https" } else { "http" };
|
||||
if addrs.len() == 1 {
|
||||
eprintln!("Listening on http://{}:{}", addrs[0], port);
|
||||
eprintln!("Listening on {}://{}:{}", protocol, addrs[0], port);
|
||||
} else {
|
||||
eprintln!("Listening on:");
|
||||
for addr in addrs {
|
||||
eprintln!(" http://{}:{}", addr, port);
|
||||
eprintln!(" {}://{}:{}", protocol, addr, port);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user