Open · david-waterworth opened 1 year ago
I think the warning is shown correctly: pickle has a limitation with local functions. Could you please move batch_transform out of the run function?
Yes, either moving batch_transform out or importing dill works fine, but my point is that my interpretation of

> UserWarning: Local function is not supported by pickle, please use regular python function or functools.partial instead.

was that functools.partial(batch_transform) should work, but it doesn't. It might be clearer if it said:

> UserWarning: Local function is not supported by pickle, please use regular python function or ensure dill is available.
I'm not really sure what the reference to functools.partial means. I confirmed that functools.partial(batch_transform) cannot be pickled, so I agree that the tests seem valid.
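To make the confirmation above concrete, here is a minimal reproduction (the function names are illustrative, not from the original report) showing that a local function cannot be pickled, and that wrapping it in functools.partial does not help, because pickling the partial still pickles the inner function:

```python
import functools
import pickle

def make_transform():
    scale = 2
    def batch_transform(x):  # local function, closes over scale
        return x * scale
    return batch_transform

fn = make_transform()

# A plain local function cannot be pickled...
try:
    pickle.dumps(fn)
except (pickle.PicklingError, AttributeError) as e:
    print("local function:", type(e).__name__)

# ...and wrapping it in functools.partial does not help, because
# partial's __reduce__ pickles the underlying function by reference.
try:
    pickle.dumps(functools.partial(fn))
except (pickle.PicklingError, AttributeError) as e:
    print("partial of local function:", type(e).__name__)
```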
The general reason for using a local function is to capture some local variable, like:

```python
def fn():
    local_var = some_instance()
    def local_fn(x):
        return local_var + x
    return local_fn
```
This can be achieved with a regular (module-level) function plus a partial instead:

```python
def local_fn(x, local_var):
    return local_var + x

def fn():
    local_var = some_instance()
    return functools.partial(local_fn, local_var=local_var)
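A runnable version of this pattern (using a concrete value in place of some_instance(), purely for illustration) shows that the partial of a module-level function round-trips through pickle, as long as the bound arguments are themselves picklable:

```python
import functools
import pickle

# Module-level (regular) function: pickled by reference to its module.
def local_fn(x, local_var):
    return local_var + x

def fn():
    local_var = 10  # stand-in for some_instance(), for illustration
    return functools.partial(local_fn, local_var=local_var)

p = fn()
restored = pickle.loads(pickle.dumps(p))
print(restored(5))  # 15
```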
Thanks, yes, I understand functools.partial in general; I was just confused about what the warning means when it says use regular python function **or** functools.partial. It doesn't seem that functools.partial can be used to turn a local function into something that can be pickled. It can of course be used to partially apply a regular python function, but that's not what the warning is about.
@david-waterworth Thanks for calling this out. We do need to improve the warning message. Would you like to send a PR to PyTorch to fix it?
Re-opening the issue to track the task of improving the warning about local functions. Users should either rely on a regular python function + partial, or install the dill package.
Has anyone updated the message? I could do it if needed.
@DongyuXu77 Not yet. Please feel free to open a PR.
I have improved the message, but it seems that I need to open a PR in pytorch instead of data. What should I do?
@ejguan Should I copy this issue to pytorch/pytorch and open a PR there? The code that needs to be updated isn't in this repository.
@DongyuXu77 You don't need to copy the issue. Just open a PR in PyTorch Core and add `Fixes: https://github.com/pytorch/data/issues/947` in the PR summary.
🐛 Describe the bug
I used a local function in a pipeline, which results in the warning:

> UserWarning: Local function is not supported by pickle, please use regular python function or functools.partial instead.
So I wrapped the local function in functools.partial and the warning persisted, so I investigated more closely. The problem appears to be that _check_unpickable_fn undoes the partial before checking whether fn is local, i.e. this code unwraps the partial, converting it back to the underlying local function:

https://github.com/pytorch/pytorch/blob/301644d3cb8ce7ecb477b5bcb95a7a8a1304fa71/torch/utils/data/datapipes/utils/common.py#L132

And this code then produces the warning:

https://github.com/pytorch/pytorch/blob/301644d3cb8ce7ecb477b5bcb95a7a8a1304fa71/torch/utils/data/datapipes/utils/common.py#L136
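The unwrap-then-check behavior described above can be sketched roughly as follows. This is a simplified, hypothetical reimplementation for illustration, not the actual PyTorch source; the point is that unwrapping the partial first means wrapping a local function in functools.partial cannot make the warning go away:

```python
import functools
import warnings

def check_unpickable_fn(fn):
    # Hypothetical, simplified version of the check described above.
    # Step 1: unwrap functools.partial down to the underlying callable.
    while isinstance(fn, functools.partial):
        fn = fn.func
    # Step 2: a local function carries '<locals>' in its qualified name,
    # e.g. 'run.<locals>.batch_transform'.
    if "<locals>" in getattr(fn, "__qualname__", ""):
        warnings.warn(
            "Local function is not supported by pickle, please use "
            "regular python function or functools.partial instead."
        )

def run():
    def batch_transform(batch):  # local function, as in the report
        return batch
    # The warning fires whether or not the function is wrapped in partial:
    check_unpickable_fn(functools.partial(batch_transform))

run()
```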
My calling code is
This produces the warning regardless of whether batch_transform is wrapped in partial or not (a fix is to import dill, or to lift batch_transform out of run).

Versions
PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-58-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2080 Ti
Nvidia driver version: 525.60.13
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-ignite==0.4.10
[pip3] torch==1.13.0+cu117
[pip3] torchdata==0.5.0
[pip3] torchtext==0.14.0
[conda] Could not collect