Closed ClashLuke closed 2 years ago
It's tested and works. A (temporary) API for a model (ppl=5.5, global_min_ppl=2.5) trained for one day is available at orbscale.com. A simple CLI like the one below works well for local testing without deploying the model:
```python
import argparse

import requests

URL = "https://orbscale.com/"


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str)
    parser.add_argument("--temperature", type=float, default=1.)
    parser.add_argument("--top-p", type=float, default=1.)
    parser.add_argument("--top-k", type=int, default=64)  # k is a count, so int rather than float
    parser.add_argument("--length", type=int, default=128)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    out = requests.post(URL, json={"prompt": args.prompt, "temperature": args.temperature,
                                   "top_p": args.top_p, "top_k": args.top_k,
                                   "length": args.length, "seed": args.seed})
    response = out.json()["completion"]
    print(response)


if __name__ == '__main__':
    main()
```
We can now save and restore the structure of an arbitrary Jax PyTree.
We can restore the original model without storing its config or recreating its parameters. Similarly, we can load a checkpoint and continue training without first instantiating fresh optimiser parameters that would only be overwritten. Dumping PyTrees this way is more efficient and elegant, and it allows adaptively filtering weights and inserting them into new structures without manually aligning them.
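As a rough sketch of the idea (not the checkpoint code from this PR, and the parameter tree below is made up): `jax.tree_util.tree_flatten` separates a PyTree into its leaves and a `treedef` describing the structure, so both can be saved together and the full tree rebuilt later without the original config. Treedefs of standard containers pickle in recent JAX versions, which is assumed here.

```python
import pickle

import jax.numpy as jnp
from jax import tree_util

# A hypothetical parameter PyTree (nested dicts of arrays).
params = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)},
          "step": jnp.array(0)}

# Split the tree into leaves (the arrays) and the structure (the treedef).
leaves, treedef = tree_util.tree_flatten(params)

# Save both; pickling the treedef records the structure alongside the weights.
blob = pickle.dumps((leaves, treedef))

# Restore the full tree without the original config or a freshly built model.
saved_leaves, saved_treedef = pickle.loads(blob)
restored = tree_util.tree_unflatten(saved_treedef, saved_leaves)
```

Because the structure is stored explicitly, restored leaves can also be filtered or spliced into a different treedef instead of being aligned against a freshly instantiated model by hand.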
Other than that, the remaining fixes are minor. The checkpoint/restore code is tested, and an evaluation model is currently being trained.