Open anthonyAndchen opened 1 year ago
Hello, I ran into a similar error. Did you manage to solve it?
File "main.py", line 246, in <module>
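Hello. After some experimenting, I found that a small change to the _get_length function in modules.model.py is enough to make it run. The original function:

    def _get_length(self, logit, dim=-1):
        """ Greed decoder to obtain length from logit"""
        out = (logit.argmax(dim=-1) == self.charset.null_label)
        out = self.first_nonzero(out.int()) + 1
        return out

Changed to:

    def _get_length(self, logit, dim=-1):
        """ Greed decoder to obtain length from logit"""
        # True wherever the predicted class is the null (end) label
        out = (logit.argmax(dim=-1) == self.charset.null_label)
        # abn: whether each sequence contains an end label at all
        abn = out.any(dim)
        # position of the first end label in each sequence
        out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
        out = out + 1  # additional end token
        # sequences without an end label get the maximum length instead
        out = torch.where(abn, out, out.new_tensor(logit.shape[1]))
        return out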
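To see what the `cumsum` line in the patched function is doing, here is a torch-free sketch of the same steps for a single sequence, written on plain Python lists. The name `length_from_eos_mask` and the toy masks are made up for illustration; only the step-by-step logic is meant to mirror the patch.

```python
from itertools import accumulate


def length_from_eos_mask(is_eos):
    """Length of one sequence, given a per-step "is end token" mask of bools."""
    max_t = len(is_eos)
    # Counterpart of `abn = out.any(dim)`: does the sequence contain an end token?
    has_eos = any(is_eos)
    # Counterpart of `out.cumsum(dim)`: running count of end tokens seen so far.
    running = list(accumulate(int(flag) for flag in is_eos))
    # Counterpart of `(cumsum == 1) & out`: True only at the first end token,
    # so there is at most one hit and its index is unambiguous.
    first_hit = [count == 1 and flag for count, flag in zip(running, is_eos)]
    index = first_hit.index(True) if has_eos else 0
    length = index + 1  # count the end token itself
    # Counterpart of `torch.where(abn, out, max_t)`: fall back to the full length.
    return length if has_eos else max_t


print(length_from_eos_mask([False, False, True, False, True]))  # 3
print(length_from_eos_mask([False] * 5))                        # 5
```

The second call is the case that triggered the assertion: with no end token anywhere, the length falls back to the number of time steps instead of 0.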
Quoted from wander's message of 26 January 2024, 21:01 (subject: Re: [FangShancheng/ABINet] Assertion error in callbacks.py: AssertionError: tensor([0, 0], device='cuda:0') != tensor([26, 26], device='cuda:0') (Issue #102)):
    File "main.py", line 246, in <module>
      main()
    File "main.py", line 234, in main
      learner.fit(epochs=config.training_epochs,
    File "/home/yaorenyuan/anaconda3/envs/fz/lib/python3.8/site-packages/fastai/basic_train.py", line 200, in fit
      fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    File "/home/yaorenyuan/anaconda3/envs/fz/lib/python3.8/site-packages/fastai/basic_train.py", line 102, in fit
      if cb_handler.on_batch_end(loss): break
    File "/home/yaorenyuan/anaconda3/envs/fz/lib/python3.8/site-packages/fastai/callback.py", line 308, in on_batch_end
      self('batch_end', call_mets = not self.state_dict['train'])
    File "/home/yaorenyuan/anaconda3/envs/fz/lib/python3.8/site-packages/fastai/callback.py", line 251, in __call__
      for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
    File "/home/yaorenyuan/anaconda3/envs/fz/lib/python3.8/site-packages/fastai/callback.py", line 241, in _call_and_update
      new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
    File "/home/yaorenyuan/fz/ABINet/callbacks.py", line 117, in on_batch_end
      last_metrics = self._validate()
    File "/home/yaorenyuan/fz/ABINet/callbacks.py", line 65, in _validate
      val_metrics = validate(self.learn.model, dl, self.loss_func, cb_handler)
    File "/home/yaorenyuan/anaconda3/envs/fz/lib/python3.8/site-packages/fastai/basic_train.py", line 63, in validate
      if cb_handler and cb_handler.on_batch_end(val_losses[-1]): break
    File "/home/yaorenyuan/anaconda3/envs/fz/lib/python3.8/site-packages/fastai/callback.py", line 308, in on_batch_end
      self('batch_end', call_mets = not self.state_dict['train'])
    File "/home/yaorenyuan/anaconda3/envs/fz/lib/python3.8/site-packages/fastai/callback.py", line 250, in __call__
      for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
    File "/home/yaorenyuan/anaconda3/envs/fz/lib/python3.8/site-packages/fastai/callback.py", line 241, in _call_and_update
      new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
    File "/home/yaorenyuan/fz/ABINet/callbacks.py", line 205, in on_batch_end
      assert (pt_lengths == pt_lengths_).all(), f'{pt_lengths} != {pt_lengths_} for {pt_text}'
    AssertionError: tensor([ 4, 11, 5, 5, 6, 4, 7, 6, 6, 5, 4, 12, 7, 4, 5, 8, 3, 5, 9, 9, 6, 5, 10, 5, 3, 3, 5, 6, 4, 11, 9, 6, 12, 8, 4, 5, 6, 5, 3, 0, 6, 3, 9, 3, 6, 5, 8, 4, 2, 11, 5, 7, 8, 6, 7, 7, 8, 7, 4, 2, 5, 10, 7, 6], device='cuda:0') != tensor([ 4, 11, 5, 5, 6, 4, 7, 6, 6, 5, 4, 12, 7, 4, 5, 8, 3, 5, 9, 9, 6, 5, 10, 5, 3, 3, 5, 6, 4, 11, 9, 6, 12, 8, 4, 5, 6, 5, 3, 26, 6, 3, 9, 3, 6, 5, 8, 4, 2, 11, 5, 7, 8, 6, 7, 7, 8, 7, 4, 2, 5, 10, 7, 6], device='cuda:0') f
My solution here:
    def _get_length(self, logit):
        """ Greed decoder to obtain length from logit"""
        out = (logit.argmax(dim=-1) == self.charset.null_label)
        out = self.first_nonzero(out.int()) + 1
        # no EOS found: use the full sequence length instead of 0
        out[out == 0] = logit.shape[1]
        return out

The gist is that the original code assumes every predicted string contains an EOS character. For difficult or small datasets that may not hold, especially in early iterations. So here we force such strings to end at the last time step.
Hello. Isn't the end token, i.e. null_label, equal to 0? If so, when handling the end token with out[out==0]=logit.shape[1], why is the value logit.shape[1] rather than 0?
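Let me rephrase the code a bit by way of explanation (~_^)

    def _get_length(self, logit):
        """ Greed decoder to obtain length from logit"""
        # True wherever the predicted class is the EOS (null) label
        iseos = (logit.argmax(dim=-1) == self.charset.null_label)
        # position of the first EOS plus one; 0 means no EOS was found
        length = self.first_nonzero(iseos.int()) + 1
        max_t = logit.shape[1]  # number of time steps
        # no EOS predicted: the string runs to the last time step
        length[length == 0] = max_t
        return length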
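The point behind the question is that `null_label` is a class index, while the value being overwritten is a length counted in time steps, so the two zeros mean different things. Below is a torch-free sketch for a single sequence. `first_nonzero_1d` is a hypothetical stand-in for the repository's `first_nonzero`, assumed here to return -1 when nothing is found (which is what makes a length of 0 the "no EOS" case). `NULL_LABEL = 0` follows the question, and `MAX_T = 26` is chosen to match the value in the assertion message.

```python
NULL_LABEL = 0  # class index of the end token (assumed, as in the question)
MAX_T = 26      # number of decoding steps, playing the role of logit.shape[1]


def first_nonzero_1d(values):
    """Index of the first non-zero entry, or -1 if there is none (assumed behavior)."""
    for index, value in enumerate(values):
        if value != 0:
            return index
    return -1


def get_length(pred_labels):
    """Predicted length of one sequence from its per-step argmax labels."""
    iseos = [int(label == NULL_LABEL) for label in pred_labels]
    length = first_nonzero_1d(iseos) + 1  # 0 here means "no EOS found"
    if length == 0:
        length = len(pred_labels)         # fall back to the last time step
    return length


with_eos = [5, 3, 9] + [NULL_LABEL] * (MAX_T - 3)  # three characters, then EOS
no_eos = [7] * MAX_T                               # EOS never predicted

print(get_length(with_eos))  # 4
print(get_length(no_eos))    # 26
```

Writing 0 into the result would leave the predicted length at 0, which is the left-hand value in the failing assertion; writing the number of time steps gives 26, the right-hand value.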
Thank you for the solution. One small slip to watch for: the variable is iseos, so iseost (with an extra "t") is a typo.
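Hello, I have recently been trying out ABINet, but when running eval I get the following error:
I set a breakpoint at the failing line in callbacks and found that both the last_outputs passed in and the output obtained from it are None.
What could be causing this? Thank you very much.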