huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Image output tiling for seamless textures with Stable Diffusion #556

Open torrinworx opened 2 years ago

torrinworx commented 2 years ago

Is your feature request related to a problem? Please describe.
Currently there is no way to create seamless textures with Stable Diffusion, a crucial feature that is missing.

Describe the solution you'd like
Something similar to this pull request on the sd-webui repo: https://github.com/sd-webui/stable-diffusion-webui/pull/911

A simple argument in the StableDiffusionPipeline that would enable seamless texture generation for 3D applications.

patrickvonplaten commented 9 months ago

I'm a bit lost with this issue, could someone try to summarize the problem ideally with a reproducible code snippet?

arisha07 commented 8 months ago

I tried the solution from https://github.com/huggingface/diffusers/issues/556#issuecomment-1691213152 with diffusers==0.25.0 and it does not tile properly. Is anyone else facing this issue?

github-actions[bot] commented 7 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

DJviolin commented 7 months ago

Keeping the light alive...

alexisrolland commented 7 months ago

> Tried (#556 (comment)) solution with diffusers==0.25.0 and it does not tile properly. Anyone else facing this issue ?

The solution mentioned by cmdr2 still works with the latest version of diffusers (0.26.3), but mostly for txt2img. With img2img, the generated images do not tile properly. Examples from img2img:

Seamless tiling horizontally (x) and vertically (y): does not work (KO) [image: img2img_x_true_y_true]

Seamless tiling vertically (y) only: looks roughly OK [image: img2img_x_false_y_true]

Seamless tiling horizontally (x) only: does not work (KO) [image: img2img_x_true_y_false]

Does anyone have a trick to implement seamless tiling with img2img, ControlNets, etc.?
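To make the per-axis "OK / KO" judgement a bit less subjective, one option is to compare the jump across the wrap-around seam with the typical step between neighbouring pixels. The following is only a sketch in plain Python on a tiny grayscale grid (no PIL or PyTorch); the helper names `mean_step_x`, `seam_jump_x` and `seam_ratio` are made up for this illustration, and the metric itself is an assumption, not something used elsewhere in this thread.

```python
from typing import List

Grid = List[List[float]]  # rows of grayscale values, indexed as grid[y][x]


def mean_step_x(grid: Grid) -> float:
    """Average absolute difference between horizontally adjacent pixels."""
    steps = [abs(row[x + 1] - row[x]) for row in grid for x in range(len(row) - 1)]
    return sum(steps) / len(steps)


def seam_jump_x(grid: Grid) -> float:
    """Average jump between the last and first column, i.e. across the x seam."""
    jumps = [abs(row[0] - row[-1]) for row in grid]
    return sum(jumps) / len(jumps)


def transpose(grid: Grid) -> Grid:
    """Swap the x and y axes so the x helpers can be reused for y."""
    return [list(col) for col in zip(*grid)]


def seam_ratio(grid: Grid, axis: str) -> float:
    """Seam jump divided by the typical interior step along the given axis."""
    g = grid if axis == "x" else transpose(grid)
    step = mean_step_x(g)
    return seam_jump_x(g) / step if step else 0.0


# A horizontal ramp: identical rows, values increasing from left to right.
ramp = [[0.0, 1.0, 2.0, 3.0] for _ in range(3)]
ratio_x = seam_ratio(ramp, "x")  # large: the right edge does not meet the left edge
ratio_y = seam_ratio(ramp, "y")  # zero: the rows are identical, so y wraps cleanly
print(ratio_x, ratio_y)
```

A ratio close to 1 (or below) suggests the seam is no more visible than any other pixel boundary; a much larger value points to a visible seam on that axis.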

alexisrolland commented 7 months ago

After further investigation, the problem in my previous message was not because of img2img, but rather because I was using an SDXL model. The previous code provided in this thread was for SD1.5 and SD2.1.
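Since the difference between the SD1.5/SD2.1 and SDXL variants in this thread comes down to which sub-models get patched (SDXL pipelines also expose `text_encoder_2`), one possible way to pick the targets is to probe the pipeline for each attribute. This is a minimal sketch only: `collect_targets` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins so that it runs without loading a model.

```python
from types import SimpleNamespace


def collect_targets(pipeline):
    """Return the sub-models to patch, skipping attributes that are absent or None."""
    names = ["vae", "text_encoder", "text_encoder_2", "unet"]
    candidates = (getattr(pipeline, name, None) for name in names)
    return [module for module in candidates if module is not None]


# Stand-in objects; a real pipeline would hold model instances here.
sd15_like = SimpleNamespace(vae="vae", text_encoder="te", unet="unet")
sdxl_like = SimpleNamespace(vae="vae", text_encoder="te", text_encoder_2="te2", unet="unet")

print(collect_targets(sd15_like))
print(collect_targets(sdxl_like))
```

With this approach the same function could serve both model families, at the cost of silently skipping a component whose attribute name differs from the ones listed.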

For everyone who would need to implement seamless tiling, here is the last version of the code which I have been using on my side. As of diffusers 0.26.3, this still works:

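import os
from typing import Optional

import torch
from diffusers.models.lora import LoRACompatibleConv

def seamless_tiling(pipeline, x_axis, y_axis):
    """Utility function used to configure the pipeline to generate seamless images."""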

    def asymmetric_conv2d_convforward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
        return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups)

    # Set padding mode
    x_mode = 'circular' if x_axis else 'constant'
    y_mode = 'circular' if y_axis else 'constant'

    # For SDXL models (BASE_MODEL is an environment variable specific to this setup)
    if os.environ.get('BASE_MODEL') == 'XL':
        targets = [pipeline.vae, pipeline.text_encoder, pipeline.text_encoder_2, pipeline.unet]

    # For SD1.5 and SD2.1 models
    else:
        targets = [pipeline.vae, pipeline.text_encoder, pipeline.unet]

    convolution_layers = []
    for target in targets:
        for module in target.modules():
            if isinstance(module, torch.nn.Conv2d):
                convolution_layers.append(module)

    for layer in convolution_layers:
        if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
            layer.lora_layer = lambda *x: 0

        layer._conv_forward = asymmetric_conv2d_convforward.__get__(layer, torch.nn.Conv2d)

    return pipeline

OrenGenieLabs commented 7 months ago

@alexisrolland I was unable to make it work for diffusers 0.26.3. I've used the model runwayml/stable-diffusion-v1-5 with the following code:

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True
)

pipe = seamless_tiling(pipe, True, True)

pipe.to("cuda")

generator = torch.Generator("cuda").manual_seed(42)

input_args = {
    "prompt": "apples and bananas",
    "generator": generator,
}
res = pipe(**input_args)

What have you done differently so it worked?

UPDATE: It seems that the upcoming version will have this fixed, as it will include PR #6031: Make LoRACompatibleConv padding_mode work

alexisrolland commented 7 months ago

@OrenGenieLabs, here is a complete code snippet that works in diffusers 0.26.3, but it's using SDXL... You can adapt it to SD1.5

import torch
from typing import Optional
from diffusers import StableDiffusionXLPipeline
from diffusers.models.lora import LoRACompatibleConv

def seamless_tiling(pipeline, x_axis, y_axis):
    def asymmetric_conv2d_convforward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
        return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups)

    # Set padding mode
    x_mode = 'circular' if x_axis else 'constant'
    y_mode = 'circular' if y_axis else 'constant'

    targets = [pipeline.vae, pipeline.text_encoder, pipeline.text_encoder_2, pipeline.unet]

    convolution_layers = []
    for target in targets:
        for module in target.modules():
            if isinstance(module, torch.nn.Conv2d):
                convolution_layers.append(module)

    for layer in convolution_layers:
        if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
            layer.lora_layer = lambda *x: 0

        layer._conv_forward = asymmetric_conv2d_convforward.__get__(layer, torch.nn.Conv2d)

    return pipeline

# Init pipeline
MODEL_PATH = '/your/model.safetensors'
pipeline = StableDiffusionXLPipeline.from_single_file(MODEL_PATH, torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()

# Generation settings
prompt = ["texture of a red brick wall"]
seed = 123456
generator = torch.Generator(device='cuda').manual_seed(seed)

# Set seamless tiling
pipeline = seamless_tiling(pipeline=pipeline, x_axis=True, y_axis=True)

# Generate and save image
image = pipeline(
    prompt=prompt,
    width=1024,
    height=1024,
    num_inference_steps=20,
    guidance_scale=7,
    num_images_per_prompt=1,
    generator=generator
).images[0]

# Reset seamless tiling
seamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False)

torch.cuda.empty_cache()
image.save('image.png')

DaLizardWizard commented 7 months ago

@alexisrolland Using your fix, I was unable to get it to work with any diffusers version > 0.21. I am not sure whether things are changing as we speak or whether it is because I use a refiner and tile asymmetrically.

Some minor modifications to make it work with my code:

def seamless_tiling(pipeline, x_axis, y_axis, is_refiner):
    def asymmetric_conv2d_convforward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
        return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups)

    # Set padding mode
    x_mode = 'circular' if x_axis else 'constant'
    y_mode = 'circular' if y_axis else 'constant'

    if is_refiner:
        targets = [pipeline.vae, pipeline.unet]
    else:
        targets = [pipeline.vae, pipeline.text_encoder, pipeline.unet]

    convolution_layers = []
    for index, target in enumerate(targets):
        for module in target.modules():
            if isinstance(module, torch.nn.Conv2d):
                convolution_layers.append(module)

    for layer in convolution_layers:
        if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
            layer.lora_layer = lambda *x: 0

        layer._conv_forward = asymmetric_conv2d_convforward.__get__(layer, torch.nn.Conv2d)

    return pipeline

base = seamless_tiling(base, True, False, False)
print("base adjusted to asymmetric tiling")
refiner = seamless_tiling(refiner, True, False, True)
print("refiner adjusted to asymmetric tiling")

OrenGenieLabs commented 7 months ago

@DaLizardWizard From what I've seen, until the upcoming diffusers update you must implement the solution of PR #6031 yourself; that is the only way to achieve seamless tiling.

This means changing the forward method of the LoRACompatibleConv convolution layers. Example code for symmetric tiling in both x and y axes:

def forward_symmetric(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    hidden_states = torch.nn.functional.pad(hidden_states, self._reversed_padding_repeated_twice, mode='circular')
    padding = torch.nn.modules.utils._pair(0)

    original_outputs = torch.nn.functional.conv2d(hidden_states, self.weight, self.bias, self.stride, padding, self.dilation, self.groups)

    if self.lora_layer is None:
        return original_outputs
    else:
        return original_outputs + (scale * self.lora_layer(hidden_states))

Then, replace the line:

layer._conv_forward = asymmetric_conv2d_convforward.__get__(layer, torch.nn.Conv2d)

with (first line stays the same, but as mentioned we change the forward function):

layer._conv_forward = asymmetric_conv2d_convforward.__get__(layer, torch.nn.Conv2d)
if isinstance(layer, LoRACompatibleConv):
    layer.forward = forward_symmetric.__get__(layer, LoRACompatibleConv)
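The `__get__` call used in these snippets binds a plain function to a single layer instance, which then shadows the method defined on the class. A minimal sketch of that mechanism follows, using a stand-in `Layer` class (not a diffusers or PyTorch class) and a made-up `forward_patched` function.

```python
class Layer:
    """Stand-in for a convolution layer; used only to illustrate method binding."""

    def forward(self, x):
        return f"regular({x})"


def forward_patched(self, x):
    """Replacement method; `self` is supplied by the binding below."""
    return f"patched({x})"


layer_a, layer_b = Layer(), Layer()

# Bind the function to one instance; this shadows the class-level method.
layer_a.forward = forward_patched.__get__(layer_a, Layer)
patched_result = layer_a.forward(1)     # uses forward_patched
untouched_result = layer_b.forward(1)   # other instances keep the class method

# Deleting the instance attribute makes the class-level method visible again.
del layer_a.forward
restored_result = layer_a.forward(1)

print(patched_result, untouched_result, restored_result)
```

On this stand-in, removing the instance attribute is enough to undo the patch. Whether that is preferable to assigning a replacement forward function on a real layer is a judgment call and is not tested here.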

And to remove the symmetric effect, simply restore the forward method with the following:

def forward_regular(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    original_outputs = torch.nn.functional.conv2d(hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    if self.lora_layer is None:
        return original_outputs
    else:
        return original_outputs + (scale * self.lora_layer(hidden_states))

yiyixuxu commented 7 months ago

cc @stevhliu @sayakpaul this seems like a pretty popular feature - maybe we should collect it as community workflow examples?

alexisrolland commented 7 months ago

@OrenGenieLabs @DaLizardWizard the code I provided above works with diffusers 0.26.3. I'll try to provide the adapted code for SD1.5 if that's what you're after.

Also, saying that "it does not work" without giving the error message you encounter or the complete code does not help us to help you ;)

alexisrolland commented 7 months ago

@OrenGenieLabs @DaLizardWizard here is the adapted version of the previous code I shared. This one is for SD1.5, using diffusers 0.26.3 and it works

import torch
from typing import Optional
from diffusers import StableDiffusionPipeline
from diffusers.models.lora import LoRACompatibleConv

def seamless_tiling(pipeline, x_axis, y_axis):
    def asymmetric_conv2d_convforward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
        return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups)

    # Set padding mode
    x_mode = 'circular' if x_axis else 'constant'
    y_mode = 'circular' if y_axis else 'constant'

    targets = [pipeline.vae, pipeline.text_encoder, pipeline.unet]
    convolution_layers = []
    for target in targets:
        for module in target.modules():
            if isinstance(module, torch.nn.Conv2d):
                convolution_layers.append(module)

    for layer in convolution_layers:
        if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
            layer.lora_layer = lambda *x: 0

        layer._conv_forward = asymmetric_conv2d_convforward.__get__(layer, torch.nn.Conv2d)

    return pipeline

# Init pipeline
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True)
pipeline.enable_model_cpu_offload()

# Generation settings
prompt = ["texture of a red brick wall"]
seed = 123456
generator = torch.Generator(device='cuda').manual_seed(seed)

# Set seamless tiling
pipeline = seamless_tiling(pipeline=pipeline, x_axis=True, y_axis=True)

# Generate and save image
image = pipeline(
    prompt=prompt,
    width=512,
    height=512,
    num_inference_steps=20,
    guidance_scale=7,
    num_images_per_prompt=1,
    generator=generator
).images[0]

# Reset seamless tiling
seamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False)

torch.cuda.empty_cache()
image.save('image.png')

Generated image:

image

Tiled:

tiled

DaLizardWizard commented 7 months ago

@alexisrolland Yes, you are absolutely right, I should have been clearer. My tests with diffusers > 0.21 completed with the SDXL model without errors, but the results did not tile: the output images looked like normal SDXL images.

I will post the entire code for completeness, as it also works with a LoRA for me.

madaror commented 7 months ago

@alexisrolland I tried your SD1.5 script with the same diffusers version as you, but the brick wall does not tile like the result you posted.

dishanil commented 6 months ago

@alexisrolland I ran your exact SD1.5 script with the same version of diffusers and used PIL to tile the generated image. I do not observe seamless tiling. However, if I only change the diffusers version to 0.21.4, the generated pattern stitches together seamlessly.
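For reference, the tiling step used to inspect the result is just a repetition of the image along both axes. The sketch below shows the idea in plain Python on a nested list of pixel values; `tile` is a hypothetical helper, and with Pillow the equivalent would be to paste the generated image at each offset of a larger canvas.

```python
def tile(grid, nx=2, ny=2):
    """Repeat a 2D grid nx times along x and ny times along y."""
    return [row * nx for _ in range(ny) for row in grid]


# A 2x2 "image" repeated into a 4x4 preview.
image = [[1, 2],
         [3, 4]]
preview = tile(image)
for row in preview:
    print(row)
```

In the preview, the seams sit where the last column meets the first column and where the last row meets the first row, so those are the places to look for discontinuities.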

dishanil commented 6 months ago

However, if I use @OrenGenieLabs's forward_symmetric for the LoRA layers, the solution works perfectly for diffusers versions >=0.22 as well. Thank you!

sayakpaul commented 6 months ago

@yiyixuxu does it make sense to have an enable_seamless_tiling() method on the pipelines? WDYT?

yiyixuxu commented 6 months ago

@sayakpaul I don't think so - but it can be a community pipeline though

asomoza commented 6 months ago

My honest opinion is that it should, this is a feature that helps people that use SD to generate textures for 3D software, all major apps have this as an "enable option" or as an extension and is a must if you want diffusers to be used in the professional space.

You can just search Google for "seamless texture pattern stable diffusion" and you'll find a lot of articles, tutorials and videos about this.

If you want a real use case example, I can do one with Blender.

alexisrolland commented 6 months ago

@dishanil @OrenGenieLabs @DaLizardWizard I have no idea why it does not work for you... I re(re)tested the code I provided above and it works seamlessly for me (pun intended ;))

Here is my diffusers config:

- `diffusers` version: 0.26.3
- Platform: Linux-5.15.133.1-microsoft-standard-WSL2-x86_64-with-glibc2.36
- Python version: 3.10.13
- PyTorch version (GPU?): 2.1.2+cu121 (True)
- Huggingface_hub version: 0.20.3
- Transformers version: 4.36.2
- Accelerate version: 0.26.1
- xFormers version: 0.0.23.post1

Here is the code which I copy/pasted in my Python interpreter... I just removed the comments and blank lines:

import torch
from typing import Optional
from diffusers import StableDiffusionPipeline
from diffusers.models.lora import LoRACompatibleConv

def seamless_tiling(pipeline, x_axis, y_axis):
    def asymmetric_conv2d_convforward(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
        self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
        working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode)
        working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode)
        return torch.nn.functional.conv2d(working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups)
    x_mode = 'circular' if x_axis else 'constant'
    y_mode = 'circular' if y_axis else 'constant'
    targets = [pipeline.vae, pipeline.text_encoder, pipeline.unet]
    convolution_layers = []
    for target in targets:
        for module in target.modules():
            if isinstance(module, torch.nn.Conv2d):
                convolution_layers.append(module)
    for layer in convolution_layers:
        if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None:
            layer.lora_layer = lambda *x: 0
        layer._conv_forward = asymmetric_conv2d_convforward.__get__(layer, torch.nn.Conv2d)
    return pipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True)
pipeline.enable_model_cpu_offload()
prompt = ["texture of a red brick wall"]
seed = 123456
generator = torch.Generator(device='cuda').manual_seed(seed)

pipeline = seamless_tiling(pipeline=pipeline, x_axis=True, y_axis=True)
image = pipeline(
    prompt=prompt,
    width=512,
    height=512,
    num_inference_steps=20,
    guidance_scale=7,
    num_images_per_prompt=1,
    generator=generator
).images[0]
seamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False)

torch.cuda.empty_cache()
image.save('image.png')

Here is the image generated:

image

It tiles perfectly:

tiled

asomoza commented 6 months ago

Since I was in this topic, I tested it with SDXL and it worked for me too:

| Generated | Tiled |
| --- | --- |
| [image: 20240313003235_573631814] | [image: wall] |

yiyixuxu commented 6 months ago

@asomoza

> My honest opinion is that it should, this is a feature that helps people that use SD to generate textures for 3D software, all major apps have this as an "enable option" or as an extension and is a must if you want diffusers to be used in the professional space.

If we add it as a community pipeline, we can apply it with the from_pipe API, which I'm working on in https://github.com/huggingface/diffusers/pull/7241 and which is intended exactly for use cases like this:

pipe_seamless = DiffusionPipeline.from_pipe(pipe_sd, custom_pipeline="....")
...

It is slightly less convenient to use than a built-in feature (but only slightly, IMO); however, it is a lot more flexible, and we do not have to worry about overwhelming our pipelines.

asomoza commented 6 months ago

@yiyixuxu

I didn't know that PR existed, very nice!

Even so, if you don't want to add this code to the SD pipelines, I think it would be better to just write it up as a technique in the new section, since it is only about 20 lines of code and would not need to be maintained or updated, which is a known issue with community pipelines.

cmdr2 commented 6 months ago

My 2 cents (maintainer of Easy Diffusion) - I'd also request that tiling be included in the main code, because tiling is a feature present in pretty much every Stable Diffusion software (including Easy Diffusion). InvokeAI has it, auto1111 has it, and it's a very common use-case.

I genuinely understand that it's a tricky balance between keeping the project lean vs empowering users. But in this particular case, given the context, I believe it makes sense to tilt towards including it. But that's obviously your call.

It's a regular task for me to fix our tiling code every few releases of diffusers, because something changed that broke tiling. I don't see that changing with making it a community pipeline. People keep patching the code in this ticket, and people will keep patching the code in a community pipeline. But the bottom line is that tiling will keep breaking every few releases (because it isn't a core feature with automated tests).

So to me it doesn't really matter whether the code is here in this ticket, or in a community pipeline. If it isn't in the core, it'll very likely continue breaking every few releases. And given the number of people maintaining the code in this ticket for so long (to keep tiling updated with diffusers), I'd say it signals the utility of this feature.

I'm not sure tiling is a different pipeline in the first place. It's an option for an image, using the same pipeline. Regardless of whether we're doing SDXL or Img2Img or Inpaint etc. But anyway, that's semantics. :)

yiyixuxu commented 6 months ago

@asomoza @cmdr2 thanks for your feedback! I hear y'all

I just realized that we merged this PR: https://github.com/huggingface/diffusers/pull/6031. I think with this, we don't have to patch it this way anymore?

I think we can add it to StableDiffusionMixin (https://github.com/huggingface/diffusers/blob/4974b84564d25bd4b5c594db4e04cb885cc0a9ed/src/diffusers/pipelines/pipeline_utils.py#L1654) and make it automatically available to all pipelines in SD family - this way we do not need to add any code to the pipelines

yiyixuxu commented 6 months ago

hey sorry guys

I looked into this and really can't come up with a clean way to support this feature from diffusers

Also I now agree it does not make too much sense to make community pipelines for it since it should be a feature every single SD pipeline can use; so I added it to the "community script" section in the community folder for now https://github.com/huggingface/diffusers/pull/7358. This will also most likely be added to our docs later

cmdr2 commented 6 months ago

Thanks @yiyixuxu ! That sounds fair enough, thanks for trying!

Just curious, what's the complication with adding a set_seamless_tiling() function to StableDiffusionMixin, which accepts the tiling type as the argument? Not challenging, just curious to understand the complication. :) Thanks!

yiyixuxu commented 6 months ago

@cmdr2

I think we would be able to support it if it's just to modify the padding_mode attribute - but patching the forward method is just way too hacky and it's better to be done outside of diffusers
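For readers wondering why the padding mode is what matters here: with circular padding, a convolution treats the signal as if it wrapped around, so cyclically shifting the input just shifts the output, and the borders behave like any interior position. With zero padding, the borders see artificial zeros and that property is lost. The following 1D sketch in plain Python illustrates this; `pad`, `conv1d` and `roll` are written for this example only and do not correspond to any PyTorch or diffusers function.

```python
def pad(signal, n, mode):
    """Pad a 1D list with n values per side, wrapping around or using zeros."""
    if n == 0:
        return list(signal)
    if mode == "circular":
        return signal[-n:] + signal + signal[:n]
    return [0.0] * n + signal + [0.0] * n


def conv1d(signal, kernel, mode):
    """Same-size 1D correlation with an odd-length kernel."""
    n = len(kernel) // 2
    padded = pad(signal, n, mode)
    return [
        sum(k * padded[i + j] for j, k in enumerate(kernel))
        for i in range(len(signal))
    ]


def roll(signal, shift):
    """Cyclically shift a list to the right by `shift` positions."""
    shift %= len(signal)
    return signal[-shift:] + signal[:-shift] if shift else list(signal)


signal = [1.0, 2.0, 4.0, 8.0, 16.0]
kernel = [1.0, 0.0, -1.0]

for mode in ("circular", "zeros"):
    shift_then_conv = conv1d(roll(signal, 2), kernel, mode)
    conv_then_shift = roll(conv1d(signal, kernel, mode), 2)
    print(mode, shift_then_conv == conv_then_shift)
```

Only the circular variant commutes with the shift, which is the 1D analogue of an image whose left edge continues into its right edge.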

cmdr2 commented 6 months ago

@yiyixuxu Yeah, that makes sense. I don't have a good suggestion either, still thinking about it. Maybe a wrapper class (like the LoRA Conv class), that is applied when seamless-tiling is enabled, and removed when disabled? Not sure yet whether this will play well with the LoRA Conv class, and whether it'll work well with torch.compile. @sayakpaul - I was wondering if you had a suggestion regarding torch.compile? thanks

I agree that hacks should be avoided, so brainstorming about ways to make this happen without bad code.

Because this feature is really useful, especially when it comes to 3D models and game-dev. For e.g. here's a tweet by Adobe Substance 3D (from yesterday), which shows its ability to generate seamless textures using prompts - https://twitter.com/Substance3D/status/1769778370227683514
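On the idea of something that is applied when seamless tiling is enabled and removed when it is disabled, one lightweight shape this could take is a context manager that swaps `padding_mode` and restores the previous values afterwards. This is a sketch under assumptions: `temporary_padding_mode` is a hypothetical helper, the layers below are `SimpleNamespace` stand-ins rather than real convolution modules, and whether a real layer honours `padding_mode` at call time depends on the diffusers version (see PR #6031 mentioned earlier in this thread).

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporary_padding_mode(layers, mode="circular"):
    """Set `padding_mode` on each layer, then restore the previous values on exit."""
    layers = list(layers)
    previous = [layer.padding_mode for layer in layers]
    for layer in layers:
        layer.padding_mode = mode
    try:
        yield layers
    finally:
        # Restore even if generation raised an exception.
        for layer, old in zip(layers, previous):
            layer.padding_mode = old


# Stand-in layers so the sketch runs without PyTorch.
layers = [SimpleNamespace(padding_mode="zeros") for _ in range(3)]

with temporary_padding_mode(layers):
    inside = [layer.padding_mode for layer in layers]  # a pipeline call would go here

after = [layer.padding_mode for layer in layers]
print(inside, after)
```

The restore step in `finally` replaces the separate "reset seamless tiling" call used in the scripts above, so a failed generation would not leave the layers in tiling mode.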

sayakpaul commented 6 months ago

> Not sure yet whether this will play well with the LoRA Conv class, and whether it'll work well with torch.compile. @sayakpaul - I was wondering if you had a suggestion regarding torch.compile? thanks

Should work as expected. If not, please let us know. Would be more than happy to investigate :)

samiede commented 5 months ago

I've recently had a lot of issues with SDXL and tiling: images come out wonky or show very strong artifacts. I used the snippets above to generate some examples. First, images generated with SD1.5 using the above technique: [image (1)] [image (4)]. And with the same seed, from SDXL: [image (2)] [image (3)]

Here's another example: [image (6)]

Is this expected behavior?

yiyixuxu commented 5 months ago

cc @asomoza here

asomoza commented 5 months ago

It probably depends on the model and the prompt. SDXL responds better to a good prompt, so if you get bad results you'll need to write a better prompt (see the ceramic floor):

| prompt | generated | tiled |
| --- | --- | --- |
| brick wall | [image: inpainting_20240503200135_52524990_52524990] | [image: inpainting_20240503200135_tiled] |
| plywood | [image: seamless_20240503200548_3245576796_3245576796] | [image: seamless_20240503200548_tiled] |
| ceramic floor | [image: seamless_20240503200758_1682921198_1682921198] | [image: seamless_20240503200758_tiled] |
| ceramic floor seamless uniform pattern | [image: seamless_20240503201014_2236064518_2236064518] | [image: seamless_20240503201014_tiled] |

samiede commented 5 months ago

I can get good-looking examples from SDXL with a lot of experimentation, but one big issue we are seeing is that lower resolutions (e.g. 512x512 here) result in heavily degraded images full of artifacts. They always look similar, and you can see the same patterns emerge in the last example I posted. Here are the examples with artifacts:

image (16) image (20)

asomoza commented 5 months ago

SDXL wasn't trained on 512x512 images or even on 768x768, so the results will almost always be bad. You'll need to use 1024x1024 or one of the other resolutions on which SDXL was trained.

samiede commented 5 months ago

Yes, I am aware of that. What I'm wondering is why the output distribution seems to be thrown off so much more strongly when tiling is enabled. For example, this is a 768x768 generation from SDXL 0.9 without tiling vs. with tiling: [image: Default_brick_wall_0]

[image]

Admittedly, 512x512 doesn't produce any useful results with or without tiling

asomoza commented 5 months ago

Yeah, I don't know the "true" answer to that, but you'll probably get the same answer I gave you anywhere you ask; few people will spend time and resources on understanding why a model doesn't work well when it is used in some specific way outside of its domain.

But in the meantime, we can safely assume that this is not a diffusers issue, maybe you can try asking in the Stability AI repo since they trained the model and maybe they did some experiments with it about this.

Edit: I forgot about this, if you need to work with lower resolutions, maybe this will work: https://github.com/bytedance/res-adapter

samiede commented 5 months ago

Yes, I agree, it's definitely not a diffusers issue. Sorry for hijacking the thread, I found it on my quest to understand :)

Thanks for the link, I'll check the adapter out!

chirag4798 commented 2 months ago

I also had to replace padding_mode for ConvTranspose2d layers; here's what worked for me:

import torch
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline

def flatten(model: torch.nn.Module):
    """
    Recursively flattens the model to retrieve all layers.
    """
    children = list(model.children())
    flattened = []

    if children == []:
        return model

    for child in children:
        try:
            flattened.extend(flatten(child))
        except TypeError:
            flattened.append(flatten(child))
    return flattened

def seamless_tiling(pipeline):
    """
    Enables seamless tiling for specific layers in the pipeline.
    """
    targets = [pipeline.vae, pipeline.text_encoder, pipeline.unet]

    if hasattr(pipeline, "text_encoder_2"):
        targets.append(pipeline.text_encoder_2)
    if getattr(pipeline, "image_encoder", None) is not None:
        targets.append(pipeline.image_encoder)

    layers = [
        layer
        for target in targets
        for layer in flatten(target)
        if isinstance(layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d))
    ]

    for layer in layers:
        layer.padding_mode = "circular"

def main(model_id="SG161222/RealVisXL_V4.0"):
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16, use_safetensors=True
    )

    prompt = "A seamless pattern showcasing a nature inspired graphic print with flowers and vines"
    pipeline.enable_model_cpu_offload()

    seamless_tiling(pipeline=pipeline)
    image = pipeline(
        prompt=prompt,
        width=1024,
        height=1024,
        num_inference_steps=40,
        guidance_scale=3,
        num_images_per_prompt=1,
    ).images[0]

    torch.cuda.empty_cache()
    image.save("image.png")

if __name__ == "__main__":
    main()

Output:

image

Tiled Output:

image