Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.36k stars 3.38k forks source link

Precision 32 disabling grad #18062

Closed JBaum2000 closed 1 year ago

JBaum2000 commented 1 year ago

Bug description

With precision=32 in the Trainer, a RuntimeError occurs after the last training_step before optimization. The runtime error appears to be associated with torch.is_grad_enabled() being set to False in the last iteration of the training_step. With precision=16 the error does not occur and torch.is_grad_enabled() prints True on each iteration.

The error seems similar to #17949 .

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import lightning.pytorch as pl
import pandas as pd
from functools import partial
import faiss
from datasets import Features, Value, Sequence
from lightning.pytorch.trainer import Trainer
from lightning.pytorch.core import LightningModule
from lightning.pytorch.utilities import rank_zero_only
from transformers import RagSequenceForGeneration, RagConfig, RagRetriever, RagTokenizer, DPRConfig, BatchEncoding, AutoModel, AdamW, DPRContextEncoderTokenizerFast, DPRContextEncoder
from transformers.models.rag.retrieval_rag import CustomHFIndex
from transformers.optimization import get_polynomial_decay_schedule_with_warmup
from collections import defaultdict
from typing import Dict, Tuple, List, Any
import numpy as np
import torch
import json
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import linecache
import re
import string
import time
import copy

class AttrDict(dict):
  def __init__(self, *args, **kwargs):
    super(AttrDict, self).__init__(*args, **kwargs)
    self.__dict__ = self

