Shivanandroy / simpleT5

simpleT5 is built on top of PyTorch Lightning⚡️ and Transformers🤗 and lets you quickly train your T5 models.
MIT License

Suppress the Output Models #21

Closed: bayismet closed this issue 2 years ago

bayismet commented 2 years ago

Hello there!

I'd like to ask whether there is any way to discard all models except the last trained one. When I fine-tune a model for X epochs, it writes X different models to disk. I only need the last model and couldn't find a way to prevent the others from being written.

Thanks!

pashok3d commented 2 years ago

You need to add a ModelCheckpoint callback with save_top_k=1 here.
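For simpleT5 versions without a built-in option, one hedged alternative to wiring up a callback is to prune the output directory after training finishes. The sketch below assumes that each epoch's model is saved as its own subdirectory of outputdir and that the most recently modified subdirectory is the final epoch. keep_latest_only is a hypothetical helper written with the standard library only; it is not part of simpleT5 or PyTorch Lightning.

```python
import os
import shutil
import tempfile
from typing import Optional


def keep_latest_only(outputdir: str) -> Optional[str]:
    """Delete every subdirectory of `outputdir` except the most recently modified one.

    Hypothetical helper, not part of simpleT5: assumes each saved model is its
    own subdirectory. Returns the kept path, or None if there are no subdirectories.
    """
    with os.scandir(outputdir) as entries:
        subdirs = [entry.path for entry in entries if entry.is_dir()]
    if not subdirs:
        return None
    # Assumption: the last-written checkpoint corresponds to the final epoch.
    latest = max(subdirs, key=os.path.getmtime)
    for path in subdirs:
        if path != latest:
            shutil.rmtree(path)  # permanently removes the older checkpoint
    return latest


if __name__ == "__main__":
    # Demo on a throwaway directory with three fake "epoch" folders.
    with tempfile.TemporaryDirectory() as outputdir:
        for i in range(3):
            path = os.path.join(outputdir, f"epoch-{i}")
            os.mkdir(path)
            os.utime(path, (1000 + i, 1000 + i))  # make epoch-2 the newest
        print(keep_latest_only(outputdir))
        print(os.listdir(outputdir))
```

Because shutil.rmtree deletes permanently, it is worth printing the directory that would be kept before running this on real training output.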

Shivanandroy commented 2 years ago

@bayismet Sorry for the late reply. In the latest release, you can save only the model from the last epoch by passing save_only_last_epoch=True to the .train method.

# import
from simplet5 import SimpleT5

# instantiate
model = SimpleT5()

# load (supports t5, mt5, byT5 models)
model.from_pretrained("t5","t5-base")

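# train (train_df and eval_df are pandas DataFrames with "source_text" and "target_text" columns)
model.train(train_df=train_df,
            eval_df=eval_df,
            source_max_token_len=512,
            target_max_token_len=128,
            batch_size=8,
            max_epochs=5,
            use_gpu=True,
            outputdir="outputs",
            save_only_last_epoch=True  # save only the model from the final epoch
            )
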
bayismet commented 2 years ago

Better late than never!

Thanks for the answer, I'll mark it as closed now.

bayismet commented 2 years ago

Sorry, but I have to reopen it. I had already installed simplet5 via pip, so I updated it with the following command:

%pip install simplet5 -U

Then I tried the following code:

from simplet5 import SimpleT5

model = SimpleT5()

model.from_pretrained("t5", "t5-base")

model.train(train_df=df,
            eval_df=df,
            source_max_token_len=512,
            target_max_token_len=128,
            batch_size=8,
            max_epochs=5,
            use_gpu=True,
            outputdir="outputs",
            save_only_last_epoch=True
)

But it gives the following error:


TypeError                                 Traceback (most recent call last)
      8 model.from_pretrained("t5", "t5-base")
      9
---> 10 model.train(train_df=df,
     11             eval_df=df,
     12             source_max_token_len=512,

TypeError: train() got an unexpected keyword argument 'save_only_last_epoch'

Thanks.

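A TypeError like this means the train method actually being executed does not declare that keyword. One way to check before calling it is to inspect the method's signature with Python's standard inspect module. accepts_kwarg is a hypothetical helper, and the commented usage line assumes simplet5 is importable in the same environment.

```python
import inspect
from typing import Callable


def accepts_kwarg(func: Callable, name: str) -> bool:
    """Return True if `func` can be called with the keyword argument `name`."""
    params = inspect.signature(func).parameters
    if name in params and params[name].kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all also accepts any keyword name.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Usage (assumes simplet5 is installed):
# from simplet5 import SimpleT5
# print(accepts_kwarg(SimpleT5.train, "save_only_last_epoch"))
```

If this prints False, the loaded version of simplet5 predates the option.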
Shivanandroy commented 2 years ago

It looks like you are still using a previous version of simplet5. You will need to update it: pip install -U simplet5
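In a notebook, a quick way to see which version pip reports, and whether the module was already loaded, is sketched below with the standard importlib.metadata and sys modules. installed_version and already_imported are hypothetical helper names. A module that was imported before pip install -U ran keeps its old code in memory until the interpreter (for example the Jupyter kernel) is restarted, so the reported version alone can be misleading.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed distribution's version string, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def already_imported(module_name: str) -> bool:
    """True if the module is already loaded in this interpreter.

    If it was loaded before `pip install -U` ran, the old code stays in use
    until the interpreter (e.g. the Jupyter kernel) is restarted.
    """
    return module_name in sys.modules


# Usage (run in the notebook where training fails):
# print(installed_version("simplet5"), already_imported("simplet5"))
```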

bayismet commented 2 years ago

I tried it again and can confirm that the parameter works. Closing the issue now. Thanks!