**cyang49** opened 2 months ago
With a hard-coded number of cache blocks, I could get around the first problem, but then another issue occurred:
File "/net/storage149/mnt/md0/ccyang/github.com/cyang49/fms-extras/fms_extras/models/paged_gpt_bigcode.py", line 297, in forward
preds = self.head(embeds)
^^^^^^^^^^^^^^^^^
File "/net/storage149/mnt/md0/ccyang/miniforge3/envs/fms_llama3-405b/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/net/storage149/mnt/md0/ccyang/miniforge3/envs/fms_llama3-405b/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/net/storage149/mnt/md0/ccyang/miniforge3/envs/fms_llama3-405b/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 117, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
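For reference, the error itself is easy to reproduce in isolation (this is a minimal sketch, not the actual model code): `F.linear` refuses to multiply fp32 activations by fp16 weights.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, dtype=torch.float32)   # fp32 activations, as the decoder emits here
w = torch.randn(8, 4, dtype=torch.float16)   # fp16 head weight
F.linear(x, w)  # RuntimeError: expected mat1 and mat2 to have the same dtype
```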
With some debug prints, I saw that, for some reason, the layers are producing fp32 outputs:
```
decoder input type=x.dtype=torch.float32
ln input type=x.dtype=torch.float32
ln output type=dec_out.dtype=torch.float32
embeds.dtype=torch.float32
```
In contrast, when using foundation-model-stack/main, this issue doesn't occur:
```
decoder input type=x.dtype=torch.float16
ln input type=x.dtype=torch.float16
ln output type=dec_out.dtype=torch.float16
embeds.dtype=torch.float16
```
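In case it's useful to others chasing this, here is roughly how such a dtype trace can be produced without editing the model code; this is a hypothetical sketch (`model` is assumed to be the loaded paged model), not the exact debug prints used above:

```python
import torch

def log_dtype(name):
    # Forward hook that reports the dtypes flowing through one submodule
    def hook(module, inputs, output):
        in_dtypes = [t.dtype for t in inputs if torch.is_tensor(t)]
        out_dtype = output.dtype if torch.is_tensor(output) else type(output)
        print(f"{name}: in={in_dtypes} out={out_dtype}")
    return hook

for name, module in model.named_modules():
    module.register_forward_hook(log_dtype(name))
```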
A third issue is that there are code changes for the llama and gpt bigcode models in fms that have not been applied to paged gpt bigcode in fms-extras, especially the tied weights. I assume similar changes are needed, something along the lines of the sketch below.
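This is a minimal illustration of what the weight tying would look like, assuming the paged model exposes an input embedding and an output head the way the fms gpt bigcode model does; the attribute names here are illustrative, not the actual fms-extras names:

```python
import torch.nn as nn

class TiedHeadModel(nn.Module):
    def __init__(self, vocab_size: int, emb_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.head = nn.Linear(emb_dim, vocab_size, bias=False)
        # Tie the output projection to the input embedding: both now point
        # at the same Parameter, so dtype casts and weight updates applied
        # to one are automatically seen by the other.
        self.head.weight = self.embedding.weight
```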
In the function `get_max_gpu_blocks_available`, the debugger shows that when peak memory usage is larger than the default 0.8 utilization budget, the negative number computed eventually results in the function returning zero. I think an error should be raised at that point, since the code can't possibly work with 0 blocks in the paged cache. Raising an exception there prevents the mysterious queue-empty error that otherwise happens later when trying to acquire blocks.
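Here is a sketch of the guard I have in mind; the block-count computation is paraphrased from the behavior described above (budget = total × utilization − peak), not copied from the actual fms-extras source:

```python
import torch

def get_max_gpu_blocks_available(block_size_bytes: int,
                                 gpu_memory_utilization: float = 0.8) -> int:
    _, total_bytes = torch.cuda.mem_get_info()
    peak_bytes = torch.cuda.max_memory_allocated()
    # If peak usage already exceeds the utilization budget, this is negative
    budget_bytes = total_bytes * gpu_memory_utilization - peak_bytes
    num_blocks = int(budget_bytes // block_size_bytes)
    if num_blocks <= 0:
        # Fail fast instead of returning 0 and hitting a queue-empty error
        # later when the paged cache tries to acquire blocks.
        raise RuntimeError(
            f"get_max_gpu_blocks_available computed {num_blocks} blocks; "
            "increase gpu_memory_utilization or reduce peak memory usage."
        )
    return num_blocks
```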
Here's the failing example: