ucbrise / actnn

ActNN: Reducing Training Memory Footprint via 2-Bit Activation Compressed Training
MIT License

Memory consumption question #38

Closed KimmiShi closed 10 months ago

KimmiShi commented 10 months ago

Hi, I am trying to use actnn on transformer models, and I am testing it on a simple module built around a single nn.Linear:

import torch
import torch.nn as nn
import torch.nn.functional as F
import actnn
from actnn import config, QScheme, QModule


class GEGLU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # self.proj = LoRACompatibleLinear(dim_in, dim_out)
        self.proj = nn.Linear(dim_in, dim_out)

    def gelu(self, gate):
        if gate.device.type != "mps":
            return F.gelu(gate)
        # mps: gelu is not implemented for float16
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states):

        tmp = self.proj(hidden_states)
        # print((tmp.numel()-hidden_states.numel())*4/1e6)
        hidden_states, gate = tmp.chunk(2, dim=-1)
        # import pdb;pdb.set_trace()

        return hidden_states * self.gelu(gate)

def test_m():
    model = GEGLU(640, 5120)
    model = QModule(model)
    model.cuda()

    inp = torch.rand(128, 2304, 640).cuda()

    _ = model(torch.rand(2, 2304, 640).cuda())
    # out.mean().backward()

    beg = torch.cuda.memory_allocated()/1e6

    out = model(inp)
    print("memory:", torch.cuda.memory_allocated()/1e6-beg)
    # print(model.proj.weight.grad.numel()/1e6)
    # out.mean().backward()

actnn.set_optimization_level("L3")

test_m()

However, the memory consumption I measure with the code above does not change whether I keep or comment out model = QModule(model), for example:

In actnn/actnn/ops.py, I printed how much memory was saved after quantized = quantize_activation(input, scheme). The quantized tensor was much smaller than the input, so about 700 MB should have been saved, but I did not see this difference in the result above.
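The 700 MB figure can be sanity-checked with a little arithmetic. The sketch below assumes a float32 input of shape (128, 2304, 640), as in the test script, compressed to exactly 2 bits per element; it ignores any per-group metadata the quantizer may store, and the helper name tensor_megabytes is made up for illustration.

```python
def tensor_megabytes(shape, bits_per_element):
    """Size in MB (1e6 bytes) of a dense tensor with the given shape and bit width."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * bits_per_element / 8 / 1e6


# Shape of `inp` in the test script above.
input_shape = (128, 2304, 640)

full_precision_mb = tensor_megabytes(input_shape, 32)  # float32 activation
two_bit_mb = tensor_megabytes(input_shape, 2)          # assumed 2-bit payload only
expected_saving_mb = full_precision_mb - two_bit_mb

print(f"float32 input: {full_precision_mb:.2f} MB")
print(f"2-bit copy:    {two_bit_mb:.2f} MB")
print(f"difference:    {expected_saving_mb:.2f} MB")
```

Under these assumptions the difference comes out a little above 700 MB, which is consistent with the number printed from ops.py.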

KimmiShi commented 10 months ago

It seems that the memory saving does not take effect when there is only one nn.Linear.

KimmiShi commented 10 months ago

I ran another experiment on a real module, and it seems that actnn only works for a certain structure.

For example, the module defined below:

class FeedForward(nn.Module):

    def __init__(
        self,
        dim: int,
        dim_out = None,
        mult: int = 4,
        dropout: float = 0.0,
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        act_fn = nn.Linear(dim, inner_dim)
        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states):
        for module in self.net:
            beg = torch.cuda.memory_allocated()/1e6
            hidden_states = module(hidden_states)
            print("module", type(module), torch.cuda.memory_allocated()/1e6-beg)

        return hidden_states

Result with QModule:

module <class 'actnn.layers.QLinear'> 3079.9303680000003
module <class 'actnn.layers.QDropout'> 94.37183999999979
module <class 'actnn.layers.QLinear'> 3220.439552
memory: 4884.79232

Result of baseline:

module <class 'torch.nn.modules.linear.Linear'> 3028.41856
module <class 'torch.nn.modules.dropout.Dropout'> 0.0
module <class 'torch.nn.modules.linear.Linear'> 6039.7977599999995
memory: 7558.266879999999

Only the last Linear consumes less memory. Can anyone tell me why not all Linear layers are quantized to 2 bits?

KimmiShi commented 10 months ago

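For the first layer, the saved activation is just a reference to the input tensor. The caller already holds that tensor, so compressing it does not reduce the allocated memory.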
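A small CPU-only check of this explanation, as a sketch rather than a statement about actnn internals: it uses torch.autograd.graph.saved_tensors_hooks to record what a plain nn.Linear saves for backward, then checks whether one of the saved tensors shares its storage pointer with the input. The helper name saved_tensor_pointers and the tiny shapes are chosen only for illustration.

```python
import torch
import torch.nn as nn


def saved_tensor_pointers(layer, x):
    """Run layer(x) and collect the data pointers of tensors autograd saves for backward."""
    pointers = []

    def pack(tensor):
        # Called once for every tensor saved for the backward pass.
        pointers.append(tensor.data_ptr())
        return tensor

    def unpack(tensor):
        return tensor

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        layer(x)
    return pointers


torch.manual_seed(0)
layer = nn.Linear(8, 16)
x = torch.rand(4, 8)  # plays the role of `inp`, which the caller keeps alive

pointers = saved_tensor_pointers(layer, x)
input_saved_by_reference = x.data_ptr() in pointers
print("input saved by reference:", input_saved_by_reference)
```

If the input is saved by reference, the baseline pays nothing extra for the first layer's saved activation, so torch.cuda.memory_allocated() has nothing to show as a saving there. Layers whose input is an intermediate result are different, because that intermediate would otherwise be kept alive only for the backward pass.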