feat: upgrade to hyper 1.0 (#321)

This commit is contained in:
sigoden
2023-12-21 14:24:20 +08:00
committed by GitHub
parent 5988442d5c
commit 270cc0cba2
11 changed files with 595 additions and 509 deletions

View File

@@ -1,38 +1,37 @@
mod args;
mod auth;
mod http_logger;
mod http_utils;
mod logger;
mod server;
mod streamer;
#[cfg(feature = "tls")]
mod tls;
#[cfg(unix)]
mod unix;
mod utils;
#[macro_use]
extern crate log;
use crate::args::{build_cli, print_completions, Args};
use crate::server::{Request, Server};
use crate::server::Server;
#[cfg(feature = "tls")]
use crate::tls::{load_certs, load_private_key, TlsAcceptor, TlsStream};
use crate::utils::{load_certs, load_private_key};
use anyhow::{anyhow, Context, Result};
use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use args::BindAddr;
use clap_complete::Shell;
use futures::future::join_all;
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use futures_util::future::join_all;
use hyper::server::conn::{AddrIncoming, AddrStream};
use hyper::service::{make_service_fn, service_fn};
use hyper::{body::Incoming, service::service_fn, Request};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
};
use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use tokio::{net::TcpListener, task::JoinHandle};
#[cfg(feature = "tls")]
use rustls::ServerConfig;
use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
#[tokio::main]
async fn main() -> Result<()> {
@@ -45,10 +44,10 @@ async fn main() -> Result<()> {
return Ok(());
}
let args = Args::parse(matches)?;
let args = Arc::new(args);
let running = Arc::new(AtomicBool::new(true));
let handles = serve(args.clone(), running.clone())?;
print_listening(args)?;
let listening = print_listening(&args)?;
let handles = serve(args, running.clone())?;
println!("{listening}");
tokio::select! {
ret = join_all(handles) => {
@@ -66,56 +65,65 @@ async fn main() -> Result<()> {
}
}
fn serve(
args: Arc<Args>,
running: Arc<AtomicBool>,
) -> Result<Vec<JoinHandle<Result<(), hyper::Error>>>> {
let inner = Arc::new(Server::init(args.clone(), running)?);
let mut handles = vec![];
fn serve(args: Args, running: Arc<AtomicBool>) -> Result<Vec<JoinHandle<()>>> {
let addrs = args.addrs.clone();
let port = args.port;
for bind_addr in args.addrs.iter() {
let inner = inner.clone();
let serve_func = move |remote_addr: Option<SocketAddr>| {
let inner = inner.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |req: Request| {
let inner = inner.clone();
inner.call(req, remote_addr)
}))
}
};
let tls_config = (args.tls_cert.clone(), args.tls_key.clone());
let server_handle = Arc::new(Server::init(args, running)?);
let mut handles = vec![];
for bind_addr in addrs.iter() {
let server_handle = server_handle.clone();
match bind_addr {
BindAddr::Address(ip) => {
let incoming = create_addr_incoming(SocketAddr::new(*ip, port))
let listener = create_listener(SocketAddr::new(*ip, port))
.with_context(|| format!("Failed to bind `{ip}:{port}`"))?;
match (&args.tls_cert, &args.tls_key) {
match &tls_config {
#[cfg(feature = "tls")]
(Some(cert_file), Some(key_file)) => {
let certs = load_certs(cert_file)?;
let key = load_private_key(key_file)?;
let config = ServerConfig::builder()
.with_safe_defaults()
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs.clone(), key.clone())?;
.with_single_cert(certs, key)?;
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let config = Arc::new(config);
let accepter = TlsAcceptor::new(config.clone(), incoming);
let new_service = make_service_fn(move |socket: &TlsStream| {
let remote_addr = socket.remote_addr();
serve_func(Some(remote_addr))
let tls_accepter = TlsAcceptor::from(config);
let handle = tokio::spawn(async move {
loop {
let (cnx, addr) = listener.accept().await.unwrap();
let Ok(stream) = tls_accepter.accept(cnx).await else {
eprintln!(
"Warning during tls handshake connection from {}",
addr
);
continue;
};
let stream = TokioIo::new(stream);
tokio::spawn(handle_stream(
server_handle.clone(),
stream,
Some(addr),
));
}
});
let server =
tokio::spawn(hyper::Server::builder(accepter).serve(new_service));
handles.push(server);
handles.push(handle);
}
(None, None) => {
let new_service = make_service_fn(move |socket: &AddrStream| {
let remote_addr = socket.remote_addr();
serve_func(Some(remote_addr))
let handle = tokio::spawn(async move {
loop {
let (cnx, addr) = listener.accept().await.unwrap();
let stream = TokioIo::new(cnx);
tokio::spawn(handle_stream(
server_handle.clone(),
stream,
Some(addr),
));
}
});
let server =
tokio::spawn(hyper::Server::builder(incoming).serve(new_service));
handles.push(server);
handles.push(handle);
}
_ => {
unreachable!()
@@ -130,10 +138,15 @@ fn serve(
{
let listener = tokio::net::UnixListener::bind(path)
.with_context(|| format!("Failed to bind `{}`", path.display()))?;
let acceptor = unix::UnixAcceptor::from_listener(listener);
let new_service = make_service_fn(move |_| serve_func(None));
let server = tokio::spawn(hyper::Server::builder(acceptor).serve(new_service));
handles.push(server);
let handle = tokio::spawn(async move {
loop {
let (cnx, _) = listener.accept().await.unwrap();
let stream = TokioIo::new(cnx);
tokio::spawn(handle_stream(server_handle.clone(), stream, None));
}
});
handles.push(handle);
}
}
}
@@ -141,7 +154,30 @@ fn serve(
Ok(handles)
}
fn create_addr_incoming(addr: SocketAddr) -> Result<AddrIncoming> {
async fn handle_stream<T>(handle: Arc<Server>, stream: TokioIo<T>, addr: Option<SocketAddr>)
where
T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
let hyper_service =
service_fn(move |request: Request<Incoming>| handle.clone().call(request, addr));
let ret = Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(stream, hyper_service)
.await;
if let Err(err) = ret {
let scope = match addr {
Some(addr) => format!(" from {}", addr),
None => String::new(),
};
match err.downcast_ref::<std::io::Error>() {
Some(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}
_ => eprintln!("Warning serving connection{}: {}", scope, err),
}
}
}
fn create_listener(addr: SocketAddr) -> Result<TcpListener> {
use socket2::{Domain, Protocol, Socket, Type};
let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
if addr.is_ipv6() {
@@ -152,11 +188,12 @@ fn create_addr_incoming(addr: SocketAddr) -> Result<AddrIncoming> {
socket.listen(1024 /* Default backlog */)?;
let std_listener = StdTcpListener::from(socket);
std_listener.set_nonblocking(true)?;
let incoming = AddrIncoming::from_listener(TcpListener::from_std(std_listener)?)?;
Ok(incoming)
let listener = TcpListener::from_std(std_listener)?;
Ok(listener)
}
fn print_listening(args: Arc<Args>) -> Result<()> {
fn print_listening(args: &Args) -> Result<String> {
let mut output = String::new();
let mut bind_addrs = vec![];
let (mut ipv4, mut ipv6) = (false, false);
for bind_addr in args.addrs.iter() {
@@ -209,17 +246,17 @@ fn print_listening(args: Arc<Args>) -> Result<()> {
.collect::<Vec<_>>();
if urls.len() == 1 {
println!("Listening on {}", urls[0]);
output.push_str(&format!("Listening on {}", urls[0]))
} else {
let info = urls
.iter()
.map(|v| format!(" {v}"))
.collect::<Vec<String>>()
.join("\n");
println!("Listening on:\n{info}\n");
output.push_str(&format!("Listening on:\n{info}\n"))
}
Ok(())
Ok(output)
}
async fn shutdown_signal() {