Shivanandroy / simpleT5

simpleT5 is built on top of PyTorch Lightning⚡️ and Transformers🤗, letting you quickly train your T5 models.
MIT License
382 stars · 61 forks

How to resume training? #43

Open RK-BAKU opened 1 year ago

RK-BAKU commented 1 year ago

Hi guys! Is it possible to continue training from a specific checkpoint?

mgh1 commented 1 year ago

This is important! Any help on this one?

CherylChaoNYCU commented 1 year ago

@RK-BAKU @mgh1

Hi, for me I just load the model I saved and then keep training it:

model.load_model("t5", 'file/to/your/trained/model', use_gpu=True)

The rest of the training call is the same:

MAX_EPOCHS = 3

# optional: inspect GPU memory before training
torch.cuda.memory_summary(device=None, abbreviated=False)

model.train(
    train_df=df[0:int(0.7 * TRAINING_SIZE)],
    eval_df=df[int(0.7 * TRAINING_SIZE):TRAINING_SIZE],
    source_max_token_len=MAX_LEN,
    target_max_token_len=SUMMARY_LEN,
    batch_size=5,
    max_epochs=MAX_EPOCHS,
    outputdir='/content/gdrive/MyDrive/HW5_HL_gen/t5model',
    use_gpu=True,
)
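The indexing in that call is just a positional 70/30 train/eval split. Spelled out with a toy DataFrame (the column names and `TRAINING_SIZE` are illustrative stand-ins for the real data):

```python
import pandas as pd

# Toy data standing in for the real summarization DataFrame
df = pd.DataFrame({
    "source_text": [f"document {i}" for i in range(10)],
    "target_text": [f"summary {i}" for i in range(10)],
})

TRAINING_SIZE = len(df)           # total rows used for the experiment
split = int(0.7 * TRAINING_SIZE)  # index separating train from eval

train_df = df[0:split]                 # first 70% of rows
eval_df = df[split:TRAINING_SIZE]      # remaining 30%
```

Note the split is by row position, so shuffle the DataFrame first if the rows are ordered in any meaningful way.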

kolaganisankar commented 5 months ago

(quoting @CherylChaoNYCU's reply above in full)

How do you save the model?

Because there doesn't seem to be any save method.