dbolya / tomesd

Speed up Stable Diffusion with this one simple trick!
MIT License

Feature/fixed rand generator #25

Closed alihassanijr closed 1 year ago

alihassanijr commented 1 year ago

When use_rand is enabled, generating the token indices with the default generator causes all future rand calls to be offset. This appears to be the reason why, in some cases, ToMe generates images that are slightly different from what SD would generate with the same random seed.

While some non-determinism is unavoidable on CUDA specifically, this particular offset can be avoided by using an independent rand generator for that one step.

With this commit, the behavior remains unchanged when use_rand is enabled, just to ensure consistency. However, if rand_seed is set to any integer, the independent generator is used and the images start to look more like what SD would produce.
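For illustration, the independent-generator idea boils down to something like the following. This is a minimal, self-contained sketch rather than the actual diff: the shapes, device setup, and seed value are placeholders, and only the torch.randint call and the tensor names (metric, sy, sx, hsy, wsx) mirror the merging code quoted later in this thread.

import torch

# Placeholder setup (shapes and device are illustrative, not ToMe's actual values)
device = "cuda" if torch.cuda.is_available() else "cpu"
sy, sx, hsy, wsx = 2, 2, 32, 32
metric = torch.randn(1, hsy * sy * wsx * sx, 64, device=device)

# One independent generator, seeded once from a hypothetical user-supplied integer
rand_seed = 42
generator = torch.Generator(device=metric.device).manual_seed(rand_seed)

# Drawing the partition indices from this generator leaves the default RNG
# untouched, so later sampling steps see the same random sequence they
# would see without ToMe.
rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1),
                         device=metric.device, generator=generator)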

Here are a few examples:

From right to left: the image generated after applying ToMe, the image generated without ToMe, the image generated with ToMe without randomness, the image generated with this commit (unchanged behavior), and the image generated with this commit with the random seed set to an integer.

(Image comparisons for the prompts "Astronaut on horse in Mars", "Temple in ruins", and "Castle view".)

dbolya commented 1 year ago

Thanks for making the PR!

What's the reason for needing one generator per layer? In the examples you have, you were using a single generator for the entire network and it worked fine. As far as I understand, it's not like the generator is reseeded every image, so it will still generate different randomness at each layer even if there's only one generator.

alihassanijr commented 1 year ago

Not a problem, happy to help.

Actually, I don't have a good response, because you're right: one generator will do. Even with multiple generators, the seeds had to be incremented, so I guess I over-complicated the solution. I'll roll that last commit back and have it use just the one.

dbolya commented 1 year ago

Yeah, that way the generator can still be in _tome_info and nothing else needs to be changed.

alihassanijr commented 1 year ago

I initially missed the fact that _tome_info was shared across layers, so I thought there was no other way than to have one generator per layer. But yeah, it's now back to just one generator in _tome_info.

dbolya commented 1 year ago

Cool, latest edit looks good to me.

I wonder, are there any downsides to having a separate generator? If there are none, maybe rand_seed=None should use a separate generator without a manual seed instead.

The only issue I can foresee with that is that if you don't set rand_seed, you can get unreproducible results (e.g., if some extension doesn't expose that variable).

Maybe if the seed is None we could just seed the new generator with the state of the current generator. That way, it's controlled by the same seed that stable diffusion itself uses.

I think the code for that would look like:

if rand_seed is None:
    # Fork from the current default RNG state, so it's controlled by SD's own seed
    generator.set_state(torch.get_rng_state())
else:
    # Use the explicitly provided seed
    generator.manual_seed(rand_seed)

Then the generator gets created and used no matter what, and rand_seed is just the seed (with None meaning "use the current RNG state").

alihassanijr commented 1 year ago

That is actually an excellent way of handling it, but I'm getting a size mismatch error when attempting it (RuntimeError: RNG state is wrong size). I have yet to figure out why.

However, while searching for an answer I came across torch.random.fork_rng, which forks the current generator and makes all of this a simple one-line fix 🙂. All we'd have to do is change:

rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)

To:

with torch.random.fork_rng():
    rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)

Let me know if you'd like to switch to this, because it seems like the most convenient approach.

If there is a caveat to this, it's that it forks the generator per call, which means more calls to the backend. I don't really expect this to result in any noticeable increase in latency, but I thought I'd bring it up.

dbolya commented 1 year ago

Hmm, fork_rng would be convenient, but I think it just saves the state with torch.get_rng_state and then loads it again after the call with torch.set_rng_state. That definitely could affect speed, but maybe it's ok?
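For intuition, the manual equivalent of what fork_rng does for the CPU RNG would be roughly the following (a sketch; fork_rng also saves and restores per-CUDA-device states, which is omitted here, and the randint call is just a stand-in for the one in the merging code):

import torch

cpu_state = torch.get_rng_state()   # save the current default CPU RNG state
try:
    # any rand call inside the fork, e.g. the randint in the merging code
    rand_idx = torch.randint(4, size=(2, 2, 1))
finally:
    torch.set_rng_state(cpu_state)  # restore, so callers outside the fork are unaffected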

The documentation says:

devices (iterable of CUDA IDs) – CUDA devices for which to fork the RNG. CPU RNG state is always forked. By default, fork_rng() operates on all devices, but will emit a warning if your machine has a lot of devices, since this function will run very slowly in that case. If you explicitly specify devices, this warning will be suppressed

This sounds like the forking process isn't free and could be expensive on some systems. We can mitigate that by passing in metric.device, but then it sounds like it also forks the CPU RNG anyway? If the time taken after passing the device in is negligible, then it could be ok.

The main issue is that I don't have a bunch of different devices to test this on. Does passing metric.device work if it's on the CPU (cpu), on an M1 device (mps), on a DirectML device (dml), etc.? I would prefer support for different platforms over simpler code in this case.

alihassanijr commented 1 year ago

Yeah, I don't really think it would affect latency all that much, but having one separate generator and sticking to it is definitely the approach least likely to introduce additional latency. Forking acts as a store/restore of the RNG state, so it would just end up being more work, however insignificant.

So yeah I'd recommend keeping the generator as well.

I figured out the reason behind the size mismatch issue as well. torch.get_rng_state() returns the CPU random state, and if we're on CUDA, the state shapes won't match. If we want the CUDA state, it appears the only choice is to explicitly use torch.cuda.get_rng_state(). Based on the source, there's also a torch.mps.get_rng_state(), and PyTorch itself handles each device differently.
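For what it's worth, the mismatch can be reproduced in isolation with something like this (a sketch; the exact state sizes depend on the PyTorch version, since the CPU and CUDA generators use different algorithms and hence differently sized state tensors):

import torch

if torch.cuda.is_available():
    cpu_state = torch.get_rng_state()            # default CPU generator state
    cuda_gen = torch.Generator(device="cuda")
    print(cpu_state.numel(), cuda_gen.get_state().numel())  # different lengths
    # cuda_gen.set_state(cpu_state)  # -> RuntimeError: RNG state is wrong size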

Any thoughts on how you'd want this handled?

dbolya commented 1 year ago

Ah that's not fun. We could just create a function that returns the correct generator, but that reminds me of a bigger issue. We can't create the generator at patch time, because the user might change the device of the model after the patch.

Instead, we need some kind of lazy generator that changes based on the current device, and constructs a new generator when it gets an unexpected device / stores a generator per device.

All of a sudden, torch.fork_rng is looking a lot more lucrative... However, I looked at the source for it, and nope, it only accepts CUDA devices, so it wouldn't work for e.g. mps.

I wonder if torch has any built-in functionality for this? For instance, torch.rand obviously needs to find the right generator for the requested device, but it also needs to create a new one if none has been initialized yet.

alihassanijr commented 1 year ago

Yeah, it is a bit of a pickle apparently. I noticed fork_rng only handling CUDA as well; I'm not sure why they handle it like that.

I guess we could handle it by bringing the Generator back into the matching method and setting its state to the current state for the input device. That should cover all of those scenarios, but the catch is that this would all be done at every single hook... I'm pretty sure the latency would still be very small, but it's hard to say whether that holds for all devices.

Alternatively, we could have the generator constructed in the first forward call by adding an extra hook?

dbolya commented 1 year ago

I think the solution with the least issues would be to do this:

  1. Create a function init_generator in utils that takes in a device.
  2. In that function, create and return a generator for that device, and load the correct state using one of the get_rng_state functions for the device type (e.g., if device.type == 'cuda': state = torch.cuda.get_rng_state(device), etc.)
  3. During the patch, init the generator with None.
  4. In each layer, add a check if _tome_info['generator'] is None or _tome_info['generator'].device != x.device: and, if it triggers, set the generator to a newly initialized one (making sure to overwrite the current _tome_info['generator']).
  5. Then pass the generator, now guaranteed to be initialized and on the correct device, to the merging function.

This removes the rand seed argument, but I'm not sure it's that useful (and it could be added back by passing rand_seed to _tome_info["args"] and adding it to init_generator if we want).
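For reference, a rough sketch of steps 1, 2, and 4 under that plan could look like the following. The names init_generator and _tome_info['generator'] come from the list above; the device handling details and the surrounding variable names are only illustrative, not the actual implementation.

import torch

def init_generator(device: torch.device) -> torch.Generator:
    # Step 2 (sketch): build a generator on the requested device and seed it
    # from the default RNG state for that device type, so ToMe's randomness
    # stays controlled by the seed the user gives stable diffusion.
    if device.type == "cuda":
        gen = torch.Generator(device=device)
        gen.set_state(torch.cuda.get_rng_state(device))
        return gen
    # CPU and, as a conservative fallback, other backends (e.g. mps has its
    # own torch.mps.get_rng_state(), which could get a dedicated branch here).
    gen = torch.Generator(device="cpu")
    gen.set_state(torch.get_rng_state())
    return gen

# Step 4 (sketch), inside each patched layer's forward; attribute and
# variable names here are just illustrative:
# generator = self._tome_info["generator"]
# if generator is None or generator.device != x.device:
#     generator = self._tome_info["generator"] = init_generator(x.device)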

alihassanijr commented 1 year ago

That all sounds great. And yeah, the rand seed was only there to allow a custom seed, which becomes unnecessary when we're just forking.

Just verified the behavior is still the same, and pushed.

(Comparison images: SD and SD + ToMe.)

dbolya commented 1 year ago

Sweet, thanks for doing this! Checked the code and looks good to me. I'll merge now and incorporate it into a tomesd 0.1.3 release later tonight.

alihassanijr commented 1 year ago

Glad I could help. Amazing work again!