Closed — CoEich closed this 2 years ago
Cleanup done apart from collate_fn:
def collate_fn(batch_data: List[Tuple[torch.Tensor, torch.Tensor]]):
    """
    # should be able to replace with this:
    return tuple(
        torch.cat(i) for i in zip(*batch_data)
    )  # [(img1, cap1), (img2, cap2), ...] -> [(img1, img2, ...), (cap1, cap2, ...)]
    """
    # [(img1, caption1), (img2, caption2), ...] -> [(img1, img2, ...), (caption1, caption2, ...)]
    all_images, all_captions = zip(*batch_data)
    return torch.cat(all_images), torch.cat([i[:, :2048] for i in all_captions])
We probably should not hard-code the sequence length (2048) here. Also, should we switch to the alternative syntax from the docstring comment?
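One possible cleanup, sketched below: pass the sequence length in as a parameter instead of hard-coding 2048 (the `max_seq_len` name and default are my suggestion, not from the code above). Note the generator syntax from the docstring comment can't be used as-is, since the captions need truncating before `torch.cat`:

```python
from typing import List, Tuple

import torch


def collate_fn(
    batch_data: List[Tuple[torch.Tensor, torch.Tensor]],
    max_seq_len: int = 2048,  # suggested parameter replacing the hard-coded constant
) -> Tuple[torch.Tensor, torch.Tensor]:
    # [(img1, cap1), (img2, cap2), ...] -> [(img1, img2, ...), (cap1, cap2, ...)]
    all_images, all_captions = zip(*batch_data)
    return (
        torch.cat(all_images),
        # truncate each caption tensor to max_seq_len along the sequence axis
        torch.cat([cap[:, :max_seq_len] for cap in all_captions]),
    )
```

In practice the default could be bound with `functools.partial(collate_fn, max_seq_len=...)` when handing it to a `DataLoader`, which takes a single-argument collate function.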