Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
https://cg.cs.tsinghua.edu.cn/jittor/
Apache License 2.0
3.08k stars 311 forks source link

1.2.3.44版本中nn.cross_entropy_loss函数,传入weight参数报错 #236

Open xiashuo opened 3 years ago

xiashuo commented 3 years ago

Traceback (most recent call last): File "/disk_sda/xs/projects/xuexian-jittor/train.py", line 124, in main() File "/disk_sda/xs/projects/xuexian-jittor/train.py", line 119, in main train(model, train_loader, optimizer, epoch, learning_rate, writer, epochs) File "/disk_sda/xs/projects/xuexian-jittor/train.py", line 54, in train loss = nn.cross_entropy_loss(pred, target, weight=jt.array([1, 20, 8, 4]),ignore_index=255) # fix a bug File "/home/kqgis/anaconda3/envs/xuexian-jittor/lib/python3.8/site-packages/jittor/nn.py", line 214, in cross_entropy_loss loss = (logsum - (outputtarget).sum(1)) target_weight RuntimeError: Wrong inputs arguments, Please refer to examples(help(jt.mul)).

Types of your inputs are: self = Var, b = Var,

The function declarations are: VarHolder multiply(VarHolder x, VarHolder* y)

Failed reason:[f 0625 14:07:57.001614 44 binary_op.cc:522] Check failed xshape(8388608) == yshape(512) Shape not match, x:float32[8388608,] y:int32[32,512,512,]

下面是我的调用代码,我是按照pytorch里的参数格式传入的(4个类别),请问是我参数格式有误还是? loss = nn.cross_entropy_loss(pred, target, weight=jt.array([1, 20, 8, 4])) # fix a bug

Jittor commented 3 years ago

请问您predtarget的形状是怎样的呢?

xiashuo commented 3 years ago

pred: (32,4,512,512) target: (32,512,512)

------------------ 原始邮件 ------------------ 发件人: @.>; 发送时间: 2021年6月25日(星期五) 下午2:27 收件人: @.>; 抄送: "(6X13N4) @.>; @.>; 主题: Re: [Jittor/jittor] 1.2.3.44版本中nn.cross_entropy_loss函数,传入weight参数报错 (#236)

请问您pred, target的形状是怎样的呢?

— You are receiving this because you authored the thread. Reply to this email directly, view it on GitHub, or unsubscribe.

xiashuo commented 3 years ago

请问您predtarget的形状是怎样的呢?

pred: (32,4,512,512) target: (32,512,512)

Jittor commented 3 years ago

您好,感谢您的反馈,我们刚刚根据您反馈的形状编写了一个test,并没能复现出您的错误,这是我们写的test:https://github.com/Jittor/jittor/commit/aaf97d5f58a3fed8aa2f93be132a9ed9be8dd4a5

请您确认一下是否更新到了最新版本,因为以前旧版本是会有这个错误的。 如果还是有问题,您可以提供一下完整的日志给我们。再次感谢

xiashuo commented 3 years ago

您好,感谢您的反馈,我们刚刚根据您反馈的形状编写了一个test,并没能复现出您的错误,这是我们写的test:aaf97d5

请您确认一下是否更新到了最新版本,因为以前旧版本是会有这个错误的。 如果还是有问题,您可以提供一下完整的日志给我们。再次感谢

升级到1.2.3.45版,问题解决了,谢谢!这版本更新太快了,一天升级了3个版本哈哈

Jittor commented 3 years ago

为了快速支持大家,我们会光速更新的😆