lealone / Lealone

比 MySQL 和 MongoDB 快10倍的 OLTP 关系数据库和文档数据库
Other
2.48k stars 520 forks source link

SPSC队列 #213

Closed areyouok closed 1 year ago

areyouok commented 1 year ago

这是原版本,我只增加了计时:

package test;

public class SpscLinkableListTest {

    public static void main(String[] args) throws Exception {
        new SpscLinkableListTest().run();
    }

    // LinkableList是一个无锁且不需要CAS的普通链表,满足单生产者单消费者的应用场景
    private final LinkableList<PendingTask> pendingTasks = new LinkableList<>();
    private final long pendingTaskCount = 1000 * 10000; // 待处理任务总数
    private long completedTaskCount; // 已经完成的任务数

    private long result; // 存放计算结果

    private void run() throws Exception {
        // 生产者创建pendingTaskCount个AsyncTask
        // 每个AsyncTask的工作就是计算从1到pendingTaskCount的和
        Thread producer = new Thread(() -> {
            for (int i = 1; i <= pendingTaskCount; i++) {
                AsyncTask task = new AsyncTask(i);
                submitTask(task);
            }
        });

        // 消费者不断从pendingTasks中取出AsyncTask执行
        Thread consumer = new Thread(() -> {
            while (completedTaskCount < pendingTaskCount) {
                runPendingTasks();
            }
        });
        long t = System.currentTimeMillis();
        producer.start();
        consumer.start();
        producer.join();
        consumer.join();

        // 如果result跟except相同,说明代码是ok的,如果不同,那就说明代码有bug
        long except = (1 + pendingTaskCount) * pendingTaskCount / 2;
        t = System.currentTimeMillis() - t;
        if (result == except) {
            System.out.println("result: " + result + ", ok. cost " + t + "ms");
        } else {
            System.out.println("result: " + result + ", not ok, except: " + except);
        }
    }

    private void submitTask(AsyncTask task) {
        PendingTask pt = new PendingTask(task);
        pendingTasks.add(pt);
        if (pendingTasks.size() > 1)
            removeCompletedTasks();
    }

    private void removeCompletedTasks() {
        PendingTask pt = pendingTasks.getHead();
        while (pt != null && pt.isCompleted()) {
            pt = pt.getNext();
            pendingTasks.decrementSize();
            pendingTasks.setHead(pt);
        }
        if (pendingTasks.getHead() == null)
            pendingTasks.setTail(null);
    }

    private void runPendingTasks() {
        PendingTask pt = pendingTasks.getHead();
        while (pt != null) {
            if (!pt.isCompleted()) {
                completedTaskCount++;
                pt.getTask().compute();
                pt.setCompleted(true);
            }
            pt = pt.getNext();
        }
    }

    public class AsyncTask {
        int value;

        AsyncTask(int value) {
            this.value = value;
        }

        void compute() {
            result += value;
        }
    }

    public class PendingTask extends LinkableBase<PendingTask> {

        private final AsyncTask task;
        private boolean completed;

        public PendingTask(AsyncTask task) {
            this.task = task;
        }

        public AsyncTask getTask() {
            return task;
        }

        public boolean isCompleted() {
            return completed;
        }

        public void setCompleted(boolean completed) {
            this.completed = completed;
        }
    }

    public interface Linkable<E extends Linkable<E>> {

        void setNext(E next);

        E getNext();

    }

    public class LinkableBase<E extends Linkable<E>> implements Linkable<E> {

        public E next;

        @Override
        public void setNext(E next) {
            this.next = next;
        }

        @Override
        public E getNext() {
            return next;
        }
    }

    public class LinkableList<E extends Linkable<E>> {

        private E head;
        private E tail;
        private int size;

        public E getHead() {
            return head;
        }

        public void setHead(E head) {
            this.head = head;
        }

        public E getTail() {
            return tail;
        }

        public void setTail(E tail) {
            this.tail = tail;
        }

        public boolean isEmpty() {
            return head == null;
        }

        public int size() {
            return size;
        }

        public void decrementSize() {
            size--;
        }

        public void add(E e) {
            size++;
            if (head == null) {
                head = tail = e;
            } else {
                tail.setNext(e);
                tail = e;
            }
        }

        public void remove(E e) {
            size--;
            if (head == e) { // 删除头
                head = e.getNext();
                if (head == null)
                    tail = null;
            } else {
                E n = head;
                E last = n;
                while (n != null) {
                    if (e == n) {
                        last.setNext(n.getNext());
                        break;
                    }
                    last = n;
                    n = n.getNext();
                }
                if (tail == e) // 删除尾
                    tail = last;
            }
        }
    }
}

