jihan1218 / brax


Using pre-trained model #2

Closed: Icepomel0 closed this issue 6 months ago.

Icepomel0 commented 6 months ago

Hi there,

Thanks for your contribution and your package update!

I would like to ask: if I am using a pre-trained model and don't want to train it for any more timesteps, how should I set the num_timesteps variable? Should I set it to 0, to a really small number, or keep the pre-trained model's num_timesteps?

For example, if I have a model pre-trained for 100000 timesteps and I'm happy with it, and I want to use it directly to visualize the model, should I set num_timesteps in functools.partial(ppo.train, ...) to 100000 or to 0?

Thanks a lot for your help!

jihan1218 commented 6 months ago

Hi, what are you trying to do? If your goal is to visualize a pre-trained model, you need both the parameters and the inference function. Let's say this is your first training run, with 10 evaluations.

import functools
from brax.training.agents.ppo import train as ppo

train_fn = functools.partial(
    ppo.train,
    num_timesteps=100_000_000,
    num_evals=10,
    # ... rest of your training parameters
)

def progress(num_steps, metrics):
    # your own progress function, e.g. log the evaluation reward
    print(num_steps, metrics['eval/episode_reward'])

# env: your brax environment
make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
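If you want to look at the learning curve afterwards, a slightly richer progress function can record each evaluation instead of only printing it. This is a minimal sketch; it assumes the metrics dict contains 'eval/episode_reward', as upstream brax PPO reports, and the list names xdata/ydata are hypothetical.

```python
# Hypothetical progress callback that records evaluation rewards for plotting later.
# Assumes metrics contains 'eval/episode_reward' (as in upstream brax PPO).
xdata, ydata = [], []

def progress(num_steps, metrics):
    """Called by the trainer after each evaluation."""
    xdata.append(num_steps)
    ydata.append(float(metrics['eval/episode_reward']))
    print(f"step {num_steps}: eval reward {ydata[-1]:.2f}")
```

Pass it as progress_fn=progress; once training finishes, xdata and ydata hold the reward curve and can be plotted with any plotting library.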

Once training is done, save both make_inference_fn and params. You can save them as follows:

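import dill
from brax.io import model

# save params (model_path: a file path of your choice)
model.save_params(model_path, params)

# save inference func (full_path: another file path of your choice)
with open(full_path, 'wb') as f:
    dill.dump(make_inference_fn, f)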
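Why two different serializers? As far as I can tell, the make_inference_fn returned by ppo.train is a closure built inside brax, and the standard pickle module cannot serialize a function defined inside another function, while plain arrays pickle fine. The sketch below demonstrates that difference with a hypothetical toy factory (make_inference_fn_toy), not brax's real one; dill is presumably used in the reply because it can serialize such closures.

```python
import pickle

import numpy as np

def make_inference_fn_toy(scale):
    # Hypothetical stand-in for brax's make_inference_fn: returns a closure.
    def policy(obs):
        return obs * scale
    return policy

# Params are plain arrays, so standard pickle can round-trip them.
params = {'w': np.arange(3.0), 'b': np.zeros(1)}
restored = pickle.loads(pickle.dumps(params))
assert all(np.array_equal(params[k], restored[k]) for k in params)

# A function defined inside another function cannot be pickled by reference.
try:
    pickle.dumps(make_inference_fn_toy(2.0))
    closure_pickled = True
except (pickle.PicklingError, AttributeError):
    closure_pickled = False
print("closure pickled with pickle:", closure_pickled)  # False
```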

To visualize your trained model, you don't need to call train_fn again (so num_timesteps doesn't matter); just load both pieces back:

import dill
import jax
from brax.io import model

# load saved params
saved_params = model.load_params(model_path)

# load saved inference function
with open(full_path, 'rb') as f:
    make_inference_fn = dill.load(f)

inference_fn = make_inference_fn(saved_params)
jit_inference_fn = jax.jit(inference_fn)
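To make the shape of this API concrete: make_inference_fn takes parameters and returns a policy that maps an observation and a PRNG key to an action plus an extras dict, and jax.jit compiles that policy. The sketch below is a hypothetical toy (a single tanh layer, named make_toy_inference_fn), not brax's actual network; its deterministic flag is part of the toy.

```python
import jax
import jax.numpy as jnp

def make_toy_inference_fn(params, deterministic=False):
    """Hypothetical toy with the same shape of API: params -> policy(obs, key)."""
    def policy(obs, key):
        mean = jnp.tanh(obs @ params['w'] + params['b'])
        if deterministic:
            return mean, {}
        # Exploration noise; a visualization run would typically skip it.
        noise = 0.1 * jax.random.normal(key, mean.shape)
        return jnp.clip(mean + noise, -1.0, 1.0), {}
    return policy

params = {'w': jnp.ones((4, 2)) * 0.1, 'b': jnp.zeros(2)}
jit_inference_fn = jax.jit(make_toy_inference_fn(params, deterministic=True))

obs = jnp.ones(4)
action, _ = jit_inference_fn(obs, jax.random.PRNGKey(0))
print(action)  # two values, each equal to tanh(0.4)
```

In a real brax rollout you would call the jitted function on the current state's observation with a fresh key at every step, then pass the action to the environment's step function.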

After checking the result, if you want to train further, you can simply pass the saved parameters back into training:

train_fn = functools.partial(
    ppo.train,
    num_timesteps=100_000_000,
    num_evals=100,
    previous_params=saved_params,  # start from the loaded parameters
    # ... rest of your training parameters
)
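The idea behind previous_params is a warm start: the resumed run begins from the saved weights rather than from a fresh initialization. The toy below illustrates that idea with gradient descent on a quadratic; the train function here is hypothetical and not brax's ppo.train. In the toy, num_steps counts only the steps run in the current call, so 5 steps plus 5 resumed steps equals one 10-step run. Whether this fork's previous_params treats num_timesteps the same way is an assumption.

```python
def train(num_steps, previous_params=None, lr=0.1):
    """Hypothetical toy trainer: gradient descent on f(w) = (w - 3)**2.

    If previous_params is given, training resumes from it; otherwise it
    starts from w = 0. num_steps counts only the steps run in this call.
    """
    w = 0.0 if previous_params is None else previous_params
    for _ in range(num_steps):
        grad = 2.0 * (w - 3.0)
        w -= lr * grad
    return w

# One long run versus a saved run that is resumed later.
full = train(10)
first = train(5)
resumed = train(5, previous_params=first)
print(full == resumed)  # True
```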

Here, I continue training for 100 more evaluations. Does this answer your question?
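For a rough sense of what num_evals means in practice, the sketch below estimates the number of environment steps between evaluations. It assumes the fork keeps upstream brax PPO's scheduling (one evaluation before training, then max(num_evals - 1, 1) evaluations spread evenly over num_timesteps); the helper eval_interval is hypothetical, and real step counts are rounded up to whole training iterations.

```python
def eval_interval(num_timesteps, num_evals):
    """Approximate env steps between evaluations (assumed upstream brax PPO schedule)."""
    num_evals_after_init = max(num_evals - 1, 1)
    return num_timesteps / num_evals_after_init

print(eval_interval(100_000_000, 10))   # about 11.1 million steps per evaluation
print(eval_interval(100_000_000, 100))  # about 1.01 million steps per evaluation
```

So, under this assumption, going from 10 to 100 evaluations for the same num_timesteps mostly changes how often you see progress, not how long training runs.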

Icepomel0 commented 6 months ago

Hi Jihan,

Yes, that answered my question wonderfully! Thanks so much for this!

Best

From: Jihan @.> Sent: Wednesday, March 13, 2024 10:37:30 AM To: jihan1218/brax @.> Cc: Yiming Xie @.>; Author @.> Subject: Re: [jihan1218/brax] Using pre-trained model (Issue #2)
