sdtblck closed this issue 2 years ago.
I've done the following tests (see test.py):
With the latest transformers version, I loaded the GPT-J weights from Hugging Face and compared them to the LM weights we have in our checkpoint. The parameters all have different names, but this renaming should do the correct conversion (in magma.from_checkpoint()):
```python
import torch

# `checkpoint_path` and `model` come from the surrounding magma.from_checkpoint() code
sd = torch.load(checkpoint_path, map_location=torch.device("cpu"))
if "module" in sd.keys():
    sd = sd["module"]

#### This is a hack to load the old checkpoint. TODO: directly release a modified checkpoint
# keep only the first 50258 vocabulary rows of the LM head
sd["lm.lm_head.weight"] = sd["lm.lm_head.weight"][:50258, :]
sd["lm.lm_head.bias"] = sd["lm.lm_head.bias"][:50258]

# rename old-style GPT-J keys to the names used by the new transformers version
new_sd = {}
for key, value in sd.items():
    if "attention" in key:
        new_key = key.replace("attention.", "")
        new_sd[new_key] = value
    elif "c_fc" in key:
        new_key = key.replace("c_fc", "fc_in")
        new_sd[new_key] = value
    elif "c_proj" in key:
        new_key = key.replace("c_proj", "fc_out")
        new_sd[new_key] = value
    else:
        new_sd[key] = value

# The below print statement evaluates to false on testing, implying that gpt-j weights
# in our checkpoint are different from the ones in the latest release
old_sd = model.state_dict()
print(
    all([torch.all(old_sd[key] == new_sd[key]).item() for key in old_sd.keys()])
)

model.load_state_dict(new_sd, strict=True)
```
As the comment in the above snippet says, I tested whether the weights loaded from the checkpoint are the same as the downloaded ones, which, unexpectedly, does not seem to be the case. Very weird.
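The single `all(...)` in the snippet only says that *something* differs. A per-key comparison would show which keys differ and by how much, and would also list keys left over on either side after renaming. A minimal sketch, assuming `model` and `new_sd` are as in the snippet; `compare_state_dicts` is a hypothetical helper, not part of magma or transformers:

```python
import torch


def compare_state_dicts(reference, candidate, atol=0.0):
    """Report how two state dicts differ instead of collapsing everything into one bool."""
    ref_keys, cand_keys = set(reference), set(candidate)
    report = {
        # keys the model expects but the converted checkpoint lacks
        "missing": sorted(ref_keys - cand_keys),
        # keys still present in the checkpoint that the model does not expect
        "unexpected": sorted(cand_keys - ref_keys),
        "shape_mismatch": [],
        "value_mismatch": {},
    }
    for key in sorted(ref_keys & cand_keys):
        a, b = reference[key], candidate[key]
        if a.shape != b.shape:
            report["shape_mismatch"].append((key, tuple(a.shape), tuple(b.shape)))
            continue
        # compare in float32 so fp16 and bool buffers subtract without dtype errors;
        # empty tensors have no max, so treat them as equal
        diff = (a.float() - b.float()).abs().max().item() if a.numel() else 0.0
        if diff > atol:
            report["value_mismatch"][key] = diff
    return report
```

Usage would be something like `report = compare_state_dicts(model.state_dict(), new_sd)`, then printing `report["unexpected"]` and the largest entries of `report["value_mismatch"]` to see whether the mismatch is confined to a few keys or spread across all of them.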
Also, I noted that after the renaming I still have these leftover parameters in the checkpoint:
I then compared the output logits of magma on the old transformers version vs. the new one, in both cases loading the weights directly from our checkpoint rather than from Hugging Face. The mean absolute error between the logits was about 0.008. This could mean that the leftover parameters above, which are not loaded in the new version, somehow cause this difference.
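A mean absolute error alone does not say whether the greedy predictions actually change. A minimal sketch of a slightly richer comparison, assuming the two logit tensors (hypothetically `logits_old` and `logits_new`) come from running the same input through both transformers versions; `compare_logits` is a hypothetical helper:

```python
import torch


def compare_logits(logits_a, logits_b):
    """Summarise the difference between two logit tensors of shape (..., vocab)."""
    a, b = logits_a.float(), logits_b.float()
    abs_diff = (a - b).abs()
    return {
        # plain mean over every entry; assumed to be how the ~0.008 figure was computed
        "mean_abs_error": abs_diff.mean().item(),
        # worst single-entry deviation
        "max_abs_error": abs_diff.max().item(),
        # fraction of positions where the greedy (argmax) token is the same
        "top1_agreement": (a.argmax(dim=-1) == b.argmax(dim=-1)).float().mean().item(),
    }
```

Calling `compare_logits(logits_old, logits_new)` returns all three numbers at once; a top-1 agreement close to 1.0 together with a small maximum error would suggest the delta is numerical noise rather than a change in behaviour.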
My next idea would be to actually compare the two implementations and see how these scale, sin and cos parameters are used.
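One possible starting point: if the leftover sin/cos entries are the usual rotary position-embedding tables, they can be recomputed and diffed against what is stored in the checkpoint. A minimal sketch, assuming the standard RoPE formula with base 10000 and a (positions, rotary_dim / 2) layout; both helper names and the layout of the stored buffers are assumptions, not taken from either transformers version:

```python
import torch


def rotary_tables(num_positions, rotary_dim, base=10000.0):
    """Sin/cos tables for rotary position embeddings (standard RoPE formula, assumed)."""
    # one inverse frequency per pair of rotary channels
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    positions = torch.arange(num_positions, dtype=torch.float32)
    # outer product: angle[p, i] = p * inv_freq[i]
    angles = torch.outer(positions, inv_freq)
    return torch.sin(angles), torch.cos(angles)


def max_table_error(stored, recomputed):
    """Largest absolute difference between a stored buffer and a recomputed table."""
    # drop leading singleton dims such as (1, positions, dim)
    stored = stored.float().squeeze()
    if stored.shape != recomputed.shape:
        raise ValueError(f"layout differs: {tuple(stored.shape)} vs {tuple(recomputed.shape)}")
    return (stored - recomputed).abs().max().item()
```

With GPT-J-6B's config this would be `rotary_tables(2048, 64)` (n_positions=2048, rotary_dim=64). If the leftover scale is the usual attention scaling factor, it would be expected to equal sqrt(head_dim) = sqrt(4096 / 16) = 16; that is likewise an assumption to check against both implementations.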
Untested as of right now