Closed timephy closed 2 years ago
Will update when I gain new insights on whether the error keeps occurring even with newer versions...
Hi @timephy, Thanks for the amazingly documented issue report! This is unfortunately something @RasmusOrsoe, myself, and others have bumped into in the past. If updating the torch/CUDA version actually fixes the problem, it would be fantastic news! Please keep us updated, and if you manage to complete a dozen or so training runs without encountering this problem, it would be brilliant if you could propose a PR! 🥳
Hey, just a quick update: The problem seems to be resolved when using the newer versions specified above! Will put in a new PR @asogaard
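As a rough sanity check on the "dozen or so training runs" suggestion above, here is a small sketch of how likely it would be to see that many clean runs purely by luck if the bug were still present. It assumes the runs fail independently and uses the failure rate reported in the issue (roughly every second or third run); `prob_all_runs_pass` is a hypothetical helper, not part of graphnet.

```python
def prob_all_runs_pass(failure_rate: float, n_runs: int) -> float:
    """Chance that n_runs independent runs all succeed although the bug is still present."""
    return (1.0 - failure_rate) ** n_runs


# Failure rates taken from the report: every second or every third run.
for rate in (1 / 2, 1 / 3):
    p = prob_all_runs_pass(rate, 12)
    print(f"failure rate {rate:.2f}: P(12 clean runs by luck) = {p:.4f}")
```

Under that independence assumption, twelve clean runs in a row would be unlikely (below 1%) unless the failure rate had actually dropped.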
Description
Reporting a weird problem: sometimes during training or inference the program throws an error, which is logged to stdout, but it does not crash entirely. Instead it freezes and can only be quit with `Control+C` (which throws more errors).
The problem occurs in this fashion:
When it occurs, the last lines of stdout are the ones above. It is not always Epoch 20, of course; that varies. It also does not happen in every training attempt. It is somewhat random, roughly every second or third run, so it is hard to pinpoint.
Sorry to report something so vague, but I feel like it should be written down somewhere because it has happened about 4-6 times to me so far...
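Because the process hangs instead of exiting, it may help to record where the main process is stuck. The following is a minimal sketch using only the standard library's `faulthandler`; `run_with_hang_dump` and the log path are hypothetical names (not part of graphnet), and the idea is to wrap the `trainer.fit(...)` call with it.

```python
import faulthandler
import os
import tempfile


def run_with_hang_dump(fn, timeout_s, log_path):
    """Run fn(); while it is still running, dump all thread stacks every timeout_s seconds."""
    with open(log_path, "a") as log_file:
        # The watchdog writes tracebacks to log_file without interrupting fn().
        faulthandler.dump_traceback_later(timeout_s, repeat=True, file=log_file)
        try:
            return fn()
        finally:
            # Always disarm the watchdog before the log file is closed.
            faulthandler.cancel_dump_traceback_later()


# Demo with a fast stand-in for the training call.
log_path = os.path.join(tempfile.gettempdir(), "hang_traceback.log")
result = run_with_hang_dump(lambda: sum(range(10)), timeout_s=600.0, log_path=log_path)
print(result)
```

Note that this only dumps the stacks of the process it runs in, so DataLoader worker processes are not covered. The timeout should be chosen longer than one normal epoch, otherwise healthy runs also produce dumps.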
My program and its log
Basically just a simple training script like `examples/train_model.py` with a very reasonable training 'scale':
Click the sections below for details
▸ Program
```python
def train_test(*, target, database, pulsemap, batch_size, num_workers, gpus, max_epochs, patience, run_name, archive):
    print()
    print(f'===== train({target=}) =====')
    print(f'{features=}')
    print(f'{truth=}')

    # Data
    if target == 'track':
        train_valid_selection, test_selection = get_even_track_cascade_indicies(database)
    elif target == 'energy' or target == 'zenith':
        selection = get_desired_event_numbers(
            database, 10000000000, fraction_muon=0, fraction_nu_e=0.34, fraction_nu_mu=0.33, fraction_nu_tau=0.33  # type: ignore
        )
        train_valid_selection, test_selection = train_test_split(selection, test_size=0.25, random_state=42)
    else:
        raise Exception('target does not match')

    training_dataloader, validation_dataloader = make_train_validation_dataloader(  # type: ignore
        db=database,
        selection=train_valid_selection,
        pulsemaps=pulsemap,
        features=features,
        truth=truth,
        batch_size=batch_size,
        num_workers=num_workers,
        test_size=0.33,
    )

    test_dataloader = make_dataloader(
        db=database,
        pulsemaps=pulsemap,
        features=features,
        truth=truth,
        selection=test_selection,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    # Building model
    detector = IceCubeDeepCore(graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8))
    gnn = DynEdge(nb_inputs=detector.nb_outputs)
    if target == 'track':
        task = BinaryClassificationTask(
            hidden_size=gnn.nb_outputs,
            target_labels=target,
            loss_function=BinaryCrossEntropyLoss(),
        )
    elif target == 'energy':
        # task = EnergyReconstruction(
        #     hidden_size=gnn.nb_outputs,
        #     target_labels=target,
        #     loss_function=LogCoshLoss(),
        #     transform_prediction_and_target=torch.log10,
        # )
        task = PassOutput1(
            hidden_size=gnn.nb_outputs,
            target_labels=target,
            loss_function=LogCoshLoss(),
            transform_target=torch.log10,
            transform_inference=lambda x: torch.pow(10, x)
        )
    elif target == 'zenith':
        task = ZenithReconstructionWithKappa(
            hidden_size=gnn.nb_outputs,
            target_labels=target,
            loss_function=VonMisesFisher2DLoss(),
        )
    else:
        raise Exception('target does not match')

    model = Model(
        detector=detector,
        gnn=gnn,
        tasks=[task],
        optimizer_class=Adam,
        optimizer_kwargs={'lr': 1e-03, 'eps': 1e-03},
        scheduler_class=PiecewiseLinearLR,
        scheduler_kwargs={
            'milestones': [0, len(training_dataloader) / 2, len(training_dataloader) * max_epochs],
            'factors': [1e-2, 1, 1e-02],
        },
        scheduler_config={
            'interval': 'step',
        },
    )

    # Training model
    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=patience,
        ),
        ProgressBar(),
    ]

    trainer = Trainer(
        gpus=gpus,
        max_epochs=max_epochs,
        callbacks=callbacks,
        log_every_n_steps=1,
        # logger=wandb_logger,
    )

    # try:
    trainer.fit(model, training_dataloader, validation_dataloader)
    # except KeyboardInterrupt:
    #     print('[ctrl+c] Exiting gracefully.')
    #     exit()

    # model.save(os.path.join(archive, f'{run_name}.pth'))
    # model.save_state_dict(os.path.join(archive, f'{run_name}_state_dict.pth'))

    # Saving predictions to file
    if target == 'track':
        results = get_predictions(
            trainer,
            model,
            test_dataloader,
            prediction_columns=[target + '_pred'],
            additional_attributes=[target, 'event_no', 'energy'],
        )
    elif target == 'energy':
        results = get_predictions(
            trainer,
            model,
            test_dataloader,
            prediction_columns=[target + '_pred'],
            additional_attributes=[target, 'event_no'],
        )
    elif target == 'zenith':
        results = get_predictions(
            trainer,
            model,
            test_dataloader,
            prediction_columns=[target + '_pred', target + '_kappa'],
            additional_attributes=[target, 'event_no', 'energy'],
        )
    else:
        raise Exception('target does not match')

    # save_results(database, run_name, results, archive, model)
    os.makedirs(archive, exist_ok=True)
    results.to_csv(f'{archive}/{run_name}_results.csv')
    model.save_state_dict(f'{archive}/{run_name}_state_dict.pth')
    model.save(f'{archive}/{run_name}_model.pth')


def main():
    # Config
    database = '/remote/ceph/user/t/timg/dev_lvl7_robustness_muon_neutrino_0000.db'
    pulsemap = 'SRTTWOfflinePulsesDC'
    batch_size = 512
    num_workers = 10
    gpus = [3]
    max_epochs = 50
    patience = 5

    archive = 'results'

    pipeline_name = 'pipeline_tim_0'
    database_pipeline = '/mnt/scratch/rasmus_orsoe/databases/oscillations/dev_lvl7_robustness_muon_neutrino_0000/data/dev_lvl7_robustness_muon_neutrino_0000.db'

    # Parser
    parser = argparse.ArgumentParser(description='A program.')
    parser.add_argument('-f', dest='functions', nargs='+', default=['train_test', 'test_results', 'pipeline'],
                        help='what functions to run on targets')
    parser.add_argument('-t', dest='targets', nargs='+', default=['track', 'energy', 'zenith'],
                        help='what targets to run functions on')
    args = parser.parse_args()
    print(f'{args=}')

    targets = args.targets
    functions = args.functions

    # Run
    for target in targets:
        run_name = f'dynedge-1-{target}'

        if 'train_test' in functions:
            train_test(
                target=target,
                database=database,
                pulsemap=pulsemap,
                batch_size=batch_size,
                num_workers=num_workers,
                gpus=gpus,
                max_epochs=max_epochs,
                patience=patience,
                run_name=run_name,
                archive=archive
            )


if __name__ == '__main__':
    main()
```
▸ Logs
```
(gnn_py38) Singularity> python3 main_1.py -t zenith
Matplotlib is building the font cache; this may take a moment.
args=Namespace(functions=['train_test', 'test_results'], targets=['zenith'])
Only 8291804 events in database, using this number instead.
There have been 2819213 requested of particle 12, we can only supply 1874116. Renormalising...

===== train(target='zenith') =====
features=['dom_x', 'dom_y', 'dom_z', 'dom_time', 'charge', 'rde', 'pmt_area']
truth=['energy', 'position_x', 'position_y', 'position_z', 'azimuth', 'zenith', 'pid', 'elasticity', 'sim_type', 'interaction_type']
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name      | Type            | Params
----------------------------------------------
0 | _detector | IceCubeDeepCore | 0
1 | _gnn      | DynEdge         | 1.4 M
2 | _tasks    | ModuleList      | 258
----------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.504     Total estimated model params size (MB)
Epoch 0: 100%|█████████████| 4038/4038 [07:27<00:00, 9.02 batch(es)/s, loss=0.541, val_loss=0.533, train_loss=0.682]
Epoch 1: 100%|█████████████| 4038/4038 [07:21<00:00, 9.14 batch(es)/s, loss=0.502, val_loss=0.528, train_loss=0.520]
Epoch 2: 100%|█████████████| 4038/4038 [07:18<00:00, 9.20 batch(es)/s, loss=0.492, val_loss=0.482, train_loss=0.489]
Epoch 3: 100%|█████████████| 4038/4038 [07:19<00:00, 9.19 batch(es)/s, loss=0.466, val_loss=0.477, train_loss=0.470]
Epoch 4: 100%|█████████████| 4038/4038 [07:16<00:00, 9.25 batch(es)/s, loss=0.452, val_loss=0.458, train_loss=0.457]
Epoch 5: 100%|█████████████| 4038/4038 [07:17<00:00, 9.23 batch(es)/s, loss=0.441, val_loss=0.446, train_loss=0.447]
Epoch 6: 100%|█████████████| 4038/4038 [07:19<00:00, 9.19 batch(es)/s, loss=0.441, val_loss=0.444, train_loss=0.440]
Epoch 7: 100%|█████████████| 4038/4038 [07:18<00:00, 9.20 batch(es)/s, loss=0.435, val_loss=0.433, train_loss=0.433]
Epoch 8: 100%|█████████████| 4038/4038 [07:17<00:00, 9.22 batch(es)/s, loss=0.419, val_loss=0.433, train_loss=0.429]
Epoch 9: 100%|█████████████| 4038/4038 [07:23<00:00, 9.10 batch(es)/s, loss=0.409, val_loss=0.439, train_loss=0.424]
Epoch 10: 100%|██████████████| 4038/4038 [07:17<00:00, 9.23 batch(es)/s, loss=0.42, val_loss=0.431, train_loss=0.420]
Epoch 11: 100%|█████████████| 4038/4038 [07:16<00:00, 9.24 batch(es)/s, loss=0.408, val_loss=0.427, train_loss=0.416]
Epoch 12: 100%|█████████████| 4038/4038 [07:16<00:00, 9.25 batch(es)/s, loss=0.422, val_loss=0.431, train_loss=0.413]
Epoch 13: 100%|█████████████| 4038/4038 [07:07<00:00, 9.44 batch(es)/s, loss=0.403, val_loss=0.425, train_loss=0.409]
Epoch 14: 100%|██████████████| 4038/4038 [07:12<00:00, 9.35 batch(es)/s, loss=0.41, val_loss=0.419, train_loss=0.406]
Epoch 15: 100%|█████████████| 4038/4038 [07:06<00:00, 9.47 batch(es)/s, loss=0.411, val_loss=0.423, train_loss=0.402]
Epoch 16: 100%|█████████████| 4038/4038 [07:10<00:00, 9.38 batch(es)/s, loss=0.404, val_loss=0.420, train_loss=0.399]
Epoch 17: 100%|█████████████| 4038/4038 [07:05<00:00, 9.49 batch(es)/s, loss=0.382, val_loss=0.428, train_loss=0.396]
Epoch 18: 100%|█████████████| 4038/4038 [07:11<00:00, 9.36 batch(es)/s, loss=0.407, val_loss=0.416, train_loss=0.393]
Epoch 19: 100%|█████████████| 4038/4038 [07:16<00:00, 9.26 batch(es)/s, loss=0.386, val_loss=0.421, train_loss=0.390]
Epoch 20: 40%|█████▎ | 1633/4038 [02:58<04:22, 9.15 batch(es)/s, loss=0.396, val_loss=0.421, train_loss=0.390]
Traceback (most recent call last):
  File "/home/iwsatlas1/timg/.conda/envs/gnn_py38/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/home/iwsatlas1/timg/.conda/envs/gnn_py38/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/home/iwsatlas1/timg/.conda/envs/gnn_py38/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 319, in reduce_storage
    metadata = storage._share_filename_()
RuntimeError: unable to open shared memory object in read-write mode
Epoch 20: 41%|█████▎ | 1662/4038 [03:15<04:39, 8.51 batch(es)/s, loss=0.397, val_loss=0.421, train_loss=0.390]
```
I have also observed this issue while running a script like `examples/make_pipeline_database.py`, so it's not a training-specific thing. It's probably a 'dataloader-thing'.
My environment
Installed just like the `README.md` instructs, plus the fix for `setuptools` as described in #185.
All `torch`-related versions (`pip list | grep torch`):
Click the sections below for details
▸ pip freeze
```
absl-py==1.0.0
aiohttp==3.8.1
aiosignal==1.2.0
alabaster==0.7.12
anybadge==1.9.0
astroid==2.11.2
async-timeout==4.0.2
attrs==21.4.0
autopep8==1.6.0
Babel==2.9.1
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1648854175163/work
cachetools==5.0.0
certifi==2021.10.8
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1636046063618/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1644853463426/work
click==8.1.2
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1602866480661/work
coverage==6.3.2
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1649035228992/work
cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1635519461629/work
dill==0.3.4
docker-pycreds==0.4.0
docutils==0.17.1
fonttools @ file:///home/conda/feedstock_root/build_artifacts/fonttools_1649176251233/work
frozenlist==1.3.0
fsspec==2022.3.0
future==0.18.2
gitdb==4.0.9
GitPython==3.1.27
google-auth==2.6.3
google-auth-oauthlib==0.4.6
googledrivedownloader @ file:///home/conda/feedstock_root/build_artifacts/googledrivedownloader_1619807768586/work
-e git+https://github.com/icecube/graphnet@7d1eeb481d8221551a9237f126559c3ed6fb5da5#egg=graphnet
greenlet==1.1.2
grpcio==1.44.0
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1642433548627/work
imagesize==1.3.0
importlib-metadata==4.11.3
iniconfig==1.1.1
isodate==0.6.1
isort==5.10.1
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1648299710939/work
joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1633637554808/work
kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/kiwisolver_1648854389294/work
lazy-object-proxy==1.7.1
Markdown==3.3.6
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1648737563195/work
matplotlib @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-suite_1639358987786/work
mccabe==0.7.0
mkl-fft==1.3.1
mkl-random==1.2.2
mkl-service==2.4.0
multidict==6.0.2
munkres==1.1.4
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1646497321764/work
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1634095647912/work
oauthlib==3.2.0
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1637239678211/work
pandas==1.4.2
pathtools==0.1.2
Pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1648857110829/work
platformdirs==2.5.1
pluggy==1.0.0
promise==2.3
protobuf==3.20.0
psutil==5.9.0
py==1.11.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycodestyle==2.8.0
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pyDeprecate==0.3.1
Pygments==2.11.2
pylint==2.13.5
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1643496850550/work
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1642753572664/work
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1648857275402/work
pytest==7.1.1
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
python-louvain @ file:///home/conda/feedstock_root/build_artifacts/python-louvain_1643704011655/work
pytorch-lightning==1.5.6
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1647961439546/work
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1648757091578/work
rdflib==6.1.1
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1641580202195/work
requests-oauthlib==1.3.1
rsa==4.8
scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1642617107864/work
scipy @ file:///tmp/build/80754af9/scipy_1641555001653/work
seaborn==0.11.2
sentry-sdk==1.5.8
setproctitle==1.2.2
shortuuid==1.0.8
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
smmap==5.0.0
snowballstemmer==2.2.0
Sphinx==4.5.0
sphinx-rtd-theme==1.0.0
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
SQLAlchemy==1.4.35
tensorboard==2.8.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1643647933166/work
timer==0.2.2
toml==0.10.2
tomli==2.0.1
torch==1.9.0
torch-cluster @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-cluster_1631029602774/work
torch-geometric==2.0.1
torch-scatter @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-scatter_1634900998173/work
torch-sparse @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-sparse_1631173914209/work
torch-spline-conv @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-spline-conv_1631008408479/work
torchmetrics==0.7.3
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1649051611147/work
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1644850595256/work
unicodedata2 @ file:///home/conda/feedstock_root/build_artifacts/unicodedata2_1649111919534/work
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1647489083693/work
versioneer==0.22
wandb==0.12.13
Werkzeug==2.1.1
wrapt==1.14.0
yacs @ file:///home/conda/feedstock_root/build_artifacts/yacs_1645705974477/work
yarl==1.7.2
zipp==3.8.0
```
▸ conda list
```
# Name                    Version              Build               Channel
_libgcc_mutex             0.1                  conda_forge         conda-forge
_openmp_mutex             4.5                  1_llvm              conda-forge
absl-py                   1.0.0                pypi_0              pypi
aiohttp                   3.8.1                pypi_0              pypi
aiosignal                 1.2.0                pypi_0              pypi
alabaster                 0.7.12               pypi_0              pypi
anybadge                  1.9.0                pypi_0              pypi
astroid                   2.11.2               pypi_0              pypi
async-timeout             4.0.2                pypi_0              pypi
attrs                     21.4.0               pypi_0              pypi
babel                     2.9.1                pypi_0              pypi
blas                      1.0                  mkl
brotli                    1.0.9                h166bdaf_7          conda-forge
brotli-bin                1.0.9                h166bdaf_7          conda-forge
brotlipy                  0.7.0                py38h0a891b7_1004   conda-forge
bzip2                     1.0.8                h7f98852_4          conda-forge
ca-certificates           2021.10.8            ha878542_0          conda-forge
cachetools                5.0.0                pypi_0              pypi
certifi                   2021.10.8            py38h578d9bd_2      conda-forge
cffi                      1.15.0               py38h3931269_0      conda-forge
charset-normalizer        2.0.12               pyhd8ed1ab_0        conda-forge
click                     8.1.2                pypi_0              pypi
colorama                  0.4.4                pyh9f0ad1d_0        conda-forge
coverage                  6.3.2                pypi_0              pypi
cryptography              36.0.2               py38h2b5fc30_1      conda-forge
cudatoolkit               11.1.74              h6bb024c_0          nvidia
cycler                    0.11.0               pyhd8ed1ab_0        conda-forge
dill                      0.3.4                pypi_0              pypi
docker-pycreds            0.4.0                pypi_0              pypi
docutils                  0.17.1               pypi_0              pypi
fonttools                 4.31.2               py38h0a891b7_1      conda-forge
freetype                  2.10.4               h0708190_1          conda-forge
frozenlist                1.3.0                pypi_0              pypi
fsspec                    2022.3.0             pypi_0              pypi
future                    0.18.2               pypi_0              pypi
giflib                    5.2.1                h36c2ea0_2          conda-forge
gitdb                     4.0.9                pypi_0              pypi
gitpython                 3.1.27               pypi_0              pypi
google-auth               2.6.3                pypi_0              pypi
google-auth-oauthlib      0.4.6                pypi_0              pypi
googledrivedownloader     0.4                  pyhd3deb0d_1        conda-forge
graphnet                  0.1.2+166.g7d1eeb4   dev_0
```
My 'fix'
As the error does not occur every time I train, it is hard to verify that it has been fixed, but in a new environment with a newer `torch` version my runs seem to be stable (the CUDA version was also upgraded to cu113).
All `torch`-related versions (`pip list | grep torch`):
Proposed change
This is what I did, and what supposedly fixed it for me:
Set the newer `torch` versions (above) in the conda environment file `env/gnn_py38.yml`:
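Whichever versions end up pinned, it can be useful to confirm what a given environment actually resolved to before comparing runs. The following is a minimal sketch of a Python equivalent of `pip list | grep torch`, using only the standard library; `torch_related_versions` is a hypothetical helper name, not part of graphnet.

```python
from importlib import metadata


def torch_related_versions():
    """Return {distribution name: version} for installed packages whose name contains 'torch'."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        # Skip broken installs without a name; match case-insensitively like grep -i would.
        if name and "torch" in name.lower():
            found[name] = dist.version
    return dict(sorted(found.items()))


for name, version in torch_related_versions().items():
    print(f"{name}=={version}")
```

Running this in both the old and the new environment, and attaching the output to a bug report, makes it easier to tell which version change correlates with the error disappearing.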