mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

Fix TE HF checkpoint saving #1280

Closed · j316chuck closed 2 weeks ago

j316chuck commented 3 weeks ago

Description

Fixes the HF Checkpointer callback for TransformerEngine FP8 saving. This PR serializes the `io.BytesIO` `extra_state` tensors as regular tensors in `save_pretrained` so the save no longer errors.
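
For reference, a minimal sketch of the serialization idea (the helper name `convert_extra_states` is illustrative, not the code in this PR): TransformerEngine keeps per-layer `_extra_state` entries as `io.BytesIO` buffers in the state dict, which `save_pretrained` cannot serialize, so each buffer is re-encoded as a flat `uint8` tensor first.

```python
import io

import torch


def convert_extra_states(state_dict: dict) -> dict:
    """Re-encode TE ``_extra_state`` BytesIO buffers as plain uint8 tensors."""
    for key, value in state_dict.items():
        if key.endswith('_extra_state') and isinstance(value, io.BytesIO):
            value.seek(0)
            # torch.frombuffer requires a writable buffer, hence the bytearray copy
            state_dict[key] = torch.frombuffer(
                bytearray(value.read()), dtype=torch.uint8)
    return state_dict
```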

Tests

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========== 25 passed, 11 skipped, 1621 deselected, 266 warnings in 92.95s (0:01:32) ==========
Waiting up to 30 seconds for all training processes to terminate. Press Ctrl-C to exit immediately.


- Before: `failed-hf-checkpointer-fp8-llama3-8b-metamath-4ep-KOTaOP` 🔴 
- After: `success-hf-checkpointer-fp8-llama3-8b-metamath-4ep-yxNFTK` ✅ 

Issues

Closes https://databricks.atlassian.net/browse/RGENAI-255

j316chuck commented 3 weeks ago

@mvpatel2000 loading from fp8 and training with bf16 seems to work; see the test run `torch-231-bf16-load-from-fp8-bR8NzC`.

Curious what the use case is in which you would do that though?
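
For context, a minimal sketch of the "load from fp8, train with bf16" scenario via the standard `transformers` load path (the checkpoint path is hypothetical, not from this PR):

```python
import torch
from transformers import AutoModelForCausalLM

# Load an HF checkpoint that was exported from an FP8 run, but instantiate the
# weights in bf16 for continued training.
model = AutoModelForCausalLM.from_pretrained(
    'checkpoints/hf-fp8-export',  # hypothetical local export path
    torch_dtype=torch.bfloat16,
)
```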