feat: upgrade to hyper 1.0 (#321)

This commit is contained in:
sigoden
2023-12-21 14:24:20 +08:00
committed by GitHub
parent 5988442d5c
commit 270cc0cba2
11 changed files with 595 additions and 509 deletions

105
src/http_utils.rs Normal file
View File

@@ -0,0 +1,105 @@
use bytes::{Bytes, BytesMut};
use futures_util::Stream;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::body::{Body, Incoming};
use std::{
pin::Pin,
task::{Context, Poll},
};
use tokio::io::AsyncRead;
use tokio_util::io::poll_read_buf;
#[derive(Debug)]
pub struct IncomingStream {
inner: Incoming,
}
impl IncomingStream {
pub fn new(inner: Incoming) -> Self {
Self { inner }
}
}
impl Stream for IncomingStream {
type Item = Result<Bytes, anyhow::Error>;
#[inline]
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
match futures_util::ready!(Pin::new(&mut self.inner).poll_frame(cx)?) {
Some(frame) => match frame.into_data() {
Ok(data) => return Poll::Ready(Some(Ok(data))),
Err(_frame) => {}
},
None => return Poll::Ready(None),
}
}
}
}
pin_project_lite::pin_project! {
pub struct LengthLimitedStream<R> {
#[pin]
reader: Option<R>,
remaining: usize,
buf: BytesMut,
capacity: usize,
}
}
impl<R> LengthLimitedStream<R> {
pub fn new(reader: R, limit: usize) -> Self {
Self {
reader: Some(reader),
remaining: limit,
buf: BytesMut::new(),
capacity: 4096,
}
}
}
impl<R: AsyncRead> Stream for LengthLimitedStream<R> {
type Item = std::io::Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.as_mut().project();
if *this.remaining == 0 {
self.project().reader.set(None);
return Poll::Ready(None);
}
let reader = match this.reader.as_pin_mut() {
Some(r) => r,
None => return Poll::Ready(None),
};
if this.buf.capacity() == 0 {
this.buf.reserve(*this.capacity);
}
match poll_read_buf(reader, cx, &mut this.buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => {
self.project().reader.set(None);
Poll::Ready(Some(Err(err)))
}
Poll::Ready(Ok(0)) => {
self.project().reader.set(None);
Poll::Ready(None)
}
Poll::Ready(Ok(_)) => {
let mut chunk = this.buf.split();
let chunk_size = (*this.remaining).min(chunk.len());
chunk.truncate(chunk_size);
*this.remaining -= chunk_size;
Poll::Ready(Some(Ok(chunk.freeze())))
}
}
}
}
pub fn body_full(content: impl Into<hyper::body::Bytes>) -> BoxBody<Bytes, anyhow::Error> {
Full::new(content.into())
.map_err(anyhow::Error::new)
.boxed()
}

View File

