stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX

Move `state_dict` serialization to Haliax #35

Open dlwh opened 10 months ago

dlwh commented 10 months ago

Currently Levanter has a bunch of machinery to support serializing to/from state_dicts, as well as writing to safetensors. I think it would be good to move that functionality to Haliax.

While doing it, I would like to revamp a few things:

rohan-mehta-1024 commented 10 months ago

I also think this would be nice because, e.g., in my nanoGPT implementation I have to install all of Levanter just to use its serialization features. It seems like this may be a large-ish task; are there any areas I could potentially help with?

dlwh commented 10 months ago

I think so!

I think we could start by:

I don't think it's a ton of work, but it's not trivial either

dlwh commented 10 months ago

@rohan-mehta-1024 i merged the change I was making into Levanter, so if you're excited about this, happy to let you take it! Happy to discuss more if you want

rohan-mehta-1024 commented 9 months ago

Sorry for the delay on this! I've just been busy with starting school. I plan to work on it over the weekend.

dlwh commented 9 months ago

no problem at all. Happy you're looking at it at all!

rohan-mehta-1024 commented 7 months ago

Sorry if this is obvious, but what did you mean by state_dict_key=foo? Also, just to confirm, when you say porting over the hf_checkpoint logic, you don't mean the whole file, right? Because some of the code in there seems to be specific to the case of language models.

dlwh commented 7 months ago

sorry that was very unclear. By that I mean, you could write something like:

class MyModule(eqx.Module):
    weights: hnn.Linear = eqx.field(state_dict_key="foo")  # proposed kwarg; eqx.field doesn't accept this today

and have the mapping work.

It's not a necessary step but I think it's a nicer api.

And re: the hf_checkpoints stuff, right. I was thinking of basically just a function to do this https://github.com/stanford-crfm/levanter/blob/16112cc003680e79e15fc7c62b63f917679e7e32/src/levanter/compat/hf_checkpoints.py#L403-L407 + the safetensors variant
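
A rough sketch of the kind of helper being described (the function names here are placeholders, not the final API):

def load_torch_state_dict(path):
    # .bin / .pt checkpoints; torch stays an optional dependency
    import torch
    return torch.load(path, map_location="cpu")

def load_safetensors_state_dict(path):
    # the safetensors variant mentioned above
    from safetensors.numpy import load_file
    return load_file(path)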

rohan-mehta-1024 commented 7 months ago

Ok, thanks for the clarification! eqx.field passes all kwargs that are not converter or static directly to dataclasses.field, which does not accept arbitrary kwargs. So we would either have to do eqx.field(metadata={'state_dict_key' : 'foo'}) (which isn't as clean), or find a somewhat hacky way around this. Would it be worth creating a hax.field function which is basically the same as eqx.field, except that if you pass it a kwarg that is not accepted by dataclasses.field it automatically shoves it in metadata? The only downside to this would be replacing a lot of instances of eqx.field... unless maybe this is general enough that we could make a pull request for this change directly in Equinox? I'm not sure how else to get it to work besides these two ways, though.
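
For concreteness, a minimal sketch of such a wrapper (the name and exact kwarg handling are assumptions, not a settled API): anything dataclasses.field would reject gets tucked into metadata, so hax.field(state_dict_key="foo") turns into eqx.field(metadata={"state_dict_key": "foo"}).

import equinox as eqx

_DATACLASS_FIELD_KWARGS = {
    "default", "default_factory", "init", "repr",
    "hash", "compare", "metadata", "kw_only",
}
_EQX_FIELD_KWARGS = {"converter", "static"}

def field(**kwargs):
    # collect any kwarg that neither dataclasses.field nor eqx.field accepts
    metadata = dict(kwargs.pop("metadata", {}) or {})
    for key in list(kwargs):
        if key not in _DATACLASS_FIELD_KWARGS and key not in _EQX_FIELD_KWARGS:
            metadata[key] = kwargs.pop(key)
    return eqx.field(metadata=metadata, **kwargs)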

dlwh commented 7 months ago

That's a good point. I think hax.field is probably a good thing to do—it's not like it's mandatory that we use it everywhere—but we don't have to decide just yet

rohan-mehta-1024 commented 7 months ago

Ok, so for now should I just go ahead with a "soft" version of hax.field and then we can modify this approach as we see fit later?

dlwh commented 7 months ago

Not sure what that means? But my guess is probably :-)

rohan-mehta-1024 commented 7 months ago

Yeah, that was poor wording, sorry. Basically I think I'm just going to go the route of explicitly using metadata, and once everything else has been ported over successfully we can decide from there, since it should be a pretty minor stylistic change. I was also wondering if we should make the following stylistic change: when flattening/unflattening layers, the type signature is (prefix, statedict: StateDict, layer: hnn.Linear, out_dims_first_in_dict: Optional[bool]). Since we pass the layer itself in, the function can use the state_dict_key_map to look up any necessary prefixes. So, if you have the following:

def _state_dict_key_map(self):
    return {'attn': 'c_attn', 'proj': 'c_proj'}

Then you only need to do:

unflatten_linear_layers(prefix, state_dict, self.attn, None)

Instead of:

unflatten_linear_layers(apply_prefix(prefix, 'c_attn'), state_dict, self.attn, None)

Since the function can infer this. However, for stacking/unstacking layers, the function does not explicitly take in the block to be stacked/unstacked. So even if you have:

def _state_dict_key_map(self):
    return {
        "blocks"              : "h",
        "tok_embedding_table" : "wte", 
        "pos_embedding_table" : "wpe"
    }

You still have to do:

stacked_params = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "h"))

Would it make sense to rewrite the function to have it take in the block to stack/unstack, so that you could do this instead:

stacked_params = stack_state_dict(state_dict, self.blocks, prefix=prefix)

It would probably complicate the underlying function, but it would also allow prefixes to be inferred automatically from state_dict_key_map. It also seems a little more explicit, since you don't have to reason through the prefixes to figure out what is being stacked/unstacked. Does this line of thought make sense to you, or are there reasons I haven't identified for keeping things as they are?

dlwh commented 7 months ago

I'm confused. How does it know what layer it is?

rohan-mehta-1024 commented 6 months ago

Maybe I'm misunderstanding, but the following code, e.g., in unflatten_linear_layers:

tree_prefixes = leaf_key_paths(
    layer, prefix, is_leaf=lambda x: isinstance(x, hnn.Linear), use_state_dict_keys=True
)

has a use_state_dict_keys argument which allows it to access the state_dict_key_map. And here is the relevant code from leaf_key_paths where it does access this mapping:

    elif isinstance(pytree, eqx.Module):
        names = []
        rec_values = []
        for field in fields(pytree):
            if field.metadata.get("static", False):
                continue
            field_name = field.name
            field = getattr(pytree, field_name)
            names.append(field_name)

            if use_state_dict_keys and hasattr(pytree, "_state_dict_key_map"):
                field_name = pytree._state_dict_key_map().get(field_name, field_name)

            rec_value = rec(field, field_name)
            rec_values.append(rec_value)
        return eqx.tree_at(lambda m: [getattr(m, name) for name in names], pytree, rec_values)

So I don't understand why you would explicitly have to do d.update(unflatten_linear_layers(apply_prefix(prefix, "c_attn"), state_dict, self.c_attn, None)). Why do we need to do the apply_prefix(prefix, "c_attn") ourselves? Won't this code figure out that the attn module has to be renamed "c_attn", since it goes through all the fields of our given layer and updates the names of all of those which are also keys in our mapping?

My understanding of how prefixes work is that they are automatically and recursively added onto more and more nested layers, so I'm a little confused about why we sometimes have to add them on explicitly ourselves. Hopefully this makes a bit more sense!

dlwh commented 6 months ago

Oh, well, the convention in all these methods is that they receive the path they're supposed to use: self.c_attn doesn't know that it's supposed to be named "c_attn"; it would need information from the parent (self) to know that.
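
A toy illustration of that convention (the helper and class names here are stand-ins, not the real Haliax code): the parent resolves the child's state-dict name from its own _state_dict_key_map and hands the fully-qualified prefix down, because the child has no way to know which attribute it is bound to.

def apply_prefix(prefix, key):
    # join e.g. "transformer.h.0" + "c_attn" -> "transformer.h.0.c_attn"
    return key if prefix is None else f"{prefix}.{key}"

class Attention:
    def _state_dict_key_map(self):
        return {"attn": "c_attn", "proj": "c_proj"}

    def child_prefix(self, field_name, prefix=None):
        # the parent, not the child, knows the child's serialized name
        key = self._state_dict_key_map().get(field_name, field_name)
        return apply_prefix(prefix, key)

assert Attention().child_prefix("attn", "transformer.h.0") == "transformer.h.0.c_attn"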

rohan-mehta-1024 commented 5 months ago

Oh ok, I see, that makes sense now (sorry for the confusion). I think I've basically made all the necessary changes at this point. The only other thing I'm wondering about is what you meant by persistent_buffer = false, since anything to do with buffers seems to come from other parts of the hf_checkpoints file. Also, I was curious what you think would be a good approach for testing. The Levanter tests for serialization/state_dict loading use full-blown models and are pretty extensive, but I was thinking of maybe just creating a small dummy model in both Haliax and PyTorch and then making sure it was possible to load the PyTorch weights into the Haliax model and vice versa. Do you think this is a robust enough test?

dlwh commented 5 months ago

Oh wow, thank you!

I think the buffers thing can be a separate PR/future issue. I agree lighter-weight tests are better here, and round-tripping to/from torch (and Haliax itself) is probably the best route.

Maybe like, a few tests that test only-Haliax roundtrip functionality (dicts, renames, reorders, whatever), and then a few tests that only fire with torch installed that make sure roundtrips work there?
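
Something along those lines might look like the sketch below (the safetensors roundtrip stands in for whatever to/from_state_dict API the port ends up with; it is not the final test suite):

import numpy as np
import pytest
from safetensors.numpy import load_file, save_file

def test_state_dict_roundtrip(tmp_path):
    # pure state-dict roundtrip: write a flat dict of arrays and read it back
    state = {"wte.weight": np.arange(12, dtype=np.float32).reshape(3, 4)}
    path = str(tmp_path / "model.safetensors")
    save_file(state, path)
    restored = load_file(path)
    np.testing.assert_array_equal(restored["wte.weight"], state["wte.weight"])

def test_torch_roundtrip(tmp_path):
    # only fires when torch is installed, as suggested above
    torch = pytest.importorskip("torch")
    from safetensors.torch import save_file as save_torch_file

    linear = torch.nn.Linear(4, 3)
    path = str(tmp_path / "linear.safetensors")
    save_torch_file(linear.state_dict(), path)
    restored = load_file(path)  # numpy view of the torch-written weights
    np.testing.assert_array_equal(restored["weight"], linear.weight.detach().numpy())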