gndlwch2w / msvm-unet

The official code for the work "MSVM-UNet: Multi-Scale Vision Mamba UNet for Medical Image Segmentation".
https://arxiv.org/abs/2408.13735

Issue with PyTorch Lightning training #1

Closed: penningavery closed this issue 1 month ago

penningavery commented 1 month ago

Hi,

Thank you for your interesting software. I am having an issue. I had to make several changes to get training to start:

1. The function model.prepare_data() needed to be called.
2. The dataset image needed its type changed to np.float32 (it gave an error with float64).
3. The Trainer needed to be called with strategy='ddp_find_unused_parameters_true'.
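Item 2 above is a dtype mismatch: NumPy creates float64 arrays by default, while PyTorch layers are float32 unless configured otherwise, so float64 inputs are rejected. A minimal sketch of doing the cast once, in the dataset, assuming samples arrive as NumPy (image, label) pairs. The helper name `to_float32_sample` and the int64 label convention are illustrative assumptions, not the repository's code:

```python
import numpy as np


def to_float32_sample(image, label):
    """Hypothetical helper: cast one (image, label) pair to model-friendly dtypes.

    Assumption: images are NumPy arrays (often float64 after loading/normalising)
    and labels are integer class maps.
    """
    # Downcast so the input matches float32 model weights.
    image = np.ascontiguousarray(image, dtype=np.float32)
    # Class-index targets for cross-entropy losses are conventionally int64.
    label = np.ascontiguousarray(label, dtype=np.int64)
    return image, label


if __name__ == "__main__":
    img = np.random.rand(224, 224)               # float64 by default
    lbl = np.zeros((224, 224), dtype=np.uint8)
    img32, lbl64 = to_float32_sample(img, lbl)
    print(img32.dtype, lbl64.dtype)              # float32 int64
```

Casting in the dataset (rather than inside the model) keeps the DataLoader output and the model weights in agreement without changing the training precision.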

Once I made these changes, it trained for 1 epoch and then failed with '[rank0]: RuntimeError: No backend type associated with device type cpu'.
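A frequently reported trigger for this RuntimeError (not confirmed for this repository) is logging CPU tensors with sync_dist=True under DDP: the NCCL backend only reduces CUDA tensors. As far as we know, Lightning's self.log turns plain Python numbers into tensors on the module's own device, so one workaround is to convert metrics to Python floats before logging. A sketch, where `to_python_scalars` and the metric names are hypothetical:

```python
import numpy as np


def to_python_scalars(metrics):
    """Hypothetical helper: convert one-element metric values to plain Python floats.

    NumPy scalars, 0-d/1-element arrays and torch tensors all expose .item();
    multi-element values raise (ValueError for NumPy, RuntimeError for torch).
    """
    out = {}
    for name, value in metrics.items():
        if hasattr(value, "item"):
            value = value.item()  # move the value off any device to a Python scalar
        out[name] = float(value)
    return out


if __name__ == "__main__":
    raw = {"val/dice": np.float32(0.8125), "val/loss": np.array(0.25)}
    print(to_python_scalars(raw))  # {'val/dice': 0.8125, 'val/loss': 0.25}
```

Inside a LightningModule this would be used as self.log_dict(to_python_scalars(metrics), sync_dist=True). The alternative is to keep the metric tensors on self.device instead of moving them to the CPU.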

I searched for this error, and some people solve it by downgrading PyTorch Lightning. I tried that, but it gave another error (TypeError: Synapse.validation_step() takes 2 positional arguments but 3 were given).
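The TypeError suggests Synapse.validation_step is declared as validation_step(self, batch). Lightning 2.x appears to inspect the hook's signature and pass batch_idx only when it is declared, whereas older releases call validation_step(batch, batch_idx) unconditionally, so a downgrade exposes the mismatch. A framework-free sketch of the usual fix, giving batch_idx a default so the hook accepts both calling conventions. `SynapseModule` is a hypothetical stand-in, not the repository's class:

```python
import inspect


class SynapseModule:
    """Hypothetical stand-in for a LightningModule's validation hook."""

    def validation_step(self, batch, batch_idx=0):
        # batch_idx is optional, so both validation_step(batch) and
        # validation_step(batch, batch_idx) are accepted.
        images, labels = batch
        return {"n_images": len(images), "batch_idx": batch_idx}


if __name__ == "__main__":
    module = SynapseModule()
    batch = (["img0", "img1"], ["lbl0", "lbl1"])
    print(module.validation_step(batch))     # {'n_images': 2, 'batch_idx': 0}
    print(module.validation_step(batch, 7))  # {'n_images': 2, 'batch_idx': 7}
    # The bound-method signature a framework would inspect (self is excluded):
    print(list(inspect.signature(module.validation_step).parameters))  # ['batch', 'batch_idx']
```

Declaring batch_idx explicitly is harmless on newer Lightning, which simply passes it through.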

Any help you can provide to get the software to train would be appreciated. Thank you.

gndlwch2w commented 1 month ago

Thank you for your interest in our work. First, in the Lightning framework there is no need to call prepare_data() manually: once the configuration is complete, running trainer.fit(...) takes care of everything else. For more information, please refer to this link. Second, the default training precision we use is float32. Finally, we use lightning version 2.2.5 by default, and everything works smoothly in our environment. I hope this information is helpful.
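Since the reply points to lightning 2.2.5 as the tested version, a quick standard-library check of what is actually installed can rule out version drift. A sketch, assuming the unified lightning distribution; if a project imports pytorch_lightning instead, pass "pytorch-lightning". The helper names are hypothetical:

```python
import re
from importlib import metadata

REFERENCE_VERSION = "2.2.5"  # version the maintainer reports using


def release_tuple(version):
    """Return the numeric release part of a version, e.g. '2.2.5+cu121' -> (2, 2, 5)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def installed_version(dist="lightning"):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version()
    if found is None:
        print(f"lightning is not installed (reference: {REFERENCE_VERSION})")
    elif release_tuple(found) != release_tuple(REFERENCE_VERSION):
        print(f"lightning {found} differs from the reference {REFERENCE_VERSION}")
    else:
        print(f"lightning {found} matches the reference version")
```

Comparing only the numeric release part ignores local build tags such as +cu121, which do not change the Lightning API.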