@@ -1,38 +1,37 @@
mod args;
mod auth;
mod http_logger;
mod http_utils;
mod logger;
mod server;
mod streamer;
#[cfg(feature = "tls")]
mod tls;
#[cfg(unix)]
mod unix;
mod utils;
#[macro_use]
extern crate log;
use crate::args::{build_cli, print_completions, Args};
use crate::server::{Request, Server};
use crate::server::Server;
#[cfg(feature = "tls")]
use crate::tls::{load_certs, load_private_key, TlsAcceptor, TlsStream};
use crate::utils::{load_certs, load_private_key};
use anyhow::{anyhow, Context, Result};
use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use args::BindAddr;
use clap_complete::Shell;
use futures::future::join_all;
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use futures_util::future::join_all;
use hyper::server::conn::{AddrIncoming, AddrStream};
use hyper::service::{make_service_fn, service_fn};
use hyper::{body::Incoming, service::service_fn, Request};
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
};
use std::net::{IpAddr, SocketAddr, TcpListener as StdTcpListener};
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use tokio::{net::TcpListener, task::JoinHandle};
#[cfg(feature = "tls")]
use rustls::ServerConfig;
use tokio_rustls::{rustls::ServerConfig, TlsAcceptor};
#[tokio::main]
async fn main() -> Result<()> {
@@ -45,10 +44,10 @@ async fn main() -> Result<()> {
return Ok(());
}
let args = Args::parse(matches)?;
let args = Arc::new(args);
let running = Arc::new(AtomicBool::new(true));
let handles = serve(args.clone(), running.clone())?;
print_listening(args)?;
let listening = print_listening(&args)?;
let handles = serve(args, running.clone())?;
println!("{listening}");
tokio::select! {
ret = join_all(handles) => {
@@ -66,56 +65,65 @@ async fn main() -> Result<()> {
}
}
fn serve(
args: Arc<Args>,
running: Arc<AtomicBool>,
) -> Result<Vec<JoinHandle<Result<(), hyper::Error>>>> {
let inner = Arc::new(Server::init(args.clone(), running)?);
let mut handles = vec![];
fn serve(args: Args, running: Arc<AtomicBool>) -> Result<Vec<JoinHandle<()>>> {
let addrs = args.addrs.clone();
let port = args.port;
for bind_addr in args.addrs.iter() {
let inner = inner.clone();
let serve_func = move |remote_addr: Option<SocketAddr>| {
let inner = inner.clone();
async move {
Ok::<_, hyper::Error>(service_fn(move |req: Request| {
let inner = inner.clone();
inner.call(req, remote_addr)
}))
}
};
let tls_config = (args.tls_cert.clone(), args.tls_key.clone());
let server_handle = Arc::new(Server::init(args, running)?);
let mut handles = vec![];
for bind_addr in addrs.iter() {
let server_handle = server_handle.clone();
match bind_addr {
BindAddr::Address(ip) => {
let incoming = create_addr_incoming(SocketAddr::new(*ip, port))
let listener = create_listener(SocketAddr::new(*ip, port))
.with_context(|| format!("Failed to bind `{ip}:{port}`"))?;
match (&args.tls_cert, &args.tls_key) {
match &tls_config {
#[cfg(feature = "tls")]
(Some(cert_file), Some(key_file)) => {
let certs = load_certs(cert_file)?;
let key = load_private_key(key_file)?;
let config = ServerConfig::builder()
.with_safe_defaults()
let mut config = ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs.clone(), key.clone())?;
.with_single_cert(certs, key)?;
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
let config = Arc::new(config);
let accepter = TlsAcceptor::new(config.clone(), incoming);
let new_service = make_service_fn(move |socket: &TlsStream| {
let remote_addr = socket.remote_addr();
serve_func(Some(remote_addr))
let tls_accepter = TlsAcceptor::from(config);
let handle = tokio::spawn(async move {
loop {
let (cnx, addr) = listener.accept().await.unwrap();
let Ok(stream) = tls_accepter.accept(cnx).await else {
eprintln!(
"Warning during tls handshake connection from {}",
addr
);
continue;
};
let stream = TokioIo::new(stream);
tokio::spawn(handle_stream(
server_handle.clone(),
stream,
Some(addr),
));
}
});
let server =
tokio::spawn(hyper::Server::builder(accepter).serve(new_service));
handles.push(server);
handles.push(handle);
}
(None, None) => {
let new_service = make_service_fn(move |socket: &AddrStream| {
let remote_addr = socket.remote_addr();
serve_func(Some(remote_addr))
let handle = tokio::spawn(async move {
loop {
let (cnx, addr) = listener.accept().await.unwrap();
let stream = TokioIo::new(cnx);
tokio::spawn(handle_stream(
server_handle.clone(),
stream,
Some(addr),
));
}
});
let server =
tokio::spawn(hyper::Server::builder(incoming).serve(new_service));
handles.push(server);
handles.push(handle);
}
_ => {
unreachable!()
@@ -130,10 +138,15 @@ fn serve(
{
let listener = tokio::net::UnixListener::bind(path)
.with_context(|| format!("Failed to bind `{}`", path.display()))?;
let acceptor = unix::UnixAcceptor::from_listener(listener);
let new_service = make_service_fn(move |_| serve_func(None));
let server = tokio::spawn(hyper::Server::builder(acceptor).serve(new_service));
handles.push(server);
let handle = tokio::spawn(async move {
loop {
let (cnx, _) = listener.accept().await.unwrap();
let stream = TokioIo::new(cnx);
tokio::spawn(handle_stream(server_handle.clone(), stream, None));
}
});
handles.push(handle);
}
}
}
@@ -141,7 +154,30 @@ fn serve(
Ok(handles)
}
fn create_addr_incoming(addr: SocketAddr) -> Result<AddrIncoming> {
async fn handle_stream<T>(handle: Arc<Server>, stream: TokioIo<T>, addr: Option<SocketAddr>)
where
T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
let hyper_service =
service_fn(move |request: Request<Incoming>| handle.clone().call(request, addr));
let ret = Builder::new(TokioExecutor::new())
.serve_connection_with_upgrades(stream, hyper_service)
.await;
if let Err(err) = ret {
let scope = match addr {
Some(addr) => format!(" from {}", addr),
None => String::new(),
};
match err.downcast_ref::<std::io::Error>() {
Some(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {}
_ => eprintln!("Warning serving connection{}: {}", scope, err),
}
}
}
fn create_listener(addr: SocketAddr) -> Result<TcpListener> {
use socket2::{Domain, Protocol, Socket, Type};
let socket = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))?;
if addr.is_ipv6() {
@@ -152,11 +188,12 @@ fn create_addr_incoming(addr: SocketAddr) -> Result<AddrIncoming> {
socket.listen(1024 /* Default backlog */)?;
let std_listener = StdTcpListener::from(socket);
std_listener.set_nonblocking(true)?;
let incoming = AddrIncoming::from_listener(TcpListener::from_std(std_listener)?)?;
Ok(incoming)
let listener = TcpListener::from_std(std_listener)?;
Ok(listener)
}
fn print_listening(args: Arc<Args>) -> Result<()> {
fn print_listening(args: &Args) -> Result<String> {
let mut output = String::new();
let mut bind_addrs = vec![];
let (mut ipv4, mut ipv6) = (false, false);
for bind_addr in args.addrs.iter() {
@@ -209,17 +246,17 @@ fn print_listening(args: Arc<Args>) -> Result<()> {
.collect::<Vec<_>>();
if urls.len() == 1 {
println!("Listening on {}", urls[0]);
output.push_str(&format!("Listening on {}", urls[0]))
} else {
let info = urls
.iter()
.map(|v| format!(" {v}"))
.collect::<Vec<String>>()
.join("\n");
println!("Listening on:\n{info}\n");
output.push_str(&format!("Listening on:\n{info}\n"))
}
Ok(())
Ok(output)
}
async fn shutdown_signal() {

View File

@@ -1,29 +1,32 @@
#![allow(clippy::too_many_arguments)]
use crate::auth::{www_authenticate, AccessPaths, AccessPerm};
use crate::streamer::Streamer;
use crate::http_utils::{body_full, IncomingStream, LengthLimitedStream};
use crate::utils::{
decode_uri, encode_uri, get_file_mtime_and_mode, get_file_name, glob, try_get_file_name,
};
use crate::Args;
use anyhow::{anyhow, Result};
use walkdir::WalkDir;
use xml::escape::escape_str_pcdata;
use async_zip::tokio::write::ZipFileWriter;
use async_zip::{Compression, ZipDateTime, ZipEntryBuilder};
use anyhow::{anyhow, Result};
use async_zip::{tokio::write::ZipFileWriter, Compression, ZipDateTime, ZipEntryBuilder};
use bytes::Bytes;
use chrono::{LocalResult, TimeZone, Utc};
use futures::TryStreamExt;
use futures_util::{pin_mut, TryStreamExt};
use headers::{
AcceptRanges, AccessControlAllowCredentials, AccessControlAllowOrigin, CacheControl,
ContentLength, ContentType, ETag, HeaderMap, HeaderMapExt, IfModifiedSince, IfNoneMatch,
IfRange, LastModified, Range,
};
use hyper::header::{
HeaderValue, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE,
RANGE,
use http_body_util::{combinators::BoxBody, BodyExt, StreamBody};
use hyper::body::Frame;
use hyper::{
body::Incoming,
header::{
HeaderValue, AUTHORIZATION, CONTENT_DISPOSITION, CONTENT_LENGTH, CONTENT_RANGE,
CONTENT_TYPE, RANGE,
},
Method, StatusCode, Uri,
};
use hyper::{Body, Method, StatusCode, Uri};
use serde::Serialize;
use std::borrow::Cow;
use std::cmp::Ordering;
@@ -39,11 +42,13 @@ use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWrite};
use tokio::{fs, io};
use tokio_util::compat::FuturesAsyncWriteCompatExt;
use tokio_util::io::StreamReader;
use tokio_util::io::{ReaderStream, StreamReader};
use uuid::Uuid;
use walkdir::WalkDir;
use xml::escape::escape_str_pcdata;
pub type Request = hyper::Request<Body>;
pub type Response = hyper::Response<Body>;
pub type Request = hyper::Request<Incoming>;
pub type Response = hyper::Response<BoxBody<Bytes, anyhow::Error>>;
const INDEX_HTML: &str = include_str!("../assets/index.html");
const INDEX_CSS: &str = include_str!("../assets/index.css");
@@ -54,7 +59,7 @@ const BUF_SIZE: usize = 65536;
const TEXT_MAX_SIZE: u64 = 4194304; // 4M
pub struct Server {
args: Arc<Args>,
args: Args,
assets_prefix: String,
html: Cow<'static, str>,
single_file_req_paths: Vec<String>,
@@ -62,7 +67,7 @@ pub struct Server {
}
impl Server {
pub fn init(args: Arc<Args>, running: Arc<AtomicBool>) -> Result<Self> {
pub fn init(args: Args, running: Arc<AtomicBool>) -> Result<Self> {
let assets_prefix = format!("{}__dufs_v{}_", args.uri_prefix, env!("CARGO_PKG_VERSION"));
let single_file_req_paths = if args.path_is_file {
vec![
@@ -365,7 +370,7 @@ impl Server {
status_forbid(&mut res);
} else if !is_miss {
*res.status_mut() = StatusCode::METHOD_NOT_ALLOWED;
*res.body_mut() = Body::from("Already exists");
*res.body_mut() = body_full("Already exists");
} else {
self.handle_mkcol(path, &mut res).await?;
}
@@ -411,7 +416,7 @@ impl Server {
Ok(res)
}
async fn handle_upload(&self, path: &Path, mut req: Request, res: &mut Response) -> Result<()> {
async fn handle_upload(&self, path: &Path, req: Request, res: &mut Response) -> Result<()> {
ensure_path_parent(path).await?;
let mut file = match fs::File::create(&path).await {
@@ -422,13 +427,12 @@ impl Server {
}
};
let body_with_io_error = req
.body_mut()
.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
let stream = IncomingStream::new(req.into_body());
let body_with_io_error = stream.map_err(|err| io::Error::new(io::ErrorKind::Other, err));
let body_reader = StreamReader::new(body_with_io_error);
futures::pin_mut!(body_reader);
pin_mut!(body_reader);
let ret = io::copy(&mut body_reader, &mut file).await;
if ret.is_err() {
@@ -596,8 +600,14 @@ impl Server {
error!("Failed to zip {}, {}", path.display(), e);
}
});
let reader = Streamer::new(reader, BUF_SIZE);
*res.body_mut() = Body::wrap_stream(reader.into_stream());
let reader_stream = ReaderStream::new(reader);
let stream_body = StreamBody::new(
reader_stream
.map_ok(Frame::data)
.map_err(|err| anyhow!("{err}")),
);
let boxed_body = stream_body.boxed();
*res.body_mut() = boxed_body;
Ok(())
}
@@ -660,21 +670,21 @@ impl Server {
}
None => match name {
"index.js" => {
*res.body_mut() = Body::from(INDEX_JS);
*res.body_mut() = body_full(INDEX_JS);
res.headers_mut().insert(
"content-type",
HeaderValue::from_static("application/javascript; charset=UTF-8"),
);
}
"index.css" => {
*res.body_mut() = Body::from(INDEX_CSS);
*res.body_mut() = body_full(INDEX_CSS);
res.headers_mut().insert(
"content-type",
HeaderValue::from_static("text/css; charset=UTF-8"),
);
}
"favicon.ico" => {
*res.body_mut() = Body::from(FAVICON_ICO);
*res.body_mut() = body_full(FAVICON_ICO);
res.headers_mut()
.insert("content-type", HeaderValue::from_static("image/x-icon"));
}
@@ -761,18 +771,24 @@ impl Server {
&& file.seek(SeekFrom::Start(range.start)).await.is_ok()
{
let end = range.end.unwrap_or(size - 1).min(size - 1);
let part_size = end - range.start + 1;
let reader = Streamer::new(file, BUF_SIZE);
let range_size = end - range.start + 1;
*res.status_mut() = StatusCode::PARTIAL_CONTENT;
let content_range = format!("bytes {}-{}/{}", range.start, end, size);
res.headers_mut()
.insert(CONTENT_RANGE, content_range.parse()?);
res.headers_mut()
.insert(CONTENT_LENGTH, format!("{part_size}").parse()?);
.insert(CONTENT_LENGTH, format!("{range_size}").parse()?);
if head_only {
return Ok(());
}
*res.body_mut() = Body::wrap_stream(reader.into_stream_sized(part_size));
let stream_body = StreamBody::new(
LengthLimitedStream::new(file, range_size as usize)
.map_ok(Frame::data)
.map_err(|err| anyhow!("{err}")),
);
let boxed_body = stream_body.boxed();
*res.body_mut() = boxed_body;
} else {
*res.status_mut() = StatusCode::RANGE_NOT_SATISFIABLE;
res.headers_mut()
@@ -784,8 +800,15 @@ impl Server {
if head_only {
return Ok(());
}
let reader = Streamer::new(file, BUF_SIZE);
*res.body_mut() = Body::wrap_stream(reader.into_stream());
let reader_stream = ReaderStream::new(file);
let stream_body = StreamBody::new(
reader_stream
.map_ok(Frame::data)
.map_err(|err| anyhow!("{err}")),
);
let boxed_body = stream_body.boxed();
*res.body_mut() = boxed_body;
}
Ok(())
}
@@ -828,7 +851,7 @@ impl Server {
if head_only {
return Ok(());
}
*res.body_mut() = output.into();
*res.body_mut() = body_full(output);
Ok(())
}
@@ -943,7 +966,7 @@ impl Server {
res.headers_mut()
.insert("lock-token", format!("<{token}>").parse()?);
*res.body_mut() = Body::from(format!(
*res.body_mut() = body_full(format!(
r#"<?xml version="1.0" encoding="utf-8"?>
<D:prop xmlns:D="DAV:"><D:lockdiscovery><D:activelock>
<D:locktoken><D:href>{token}</D:href></D:locktoken>
@@ -1014,7 +1037,7 @@ impl Server {
.typed_insert(ContentType::from(mime_guess::mime::TEXT_HTML_UTF_8));
res.headers_mut()
.typed_insert(ContentLength(output.as_bytes().len() as u64));
*res.body_mut() = output.into();
*res.body_mut() = body_full(output);
if head_only {
return Ok(());
}
@@ -1060,7 +1083,7 @@ impl Server {
if head_only {
return Ok(());
}
*res.body_mut() = output.into();
*res.body_mut() = body_full(output);
Ok(())
}
@@ -1419,7 +1442,7 @@ fn res_multistatus(res: &mut Response, content: &str) {
"content-type",
HeaderValue::from_static("application/xml; charset=utf-8"),
);
*res.body_mut() = Body::from(format!(
*res.body_mut() = body_full(format!(
r#"<?xml version="1.0" encoding="utf-8" ?>
<D:multistatus xmlns:D="DAV:">
{content}
@@ -1539,12 +1562,12 @@ fn parse_range(headers: &HeaderMap<HeaderValue>) -> Option<RangeValue> {
fn status_forbid(res: &mut Response) {
*res.status_mut() = StatusCode::FORBIDDEN;
*res.body_mut() = Body::from("Forbidden");
*res.body_mut() = body_full("Forbidden");
}
fn status_not_found(res: &mut Response) {
*res.status_mut() = StatusCode::NOT_FOUND;
*res.body_mut() = Body::from("Not Found");
*res.body_mut() = body_full("Not Found");
}
fn status_no_content(res: &mut Response) {

View File

@@ -1,68 +0,0 @@
use async_stream::stream;
use futures::{Stream, StreamExt};
use std::io::Error;
use std::pin::Pin;
use tokio::io::{AsyncRead, AsyncReadExt};
pub struct Streamer<R>
where
R: AsyncRead + Unpin + Send + 'static,
{
reader: R,
buf_size: usize,
}
impl<R> Streamer<R>
where
R: AsyncRead + Unpin + Send + 'static,
{
#[inline]
pub fn new(reader: R, buf_size: usize) -> Self {
Self { reader, buf_size }
}
pub fn into_stream(
mut self,
) -> Pin<Box<impl ?Sized + Stream<Item = Result<Vec<u8>, Error>> + 'static>> {
let stream = stream! {
loop {
let mut buf = vec![0; self.buf_size];
let r = self.reader.read(&mut buf).await?;
if r == 0 {
break
}
buf.truncate(r);
yield Ok(buf);
}
};
stream.boxed()
}
// allow truncation as truncated remaining is always less than buf_size: usize
pub fn into_stream_sized(
mut self,
max_length: u64,
) -> Pin<Box<impl ?Sized + Stream<Item = Result<Vec<u8>, Error>> + 'static>> {
let stream = stream! {
let mut remaining = max_length;
loop {
if remaining == 0 {
break;
}
let bs = if remaining >= self.buf_size as u64 {
self.buf_size
} else {
remaining as usize
};
let mut buf = vec![0; bs];
let r = self.reader.read(&mut buf).await?;
if r == 0 {
break;
} else {
buf.truncate(r);
yield Ok(buf);
}
remaining -= r as u64;
}
};
stream.boxed()
}
}

View File

@@ -1,161 +0,0 @@
use anyhow::{anyhow, bail, Context as AnyhowContext, Result};
use core::task::{Context, Poll};
use futures::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use rustls::{Certificate, PrivateKey};
use std::future::Future;
use std::net::SocketAddr;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::{fs, io};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::ServerConfig;
enum State {
Handshaking(tokio_rustls::Accept<AddrStream>),
Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}
// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
pub struct TlsStream {
state: State,
remote_addr: SocketAddr,
}
impl TlsStream {
fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
let remote_addr = stream.remote_addr();
let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
TlsStream {
state: State::Handshaking(accept),
remote_addr,
}
}
pub fn remote_addr(&self) -> SocketAddr {
self.remote_addr
}
}
impl AsyncRead for TlsStream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf,
) -> Poll<io::Result<()>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl AsyncWrite for TlsStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
},
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
pub struct TlsAcceptor {
config: Arc<ServerConfig>,
incoming: AddrIncoming,
}
impl TlsAcceptor {
pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> TlsAcceptor {
TlsAcceptor { config, incoming }
}
}
impl Accept for TlsAcceptor {
type Conn = TlsStream;
type Error = io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
Some(Err(e)) => Poll::Ready(Some(Err(e))),
None => Poll::Ready(None),
}
}
}
// Load public certificate from file.
pub fn load_certs<T: AsRef<Path>>(filename: T) -> Result<Vec<Certificate>> {
// Open certificate file.
let cert_file = fs::File::open(filename.as_ref())
.with_context(|| format!("Failed to access `{}`", filename.as_ref().display()))?;
let mut reader = io::BufReader::new(cert_file);
// Load and return certificate.
let certs = rustls_pemfile::certs(&mut reader).with_context(|| "Failed to load certificate")?;
if certs.is_empty() {
bail!("No supported certificate in file");
}
Ok(certs.into_iter().map(Certificate).collect())
}
// Load private key from file.
pub fn load_private_key<T: AsRef<Path>>(filename: T) -> Result<PrivateKey> {
let key_file = fs::File::open(filename.as_ref())
.with_context(|| format!("Failed to access `{}`", filename.as_ref().display()))?;
let mut reader = io::BufReader::new(key_file);
// Load and return a single private key.
let keys = rustls_pemfile::read_all(&mut reader)
.with_context(|| "There was a problem with reading private key")?
.into_iter()
.find_map(|item| match item {
rustls_pemfile::Item::RSAKey(key)
| rustls_pemfile::Item::PKCS8Key(key)
| rustls_pemfile::Item::ECKey(key) => Some(key),
_ => None,
})
.ok_or_else(|| anyhow!("No supported private key in file"))?;
Ok(PrivateKey(keys))
}

View File

@@ -1,31 +0,0 @@
use hyper::server::accept::Accept;
use tokio::net::UnixListener;
use std::pin::Pin;
use std::task::{Context, Poll};
pub struct UnixAcceptor {
inner: UnixListener,
}
impl UnixAcceptor {
pub fn from_listener(listener: UnixListener) -> Self {
Self { inner: listener }
}
}
impl Accept for UnixAcceptor {
type Conn = tokio::net::UnixStream;
type Error = std::io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
match self.inner.poll_accept(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok((socket, _addr))) => Poll::Ready(Some(Ok(socket))),
Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err))),
}
}
}

View File

@@ -1,5 +1,7 @@
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
#[cfg(feature = "tls")]
use rustls_pki_types::{CertificateDer, PrivateKeyDer};
use std::{
borrow::Cow,
path::Path,
@@ -58,6 +60,46 @@ pub fn glob(pattern: &str, target: &str) -> bool {
pat.matches(target)
}
// Load public certificate from file.
#[cfg(feature = "tls")]
pub fn load_certs<T: AsRef<Path>>(filename: T) -> Result<Vec<CertificateDer<'static>>> {
// Open certificate file.
let cert_file = std::fs::File::open(filename.as_ref())
.with_context(|| format!("Failed to access `{}`", filename.as_ref().display()))?;
let mut reader = std::io::BufReader::new(cert_file);
// Load and return certificate.
let mut certs = vec![];
for cert in rustls_pemfile::certs(&mut reader) {
let cert = cert.with_context(|| "Failed to load certificate")?;
certs.push(cert)
}
if certs.is_empty() {
anyhow::bail!("No supported certificate in file");
}
Ok(certs)
}
// Load private key from file.
#[cfg(feature = "tls")]
pub fn load_private_key<T: AsRef<Path>>(filename: T) -> Result<PrivateKeyDer<'static>> {
let key_file = std::fs::File::open(filename.as_ref())
.with_context(|| format!("Failed to access `{}`", filename.as_ref().display()))?;
let mut reader = std::io::BufReader::new(key_file);
// Load and return a single private key.
for key in rustls_pemfile::read_all(&mut reader) {
let key = key.with_context(|| "There was a problem with reading private key")?;
match key {
rustls_pemfile::Item::Pkcs1Key(key) => return Ok(PrivateKeyDer::Pkcs1(key)),
rustls_pemfile::Item::Pkcs8Key(key) => return Ok(PrivateKeyDer::Pkcs8(key)),
rustls_pemfile::Item::Sec1Key(key) => return Ok(PrivateKeyDer::Sec1(key)),
_ => {}
}
}
anyhow::bail!("No supported private key in file");
}
#[test]
fn test_glob_key() {
assert!(glob("", ""));