LaurentMazare / diffusers-rs

An implementation of the diffusers api in Rust
Apache License 2.0

Add Scheduler trait/enum #36

Open · rockerBOO opened this issue 1 year ago

rockerBOO commented 1 year ago

Right now we keep adding schedulers, but they are hard to work with because swapping one scheduler for another doesn't work well. This also slows down testing and evaluating the schedulers, since a separate script has to be written each time to test a sampler. I was also integrating these into an application, and swapping schedulers at runtime didn't work because each scheduler is a different type.

I experimented with adding a trait so we can use impl Scheduler. I came up with the following, but it raises some points of contention:

use tch::Tensor;

pub trait Scheduler {
    // Placeholder bound: some schedulers take an f64 timestep, others a usize.
    fn step<T: SomeTraitThatWouldTakef64AndUsize>(&mut self, model_output: &Tensor, timestep: T, sample: &Tensor) -> Tensor;
    fn timesteps(&self) -> &[usize];
    fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor;
    fn init_noise_sigma(&self) -> f64;
    fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Tensor;
}
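Because step<T> is a generic method, a trait shaped like this is not object-safe: it works with impl Scheduler, but it cannot be used as Box<dyn Scheduler> to pick a scheduler at runtime. A minimal sketch of an object-safe variant, assuming every scheduler can accept its timestep as an f64 and using Vec<f64> as a stand-in for tch::Tensor so it compiles on its own; ToyA, ToyB and scheduler_from_name are hypothetical:

```rust
// Hypothetical stand-in for tch::Tensor so the sketch compiles without libtorch.
type Tensor = Vec<f64>;

// Object-safe variant: no generic methods, so `Box<dyn Scheduler>` is allowed.
// Assumption: every scheduler accepts its timestep as an f64.
trait Scheduler {
    fn step(&mut self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> Tensor;
    fn timesteps(&self) -> &[usize];
    fn init_noise_sigma(&self) -> f64;
}

// Two toy schedulers with made-up update rules, only to illustrate dispatch.
struct ToyA {
    timesteps: Vec<usize>,
}
struct ToyB {
    timesteps: Vec<usize>,
}

impl Scheduler for ToyA {
    fn step(&mut self, model_output: &Tensor, _timestep: f64, sample: &Tensor) -> Tensor {
        // sample - model_output, element-wise
        sample.iter().zip(model_output).map(|(s, m)| s - m).collect()
    }
    fn timesteps(&self) -> &[usize] {
        &self.timesteps
    }
    fn init_noise_sigma(&self) -> f64 {
        1.0
    }
}

impl Scheduler for ToyB {
    fn step(&mut self, model_output: &Tensor, _timestep: f64, sample: &Tensor) -> Tensor {
        // sample - 0.5 * model_output, element-wise
        sample.iter().zip(model_output).map(|(s, m)| s - 0.5 * m).collect()
    }
    fn timesteps(&self) -> &[usize] {
        &self.timesteps
    }
    fn init_noise_sigma(&self) -> f64 {
        2.0
    }
}

// Chooses a scheduler from a runtime string; the concrete type is erased behind the Box.
fn scheduler_from_name(name: &str) -> Option<Box<dyn Scheduler>> {
    let timesteps = vec![999, 499, 0];
    let scheduler: Box<dyn Scheduler> = match name {
        "toy_a" => Box::new(ToyA { timesteps }),
        "toy_b" => Box::new(ToyB { timesteps }),
        _ => return None,
    };
    Some(scheduler)
}
```

The price is a fixed timestep type in the signature, which is exactly the point of contention above.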

And then I could do the following (I'm still learning traits):

fn sample<T: Scheduler>(
    ...,
    mut scheduler: T
)
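A generic sample function like this uses static dispatch: the compiler generates one copy per concrete scheduler type. A minimal sketch of what the body could look like, assuming usize timesteps, Vec<f64> as a stand-in for tch::Tensor, and a closure model standing in for the UNet forward pass (Toy and model are hypothetical):

```rust
// Hypothetical stand-in for tch::Tensor so the sketch compiles without libtorch.
type Tensor = Vec<f64>;

// A trimmed-down scheduler trait with usize timesteps (assumption).
trait Scheduler {
    fn timesteps(&self) -> &[usize];
    fn init_noise_sigma(&self) -> f64;
    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor;
}

// Toy scheduler: each step subtracts `fraction * model_output` from the sample.
struct Toy {
    timesteps: Vec<usize>,
    fraction: f64,
}

impl Scheduler for Toy {
    fn timesteps(&self) -> &[usize] {
        &self.timesteps
    }
    fn init_noise_sigma(&self) -> f64 {
        1.0
    }
    fn step(&mut self, model_output: &Tensor, _timestep: usize, sample: &Tensor) -> Tensor {
        let f = self.fraction;
        sample.iter().zip(model_output).map(|(s, m)| s - f * m).collect()
    }
}

// Generic over the scheduler, so its concrete type must be known at compile time.
// `model` stands in for the UNet forward pass (hypothetical).
fn sample<S: Scheduler>(
    mut scheduler: S,
    initial_noise: Tensor,
    model: impl Fn(&Tensor, usize) -> Tensor,
) -> Tensor {
    // Scale the starting noise by the scheduler's initial sigma.
    let sigma = scheduler.init_noise_sigma();
    let mut latents: Tensor = initial_noise.iter().map(|x| x * sigma).collect();
    // Copy the timesteps so `scheduler` can be borrowed mutably inside the loop.
    let timesteps = scheduler.timesteps().to_vec();
    for t in timesteps {
        let noise_pred = model(&latents, t);
        latents = scheduler.step(&noise_pred, t, &latents);
    }
    latents
}
```

For example, sample(Toy { timesteps: vec![3, 2, 1], fraction: 0.5 }, vec![8.0], |x, _t| x.clone()) halves the latent three times. Choosing the scheduler from a runtime string would still need a match that calls sample separately in each arm.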

And/or we could also do a Scheduler enum.

enum SamplerScheduler {
    Dpmpp2m(dpmsolver_multistep::DPMSolverMultistepScheduler),
    Dpmpp2s(dpmsolver_singlestep::DPMSolverSinglestepScheduler),
    Ddim(ddim::DDIMScheduler),
    Ddpm(ddpm::DDPMScheduler),
    EulerDiscrete(euler_discrete::EulerDiscreteScheduler),
}
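With the enum, each shared method is implemented once on SamplerScheduler and forwards to the wrapped scheduler with a match. A minimal sketch with two variants, using toy structs in place of ddim::DDIMScheduler and euler_discrete::EulerDiscreteScheduler (their update rules, the sigma value and from_name are made up for illustration) and Vec<f64> as a stand-in for tch::Tensor:

```rust
// Hypothetical stand-in for tch::Tensor so the sketch compiles without libtorch.
type Tensor = Vec<f64>;

// Toy stand-ins for the real scheduler structs; update rules and sigma are made up.
struct Ddim {
    timesteps: Vec<usize>,
}
struct EulerDiscrete {
    timesteps: Vec<usize>,
    sigma_max: f64,
}

impl Ddim {
    fn step(&mut self, model_output: &Tensor, _timestep: usize, sample: &Tensor) -> Tensor {
        sample.iter().zip(model_output).map(|(s, m)| s - m).collect()
    }
}

impl EulerDiscrete {
    fn step(&mut self, model_output: &Tensor, _timestep: usize, sample: &Tensor) -> Tensor {
        sample.iter().zip(model_output).map(|(s, m)| s - 0.1 * m).collect()
    }
}

enum SamplerScheduler {
    Ddim(Ddim),
    EulerDiscrete(EulerDiscrete),
}

impl SamplerScheduler {
    // Runtime selection, e.g. from a command-line flag (hypothetical names).
    fn from_name(name: &str, timesteps: Vec<usize>) -> Option<Self> {
        match name {
            "ddim" => Some(Self::Ddim(Ddim { timesteps })),
            "euler_discrete" => Some(Self::EulerDiscrete(EulerDiscrete { timesteps, sigma_max: 10.0 })),
            _ => None,
        }
    }

    // Each shared method forwards to the wrapped scheduler, one arm per variant.
    fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor {
        match self {
            Self::Ddim(s) => s.step(model_output, timestep, sample),
            Self::EulerDiscrete(s) => s.step(model_output, timestep, sample),
        }
    }

    fn timesteps(&self) -> &[usize] {
        match self {
            Self::Ddim(s) => &s.timesteps,
            Self::EulerDiscrete(s) => &s.timesteps,
        }
    }

    fn init_noise_sigma(&self) -> f64 {
        match self {
            Self::Ddim(_) => 1.0,
            Self::EulerDiscrete(s) => s.sigma_max,
        }
    }
}
```

This avoids boxing and keeps the matches exhaustive, so adding a scheduler means adding a variant and one arm per forwarding method; the compiler flags any arm that was missed. Each variant can also keep its own timestep type internally, converting inside its arm.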

I'm not 100% sure what's the best approach.

mspronesti commented 1 year ago

One way to solve the timestep type issue is to give the trait an associated type, say timestep_t, that must be Copy and Clone and convertible to a primitive type:

type timestep_t: Copy + Clone + num_traits::ToPrimitive;

and then use it in the trait's method signatures as Self::timestep_t. When implementing the trait for a particular scheduler, one then only needs to set it appropriately:

impl Scheduler for MyScheduler {
    type timestep_t = usize;
    // ...the trait's methods, which take Self::timestep_t as their timestep
}
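A minimal sketch of the associated-type approach. To stay dependency-free it uses a small hypothetical ToF64 trait in place of num_traits::ToPrimitive, names the associated type Timestep (CamelCase per Rust naming conventions) in the role of timestep_t, and uses Vec<f64> as a stand-in for tch::Tensor; the toy schedulers and timesteps_as_f64 are illustrative only:

```rust
// Hypothetical stand-in for tch::Tensor so the sketch compiles without libtorch.
type Tensor = Vec<f64>;

// Std-only stand-in for num_traits::ToPrimitive (hypothetical helper trait).
trait ToF64 {
    fn to_f64(self) -> f64;
}
impl ToF64 for usize {
    fn to_f64(self) -> f64 {
        self as f64
    }
}
impl ToF64 for f64 {
    fn to_f64(self) -> f64 {
        self
    }
}

trait Scheduler {
    // Plays the role of `timestep_t`: each scheduler picks its own timestep type.
    type Timestep: Copy + ToF64;
    fn timesteps(&self) -> &[Self::Timestep];
    fn step(&mut self, model_output: &Tensor, timestep: Self::Timestep, sample: &Tensor) -> Tensor;
}

// Toy scheduler with integer timesteps.
struct Ddim {
    timesteps: Vec<usize>,
}
impl Scheduler for Ddim {
    type Timestep = usize;
    fn timesteps(&self) -> &[usize] {
        &self.timesteps
    }
    fn step(&mut self, model_output: &Tensor, _timestep: usize, sample: &Tensor) -> Tensor {
        sample.iter().zip(model_output).map(|(s, m)| s - m).collect()
    }
}

// Toy scheduler with float timesteps.
struct EulerDiscrete {
    timesteps: Vec<f64>,
}
impl Scheduler for EulerDiscrete {
    type Timestep = f64;
    fn timesteps(&self) -> &[f64] {
        &self.timesteps
    }
    fn step(&mut self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> Tensor {
        // Made-up rule that actually uses the float timestep.
        sample.iter().zip(model_output).map(|(s, m)| s - timestep * m).collect()
    }
}

// Generic code converts through the bound only where it needs a plain number,
// e.g. when building the timestep input for the UNet.
fn timesteps_as_f64<S: Scheduler>(scheduler: &S) -> Vec<f64> {
    scheduler.timesteps().iter().map(|&t| t.to_f64()).collect()
}
```

One caveat: a trait object over this trait has to name the associated type, e.g. Box<dyn Scheduler<Timestep = usize>>, so schedulers with different timestep types still cannot sit behind a single Box<dyn Scheduler> without another layer such as an enum.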

I've already implemented all of this, but I'm holding off on opening a PR because I have two more PRs pending and another one coming soon that implements the Heun Discrete scheduler 😅

mspronesti commented 1 year ago

On second thought, I'm not so sure this is a good idea, at least not if we want to port other diffusion models. Some of the missing schedulers follow different logic: some don't implement add_noise, and others have two different kinds of steps (a predictor step and a corrector step, with different return types), e.g. the stochastic differential equation (SDE) scheduler.
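One possible way to accommodate schedulers with different capabilities, swapped in here as an assumption rather than something proposed in the thread, is a small core trait plus separate capability traits, so a scheduler without add_noise simply doesn't implement that trait and a predictor-corrector scheduler gets its own two step methods. A minimal sketch loosely following the predictor/corrector split described above; every trait, struct and method name is hypothetical, and Vec<f64> stands in for tch::Tensor:

```rust
// Hypothetical stand-in for tch::Tensor so the sketch compiles without libtorch.
type Tensor = Vec<f64>;

// Core behaviour every scheduler shares.
trait Scheduler {
    fn timesteps(&self) -> &[f64];
}

// Optional capability: only schedulers that support it implement add_noise.
trait AddNoise: Scheduler {
    fn add_noise(&self, original_samples: &Tensor, noise: &Tensor, timestep: f64) -> Tensor;
}

// Single-step schedulers (DDIM, Euler, ...).
trait SingleStep: Scheduler {
    fn step(&mut self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> Tensor;
}

// Predictor-corrector schedulers: the two steps return different types.
struct PredictorOutput {
    prev_sample: Tensor,
    prev_sample_mean: Tensor,
}
trait PredictorCorrector: Scheduler {
    fn step_pred(&mut self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> PredictorOutput;
    fn step_correct(&mut self, model_output: &Tensor, sample: &Tensor) -> Tensor;
}

// Toy single-step scheduler; it opts out of AddNoise simply by not implementing it.
struct ToyEuler {
    timesteps: Vec<f64>,
}
impl Scheduler for ToyEuler {
    fn timesteps(&self) -> &[f64] {
        &self.timesteps
    }
}
impl SingleStep for ToyEuler {
    fn step(&mut self, model_output: &Tensor, _timestep: f64, sample: &Tensor) -> Tensor {
        sample.iter().zip(model_output).map(|(s, m)| s - m).collect()
    }
}

// Toy predictor-corrector scheduler; a real SDE scheduler would also inject noise.
struct ToySde {
    timesteps: Vec<f64>,
}
impl Scheduler for ToySde {
    fn timesteps(&self) -> &[f64] {
        &self.timesteps
    }
}
impl PredictorCorrector for ToySde {
    fn step_pred(&mut self, model_output: &Tensor, _timestep: f64, sample: &Tensor) -> PredictorOutput {
        let mean: Tensor = sample.iter().zip(model_output).map(|(s, m)| s + m).collect();
        PredictorOutput { prev_sample: mean.clone(), prev_sample_mean: mean }
    }
    fn step_correct(&mut self, model_output: &Tensor, sample: &Tensor) -> Tensor {
        sample.iter().zip(model_output).map(|(s, m)| s + 0.5 * m).collect()
    }
}

// Generic helpers ask only for the capability they actually use.
fn first_step<S: SingleStep>(scheduler: &mut S, model_output: &Tensor, sample: &Tensor) -> Tensor {
    let t = scheduler.timesteps()[0];
    scheduler.step(model_output, t, sample)
}
```

The trade-off is that a sampling loop has to know which stepping trait it is driving, so runtime swapping across both families would still need an enum or similar wrapper on top.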