lllyasviel / ControlNet

Let us control diffusion models!
Apache License 2.0

semantic pattern generation given a boundary #593

Closed (BigDevil82 closed this issue 7 months ago)

BigDevil82 commented 7 months ago

Hello, and thanks for the great work!

I'd like to apply ControlNet to generating a semantic room layout within a specified boundary, as illustrated in the following images:

[image]

The prompt for this is:

room layout; a symmetric architectural layout; gray lines; room space separation; residential building plan; 2D floor plan; 2D architectural layout; 2D room layout; 2D floor plan; gray lines representing walls; sample case 0

Because it is difficult to describe the content of each image, the prompts are nearly identical, differing only slightly from image to image.

The source image is the floor boundary of the room layout, and the target shows the wall positions in the building's plan view, drawn as gray lines. I have 240 image pairs in total.
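
For reference, the stock tutorial_dataset.py reads a prompt.json file with one JSON object per line, each holding `source`, `target`, and `prompt` keys. A minimal sketch of generating such a file for the 240 pairs, assuming hypothetical filenames `source/{i}.png` and `target/{i}.png` and assuming only the trailing `sample case {i}` of the prompt varies:

```python
import json

# Shared prompt text from this issue; varying only the case index is an assumption.
BASE_PROMPT = (
    "room layout; a symmetric architectural layout; gray lines; room space separation; "
    "residential building plan; 2D floor plan; 2D architectural layout; 2D room layout; "
    "2D floor plan; gray lines representing walls; sample case {i}"
)


def write_prompt_json(path, num_pairs):
    """Write one JSON object per line, the layout tutorial_dataset.py expects.

    The source/{i}.png and target/{i}.png names are hypothetical placeholders.
    """
    with open(path, "w") as f:
        for i in range(num_pairs):
            record = {
                "source": f"source/{i}.png",
                "target": f"target/{i}.png",
                "prompt": BASE_PROMPT.format(i=i),
            }
            f.write(json.dumps(record) + "\n")
```

Calling `write_prompt_json("./training/filled_cond/prompt.json", 240)` would produce one line per pair; note that the stock dataset class prepends its own data directory to the `source` and `target` paths.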

I trained ControlNet with the provided tutorial_train.py script for 300 epochs (about 6k steps), using the following configuration:
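
The 6k figure follows from the dataset size and batch size: with 240 pairs and `batch_size = 12` on a single GPU (`gpus=[1]` selects one device), each epoch is 20 optimizer steps, so 300 epochs come to 6,000 steps. A quick check, assuming no gradient accumulation:

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Optimizer steps per epoch on one device without gradient accumulation."""
    if drop_last:
        return num_samples // batch_size
    # The last partial batch still counts as a step when drop_last is False.
    return math.ceil(num_samples / batch_size)


per_epoch = steps_per_epoch(240, 12)
total_steps = per_epoch * 300
print(per_epoch, total_steps)  # 20 6000
```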

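```python
from share import *  # ControlNet's shared setup, as in tutorial_train.py
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset  # modified here to take a prompt-file path
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict

# Configs
resume_path = "./models/control_sd15_ini.ckpt"  # SD1.5 + initialized ControlNet (from tool_add_control.py)
# resume_path = None
prompt_path = "./training/filled_cond/prompt.json"
batch_size = 12
logger_freq = 300  # log sample images every 300 batches
learning_rate = 1e-6
sd_locked = True  # freeze the SD U-Net decoder; only the ControlNet branch is trained
only_mid_control = False

if __name__ == "__main__":
    # First use cpu to load models. Pytorch Lightning will automatically move it to GPUs.
    model = create_model("./models/cldm_v15.yaml").cpu()
    model.load_state_dict(load_state_dict(resume_path, location="cpu"))
    model.learning_rate = learning_rate
    model.sd_locked = sd_locked
    model.only_mid_control = only_mid_control

    # Misc
    dataset = MyDataset(prompt_path)
    dataloader = DataLoader(dataset, num_workers=4, batch_size=batch_size, shuffle=True)
    logger = ImageLogger(batch_frequency=logger_freq)
    # gpus=[1] trains on the single GPU with index 1
    trainer = pl.Trainer(gpus=[1], precision=32, callbacks=[logger], max_epochs=300)

    # Train!
    trainer.fit(model, dataloader)
```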
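
The stock `MyDataset` in tutorial_dataset.py takes no arguments, so `MyDataset(prompt_path)` above is presumably a modified version. A minimal sketch of such a variant, assuming it keeps the tutorial's conventions (source/hint scaled to [0, 1], target scaled to [-1, 1], returned under the `jpg`, `txt`, and `hint` keys). The class name `PromptJsonDataset` and the injected `load_image` callable are hypothetical; the tutorial itself loads images with `cv2.imread` and converts BGR to RGB.

```python
import json

import numpy as np


class PromptJsonDataset:
    """Hypothetical stand-in for a MyDataset that accepts a prompt-file path."""

    def __init__(self, prompt_path, load_image):
        # load_image: callable mapping a path to an HxWx3 uint8 RGB array (assumption).
        self.load_image = load_image
        with open(prompt_path) as f:
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Condition image (the boundary) normalized to [0, 1].
        source = self.load_image(item["source"]).astype(np.float32) / 255.0
        # Target image (the walls) normalized to [-1, 1].
        target = self.load_image(item["target"]).astype(np.float32) / 127.5 - 1.0
        return dict(jpg=target, txt=item["prompt"], hint=source)
```

A map-style object with `__len__` and `__getitem__` like this can be passed to `DataLoader` directly.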

The results show that the network has learned the general pattern of the training images, but it does not tie the output to the source: the generated layouts look plausible in style but ignore the given boundary condition.

[image]

In case the correspondence between the source and target images was not obvious enough, I changed the boundary color to red and added the boundary contour to the target images, like this:
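
The issue does not say how these modified images were produced. A minimal sketch of the idea, assuming the boundary is available as a boolean mask, the source is drawn on a white canvas, and the same red contour is painted into the target; the point is that the contour occupies exactly the same pixels in both images. The function name `make_pair` is hypothetical.

```python
import numpy as np

RED = np.array([255, 0, 0], dtype=np.uint8)


def make_pair(boundary_mask, target_rgb):
    """Draw the boundary in red on a new source image and overlay it on the target.

    boundary_mask: HxW bool array, True on boundary pixels (assumed input format).
    target_rgb:    HxWx3 uint8 image of the walls (left unmodified; a copy is returned).
    """
    h, w = boundary_mask.shape
    source = np.full((h, w, 3), 255, dtype=np.uint8)  # white canvas (assumption)
    source[boundary_mask] = RED
    target = target_rgb.copy()
    target[boundary_mask] = RED  # same pixels, same color as in the source
    return source, target
```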

[image]

But the model still cannot learn the condition:

[image]

I'm not sure what could be causing this. Is the dataset too small? Or are the prompts too similar to one another? I can provide additional information as needed. Any suggestions are welcome, thanks!

usmancheema89 commented 6 months ago

Were you able to figure out a solution to this problem?