class RAGEnd2EndTransformer(LightningModule):
  def __init__(self, label_smoothing: float = 0.1, weight_decay: float = 0.001, warmup_steps: int = 500, batch_size: int = 1, num_epochs: int = 3, adam_epsilon: float = 1e-8, indexing_freq: int = 500, seed: int = 42, num_workers: int = 4,  learning_rate: float = 3e-05):
    self.model_class = RagSequenceForGeneration
    config = RagConfig.from_pretrained("facebook/rag-sequence-nq")
    config.index_name = "custom"
    config.index_path = kb_faiss_path
    config.passages_path = kb_dataset_path
    config.use_dummy_dataset = False
    config.generator.encoder_layerdrop
    config.generator.decoder_layerdrop
    config.generator.attention_dropout
    config.generator.dropout
    config.label_smoothing = label_smoothing
    retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", config=config)
    ctx_encoder_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
    retriever.set_ctx_encoder_tokenizer(ctx_encoder_tokenizer)
    model = self.model_class.from_pretrained("facebook/rag-sequence-nq", config=config, retriever=retriever)
    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
    model.set_context_encoder_for_training(ctx_encoder)
    prefix = config.question_encoder.prefix
    self.tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    self.retriever = retriever
    super().__init__()
    self.config_dpr = DPRConfig.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
    self.custom_config = AttrDict({})
    self.custom_config.end2end = True
    self.custom_config.indexing_freq = indexing_freq
    # cc_data.xlsx consists of 100 rows, col 1 -> title, col 2 -> text
    self.custom_config.csv_path = "/content/cc_data.xlsx"
    self.context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
    self.step_count = 0
    self.metrics = defaultdict(list)
    self.dataset_kwargs: dict = {"data_dir": "/content/model_checkpoints/data_dir/", "max_source_length": 128, "prefix": prefix or "",}
    n_observations_per_split = {"train": -1, "val": -1, "test": -1,}
    self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
    self.target_lens = {"train": 25, "val": 25, "test": 25,}
    self.num_workers = num_workers
    self.model = model
    self.model.retriever.init_retrieval()
    self.distributed_retriever = "pytorch"
    self.metric_names = ["em"]
    self.train_batch_size = batch_size
    self.eval_batch_size = batch_size
    self.seed = seed
    self.accumulate_grad_batches = 3
    self.model_type = AutoModel
    self.warmup_steps = warmup_steps
    self.learning_rate = learning_rate
    self.weight_decay = weight_decay
    self.num_train_epochs = num_epochs
    self.adam_epsilon = adam_epsilon
    self.max_epochs = num_epochs
    self.loss_names = ["loss"]
    self.val_metric = "em"
    self.mode = "generative_qa"
    self.metric_names = ["em"]
    self.validation_step_outputs = []
    self.test_step_outputs = []

  def forward(self, input_ids, **kwargs):
    return self.model(input_ids, **kwargs)

  def configure_optimizers(self):
    model = self.model
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [{"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],"weight_decay": self.weight_decay,}, {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],"weight_decay": 0.0,},]
    optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon)
    self.opt = optimizer
    scheduler = self.get_lr_scheduler()
    return [optimizer], [scheduler]

  def setup(self, stage):
    if stage == "test":
      self.dataset_size = len(self.test_dataloader().dataset)
    else:
      self.train_loader = self.get_dataloader("train", self.train_batch_size, shuffle=True)
      self.dataset_size = len(self.train_dataloader().dataset)

  def total_steps(self) -> int:
    effective_batch_size = self.train_batch_size * self.accumulate_grad_batches
    return (self.dataset_size / effective_batch_size) * self.max_epochs

  def get_lr_scheduler(self):
    scheduler = get_polynomial_decay_schedule_with_warmup(self.opt, self.warmup_steps, num_training_steps=self.total_steps())
    scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
    return scheduler

  def ids_to_clean_text(self, generated_ids: List[int]):
    gen_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return list(map(str.strip, gen_text))

  def _step(self, batch: dict) -> Tuple:
    source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
    rag_kwargs = {}
    decoder_input_ids = target_ids
    lm_labels = decoder_input_ids
    rag_kwargs["reduce_loss"] = True
    outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False, labels=lm_labels, **rag_kwargs,)
    loss = outputs["loss"]
    print("loss requires_grad:", loss.requires_grad)
    return (loss,)

  def training_step(self, batch, batch_idx) -> Dict:
    print("Grad enabled (training_step):", torch.is_grad_enabled())
    print("model training (training_step):", self.model.training)
    model_copy = type(self.model.rag.ctx_encoder)(self.config_dpr)
    model_copy.load_state_dict(self.model.rag.ctx_encoder.state_dict())
    device = torch.device("cuda")
    df = pd.read_excel("/content/cc_data.xlsx", usecols=["title", "plaintext"])
    df.rename(columns={"plaintext":"text"}, inplace=True)
    titles = df["title"].tolist()
    articles = df["text"].tolist()
    p_titles = []
    passages = []
    for i in range(len(titles)):
      title = titles[i].lower()
      article = articles[i].lower()
      words = article.split()
      for i in range(0, len(words), 100):
        chunk = " ".join(words[i:i+100])
        p_titles.append(title)
        passages.append(chunk)
    chunked_corpus = {"title": p_titles, "text": passages}
    from datasets import Dataset
    kb_dataset = Dataset.from_dict(chunked_corpus)
    context_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
    ctx_encoder = copy.deepcopy(model_copy)
    ctx_encoder = ctx_encoder.to(device=device)
    features = Features({"title": Value("string"), "text": Value("string"), "embeddings": Sequence(Value("float32"))})
    def embed(documents: dict, ctx_encoder: DPRContextEncoder, ctx_tokenizer: DPRContextEncoderTokenizerFast, device):
      input_ids = ctx_tokenizer(documents["title"], documents["text"], truncation=True, padding="longest", return_tensors="pt",)["input_ids"]
      embeddings = ctx_encoder(input_ids.to(device=device), return_dict=True).pooler_output
      return {"embeddings": embeddings.detach().cpu().numpy()}

    dataset = kb_dataset.map(partial(embed, ctx_encoder=ctx_encoder, ctx_tokenizer=context_tokenizer, device=device), batched=True, batch_size=self.train_batch_size, features=features,)
    dataset.save_to_disk(kb_dataset_path)
    index = faiss.IndexHNSWFlat(768, 128, faiss.METRIC_INNER_PRODUCT)
    dataset.add_faiss_index("embeddings", custom_index=index)
    dataset.get_index("embeddings").save(kb_faiss_path)
    self.trainer.model.model.rag.retriever.re_load()
    self.trainer.model.model.rag.retriever.init_retrieval()
    self.trainer.strategy.barrier("barrier")
    loss_tensors = self._step(batch)
    print("loss_tensors requires gradients: ", loss_tensors[0].requires_grad)
    logs = dict(zip(["loss"], loss_tensors))
    return loss_tensors[0]

  def normalize_answer(self, s):
    def remove_articles(text):
      return re.sub(r"\b(a|an|the)\b", " ", text)
    def white_space_fix(text):
      return " ".join(text.split())
    def remove_punc(text):
      exclude = set(string.punctuation)
      return "".join(ch for ch in text if ch not in exclude)
    def lower(text):
      return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

  def exact_match_score(self, prediction, ground_truth):
    return self.normalize_answer(prediction) == self.normalize_answer(ground_truth)

  def calculate_exact_match(self, output_lns: List[str], reference_lns: List[str]) -> Dict:
    em = 0
    for hypo, pred in zip(output_lns, reference_lns):
      em += self.exact_match_score(hypo, pred)
    if len(output_lns) > 0:
      em /= len(output_lns)
    return {"em": em}

  def calc_generative_metrics(self, preds, target) -> Dict:
    return self.calculate_exact_match(preds, target)

  def validation_step(self, batch, batch_idx) -> Dict:
    output = self._generative_step(batch)
    self.validation_step_outputs.append(output)
    return output

  def _generative_step(self, batch: dict) -> dict:
    start_time = time.time()
    batch = BatchEncoding(batch).to(device=self.model.device)
    generated_ids = self.model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], do_deduplication=False, use_cache=True, min_length=1, max_length=self.target_lens["val"],)
    print('generated_ids requires grad:', generated_ids.requires_grad)
    preds: List[str] = self.ids_to_clean_text(generated_ids)
    target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"])
    loss_tensors = self._step(batch)
    print("loss_tensors requires_grad (_generative_step):", loss_tensors[0].requires_grad)
    base_metrics = dict(zip(["loss"], loss_tensors))
    gen_metrics: Dict = self.calc_generative_metrics(preds, target)
    summ_len = np.mean(list(map(len, generated_ids)))
    gen_time = (time.time() - start_time) / batch["input_ids"].shape[0]
    base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **gen_metrics)
    return base_metrics

  def on_validation_epoch_end(self):
    prefix="val"
    outputs = self.validation_step_outputs
    self.step_count += 1
    losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
    loss = losses["loss"]
    gen_metrics = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
    metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
    gen_metrics.update({k: v.item() for k, v in losses.items()})
    losses.update(gen_metrics)
    metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
    metrics["step_count"] = self.step_count
    log_dict = {f"{prefix}_avg_em": metrics[f"{prefix}_avg_em"], "step_count": metrics["step_count"], f"{prefix}_avg_loss": metrics[f"{prefix}_avg_loss"], f"{prefix}_loss": loss, f"{prefix}_em": metrics_tensor,}
    self.log_dict(log_dict)
    self.validation_step_outputs.clear()

  def test_step(self, batch, batch_idx):
    output = self._generative_step(batch)
    self.test_step_outputs.append(output)
    return output

  def on_test_epoch_end(self):
    prefix="test"
    outputs = self.test_step_outputs
    self.step_count += 1
    losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
    loss = losses["loss"]
    print("here are the outputs:", outputs)
    gen_metrics = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]}
    metrics_tensor: torch.FloatTensor = torch.tensor(gen_metrics[self.val_metric]).type_as(loss)
    gen_metrics.update({k: v.item() for k, v in losses.items()})
    losses.update(gen_metrics)
    metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()}
    metrics["step_count"] = self.step_count
    self.save_metrics(metrics, prefix)
    log_dict = {f"{prefix}_avg_em": metrics[f"{prefix}_avg_em"], "step_count": metrics["step_count"], f"{prefix}_avg_loss": metrics[f"{prefix}_avg_loss"], f"{prefix}_loss": loss, f"{prefix}_em": metrics_tensor,}
    self.log_dict(log_dict)
    self.test_step_outputs.clear()

  def get_dataset(self, type_path) -> Seq2SeqDataset:
    n_obs = self.n_obs[type_path]
    max_target_length = self.target_lens[type_path]
    dataset = Seq2SeqDataset(self.tokenizer, type_path=type_path, n_obs=n_obs, max_target_length=max_target_length, **self.dataset_kwargs,)
    return dataset

  def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader:
    dataset = self.get_dataset(type_path)
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=shuffle, num_workers=self.num_workers,)
    return dataloader

  def train_dataloader(self) -> DataLoader:
    dataloader = self.get_dataloader("train", batch_size=self.train_batch_size, shuffle=True)
    return dataloader

  def val_dataloader(self) -> DataLoader:
    return self.get_dataloader("val", batch_size=self.eval_batch_size)

  def test_dataloader(self) -> DataLoader:
    return self.get_dataloader("test", batch_size=self.eval_batch_size)

