HabanaAI / Model-References

Reference models for Intel(R) Gaudi(R) AI Accelerator

training time is slow because of PReLU #14

Closed: anti-machinee closed this issue 2 years ago

anti-machinee commented 2 years ago
greg-serochi commented 2 years ago

Hi @anti-machinee, the SynapseAI software stack does not yet support PReLU, so this op falls back to the CPU and takes more execution time. We do support standard ReLU, which is why it's 6x faster when using ReLU instead.

The team did a check on some options here. Moving your ops to ReLU would be ideal, but there are other options to try, inspired by this identity: prelu(x) = relu(x) + alpha * (x - abs(x)) * 0.5 (a quick check of the identity is sketched below).
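
For reference, a quick way to verify that identity against PyTorch's built-in PReLU on CPU (a minimal sketch; the 0.25 slope simply mirrors nn.PReLU's default init):

import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 2.0, -3.0, 0.0])
alpha = torch.tensor([0.25])

reference = F.prelu(x, alpha)                             # built-in PReLU
rewritten = F.relu(x) + alpha * (x - torch.abs(x)) * 0.5  # rewrite using supported ops
print(torch.allclose(reference, rewritten))  # True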

By the way, you can see the supported PyTorch ops in our documentation here.

anti-machinee commented 2 years ago

@greg-serochi PReLU in PyTorch has an alpha, which is a learnable parameter. If I use PReLU only as the formulation PReLU(x) = max(0, x) + a * min(0, x), the speed is the same as ReLU, but I want to use PReLU with a trainable parameter. Can you explain:

  1. Why does PReLU run slowly with the SynapseAI software but fast with CUDA?
  2. Are there any tricks to use a learnable parameter in PReLU as in default PyTorch?

ssarkar2 commented 2 years ago

Hi @anti-machinee

Why does PReLU run slowly with SynapseAI but fast with CUDA: PReLU is falling back to the CPU because Gaudi does not support this op. To work around this for now, we have the following suggestions: a) use ReLU if it suits your needs, or b) implement PReLU using basic ops, something similar to this; see @greg-serochi's answer above.

Are there any tricks to use a learnable parameter in PReLU as in default PyTorch: Please refer to custom nn.Modules with learnable parameters, such as the ones mentioned in this discussion. Create an nn.Module to represent Prelu, with an nn.Parameter for "alpha" in the constructor.

For example:

import torch
import torch.nn as nn

class Prelu(nn.Module):
    def __init__(self):
        super(Prelu, self).__init__()
        self.alpha = nn.Parameter(torch.randn(1), requires_grad=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        # equivalent to PReLU: for x >= 0 this returns x, for x < 0 it returns alpha * x
        return self.relu(x) + self.alpha * (x - torch.abs(x)) * 0.5

torch.manual_seed(1)

module = Prelu()
print([i for i in module.parameters()])
#[Parameter containing:
#tensor([0.6614], requires_grad=True)]

x = torch.tensor([1,2,-3,0])
res = module(x)
print(res)
#tensor([ 1.0000,  2.0000, -1.9841,  0.0000], grad_fn=<AddBackward0>)

From the definition of PReLU in the PyTorch docs here: a is a learnable parameter. When called without arguments, nn.PReLU() uses a single parameter a across all input channels. If called with nn.PReLU(nChannels), a separate a is used for each input channel.

In my example, alpha is a single number used for all elements of x, but if you want, you can make "a" the same size as x, or have "a" be of length number-of-channels, etc., as sketched below.
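
For example, a minimal sketch of a per-channel variant (assuming NCHW input; the class name PreluPerChannel is just for illustration):

import torch
import torch.nn as nn

class PreluPerChannel(nn.Module):
    def __init__(self, num_channels):
        super(PreluPerChannel, self).__init__()
        # one learnable slope per channel, initialized to 0.25 like nn.PReLU
        self.alpha = nn.Parameter(torch.full((num_channels,), 0.25))
        self.relu = nn.ReLU()

    def forward(self, x):
        # reshape to (1, C, 1, 1) so the slopes broadcast over NCHW input
        a = self.alpha.view(1, -1, 1, 1)
        return self.relu(x) + a * (x - torch.abs(x)) * 0.5

module = PreluPerChannel(3)
x = torch.randn(2, 3, 4, 4)
print(module(x).shape)  # torch.Size([2, 3, 4, 4])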

You could also look at this implementation of PReLU in the PyTorch code and edit the forward function as in my example above.

anti-machinee commented 2 years ago

Thank you @ssarkar2. I wrote PReLU myself and it works; training time decreased from 7 hours (8 Gaudi) to 2.5 hours (8 Gaudi). The model is different from the one using the default PReLU, but accuracy does not change much. One more note: I use the PyTorch PReLU formulation, which is different from the formulation you used.

import torch
import torch.nn as nn

class PReLUCustom(nn.Module):
    def __init__(self, num_parameters):
        super(PReLUCustom, self).__init__()
        # one learnable slope per channel; torch.FloatTensor is uninitialized,
        # so initialize it explicitly (0.25 matches nn.PReLU's default)
        self.weight = nn.Parameter(torch.FloatTensor(num_parameters))
        nn.init.constant_(self.weight, 0.25)
        self.relu = nn.ReLU()

    def forward(self, _input):
        # PReLU(x) = relu(x) + w * min(0, x), with w broadcast over the channel dim
        zero_tmp = torch.zeros_like(_input)
        x = self.relu(_input) + self.weight.unsqueeze(1).unsqueeze(2).expand_as(_input) * torch.min(zero_tmp, _input)
        return x

ssarkar2 commented 2 years ago

Yes, your implementation prelu(x) = relu(x) + w * min(0, x) looks correct as well. If you are unblocked, please feel free to close this issue, and you can keep an eye out here to see if we support PReLU in the future.
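
As a quick sanity check (a sketch assuming the PReLUCustom class above is defined, with the built-in module's slopes copied over so the comparison is fair), it matches nn.PReLU on CPU:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 5, 5)      # NCHW input
reference = nn.PReLU(3)          # built-in, one slope per channel
custom = PReLUCustom(3)
with torch.no_grad():
    custom.weight.copy_(reference.weight)   # use identical slopes
print(torch.allclose(reference(x), custom(x)))  # True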

anti-machinee commented 2 years ago

Will PReLU be supported soon?