huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Unexpected results with controlnet_reference community pipeline #3566

Closed reimager closed 1 year ago

reimager commented 1 year ago

Describe the bug

I can't seem to get the expected results from the new controlnet_reference pipeline. I get subpar results with style_fidelity=0.0, and with anything greater than 0.0 I get weird/collapsed images. I believe there are a few issues and have tried many things, but to keep it simple, below is a test case comparing sd-webui-controlnet 1.1.191 and the diffusers controlnet_reference pipeline, configured as similarly as I could get them.

The test case:


Input Images:
Reference Image: cat-striped-768x768.jpg
Depth Image: cat-orange-768x768.jpg
HED Image: cat-orange-768x768.jpg

A1111 + sd-webui-controlnet Config:
ControlNet 1.1.191
Model: Realistic Vision 2.0 (results similar with runwayml 1.5 if that is preferred)
Prompt: a cat
Width: 768
Height: 768
Steps: 40
Scheduler: euler-a
Scale: 7.0
Seed: 0
ControlNet_0: reference_only, Image: cat-striped-768x768.jpg, Weight: 1.0
ControlNet_1: depth_midas, Image: cat-orange-768x768.jpg, Weight: 1.0
ControlNet_2: softedge_hed, Image: cat-orange-768x768.jpg, Weight: 1.0

A1111 + sd-webui-controlnet results: [image: controlnet-results]

The diffusers config is in the script below; it is the same config as used in A1111, or as close as I could get.

diffusers results: [image: diffusers-results]

@okotaku

Reproduction

import torch
from PIL import Image
from pytorch_lightning import seed_everything
from diffusers.utils import load_image
from diffusers import (
    DiffusionPipeline,
    ControlNetModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
    MultiControlNetModel,
)
from diffusers.schedulers import (
    EulerAncestralDiscreteScheduler,
)

import transformers
from controlnet_aux import HEDdetector
import numpy as np

image_ref = load_image("http://metaloft.com/images/cat-striped-768x768.jpg").convert("RGBA")
image_con = load_image("http://metaloft.com/images/cat-orange-768x768.jpg").convert("RGB")

controlnet_hed = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_softedge", torch_dtype=torch.float16)
controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16)  # ControlNet v1.1 depth model
hed_model = HEDdetector.from_pretrained('lllyasviel/ControlNet')
depth_model = transformers.pipeline(task="depth-estimation", model="Intel/dpt-large")

image_depth = depth_model(image_con)['depth']
image_depth = np.array(image_depth)
image_depth = image_depth[:, :, None]
image_depth = np.concatenate([image_depth, image_depth, image_depth], axis=2)
image_depth = Image.fromarray(image_depth)
image_depth = image_depth.resize((768, 768))
#image_depth.save('depth.png')

image_hed = hed_model(image_con)
image_hed = image_hed.resize((768, 768))
#image_hed.save('hed.png')

model = "SG161222/Realistic_Vision_V2.0"
#model = "runwayml/stable-diffusion-v1-5"

scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model, subfolder="scheduler", use_karras_sigmas=True)
pipe = DiffusionPipeline.from_pretrained(model,
                                         custom_pipeline="stable_diffusion_controlnet_reference",
                                         torch_dtype=torch.float16,
                                         controlnet=MultiControlNetModel([controlnet_depth, controlnet_hed]),
                                         revision="main",
                                         scheduler=scheduler)
pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()

for i, style_fidelity in enumerate([0.0, 0.5, 1.0]):
    seed_everything(0)
    result = pipe(prompt=["a cat"], image=[image_depth, image_hed],
                  width=768, height=768, guidance_scale=7.0, num_inference_steps=40,
                  ref_image=image_ref, reference_attn=True, reference_adain=False, style_fidelity=style_fidelity)
    result.images[0].save(f"output-{i}.png")

summary = Image.new("RGB", (768*3, 768))
for i in range(3):
    summary.paste(Image.open(f"output-{i}.png"), (768*i,0))
summary.save("summary.png")

Logs

No response

System Info

okotaku commented 1 year ago

@reimager Thank you for sharing the details! Running inference in my web UI, I get results similar to those from diffusers. It's strange.

[Screenshot 2023-05-26 16:53:21]

ControlNet 1.1.194

okotaku commented 1 year ago

I found that certain ControlNet settings cause these results.

depth only

[Screenshot 2023-05-26 17:42:20]

depth and soft edge

[Screenshot 2023-05-26 17:41:34]

reimager commented 1 year ago

Interesting. What version of sd-webui-controlnet do you have? I was originally thinking we just had some newer updates to port over to this pipeline.

Here are my very different results with your settings (same seed, too):

[screenshot 2023-05-26-073353]

And now with "ControlNet is more important":

[screenshot 2023-05-26-073518]

ControlNet v1.1.191

[screenshot 2023-05-26-073650]

Things are definitely better with depth only, but you still get the darkness with style_fidelity=1.0. The difference (for me) is more pronounced the more ControlNets you have, which is why I used two in the original example.

okotaku commented 1 year ago

It seems there is a version mismatch in the ControlNet models being used. While I am using ControlNet v1.1, you are using an older-generation model.

https://huggingface.co/lllyasviel/ControlNet-v1-1/tree/main https://huggingface.co/lllyasviel/ControlNet/tree/main/models

Based on your testing, it appears that the multi-ControlNet combination works well with the previous-generation models. If you wish to achieve a similar result with diffusers, please try the following code:

import torch
from PIL import Image
from pytorch_lightning import seed_everything
from diffusers.utils import load_image
from diffusers import (
    DiffusionPipeline,
    ControlNetModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
    MultiControlNetModel,
)
from diffusers.schedulers import (
    EulerAncestralDiscreteScheduler, UniPCMultistepScheduler
)

import transformers
from controlnet_aux import HEDdetector
import numpy as np

image_ref = load_image("http://metaloft.com/images/cat-striped-768x768.jpg").convert("RGBA")
image_con = load_image("http://metaloft.com/images/cat-orange-768x768.jpg").convert("RGB")

controlnet_hed = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-hed", torch_dtype=torch.float16)
controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16)
hed_model = HEDdetector.from_pretrained('lllyasviel/ControlNet')
depth_model = transformers.pipeline(task="depth-estimation", model="Intel/dpt-large")

image_depth = depth_model(image_con)['depth']
image_depth = np.array(image_depth)
image_depth = image_depth[:, :, None]
image_depth = np.concatenate([image_depth, image_depth, image_depth], axis=2)
image_depth = Image.fromarray(image_depth)
image_depth = image_depth.resize((768, 768))
#image_depth.save('depth.png')

image_hed = hed_model(image_con)
image_hed = image_hed.resize((768, 768))
#image_hed.save('hed.png')

model = "SG161222/Realistic_Vision_V2.0"
#model = "runwayml/stable-diffusion-v1-5"

scheduler = UniPCMultistepScheduler.from_pretrained(model, subfolder="scheduler", use_karras_sigmas=True)
pipe = DiffusionPipeline.from_pretrained(model,
                                        custom_pipeline="stable_diffusion_controlnet_reference",
                                        torch_dtype=torch.float16,
                                        controlnet=MultiControlNetModel([controlnet_depth, controlnet_hed]),
                                        revision="main",
                                        scheduler=scheduler)
pipe.to("cuda")

for i, style_fidelity in enumerate([0.0, 0.5, 1.0]):
    seed_everything(0)
    result = pipe(prompt=["a cat"], image=[image_depth, image_hed],
                width=768, height=768, guidance_scale=7.0, num_inference_steps=20,
                ref_image=image_ref, reference_attn=True, reference_adain=False, style_fidelity=style_fidelity)
    result.images[0].save(f"output-{i}.png")

summary = Image.new("RGB", (768*3, 768))
for i in range(3):
    summary.paste(Image.open(f"output-{i}.png"), (768*i,0))
summary.save("summary.png")

[image: summary.png]

There may also be issues with ControlNet v1.1 in the web UI-side implementation.

reimager commented 1 year ago

Great eye! I didn't notice that. Yes indeed, the old models don't have the same darkness issue with style_fidelity. When I switch to the old models in diffusers I get the same results as you.

Still pretty different from sd-webui-controlnet, but at least we now get the same results (see the sketch below for the swap I mean).
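A minimal sketch of the checkpoint swap, assuming the rest of the repro script above stays the same (the previous-generation repo IDs are the ones from your script):

import torch
from diffusers import ControlNetModel

# previous-generation ("ControlNet 1.0") checkpoints instead of the v1.1 ones
controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16)
controlnet_hed = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-hed", torch_dtype=torch.float16)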

okotaku commented 1 year ago

@reimager

webui

Based on some experiments, I got poor results even with the previous-generation ControlNet.

[Screenshots 2023-05-27 10:20:48, 10:22:09, 10:23:43, 10:27:21]

diffusers

However, even compared to these results, my implementation of reference AdaIN seems to produce very poor results. I will investigate the possible causes.

[image: summary.png]

reimager commented 1 year ago

I will continue to play with it also. For the most part I'm just trying to replicate the basic reference_only example from A1111. I get amazing results with A1111. It is definitely the first cat, in the pose of the second cat.

I also tried the new ControlNet models just to be sure. I haven't been able to replicate these results at all in diffusers with any settings.

ControlNet 1.0 models: [screenshot 2023-05-26-205555]

ControlNet 1.1 models (this result really blows my mind):

[screenshot 2023-05-26-211012]

okotaku commented 1 year ago

@reimager

I have found that when using the "ControlNet is more important" option in the web UI, I indeed obtain excellent output. This seems to correspond to guess_mode=True in diffusers, but diffusers does not yield results similar to the web UI. I will continue to investigate this further.

Specifically, comparing guess_mode=True in diffusers against the "ControlNet is more important" option in the web UI, using only depth and soft edge without a reference, the diffusers results are clearly incorrect.

import cv2
import torch
import numpy as np
from PIL import Image
from diffusers.models import ControlNetModel
from diffusers import UniPCMultistepScheduler, EulerAncestralDiscreteScheduler
from diffusers.utils import load_image
import transformers
from controlnet_aux import HEDdetector, PidiNetDetector
# StableDiffusionControlNetReferencePipeline is the community pipeline class
# (examples/community/stable_diffusion_controlnet_reference.py); import it from wherever that file lives locally
from stable_diffusion_controlnet_reference import StableDiffusionControlNetReferencePipeline

input_image = load_image("http://metaloft.com/images/cat-striped-768x768.jpg").convert("RGB")
image_con = load_image("http://metaloft.com/images/cat-orange-768x768.jpg").convert("RGB")

# get canny image
"""image = cv2.Canny(np.array(image_con), 150, 230)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
canny_image.save('c1.png')"""

#processor = HEDdetector.from_pretrained('lllyasviel/ControlNet')
processor = PidiNetDetector.from_pretrained('lllyasviel/Annotators')
image_hed = processor(image_con)
image_hed.save('c1.png')

depth_model = transformers.pipeline(task="depth-estimation", model="Intel/dpt-large")
image_depth = depth_model(image_con)['depth']
image_depth = np.array(image_depth)
image_depth = image_depth[:, :, None]
image_depth = np.concatenate([image_depth, image_depth, image_depth], axis=2)
image_depth = Image.fromarray(image_depth)
image_depth.save('c2.png')

controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16)
controlnet_hed = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_softedge", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
       "SG161222/Realistic_Vision_V2.0",
       controlnet=[controlnet_depth, controlnet_hed],
       safety_checker=None,
       torch_dtype=torch.float16
       ).to('cuda:0')

pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

from pytorch_lightning import seed_everything
seed_everything(1606954448)
result_img = pipe(ref_image=input_image,
      prompt="a cat",
      image=[image_depth,image_hed],
      num_inference_steps=20,
      height=512,
      width=512,
      reference_attn=True,
      reference_adain=False,
      style_fidelity=1.0,
      guidance_scale=7.0,
      guess_mode=True).images[0]
result_img.save('tmp.png')

[image: tmp.png]

import torch
import numpy as np
from PIL import Image
from diffusers.models import ControlNetModel
from diffusers import UniPCMultistepScheduler, EulerAncestralDiscreteScheduler, StableDiffusionControlNetPipeline
from diffusers.utils import load_image
import transformers
from controlnet_aux import HEDdetector, PidiNetDetector

input_image = load_image("http://metaloft.com/images/cat-striped-768x768.jpg").convert("RGB")
image_con = load_image("http://metaloft.com/images/cat-orange-768x768.jpg").convert("RGB")

#processor = HEDdetector.from_pretrained('lllyasviel/ControlNet')
processor = PidiNetDetector.from_pretrained('lllyasviel/Annotators')
image_hed = processor(image_con).resize((512, 512))
image_hed.save('c1.png')

depth_model = transformers.pipeline(task="depth-estimation", model="Intel/dpt-large")
image_depth = depth_model(image_con)['depth']
image_depth = np.array(image_depth)
image_depth = image_depth[:, :, None]
image_depth = np.concatenate([image_depth, image_depth, image_depth], axis=2)
image_depth = Image.fromarray(image_depth).resize((512, 512))
image_depth.save('c2.png')

controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16)
controlnet_hed = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_softedge", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
       "SG161222/Realistic_Vision_V2.0",
       controlnet=[controlnet_depth, controlnet_hed],
       safety_checker=None,
       torch_dtype=torch.float16
       ).to('cuda:0')

pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

result_img = pipe(
      prompt="a cat",
      image=[image_depth,image_hed],
      num_inference_steps=20,
      height=512,
      width=512,
      guidance_scale=5.0,
      guess_mode=True
      ).images[0]
result_img.save('tmp.png')

[image: tmp.png]

webui

[Screenshot 2023-05-28 09:53:04]

okotaku commented 1 year ago

I found that EulerAncestralDiscreteScheduler doesn't work well with guess_mode=True. The following script works well.

[image: tmp.png]

import cv2
import torch
import numpy as np
from PIL import Image
from diffusers.models import ControlNetModel
from diffusers import UniPCMultistepScheduler, EulerAncestralDiscreteScheduler
from diffusers.utils import load_image
import transformers
from controlnet_aux import HEDdetector, PidiNetDetector
# StableDiffusionControlNetReferencePipeline is the community pipeline class
# (examples/community/stable_diffusion_controlnet_reference.py); import it from wherever that file lives locally
from stable_diffusion_controlnet_reference import StableDiffusionControlNetReferencePipeline

input_image = load_image("http://metaloft.com/images/cat-striped-768x768.jpg").convert("RGB")
image_con = load_image("http://metaloft.com/images/cat-orange-768x768.jpg").convert("RGB")

# get canny image
"""image = cv2.Canny(np.array(image_con), 150, 230)
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
canny_image = Image.fromarray(image)
canny_image.save('c1.png')"""

#processor = HEDdetector.from_pretrained('lllyasviel/ControlNet')
processor = PidiNetDetector.from_pretrained('lllyasviel/Annotators')
image_hed = processor(image_con)
image_hed.save('c1.png')

depth_model = transformers.pipeline(task="depth-estimation", model="Intel/dpt-large")
image_depth = depth_model(image_con)['depth']
image_depth = np.array(image_depth)
image_depth = image_depth[:, :, None]
image_depth = np.concatenate([image_depth, image_depth, image_depth], axis=2)
image_depth = Image.fromarray(image_depth)
image_depth.save('c2.png')

controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16)
controlnet_hed = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_softedge", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
       "SG161222/Realistic_Vision_V2.0",
       controlnet=[controlnet_depth, controlnet_hed],
       safety_checker=None,
       torch_dtype=torch.float16
       ).to('cuda:0')

#pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

from pytorch_lightning import seed_everything
seed_everything(1606954448)
result_img = pipe(ref_image=input_image,
      prompt="a cat",
      image=[image_depth,image_hed],
      #num_inference_steps=20,
      #height=512,
      #width=512,
      reference_attn=True,
      reference_adain=False,
      style_fidelity=1.0,
      #guidance_scale=7.0,
      guess_mode=True).images[0]
result_img.save('tmp.png')

okotaku commented 1 year ago

Is the following result with 'guess_mode=True' a bug?

Default scheduler

import torch
import numpy as np
from PIL import Image
from diffusers.models import ControlNetModel
from diffusers import UniPCMultistepScheduler, EulerAncestralDiscreteScheduler, StableDiffusionControlNetPipeline
from diffusers.utils import load_image
import transformers
from controlnet_aux import HEDdetector, PidiNetDetector

input_image = load_image("http://metaloft.com/images/cat-striped-768x768.jpg").convert("RGB")
image_con = load_image("http://metaloft.com/images/cat-orange-768x768.jpg").convert("RGB")

#processor = HEDdetector.from_pretrained('lllyasviel/ControlNet')
processor = PidiNetDetector.from_pretrained('lllyasviel/Annotators')
image_hed = processor(image_con).resize((512, 512))
image_hed.save('c1.png')

depth_model = transformers.pipeline(task="depth-estimation", model="Intel/dpt-large")
image_depth = depth_model(image_con)['depth']
image_depth = np.array(image_depth)
image_depth = image_depth[:, :, None]
image_depth = np.concatenate([image_depth, image_depth, image_depth], axis=2)
image_depth = Image.fromarray(image_depth).resize((512, 512))
image_depth.save('c2.png')

controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16)
controlnet_hed = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_softedge", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
       "SG161222/Realistic_Vision_V2.0",
       controlnet=[controlnet_depth, controlnet_hed],
       safety_checker=None,
       torch_dtype=torch.float16
       ).to('cuda:0')

result_img = pipe(
      prompt="a cat",
      image=[image_depth, image_hed],
      guess_mode=True
      ).images[0]
result_img.save('tmp.png')

[image: tmp.png]

UniPCMultistepScheduler

import torch
import numpy as np
from PIL import Image
from diffusers.models import ControlNetModel
from diffusers import UniPCMultistepScheduler, EulerAncestralDiscreteScheduler, StableDiffusionControlNetPipeline
from diffusers.utils import load_image
import transformers
from controlnet_aux import HEDdetector, PidiNetDetector

input_image = load_image("http://metaloft.com/images/cat-striped-768x768.jpg").convert("RGB")
image_con = load_image("http://metaloft.com/images/cat-orange-768x768.jpg").convert("RGB")

#processor = HEDdetector.from_pretrained('lllyasviel/ControlNet')
processor = PidiNetDetector.from_pretrained('lllyasviel/Annotators')
image_hed = processor(image_con).resize((512, 512))
image_hed.save('c1.png')

depth_model = transformers.pipeline(task="depth-estimation", model="Intel/dpt-large")
image_depth = depth_model(image_con)['depth']
image_depth = np.array(image_depth)
image_depth = image_depth[:, :, None]
image_depth = np.concatenate([image_depth, image_depth, image_depth], axis=2)
image_depth = Image.fromarray(image_depth).resize((512, 512))
image_depth.save('c2.png')

controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16)
controlnet_hed = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_softedge", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
       "SG161222/Realistic_Vision_V2.0",
       controlnet=[controlnet_depth, controlnet_hed],
       safety_checker=None,
       torch_dtype=torch.float16
       ).to('cuda:0')
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

result_img = pipe(
      prompt="a cat",
      image=[image_depth, image_hed],
      guess_mode=True,
      num_inference_steps=20,
      ).images[0]
result_img.save('tmp.png')

[image: tmp2.png]

EulerAncestralDiscreteScheduler

import torch
import numpy as np
from PIL import Image
from diffusers.models import ControlNetModel
from diffusers import UniPCMultistepScheduler, EulerAncestralDiscreteScheduler, StableDiffusionControlNetPipeline
from diffusers.utils import load_image
import transformers
from controlnet_aux import HEDdetector, PidiNetDetector

input_image = load_image("http://metaloft.com/images/cat-striped-768x768.jpg").convert("RGB")
image_con = load_image("http://metaloft.com/images/cat-orange-768x768.jpg").convert("RGB")

#processor = HEDdetector.from_pretrained('lllyasviel/ControlNet')
processor = PidiNetDetector.from_pretrained('lllyasviel/Annotators')
image_hed = processor(image_con).resize((512, 512))
image_hed.save('c1.png')

depth_model = transformers.pipeline(task="depth-estimation", model="Intel/dpt-large")
image_depth = depth_model(image_con)['depth']
image_depth = np.array(image_depth)
image_depth = image_depth[:, :, None]
image_depth = np.concatenate([image_depth, image_depth, image_depth], axis=2)
image_depth = Image.fromarray(image_depth).resize((512, 512))
image_depth.save('c2.png')

controlnet_depth = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth", torch_dtype=torch.float16)
controlnet_hed = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_softedge", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
       "SG161222/Realistic_Vision_V2.0",
       controlnet=[controlnet_depth, controlnet_hed],
       safety_checker=None,
       torch_dtype=torch.float16
       ).to('cuda:0')
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)

result_img = pipe(
      prompt="a cat",
      image=[image_depth, image_hed],
      guess_mode=True,
      num_inference_steps=20,
      ).images[0]
result_img.save('tmp.png')

[image: tmp3.png]

okotaku commented 1 year ago

@reimager

I have created a PR based on this issue.

https://github.com/huggingface/diffusers/pull/3589

reimager commented 1 year ago

Interesting. I was just using euler-a for the most apples-to-apples comparison. I usually use unipc, so I'll try that again.

Sorry, I assumed style_fidelity=1.0 was the equivalent of "ControlNet is more important" based on this comment:

            style_fidelity (`float`):
                style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important,
                elif style_fidelity=0.0, prompt more important, else balanced.

I will also try again, playing with guess_mode and controlnet_conditioning_scale instead; a rough sketch of what I plan to try is below.
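Roughly what I plan to try (a sketch only; it assumes the pipe, image_depth, image_hed, and image_ref objects from my repro script above, and the conditioning-scale values are arbitrary placeholders):

from diffusers import UniPCMultistepScheduler

# switch to UniPC and enable guess_mode with explicit per-ControlNet weights
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

result = pipe(prompt=["a cat"], image=[image_depth, image_hed],
              ref_image=image_ref, width=768, height=768,
              num_inference_steps=40, guidance_scale=7.0,
              reference_attn=True, reference_adain=False, style_fidelity=1.0,
              guess_mode=True,                           # webui "ControlNet is more important" for depth/softedge
              controlnet_conditioning_scale=[1.0, 1.0])  # one weight per ControlNet
result.images[0].save("guess-mode-test.png")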

okotaku commented 1 year ago

Sorry I assumed style_fidelity = 1.0 was the equivalent of 'controlnet more important' based on this comment:

Yes, you can see here: for the reference control, "ControlNet is more important" means style_fidelity=1.0.

https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/hook.py#L480-L481

For depth or other controls, "ControlNet is more important" means guess_mode=True. A sketch of the combined mapping follows the links below.

https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/hook.py#L410-L414 https://github.com/Mikubill/sd-webui-controlnet/blob/2e0dc37d222aaba355a71dac0eda4bb7ca54f05f/scripts/external_code.py#L203-L204
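In other words, a rough sketch of how the web UI option maps onto the diffusers reference pipeline arguments (illustrative only; it assumes the pipe and images from the scripts above):

# "ControlNet is more important" enabled on every unit in the web UI
result = pipe(prompt="a cat",
              ref_image=image_ref,             # reference_only unit
              image=[image_depth, image_hed],  # depth + softedge units
              reference_attn=True, reference_adain=False,
              style_fidelity=1.0,  # reference unit: "ControlNet is more important"
              guess_mode=True)     # depth/softedge units: "ControlNet is more important"
result.images[0].save("mapping-test.png")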

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.