pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License
16.21k stars, 6.95k forks

Transforms V2 proposal: Enabling reproducible workflows via local RNGs #7027

Open rsokl opened 1 year ago

rsokl commented 1 year ago

🚀 The feature

(This was originally pitched in this long feedback thread; it was recommended that I open a separate issue.)

Enable the new transforms API to support local generators for controlling randomness, via the following modified API:

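from typing import Any, Dict, List

import torch.nn as nn
from torch import Generator, default_generator
from torch.utils._pytree import tree_flatten


class Transform(nn.Module):
    def _get_params(self, flat_inputs: List[Any], *, generator: Generator) -> Dict[str, Any]:
        ...

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        # flatten the inputs as the existing forward() already does
        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
        # the only modification: pass the generator through to _get_params
        params = self._get_params(flat_inputs, generator=generator)
        ...  # apply the transform with `params`, unchanged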

Thus, transforms that implement _get_params would replace calls like

angle = float(torch.empty(1).uniform_(0.0, 180.).item())

with

# specifying the device is, unfortunately, necessary: https://github.com/pytorch/pytorch/issues/79018
angle = float(torch.empty(1, device=generator.device).uniform_(0.0, 180., generator=generator).item())
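A quick standalone check of that replacement line, using only plain torch (the helper name `sample_angle` is hypothetical): two generators seeded identically produce the same angle, and drawing from a local generator leaves the global RNG state untouched.

```python
import torch


def sample_angle(generator: torch.Generator) -> float:
    # Hypothetical helper mirroring the replacement above: allocate the buffer
    # on the generator's device, then sample in place from that generator only.
    buf = torch.empty(1, device=generator.device)
    return float(buf.uniform_(0.0, 180.0, generator=generator).item())


g1 = torch.Generator().manual_seed(0)
g2 = torch.Generator().manual_seed(0)
global_before = torch.get_rng_state()

angle1 = sample_angle(g1)
angle2 = sample_angle(g2)

# Same seed gives the same angle, and the global RNG was never consumed.
print(angle1 == angle2, torch.equal(global_before, torch.get_rng_state()))
```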

A transform like Compose would have to be modified as well. Currently, it accepts a sequence of callables, each assumed to take a single positional argument. We could assume that only instances of Transform involve randomness, and pass the random generator only to those. In this case, Compose would look like:

class Compose(Transform):
    # __init__ is unchanged

    def forward(self, *inputs: Any, generator: Generator = default_generator) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            if isinstance(transform, Transform):
                sample = transform(sample, generator=generator)
            else:
                sample = transform(sample)
        return sample

It would be straightforward to document this behavior for users (only instances of Transform are passed the generator) so that they know how to opt in to having the generator passed to their custom transforms. This would also remain compatible with the old nn.Module transforms, which would simply be called without the generator.

An example of this in practice would be:

import torch
import torchvision.transforms.v2 as T

rng = torch.Generator().manual_seed(0)

trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels, generator=rng)
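Because the `generator=` argument is only proposed, the example above cannot run against torchvision as-is. The following self-contained toy version sketches the Compose behavior described earlier: `Transform` instances receive the generator, plain callables are called without it. Every class here is a hypothetical stand-in rather than torchvision's code, and it takes a single input instead of `*inputs` for brevity.

```python
from typing import Any, Callable, Dict, List, Sequence

import torch
from torch import Generator, default_generator


class Transform(torch.nn.Module):
    """Toy stand-in for the proposed base class (not torchvision's actual code)."""

    def _get_params(self, flat_inputs: List[Any], *, generator: Generator) -> Dict[str, Any]:
        return {}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def forward(self, inpt: Any, generator: Generator = default_generator) -> Any:
        params = self._get_params([inpt], generator=generator)
        return self._transform(inpt, params)


class RandomScale(Transform):
    """Hypothetical transform: multiply the input by a random factor in [lo, hi)."""

    def __init__(self, lo: float, hi: float) -> None:
        super().__init__()
        self.lo, self.hi = lo, hi

    def _get_params(self, flat_inputs, *, generator):
        buf = torch.empty(1, device=generator.device)
        return {"factor": buf.uniform_(self.lo, self.hi, generator=generator).item()}

    def _transform(self, inpt, params):
        return inpt * params["factor"]


class Compose(Transform):
    def __init__(self, transforms: Sequence[Callable]) -> None:
        super().__init__()
        self.transforms = list(transforms)

    def forward(self, inpt: Any, generator: Generator = default_generator) -> Any:
        sample = inpt
        for transform in self.transforms:
            # Only Transform instances opt in to receiving the generator.
            if isinstance(transform, Transform):
                sample = transform(sample, generator=generator)
            else:
                sample = transform(sample)
        return sample


# A plain lambda sits between two random transforms and never sees the generator.
pipeline = Compose([RandomScale(0.5, 2.0), lambda t: t + 1, RandomScale(0.5, 2.0)])
x = torch.ones(3)
out1 = pipeline(x, generator=torch.Generator().manual_seed(0))
out2 = pipeline(x, generator=torch.Generator().manual_seed(0))
print(torch.equal(out1, out2))  # same seed, same output
```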

Another nice thing about this is that specific failure cases that occur during training/testing can be reproduced in isolation: _get_params([dummy_img], generator=rng) can be used to advance the generator's state and "replay" a sequence of transformations without having to redo all of the compute. This would not work if the model and the transforms both consume and depend on global random state.
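A toy sketch of the replay idea, assuming each step consumes the generator in the same way regardless of the input (true for simple parameter samplers, not necessarily for data-dependent ones). `get_params` and `expensive_step` are hypothetical placeholders.

```python
import torch


def get_params(generator: torch.Generator) -> float:
    # Hypothetical stand-in for a transform's _get_params: a random rotation angle.
    return torch.empty(1).uniform_(-30.0, 30.0, generator=generator).item()


def expensive_step(angle: float) -> float:
    # Stand-in for the costly part (decoding, rotating, model forward pass, ...).
    return angle * 2.0


SEED, FAILING_STEP = 0, 5

# Original run: augmentation parameters come only from the local generator.
rng = torch.Generator().manual_seed(SEED)
seen = []
for step in range(10):
    angle = get_params(rng)
    seen.append(angle)
    expensive_step(angle)

# Replay: re-seed, then draw parameters only, skipping all the expensive work.
replay_rng = torch.Generator().manual_seed(SEED)
for _ in range(FAILING_STEP):
    get_params(replay_rng)
replayed = get_params(replay_rng)
print(replayed == seen[FAILING_STEP])
```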

Motivation, pitch

In recent years, NumPy has completely revised its PRNG API to avoid global random state (here is a great post on good practices with NumPy's generators). JAX avoids mutable RNG objects altogether. PyTorch provides torch.Generator to let users make randomness local and "non-spooky", but many libraries prevent users from taking advantage of this capability.

I am proposing that Transform let users optionally pass a Generator to the forward pass, so that torchvision transform pipelines can be isolated from global random state and thus support more reproducible workflows. This reproducibility is especially useful for testing and evaluation: the specific sequence of data transformations performed should not depend on whether or not the model uses dropout in its forward pass.
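The dropout coupling can be sketched with plain torch. This assumes, as today, that a transform sampling without a generator draws from the global RNG; `augment_angle` and `run` are hypothetical helpers.

```python
import torch

dropout = torch.nn.Dropout(p=0.5)  # modules start in training mode, so this consumes the global RNG
x = torch.ones(8)


def augment_angle(generator=None) -> float:
    # generator=None falls back to the global RNG, like the current transforms.
    return torch.empty(1).uniform_(0.0, 180.0, generator=generator).item()


def run(use_dropout: bool, local: bool) -> list:
    torch.manual_seed(0)
    rng = torch.Generator().manual_seed(0) if local else None
    angles = []
    for _ in range(3):
        angles.append(augment_angle(rng))
        if use_dropout:
            dropout(x)  # the "model" draws from the global RNG between samples
    return angles


# Global RNG: enabling dropout changes the augmentation sequence.
print(run(use_dropout=False, local=False) == run(use_dropout=True, local=False))
# Local generator: the augmentation sequence is the same either way.
print(run(use_dropout=False, local=True) == run(use_dropout=True, local=True))
```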

Alternatives

No response

Additional context

@pmeier already provided (positive) feedback on this proposal here

cc @vfdev-5 @datumbox @bjuncek @pmeier

vadimkantorov commented 1 year ago

Another important point is propagating the passed generator to all components of a transform pipeline. (An additional complication: not all transforms need a generator, but those that can accept one need it propagated to them.)

In my own code, I implemented my own containers and transform classes partly because of this. That is possible, but there should at least be reusable static methods for sampling the transform parameters that accept an rng.

rsokl commented 1 year ago

@vadimkantorov I updated my post to include a description of how one would pass rng through to all components of a Compose-based pipeline.

vadimkantorov commented 1 year ago

@rsokl Yes, I implemented something similar in my own code. For that, all transforms must accept a generator argument even if they don't use any randomness and are deterministic, which may be an okay solution!

pmeier commented 1 year ago

Although somewhat niche and low priority, https://github.com/pytorch/vision/pull/3001#issuecomment-814849595 also shows an example of why good RNG support is needed. In case a user wants to use the same random parameters at different points in time, there are currently only two solutions:

  1. Some random transformations in transforms v1 can sample their random parameters through a static method. However, there is no support for applying these parameters directly, so even when the parameters were known, the user had to write the actual transformation themselves. This might be a single call to a functional, but it can quickly get more complicated.
  2. The user could use the nuclear option and call torch.manual_seed(42) before each call. But without any containment, this leaks into the surrounding code.

Thus, in such a situation it would be really beneficial to just reset a generator and pass it again.
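With a local generator, reusing the same random parameters later only requires restoring that one generator, e.g. via `Generator.get_state()` / `set_state()` (re-seeding with `manual_seed` works too). A minimal sketch, where `sample_crop_box` is a hypothetical parameter sampler:

```python
import torch


def sample_crop_box(generator: torch.Generator, size: int = 100, crop: int = 32):
    # Hypothetical parameter sampler: top-left corner of a random crop.
    top = int(torch.randint(0, size - crop + 1, (1,), generator=generator))
    left = int(torch.randint(0, size - crop + 1, (1,), generator=generator))
    return top, left


rng = torch.Generator().manual_seed(42)
checkpoint = rng.get_state()  # snapshot before sampling

params_now = sample_crop_box(rng)
# ... arbitrary other work, consuming this generator and the global RNG ...
sample_crop_box(rng)
torch.rand(10)

rng.set_state(checkpoint)  # reset only this generator; the global RNG is untouched
params_later = sample_crop_box(rng)
print(params_now == params_later)
```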

vadimkantorov commented 1 year ago

Generalizing those get_params / sample_params methods to optionally accept an rng/generator could be a first step towards easier reproducibility (if not done yet...).
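A minimal sketch of such a generalization, modeled on the static `get_params` of v1's `RandomRotation`; the class name and the `generator` keyword are assumptions, not existing torchvision API. Defaulting to `None` keeps the current global-RNG behavior, since torch sampling methods use the default generator when `generator=None`.

```python
from typing import Optional, Sequence

import torch


class RandomRotationSketch:
    """Hypothetical sketch, not torchvision's actual API."""

    @staticmethod
    def get_params(degrees: Sequence[float], generator: Optional[torch.Generator] = None) -> float:
        low, high = float(degrees[0]), float(degrees[1])
        # generator=None: global RNG (today's behavior); a Generator keeps the draw local.
        return float(torch.empty(1).uniform_(low, high, generator=generator).item())


a = RandomRotationSketch.get_params([-30, 30], generator=torch.Generator().manual_seed(0))
b = RandomRotationSketch.get_params([-30, 30], generator=torch.Generator().manual_seed(0))
c = RandomRotationSketch.get_params([-30, 30])  # global RNG, as before
```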

NicolasHug commented 1 year ago

Passing the generator in forward() means that we can't have a per-transform RNG stream anymore when using containers like Compose() (and most use-cases involve Compose()). Taking the example from above

import torch
import torchvision.transforms.v2 as T

rng = torch.Generator().manual_seed(0)

trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)
img, bboxes, labels = trans(img, bboxes, labels, generator=rng)

All 3 transforms will be using the same rng generator. There's no way to have a finer control where each transform would get its own RNG stream.

What if I want to "freeze" the RNG of one transform in Compose, while preserving maximal entropy for the rest of the transforms? I can't re-seed the rng object because that would affect all transforms. Is that something we want to enable?

Ultimately, the only benefit I can see from allowing generators in forward() is that it separates the transforms' RNG from the global PyTorch RNG (the one seeded by torch.manual_seed()), but that's it. It still treats the transforms' RNG as one (and only one) RNG; it doesn't allow a per-transform RNG in practice, since Compose must be used.

I'm curious if either @rsokl or @vadimkantorov have found this to be a limitation in practice?

(Instead of passing generators in forward(), we could pass them in __init__, but I need to think more about what that means.)
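For comparison, a hypothetical sketch of the `__init__` alternative: each transform owns its generator, so one transform can be "frozen" by re-seeding its own generator before every draw while the others keep their entropy. `RandomShift` and `frozen_seed` are illustrative names, and the sketch ignores DataLoader multi-processing entirely.

```python
from typing import Optional

import torch


class RandomShift(torch.nn.Module):
    """Hypothetical transform that owns its generator (passed in __init__)."""

    def __init__(self, generator: torch.Generator, frozen_seed: Optional[int] = None) -> None:
        super().__init__()
        self.generator = generator
        self.frozen_seed = frozen_seed
        self.last_shift = 0.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.frozen_seed is not None:
            # "Freeze" this transform: re-seed its own generator before each draw.
            self.generator.manual_seed(self.frozen_seed)
        self.last_shift = torch.empty(1).uniform_(-1.0, 1.0, generator=self.generator).item()
        return x + self.last_shift


frozen = RandomShift(torch.Generator(), frozen_seed=0)
free = RandomShift(torch.Generator().manual_seed(1))
pipeline = torch.nn.Sequential(frozen, free)

shifts = []
for _ in range(3):
    pipeline(torch.zeros(1))
    shifts.append((frozen.last_shift, free.last_shift))
# frozen.last_shift repeats on every call; free.last_shift changes each time.
```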

NicolasHug commented 1 year ago

Oh, I think we can rule out forward() completely. How would that even work with datasets, where the transforms' forward() is called in the dataset's __getitem__()?

NicolasHug commented 1 year ago

I spent a bit more time thinking about it and I implemented a toy solution to check what happens when DataLoader multi-processing is involved (passing the generators in __init__, as forward() is impossible). As far as I can tell, we're going to hit major UX issues (more details in the notebook below).

For now I can't think of a clean way to handle all this without requiring users to understand inner-details of the DataLoader, so that's a bummer. But I'm curious what you all think.

# %%
import torch
from torch.utils.data import DataLoader

class MyTransform(torch.nn.Module):
    def __init__(self, rng):
        super().__init__()
        self.rng = rng

    def forward(self):
        return torch.randint(0, 1000, size=(1,), generator=self.rng).item()

class Dataset:
    def __init__(self, transform):
        self.transform = transform

    def __getitem__(self, _):
        return self.transform()  # no input to the transform, we don't care.

    def __len__(self):
        return 1000

rng = torch.Generator()

t = MyTransform(rng)
ds = Dataset(t)

# %%
# Dataset only, so far so good
for x, _ in zip(ds, range(4)):
    print(x)
# 710
# 284
# 837
# 820

# %%
# Things break with DataLoader(num_workers > 0).
# The generator is duplicated across workers when we fork.
# Oopsies.
# Note: this is actually documented! https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading
dl = DataLoader(ds, num_workers=2)
for x, _ in zip(dl, range(10)):
    print(x)
# tensor([299])
# tensor([299])
# tensor([754])
# tensor([754])
# tensor([334])
# tensor([334])
# tensor([739])
# tensor([739])
# tensor([609])
# tensor([609])
# %%

# Only way to make it work is to set a per-worker seed: https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers
# And BTW the only reason things "work" by default is because torch does that
# for us already: https://github.com/pytorch/pytorch/blob/1a661639f77a172df5d1ccd6987049292c6f3440/torch/utils/data/_utils/worker.py#L223-L225
def worker_init_fn(worker_id):
    rng.manual_seed(worker_id)

prev_state = rng.get_state()  # surprise surprise, see cell below
dl = DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)
for x, _ in zip(dl, range(10)):
    print(x)
# tensor([44])
# tensor([845])
# tensor([239])
# tensor([139])
# tensor([933])
# tensor([124])
# tensor([760])
# tensor([368])
# tensor([963])
# tensor([263])

# %%
# Oh and on top of that, the RNG from the main process is never consumed!
# So we get the exact same RNG across epochs.
# EDIT: As Philip pointed out, this is actually already the case even when the global RNG is used everywhere
# Things only work OK because the global RNG gets consumed elsewhere e.g. by the RandomSampler. Ew.
assert (rng.get_state() == prev_state).all()

vadimkantorov commented 1 year ago

I'm curious if either @rsokl or @vadimkantorov have found this to be a limitation in practice?

In my own code, I implemented my own Compose and functions for passing the RNG down when I need it (as I mainly use the pure tensor functions). Supporting it as a forward() argument is worthwhile, even if the default Compose doesn't pass it down. Your solution of passing it in the constructor is also not bad!

vadimkantorov commented 1 year ago

Oh, I think we can rule out forward() completely. How would that even work with datasets, where the transforms' forward() is called in the dataset's __getitem__()?

In my own code, when I need to control the RNG / transforms (sometimes because I had to apply dependent versions of transforms to several images in the batch), I usually implement my own custom dataset code and my own custom samplers.

Even if only to simplify life for advanced use cases, it's much better to provide an explicit option, be it an RNG stored as a field or one passed to forward(). Both could easily be supported: use the RNG passed as a forward() argument if given, otherwise the RNG from the field.
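Supporting both looks straightforward in principle. A hypothetical sketch (not an existing torchvision API) in which an RNG passed to forward() takes precedence over the one stored at construction, and `None` for both falls back to the global RNG:

```python
from typing import Optional

import torch


class RandomShift(torch.nn.Module):
    """Hypothetical transform accepting an RNG both as a field and as a forward() argument."""

    def __init__(self, generator: Optional[torch.Generator] = None) -> None:
        super().__init__()
        self.generator = generator

    def forward(self, x: torch.Tensor, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        # Explicit argument first, then the field; if both are None, torch uses the global RNG.
        g = generator if generator is not None else self.generator
        shift = torch.empty(1).uniform_(-1.0, 1.0, generator=g)
        return x + shift


x = torch.zeros(2)
t = RandomShift(torch.Generator().manual_seed(0))
a = t(x, generator=torch.Generator().manual_seed(7))  # the forward() argument wins
b = t(x, generator=torch.Generator().manual_seed(7))
c = t(x)  # falls back to the field, whose stream the calls above did not consume
d = RandomShift(torch.Generator().manual_seed(0))(x)
```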

NicolasHug commented 1 year ago

Another fun fact: I just realized that Generators aren't picklable, so they can never be used with DataLoader(num_workers > 0) on Windows or macOS (https://pytorch.org/docs/stable/data.html#platform-specific-behaviors). This isn't really a blocker because most training jobs happen on Linux. But the UX issues from https://github.com/pytorch/vision/issues/7027#issuecomment-1626174673 definitely are blocking IMO. I've followed up with the torch core team to see if we can do something about it.
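One possible workaround, sketched under the assumption that only CPU generators are involved: wrap the generator and pickle its state tensor (`get_state()`) instead of the generator itself. `PicklableGenerator` is a hypothetical name. This only makes the object transferable; every worker would still receive an identical copy of the stream, which is exactly the duplication shown in the notebook above, so per-worker reseeding would still be needed.

```python
import pickle

import torch


class PicklableGenerator:
    """Hypothetical wrapper that pickles a CPU torch.Generator through its state tensor."""

    def __init__(self, seed: int = 0) -> None:
        self.generator = torch.Generator().manual_seed(seed)

    def __getstate__(self) -> dict:
        # get_state() returns a CPU uint8 tensor, which pickles like any other tensor.
        return {"state": self.generator.get_state()}

    def __setstate__(self, state: dict) -> None:
        self.generator = torch.Generator()
        self.generator.set_state(state["state"])


src = PicklableGenerator(seed=0)
torch.rand(5, generator=src.generator)  # advance the stream a bit
clone = pickle.loads(pickle.dumps(src))
# The clone continues exactly where the original left off.
same = torch.equal(torch.rand(3, generator=src.generator), torch.rand(3, generator=clone.generator))
```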

vadimkantorov commented 1 year ago

A problem is that we may still want different generators for different worker processes, so the generator should not just be cloned but also seeded with the worker id, or somehow depend on the example id. Moreover, even if all that were satisfied, the mapping between worker ids and example ids is not guaranteed.

I would propose that the most useful thing is to support passing an rng to forward()/get_params() to simplify advanced cases (e.g. an advanced user could create a new Generator in the Dataset's __getitem__, seed it with epoch_id + example_id, and pass it to the augmentation pipeline). Yes, this would not achieve aug-param reproducibility of reference workflows, but it would make life simpler for complicated augmentation pipelines.
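A minimal sketch of that suggestion, with hypothetical names throughout: the dataset derives a fresh generator from (epoch, index) in `__getitem__`, so each sample's randomness does not depend on which worker handles it or in what order. It uses `base_seed + epoch * len(dataset) + index` rather than a plain `epoch_id + example_id` sum, which would collide (epoch 1, index 0 vs. epoch 0, index 1). `torch.rand` stands in for an augmentation pipeline that accepts a generator.

```python
import torch
from torch.utils.data import Dataset


class AugmentedDataset(Dataset):
    """Hypothetical dataset that seeds a fresh generator per (epoch, index)."""

    def __init__(self, n: int, base_seed: int = 0) -> None:
        self.n = n
        self.base_seed = base_seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: int) -> torch.Tensor:
        # Collision-free within a run: every (epoch, index) pair maps to its own seed.
        seed = self.base_seed + self.epoch * self.n + index
        g = torch.Generator().manual_seed(seed)
        # Stand-in for an augmentation pipeline that accepts a generator.
        return torch.rand(2, generator=g)


ds = AugmentedDataset(n=4)
forward_order = [ds[i] for i in range(4)]
reverse_order = [ds[i] for i in reversed(range(4))][::-1]  # access order does not matter
ds.set_epoch(1)
next_epoch = [ds[i] for i in range(4)]
```

With non-persistent workers (the DataLoader default), each epoch's workers get a fresh copy of the dataset, so calling `set_epoch` in the main process before iterating should be enough; with `persistent_workers=True`, the workers would keep their old copy.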

NicolasHug commented 1 year ago

forward() isn't an option @vadimkantorov. It wouldn't be compatible with what torchvision currently looks like.

vadimkantorov commented 1 year ago

@NicolasHug This prototype https://github.com/pytorch/vision/pull/7445 by @pmeier is actually similar to what I rolled for my own code

NicolasHug commented 1 year ago

I am familiar with https://github.com/pytorch/vision/pull/7445.

I have explained in a few of my comments above why passing RNGs to forward() is impossible. It just doesn't work with datasets, nor with the DataLoader. It is technically impossible. Passing RNGs to __init__ is the only possible solution that wouldn't require changes to datasets or to the entire DataLoader codebase.

Let me know if there is anything I can clarify. Otherwise, let's please reduce the noise on this issue and focus on the more urgent matter described in https://github.com/pytorch/vision/issues/7027#issuecomment-1626174673. I am in a tight loop with torch core to get it hopefully resolved.

If you're still keen on debating the forward vs init issue, let's please do so on another issue.

vadimkantorov commented 1 year ago

If you're still keen on debating the forward vs init issue, let's please do so on another issue.

I was not sure whether you were responding only to my arguments or also to the prototype in the linked PR. I linked it because I had missed it earlier in the thread, and because it is more complete and concrete than my words.

I respectfully don't consider your technical arguments correct in this case. As I outlined, being able to pass the augmentation transforms and the RNG from the main process down to the workers will likely not solve reproducibility, and it defeats the point of augmentations: for RNG-based augmentations to be truly deterministic yet random enough to be meaningful, they need to be correlated with things like the epoch id and the example id, to control for the randomness of worker scheduling. So I consider the "technically impossible" argument incorrect, and the lack of picklability not very relevant to the actual goal of reproducibility (of course it would be nice if Generators could be pickled, but for these reasons it doesn't seem very relevant). You seem to have ignored this argument. This means that the goal of "not making any changes to Datasets/DataLoaders and still achieving meaningful reproducibility" is indeed unachievable, but passing the generator down is not a good solution either! To be meaningful, the field approach would still require users to reset these RNG objects across the whole augmentation pipeline in the Dataset's __getitem__. Of course, that would still be better than global state, but if we accept that users must modify their Datasets, the forward() solution is at least worth considering (and it is more FP-like / JAX-like).

But, as I now see, the decision has been made and the PR in question has been evaluated as well. I of course recognize that, and I see no point in arguing further in this or any other issue in this repo (for that matter).

I will reduce my noise/feedback/conversation/issues in this repo to zero from now on, sorry for the bother.

NicolasHug commented 1 year ago

Your message above is mixing up 2 orthogonal issues:

They're orthogonal. The first one is solved, and I am working with torch core to address the second.

Respectfully, I don't think we have a shared understanding on this topic. Let's leave it at that please.