MouseLand / cellpose

a generalist algorithm for cellular segmentation with human-in-the-loop capabilities
https://www.cellpose.org/
BSD 3-Clause "New" or "Revised" License

[BUG] Error: RuntimeError: itensor_from_mkldnn expects MKL-DNN tensor input when transitioning from evaluation to training mode #820

Closed: KorenMary closed this issue 4 days ago

KorenMary commented 10 months ago

Issue Description: I encountered the following error: RuntimeError: itensor_from_mkldnn expects MKL-DNN tensor input. It occurs when the model switches from evaluation mode back to training mode, for example after a validation pass. The only workaround I found was to run all of the training first and only then run the evaluation.
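The workaround described above (finish all training, then evaluate) can be sketched as a two-phase loop. This is a minimal sketch, not code from the thread: train_then_evaluate is a hypothetical helper, and model is assumed to behave like cellpose's CellposeModel, i.e. train(train_data=..., train_labels=..., ...) and an eval(imgs) that returns a tuple whose first element is the predicted masks.

```python
from typing import Any, Callable, Iterable, List, Sequence, Tuple


def train_then_evaluate(
    model: Any,
    train_batches: Iterable[Tuple[Sequence, Sequence]],
    val_batches: Iterable[Tuple[Sequence, Sequence]],
    metric_fn: Callable[[Sequence, Sequence], float],
    **train_kwargs: Any,
) -> Tuple[List[Any], List[float]]:
    """Finish every training call before the first evaluation call.

    Assumption: `model` exposes train(train_data=..., train_labels=..., **kwargs)
    and eval(imgs) returning a tuple whose first element is the predicted masks.
    """
    # Phase 1: training only; eval() is never called in between.
    train_results: List[Any] = []
    for imgs, masks in train_batches:
        train_results.append(model.train(train_data=imgs, train_labels=masks, **train_kwargs))

    # Phase 2: evaluation only, once all training is complete.
    scores: List[float] = []
    for imgs, masks in val_batches:
        masks_hat = model.eval(imgs)[0]
        scores.append(metric_fn(masks, masks_hat))
    return train_results, scores
```

In the sample code later in the thread, this corresponds to moving the model.eval(...) call out of the training loop; metric_fn could wrap np.mean(aggregated_jaccard_index(...)) from cellpose.metrics. The trade-off is that you lose per-iteration validation scores.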

mrariden commented 10 months ago

Hi,

Can you list the exact steps to reproduce the error?

Thanks

KorenMary commented 9 months ago

Sample code to reproduce the issue:

import os

import numpy as np
from tqdm import tqdm
from cellpose.metrics import aggregated_jaccard_index


def training_main(model, dataloaders, channel, num_of_epochs, checkpoint_dir, batch_size=8):
    # model: a cellpose models.CellposeModel; dataloaders: (train_dataloader, val_dataloader)

    metrics, losses = [], []

    train_dataloader, val_dataloader = dataloaders  # val_dataloader is unused in this repro

    checkpoint_path = os.path.join(checkpoint_dir, f"celltype_{channel}")
    cellpose_model_path = os.path.join(checkpoint_path, "models")

    os.makedirs(checkpoint_path, exist_ok=True)

    for e in range(num_of_epochs):

        tqdm_toolbar_train = tqdm(enumerate(train_dataloader), position=0, leave=False)

        arr, losses_arr = [], []

        for i, (imgs, masks) in tqdm_toolbar_train:

            # torch batches -> numpy, then split into lists of single images
            imgs = imgs.numpy()
            masks = masks[:, ...].numpy() if len(masks.shape) == 4 else masks.numpy()

            imgs, masks = np.split(imgs, imgs.shape[0], axis=0), np.split(masks, masks.shape[0], axis=0)

            if len(imgs) and len(masks):

                # delete model files written by the previous train() call
                if os.path.exists(cellpose_model_path):
                    for filename in os.listdir(cellpose_model_path):
                        os.remove(os.path.join(cellpose_model_path, filename))

                # training step (cellpose runs n_epochs internally)
                _, loss_avg = model.train(train_data=imgs, train_labels=masks, train_files=None,
                    test_data=None, test_labels=None, test_files=None,
                    channels=[0, 0], normalize=True, min_train_masks=0,
                    save_path=checkpoint_path, save_every=6, SGD=True,
                    learning_rate=1e-3, n_epochs=5, momentum=0.9, weight_decay=0, batch_size=batch_size, rescale=True)

                # evaluation right after training; per this report, on CPU the next
                # model.train() call then fails with the itensor_from_mkldnn RuntimeError
                masks_hat, _, _ = model.eval(imgs, batch_size=batch_size, channels=[0, 0])
                metric = np.mean(aggregated_jaccard_index(masks, masks_hat))

                tqdm_toolbar_train.set_description(f"train iteration {i+1}: Jaccard index: {metric}, loss: {loss_avg}")

                arr.append(metric)
                losses_arr.append(loss_avg)

                break  # only the first batch of each epoch is used

        metrics.append(np.mean(arr))
        losses.append(np.mean(losses_arr))

    return metrics, losses

The error only arises when using CPU.
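Since PyTorch uses MKL-DNN (oneDNN) for CPU kernels, a CPU-only failure like this is easier to triage when the report includes the PyTorch backend state. A small diagnostic sketch; torch_backend_report is a hypothetical helper, and only standard torch attributes are read.

```python
def torch_backend_report() -> dict:
    """Collect the PyTorch backend flags relevant to a CPU-only failure."""
    try:
        import torch
    except ImportError:
        return {"torch": None}  # PyTorch is not installed
    return {
        "torch": str(torch.__version__),
        "cuda_available": torch.cuda.is_available(),
        # compiled-in MKL-DNN support vs. the runtime switch that can be turned off
        "mkldnn_available": torch.backends.mkldnn.is_available(),
        "mkldnn_enabled": bool(torch.backends.mkldnn.enabled),
    }


if __name__ == "__main__":
    print(torch_backend_report())
```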

mrariden commented 9 months ago

Can you give me the line number in the cellpose source where the error occurs? That will help me debug.

KorenMary commented 9 months ago

The error occurs in:

File ~\anaconda3\envs\deep-lr\lib\site-packages\cellpose\resnet_torch.py:47, in resdown.forward(self, x)
     46 def forward(self, x):
---> 47     x = self.proj(x) + self.conv[1](self.conv[0](x))
     48     x = x + self.conv[3](self.conv[2](x))
     49     return x

However, we solved the problem on our side by disabling PyTorch's MKL-DNN (mkldnn) backend.
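The comment does not say exactly how the backend was disabled. Assuming it means PyTorch's process-wide torch.backends.mkldnn.enabled switch, a minimal sketch follows; disable_torch_mkldnn is a hypothetical helper name. Because the switch is global, it may make CPU convolutions slower for everything else in the process.

```python
def disable_torch_mkldnn() -> bool:
    """Turn off PyTorch's MKL-DNN (oneDNN) backend for the whole process.

    Returns True if PyTorch is installed and the flag is now off.
    """
    try:
        import torch
    except ImportError:
        return False  # PyTorch is not installed; nothing to disable
    # Process-wide switch; call it early, e.g. before building and training the model.
    torch.backends.mkldnn.enabled = False
    return not torch.backends.mkldnn.enabled
```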

carsen-stringer commented 4 days ago

I've added an option to turn off MKL-DNN: pass mkldnn=False as an input to models.CellposeModel, i.e. models.CellposeModel(mkldnn=False, ...).
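The thread does not say which cellpose release first ships the mkldnn argument, so this sketch checks the installed CellposeModel signature before passing it. cellpose_supports_mkldnn_flag and build_cpu_model are hypothetical helpers; only models.CellposeModel, its gpu argument and the new mkldnn argument come from cellpose.

```python
import inspect


def cellpose_supports_mkldnn_flag() -> bool:
    """Return True if the installed cellpose's CellposeModel accepts `mkldnn`."""
    try:
        from cellpose import models
    except ImportError:
        return False  # cellpose (or one of its dependencies) is not importable
    params = inspect.signature(models.CellposeModel.__init__).parameters
    return "mkldnn" in params


def build_cpu_model(**kwargs):
    """Create a CPU CellposeModel with MKL-DNN turned off when the option exists.

    Extra keyword arguments (e.g. model_type) are passed through unchanged.
    """
    from cellpose import models

    kwargs["gpu"] = False  # the reported error only happens on CPU
    if cellpose_supports_mkldnn_flag():
        kwargs.setdefault("mkldnn", False)
    return models.CellposeModel(**kwargs)
```

On a release that includes the change, build_cpu_model(model_type="cyto") is equivalent to models.CellposeModel(gpu=False, model_type="cyto", mkldnn=False). On older releases the argument is simply not passed, so the PyTorch-level workaround mentioned earlier in the thread is still needed there.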