pytorch-labs / float8_experimental

This repository contains the experimental PyTorch-native float8 training UX.
BSD 3-Clause "New" or "Revised" License

[FSDP2] set vocab_size=32 to avoid must be divisible by 16 error #265

Closed: weifengpy closed this pull request 4 months ago

weifengpy commented 4 months ago

pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic

E             File "/home/weif/local/pytorch-official/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 205, in forward
E               output = self.output(h).float()
E             File "/home/weif/local/pytorch-official/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
E               return self._call_impl(*args, **kwargs)
E             File "/home/weif/local/pytorch-official/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
E               return forward_call(*args, **kwargs)
E             File "/data/users/weif/float8_experimental/float8_experimental/float8_dynamic_linear.py", line 71, in forward
E               y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
E             File "/data/users/weif/float8_experimental/float8_experimental/float8_tensor.py", line 297, in __torch_dispatch__
E               return FLOAT8_OPS_TABLE[func](func, args, kwargs)
E             File "/data/users/weif/float8_experimental/float8_experimental/float8_ops.py", line 151, in float8_mm
E               tensor_out, amax = addmm_float8_unwrapped(
E             File "/data/users/weif/float8_experimental/float8_experimental/float8_python_api.py", line 55, in addmm_float8_unwrapped
E               output, output_amax = torch._scaled_mm(
E           RuntimeError: mat2 shape (768x8 must be divisible by 16
E           Exception raised from _scaled_mm_out_cuda at /data/users/weif/pytorch-official/pytorch/aten/src/ATen/native/cuda/Blas.cpp:874 (most recent call first):
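For context, `torch._scaled_mm` requires the matmul operand dimensions to be multiples of 16. In the failing test, mat2 has shape (768x8): the weight of the output projection, whose leading dimension comes from the model's vocab size, and 8 is not divisible by 16. Choosing `vocab_size=32` satisfies the constraint. A minimal sketch of the check, using a hypothetical `round_up` helper that is not part of this repo:

```python
def round_up(n: int, multiple: int = 16) -> int:
    """Round n up to the nearest multiple (hypothetical helper, not in float8_experimental)."""
    return ((n + multiple - 1) // multiple) * multiple

# vocab_size=8 yields mat2 shape (768x8); 8 % 16 != 0, so _scaled_mm rejects it.
# vocab_size=32 yields mat2 shape (768x32), which passes the divisibility check.
for vocab_size in (8, 32):
    print(f"vocab_size={vocab_size}, divisible by 16: {vocab_size % 16 == 0}")

# If an arbitrary vocab size had to be supported, padding it up would work:
print(round_up(8))  # the smallest valid padded size for vocab_size=8
```

This illustrates why the PR hardcodes 32 in the test model rather than padding at runtime: the test only needs any vocab size that meets the alignment requirement.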
facebook-github-bot commented 4 months ago

@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 4 months ago

@weifengpy merged this pull request in pytorch-labs/float8_experimental@cdb78678d543178aab59f7216dc0458f2242f629.