kijai / ComfyUI-DiffusersStableCascade

Simple inference with StableCascade using diffusers in ComfyUI

High VRAM use - ComfyUI won't unload models #9

Open GitWilburn opened 7 months ago

GitWilburn commented 7 months ago

Because the models are loaded directly, the ComfyUI model manager doesn't know about them and can't unload them. There are probably better ways to deal with this, and once ComfyUI adds a native version it shouldn't matter, but in order to run this on my 12GB GPU I had to unload the models in between phases. There's probably a cleaner way to do it, but I'm still pretty new to ComfyUI development, so this is my solution for now. Since loading/unloading the models between phases takes some time, this worked pretty well for running batches, and because the latents are smaller I could run 3-4 images at a time without trouble.

# module-level imports this method relies on
import gc
import torch
import comfy.model_management
from torchvision.transforms import ToTensor
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

def process(self, width, height, seed, steps, guidance_scale, prompt, negative_prompt, batch_size, decoder_steps, image=None):

        comfy.model_management.unload_all_models()
        torch.manual_seed(seed)

        device = comfy.model_management.get_torch_device()
        #load the prior
        if not hasattr(self, 'prior') or self.prior is None:
            self.prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=torch.bfloat16).to(device)

        prior_output = self.prior(
            image=image,
            prompt=prompt,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=batch_size,
            num_inference_steps=steps
        )
        #unload the prior
        if hasattr(self, 'prior'):
            self.prior = None
            gc.collect()
        #load the decoder
        if not hasattr(self, 'decoder') or self.decoder is None:
            self.decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=torch.float16).to(device)

        decoder_output = self.decoder(
            image_embeddings=prior_output.image_embeddings.half(),
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=0.0,
            output_type="pil",
            num_inference_steps=decoder_steps
        ).images

        #unload the decoder
        if hasattr(self, 'decoder'):
            self.decoder = None
            gc.collect()

        tensors = [ToTensor()(img) for img in decoder_output]
        batch_tensor = torch.stack(tensors).permute(0, 2, 3, 1).cpu()

        return (batch_tensor, image)
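
One detail worth noting (not part of the snippet above): dropping the references with self.prior = None / self.decoder = None and calling gc.collect() releases the Python objects, but PyTorch's caching allocator can still keep the freed VRAM reserved. A minimal sketch of an extra step that is often added right after the reference is dropped:

    import gc
    import torch

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return unused cached blocks to the driver

Within the same process PyTorch will normally reuse its cache anyway, so this mostly helps with fragmentation or when other applications need the memory; it costs little to try.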
Entretoize commented 7 months ago

Where do I put that? Maybe an "unload model" node would be useful.
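
For reference, a minimal sketch of what such an unload node could look like, assuming it only needs to call ComfyUI's own model management; the class name, category, and pass-through IMAGE socket are made up for illustration:

    import gc
    import torch
    import comfy.model_management

    class UnloadAllModels:
        # Hypothetical utility node: passes its input through unchanged and,
        # as a side effect, asks ComfyUI to unload the models it manages.
        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"images": ("IMAGE",)}}

        RETURN_TYPES = ("IMAGE",)
        FUNCTION = "unload"
        CATEGORY = "utils"

        def unload(self, images):
            comfy.model_management.unload_all_models()
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return (images,)

    NODE_CLASS_MAPPINGS = {"UnloadAllModels": UnloadAllModels}

Note that, as described at the top of this issue, this only helps with models the ComfyUI model manager actually knows about; pipelines loaded directly through diffusers still have to be dropped manually, as in the snippet above.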