huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Add SpikeGPT model #21875

Open gsarti opened 1 year ago

gsarti commented 1 year ago

Model description

Abstract:

As the size of large language models continues to scale, so do the computational resources required to run them. Spiking neural networks (SNNs) have emerged as an energy-efficient approach to deep learning that leverages sparse and event-driven activations to reduce the computational overhead associated with model inference. While they have become competitive with non-spiking models on many computer vision tasks, SNNs have also proven to be more challenging to train. As a result, their performance lags behind modern deep learning, and the effectiveness of SNNs in language generation has yet to be demonstrated. In this paper, inspired by the RWKV language model, we successfully implement 'SpikeGPT', a generative language model with pure binary, event-driven spiking activation units. We train three variants of the proposed model: 45M, 125M and 260M parameters. To the best of our knowledge, this is 4x larger than any functional backprop-trained SNN to date. We achieve this by modifying the transformer block to replace multi-head self-attention, reducing the computational complexity from quadratic to linear in sequence length. Input tokens are instead streamed in sequentially to our attention mechanism (as with typical SNNs). Our preliminary experiments show that SpikeGPT remains competitive with non-spiking models on tested benchmarks, while consuming 5x less energy when processed on neuromorphic hardware that can leverage sparse, event-driven activations.
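The "pure binary, event-driven spiking activation units" mentioned in the abstract can be illustrated with a minimal sketch of a leaky integrate-and-fire (LIF) neuron, a common SNN building block. This is not SpikeGPT's code: the function name `lif_spikes`, the decay factor `beta`, the `threshold`, and the reset-to-zero rule are all assumptions chosen for illustration.

```python
def lif_spikes(inputs, beta=0.5, threshold=1.0):
    """Leaky integrate-and-fire neuron over a sequence of input currents.

    Hypothetical helper for illustration only.
    inputs: list of floats, one input current per time step.
    beta: membrane decay factor (assumed value).
    threshold: firing threshold (assumed value).
    Returns a list of binary spikes (0 or 1), one per time step.
    """
    membrane = 0.0
    spikes = []
    for x in inputs:
        # Leak part of the previous potential, then integrate the new input.
        membrane = beta * membrane + x
        # Heaviside step: emit a binary spike once the threshold is reached.
        spike = 1 if membrane >= threshold else 0
        spikes.append(spike)
        # Reset-to-zero after a spike (one of several possible reset rules).
        if spike:
            membrane = 0.0
    return spikes
```

For example, `lif_spikes([0.6, 0.6, 0.6, 0.0, 1.2])` returns `[0, 0, 1, 0, 1]`: the neuron fires only once enough input has accumulated, and stays silent otherwise. The Heaviside step has zero gradient almost everywhere, which is one reason SNNs are harder to train with backpropagation; surrogate gradients are a common workaround.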

Concretely, it is a GPT-style model that replaces regular attention with Receptance Weighted Key Value (RWKV) and uses an adapted FFN layer.
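To make "RWKV instead of regular attention" more concrete, here is a single-channel sketch of an RWKV-4-style "WKV" time-mixing recurrence in plain Python. It is a simplified illustration, not SpikeGPT's or the transformers library's implementation: the function name `wkv_recurrent`, the scalar decay `w`, the current-token bonus `u`, and the omission of numerical stabilization and of SpikeGPT's spiking layers are all assumptions.

```python
import math


def wkv_recurrent(keys, values, receptances, w=0.5, u=0.0):
    """Single-channel RWKV-style time mixing as a linear-time recurrence.

    Hypothetical helper for illustration only.
    keys, values, receptances: equal-length lists of floats, one per token.
    w: positive decay rate for past tokens (assumed value).
    u: bonus added to the current token's key (assumed value).
    """
    num, den = 0.0, 0.0  # decayed weighted sums over all past tokens
    decay = math.exp(-w)
    outputs = []
    for k, v, r in zip(keys, values, receptances):
        bonus = math.exp(u + k)
        # Weighted average of past values and the current value.
        wkv = (num + bonus * v) / (den + bonus)
        # Receptance gate: sigmoid(r) scales how much of wkv passes through.
        outputs.append(wkv / (1.0 + math.exp(-r)))
        # Update the fixed-size state: decay the past, add the current token.
        num = decay * num + math.exp(k) * v
        den = decay * den + math.exp(k)
    return outputs
```

Because each step only updates two running sums, tokens can be streamed in one at a time and the total cost grows linearly with sequence length, whereas full self-attention compares every pair of tokens.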

Open source status

Provide useful links for the implementation

Paper | Code

Author: @ridgerchu

ridgerchu commented 1 year ago

Thanks for your interest in our work! The checkpoint weights for the 120M SpikeGPT model are available now, but they are intended only for debugging and playing with the model.

julien-c commented 1 year ago

I've read the paper, this model looks really cool 👍