Closed: CompRhys closed this issue 7 months ago.
@CompRhys If you share a runnable code example, I can take a look at what caused it. You probably have a custom batch sampler that is causing trouble for Lightning to handle.
Hi,
class BySequenceLengthSampler(Sampler):
    def __init__(self, idx_seq_lengths, bucket_boundaries, batch_size=64, shuffle=True):
        self.idx_seq_lengths = idx_seq_lengths
        self.bucket_boundaries = bucket_boundaries
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        data_buckets = dict()
        for idx, seq_length in self.idx_seq_lengths:
            bucket_id = self.seq_length_to_bucket_id(seq_length)
            if bucket_id in data_buckets:
                data_buckets[bucket_id].append(idx)
            else:
                data_buckets[bucket_id] = [idx]

        for k in data_buckets:
            data_buckets[k] = np.asarray(data_buckets[k])

        iter_list = []
        for k in data_buckets:
            np.random.shuffle(data_buckets[k])
            n_batches = ceil(data_buckets[k].shape[0] / self.batch_size)
            if n_batches == 0:
                raise ValueError
            # NOTE for each bucket we will likely get a batch that is smaller than the batch size
            iter_list += np.array_split(data_buckets[k], n_batches)

        if self.shuffle:
            shuffle(iter_list)  # shuffle all the batches so they aren't ordered by bucket
        for i in iter_list:
            yield i.tolist()  # convert array to list

    def __len__(self):
        return len(list(iter(self)))

    def seq_length_to_bucket_id(self, seq_length: int) -> int:
        boundaries = list(self.bucket_boundaries)
        buckets_min = [np.iinfo(np.int32).min] + boundaries
        buckets_max = boundaries + [np.iinfo(np.int32).max]
        conditions_c = np.logical_and(np.less_equal(buckets_min, seq_length), np.less(seq_length, buckets_max))
        bucket_id = np.min(np.where(conditions_c))
        return bucket_id
I have this custom sampler class, which doesn't comply with the PyTorch BatchSampler API. It works fine when training with PL, but when going to predict it raises the above error, even with the use_distributed_sampler flag turned off.
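For context, the sampler is meant to be passed as the DataLoader's batch_sampler argument, since it yields whole lists of indices. A minimal usage sketch (my own illustration, assuming the class above plus the numpy, math.ceil and random.shuffle imports it relies on, and a toy index-to-length map):

import numpy as np
import torch
from math import ceil
from random import shuffle
from torch.utils.data import DataLoader, TensorDataset

lengths = {0: 5, 1: 12, 2: 25, 3: 7, 4: 15, 5: 28}  # toy index -> sequence length
sampler = BySequenceLengthSampler(list(lengths.items()), bucket_boundaries=[10, 20], batch_size=2)
loader = DataLoader(TensorDataset(torch.arange(6)), batch_sampler=sampler)
for (batch,) in loader:
    print(batch)  # each batch holds indices whose lengths fall in the same bucket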
from torch.utils.data import BatchSampler, Sampler

class BatchBySequenceLengthSampler(BatchSampler):
    def __init__(self, sampler: Sampler, batch_size: int = 1, drop_last: bool = False) -> None:
        # the wrapped sampler already yields whole batches, so batch_size and
        # drop_last are pinned to values that make BatchSampler a no-op
        super().__init__(sampler, batch_size=1, drop_last=False)

    def __iter__(self):
        # pass the batches produced by the wrapped sampler through unchanged
        sampler_iter = iter(self.sampler)
        while True:
            try:
                yield next(sampler_iter)
            except StopIteration:
                break
If I add this as a wrapper around my sampler, a bunch of extra errors arise, such as:
lightning_fabric.utilities.exceptions.MisconfigurationException: Trying to inject parameters into the `CustomDataLoader` instance. This would fail as it doesn't expose all its attributes in the `__init__` signature. The missing arguments are ['batch_sampler', 'drop_last', 'sampler']. HINT: If you wrote the `CustomDataLoader` class, add the `__init__` arguments or allow passing `**kwargs`
The documentation for this is very lacking, given what turns out to be a large number of API requirements on top of standard PyTorch that feel contrary to the intent of PL. Indeed, this is primarily an issue only because of the title bug, as I am on a single-GPU machine.
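For reference, the HINT in that error appears to ask for something like the following sketch (based only on the error message; CustomDataLoader is the hypothetical class the message names):

from torch.utils.data import DataLoader

class CustomDataLoader(DataLoader):
    def __init__(self, dataset, **kwargs):
        # Accepting and forwarding **kwargs exposes sampler, batch_sampler and
        # drop_last in the signature, so Lightning can re-inject them when it
        # re-instantiates the loader.
        super().__init__(dataset, **kwargs)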
I won't be able to reply this week.
The code you posted is the sampler definition. I don't know how to use it, and some imports are missing. Could you please integrate it into a runnable example, for instance based on this: https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/pytorch/bug_report/bug_report_model.py and show how setting use_distributed_sampler=False results in the same issue?
a large number of API requirements on top of standard PyTorch that feel contrary to the intent of PL
I don't know what you mean here. PyTorch Lightning is built on top of PyTorch and wraps it, so we definitely want to follow its API.
import os
import torch
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset, Sampler
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from math import ceil
from random import shuffle
class RandomSequenceDataset(Dataset):
    def __init__(self, seq_len, size, length):
        self.len = length
        # draw a random length in [1, seq_len) per sequence; .item() converts the
        # one-element tensor to an int, and the lower bound of 1 avoids
        # zero-length sequences that would break padding
        self.data = [torch.randn((torch.randint(1, seq_len, (1,)).item(), size)) for _ in range(length)]
        self.idx_seq_lengths = {i: v.size(0) for i, v in enumerate(self.data)}

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
def create_mask_from_lengths(lengths):
    max_length = torch.max(lengths).item()
    idx = torch.arange(max_length).unsqueeze(0)
    mask = idx < lengths.unsqueeze(1)
    return mask.long()  # convert bool to int
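# A quick worked example of the mask construction (my own illustration):
# create_mask_from_lengths(torch.tensor([2, 4])) returns
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 1]])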
def pad_collate(batch):
    lengths = torch.as_tensor([v.size(0) for v in batch], dtype=torch.int64, device="cpu")
    input_pad = pad_sequence(batch, batch_first=True, padding_value=0)
    mask = create_mask_from_lengths(lengths)
    assert mask.size(1) == input_pad.size(1)
    return input_pad
class BySequenceLengthSampler(Sampler):
    def __init__(self, idx_seq_lengths, bucket_boundaries, batch_size=64, shuffle=True):
        self.idx_seq_lengths = idx_seq_lengths
        self.bucket_boundaries = bucket_boundaries
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        # group dataset indices into buckets by which boundary interval their length falls in
        data_buckets = dict()
        for idx, seq_length in self.idx_seq_lengths:
            bucket_id = self.seq_length_to_bucket_id(seq_length)
            if bucket_id in data_buckets:
                data_buckets[bucket_id].append(idx)
            else:
                data_buckets[bucket_id] = [idx]

        for k in data_buckets:
            data_buckets[k] = np.asarray(data_buckets[k])

        # split each bucket into batches, then shuffle the batch order across buckets
        iter_list = []
        for k in data_buckets:
            np.random.shuffle(data_buckets[k])
            n_batches = ceil(data_buckets[k].shape[0] / self.batch_size)
            if n_batches == 0:
                raise ValueError(f"bucket {k} yielded no batches")
            # NOTE for each bucket we will likely get a batch that is smaller than the batch size
            iter_list += np.array_split(data_buckets[k], n_batches)

        if self.shuffle:
            shuffle(iter_list)  # shuffle all the batches so they aren't ordered by bucket
        for i in iter_list:
            yield i.tolist()  # convert array to list

    def __len__(self):
        # NOTE this runs a full iteration just to count the batches
        return len(list(iter(self)))

    def seq_length_to_bucket_id(self, seq_length: int) -> int:
        boundaries = list(self.bucket_boundaries)
        buckets_min = [np.iinfo(np.int32).min] + boundaries
        buckets_max = boundaries + [np.iinfo(np.int32).max]
        conditions_c = np.logical_and(np.less_equal(buckets_min, seq_length), np.less(seq_length, buckets_max))
        bucket_id = np.min(np.where(conditions_c))
        return bucket_id
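# A quick worked example of the bucketing (my own illustration): with
# bucket_boundaries=[10, 20, 30], a sequence of length 15 satisfies
# 10 <= 15 < 20, so seq_length_to_bucket_id(15) == 1.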
class SequenceDataLoader(DataLoader):
    def __init__(self, dataset, batch_size, bucket_boundaries, shuffle=True, **kwargs):
        self.bucket_boundaries = bucket_boundaries
        idx_seq_lengths = list(dataset.idx_seq_lengths.items())
        sampler = BySequenceLengthSampler(idx_seq_lengths, bucket_boundaries, batch_size, shuffle=shuffle)
        super().__init__(dataset, batch_sampler=sampler, collate_fn=pad_collate)
class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(3, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
def run():
    train_data = SequenceDataLoader(RandomSequenceDataset(32, 3, 64), batch_size=2, bucket_boundaries=[10, 20, 30])
    val_data = SequenceDataLoader(RandomSequenceDataset(32, 3, 64), batch_size=2, bucket_boundaries=[10, 20, 30])
    test_data = SequenceDataLoader(RandomSequenceDataset(32, 3, 64), batch_size=2, bucket_boundaries=[10, 20, 30])

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        use_distributed_sampler=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)
    trainer.predict(model, dataloaders=test_data)


if __name__ == "__main__":
    run()
This throws the error on the predict call even with the flag turned off.
I don't know what you mean here. PyTorch Lightning is built on top of PyTorch and wraps it, so we definitely want to follow its API.
I mean that Sampler is the base class, and that restricting the allowed inputs to BatchSampler's API of just sampler, batch_size, drop_last breaks the idea that you can simply wrap functional torch code.
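To illustrate the point, a sketch in plain PyTorch (no Lightning; PairSampler is a made-up toy): DataLoader itself only requires batch_sampler to be an iterable over lists of indices, not a BatchSampler subclass.

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset

class PairSampler(Sampler):
    # yields consecutive index pairs; a plain Sampler, not a BatchSampler
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        idx = list(range(self.n))
        return iter([idx[i:i + 2] for i in range(0, self.n, 2)])

    def __len__(self):
        return (self.n + 1) // 2

loader = DataLoader(TensorDataset(torch.arange(6)), batch_sampler=PairSampler(6))
for (batch,) in loader:
    print(batch)  # tensor([0, 1]), tensor([2, 3]), tensor([4, 5])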
I opened a PR that relaxes the conditions for prediction, and I used your code example to validate that it works. If the tests pass, we can probably move forward with it. Feel free to try it out if you want.
Bug description
Using a custom dataloader I get the following error regardless of whether or not I set use_distributed_sampler=False.
What version are you seeing the problem on?
v2.1
How to reproduce the bug
No response
Error messages and logs
Environment
* CUDA:
  - GPU: None
  - available: False
  - version: None
* Lightning:
  - lightning-utilities: 0.10.1
  - pytorch-lightning: 2.1.0
  - torch: 2.2.0
  - torch-geometric: 2.4.0
  - torchmetrics: 1.3.0.post0
  - torchvision: 0.17.0
* System:
  - OS: Darwin
  - architecture: 64bit
  - processor: i386
  - python: 3.10.11
  - release: 23.2.0
  - version: Darwin Kernel Version 23.2.0: Wed Nov 15 21:54:10 PST 2023; root:xnu-10002.61.3~2/RELEASE_X86_64

More info
No response
cc @justusschock @awaelchli