执行结果:result: 50000005000000, ok. cost 487ms

我写的版本

package test;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;

public class SpscLinkableListTest2 {

    public static void main(String[] args) throws Exception {
        new SpscLinkableListTest2().run();
    }

    // LinkableList是一个无锁且不需要CAS的普通链表,满足单生产者单消费者的应用场景
    private final LinkableList<AsyncTask> pendingTasks = new LinkableList<>();
    private final long pendingTaskCount = 1000 * 10000; // 待处理任务总数
    private long completedTaskCount; // 已经完成的任务数

    private long result; // 存放计算结果

    private void run() throws Exception {
        // 生产者创建pendingTaskCount个AsyncTask
        // 每个AsyncTask的工作就是计算从1到pendingTaskCount的和
        Thread producer = new Thread(() -> {
            for (int i = 1; i <= pendingTaskCount; i++) {
                AsyncTask task = new AsyncTask(i);
                submitTask(task);
            }
        });

        // 消费者不断从pendingTasks中取出AsyncTask执行
        Thread consumer = new Thread(() -> {
            while (completedTaskCount < pendingTaskCount) {
                runPendingTasks();
            }
        });
        long t = System.currentTimeMillis();
        producer.start();
        consumer.start();
        producer.join();
        consumer.join();

        // 如果result跟except相同,说明代码是ok的,如果不同,那就说明代码有bug
        long except = (1 + pendingTaskCount) * pendingTaskCount / 2;
        t = System.currentTimeMillis() - t;
        if (result == except) {
            System.out.println("result: " + result + ", ok. cost " + t + "ms");
        } else {
            System.out.println("result: " + result + ", not ok, except: " + except);
        }
    }

    private void submitTask(AsyncTask task) {
        pendingTasks.add(task);
    }

    private void runPendingTasks() {
        AsyncTask r = pendingTasks.remove();
        if (r != null) {
            r.compute();
            completedTaskCount++;
        }
    }

    public class AsyncTask {
        int value;

        AsyncTask(int value) {
            this.value = value;
        }

        void compute() {
            result += value;
        }
    }

    public static class Node<E> {
        E data;
        volatile Node<E> next;

        static VarHandle NEXT;

        static {
            try {
                MethodHandles.Lookup l = MethodHandles.lookup();
                NEXT = l.findVarHandle(Node.class, "next", Node.class);
            } catch (Exception e) {
                throw new Error(e);
            }
        }
    }

    public static class LinkableList<E> {
        long p00, p01, p02, p03, p04, p05, p06, p07, p08, p09, p0a, p0b, p0c, p0d, p0e, p0f;
        private Node<E> producerNode;
        long p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p1a, p1b, p1c, p1d, p1e, p1f;
        private Node<E> consumerNode;
        long p20, p21, p22, p23, p24, p25, p26, p27, p28, p29, p2a, p2b, p2c, p2d, p2e, p2f;

        public LinkableList() {
            Node<E> node = new Node<>();
            producerNode = node;
            consumerNode = node;
            VarHandle.releaseFence();
        }

        public void add(E data) {
            Node n = new Node();
            n.data = data;
            Node oldProducerNode = producerNode;
            producerNode = n;
            Node.NEXT.setRelease(oldProducerNode, n);
        }

        public E remove() {
            Node<E> nextNode = (Node) Node.NEXT.getAcquire(consumerNode);
            if (nextNode == null) {
                return null;
            }
            consumerNode = nextNode;
            E data = nextNode.data;
            nextNode.data = null;
            return data;
        }
    }
}

result: 50000005000000, ok. cost 278ms

如果循环次数太多,可能因为消费跟不上生产,导致内存膨胀,结果就不太准确。

areyouok commented 1 year ago

只实现了add/remove,这样是最快的,如果要实现size什么的,就要复杂一些了,流程和顺序可能要改。

java8下用Unsafe可以实现release写,但是acquire读无法直接实现,用unsafe设置读屏障应该可以基本实现(可能略有差异)

codefollower commented 1 year ago
  1. 首先你这个版本不能在 jdk 1.8 跑。
  2. 功能少了一点,代码跟 JCTools 的 SpscLinkedQueue 一样有一堆字段,可读性不好,这种实现太 hack 了。
  3. 任务数百万级时两者的性能差距并没有那么大。
  4. SpscLinkedQueue 的性能比你的版本还略好一些,我不用 JCTools 就是不想用一个类引入一个库。
