Rust 程序设计语言 web server (tokio)

Rust 程序设计语言 web server

1 web server 源码

// src/bin/main.rs
use std::fs;
use std::io::prelude::*;
use std::net::{TcpListener, TcpStream};
use web_server::ThreadPool;

/// Entry point: accepts TCP connections on 127.0.0.1:9981 and hands each one to a
/// four-thread `ThreadPool`, which runs `handle_connection` for it.
fn main() {
    let listener = TcpListener::bind("127.0.0.1:9981").expect("failed to bind 127.0.0.1:9981");
    let pool = ThreadPool::new(4);

    for stream in listener.incoming() {
        // Resolve accept errors here on the accept loop. Unwrapping inside the job (as before)
        // would turn an error for a single connection into a panic on a worker thread.
        let stream = match stream {
            Ok(stream) => stream,
            Err(e) => {
                eprintln!("failed to accept connection: {}", e);
                continue;
            }
        };

        pool.execute(move || handle_connection(stream));
    }
}

/// Reads one request from `stream` and answers it with the contents of an HTML file.
///
/// Routing:
/// - `GET /` is answered with `hello.html`.
/// - `GET /sleep` is answered with `hello.html` after a blocking five-second pause.
/// - Anything else is answered with `404.html` and a `404 Not Found` status line.
///
/// Files are read relative to the process's current working directory.
///
/// I/O failures are reported on stderr and the connection is closed instead of panicking.
/// These include read or write errors and a missing HTML file.
fn handle_connection(mut stream: TcpStream) {
    if let Err(e) = serve_request(&mut stream) {
        eprintln!("error while handling connection: {}", e);
    }
}

/// Does the request handling for `handle_connection`, propagating I/O errors with `?`.
fn serve_request(stream: &mut TcpStream) -> std::io::Result<()> {
    // Only a single read is performed and only the prefix of the request line is inspected.
    let mut buffer = [0; 1024];
    let size = stream.read(&mut buffer)?;
    let request = &buffer[..size];

    let get = b"GET / HTTP/1.1\r\n";
    let sleep = b"GET /sleep HTTP/1.1\r\n";

    let (status_line, filename) = if request.starts_with(get) {
        ("HTTP/1.1 200 OK", "hello.html")
    } else if request.starts_with(sleep) {
        std::thread::sleep(std::time::Duration::from_secs(5));
        ("HTTP/1.1 200 OK", "hello.html")
    } else {
        ("HTTP/1.1 404 Not Found", "404.html")
    };

    let body = fs::read_to_string(filename)?;
    let response = format!(
        "{}\r\nContent-Length: {}\r\n\r\n{}",
        status_line,
        body.len(),
        body
    );
    // `write` may send only part of the buffer; `write_all` keeps writing until all of it is sent.
    stream.write_all(response.as_bytes())
}

// src/lib.rs
use std::sync::{Arc, Mutex, mpsc};

/// A boxed closure that a worker runs exactly once.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// A message sent from the pool to its workers over the shared channel.
enum Message {
    /// Run the contained job.
    NewJob(Job),
    /// Leave the receive loop so the worker thread can be joined.
    Terminate,
}

/// A fixed-size pool of OS threads that run submitted closures.
///
/// All workers receive from one shared `mpsc` channel behind a mutex, so whichever worker is
/// free takes the next job. Dropping the pool asks every worker to stop and waits for it.
pub struct ThreadPool {
    workers: Vec<Worker>,
    sender: mpsc::Sender<Message>,
}

impl ThreadPool {
    /// Creates a pool with `size` worker threads.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn new(size: usize) -> Self {
        assert!(size > 0);

        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));

        let workers = (0..size)
            .map(|id| Worker::new(id, Arc::clone(&receiver)))
            .collect();

        Self { workers, sender }
    }

    /// Queues `f` to run on one of the worker threads.
    ///
    /// # Panics
    ///
    /// Panics if no worker is left to receive jobs (the channel's receiver has been dropped).
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender
            .send(Message::NewJob(job))
            .expect("all worker threads have exited");
    }
}

impl Drop for ThreadPool {
    /// Sends one `Terminate` per worker, then joins every worker thread.
    ///
    /// A failed send or a worker that panicked is reported rather than turned into a
    /// panic inside `drop`.
    fn drop(&mut self) {
        println!("Sending terminate message to all workers.");

        for _ in &self.workers {
            // An error only means the receiver is gone, i.e. every worker has already exited.
            let _ = self.sender.send(Message::Terminate);
        }

        println!("Shutting down all workers.");

        for worker in &mut self.workers {
            println!("Shutting down worker {}", worker.id);

            if let Some(thread) = worker.thread.take() {
                if thread.join().is_err() {
                    eprintln!("Worker {} panicked", worker.id);
                }
            }
        }
    }
}

