AdityaNG / kan-gpt

The PyTorch implementation of Generative Pre-trained Transformers (GPTs) using Kolmogorov-Arnold Networks (KANs) for language modeling
https://adityang.github.io/kan-gpt/
MIT License

CUDA out of memory #18

Open · ybu-lxd opened this issue 4 months ago

ybu-lxd commented 4 months ago

import torch
import torch.nn as nn
from kan import KAN  # assumption: a pykan-style KAN(width=[...]); adjust to the KAN class you use


class KanMLP(nn.Module):
    """minGPT-style MLP block with both linear layers replaced by KAN layers."""

    def __init__(self, in_features=1152, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.mlp = nn.ModuleDict(
            dict(
                c_fc=KAN(width=[in_features, hidden_features]),
                c_proj=KAN(width=[hidden_features, out_features]),
                # tanh-approximate GELU: the same formula as minGPT's NewGELU
                act=nn.GELU(approximate="tanh"),
                dropout=nn.Dropout(drop),
            )
        )

    def forward(self, x):
        m = self.mlp
        return m.dropout(m.c_proj(m.act(m.c_fc(x))))  # MLP forward


net = KanMLP(1152, 1152 * 4).to("cuda")
x = torch.rand(size=(4, 4096 * 4, 1152)).to("cuda")  # 4 sequences of 16384 tokens
net(x)  # raises "CUDA out of memory" at this token count

When the number of tokens reaches a certain size, the following error occurs:

CUDA out of memory.
mfrederico commented 1 month ago

Hello! Can you answer these questions?

I dropped your code into Claude; hopefully this gives you some indication:

The main reason you're running out of CUDA memory is the number of tokens in your input combined with the width of the KAN layers. Let's break down the memory usage:

Input tensor x:

Shape: (4, 4096*4, 1152) = (4, 16384, 1152)
Elements: 4 * 16384 * 1152 = 75,497,472
Assuming float32 (4 bytes per element), this tensor alone requires about 302 MB of memory.

Network parameters:

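Input size: 1152
Hidden size: 1152 * 4 = 4608
A KAN layer learns a separate spline function for every (input, output) pair, so c_fc (1152 -> 4608) and c_proj (4608 -> 1152) each contain 1152 * 4608 = 5,308,416 learnable edge functions, each with its own coefficients.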
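To put the KAN layer sizes in numbers, here is a rough parameter count. It is a sketch under stated assumptions: every (input, output) edge is taken to carry `grid + k` B-spline coefficients plus two scalar weights (a base weight and a spline weight), which is roughly how the original pykan layer is organized. The helper names and the `grid=5, k=3` values are hypothetical.

```python
def kan_layer_params(in_dim: int, out_dim: int, grid: int, k: int) -> tuple:
    """Rough size of one KAN layer (hypothetical helper).

    Assumes each (input, output) edge has its own B-spline with `grid + k`
    coefficients plus 2 scalar weights (base and spline scale).
    """
    edges = in_dim * out_dim
    per_edge = grid + k + 2  # assumption about the per-edge parameter layout
    return edges, edges * per_edge


def linear_params(in_dim: int, out_dim: int) -> int:
    """Parameters of a standard nn.Linear(in_dim, out_dim) with bias, for comparison."""
    return in_dim * out_dim + out_dim


edges, kan_params = kan_layer_params(1152, 4608, grid=5, k=3)
print(edges, kan_params, linear_params(1152, 4608))
```

Under these assumptions each KAN layer has roughly 10x the parameters of the nn.Linear it replaces (about 212 MB in float32 per layer). That is large, but parameters on their own are unlikely to explain the out-of-memory error.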

Intermediate activations:

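A KAN layer evaluates every edge function for every token. If the implementation materializes those per-edge values, as the original pykan KANLayer does, the intermediate tensors scale with tokens * in_features * out_features rather than tokens * features, and that term dwarfs the input tensor.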
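A back-of-the-envelope estimate makes this scaling concrete. The sketch assumes a KAN implementation that materializes one float32 value per (token, input, output) edge; implementations that avoid this (for example by contracting the spline bases with a coefficient matrix) need far less. The helper names are hypothetical.

```python
def tensor_bytes(*shape: int, bytes_per_elem: int = 4) -> int:
    """Bytes needed for a dense tensor of the given shape (float32 by default)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem


def per_edge_activation_bytes(batch: int, seq_len: int, in_dim: int, out_dim: int) -> int:
    """Bytes for ONE float32 tensor holding a value per (token, input, output) edge.

    Assumes the KAN layer materializes per-edge activations (worst case).
    """
    return tensor_bytes(batch * seq_len, in_dim, out_dim)


GB = 10**9
print(f"input tensor:      {tensor_bytes(4, 4096 * 4, 1152) / GB:.2f} GB")
print(f"c_fc edges, B=4:   {per_edge_activation_bytes(4, 4096 * 4, 1152, 4608) / GB:.0f} GB")
print(f"c_fc edges, B=1:   {per_edge_activation_bytes(1, 4096 * 4, 1152, 4608) / GB:.0f} GB")
```

Under that assumption a single per-edge tensor for c_fc needs about 1392 GB at batch size 4 and still about 348 GB at batch size 1, versus roughly 0.30 GB for the input. Shrinking the batch alone is therefore unlikely to be enough; processing fewer tokens per call, or using a KAN implementation that does not materialize per-edge values, matters more.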

To address this issue, you can try the following approaches:

Reduce batch size: Instead of processing 4 samples at once, try reducing it to 1 or 2:

x = torch.rand(size=(1, 4096*4, 1152)).to("cuda")

Use gradient accumulation: If you need larger effective batches for training stability, process smaller sub-batches and accumulate their gradients before performing an optimization step.

Use mixed precision training: Run computations in float16 (half precision) to reduce memory usage. You can use NVIDIA's Apex library or PyTorch's native AMP (Automatic Mixed Precision):

import torch

with torch.autocast(device_type="cuda", dtype=torch.float16):
    output = net(x)

Optimize your model architecture: Consider whether you can reduce the size of your hidden layers or use a more memory-efficient architecture.

Use gradient checkpointing: This technique trades computation for memory by not storing all intermediate activations.

Process your data in smaller chunks: If possible, instead of processing the entire 4096*4 sequence length at once, process it in smaller segments.
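Chunking and gradient checkpointing can be combined for the MLP block specifically. Because an MLP (KAN or not) acts on each token independently, running it over slices of the flattened token dimension gives the same result as one big call; wrapping each slice in `torch.utils.checkpoint.checkpoint` discards that slice's intermediates after the forward pass and recomputes them one chunk at a time during backward. This does not apply to attention, which mixes tokens. A sketch, assuming `mlp` is any token-wise module that accepts `(tokens, features)` input; the helper names and `chunk_tokens` values are hypothetical, and a small nn.Sequential stands in for the KAN MLP.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


@torch.no_grad()
def chunked_forward(mlp: nn.Module, x: torch.Tensor, chunk_tokens: int = 1024) -> torch.Tensor:
    """Inference only: apply a token-wise module to x of shape (B, T, C) chunk by chunk."""
    b, t, c = x.shape
    flat = x.reshape(b * t, c)  # rows are independent tokens
    outs = [mlp(chunk) for chunk in flat.split(chunk_tokens, dim=0)]
    return torch.cat(outs, dim=0).reshape(b, t, -1)


def checkpointed_chunked_forward(
    mlp: nn.Module, x: torch.Tensor, chunk_tokens: int = 1024
) -> torch.Tensor:
    """Training: same chunking, but each chunk's intermediates are recomputed in backward."""
    b, t, c = x.shape
    flat = x.reshape(b * t, c)
    outs = [
        checkpoint(mlp, chunk, use_reentrant=False)
        for chunk in flat.split(chunk_tokens, dim=0)
    ]
    return torch.cat(outs, dim=0).reshape(b, t, -1)


# Usage with a stand-in MLP; on a GPU the KanMLP from the report would go here instead.
mlp = nn.Sequential(nn.Linear(16, 64), nn.GELU(approximate="tanh"), nn.Linear(64, 16))
x = torch.randn(2, 100, 16)
y = chunked_forward(mlp, x, chunk_tokens=32)
print(y.shape)
```

Peak activation memory then scales with `chunk_tokens` instead of the full 4 * 16384 tokens, at the cost of one extra forward pass per chunk when training.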