Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Trainer(use_distributed_sampler=False) doesn't remove associated error. #19657

Closed CompRhys closed 7 months ago

CompRhys commented 8 months ago

Bug description

Using a custom dataloader I get the following error regardless of whether or not I set use_distributed_sampler=False

TypeError:  Lightning can't inject a (distributed) sampler into your batch sampler, because it doesn't subclass PyTorch's `BatchSampler`. To mitigate this, either follow the API of `BatchSampler` or set `Trainer(use_distributed_sampler=False)`. If you choose the latter, you will be responsible for handling the distributed sampling within your batch sampler.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment * CUDA: - GPU: None - available: False - version: None * Lightning: - lightning-utilities: 0.10.1 - pytorch-lightning: 2.1.0 - torch: 2.2.0 - torch-geometric: 2.4.0 - torchmetrics: 1.3.0.post0 - torchvision: 0.17.0 * Packages: - about-time: 4.2.1 - absl-py: 2.1.0 - aiobotocore: 2.11.2 - aiohttp: 3.9.3 - aioitertools: 0.11.0 - aiosignal: 1.3.1 - alembic: 1.13.1 - alive-progress: 3.1.5 - amqp: 5.2.0 - aniso8601: 9.0.1 - annotated-types: 0.6.0 - anyio: 4.2.0 - appnope: 0.1.4 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.3.0 - asttokens: 2.4.1 - async-lru: 2.0.4 - async-timeout: 4.0.3 - attrs: 23.2.0 - autograd: 1.6.2 - autograd-gamma: 0.5.0 - babel: 2.14.0 - backoff: 2.2.1 - beautifulsoup4: 4.12.3 - billiard: 4.2.0 - bleach: 6.1.0 - bokeh: 2.4.3 - botocore: 1.34.34 - build: 1.0.3 - cachelib: 0.9.0 - cachetools: 5.3.2 - celery: 5.3.6 - cell-model: 0.1.0 - certifi: 2024.2.2 - cffi: 1.16.0 - cfgv: 3.4.0 - charset-normalizer: 3.3.2 - click: 8.1.7 - click-didyoumean: 0.3.0 - click-plugins: 1.1.1 - click-repl: 0.3.0 - cloudpickle: 3.0.0 - cma: 3.2.2 - colorama: 0.4.6 - coloredlogs: 14.0 - comm: 0.2.1 - contourpy: 1.2.0 - coverage: 7.4.2 - cramjam: 2.8.1 - croniter: 2.0.1 - cryptography: 39.0.1 - cycler: 0.12.1 - dagit: 1.6.3 - dagster: 1.6.3 - dagster-cloud: 1.6.3 - dagster-cloud-cli: 1.6.3 - dagster-graphql: 1.6.3 - dagster-pipes: 1.6.3 - dagster-webserver: 1.6.3 - dash: 2.15.0 - dash-ag-grid: 2.4.0 - dash-bootstrap-components: 1.5.0 - dash-core-components: 2.0.0 - dash-daq: 0.5.0 - dash-html-components: 2.0.0 - dash-table: 5.0.0 - dask: 2023.2.1 - dask-cloudprovider: 2022.10.0 - db-dtypes: 1.1.1 - debugpy: 1.8.0 - decorator: 5.1.1 - deepchem: 2.7.2.dev20240208164340 - defusedxml: 0.7.1 - deprecated: 1.2.14 - dill: 0.3.8 - diskcache: 5.6.3 - distlib: 0.3.8 - distributed: 2023.2.1 - docker: 6.0.1 - docstring-parser: 0.15 - einops: 0.7.0 - et-xmlfile: 1.1.0 - exceptiongroup: 1.2.0 - executing: 2.0.1 - fakeredis: 2.21.0 - fastjsonschema: 
2.19.1 - fastparquet: 2022.12.0 - filelock: 3.13.1 - flask: 2.2.2 - flask-caching: 2.1.0 - flask-login: 0.6.3 - fonttools: 4.49.0 - formulaic: 1.0.1 - fqdn: 1.5.1 - frozenlist: 1.4.1 - fsspec: 2023.12.2 - future: 1.0.0 - gcsfs: 2023.12.2.post1 - github3.py: 4.0.1 - google-api-core: 2.17.1 - google-api-python-client: 2.79.0 - google-auth: 2.28.1 - google-auth-httplib2: 0.2.0 - google-auth-oauthlib: 1.2.0 - google-cloud-bigquery: 3.17.2 - google-cloud-core: 2.4.1 - google-cloud-storage: 2.7.0 - google-crc32c: 1.5.0 - google-resumable-media: 2.7.0 - googleapis-common-protos: 1.62.0 - gql: 3.5.0 - grapheme: 0.6.0 - graphene: 3.3 - graphql-core: 3.2.3 - graphql-relay: 3.2.0 - greenlet: 3.0.3 - grpcio: 1.62.0 - grpcio-health-checking: 1.47.0 - h11: 0.14.0 - heapdict: 1.0.1 - holidays: 0.32 - httpcore: 1.0.2 - httplib2: 0.21.0 - httptools: 0.6.1 - httpx: 0.26.0 - humanfriendly: 10.0 - identify: 2.5.33 - idna: 3.6 - importlib-metadata: 7.0.1 - importlib-resources: 5.10.1 - inflection: 0.5.1 - iniconfig: 2.0.0 - interface-meta: 1.3.0 - ipykernel: 6.29.2 - ipython: 8.21.0 - ipywidgets: 8.1.1 - isoduration: 20.11.0 - itsdangerous: 2.1.2 - jedi: 0.19.1 - jinja2: 3.1.3 - jmespath: 1.0.1 - joblib: 1.3.2 - json5: 0.9.14 - jsonpointer: 2.4 - jsonschema: 4.21.1 - jsonschema-specifications: 2023.12.1 - jupyter-client: 8.6.0 - jupyter-console: 6.6.3 - jupyter-core: 5.7.1 - jupyter-events: 0.9.0 - jupyter-lsp: 2.2.2 - jupyter-server: 2.12.5 - jupyter-server-terminals: 0.5.2 - jupyterlab: 4.1.0 - jupyterlab-pygments: 0.3.0 - jupyterlab-server: 2.25.2 - jupyterlab-widgets: 3.0.9 - kaleido: 0.2.1 - kiwisolver: 1.4.5 - kombu: 5.3.5 - lifelines: 0.28.0 - lightning-utilities: 0.10.1 - locket: 1.0.0 - lxml: 4.9.2 - mako: 1.3.2 - markdown: 3.5.2 - markdown-it-py: 3.0.0 - markupsafe: 2.1.5 - marshmallow: 3.19.0 - marshmallow-oneofschema: 3.0.1 - matplotlib: 3.8.3 - matplotlib-inline: 0.1.6 - mdurl: 0.1.2 - mistune: 3.0.2 - mockito: 1.4.0 - mpmath: 1.3.0 - msgpack: 1.0.7 - multidict: 6.0.5 - 
multimethod: 1.11.1 - multiprocess: 0.70.16 - mypy-extensions: 1.0.0 - mysql: 0.0.3 - mysql-connector-python: 8.0.33 - mysqlclient: 2.0.3 - nbclient: 0.9.0 - nbconvert: 7.16.0 - nbformat: 5.9.2 - nest-asyncio: 1.6.0 - networkx: 3.2.1 - nodeenv: 1.8.0 - notebook: 7.0.7 - notebook-shim: 0.2.3 - numpy: 1.23.5 - oauth2client: 4.1.3 - oauthlib: 3.2.2 - openpyxl: 3.0.10 - outcome: 1.3.0.post0 - overrides: 7.7.0 - packaging: 23.2 - pandas: 1.5.2 - pandera: 0.17.2 - pandocfilters: 1.5.1 - parso: 0.8.3 - partd: 1.4.1 - pendulum: 3.0.0 - pex: 2.1.163 - pexpect: 4.9.0 - pillow: 10.2.0 - pip: 24.0 - pip-tools: 7.3.0 - platformdirs: 4.2.0 - plotly: 5.17.0 - pluggy: 1.4.0 - pre-commit: 3.6.0 - prometheus-client: 0.19.0 - prompt-toolkit: 3.0.43 - protobuf: 3.20.3 - psutil: 5.9.8 - ptyprocess: 0.7.0 - pubchempy: 1.0.4 - pure-eval: 0.2.2 - pyairtable: 2.1.0.post1 - pyarrow: 10.0.1 - pyasn1: 0.5.1 - pyasn1-modules: 0.3.0 - pycparser: 2.21 - pydantic: 2.6.1 - pydantic-core: 2.16.2 - pygit2: 1.11.1 - pygithub: 2.1.1 - pygments: 2.17.2 - pyjwt: 2.8.0 - pymsteams: 0.2.2 - pymysql: 1.1.0 - pynacl: 1.5.0 - pyparsing: 3.1.1 - pyproject-hooks: 1.0.0 - pysocks: 1.7.1 - pytest: 7.4.0 - pytest-cov: 4.1.0 - python-box: 6.1.0 - python-dateutil: 2.8.2 - python-dotenv: 1.0.1 - python-json-logger: 2.0.7 - python-slugify: 7.0.0 - pytorch-lightning: 2.1.0 - pytz: 2024.1 - pytzdata: 2020.1 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - qtconsole: 5.5.1 - qtpy: 2.4.1 - questionary: 1.10.0 - rdkit: 2023.9.4 - redis: 5.0.1 - referencing: 0.33.0 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - requests-toolbelt: 1.0.0 - retrying: 1.3.4 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.7.0 - rpds-py: 0.17.1 - rsa: 4.9 - ruff: 0.2.1 - s3fs: 2023.12.2 - scikit-learn: 1.4.0 - scipy: 1.12.0 - selenium: 4.8.0 - send2trash: 1.8.2 - setuptools: 69.0.3 - shellingham: 1.5.4 - six: 1.16.0 - sniffio: 1.3.0 - sortedcontainers: 2.4.0 - soupsieve: 2.5 - sqlalchemy: 2.0.22 - stack-data: 0.6.3 - starlette: 0.37.0 - 
structlog: 24.1.0 - sympy: 1.12 - tabulate: 0.9.0 - tblib: 3.0.0 - tenacity: 8.2.3 - tensorboard: 2.15.1 - tensorboard-data-server: 0.7.2 - terminado: 0.18.0 - text-unidecode: 1.3 - threadpoolctl: 3.3.0 - time-machine: 2.13.0 - tinycss2: 1.2.1 - toml: 0.10.2 - tomli: 2.0.1 - toolz: 0.12.1 - toposort: 1.10 - torch: 2.2.0 - torch-geometric: 2.4.0 - torchmetrics: 1.3.0.post0 - torchvision: 0.17.0 - tornado: 6.4 - tqdm: 4.64.1 - traitlets: 5.14.1 - trio: 0.24.0 - trio-websocket: 0.11.1 - typeguard: 4.1.5 - typer: 0.9.0 - types-python-dateutil: 2.8.19.20240106 - typing-extensions: 4.9.0 - typing-inspect: 0.9.0 - tzdata: 2024.1 - universal-pathlib: 0.0.23 - uri-template: 1.3.0 - uritemplate: 4.1.1 - urllib3: 2.0.7 - uvicorn: 0.27.0.post1 - uvloop: 0.19.0 - vine: 5.1.0 - virtualenv: 20.25.0 - watchdog: 4.0.0 - watchfiles: 0.21.0 - wcwidth: 0.2.13 - webcolors: 1.13 - webdriver-manager: 3.8.5 - webencodings: 0.5.1 - websocket-client: 1.4.2 - websockets: 12.0 - werkzeug: 2.3.7 - wheel: 0.42.0 - widgetsnbextension: 4.0.9 - wrapt: 1.16.0 - wsproto: 1.2.0 - yarl: 1.9.4 - zict: 3.0.0 - zipp: 3.17.0 * System: - OS: Darwin - architecture: - 64bit - - processor: i386 - python: 3.10.11 - release: 23.2.0 - version: Darwin Kernel Version 23.2.0: Wed Nov 15 21:54:10 PST 2023; root:xnu-10002.61.3~2/RELEASE_X86_64

More info

No response

cc @justusschock @awaelchli

awaelchli commented 7 months ago

@CompRhys If you share a runnable code example, I can take a look at what caused it. You probably have a custom batch sampler that is causing trouble for Lightning to handle.

CompRhys commented 7 months ago

Hi,

class BySequenceLengthSampler(Sampler):

    def __init__(self, idx_seq_lengths, bucket_boundaries, batch_size=64, shuffle=True):
        self.idx_seq_lengths = idx_seq_lengths
        self.bucket_boundaries = bucket_boundaries
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        data_buckets = dict()

        for idx, seq_length in self.idx_seq_lengths:
            bucket_id = self.seq_length_to_bucket_id(seq_length)
            if bucket_id in data_buckets:
                data_buckets[bucket_id].append(idx)
            else:
                data_buckets[bucket_id] = [idx]

        for k in data_buckets:
            data_buckets[k] = np.asarray(data_buckets[k])

        iter_list = []
        for k in data_buckets:
            np.random.shuffle(data_buckets[k])
            n_batches = ceil(data_buckets[k].shape[0]/self.batch_size)
            if n_batches == 0:
                raise ValueError
            # NOTE for each bucket we will likely get a batch that is smaller than the batch size
            iter_list += (np.array_split(data_buckets[k], n_batches))

        if self.shuffle:
            shuffle(iter_list)  # shuffle all the batches so they aren't ordered by bucket

        for i in iter_list:
            yield i.tolist()  # convert array to list

    def __len__(self):
        return len(list(iter(self)))

    def seq_length_to_bucket_id(self, seq_length: int) -> int:
        boundaries = list(self.bucket_boundaries)
        buckets_min = [np.iinfo(np.int32).min] + boundaries
        buckets_max = boundaries + [np.iinfo(np.int32).max]
        conditions_c = np.logical_and(np.less_equal(buckets_min, seq_length), np.less(seq_length, buckets_max))
        bucket_id = np.min(np.where(conditions_c))
        return bucket_id

I have this custom sampler class, which doesn't comply with PyTorch's BatchSampler API. Training with pl works fine, but predict gives the above error even with the use_distributed_sampler flag turned off.
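As a side note on the bucketing logic in the sampler above, here is a minimal standard-library sketch of the same bucket assignment. It assumes the boundaries are sorted in ascending order; the helper names `bucket_id` and `num_batches` are hypothetical and not part of the posted code. It also shows one way the number of batches could be computed without running the shuffling iterator, which the posted `__len__` does on every call.

```python
from bisect import bisect_right
from collections import Counter
from math import ceil


def bucket_id(seq_length, boundaries):
    """Index of the half-open bucket [lower, upper) that contains seq_length.

    Assumes `boundaries` is sorted ascending (an assumption of this sketch).
    """
    return bisect_right(boundaries, seq_length)


def num_batches(seq_lengths, boundaries, batch_size):
    """Count batches per bucket without materialising or shuffling them."""
    counts = Counter(bucket_id(n, boundaries) for n in seq_lengths)
    return sum(ceil(count / batch_size) for count in counts.values())


boundaries = [10, 20, 30]
print([bucket_id(n, boundaries) for n in (9, 10, 19, 30)])  # [0, 1, 1, 3]
print(num_batches([3, 5, 7, 12, 14, 25], boundaries, batch_size=2))  # 4
```

A length equal to a boundary falls into the upper bucket, matching the `less_equal` / `less` comparison in `seq_length_to_bucket_id`.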

CompRhys commented 7 months ago

Related issue: https://github.com/Lightning-AI/pytorch-lightning/issues/11807

CompRhys commented 7 months ago

Also related: https://github.com/Lightning-AI/pytorch-lightning/issues/18023

CompRhys commented 7 months ago

class BatchBySequenceLengthSampler(BatchSampler):
    def __init__(self, sampler: Sampler, batch_size: int = 1, drop_last: bool = False) -> None:
        super().__init__(sampler, batch_size=1, drop_last=False)

    def __iter__(self):
        sampler_iter = iter(self.sampler)
        while True:
            try:
                yield next(sampler_iter)
            except StopIteration:
                break

If I add this as a wrapper around my sampler then there's a bunch of extra errors like:

lightning_fabric.utilities.exceptions.MisconfigurationException: Trying to inject parameters into the `CustomDataLoader` instance. This would fail as it doesn't expose all its attributes in the `__init__` signature. The missing arguments are ['batch_sampler', 'drop_last', 'sampler']. HINT: If you wrote the `CustomDataLoader` class, add the `__init__` arguments or allow passing `**kwargs`

that arise. The documentation for this is very lacking for what seems to actually make a large number of API requirements around standard pytorch which feel contrary to the intent of pl. Indeed, this only matters because of the bug in the title, since I am on a single-GPU machine.
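To illustrate what the hint in that exception is asking for, here is a rough standard-library sketch of a signature check of this kind. It is an illustration only, not Lightning's actual implementation: the helper `missing_init_args` and the two loader classes are hypothetical, and the injected argument names are copied from the error message above.

```python
import inspect


def missing_init_args(cls, injected):
    """Return the injected argument names that cls.__init__ cannot accept.

    Hypothetical helper: a constructor that takes **kwargs accepts anything.
    """
    params = inspect.signature(cls.__init__).parameters
    accepts_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if accepts_kwargs:
        return []
    return sorted(name for name in injected if name not in params)


class StrictLoader:
    # Stand-in for a loader whose __init__ hides the sampler arguments.
    def __init__(self, dataset, batch_size, bucket_boundaries, shuffle=True):
        self.dataset = dataset


class FlexibleLoader:
    # Same signature plus **kwargs, as the hint suggests.
    def __init__(self, dataset, batch_size, bucket_boundaries, shuffle=True, **kwargs):
        self.dataset = dataset
        self.extra = kwargs


injected = ["batch_sampler", "drop_last", "sampler"]
print(missing_init_args(StrictLoader, injected))    # ['batch_sampler', 'drop_last', 'sampler']
print(missing_init_args(FlexibleLoader, injected))  # []
```

Accepting `**kwargs` only gets past a check like this one; the constructor still has to forward those arguments to the parent `DataLoader` for them to have any effect.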

awaelchli commented 7 months ago

I won't be able to reply this week. The code you posted is only the sampler definition; I don't know how to use it, and some imports are missing. Could you please integrate it into a runnable example? For example, take this: https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/pytorch/bug_report/bug_report_model.py and show how setting use_distributed_sampler=False results in the same issue.

to actually make a large number of API requirements around standard pytorch which feel contrary to the intent of pl

I don't know what you mean here. PyTorch Lightning is built on top of and wrapping PyTorch so we definitely want to follow their API.

CompRhys commented 7 months ago

import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset, Sampler

from torch.nn.utils.rnn import pad_sequence

import numpy as np
from math import ceil
from random import shuffle

class RandomSequenceDataset(Dataset):
    def __init__(self, seq_len, size, length):
        self.len = length
        self.data = [torch.randn((torch.randint(seq_len,(1,)), size)) for _ in range(length)]
        self.idx_seq_lengths = {i: v.size(0) for i, v in enumerate(self.data)}

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

def create_mask_from_lengths(lengths):
    max_length = torch.max(lengths).item()
    idx = torch.arange(max_length).unsqueeze(0)
    mask = idx < lengths.unsqueeze(1)
    return mask.long()  # convert bool to int

def pad_collate(batch):
    lengths = torch.as_tensor([v.size(0) for v in batch], dtype=torch.int64, device="cpu")
    input_pad = pad_sequence(batch, batch_first=True, padding_value=0)
    mask = create_mask_from_lengths(lengths)
    assert mask.size(1) == input_pad.size(1)
    return input_pad

class BySequenceLengthSampler(Sampler):

    def __init__(self, idx_seq_lengths, bucket_boundaries, batch_size=64, shuffle=True):
        self.idx_seq_lengths = idx_seq_lengths
        self.bucket_boundaries = bucket_boundaries
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        data_buckets = dict()

        for idx, seq_length in self.idx_seq_lengths:
            bucket_id = self.seq_length_to_bucket_id(seq_length)
            if bucket_id in data_buckets:
                data_buckets[bucket_id].append(idx)
            else:
                data_buckets[bucket_id] = [idx]

        for k in data_buckets:
            data_buckets[k] = np.asarray(data_buckets[k])

        iter_list = []
        for k in data_buckets:
            np.random.shuffle(data_buckets[k])
            n_batches = ceil(data_buckets[k].shape[0]/self.batch_size)
            if n_batches == 0:
                raise ValueError
            # NOTE for each bucket we will likely get a batch that is smaller than the batch size
            iter_list += (np.array_split(data_buckets[k], n_batches))

        if self.shuffle:
            shuffle(iter_list)  # shuffle all the batches so they aren't ordered by bucket

        for i in iter_list:
            yield i.tolist()  # convert array to list

    def __len__(self):
        return len(list(iter(self)))

    def seq_length_to_bucket_id(self, seq_length: int) -> int:
        boundaries = list(self.bucket_boundaries)
        buckets_min = [np.iinfo(np.int32).min] + boundaries
        buckets_max = boundaries + [np.iinfo(np.int32).max]
        conditions_c = np.logical_and(np.less_equal(buckets_min, seq_length), np.less(seq_length, buckets_max))
        bucket_id = np.min(np.where(conditions_c))
        return bucket_id

class SequenceDataLoader(DataLoader):
    def __init__(self, dataset, batch_size, bucket_boundaries, shuffle=True, **kwargs):
        self.bucket_boundaries = bucket_boundaries
        idx_seq_lengths = list(dataset.idx_seq_lengths.items())
        sampler = BySequenceLengthSampler(idx_seq_lengths, bucket_boundaries, batch_size, shuffle=shuffle)
        super().__init__(dataset, batch_sampler=sampler, collate_fn=pad_collate)

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run():
    train_data = SequenceDataLoader(RandomSequenceDataset(32, 3, 64), batch_size=2, bucket_boundaries=[10, 20, 30])
    val_data = SequenceDataLoader(RandomSequenceDataset(32, 3, 64), batch_size=2, bucket_boundaries=[10, 20, 30])
    test_data = SequenceDataLoader(RandomSequenceDataset(32, 3, 64), batch_size=2, bucket_boundaries=[10, 20, 30])

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        use_distributed_sampler=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

    trainer.predict(model, dataloaders=test_data)

if __name__ == "__main__":
    run()

This throws the error on the predict call even with the flag turned off.

I don't know what you mean here. PyTorch Lightning is built on top of and wrapping PyTorch so we definitely want to follow their API.

I mean that Sampler is the base class, and that restricting the allowed inputs to BatchSampler's API of just sampler, batch_size, and drop_last breaks the idea that you can just wrap functional torch code.
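For readers weighing the "follow the API of `BatchSampler`" option from the error message, the sketch below shows one possible shape: a length-bucketing batch sampler that subclasses `BatchSampler` and keeps `sampler`, `batch_size`, `drop_last` as its leading constructor arguments. The class name `LengthBucketBatchSampler` and the extra `lengths` and `boundaries` arguments are assumptions made for illustration. Nothing in this thread confirms that Lightning preserves such extra arguments when it re-creates a batch sampler, so treat this as a PyTorch-only example.

```python
from bisect import bisect_right

from torch.utils.data import BatchSampler, SequentialSampler


class LengthBucketBatchSampler(BatchSampler):
    """Hypothetical sampler: batches only contain indices from one length bucket."""

    def __init__(self, sampler, batch_size, drop_last, lengths, boundaries):
        super().__init__(sampler, batch_size, drop_last)
        self.lengths = list(lengths)          # sequence length per dataset index
        self.boundaries = sorted(boundaries)  # ascending bucket edges

    def _bucket(self, idx):
        return bisect_right(self.boundaries, self.lengths[idx])

    def __iter__(self):
        open_buckets = {}
        for idx in self.sampler:
            bucket = open_buckets.setdefault(self._bucket(idx), [])
            bucket.append(idx)
            if len(bucket) == self.batch_size:
                # Emit a full batch and start a fresh one for this bucket.
                yield open_buckets.pop(self._bucket(idx))
        if not self.drop_last:
            # Emit the leftover partial batches, one per bucket.
            for bucket_id in sorted(open_buckets):
                yield open_buckets[bucket_id]

    def __len__(self):
        counts = {}
        for idx in self.sampler:
            counts[self._bucket(idx)] = counts.get(self._bucket(idx), 0) + 1
        full = sum(c // self.batch_size for c in counts.values())
        partial = sum(1 for c in counts.values() if c % self.batch_size)
        return full if self.drop_last else full + partial


lengths = [3, 12, 5, 25, 14, 7]
batch_sampler = LengthBucketBatchSampler(
    SequentialSampler(range(len(lengths))), 2, False, lengths, [10, 20]
)
print(list(batch_sampler))  # [[0, 2], [1, 4], [5], [3]]
print(len(batch_sampler))   # 4
```

Because the inner `sampler` decides the index order, shuffling or distributed sharding can be delegated to it instead of being reimplemented inside the batch sampler.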

awaelchli commented 7 months ago

I opened a PR that relaxes the conditions for prediction, and I used your code example to validate that it works. If the tests pass, we can probably move forward with it. Feel free to try it out if you want.