ML4GW / ml4gw

Torch utilities for doing machine learning in gravitational wave physics

Refactor TaylorF2 and IMRPhenomD #137

Closed. ravioli1369 closed this 5 months ago.

ravioli1369 commented 5 months ago

This PR adds the following:

ravioli1369 commented 5 months ago

@deepchatterjeeligo so, apparently it's not as simple as using self.qnmdata to access the registered buffers; there seem to be self.buffers and self.named_buffers methods for iterating over them. @EthanMarx, could you confirm whether this is the right way to access them? When I used self.qnmdata_, I got an error saying that the parameter is not part of my class.
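A minimal sketch comparing the two ways of reaching a registered buffer. The module and buffer names (`QNMModule`, `qnmdata`) are hypothetical stand-ins for the PR's code. The point is that `named_buffers()` is only needed for iteration; a single buffer can be read directly as an attribute, and both return the same tensor object.

```python
import torch


class QNMModule(torch.nn.Module):
    # Hypothetical module; "qnmdata" stands in for the PR's actual buffer
    def __init__(self):
        super().__init__()
        self.register_buffer("qnmdata", torch.arange(4, dtype=torch.float32))

    def forward(self, x):
        # Direct attribute access to the registered buffer
        return x * self.qnmdata


m = QNMModule()

# Iterating: named_buffers() yields (name, tensor) pairs
names = [name for name, _ in m.named_buffers()]
print(names)  # ['qnmdata']

# The attribute and the iterated entry are the same tensor object
same = dict(m.named_buffers())["qnmdata"] is m.qnmdata
print(same)  # True

print(m(torch.ones(4)))  # tensor([0., 1., 2., 3.])
```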

EthanMarx commented 5 months ago

@ravioli1369 When you register a buffer, it should be accessible as an attribute of the module (self.<name>). Can you post the error?
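One plausible cause of a "not part of my class" error is a mismatch between the name passed to `register_buffer` and the attribute read later, such as `qnmdata` versus `qnmdata_`. This is an assumption, not a diagnosis of the PR. In the sketch below, which uses hypothetical names, one name is registered and a different one is read. `nn.Module` then raises `AttributeError` for the unregistered name. The exact message text varies between PyTorch versions.

```python
import torch


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registered under the (hypothetical) name "qnmdata"
        self.register_buffer("qnmdata", torch.zeros(3))


net = Net()

# The registered name resolves to the buffer tensor
print(net.qnmdata.shape)  # torch.Size([3])

# A slightly different name was never registered, so lookup fails
try:
    net.qnmdata_
    error = None
except AttributeError as exc:
    error = str(exc)
print(error)
```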

deepchatterjeeligo commented 5 months ago

I agree with Ethan. For example:

```python
>>> import torch
>>> class Net(torch.nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.register_buffer('mean', torch.as_tensor(1.0))
...     def forward(self, x):
...         return x
...
>>> net = Net()
>>> net.mean
tensor(1.)
```

So something like self.register_buffer('qnm_data_a', QNMData_a) should work, and the buffer can then be read as self.qnm_data_a?
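A minimal sketch of why `register_buffer` is preferable to a plain tensor attribute. Buffers follow the module through dtype and device conversions, and by default they appear in its `state_dict`. The sketch assumes `QNMData_a` is a precomputed tensor but fills it with placeholder values. The module name `Ringdown` is hypothetical, and the `persistent=False` variant is shown only for comparison.

```python
import torch


class Ringdown(torch.nn.Module):
    # Hypothetical module; qnm_data_a mirrors the name suggested above
    def __init__(self, qnm_data_a: torch.Tensor):
        super().__init__()
        # Persistent buffer: saved in state_dict and converted by .to()/.double()
        self.register_buffer("qnm_data_a", qnm_data_a)
        # Non-persistent buffer: converted with the module but not saved
        self.register_buffer("scratch", torch.zeros(2), persistent=False)
        # Plain tensor attribute: neither converted nor saved
        self.plain = torch.ones(2)


QNMData_a = torch.linspace(0.0, 1.0, 5)  # placeholder values, not real QNM data
model = Ringdown(QNMData_a)

keys = sorted(model.state_dict().keys())
print(keys)  # ['qnm_data_a']

model = model.double()
print(model.qnm_data_a.dtype)  # torch.float64
print(model.scratch.dtype)  # torch.float64
print(model.plain.dtype)  # torch.float32
```

The same reasoning applies to `model.to(device)`. Data stored as buffers moves with the module, while a plain attribute stays on the original device.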

ravioli1369 commented 5 months ago

@EthanMarx @deepchatterjeeligo, I'm not sure what happened when I first tried to access it through self, but it seems to be working now. I've updated the code accordingly.