MouseLand / cellpose

a generalist algorithm for cellular segmentation with human-in-the-loop capabilities
https://www.cellpose.org/
BSD 3-Clause "New" or "Revised" License

AttributeError during training [BUG] #903

Closed redrodion closed 2 weeks ago

redrodion commented 6 months ago

torch version 2.2.2, cellpose version 3.0.7. I get an AttributeError when I call the train_seg function. I am using the latest version of cellpose.

path_to_new_model = train_seg(old_model, train_data=training_images, train_labels=training_masks, test_data=images_test, test_labels=masks_test, batch_size=8, learning_rate=0.05, n_epochs=10, weight_decay=1e-05, momentum=0.9, SGD=False, channels=[0,0], model_name="NewModelTest")

File ~/mambaforge/envs/cellpose/lib/python3.8/site-packages/cellpose/train.py:362 in train_seg
    net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)

AttributeError: attribute 'data' of 'numpy.generic' objects is not writable

I attempted to fix this error by changing the source code. Inside "train.py", I changed

net.diam_labels.data = torch.Tensor([diam_train.mean()]).to(device)

to

net.diam_labels = torch.Tensor([diam_train.mean()]).to(device)

and that fixed the error I was getting. However, I then started getting a different error, related to accessing the network's parameters inside the "train_seg" function.

Here is the log:

path_to_new_model = train_seg(old_model, train_data=training_images, train_labels=training_masks, test_data=images_test, test_labels=masks_test, batch_size=8, learning_rate=0.05, n_epochs=10, weight_decay=1e-05, momentum=0.9, SGD=True, channels=[0,0], model_name="NewModelTest")

File ~/mambaforge/envs/cellpose/lib/python3.8/site-packages/cellpose/train.py:394 in train_seg
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate,

AttributeError: 'CellposeModel' object has no attribute 'parameters'

Can you kindly explain what my problem is, if possible?

thomasrose02 commented 2 months ago

Hi, did you ever get around this problem? I'm having the same issue and haven't found a way to solve it yet. Glad to know it's not just me.

kemaleren commented 1 month ago

I think the documentation has a mistake in it. train_seg() takes a network model as the first argument, not a CellPose model. I was able to get it to train by passing the net attribute directly:

path_to_new_model = train_seg(old_model.net, ...)
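To illustrate why both errors go away, here is a minimal sketch of the wrapper-versus-network mix-up. TinyNet, ModelWrapper and resolve_network are hypothetical stand-ins I made up for this example, not part of the cellpose API; the only thing taken from the tracebacks above is that the training code calls parameters() on whatever it receives as the first argument.

```python
# Stand-in classes for illustration only; they are NOT the cellpose API.
class TinyNet:
    """Plays the role of the underlying network object."""

    def __init__(self):
        self._weights = [0.1, 0.2, 0.3]

    def parameters(self):
        # An optimizer would iterate over these values.
        return iter(self._weights)


class ModelWrapper:
    """Plays the role of the high-level model that keeps the network in .net."""

    def __init__(self):
        self.net = TinyNet()


def resolve_network(model_or_net):
    """Return the object that exposes parameters(), unwrapping .net if needed."""
    if hasattr(model_or_net, "parameters"):
        return model_or_net
    net = getattr(model_or_net, "net", None)
    if net is not None and hasattr(net, "parameters"):
        return net
    raise TypeError(
        f"{type(model_or_net).__name__} has no parameters() and no usable .net"
    )


wrapper = ModelWrapper()

# Calling parameters() on the wrapper reproduces the failure mode.
try:
    wrapper.parameters()
except AttributeError as err:
    print(f"wrapper: {err}")

# Unwrapping first hands the training code an object it can work with.
net = resolve_network(wrapper)
print(f"network parameters: {list(net.parameters())}")
```

If this picture is right, editing train.py only hides the first symptom: the wrapper is still the wrong kind of object, so the next attribute the training code expects on a network fails instead.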

carsen-stringer commented 2 weeks ago

Thanks, this is fixed now.