Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

LightningModule.all_gather produces garbage results #17085

Closed gfx73 closed 1 year ago

gfx73 commented 1 year ago

Bug description

The `all_gather` function somehow produces negative values. Here is the code snippet from my `on_train_epoch_end` hook:

        if torch.any(self.query_labels < 0):
            print('ahtung')
            raise Exception('ahtung')

        print('query_labels before all gather:', self.query_labels[0], self.query_labels.shape, self.query_labels.dtype)

        feature_bank = self.all_gather(self.feature_bank).view(-1, emb_dim)
        target_bank = self.all_gather(self.target_bank).view(-1)

        query_embeddings = self.all_gather(self.query_embeddings).view(-1, emb_dim)
        query_labels = self.all_gather(self.query_labels).reshape(-1)
        print('query_labels after gather:', query_labels[0], query_labels.shape, query_labels.dtype)

I explicitly check that `self.query_labels` has no negative values, yet the prints (interleaved from the two GPU processes) are as follows:

        query_labels before all gather: tensor(3, device='cuda:1') torch.Size([21551]) torch.int64
        query_labels after gather: query_labels before all gather: tensor(2, device='cuda:0') torch.Size([21955]) torch.int64
        query_labels after gather: tensor(-4759082983527105483, device='cuda:1') torch.Size([43102]) torch.int64

Additionally, training gets stuck at this point. What are the possible reasons for such behavior? Maybe I'm missing something important.

How to reproduce the bug

https://www.kaggle.com/code/aidarkhuzin1/codebertdsl
Run this Kaggle notebook to reproduce the bug. It takes around 10 minutes to reach the point where it occurs.


Environment

My environment is a Kaggle notebook with 2 GPUs.

Current environment

```
* CUDA:
    - GPU:
        - Tesla T4
        - Tesla T4
    - available: True
    - version: 11.3
* Lightning:
    - lightning-utilities: 0.7.1
    - pytorch-ignite: 0.4.11
    - pytorch-lightning: 1.9.3
    - torch: 1.13.0
    - torchaudio: 0.13.0
    - torchinfo: 1.7.2
    - torchmetrics: 0.11.1
    - torchtext: 0.14.0
    - torchvision: 0.14.0
* Packages:
    - absl-py: 1.4.0 - accelerate: 0.12.0 - access: 1.1.8 - affine: 2.4.0 - aiobotocore: 2.4.2 - aiohttp: 3.8.3 - aiohttp-cors: 0.7.0 - aioitertools: 0.11.0 - aiorwlock: 1.3.0 - aiosignal: 1.3.1 - albumentations: 1.3.0 - alembic: 1.9.4 - altair: 4.2.2 - annoy: 1.17.1 - ansiwrap: 0.8.4 - anyio: 3.6.2 - apache-beam: 2.44.0 - aplus: 0.11.0 - appdirs: 1.4.4 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - arviz: 0.12.1 - astroid: 2.14.2 - astropy: 4.3.1 - astunparse: 1.6.3 - async-timeout: 4.0.2 - asynctest: 0.13.0 - atpublic: 2.3 - attrs: 22.2.0 - audioread: 3.0.0 - autocfg: 0.0.8 - autopep8: 1.6.0 - aws-requests-auth: 0.4.3 - babel: 2.11.0 - backcall: 0.2.0 - backoff: 1.10.0 - backports.functools-lru-cache: 1.6.4 - backports.zoneinfo: 0.2.1 - bayesian-optimization: 1.4.2 - bayespy: 0.5.25 - beatrix-jupyterlab: 2023.123.173907 - beautifulsoup4: 4.11.1 - bidict: 0.22.1 - biopython: 1.81 - blake3: 0.2.1 - bleach: 6.0.0 - blessed: 1.19.1 - blis: 0.7.9 - bokeh: 2.4.3 - boruta: 0.3 - boto3: 1.24.40 - botocore: 1.27.59 - bq-helper: 0.4.1 - bqplot: 0.12.36 - branca: 0.6.0 - bravado: 11.0.3 - bravado-core: 5.17.1 - brewer2mpl: 1.4.1 - brotlipy: 0.7.0 - cached-property: 1.5.2 - cachetools: 4.2.4 - cartopy: 0.19.0.post1 - catalogue: 2.0.8 - catalyst: 22.4 - catboost: 1.1.1 - category-encoders: 2.6.0 - certifi: 2022.12.7 - cesium: 0.10.1 - cffi: 1.15.1 - cftime: 1.6.2 - charset-normalizer: 2.1.1 - chex: 0.1.5 - cleverhans: 4.0.0 - click: 8.1.3 - click-plugins: 1.1.1 - cligj: 0.7.2 - cloud-tpu-client: 0.10 - cloud-tpu-profiler: 2.4.0 - cloudpickle: 2.2.1 - cmaes: 0.9.1 - cmake: 3.25.0 - cmdstanpy: 1.1.0 - cmudict: 1.0.13 - colorama: 0.4.6 -
colorcet: 3.0.1 - colorful: 0.5.5 - colorlog: 6.7.0 - colorlover: 0.3.0 - conda: 22.9.0 - conda-content-trust: 0+unknown - conda-package-handling: 2.0.2 - conda-package-streaming: 0.7.0 - confection: 0.0.4 - contextily: 1.3.0 - convertdate: 2.4.0 - crcmod: 1.7 - cryptography: 38.0.2 - cudf: 21.12.2 - cufflinks: 0.17.3 - cuml: 21.12.0 - cupy: 9.6.0 - cupy-cuda113: 10.6.0 - cvxcanon: 0.1.2 - cycler: 0.11.0 - cymem: 2.0.7 - cysignals: 1.11.2 - cython: 0.29.33 - cytoolz: 0.12.0 - daal: 2021.6.0 - daal4py: 2021.6.3 - dask: 2022.2.0 - dask-cudf: 21.12.2 - datasets: 2.1.0 - datashader: 0.14.4 - datashape: 0.5.2 - datatable: 1.0.0 - datatile: 1.0.3 - db-dtypes: 1.0.5 - deap: 1.3.3 - debugpy: 1.6.6 - decorator: 5.1.1 - defusedxml: 0.7.1 - delorean: 1.0.0 - deprecat: 2.1.1 - deprecation: 2.1.0 - descartes: 1.1.0 - dill: 0.3.6 - dipy: 1.6.0 - distlib: 0.3.6 - distributed: 2021.11.2 - dm-tree: 0.1.8 - docker: 6.0.1 - docker-pycreds: 0.4.0 - docopt: 0.6.2 - docstring-to-markdown: 0.11 - docutils: 0.19 - earthengine-api: 0.1.342 - easydict: 1.10 - easyocr: 1.6.2 - ecos: 2.0.12 - eli5: 0.13.0 - emoji: 2.2.0 - en-core-web-lg: 3.5.0 - en-core-web-sm: 3.5.0 - entrypoints: 0.4 - ephem: 4.1.4 - esda: 2.4.3 - essentia: 2.1b6.dev858 - et-xmlfile: 1.1.0 - etils: 0.9.0 - exceptiongroup: 1.1.0 - explainable-ai-sdk: 1.3.3 - explainers: 0.1 - fastai: 2.7.11 - fastapi: 0.89.1 - fastavro: 1.7.0 - fastcore: 1.5.28 - fastdownload: 0.0.7 - fasteners: 0.18 - fastjsonschema: 2.16.2 - fastprogress: 1.0.3 - fastrlock: 0.8 - fasttext: 0.9.2 - fbpca: 1.0 - feather-format: 0.4.1 - featuretools: 1.11.1 - filelock: 3.9.0 - fiona: 1.8.22 - fitter: 1.5.2 - flake8: 5.0.4 - flashtext: 2.7 - flask: 2.2.3 - flatbuffers: 23.1.21 - flax: 0.6.4 - flit-core: 3.8.0 - folium: 0.14.0 - fonttools: 4.38.0 - fqdn: 1.5.1 - frozendict: 2.3.5 - frozenlist: 1.3.3 - fsspec: 2023.1.0 - funcy: 1.18 - fury: 0.8.0 - future: 0.18.3 - fuzzywuzzy: 0.18.0 - gast: 0.4.0 - gatspy: 0.3 - gcsfs: 2023.1.0 - gensim: 4.0.1 - geographiclib: 
2.0 - geohash: 1.0 - geojson: 3.0.1 - geopandas: 0.10.2 - geoplot: 0.5.1 - geopy: 2.3.0 - geoviews: 1.9.6 - ggplot: 0.11.5 - giddy: 2.3.3 - gitdb: 4.0.10 - gitpython: 3.1.30 - gluoncv: 0.10.5.post0 - gluonnlp: 0.10.0 - google-api-core: 1.34.0 - google-api-python-client: 2.79.0 - google-apitools: 0.5.31 - google-auth: 1.35.0 - google-auth-httplib2: 0.1.0 - google-auth-oauthlib: 0.4.6 - google-cloud-aiplatform: 0.6.0a1 - google-cloud-automl: 1.0.1 - google-cloud-bigquery: 2.2.0 - google-cloud-bigtable: 1.7.3 - google-cloud-core: 1.7.3 - google-cloud-datastore: 1.15.5 - google-cloud-dlp: 3.11.1 - google-cloud-language: 2.6.1 - google-cloud-monitoring: 2.14.1 - google-cloud-pubsub: 2.14.0 - google-cloud-pubsublite: 1.6.0 - google-cloud-recommendations-ai: 0.7.1 - google-cloud-resource-manager: 1.8.1 - google-cloud-spanner: 3.27.0 - google-cloud-storage: 1.44.0 - google-cloud-translate: 3.8.4 - google-cloud-videointelligence: 2.8.3 - google-cloud-vision: 2.8.0 - google-crc32c: 1.5.0 - google-pasta: 0.2.0 - google-resumable-media: 1.3.3 - googleapis-common-protos: 1.58.0 - gplearn: 0.4.2 - gpustat: 1.0.0 - gpxpy: 1.5.0 - graphviz: 0.8.4 - greenlet: 2.0.1 - grpc-google-iam-v1: 0.12.6 - grpcio: 1.51.1 - grpcio-status: 1.48.2 - gviz-api: 1.10.0 - gym: 0.23.1 - gym-notices: 0.0.8 - h11: 0.14.0 - h2o: 3.40.0.1 - h5py: 3.8.0 - haversine: 2.7.0 - hdfs: 2.7.0 - heapdict: 1.0.1 - hep-ml: 0.7.1 - hijri-converter: 2.2.4 - hmmlearn: 0.2.8 - holidays: 0.19 - holoviews: 1.15.4 - hpsklearn: 0.1.0 - html5lib: 1.1 - htmlmin: 0.1.12 - httplib2: 0.20.4 - httptools: 0.5.0 - huggingface-hub: 0.12.1 - humanize: 4.6.0 - hunspell: 0.5.5 - husl: 4.0.3 - hydra-slayer: 0.4.0 - hyperopt: 0.2.7 - hypertools: 0.8.0 - ibis-framework: 2.1.1 - idna: 3.4 - igraph: 0.10.4 - imagecodecs: 2021.11.20 - imagehash: 4.3.1 - imageio: 2.25.0 - imbalanced-learn: 0.10.1 - imgaug: 0.4.0 - implicit: 0.5.2 - importlib-metadata: 4.11.4 - importlib-resources: 5.10.2 - inequality: 1.0.0 - iniconfig: 2.0.0 - 
ipydatawidgets: 4.3.3 - ipykernel: 6.16.2 - ipyleaflet: 0.17.2 - ipympl: 0.7.0 - ipython: 7.34.0 - ipython-genutils: 0.2.0 - ipython-sql: 0.4.1 - ipyvolume: 0.6.0 - ipyvue: 1.8.0 - ipyvuetify: 1.8.4 - ipywebrtc: 0.6.0 - ipywidgets: 7.7.1 - isoduration: 20.11.0 - isort: 5.11.5 - isoweek: 1.3.3 - itsdangerous: 2.1.2 - janome: 0.4.2 - jax: 0.3.25 - jaxlib: 0.3.25+cuda11.cudnn805 - jedi: 0.18.2 - jieba: 0.42.1 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.2.0 - json5: 0.9.11 - jsonlines: 1.2.0 - jsonpointer: 2.3 - jsonref: 1.1.0 - jsonschema: 4.17.3 - jupyter-client: 7.3.4 - jupyter-console: 6.6.1 - jupyter-core: 4.12.0 - jupyter-http-over-ws: 0.0.8 - jupyter-lsp: 1.5.1 - jupyter-server: 1.23.5 - jupyter-server-mathjax: 0.2.6 - jupyter-server-proxy: 3.2.2 - jupyterlab: 3.4.8 - jupyterlab-git: 0.41.0 - jupyterlab-lsp: 3.10.2 - jupyterlab-pygments: 0.2.2 - jupyterlab-server: 2.15.2 - jupyterlab-widgets: 3.0.5 - jupytext: 1.14.4 - kaggle: 1.5.12 - kaggle-environments: 1.12.0 - keras: 2.11.0 - keras-tuner: 1.1.3 - kiwisolver: 1.4.4 - kmapper: 2.0.1 - kmodes: 0.12.2 - korean-lunar-calendar: 0.3.1 - kornia: 0.5.8 - kt-legacy: 1.0.4 - kubernetes: 25.3.0 - langcodes: 3.3.0 - langid: 1.1.6 - lazy-loader: 0.1 - lazy-object-proxy: 1.9.0 - learntools: 0.3.4 - leven: 1.0.4 - levenshtein: 0.20.9 - libclang: 15.0.6.1 - libmambapy: 0.27.0 - libpysal: 4.7.0 - librosa: 0.10.0 - lightfm: 1.16 - lightgbm: 3.3.2 - lightning-utilities: 0.7.1 - lime: 0.2.0.1 - line-profiler: 4.0.2 - llvmlite: 0.39.1 - lml: 0.1.0 - locket: 1.0.0 - lunarcalendar: 0.0.9 - lxml: 4.9.2 - lz4: 4.3.2 - mako: 1.2.4 - mamba: 0.27.0 - mapclassify: 2.5.0 - marisa-trie: 0.7.8 - markdown: 3.4.1 - markdown-it-py: 2.1.0 - markovify: 0.9.4 - markupsafe: 2.1.1 - matplotlib: 3.5.3 - matplotlib-inline: 0.1.6 - matplotlib-venn: 0.11.7 - matrixprofile: 1.1.10 - mccabe: 0.7.0 - mdit-py-plugins: 0.3.3 - mdurl: 0.1.2 - memory-profiler: 0.61.0 - mercantile: 1.2.1 - mgwr: 2.1.2 - missingno: 0.5.1 - mistune: 2.0.4 - mizani: 0.7.3 - 
mlcrate: 0.2.0 - mlens: 0.2.3 - mlxtend: 0.21.0 - mmh3: 3.0.0 - mne: 1.3.0 - mnist: 0.2.2 - mock: 5.0.1 - momepy: 0.5.4 - monotonic: 1.6 - mpld3: 0.5.9 - mpmath: 1.2.1 - msgpack: 1.0.4 - msgpack-numpy: 0.4.8 - multidict: 6.0.4 - multimethod: 1.9.1 - multipledispatch: 0.6.0 - multiprocess: 0.70.14 - munch: 2.5.0 - munkres: 1.1.4 - murmurhash: 1.0.9 - mxnet-cu112: 1.9.1 - nb-conda: 2.2.1 - nb-conda-kernels: 2.3.1 - nbclassic: 0.4.8 - nbclient: 0.7.2 - nbconvert: 7.2.8 - nbdime: 3.1.1 - nbformat: 5.7.3 - neptune-client: 1.0.2 - nest-asyncio: 1.5.6 - netcdf4: 1.6.2 - networkx: 2.6.3 - nibabel: 4.0.2 - nilearn: 0.10.0 - ninja: 1.11.1 - nltk: 3.2.4 - nose: 1.3.7 - notebook: 6.5.2 - notebook-executor: 0.2 - notebook-shim: 0.2.2 - numba: 0.56.4 - numexpr: 2.8.4 - numpy: 1.21.6 - nvidia-ml-py: 11.495.46 - nvtx: 0.2.3 - oauth2client: 4.1.3 - oauthlib: 3.2.2 - objsize: 0.6.1 - odfpy: 1.4.1 - olefile: 0.46 - onnx: 1.13.1 - opencensus: 0.11.1 - opencensus-context: 0.1.3 - opencv-contrib-python: 4.5.4.60 - opencv-python: 4.5.4.60 - opencv-python-headless: 4.5.4.60 - openpyxl: 3.1.1 - openslide-python: 1.2.0 - opentelemetry-api: 1.1.0 - opentelemetry-exporter-otlp: 1.1.0 - opentelemetry-exporter-otlp-proto-grpc: 1.1.0 - opentelemetry-proto: 1.1.0 - opentelemetry-sdk: 1.1.0 - opentelemetry-semantic-conventions: 0.20b0 - opt-einsum: 3.3.0 - optax: 0.1.4 - optuna: 3.1.0 - orbax: 0.1.0 - orderedmultidict: 1.0.1 - orjson: 3.8.5 - ortools: 9.4.1874 - osmnx: 1.1.1 - overrides: 6.5.0 - packaging: 23.0 - palettable: 3.3.0 - pandarallel: 1.6.4 - pandas: 1.3.5 - pandas-datareader: 0.10.0 - pandas-profiling: 3.6.2 - pandas-summary: 0.2.0 - pandasql: 0.7.3 - pandocfilters: 1.5.0 - panel: 0.14.3 - papermill: 2.4.0 - param: 1.12.3 - parso: 0.8.3 - parsy: 1.4.0 - partd: 1.3.0 - path: 16.6.0 - path.py: 12.5.0 - pathos: 0.3.0 - pathtools: 0.1.2 - pathy: 0.10.1 - patsy: 0.5.3 - pdf2image: 1.16.2 - pexpect: 4.8.0 - phik: 0.12.3 - pickleshare: 0.7.5 - pillow: 9.4.0 - pip: 22.3.1 - 
pkgutil-resolve-name: 1.3.10 - platformdirs: 2.6.2 - plotly: 5.13.0 - plotly-express: 0.4.1 - plotnine: 0.8.0 - pluggy: 1.0.0 - pointpats: 2.2.0 - polars: 0.16.8 - polyglot: 16.7.4 - pooch: 1.6.0 - portalocker: 2.7.0 - pox: 0.3.2 - ppca: 0.0.4 - ppft: 1.7.6.6 - preprocessing: 0.1.13 - preshed: 3.0.8 - prettytable: 0.7.2 - progressbar2: 4.2.0 - prometheus-client: 0.15.0 - promise: 2.3 - prompt-toolkit: 3.0.36 - pronouncing: 0.2.0 - prophet: 1.1.1 - proto-plus: 1.22.2 - protobuf: 3.20.3 - psutil: 5.9.3 - ptxcompiler: 0.7.0 - ptyprocess: 0.7.0 - pudb: 2022.1.3 - pulp: 2.7.0 - py-lz4framed: 0.14.0 - py-spy: 0.3.14 - py-stringmatching: 0.4.2 - py-stringsimjoin: 0.3.2 - py4j: 0.10.9.7 - pyaml: 21.10.1 - pyarabic: 0.6.15 - pyarrow: 5.0.0 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pyastronomy: 0.18.1 - pybind11: 2.10.3 - pyclipper: 1.3.0.post4 - pycodestyle: 2.9.1 - pycolmap: 0.3.0 - pycosat: 0.6.4 - pycparser: 2.21 - pycryptodome: 3.17 - pyct: 0.5.0 - pycuda: 2022.1 - pydantic: 1.10.4 - pydegensac: 0.1.2 - pydicom: 2.3.1 - pydocstyle: 6.2.3 - pydot: 1.4.2 - pydub: 0.25.1 - pyemd: 0.5.1 - pyerfa: 2.0.0.1 - pyexcel-io: 0.6.6 - pyexcel-ods: 0.6.0 - pyfasttext: 0.4.6 - pyflakes: 2.5.0 - pygeos: 0.14 - pygments: 2.14.0 - pyjwt: 2.6.0 - pykalman: 0.9.5 - pyldavis: 3.2.2 - pylint: 2.16.2 - pymc3: 3.11.5 - pymeeus: 0.5.12 - pymongo: 3.13.0 - pympler: 1.0.1 - pynndescent: 0.5.8 - pynvml: 11.5.0 - pynvrtc: 9.2 - pyocr: 0.8.3 - pyopenssl: 23.0.0 - pyparsing: 3.0.9 - pypdf: 3.4.1 - pyprind: 2.11.3 - pyproj: 3.1.0 - pyrsistent: 0.19.3 - pysal: 2.6.0 - pyshp: 2.3.1 - pysocks: 1.7.1 - pytesseract: 0.3.10 - pytest: 7.2.1 - python-bidi: 0.4.2 - python-dateutil: 2.8.2 - python-dotenv: 0.21.1 - python-igraph: 0.10.4 - python-levenshtein: 0.20.9 - python-louvain: 0.16 - python-lsp-jsonrpc: 1.0.0 - python-lsp-server: 1.7.1 - python-slugify: 8.0.0 - python-utils: 3.5.2 - pythreejs: 2.4.2 - pytoolconfig: 1.2.5 - pytools: 2022.1.12 - pytorch-ignite: 0.4.11 - pytorch-lightning: 1.9.3 - pytz: 
2022.7.1 - pytz-deprecation-shim: 0.1.0.post0 - pyu2f: 0.1.5 - pyupset: 0.1.1.post7 - pyviz-comms: 2.2.1 - pywavelets: 1.3.0 - pyyaml: 6.0 - pyzmq: 25.0.0 - qgrid: 1.3.1 - qtconsole: 5.4.0 - qtpy: 2.3.0 - quantecon: 0.6.0 - quantities: 0.13.0 - qudida: 0.0.4 - quilt3: 5.1.1 - randomgen: 1.23.1 - rapidfuzz: 2.13.7 - rasterio: 1.2.10 - rasterstats: 0.18.0 - ray: 2.2.0 - ray-cpp: 2.2.0 - regex: 2021.11.10 - requests: 2.28.2 - requests-futures: 1.0.0 - requests-oauthlib: 1.3.1 - responses: 0.18.0 - retrying: 1.3.4 - rfc3339-validator: 0.1.4 - rfc3987: 1.3.8 - rgf-python: 3.12.0 - rich: 13.2.0 - rmm: 21.12.0 - rope: 1.7.0 - rsa: 4.9 - rtree: 1.0.1 - ruamel-yaml-conda: 0.15.100 - rvlib: 0.0.6 - s2sphere: 0.2.5 - s3fs: 2023.1.0 - s3transfer: 0.6.0 - scattertext: 0.1.12 - scikit-image: 0.19.3 - scikit-learn: 1.0.2 - scikit-learn-intelex: 2021.6.3 - scikit-multilearn: 0.2.0 - scikit-optimize: 0.9.0 - scikit-plot: 0.3.7 - scikit-surprise: 1.1.1 - scipy: 1.7.3 - seaborn: 0.12.2 - segregation: 2.3.1 - semver: 2.13.0 - send2trash: 1.8.0 - sentencepiece: 0.1.97 - sentry-sdk: 1.15.0 - setproctitle: 1.3.2 - setuptools: 59.8.0 - setuptools-git: 1.2 - shap: 0.41.0 - shapely: 1.8.0 - simpervisor: 0.4 - simpleitk: 2.2.1 - simplejson: 3.18.3 - six: 1.16.0 - sklearn-contrib-py-earth: 0.1.0+1.gdde5f89 - sklearn-pandas: 2.2.0 - slicer: 0.0.7 - smart-open: 6.3.0 - smhasher: 0.150.1 - smmap: 5.0.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - snuggs: 1.4.7 - sortedcontainers: 2.4.0 - soundfile: 0.11.0 - soupsieve: 2.3.2.post1 - soxr: 0.3.3 - spacy: 3.5.0 - spacy-legacy: 3.0.12 - spacy-loggers: 1.0.4 - spaghetti: 1.6.5 - spectral: 0.23.1 - spglm: 1.0.8 - sphinx-rtd-theme: 0.2.4 - spint: 1.0.7 - splot: 1.1.5.post1 - spopt: 0.4.1 - spreg: 1.3.0 - spvcm: 0.3.0 - sqlalchemy: 1.4.46 - sqlparse: 0.4.3 - squarify: 0.4.3 - srsly: 2.4.5 - starlette: 0.22.0 - statsmodels: 0.13.5 - stemming: 1.0.1 - stop-words: 2018.7.23 - stopit: 1.1.2 - stumpy: 1.11.1 - swagger-spec-validator: 3.0.3 - sympy: 1.10.1 - 
tables: 3.7.0 - tabulate: 0.9.0 - tangled-up-in-unicode: 0.2.0 - tbb: 2021.8.0 - tblib: 1.7.0 - tenacity: 8.1.0 - tensorboard: 2.11.2 - tensorboard-data-server: 0.6.1 - tensorboard-plugin-profile: 2.11.1 - tensorboard-plugin-wit: 1.8.1 - tensorboardx: 2.5.1 - tensorflow: 2.11.0 - tensorflow-addons: 0.19.0 - tensorflow-cloud: 0.1.16 - tensorflow-datasets: 4.8.2 - tensorflow-decision-forests: 1.2.0 - tensorflow-estimator: 2.11.0 - tensorflow-gcs-config: 2.11.0 - tensorflow-hub: 0.12.0 - tensorflow-io: 0.29.0 - tensorflow-io-gcs-filesystem: 0.29.0 - tensorflow-metadata: 1.12.0 - tensorflow-probability: 0.19.0 - tensorflow-serving-api: 2.11.0 - tensorflow-text: 2.11.0 - tensorflow-transform: 1.12.0 - tensorpack: 0.11 - tensorstore: 0.1.28 - termcolor: 2.2.0 - terminado: 0.17.1 - text-unidecode: 1.3 - textblob: 0.17.1 - texttable: 1.6.7 - textwrap3: 0.9.2 - tfx-bsl: 1.12.0 - theano: 1.0.5 - theano-pymc: 1.1.2 - thinc: 8.1.7 - threadpoolctl: 3.1.0 - tifffile: 2021.11.2 - timm: 0.6.12 - tinycss2: 1.2.1 - tobler: 0.9.0 - tokenizers: 0.13.2 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.6 - toolz: 0.11.2 - torch: 1.13.0 - torchaudio: 0.13.0 - torchinfo: 1.7.2 - torchmetrics: 0.11.1 - torchtext: 0.14.0 - torchvision: 0.14.0 - tornado: 6.1 - tpot: 0.11.7 - tqdm: 4.64.1 - traceml: 1.0.8 - traitlets: 5.8.1 - traittypes: 0.2.1 - transformers: 4.26.1 - treelite: 2.1.0 - treelite-runtime: 2.1.0 - trueskill: 0.4.5 - tsfresh: 0.20.0 - typed-ast: 1.5.4 - typeguard: 2.13.3 - typer: 0.7.0 - typing-extensions: 4.4.0 - tzdata: 2022.7 - tzlocal: 4.2 - ucx-py: 0.23.0 - ujson: 5.7.0 - umap-learn: 0.5.3 - unicodedata2: 14.0.0 - unidecode: 1.3.6 - update-checker: 0.18.0 - uri-template: 1.2.0 - uritemplate: 3.0.1 - urllib3: 1.26.14 - urwid: 2.1.2 - urwid-readline: 0.13 - uvicorn: 0.20.0 - uvloop: 0.17.0 - vaex: 4.16.0 - vaex-astro: 0.9.3 - vaex-core: 4.16.1 - vaex-hdf5: 0.14.1 - vaex-jupyter: 0.8.1 - vaex-ml: 0.18.1 - vaex-server: 0.8.1 - vaex-viz: 0.5.4 - vecstack: 0.4.0 - virtualenv: 20.17.1 
- visions: 0.7.5 - vowpalwabbit: 9.7.0 - vtk: 9.2.6 - wand: 0.6.11 - wandb: 0.13.10 - wasabi: 1.1.1 - watchfiles: 0.18.1 - wavio: 0.0.7 - wcwidth: 0.2.6 - webcolors: 1.12 - webencodings: 0.5.1 - websocket-client: 1.4.2 - websockets: 10.4 - werkzeug: 2.2.3 - wfdb: 4.1.0 - whatthepatch: 1.0.4 - wheel: 0.38.4 - widgetsnbextension: 3.6.2 - witwidget: 1.8.1 - woodwork: 0.16.4 - wordbatch: 1.4.9 - wordcloud: 1.8.2.2 - wordsegment: 1.3.1 - wrapt: 1.14.1 - wurlitzer: 3.0.3 - xarray: 0.20.2 - xarray-einstats: 0.2.2 - xgboost: 1.6.2 - xvfbwrapper: 0.2.9 - xxhash: 3.2.0 - xyzservices: 2023.2.0 - yacs: 0.1.8 - yapf: 0.32.0 - yarl: 1.8.2 - yellowbrick: 1.5 - zict: 2.2.0 - zipp: 3.11.0 - zstandard: 0.18.0
* System:
    - OS: Linux
    - architecture:
        - 64bit
        -
    - processor: x86_64
    - python: 3.7.12
    - version: #1 SMP Sat Mar 11 10:24:08 UTC 2023
```

More info

No response

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

awaelchli commented 1 year ago

@gfx73 Could this be due to the fact that your tensors don't have the same shape? You have two distinct sizes:

torch.Size([21551])
torch.Size([21955])

before the all-gather. It shouldn't be possible to gather tensors like that. Sorry for the late answer, but how did you work around this issue in the meantime?

gfx73 commented 1 year ago

Hi @awaelchli Thank you for your response.

As for your question, at least I did not find any mention of such a requirement in the LightningModule documentation. I don't really have expertise in PyTorch's distributed functionality.

I switched to a single-GPU accelerator as a workaround.

awaelchli commented 1 year ago

@gfx73 Thanks for the feedback, I'll clarify this in the docs.

gfx73 commented 1 year ago

> @gfx73 Thanks for the feedback, I'll clarify this in the docs.

Just curious: how is it possible to gather tensors in this specific case? Intuitively, I thought `all_gather` would work the same way as `torch.cat`.
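
The difference can be illustrated on CPU with plain tensors. The Lightning docs describe `LightningModule.all_gather` as returning a tensor with a new leading `world_size` dimension, so it behaves like `torch.stack` rather than `torch.cat`. The log above fits this: the result on `cuda:1` has exactly 2 × 21551 = 43102 elements, i.e. it was sized from that rank's own tensor. The sketch below is only an analogy with hypothetical tensors standing in for each rank's data; it does not call any distributed API.

```python
import torch

# Hypothetical per-rank tensors with the lengths reported in this issue (rank 0, rank 1).
rank_tensors = [torch.arange(21955), torch.arange(21551)]

# torch.cat joins along an existing dimension, so unequal lengths are fine.
concatenated = torch.cat(rank_tensors)
print(concatenated.shape)  # torch.Size([43506])

# A gather that returns one slot per rank behaves like torch.stack:
# it adds a leading world_size dimension, which requires equal shapes.
try:
    torch.stack(rank_tensors)
except RuntimeError as err:
    print("stack failed:", err)

# With equal shapes, stacking and then flattening matches torch.cat.
equal = [torch.arange(4), torch.arange(4, 8)]
assert torch.equal(torch.stack(equal).view(-1), torch.cat(equal))
```

Unlike `torch.stack`, the real collective cannot check shapes on other ranks before it runs, which is consistent with the garbage values and the hang reported above.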

Additionally, if gathering tensors of different shapes leads to such errors, maybe it would be better to raise an exception? It took me a lot of effort to understand why my program was getting stuck.
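
Until such a check exists in the library, a user-side guard can be run just before the real gather. This is a minimal sketch; the helper name `check_equal_gather_shapes` is hypothetical and not a Lightning API. It assumes the per-rank sizes were first gathered as 0-dim tensors, which have the same shape everywhere and therefore gather safely.

```python
import torch


def check_equal_gather_shapes(gathered_sizes: torch.Tensor) -> None:
    """Raise if the per-rank first-dimension sizes collected before an all_gather differ.

    `gathered_sizes` is assumed to hold one entry per rank, e.g. the result of
    gathering each rank's `tensor.shape[0]` as a 0-dim tensor.
    """
    sizes = gathered_sizes.reshape(-1).tolist()
    if len(set(sizes)) > 1:
        raise ValueError(
            f"all_gather needs equal shapes on every rank, got first-dim sizes {sizes}; "
            "pad the tensors to a common length first"
        )


# Hypothetical usage inside a LightningModule hook:
#     sizes = self.all_gather(torch.tensor(self.query_labels.shape[0], device=self.device))
#     check_equal_gather_shapes(sizes)
#     query_labels = self.all_gather(self.query_labels).reshape(-1)
```

Because every rank receives the same gathered sizes, every rank raises at the same point instead of one of them blocking in the collective. The cost is one extra small collective per check.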

eric-tc-wong commented 1 year ago

Hi, I am working on a similar case. My solution is to first pad the local tensor so that it has the same shape on all devices.

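        # Gather each rank's local length; a 0-dim tensor has the same shape on every rank.
        local_size = torch.tensor(pred.shape[0], device=pred.device)
        max_size = int(self.all_gather(local_size).max())
        # Pad the local predictions to that common length so every rank sends the same shape.
        padded_pred = torch.zeros(max_size, dtype=pred.dtype, device=pred.device)
        padded_pred[:pred.shape[0]] = pred
        # Note: the flattened result still contains the zero padding of the shorter ranks.
        pred = self.all_gather(padded_pred, sync_grads=True).view(-1)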

However, I am not sure whether the loss calculated from these output tensors needs to be normalized by the number of devices. Is there a better solution?
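
One way to complete the padding approach is to keep the gathered true lengths and strip each rank's padding after the gather, so the zeros never reach the loss. This is a minimal sketch with hypothetical helper names (`pad_to_length`, `unpad_gathered`), simulated in a single process by stacking what two ranks would contribute instead of calling `self.all_gather`.

```python
import torch


def pad_to_length(t: torch.Tensor, length: int) -> torch.Tensor:
    """Zero-pad `t` along dim 0 up to `length`; other dims and dtype are kept."""
    padded = t.new_zeros((length, *t.shape[1:]))
    padded[: t.shape[0]] = t
    return padded


def unpad_gathered(gathered: torch.Tensor, sizes: torch.Tensor) -> torch.Tensor:
    """Strip padding from a (world_size, max_len, ...) gather and concatenate.

    `sizes[r]` is the true length of rank r's tensor before padding.
    """
    return torch.cat([gathered[r, : int(n)] for r, n in enumerate(sizes)])


# Single-process simulation of two ranks. In a real run, `sizes` and `gathered`
# would come from self.all_gather(...) calls like the ones in the snippet above.
preds = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0])]
sizes = torch.tensor([p.shape[0] for p in preds])
max_len = int(sizes.max())
gathered = torch.stack([pad_to_length(p, max_len) for p in preds])
result = unpad_gathered(gathered, sizes)
print(result)  # tensor([1., 2., 3., 4., 5.])
```

Indexing and `torch.cat` are differentiable, so the same post-processing can follow a gather made with `sync_grads=True`.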