areyouok commented 1 year ago

我知道你的代码无屏障情况下会有什么问题了。

因为重排序消费者可能会拿到构造未完成的对象,AsyncTask的value字段没加final,可能会重排序到构造函数之外,你不能要求入队的每个对象的所有字段都加final。

PendingTask有final,不知jdk会怎么实现,也许它会加屏障,但至少理论上是有问题的,从理论上说final只能保一个字段,AsyncTask的字段也要final。

areyouok commented 1 year ago

就是说,消费者都拿到那个对象了,结果生产者那边构造函数还没执行完。

areyouok commented 1 year ago

jdk版本问题不大,探测一下java版本,然后分别实现就好了。

最不济就像jctools那样,你要相信他事专业的,它这个类库只是有点老了。

codefollower commented 1 year ago

value 那个字段本就不用变更的,加个 final 只是懒而已,这种细节无关痛痒,就算 javac 和 jit 对字节码和本地代码胡乱排序,难不成执行到 compute 方法时 value 还是未知的,jit 敢做这么深的乱排序得需要多复杂的算法才能保证语义正确。实在不行就加上 final 就好了。

areyouok commented 1 year ago

value 那个字段本就不用变更的,加个 final 只是懒而已,这种细节无关痛痒,就算 javac 和 jit 对字节码和本地代码胡乱排序,难不成执行到 compute 方法时 value 还是未知的,jit 敢做这么深的乱排序得需要多复杂的算法才能保证语义正确。实在不行就加上 final 就好了。

你要知道,这个重排序并不是javac或者jit造成的(当然它们可能也会搞这种事),而是CPU它就这么干。线程A构造一个对象放在某个(非volatile)字段上,线程B看见这个字段不是null了,结果线程A构造函数干的事,线程B居然看不见。并不会因为你用java,就有什么优待。反而是,用java写出个只能在x86平台运行的程序(因为这个构造函数重排序的事情不会在x86架构下发生),有点尴尬。

你如果不了解这些,还是不要自己折腾了,ConcurrentLinkedQueue/LinkedBlockingQueue就够了。你甚至都不在非x86架构下做测试,那怎么让人放心呢。

可读性不好,这种实现太 hack 了

这个我不认可,都是堂堂正正的java代码,连unsafe都没用,可移植性也没有问题,怎么就是hack呢。

那些padding只要是了解伪共享的人都知道是什么意思,早就都是常规手段了,以前jdk类库里面也用,现在只是改成注解了。

SpscLinkedQueue 的性能比你的版本还略好一些

这个我倒没有测试,不过我这个程序其实就是JCTools SpscLinkedQueue的现代简化版。理论上性能只会更好。如果确实不如SpscLinkedQueue,那有3种可能:1、测试误差或者和运行时工况有关;2、我什么地方搞错了;3、你什么地方搞错了。

功能少了一点

既然都是特化的SPSC队列了,要那么多功能干嘛,最多是少个size方法,那要实现也很容易。第一个办法是从JCTools抄,第二更简单,弄两个LongAdder,生产者和消费者分别累加,调用size的时候计算差值看个大概就行了(你原来的代码中,size的计算也是不准确的)。

不能在 jdk 1.8 跑。

下面这个版本可以在jdk 1.6下跑,而且代码更简单。

package test;

import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;

public class SpscLinkableListTest3 {

    public static void main(String[] args) throws Exception {
        new SpscLinkableListTest3().run();
    }

    // LinkableList是一个无锁且不需要CAS的普通链表,满足单生产者单消费者的应用场景
    private final LinkableList<AsyncTask> pendingTasks = new LinkableList<>();
    private final long pendingTaskCount = 1000 * 10000; // 待处理任务总数
    private long completedTaskCount; // 已经完成的任务数

    private long result; // 存放计算结果

    private void run() throws Exception {
        // 生产者创建pendingTaskCount个AsyncTask
        // 每个AsyncTask的工作就是计算从1到pendingTaskCount的和
        Thread producer = new Thread(() -> {
            for (int i = 1; i <= pendingTaskCount; i++) {
                AsyncTask task = new AsyncTask(i);
                submitTask(task);
            }
        });

        // 消费者不断从pendingTasks中取出AsyncTask执行
        Thread consumer = new Thread(() -> {
            while (completedTaskCount < pendingTaskCount) {
                runPendingTasks();
            }
        });
        long t = System.currentTimeMillis();
        producer.start();
        consumer.start();
        producer.join();
        consumer.join();

        // 如果result跟except相同,说明代码是ok的,如果不同,那就说明代码有bug
        long except = (1 + pendingTaskCount) * pendingTaskCount / 2;
        t = System.currentTimeMillis() - t;
        if (result == except) {
            System.out.println("result: " + result + ", ok. cost " + t + "ms");
        } else {
            System.out.println("result: " + result + ", not ok, except: " + except);
        }
    }

