cistrome / MIRA

Python package for analysis of multiomic single cell RNA-seq and ATAC-seq.

Errors when loading trained models #6

Closed KongMingxi closed 2 years ago

KongMingxi commented 2 years ago

Hi authors,

Many thanks for your excellent work and contributions! I am working through the tutorial "2021-12-26_MIRA_tutorial.ipynb". When I load the trained model using mira.topic_model.ExpressionTopicModel.load(), I get this error:


UnpicklingError                           Traceback (most recent call last)
/tmp/ipykernel_47536/3391690466.py in
----> 1 tt.load('./models/best_rna_model.pth')

~/miniconda3/envs/mira/lib/python3.7/site-packages/mira/topic_model/base.py in load(cls, filename)
    165         '''
    166
--> 167         data = torch.load(filename)
    168
    169         model = cls(**data['params'])

~/miniconda3/envs/mira/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, pickle_load_args)
    606                     return torch.jit.load(opened_file)
    607                 return _load(opened_zipfile, map_location, pickle_module, pickle_load_args)
--> 608         return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    609
    610

~/miniconda3/envs/mira/lib/python3.7/site-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, pickle_load_args)
    775                         "functionality.")
    776
--> 777     magic_number = pickle_module.load(f, pickle_load_args)
    778     if magic_number != MAGIC_NUMBER:
    779         raise RuntimeError("Invalid magic number; corrupt file?")

UnpicklingError: NEWOBJ expected an arg tuple.

I also tried loading it directly with torch.load(), but that did not work either.

How can I load trained models (.pth files)?

Best, Mingxi

AllenWLynch commented 2 years ago

Hi,

Many thanks! Did you specify the file path in your “load” command? Regardless of any internal issues with MIRA, torch.load should always work on the saved object. Can you give me more of the code you used to save the model so I can make sure everything is working correctly?
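
As a side check: the traceback above ends in _legacy_load, which is the branch torch.load takes when the file is not a zip archive, whereas torch.save has written zip archives by default since PyTorch 1.6. Below is a minimal sketch, using only the standard library, for guessing what a .pth file contains before trying to load it. The helper name describe_checkpoint_file and its return strings are made up for illustration and are not part of MIRA or PyTorch.

```python
import zipfile


def describe_checkpoint_file(path):
    """Guess what kind of file `path` is from its container format and first byte.

    Hypothetical helper for illustration only; it does not load the file.
    """
    # torch.save writes a zip archive by default since PyTorch 1.6.
    if zipfile.is_zipfile(path):
        return "zip archive (default torch.save format since PyTorch 1.6)"

    with open(path, "rb") as handle:
        first_byte = handle.read(1)

    # Pickle protocol 2 and later starts with the PROTO opcode, byte 0x80.
    if first_byte == b"\x80":
        return "raw pickle stream (possibly a legacy torch checkpoint)"

    return "unrecognized format"
```

If the result is anything other than the zip archive case, the file was probably not produced by a recent torch.save call, and it is worth checking which step of the workflow actually wrote it.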

Thanks! AL


KongMingxi commented 2 years ago

Hi AL,

Thanks for your help! Yes, I specified the path for the .pth file in my "load" command. My code for model building, saving and loading is as follows:

  1. To initialize a model for scRNA data:
       example_rna_model = mira.topics.ExpressionTopicModel(exogenous_key='exog_feature', endogenous_key='highly_variable')
  2. To train and save the best model after hyperparameter tuning:
       tuner = mira.topics.TopicModelTuner(example_rna_model, save_name="./my_folder/best_rna_model.pth")
       tuner.train_test_split(rna_data, train_size = 0.8)
       study = tuner.tune(rna_data)
       example_rna_model = tuner.select_best_model(rna_data)
  3. To load the trained model:
       rna_model_trained = mira.topic_model.ExpressionTopicModel.load('./my_folder/best_rna_model.pth')

Also, I tried to load the trained model using the following command, but it failed to load:

rna_model_trained = torch.load('./my_folder/best_rna_model.pth')

Please let me know if you need more code.

Thanks! Mingxi

AllenWLynch commented 2 years ago

Ahh, I see the problem. It comes from a confusing part of the API.

The "save_name" parameter of the tuner object records the trials and results of the optimization sequence, but it does not record the weights of the model. After calling "select_best_model", which returns your best model (assigned to "example_rna_model" in your code, called "best_rna_model" here), you have to save it explicitly:

best_rna_model.save('path/to/weights.pth')

As for recovering your best model: when you ran "select_best_model", the last printout reported the best parameters. You can retrain the model with those parameters on all of your data, and it will produce the same result.
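
Putting the steps together, here is a minimal sketch of the tune-then-save flow. It only calls methods that appear in this thread (train_test_split, tune, select_best_model, save). The wrapper name tune_and_save and its arguments are hypothetical, and it assumes the tuner and model objects behave as described above.

```python
def tune_and_save(model_tuner, data, weights_path, train_size=0.8):
    """Run the tuning steps from this thread, then save the selected model.

    `model_tuner` is assumed to be a tuner like the one built in the question;
    this wrapper itself is illustrative and not part of MIRA.
    """
    model_tuner.train_test_split(data, train_size=train_size)
    model_tuner.tune(data)
    best_model = model_tuner.select_best_model(data)

    # The step that was missing: the tuner's save_name file holds the tuning
    # trials, so the model weights must be written separately.
    best_model.save(weights_path)
    return best_model
```

With the objects from the question this would be called as tune_and_save(tuner, rna_data, './my_folder/best_rna_model.pth'), after which the load() call on that path should find actual model weights. Pointing "save_name" at a different file than the weights path avoids mixing the two up.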

Let me know if you have more questions, AL



KongMingxi commented 2 years ago

Thanks! I can save and load models now! Thanks again for your help!