dingo-gw / dingo

Dingo: Deep inference for gravitational-wave observations
MIT License

Model weights as hdf5 #170

Closed by mpuerrer 1 year ago

mpuerrer commented 1 year ago

Implement conversion and loading of trained models to HDF5 format for inference

mpuerrer commented 1 year ago

> Looks great. There are just a couple of small comments to address. Could you also confirm that this is well tested?

I think all comments are taken care of now. In terms of testing I've converted one IMRPhenomXPHM network and have done inference using the resulting HDF5 files (see below). This appears to work and the posterior looks unchanged.

dingo_pt_to_hdf5 -n 1 -i training/train_main/model_stage_1.pt -o training/train_main/model_stage_1.hdf5
dingo_pt_to_hdf5 -n 1 -i training/train_time/model_stage_1.pt -o training/train_time/model_stage_1.hdf5
dingo_analyze_event --model training/train_main/model_stage_1_v1.hdf5 --model_init training/train_time/model_stage_1_v1.hdf5 --gps_time_event 1126259462.4 --num_samples 50000 --num_gnpe_iterations 30 --batch_size 4096

mpuerrer commented 1 year ago

One last thing to check: do older networks also have the keys ['model_kwargs', 'model_state_dict', 'epoch', 'metadata']? If not I'll add a check.

stephengreen commented 1 year ago

> One last thing to check: do older networks also have the keys ['model_kwargs', 'model_state_dict', 'epoch', 'metadata']? If not I'll add a check.

I believe that they should all have these keys, yes.
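For reference, a minimal sketch of the kind of key check discussed above, in case a checkpoint ever turns out to lack one of the four keys. The names `REQUIRED_KEYS`, `missing_checkpoint_keys`, and `check_checkpoint` are hypothetical and not part of Dingo; the checkpoint is assumed to be a plain dict, as returned when loading a saved PT file.

```python
# Keys named in the discussion above; everything else here is illustrative.
REQUIRED_KEYS = ("model_kwargs", "model_state_dict", "epoch", "metadata")


def missing_checkpoint_keys(checkpoint, required=REQUIRED_KEYS):
    """Return the required keys that are absent from a checkpoint dict."""
    return [key for key in required if key not in checkpoint]


def check_checkpoint(checkpoint):
    """Raise a KeyError naming every missing key, otherwise return the dict."""
    missing = missing_checkpoint_keys(checkpoint)
    if missing:
        raise KeyError(f"Checkpoint is missing required keys: {missing}")
    return checkpoint
```

Reporting all missing keys at once, rather than failing on the first, makes it easier to see how an older network differs from the current format.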

stephengreen commented 1 year ago

One other suggestion would be to add the new script to the list of commands in the documentation: https://dingo-gw.readthedocs.io/en/latest/overview.html#summary-of-commands.

That would require changes to https://github.com/dingo-gw/dingo/blob/main/docs/source/overview.md.

mpuerrer commented 1 year ago

> Thanks @mpuerrer. I think one other change that is needed is to modify dingo_ls to print information for HDF5 models, just like it does for PT models.
>
> For this it would be useful to have an identifier within the HDF5 file that specifies its type. For classes that inherit from DingoDataset we have an attribute called dataset_type that does this, e.g.,

Ok, should I add an attribute dataset_type='model_weights' to the HDF5 files saved by pt_to_hdf5.py and print a bit of information in dingo_ls about this?

stephengreen commented 1 year ago

> > Thanks @mpuerrer. I think one other change that is needed is to modify dingo_ls to print information for HDF5 models, just like it does for PT models. For this it would be useful to have an identifier within the HDF5 file that specifies its type. For classes that inherit from DingoDataset we have an attribute called dataset_type that does this, e.g.,
>
> Ok, should I add an attribute dataset_type='model_weights' to the HDF5 files saved by pt_to_hdf5.py and print a bit of information in dingo_ls about this?

That would be great. I'd call it simply "trained_model" maybe.

mpuerrer commented 1 year ago

I've added the dataset_type, updated dingo_ls to print info about the HDF5 file (pprinting the dicts 'model_kwargs' and 'metadata'), and mentioned dingo_pt_to_hdf5 in the documentation. Should be good to go.
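As an illustration of the idea discussed in this thread, here is a rough sketch of tagging an HDF5 file with a `dataset_type` attribute and reading it back for an ls-style summary. It is not the code from pt_to_hdf5.py or dingo_ls: the function names `save_weights_hdf5` and `describe_hdf5`, the group name `model_weights`, and the choice to store the `model_kwargs` and `metadata` dicts as JSON strings are all assumptions made for this example. It uses `h5py` and NumPy arrays in place of a real PyTorch state dict.

```python
import json
import os
import tempfile

import h5py
import numpy as np


def save_weights_hdf5(path, state_dict, model_kwargs, metadata, epoch):
    """Write weight arrays and settings to an HDF5 file tagged with a type."""
    with h5py.File(path, "w") as f:
        # File-level identifier, in the spirit of DingoDataset's dataset_type.
        f.attrs["dataset_type"] = "trained_model"
        f.attrs["epoch"] = epoch
        # Assumption: nested settings dicts are serialized as JSON strings.
        f.attrs["model_kwargs"] = json.dumps(model_kwargs)
        f.attrs["metadata"] = json.dumps(metadata)
        group = f.create_group("model_weights")
        for name, array in state_dict.items():
            group.create_dataset(name, data=np.asarray(array))


def describe_hdf5(path):
    """Return the information an ls-style tool might print for the file."""
    with h5py.File(path, "r") as f:
        dataset_type = f.attrs.get("dataset_type", None)
        if dataset_type != "trained_model":
            # Not a file written by save_weights_hdf5; report only the type.
            return {"dataset_type": dataset_type}
        return {
            "dataset_type": dataset_type,
            "epoch": int(f.attrs["epoch"]),
            "model_kwargs": json.loads(f.attrs["model_kwargs"]),
            "metadata": json.loads(f.attrs["metadata"]),
            "weights": sorted(group_name for group_name in f["model_weights"]),
        }


if __name__ == "__main__":
    from pprint import pprint

    demo_path = os.path.join(tempfile.mkdtemp(), "toy_model.hdf5")
    save_weights_hdf5(
        demo_path,
        {"layer.weight": np.zeros((2, 3))},
        {"hidden_dim": 8},
        {"note": "toy example"},
        epoch=1,
    )
    pprint(describe_hdf5(demo_path))
```

Checking the attribute before reading anything else lets one tool handle several kinds of HDF5 file without guessing from the file name.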