Genentech / gReLU

gReLU is a python library to train, interpret, and apply deep learning models to DNA sequences.
https://genentech.github.io/gReLU/
MIT License

added map_location arg in load_model to load model onto any device #18

Closed. avantikalal closed this 4 months ago.

gokceneraslan commented 4 months ago

How about something like:
def load_model(project, model_name, device, alias='latest', checkpoint_file='model.ckpt')
1) The name device is more familiar; map_location is trickier.
2) The user has to specify where to load the model, so if there is no GPU, it can be CPU.
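The proposed signature can be pictured as a thin wrapper that exposes a friendly device argument and forwards it to a backend loader that expects map_location. This is a minimal, stdlib-only sketch under stated assumptions: fake_backend_load, the backend parameter, the project/model_name/alias/checkpoint_file path layout, and the returned dict are all hypothetical stand-ins, not gReLU's implementation. In a real PyTorch setting the backend would be something like torch.load(path, map_location=device).

```python
from typing import Any, Callable, Dict


def fake_backend_load(path: str, map_location: str) -> Dict[str, Any]:
    # Hypothetical stand-in for a checkpoint loader such as
    # torch.load(path, map_location=...); it just echoes its inputs.
    return {"path": path, "device": map_location}


def load_model(
    project: str,
    model_name: str,
    device: str,
    alias: str = "latest",
    checkpoint_file: str = "model.ckpt",
    backend: Callable[..., Dict[str, Any]] = fake_backend_load,
) -> Dict[str, Any]:
    # The caller must state where the model goes, e.g. "cpu" when no GPU exists.
    # Hypothetical path layout: project/model_name/alias/checkpoint_file.
    path = "/".join([project, model_name, alias, checkpoint_file])
    # The user-facing name is `device`; the backend still receives `map_location`.
    return backend(path, map_location=device)
```

For example, load_model("my-project", "my-model", "cpu") returns {"path": "my-project/my-model/latest/model.ckpt", "device": "cpu"}. The rename only touches the public signature; the backend keyword stays map_location, so nothing below the wrapper has to change.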

avantikalal commented 4 months ago

Done! On second thought, I decided to make CPU the default, since the functions most users will call later (predict_on_dataset, embed_on_dataset) take care of moving the model to the GPU.
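The design choice of loading onto the CPU by default and letting downstream functions move the model can be sketched as below. This is a stdlib-only illustration under stated assumptions: FakeModel is a hypothetical stand-in for a torch.nn.Module-like object, and load_model and predict_on_dataset here are simplified versions with hypothetical signatures and a hypothetical default device of "cuda:0", not gReLU's actual functions.

```python
from typing import List


class FakeModel:
    """Hypothetical stand-in for a torch.nn.Module-like model with a .to() method."""

    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "FakeModel":
        # Mirrors the nn.Module.to() convention: move in place and return self.
        self.device = device
        return self

    def __call__(self, x: float) -> float:
        return 2.0 * x  # placeholder computation


def load_model(project: str, model_name: str, device: str = "cpu") -> FakeModel:
    # CPU default: loading works even on a machine without a GPU.
    # project and model_name are unused in this stand-in.
    return FakeModel().to(device)


def predict_on_dataset(
    model: FakeModel, dataset: List[float], device: str = "cuda:0"
) -> List[float]:
    # The downstream function, not load_model, decides where inference runs.
    model = model.to(device)
    return [model(x) for x in dataset]
```

With this split, load_model("my-project", "my-model") yields a model on "cpu", and predict_on_dataset(model, [1.0, 2.0]) returns [2.0, 4.0] after moving the model to "cuda:0". Users who only inspect or fine-tune configuration never need a GPU, while inference still lands on the accelerator.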

gokceneraslan commented 4 months ago

Sounds good