lucidrains / CoCa-pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch

register_buffer for masks and position encodings breaks DDP #17

Closed · gshaikov-paige closed this issue 1 year ago

gshaikov-paige commented 1 year ago

Hi!

Unfortunately, using buffers to cache masks and pos encodings fails when running with DDP.

https://github.com/lucidrains/CoCa-pytorch/blob/790415ceaf2af3e937cf2dc16826ccef91ffddfa/coca_pytorch/coca_pytorch.py#L116-L130
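
(Paraphrasing the linked lines: the cache re-registers a non-persistent buffer whenever the current sequence length outgrows it, roughly like this; the rotary embedding cache follows the same pattern.)

    # inside ParallelTransformerBlock in coca_pytorch.py (paraphrased, not copied verbatim)
    def get_mask(self, n, device):
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)  # buffer -> DDP tries to sync it
        return mask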

Each rank ends up with a different sequence length because the text comes in different sizes. PyTorch buffers are synced by DDP (broadcast from rank 0 on each forward pass), and that sync fails because the tensors have different shapes on each rank.

I found that using buffers is redundant here anyway since we don't store them in state_dict (persistent=False). Unless you know of a good reason why buffers are preferable that I am missing, @lucidrains?

This code worked in the DDP setting:

import torch
from torch import nn, Tensor


class ParallelTransformerBlock(nn.Module):
    ...
        # cached as plain attributes instead of buffers, so DDP never tries to sync them
        self.mask = None
        self.pos_emb = None

    def get_mask(self, n: int, device: torch.device) -> Tensor:
        # reuse the cached causal mask if it already covers the sequence length
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.mask = mask
        return mask

    def get_rotary_embedding(self, n: int, device: torch.device) -> Tensor:
        # reuse the cached rotary embedding if it already covers the sequence length
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.pos_emb = pos_emb
        return pos_emb
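
For reference, DDP broadcasts registered buffers from rank 0 on every forward pass by default, which is where the shape mismatch above comes from. Another option, which I did not need after the change above, would be to turn that broadcast off when wrapping the model. A minimal sketch, assuming the process group is already initialised and rank is this process's GPU index:

from torch.nn.parallel import DistributedDataParallel as DDP
from coca_pytorch.coca_pytorch import CoCa

model = CoCa(...)  # construct with your usual hyperparameters
model = DDP(
    model.to(rank),           # rank: this process's GPU index (assumed set up elsewhere)
    device_ids=[rank],
    broadcast_buffers=False,  # skip the per-forward buffer broadcast entirely
)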

Thanks, George

lucidrains commented 1 year ago

yea you are right, that is acceptable

the only advantage to using buffers is that their device would get managed along with the rest of the network parameters, if you were to move the model
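
quick sketch of what i mean (module / tensor names here are just for illustration) - a buffer follows .to() or .cuda() together with the module, a plain tensor attribute stays where it was created:

import torch
from torch import nn

class Cache(nn.Module):
    def __init__(self):
        super().__init__()
        # non-persistent buffer: tracked by the module, moved with .to(...)
        self.register_buffer('mask_buf', torch.ones(4, 4), persistent=False)
        # plain attribute: invisible to the module, stays on its original device
        self.mask_attr = torch.ones(4, 4)

m = Cache().to('cuda')       # assuming a gpu is available
print(m.mask_buf.device)     # cuda:0
print(m.mask_attr.device)    # cpu - has to be moved (or recreated) by hand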

lucidrains commented 1 year ago

@gshaikov-paige let me know if 0.0.8 takes care of it

gshaikov-paige commented 1 year ago

Thanks! Almost - you need to change it in __init__ as well

lucidrains commented 1 year ago

@gshaikov-paige oh yes, thanks! i'll also go make the change for a contending method here

lucidrains commented 1 year ago

@gshaikov-paige so one note is that since i open sourced this, there has been research suggesting the PaLM architecture has some sort of inherent instability. i think i should start removing the PaLM block from all my repositories soon and replace it with the one from llama (which is close, just serial instead of parallel)
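
roughly the difference, in a simplified sketch that leaves out the norms, rotary embeddings and feedforward details:

def parallel_block(x, attn, ff):
    # palm style: attention and feedforward both read the same input
    return x + attn(x) + ff(x)

def serial_block(x, attn, ff):
    # llama style: the feedforward reads the attention output
    x = x + attn(x)
    x = x + ff(x)
    return x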

lucidrains commented 1 year ago

happy training!

gshaikov-paige commented 1 year ago

Thanks for the heads up! I can help with the transition to LLaMA in this repo, if you're open to PRs?

lucidrains commented 1 year ago

@gshaikov-paige yea absolutely! 🙏

gshaikov-paige commented 1 year ago

I also added type hints on my branch, I can commit them back if you'd like.

lucidrains commented 1 year ago

@gshaikov-paige hmm, i think this project is small enough type hints are not that crucial

lucidrains commented 1 year ago

@gshaikov-paige are you aware that LAION has already trained a CoCa? https://laion.ai/blog/coca/

gshaikov-paige commented 1 year ago

@lucidrains thanks - yes, I am doing research to modify CoCa or related methods for my use case, so I needed to be able to train it from scratch. I do plan to experiment with a pretrained model though!

Re the LLaMA port: I am planning to make the change from my personal account (@gshaikov), but I can only do it next week since I am away this week, so if this work is urgent please don't wait for me :)

lucidrains commented 1 year ago

@gshaikov-paige sg! i think finetuning it for the medical domain, if that's what you are still doing, should work well

yea sg, i won't wait haha

gshaikov-paige commented 1 year ago

Ok fair enough :)