Closed CiaoHe closed 2 years ago
@CiaoHe Hi again :wave: :smile: turns out I had that notated incorrectly https://github.com/lucidrains/CoCa-pytorch/blob/main/coca_pytorch/coca_pytorch.py#L461 sorry for the confusion!
@lucidrains Oh, I misunderstood how CE is used. Ha, yes, the input logits should have shape `(b, vocab_size, length)`
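
For anyone landing here later, a minimal sketch of the shape convention discussed above, assuming "CE" refers to PyTorch's `F.cross_entropy`. The sizes and variable names are hypothetical and not taken from the CoCa code. `F.cross_entropy` takes the class dimension at index 1, so logits of shape `(bsz, length, num_tokens)` would be read with `length` as the class dimension unless they are rearranged first.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
bsz, length, num_tokens = 2, 5, 11

logits = torch.randn(bsz, length, num_tokens)         # (b, n, c), e.g. from a to-logits layer
labels = torch.randint(0, num_tokens, (bsz, length))  # (b, n) target token ids

# Move the class (vocab) dimension to index 1: (b, n, c) -> (b, c, n).
logits_bcn = logits.transpose(1, 2)
loss_seq = F.cross_entropy(logits_bcn, labels)

# Equivalent alternative: flatten batch and sequence into a single dimension.
loss_flat = F.cross_entropy(logits.reshape(-1, num_tokens), labels.reshape(-1))
```

Both forms average the per-token loss over all `bsz * length` positions, so they should agree up to floating-point error. The transpose plays the same role as an einops `rearrange(logits, 'b n c -> b c n')`.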
I think the logits before this line have shape `(bsz, length, num_tokens)`, so I don't think another rearrange is needed here: https://github.com/lucidrains/CoCa-pytorch/blob/25de0b04326d8dc4c6f969e90b4466fc4894835e/coca_pytorch/coca_pytorch.py#L461