Alescontrela / viper_rl

Using advances in generative modeling to learn reward functions from unlabeled videos.
MIT License

ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/encoder/Downsample_0/conv". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError) #6

Closed SKYWALKERRAY closed 4 months ago

SKYWALKERRAY commented 5 months ago

It seems the code is having a problem finding the kernel weights?

UnfilteredStackTrace: flax.errors.ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/encoder/Downsample_0/conv". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError)

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

```
Traceback (most recent call last):
  File "/home/asc/viper_rl/scripts/train_dreamer.py", line 247, in <module>
    main()
  File "/home/asc/viper_rl/scripts/train_dreamer.py", line 60, in main
    embodied.run.train(agent, env, replay, logger, args)
  File "/home/asc/viper_rl/viper_rl/dreamerv3/embodied/run/train.py", line 68, in train
    driver(random_agent.policy, steps=100)
  File "/home/asc/viper_rl/viper_rl/dreamerv3/embodied/core/driver.py", line 42, in __call__
    step, episode = self._step(policy, step, episode)
  File "/home/asc/viper_rl/viper_rl/dreamerv3/embodied/core/driver.py", line 65, in _step
    [fn(trn, i, **self._kwargs) for fn in self._on_steps]
  File "/home/asc/viper_rl/viper_rl/dreamerv3/embodied/core/driver.py", line 65, in <listcomp>
    [fn(trn, i, **self._kwargs) for fn in self._on_steps]
  File "/home/asc/anaconda3/envs/viper/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/asc/viper_rl/viper_rl/dreamerv3/embodied/replay/replays.py", line 86, in add
    return self.add_chunk(step, worker, load)
  File "/home/asc/viper_rl/viper_rl/dreamerv3/embodied/replay/replays.py", line 120, in add_chunk
    seq = self.reward_model(tuple(stream))
  File "/home/asc/viper_rl/viper_rl/videogpt/reward_models/videogpt_reward_model.py", line 111, in __call__
    return self.process_seq(self.compute_reward(seq, **kwargs), **kwargs)
  File "/home/asc/viper_rl/viper_rl/videogpt/reward_models/videogpt_reward_model.py", line 219, in compute_reward
    encodings = self.ae.encode(jnp.expand_dims(image_batch, axis=0))
  File "/home/asc/viper_rl/viper_rl/videogpt/models/__init__.py", line 142, in encode
    encodings = self._encode(video)
  File "/home/asc/viper_rl/viper_rl/videogpt/models/__init__.py", line 85, in fn
    out = self.ae.apply(
  File "/home/asc/viper_rl/viper_rl/videogpt/models/vqgan.py", line 62, in encode
    h = self.encoder(image, deterministic=deterministic)
  File "/home/asc/viper_rl/viper_rl/videogpt/models/vqgan.py", line 172, in __call__
    h = Downsample(all_strides[i_level], self.resample_with_conv)(h)
  File "/home/asc/viper_rl/viper_rl/videogpt/models/vqgan.py", line 281, in __call__
    x = self.conv(x)  # use the conv layer defined in setup
  File "/home/asc/anaconda3/envs/viper/lib/python3.8/site-packages/flax/linen/linear.py", line 480, in __call__
    kernel = self.param(
flax.errors.ScopeParamNotFoundError: Could not find parameter named "kernel" in scope "/encoder/Downsample_0/conv". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamNotFoundError)
```

SKYWALKERRAY commented 4 months ago

Found the problem. You have to run the two training stages (VQGAN and VideoGPT) before running the Dreamer training, or you will hit this error.

NonsansWD commented 1 month ago

> Found the problem. You have to run the two training stages (VQGAN and VideoGPT) before running the Dreamer training, or you will hit this error.

But aren't there checkpoints already? I don't really understand why I have to run it in that exact order when these steps only produce something with the same structure as the checkpoints. Also, I unfortunately don't even know where training the VQGAN puts the resulting model :/ Did you just let it run until it terminated, then run the VideoGPT training until it terminated, and then run the policy training until it terminated?

SKYWALKERRAY commented 1 month ago

The provided checkpoints are not fully trained; for example, the VQGAN may have been trained for only 10 epochs while it needs around 100 epochs to be usable. If you just run the training program, it will finish once it reaches the default number of epochs.
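If you want to check whether a given checkpoint is actually usable before launching `train_dreamer.py`, one option is to load it and verify that the VQGAN encoder parameters (including `encoder/Downsample_0/conv/kernel`) are present. A rough sketch, assuming the checkpoint was written with `flax.training.checkpoints` and using a placeholder path:

```python
# Rough diagnostic sketch (assumptions: checkpoint saved with
# flax.training.checkpoints, placeholder path). Prints every parameter path
# in the restored tree so you can check that encoder/Downsample_0/conv/kernel
# actually exists before starting Dreamer training.
from flax.training import checkpoints

ckpt = checkpoints.restore_checkpoint('path/to/your/vqgan_checkpoint', target=None)
params = ckpt.get('params', ckpt)  # layout depends on how the model was saved


def print_tree(tree, prefix=''):
    """Recursively print parameter paths and leaf shapes."""
    if isinstance(tree, dict):
        for key, value in tree.items():
            print_tree(value, f'{prefix}/{key}')
    else:
        print(prefix, getattr(tree, 'shape', None))


print_tree(params)
```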