ivanpanshin / SupCon-Framework

Implementation of Supervised Contrastive Learning with AMP, EMA, SWA, and many other tricks
MIT License

ValueError in Stage2 #28

Closed: JiarunLiu closed this issue 2 years ago

JiarunLiu commented 2 years ago

Hi, I'm trying to run your code on CIFAR-10. Training and SWA in stage 1 ran fine, but I got the following error when training stage 2:

root@864d7f9c24b4:/SupCon-Framework# python train.py --config_name configs/train/train_supcon_resnet18_cifar10_stage2.yml 
{'model': {'backbone': 'resnet18', 'ckpt_pretrained': 'weights/supcon_first_stage_cifar10/swa', 'num_classes': 10}, 'train': {'n_epochs': 20, 'amp': True, 'ema': True, 'ema_decay_per_epoch': 0.3, 'logging_name': 'supcon_second_stage_cifar10', 'target_metric': 'accuracy', 'stage': 'second'}, 'dataset': 'data/cifar10', 'dataloaders': {'train_batch_size': 20, 'valid_batch_size': 20, 'num_workers': 12}, 'optimizer': {'name': 'SGD', 'params': {'lr': 0.01}}, 'scheduler': {'name': 'CosineAnnealingLR', 'params': {'T_max': 20, 'eta_min': 0.001}}, 'criterion': {'name': 'LabelSmoothing', 'params': {'classes': 10, 'smoothing': 0.01}}}
Files already downloaded and verified
Files already downloaded and verified
Traceback (most recent call last):
  File "train.py", line 111, in <module>
    train_metrics = utils.train_epoch_ce(loaders['train_features_loader'], model, criterion, optimizer, scaler, ema)
  File "/SupCon-Framework/tools/utils.py", line 250, in train_epoch_ce
    ema.update(model.parameters())
  File "/usr/local/lib/python3.8/dist-packages/torch_ema/ema.py", line 88, in update
    parameters = self._get_parameters(parameters)
  File "/usr/local/lib/python3.8/dist-packages/torch_ema/ema.py", line 65, in _get_parameters
    raise ValueError(
ValueError: Number of parameters passed as argument is different from number of shadow parameters maintained by this ExponentialMovingAverage
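One plausible reading of this traceback, not confirmed anywhere in this thread, is the following. The stage-2 config loads a pretrained backbone (`ckpt_pretrained`) and may freeze it. The pinned torch_ema commit may keep shadow copies only for parameters with `requires_grad=True`, while `update()` counts every parameter it receives. The pure-Python sketch below mimics that assumed behaviour using a hypothetical `Param` class and `ToyEMA` helper; neither exists in SupCon-Framework or torch_ema. It shows how the count mismatch arises, and how passing one consistent list of trainable parameters to both the constructor and every update avoids it.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Param:
    """Hypothetical stand-in for a model tensor: one float plus a requires_grad flag."""
    value: float
    requires_grad: bool = True


class ToyEMA:
    """Simplified EMA helper, NOT the torch_ema source.

    Assumed behaviour being mimicked: shadow copies are created only for
    trainable parameters, but update() compares the count of *all*
    parameters passed in against the number of shadow copies.
    """

    def __init__(self, parameters: List[Param], decay: float):
        self.decay = decay
        self.shadow = [p.value for p in parameters if p.requires_grad]

    def update(self, parameters: List[Param]) -> None:
        parameters = list(parameters)
        if len(parameters) != len(self.shadow):
            raise ValueError("Number of parameters passed as argument is different "
                             "from number of shadow parameters")
        trainable = [p for p in parameters if p.requires_grad]
        for i, p in enumerate(trainable):
            # Standard EMA step: shadow <- decay * shadow + (1 - decay) * param
            self.shadow[i] = self.decay * self.shadow[i] + (1 - self.decay) * p.value


def try_update(ema: ToyEMA, params: List[Param]) -> Optional[str]:
    """Return the error message if update() fails, else None."""
    try:
        ema.update(params)
        return None
    except ValueError as err:
        return str(err)


# Hypothetical stage-2 model: two frozen backbone weights plus one trainable head weight.
all_params = [Param(0.5, requires_grad=False), Param(-1.0, requires_grad=False), Param(2.0)]

# Passing every parameter (like model.parameters()) trips the count check: 3 vs 1.
failing = ToyEMA(all_params, decay=0.9)
print(try_update(failing, all_params))

# Workaround: build the trainable list once and pass the same list everywhere.
trainable = [p for p in all_params if p.requires_grad]
working = ToyEMA(trainable, decay=0.9)
trainable[0].value = 3.0  # pretend an optimizer step changed the head weight
print(try_update(working, trainable), working.shadow)
```

If this hypothesis holds, the analogous change in the real training script would be to build the trainable-parameter list once. That same list would then go to the EMA constructor and to every `update()` call, instead of calling `model.parameters()` each time. This is an untested suggestion, not a fix confirmed by the maintainer.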

Another, minor, problem is GPU usage. I previously ran another implementation of SupContrast, and it needs 8x the GPU memory (and higher utilization of each GPU) to train stage 1 with the same backbone and batch size. Do you know what causes that difference?

ivanpanshin commented 2 years ago

Hey!

  1. That's weird. Worst case, you can just remove EMA from the code, but it's better to keep it. As a sanity check, could you share your package versions?

  2. Pretty sure I do. I guess you're referring to this implementation. As I noted in my README, that implementation is very faulty. In particular, its ResNet implementation is incorrect and very memory-inefficient, which is why you see this 8x memory usage.
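Regarding point 1 (keeping EMA), it may help to see what the `ema_decay_per_epoch: 0.3` entry in the stage-2 config plausibly controls. The sketch below assumes, without confirmation from this thread, that the per-epoch value is spread geometrically over the iterations of one epoch, giving a per-step decay of `0.3 ** (1 / iters_per_epoch)`. The iteration count uses CIFAR-10's 50,000 training images and the config's `train_batch_size: 20`. The function name `per_step_decay` is hypothetical.

```python
import math


def per_step_decay(decay_per_epoch: float, iters_per_epoch: int) -> float:
    """Per-iteration EMA decay whose product over one epoch equals decay_per_epoch.

    Assumption (not confirmed in this thread): the per-epoch setting is spread
    geometrically across the iterations of one epoch.
    """
    return decay_per_epoch ** (1.0 / iters_per_epoch)


# CIFAR-10 has 50,000 training images; train_batch_size is 20 in the config above.
iters = math.ceil(50_000 / 20)           # 2500 updates per epoch
step_decay = per_step_decay(0.3, iters)  # roughly 0.99952

# After a full epoch of updates, the old shadow weights keep about a 0.3 share.
print(round(step_decay, 5), round(step_decay ** iters, 6))
```

Under this assumption, each step moves the shadow weights only about 0.05% toward the live weights. The averaged model therefore changes slowly, which is the smoothing that is lost when EMA is switched off.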

ivanpanshin commented 2 years ago

Sorry, I didn't see that you included the link to the other implementation. In that case, I'm fairly sure my explanation above is why it behaves that way.

JiarunLiu commented 2 years ago

Thanks for your reply. It works when ema=False, so I guess the problem is caused by the EMA module. Here is my pip environment:

absl-py==0.11.0
albumentations==1.1.0
argon2-cffi==20.1.0
async-generator==1.10
attrs==20.3.0
backcall==0.2.0
bleach==3.2.1
cachetools==4.2.0
certifi==2020.12.5
cffi==1.14.4
chardet==4.0.0
cycler==0.10.0
dataclasses==0.6
decorator==4.4.2
defusedxml==0.6.0
entrypoints==0.3
faiss-gpu==1.7.1.post2
google-auth==1.24.0
google-auth-oauthlib==0.4.2
grpcio==1.34.0
idna==2.10
imageio==2.9.0
imgaug==0.4.0
ipykernel==5.4.3
ipython==7.16.1
ipython-genutils==0.2.0
ipywidgets==7.6.3
jedi==0.18.0
Jinja2==2.11.2
joblib==1.0.0
jsonschema==3.2.0
jupyter-client==6.1.11
jupyter-console==6.2.0
jupyter-core==4.7.0
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
kiwisolver==1.3.1
Markdown==3.3.3
MarkupSafe==1.1.1
matplotlib==3.3.3
mistune==0.8.4
nbclient==0.5.1
nbconvert==6.0.7
nbformat==5.0.8
nest-asyncio==1.4.3
networkx==2.5
notebook==6.1.6
numpy==1.19.4
oauthlib==3.1.0
opencv-python==4.4.0.46
opencv-python-headless==4.4.0.46
packaging==20.8
pandas==1.1.5
pandocfilters==1.4.3
parso==0.8.1
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.0.1
prometheus-client==0.9.0
prompt-toolkit==3.0.10
protobuf==3.14.0
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
Pygments==2.7.4
PyGObject==3.26.1
pyparsing==2.4.7
pyrsistent==0.17.3
python-apt==1.6.5+ubuntu0.7
python-dateutil==2.8.1
pytorch-metric-learning==0.9.95
pytorch-ranger==0.1.1
pytz==2020.4
PyWavelets==1.1.1
PyYAML==5.3.1
pyzmq==20.0.0
qtconsole==5.0.1
QtPy==1.9.0
qudida==0.0.4
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.6
scikit-image==0.17.2
scikit-learn==0.23.2
scipy==1.5.4
Send2Trash==1.5.0
Shapely==1.7.1
six==1.15.0
tensorboard==2.4.0
tensorboard-plugin-wit==1.7.0
terminado==0.9.2
testpath==0.4.4
threadpoolctl==2.1.0
tifffile==2020.9.3
timm==0.3.4
torch==1.7.1
torch-ema @ git+https://github.com/fadel/pytorch_ema@3985995e523aa25dd3cff7e7984130eef90a4282
torch-lr-finder==0.2.1
torch-optimizer==0.0.1a17
torchvision==0.8.2
tornado==6.1
tqdm==4.54.1
traitlets==4.3.3
typing-extensions==3.7.4.3
unattended-upgrades==0.1
urllib3==1.26.2
wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==1.0.1
widgetsnbextension==3.5.1
zipp==3.4.0
JiarunLiu commented 2 years ago

Email received!