danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Exporting the JAX policy as a TF model #15

Open LipJ01 opened 1 year ago

LipJ01 commented 1 year ago

Congratulations Danijar on this project and your paper! Again, not really an issue per se. Having read and executed example.py and most of the code, I understand that this project doesn't use TensorFlow in the way I'm familiar with and instead uses JAX. I will endeavour to understand it myself, but I was wondering if it would be "simple" to use jax2tf to obtain a SavedModel, ideally after training? Then, if I'm feeling really brave, I intend to use tfjs-converter to run inference in a web demo.

Update: A) I realise I can go straight from JAX to tfjs. B) I also realise (or think I understand) that I'm actually going to have to convert 3 nets, the world model, actor and critic, and then implement Dreamer in client-side JavaScript. I'm becoming ever more doubtful of my ability to pull this off, but the payoff has this occupying my full attention (calendar emptied for the next 3 days).

EelcoHoogendoorn commented 1 year ago

jax2tf allegedly has some limitations, but none that I think would be relevant here. The trained models are simply JAX functions, and it should be a few lines of code to save them using jax2tf. Certainly worth a try, I'd say.
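
For reference, a minimal sketch of that few-lines path, using a toy function in place of the trained nets (the function, shapes, and save path here are placeholders, not the dreamerv3 API):

    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf

    # stand-in for a trained JAX function
    def toy_fn(x):
        return jnp.tanh(x)

    # wrap the converted function so TF can trace and serialize it
    module = tf.Module()
    module.f = tf.function(
        jax2tf.convert(toy_fn),
        autograph=False,
        input_signature=[tf.TensorSpec([None, 4], tf.float32)],
    )
    tf.saved_model.save(module, '/tmp/toy_saved_model')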

LipJ01 commented 1 year ago

I feel close. The conversion/export is as simple as three calls (one per net) of the form:

    tfjs.converters.convert_jax(
        apply_fn=jax.jit(???),
        params=params,
        input_signatures=[tf.TensorSpec(???)],
        model_dir='/path/to/tfjs_models_directory',
    )

So far I've managed to get the params from data['agent'] (I probably need different keys; I haven't figured that out yet), and I can probably figure out the TensorSpec using 'heuristic exploration'. What I'm really struggling to find is where in the code the JAX functions are. 😣
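
A self-contained toy version of that call might look like the following; the apply function, params, and input shape are hypothetical stand-ins, not the dreamerv3 nets:

    import jax.numpy as jnp
    import tensorflow as tf
    import tensorflowjs as tfjs

    # hypothetical pure apply function; convert_jax invokes it as apply_fn(params, *inputs)
    def apply_fn(params, x):
        return jnp.tanh(x @ params['w'] + params['b'])

    params = {'w': jnp.zeros((16, 4)), 'b': jnp.zeros((4,))}
    tfjs.converters.convert_jax(
        apply_fn=apply_fn,
        params=params,
        input_signatures=[tf.TensorSpec((1, 16), tf.float32)],
        model_dir='/tmp/tfjs_model',
    )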

EelcoHoogendoorn commented 1 year ago

I've only used conversion from the jax2tf side, not via tfjs, so I can't comment on that.

jaxagent.policy seems relevant as a function to export, for instance, though I'm still in the process of figuring out the docstrings myself.

danijar commented 1 year ago

@LipJ01 The JAX functions are in jaxagent.py. For inference, you'll only need self._policy and self._init_policy.

ChrisAGBlake commented 1 year ago

@LipJ01 Have you had any success with this? I'm also looking to export trained models on a custom environment to tensorflow and then ONNX.

@danijar Thank you so much for publishing this. It's game changing, particularly for hard exploration environments or working with "slow" simulators. Other methods I've tested just aren't feasible and are far too time consuming / expensive to train.

LipJ01 commented 1 year ago

@ChrisAGBlake afraid not 🙃

ChrisAGBlake commented 11 months ago

I managed to convert it to TensorFlow but wasn't able to convert to TFLite or ONNX, as it seems to use some operations that aren't supported by those converters.
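
For anyone retrying the TFLite route, here is a minimal sketch of such a conversion attempt, assuming the SavedModel exported by the code below. Enabling SELECT_TF_OPS lets TFLite fall back to full TensorFlow kernels for ops it doesn't support natively, and the converter typically reports the offending ops if it still fails:

    import tensorflow as tf

    converter = tf.lite.TFLiteConverter.from_saved_model('models/tf_agent')
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,  # native TFLite ops
        tf.lite.OpsSet.SELECT_TF_OPS,    # fall back to TensorFlow kernels
    ]
    tflite_model = converter.convert()
    with open('models/agent.tflite', 'wb') as f:
        f.write(tflite_model)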

For anyone interested, here's a bit of an example. It's hard-coded for my observation space, which is a vector and no image. I'm not sure if this is the best way of doing it, but it seems to be working.

Manually specifying the function signatures and modifying them was necessary to use the TensorFlow model from the C++ API, but it wasn't required for using it with TensorFlow in Python.

    # load the trained weights
    checkpoint = embodied.Checkpoint()
    checkpoint.agent = agent
    logdir = embodied.Path(config.logdir)
    checkpoint.load(logdir / 'checkpoint.ckpt', keys=['agent'])

    # modify the init policy function signature so it's compatible with c++
    def mod_init_policy(weights, rng, is_first):
        ((latent, action), task_state, expl_state), _ = agent._init_policy(weights, rng, is_first)
        out = latent
        out['action'] = action
        return out

    # modify the policy function signature so it's compatible with c++
    def mod_policy(weights, rng, vector, reward, is_first, is_last, is_terminal, deter, logit, stoch, action):
        obs = {'vector': vector, 'reward': reward, 'is_first': is_first, 'is_last': is_last, 'is_terminal': is_terminal}
        state = (({'deter': deter, 'logit': logit, 'stoch': stoch}, action), {}, {})
        (outs, state), _ = agent._policy(weights, rng, obs, state)
        (latent, action), _, _ = state
        out = latent
        out['action'] = action
        for k, v in outs.items():
            out[f'outs_{k}'] = v
        return out

    # wrap the JAX functions with jax2tf and attach them to a tf.Module
    tf_agent = tf.Module()
    weights = tf.nest.map_structure(tf.Variable, agent.varibs)
    init_f = lambda rng, is_first: jax2tf.convert(mod_init_policy)(weights, rng, is_first)
    policy_f = lambda rng, vector, reward, is_first, is_last, is_terminal, deter, logit, stoch, action: jax2tf.convert(mod_policy)(weights, rng, vector, reward, is_first, is_last, is_terminal, deter, logit, stoch, action)
    tf_agent._variables = tf.nest.flatten(weights)
    tf_agent.init_policy = tf.function(init_f, autograph=False)
    tf_agent.policy = tf.function(policy_f, autograph=False)
    # run init_policy once on dummy inputs and trace a concrete function for the signature
    rng = agent._next_rngs(agent.policy_devices)
    obs = agent._dummy_batch(env.obs_space, (1,))
    out = tf_agent.init_policy(rng, obs['is_first'])
    prev_latent = {k: v for k, v in out.items() if k != 'action'}
    prev_action = out['action']
    init_call = tf_agent.init_policy.get_concrete_function(rng, obs['is_first'])
    # build dummy inputs for the policy call
    rng = agent._next_rngs(agent.policy_devices)
    obs = agent._dummy_batch(env.obs_space, (1,))
    vector = obs['vector']
    reward = obs['reward']
    is_first = obs['is_first']
    is_last = obs['is_last']
    is_terminal = obs['is_terminal']
    deter = prev_latent['deter']
    logit = prev_latent['logit']
    stoch = prev_latent['stoch']
    # run the policy once, trace it, and save both concrete functions as named signatures
    tf_agent.policy(rng, vector, reward, is_first, is_last, is_terminal, deter, logit, stoch, prev_action)
    call = tf_agent.policy.get_concrete_function(rng, vector, reward, is_first, is_last, is_terminal, deter, logit, stoch, prev_action)
    tf.saved_model.save(tf_agent, 'models/tf_agent', signatures={'policy_init': init_call, 'policy': call})
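
As a quick usage sketch, the exported model can be loaded back in Python and stepped like this, reusing the dummy inputs built above; this is an assumption-laden example, not part of the export itself:

    # load the SavedModel; the saved tf.functions are restored as attributes
    loaded = tf.saved_model.load('models/tf_agent')
    out = loaded.init_policy(rng, is_first)
    prev_latent = {k: v for k, v in out.items() if k != 'action'}
    prev_action = out['action']
    out = loaded.policy(rng, vector, reward, is_first, is_last, is_terminal,
                        prev_latent['deter'], prev_latent['logit'],
                        prev_latent['stoch'], prev_action)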