pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

🐛 [Bug] Cannot export with PTQ using a cached calibrator: "TypeError: cannot pickle DataLoaderCalibrator object" #2168

Open laclouis5 opened 1 year ago

laclouis5 commented 1 year ago

Bug Description

When doing post-training quantization (PTQ) with the INT8 calibration API, the model export works fine with ptq.DataLoaderCalibrator, but loading the calibrator from the cache with ptq.CacheCalibrator fails at runtime with:

TypeError: cannot pickle 'DataLoaderCalibrator' object

To Reproduce

Here is some example code to reproduce the issue. I'm using Torchvision's ResNet18 for simplicity.

First, export a model with PTQ and cache the calibrator data:

import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import torch_tensorrt as trt
import torch_tensorrt.ptq as ptq

C, H, W = 3, 768, 1024

class PTQDataset:
    def __len__(self):
        return 100

    def __getitem__(self, i):
        return torch.randn(C, H, W)

def main():
    model = resnet18().eval().cuda()
    example_input = torch.randn(1, C, H, W, device="cuda")

    traced_model = torch.jit.trace(model, example_input)

    test_ds = PTQDataset()
    test_dl = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=1)
    calibrator = ptq.DataLoaderCalibrator(test_dl, cache_file="calibrator.cache")

    with trt.logging.debug():
        trt_model = trt.ts.compile(
            module=traced_model,
            inputs=[example_input],
            enabled_precisions={torch.float, torch.half, torch.int8},
            calibrator=calibrator,
        )

    torch.jit.save(trt_model, "resnet18.ts")

if __name__ == "__main__":
    main()

This export works fine and the calibration cache is saved to disk.

Second, export the model again with PTQ but load the calibrator from the cache, as explained in the documentation:

import torch
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import torch_tensorrt as trt
import torch_tensorrt.ptq as ptq

C, H, W = 3, 768, 1024

class PTQDataset:
    def __len__(self):
        return 100

    def __getitem__(self, i):
        return torch.randn(C, H, W)

def main():
    model = resnet18().eval().cuda()
    example_input = torch.randn(1, C, H, W, device="cuda")

    traced_model = torch.jit.trace(model, example_input)

    calibrator = ptq.CacheCalibrator("calibrator.cache")

    with trt.logging.debug():
        trt_model = trt.ts.compile(
            module=traced_model,
            inputs=[example_input],
            enabled_precisions={torch.float, torch.half, torch.int8},
            calibrator=calibrator,
        )

    torch.jit.save(trt_model, "resnet18.ts")

if __name__ == "__main__":
    main()

This fails with the error:

TypeError: cannot pickle 'DataLoaderCalibrator' object

Expected behavior

This should work and produce exactly the same output as the first export. If it fails for whatever reason, the error should be clear and should help the user solve the issue.

Environment

peri044 commented 11 months ago

This happens because we deepcopy the calibrator object, for which pickling is not defined. Can you replace this line, https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/ts/_compile_spec.py#L228, with compile_spec = compile_spec_? We will investigate this further after verifying other test cases.
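For readers unfamiliar with this failure mode, here is a minimal sketch of why copy.deepcopy can raise a pickling error. It does not use torch_tensorrt: FakeCalibrator and copy_spec are hypothetical names, and a threading.Lock stands in for any object without pickle support. The assumption is that the real calibrator behaves similarly because deepcopy falls back to the pickle protocol for objects it does not know how to copy.

```python
import copy
import threading


class FakeCalibrator:
    """Hypothetical stand-in for an object that cannot be pickled."""

    def __init__(self):
        # A lock has no pickle support, so deep-copying this instance fails.
        self._lock = threading.Lock()


def copy_spec(spec):
    """Deep-copy a spec dict while sharing the calibrator by reference.

    Hypothetical helper: it illustrates skipping the deepcopy for the one
    entry that cannot be copied, not the library's actual fix.
    """
    calibrator = spec.get("calibrator")
    rest = {key: value for key, value in spec.items() if key != "calibrator"}
    new_spec = copy.deepcopy(rest)
    new_spec["calibrator"] = calibrator
    return new_spec


spec = {"inputs": [[1, 3, 768, 1024]], "calibrator": FakeCalibrator()}

# deepcopy walks into the calibrator, reaches the lock, and raises TypeError.
try:
    copy.deepcopy(spec)
    deepcopy_error = None
except TypeError as exc:
    deepcopy_error = exc

# Sharing the calibrator avoids the error; everything else is still copied.
safe_spec = copy_spec(spec)
```

Assigning compile_spec = compile_spec_ goes one step further than this sketch: it skips the copy entirely, so any later mutation of the spec is visible to the caller.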

laclouis5 commented 11 months ago

My apologies for the late reply. This gives a different error later in the process:

terminate called after throwing an instance of 'pybind11::error_already_set'
  what():  TypeError: get_cache_mode_batch() takes 1 positional argument but 2 were given
Aborted (core dumped)

peri044 commented 8 months ago

Hello @laclouis5, sorry for the delay. Can you let me know what inputs you are receiving for this function? https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/ptq.py#L24-L25

laclouis5 commented 8 months ago

@peri044 I suppose I have to pass the CacheCalibrator as input to this function like this?

calibrator = ptq.CacheCalibrator("calibrator.cache")
cache_mode = ptq.get_cache_mode_batch(calibrator)
print(cache_mode)

This returns None.

I see that this function is bound to the CacheCalibrator class later in the code as a get_batch method. Calling this method directly also returns None.
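The "takes 1 positional argument but 2 were given" message reported earlier is what Python raises when a function that accepts only self is bound as a method and then called with one extra argument. The sketch below reproduces that in plain Python, without torch_tensorrt. The class names Strict and Lenient, both function names, and the names value are illustrative assumptions; the guess is that the caller passes one argument (for example a list of input names) to get_batch.

```python
def get_batch_one_arg(self):
    """Accepts only self, so any extra positional argument is rejected."""
    return None


def get_batch_var_args(self, *args):
    """Accepts and ignores extra positional arguments."""
    return None


# Build classes dynamically and bind the plain functions as get_batch methods.
Strict = type("Strict", (object,), {"get_batch": get_batch_one_arg})
Lenient = type("Lenient", (object,), {"get_batch": get_batch_var_args})

names = ["input_0"]  # illustrative argument supplied by the caller

# The one-parameter version fails because self plus names makes two arguments.
try:
    Strict().get_batch(names)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# The variadic version tolerates the extra argument and returns None.
lenient_result = Lenient().get_batch(names)
```

If this is indeed the cause, widening the bound function's signature would remove the TypeError, though whether returning None is then the right behavior for a cache-only calibrator is a separate question.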

peri044 commented 8 months ago

calibrator = ptq.CacheCalibrator("calibrator.cache") is the right usage. You never have to call get_cache_mode_batch directly; the signature for this function is at https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/ptq.py#L24. Once you create the CacheCalibrator object, it needs to be passed as the calibrator argument.

laclouis5 commented 8 months ago

OK, so something like this is what you are looking for?

def get_cache_mode_batch(self, *args):
    print(self, args)
    return None

This prints:

<torch_tensorrt.ptq.DataLoaderCalibrator object at 0x7fc114faa110> (['input_0'],)

peri044 commented 7 months ago

Hello @laclouis5, sorry for the delay. I have a workaround for the error in https://github.com/pytorch/TensorRT/issues/2168#issuecomment-1732336504. You can try using a DataLoaderCalibrator with use_cache=True to reuse the calibration cache file you already have.

test_ds = PTQDataset()
test_dl = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=1)
calibrator = ptq.DataLoaderCalibrator(test_dl, use_cache=True, cache_file="calibrator.cache")

This won't recalibrate, and the engine seems to compile successfully.
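One possible way to apply this workaround in a script that may run before or after the cache exists is to decide use_cache from the file system. The helper below is a hypothetical sketch, not part of torch_tensorrt: calibrator_kwargs is an invented name, and it only builds the use_cache and cache_file keyword arguments shown above.

```python
import os
import tempfile


def calibrator_kwargs(cache_file):
    """Build keyword arguments for ptq.DataLoaderCalibrator.

    Hypothetical helper: use_cache is enabled only when the cache file
    exists and is non-empty, so the first run still calibrates.
    """
    has_cache = os.path.isfile(cache_file) and os.path.getsize(cache_file) > 0
    return {"use_cache": has_cache, "cache_file": cache_file}


# Small demonstration using a temporary directory instead of a real cache.
with tempfile.TemporaryDirectory() as tmp_dir:
    cache_path = os.path.join(tmp_dir, "calibrator.cache")
    kwargs_without_cache = calibrator_kwargs(cache_path)

    with open(cache_path, "wb") as handle:
        handle.write(b"placeholder calibration data")
    kwargs_with_cache = calibrator_kwargs(cache_path)
```

It would be used as ptq.DataLoaderCalibrator(test_dl, **calibrator_kwargs("calibrator.cache")), which still requires constructing the data loader even when the cache is reused.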