sunface / rust-course

“连续六年成为全世界最受喜爱的语言,无 GC 也无需手动内存管理、极高的性能和安全性、过程/OO/函数式编程、优秀的包管理、JS 未来基石" — 工作之余的第二语言来试试 Rust 吧。<<Rust语言圣经>>拥有全面且深入的讲解、生动贴切的示例、德芙般丝滑的内容,甚至还有JS程序员关注的 WASM 和 Deno 等专题。这可能是目前最用心的 Rust 中文学习教程 / Book
https://course.rs
24.76k stars 2.13k forks source link

Optimize sample code #1324

Closed asthetik closed 9 months ago

asthetik commented 9 months ago

在原先的官方示例代码中,for循环里的join()方法都会等待子线程执行完毕,当前循环执行完,子线程都会退出,并且在下一次循环时,生成的子线程会回收退出线程的对象。所以for循环执行完之后,只剩下一个子线程。官方示例代码不能很好地体现出求和多个子线程局部变量中的计数器值,事实上只有一个子线程。

证明

fn main() {

let tls = Arc::new(ThreadLocal::new());
// 创建多个线程
for i in 0..5 {
    let tls2 = tls.clone();
    thread::spawn(move || {
        let cell = tls2.get_or(|| Cell::new(0));
        // 回收上一次循环时退出线程的对象,所以当前线程会随着 i 自增而自增
        assert_eq!(cell.get(), i);
        println!("value: {}, i: {}", cell.get(), i);
        cell.set(cell.get() + 1);
    }).join().unwrap();
}
// for循环执行完之后,只有一个线程,计数器总和等于5
assert_eq!(tls.get().unwrap().get(), 5);

// 一旦所有子线程结束,收集它们的线程局部变量中的计数器值,然后进行求和
let tls = Arc::try_unwrap(tls).unwrap();
let total = tls.into_iter().fold(0, |x, y| {
    // 打印每个线程的值,迭代完成后可发现只有一个线程
    println!("x: {}, y: {}", x, y.get());
    x + y.get()
});

// 和为5
assert_eq!(total, 5);

}

成功执行以上代码并输出,证明为真!
```rust
value: 0, i: 0
value: 1, i: 1
value: 2, i: 2
value: 3, i: 3
value: 4, i: 4
x: 0, y: 5

优化官方示例代码

让多个子线程存在,收集它们的线程局部变量中的计数器值

use std::{cell::Cell, sync::Arc, thread};
use thread_local::ThreadLocal;

fn main() {

    let tls = Arc::new(ThreadLocal::new());
    let mut v = vec![];
    // 创建多个线程
    for i in 0..5 {
        let tls2 = tls.clone();
        let handle = thread::spawn(move || {
            // 将计数器加1
            let cell = tls2.get_or(|| Cell::new(0));
           // 并不是所有的线程初始值都为 0,因为当前线程可能会回收退出线程的对象
            println!("value: {}, i: {}", cell.get(), i);
            cell.set(cell.get() + 1);
        });
        v.push(handle);
    }
    for handle in v {
        handle.join().unwrap();
    }
    // 一旦所有子线程结束,收集它们的线程局部变量中的计数器值,然后进行求和
    let tls = Arc::try_unwrap(tls).unwrap();
    let total = tls.into_iter().fold(0, |x, y| {
        // 打印每个线程的值,迭代完成后可发现有多个线程,
        // 不一定存在5个线程,因为一些线程已退出
        println!("x: {}, y: {}", x, y.get());
        x + y.get()
    });

    // 和为5
    assert_eq!(total, 5);
}

执行以上代码并输出

value: 0, i: 0
value: 0, i: 2
value: 0, i: 1
value: 1, i: 3
value: 1, i: 4
x: 0, y: 2
x: 2, y: 2
x: 4, y: 1

多执行几次代码(也可以增大for循环的次数),每一次for循环执行完之后存在的线程数量不一定相同。 参考 std thread_local 库的文档说明: Note that since thread IDs are recycled when a thread exits, it is possible for one thread to retrieve the object of another thread. Since this can only occur after a thread has exited this does not lead to any race conditions.

故作此优化。