lucidrains / ring-attention-pytorch

Implementation of 💍 Ring Attention, from Liu et al. at Berkeley AI, in Pytorch

Is the GPU being used? #14

Open overseerlabs opened 4 months ago

overseerlabs commented 4 months ago

Hello, I tried running your test program on Google Colab using a T4 GPU and noticed that GPU RAM usage would always stay at 0. Do you know if this is the expected behavior? Here's the code I was using; I modified the context window size to something larger:

import torch
from ring_attention_pytorch import RingAttention

attn = RingAttention(
    dim = 256,
    dim_head = 4,
    heads = 32,
    bucket_size = 256,
    causal = True,
    auto_shard_seq = True,
    ring_attn = True,
    ring_seq_size = 256
)

tokens = torch.randn(1, 25000, 256)
attended = attn(tokens)

assert attended.shape == tokens.shape

lucidrains commented 4 months ago

Oh, you need to move the model and the input onto the CUDA device with a .cuda() call.
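
For example, applied to the snippet above (a minimal sketch, assuming a CUDA-capable runtime such as the Colab T4 is available):

import torch
from ring_attention_pytorch import RingAttention

attn = RingAttention(
    dim = 256,
    dim_head = 4,
    heads = 32,
    bucket_size = 256,
    causal = True,
    auto_shard_seq = True,
    ring_attn = True,
    ring_seq_size = 256
).cuda()  # move the model parameters onto the GPU

tokens = torch.randn(1, 25000, 256).cuda()  # move the input tensor onto the GPU as well
attended = attn(tokens)

assert attended.shape == tokens.shape

# optional sanity check: nonzero allocated memory confirms the GPU is actually in use
print(torch.cuda.memory_allocated())

Without those two .cuda() calls, everything runs on the CPU, which is why the GPU RAM reading stays at 0.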