HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU-Training
BSD 2-Clause "Simplified" License

Checkpoint, Restore and Inference #27

Closed ClashLuke closed 2 years ago

ClashLuke commented 2 years ago

We can now save and restore the structure of an arbitrary JAX PyTree.

>>> import json, jax, jax.numpy as jnp
>>> example_pytree = jax.tree_util.tree_flatten({'2':{'a':jax.numpy.ones((256,))}})[1]
>>> str(example_pytree)
"PyTreeDef({'2': {'a': *}})"
>>> (txt := str(example_pytree).replace('PyTreeDef', '')[1:-1].replace(': *', ': null').replace("{'", '{"').replace("':", '":').replace("', ", '", ').replace(", '", ', "'))
'{"2": {"a": null}}'
>>> # txt is now a plain string that can be dumped to disk and reloaded as JSON using the stdlib
>>> (d := json.loads(txt))
{'2': {'a': None}}
>>> # deep_replace is a helper (not part of this snippet) that puts the given value at every leaf
>>> (d := deep_replace(d, jnp.zeros((1,))))
{'2': {'a': DeviceArray([0.], dtype=float32)}}
>>> jax.tree_util.tree_flatten(d)[1]   # get the original tree
PyTreeDef({'2': {'a': *}})

We can restore an original model without storing its config or recreating its parameters. Similarly, we can load a checkpoint and continue training without instantiating new optimiser parameters that we would only replace. Dumping PyTrees improves efficiency and elegance while allowing for adaptive filtering of weights and insertion into new structures without manually aligning them.
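The round trip described above can be sketched with the standard library alone. This is a minimal illustration, not the code from the repository: `treedef_to_json` is a hypothetical name that wraps the replacement chain from the session above, and `deep_replace` is a guess at the helper's behaviour based on its printed output. It assumes the tree consists only of nested dicts with simple string keys, and it uses a plain float as a stand-in for an array.

```python
import json


def treedef_to_json(treedef_str):
    """Turn str(PyTreeDef) of nested, string-keyed dicts into JSON text (hypothetical helper)."""
    body = treedef_str.replace("PyTreeDef", "")[1:-1]  # drop the "PyTreeDef(" ... ")" wrapper
    body = body.replace(": *", ": null")  # leaves are printed as "*"
    # Swap Python's single-quoted keys for JSON's double-quoted keys.
    for old, new in (("{'", '{"'), ("':", '":'), ("', ", '", '), (", '", ', "')):
        body = body.replace(old, new)
    return body


def deep_replace(tree, value):
    """Return a copy of `tree` with every leaf replaced by `value` (assumed behaviour)."""
    if isinstance(tree, dict):
        return {key: deep_replace(child, value) for key, child in tree.items()}
    return value


skeleton = json.loads(treedef_to_json("PyTreeDef({'2': {'a': *}})"))
restored = deep_replace(skeleton, 0.0)  # 0.0 stands in for an array such as jnp.zeros((1,))
print(restored)  # {'2': {'a': 0.0}}
```

Keys that themselves contain quotes, colons or commas would break the string replacements, so a production version would likely need a more robust serialisation.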

Other than that, the fixes are non-issues. The checkpoint/restore code is tested, and an evaluation model is currently being trained.

ClashLuke commented 2 years ago

It's tested and works. A (temporary) API for a model (ppl=5.5, global_min_ppl=2.5) trained for a day is available at orbscale.com. A simple CLI like the one below works well for local testing without deploying the model:

import requests
import argparse

URL = "https://orbscale.com/"

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--temperature", type=float, default=1.)
    parser.add_argument("--top-p", type=float, default=1.)
    parser.add_argument("--top-k", type=int, default=64)  # a token count, so int rather than float
    parser.add_argument("--length", type=int, default=128)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    payload = {"prompt": args.prompt, "temperature": args.temperature, "top_p": args.top_p,
               "top_k": args.top_k, "length": args.length, "seed": args.seed}
    out = requests.post(URL, json=payload)
    out.raise_for_status()  # fail on HTTP errors instead of on a missing "completion" key
    response = out.json()["completion"]
    print(response)

if __name__ == '__main__':
    main()
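For calling the endpoint from other Python code rather than from the shell, the same request can be split into a payload builder and a thin network call. This is a sketch under assumptions: `build_payload` and `complete` are hypothetical names, the type coercions are a guess at what the server expects, and `urllib` from the standard library is swapped in for `requests` so the snippet has no third-party dependency. The URL, the JSON keys and the "completion" field are taken from the CLI above.

```python
import json
import urllib.request

URL = "https://orbscale.com/"  # endpoint from the CLI above, described there as temporary


def build_payload(prompt, temperature=1.0, top_p=1.0, top_k=64, length=128, seed=0):
    """Assemble the JSON body with the same keys and defaults the CLI uses."""
    if not prompt:
        raise ValueError("prompt must be a non-empty string")
    return {
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "top_k": int(top_k),  # assumption: the server expects an integer here
        "length": int(length),
        "seed": int(seed),
    }


def complete(prompt, timeout=60.0, **sampling):
    """POST the payload and return the "completion" field of the JSON reply."""
    body = json.dumps(build_payload(prompt, **sampling)).encode("utf-8")
    request = urllib.request.Request(
        URL, data=body, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(request, timeout=timeout) as reply:
        return json.loads(reply.read().decode("utf-8"))["completion"]
```

Keeping `build_payload` free of network access means the argument handling can be checked locally even when the temporary endpoint is offline.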