pannous opened this issue 1 year ago
I don't think I know of an Equinox version of these models. There shouldn't be any incompatibility; it's just that no one has yet taken the time to implement them!
I used equinox to write a version of Andrej Karpathy's "MinGPT" that may be of interest:
https://github.com/dieterichlawson/mingpt-jax
I think some parts are actually cleaner than the PyTorch implementation, such as figuring out which parameters to weight decay.
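To illustrate the kind of thing meant here, a minimal sketch (not taken from the linked repo) of selecting which parameters to weight-decay in Equinox/JAX, using optax's `mask` argument; the `eqx.nn.MLP` model and the "decay anything with ndim >= 2" rule are placeholder assumptions:

```python
import equinox as eqx
import jax
import optax

# Decay only weight matrices (ndim >= 2), not biases or layer-norm parameters.
def decay_mask(params):
    return jax.tree_util.tree_map(
        lambda leaf: eqx.is_array(leaf) and leaf.ndim >= 2, params
    )

model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=jax.random.PRNGKey(0))
params = eqx.filter(model, eqx.is_array)  # keep only the array leaves

optim = optax.adamw(learning_rate=3e-4, weight_decay=0.1, mask=decay_mask)
opt_state = optim.init(params)
```

The corresponding PyTorch code in minGPT builds explicit decay/no-decay parameter sets by hand; here it is a single tree map.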
That's awesome! Do I understand that this is GPT2, as with Andrej's version?
Would you be okay with me linking to this as an example in the docs?
How would you approach converting PyTorch models, in order to perform model surgery on a pre-trained model and design subsequent models in Equinox?
I've noticed Flax has documentation of the mapping for some of their layers. I'd assume it will otherwise be a manual effort to replicate something similar?
In a similar vein, it's very easy to .apply the params to Flax layer instances, but I haven't found any straightforward way to load params in using Equinox.
I am either going down the route of:
eqx.tree_deserialise_leaves(buffer, model_cls)
for all blocks and layers.
What is the best path to pursue of the three above, if any of these at all?
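For reference, that serialisation route might look roughly like the following sketch (using eqx.nn.MLP and a file path as stand-ins for the real blocks and buffer):

```python
import equinox as eqx
import jax

# Stand-in model; in practice this would be the block/layer class in question.
def make_model(key):
    return eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=key)

model = make_model(jax.random.PRNGKey(0))
eqx.tree_serialise_leaves("model.eqx", model)

# Later: build a module with the same structure and fill in the saved arrays.
template = make_model(jax.random.PRNGKey(1))
restored = eqx.tree_deserialise_leaves("model.eqx", template)
```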
Thank you as always for your hard work; I would be grateful for your guidance on where I can focus my contributions in this regard.
When I do this, I tend to use eqx.tree_at, as suggested in your second bullet point :)
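For instance, surgically loading pre-trained arrays into an existing module might look like this minimal sketch (the eqx.nn.MLP and the zero arrays are placeholders standing in for a real converted PyTorch state_dict):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=4, out_size=2, width_size=8, depth=1, key=key)

# Stand-ins for arrays converted from a pre-trained PyTorch checkpoint
# (e.g. via torch_tensor.numpy()).
pretrained_weight = jnp.zeros((8, 4))
pretrained_bias = jnp.zeros((8,))

# Replace just the first layer's parameters, leaving the rest of the model intact.
model = eqx.tree_at(
    lambda m: (m.layers[0].weight, m.layers[0].bias),
    model,
    replace=(pretrained_weight, pretrained_bias),
)
```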
@dieterichlawson I would like to write a script that loads pre-trained weights for minGPT using the Flax weights from HuggingFace. Would this work, or are there architectural differences? I am thinking something along these lines would work; I just need to line up the parameters properly.
```python
import equinox as eqx
import jax


def gpt2_of_hf_flax_params(flax_params):
    flax_params = flax_params["transformer"]
    gpt2 = GPT2(...)  # construct the Equinox minGPT model (config elided)
    # Replace each attention block's parameter dict with a flat tuple of its
    # arrays, so the leaf ordering lines up with the Equinox model.
    flax_params = eqx.tree_at(
        where=lambda params: tuple(layer["attn"] for layer in params["h"].values()),
        pytree=flax_params,
        replace_fn=lambda attn: (
            attn["c_attn"]["kernel"],
            attn["c_attn"]["bias"],
            attn["c_proj"]["kernel"],
            attn["c_proj"]["bias"],
        ),
    )
    # Collect the per-block parameters in the order the Equinox model expects.
    flax_layers_params = tuple(
        (
            layer["ln_1"]["scale"],
            layer["ln_1"]["bias"],
            layer["attn"],
            layer["ln_2"]["scale"],
            layer["ln_2"]["bias"],
            layer["mlp"]["c_fc"]["kernel"],
            layer["mlp"]["c_fc"]["bias"],
            layer["mlp"]["c_proj"]["kernel"],
            layer["mlp"]["c_proj"]["bias"],
        )
        for layer in flax_params["h"].values()
    )
    # Token embeddings, position embeddings, transformer blocks, final layer norm.
    flax_params = (
        flax_params["wte"]["embedding"],
        flax_params["wpe"]["embedding"],
        flax_layers_params,
        flax_params["ln_f"]["scale"],
        flax_params["ln_f"]["bias"],
    )
    # Swap every array leaf of the Equinox model for the corresponding Flax array.
    params, static = eqx.partition(gpt2, eqx.is_array)
    params = eqx.tree_at(
        where=lambda tree: jax.tree.flatten(tree)[0],
        pytree=params,
        replace=jax.tree.flatten(flax_params)[0],
    )
    gpt2 = eqx.combine(params, static)
    return gpt2
```
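A possible way to obtain flax_params for this (untested; it assumes HuggingFace's FlaxGPT2LMHeadModel exposes the parameter layout used above):

```python
from transformers import FlaxGPT2LMHeadModel

hf_model = FlaxGPT2LMHeadModel.from_pretrained("gpt2")
gpt2 = gpt2_of_hf_flax_params(hf_model.params)
```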
> That's awesome! Do I understand that this is GPT2, as with Andrej's version?
> Would you be okay with me linking to this as an example in the docs?
@patrick-kidger Please add a link to this in the doc! I think it'd be super helpful to users and I would have loved to see something like this!
Does it run any of the public GPT models, or are the data structures fundamentally incompatible?