**Birch-san** opened 2 years ago
d'oh, I just noticed you already implemented a `quantize` boolean on ingress to the model:
https://github.com/crowsonkb/k-diffusion/blob/3f93c28088890e4d6bc593072739e1d6e759b392/k_diffusion/external.py
and that approach certainly gives me other ideas about how to factor this (i.e. make use of the `DiscreteSchedule` class rather than doing everything in the sampler).
our solutions are equivalent when `gamma == 0`…
but when `gamma > 0`, though, I think your `quantize=True` diverges slightly from the paper.
the paper says to quantize `sigma_hat` after line 5, such that it impacts the computation of `x` on line 6.
whereas currently k-diffusion passes into the model:

- `x` computed from a non-discretized `sigma_hat`
- the non-discretized `sigma_hat`
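To make the ordering concrete, here's a minimal sketch (my own illustration, not k-diffusion's code) of quantizing `sigma_hat` before it's used to compute `x`. The `quantize_sigma` helper and the sigma table are hypothetical stand-ins for a model's discrete sigmas:

```python
import torch

def quantize_sigma(sigma: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: snap sigma to the nearest entry of the model's
    # discrete sigma table via argmin over absolute distance
    return table[(table - sigma).abs().argmin()]

table = torch.linspace(0.0292, 14.6146, 1000)  # stand-in for model.sigmas
x = torch.randn(4)
sigma = torch.tensor(1.9104)
gamma = 0.1

# paper ordering: quantize sigma_hat first, so the SAME discretized value
# both perturbs x (Algorithm 2 line 6) and conditions the model
sigma_hat = quantize_sigma(sigma * (gamma + 1), table)
eps = torch.randn_like(x)
x_hat = x + (sigma_hat ** 2 - sigma ** 2).sqrt() * eps
# denoised = model(x_hat, sigma_hat)  # the model would see the quantized sigma_hat
```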
> So the paper didn't mention this, but the result is terrible at low step counts if you actually implement the 0 as they describe. Maybe this is just a problem for discrete time models?

You need the 0 on the end so the sampler outputs a fully denoised image; the ODE needs to be integrated from sigma_max to 0 for this to happen. I think the thing you are observing happens because sigma_min (the last noise level the model is evaluated at) is too low for low step counts. Have you tried increasing sigma_min instead, but keeping the concatenation of 0?
thanks very much @crowsonkb for explaining the importance of the 0!
okay, so we need to keep the 0. but ramping all the way down to sigma_min inclusive isn't the best use of our limited sigmas.
looking at the next-lowest sigma: the successful picture sampled 0.1072; the unsuccessful picture sampled 0.0292.
so one idea is to formalize the wacky way in which that 0.1072 was computed, so we can intentionally use it as our sigma_min.
the 0.1072 can be obtained like this:
```python
steps = 7
get_sigmas_karras(
    # there's an argument that steps+1 is wacky, so let's remember to try without the +1 too
    n=steps + 1,
    # 14.6146
    sigma_max=model.sigmas[-1].item(),
    # 0.0292
    sigma_min=model.sigmas[0].item(),
    rho=7.,
)[-3]  # skip nth because it's 0; skip n-1th because it's the known-bad sigma_min
```
or more efficiently like this:

```python
# gets the (n-1)th sigma from a Karras noise schedule
def get_awesome_sigma_min(
    steps: int,
    sigma_max: float,
    sigma_min_nominal: float,
    rho: float,
) -> float:
    min_inv_rho = sigma_min_nominal ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    ramp = (steps - 2) / (steps - 1)
    sigma_min = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return sigma_min
```
```python
steps = 7
# 14.6146
sigma_max = model.sigmas[-1].item()
sigma_min = get_awesome_sigma_min(
    steps=steps + 1,
    sigma_max=sigma_max,
    # 0.0292
    sigma_min_nominal=model.sigmas[0].item(),
    rho=7.,
)
```
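As a sanity check, here's my own self-contained sketch (re-implementing the standard Karras ramp locally rather than importing k-diffusion) confirming that the closed form agrees with indexing the full schedule at `[-3]`, and reproduces the 0.1072 (with steps+1) and 0.1303 (with plain steps) values from this experiment:

```python
import torch

def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.):
    # standard Karras et al. (2022) schedule, with a 0 appended as k-diffusion does
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, sigmas.new_zeros([1])])

def get_awesome_sigma_min(steps, sigma_max, sigma_min_nominal, rho):
    min_inv_rho = sigma_min_nominal ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    ramp = (steps - 2) / (steps - 1)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# steps+1 = 8: closed form lands on the same value as schedule[-3]
full = get_sigmas_karras(n=8, sigma_min=0.0292, sigma_max=14.6146)
closed = get_awesome_sigma_min(steps=8, sigma_max=14.6146, sigma_min_nominal=0.0292, rho=7.)
```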
having computed a new sigma_min 0.1072 using steps+1 (a bit arbitrary, but it matches my original experiment), we call the (unmodified) `get_sigmas_karras()` the normal way, with our new sigma_min:

```python
sigmas = get_sigmas_karras(
    n=opt.steps,
    sigma_min=sigma_min,
    sigma_max=sigma_max,
    rho=rho,
)
```

it returns the following noise schedule, identical to our first experiment except ending with 0 instead of 0.0292:

```
[14.6146, 7.9029, 4.0277, 1.9104, 0.8289, 0.3211, 0.1072, 0.0000]
```
the sigma_hats get discretized to these before being passed into the model:

```
[14.6146, 7.9216, 4.0300, 1.9103, 0.8299, 0.3213, 0.1072]
```
Picture still looks good (the power of the 0.1072 sigma, probably):
now let's try simplifying that ugly `get_awesome_sigma_min(steps=steps+1)` to just `steps=steps`.
new sigmas out of the oven. spends more time on the middle sigmas. "more time in the middle" sounds closer to the behaviour of `DiscreteSchedule`, based on my comparison in this thread:
https://twitter.com/Birchlabs/status/1565114066548527104

```
[14.6146, 8.0451, 4.1888, 2.0392, 0.9141, 0.3692, 0.1303, 0.0000]
```

the sigma_hats discretize to:

```
[14.6146, 8.0461, 4.1878, 2.0402, 0.9136, 0.3687, 0.1308]
```
Picture still looks good:
We got a new floral pattern on the sleeve! plus some new hair detail. though the eyebrows, eyesockets, irises and shading on facial bones suffered a bit. probably not good to lose so many low sigmas.
`steps+1` is probably better for keeping the details we care about.
but overall, keeping 0 seems to make this a nicer algorithm than we started with.
so, I think we don't need the `concat_zero` change from this PR, since zeroes weren't the problem; wasteful use of the low sigma, sigma_min, was the problem.
the same schedule can be computed without changes to k-diffusion, unless you wanted to accept a `get_awesome_sigma_min` helper under "wacky experiments". 😛
but I still think two problems remain regarding adhering to the paper:

- `quantize=True` quantizes too late (i.e. on model ingress), rather than quantizing `sigma_hat` before `x` is computed
- the outputs of `get_sigmas_karras` should be discretized too, via the `DiscreteSchedule` abstraction somehow. `quantize=True` achieves the same effect, but only when `gamma=0` (i.e. when `sigma_hat == sigma`). perhaps a `DiscreteSchedule#get_sigmas_karras` method, which delegates to `get_sigmas_karras` and applies an argmin afterward.

If you do `model.t_to_sigma(model.sigma_to_t(sigma))` inside the sampler you can get the quantized sigma... but you can't count on those methods being there because the user could just pass in any arbitrarily wrapped model. I'm not really sure what to do tbh.
an older factoring of the code that I tried was to expose a `quanta` parameter:
https://github.com/Birch-san/k-diffusion/commit/8ebcd0098e67823ceee6499c34bb29e20704c1f3
but it feels kinda like a failure to make use of the model-wrapping idiom.
given that discrete sigmas are a `DiscreteSchedule` concern, it feels like `DiscreteSchedule` should provide peers of all Karras sampler methods, like `DiscreteSchedule#sample_heun`.
`DiscreteSchedule#sample_heun` would forward calls to `sample_heun`. or perhaps `DiscreteSchedule#sample_heun` and `sample_heun` would mutually forward calls to a private function, `_sample_heun()` (which could expose an optional `quantize` callback).
factoring out a common core might not be the craziest thing to do, since `sample_euler()` and `sample_heun()` both have a lot of code they could share.
The samplers are supposed to be independent of the models, though, that would duplicate a ton of code and I might add new samplers later etc. Is there some reasonable way to guarantee that a wrapper class has all the required methods? The usual idiom here is subclassing but that doesn't really work with the wrapper idiom...
you mean a way to sniff `model` to see if it provides a way to quantize?
if we're ruling out "checking if it extends a class/mixin", then I guess that leaves "check whether it has a particular method decorated with a decorator you provide"?
Maybe there could be a model wrapper class that has all of the methods that the samplers etc. expect, and the default implementation of these methods just forwards to the wrapped model, and users could override these methods to customize the behavior. That is, all model wrappers would subclass this and override methods, maybe just forward() but they could also alter the other methods if they did something more complicated.
yes, that would be a good way to do it.
if you're a continuous-time model, you don't want to quantize sigma_hat at all.
so maybe a new base model wrapper class would be introduced (from which `VDenoiser` and `DiscreteSchedule` would then inherit).
the base model wrapper class would have a `decorate_karras_sigma_hat(tensor: Tensor) -> Tensor` which would just be the identity function. `DiscreteSchedule` would override this to quantize the tensor (if `quantize=True`).
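A sketch of that hierarchy (the class and method names here are hypothetical, taken from the discussion rather than any actual k-diffusion API): the base wrapper's `decorate_karras_sigma_hat` is the identity, and a discrete subclass overrides it with an argmin lookup when `quantize=True`:

```python
import torch
from torch import nn

class BaseWrapperSketch(nn.Module):
    def decorate_karras_sigma_hat(self, sigma_hat: torch.Tensor) -> torch.Tensor:
        return sigma_hat  # continuous-time models: leave sigma_hat untouched

class DiscreteWrapperSketch(BaseWrapperSketch):
    def __init__(self, sigmas: torch.Tensor, quantize: bool = True):
        super().__init__()
        self.register_buffer('sigmas', sigmas)
        self.quantize = quantize

    def decorate_karras_sigma_hat(self, sigma_hat: torch.Tensor) -> torch.Tensor:
        if not self.quantize:
            return sigma_hat
        # batched nearest lookup: works on scalars and vectors alike
        dists = (sigma_hat[..., None] - self.sigmas).abs()
        return self.sigmas[dists.argmin(dim=-1)]

table = torch.linspace(0.0292, 14.6146, 1000)
wrapper = DiscreteWrapperSketch(table)
snapped = wrapper.decorate_karras_sigma_hat(torch.tensor([1.9104, 0.1072]))
```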
I need to think about which methods to make standard on the wrapper...

- `forward()`/`__call__()` obviously
- `sigma_to_t()` and `t_to_sigma()` (for k-diffusion native models these can just be the identity function). If you have these two methods you can quantize by going from sigma to t and back.
- `loss()` probably.
- Maybe forward `get_sigmas()` if it exists on the inner model, and don't forward it if it doesn't? Or just make the base model class have this method but raise `NotImplementedError`.

Oh! Maybe add `encode()` and `decode()` methods that are the identity function by default but which, for latent diffusion models, encode/decode using the autoencoder.
Maybe also have a `sigma_min` and a `sigma_max` property for easy access to the valid timestep range.
disclaimer: my design patterns are based on Java experience, not Python.
I'd start by only implementing stuff that you actually have a user for.
easy to add more later once the requirement is discovered. very hard to take back something once it's shipped.
I'd start from "who consumes a base class?".
your samplers will. they currently only do one thing with `model`: `__call__()`.
and we'll want at least one new capability to help us discretize sigmas. maybe that's `sigma_to_t()` + `t_to_sigma()`, as you say. but whilst it's good composition, it's anti-performance (it prevents simplifying it to just a one-line argmin). so I'd say there's a performance case for a `quantize_sigma()` method. this could be provided in addition to the two other functions of course.
will end-users consume this base class? I don't know a use-case that would mean they'd ever see the base class.
I, for example, construct a `CompVisDenoiser(model)`. I know the subclass, so I don't need the base class to be descriptive.
the only reason I'd lose this information is if I'm doing something like a strategy factory, to pick a wrapper at runtime. I don't think anybody would actually do this…
another consideration vis-à-vis forcing subclasses to adhere to the same method signatures… we already see some divergence here; `DiscreteSchedule`'s `sigma_to_t()` has an additional `quantize: bool`. it's compatible, but would enforcing method signatures from a base class restrict your design options in future (e.g. if one of the subclasses needs some bespoke, additional params on a method)?
not sure how polymorphism works in Python (i.e. whether it's still considered "overriding" a function if the method signatures are different but compatible).
I guess you can get around anything by picking a lowest common denominator, and spreading `**kwargs` to make it extensible.
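For what it's worth, Python doesn't enforce override signatures at all: a subclass override that adds optional, defaulted params remains substitutable for the base. A tiny sketch (the method bodies here are hypothetical stand-ins):

```python
class Base:
    def sigma_to_t(self, sigma, **kwargs):
        return sigma  # identity: k-diffusion native models

class Discrete(Base):
    # compatible override: adds an optional param, like DiscreteSchedule's
    # sigma_to_t(sigma, quantize=...) mentioned above
    def sigma_to_t(self, sigma, quantize: bool = True, **kwargs):
        t = round(sigma)  # stand-in for a real sigma -> t table lookup
        return t if quantize else sigma

# callers written against the base signature work with either subclass:
ts = [m.sigma_to_t(1.6) for m in (Base(), Discrete())]
```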
> Maybe forward `get_sigmas()` if it exists on the inner model, and don't forward it if it doesn't? Or just make the base model class have this method but raise `NotImplementedError`.

of the choices, I prefer raise `NotImplementedError` (principle of least astonishment: better to be limited and clear than to try and be helpful in an unpredictable way).
if it's a method that's mandatory, but for which no sensible default implementation can be provided: abstract methods are a good way to force the subclass to make the decision of how to implement instead (but yeah that could just be raise `NotImplementedError`).
> Oh! Maybe add `encode()` and `decode()` methods that are the identity function by default but which, for latent diffusion models, encode/decode using the autoencoder.

This might be another situation where, for performance reasons, it would be good to support `roundtrip()` (which, as you say, could for some models be the identity function) rather than forcing to go through `encode(decode())`.
> Maybe also have a `sigma_min` and a `sigma_max` property for easy access to the valid timestep range.

hmm, I guess that's something I'd use (I'm currently resorting to `model_k_wrapped.sigmas[-1].item()`), but it feels like something that only discrete-timestep models would need?
so maybe it wouldn't go as low as the base class, but rather into a mixin or superclass that discrete models inherit?
I was thinking about something along the lines of the following:
```python
class BaseModelWrapper(nn.Module):
    """The base wrapper class for the k-diffusion model wrapper idiom. Model
    wrappers should subclass this class and customize the behavior of the
    wrapped model by implementing or overriding methods."""

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model

    def __dir__(self):
        return list(set(super().__dir__() + dir(self.inner_model)))

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.inner_model, name)

    def forward(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)
```
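A quick illustration of the attribute fall-through in that design (the `ToyDenoiser` is my own stand-in for a wrapped model; the wrapper is repeated so this snippet is self-contained):

```python
import torch
from torch import nn

class BaseModelWrapper(nn.Module):
    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)  # nn.Module's param/buffer/module lookup
        except AttributeError:
            return getattr(self.inner_model, name)  # fall through to wrapped model

    def forward(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)

class ToyDenoiser(nn.Module):
    def __init__(self):
        super().__init__()
        self.sigma_data = 0.5  # an attribute only the inner model has

    def forward(self, x, sigma):
        return x / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

wrapped = BaseModelWrapper(ToyDenoiser())
out = wrapped(torch.ones(3), torch.tensor(1.0))  # forwards to the inner forward
sigma_data = wrapped.sigma_data                  # falls through to the inner model
```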
I'm not sure where the standard methods should go yet, on this base wrapper or separately implemented on the different denoiser wrappers, which would be changed to be subclasses of this.
hey, sorry for slow reply.
okay, so this wrapper creates the illusion that the wrapped instance is a subtype of inner_model. seems reasonable (we want the wrapped instance to be a substitute that can be used in all the same situations).
as for where the standard methods (e.g. sigma_to_t()) should go...
I guess it depends what this base wrapper claims its responsibilities are.
if it's a "general model wrapper" (i.e. nothing to do with diffusion, but perhaps with generic responsibilities like logging), then I wouldn't put sigma_to_t() this low.
if it's a "diffusion model wrapper" (and I assume it is), then I think it makes sense to put sigma_to_t() this low if (and only if) that's something that every diffusion model needs.
if there's a one-size-fits-all implementation of sigma_to_t() that can be put here, it can go here.
if "it depends", then I think it should be an abstract method. the base wrapper forces subclasses to provide an implementation.
generally, the decision of "should I put sigma_to_t() -- or at least an abstract interface for it -- this low", is answered by "what model type will the samplers integrate against?"
if the samplers expect the user to pass in an instance of BaseModelWrapper, then sigma_to_t() needs to be on BaseModelWrapper, or the sampler needs to be prepared to sniff the model instance and look for a more specific subclass. it's preferable to avoid that.
> okay, so this wrapper creates the illusion that the wrapped instance is a subtype of inner_model. seems reasonable (we want the wrapped instance to be a substitute that can be used in all the same situations).
> as for where the standard methods (e.g. sigma_to_t()) should go... I guess it depends what this base wrapper claims its responsibilities are.
> if it's a "general model wrapper" (i.e. nothing to do with diffusion, but perhaps with generic responsibilities like logging), then I wouldn't put sigma_to_t() this low.
> if it's a "diffusion model wrapper" (and I assume it is), then I think it makes sense to put sigma_to_t() this low if (and only if) that's something that every diffusion model needs.

It is yeah.

> if there's a one-size-fits-all implementation of sigma_to_t() that can be put here, it can go here. if "it depends", then I think it should be an abstract method. the base wrapper forces subclasses to provide an implementation.

For a k-diffusion native model, sigma(t) = t, so the default implementation can simply return its input, same for t_to_sigma().
okay sure, then yes: let's put a default implementation (identity function) of sigma_to_t() and t_to_sigma() in the base diffusion model wrapper.
Improves support for diffusion models with discrete time-steps (such as Stable Diffusion's DDIM).
I have some questions though, so this may need some iterating.
The user would invoke like so:
Implements the change to "Algorithm 2, line 5" described in the Elucidating paper arXiv:2206.00364 section C.3.4 "iDDPM practical considerations" practical challenge 3.
In other words we round sigmas to the nearest sigma supported by the DDIM.
For your convenience, here are the sigmas supported by Stable Diffusion's DDIM:
https://gist.github.com/Birch-san/6cd1574e51871a5e2b88d59f0f3d4fd3
You may be wondering "okay, rounding sigma_hat solves challenge 3, but what about challenge 2".
There's an argument that solving challenge 3, solves challenge 2 for some situations.
When `gamma == 0`, rounding sigma_hat is equivalent to rounding sigma (which is what challenge 2 requires you to do for any outputs of `get_sigmas_karras()`).
A problem here is the final sigma we'll receive, 0. we probably don't want to apply the same rounding rules to that… especially because we have a special case predicated on 0. should that be predicated on the argmin'd sigma instead, or perhaps on "have we reached the final sigma?"
edit: maybe the only reason they special-case 0 is because they want to avoid dividing by zero?
If we do care about satisfying challenge 2 in the `gamma > 0` situation, we'd want to round-to-nearest-sigma what comes out of `get_sigmas_karras()`. I happen to have made a torch snippet for running `argmin` on every element returned by `get_sigmas_karras()` simultaneously. But again, I'm not sure what the implications are for the 0 it returns.
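The snippet itself isn't preserved above, but a sketch of the idea (my own reconstruction, not the original code) would round every schedule element at once via a pairwise-distance argmin:

```python
import torch

def round_schedule(karras_sigmas: torch.Tensor, model_sigmas: torch.Tensor) -> torch.Tensor:
    # pairwise |karras - model| distances, shape (n_karras, n_model);
    # argmin over the model axis rounds every element simultaneously
    dists = (karras_sigmas[:, None] - model_sigmas[None, :]).abs()
    return model_sigmas[dists.argmin(dim=1)]

model_sigmas = torch.linspace(0.0292, 14.6146, 1000)  # stand-in for model.sigmas
karras = torch.tensor([14.6146, 7.9029, 4.0277, 1.9104, 0.8289, 0.3211, 0.1072, 0.0])
rounded = round_schedule(karras, model_sigmas)
# note: the trailing 0 rounds up to sigma_min (0.0292), the concern raised above
```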
Anyway, maybe we can look at the outputs to decide. We'll try with keeping the 0 and without.
I tried to stress this to its limits by using as few steps as I could manage before it looked bad. All images are:

- 68673924
- `get_sigmas_karras()` noise schedule
- Heun, 7 steps
## Excluding 0 from `get_sigmas_karras()`

The better-looking result was when I excluded the 0 returned by `get_sigmas_karras()`, in favour of ramping for 1 more step.
Recall that SD's sigmas run from max = `14.6146` to min = `0.0292`.
Sigmas returned by `get_sigmas_karras()`:
_`sample_heun` only iterates to n-1, so never touches the `0.0292`._

### Time-step discretization enabled

Sigmas (up to n-1) after discretization:

### Original k-diffusion behaviour (no discretization)

Not much perceptible difference. The discrete one defines the far sleeve better, but for the other subtle differences it's hard for me to say which is the better generation.
## Keeping 0 from `get_sigmas_karras()`

So the paper didn't mention this, but the result is terrible at low step counts if you actually implement the 0 as they describe. Maybe this is just a problem for discrete time models?
Sigmas returned by `get_sigmas_karras()`:
_`sample_heun` only iterates to n-1, so never touches the 0._

### Time-step discretization enabled

Sigmas (up to n-1) after discretization:

### Original k-diffusion behaviour (no discretization)

Slightly more perceptible difference. The discrete one did better on the eyes and has slightly more clothing definition.
## Conclusion

Removing the "concat 0" from `get_sigmas_karras()` seems to be hugely beneficial for small numbers of steps. This is not backed up by the literature. The reason I tried this was due to a misunderstanding: I saw that if I discretized the whole schedule, I'd end up with a repeated final sigma (`… 0.0292, 0.0292]`). I removed the concat 0 to ensure I didn't end up producing duplicates. I didn't realize, though, that the sampler stops at n-1, so repeats aren't actually a problem. But it seems that, for a different reason, the results are far better.
Discretization of time-steps doesn't have the dramatic impact I was hoping for, but is probably still a sensible thing to do on the basis that the paper recommended it.
## Heun, 50 steps, excluding 0

Let's do one more example, to 50 steps.

### Time-step discretization enabled

### Original k-diffusion behaviour (no discretization)

Discretization seems to be more noticeable over 50 steps. The discretized image seems to have sharper hair and clothing, and highlights are brighter. Not sure I could say which is "better" though.
It's hard to compare images scrolling on GitHub; personally I flicked between these using QuickLook in the Finder.
If you know a better way to evaluate whether this is an improvement: I'm all ears! 👂