huggingface / nanotron

Minimalistic large language model 3D-parallelism training
Apache License 2.0

Deprecate `recompute_granularity` in config #76

Closed NouamaneTazi closed 7 months ago

NouamaneTazi commented 7 months ago

Selective recomputation is now handled by flash-attn, so there's no need to keep `recompute_granularity` in the config. We still keep the handy `@checkpoint_method` decorator for activating recomputation on individual methods:

class MyFancyModule(nn.Module):
    def __init__(self):
        ...
        # Flag that @checkpoint_method looks up through attr_name
        self.do_checkpoint: bool = True

    @checkpoint_method(attr_name="do_checkpoint")
    def forward(self, x):
        ...
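
For readers wondering how a flag-gated decorator like this can work, here is a minimal sketch. It is not nanotron's actual implementation: `checkpoint_method_sketch`, its `checkpoint_fn` parameter, `recording_checkpoint`, and `Doubler` are hypothetical names made up for illustration. The sketch assumes that, when no `checkpoint_fn` is injected, recomputation falls back to `torch.utils.checkpoint.checkpoint`; the demo injects a torch-free stand-in so the dispatch logic can be run without a GPU or PyTorch installed.

```python
from functools import partial, wraps
from typing import Callable, Optional


def checkpoint_method_sketch(attr_name: str, checkpoint_fn: Optional[Callable] = None):
    """Hypothetical stand-in for `checkpoint_method` (not nanotron's real code).

    The wrapped method is routed through a checkpoint function only when
    `getattr(self, attr_name)` is truthy at call time.
    """

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            if not getattr(self, attr_name):
                # Flag is off: plain call, activations are kept as usual.
                return fn(self, *args, **kwargs)
            run = checkpoint_fn
            if run is None:
                # Assumed default: PyTorch activation checkpointing, imported lazily.
                from torch.utils.checkpoint import checkpoint
                run = partial(checkpoint, use_reentrant=False)
            # Bind `self` so the checkpoint function only sees the call arguments.
            return run(partial(fn, self), *args, **kwargs)

        return wrapper

    return decorator


calls = []


def recording_checkpoint(function, *args, **kwargs):
    # Torch-free stand-in: records that the checkpoint path was taken,
    # then simply runs the function (no recomputation happens here).
    calls.append("checkpointed")
    return function(*args, **kwargs)


class Doubler:
    def __init__(self, do_checkpoint: bool):
        self.do_checkpoint = do_checkpoint

    @checkpoint_method_sketch(attr_name="do_checkpoint", checkpoint_fn=recording_checkpoint)
    def forward(self, x):
        return 2 * x
```

Because the flag is read on every call, flipping `do_checkpoint` on an existing instance switches the behaviour without redefining the class, which is what makes a per-module attribute a workable replacement for a global config option.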