yhcc / BARTABSA


Question about the displayed batch size #17

Closed. cos4007 closed this issue 2 years ago.

cos4007 commented 2 years ago

No matter what batch size I set (I have tried 4, 8, 32 and others), the following message is printed once:

input fields after batch(if batch size is 2):
    tgt_tokens: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 22])
    src_tokens: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 41])
    src_seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
    tgt_seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
target fields after batch(if batch size is 2):
    tgt_tokens: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 22])
    target_span: (1)type:numpy.ndarray (2)dtype:object, (3)shape:(2,)
    tgt_seq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])

I don't understand this part of the output. Could the author please help explain it? Thanks.
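The shapes in the log above are consistent with a common batching convention: each token field (src_tokens, tgt_tokens) is padded to the longest sequence in the batch, giving [batch_size, max_len], while each *_seq_len field stores one length per example, giving [batch_size]. The following is a minimal plain-Python sketch of that padding step, assuming this convention. The helpers pad_batch and shape and the value PAD_ID are hypothetical illustrations, not part of BARTABSA or fastNLP.

```python
from typing import List, Tuple

# Hypothetical padding index; the real value depends on the tokenizer in use.
PAD_ID = 0


def pad_batch(seqs: List[List[int]], pad_id: int = PAD_ID) -> Tuple[List[List[int]], List[int]]:
    """Pad variable-length token-id lists to the longest one in the batch.

    Returns the padded batch (shape [batch_size, max_len]) and the original
    lengths (shape [batch_size]), mirroring the src_tokens / src_seq_len pair.
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    padded = [s + [pad_id] * (max_len - len(s)) for s in seqs]
    return padded, lengths


def shape(batch: List[List[int]]) -> Tuple[int, int]:
    """Return (rows, columns) of a rectangular list-of-lists."""
    return (len(batch), len(batch[0]))


# Two hypothetical source sentences with 41 and 30 token ids.
src = [list(range(1, 42)), list(range(1, 31))]
src_tokens, src_seq_len = pad_batch(src)

print(shape(src_tokens))  # (2, 41)
print(src_seq_len)        # [41, 30]
```

Under this assumption, the second dimension (41 for src_tokens, 22 for tgt_tokens) is the longest length within that particular batch, so it can vary from batch to batch, while the first dimension is always the batch size.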