Aleph-Alpha / magma

MAGMA - a GPT-style multimodal model that can understand any combination of images and language. NOTE: The freely available model from this repo is only a demo. For the latest multimodal and multilingual models from Aleph Alpha check out our website https://app.aleph-alpha.com
MIT License

update transformers version + requirements #16

Closed: sdtblck closed this 2 years ago

sdtblck commented 2 years ago

Untested as of right now

CoEich commented 2 years ago

I've done the following tests (see test.py):

With the latest transformers version, I loaded the GPT-J weights from Hugging Face and compared them to the LM weights in our checkpoint. The parameters all have different names, but the following renaming (in magma.from_checkpoint()) should do the correct conversion:

        import torch

        # `checkpoint_path` and `model` come from the enclosing from_checkpoint()
        sd = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        if "module" in sd.keys():
            # weights may be nested under a "module" key
            sd = sd["module"]

        #### This is a hack to load the old checkpoint. TODO: directly release a modified checkpoint
        # Keep only the first 50258 vocabulary rows of the LM head
        sd["lm.lm_head.weight"] = sd["lm.lm_head.weight"][:50258, :]
        sd["lm.lm_head.bias"] = sd["lm.lm_head.bias"][:50258]

        # Map old parameter names to the names used by the latest transformers release
        new_sd = {}
        for key, value in sd.items():
            if "attention" in key:
                new_key = key.replace("attention.", "")
            elif "c_fc" in key:
                new_key = key.replace("c_fc", "fc_in")
            elif "c_proj" in key:
                new_key = key.replace("c_proj", "fc_out")
            else:
                new_key = key
            new_sd[new_key] = value

        # The print statement below evaluates to False on testing, implying that the GPT-J weights
        # in our checkpoint are different from the ones in the latest release
        old_sd = model.state_dict()
        print(all(torch.equal(old_sd[key], new_sd[key]) for key in old_sd.keys()))

        model.load_state_dict(new_sd, strict=True)

As the comment in the snippet above says, I tested whether the weights loaded from the checkpoint are the same as the downloaded ones; unexpectedly, they do not seem to be. Very weird.
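The single boolean printed in the snippet only says that *some* tensor differs. A small helper like the following (a hypothetical sketch, not part of the MAGMA codebase) reports which keys are missing, which have different shapes, and how large the value differences are, which would help distinguish a handful of mismatched tensors from a wholesale difference. It could be called as `diff_state_dicts(model.state_dict(), new_sd)`.

```python
import torch


def diff_state_dicts(reference, candidate, atol=0.0):
    """Compare two state dicts key by key (hypothetical helper).

    Returns a dict listing keys missing from `candidate`, keys present only
    in `candidate`, keys whose shapes differ, and, for same-shape tensors,
    the max absolute difference whenever it exceeds `atol`.
    """
    report = {"missing": [], "unexpected": [], "shape_mismatch": [], "value_diff": {}}
    for key, ref in reference.items():
        if key not in candidate:
            report["missing"].append(key)
            continue
        cand = candidate[key]
        if ref.shape != cand.shape:
            report["shape_mismatch"].append((key, tuple(ref.shape), tuple(cand.shape)))
            continue
        # Cast to float so integer or bool buffers can be compared too;
        # empty tensors are treated as identical.
        max_diff = (ref.float() - cand.float()).abs().max().item() if ref.numel() else 0.0
        if max_diff > atol:
            report["value_diff"][key] = max_diff
    report["unexpected"] = [k for k in candidate if k not in reference]
    return report
```

A nonzero but tiny `atol` (for example `1e-6`) would separate genuine weight changes from dtype round-off.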

Also, I noticed that after the renaming, these leftover parameters remain in the checkpoint: [screenshot not preserved]
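To list the leftovers programmatically, PyTorch's `load_state_dict(..., strict=False)` returns the `missing_keys` and `unexpected_keys` instead of raising. A minimal sketch (the helper name is hypothetical; note that it really does load the matching tensors into the model) that also tallies the leftover keys by their last name component:

```python
from collections import Counter

import torch
from torch import nn


def leftover_parameter_summary(model, state_dict):
    """Load non-strictly and summarize the keys the model did not consume.

    Returns (unexpected_keys, counts), where counts tallies the unexpected
    keys by their final dotted component. Side effect: tensors whose keys
    do match are copied into `model`.
    """
    result = model.load_state_dict(state_dict, strict=False)
    unexpected = list(result.unexpected_keys)
    counts = Counter(key.rsplit(".", 1)[-1] for key in unexpected)
    return unexpected, counts


# Toy usage; the extra key names below are made up for illustration.
toy = nn.Linear(2, 2)
toy_sd = dict(toy.state_dict())
toy_sd["h.0.attn.sin"] = torch.zeros(4)
toy_sd["h.1.attn.sin"] = torch.zeros(4)
toy_sd["h.0.attn.scale"] = torch.tensor(8.0)
leftovers, tally = leftover_parameter_summary(toy, toy_sd)
print(tally)
```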

I then compared MAGMA's output logits under the old and the new transformers versions, in both cases loading the weights directly from our checkpoint rather than from Hugging Face. The mean absolute error between the logits was about 0.008. This could mean that the leftover parameters above, which are not loaded in the new version, somehow cause this difference.
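A mean absolute error of about 0.008 is hard to judge on its own. A sketch of a slightly broader comparison (a hypothetical helper, assuming both runs produce logits of the same shape `(..., vocab)` for the same inputs) also reports the worst-case deviation and how often the top-1 predicted token agrees, which indicates whether the delta would actually change generations:

```python
import torch


def compare_logits(logits_a, logits_b):
    """Summarize the difference between two logit tensors of shape (..., vocab)."""
    diff = (logits_a.float() - logits_b.float()).abs()
    return {
        "mean_abs_error": diff.mean().item(),
        "max_abs_error": diff.max().item(),
        # fraction of positions whose highest-scoring token is the same in both runs
        "top1_agreement": (logits_a.argmax(-1) == logits_b.argmax(-1)).float().mean().item(),
    }
```

A `top1_agreement` of 1.0 with a small mean error would suggest the versions behave the same for greedy decoding, even if the tensors are not bit-identical.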

My next idea would be to actually compare the implementations and see how these scale, sin and cos parameters are used.
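One possible first step, under an assumption not confirmed by the checkpoint: if the leftover `sin` and `cos` tensors are precomputed rotary position-embedding tables of the usual form (base 10000, one frequency per pair of rotary dimensions), they can be recomputed and compared against the stored ones. The helper below is a hypothetical sketch; the real shape and layout of the stored buffers would need to be checked first.

```python
import torch


def rotary_tables(seq_len, rotary_dim, base=10000.0):
    """Standard rotary sin/cos tables, each of shape (seq_len, rotary_dim // 2).

    Assumes the usual RoPE frequencies inv_freq[j] = base ** (-2j / rotary_dim).
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    positions = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)  # angles[i, j] = i * inv_freq[j]
    return torch.sin(angles), torch.cos(angles)
```

For a stored tensor with the same number of elements, a check along the lines of `torch.allclose(stored_sin.reshape(sin.shape), sin, atol=1e-5)` (where `stored_sin` is a hypothetical name for the leftover tensor) would show whether the leftovers are just redundant caches. If they match, the 0.008 delta would more likely come from how the tables are applied than from the values themselves.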