pytorch / cppdocs

PyTorch C++ API Documentation
https://pytorch.org/cppdocs

[Pytorch C++ Usage Support] #6

Closed: vinayak618 closed this issue 5 years ago

vinayak618 commented 5 years ago

Hi. I'm trying to convert my Python model to C++, but when one class uses another class inside it, I don't know how to register that class as a submodule. In Python I have:

```python
import torch.nn as nn


class A(nn.Module):
    def __init__(self, num_features):
        super(A, self).__init__()
        self.bn = nn.BatchNorm2d(num_features)

    def forward(self, x):
        bn = self.bn(x)
        return bn

    # Not strictly needed: nn.Module.__call__ already runs hooks and calls forward().
    def __call__(self, x):
        return self.forward(x)


class B(nn.Module):
    def __init__(self, num_features):
        super(B, self).__init__()
        self.bn = CustomBatchNorm(in_channels, renorm, arg_ex)
        self.conv_gate = nn.Sequential()
        ......
```

How should I write this using the PyTorch C++ API? Right now I'm able to do this:

```cpp
struct A : torch::nn::Module {
  A(int64_t num_features) : bn(torch::nn::BatchNorm(num_features)) {}

  torch::Tensor forward(torch::Tensor x) {
    x = bn->forward(x);
    return x;
  }

  torch::nn::BatchNorm bn;
};
```

```cpp
struct BNCC_Gate : torch::nn::Module {
  BNCC_Gate() {
    bn(CustomBatchNorm(3, NULL, NULL));
    this->register_module("bn", bn);
    conv_gate = torch::nn::Sequential();
    ......
  }
};
```

But this gives me the error `undefined bn`. How should I create an object of class A and register it in class B so that B has the complete network structure?

@goldsborough @soumith @asmeurer @asuhan @xuhdev Any help would be great. Thank you

soumith commented 5 years ago

please use https://discuss.pytorch.org for questions