rsokl opened 1 year ago
Also, an important thing is the propagation of a passed generator to all components of a transform pipeline (there's an additional complexity: not all transforms need these generators, but those that can accept them may need the generator to be propagated to them).
In my own code, I implemented my own containers and transform classes partly because of this. This is possible, but at the very least there should be reusable staticmethod ways of sampling the transform parameters that can accept an rng, as sketched below.
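For illustration, a minimal sketch of such a reusable, rng-aware sampling staticmethod (the class and method names here are hypothetical, not an existing torchvision API):

```python
import torch

class MyRandomRotation(torch.nn.Module):
    """Hypothetical transform whose parameter sampling lives in a reusable staticmethod."""

    def __init__(self, degrees):
        super().__init__()
        self.degrees = degrees

    @staticmethod
    def sample_angle(degrees, generator=None):
        # All randomness is isolated here and can be driven by an optional torch.Generator;
        # generator=None falls back to the global RNG.
        return (torch.rand(1, generator=generator).item() * 2 - 1) * degrees

    def forward(self, img, generator=None):
        angle = self.sample_angle(self.degrees, generator=generator)
        # the actual rotation of `img` by `angle` is omitted in this sketch
        return img
```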
@vadimkantorov I updated my post to include a description of how one would pass rng through to all components of a Compose-based pipeline.
@rsokl Yes, I implemented something similar in my own code. For that, all transforms must accept a generator argument even if they don't use any randomness and are deterministic, which may be an okay solution!
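A tiny sketch of that convention, assuming a hypothetical uniform generator argument: deterministic transforms simply accept and ignore it, so a container can pass it to every stage.

```python
import torch

class Scale(torch.nn.Module):
    """Deterministic transform that accepts a generator only to keep the interface uniform."""

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, img, generator=None):  # `generator` is accepted but never used
        return img * self.factor
```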
Although somewhat niche and low priority, https://github.com/pytorch/vision/pull/3001#issuecomment-814849595 also shows an example of why good RNG support is needed. If a user wants to use the same random parameters at different points in time, there are currently only two workable approaches. One is calling torch.manual_seed(42) before each call, but without any containment this leaks into the surroundings. Thus, in such a situation it would be really beneficial to just reset a generator and pass it again.
Generalizing those get_params / sample_params methods to optionally accept an rng/generator could be a first step towards easier reproducibility (if not done yet...).
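As a rough illustration of the reset-and-reuse idea (the generator-aware sampling function below is a hypothetical stand-in for such a get_params / sample_params):

```python
import torch

def sample_params(generator=None):
    # hypothetical stand-in for a generator-aware get_params / sample_params
    return torch.randint(0, 360, (1,), generator=generator).item()

rng = torch.Generator().manual_seed(42)
first = sample_params(generator=rng)

# Later: reproduce the exact same parameters by resetting only this generator,
# without calling torch.manual_seed() and leaking into the surrounding code.
rng.manual_seed(42)
assert sample_params(generator=rng) == first
```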
Passing the generator in forward() means that we can't have a per-transform RNG stream anymore when using containers like Compose() (and most use-cases involve Compose()). Taking the example from above:
from torch import Generator
import torchvision.transforms.v2 as T

rng = Generator().manual_seed(0)  # the Generator must be instantiated before seeding

trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5),
        T.RandomRotation(30),
        T.CenterCrop(480),
    ]
)

img, bboxes, labels = trans(img, bboxes, labels, generator=rng)
All 3 transforms will be using the same rng generator. There's no way to have finer control where each transform would get its own RNG stream.
What if I want to "freeze" the RNG of one transform in Compose, while preserving maximal entropy for the rest of the transforms? I can't re-seed the rng object because that would affect all transforms. Is that something we want to enable?
In the limit, the only benefit I can see from allowing generators in forward() is that it separates the transforms' RNG from the global pytorch RNG (the one from torch.manual_seed()), but that's it. It still treats the transforms' RNG as one (and only one) RNG. It doesn't allow a per-transform RNG in practice, since Compose must be used.
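For contrast, a hedged sketch of what per-transform streams could look like if generators were instead accepted in __init__ (the generator keyword on these transforms is hypothetical, not the current API):

```python
import torch
import torchvision.transforms.v2 as T

rng_jitter = torch.Generator().manual_seed(0)
rng_rotate = torch.Generator().manual_seed(1)

trans = T.Compose(
    [
        T.ColorJitter(contrast=0.5, generator=rng_jitter),  # hypothetical kwarg
        T.RandomRotation(30, generator=rng_rotate),          # hypothetical kwarg
        T.CenterCrop(480),                                   # deterministic, no rng needed
    ]
)

# "Freezing" only the rotation becomes possible: re-seed its private stream,
# while the jitter keeps consuming fresh entropy from its own generator.
rng_rotate.manual_seed(1)
```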
I'm curious if either @rsokl or @vadimkantorov have found this to be a limitation in practice?
(Instead of passing generators in forward(), we could pass them in __init__, but I need to think more about what that means.)
Oh, I think we can rule out forward() completely. How would that even work with datasets, where the transform's forward is called in the dataset's __getitem__()?
I spent a bit more time thinking about it and I implemented a toy solution to check what happens when DataLoader multi-processing is involved (passing the generators in __init__, since forward() is impossible). As far as I can tell, we're going to hit major UX issues (more details in the notebook below).
For now I can't think of a clean way to handle all this without requiring users to understand inner details of the DataLoader, so that's a bummer. But I'm curious what you all think.
# %%
import torch
from torch.utils.data import DataLoader


class MyTransform(torch.nn.Module):
    def __init__(self, rng):
        super().__init__()
        self.rng = rng

    def forward(self):
        return torch.randint(0, 1000, size=(1,), generator=self.rng).item()


class Dataset:
    def __init__(self, transform):
        self.transform = transform

    def __getitem__(self, _):
        return self.transform()  # no input to the transform, we don't care.

    def __len__(self):
        return 1000


rng = torch.Generator()
t = MyTransform(rng)
ds = Dataset(t)

# %%
# Dataset only, so far so good
for x, _ in zip(ds, range(4)):
    print(x)
# 710
# 284
# 837
# 820

# %%
# Things break with DataLoader(num_workers > 0).
# The generator is duplicated across workers when we fork.
# Oopsies.
# Note: this is actually documented! https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading
dl = DataLoader(ds, num_workers=2)
for x, _ in zip(dl, range(10)):
    print(x)
# tensor([299])
# tensor([299])
# tensor([754])
# tensor([754])
# tensor([334])
# tensor([334])
# tensor([739])
# tensor([739])
# tensor([609])
# tensor([609])

# %%
# Only way to make it work is to set a per-worker seed: https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers
# And BTW the only reason things "work" by default is because torch does that
# for us already https://github.com/pytorch/pytorch/blob/1a661639f77a172df5d1ccd6987049292c6f3440/torch/utils/data/_utils/worker.py#L223-L225
def worker_init_fn(worker_id):
    rng.manual_seed(worker_id)


prev_state = rng.get_state()  # surprise surprise, see cell below

dl = DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)
for x, _ in zip(dl, range(10)):
    print(x)
# tensor([44])
# tensor([845])
# tensor([239])
# tensor([139])
# tensor([933])
# tensor([124])
# tensor([760])
# tensor([368])
# tensor([963])
# tensor([263])

# %%
# Oh and on top of that, the RNG from the main process is never consumed!
# So we get the exact same RNG across epochs.
# EDIT: As Philip pointed out, this is actually already the case even when the global RNG is used everywhere.
# Things only work OK because the global RNG gets consumed elsewhere, e.g. by the RandomSampler. Ew.
assert (rng.get_state() == prev_state).all()
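For completeness, a possible continuation cell sketching a less brittle worker_init_fn. It reuses rng and ds from the cells above and relies on torch.utils.data.get_worker_info(), i.e. it still requires users to understand DataLoader internals, which is exactly the UX problem:

```python
# %%
from torch.utils.data import get_worker_info

def worker_init_fn(worker_id):
    # get_worker_info().seed is derived from the DataLoader's per-epoch base seed and the
    # worker id, so it differs across workers and typically across epochs as well.
    rng.manual_seed(get_worker_info().seed)

dl = DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)
for x, _ in zip(dl, range(4)):
    print(x)
```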
> I'm curious if either @rsokl or @vadimkantorov have found this to be a limitation in practice?
In my own practice, I implemented my own Compose/functions for passing down the RNG when I need it (as I'm using mainly the pure tensor functions). Supporting it as a forward argument is a worthy thing, even if it's not passed down by the default Compose. Your solution of passing it in the constructor is also not bad!
> Oh, I think we can rule out forward() completely. How would that even work with datasets, where the transform's forward is called in the dataset's __getitem__()?
In my own code, when I need to control the RNG of transforms (sometimes when I had to apply dependent versions of transforms to several images in the batch), I usually implemented my own custom dataset code and my own custom samplers.
At least to simplify life for advanced use cases, it's much better to allow an explicit option (be it with the RNG in a field or also via forward; both could even easily be supported: either take the RNG provided as a forward argument, or fall back to the RNG stored in the field), as in the sketch below.
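A hedged sketch of that either/or convention (the transform is hypothetical, just to make the idea concrete): take the generator passed to forward if given, otherwise fall back to the one stored on the instance.

```python
import torch

class RandomNoise(torch.nn.Module):
    """Hypothetical transform supporting both a stored RNG and a forward-time override."""

    def __init__(self, generator=None):
        super().__init__()
        self.generator = generator  # optional per-instance stream

    def forward(self, img, generator=None):
        # the explicit forward argument wins; otherwise use the field; None means global RNG
        gen = generator if generator is not None else self.generator
        return img + torch.randn(img.shape, generator=gen)
```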
Another fun fact: I just realized that Generators aren't pickleable, so they can never be used with DataLoader(num_workers > 0) on Windows or macOS (https://pytorch.org/docs/stable/data.html#platform-specific-behaviors). This isn't really a blocker because most training jobs happen on Linux. But the UX issues from https://github.com/pytorch/vision/issues/7027#issuecomment-1626174673 definitely are blocking IMO.
I've followed up with the torch core team to see if we can do something about it.
A problem is that we may still want different generators for different workers, so the generator should not just be cloned but also seeded with the worker-id, or somehow depend on the example-id. And even if all of that were satisfied, the correspondence between worker-id and example-id might not be guaranteed.
I would propose that the most useful thing is to support passing an rng to forward / get_params, to simplify advanced cases: e.g. the advanced user may just create a new Generator in the Dataset's __getitem__, seed it with epoch_id + example_id, and pass it to the augmentation pipeline (a sketch follows below). Yes, this would not achieve aug-param reproducibility of reference workflows, but it would make life simpler for complicated aug pipelines.
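A hedged sketch of that advanced use case (the dataset, the set_epoch hook, and the transform's generator argument are all hypothetical):

```python
import torch

class AugDataset(torch.utils.data.Dataset):
    def __init__(self, samples, transform, epoch=0):
        self.samples = samples
        self.transform = transform  # assumed to accept a `generator` keyword argument
        self.epoch = epoch

    def set_epoch(self, epoch):
        # called by the training loop at the start of each epoch
        self.epoch = epoch

    def __getitem__(self, idx):
        # per-example, per-epoch stream: deterministic, yet different for every sample and epoch
        rng = torch.Generator().manual_seed(self.epoch * len(self.samples) + idx)
        return self.transform(self.samples[idx], generator=rng)

    def __len__(self):
        return len(self.samples)
```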
forward() isn't an option @vadimkantorov. It wouldn't be compatible with what torchvision currently looks like.
@NicolasHug This prototype https://github.com/pytorch/vision/pull/7445 by @pmeier is actually similar to what I rolled for my own code
I am familiar with https://github.com/pytorch/vision/pull/7445.
I have explained in a few of my comments above why passing RNGs to forward() is impossible. It just doesn't work with datasets, nor with the DataLoader. It is technically impossible. Passing RNGs to __init__ is the only possible solution that wouldn't require changes to datasets or to the entire DataLoader code-base.
Let me know if there is anything I can clarify. Otherwise, let's please reduce the noise on this issue and focus on the more urgent matter described in https://github.com/pytorch/vision/issues/7027#issuecomment-1626174673. I am in a tight loop with torch core to get it hopefully resolved.
If you're still keen on debating the forward vs init issue, let's please do so on another issue.
> If you're still keen on debating the forward vs init issue, let's please do so on another issue.
I was not sure if you were responding to my arguments only or also to the prototype in the linked PR; I linked it because I had missed it earlier in the thread and because it's more complete and concrete than my words.
I respectfully don't consider your technical arguments correct in this case. As I outlined, being able to pass the aug transforms and rng down to the worker from the main process will not, by itself, solve reproducibility, and it undermines the point of augmentations: for RNG-based augs to be truly deterministic yet still meaningfully random, they need to be correlated with things like epoch-id and example-id to control for thread-scheduling randomness. So I consider the "technically impossible" argument incorrect, and the non-pickleability of Generators not very relevant to the actual goal of reproducibility (of course it would be nice if Generators could be pickled, but it doesn't seem very relevant for these reasons). You seem to have ignored this argument. This means that the goal of "not adding any changes to Datasets/DataLoaders and still achieving meaningful reproducibility" is indeed unachievable, but passing down the generator is also not a good solution! For the field approach to be meaningful, users would still have to reset these RNG objects for the whole aug pipeline in the Dataset's __getitem__. Of course it would still be better than global state, but if we accept that users must modify their Datasets anyway, the forward solution is at least worth considering (and more FP-like / JAX-like).
But, as I see now, the decision has been taken and the PR in question has been evaluated as well. I of course recognize that, and I see no point in arguing anymore in this or any other issue in this repo (for that matter).
I will reduce my noise/feedback/conversation/issues in this repo to zero from now on, sorry for the bother.
Your message above is mixing up two orthogonal issues. The first one is solved, and I am working with torch core to address the second.
Respectfully, I don't think we have a shared understanding on this topic. Let's leave it at that please.
🚀 The feature
(This was originally pitched in this long feedback thread. It was recommended that I open a separate issue.)
Enable the new transforms API to support the use of local generators to control RNG via a modified API. Thus, transforms that implement _get_params would replace sampling calls that draw from the global RNG with calls that accept an explicit generator, e.g. as sketched below.
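A rough sketch of that replacement, assuming a RandomRotation-style transform (these are illustrative signatures, not the exact torchvision internals):

```python
import torch

# Before: parameter sampling draws from the global RNG.
def _get_params_global(degrees):
    angle = float(torch.empty(1).uniform_(-degrees, degrees))
    return dict(angle=angle)

# After: the same sampling, driven by an optional, explicitly passed generator.
def _get_params_local(degrees, generator=None):
    angle = float(torch.empty(1).uniform_(-degrees, degrees, generator=generator))
    return dict(angle=angle)
```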
A transform like Compose would have to be modified as well. Currently, it supports a sequence of callables that are assumed to accept a single positional argument. It could be assumed that only instances of Transform involve stochasticity and will be passed the random generator. It would be straightforward to document this behavior to users (that only instances of Transform are passed the generator) so that they know how to opt in to having the generator be passed to their custom transforms. And, again, this would be compatible with the old nn.Module transforms. A sketch of such a Compose, together with an example of it in practice, follows below.
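A hedged sketch of such a Compose and an example of it in practice (the generator plumbing is hypothetical; v2 transforms do not accept a generator argument today):

```python
import torch
from torchvision.transforms import v2

class Compose(torch.nn.Module):
    """Sketch: pass the generator only to v2.Transform instances, call other callables as before."""

    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms

    def forward(self, inputs, generator=None):
        for t in self.transforms:
            if isinstance(t, v2.Transform):
                inputs = t(inputs, generator=generator)  # hypothetical kwarg
            else:
                inputs = t(inputs)  # plain callables / old nn.Module transforms
        return inputs

# Example usage with an isolated RNG stream:
rng = torch.Generator().manual_seed(0)
pipeline = Compose([v2.RandomRotation(30), v2.CenterCrop(480)])
# out = pipeline(img, generator=rng)  # would work once transforms accept `generator`
```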
Another nice thing about this is that specific failure cases that occur during training/testing can be reproduced in an isolated way: _get_params(dummy_img, generator=rng) can be used to iterate the generator's state and "replay" a sequence of transformations without having to redo all of the compute (see the sketch below). This would not work if the model and the transforms both affect and derive from global state.
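A hedged sketch of that replay pattern (the _get_params stand-in, the dummy input, and the step count are all illustrative):

```python
import torch

def _get_params(img, degrees=30.0, generator=None):
    # illustrative stand-in for a transform's generator-aware parameter sampling
    angle = float(torch.empty(1).uniform_(-degrees, degrees, generator=generator))
    return dict(angle=angle)

rng = torch.Generator().manual_seed(0)
dummy_img = torch.zeros(3, 8, 8)

# Fast-forward the generator through the first 99 draws without doing any image compute,
# then recover the parameters that produced, say, a failure observed at step 100.
for _ in range(99):
    _get_params(dummy_img, generator=rng)
params_at_failure = _get_params(dummy_img, generator=rng)
print(params_at_failure)
```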
Motivation, pitch
In recent years, NumPy has completely revised its PRNG API to avoid global random state (here is a great post on good practices with NumPy's generators). JAX avoids mutable RNG objects altogether. PyTorch provides torch.Generator to users to make randomness local and "non-spooky", but many libraries prevent users from utilizing this capability.
I am proposing that Transform enable users to optionally pass a Generator to the forward pass so that torchvision transform pipelines can be isolated from global entropy and thus support more reproducible workflows. This reproducibility is especially useful in the context of testing & evaluation: the specific sequence of data transformations performed should be able to be isolated from whether or not a model uses dropout in its forward pass.
Alternatives
No response
Additional context
@pmeier already provided (positive) feedback on this proposal here
cc @vfdev-5 @datumbox @bjuncek @pmeier