gsquire / singleflight

Rust port of Go's singleflight package
https://crates.io/crates/singleflight
MIT License
12 stars 1 forks source link

整个函数执行期间都持有锁, 不同的任务也需要等待, 这应该是不太合理的喔 #4

Open Zzaniu opened 5 months ago

Zzaniu commented 5 months ago

Mutex 锁住了整个任务函数, 这应该是不合理的, 这回阻止其他不同的任务函数执行. 返回值指定了T, 那么就只能用于拥有相同返回类型的任务. PS: golang里面是使用interface来支持不同返回类型的任务函数

gsquire commented 5 months ago

I'm relying on Google translate so please excuse me if I'm misunderstanding.

Here's what I see from the issue title:

The lock is held during the entire function execution, and different tasks also need to wait. This should be unreasonable.

I guess we could try and optimize this and just wait for the WaitGroup to be dropped.

Zzaniu commented 5 months ago

我尝试编写下面的代码, 我认为这是可行的. 由于我学习rust时间还不长, 如有错误的地方, 敬请谅解


use std::any::Any;
use std::ops::Deref;
use std::sync::Arc;

use hashbrown::HashMap;
use parking_lot::Mutex;

// 单飞模式, 模拟 go 的 singleflight 写的
// 用于避免重复请求, 当有相同的请求时, 直接返回之前的结果
// 适用于耗时长的请求, 如网络请求, 数据库查询等
// 需要拥有多线程的内部可变性, 所以需要 Arc + Mutex
// 需要支持不同的类型分发到不同的线程, 所以要用 Arc<dyn Any + Send + Sync>

type ResultArcAny<T> = Result<Arc<dyn Any + Send + Sync>, T>;
type Call<T> = Arc<Mutex<Option<ResultArcAny<T>>>>;

#[derive(Default, Clone)] // Group 是需要给到不同的线程的, 所以需要 Clone
pub struct Group<T> {
    m: Arc<Mutex<HashMap<String, Call<T>>>>, // 这里加个 Arc + Mutex 包装一下, 使得内部可变性可以被多个线程访问
}

impl<T: Clone + Send + Sync> Group<T> {
    pub fn work<F>(&self, key: &str, func: F) -> ResultArcAny<T>
        where
            F: FnOnce() -> ResultArcAny<T>
    {
        // 加大锁, 避免并发访问
        let mut m = self.m.lock();
        // 如果任务 key 已存在, 获取包裹结果的锁
        if let Some(c) = m.get(key) {
            // 拷贝一下, 因为要释放大锁
            let c = c.clone();
            // 释放大锁
            drop(m);
            // 加锁等待, 返回之前的结果并释放锁
            return c.lock().clone().take().unwrap();
        }

        let c = Call::default();
        // 拿到结果的锁
        let res = c.clone();
        // 写入 map, 其他相同任务看到这个 key 后, 不会再执行 func 函数
        m.insert(key.to_owned(), c);
        // 加锁, 阻止其他相同任务在结果写入前访问结果
        let mut res_guard = res.lock();
        // 释放大锁. 这把锁不能长时间拥有, 否则会影响其他任务执行
        drop(m);

        // 执行任务
        let res = func();
        // 将结果写入内部可变性的 Mutex 中
        res_guard.replace(res.clone());
        // 释放锁, 让其他相同的任务可以访问结果
        drop(res_guard);

        // 加大锁删除 map 中的 key, 避免内存泄漏, 并释放锁.
        // 在这之后拿到大锁的任务, 需要重新执行 func 函数
        self.m.lock().remove(key).unwrap();

        res
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::thread;
    use std::thread::sleep;
    use std::time::Duration;

    use crate::single_flight::Group;

    #[derive(Clone)]
    struct A {
        value: i32,
    }

    #[test]
    fn test_single_flight() {
        use std::sync::mpsc;
        let count = 200;
        let (tx, rx) = mpsc::channel::<()>();
        let g: Group<String> = Group::default();
        for i in 0..count {
            let g = g.clone();
            let tx = tx.clone();
            thread::spawn(move || {
                let (key, value) = if i % 2 == 0 {
                    ("aaa", 121)
                } else {
                    ("bbb", 111)
                };
                let res = g.work(key, || {
                    sleep(Duration::from_millis(300));
                    Ok(Arc::new(value))
                });
                let res = res.unwrap().downcast::<i32>().unwrap();
                assert_eq!(*res, value);
                drop(tx);
            });
        }

        for i in 0..count {
            let g = g.clone();
            let tx = tx.clone();
            thread::spawn(move || {
                let (key, value) = if i % 2 == 0 {
                    ("ccc", "121".to_owned())
                } else {
                    ("ddd", "111".to_owned())
                };
                let value2 = value.clone();
                let res = g.work(key, || {
                    sleep(Duration::from_millis(300));
                    Ok(Arc::new(value2))
                });
                let res = res.unwrap().downcast::<String>().unwrap();
                assert_eq!(*res, value);
                drop(tx);
            });
        }

        for i in 0..count {
            let g = g.clone();
            let tx = tx.clone();
            thread::spawn(move || {
                let (key, value) = if i % 2 == 0 {
                    ("eee", A { value: 121 })
                } else {
                    ("fff", A { value: 111 })
                };
                let value2 = value.clone();
                let res = g.work(key, || {
                    sleep(Duration::from_millis(300));
                    Ok(Arc::new(value2))
                });
                let res = res.unwrap().downcast::<A>().unwrap();
                assert_eq!(res.value, value.value);
                drop(tx);
            });
        }
        drop(tx);
        while rx.recv().is_ok() {};
    }
}
gsquire commented 5 months ago

Your implementation looks good as well since it uses an Arc value and more closely mimics Go's internals. Nice work!