Fixes the HF Checkpoint callback when saving TransformerEngine FP8 models. This PR serializes the io.BytesIO extra_state entries as regular tensors in save_pretrained, so saving no longer errors.
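The conversion can be illustrated with a minimal sketch. It assumes the checkpoint state dict is a flat dict mapping names to either tensors or io.BytesIO objects. The helper name convert_bytesio_extra_state is hypothetical and is not necessarily the code in this PR. The sketch copies each BytesIO payload into a 1-D torch.uint8 tensor, which is the same byte-tensor layout used by the torch.frombuffer(pickle.dumps(state), dtype=torch.uint8) line in the test output.

```python
import io

import torch


def convert_bytesio_extra_state(state_dict):
    """Hypothetical helper (illustration only): return a copy of state_dict in
    which every io.BytesIO value, such as a TransformerEngine extra_state entry,
    is replaced by a 1-D uint8 tensor holding the same bytes. The result then
    contains only tensors when it is handed to save_pretrained."""
    converted = {}
    for key, value in state_dict.items():
        if isinstance(value, io.BytesIO):
            # getvalue() returns the whole buffer regardless of the read position
            raw = value.getvalue()
            if raw:
                # bytearray gives torch.frombuffer a writable buffer, which avoids
                # the "buffer is not writable" UserWarning; clone() gives the
                # tensor its own storage
                converted[key] = torch.frombuffer(bytearray(raw), dtype=torch.uint8).clone()
            else:
                # torch.frombuffer rejects empty buffers, so build an empty tensor directly
                converted[key] = torch.empty(0, dtype=torch.uint8)
        else:
            # ordinary tensors (weights, buffers) pass through unchanged
            converted[key] = value
    return converted


state = {
    "layer.weight": torch.ones(2),
    "layer._extra_state": io.BytesIO(b"\x01\x02\x03"),
}
serializable = convert_bytesio_extra_state(state)
```

The payload stays as raw bytes and is not decoded, so the pickled FP8 metadata inside it is untouched. Only its container changes from a file-like object to a tensor.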
Tests
Added a unit test; it is skipped on A100 GPUs, which do not support FP8 ✔️
Ran the unit test manually on an H100 GPU ✅
tests/a_scripts/inference/test_convert_composer_to_hf.py::test_huggingface_conversion_callback[1ba-1ba-1ba-1-1-amp_fp8-full-mpt-True-None]
/usr/lib/python3/dist-packages/transformer_engine/pytorch/module/base.py:394: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:1524.)
state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
===================================================================================== 25 passed, 11 skipped, 1621 deselected, 266 warnings in 92.95s (0:01:32) ======================================================================================
Waiting up to 30 seconds for all training processes to terminate. Press Ctrl-C to exit immediately.