mirror of
https://github.com/sigoden/dufs.git
synced 2026-04-09 09:09:03 +03:00
feat: upgrade to hyper 1.0 (#321)
This commit is contained in:
169
src/main.rs
169
src/main.rs
@@ -1,38 +1,37 @@
|
||||
mod args;
|
||||
mod auth;
|
||||
mod http_logger;
|
||||
mod http_utils;
|
||||
mod logger;
|
||||
mod server;
|
||||
mod streamer;
|
||||
#[cfg(feature = "tls")]
|
||||
mod tls;
|
||||
#[cfg(unix)]
|
||||
mod unix;
|
||||
mod utils;
|
||||
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
use crate::args::{build_cli, print_completions, Args};
|
||||
use crate::server::{Request, Server};
|
||||
use crate::server::Server;
|
||||
#[cfg(feature = "tls")]
|
||||
use crate::tls::{load_certs, load_private_key, TlsAcceptor, TlsStream};
|
||||
use crate::utils::{load_certs, load_private_key};
|
||||
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
use args::BindAddr;
|
||||
use clap_complete::Shell;
|
||||
use futures::future::join_all;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::task::JoinHandle;
|
||||
use futures_util::future::join_all;
|
||||
|
||||
use hyper::server::conn::{AddrIncoming, AddrStream};
|
||||
use hyper::service::{make_service_fn, service_fn};
|
||||
use hyper::{body::Incoming, service::service_fn, Request};
|
||||
use hyper_util::{
|
||||
rt::{TokioExecutor, TokioIo},
|
||||
server::conn::auto::Builder,
|
||||
};
|
||||
use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener};
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use tokio::{net::TcpListener, task::JoinHandle};
|
||||
#[cfg(feature = "tls")]
|
||||
use rustls::ServerConfig;
|
||||
use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
@@ -45,10 +44,10 @@ async fn main() -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
let args = Args::parse(matches)?;
|
||||
let args = Arc::new(args);
|
||||
let running = Arc::new(AtomicBool::new(true));
|
||||
let handles = serve(args.clone(), running.clone())?;
|
||||
print_listening(args)?;
|
||||
let listening = print_listening(&args)?;
|
||||
let handles = serve(args, running.clone())?;
|
||||
println!("{listening}");
|
||||
|
||||
tokio::select! {
|
||||
ret = join_all(handles) => {
|
||||
@@ -66,56 +65,65 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
fn serve(
|
||||
args: Arc<Args>,
|
||||
running: Arc<AtomicBool>,
|
||||
) -> Result<Vec<JoinHandle<Result<(), hyper::Error>>>> {
|
||||
let inner = Arc::new(Server::init(args.clone(), running)?);
|
||||
let mut handles = vec![];
|
||||
fn serve(args: Args, running: Arc<AtomicBool>) -> Result<Vec<JoinHandle<()>>> {
|
||||
let addrs = args.addrs.clone();
|
||||
let port = args.port;
|
||||
for bind_addr in args.addrs.iter() {
|
||||
let inner = inner.clone();
|
||||
let serve_func = move |remote_addr: Option<SocketAddr>| {
|
||||
let inner = inner.clone();
|
||||
async move {
|
||||
Ok::<_, hyper::Error>(service_fn(move |req: Request| {
|
||||
let inner = inner.clone();
|
||||
inner.call(req, remote_addr)
|
||||
}))
|
||||
}
|
||||
};
|
||||
let tls_config = (args.tls_cert.clone(), args.tls_key.clone());
|
||||
let server_handle = Arc::new(Server::init(args, running)?);
|
||||
let mut handles = vec![];
|
||||
for bind_addr in addrs.iter() {
|
||||
let server_handle = server_handle.clone();
|
||||
match bind_addr {
|
||||
BindAddr::Address(ip) => {
|
||||
let incoming = create_addr_incoming(SocketAddr::new(*ip, port))
|
||||
let listener = create_listener(SocketAddr::new(*ip, port))
|
||||
.with_context(|| format!("Failed to bind `{ip}:{port}`"))?;
|
||||
|
||||
match (&args.tls_cert, &args.tls_key) {
|
||||
match &tls_config {
|
||||
#[cfg(feature = "tls")]
|
||||
(Some(cert_file), Some(key_file)) => {
|
||||
let certs = load_certs(cert_file)?;
|
||||
let key = load_private_key(key_file)?;
|
||||
let config = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
let mut config = ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs.clone(), key.clone())?;
|
||||
.with_single_cert(certs, key)?;
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
let config = Arc::new(config);
|
||||
let accepter = TlsAcceptor::new(config.clone(), incoming);
|
||||
let new_service = make_service_fn(move |socket: &TlsStream| {
|
||||
let remote_addr = socket.remote_addr();
|
||||
serve_func(Some(remote_addr))
|
||||
let tls_accepter = TlsAcceptor::from(config);
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
let (cnx, addr) = listener.accept().await.unwrap();
|
||||
let Ok(stream) = tls_accepter.accept(cnx).await else {
|
||||
eprintln!(
|
||||
"Warning during tls handshake connection from {}",
|
||||
addr
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let stream = TokioIo::new(stream);
|
||||
tokio::spawn(handle_stream(
|
||||
server_handle.clone(),
|
||||
stream,
|
||||
Some(addr),
|
||||
));
|
||||
}
|
||||
});
|
||||
let server =
|
||||
tokio::spawn(hyper::Server::builder(accepter).serve(new_service));
|
||||
handles.push(server);
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
(None, None) => {
|
||||
let new_service = make_service_fn(move |socket: &AddrStream| {
|
||||
let remote_addr = socket.remote_addr();
|
||||
serve_func(Some(remote_addr))
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
let (cnx, addr) = listener.accept().await.unwrap();
|
||||
let stream = TokioIo::new(cnx);
|
||||
tokio::spawn(handle_stream(
|
||||
server_handle.clone(),
|
||||
stream,
|
||||
Some(addr),
|
||||
));
|
||||
}
|
||||
});
|
||||
let server =
|
||||
tokio::spawn(hyper::Server::builder(incoming).serve(new_service));
|
||||
handles.push(server);
|
||||
handles.push(handle);
|
||||
}
|
||||
_ => {
|
||||
unreachable!()
|
||||
@@ -130,10 +138,15 @@ fn serve(
|
||||
{
|
||||
let listener = tokio::net::UnixListener::bind(path)
|
||||
.with_context(|| format!("Failed to bind `{}`", path.display()))?;
|
||||
let acceptor = unix::UnixAcceptor::from_listener(listener);
|
||||
let new_service = make_service_fn(move |_| serve_func(None));
|
||||
let server = tokio::spawn(hyper::Server::builder(acceptor).serve(new_service));
|
||||
handles.push(server);
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
let (cnx, _) = listener.accept().await.unwrap();
|
||||
let stream = TokioIo::new(cnx);
|
||||
tokio::spawn(handle_stream(server_handle.clone(), stream, None));
|
||||
}
|
||||
});
|
||||
|
||||
handles.push(handle);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -141,7 +154,30 @@ fn serve(
|
||||
Ok(handles)
|
||||
}
|
||||
|
||||
fn create_addr_incoming(addr: SocketAddr) -> Result<AddrIncoming> {
|
||||
async fn handle_stream<T>(handle: Arc<Server>, stream: TokioIo<T>, addr: Option<SocketAddr>)
|
||||
where
|
||||
T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let hyper_service =
|
||||
service_fn(move |request: Request<Incoming>| handle.clone().call(request, addr));
|
||||
|
||||
let ret = Builder::new(TokioExecutor::new())
|
||||
.serve_connection_with_upgrades(stream, hyper_service)
|
||||
.await;
|
||||
|
||||
if let Err(err) = ret {
|
||||
let scope = match addr {
|
||||
Some(addr) => format!(" from {}", addr),
|
||||
None => String::new(),
|
||||
};
|
||||
match err.downcast_ref::<std::io::Error>() {
|
||||
Some(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}
|
||||
_ => eprintln!("Warning serving connection{}: {}", scope, err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_listener(addr: SocketAddr) -> Result<TcpListener> {
|
||||
use socket2::{Domain, Protocol, Socket, Type};
|
||||
let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
|
||||
if addr.is_ipv6() {
|
||||
@@ -152,11 +188,12 @@ fn create_addr_incoming(addr: SocketAddr) -> Result<AddrIncoming> {
|
||||
socket.listen(1024 /* Default backlog */)?;
|
||||
let std_listener = StdTcpListener::from(socket);
|
||||
std_listener.set_nonblocking(true)?;
|
||||
let incoming = AddrIncoming::from_listener(TcpListener::from_std(std_listener)?)?;
|
||||
Ok(incoming)
|
||||
let listener = TcpListener::from_std(std_listener)?;
|
||||
Ok(listener)
|
||||
}
|
||||
|
||||
fn print_listening(args: Arc<Args>) -> Result<()> {
|
||||
fn print_listening(args: &Args) -> Result<String> {
|
||||
let mut output = String::new();
|
||||
let mut bind_addrs = vec![];
|
||||
let (mut ipv4, mut ipv6) = (false, false);
|
||||
for bind_addr in args.addrs.iter() {
|
||||
@@ -209,17 +246,17 @@ fn print_listening(args: Arc<Args>) -> Result<()> {
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if urls.len() == 1 {
|
||||
println!("Listening on {}", urls[0]);
|
||||
output.push_str(&format!("Listening on {}", urls[0]))
|
||||
} else {
|
||||
let info = urls
|
||||
.iter()
|
||||
.map(|v| format!(" {v}"))
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
println!("Listening on:\n{info}\n");
|
||||
output.push_str(&format!("Listening on:\n{info}\n"))
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
async fn shutdown_signal() {
|
||||
|
||||
Reference in New Issue
Block a user