laclouis5 opened this issue 1 year ago (status: Open)
This issue is because we deepcopy the calibrator object (whose pickling is not defined). Can you replace this line https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/ts/_compile_spec.py#L228 with compile_spec = compile_spec_? We shall investigate this further after verifying other test cases.
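For reference, the suggested edit would look roughly like this; this is a sketch, and the exact surrounding code in _compile_spec.py may differ from what is shown here:

# In py/torch_tensorrt/ts/_compile_spec.py, around the linked line:
# compile_spec = deepcopy(compile_spec_)  # fails: the calibrator object cannot be pickled
compile_spec = compile_spec_              # suggested replacement: use the spec directly, skipping the deepcopy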
My apologies for the late reply. This gives a different error later in the process:
terminate called after throwing an instance of 'pybind11::error_already_set'
what(): TypeError: get_cache_mode_batch() takes 1 positional argument but 2 were given
Aborted (core dumped)
Hello @laclouis5, sorry for the delay. Can you let me know what inputs you are receiving for this function? https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/ptq.py#L24-L25
@peri044 I suppose I have to pass the CacheCalibrator as input to this function, like this?
calibrator = ptq.CacheCalibrator("calibrator.cache")
cache_mode = ptq.get_cache_mode_batch(calibrator)
print(cache_mode)
This returns None.

I see that this function is bound to the CacheCalibrator class later in the code as a get_batch method. Calling this method directly also returns None.
calibrator = ptq.CacheCalibrator("calibrator.cache") is the right usage. You never have to call get_cache_mode_batch directly; its signature is https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/ptq.py#L24. Once you construct the CacheCalibrator, it needs to be passed into the calibrator argument.
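Roughly, the intended flow looks like this; the model variable and input shape below are placeholders for illustration, not taken from this thread:

import torch
import torch_tensorrt
from torch_tensorrt import ptq

# Load calibration scales from the existing cache; no dataloader is needed in this mode.
calibrator = ptq.CacheCalibrator("calibrator.cache")

# The calibrator is passed to compile(); get_batch()/get_cache_mode_batch()
# are invoked internally by TensorRT, never by user code.
trt_model = torch_tensorrt.compile(
    scripted_model,  # placeholder: a TorchScript module
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],  # placeholder shape
    enabled_precisions={torch.int8},
    calibrator=calibrator,
)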
Ok, so something like this is what you are looking for?
def get_cache_mode_batch(self, *args):
    # Debug override: log the instance and whatever positional args TensorRT passes in.
    print(self, args)
    return None
This prints:
<torch_tensorrt.ptq.DataLoaderCalibrator object at 0x7fc114faa110> (['input_0'],)
Hello @laclouis5, sorry for the delay. I have a workaround for the error in https://github.com/pytorch/TensorRT/issues/2168#issuecomment-1732336504: you can use a DataLoaderCalibrator with use_cache=True to reuse the calibration cache file you already have.
test_ds = PTQDataset()
test_dl = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=1)
calibrator = ptq.DataLoaderCalibrator(test_dl, use_cache=True, cache_file="calibrator.cache")
This won't recalibrate and the engine seems to be compiling successfully.
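With that calibrator in hand, compilation proceeds as usual. A minimal sketch, where the model variable and input shape are placeholders:

import torch
import torch_tensorrt

trt_model = torch_tensorrt.compile(
    scripted_model,  # placeholder: a TorchScript module
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],  # placeholder shape
    enabled_precisions={torch.int8},
    calibrator=calibrator,  # reuses calibrator.cache, no recalibration
)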
Bug Description
When doing post-training quantization using the INT8 calibration API, the model export works fine when using the ptq.DataLoaderCalibrator, but there is a runtime error when loading the calibrator from the cache using ptq.CacheCalibrator:

To Reproduce
Here is some example code to reproduce the issue. I'm using Torchvision's ResNet18 for simplicity.
First, export a model with PTQ and cache the calibrator data:
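The snippet itself is missing from this copy of the issue; the following is a rough reconstruction from details given in the thread. PTQDataset is the reporter's calibration dataset class (not defined here), and the input shape is an assumption for ResNet18:

import torch
import torchvision
import torch_tensorrt
from torch_tensorrt import ptq
from torch.utils.data import DataLoader

model = torchvision.models.resnet18(weights="DEFAULT").eval().cuda()
scripted_model = torch.jit.script(model)

test_ds = PTQDataset()  # placeholder: the reporter's dataset class
test_dl = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=1)

# Calibrate from the dataloader and write the computed scales to calibrator.cache.
calibrator = ptq.DataLoaderCalibrator(test_dl, use_cache=False, cache_file="calibrator.cache")

trt_model = torch_tensorrt.compile(
    scripted_model,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],  # assumed ResNet18 input shape
    enabled_precisions={torch.int8},
    calibrator=calibrator,
)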
This export works fine and the calibration cache is saved to disk.
Second, export the model again with PTQ but load the calibrator from the cache, as explained in the documentation:
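This snippet is also missing from this copy; per the usage discussed above, it was roughly the same compile call with the calibrator loaded from the cache instead:

calibrator = ptq.CacheCalibrator("calibrator.cache")

trt_model = torch_tensorrt.compile(
    scripted_model,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],  # assumed ResNet18 input shape
    enabled_precisions={torch.int8},
    calibrator=calibrator,
)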
This fails with the error:
Expected behavior
This should work and produce the exact same output as the first export. If it fails for whatever reason, the error should be clear and should help the user solve the issue.
Environment
How you installed PyTorch (conda, pip, libtorch, source): pip