SCIInstitute / ShapeWorks

ShapeWorks
http://sciinstitute.github.io/ShapeWorks/

1881 deepssm implement variants and losses #2129

Closed: zahidemon closed this pull request 1 year ago

zahidemon commented 1 year ago

Main changes:

  1. Added TL-Net as a model architecture, along with its training and validation functions.
  2. Added focal loss as a loss function option.
  3. Added several learning rate scheduler options.
  4. Added documentation for the newly added variants.
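The focal loss option is listed without a formula. As a hedged illustration only, the sketch below implements one common focal-style regression loss, the Focal-MSE form from Yang et al. ("Delving into Deep Imbalanced Regression", 2021): each squared error is scaled by `(2 * sigmoid(beta * |e|) - 1) ** gamma`, which down-weights small errors so that training focuses on harder samples. This is not necessarily the formulation used in this PR; the function names and the `beta`/`gamma` defaults are assumptions, and it is written in plain Python (rather than PyTorch) so it is self-contained.

```python
import math
from typing import Sequence


def mse_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Plain mean squared error, shown for comparison (hypothetical helper)."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def focal_mse_loss(pred: Sequence[float], target: Sequence[float],
                   beta: float = 0.2, gamma: float = 1.0) -> float:
    """Focal-style MSE (assumed Focal-MSE form, not necessarily ShapeWorks' code).

    Each squared error e**2 is multiplied by (2 * sigmoid(beta * |e|) - 1) ** gamma,
    a factor in [0, 1) that shrinks the contribution of small errors.
    """
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and the same length")
    total = 0.0
    for p, t in zip(pred, target):
        err = abs(p - t)
        sigmoid = 1.0 / (1.0 + math.exp(-beta * err))
        weight = (2.0 * sigmoid - 1.0) ** gamma  # gamma = 0 gives weight 1 (plain MSE)
        total += weight * err ** 2
    return total / len(pred)


# Example: predicted vs. ground-truth correspondence coordinates (made-up values).
pred = [1.0, 3.0]
target = [0.0, 0.0]
print(mse_loss(pred, target), focal_mse_loss(pred, target))
```

Because the weight is always below 1, the focal variant is never larger than plain MSE on the same inputs, and setting `gamma=0` recovers plain MSE exactly.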
zahidemon commented 1 year ago

To test the added variants in the DeepSSM use case, please make the following changes in the config file:

  1. Set config["loss"]["function"] to the loss function used to train the model. Available options: MSE and Focal. Default: MSE.
  2. Set config["tl_net"]["enabled"] to True to train the model with TL-Net. If it is False, the model is trained with vanilla (base) DeepSSM. Default: False.
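The two settings above can be illustrated with a minimal sketch, assuming the config is a JSON-backed Python dict. Only the keys `config["loss"]["function"]` and `config["tl_net"]["enabled"]`, the values `MSE`/`Focal`, and the defaults come from this thread; the helper names `load_config` and `resolve_training_options` are hypothetical, and the case-sensitive string check is an assumption about how ShapeWorks validates the value.

```python
import json
from typing import Any, Dict, Tuple

# Loss options named in this PR; exact-case matching is an assumption.
VALID_LOSSES = ("MSE", "Focal")


def load_config(text: str) -> Dict[str, Any]:
    """Parse config text, assuming the config file is JSON (hypothetical helper)."""
    return json.loads(text)


def resolve_training_options(config: Dict[str, Any]) -> Tuple[str, bool]:
    """Return (loss_name, use_tl_net), applying the defaults stated in the PR."""
    loss_name = config.get("loss", {}).get("function", "MSE")  # default: MSE
    if loss_name not in VALID_LOSSES:
        raise ValueError(f"unknown loss {loss_name!r}; expected one of {VALID_LOSSES}")
    use_tl_net = bool(config.get("tl_net", {}).get("enabled", False))  # default: False
    return loss_name, use_tl_net


# Example: train TL-Net with focal loss (all other config keys omitted).
config = load_config('{"loss": {"function": "Focal"}, "tl_net": {"enabled": true}}')
loss_name, use_tl_net = resolve_training_options(config)
print(loss_name, "TL-Net" if use_tl_net else "vanilla DeepSSM")
```

Leaving either key out falls back to the defaults, so an existing config without these sections still trains vanilla DeepSSM with MSE.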

@akenmorris @jadie1

akenmorris commented 1 year ago

@jadie1 , any comments?