Slimnios / SaGess

SaGess denoising diffusion model

AssertionError #6

Open tonyPo opened 5 months ago

tonyPo commented 5 months ago

Dear developers,

I'm running the algorithm on the Cora dataset (the same happens on Wiki), with name = test, in a Jupyter notebook. During training I receive the following error message:

AssertionError: expected size 16==16, stride 2432==2560 at dim=0

This happens during the execution of the trainer.fit method.
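In case it helps narrow this down: the last frame of the trace is TorchInductor-generated code, and the failing check expects strides (2560, 0, 128, 16, 1) for a tensor of shape (16, 1, s0, 8, 16), i.e. s0 = 2560 / (8 * 16) = 20, while the runtime stride 2432 corresponds to s0 = 2432 / (8 * 16) = 19. My (unverified) reading is that the compiled backward graph was specialized to one value of s0 and then saw a different one. Below is an untested sketch of that arithmetic plus two generic ways to take the compiled path out of the picture while debugging; I am assuming torch.compile is applied somewhere in the SaGess training code (the _dynamo/_inductor frames suggest so), but I have not located that call, so the snippet is illustrative only.

import os

# Stride arithmetic from the failing assert_size_stride in the last frame:
# expected strides (2560, 0, 128, 16, 1) for shape (16, 1, s0, 8, 16).
expected_s0 = 2560 // (8 * 16)   # 20: the value the graph appears to be compiled for
observed_s0 = 2432 // (8 * 16)   # 19: the value implied by the runtime stride

# Untested workaround sketch 1: disable TorchDynamo/Inductor so training runs
# eagerly and the generated stride asserts never exist. Set this early in the
# notebook, before the model is built and compiled.
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Untested workaround sketch 2: wherever the model is compiled, request dynamic
# shapes so the graph is not specialized to a single s0, e.g.
#   model = torch.compile(model, dynamic=True)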

Could you please look into this.

Kind regards, Ton Poppe

This is the stack trace:

AssertionError Traceback (most recent call last) /Users/tonpoppe/workspace/synth_graph_baselines/sagess/SaGess/src/demo.ipynb Cell In[11], line 6 4 train_start_time = time.time() 5 #trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only) ----> 6 trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume) 7 train_end_time = time.time() 8 train_time = train_end_time - train_start_time

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 542 self.state.status = TrainerStatus.RUNNING 543 self.training = True --> 544 call._call_and_handle_interrupt( 545 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path 546 )

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs) 42 if trainer.strategy.launcher is not None: 43 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs) ---> 44 return trainer_fn(*args, **kwargs) 46 except _TunerExitException: 47 _call_teardown_hook(trainer)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path) 573 assert self.state.fn is not None 574 ckpt_path = self._checkpoint_connector._select_ckpt_path( 575 self.state.fn, 576 ckpt_path, 577 model_provided=True, 578 model_connected=self.lightning_module is not None, 579 ) --> 580 self._run(model, ckpt_path=ckpt_path) 582 assert self.state.stopped 583 self.training = False

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path) 984 self._signal_connector.register_signal_handlers() 986 # ---------------------------- 987 # RUN THE TRAINER 988 # ---------------------------- --> 989 results = self._run_stage() 991 # ---------------------------- 992 # POST-Training CLEAN UP 993 # ---------------------------- 994 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:1035, in Trainer._run_stage(self) 1033 self._run_sanity_check() 1034 with torch.autograd.set_detect_anomaly(self._detect_anomaly): -> 1035 self.fit_loop.run() 1036 return None 1037 raise RuntimeError(f"Unexpected state {self.state}")

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:202, in _FitLoop.run(self) 200 try: 201 self.on_advance_start() --> 202 self.advance() 203 self.on_advance_end() 204 self._restarting = False

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:359, in _FitLoop.advance(self) 357 with self.trainer.profiler.profile("run_training_epoch"): 358 assert self._data_fetcher is not None --> 359 self.epoch_loop.run(self._data_fetcher)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher) 134 while not self.done: 135 try: --> 136 self.advance(data_fetcher) 137 self.on_advance_end(data_fetcher) 138 self._restarting = False

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher) 237 with trainer.profiler.profile("run_training_batch"): 238 if trainer.lightning_module.automatic_optimization: 239 # in automatic optimization, there can only be one optimizer --> 240 batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs) 241 else: 242 batch_output = self.manual_optimization.run(kwargs)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs) 180 closure() 182 # ------------------------------ 183 # BACKWARD PASS 184 # ------------------------------ 185 # gradient update with accumulated gradients 186 else: --> 187 self._optimizer_step(batch_idx, closure) 189 result = closure.consume_result() 190 if result.loss is None:

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure) 262 self.optim_progress.optimizer.step.increment_ready() 264 # model hook --> 265 call._call_lightning_module_hook( 266 trainer, 267 "optimizer_step", 268 trainer.current_epoch, 269 batch_idx, 270 optimizer, 271 train_step_and_backward_closure, 272 ) 274 if not should_accumulate: 275 self.optim_progress.optimizer.step.increment_completed()

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs) 154 pl_module._current_fx_name = hook_name 156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"): --> 157 output = fn(*args, **kwargs) 159 # restore current_fx when nested context 160 pl_module._current_fx_name = prev_fx_name

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/core/module.py:1291, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure) 1252 def optimizer_step( 1253 self, 1254 epoch: int, (...) 1257 optimizer_closure: Optional[Callable[[], Any]] = None, 1258 ) -> None: 1259 r"""Override this method to adjust the default way the :class:~pytorch_lightning.trainer.trainer.Trainer calls 1260 the optimizer. 1261 (...) 1289 1290 """ -> 1291 optimizer.step(closure=optimizer_closure)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs) 148 raise MisconfigurationException("When optimizer.step(closure) is called, the closure should be callable") 150 assert self._strategy is not None --> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs) 153 self._on_after_step() 155 return step_output

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs) 228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed 229 assert isinstance(model, pl.LightningModule) --> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/precision.py:117, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs) 115 """Hook to run the optimizer step.""" 116 closure = partial(self._wrap_closure, model, optimizer, closure) --> 117 return optimizer.step(closure=closure, **kwargs)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/optim/optimizer.py:391, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs) 386 else: 387 raise RuntimeError( 388 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}." 389 ) --> 391 out = func(*args, **kwargs) 392 self._optimizer_step_code() 394 # call optimizer step post hooks

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs) 74 torch.set_grad_enabled(self.defaults['differentiable']) 75 torch._dynamo.graph_break() ---> 76 ret = func(self, *args, **kwargs) 77 finally: 78 torch._dynamo.graph_break()

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/optim/adamw.py:165, in AdamW.step(self, closure) 163 if closure is not None: 164 with torch.enable_grad(): --> 165 loss = closure() 167 for group in self.param_groups: 168 params_with_grad = []

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/precision.py:104, in Precision._wrap_closure(self, model, optimizer, closure) 91 def _wrap_closure( 92 self, 93 model: "pl.LightningModule", 94 optimizer: Optimizer, 95 closure: Callable[[], Any], 96 ) -> Any: 97 """This double-closure allows makes sure the closure is executed before the on_before_optimizer_step 98 hook is called. 99 (...) 102 103 """ --> 104 closure_result = closure() 105 self._after_closure(model, optimizer) 106 return closure_result

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs) 139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: --> 140 self._result = self.closure(*args, **kwargs) 141 return self._result.loss

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs) 112 @functools.wraps(func) 113 def decorate_context(*args, **kwargs): 114 with ctx_factory(): --> 115 return func(*args, **kwargs)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py:135, in Closure.closure(self, *args, **kwargs) 132 self._zero_grad_fn() 134 if self._backward_fn is not None and step_output.closure_loss is not None: --> 135 self._backward_fn(step_output.closure_loss) 137 return step_output

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/loops/optimization/automatic.py:236, in _AutomaticOptimization._make_backward_fn.<locals>.backward_fn(loss) 235 def backward_fn(loss: Tensor) -> None: --> 236 call._call_strategy_hook(self.trainer, "backward", loss, optimizer)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs) 306 return None 308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"): --> 309 output = fn(*args, **kwargs) 311 # restore current_fx when nested context 312 pl_module._current_fx_name = prev_fx_name

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py:204, in Strategy.backward(self, closure_loss, optimizer, *args, **kwargs) 201 assert self.lightning_module is not None 202 closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_module) --> 204 self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs) 206 closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_module) 207 self.post_backward(closure_loss)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/plugins/precision/precision.py:69, in Precision.backward(self, tensor, model, optimizer, *args, **kwargs) 50 def backward( # type: ignore[override] 51 self, 52 tensor: Tensor, (...) 56 **kwargs: Any, 57 ) -> None: 58 r"""Performs the actual backpropagation. 59 60 Args: (...) 67 68 """ ---> 69 model.backward(tensor, *args, **kwargs)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/pytorch_lightning/core/module.py:1078, in LightningModule.backward(self, loss, *args, **kwargs) 1076 self._fabric.backward(loss, *args, **kwargs) 1077 else: -> 1078 loss.backward(*args, **kwargs)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/_tensor.py:525, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs) 515 if has_torch_function_unary(self): 516 return handle_torch_function( 517 Tensor.backward, 518 (self,), (...) 523 inputs=inputs, 524 ) --> 525 torch.autograd.backward( 526 self, gradient, retain_graph, create_graph, inputs=inputs 527 )

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/autograd/__init__.py:267, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs) 262 retain_graph = create_graph 264 # The reason we repeat the same comment below is that 265 # some Python versions print out the first line of a multi-line function 266 # calls in the traceback and some print out the last line --> 267 _engine_run_backward( 268 tensors, 269 grad_tensors, 270 retain_graph, 271 create_graph, 272 inputs, 273 allow_unreachable=True, 274 accumulate_grad=True, 275 )

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/autograd/graph.py:744, in _engine_run_backward(t_outputs, *args, **kwargs) 742 unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) 743 try: --> 744 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass 745 t_outputs, *args, **kwargs 746 ) # Calls into the C++ engine to run the backward pass 747 finally: 748 if attach_logging_hooks:

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/autograd/function.py:301, in BackwardCFunction.apply(self, *args) 295 raise RuntimeError( 296 "Implementing both 'backward' and 'vjp' for a custom " 297 "Function is not allowed. You should only implement one " 298 "of them." 299 ) 300 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn --> 301 return user_fn(self, *args)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:882, in aot_dispatch_autograd.<locals>.CompiledFunction.backward(ctx, *flat_args) 880 out = CompiledFunctionBackward.apply(*all_args) 881 else: --> 882 out = call_compiled_backward() 884 # TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here. 885 if CompiledFunction.maybe_subclass_metadata is not None:

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py:831, in aot_dispatch_autograd.<locals>.CompiledFunction.backward.<locals>.call_compiled_backward() 824 with tracing(saved_context), context(), track_graph_compiling( 825 aot_config, "backward" 826 ): 827 CompiledFunction.compiled_bw = aot_config.bw_compiler( 828 bw_module, placeholder_list 829 ) --> 831 out = call_func_at_runtime_with_args( 832 CompiledFunction.compiled_bw, 833 all_args, 834 steal_args=True, 835 disable_amp=disable_amp, 836 ) 838 out = functionalized_rng_runtime_epilogue( 839 CompiledFunction.metadata, out 840 ) 841 return tuple(out)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py:113, in call_func_at_runtime_with_args(f, args, steal_args, disable_amp) 111 with context(): 112 if hasattr(f, "_boxed_call"): --> 113 out = normalize_as_list(f(args)) 114 else: 115 # TODO: Please remove soon 116 # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 117 warnings.warn( 118 "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. " 119 "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. " 120 "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale." 121 )

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:451, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs) 449 prior = set_eval_frame(callback) 450 try: --> 451 return fn(*args, **kwargs) 452 finally: 453 set_eval_frame(prior)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/_dynamo/external_utils.py:36, in wrap_inline.<locals>.inner(*args, **kwargs) 34 @functools.wraps(fn) 35 def inner(*args, **kwargs): ---> 36 return fn(*args, **kwargs)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/_inductor/codecache.py:906, in CompiledFxGraph.__call__(self, inputs) 905 def __call__(self, inputs: List[Any]) -> Any: --> 906 return self.get_current_callable()(inputs)

File ~/anaconda3/envs/sagess2/lib/python3.11/site-packages/torch/_inductor/codecache.py:934, in _run_from_cache(compiled_graph, inputs) 926 assert compiled_graph.artifact_path 927 compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path( 928 compiled_graph.cache_key, 929 compiled_graph.artifact_path, 930 compiled_graph.cache_linemap, 931 compiled_graph.constants, 932 ).call --> 934 return compiled_graph.compiled_artifact(inputs)

File /var/folders/6g/xr0x199n03z1hfqs24wrzth40000gn/T/torchinductor_tonpoppe/jt/cjtnazhwvyl5db7g4553naqh45jphlrrszvc34p7aki4wtrpb25j.py:336, in call(args) 334 assert_size_stride(sum_1, (16, s0, 1, 8, 16), (128*s0, 128, 128, 16, 1)) 335 assert_size_stride(view, (16*s0, 128), (128, 1)) --> 336 assert_size_stride(unsqueeze, (16, 1, s0, 8, 16), (2560, 0, 128, 16, 1)) 337 assert_size_stride(view_3, (16, s0, 128), (128*s0, 128, 1)) 338 assert_size_stride(addmm_2, (16, 128), (128, 1))