ZhengPeng7 / BiRefNet

[CAAI AIR'24] Bilateral Reference for High-Resolution Dichotomous Image Segmentation
MIT License
1.09k stars, 84 forks

inference speed extremely slow #26

Closed: razvypp closed this issue 3 months ago.

razvypp commented 4 months ago


The inference speed is extremely slow. I am running inference on a GPU, the same setup I use for U²-Net, and the speed there is 12x faster.

Is there anything I can do to speed things up?
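Before comparing speeds, it helps to make sure both models are timed the same way. CUDA kernels run asynchronously, so without a synchronization call the timer can stop before the GPU has finished, and one-time costs (kernel autotuning, lazy initialization, memory allocation) inflate the first few iterations. The comparison also only makes sense at the same input resolution and precision (the export script in this post uses 1024×1024). A minimal timing sketch, assuming a zero-argument callable that runs one forward pass; `benchmark` is a hypothetical helper, not part of BiRefNet:

```python
import statistics
import time
from typing import Callable, Optional


def benchmark(fn: Callable[[], object], warmup: int = 5, iters: int = 20,
              sync: Optional[Callable[[], None]] = None) -> dict:
    """Hypothetical helper: time fn() after warm-up, syncing async work if asked."""
    if iters < 1:
        raise ValueError("iters must be >= 1")
    # Warm-up runs absorb one-time costs so they don't skew the average.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued (e.g. GPU) work before stopping the clock
        times.append(time.perf_counter() - start)
    mean = statistics.mean(times)
    return {
        "mean_s": mean,
        "median_s": statistics.median(times),
        "fps": 1.0 / mean if mean > 0 else float("inf"),
    }
```

With PyTorch on a GPU one might call `benchmark(lambda: model(x), sync=torch.cuda.synchronize)` inside a `torch.inference_mode()` block, once per model, with inputs of the same size for both.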

I have also tried to export the model to ONNX, but I get an error:

    import torch
    import torch.onnx
    from models.birefnet import BiRefNet
    from utils import check_state_dict
    from torch.onnx import register_custom_op_symbolic

    # Register custom symbolic function for deform_conv2d
    def deform_conv2d_symbolic(g, input, weight, offset, bias, stride, padding, dilation,
                               groups, deformable_groups, use_mask=False, mask=None):
        return g.op("DeformConv2d", input, weight, offset, bias,
                    stride_i=stride, padding_i=padding, dilation_i=dilation,
                    groups_i=groups, deformable_groups_i=deformable_groups)

    register_custom_op_symbolic('torchvision::deform_conv2d', deform_conv2d_symbolic, 11)

    # Load the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BiRefNet(bb_pretrained=False).to(device)
    state_dict = torch.load("/root/BiRefNet-massive-epoch_240.pth", map_location=device)
    state_dict = check_state_dict(state_dict)
    model.load_state_dict(state_dict)
    model.eval()

    # Dummy input to trace the model
    dummy_input = torch.randn(1, 3, 1024, 1024).to(device)

    # Ensure to handle tensor-to-Python type conversions in your model
    # Example modifications:
    #     if W % self.patch_size[1] != 0:
    # replace with
    #     if (W % self.patch_size[1]).item() != 0:

    # Export the model
    onnx_model_path = "/root/BiRefNet.onnx"
    torch.onnx.export(
        model,                     # model being run
        dummy_input,               # model input (or a tuple for multiple inputs)
        onnx_model_path,           # where to save the model (can be a file or file-like object)
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=11,          # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=['input'],     # the model's input names
        output_names=['output'],   # the model's output names
        dynamic_axes={'input': {0: 'batch_size'},   # variable length axes
                      'output': {0: 'batch_size'}},
    )

    print(f"Model has been converted to ONNX and saved at {onnx_model_path}")
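One likely cause of the export error: torchvision registers its operator as `torchvision::deform_conv2d(input, weight, offset, mask, bias, stride_h, stride_w, pad_h, pad_w, dilation_h, dilation_w, groups, offset_groups, use_mask)`, so the exporter passes 14 positional arguments, while the symbolic function in the script expects `bias` in fourth position and accepts at most 11 arguments after `g`. ONNX's standard `DeformConv` operator (opset 19 and later) takes inputs `X, W, offset, B, mask` and attributes `strides`, `pads`, `dilations`, `group`, `offset_group`, and `kernel_shape`. The torch-free sketch below only illustrates that argument mapping; `map_deform_conv2d_args` is a hypothetical helper, strings stand in for tensors, and it has not been checked against a real BiRefNet export.

```python
from typing import Any, Dict, List, Sequence, Tuple

# Positional argument order of torchvision's deform_conv2d operator schema.
TORCHVISION_DEFORM_CONV2D_ARGS = (
    "input", "weight", "offset", "mask", "bias",
    "stride_h", "stride_w", "pad_h", "pad_w",
    "dilation_h", "dilation_w", "groups", "offset_groups", "use_mask",
)


def map_deform_conv2d_args(args: Sequence[Any],
                           kernel_shape: Sequence[int]) -> Tuple[List[Any], Dict[str, Any]]:
    """Hypothetical helper: torchvision deform_conv2d args -> ONNX DeformConv inputs/attrs."""
    if len(args) != len(TORCHVISION_DEFORM_CONV2D_ARGS):
        raise TypeError(
            f"expected {len(TORCHVISION_DEFORM_CONV2D_ARGS)} arguments, got {len(args)}")
    a = dict(zip(TORCHVISION_DEFORM_CONV2D_ARGS, args))

    # ONNX DeformConv input order: X, W, offset, B (optional), mask (optional).
    inputs = [a["input"], a["weight"], a["offset"], a["bias"]]
    if a["use_mask"]:
        inputs.append(a["mask"])

    attrs = {
        "strides": [a["stride_h"], a["stride_w"]],
        # ONNX pads are [h_begin, w_begin, h_end, w_end]; torchvision pads both sides equally.
        "pads": [a["pad_h"], a["pad_w"], a["pad_h"], a["pad_w"]],
        "dilations": [a["dilation_h"], a["dilation_w"]],
        "group": a["groups"],
        "offset_group": a["offset_groups"],
        "kernel_shape": list(kernel_shape),
    }
    return inputs, attrs
```

In an actual symbolic function the same mapping would feed `g.op("DeformConv", *inputs, strides_i=..., pads_i=..., ...)`, with the scalar arguments unpacked via `torch.onnx.symbolic_helper.parse_args`, `kernel_shape` taken from the weight's last two dimensions, and both the symbolic registration and `opset_version` set to at least 19, which older PyTorch versions may not support.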

ZhengPeng7 commented 4 months ago

Hi, could you check your text and fix its formatting?

ZhengPeng7 commented 4 months ago

About the comparison with U²-Net, I don't think that is a problem. It's almost impossible to improve accuracy much while keeping the same number of parameters.

Statistics for BiRefNet with different backbones can be found in this issue.

You can also find other issues where we discussed speeding up inference (ONNX, FP16, ...). So far, there haven't been any very good methods for it. You can wait for the version with the swin_v1_tiny backbone, which could be about 4 times faster than the official one.

razvypp commented 4 months ago

Will the swin_v1_tiny version have the same performance on general datasets?

I failed to convert the model to ONNX; it crashes. Is there a guide for this?

ZhengPeng7 commented 4 months ago

Of course not; that's a trade-off. Sorry, but as for the issues I mentioned above, I currently have no time for this kind of thing.

ZhengPeng7 commented 4 months ago

The fully trained BiRefNet with the swin_v1_tiny backbone has been uploaded to my Google Drive. See the README for access to the weights, performance, predicted maps, and training log in the corresponding folder (exp-xxx). Its performance is a bit lower than the official version's but still good (HCE↓: 1152 -> 1182 on DIS-VD, where lower is better). Feel free to download and use them.
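To put the HCE figures in proportion, here is a quick worked ratio using only the two numbers quoted (HCE measures human correction effort, so lower is better). By this metric alone, the swin_v1_tiny version needs roughly 2.6% more correction effort than the official model on DIS-VD:

```latex
\frac{\mathrm{HCE}_{\text{tiny}} - \mathrm{HCE}_{\text{official}}}{\mathrm{HCE}_{\text{official}}}
  = \frac{1182 - 1152}{1152}
  = \frac{30}{1152}
  \approx 0.026 \approx 2.6\%
```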

ZhengPeng7 commented 4 months ago

By the way, check the update in inference.py. Setting torch.set_float32_matmul_precision('high') can increase the FPS on an A100 from 5 to 12 with almost no loss in accuracy (because I also set it to 'high' during training).
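In per-image terms, the quoted FPS figures work out as follows. This assumes they were measured one image at a time (batch size 1), which the comment does not state:

```latex
t_{\text{default}} = \frac{1}{5\,\mathrm{s}^{-1}} = 200\,\mathrm{ms}, \qquad
t_{\text{high}} = \frac{1}{12\,\mathrm{s}^{-1}} \approx 83.3\,\mathrm{ms}, \qquad
\text{speedup} = \frac{12}{5} = 2.4\times
```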

luoww1992 commented 1 month ago

@razvypp Did you manage to convert it successfully? If so, how did you do it?