ddlee-cn / MuLUT

[ECCV 2022 & T-PAMI 2024] Multiple Look-Up Tables for Efficient Image Restoration
https://mulut.pages.dev
MIT License

Inference speed problem on Android #13

Closed · undcloud closed this issue 2 months ago

undcloud commented 2 months ago

Hello, thanks for your work; it is very novel.

I tested the first model (s1_d) with speed_benchmark_torch on an Android device with a Qualcomm Snapdragon 888+:

resolution     latency
224*224*3      282 ms
512*512*3      1.8 s
1024*1024*3    7 s

Why is the speed so slow?

undcloud commented 2 months ago
import sys

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

import model
# from data import Provider, SRBenchmark

sys.path.insert(0, "../")  # run under the project directory
from common.option import TrainOptions
# from common.utils import PSNR, logger_info, _rgb2ycbcr

torch.backends.cudnn.benchmark = True

# padding required by each sampling pattern (mode)
mode_pad_dict = {"s": 1, "d": 2, "y": 2, "e": 3, "h": 3, "o": 3}

if __name__ == "__main__":
    opt_inst = TrainOptions()
    opt = opt_inst.parse()
    modes = [i for i in opt.modes]
    stages = opt.stages

    # build the network and load the trained checkpoint on CPU
    model_class = getattr(model, opt.model)  # avoid shadowing the `model` module
    model_G = model_class(nf=opt.nf, scale=opt.scale, modes=modes, stages=stages)  # .cuda()
    lm = torch.load('/home/fanhao1/MuLUT/models/sr_x2sdy/Model_200000.pth',
                    map_location='cpu')
    model_G.load_state_dict(lm.state_dict(), strict=True)
    model_G.eval()  # switch to inference mode before tracing

    dummy_input = torch.randn(1, 3, 128, 228)  # .to('cuda')
    # torch.onnx.export(model_G.s1_d, dummy_input, "MuLUT.onnx", verbose=False, opset_version=17)
    # torch.onnx.dynamo_export(model_G.s1_d, dummy_input).save('s1_d.onnx')

    # trace the s1_d branch and export it for the PyTorch Lite interpreter
    traced_script_module = torch.jit.trace(model_G.s1_d, dummy_input)
    # traced_script_module = torch.jit.load(load_path)
    traced_script_module_optimized = optimize_for_mobile(traced_script_module)
    traced_script_module_optimized._save_for_lite_interpreter("s1_d.ptl")
    # also keep unoptimized copies for comparison
    traced_script_module._save_for_lite_interpreter("s1_d_nop.ptl")
    traced_script_module.save("s1_d_nop.pt")
    print(model_G)
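
For on-device timing, the exported .ptl can be loaded with the pytorch_android_lite runtime. A minimal sketch, assuming the file has been pushed to the device; the path and class name below are illustrative:

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;

public class S1dLatency {
    public static long timeForwardMs(String modelPath) {
        // load the lite-interpreter model exported above (path is illustrative)
        Module module = LiteModuleLoader.load(modelPath);

        // dummy NCHW input matching the traced shape (1, 3, 128, 228)
        float[] data = new float[3 * 128 * 228];
        Tensor input = Tensor.fromBlob(data, new long[]{1, 3, 128, 228});

        module.forward(IValue.from(input));  // warm-up run, excluded from timing
        long start = System.nanoTime();
        Tensor output = module.forward(IValue.from(input)).toTensor();
        System.out.println("output shape: " + java.util.Arrays.toString(output.shape()));
        return (System.nanoTime() - start) / 1_000_000;
    }
}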
ddlee-cn commented 2 months ago

Hi, thanks for your interest in our work.

There are many tensor operations inside the torch model, e.g., reshape, which contribute a lot to the running time. As stated in our paper, we reimplemented the LUT retrieval process in Java, following SR-LUT, to evaluate its running time.
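
For context, the retrieval that replaces the network at inference time is essentially array indexing into a 4D table. A minimal sketch of the idea in Java, simplified to nearest-neighbor lookup (SR-LUT and MuLUT actually interpolate between the surrounding sampled entries, e.g. with 4-simplex interpolation; all names below are illustrative):

// nearest-neighbor retrieval from a 4D LUT for x2 SR on a grayscale image
public class LutRetrieval {
    static final int STEP = 16;   // sampling interval for 8-bit inputs
    static final int L = 17;      // sampled levels per input pixel
    static final int SCALE = 2;   // upscaling factor

    // img: H x W pixel values in [0, 255]; lut: L^4 * SCALE^2 entries
    public static float[][] retrieve(int[][] img, float[] lut) {
        int h = img.length, w = img[0].length;
        float[][] out = new float[h * SCALE][w * SCALE];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int x1 = Math.min(x + 1, w - 1), y1 = Math.min(y + 1, h - 1);
                // quantize the 2x2 neighborhood to LUT indices
                int a = img[y][x] / STEP, b = img[y][x1] / STEP;
                int c = img[y1][x] / STEP, d = img[y1][x1] / STEP;
                int base = (((a * L + b) * L + c) * L + d) * SCALE * SCALE;
                // write the SCALE x SCALE output patch for this pixel
                for (int i = 0; i < SCALE; i++)
                    for (int j = 0; j < SCALE; j++)
                        out[y * SCALE + i][x * SCALE + j] = lut[base + i * SCALE + j];
            }
        }
        return out;
    }
}

Because each output patch costs only a quantization, one table read, and an interpolation, this path avoids the reshape-heavy tensor graph entirely.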

undcloud commented 2 months ago

Thank you. Could you provide the MuLUT Android Java project code or an Android APK demo?

ddlee-cn commented 2 months ago

Please send me an email.