ravioli1369 closed this 5 months ago
@deepchatterjeeligo so, apparently it's not as simple as using `self.qnmdata` to access the registered buffers; there seem to be `self.buffers()` and `self.named_buffers()` methods for iterating over them.
@EthanMarx could you confirm whether this is the right way to access them? When I used `self.qnmdata_`, I got an error saying that the parameter is not part of my class.
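For reference, a minimal sketch of the two access styles on a toy module (the `Net` class and the buffer names `mean` and `std` are illustrative, not taken from the PR): a registered buffer can be read directly as an attribute, while `buffers()` and `named_buffers()` are for iterating over all registered buffers.

```python
import torch


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registered buffers are saved in state_dict() and follow .to()/.cuda(),
        # but are not returned by .parameters()
        self.register_buffer("mean", torch.as_tensor(1.0))
        self.register_buffer("std", torch.as_tensor(2.0))


net = Net()

# Attribute access: nn.Module.__getattr__ looks the name up among registered buffers
print(net.mean)  # tensor(1.)

# named_buffers() yields (name, tensor) pairs; buffers() yields only the tensors
for name, buf in net.named_buffers():
    print(name, buf)

print(len(list(net.buffers())))  # 2
print(len(list(net.parameters())))  # 0
```

Both styles see the same tensors; iteration is mainly useful for generic code that does not know the buffer names in advance.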
@ravioli1369 When you register a buffer, it should be accessible as an attribute of the module (`self.<name>`). Can you post the error?
I agree with Ethan. For example:
```python
>>> import torch
>>> class Net(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.register_buffer('mean', torch.as_tensor(1.0))
...     def forward(self, x):
...         return x
...
>>> net = Net()
>>> net.mean
tensor(1.)
```
So `self.register_buffer('qnm_data_a', QNMData_a)` should work?
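A minimal sketch of that pattern, assuming `QNMData_a` is a tensor. The class name `QNMModule` and the placeholder values are hypothetical, since the actual PR code is not shown in this thread:

```python
import torch


class QNMModule(torch.nn.Module):  # hypothetical name, for illustration only
    def __init__(self, QNMData_a: torch.Tensor):
        super().__init__()
        # After this call the tensor is reachable as self.qnm_data_a
        self.register_buffer("qnm_data_a", QNMData_a)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Read the buffer like any other attribute
        return x * self.qnm_data_a


# Placeholder data; the real QNMData_a from the PR is not shown in the thread
module = QNMModule(torch.tensor([1.0, 2.0, 3.0]))
print(module(torch.ones(3)))  # tensor([1., 2., 3.])
print("qnm_data_a" in module.state_dict())  # True
```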
@EthanMarx @deepchatterjeeligo, I'm not sure what happened when I first tried to access it using `self`, but it now seems to be working. I've updated the code accordingly.
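The thread does not say what caused the first error. One common cause (an assumption here, not confirmed above) is a mismatch between the string passed to `register_buffer` and the attribute name used later, which raises `AttributeError`. A small sketch with hypothetical names:

```python
import torch


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("qnm_data_a", torch.zeros(3))


net = Net()

try:
    # Typo: the buffer was registered as "qnm_data_a", not "qnmdata"
    net.qnmdata
except AttributeError as err:
    print("AttributeError:", err)

# Listing the registered names is a quick way to check what actually exists
print([name for name, _ in net.named_buffers()])  # ['qnm_data_a']
```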
This PR adds the following:
- `torch.nn.Module` classes that use `self.register_buffer` to avoid initializing tensors on the CPU (registered buffers are moved along with the module by `.to(device)`).
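A minimal sketch of why buffers help here, using hypothetical buffer names `grid` and `scratch`: the tensor is created once in `__init__`, and `Module.to()` moves and casts it together with the module, so `forward` never has to build a new tensor on the CPU. The sketch falls back to the CPU when no GPU is available.

```python
import torch


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Created once at construction time; .to() later moves/casts it with the module
        self.register_buffer("grid", torch.linspace(0.0, 1.0, 5))
        # persistent=False: still moved by .to(), but left out of state_dict()
        self.register_buffer("scratch", torch.zeros(5), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No torch.tensor(...) call here, so nothing is re-created on the CPU per call
        return x + self.grid


device = "cuda" if torch.cuda.is_available() else "cpu"
net = Net().to(device=device, dtype=torch.float64)

print(net.grid.device.type, net.grid.dtype)  # "cuda torch.float64" or "cpu torch.float64"
print(sorted(net.state_dict().keys()))  # ['grid']
```

`persistent=False` is worth considering for buffers that can be recomputed in `__init__` and do not need to be stored in checkpoints.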