Closed JingliangGao closed 7 months ago
Good question: dlprimitives provides several functions for activation, so I tried to do something generic. After some time I reverted to a direct implementation, but this remained.
So the answer is: historically, and it can probably be implemented like all the other functions without a problem.
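The "generic" approach mentioned above can be pictured, in toy form, as a table mapping an activation enum to an elementwise function. This is a hypothetical sketch for illustration only; the names (`StandardActivations`, `activation_forward`) mimic the dlprimitives identifiers seen later in this thread but the implementation is not dlprimitives' actual API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <map>
#include <vector>

// Toy sketch of a generic activation dispatcher: one table of
// elementwise functions instead of one hand-written op per activation.
// Names are made up for illustration, not dlprimitives' real API.
enum class StandardActivations { relu, tanh_act };

std::vector<float> activation_forward(const std::vector<float>& x,
                                      StandardActivations act) {
    static const std::map<StandardActivations,
                          std::function<float(float)>> table = {
        {StandardActivations::relu,
         [](float v) { return v > 0 ? v : 0.0f; }},
        {StandardActivations::tanh_act,
         [](float v) { return std::tanh(v); }},
    };
    const auto& f = table.at(act);
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = f(x[i]);
    return y;
}
```

The trade-off is exactly the one described above: the table keeps per-activation code tiny, but a direct implementation per operator is often simpler to read and to register.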
In order to register the relu operator in PrivateUse1, I rewrote the function like this:

```cpp
Tensor & relu(Tensor & self)
{
    GUARD;
    dlprim::Tensor X = todp(self);
    Tensor out = new_tensor_as(X.shape(),self);
    dlprim::Tensor Y = todp(out);
    dlprim::core::activation_forward(X,Y,dlprim::StandardActivations::relu,getExecutionContext(self));
    sync_if_needed(self.device());
    return self;
}
```
I succeeded in building the project, but it failed at runtime as follows.
Could you provide more details about how to design a custom operator and then register it? BTW, is it also possible to register the linear and max_pool2d operators in PrivateUse1?
As far as I remember, max_pool2d and linear use custom ops because of slightly different input/output/gradient implementations for backpropagation between the Torch operators and dlprimitives.
That is why a custom backward function is applied. I used AutogradPrivateUse1 any time I needed a slightly different backpropagation. I don't recall why I used it for ReLU, but I assume I used it this way for max pooling because of the different way I handle the location of the "max" item for backpropagation.
I assume it could be rewritten differently, but you would likely need to change the kernels themselves in dlprimitives and run all the tests.
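The max-pooling bookkeeping mentioned above can be sketched as follows: the forward pass records where each "max" came from, and the backward pass routes gradients only to those locations. This is a self-contained toy (1-D pooling, window 2, stride 2, `std::vector` in place of tensors), not the actual dlprimitives kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy 1-D max pooling (window = 2, stride = 2). The forward pass saves
// the index of each selected "max" element; the backward pass scatters
// incoming gradients back to exactly those positions. Different ways of
// recording these locations are why a custom backward is needed.
struct MaxPool1D {
    std::vector<std::size_t> argmax;  // saved in forward, used in backward

    std::vector<float> forward(const std::vector<float>& x) {
        std::vector<float> y;
        argmax.clear();
        for (std::size_t i = 0; i + 1 < x.size(); i += 2) {
            std::size_t best = (x[i] >= x[i + 1]) ? i : i + 1;
            argmax.push_back(best);
            y.push_back(x[best]);
        }
        return y;
    }

    // Gradient flows only to the recorded max positions; all other
    // input positions get zero gradient.
    std::vector<float> backward(const std::vector<float>& grad_y,
                                std::size_t input_size) {
        std::vector<float> grad_x(input_size, 0.0f);
        for (std::size_t j = 0; j < grad_y.size(); ++j)
            grad_x[argmax[j]] += grad_y[j];
        return grad_x;
    }
};
```

If the framework's built-in backward expects the max locations in a different format than the kernel produces, the shapes of forward and backward no longer line up, which is the motivation for registering a custom autograd op instead.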
Can you explain why it bothers you?
```cpp
Tensor & relu(Tensor & self)
```

I assume you are replacing aten::relu - it requires self to be const:

```cpp
Tensor relu(const Tensor & self); // {"schema": "aten::relu(Tensor self) -> Tensor", "dispatch": "True", "default": "False"}
```
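The schema `aten::relu(Tensor self) -> Tensor` is out-of-place: the input is taken by const reference and a new tensor is returned (note the earlier snippet also does `return self;` where it should return the freshly allocated `out`). A minimal toy sketch of the expected shape, with `std::vector<float>` standing in for `at::Tensor`:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in: std::vector<float> plays the role of at::Tensor.
// An out-of-place op like aten::relu must take self by const reference
// and return a newly allocated result - never self itself (that is the
// job of the in-place variant, relu_).
std::vector<float> relu(const std::vector<float>& self) {
    std::vector<float> out(self.size());
    for (std::size_t i = 0; i < self.size(); ++i)
        out[i] = self[i] > 0 ? self[i] : 0.0f;
    return out;  // return the new tensor, not self
}
```

Matching the signature from native_functions.yaml exactly (const-ness included) is what lets the dispatcher accept the registration.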
Hi, in the pointwiseops.cpp file, you register the relu operator under AutogradPrivateUse1, which is different from the other operators. Why not register the relu operator under PrivateUse1, just like the tan and tanh_ operators, which are registered under PrivateUse1?