YihanHu-2022 / DiffMatte


[Bug] run_one_image demo creates bad results #5

Closed: zijunwei closed this issue 19 hours ago

zijunwei commented 5 months ago

I followed the setup; the only differences are:

  1. pytorch 2.2.2 with cuda 12.1: 2.2.2+cu121
  2. the einops package seems to be missing, so I installed einops==0.7.0

I got the following results using ViTS_1024 and associated weights

(attached images: bulb_reproduce, result_reproduce)

There are no warnings/errors during inference. Please advise.

YihanHu-2022 commented 4 months ago

It seems that the problem is caused by the values in the trimap. Please ensure that the trimap values in the foreground, unknown, and background areas are 255, 128, and 0, respectively.
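
A quick sanity-check sketch (mine, not from the repo): load a trimap, print its unique values, and snap everything to exactly {0, 128, 255} before running the demo. The file name below is just an example; the thresholds 50/200 are my own choice.

import numpy as np
from PIL import Image

trimap = np.array(Image.open("demo/retriever_trimap.png").convert("L"))
print(np.unique(trimap))  # should show only [  0 128 255]

# snap stray values: bright -> 255 (foreground), dark -> 0 (background), rest -> 128 (unknown)
snapped = np.where(trimap >= 200, 255, np.where(trimap <= 50, 0, 128)).astype(np.uint8)
Image.fromarray(snapped).save("demo/retriever_trimap_fixed.png")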

jseobyun commented 3 months ago

Same issue. Have you solved it?

In my case, even though I double-checked the trimap values and corrected any wrong values to one of (0, 128, 255), I still could not get a correct result.

Could you share a correct trimap? Both images above come from the DiffMatte repo, but they do not work.

thangnh0608 commented 3 months ago

@jseobyun Hi, can you share an example of how to run run_one_image.py? I couldn't get this file to run.

YihanHu-2022 commented 3 months ago

I cannot reproduce this bug in my recommended environment, although I do not think it is caused by a mismatched environment. Maybe using the fixed command for the run_one_image.py demo can help.

YihanHu-2022 commented 3 months ago

I reran this code and got the following results:

(attached result image)

It seems just fine.

MaxTeselkin commented 3 months ago

I am facing the same issue

MaxTeselkin commented 3 months ago

@YihanHu-2022 the values of the trimap must be 0, 0.5, and 1 according to your code (F.to_tensor in get_data scales the trimap to [0, 1], for example). If I replace them with 0, 128, and 255 as you proposed above, I get RuntimeError: CUDA error: device-side assert triggered.

I also tried running run_one_image.py on images from your repository and got the same result as @zijunwei reported. If the same code gives you a different result, then there is probably a problem with the checkpoint (maybe your local checkpoint is not the same as the one you open-sourced).
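
If it helps narrow this down, a generic way to rule out a corrupted or mismatched download is to compare checksums (a sketch only; I don't have reference hashes for the released weights, so this only helps if the authors can post theirs; the file name is the ViTS_1024 checkpoint mentioned elsewhere in this thread):

import hashlib

def sha256_of(path, chunk_size=1 << 20):
    # stream the file so large checkpoints don't need to fit in memory
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(chunk_size), b""):
            h.update(block)
    return h.hexdigest()

print(sha256_of("DiffMatte_ViTS_Com_1024.pth"))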

wangjia184 commented 2 months ago

@YihanHu-2022 I am encountering the same issue using the demo. The trimap is from the demo folder.

I tried both ViTB and ViTS_1024; both produce bad results.

python run_one_image.py \
    --config-file ViTB \
    --checkpoint-dir /pretrained/diffmatte/DiffMatte_ViTB_Com.pth \
    --image-dir demo/retriever_rgb.png \
    --trimap-dir demo/retriever_trimap.png \
    --output-dir /out/result2.png \
    --device cuda \
    --sample-strategy ddim10

python run_one_image.py \
    --config-file ViTS_1024 \
    --checkpoint-dir /pretrained/diffmatte/DiffMatte_ViTS_Com_1024.pth \
    --image-dir demo/retriever_rgb.png \
    --trimap-dir demo/retriever_trimap.png \
    --output-dir /out/result.png \
    --device cuda \
    --sample-strategy ddim10

Here is my environment in docker container:

FROM nvcr.io/nvidia/cuda:11.8.0-runtime-ubuntu20.04

RUN apt update -y && apt install python3.8 python3-pip -y

RUN ln -s /usr/bin/python3.8 /usr/bin/python

ARG USE_CUDA=0

ENV DEBIAN_FRONTEND=noninteractive
ENV AM_I_DOCKER=True
ENV BUILD_WITH_CUDA="${USE_CUDA}"
ENV TORCH_CUDA_ARCH_LIST="8.9"
ENV CUDA_HOME=/usr/local/cuda-11.8

# Great Firewall workaround: install from local wheels / a mirror instead of the default index
#RUN pip3 install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118;
COPY install /install
RUN pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple  /install/torch-2.2.2+cu118-cp38-cp38-linux_x86_64.whl; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple  /install/torchaudio-2.2.2+cu118-cp38-cp38-linux_x86_64.whl; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple  /install/torchvision-0.17.2+cu118-cp38-cp38-linux_x86_64.whl; \
    rm -rf /install

#git submodule add https://github.com/YihanHu-2022/DiffMatte.git DiffMatte
#git submodule add https://github.com/facebookresearch/detectron2.git detectron2

COPY DiffMatte /app/DiffMatte
COPY detectron2 /detectron2
COPY pretrained/diffmatte /pretrained/diffmatte

RUN pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple /detectron2

RUN pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple -U setuptools; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple flasgger==0.9.7b2; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple gunicorn==22.0.0; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple marshmallow==3.21.3; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple apispec==6.6.1; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple wheel==0.43.0; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple opencv-python==4.10.0.82; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple supervision==0.21.0; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple pycocotools==2.0.7; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple timm==1.0.3; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple easydict==1.13; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple scikit-image==0.23.2; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple fairscale==0.4.13; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple wget==3.2; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple tensorboard==2.17.0; \
    pip3 install -i https://pypi.tuna.tsinghua.edu.cn/simple einops==0.8.0; 

(attached result images: ViTB, ViTS_1024)

@YihanHu-2022 can you pack your environment into a docker image so that we can see why it fails?

hsdjkfnsfc commented 2 months ago

Same issue as you guys.

wangjia184 commented 2 months ago

The trimap is not the cause. I added the three lines below to guarantee only three values (0/0.5/1) in the trimap. Same result.

from PIL import Image
from torchvision.transforms import functional as F

def get_data(image_dir, trimap_dir):
    """
    Get the data of one image.
    Input:
        image_dir: the directory of the image
        trimap_dir: the directory of the trimap
    """
    image = Image.open(image_dir).convert('RGB')
    image = F.to_tensor(image).unsqueeze(0)
    trimap = Image.open(trimap_dir).convert('L')
    trimap = F.to_tensor(trimap).unsqueeze(0)

    # force tri-values in trimap
    trimap[trimap > 0.9] = 1.00000
    trimap[(trimap >= 0.1) & (trimap <= 0.9)] = 0.50000
    trimap[trimap < 0.1] = 0.00000

    return {
        'image': image,
        'trimap': trimap
    }
ReddyNick commented 2 months ago

The problem is in saving the output. Use this fixed function instead:

import cv2
from os.path import join as opj  # opj is assumed to be os.path.join, as in the original script

def infer_one_image(model, input, save_dir=None):
    output = model(input)

    # output = F.to_pil_image(output).convert('RGB')
    # output.save(opj(save_dir))

    output = cv2.cvtColor(output, cv2.COLOR_GRAY2RGB)
    cv2.imwrite(opj(save_dir), output)
    return None
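
For what it's worth, an equivalent PIL-based save also works if the output is clipped and cast to uint8 first. This assumes the output alpha is already in the 0-255 range (my assumption, not confirmed), which would also explain why the original F.to_pil_image call, which expects floats in [0, 1], produced bad mattes.

import numpy as np
from PIL import Image

def save_matte_with_pil(output, save_path):
    # assumes `output` is a single-channel array with alpha values in the 0-255 range
    alpha = np.clip(np.asarray(output).squeeze(), 0, 255).astype(np.uint8)
    Image.fromarray(alpha, mode="L").convert("RGB").save(save_path)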
wangjia184 commented 2 months ago

@ReddyNick yes, you saved the world!

I created a PR here https://github.com/YihanHu-2022/DiffMatte/pull/8

From my test, DiffMatte is faster than AEMatte, although they both produce good results.

MaxTeselkin commented 2 months ago

@ReddyNick @wangjia184 @YihanHu-2022 Am I the only one here who is unable to run DiffMatte ViTS-1024 on an image of resolution 1920 x 2880 without hitting a CUDA out-of-memory error? On an image of resolution 1920 x 1280 everything is fine. It seems to me that gradient calculation is not disabled somewhere; I tried torch.set_grad_enabled(False), but it didn't make any difference. Simply resizing the input image and interpolating the mask back to the original size is not a good option for matting, since it reduces mask detail.

YihanHu-2022 commented 1 month ago

> @ReddyNick @wangjia184 @YihanHu-2022 Am I the only one here who is unable to run DiffMatte ViTS-1024 on an image of resolution 1920 x 2880 without hitting a CUDA out-of-memory error? On an image of resolution 1920 x 1280 everything is fine. It seems to me that gradient calculation is not disabled somewhere; I tried torch.set_grad_enabled(False), but it didn't make any difference. Simply resizing the input image and interpolating the mask back to the original size is not a good option for matting, since it reduces mask detail.

Hi, this is a valuable question and is indeed what our team is currently working on. The huge CUDA memory requirement is caused by the global attention employed in the ViT backbone. To address this problem, you can refer to the solution at https://github.com/hustvl/ViTMatte/issues/10. BTW, we will share another work that provides a better way to handle high-res matting with a mature ViT-series backbone, and I'll post the link under this issue.
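
To make the scaling concrete, here is a back-of-the-envelope estimate of the materialized attention map for the two resolutions above (assuming a ViT patch size of 16 and fp32 activations; these are my assumptions, not numbers from the DiffMatte code):

def attn_map_gib(height, width, patch=16, bytes_per_elem=4):
    # size of one N x N attention matrix, per head and per layer
    tokens = (height // patch) * (width // patch)
    return tokens ** 2 * bytes_per_elem / 1024 ** 3

print(attn_map_gib(1280, 1920))  # ~0.34 GiB per head: 1920 x 1280 fits
print(attn_map_gib(2880, 1920))  # ~1.74 GiB per head: 1920 x 2880 quickly OOMs across heads and blocks

Since the attention map scales with the square of the number of tokens, doubling the image area quadruples this per-head cost.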

LukaGiorgadze commented 1 month ago

@YihanHu-2022 First of all, thanks for responding to the issues. I'm running DiffMatte on every HTTP POST request, but memory usage keeps growing until I run out of memory: the first inference uses 10 GB, the second 12 GB, the third 18 GB, and so on until OOM. What might be wrong?

Here's an example code:

import io
import traceback

import torch
from flask import Flask, Response, jsonify
from torchvision.transforms import functional as F

from diffmatte_run import infer_one_image as infer_diffmate, init_model as init_model_diffmatte

app = Flask(__name__)

@app.route("/", methods=["POST"])
def run_inference():
    # ... code to process the image

    try:
        image_alpha = infer_diffmate(
                init_model_diffmatte(args.model, args.checkpoint, args.device, "ddim10"),
                {
                    "image": F.to_tensor(image_rgb).unsqueeze(0),
                    "trimap": F.to_tensor(image_trimap).unsqueeze(0),
                },
            )

        output_format = "png"
        byte_arr = io.BytesIO()
        image_alpha.save(byte_arr, format=output_format)
        byte_arr.seek(0)
        return Response(byte_arr.getvalue(), mimetype=f"image/{output_format}")

    except Exception as e:
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500

    finally:
        if args.device == "gpu":
            torch.cuda.empty_cache()

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=8092, debug=True, use_reloader=True)

If you are interested, I can provide the full code of the Flask app.

wangjia184 commented 1 month ago

@LukaGiorgadze Start a subprocess for every inference, that is how I use it :)
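
For anyone who wants a concrete version of this, here is a minimal sketch of the subprocess-per-inference approach. It assumes the fixed infer_one_image(model, input, save_dir) from earlier in this thread; the helper names _worker / infer_in_subprocess and the model_args dict are mine.

import multiprocessing as mp

def _worker(image_rgb, image_trimap, model_args, save_path):
    # import inside the child so CUDA is only initialized in the subprocess
    import torch
    from torchvision.transforms import functional as F
    from diffmatte_run import infer_one_image, init_model

    model = init_model(model_args["model"], model_args["checkpoint"],
                       model_args["device"], "ddim10")
    with torch.no_grad():
        infer_one_image(model, {
            "image": F.to_tensor(image_rgb).unsqueeze(0),
            "trimap": F.to_tensor(image_trimap).unsqueeze(0),
        }, save_path)

def infer_in_subprocess(image_rgb, image_trimap, model_args, save_path):
    ctx = mp.get_context("spawn")  # a clean CUDA context per call
    proc = ctx.Process(target=_worker,
                       args=(image_rgb, image_trimap, model_args, save_path))
    proc.start()
    proc.join()  # all GPU memory is released when the child process exits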

YihanHu-2022 commented 19 hours ago

As this issue seems to be addressed, I am going to close it.