libffcv / ffcv

FFCV: Fast Forward Computer Vision (and other ML workloads!)
https://ffcv.io
Apache License 2.0

Numba typing error: Cannot determine Numba type of <class 'ffcv.transforms.module.ModuleWrapper'> #372

Closed. jnboehm closed this issue 2 months ago.

jnboehm commented 2 months ago

I get the following error when trying to create an FFCV data loader from a .beton file.

Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/local/lib/python3.11/dist-packages/ffcv/loader/epoch_iterator.py", line 84, in run
    result = self.run_pipeline(b_ix, ixes, slot, events[slot])
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/ffcv/loader/epoch_iterator.py", line 146, in run_pipeline
    results = stage_code(**args)
              ^^^^^^^^^^^^^^^^^^
  File "", line 2, in stage_code_0
  File "/usr/local/lib/python3.11/dist-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/usr/local/lib/python3.11/dist-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'self': Cannot determine Numba type of <class 'ffcv.transforms.module.ModuleWrapper'>
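
If I read the message correctly, Numba refuses to compile a pipeline stage that references an arbitrary Python object (here the ModuleWrapper instance) as a global. A minimal, FFCV-free sketch of the same class of failure (Wrapper is just a hypothetical stand-in):

from numba import njit

class Wrapper:  # stand-in for ffcv.transforms.module.ModuleWrapper
    def __call__(self, x):
        return x

w = Wrapper()

@njit
def stage(x):
    return w(x)  # 'w' is an untypable global, like 'self' in the FFCV stage

stage(1)
# numba.core.errors.TypingError: Untyped global name 'w':
# Cannot determine Numba type of <class '__main__.Wrapper'>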

The execution of the code then hangs when trying to access the first batch. When I interrupt it with Ctrl-C (after waiting far longer than it usually takes to get a first batch), I can see that it was blocked waiting on a lock:

KeyboardInterrupt                         Traceback (most recent call last)
Cell In[11], line 1
----> 1 b= next(iter(l))

File /usr/local/lib/python3.11/dist-packages/ffcv/loader/epoch_iterator.py:155, in EpochIterator.__next__(self)
    154 def __next__(self):
--> 155     result = self.output_queue.get()
    156     if result is None:
    157         self.close()

File /usr/lib/python3.11/queue.py:171, in Queue.get(self, block, timeout)
    169 elif timeout is None:
    170     while not self._qsize():
--> 171         self.not_empty.wait()
    172 elif timeout < 0:
    173     raise ValueError("'timeout' must be a non-negative number")

File /usr/lib/python3.11/threading.py:327, in Condition.wait(self, timeout)
    325 try:    # restore state no matter what (e.g., KeyboardInterrupt)
    326     if timeout is None:
--> 327         waiter.acquire()
    328         gotit = True
    329     else:

KeyboardInterrupt:
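
In hindsight, the hang itself is just a consequence of the dead worker: the pipeline thread raises the TypingError before it ever puts a result on the queue, and EpochIterator.__next__ calls output_queue.get() without a timeout. The same pattern, independent of FFCV:

import threading
import queue

q = queue.Queue()

def worker():
    # dies before producing anything, like the FFCV pipeline thread above
    raise RuntimeError("typing error during JIT compilation")

threading.Thread(target=worker, daemon=True).start()

q.get()  # blocks forever: the only producer is dead and no timeout is set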

The code below should reproduce the stacktrace. I am using ffcv 1.0.2, torch 2.2 and the latest stable version of torchvision. Not sure if the use of the v2 transforms is causing the problem or if there is some other issue.

import ffcv
import torch
from torchvision.transforms import v2 as transforms

# torchvision v2 transforms: image -> float32 tensor scaled to [0, 1]
pil2tensor = transforms.Compose(
    [
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ]
)

# splice the torchvision transforms directly after the FFCV decoder
ffcvtr = [ffcv.fields.rgb_image.SimpleRGBImageDecoder()] + pil2tensor.transforms

l = ffcv.Loader(
    "cifar10.beton",
    batch_size=1024,
    order=ffcv.loader.OrderOption.QUASI_RANDOM,
    num_workers=8,
    pipelines=dict(image=ffcvtr),
)

next(iter(l))  # hangs here; the worker thread raises the TypingError above

jnboehm commented 2 months ago

FWIW, I also observe the same problem with the following pipeline, which relies only on FFCV transforms:

l = ffcv.Loader(
    "cifar10.beton",
    batch_size=1024,
    order=ffcv.loader.OrderOption.QUASI_RANDOM,
    num_workers=8,
    pipelines=dict(
       image=[
          ffcv.fields.rgb_image.SimpleRGBImageDecoder(),
          ffcv.transforms.ToTorchImage(),
       ]
    ),
)

jnboehm commented 2 months ago

I managed to figure out how to resolve my issue, which I now suspect was caused by incorrect usage of the API (unfortunately, the errors are quite hard to decipher from my perspective). I'll leave my workaround below for posterity's sake.

The error from the last loader (the one relying only on FFCV transforms) seems to be due to incompatible transforms: as I understand it, ffcv.transforms.ToTorchImage() cannot directly transform the decoded image, which is still a numpy array, into a torch image. The error in this case (which wasn't helpful either) read:

Exception in thread Thread-147:
Traceback (most recent call last):
  File "/usr/lib/python3.11/threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "/usr/local/lib/python3.11/dist-packages/ffcv/loader/epoch_iterator.py", line 84, in run
    result = self.run_pipeline(b_ix, ixes, slot, events[slot])
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/ffcv/loader/epoch_iterator.py", line 146, in run_pipeline
    results = stage_code(**args)
              ^^^^^^^^^^^^^^^^^^
  File "", line 2, in stage_code_0
  File "/usr/local/lib/python3.11/dist-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/usr/local/lib/python3.11/dist-packages/numba/core/dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'permute' of type array(uint8, 4d, C)

File "../../../../usr/local/lib/python3.11/dist-packages/ffcv/transforms/ops.py", line 109:
        def to_torch_image(inp: ch.Tensor, dst):
            <source elided>
                pass
            inp = inp.permute([0, 3, 1, 2])
            ^

During: typing of get attribute at /usr/local/lib/python3.11/dist-packages/ffcv/transforms/ops.py (109)

File "../../../../usr/local/lib/python3.11/dist-packages/ffcv/transforms/ops.py", line 109:
        def to_torch_image(inp: ch.Tensor, dst):
            <source elided>
                pass
            inp = inp.permute([0, 3, 1, 2])
            ^
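
The underlying mismatch is easy to reproduce outside of Numba: permute is a torch.Tensor method, while the decoder hands ToTorchImage a plain numpy array (which only has transpose):

import numpy as np
import torch

batch = np.zeros((1024, 32, 32, 3), dtype=np.uint8)  # NHWC, as the decoder emits

# batch.permute([0, 3, 1, 2])  # AttributeError: numpy arrays have no 'permute'
nchw = torch.from_numpy(batch).permute(0, 3, 1, 2)   # fine once it is a tensor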

Apparently the correct usage is:

l = ffcv.Loader(
    "cifar10.beton",
    batch_size=1024,
    order=ffcv.loader.OrderOption.QUASI_RANDOM,
    num_workers=8,
    pipelines=dict(
       image=[
          ffcv.fields.rgb_image.SimpleRGBImageDecoder(),
          ffcv.transforms.ToTensor(),  # numpy array -> torch tensor
          ffcv.transforms.ToTorchImage(convert_back_int16=False),
       ]
    ),
)

which does work. To get the image as a float tensor that PyTorch can then work with, you need one more transform, which you have to write yourself (I took DivideImageBy255 from https://github.com/SerezD/ffcv_pytorch_lightning/blob/32def277998a1b9609703cd6a72b847eb2149c5c/src/ffcv_pl/ffcv_utils/augmentations.py).
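
A minimal sketch of that transform, assuming the same cast-then-rescale behavior as the linked implementation (the exact class there may differ slightly):

import torch

class DivideImageBy255(torch.nn.Module):
    """Cast a uint8 image batch to a float dtype and scale it into [0, 1]."""

    def __init__(self, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.dtype = dtype  # assumed to be a float dtype

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # ToTorchImage outputs uint8 NCHW batches; cast first, then rescale
        return image.to(self.dtype) / 255.0

With DivideImageBy255 defined as above, the loader becomes: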

l = ffcv.Loader(
    "cifar10.beton",
    batch_size=1024,
    order=ffcv.loader.OrderOption.QUASI_RANDOM,
    num_workers=8,
    pipelines=dict(
       image=[
          ffcv.fields.rgb_image.SimpleRGBImageDecoder(),
          ffcv.transforms.ToTensor(),  # numpy array -> torch tensor
          ffcv.transforms.ToTorchImage(convert_back_int16=False),
          DivideImageBy255(torch.float32),  # was self.dtype in the class this came from
       ]
    ),
)

I guess it would be nice to improve the error reporting, since the hard-to-decipher messages were the main reason I reported this issue in the first place. Or at least to add a remark somewhere that the transforms cannot be intermixed as freely as the main website describes.

But I think that goes somewhat beyond the scope of this issue (or at least its title would no longer fit), so feel free to close it.