[Open] sukkritsharmaofficial opened this issue 9 months ago
@sukkritsharmaofficial Can you provide more details?
Thank you!
Hey @awaelchli, so I was trying to train a sentence transformer model using Lightning. Here's the model class:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from sentence_transformers import SentenceTransformer


class SentenceSimilarityModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Initialize a sentence transformer model
        self.model = SentenceTransformer('sentence-transformers/stsb-xlm-r-multilingual')
        self.loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)

    def forward(self, input_ids):
        # The model handles tokenization internally
        self.model.train()
        torch.set_grad_enabled(True)
        return self.model.encode(input_ids, convert_to_tensor=True)

    def training_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_emb = self(anchor)
        positive_emb = self(positive)
        negative_emb = self(negative)
        loss = self.loss_fn(anchor_emb, positive_emb, negative_emb)
        self.log('train_loss', loss, batch_size=len(batch))
        return loss

    def validation_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_emb = self(anchor)
        positive_emb = self(positive)
        negative_emb = self(negative)
        loss = self.loss_fn(anchor_emb, positive_emb, negative_emb)
        self.log('val_loss', loss, batch_size=len(batch))

    def test_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_emb = self(anchor)
        positive_emb = self(positive)
        negative_emb = self(negative)
        loss = self.loss_fn(anchor_emb, positive_emb, negative_emb)
        self.log('test_loss', loss, batch_size=len(batch))

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=3e-6)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=3e-4,  # Adjust as needed
                steps_per_epoch=len(train_loader),
                epochs=self.trainer.max_epochs,
                anneal_strategy='linear'  # or 'cos'
            ),
            'interval': 'step',
            'frequency': 1
        }
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
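For reference, the module is driven roughly like this. The triplet data and Trainer arguments below are only a placeholder sketch, not the real setup; the bf16 precision flag just mirrors the mixed-precision path visible in the stack trace.

# Minimal sketch of how the module is used; dummy triplets stand in for the real dataset.
from torch.utils.data import DataLoader

triplets = [
    ("how do I reset my password", "password reset steps", "store opening hours"),
    # ... more (anchor, positive, negative) string triplets
]

# Default collation turns a batch of string triplets into three lists of strings,
# matching the `anchor, positive, negative = batch` unpacking in the steps above.
train_loader = DataLoader(triplets, batch_size=16, shuffle=True)
val_loader = DataLoader(triplets, batch_size=16)

model = SentenceSimilarityModel()
trainer = pl.Trainer(max_epochs=3, precision="bf16-mixed")
trainer.fit(model, train_loader, val_loader)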
The error I'm getting, full stack trace:
Cell In[14], line 1
----> 1 trainer.fit(model, train_loader, val_loader)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
542 self.state.status = TrainerStatus.RUNNING
543 self.training = True
--> 544 call._call_and_handle_interrupt(
545 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
546 )
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
42 if trainer.strategy.launcher is not None:
43 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44 return trainer_fn(*args, **kwargs)
46 except _TunerExitException:
47 _call_teardown_hook(trainer)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
573 assert self.state.fn is not None
574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
575 self.state.fn,
576 ckpt_path,
577 model_provided=True,
578 model_connected=self.lightning_module is not None,
579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
582 assert self.state.stopped
583 self.training = False
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
984 self._signal_connector.register_signal_handlers()
986 # ----------------------------
987 # RUN THE TRAINER
988 # ----------------------------
--> 989 results = self._run_stage()
991 # ----------------------------
992 # POST-Training CLEAN UP
993 # ----------------------------
994 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1035, in Trainer._run_stage(self)
1033 self._run_sanity_check()
1034 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035 self.fit_loop.run()
1036 return None
1037 raise RuntimeError(f"Unexpected state {self.state}")
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:202, in _FitLoop.run(self)
200 try:
201 self.on_advance_start()
--> 202 self.advance()
203 self.on_advance_end()
204 self._restarting = False
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:359, in _FitLoop.advance(self)
357 with self.trainer.profiler.profile("run_training_epoch"):
358 assert self._data_fetcher is not None
--> 359 self.epoch_loop.run(self._data_fetcher)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher)
134 while not self.done:
135 try:
--> 136 self.advance(data_fetcher)
137 self.on_advance_end(data_fetcher)
138 self._restarting = False
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher)
237 with trainer.profiler.profile("run_training_batch"):
238 if trainer.lightning_module.automatic_optimization:
239 # in automatic optimization, there can only be one optimizer
--> 240 batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
241 else:
242 batch_output = self.manual_optimization.run(kwargs)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
180 closure()
182 # ------------------------------
183 # BACKWARD PASS
184 # ------------------------------
185 # gradient update with accumulated gradients
186 else:
--> 187 self._optimizer_step(batch_idx, closure)
189 result = closure.consume_result()
190 if result.loss is None:
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
262 self.optim_progress.optimizer.step.increment_ready()
264 # model hook
--> 265 call._call_lightning_module_hook(
266 trainer,
267 "optimizer_step",
268 trainer.current_epoch,
269 batch_idx,
270 optimizer,
271 train_step_and_backward_closure,
272 )
274 if not should_accumulate:
275 self.optim_progress.optimizer.step.increment_completed()
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
154 pl_module._current_fx_name = hook_name
156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157 output = fn(*args, **kwargs)
159 # restore current_fx when nested context
160 pl_module._current_fx_name = prev_fx_name
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py:1291, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
1252 def optimizer_step(
1253 self,
1254 epoch: int,
(...)
1257 optimizer_closure: Optional[Callable[[], Any]] = None,
1258 ) -> None:
1259 r"""Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
1260 the optimizer.
1261
(...)
1289
1290 """
-> 1291 optimizer.step(closure=optimizer_closure)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs)
148 raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
150 assert self._strategy is not None
--> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
153 self._on_after_step()
155 return step_output
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
229 assert isinstance(model, pl.LightningModule)
--> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/amp.py:74, in MixedPrecision.optimizer_step(self, optimizer, model, closure, **kwargs)
65 def optimizer_step( # type: ignore[override]
66 self,
67 optimizer: Optimizable,
(...)
70 **kwargs: Any,
71 ) -> Any:
72 if self.scaler is None:
73 # skip scaler logic, as bfloat16 does not require scaler
---> 74 return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
75 if isinstance(optimizer, LBFGS):
76 raise MisconfigurationException("AMP and the LBFGS optimizer are not compatible.")
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py:117, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
115 """Hook to run the optimizer step."""
116 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 117 return optimizer.step(closure=closure, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:68, in LRScheduler.__init__.<locals>.with_counter.<locals>.wrapper(*args, **kwargs)
66 instance._step_count += 1
67 wrapped = func.__get__(instance, cls)
---> 68 return wrapped(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py:373, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
368 else:
369 raise RuntimeError(
370 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
371 )
--> 373 out = func(*args, **kwargs)
374 self._optimizer_step_code()
376 # call optimizer step post hooks
File /opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
74 torch.set_grad_enabled(self.defaults['differentiable'])
75 torch._dynamo.graph_break()
---> 76 ret = func(self, *args, **kwargs)
77 finally:
78 torch._dynamo.graph_break()
File /opt/conda/lib/python3.10/site-packages/torch/optim/adamw.py:161, in AdamW.step(self, closure)
159 if closure is not None:
160 with torch.enable_grad():
--> 161 loss = closure()
163 for group in self.param_groups:
164 params_with_grad = []
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py:104, in Precision._wrap_closure(self, model, optimizer, closure)
91 def _wrap_closure(
92 self,
93 model: "pl.LightningModule",
94 optimizer: Optimizer,
95 closure: Callable[[], Any],
96 ) -> Any:
97 """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
98 hook is called.
99
(...)
102
103 """
--> 104 closure_result = closure()
105 self._after_closure(model, optimizer)
106 return closure_result
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140 self._result = self.closure(*args, **kwargs)
141 return self._result.loss
File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:135, in Closure.closure(self, *args, **kwargs)
132 self._zero_grad_fn()
134 if self._backward_fn is not None and step_output.closure_loss is not None:
--> 135 self._backward_fn(step_output.closure_loss)
137 return step_output
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:236, in _AutomaticOptimization._make_backward_fn.<locals>.backward_fn(loss)
235 def backward_fn(loss: Tensor) -> None:
--> 236 call._call_strategy_hook(self.trainer, "backward", loss, optimizer)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
306 return None
308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309 output = fn(*args, **kwargs)
311 # restore current_fx when nested context
312 pl_module._current_fx_name = prev_fx_name
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:204, in Strategy.backward(self, closure_loss, optimizer, *args, **kwargs)
201 assert self.lightning_module is not None
202 closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_module)
--> 204 self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
206 closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_module)
207 self.post_backward(closure_loss)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py:69, in Precision.backward(self, tensor, model, optimizer, *args, **kwargs)
50 def backward( # type: ignore[override]
51 self,
52 tensor: Tensor,
(...)
56 **kwargs: Any,
57 ) -> None:
58 r"""Performs the actual backpropagation.
59
60 Args:
(...)
67
68 """
---> 69 model.backward(tensor, *args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py:1078, in LightningModule.backward(self, loss, *args, **kwargs)
1076 self._fabric.backward(loss, *args, **kwargs)
1077 else:
-> 1078 loss.backward(*args, **kwargs)
File /opt/conda/lib/python3.10/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
482 if has_torch_function_unary(self):
483 return handle_torch_function(
484 Tensor.backward,
485 (self,),
(...)
490 inputs=inputs,
491 )
--> 492 torch.autograd.backward(
493 self, gradient, retain_graph, create_graph, inputs=inputs
494 )
File /opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
246 retain_graph = create_graph
248 # The reason we repeat the same comment below is that
249 # some Python versions print out the first line of a multi-line function
250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
252 tensors,
253 grad_tensors_,
254 retain_graph,
255 create_graph,
256 inputs,
257 allow_unreachable=True,
258 accumulate_grad=True,
259 )
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
How do I fix this?
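For anyone hitting the same error: a likely cause is that SentenceTransformer.encode() is an inference helper; it puts the model in eval mode and computes embeddings without gradient tracking, so the tensors it returns have no grad_fn and loss.backward() fails. Below is a sketch of a workaround that bypasses encode() and runs the model's forward pass on tokenized features instead. The tokenize() call and the 'sentence_embedding' output key follow the usual sentence-transformers module API, but treat the exact calls as an assumption to verify against the installed version.

# Sketch of a gradient-friendly forward pass for the LightningModule above,
# assuming the standard sentence-transformers API (model.tokenize() produces
# the input features; calling the model returns a dict with 'sentence_embedding').
def forward(self, sentences):
    # Tokenize on the fly and move the resulting tensors to the module's device.
    features = self.model.tokenize(sentences)
    features = {k: v.to(self.device) for k, v in features.items()}
    # Calling the SentenceTransformer directly (not encode()) keeps the computation
    # in the autograd graph, so the triplet loss can backpropagate.
    out = self.model(features)
    return out['sentence_embedding']

With a forward pass like this, the manual self.model.train() and torch.set_grad_enabled(True) calls should be unnecessary, since Lightning already puts the module in train mode and enables gradients for training_step.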
@awaelchli anything on this?
Description & Motivation
Currently we can't train a sentence transformer model; it throws an error. Can we have support for this as well, so that we can train the model more effectively and use all GPUs?
Sentence Transformer Library : https://github.com/UKPLab/sentence-transformers/tree/master
Pitch
No response
Alternatives
No response
Additional context
No response
cc @borda