patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

FAQ: does it run any of the public GPT models? #272

Open pannous opened 1 year ago

pannous commented 1 year ago

Does it run any of the public GPT models, or are the data structures fundamentally incompatible?

patrick-kidger commented 1 year ago

I don't think I know of an Equinox version of these models. There shouldn't be any incompatibility; just that no-one has yet taken the time to implement them!

dieterichlawson commented 1 year ago

I used equinox to write a version of Andrej Karpathy's "MinGPT" that may be of interest:

https://github.com/dieterichlawson/mingpt-jax

I think some parts are actually cleaner than the PyTorch implementation, such as figuring out which parameters to weight decay.

patrick-kidger commented 1 year ago

That's awesome! Do I understand that this is GPT2, as with Andrej's version?

Would you be okay with me linking to this as an example in the docs?

alexlatif commented 9 months ago

How would you approach converting PyTorch models, in order to perform model surgery on a pre-trained model and design subsequent models in Equinox?

I've noticed Flax has documentation of the mapping for some of their layers. I assume replicating something similar would otherwise be a manual effort?

On a related note, in Flax it's very easy to .apply the params to layer instances, but I have not found any straightforward way to load params into Equinox models.

I am either going down the route of:

Which of the three above is the best path to pursue, if any?

Thank you as always for your hard work. I would be grateful for your guidance on where I can focus my contributions in this regard.

patrick-kidger commented 9 months ago

When I do this, I tend to use eqx.tree_at as suggested in your second bullet point :)

mattf1n commented 7 months ago

@dieterichlawson I would like to write a script that loads pre-trained weights into minGPT using the Flax GPT-2 weights from HuggingFace. Would this work, or are there architectural differences? I think something along these lines would work; I just need to line up the parameters properly.

```python
import equinox as eqx
import jax


def gpt2_of_hf_flax_params(flax_params):
    flax_params = flax_params["transformer"]
    gpt2 = GPT2(...)  # the Equinox model, built with a config matching the checkpoint

    # Flatten each transformer block into a tuple, in the order the Equinox
    # model's leaves are expected to appear. Index the blocks numerically:
    # the keys are the strings "0", "1", ..., and anything that sorts them
    # (JAX's dict flattening does) would put "10" before "2".
    blocks = flax_params["h"]
    flax_layers_params = tuple(
        (
            layer["ln_1"]["scale"],
            layer["ln_1"]["bias"],
            layer["attn"]["c_attn"]["kernel"],
            layer["attn"]["c_attn"]["bias"],
            layer["attn"]["c_proj"]["kernel"],
            layer["attn"]["c_proj"]["bias"],
            layer["ln_2"]["scale"],
            layer["ln_2"]["bias"],
            layer["mlp"]["c_fc"]["kernel"],
            layer["mlp"]["c_fc"]["bias"],
            layer["mlp"]["c_proj"]["kernel"],
            layer["mlp"]["c_proj"]["bias"],
        )
        for layer in (blocks[str(i)] for i in range(len(blocks)))
    )
    flax_params = (
        flax_params["wte"]["embedding"],
        flax_params["wpe"]["embedding"],
        flax_layers_params,
        flax_params["ln_f"]["scale"],
        flax_params["ln_f"]["bias"],
    )

    # Split off the arrays, swap every leaf in order, then recombine.
    # Leaf order must match the Equinox model's field order exactly;
    # tree_at does not check shapes. `where` returns a tuple of nodes.
    params, static = eqx.partition(gpt2, eqx.is_array)
    params = eqx.tree_at(
        where=lambda tree: tuple(jax.tree.leaves(tree)),
        pytree=params,
        replace=tuple(jax.tree.leaves(flax_params)),
    )
    return eqx.combine(params, static)
```
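
Because `eqx.tree_at` only swaps nodes and doesn't compare shapes, a quick check before the final `tree_at` can catch misaligned parameters early. Below is a minimal sketch with a hypothetical helper, `check_leaf_shapes`. With the names from the snippet above, it would be called as `check_leaf_shapes(params, jax.tree.leaves(flax_params))`.

```python
import jax
import numpy as np


def check_leaf_shapes(model_params, new_leaves):
    """Hypothetical helper: compare a model's array leaves with replacement arrays.

    Raises if the counts differ; otherwise returns a list of
    (index, expected_shape, got_shape) for every shape mismatch.
    """
    old_leaves = jax.tree_util.tree_leaves(model_params)
    if len(old_leaves) != len(new_leaves):
        raise ValueError(
            f"model has {len(old_leaves)} array leaves but {len(new_leaves)} were supplied"
        )
    return [
        (i, tuple(np.shape(old)), tuple(np.shape(new)))
        for i, (old, new) in enumerate(zip(old_leaves, new_leaves))
        if np.shape(old) != np.shape(new)
    ]
```

Matching shapes is necessary but not sufficient: a square matrix or a fused QKV projection can have the right shape in the wrong layout. Comparing a forward pass against the reference model on the same input is a stronger check.
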

gil2rok commented 1 month ago

> That's awesome! Do I understand that this is GPT2, as with Andrej's version?
>
> Would you be okay with me linking to this as an example in the docs?

@patrick-kidger Please add a link to this in the docs! I think it'd be super helpful to users, and I would have loved to see something like this!