Open 152334H opened 9 months ago
Can you share the detailed reproduce steps? It seems PyTorch 2.2 needs a higher version of NCCL, and currently we only support PyTorch 2.1 and 1.4.
I met the same problem. My torch version is 2.4.0 with CUDA 12.1:
```
  File "/home/yatorho/doc/projs/MS-AMP/examples/mnist.py", line 182, in <module>
    main()
  File "/home/yatorho/doc/projs/MS-AMP/examples/mnist.py", line 173, in main
    train(args, model, device, train_loader, optimizer, epoch)
  File "/home/yatorho/doc/projs/MS-AMP/examples/mnist.py", line 73, in train
    scaler.step(optimizer)
  File "/home/yatorho/anaconda3/envs/t24/lib/python3.12/site-packages/torch/amp/grad_scaler.py", line 448, in step
    self.unscale_(optimizer)
  File "/home/yatorho/anaconda3/envs/t24/lib/python3.12/site-packages/torch/amp/grad_scaler.py", line 338, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
                                              ^^^^^^^^^^^^^^^^^^^^^
  File "/home/yatorho/anaconda3/envs/t24/lib/python3.12/site-packages/torch/amp/grad_scaler.py", line 256, in _unscale_grads_
    assert isinstance(param, torch.Tensor), f"param is not a Tensor: {type(param)}"
AssertionError: param is not a Tensor: <class 'msamp.nn.parameter.ScalingParameter'>
```
The `param`'s type is `ScalingParameter`.
Hi @yatorho, PyTorch added a new assertion to check whether `param` is a `torch.Tensor`, but `ScalingTensor` in MS-AMP is not a `torch.Tensor`. A temporary solution is to comment out line 256 in `torch/amp/grad_scaler.py`: `assert isinstance(param, torch.Tensor), f"param is not a Tensor: {type(param)}"`.
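If editing a file inside `site-packages` feels too invasive, the same effect can be had at runtime: in CPython, a module-level name shadows the builtin of the same name for every function defined in that module, so injecting a relaxed `isinstance` into `torch.amp.grad_scaler`'s namespace would make the assertion accept `ScalingParameter`. Below is a minimal sketch of that monkey-patching pattern using stand-ins (`grad_scaler`, `Tensor`, and `ScalingParameter` here are toy substitutes, not the real torch/MS-AMP objects); whether to apply it to the real module is a judgment call.

```python
import builtins
import types

# Stand-in for torch/amp/grad_scaler.py: a module whose function contains
# the same kind of strict isinstance assertion as in the traceback.
grad_scaler = types.ModuleType("grad_scaler")
src = """
class Tensor:            # stand-in for torch.Tensor
    pass

def unscale(param):
    assert isinstance(param, Tensor), f"param is not a Tensor: {type(param)}"
    return True
"""
exec(src, grad_scaler.__dict__)

class ScalingParameter:  # stand-in: does NOT subclass Tensor
    pass

# The patch: a module-global named `isinstance` shadows the builtin for
# functions defined in that module, so unscale() now calls this instead.
def relaxed_isinstance(obj, cls):
    if isinstance(obj, ScalingParameter):   # builtin isinstance here
        return True
    return builtins.isinstance(obj, cls)

grad_scaler.isinstance = relaxed_isinstance

print(grad_scaler.unscale(ScalingParameter()))  # no AssertionError
```

The upside over editing `grad_scaler.py` is that the change lives in your own training script and survives reinstalling or upgrading the package; the downside is that it silently relaxes the check for every caller of that module.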
Thanks! It works for me.
What's the issue, what's expected?:
`python mnist.py --enable-msamp --opt-level=O2` should work with the versions pinned in `pyproject.toml`. Specifically, it should work with `torch==2.2.1`, given that torch is unpinned.

How to reproduce it?:
Build MS-AMP with `torch==2.2.1`.

Log message or snapshot?:
Additional information:
This occurs because `optimizer.param_groups[:,'params']` contains `ScalingParameter`s. `ScalingParameter` subclasses `ScalingTensor`, which subclasses nothing, so the `isinstance` check fails. Commenting out the assertion line manually fixes the issue. I do not know how to reasonably fix this without resorting to that.