/// One pool thread together with the id used in its log messages.
struct Worker {
    id: usize,
    /// `None` once the thread has been joined by `ThreadPool::drop`.
    thread: Option<std::thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns a thread that takes one message at a time from the shared `receiver`.
    ///
    /// The thread runs each job it receives. It stops on `Terminate` or when the channel is
    /// closed.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Message>>>) -> Self {
        use std::panic::{self, AssertUnwindSafe};

        let thread = std::thread::spawn(move || loop {
            // The `MutexGuard` is a temporary of this `let` statement, so the lock is released
            // before the job runs and other workers can receive in the meantime.
            //
            // Do not write `while let Ok(job) = receiver.lock().unwrap().recv() { job(); }`.
            // Temporaries in a `while let` scrutinee live until the end of the loop body, so
            // the lock would stay held while `job()` runs. Because `recv()` blocks, only one
            // worker could make progress at a time.
            let received = receiver.lock().expect("receiver mutex poisoned").recv();
            let message = match received {
                Ok(message) => message,
                // Every sender is gone; no more messages can arrive.
                Err(_) => break,
            };

            match message {
                Message::NewJob(job) => {
                    println!("Worker {} got a job; executing.", id);
                    // Contain a panicking job so this worker keeps serving the pool instead of
                    // dying. The job is dropped afterwards, so no state of it is reused.
                    if panic::catch_unwind(AssertUnwindSafe(job)).is_err() {
                        eprintln!("Worker {} job panicked; continuing.", id);
                    }
                }
                Message::Terminate => {
                    println!("Worker {} was told to terminate.", id);
                    break;
                }
            }
        });

        Self {
            id,
            thread: Some(thread),
        }
    }
}

2 web server (tokio)

// src/main.rs
mod thread_pool;

use std::io::prelude::*;
use std::net::{TcpListener, TcpStream};
use thread_pool::ThreadPool;
use tokio;

/// Entry point of the tokio variant: accepts connections on 127.0.0.1:9981 and dispatches
/// each one to the four-worker `ThreadPool`.
///
/// NOTE(review): `std::net::TcpListener::incoming` blocks the thread driving this async
/// `main`. This appears harmless only because the pool's workers run on their own std
/// threads, not on the tokio runtime. `tokio::net::TcpListener` is the non-blocking
/// alternative.
#[tokio::main]
async fn main() {
    let listener = TcpListener::bind("127.0.0.1:9981").expect("failed to bind 127.0.0.1:9981");
    let thread_pool = ThreadPool::new(4);

    for stream in listener.incoming() {
        // Handle accept errors here. Unwrapping inside the task (as before) would panic the
        // worker thread that runs it.
        let stream = match stream {
            Ok(stream) => stream,
            Err(e) => {
                eprintln!("failed to accept connection: {}", e);
                continue;
            }
        };

        thread_pool.execute(move || handle_connection(stream)).await;
    }
}

/// Reads one request from `stream` and writes a plain-text response.
///
/// Routing:
/// - `GET /` is answered with `200 OK` and body `Hello Rust`.
/// - `GET /sleep` gets the same answer after a blocking five-second pause.
/// - Anything else is answered with `404 Not Found` and body `Not Found`.
///
/// I/O errors are reported on stderr instead of panicking the worker thread.
fn handle_connection(mut stream: TcpStream) {
    if let Err(e) = serve_request(&mut stream) {
        eprintln!("error while handling connection: {}", e);
    }
}

/// Does the request handling for `handle_connection`, propagating I/O errors with `?`.
fn serve_request(stream: &mut TcpStream) -> std::io::Result<()> {
    // Only a single read is performed and only the prefix of the request line is inspected.
    let mut buffer = [0; 1024];
    let size = stream.read(&mut buffer)?;
    let request = &buffer[..size];

    let get = b"GET / HTTP/1.1\r\n";
    let sleep = b"GET /sleep HTTP/1.1\r\n";

    let (status_line, content) = if request.starts_with(get) {
        ("HTTP/1.1 200 OK", "Hello Rust")
    } else if request.starts_with(sleep) {
        std::thread::sleep(std::time::Duration::from_secs(5));
        ("HTTP/1.1 200 OK", "Hello Rust")
    } else {
        ("HTTP/1.1 404 Not Found", "Not Found")
    };

    let response = format!(
        "{}\r\nContent-Length: {}\r\n\r\n{}",
        status_line,
        content.len(),
        content
    );
    // `write` may send only part of the buffer; `write_all` keeps writing until all of it is sent.
    stream.write_all(response.as_bytes())
}

