jzhang38 / TinyLlama

The TinyLlama project is an open endeavor to pretrain a 1.1B Llama model on 3 trillion tokens.
Apache License 2.0

Activation Checkpointing #152

Closed syncdoth closed 4 months ago

syncdoth commented 5 months ago

I'm trying to set up activation checkpointing to train a larger model. I've adjusted the code to:

# tinyllama.py
strategy = FSDPStrategy(
    auto_wrap_policy={Block},
    activation_checkpointing_policy={Block},
    state_dict_type="full",
    limit_all_gathers=True,
    cpu_offload=False,
)

However, I keep getting the following error:

torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: A different number of tensors was saved during the original forward and recomputation.
Number of tensors saved during forward: 27
Number of tensors saved during recomputation: 8

I'm not sure what could be causing this in the GPT code.
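
For context: activation_checkpointing_policy={Block} presumably makes Lightning wrap each Block with PyTorch's non-reentrant activation checkpointing, which is roughly what the sketch below does (the helper name is illustrative, not code from the repo):

import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_block_forward(block: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Only the block's inputs are kept for backward; the block's forward is
    # re-run during the backward pass to regenerate its activations.
    # The non-reentrant implementation records which tensors each run saves,
    # and raises the CheckpointError shown above if the recomputed forward
    # saves a different set than the original forward did.
    return checkpoint(block, x, use_reentrant=False)

A mismatch like that usually points to something in the block's forward, such as a custom fused kernel with its own autograd function, that saves tensors differently when it is re-executed.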

jzhang38 commented 4 months ago

The xFormers fused SwiGLU is not compatible with activation checkpointing. Consider disabling the fused xFormers SwiGLU and using plain torch SwiGLU layers instead.
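
For reference, an unfused SwiGLU built from plain torch ops looks roughly like this (a minimal sketch; the class and the w1/w2/w3 names are illustrative, not the repo's identifiers):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TorchSwiGLU(nn.Module):
    # Unfused SwiGLU: silu(x @ W1) * (x @ W3), projected back down by W2.
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # up projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

Because this is built from ordinary nn.Linear and F.silu calls rather than a custom fused kernel, the recomputed forward under checkpointing saves the same tensors as the original one.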

syncdoth commented 4 months ago

Thanks! I totally forgot about that.