    private void submitTask(AsyncTask task) {
        pendingTasks.add(task);
    }

    private void runPendingTasks() {
        AsyncTask r = pendingTasks.remove();
        if (r != null) {
            r.compute();
            completedTaskCount++;
        }
    }

    public class AsyncTask {
        int value;

        AsyncTask(int value) {
            this.value = value;
        }

        void compute() {
            result += value;
        }
    }

    public static class Node<E> {
        E data;
        volatile Node<E> next;

        final static AtomicReferenceFieldUpdater<Node, Node> NEXT = AtomicReferenceFieldUpdater
                .newUpdater(Node.class, Node.class, "next");
    }

    public static class LinkableList<E> {
        long p00, p01, p02, p03, p04, p05, p06, p07, p08, p09, p0a, p0b, p0c, p0d, p0e, p0f;
        private Node<E> producerNode;
        long p10, p11, p12, p13, p14, p15, p16, p17, p18, p19, p1a, p1b, p1c, p1d, p1e, p1f;
        private Node<E> consumerNode;
        long p20, p21, p22, p23, p24, p25, p26, p27, p28, p29, p2a, p2b, p2c, p2d, p2e, p2f;

        public LinkableList() {
            Node<E> node = new Node<>();
            producerNode = node;
            consumerNode = node;
        }

        public void add(E data) {
            Node n = new Node();
            n.data = data;
            Node oldProducerNode = producerNode;
            producerNode = n;
            Node.NEXT.lazySet(oldProducerNode, n);
        }

        public E remove() {
            Node<E> nextNode = consumerNode.next;
            if (nextNode == null) {
                return null;
            }
            consumerNode = nextNode;
            E data = nextNode.data;
            nextNode.data = null;
            return data;
        }
    }
}

result: 50000005000000, ok. cost 291ms

codefollower commented 1 year ago

我把微博的内容搬过来,设计 org.lealone.db.link.LinkableList 的初衷是为了取代 java.util.LinkedList

在单线程场景下使用 LinkedList 我已经发现三个小问题了:1. 添加新元素时需要额外创建一个 node 对象;2. 遍历的时候又创建一个迭代器对象;3. 遍历的过程中不能在其他地方随意删除元素。结论就是如果在一个高频率执行代码的事件循环线程里不适合用 LinkedList。 ​​​

lealone 6 中的 LinkableList 最初只是为了替换 java.util.LinkedList,然后变成每个全局调度器的私有队列,java.util.LinkedList 的问题我记得在微博聊过,忘记了。

全局调度器的各种私有队列并不是为 SPSC 服务的,只是最近搞了个客户端调度器才发现 LinkableList 可以当 SPSC 链表用,把应用的线程产生的任务放到这个链表。性能我都没测,只是想尝试一下不用锁不用 CAS 如何用最常规的技术实现一个 SPSC 链表,这是一个很小的东西,我本来不想说了,结果看到网友说可以放到教科书就一时兴起拿出来献丑。

LinkableList 如果用来实现 SPSC,目前的版本依然是不高效的,很多时间都花在无效的循环检测上了。我拿 jdk 的 ConcurrentLinkedQueue 来对比压测,发现比 LinkableList 实现的 SPSC 链表还略好一些。

ConcurrentLinkedQueue 有将近一千行代码,基于 LinkableList 的 SPSC 链表才100多行代码,初级版本就能取得这样的效果我已经挺满意了。

areyouok commented 1 year ago

既然如此,那么这个话题就结束了。

不过我还是想再补充一个信息。

你觉得jctools对比jdk ConcurrentLinkedQueue似乎没有优势,那是因为jctools是一个古老的类库,它是为java8设计的。而你在java11/17/21下运行的ConcurrentLinkedQueue有更加现代的优化加成。用java17的ConcurrentLinkedQueue去打为java8设计的jctools有点不公平,到java8下面ConcurrentLinkedQueue的性能就会差好多。

还是那句话,jctools是专业的,它只是有点老了。

当然ConcurrentLinkedQueue也是极其牛逼的人开发的,一个通用的MPMC能做到这么高性能,绝大部分场景下确实没有自己造一个queue的必要。