Open gillotte opened 5 months ago
https://github.com/lucidrains/CoCa-pytorch/blob/edee92c74e311ccfa4a0024412fd991c98aff5fd/coca_pytorch/coca_pytorch.py#L532
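fyi the batch size used for the distributed contrastive labels isn't correct. Once the latents are gathered across processes, the similarity matrix spans the global batch, so the labels need to be built from the gathered size (`text_latents.shape[0]`) rather than the per-process `batch`:

`torch.arange(batch, device=device)` -> `torch.arange(text_latents.shape[0], device=device)`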
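Below is a minimal single-process sketch of why the label length matters. It assumes the latents have already been all-gathered (simulated here with `torch.cat` rather than a real process group) and uses a symmetric cross-entropy loss with no temperature. `contrastive_loss` and the shapes are hypothetical and illustrative, not the repo's actual code.

```python
import torch
import torch.nn.functional as F


def contrastive_loss(text_latents: torch.Tensor, image_latents: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: both inputs are (global_batch, dim), i.e. already
    # gathered from every rank, so sim is (global_batch, global_batch).
    sim = text_latents @ image_latents.t()
    # Labels must index the gathered batch, not the per-rank batch.
    labels = torch.arange(text_latents.shape[0], device=text_latents.device)
    # Symmetric loss: text->image and image->text.
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2


if __name__ == "__main__":
    torch.manual_seed(0)
    local_batch, world_size, dim = 4, 2, 16
    # Simulated all-gather: each "rank" contributes local_batch rows.
    text = torch.cat([F.normalize(torch.randn(local_batch, dim), dim=-1) for _ in range(world_size)])
    image = torch.cat([F.normalize(torch.randn(local_batch, dim), dim=-1) for _ in range(world_size)])
    print(contrastive_loss(text, image))  # 8 labels for an 8x8 similarity matrix
    try:
        # Per-rank labels (length 4) no longer match the 8x8 gathered matrix.
        F.cross_entropy(text @ image.t(), torch.arange(local_batch))
    except (ValueError, RuntimeError) as err:
        print("per-rank labels fail:", err)
```

With a single process the two sizes coincide, which is presumably why the mismatch only surfaces in the distributed path.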