maciej-sypetkowski / kaggle-rcic-1st

1st Place Solution for Kaggle Recursion Cellular Image Classification Challenge -- https://www.kaggle.com/c/recursion-cellular-image-classification/
MIT License

sirnas are not parsed correctly #4

Closed: alxndrkalinin closed this issue 4 years ago

alxndrkalinin commented 4 years ago

Hi, thanks for sharing your code!

I'm trying to run training with the first command (with or without --all-controls-train 0):

python main.py -e 130 --pl-epoch 90 --lr cosine,1.5e-4,90,6e-5,150,0 --pl-size-func 0.6*x+0.4 --cv-number -1 --seed 0 --save /results/dn161_0

It seems there is an issue with parsing sirna numbers when reading them from the .csv file:

Traceback (most recent call last):
  File "main.py", line 657, in <module>
    main(args)
  File "main.py", line 646, in main
    train(args, model)
  File "main.py", line 549, in train
    train_loader, val_loader = dataset.get_train_val_loader(args)
  File "/home/user/projects/kaggle-rcic-1st/dataset.py", line 95, in get_train_val_loader
    normalization=args.data_normalization,
  File "/home/user/projects/kaggle-rcic-1st/dataset.py", line 249, in __init__
    if not hasattr(r, "sirna") or r.sirna < self.treatment_classes:
TypeError: '<' not supported between instances of 'NoneType' and 'int'

I noticed you use 1138 later as an extra id, so I assumed it is for UNTREATED controls and tried parsing sirnas like this:

if hasattr(r, "sirna"):
  sirna = r.sirna.split('_')
  r.sirna = int(sirna[1]) if len(sirna) == 2 else 1138

but I still ran into an assertion error:

Traceback (most recent call last):
  File "main.py", line 657, in <module>
    main(args)
  File "main.py", line 646, in main
    train(args, model)
  File "main.py", line 549, in train
    train_loader, val_loader = dataset.get_train_val_loader(args)
  File "/home/user/projects/kaggle-rcic-1st/dataset.py", line 95, in get_train_val_loader
    normalization=args.data_normalization,
  File "/home/user/projects/kaggle-rcic-1st/dataset.py", line 260, in __init__
    assert sirna < self.treatment_classes
AssertionError

Could you please help me understand how sirna ids should be parsed?

maciej-sypetkowski commented 4 years ago

That's interesting. For some reason, the sirna column in the Kaggle dataset has changed. Previously, sirna was just an integer, and the untreated class was 1138. Now a sirna_ prefix has been added and the sirna ids themselves have changed, so 1108 is no longer the split point between treatment and control sirnas. Try downloading the dataset from the official webpage (https://www.rxrx.ai/rxrx1) and it should work.
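A minimal sketch of what the prefix change means for parsing, assuming the newer Kaggle labels look like `sirna_<number>` plus a non-numeric value such as `UNTREATED` for untreated wells. The function name `parse_kaggle_sirna` is hypothetical, not part of this repo. Stripping the prefix recovers a number, but it is in the new numbering, so it still cannot be compared against 1108 directly:

```python
import re
from typing import Optional

# Assumed label shape in the newer Kaggle CSVs: "sirna_<number>",
# plus a non-numeric value (e.g. "UNTREATED") for untreated controls.
_SIRNA_RE = re.compile(r"^sirna_(\d+)$")


def parse_kaggle_sirna(label: str) -> Optional[int]:
    """Strip the 'sirna_' prefix and return the raw numeric id.

    Returns None for labels without a number (e.g. 'UNTREATED').
    The result uses the *new* Kaggle numbering, so it must not be
    compared against the 1108 treatment/control split point.
    """
    match = _SIRNA_RE.match(label.strip())
    return int(match.group(1)) if match else None


if __name__ == "__main__":
    print(parse_kaggle_sirna("sirna_706"))  # 706
    print(parse_kaggle_sirna("UNTREATED"))  # None
```

This is why the user's split('_') attempt still tripped the assertion: the numbers parse, but they no longer encode which sirnas are treatments.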

maciej-sypetkowski commented 4 years ago

The dataset from the official webpage also has a slightly different format, but its sirna numbering is correct. You can either build a mapping between the two numberings, or assign indices < 1108 to all treatment sirnas (i.e. those in train.csv) and indices >= 1108 to all control sirnas.
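The second option can be sketched as follows, assuming you have already collected the string sirna labels from train.csv (treatments) and from the controls CSV(s) in your copy of the dataset. `build_sirna_index` and its parameters are hypothetical, not the repo's code. The order within each group is arbitrary (sorted strings), so this reproduces only the treatment/control split, not the official RxRx1 ids:

```python
from typing import Dict, Iterable


def build_sirna_index(
    treatment_labels: Iterable[str],
    control_labels: Iterable[str],
    num_treatment: int = 1108,  # split point mentioned in this thread
    untreated_label: str = "UNTREATED",  # assumed label for untreated wells
) -> Dict[str, int]:
    """Map string sirna labels to integer class indices.

    Treatment sirnas get 0..num_treatment-1; control sirnas get
    indices >= num_treatment, with the untreated label placed last.
    """
    treatments = sorted(set(treatment_labels))
    if len(treatments) != num_treatment:
        raise ValueError(
            f"expected {num_treatment} treatment sirnas, got {len(treatments)}"
        )
    # Any control label that also appears as a treatment keeps its treatment index.
    controls = sorted(set(control_labels) - set(treatments))
    if untreated_label in controls:
        # Move untreated to the end of the control range.
        controls.remove(untreated_label)
        controls.append(untreated_label)

    mapping = {label: idx for idx, label in enumerate(treatments)}
    for idx, label in enumerate(controls, start=num_treatment):
        mapping[label] = idx
    return mapping
```

Each row's sirna label can then be replaced by `mapping[label]` before the dataset code compares it with `self.treatment_classes`. Placing the untreated label last is a choice made here to mirror the old numbering, where untreated was 1138.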

alxndrkalinin commented 4 years ago

Thanks a lot! It seems like I was able to map IDs and the training works now.