jihan1218 / brax


Only inference #3

Open Stefano-retinize opened 3 months ago

Stefano-retinize commented 3 months ago

Hello, thanks for this code.

For my use case I only want to run inference with a pretrained model, i.e. load the trained model and use it without training.

The README explains how to continue training. Is it possible to run inference only, with a pretrained model, without calling the train function?

Thanks

jihan1218 commented 3 months ago

To run inference later without retraining, you must save the inference function right after your initial training:

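import os

import dill  # plain pickle cannot serialize closures such as make_inference_fn

# Save the inference-function factory returned by training
# (exp_num is your own experiment counter)
filename = 'exp' + str(exp_num) + '_make_inference.dill'
directory = 'Your own directory'
full_path = os.path.join(directory, filename)  # adds the path separator that plain + omits
with open(full_path, 'wb') as f:
    dill.dump(make_inference_fn, f)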
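The snippet above saves only the factory, but the loading step also calls `make_inference_fn(params)`, and in a fresh session those trained `params` only exist if they were saved as well. A minimal sketch of saving both side by side, assuming the params are picklable; `save_inference_bundle` and the `_params.pkl` file name are hypothetical, and Brax's own `brax.io.model.save_params` could be used for the params instead:

```python
import os
import pickle
from typing import Any, Callable, Tuple

try:
    import dill as serializer  # dill can serialize closures like make_inference_fn
except ImportError:
    serializer = pickle  # fallback: plain pickle only handles module-level callables


def save_inference_bundle(directory: str, exp_num: int,
                          make_inference_fn: Callable,
                          params: Any) -> Tuple[str, str]:
    """Hypothetical helper: write the inference factory and the trained params together."""
    os.makedirs(directory, exist_ok=True)
    fn_path = os.path.join(directory, f"exp{exp_num}_make_inference.dill")
    params_path = os.path.join(directory, f"exp{exp_num}_params.pkl")
    with open(fn_path, "wb") as f:
        serializer.dump(make_inference_fn, f)
    with open(params_path, "wb") as f:
        pickle.dump(params, f)  # assumes params are picklable (e.g. a pytree of arrays)
    return fn_path, params_path
```

Call it once after training with the `make_inference_fn` and `params` from your run, e.g. `save_inference_bundle('Your own directory', exp_num, make_inference_fn, params)`.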

Then you can load it later, without calling the train function:

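import dill
import jax

# Load the saved inference-function factory (full_path as built above)
with open(full_path, 'rb') as f:
    make_inference_fn = dill.load(f)

# params are the trained policy parameters; they must also be saved after
# training and loaded here (e.g. with brax.io.model.save_params / load_params)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)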
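Loading mirrors saving: the factory and the `params` both have to come back from disk before `make_inference_fn(params)` can be called. A minimal, self-contained sketch of that round trip, assuming the params were pickled to a separate file next to the factory; `toy_make_inference_fn` is a hypothetical stand-in for the factory returned by training and `load_inference_bundle` is a hypothetical helper:

```python
import pickle
from typing import Any, Callable, Tuple

try:
    import dill as serializer  # required when the factory is a closure, as after training
except ImportError:
    serializer = pickle  # fallback: enough for the module-level toy factory below


def toy_make_inference_fn(params: Any) -> Callable:
    """Hypothetical stand-in for the trained make_inference_fn: params -> policy."""
    def inference_fn(obs, rng):
        # Linear "policy": one action per weight row; rng is unused in this toy.
        action = [sum(w * o for w, o in zip(row, obs)) for row in params["w"]]
        return action, {}
    return inference_fn


def load_inference_bundle(fn_path: str, params_path: str) -> Tuple[Callable, Any]:
    """Hypothetical helper: restore the factory and the params saved after training."""
    with open(fn_path, "rb") as f:
        make_inference_fn = serializer.load(f)
    with open(params_path, "rb") as f:
        params = pickle.load(f)
    return make_inference_fn, params
```

The toy mimics the `(obs, rng) -> (action, extras)` signature of the policies produced by Brax's training agents; with the real factory, the returned `inference_fn` would be wrapped in `jax.jit` as in the snippet above.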