EleutherAI / pythia

The hub for EleutherAI's work on interpretability and learning dynamics
Apache License 2.0

Weights of "step0" and "step1" checkpoints are identical for all pythia models #83

Closed byungdoh closed 1 year ago

byungdoh commented 1 year ago

Dear EleutherAI team,

I've noticed that the weights associated with the recently added "step0" and "step1" checkpoints are identical for all pythia models:

import sys
import torch
from transformers import GPTNeoXForCausalLM

def main():
    print(f"========== {sys.argv[1]} ==========")
    model_step0 = GPTNeoXForCausalLM.from_pretrained(sys.argv[1], revision="step0", cache_dir="./test")
    model_step1 = GPTNeoXForCausalLM.from_pretrained(sys.argv[1], revision="step1", cache_dir="./test")

    # Compare every parameter tensor between the two revisions
    for (name0, param0), (name1, param1) in zip(model_step0.named_parameters(), model_step1.named_parameters()):
        print(name0, name1, name0 == name1, torch.all(param0 == param1))

if __name__ == "__main__":
    main()

This yields something like the following for all eight pythia models:

========== EleutherAI/pythia-70m ==========
gpt_neox.embed_in.weight gpt_neox.embed_in.weight True tensor(True)
gpt_neox.layers.0.input_layernorm.weight gpt_neox.layers.0.input_layernorm.weight True tensor(True)
...
gpt_neox.final_layer_norm.weight gpt_neox.final_layer_norm.weight True tensor(True)
gpt_neox.final_layer_norm.bias gpt_neox.final_layer_norm.bias True tensor(True)
embed_out.weight embed_out.weight True tensor(True)

Would it be possible for you to clarify whether these identical weights correspond to "step0" or "step1"? I've noticed that the conditional probabilities calculated using these weights aren't perfectly uniform, which leads me to believe they're actually the "step1" weights.
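
For reference, a minimal sketch of this kind of check (the model name and prompt here are just illustrative):

import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM

model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m", revision="step0")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits[0, -1]  # next-token logits

# A perfectly uniform model would assign 1/vocab_size to every token.
probs = torch.softmax(logits, dim=-1)
print(probs.min().item(), probs.max().item(), 1.0 / probs.numel())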

Thanks! Byung-Doh

haileyschoelkopf commented 1 year ago

Hi, thanks very much for reporting this! I'll look into it and get back to you as soon as I'm able.

StellaAthena commented 1 year ago

@haileyschoelkopf did you end up looking into this?

haileyschoelkopf commented 1 year ago

I have not yet, unfortunately; I'll look at this tomorrow and report back!

StellaAthena commented 1 year ago

Looking around the checkpointing code, it looks to me like we save the 0th checkpoint before doing any weight updates. Saving it after the first update would be the obvious failure mode that could cause this.
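
For illustration, a minimal sketch (not the actual NeoX code) of the intended ordering:

def train(model, optimizer, data_loader, save_checkpoint, checkpoint_steps):
    save_checkpoint(model, step=0)  # step0 written before any optimizer update
    for step, batch in enumerate(data_loader, start=1):
        loss = model(batch)  # simplified: assume the model returns its loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step in checkpoint_steps:  # e.g. {1, 2, 4, ...} for the early Pythia checkpoints
            save_checkpoint(model, step=step)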

haileyschoelkopf commented 1 year ago

Continuing to investigate, but upon digging in I'm finding that DeepSpeed's checkpoint metadata for the NeoX-library checkpoints reports that all is OK. For the EleutherAI/pythia-160m model, the step0 checkpoint reports global_samples: 0 and global_steps: 0, while the step1 checkpoint reports global_samples: 1024 and global_steps: 1.
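
To check this yourself, something like the following works on a raw DeepSpeed checkpoint (the file name and key names here are assumptions and depend on the parallelism configuration used):

import torch

# weights_only=False because DeepSpeed states are pickled Python objects
state = torch.load("global_step0/mp_rank_00_model_states.pt",
                   map_location="cpu", weights_only=False)
print(state.get("global_steps"), state.get("global_samples"))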

I therefore suspect that this is an artifact of LR warmup starting from 0, which means the weights don't update at all on the first step, but I'm looking into this further. Scanning a couple of parameters in one layer of the 160M model, many (but not all) of the individual floating-point parameters printed as identical to their step-0 values, indicating that some parameters will show up as equal to the step0 checkpoint even after multiple training steps during these very early warmup steps.
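
One plausible mechanism, sketched here under the assumption that parameters are stored in fp16: an update far smaller than the parameter's floating-point spacing simply rounds away.

import torch

p = torch.tensor(0.25, dtype=torch.float16)
update = torch.tensor(1e-6, dtype=torch.float16)  # tiny lr * gradient
# The spacing between fp16 values near 0.25 is ~2.4e-4, so the update is lost.
print((p - update) == p)  # tensor(True): the stored value is unchanged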

I'm therefore pretty confident I did in fact save and upload the correct early checkpoints.

Hope this answers your question @byungdoh !

(Aside: note that there was a NeoX issue in which the "step0" checkpoint would be overwritten by the checkpoint of the step a resumed job restarted from, but that issue was patched before these models were trained.)

byungdoh commented 1 year ago

I think it is indeed because the learning rate is 0.0 at the first step, since self.num_iters (and thereby num_iters_) here is initialized to 0. Thank you both @haileyschoelkopf @StellaAthena!
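
For concreteness, a simplified sketch of that linear-warmup schedule (the constants below are illustrative, not the exact Pythia hyperparameters):

def warmup_lr(start_lr, num_iters, warmup_iters):
    # num_iters starts at 0, so the very first optimizer step sees lr == 0.0
    # and leaves the weights unchanged (hence step0 == step1).
    if num_iters <= warmup_iters:
        return start_lr * num_iters / warmup_iters
    return start_lr

print(warmup_lr(1e-3, 0, 1430))  # 0.0 on the first step
print(warmup_lr(1e-3, 1, 1430))  # small but nonzero on the second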