yandex-research / ddpm-segmentation

Label-Efficient Semantic Segmentation with Diffusion Models (ICLR'2022)
https://yandex-research.github.io/ddpm-segmentation/
MIT License

Question about colorize_mask #7

Closed: JingyeChen closed this issue 2 years ago

JingyeChen commented 2 years ago

Hello, thanks for your great work and for sharing this code! However, when running it I ran into an issue with the function "colorize_mask". Specifically, when I use the following line in "pixel_classifier.py":

mask = colorize_mask(pred[0], palette)

an error was reported


It works fine when I change it to:

mask = colorize_mask(pred, palette)

So is it a bug?

dbaranchuk commented 2 years ago

Hi, thanks for reporting this!

"save_predictions" expects "preds" as an NxCxHxW array with C=1. In that case, each "pred" in the loop is 1xHxW, and pred[0] just drops the channel axis, giving an HxW mask. I guess your "preds" are NxHxW, so each "pred" is already HxW and pred[0] is only its first row, a one-dimensional vector. That is why it fails.
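The shape difference can be checked directly with NumPy. This is a minimal sketch; `per_image_masks` is a hypothetical helper that only mimics the `pred[0]` indexing done per image in the loop, and the sizes N, H, W are arbitrary illustration values.

```python
import numpy as np

def per_image_masks(preds: np.ndarray) -> list:
    """Hypothetical helper: apply the loop's pred[0] indexing to each image."""
    return [pred[0] for pred in preds]

# Arbitrary illustration sizes: N images of H x W pixels
N, H, W = 2, 4, 5

# Expected layout: N x 1 x H x W, one channel of class labels per image
preds_nchw = np.zeros((N, 1, H, W), dtype=np.uint8)
# Layout from the report: N x H x W, no channel axis
preds_nhw = np.zeros((N, H, W), dtype=np.uint8)

# pred[0] yields the full H x W mask for the expected layout...
print([m.shape for m in per_image_masks(preds_nchw)])  # [(4, 5), (4, 5)]
# ...but only the first row (length W) when the channel axis is missing
print([m.shape for m in per_image_masks(preds_nhw)])   # [(5,), (5,)]
```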

I think we can do `preds = np.squeeze(preds)` before the loop in "save_predictions" and use `mask = colorize_mask(pred, palette)` inside it. That way it will work for both NxHxW and Nx1xHxW. Do you want to make a PR?
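The normalization step suggested above might look like the following sketch. `to_label_maps` is a hypothetical helper, not code from the repository; after calling it, each `pred` in the loop would be HxW and could be passed to `colorize_mask(pred, palette)` directly. One caveat worth knowing: `np.squeeze` without an `axis` argument drops every size-1 axis, so an Nx1xHxW batch with N = 1 collapses to HxW and the loop would then iterate over rows. Squeezing only `axis=1` avoids that.

```python
import numpy as np

def to_label_maps(preds: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return preds as an N x H x W array of label maps.

    Accepts N x H x W or N x 1 x H x W input. Unlike a bare np.squeeze(preds),
    it removes only the channel axis, so a batch with N == 1 keeps its
    leading batch dimension.
    """
    preds = np.asarray(preds)
    if preds.ndim == 4:
        if preds.shape[1] != 1:
            raise ValueError(f"expected a single channel, got shape {preds.shape}")
        preds = np.squeeze(preds, axis=1)  # drop only the channel axis
    if preds.ndim != 3:
        raise ValueError(f"expected N x H x W or N x 1 x H x W, got shape {preds.shape}")
    return preds

# Compare with a bare np.squeeze on a few input layouts
for shape in [(2, 1, 4, 5), (2, 4, 5), (1, 1, 4, 5)]:
    preds = np.zeros(shape, dtype=np.uint8)
    print(shape, "->", to_label_maps(preds).shape, "| np.squeeze:", np.squeeze(preds).shape)
# (2, 1, 4, 5) -> (2, 4, 5) | np.squeeze: (2, 4, 5)
# (2, 4, 5) -> (2, 4, 5) | np.squeeze: (2, 4, 5)
# (1, 1, 4, 5) -> (1, 4, 5) | np.squeeze: (4, 5)
```

With N > 1 both approaches give the same result; the difference only shows up for a single-image batch.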

JingyeChen commented 2 years ago

I've made a PR. Thanks for your reply :D

dbaranchuk commented 2 years ago

Great! Thank you too :)