Shivanandroy / simpleT5

simpleT5 is built on top of PyTorch Lightning⚡️ and Transformers🤗 and lets you quickly train your T5 models.
MIT License

Kernel dies every time when I start training the model #45

Closed: kkrishnan90 closed this issue 1 year ago

kkrishnan90 commented 1 year ago

Hi Shiva, thank you very much for such a clean and neat wrapper for training ML models. I am using T5 (specifically t5-small) as the base for training a summarization model, and I load the dataset with the Hugging Face datasets library. However, every time I start the training code, the kernel dies and restarts. Any help here is much appreciated!

Here is my code.

Install dependencies

%%capture
!pip install --user simplet5==0.1.4
!pip install transformers
!pip install wandb
!pip install pandas
!pip install datasets
!pip install --user simpletransformers
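Since mismatched package versions turned out to matter in this thread, it can help to record exactly what is installed right after these cells run. The following is a minimal sketch using only the standard library's `importlib.metadata`; the helper name `installed_versions` and the package list are illustrative choices, not part of simpleT5. Reading package metadata does not import the packages, so a broken PyTorch build cannot crash the kernel at this step.

```python
from importlib.metadata import PackageNotFoundError, version

# Illustrative list of distributions to inspect. Underscores are used because
# that is how wheel installs record names like pytorch_lightning on disk.
PACKAGES = ["torch", "pytorch_lightning", "transformers", "simplet5", "datasets"]


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}.

    Only package metadata is read; nothing is imported.
    """
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:20s} {ver or 'not installed'}")
```

Running `pip check` in the same environment additionally reports installed packages whose declared requirements are not satisfied.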

Load the data with the Hugging Face datasets library

import pandas as pd
import warnings
warnings.filterwarnings("ignore")
from datasets import load_dataset
dataset = load_dataset("scitldr")

Preparing the train and eval data

train_df = dataset["train"].to_pandas().copy()
train_df.drop(columns=["source_labels","rouge_scores","paper_id"],inplace=True)
train_df.rename(columns={"source":"source_text","target":"target_text"}, inplace=True)
train_df.count()  # no NaNs found: all 1992 rows are non-null

train_df['source_text'] = train_df['source_text'].astype('str').str.rstrip(']\'')
train_df['source_text'] = train_df['source_text'].astype('str').str.lstrip('[\'')
train_df['target_text'] = train_df['target_text'].astype('str').str.rstrip(']\'')
train_df['target_text'] = train_df['target_text'].astype('str').str.lstrip('[\'')

train_df["source_text"]=train_df["source_text"].str.replace('\'','')
train_df["target_text"]=train_df["target_text"].str.replace('\'','')
train_df["source_text"]="summarize: "+train_df["source_text"]
train_df.to_csv("train.csv")
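The strip/replace chain above works on the string form of each cell. In scitldr, `source` and `target` hold lists of strings (sentences and reference summaries), so after `to_pandas()` the cells are array-like, and `astype('str')` yields their printed representation. Removing every `'` afterwards also deletes apostrophes inside the text (for example in "model's"). A minimal alternative sketch, assuming the cells are lists or NumPy arrays of strings, joins the items directly; `to_text` is a hypothetical helper name.

```python
def to_text(value, first_only=False):
    """Turn a list-like field (e.g. a list of sentences) into one plain string.

    Strings only have surrounding whitespace stripped; any other iterable
    (list, tuple, NumPy array) is joined with single spaces, skipping empty
    items. With first_only=True only the first item is kept, which may suit
    fields that hold several alternative reference summaries.
    """
    if isinstance(value, str):
        return value.strip()
    items = [str(item).strip() for item in value]
    if first_only:
        items = items[:1]
    return " ".join(item for item in items if item)


if __name__ == "__main__":
    sentences = ["We propose a model.", "It's fast and accurate."]
    print(to_text(sentences))  # the apostrophe in "It's" is preserved
    print(to_text(["tldr one", "tldr two"], first_only=True))
```

With pandas this could replace the rstrip/lstrip/replace lines, for example `train_df["source_text"] = "summarize: " + train_df["source_text"].apply(to_text)` and `train_df["target_text"] = train_df["target_text"].apply(to_text)`. Whether to join several reference summaries or keep only the first (`first_only=True`) is a modeling choice; joining is closest to what the original string-based cleanup produces.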

eval_df = dataset["validation"].to_pandas().copy()
eval_df.drop(columns=["source_labels","rouge_scores","paper_id"],inplace=True)
eval_df.rename(columns={"source":"source_text","target":"target_text"}, inplace=True)
eval_df.count()  # non-null count per column, to check for NaNs

eval_df['source_text'] = eval_df['source_text'].astype('str').str.rstrip(']\'')
eval_df['source_text'] = eval_df['source_text'].astype('str').str.lstrip('[\'')
eval_df['target_text'] = eval_df['target_text'].astype('str').str.rstrip(']\'')
eval_df['target_text'] = eval_df['target_text'].astype('str').str.lstrip('[\'')

# operate on eval_df's own columns so the validation set stays separate from the training data
eval_df["source_text"]=eval_df["source_text"].str.replace('\'','')
eval_df["target_text"]=eval_df["target_text"].str.replace('\'','')
eval_df["source_text"]="summarize: "+eval_df["source_text"]
eval_df.to_csv("eval.csv")
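Because the eval preparation repeats the train preparation line for line, it is easy for a `train_df` reference to slip into the eval block. pandas aligns an assigned column on the index, so such a slip silently fills `eval_df` with the first rows of the training data and the validation loss no longer measures generalization. A cheap guard, sketched below with a hypothetical `overlap_fraction` helper, measures how many eval sources also occur verbatim among the training sources.

```python
def overlap_fraction(train_texts, eval_texts):
    """Fraction of eval texts that also occur verbatim (after strip) in the training texts."""
    train_set = {t.strip() for t in train_texts}
    eval_list = [t.strip() for t in eval_texts]
    if not eval_list:
        return 0.0
    shared = sum(1 for t in eval_list if t in train_set)
    return shared / len(eval_list)


if __name__ == "__main__":
    train = ["summarize: paper A", "summarize: paper B"]
    eval_ok = ["summarize: paper C"]
    eval_bad = ["summarize: paper A", "summarize: paper C"]
    print(overlap_fraction(train, eval_ok))   # 0.0
    print(overlap_fraction(train, eval_bad))  # 0.5
```

For example, `overlap_fraction(train_df["source_text"], eval_df["source_text"])`. With separate train and validation splits a value at or near 0.0 is expected; a value near 1.0 indicates that the eval frame was built from training rows.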

Load simpleT5 and the WandbLogger, then load the model and run training

from simplet5 import SimpleT5
from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(project="ask-poc-logger")
model = SimpleT5()
model.from_pretrained("t5","t5-small")
model.train(train_df=train_df[0:100], 
            eval_df=eval_df[0:100],
            source_max_token_len = 512, 
            target_max_token_len = 100,
            batch_size = 2,
            max_epochs = 3,
            use_gpu = True,
            outputdir = "outputs",
            logger = wandb_logger
            )

I am running this code on a Vertex AI Workbench instance on Google Cloud: machine type n1-standard-16 (16 cores, 60 GB memory) with an NVIDIA P100 GPU attached. Any help is much appreciated! Thanks in advance!
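When a Jupyter kernel dies without a Python traceback, the kernel process has typically been terminated by something outside normal exception handling, such as a crash in native code or the system running out of memory; a PyTorch build that does not match the rest of the installed stack is one possible source of such a crash. One way to narrow it down, sketched below under that assumption, is to run each suspect step in a child interpreter so that a hard crash ends only the child and shows up as a return code. `probe` and the list of checks are hypothetical; adjust them to the packages in your environment.

```python
import subprocess
import sys

# Hypothetical checks, ordered from cheapest to most involved.
CHECKS = [
    "import torch; print(torch.__version__, torch.cuda.is_available())",
    "import torch; print(torch.zeros(1).cuda())",  # actually touches the GPU
    "import pytorch_lightning as pl; print(pl.__version__)",
    "import transformers; print(transformers.__version__)",
    "import simplet5",
]


def probe(snippet, timeout=300):
    """Run `snippet` in a fresh interpreter; return (returncode, stdout, stderr tail).

    sys.executable is the interpreter of the current process (in a notebook,
    the kernel's Python). A crash kills only the child. On POSIX, a child
    terminated by signal N reports return code -N (e.g. -11 for SIGSEGV).
    Raises subprocess.TimeoutExpired if the child runs longer than `timeout`.
    """
    proc = subprocess.run(
        [sys.executable, "-c", snippet],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode, proc.stdout.strip(), proc.stderr.strip()[-2000:]


if __name__ == "__main__":
    for snippet in CHECKS:
        code, out, err = probe(snippet)
        print(f"$ {snippet}\n  return code: {code}")
        if out:
            print(f"  stdout: {out}")
        if code != 0 and err:
            print(f"  stderr (tail): {err}")
```

The first check that fails or returns a negative code points at the layer to investigate, which is cheaper than restarting the kernel after every attempt.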

kkrishnan90 commented 1 year ago

[Fixed] After some debugging, I found a PyTorch version conflict in my Vertex AI managed notebook (a pre-built container with PyTorch and Transformers already installed). I created a custom container and also tried the code in a user-managed notebook, and it works fine. Closing the issue.