mean = torch.randn(3, 4, device='cuda')
std = 1.5
result = normal_tensor_float(mean, std)
然而出现错误:
result = normal_tensor_float(mean, std)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/szli/FlagGems/src/flag_gems/ops_copy/normal.py", line 67, in normal_tensor_float
out = normal_distribution(mean, std)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/szli/FlagGems/src/flag_gems/ops_copy/normal.py", line 47, in normal_distribution
shape = broadcast_shapes([mean.shape, std.shape])
^^^^^^^^^
AttributeError: 'float' object has no attribute 'shape'
Please describe your question
当我在尝试调用normal.py中的def normal_tensor_float(mean, std, *, generator=None):方法时,我采用了下面的方式:
然而出现错误:
当指定mean或std为float类型时,其本身并不具备shape属性。是我在传入的参数类型这里出错了吗