Closed Icepomel0 closed 8 months ago
Hi! What are you trying to do? If your goal is to visualize a pre-trained model, you need both the parameters and the inference function. Let's say this is your first training run, with 10 evaluations.
```python
train_fn = functools.partial(
    ppo.train,
    num_timesteps=100_000_000,
    num_evals=10,
    # ... rest of your training parameters
)

def progress(num_steps, metrics):
    ...  # your own progress function

make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
```
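For clarity, here is a self-contained sketch of what the `functools.partial` pattern above does: the kwargs bound in `partial` are merged with the ones you pass at call time. The `train` function below is a hypothetical stand-in for `ppo.train`, just for illustration.

```python
import functools

# Hypothetical stand-in for ppo.train: calls the progress callback once
# per evaluation and returns the same triple shape as ppo.train.
def train(environment, progress_fn, num_timesteps=0, num_evals=0):
    for i in range(num_evals):
        progress_fn(i * num_timesteps // num_evals, {"eval": i})
    return "make_inference_fn", "params", "metrics"

# Bind the training hyperparameters up front...
train_fn = functools.partial(train, num_timesteps=1_000, num_evals=2)

# ...then supply the environment and progress function at call time.
steps_seen = []
make_inference_fn, params, _ = train_fn(
    environment="dummy-env",
    progress_fn=lambda step, metrics: steps_seen.append(step),
)
print(steps_seen)
```

The real `ppo.train` takes many more parameters; only the calling convention is shown here.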
Once training is done, you should save both `make_inference_fn` and `params`. You can save them as follows.
```python
# save params
model.save_params(model_path, params)

# save inference function
with open(full_path, 'wb') as f:
    dill.dump(make_inference_fn, f)
```
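A note on why `dill` is used here: `make_inference_fn` is typically a closure returned by the training function, and the standard-library `pickle` cannot serialize closures, while `dill` can. The save/load pattern itself is an ordinary serialization round-trip; the minimal stdlib sketch below illustrates it with a plain params-like dict (the brax `model.save_params` / `model.load_params` helpers play the analogous role for real parameter pytrees).

```python
import io
import pickle

# An illustrative params-like structure (real brax params are JAX pytrees).
params = {"policy": {"w": [0.1, 0.2], "b": [0.0]}}

buf = io.BytesIO()
pickle.dump(params, buf)       # analogous to model.save_params(model_path, params)
buf.seek(0)
restored = pickle.load(buf)    # analogous to model.load_params(model_path)

assert restored == params
print("round-trip ok")
```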
To visualize your trained model, load them back:
```python
# load saved params
saved_params = model.load_params(model_path)

# load saved inference function
with open(full_path, 'rb') as f:
    make_inference_fn = dill.load(f)

inference_fn = make_inference_fn(saved_params)
jit_inference_fn = jax.jit(inference_fn)
```
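Once you have `jit_inference_fn`, visualization is a rollout loop: reset the environment, query the policy for an action, step, and collect states. The sketch below is schematic and self-contained; `DummyEnv` and the stub policy are stand-ins (the real brax `env.reset`/`env.step` take JAX arrays and RNG keys, and rendering is done separately, e.g. with brax's HTML renderer).

```python
# Schematic rollout loop with stand-ins for a Brax env and a jitted policy.
class DummyEnv:
    def reset(self):
        return {"obs": 0.0}

    def step(self, state, action):
        # a real env would integrate physics; here we just accumulate
        return {"obs": state["obs"] + action}

def jit_inference_fn(obs):
    # a trained policy would map observations to actions
    return 1.0

env = DummyEnv()
state = env.reset()
rollout = [state]
for _ in range(5):
    action = jit_inference_fn(state["obs"])
    state = env.step(state, action)
    rollout.append(state)

# the collected `rollout` is what you would pass to a renderer
print(len(rollout), rollout[-1]["obs"])
```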
After checking the result, if you want to train it further, simply load the saved parameters:
```python
train_fn = functools.partial(
    ppo.train,
    num_timesteps=100_000_000,
    num_evals=100,
    previous_params=saved_params,
    # ... rest of your training parameters
)
```
Here, I continue training with 100 more evaluations. Does this answer your question?
Hi Jihan,
Yes, this wonderfully answered my question! Thanks so much for this!
Best
Hi there,
Thanks for your contribution and your package update!
I would like to ask: if I am using a pre-trained model and do not want to train it for more timesteps, how should I set the `num_timesteps` variable? Should I set it to 0, to a really small number, or stick with the pre-trained model's `num_timesteps`? For example, if I have a model pre-trained for 100000 timesteps and I am happy with it, when I try to directly use it to visualize the model, should I set `num_timesteps` in `functools.partial` to 100000 or 0? Really, thanks for your help!