Aleph-Alpha / magma

MAGMA - a GPT-style multimodal model that can understand any combination of images and language. NOTE: The freely available model from this repo is only a demo. For the latest multimodal and multilingual models from Aleph Alpha check out our website https://app.aleph-alpha.com
MIT License

Remove dataset builders and old classes in multimodal_fewshot.datasets #6

Closed CoEich closed 2 years ago

CoEich commented 2 years ago

Cleanup done apart from collate_fn:

from typing import List, Tuple

import torch

def collate_fn(batch_data: List[Tuple[torch.Tensor, torch.Tensor]]):
    """
    # should be able to replace with this:
    return tuple(
        torch.cat(i) for i in list(zip(*batch_data))
    )  # [(img1, cap1), (img2, cap2), ... ] -> [(img1, img2, ... ), (cap1, cap2, ... )])
    """
    all_images, all_captions = list(
        zip(*batch_data)
    )  # [(img1, caption1), (img2, caption2), ... ] -> [(img1, img2, ... ), (caption1, caption2, ... )]
    return torch.cat(all_images), torch.cat([i[:, :2048] for i in all_captions])

We probably should not hard-code the sequence length (2048) here. Should we also switch to the more concise syntax from the commented-out version?
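One way to avoid the hard-coded sequence length would be to make it a parameter and bind it with `functools.partial` when constructing the `DataLoader`. A minimal sketch (the `seq_len` parameter and its default are assumptions, not part of the current code):

```python
from functools import partial
from typing import List, Tuple

import torch

def collate_fn(
    batch_data: List[Tuple[torch.Tensor, torch.Tensor]],
    seq_len: int = 2048,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Stack images along the batch dim and truncate captions to seq_len tokens."""
    # [(img1, cap1), (img2, cap2), ...] -> [(img1, img2, ...), (cap1, cap2, ...)]
    all_images, all_captions = zip(*batch_data)
    return torch.cat(all_images), torch.cat([c[:, :seq_len] for c in all_captions])

# A different length can then be bound without changing the function, e.g.:
# loader = DataLoader(dataset, batch_size=4, collate_fn=partial(collate_fn, seq_len=1024))
```

Note that the one-liner from the comment (`tuple(torch.cat(i) for i in zip(*batch_data))`) drops the caption truncation, so it is only equivalent if truncation is handled elsewhere.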