philippe-eecs / IDQL

Repo for Implicit Diffusion Q-Learning

ValueError: Custom node type mismatch: expected type: <class 'flax.core.frozen_dict.FrozenDict'> #4

Closed JinGuang-cuhksz closed 7 months ago

JinGuang-cuhksz commented 9 months ago

Hello, I ran into a ValueError when running the code directly. I extracted the relevant code in the Code section below and got the error shown in the Error section. It works after I add the extra line actor_params = freeze(actor_params); however, the performance is below 80 on the walker2d-med task. I'm not sure whether adding actor_params = freeze(actor_params) is the right fix. Could you help me? Thanks a lot.

Code

from functools import partial

import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.core import freeze
from flax.training.train_state import TrainState
from jaxrl5.networks import MLP, DDPM, FourierFeatures, MLPResNet, get_weight_decay_mask


def mish(x):
    # Mish activation: x * tanh(softplus(x))
    return x * jnp.tanh(nn.softplus(x))


rng = jax.random.PRNGKey(0)
rng, actor_key, critic_key, value_key = jax.random.split(rng, 4)

# Learnable Fourier features for the diffusion timestep.
preprocess_time_cls = partial(FourierFeatures,
                              output_size=128,
                              learnable=True)

# MLP that encodes the conditioning input.
cond_model_cls = partial(MLP,
                         hidden_dims=(128, 128),
                         activations=mish,
                         activate_final=False)

# Residual MLP used as the reverse (denoising) network.
base_model_cls = partial(MLPResNet,
                         use_layer_norm=True,
                         num_blocks=3,
                         dropout_rate=0.1,
                         out_dim=6,
                         activations=mish)

actor_def = DDPM(time_preprocess_cls=preprocess_time_cls,
                 cond_encoder_cls=cond_model_cls,
                 reverse_encoder_cls=base_model_cls)

# Initialize the actor parameters with dummy inputs.
actor_params = actor_def.init(actor_key, jnp.ones((1, 17)), jnp.ones((1, 5)),
                              jnp.ones((1, 1)))['params']

# Workaround: uncommenting this line makes the error below disappear.
# actor_params = freeze(actor_params)

# Building the TrainState with the masked AdamW optimizer raises the ValueError.
score_model = TrainState.create(apply_fn=actor_def.apply,
                                params=actor_params,
                                tx=optax.adamw(learning_rate=3e-4,
                                               weight_decay=0.0,
                                               mask=get_weight_decay_mask))
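
As a quick sanity check (added for illustration only; it is not needed to reproduce the error), the following shows that freeze only changes the pytree container type, which is why the error below disappears. Run it just before the TrainState.create call:

# The leaves are identical; only the container node types differ
# (plain dict vs. flax.core.frozen_dict.FrozenDict), so the treedefs
# do not compare equal.
raw_treedef = jax.tree_util.tree_structure(actor_params)
frozen_treedef = jax.tree_util.tree_structure(freeze(actor_params))
print(raw_treedef == frozen_treedef)  # False: the node types do not match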

Error

$ python test_file.py
2023-12-05 05:30:41.920337: W external/xla/xla/service/gpu/nvptx_compiler.cc:679] The NVIDIA driver's CUDA version is 11.4 which is older than the ptxas CUDA version (11.8.89). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Traceback (most recent call last):
  File "test_file.py", line 40, in <module>
    score_model = TrainState.create(apply_fn=actor_def.apply,
  File "/root/miniconda3/envs/jax/lib/python3.9/site-packages/flax/training/train_state.py", line 110, in create
    opt_state = tx.init(params_with_opt)
  File "/root/miniconda3/envs/jax/lib/python3.9/site-packages/optax/_src/combine.py", line 50, in init_fn
    return tuple(fn(params) for fn in init_fns)
  File "/root/miniconda3/envs/jax/lib/python3.9/site-packages/optax/_src/combine.py", line 50, in <genexpr>
    return tuple(fn(params) for fn in init_fns)
  File "/root/miniconda3/envs/jax/lib/python3.9/site-packages/optax/_src/wrappers.py", line 516, in init_fn
    masked_params = mask_pytree(params, mask_tree)
  File "/root/miniconda3/envs/jax/lib/python3.9/site-packages/optax/_src/wrappers.py", line 497, in mask_pytree
    return tree_map(lambda m, p: p if m else MaskedNode(), mask_tree, pytree)
  File "/root/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/tree_util.py", line 243, in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
  File "/root/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/tree_util.py", line 243, in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Custom node type mismatch: expected type: <class 'flax.core.frozen_dict.FrozenDict'>, value: {'FourierFeatures_0': {'kernel': Array([[-0.12764107],
       [ 0.18242973],
       [-0.05237452],
       [ 0.07523398],
       [-0.05189894],
      ......

Env

flax=0.7.5=pypi_0
jax=0.4.20=pypi_0
jaxlib=0.4.20+cuda11.cudnn86=pypi_0
philippe-eecs commented 9 months ago

What happens when you remove the weight decay and the mask argument from adamw? Does the issue persist? I haven't seen this before.

score_model = TrainState.create(apply_fn=actor_def.apply, params=actor_params, tx=optax.adamw(learning_rate=3e-4))

Also, make sure to train for the full 3 million steps to get above 80 on walker2d-med (there is some variance, of course).

philippe-eecs commented 9 months ago

Btw, I just pushed some bug fixes that will speed up the code a fair amount during evaluation. One thing I wanted to do was choose the checkpoint for the BC actor that achieved the lowest validation loss.

I also removed the weight-decay masking on the weights, since it was causing this bug.

JinGuang-cuhksz commented 9 months ago

Thank you so much for your kind and fast response. The ValueError disappears.

I will try to use your new code to train for the full 3 million steps.

JinGuang-cuhksz commented 9 months ago

Thanks for your suggestions, but I still can't reach the desired results. I get $78.17 \pm 2.24$ (N=256) and $76.46 \pm 0.82$ (N=64) on the walker2d-med task.

[screenshot: walker2d-med evaluation results]

Could you please share your checkpoints so that I can get evaluation results similar to the paper's? Thank you so much.

philippe-eecs commented 9 months ago

These seem pretty in line with what I got. I made some small changes to the code I recently pushed to make it run faster, such as reducing the batch size and adding layer norm to IQL to avoid some instability issues. That probably explains the small 2-4 point deviation.

I don't have my old checkpoints, unfortunately. If you really want to get past 80, you might want to revert to the prior commit and rerun, but it will take 2-3x the runtime plus evaluation.
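
Roughly, the layer-norm change amounts to inserting nn.LayerNorm between the hidden layers of the critic. A generic flax sketch for illustration (not verbatim from the repo; the module and its fields are placeholders):

import flax.linen as nn
import jax.numpy as jnp

class CriticMLP(nn.Module):
    # Hypothetical Q-network sketch: LayerNorm after each hidden layer
    # to stabilize value estimates during training.
    hidden_dims: tuple = (256, 256)

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        x = jnp.concatenate([obs, action], axis=-1)
        for dim in self.hidden_dims:
            x = nn.Dense(dim)(x)
            x = nn.LayerNorm()(x)
            x = nn.relu(x)
        return nn.Dense(1)(x).squeeze(-1)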

Here are my learning curves for walker2d-medium with N=64 and N=256 from the prior commit. I'm guessing my small changes reduced performance slightly.

[learning curves: walker2d-medium, N=64 and N=256]

JinGuang-cuhksz commented 9 months ago

Thanks for your kind and fast response. Your performance looks good and stable. Since I need checkpoints with good performance, I will go back to your old version by increasing the batch size and removing the layer norm, and hope that works.

JinGuang-cuhksz commented 9 months ago

Hello. Do the line and the shaded region in your figures show the average performance and the confidence interval over 10 runs, respectively?

I would appreciate your advice. I used the previous version of the code and got an average above 80, but it is still slightly worse than yours. Could you share the random seed? I noticed the code doesn't specify one.

[learning curve from the previous version of the code]

philippe-eecs commented 9 months ago

Your results look good. I re-checked the appendix of my paper: I used N=128 for the locomotion results and N=32 for antmaze.

Yes, the shaded region is the standard deviation across my runs.

If results from other papers are within 2 standard deviations of mine, I mark both as bold; that usually gives a 4-10 point interval. Feel free to bold however you like, just report how you do it.

I believe I used completely random RNG keys, but you can fix this by setting the seed in the code to a fixed integer, for example as sketched below.
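
A minimal sketch of pinning the seed (the --seed flag and variable names are assumptions for illustration, not necessarily what the training script exposes):

import argparse
import jax
import numpy as np

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)  # hypothetical flag
args = parser.parse_args()

# Pin both the JAX and NumPy RNGs to the same fixed integer so repeated
# runs are reproducible instead of drawing a fresh random seed each time.
rng = jax.random.PRNGKey(args.seed)
np.random.seed(args.seed)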

Feel free to report the results from your run instead of from the paper!

philippe-eecs commented 9 months ago

Oh, hang on. Did you run for 3 million steps? You only need to input 1.5 million, because the actor takes two gradient steps per critic gradient step and I report the number of critic gradient steps. You can just take the results from the 1.5 million evaluation if you'd like.
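
To illustrate the step accounting (a hypothetical sketch; update_critic and update_actor are placeholder stubs, not the repo's actual functions):

def update_critic():  # placeholder for one critic gradient step
    pass

def update_actor():   # placeholder for one actor gradient step
    pass

# With two actor gradient steps per critic gradient step, an input of
# 1.5M iterations gives 1.5M critic steps and 3M actor steps.
num_iterations = 1_500_000

for it in range(num_iterations):
    update_critic()   # 1 critic gradient step per iteration
    update_actor()    # 2 actor gradient steps per iteration
    update_actor()

# critic gradient steps: 1_500_000; actor gradient steps: 3_000_000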

JinGuang-cuhksz commented 9 months ago

OK. Thank you so much.