ashishkumar822 / PiX

PiX: Dynamic Channel Sampling for ConvNets (CVPR 2024)
GNU General Public License v3.0

FP16 and use_amp support? #2

Open FlotingDream opened 3 weeks ago

FlotingDream commented 3 weeks ago

FP16 and use_amp support

FlotingDream commented 3 weeks ago

AMP support can be done like this:

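import torch
from torch.cuda.amp import custom_fwd, custom_bwd

import pix_layer_cuda  # compiled PiX CUDA extension


# backward() receives one gradient per forward() output, in the same order,
# and returns one gradient per forward() input (None for non-tensor arguments).
class PiXOperator(torch.autograd.Function):
    @staticmethod
    # Under autocast, cast floating-point CUDA tensor inputs to FP32 and run
    # forward with autocast disabled, since the kernels are FP32-only.
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, zeta: int, tau: float, input: torch.Tensor, p: torch.Tensor):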
        outputs = pix_layer_cuda.forward(zeta, tau, input, p)
        ctx.save_for_backward(input, p)
        ctx.zeta = zeta
        ctx.tau = tau
        return outputs[0]

    @staticmethod
    @custom_bwd
    def backward(ctx, out_grad):
        input, p = ctx.saved_tensors
        zeta = ctx.zeta
        tau = ctx.tau
        input_grad, fusion_prob_grad = pix_layer_cuda.backward(zeta, tau, input, p, out_grad)
        # zeta and tau are non-tensor arguments, so they get no gradient.
        return None, None, input_grad, fusion_prob_grad

However, I would like to know whether the CUDA code can support an FP16 operator and FP16 inference, and ideally ONNX and TensorRT support as well, e.g. via plugins like https://github.com/5had3z/CerberusNet/tree/master/runtime/cerberus_net/trt_plugins

please help, thx!

ashishkumar822 commented 2 weeks ago

Thanks for your patience, @FlotingDream. We have not yet written CUDA kernels for FP16, but FP16 kernels could be written.
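
Until FP16 kernels exist, one possible workaround for FP16 inference is to upcast the inputs to FP32 around the operator and cast the result back. The following is only a minimal sketch of that idea, not part of the PiX code: `run_in_fp32` and `fp32_only_op` are hypothetical names, and `fp32_only_op` is a pure-PyTorch stand-in for an FP32-only kernel so the example runs without the compiled extension.

```python
import torch


def run_in_fp32(op, *tensors):
    """Call an FP32-only op with half-precision inputs (hypothetical helper)."""
    out_dtype = tensors[0].dtype
    # Upcast only floating-point tensors; leave integer tensors untouched.
    upcast = [t.float() if t.is_floating_point() else t for t in tensors]
    out = op(*upcast)
    # Cast the result back so downstream FP16 layers see the dtype they expect.
    return out.to(out_dtype)


def fp32_only_op(x, p):
    # Hypothetical stand-in: rejects anything but float32, like an FP32-only kernel.
    assert x.dtype == torch.float32 and p.dtype == torch.float32
    return x * p


x = torch.ones(2, 4, dtype=torch.float16)
p = torch.full((2, 4), 0.5, dtype=torch.float16)
y = run_in_fp32(fp32_only_op, x, p)  # y is float16 again
```

With the real operator, `op` would presumably be something like `lambda x, p: PiXOperator.apply(zeta, tau, x, p)`. The extra casts cost memory bandwidth, so this layer itself gains no FP16 speedup; it only lets the rest of the network run in half precision.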