filipstrand / mflux

An MLX port of FLUX based on the Hugging Face Diffusers implementation.
657 stars, 48 forks

Include FLUX.1-Dev #10

Closed Xuzzo closed 4 weeks ago

Xuzzo commented 1 month ago

Hello and thanks again for your very nice work. This PR:

Note: the flux_1_schnell folder keeps its current name here so that the diff stays readable and the changes are easier to review. Once they are merged, the names of the folders and files will need to be updated too.

Xuzzo commented 1 month ago

There are still some minor things that I want to check, but this should be ready for review. I compared the images generated here with those from Hugging Face Diffusers, and they are very similar. Here is an example with FLUX.1-dev. Diffusers: [image] Mflux: [image]

filipstrand commented 4 weeks ago

@Xuzzo Big thanks for all the help! I will take a look at this PR later today when I have the time :)

filipstrand commented 4 weeks ago

Hi again, I have added some smaller changes on top of your work. Most notably, I have added a ModelConfig class that holds a few pieces of information (such as max_sequence_length) for each model. Since the sigmas depend on both the model configuration and the more "run time" configuration of the Config class, I now begin the generate_image method with config = config.copy_with_sigmas(self.model_config). It is very similar to how you did it, but I always make a new copy of the Config (since it is lightweight) to make sure we don't accidentally overwrite anything.
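A minimal sketch of the copy-on-generate pattern described above, not the actual mflux code. The ModelConfig fields other than max_sequence_length, the values assigned to them, the evenly spaced sigma schedule, and the `_shift` helper with its constant `mu` are all assumptions made for illustration; plain Python floats stand in for mx.array so the example runs without MLX.

```python
import math
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class ModelConfig:
    # Hypothetical fields: only max_sequence_length is named in the thread.
    alias: str
    max_sequence_length: int
    requires_sigma_shift: bool


@dataclass
class Config:
    num_inference_steps: int = 4
    width: int = 1024
    height: int = 1024
    sigmas: Optional[Tuple[float, ...]] = None

    def copy_with_sigmas(self, model_config: ModelConfig) -> "Config":
        n = self.num_inference_steps
        # Assumed schedule: n + 1 evenly spaced values from 1.0 down to 0.0.
        sigmas = tuple(1.0 - i / n for i in range(n + 1))
        if model_config.requires_sigma_shift:
            sigmas = tuple(_shift(s) for s in sigmas)
        # replace() builds a new Config, so the caller's object is never mutated.
        return replace(self, sigmas=sigmas)


def _shift(sigma: float, mu: float = 1.15) -> float:
    # Illustrative stand-in for the model-specific shift; mu is an arbitrary constant.
    if sigma <= 0.0:
        return 0.0
    return math.exp(mu) / (math.exp(mu) + (1.0 / sigma - 1.0))


# Illustrative model entries (values are assumptions for this sketch).
schnell = ModelConfig("schnell", max_sequence_length=256, requires_sigma_shift=False)
dev = ModelConfig("dev", max_sequence_length=512, requires_sigma_shift=True)

user_config = Config(num_inference_steps=4)
run_config = user_config.copy_with_sigmas(dev)
```

After the call, `user_config.sigmas` is still `None` and `run_config` is a separate object, which is the property the fresh copy is meant to guarantee.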

Xuzzo commented 4 weeks ago

Thanks for your additions. Looks good to me. The only thing I am not fully sure about is the config logic: I don't think the sigma methods should live in Config, since they have external dependencies (i.e. the model type). It does not hurt too much for now, but maybe having another class that takes both Config and ModelConfig would be better? Something like:

class RuntimeConfig:
    def __init__(self, config: Config, model: ModelConfig):
        self.model = model
        self.config = config
        self.sigmas = self._get_sigmas(config.num_inference_steps)
        if model == ModelConfig.FLUX1_DEV:
            self.sigmas = self._shift_sigmas(self.sigmas, config.width, config.height)

    @staticmethod
    def _get_sigmas(num_inference_steps):
        ...

    @staticmethod
    def _shift_sigmas(sigmas: mx.array, width: int, height: int):
        ...

In this case, Config and ModelConfig can be frozen after init. You can then pass around a RuntimeConfig object and refer to everything as we were doing before.
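To make the proposal concrete, here is a runnable sketch of the same idea under stated assumptions. The enum values, the Config fields and defaults, the evenly spaced schedule in `_get_sigmas`, the constants in `_shift_sigmas`, and the `__getattr__` forwarding are all hypothetical choices for illustration; lists of floats replace mx.array so the example has no MLX dependency.

```python
import math
from dataclasses import dataclass
from enum import Enum
from typing import List


class ModelConfig(Enum):
    # Enum members cannot be reassigned; the values are illustrative aliases.
    FLUX1_SCHNELL = "schnell"
    FLUX1_DEV = "dev"


@dataclass(frozen=True)
class Config:
    # frozen=True makes any attribute assignment after init raise an error.
    num_inference_steps: int = 4
    width: int = 1024
    height: int = 1024


class RuntimeConfig:
    def __init__(self, config: Config, model: ModelConfig):
        self.config = config
        self.model = model
        self.sigmas = self._get_sigmas(config.num_inference_steps)
        if model == ModelConfig.FLUX1_DEV:
            self.sigmas = self._shift_sigmas(self.sigmas, config.width, config.height)

    def __getattr__(self, name: str):
        # Called only when normal lookup fails: forward to the wrapped Config
        # so existing code can keep reading e.g. runtime_config.width.
        if name == "config":
            raise AttributeError(name)
        return getattr(self.config, name)

    @staticmethod
    def _get_sigmas(num_inference_steps: int) -> List[float]:
        # Assumed schedule: evenly spaced from 1.0 down to 0.0.
        n = num_inference_steps
        return [1.0 - i / n for i in range(n + 1)]

    @staticmethod
    def _shift_sigmas(sigmas: List[float], width: int, height: int) -> List[float]:
        # Stand-in for a resolution-dependent shift; the linear mu and its
        # constants are illustrative assumptions, not the project's values.
        tokens = (width // 16) * (height // 16)
        mu = 0.5 + 0.65 * (tokens - 256) / (4096 - 256)
        e = math.exp(mu)
        return [e / (e + (1.0 / s - 1.0)) if s > 0 else 0.0 for s in sigmas]


config = Config(num_inference_steps=4)
dev_runtime = RuntimeConfig(config, ModelConfig.FLUX1_DEV)
schnell_runtime = RuntimeConfig(config, ModelConfig.FLUX1_SCHNELL)
```

With this layout the sigma logic lives only in RuntimeConfig, while Config stays a plain immutable record that can be shared between runs.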

filipstrand commented 4 weeks ago

Yes, that might actually be a bit cleaner and better separated. I'll try it out and let's see how it fits...

filipstrand commented 4 weeks ago

Thanks again for your help. After this is merged and tested a bit more, I'll probably send out a tweet about it. Just wanted to check that it is OK if I mention your Twitter handle https://x.com/f_xuzzo there :) ?

Xuzzo commented 4 weeks ago

ah thanks! yes that's me 🙂

filipstrand commented 4 weeks ago

Great!