ROCm / torch_migraphx

Libraries integrating migraphx with pytorch
BSD 3-Clause "New" or "Revised" License

Op Support: torchvision.roi_align.default #143

Open shivadbhavsar opened 2 months ago

shivadbhavsar commented 1 month ago

Here is some sample test code to help with the acc_op implementation:

import torch
from torchvision.ops import RoIAlign

roi_mod = RoIAlign(output_size=[2, 2], spatial_scale=1.0, sampling_ratio=-1, aligned=False)

inp = torch.randn(2, 3, 4, 4)
boxes = torch.tensor([[0, 0, 0, 0.5, 0.7]])

torch_out = roi_mod(inp, boxes)
print(torch_out)

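import torch_migraphx.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_migraphx.fx.fx2mgx import MGXInterpreter
from torch_migraphx.fx.mgx_module import MGXModule


def convert_to_mgx(mod, inp):
    # Trace the module into an acc_ops graph and print it for inspection
    traced = acc_tracer.trace(mod.eval(), inp)
    traced.graph.print_tabular()
    # Convert the traced graph node by node into a MIGraphX program
    interp = MGXInterpreter(traced, inp)
    interp.run()
    # Wrap the program so it can be called like a regular torch module
    return MGXModule(interp.program, interp.get_input_names())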

mgx_mod = convert_to_mgx(roi_mod, [inp, boxes])

mgx_out = mgx_mod(inp, boxes)
print(mgx_out)

Currently, this fails with:

RuntimeError: Conversion of function torchvision.ops.roi_align.roi_align not supported.

Do one of the following:
a. Create an acc_op wrapper for torchvision.ops.roi_align.roi_align and implement a converter for it in acc_ops_converters.
b. Skip the acc wrapper: create a torchvision_converters file and implement a converter for torchvision.ops.roi_align.roi_align directly.
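
Whichever option is chosen, the converter's output has to match the torch result numerically. As a rough aid for reasoning about expected values, here is a minimal pure-Python sketch of RoI Align for a single channel with aligned=False. It is based on an assumed reading of torchvision's sampling rule (RoI width and height clamped to at least 1, an adaptive sampling grid when sampling_ratio <= 0, bilinear interpolation, average over samples). The names bilinear and roi_align_ref are hypothetical helpers and not part of torchvision or torch_migraphx, so check the results against torch_out before relying on them.

```python
import math


def bilinear(feat, y, x):
    """Bilinearly sample the 2-D list `feat` at (y, x); far out-of-range points give 0."""
    h, w = len(feat), len(feat[0])
    if y < -1.0 or y > h or x < -1.0 or x > w:
        return 0.0
    y, x = max(y, 0.0), max(x, 0.0)
    y0, x0 = int(y), int(x)
    # Clamp to the last row/column so the neighbour index stays in range
    if y0 >= h - 1:
        y0 = y1 = h - 1
        y = float(y0)
    else:
        y1 = y0 + 1
    if x0 >= w - 1:
        x0 = x1 = w - 1
        x = float(x0)
    else:
        x1 = x0 + 1
    ly, lx = y - y0, x - x0
    hy, hx = 1.0 - ly, 1.0 - lx
    return (hy * hx * feat[y0][x0] + hy * lx * feat[y0][x1]
            + ly * hx * feat[y1][x0] + ly * lx * feat[y1][x1])


def roi_align_ref(feat, box, out_h=2, out_w=2, spatial_scale=1.0, sampling_ratio=-1):
    """Hypothetical reference for one channel and one box (x1, y1, x2, y2), aligned=False."""
    x1, y1, x2, y2 = (v * spatial_scale for v in box)
    # Assumption: with aligned=False the RoI is forced to be at least 1x1
    roi_w = max(x2 - x1, 1.0)
    roi_h = max(y2 - y1, 1.0)
    bin_h, bin_w = roi_h / out_h, roi_w / out_w
    # Assumption: sampling_ratio <= 0 means an adaptive number of samples per bin
    grid_h = sampling_ratio if sampling_ratio > 0 else math.ceil(roi_h / out_h)
    grid_w = sampling_ratio if sampling_ratio > 0 else math.ceil(roi_w / out_w)
    count = max(grid_h * grid_w, 1)
    out = []
    for ph in range(out_h):
        row = []
        for pw in range(out_w):
            acc = 0.0
            for iy in range(grid_h):
                y = y1 + ph * bin_h + (iy + 0.5) * bin_h / grid_h
                for ix in range(grid_w):
                    x = x1 + pw * bin_w + (ix + 0.5) * bin_w / grid_w
                    acc += bilinear(feat, y, x)
            row.append(acc / count)  # average over the samples in this bin
        out.append(row)
    return out


# A 4x4 ramp where each value equals its column index, with the box from the test above
ramp = [[float(x) for x in range(4)] for _ in range(4)]
print(roi_align_ref(ramp, (0.0, 0.0, 0.5, 0.7)))
```

On the ramp input each output value is simply the x-coordinate of its sample point, which makes it easy to see where the bins are being sampled. The same check can be run against one channel of inp to compare with torch_out.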