OrangeSodahub / SceneCraft

[NeurIPS 2024] SceneCraft: Layout-Guided 3D Scene Generation
https://orangesodahub.github.io/SceneCraft/
MIT License
124 stars 10 forks

Question about the depth map #3

Closed · SYSUykLin closed this 2 weeks ago

SYSUykLin commented 2 weeks ago

Hello, great work!! I want to run the model on my own dataset, but I don't know the expected types of the segmentation map and depth map. For example, I have this seg map and depth map:

(depth and segmentation images attached)

What kind of transformation should I perform so that the depth map and segmentation map are correctly fed into your model and it runs properly? For example, should the depth map be a grayscale map, and should the seg map be one-hot or RGB like Hypersim? Thanks.

SYSUykLin commented 2 weeks ago

Hello, I assume the depth and segment maps are both [H, W], because I saw the default parameter condition_type=one_hot:

- The depth map is the normalized distance from the 3D point cloud to the camera. Shape: [H, W].
- The seg map contains the labels of the 3D point cloud in the NYU40 classes. Shape: [H, W].

The parameters in the code below are copied from scripts/generate_outputs.py:

- depth.npy is the distance from the 3D point cloud to the camera. Shape: [H, W].
- segmentation.jpg is the RGB semantic map. Shape: [H, W]. Shown as follows:

(segmentation image attached)

And I wrote the code:

# NOTE: imports and the helpers SCALE_H / depth_inv_norm_transforms / args are
# taken from scripts/generate_outputs.py, as in the original script.
depth_npy_path = "scripts/depth.npy"
segment_path = "scripts/segmentation.jpg"
controlnet_conditioning_scale = [3.5, 1.5]  # scales for the [seg, depth] ControlNets
base_model_path = "ckpts/stable-diffusion-2-1"
checkpoint_path = "controlnet_ckpts/layout_diffusion_hypersim_prompt_one_hot_multi_control_bs32_epoch24"
checkpoint_subfolder = "checkpoint-10900"
condition_type = "one_hot"

def generate_output(condition_images, controlnet, pipe, generator):
    # condition_images is [seg_images, depth_images] as returned by prepare_images();
    # unpack instead of relying on module-level globals.
    seg_images, depth_images = condition_images

    prompts = ["A good-looking kitchen"]
    extra_kwargs = dict()
    if pipe.unet.config.in_channels == 5:
        if "depths" not in extra_kwargs:
            scale_h = SCALE_H[args.dataset]
            depths = [transforms.Resize(scale_h, interpolation=transforms.InterpolationMode.NEAREST_EXACT)(depth) for depth in depth_images]
            extra_kwargs.update(depths=depths)

        depth_inv_norm_smalls = [depth_inv_norm_transforms(transforms.ToTensor()(depth)[0])
                                    for depth in extra_kwargs["depths"]]
        height, width = extra_kwargs["depths"][0].shape
        extra_kwargs.update(depth_inv_norm_smalls=depth_inv_norm_smalls, height=height, width=width)

    if isinstance(pipe, StableDiffusionControlNetPipeline):
        # pass both condition maps: [seg, depth]
        extra_kwargs.update(image=condition_images, condition_type=condition_type,
                            indice_to_embedding=None,
                            controlnet_conditioning_scale=controlnet_conditioning_scale,
                            no_depth_cond=False)

    guide_images: List[Image.Image] = pipe(prompts,
                                            num_inference_steps=20,
                                            latents=None,
                                            generator=generator,
                                            guidance_scale=7.5,
                                            **extra_kwargs).images
    return guide_images

"""
假设:depth map的格式应该是单纯的深度
     segment map的格式应是one hot
"""
def prepare_images(seg_path, depth_path):
    # Load the depth map and replace NaNs with the max depth.
    depth_map = np.load(depth_path)
    nan_mask = np.isnan(depth_map)
    depth_map[nan_mask] = np.nanmax(depth_map)
    max_depth = np.nanmax(depth_map)
    min_depth = np.nanmin(depth_map)
    # NOTE: this log-normalizes the depth; per the maintainer's reply below,
    # the model actually expects raw metric depth and normalizes internally.
    norm = colors.LogNorm(vmin=min_depth, vmax=max_depth)
    depth_map_norm = np.asarray(norm(depth_map), dtype=np.float32)
    # depth_map_norm = np.concatenate([np.array(depth_map_norm)[:, :, None]] * 3, axis=2).astype(np.uint8)
    depth_images = [Image.fromarray(depth_map_norm)]

    # Map each RGB color in the semantic image back to its label index.
    seg_map = cv2.imread(seg_path)
    seg_map = cv2.cvtColor(seg_map, cv2.COLOR_BGR2RGB)
    metadata_semantic_colors_hdf5_file = "scripts/metadata_semantic_colors.hdf5"
    with h5py.File(metadata_semantic_colors_hdf5_file, "r") as f:
        semantic_colors = f["dataset"][:]

    image_idx = np.zeros((seg_map.shape[0], seg_map.shape[1], 1), dtype=np.uint8)
    for i, color in enumerate(semantic_colors):
        # cast to a signed type to avoid uint8 wraparound when subtracting
        diff = np.mean(np.abs(seg_map.astype(np.int16) - color.astype(np.int16)), axis=-1)
        match_mask = (diff <= 10)  # tolerance for lossy JPEG colors
        image_idx[match_mask] = i

    image_idx = np.squeeze(image_idx)
    seg_images = [Image.fromarray(image_idx.astype(np.uint8))]

    # both are [H, W]
    return [seg_images, depth_images]

def load_ckpts():
    unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
        base_model_path, subfolder="unet", torch_dtype=torch.float16
    )

    if checkpoint_path is None:
        controlnet = None
    elif "multi_control" not in checkpoint_path:
        print("single controlnet...")
        controlnet = ControlNetModel.from_pretrained(
            checkpoint_path, subfolder=f"{checkpoint_subfolder.strip('/')}/controlnetmodel", torch_dtype=torch.float16)
    else:
        print("multiple controlnet...")
        controlnet: MultiControlNetModel = MultiControlNetModel.from_pretrained(
            checkpoint_path, subfolder=f"{checkpoint_subfolder.strip('/')}", torch_dtype=torch.float16
        )
        assert isinstance(controlnet_conditioning_scale, (list, tuple))

    if controlnet is not None:
        pipe_method = StableDiffusionControlNetPipeline
        pipe = pipe_method.from_pretrained(
            base_model_path, unet=unet, controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None,
        )
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            base_model_path, unet=unet, torch_dtype=torch.float16, safety_checker=None,
        )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    pipe.to(torch_device=torch.device('cuda'))
    # pipe.enable_xformers_memory_efficient_attention()
    pipe.set_progress_bar_config(disable=True)
    generator = torch.Generator(device=pipe.device).manual_seed(10000)

    return controlnet, pipe, generator

if __name__ == "__main__":
    control_images = prepare_images(segment_path, depth_npy_path)

    controlnet, pipe, generator = load_ckpts()

    generate_output(control_images, controlnet, pipe, generator)

The shapes of the depth map and seg map are both [H, W], but it shows an error:

    raise ValueError(f"Wrong channels of input condition, expected {self.config.conditioning_channels}, "
ValueError: Wrong channels of input condition, expected 8, got 3.

Could you please tell me the correct input format for the depth and semantic maps? Thank you.

OrangeSodahub commented 2 weeks ago

Hi, thanks for your interest in this work!

First, the depth map used in our model is an [H, W] shaped map with float values (in meters; the normalization is done internally, so the input should not be a normalized one). The seg map in one_hot format is also an [H, W] shaped map with int values (label indices).
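
For illustration, a minimal sketch of what the two inputs should look like before any model-specific processing (the file names here are hypothetical):

import numpy as np

# Hypothetical file names; both maps are single-channel [H, W].
depth = np.load("depth.npy").astype(np.float32)            # raw metric depth in meters
assert depth.ndim == 2                                     # do NOT normalize it yourself

seg = np.load("segmentation_labels.npy").astype(np.int64)  # NYU40 label indices
assert seg.ndim == 2 and seg.min() >= 0 and seg.max() <= 40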

For your issue, it is weird to have a channel dimension equal to 3. Could you please add a breakpoint here:

https://github.com/OrangeSodahub/SceneCraft/blob/6568a2c9da79beab8c5004db0a9fd26589e89c76/scenecraft/model.py#L845

and check the channels of image right after this line? Then let me know the result.

SYSUykLin commented 2 weeks ago

Thanks very much.

  1. Yes, I read the code. If the condition image is the segmentation (labels), it will be encoded and output as [2, 8, H, W]; if the condition image is the depth, nothing is done to it. So after https://github.com/OrangeSodahub/SceneCraft/blob/6568a2c9da79beab8c5004db0a9fd26589e89c76/scenecraft/model.py#L845 the seg map is [2, 8, H, W] and the depth map is [2, 1, H, W] (a toy illustration of this 8-channel encoding follows at the end of this comment).
  2. Sorry, I made a mistake. The error is:
    File "/cluster/personal/NeRF_projects/NeRF/diffusers/SceneCraft/scenecraft/model.py", line 603, in forward
    raise ValueError(f"Wrong channels of input condition, expected {self.config.conditioning_channels}, "
    ValueError: Wrong channels of input condition, expected 8, got 1.

    And the error code is: https://github.com/OrangeSodahub/SceneCraft/blob/6568a2c9da79beab8c5004db0a9fd26589e89c76/scenecraft/model.py#L598

I think I might be missing a depth encoder? I use the default parameter condition_type=one_hot, and I found that indice_to_embedding is None, so I set indice_to_embedding = None. Could this be the problem?
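
For intuition only, here is a toy example of how a [B, H, W] label map could be turned into a [B, 8, H, W] condition via an embedding lookup; this illustrates where the 8 channels may come from, not necessarily how scenecraft/model.py implements it:

import torch

# Toy sketch: 41 labels (NYU40 + background), embedded into 8 channels.
embedding = torch.nn.Embedding(num_embeddings=41, embedding_dim=8)

labels = torch.zeros(2, 768, 1024, dtype=torch.long)  # [B, H, W] label map
cond = embedding(labels).permute(0, 3, 1, 2)          # [B, H, W, 8] -> [B, 8, H, W]
print(cond.shape)                                     # torch.Size([2, 8, 768, 1024])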

OrangeSodahub commented 2 weeks ago

Nah, indice_to_embedding = None is not the issue.

For the seg map, the input channels should be 8, and for the depth map they should be 1:

https://huggingface.co/gzzyyxy/layout_diffusion_hypersim_prompt_one_hot_multi_control_bs32_epoch24/blob/1d27df15c619672e64377467e8973d5e116cb36d/checkpoint-10900/multicontrolnetmodel_1/config.json#L21

Could you please let me know whether the error above pops up when the input image is the seg map or the depth map?

To check the seg map:

https://github.com/OrangeSodahub/SceneCraft/blob/6568a2c9da79beab8c5004db0a9fd26589e89c76/scenecraft/model.py#L634

To check the depth map:

https://github.com/OrangeSodahub/SceneCraft/blob/6568a2c9da79beab8c5004db0a9fd26589e89c76/scenecraft/model.py#L641

SYSUykLin commented 2 weeks ago

Thanks for your reply.

  1. Sorry, it's possible that I didn't post the full error message. I think it is the depth map. Error message:
    You have 2 ControlNets and you have passed 1 prompts. The conditionings will be fixed across the prompts.
    Traceback (most recent call last):
    File "/cluster/personal/NeRF_projects/NeRF/diffusers/SceneCraft/scripts/inference.py", line 133, in <module>
    generate_output(control_images, controlnet, pipe, generator)
    File "/cluster/personal/NeRF_projects/NeRF/diffusers/SceneCraft/scripts/inference.py", line 50, in generate_output
    guide_images: List[Image.Image] = pipe(prompts,
    File "/opt/conda/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
    File "/cluster/personal/NeRF_projects/NeRF/diffusers/SceneCraft/scenecraft/model.py", line 1059, in __call__
    down_block_res_samples, mid_block_res_sample = self.controlnet(
    File "/opt/conda/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
    File "/opt/conda/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
    File "/cluster/personal/NeRF_projects/NeRF/diffusers/SceneCraft/scenecraft/model.py", line 646, in forward
    down_samples2, mid_sample2 = self.nets[1](
    File "/opt/conda/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
    File "/opt/conda/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
    File "/cluster/personal/NeRF_projects/NeRF/diffusers/SceneCraft/scenecraft/model.py", line 603, in forward
    raise ValueError(f"Wrong channels of input condition, expected {self.config.conditioning_channels}, "
    ValueError: Wrong channels of input condition, expected 8, got 1.

    Note that

    File "/cluster/personal/NeRF_projects/NeRF/diffusers/SceneCraft/scenecraft/model.py", line 646, in forward
    down_samples2, mid_sample2 = self.nets[1](

    https://github.com/OrangeSodahub/SceneCraft/blob/6568a2c9da79beab8c5004db0a9fd26589e89c76/scenecraft/model.py#L641

nets[1] is the depth ControlNet.

SYSUykLin commented 2 weeks ago

And I printed the shape of controlnet_cond before feeding it into self.nets:

        print(controlnet_cond[0].shape)
        down_samples, mid_sample = self.nets[0](
            sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_cond[0], conditioning_scale=conditioning_scale[0], **kwargs
        )
        print(controlnet_cond[1].shape)

        if not kwargs.get("no_depth_cond", False):
            # depth controlnet
            kwargs.pop("indice_to_embedding", None)
            down_samples2, mid_sample2 = self.nets[1](
                sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=controlnet_cond[1], conditioning_scale=conditioning_scale[1], **kwargs
            )

controlnet_cond[0].shape is [2, 8, 768, 1024]
controlnet_cond[1].shape is [2, 1, 768, 1024]
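
As a quick sanity check before the forward calls, each condition tensor's channel dimension can be compared against its ControlNet config (a sketch placed in the same MultiControlNetModel.forward context as the prints above):

# Sketch: each condition's channels should match net.config.conditioning_channels.
for i, (cond, net) in enumerate(zip(controlnet_cond, self.nets)):
    expected = net.config.conditioning_channels
    assert cond.shape[1] == expected, \
        f"nets[{i}]: got {cond.shape[1]} condition channels, expected {expected}"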

OrangeSodahub commented 2 weeks ago

Hi, then it may be a model loading issue, since the channels should be set to 1 when loading the model:

https://huggingface.co/gzzyyxy/layout_diffusion_hypersim_prompt_one_hot_multi_control_bs32_epoch24/blob/1d27df15c619672e64377467e8973d5e116cb36d/checkpoint-10900/multicontrolnetmodel_1/config.json#L21

And I noticed you use:

base_model_path = "ckpts/stable-diffusion-2-1"
checkpoint_path = "controlnet_ckpts/layout_diffusion_hypersim_prompt_one_hot_multi_control_bs32_epoch24"

Did you download the checkpoints to local disk? If so, I recommend directly replacing the model path with the repo name on Hugging Face, for example: gzzyyxy/layout_diffusion_hypersim_prompt_one_hot_multi_control_bs32_epoch24. I'm not sure whether loading from local disk reads the config.json correctly, but you can try this suggestion first.

Before you try that, you can check self.nets[1].config.conditioning_channels; it should be 1 as set in config.json. If it is not, try loading from Hugging Face.

I verified this myself with:

import torch
from scenecraft.model import MultiControlNetModel

checkpoint_path = "gzzyyxy/layout_diffusion_hypersim_prompt_one_hot_multi_control_bs32_epoch24"
checkpoint_subfolder = "checkpoint-10900"

controlnet: MultiControlNetModel = MultiControlNetModel.from_pretrained(
    checkpoint_path, subfolder=f"{checkpoint_subfolder.strip('/')}", torch_dtype=torch.float16
)

print(controlnet.nets[0].config.conditioning_channels)
print(controlnet.nets[1].config.conditioning_channels)

output:

8
1

OrangeSodahub commented 2 weeks ago

I tried loading from local folders and that also performs well. So just make sure the config.json file is correct. Please let me know if you still cannot resolve it.
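
To double-check a local copy without loading the full model, the config files can be read directly (a sketch using the local paths from earlier in this thread):

import json, os

checkpoint_path = "controlnet_ckpts/layout_diffusion_hypersim_prompt_one_hot_multi_control_bs32_epoch24"
for sub in ("multicontrolnetmodel", "multicontrolnetmodel_1"):
    cfg = os.path.join(checkpoint_path, "checkpoint-10900", sub, "config.json")
    with open(cfg) as f:
        print(sub, json.load(f)["conditioning_channels"])  # expect 8, then 1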

SYSUykLin commented 2 weeks ago

Thanks for your reply.

  1. Yes, the problem was in the model loading.
  2. In MultiControlNetModel, I modified the code:

controlnets = [
    ControlNetModel.from_pretrained(os.path.join(model_path_to_load, f"{subfolder}/multicontrolnetmodel"), torch_dtype=torch_dtype),
    ControlNetModel.from_pretrained(os.path.join(model_path_to_load, f"{subfolder}/multicontrolnetmodel_1"), torch_dtype=torch_dtype)
]

It works now. Maybe it was due to the diffusers version. Thanks for your help, thanks very much.

SYSUykLin commented 2 weeks ago

Hello, one last question: could you provide the correspondence between labels and numbers? For example, (bed, 0), (wall, 1), (desk, 2)... These index numbers are the elements of the seg map. Thanks.

OrangeSodahub commented 2 weeks ago

I see the data sample you posted is from Hypersim, so you can just follow its NYU40 labels:

https://github.com/OrangeSodahub/SceneCraft/blob/6568a2c9da79beab8c5004db0a9fd26589e89c76/scenecraft/data/hypersim_utils/meta_data/nyu40id.txt#L1-L3
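
If a programmatic mapping is convenient, that file can be parsed into a dict (a sketch that assumes each line pairs an id with a label name; adjust the parsing to the actual file layout):

# Hypothetical parser: build {index: label} from nyu40id.txt.
nyu40 = {}
with open("scenecraft/data/hypersim_utils/meta_data/nyu40id.txt") as f:
    for line in f:
        parts = line.strip().split(maxsplit=1)
        if len(parts) == 2:
            idx, name = parts
            nyu40[int(idx)] = name
print(nyu40.get(1), nyu40.get(2))  # e.g. wall, floor if the file follows NYU40 order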

SYSUykLin commented 2 weeks ago

Thanks very much.