def generic_train(model: RAGEnd2EndTransformer):
  pl.seed_everything(model.seed)
  train_params = {}
  train_params["accumulate_grad_batches"] = 8
  train_params["devices"] = "auto"
  train_params["max_epochs"] = 3
  train_params["gradient_clip_val"] = 0.1
  train_params["precision"] = 32 # no RuntimeError when precision=16
  trainer = Trainer(val_check_interval=1, num_sanity_val_steps=2, **train_params)
  trainer.fit(model)
  return trainer

def main() -> RAGEnd2EndTransformer:
  model: RAGEnd2EndTransformer = RAGEnd2EndTransformer()
  trainer: pl.Trainer = generic_train(model)
  trainer.test()
  return model

Error messages and logs

 /usr/local/lib/python3.10/dist-packages/lightning/pytorch/trainer/call.py:291 in                 
 _call_strategy_hook                                                                              

   288 │   │   return None                                                                        
   289 │                                                                                          
   290 │   with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hoo   
 ❱ 291 │   │   output = fn(*args, **kwargs)                                                       
   292 │                                                                                          
   293 │   # restore current_fx when nested context                                               
   294 │   pl_module._current_fx_name = prev_fx_name                                              

 /usr/local/lib/python3.10/dist-packages/lightning/pytorch/strategies/strategy.py:200 in backward 

   197 │   │   assert self.lightning_module is not None                                           
   198 │   │   closure_loss = self.precision_plugin.pre_backward(closure_loss, self.lightning_m   
   199 │   │                                                                                      
 ❱ 200 │   │   self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *   
   201 │   │                                                                                      
   202 │   │   closure_loss = self.precision_plugin.post_backward(closure_loss, self.lightning_   
   203 │   │   self.post_backward(closure_loss)                                                   

 /usr/local/lib/python3.10/dist-packages/lightning/pytorch/plugins/precision/precision_plugin.py: 
 67 in backward                                                                                   

    64 │   │   │   │   :meth:`~torch.Tensor.backward`.                                           
    65 │   │   │   \**kwargs: Keyword arguments for the same purpose as ``*args``.                
    66 │   │   """                                                                                
 ❱  67 │   │   model.backward(tensor, *args, **kwargs)                                            
    68 │                                                                                          
    69 │   def post_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor:  #    
    70 │   │   # once backward has been applied, release graph                                    

 /usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py:1046 in backward        

   1043 │   │   if self._fabric:                                                                  
   1044 │   │   │   self._fabric.backward(loss, *args, **kwargs)                                  
   1045 │   │   else:                                                                             
 ❱ 1046 │   │   │   loss.backward(*args, **kwargs)                                                
   1047 │                                                                                         
   1048 │   def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> None:  
   1049 │   │   """Makes sure only the gradients of the current optimizer's parameters are calcu  

 /usr/local/lib/python3.10/dist-packages/torch/_tensor.py:487 in backward                         

    484 │   │   │   │   create_graph=create_graph,                                                
    485 │   │   │   │   inputs=inputs,                                                            
    486 │   │   │   )                                                                             
 ❱  487 │   │   torch.autograd.backward(                                                          
    488 │   │   │   self, gradient, retain_graph, create_graph, inputs=inputs                     
    489 │   │   )                                                                                 
    490                                                                                           

 /usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py:200 in backward               

   197 │   # The reason we repeat same the comment below is that                                  
   198 │   # some Python versions print out the first line of a multi-line function               
   199 │   # calls in the traceback and some print out the last line                              
 ❱ 200 │   Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the bac   
   201 │   │   tensors, grad_tensors_, retain_graph, create_graph, inputs,                        
   202 │   │   allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to ru   
   203                                                                                            
╰───────────────────────────────────────────────────────────────
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Environment

cc @borda

stephen-nju commented 1 year ago

maybe the optimizer from transformers AdamW,try using torch.optim.AdamW

awaelchli commented 1 year ago

From your code it looks like you were using the AdamW as suggested by @stephen-nju (thanks for helping out ❤️). This was fixed here #18268, so I'm closing the issue. But let me know if this doesn't solve it for you.

As a workaround, if you can't wait for the fix to be released, you can manually set torch.set_grad_enabled(True) at the beginning of your training step.