FlagOpen / FlagGems

FlagGems is an operator library for large language models implemented in Triton Language.
Apache License 2.0
296 stars 26 forks source link

在运行normal.py时遇到了问题 #234

Open ShZ-Li opened 2 days ago

ShZ-Li commented 2 days ago

Please describe your question

当我在尝试调用normal.py中的def normal_tensor_float(mean, std, *, generator=None):方法时,我采用了下面的方式:

mean = torch.randn(3, 4, device='cuda')
std = 1.5
result = normal_tensor_float(mean, std)

然而出现错误:

result = normal_tensor_float(mean, std)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/szli/FlagGems/src/flag_gems/ops_copy/normal.py", line 67, in normal_tensor_float
    out = normal_distribution(mean, std)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/szli/FlagGems/src/flag_gems/ops_copy/normal.py", line 47, in normal_distribution
    shape = broadcast_shapes([mean.shape, std.shape])
                                          ^^^^^^^^^
AttributeError: 'float' object has no attribute 'shape'

当指定mean或std为float类型时,其本身并不具备shape属性。是我在传入的参数类型这里出错了吗

StrongSpoon commented 1 day ago

请 @Bowen12992 帮忙看看,我理解normal_tensor_float的std参数应该是float形式,不支持.shape?