Hello, I tried running your test program on Google Colab using a T4 GPU and noticed that the GPU RAM usage would always stay at 0. Do you know if this is the expected behavior? Here's the code I was using; I modified the context window size to something larger:
import torch
from ring_attention_pytorch import RingAttention
attn = RingAttention(
    dim = 256,
    dim_head = 4,
    heads = 32,
    bucket_size = 256,
    causal = True,
    auto_shard_seq = True,
    ring_attn = True,
    ring_seq_size = 256
)

tokens = torch.randn(1, 25000, 256)
attended = attn(tokens)
assert attended.shape == tokens.shape
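
For what it's worth, one thing I noticed while re-reading my snippet: as written, neither the module nor the tokens tensor is ever moved to CUDA, so zero GPU memory might just mean everything ran on the CPU. Below is a minimal sketch of what I mean; the explicit device handling, the no_grad context, and the memory_allocated check are my own additions for testing, not anything taken from the library's docs:

import torch
from ring_attention_pytorch import RingAttention

# pick the GPU if Colab actually gave us one
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

attn = RingAttention(
    dim = 256,
    dim_head = 4,
    heads = 32,
    bucket_size = 256,
    causal = True,
    auto_shard_seq = True,
    ring_attn = True,
    ring_seq_size = 256
).to(device)  # move the module's parameters onto the GPU

# allocate the input directly on the same device
tokens = torch.randn(1, 25000, 256, device = device)

with torch.no_grad():
    attended = attn(tokens)

assert attended.shape == tokens.shape

# rough check that GPU memory is actually being used
print(torch.cuda.memory_allocated() / 1e6, 'MB allocated on GPU')

With that change I would expect the T4's memory usage to be non-zero, but I'm not sure how the ring_attn / auto_shard_seq flags behave on a single, non-distributed GPU, so maybe you can confirm.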