KindXiaoming / pykan

Kolmogorov Arnold Networks
MIT License

CUDA training #84

Open Maunberg opened 4 months ago

Maunberg commented 4 months ago

I tried to use KANLayer with CUDA, but I get an error:

```python
test = KANLayer(2, 2, device='cuda')
test(torch.tensor([[0.2, .32]]).cuda())
```

which fails with:


```
RuntimeError                              Traceback (most recent call last)
Cell In[47], line 1
----> 1 test(torch.tensor([[0.2, .32]]).cuda())

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
-> 1518 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
-> 1527 return forward_call(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/kan/KANLayer.py:178, in KANLayer.forward(self, x)
    176 y = y.permute(1,0)  # shape (batch, size)
    177 postspline = y.clone().reshape(batch, self.out_dim, self.in_dim)
--> 178 y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * y
    179 y = self.mask[None,:] * y
    180 postacts = y.clone().reshape(batch, self.out_dim, self.in_dim)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
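
The message points at a device mismatch inside the layer itself. A quick check (a sketch; it only assumes KANLayer is a regular `torch.nn.Module`, and the import path is taken from the traceback) shows which registered tensors actually ended up on the GPU:

```python
import torch
from kan.KANLayer import KANLayer  # import path as in the traceback; may differ by version

test = KANLayer(2, 2, device='cuda')

# Print the device each registered parameter/buffer actually lives on.
# If the device= argument is not fully honored, some of these still report cpu.
for name, t in list(test.named_parameters()) + list(test.named_buffers()):
    print(f"{name:15s} {tuple(t.shape)} -> {t.device}")
```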

It happens here, in `KANLayer.forward` (the `!!` comment marks the problem):

```python
def forward(self, x):
    batch = x.shape[0]
    # x: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim)
    x = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0)
    preacts = x.permute(1, 0).clone().reshape(batch, self.out_dim, self.in_dim)
    base = self.base_fun(x).permute(1, 0)  # shape (batch, size)
    # !! self.coef (indexed below) is not on self.device !!
    y = coef2curve(x_eval=x, grid=self.grid[self.weight_sharing], coef=self.coef[self.weight_sharing], k=self.k, device=self.device)  # shape (size, batch)
    y = y.permute(1, 0)  # shape (batch, size)
    postspline = y.clone().reshape(batch, self.out_dim, self.in_dim)
    y = self.scale_base.unsqueeze(dim=0) * base + self.scale_sp.unsqueeze(dim=0) * y
    y = self.mask[None, :] * y
    postacts = y.clone().reshape(batch, self.out_dim, self.in_dim)
    y = torch.sum(y.reshape(batch, self.out_dim, self.in_dim), dim=2)
```
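
Independently of patching the library source, a possible local workaround (just a sketch, assuming the tensors above are registered as parameters/buffers of the module so that `nn.Module.to()` moves them) is to move the whole layer after construction:

```python
import torch
from kan.KANLayer import KANLayer

# Workaround sketch (not the upstream fix): let nn.Module.to() move every
# registered tensor (coef, scale_base, scale_sp, mask, grid, ...) to the GPU
# instead of relying only on the device= constructor argument.
layer = KANLayer(2, 2, device='cuda').to('cuda')
x = torch.tensor([[0.2, 0.32]], device='cuda')
out = layer(x)  # should no longer mix cuda:0 and cpu tensors if everything moved
```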

After fixing that locally, CUDA becomes usable for the forward pass. But the optimizer then fails with essentially the same mistake:


```
RuntimeError                              Traceback (most recent call last)
Cell In[44], line 1
----> 1 trainer.fit(model, train_loader)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
---> 44 return trainer_fn(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
--> 580 self._run(model, ckpt_path=ckpt_path)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:987, in Trainer._run(self, model, ckpt_path)
--> 987 results = self._run_stage()

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1033, in Trainer._run_stage(self)
-> 1033 self.fit_loop.run()

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
--> 205 self.advance()

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
--> 363 self.epoch_loop.run(self._data_fetcher)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140, in _TrainingEpochLoop.run(self, data_fetcher)
--> 140 self.advance(data_fetcher)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250, in _TrainingEpochLoop.advance(self, data_fetcher)
--> 250 batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:190, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
--> 190 self._optimizer_step(batch_idx, closure)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:268, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
--> 268 call._call_lightning_module_hook(
    269     trainer,
    270     "optimizer_step",
    271     trainer.current_epoch,
    272     batch_idx,
    273     optimizer,
    274     train_step_and_backward_closure,
    275 )

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
--> 157 output = fn(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py:1303, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
-> 1303 optimizer.step(closure=optimizer_closure)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py:152, in LightningOptimizer.step(self, closure, **kwargs)
--> 152 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:239, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
--> 239 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py:122, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
--> 122 return optimizer.step(closure=closure, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py:373, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
--> 373 out = func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
---> 76 ret = func(self, *args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/optim/adamw.py:184, in AdamW.step(self, closure)
--> 184 adamw(
    185     params_with_grad,
   (...)
    204 )

File /opt/conda/lib/python3.10/site-packages/torch/optim/adamw.py:335, in adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, foreach, capturable, differentiable, fused, grad_scale, found_inf, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize)
--> 335 func(
    336     params,
   (...)
    353 )

File /opt/conda/lib/python3.10/site-packages/torch/optim/adamw.py:509, in _multi_tensor_adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, grad_scale, found_inf, amsgrad, beta1, beta2, lr, weight_decay, eps, maximize, capturable, differentiable)
--> 509 grouped_tensors = Optimizer._group_tensors_by_device_and_dtype([
    510     params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])

File /opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py:397, in Optimizer._group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
--> 397 return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)

File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
--> 115 return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/utils/_foreach_utils.py:42, in _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)
---> 42 torch._C._group_tensors_by_device_and_dtype(tensorlistlist, with_indices).items()

RuntimeError: Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32 notwithstanding
```

Unfortunately, here I don't know how to track down the mistake. How can I solve this?
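
One way to narrow down which tensors are responsible (a sketch, not pykan-specific; `model` stands for the LightningModule passed to `trainer.fit`):

```python
import torch

def check_single_device(model: torch.nn.Module) -> None:
    """List parameter devices/dtypes; AdamW's multi-tensor path expects params,
    grads and optimizer state of the same index to share device and dtype."""
    devices = {p.device for p in model.parameters()}
    dtypes = {p.dtype for p in model.parameters()}
    print("parameter devices:", devices)
    print("parameter dtypes: ", dtypes)
    if len(devices) > 1 or len(dtypes) > 1:
        for name, p in model.named_parameters():
            print(f"  {name}: {p.device}, {p.dtype}")

# e.g. call check_single_device(model) right before trainer.fit(model, train_loader)
```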

AlessandroFlati commented 4 months ago

This issue is solved by https://github.com/KindXiaoming/pykan/pull/98