ai4co / rl4co

A PyTorch library for all things Reinforcement Learning (RL) for Combinatorial Optimization (CO)
https://rl4.co
MIT License
381 stars 70 forks source link

[BUG] Can't run the quickstart #206

Closed sunweice closed 2 weeks ago

sunweice commented 2 weeks ago

Describe the bug

I'm using RL4CO. When I followed the quickstart example on this page (https://rl4co.readthedocs.io/en/latest/examples/1-quickstart/) for testing, I encountered an error. What should I do?

To Reproduce

In the step `trainer.fit(model)`,

the error is:

val_file not set. Generating dataset instead
test_file not set. Generating dataset instead

ValueError Traceback (most recent call last) Cell In[5], line 1 ----> 1 trainer.fit(model)

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\utils\trainer.py:146, in RL4COTrainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 141 log.warning( 142 "Overriding gradient_clip_val to None for 'automatic_optimization=False' models" 143 ) 144 self.gradient_clip_val = None --> 146 super().fit( 147 model=model, 148 train_dataloaders=train_dataloaders, 149 val_dataloaders=val_dataloaders, 150 datamodule=datamodule, 151 ckpt_path=ckpt_path, 152 )

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\trainer.py:538, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 536 self.state.status = TrainerStatus.RUNNING 537 self.training = True --> 538 call._call_and_handle_interrupt( 539 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path 540 )

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\call.py:47, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs) 45 if trainer.strategy.launcher is not None: 46 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs) ---> 47 return trainer_fn(*args, **kwargs) 49 except _TunerExitException: 50 _call_teardown_hook(trainer)

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\trainer.py:574, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 567 assert self.state.fn is not None 568 ckpt_path = self._checkpoint_connector._select_ckpt_path( 569 self.state.fn, 570 ckpt_path, 571 model_provided=True, 572 model_connected=self.lightning_module is not None, 573 ) --> 574 self._run(model, ckpt_path=ckpt_path) 576 assert self.state.stopped 577 self.training = False

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\trainer.py:943, in Trainer._run(self, model, ckpt_path) 940 log.debug(f"{self.__class__.__name__}: preparing data") 941 self._data_connector.prepare_data() --> 943 call._call_setup_hook(self) # allow user to set up LightningModule in accelerator environment 944 log.debug(f"{self.__class__.__name__}: configuring model") 945 call._call_configure_model(self)

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\call.py:104, in _call_setup_hook(trainer) 102 _call_lightning_datamodule_hook(trainer, "setup", stage=fn) 103 _call_callback_hooks(trainer, "setup", stage=fn) --> 104 _call_lightning_module_hook(trainer, "setup", stage=fn) 106 trainer.strategy.barrier("post_setup")

File ~\anaconda3\envs\RL4CO\lib\site-packages\lightning\pytorch\trainer\call.py:167, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs) 164 pl_module._current_fx_name = hook_name 166 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"): --> 167 output = fn(*args, **kwargs) 169 # restore current_fx when nested context 170 pl_module._current_fx_name = prev_fx_name

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\common\base.py:155, in RL4COLitModule.setup(self, stage) 153 self.dataloader_names = None 154 self.setup_loggers() --> 155 self.post_setup_hook()

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\reinforce.py:110, in REINFORCE.post_setup_hook(self, stage) 108 def post_setup_hook(self, stage="fit"): 109 # Make baseline taking model itself and train_dataloader from model as input --> 110 self.baseline.setup( 111 self.policy, 112 self.env, 113 batch_size=self.val_batch_size, 114 device=get_lightning_device(self), 115 dataset_size=self.data_cfg["val_data_size"], 116 )

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\baselines.py:117, in WarmupBaseline.setup(self, *args, **kw) 116 def setup(self, *args, **kw): --> 117 self.baseline.setup(*args, **kw)

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\baselines.py:174, in RolloutBaseline.setup(self, *args, **kw) 173 def setup(self, *args, **kw): --> 174 self._update_policy(*args, **kw)

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\baselines.py:187, in RolloutBaseline._update_policy(self, policy, env, batch_size, device, dataset_size, dataset) 183 self.dataset = env.dataset(batch_size=[dataset_size]) 185 log.info("Evaluating baseline policy on evaluation dataset") 186 self.bl_vals = ( --> 187 self.rollout(self.policy, env, batch_size, device, self.dataset).cpu().numpy() 188 ) 189 self.mean = self.bl_vals.mean()

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\baselines.py:242, in RolloutBaseline.rollout(self, policy, env, batch_size, device, dataset) 238 return policy(batch, env, decode_type="greedy")["reward"] 240 dl = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn) --> 242 rewards = torch.cat([eval_policy(batch) for batch in dl], 0) 243 return rewards

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\models\rl\reinforce\baselines.py:242, in <listcomp>(.0) 238 return policy(batch, env, decode_type="greedy")["reward"] 240 dl = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn) --> 242 rewards = torch.cat([eval_policy(batch) for batch in dl], 0) 243 return rewards

File ~\anaconda3\envs\RL4CO\lib\site-packages\torch\utils\data\dataloader.py:630, in _BaseDataLoaderIter.__next__(self) 627 if self._sampler_iter is None: 628 # TODO(https://github.com/pytorch/pytorch/issues/76750) 629 self._reset() # type: ignore[call-arg] --> 630 data = self._next_data() 631 self._num_yielded += 1 632 if self._dataset_kind == _DatasetKind.Iterable and \ 633 self._IterableDataset_len_called is not None and \ 634 self._num_yielded > self._IterableDataset_len_called:

File ~\anaconda3\envs\RL4CO\lib\site-packages\torch\utils\data\dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self) 671 def _next_data(self): 672 index = self._next_index() # may raise StopIteration --> 673 data = self._dataset_fetcher.fetch(index) # may raise StopIteration 674 if self._pin_memory: 675 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~\anaconda3\envs\RL4CO\lib\site-packages\torch\utils\data\_utils\fetch.py:55, in _MapDatasetFetcher.fetch(self, possibly_batched_index) 53 else: 54 data = self.dataset[possibly_batched_index] ---> 55 return self.collate_fn(data)

File ~\anaconda3\envs\RL4CO\lib\site-packages\rl4co\data\dataset.py:37, in TensorDictDataset.collate_fn(batch) 34 @staticmethod 35 def collate_fn(batch: Union[dict, TensorDict]): 36 """Collate function compatible with TensorDicts that reassembles a list of dicts.""" ---> 37 return TensorDict( 38 {key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()}, 39 batch_size=torch.Size([len(batch)]), 40 _run_checks=False, 41 )

File ~\anaconda3\envs\RL4CO\lib\site-packages\tensordict\_td.py:240, in TensorDict.__init__(self, source, batch_size, device, names, non_blocking, lock, **kwargs) 229 def __init__( 230 self, 231 source: T | dict[str, CompatibleType] = None, (...) 237 **kwargs, 238 ) -> None: 239 if (source is not None) and kwargs: --> 240 raise ValueError( 241 "Either a dictionary or a sequence of kwargs must be provided, not both." 242 ) 243 source = source if not kwargs else kwargs 244 if names and is_dynamo_compiling():

ValueError: Either a dictionary or a sequence of kwargs must be provided, not both.

sunweice commented 2 weeks ago
@staticmethod
def collate_fn(batch: Union[dict, TensorDict]):
    """Collate function compatible with TensorDicts that reassembles a list of dicts.

    For every key of the first item, the corresponding values of all items are
    stacked with ``torch.stack``; the stacked entries are wrapped in a
    ``TensorDict`` whose batch size is ``len(batch)``.
    """
    return TensorDict(
        # Keys come from the first item only; every item is indexed with each of them.
        {key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()},
        batch_size=torch.Size([len(batch)]),
        # NOTE(review): with the TensorDict version installed in this report, this
        # keyword ends up in **kwargs next to the source dict and triggers
        # "ValueError: Either a dictionary or a sequence of kwargs must be
        # provided, not both."; removing this argument was reported to fix it.
        _run_checks=False,
    )

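It seems I've found the cause. The discussion at https://github.com/pytorch/tensordict/pull/175#discussion_r1084354980 mentions that TensorDict does not support the `_run_checks` parameter. Because the constructor also receives a source dictionary, the unrecognised `_run_checks` keyword lands in `**kwargs` and triggers the ValueError above. After I deleted `_run_checks=False`, the code runs.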

fedebotu commented 2 weeks ago

@sunweice yep, that's correct! We recently updated the development version of RL4CO to make it compatible with the latest TensorDict (only that change was needed).

Note we should officially release the latest RL4CO version in the coming month, with lots of updates!