// src/thread_pool.rs
use std::time::{SystemTime, UNIX_EPOCH};
use futures::executor::block_on;
use tokio::sync::mpsc;

/// A unit of work: a boxed closure run once on a worker thread.
/// (`Box<dyn ...>` in a type alias already defaults to a `'static` trait object.)
type Task = Box<dyn FnOnce() + Send>;

/// Commands delivered to a single worker through its own channel.
enum Message {
    /// Execute the boxed task.
    NewTask(Task),
    /// Stop the worker's receive loop.
    Terminate,
}

/// A pool of worker threads, each fed through its own bounded tokio channel.
///
/// Tasks are distributed to workers in round-robin order.
pub struct ThreadPool {
    senders: Vec<(Worker, mpsc::Sender<Message>)>,
    /// Counter whose value (modulo the number of workers) selects the next worker.
    next: std::sync::atomic::AtomicUsize,
}

impl ThreadPool {
    /// Spawns `size` workers.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn new(size: usize) -> Self {
        assert!(size > 0);

        let senders = (0..size).map(Worker::new).collect();

        Self {
            senders,
            next: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Sends `f` to the next worker in round-robin order.
    ///
    /// Waits while that worker's queue is full. If the chosen worker has already stopped,
    /// the task is dropped and an error is printed to stderr.
    pub async fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        use std::sync::atomic::Ordering;

        // Round-robin replaces the old `SystemTime % len` choice, which was nondeterministic
        // and could pile tasks onto one worker. `Relaxed` is enough because the counter only
        // picks an index and orders no other memory.
        let index = self.next.fetch_add(1, Ordering::Relaxed) % self.senders.len();
        let (worker, sender) = &self.senders[index];

        if sender.send(Message::NewTask(Box::new(f))).await.is_err() {
            eprintln!("Worker {} is no longer running; task dropped.", worker.id);
        }
    }
}

impl Drop for ThreadPool {
    /// Asks every worker to stop, then joins its thread.
    ///
    /// Two conditions are reported on stderr instead of panicking inside `drop`: a worker
    /// that already stopped (failed send) and a worker thread that panicked.
    fn drop(&mut self) {
        println!("Sending terminate message to all workers.");
        for (worker, sender) in &self.senders {
            // `drop` cannot be async, so drive the send with `block_on`. This waits if the
            // worker's queue is full.
            if block_on(sender.send(Message::Terminate)).is_err() {
                eprintln!("Worker {} already stopped; no terminate message sent.", worker.id);
            }
        }

        println!("Shutting down all workers.");

        for (worker, _) in &mut self.senders {
            println!("Shutting down worker {}", worker.id);

            if let Some(thread) = worker.thread.take() {
                if thread.join().is_err() {
                    eprintln!("Worker {} panicked", worker.id);
                }
            }
        }
    }
}

/// A worker thread plus the id used in its log output.
struct Worker {
    id: usize,
    /// Taken (set to `None`) when the thread is joined in `ThreadPool::drop`.
    thread: Option<std::thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns a worker thread with its own bounded channel (capacity 32).
    ///
    /// Returns the worker together with the sending half of that channel. The thread
    /// drives its receive loop with `block_on` and runs each task to completion. It stops
    /// when it receives `Terminate` or when every sender has been dropped.
    fn new(id: usize) -> (Worker, mpsc::Sender<Message>) {
        use std::panic::{catch_unwind, AssertUnwindSafe};

        let (tx, mut rx) = mpsc::channel::<Message>(32);

        let thread = std::thread::spawn(move || {
            block_on(async move {
                while let Some(message) = rx.recv().await {
                    match message {
                        Message::NewTask(task) => {
                            println!("Worker {} got a job; executing.", id);
                            // A panicking task would otherwise end this thread and drop `rx`.
                            // Every later task routed to this worker would then be lost.
                            if catch_unwind(AssertUnwindSafe(task)).is_err() {
                                eprintln!("Worker {} task panicked; continuing.", id);
                            }
                        }
                        Message::Terminate => {
                            println!("Worker {} was told to terminate.", id);
                            break;
                        }
                    }
                }
            })
        });

        let worker = Worker {
            id,
            thread: Some(thread),
        };

        (worker, tx)
    }
}

// Cargo.toml
[package]
name = "web_server"
version = "0.1.0"
edition = "2021"

[dependencies]
# Async runtime: supplies `#[tokio::main]` and the `tokio::sync::mpsc` channels used by the thread pool.
tokio = { version = "1.15.0", features = ["full"] }
# In the code shown, only `futures::executor::block_on` is used (worker loops and `Drop`).
futures = "0.3"

上一篇:【Rust日报】2020-11-25 AWS loves Rust


下一篇:Rust网络编程框架-深入理解Tokio中的管道