Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Feature for Training Sentence Transformer Model #19230

Open sukkritsharmaofficial opened 9 months ago

sukkritsharmaofficial commented 9 months ago

Description & Motivation

Currently we can't train a sentence transformer model with Lightning; it throws an error. Could support be added for this so that the model can be trained more effectively across all GPUs?

Sentence Transformers library: https://github.com/UKPLab/sentence-transformers/tree/master

Pitch

No response

Alternatives

No response

Additional context

No response

cc @borda

awaelchli commented 9 months ago

@sukkritsharmaofficial Can you provide more details?

Thank you!

sukkritsharmaofficial commented 9 months ago

Hey @awaelchli, I was trying to train a sentence transformer model using Lightning. Here's the model class:

import torch
import torch.nn as nn
import pytorch_lightning as pl
from sentence_transformers import SentenceTransformer

class SentenceSimilarityModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Initialize a sentence transformer model
        self.model = SentenceTransformer('sentence-transformers/stsb-xlm-r-multilingual')
        self.loss_fn = nn.TripletMarginLoss(margin=1.0, p=2)

    def forward(self, input_ids):
        # The model handles tokenization internally
        self.model.train()
        torch.set_grad_enabled(True)
        return self.model.encode(input_ids, convert_to_tensor=True)

    def training_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_emb = self(anchor)
        positive_emb = self(positive)
        negative_emb = self(negative)
        loss = self.loss_fn(anchor_emb, positive_emb, negative_emb)
        self.log('train_loss', loss, batch_size=len(batch))
        return loss

    def validation_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_emb = self(anchor)
        positive_emb = self(positive)
        negative_emb = self(negative)
        loss = self.loss_fn(anchor_emb, positive_emb, negative_emb)
        self.log('val_loss', loss, batch_size=len(batch))

    def test_step(self, batch, batch_idx):
        anchor, positive, negative = batch
        anchor_emb = self(anchor)
        positive_emb = self(positive)
        negative_emb = self(negative)
        loss = self.loss_fn(anchor_emb, positive_emb, negative_emb)
        self.log('test_loss', loss, batch_size=len(batch))

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=3e-6)
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=3e-4,  # Adjust as needed
                steps_per_epoch=len(train_loader), 
                epochs=self.trainer.max_epochs,
                anneal_strategy='linear'  # or 'cos'
            ),
            'interval': 'step',
            'frequency': 1
        }
        return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
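
The loss ends up detached from the graph because, as far as I can tell, SentenceTransformer.encode() is an inference helper that wraps its forward pass in torch.no_grad(), so the embeddings it returns carry no grad_fn even though torch.set_grad_enabled(True) is called first. A minimal check that illustrates this (hypothetical snippet, using the same checkpoint as above):

import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('sentence-transformers/stsb-xlm-r-multilingual')
model.train()
torch.set_grad_enabled(True)

# encode() runs under torch.no_grad() internally, so the returned
# embeddings are detached and cannot be backpropagated through.
emb = model.encode(['a test sentence'], convert_to_tensor=True)
print(emb.requires_grad)  # False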

The error I'm getting, with the full stack trace:

Cell In[14], line 1
----> 1 trainer.fit(model, train_loader, val_loader)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42     if trainer.strategy.launcher is not None:
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
    984 self._signal_connector.register_signal_handlers()
    986 # ----------------------------
    987 # RUN THE TRAINER
    988 # ----------------------------
--> 989 results = self._run_stage()
    991 # ----------------------------
    992 # POST-Training CLEAN UP
    993 # ----------------------------
    994 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1035, in Trainer._run_stage(self)
   1033         self._run_sanity_check()
   1034     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035         self.fit_loop.run()
   1036     return None
   1037 raise RuntimeError(f"Unexpected state {self.state}")

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:202, in _FitLoop.run(self)
    200 try:
    201     self.on_advance_start()
--> 202     self.advance()
    203     self.on_advance_end()
    204     self._restarting = False

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:359, in _FitLoop.advance(self)
    357 with self.trainer.profiler.profile("run_training_epoch"):
    358     assert self._data_fetcher is not None
--> 359     self.epoch_loop.run(self._data_fetcher)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher)
    134 while not self.done:
    135     try:
--> 136         self.advance(data_fetcher)
    137         self.on_advance_end(data_fetcher)
    138         self._restarting = False

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher)
    237 with trainer.profiler.profile("run_training_batch"):
    238     if trainer.lightning_module.automatic_optimization:
    239         # in automatic optimization, there can only be one optimizer
--> 240         batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
    241     else:
    242         batch_output = self.manual_optimization.run(kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
    180         closure()
    182 # ------------------------------
    183 # BACKWARD PASS
    184 # ------------------------------
    185 # gradient update with accumulated gradients
    186 else:
--> 187     self._optimizer_step(batch_idx, closure)
    189 result = closure.consume_result()
    190 if result.loss is None:

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    262     self.optim_progress.optimizer.step.increment_ready()
    264 # model hook
--> 265 call._call_lightning_module_hook(
    266     trainer,
    267     "optimizer_step",
    268     trainer.current_epoch,
    269     batch_idx,
    270     optimizer,
    271     train_step_and_backward_closure,
    272 )
    274 if not should_accumulate:
    275     self.optim_progress.optimizer.step.increment_completed()

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    154 pl_module._current_fx_name = hook_name
    156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157     output = fn(*args, **kwargs)
    159 # restore current_fx when nested context
    160 pl_module._current_fx_name = prev_fx_name

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py:1291, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1252 def optimizer_step(
   1253     self,
   1254     epoch: int,
   (...)
   1257     optimizer_closure: Optional[Callable[[], Any]] = None,
   1258 ) -> None:
   1259     r"""Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
   1260     the optimizer.
   1261 
   (...)
   1289 
   1290     """
-> 1291     optimizer.step(closure=optimizer_closure)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs)
    148     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    150 assert self._strategy is not None
--> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    153 self._on_after_step()
    155 return step_output

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    229 assert isinstance(model, pl.LightningModule)
--> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/amp.py:74, in MixedPrecision.optimizer_step(self, optimizer, model, closure, **kwargs)
     65 def optimizer_step(  # type: ignore[override]
     66     self,
     67     optimizer: Optimizable,
   (...)
     70     **kwargs: Any,
     71 ) -> Any:
     72     if self.scaler is None:
     73         # skip scaler logic, as bfloat16 does not require scaler
---> 74         return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
     75     if isinstance(optimizer, LBFGS):
     76         raise MisconfigurationException("AMP and the LBFGS optimizer are not compatible.")

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py:117, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
    115 """Hook to run the optimizer step."""
    116 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 117 return optimizer.step(closure=closure, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:68, in LRScheduler.__init__.<locals>.with_counter.<locals>.wrapper(*args, **kwargs)
     66 instance._step_count += 1
     67 wrapped = func.__get__(instance, cls)
---> 68 return wrapped(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py:373, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    368         else:
    369             raise RuntimeError(
    370                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    371             )
--> 373 out = func(*args, **kwargs)
    374 self._optimizer_step_code()
    376 # call optimizer step post hooks

File /opt/conda/lib/python3.10/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     74     torch.set_grad_enabled(self.defaults['differentiable'])
     75     torch._dynamo.graph_break()
---> 76     ret = func(self, *args, **kwargs)
     77 finally:
     78     torch._dynamo.graph_break()

File /opt/conda/lib/python3.10/site-packages/torch/optim/adamw.py:161, in AdamW.step(self, closure)
    159 if closure is not None:
    160     with torch.enable_grad():
--> 161         loss = closure()
    163 for group in self.param_groups:
    164     params_with_grad = []

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py:104, in Precision._wrap_closure(self, model, optimizer, closure)
     91 def _wrap_closure(
     92     self,
     93     model: "pl.LightningModule",
     94     optimizer: Optimizer,
     95     closure: Callable[[], Any],
     96 ) -> Any:
     97     """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
     98     hook is called.
     99 
   (...)
    102 
    103     """
--> 104     closure_result = closure()
    105     self._after_closure(model, optimizer)
    106     return closure_result

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
    139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140     self._result = self.closure(*args, **kwargs)
    141     return self._result.loss

File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:135, in Closure.closure(self, *args, **kwargs)
    132     self._zero_grad_fn()
    134 if self._backward_fn is not None and step_output.closure_loss is not None:
--> 135     self._backward_fn(step_output.closure_loss)
    137 return step_output

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:236, in _AutomaticOptimization._make_backward_fn.<locals>.backward_fn(loss)
    235 def backward_fn(loss: Tensor) -> None:
--> 236     call._call_strategy_hook(self.trainer, "backward", loss, optimizer)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    306     return None
    308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309     output = fn(*args, **kwargs)
    311 # restore current_fx when nested context
    312 pl_module._current_fx_name = prev_fx_name

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:204, in Strategy.backward(self, closure_loss, optimizer, *args, **kwargs)
    201 assert self.lightning_module is not None
    202 closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_module)
--> 204 self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
    206 closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_module)
    207 self.post_backward(closure_loss)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/precision.py:69, in Precision.backward(self, tensor, model, optimizer, *args, **kwargs)
     50 def backward(  # type: ignore[override]
     51     self,
     52     tensor: Tensor,
   (...)
     56     **kwargs: Any,
     57 ) -> None:
     58     r"""Performs the actual backpropagation.
     59 
     60     Args:
   (...)
     67 
     68     """
---> 69     model.backward(tensor, *args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/core/module.py:1078, in LightningModule.backward(self, loss, *args, **kwargs)
   1076     self._fabric.backward(loss, *args, **kwargs)
   1077 else:
-> 1078     loss.backward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )

File /opt/conda/lib/python3.10/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

How do I fix this?
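
A possible workaround (only a sketch, not verified): avoid encode() during training and call the SentenceTransformer module directly on tokenized features, which should keep the autograd graph intact. This assumes the batch elements are lists of raw strings (as in the batches above), and that SentenceTransformer.tokenize() plus a forward pass returning a 'sentence_embedding' entry are available, as in recent sentence-transformers releases:

    def forward(self, sentences):
        # Tokenize on the fly and run the full SentenceTransformer forward
        # pass; unlike encode(), this keeps gradients flowing to the loss.
        features = self.model.tokenize(sentences)
        features = {k: v.to(self.device) for k, v in features.items()}
        return self.model(features)['sentence_embedding']

With forward written this way, the rest of the LightningModule above should work unchanged, since the triplet loss then receives embeddings that require grad.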

sukkritsharmaofficial commented 9 months ago

@awaelchli, anything on this?