Lee-Gihun / MEDIAR

(NeurIPS 2022 CellSeg Challenge - 1st Winner) Open source code for "MEDIAR: Harmony of Data-Centric and Model-Centric for Multi-Modality Microscopy"
MIT License
141 stars, 29 forks

Fine-tuning issues #8

hey2homie closed this issue 7 months ago

hey2homie commented 11 months ago

Thank you for this amazing model.

We would like to use MEDIAR in our lab to analyze our mIF images. We have quite specific cell types with very diverse morphologies, so we would like to fine-tune the model. After some minor modifications to the code base (partly to make it easier to incorporate into our pipeline later), I was able to run training on our custom dataset. However, something weird happens as training progresses: performance degrades until the model stops outputting any segmentation masks at all. Below are examples of the same image from the validation set, taken at epochs 1, 3, and 9. During the validation step, F1 values are 0, and usually after epoch 5 no segmentation results are found.

[Figures: the same validation image at epochs 1, 3, and 9]

And here is the slightly modified config file used for training:

{
    "data": {
        "root": "./",
        "mapping_file": "./config/mapping_labeled.json",
        "amplified": false,
        "batch_size": 4,
        "valid_portion": 0.1
    },
    "train_setups": {
        "model": {
            "name": "mediar-former",
            "params": {
            },
            "pretrained": {
                "enabled": true,
                "weights": "./weights/finetuned/from_phase2.pth",
                "strict": false
            }
        },
        "trainer": {
            "name": "mediar",
            "params": {            
                "num_epochs": 60,
                "valid_frequency": 1,
                "device": "cuda:0",
                "amp": false
            }
        },
        "optimizer": {
            "name": "adamw",
            "ft_rate": 1.0,
            "params": {
                "lr": 5e-5
            }
        },
        "scheduler": {
            "enabled": true,
            "name": "cosine",
            "params": {
                "T_max": 60,
                "eta_min": 1e-6
            }
        },
        "seed": 19940817
    },
    "pred_setups": {
        "input_path": "",
        "output_path": "",
        "algo_params": {
            "use_tta": false
        }
    },
    "wandb_setups": {
        "project": "Akoya Data",
        "group": "Test",
        "name": "Test Configs"
    }
}
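For reference, the scheduler block above can be read as follows. This is a sketch assuming the "cosine" scheduler wraps PyTorch's CosineAnnealingLR and is stepped once per epoch (the config itself does not confirm either point); `cosine_lr` is a hypothetical helper that evaluates the closed-form cosine annealing formula with the config's lr, eta_min, and T_max.

```python
import math


def cosine_lr(epoch, base_lr=5e-5, eta_min=1e-6, t_max=60):
    """Closed-form cosine annealing learning rate at a given epoch.

    Assumes the "cosine" scheduler behaves like PyTorch's CosineAnnealingLR
    stepped once per epoch; this is an assumption, not confirmed by the config.
    """
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


for epoch in (0, 15, 30, 45, 60):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch):.2e}")
```

Under these assumptions the learning rate decays from 5e-5 at epoch 0 to about 2.55e-5 at epoch 30 and 1e-6 at epoch 60, so the fine-tuning rate stays small throughout.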

We've also tried fine-tuning from the other available pretrained weights, on a subset of the datasets that were part of the competition, with the baseline model, with different parameters, and so on. All of these result in the same behavior. However, training and fine-tuning a Cellpose model on our data doesn't cause the same problem.

Could you help us identify the issue?

Lee-Gihun commented 11 months ago

Hi, it's strange that even fine-tuning on the challenge datasets results in degraded performance.

For clarification, does this mean when you fine-tuned our code on the challenge datasets, it resulted in an F1-score of 0? Or is this problem specific to your datasets?

If fine-tuning on challenge datasets also leads to a score of 0, it's possible that unspecified package versions might result in different preprocessing outcomes. We may need to reproduce the issue to resolve it. Please provide more details to better understand your issue.

I suggest quickly checking the following:

hey2homie commented 11 months ago

Thanks for the quick reply. On the challenge dataset, F1 starts at around 0.05 and drops to zero within the first 10 or so epochs. On our data, F1 is zero from epoch 1.

I will come back with the package versions and the results of checking the points you mentioned a bit later, as our HPC is under maintenance this week. Thank you!

Lee-Gihun commented 11 months ago

Certainly! If the problem is on our side, I'll quickly work on fixing it. Please provide more details about the issue. Additionally, I suggest starting with the pretrained weights at ./weights/finetuned/phase2.pth. This might not directly relate to the issue of failing to work on the challenge datasets, but it could provide further insights.

hey2homie commented 11 months ago

So, it seems the problem occurs during loss computation. As mentioned in #3, the labels are rotated (90 degrees to the left) and flipped horizontally. Here is an example (the input image before it is fed into the network, the cell probabilities, and the labels):

[Figures: image_0_before_input, image_0_2_before_post_process, labels_0_pre_process]

This would explain why, on the challenge dataset, the model still reaches some F1 in the early stages of training: the cells there are denser, so even misoriented labels partly overlap real cells, whereas in our images the cells are quite sparse.
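One way to confirm which rotation/flip relates the loaded labels to the images is to try all eight rotations and flips of a label mask and keep the one whose foreground best overlaps a correctly oriented reference (for example, the same mask read directly from disk with a reader you trust). The sketch below is a hypothetical helper, not part of MEDIAR; `foreground_iou`, `dihedral_variants`, and `best_orientation` are illustrative names. Its demo also shows a pixel-level analogue of the density argument above, using random masks as a stand-in for real cells.

```python
import numpy as np


def foreground_iou(a, b):
    """IoU of the foreground (nonzero pixels) of two equally shaped masks."""
    a, b = a > 0, b > 0
    union = np.logical_or(a, b).sum()
    return float(np.logical_and(a, b).sum() / union) if union else 1.0


def dihedral_variants(arr):
    """Yield (name, array) for the 8 rotations/flips of a 2-D array."""
    for k in range(4):
        rotated = np.rot90(arr, k)  # counterclockwise by 90 * k degrees
        yield f"rot90(k={k})", rotated
        yield f"flipud(rot90(k={k}))", np.flipud(rotated)


def best_orientation(label, reference):
    """Return the variant of `label` whose foreground best overlaps `reference`."""
    scores = {
        name: foreground_iou(variant, reference)
        for name, variant in dihedral_variants(label)
        if variant.shape == reference.shape
    }
    return max(scores, key=scores.get), scores


rng = np.random.default_rng(0)
for density in (0.6, 0.05):
    reference = (rng.random((64, 64)) < density).astype(np.uint16)
    loaded = reference.T  # simulate a loader that swaps H and W
    best, scores = best_orientation(loaded, reference)
    print(f"density={density}: uncorrected IoU={foreground_iou(loaded, reference):.2f}, "
          f"best fix={best} (IoU={scores[best]:.2f})")
```

For independent random pixels at density p, the expected overlap between a mask and its transpose is roughly p / (2 - p), so a dense mask keeps substantial overlap after misorientation while a sparse one drops close to zero.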

Here is the debugging snippet from Trainer.py:

  # Forward pass
  # Debugging only: dump the first two images of the batch, the raw network
  # outputs, and the labels to PNG files, then stop (needs matplotlib.pyplot as plt).
  with torch.cuda.amp.autocast(enabled=self.amp):
      with torch.set_grad_enabled(phase == "train"):
          # Output shape is B x [grad y, grad x, cellprob] x H x W
          outputs = self._inference(images, phase)

          # Copy tensors to host memory; plt.imsave cannot read CUDA tensors
          images_np = images.detach().cpu().numpy()
          outputs_np = outputs.detach().float().cpu().numpy()
          labels_np = labels.detach().cpu().numpy()

          for i in range(2):  # first two items of the batch
              plt.imsave(arr=images_np[i][0], fname=f"image_{i}_before_input.png")
              for c in range(3):  # 0: grad y, 1: grad x, 2: cellprob
                  plt.imsave(arr=outputs_np[i][c], fname=f"image_{i}_{c}_before_post_process.png")
              plt.imsave(arr=labels_np[i][0], fname=f"labels_{i}_pre_process.png")

          raise Exception("test")  # stop after the first batch

Could the problem be in how .tiff files are read in LoadImage.py? I'm having trouble digging into that code.
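One way to test the LoadImage.py hypothesis in isolation is a round trip with a non-square array whose pixels all have distinct values: a square or symmetric test image can hide a height/width swap, while this probe cannot. In the sketch below, `make_probe`, `check_orientation`, and the `load_fn` callback are hypothetical names; `load_fn` stands in for "write the array in the format under test, then read it back through the loader being checked".

```python
import numpy as np


def make_probe(height=5, width=7):
    """Non-square array with a unique value per pixel, so any transpose,
    rotation, or flip changes either its shape or its contents."""
    return np.arange(height * width, dtype=np.uint16).reshape(height, width)


def check_orientation(load_fn, probe):
    """Compare what `load_fn` returns against the array that was written.

    `load_fn` is a hypothetical stand-in: it should save `probe` in the format
    under test and read it back through the loader being checked.
    """
    loaded = np.asarray(load_fn(probe)).squeeze()  # drop singleton channel axes
    if loaded.shape == probe.shape and np.array_equal(loaded, probe):
        return "ok"
    if loaded.shape == probe.T.shape and np.array_equal(loaded, probe.T):
        return "transposed (H and W swapped)"
    return f"mismatch: wrote {probe.shape}, read {loaded.shape}"


probe = make_probe()
print(check_orientation(lambda a: a.copy(), probe))  # a loader that behaves
print(check_orientation(lambda a: a.T, probe))       # a loader that swaps axes
```

In practice, `load_fn` could write the probe with `tifffile.imwrite` and read it back through the same loading transform the training pipeline uses; a "transposed" result would point at the loader rather than at the files on disk.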

And here are the packages used on the HPC (A100). I've also tried running locally on an M1 CPU, with the same results:

Package                           Version
--------------------------------- -----------
alabaster                         0.7.12
appdirs                           1.4.4
asn1crypto                        1.5.1
atomicwrites                      1.4.0
attrs                             21.4.0
Babel                             2.10.1
backports.entry-points-selectable 1.1.1
backports.functools-lru-cache     1.6.4
bcrypt                            3.2.2
beniget                           0.4.1
bitstring                         3.1.9
blist                             1.3.6
Bottleneck                        1.3.4
CacheControl                      0.12.11
cachy                             0.3.0
cellpose                          2.2.3
certifi                           2021.10.8
cffi                              1.15.0
chardet                           4.0.0
charset-normalizer                2.0.12
cleo                              0.8.1
click                             8.1.3
clikit                            0.6.2
colorama                          0.4.4
contourpy                         1.2.0
crashtest                         0.3.1
cryptography                      37.0.1
cycler                            0.12.1
Cython                            0.29.28
deap                              1.3.3
decorator                         5.1.1
distlib                           0.3.4
docker-pycreds                    0.4.0
docopt                            0.6.2
docutils                          0.17.1
ecdsa                             0.17.0
editables                         0.3
efficientnet-pytorch              0.7.1
einops                            0.7.0
expecttest                        0.1.3
fastremap                         1.14.0
filelock                          3.6.0
flit                              3.7.1
flit_core                         3.7.1
fonttools                         4.46.0
fsspec                            2022.3.0
future                            0.18.2
gast                              0.5.3
gitdb                             4.0.9
GitPython                         3.1.27
glob2                             0.7
html5lib                          1.1
huggingface-hub                   0.13.4
idna                              3.3
imagecodecs                       2023.9.18
imageio                           2.31.6
imagesize                         1.3.0
importlib-metadata                4.11.3
importlib-resources               5.7.1
iniconfig                         1.1.1
inplace-abn                       1.1.0
intervaltree                      3.1.0
intreehooks                       1.0
ipaddress                         1.0.23
jeepney                           0.8.0
Jinja2                            3.1.2
joblib                            1.1.0
jsonschema                        4.4.0
keyring                           23.5.0
keyrings.alt                      4.1.0
kiwisolver                        1.4.5
lazy_loader                       0.3
liac-arff                         2.5.0
llvmlite                          0.41.1
lockfile                          0.12.2
MarkupSafe                        2.1.1
matplotlib                        3.8.2
mock                              4.0.3
monai                             1.3.0
more-itertools                    8.12.0
mpi4py                            3.1.3
mpmath                            1.2.1
msgpack                           1.0.3
munch                             4.0.0
natsort                           8.4.0
netaddr                           0.8.0
netifaces                         0.11.0
networkx                          3.2.1
numba                             0.58.0
numexpr                           2.8.1
numpy                             1.22.3
nvidia-cublas-cu12                12.1.3.1
nvidia-cuda-cupti-cu12            12.1.105
nvidia-cuda-nvrtc-cu12            12.1.105
nvidia-cuda-runtime-cu12          12.1.105
nvidia-cudnn-cu12                 8.9.2.26
nvidia-cufft-cu12                 11.0.2.54
nvidia-curand-cu12                10.3.2.106
nvidia-cusolver-cu12              11.4.5.107
nvidia-cusparse-cu12              12.1.0.106
nvidia-nccl-cu12                  2.18.1
nvidia-nvjitlink-cu12             12.3.52
nvidia-nvtx-cu12                  12.1.105
opencv-python-headless            3.4.18.65
packaging                         23.2
pandas                            1.4.2
paramiko                          2.10.4
pastel                            0.2.1
pathlib2                          2.3.7.post1
pathspec                          0.9.0
pathtools                         0.1.2
pbr                               5.8.1
pexpect                           4.8.0
Pillow                            9.2.0
pip                               22.0.4
pkginfo                           1.8.2
platformdirs                      2.4.1
pluggy                            1.0.0
ply                               3.11
poetry                            1.1.13
poetry-core                       1.0.8
pretrainedmodels                  0.7.4
promise                           2.3
protobuf                          3.19.4
psutil                            5.9.0
ptyprocess                        0.7.0
py                                1.11.0
py-expression-eval                0.3.14
pyasn1                            0.4.8
pybind11                          2.9.2
pycparser                         2.21
pycryptodome                      3.17
Pygments                          2.12.0
pylev                             1.4.0
PyNaCl                            1.5.0
pyparsing                         3.0.8
pyrsistent                        0.18.1
pytest                            7.1.2
python-dateutil                   2.8.2
pythran                           0.11.0
pytoml                            0.1.21
pytz                              2022.1
PyYAML                            6.0
regex                             2022.4.24
requests                          2.27.1
requests-toolbelt                 0.9.1
roifile                           2023.8.30
safetensors                       0.3.0
scandir                           1.10.0
scikit-image                      0.22.0
SciPy                             1.8.1
SecretStorage                     3.3.2
semantic-version                  2.9.0
sentry-sdk                        1.8.0
setproctitle                      1.3.2
setuptools                        62.1.0
setuptools-rust                   1.3.0
setuptools-scm                    6.4.2
shellingham                       1.4.0
shortuuid                         1.0.9
simplegeneric                     0.8.1
simplejson                        3.17.6
six                               1.16.0
smmap                             5.0.0
snowballstemmer                   2.2.0
sortedcontainers                  2.4.0
Sphinx                            4.5.0
sphinx-bootstrap-theme            0.8.1
sphinxcontrib-applehelp           1.0.2
sphinxcontrib-devhelp             1.0.2
sphinxcontrib-htmlhelp            2.0.0
sphinxcontrib-jsmath              1.0.1
sphinxcontrib-qthelp              1.0.3
sphinxcontrib-serializinghtml     1.1.5
sphinxcontrib-websupport          1.2.4
sympy                             1.12
tabulate                          0.8.9
termcolor                         1.1.0
threadpoolctl                     3.1.0
tifffile                          2023.9.26
timm                              0.6.13
toml                              0.10.2
tomli                             2.0.1
tomli_w                           1.0.0
tomlkit                           0.10.2
torch                             1.12.0
torchvision                       0.13.1
tqdm                              4.64.0
triton                            2.1.0
typing_extensions                 4.2.0
ujson                             5.2.0
urllib3                           1.26.9
virtualenv                        20.14.1
wandb                             0.13.4
wcwidth                           0.2.5
webencodings                      0.5.1
wheel                             0.37.1
xlrd                              2.0.1
yaspin                            2.1.0
zipfile36                         0.1.3
zipp                              3.8.0

hey2homie commented 11 months ago

Just to be sure, I've checked that the label files are written correctly in the first place, and they are (we also use the same dataset to train Cellpose, with no issues there):

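import matplotlib.pyplot as plt

# Read an image and its label mask and save them back out for visual
# inspection; if the two PNGs line up, the files on disk are oriented correctly.
img = plt.imread("./Images/146_rgb.png")
plt.imsave("img.png", img)

label = plt.imread("./Labels/146_masks.png")
plt.imsave("label.png", label)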

[Figures: img.png, label.png]

pakiessling commented 9 months ago

Hi @hey2homie, did you ever figure out a way to fine-tune MEDIAR? I would be very interested.

hey2homie commented 9 months ago

@pakiessling, @Lee-Gihun, I've finally had time to look more closely at the label issue. Although I haven't found why the labels are rotated/flipped (the file reader uses tifffile, which I use in another project without any issues), I have an ad-hoc fix for the problem I'm experiencing. In a nutshell, in transforms.py I disable both RandAxisFlipd and RandRotate90d for the image, and I add these two transforms to both the validation and training transforms:

Rotate90d(k=1, keys=["label"], spatial_axes=(0, 1)),
Flipd(keys=["label"], spatial_axis=0),
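Assuming MONAI's Rotate90d and Flipd follow the same conventions as NumPy's `np.rot90` (rotating from the first spatial axis toward the second) and `np.flip` on the spatial axes, the two label transforms above compose to a plain swap of the label's height and width axes. The NumPy sketch below checks this on a channel-first array; `rotate90_then_flip` is an illustrative helper, not MEDIAR or MONAI code. If the equivalence holds, the fix is consistent with the loader transposing the labels rather than applying a genuine rotation plus flip.

```python
import numpy as np


def rotate90_then_flip(label_chw):
    """Assumed NumPy equivalent of Rotate90d(k=1, spatial_axes=(0, 1))
    followed by Flipd(spatial_axis=0) on a channel-first (C, H, W) array.
    Spatial axes 0 and 1 are array axes 1 and 2 because axis 0 is the channel."""
    rotated = np.rot90(label_chw, k=1, axes=(1, 2))
    return np.flip(rotated, axis=1)


label = np.arange(3 * 4).reshape(1, 3, 4)  # (C, H, W), non-square on purpose
corrected = rotate90_then_flip(label)
print(corrected.shape)  # (1, 4, 3)
print(np.array_equal(corrected, label.transpose(0, 2, 1)))  # True
```

A transpose expressed directly (for example, swapping the two spatial axes of the label) would make the intent clearer than a rotation followed by a flip, though both give the same array.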

It works like a charm, and I'm now actually getting very good results when fine-tuning. I don't know whether I should submit a pull request, as it seems to me the cause is purely the package versions.

pakiessling commented 9 months ago

@hey2homie That's awesome! I am going to try my luck. Do you mind if I shoot you a quick message on how you approached things if I run into problems?

hey2homie commented 9 months ago

@pakiessling, not at all, and good luck with your work!

Lee-Gihun commented 7 months ago

Sorry for the inconvenience. I've been swamped with preparations for my Ph.D. graduation...

I believe the root of the problem is mismatched versions of related packages, which caused unexpected behavior in the following custom loading pipeline:

https://github.com/Lee-Gihun/MEDIAR/blob/9c8b9eead41d75116765cc35e1f26f867f552fb7/train_tools/data_utils/custom/LoadImage.py#L110-L158

I have now updated the related package versions to the latest ones and verified that the loss decreases steadily as training progresses.
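Because the resolution hinges on package versions, it can help to compare an environment against a set of pins before training. A minimal sketch using only the standard library's `importlib.metadata`; `compare_versions` is a hypothetical helper, and the `pins` dictionary reuses versions from the pip list above purely as placeholders, not as MEDIAR's requirements.

```python
from importlib import metadata


def compare_versions(expected):
    """Return {package: (expected, installed)} for every package whose installed
    version differs from `expected`; `installed` is None if it is missing."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches


# Hypothetical pins: versions copied from the pip list earlier in this thread,
# used only as an example, NOT MEDIAR's official requirements.
pins = {"monai": "1.3.0", "tifffile": "2023.9.26", "torch": "1.12.0"}
for package, (wanted, installed) in compare_versions(pins).items():
    print(f"{package}: expected {wanted}, found {installed}")
```

The comparison is an exact string match, so a build suffix such as a CUDA tag on torch is reported as a mismatch even when the base version agrees.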

Please reopen this issue if the problem reoccurs.