@tianyu-l I opened a PR into the meta-llama/llama3 repo with this change.

Personally, I think this kind of change is valid for us to make in torchtitan: it does not affect the model structure, it is not intrusive (the llama3 repo could adopt it easily if they wanted to), and it is numerics-preserving.

Stack from ghstack (oldest at bottom):
**Overview**

This PR replaces `x[:, :, :, None, :]` with `torch.unsqueeze(x, dim=3)` to avoid unnecessary device copies and fill kernels in the backward pass (namely, 4 fills and 4 copies with shape `(bs, seq_len, n_kv_heads, head_dim)`).
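For illustration, a minimal sketch (hypothetical shapes, not from the PR) showing that the two forms produce identical values:

```python
import torch

bs, seq_len, n_kv_heads, head_dim, n_rep = 2, 8, 4, 16, 2
x = torch.randn(bs, seq_len, n_kv_heads, head_dim, requires_grad=True)

# Before: slicing plus a None-index to add the repeat dim.
y_old = x[:, :, :, None, :].expand(bs, seq_len, n_kv_heads, n_rep, head_dim)

# After: a single unsqueeze adds the same dim without any slice ops.
y_new = torch.unsqueeze(x, dim=3).expand(bs, seq_len, n_kv_heads, n_rep, head_dim)

assert torch.equal(y_old, y_new)  # numerics preserving
```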
**Traces**

Existing forward (CPU):
Existing backward (CPU):
Existing backward (GPU):
Each `aten::slice` in the forward leads to a `SliceBackward0`, which is an `aten::zeros` -> `aten::copy_`.
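The difference in recorded autograd nodes can be checked directly; a minimal sketch (hypothetical shapes; exact node names may vary by PyTorch version):

```python
import torch

x = torch.randn(2, 8, 4, 16, requires_grad=True)

old = x[:, :, :, None, :]
new = torch.unsqueeze(x, dim=3)

# In the PyTorch version profiled above, the indexing form records a
# SliceBackward0 per sliced dim (each backing an aten::zeros ->
# aten::copy_ in backward), while unsqueeze records only
# UnsqueezeBackward0.
print(old.grad_fn)  # e.g. <SliceBackward0 object at ...>
print(new.grad_fn)  # e.g. <UnsqueezeBackward0 object at ...>
```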
New forward (CPU):
New backward (CPU):
New backward (GPU): no more fill kernels or device copies
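Traces like the ones above can be collected with `torch.profiler`; a sketch under stated assumptions (the `repeat_kv_old` helper is hypothetical, modeled on the reference-style implementation, and shapes are illustrative):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Hypothetical repeat_kv-style helper using the old indexing, for tracing only.
def repeat_kv_old(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, n_kv_heads, head_dim = x.shape
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1, 8, 4, 16, device=device, requires_grad=True)

activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    repeat_kv_old(x, n_rep=2).sum().backward()

# Look for aten::zeros / aten::copy_ under SliceBackward0 in the table,
# or export a Chrome trace for a timeline view.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")
```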
**Test Plan**

- Added `torch.manual_seed(0)` and `torch.use_deterministic_algorithms(True, warn_only=False)` at the beginning of `main` in `train.py` (sketched below).
- Ran `CUBLAS_WORKSPACE_CONFIG=:4096:8 CONFIG_FILE=train_configs/llama3_8b.toml ./run_llama_train.sh` with DP=8, batch size 1, and no activation checkpointing (AC).
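A minimal sketch of that determinism setup (the exact placement within `main` in `train.py` may differ):

```python
import torch

def main():
    # Seed and force deterministic kernels so the old and new code paths
    # can be compared bit-wise; warn_only=False raises on any
    # nondeterministic op instead of silently proceeding.
    torch.manual_seed(0)
    torch.use_deterministic_algorithms(True, warn_only=False)
    ...  # rest of training setup / loop
```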