pytorch / data

A PyTorch repo for data loading and utilities to be shared by the PyTorch domain libraries.
BSD 3-Clause "New" or "Revised" License

Improve the warning message regarding local function not supported by pickle #947

Open david-waterworth opened 1 year ago

david-waterworth commented 1 year ago

🐛 Describe the bug

I used a local function in a pipeline which results in the warning:

UserWarning: Local function is not supported by pickle, please use regular python function or functools.partial instead.

So I wrapped the local function in functools.partial, but the warning persisted, so I looked more closely. The problem appears to be that _check_unpickable_fn unwraps the partial before checking whether fn is local:

This code unwraps partial, converting back to the underlying local function

https://github.com/pytorch/pytorch/blob/301644d3cb8ce7ecb477b5bcb95a7a8a1304fa71/torch/utils/data/datapipes/utils/common.py#L132

And this code then produces the warning

https://github.com/pytorch/pytorch/blob/301644d3cb8ce7ecb477b5bcb95a7a8a1304fa71/torch/utils/data/datapipes/utils/common.py#L136
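To make the two linked steps concrete, here is a minimal sketch of the behaviour described above. It is not the PyTorch source: `is_local_fn` and `check_unpicklable_fn` are hypothetical stand-ins, the dill check is left out, and the `"<locals>"` test on `__qualname__` is an assumption about how a local function could be detected.

```python
import functools
import warnings


def is_local_fn(fn):
    # Hypothetical helper: a function defined inside another function
    # carries "<locals>" in its __qualname__.
    return "<locals>" in getattr(fn, "__qualname__", "")


def check_unpicklable_fn(fn):
    # Hypothetical stand-in for the check described above (dill handling omitted).
    if isinstance(fn, functools.partial):
        fn = fn.func  # unwrap the partial to inspect the underlying function
    if is_local_fn(fn):
        warnings.warn("Local function is not supported by pickle")
        return False
    return True


def run():
    def batch_transform(x):
        return x
    return batch_transform


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    plain_ok = check_unpicklable_fn(run())
    wrapped_ok = check_unpicklable_fn(functools.partial(run()))
```

Under these assumptions both calls warn, because the partial is unwrapped back to the same local function before the locality test runs.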

My calling code is


from functools import partial

def run(args):

    def batch_transform(x):
        return {"inputs": text_transform(x["text"]), "target": label_transform(x["label"])}

    train_datapipe, test_datapipe = DATASETS[args.dataset](root=args.data_dir, split=('train', 'test'))

    # shuffle and batch
    train_datapipe = train_datapipe.shuffle().batch(args.batch_size).rows2columnar(["text", "label"])
    # wrapping the local function in partial still triggers the warning
    train_datapipe = train_datapipe.map(partial(batch_transform))

This produces the warning regardless of whether batch_transform is wrapped in partial or not (a fix is to import dill or to lift batch_transform out of run).

Versions

PyTorch version: 1.13.0+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-58-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2080 Ti
Nvidia driver version: 525.60.13
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-ignite==0.4.10
[pip3] torch==1.13.0+cu117
[pip3] torchdata==0.5.0
[pip3] torchtext==0.14.0
[conda] Could not collect

ejguan commented 1 year ago

I think the warning is shown correctly. pickle cannot serialize local functions. Could you please move batch_transform out of the run function?

david-waterworth commented 1 year ago

Yes, either moving batch_transform or importing dill works fine, but my point is that my interpretation of

UserWarning: Local function is not supported by pickle, please use regular python function or functools.partial instead.

was that functools.partial(batch_transform) should work, but it doesn't. It might be clearer if it said

UserWarning: Local function is not supported by pickle, please use regular python function or ensure dill is available.

I'm not really sure what the reference to functools.partial means. I confirmed that functools.partial(batch_transform) cannot be pickled, so I agree that the check itself seems valid.
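The observation that wrapping a local function in functools.partial does not make it picklable can be reproduced with the standard library alone. This is a minimal sketch: `can_pickle` is a hypothetical helper, and the `batch_transform` body is a placeholder rather than the one from the calling code.

```python
import functools
import pickle


def run():
    # Placeholder local function, defined inside run() like in the report.
    def batch_transform(x):
        return x
    return batch_transform


def can_pickle(obj):
    # Hypothetical helper: report whether pickle can serialize obj.
    # The exception type raised for local functions differs across
    # Python versions, so both are caught.
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError):
        return False


local_fn = run()
local_ok = can_pickle(local_fn)
partial_ok = can_pickle(functools.partial(local_fn))
```

Both `local_ok` and `partial_ok` come out as False: pickling a partial object requires pickling the function it wraps, and pickle stores plain functions by qualified name, which a local function does not have at module level.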

ejguan commented 1 year ago

The usual reason for using a local function is to capture a local variable inside it, like:

def fn():
    local_var = some_instance()
    def local_fn(x):
        return local_var + x
    return local_fn

This can be achieved by a partial function + regular function like:

def local_fn(x, local_var):
    return local_var + x

def fn():
    local_var = some_instance()
    return functools.partial(local_fn, local_var=local_var)
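
Applied to the calling code from the report, the same pattern might look like the sketch below. It assumes text_transform and label_transform are themselves picklable callables; `str.upper` and `int` are stand-ins chosen only for illustration, and `make_batch_transform` is a hypothetical helper name.

```python
import functools
import pickle


def batch_transform(x, text_transform, label_transform):
    # Module-level function: pickle can locate it by its qualified name.
    return {"inputs": text_transform(x["text"]), "target": label_transform(x["label"])}


def make_batch_transform(text_transform, label_transform):
    # Bind the former closure variables as keyword arguments instead.
    return functools.partial(
        batch_transform,
        text_transform=text_transform,
        label_transform=label_transform,
    )


# Stand-in transforms for illustration only.
fn = make_batch_transform(str.upper, int)
result = fn({"text": "hello", "label": "3"})

if __name__ == "__main__":
    # Round trip through pickle when run as a script.
    restored = pickle.loads(pickle.dumps(fn))
    print(restored({"text": "hello", "label": "3"}))
```

The partial is picklable here only because the function it wraps lives at module level and the bound arguments are picklable too; the partial by itself adds nothing.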

david-waterworth commented 1 year ago

Thanks, yes I understand functools.partial in general, I was just confused as to what the warning meant when it states use regular python function **or** functools.partial. It doesn't seem that functools.partial can be used to turn a local function into something that can be pickled. It can of course be used to partially apply a regular python function but that's not what the warning is about.

ejguan commented 1 year ago

@david-waterworth Thanks for calling this out. We do need to improve the warning message. Do you want to send a PR to PyTorch to fix that?

ejguan commented 1 year ago

Re-opening the issue to track improving the warning message about local functions. Users should either rely on a regular Python function + partial, or install the dill package.

DongyuXu77 commented 1 year ago

Has anyone updated the message? I could do that if needed.

ejguan commented 1 year ago

@DongyuXu77 Not yet. Please feel free to open a PR.

DongyuXu77 commented 1 year ago

I have improved the message, but it seems that I need to open a PR in pytorch instead of data. What should I do?

DongyuXu77 commented 1 year ago

@ejguan Should I copy this issue to pytorch/pytorch and open a PR? The code that needs to be updated doesn't appear in this repository.

ejguan commented 1 year ago

@DongyuXu77 You don't need to copy the issue. Just open a PR in PyTorch Core and add fixes: https://github.com/pytorch/data/issues/947 in the PR summary.