pmelchior / spender

Spectrum encoder and decoder
MIT License

Minor issues running README example #36

Closed jwuphysics closed 1 year ago

jwuphysics commented 1 year ago

I tried running the examples in the README using PyTorch version 1.12.1 on a CUDA-enabled machine. I ran into two small issues at this point:

spec, w, z, ids, norm, zerr = SDSS.make_batch(data_path, ids)
with torch.no_grad():
    s, spec_rest, spec_reco = model._forward(spec, instrument=sdss, z=z)

The first was that the model.wave_rest tensor was on the GPU while z was on the CPU, which triggered a RuntimeError here:

File ~/research/spender/spender/model.py:333, in SpectrumDecoder.transform(self, x, instrument, z)
    (...)
--> 333     wave_redshifted = (self.wave_rest.unsqueeze(1) * (1 + z)).T

This can be easily resolved with a self.wave_rest.unsqueeze(1).cpu() call.

Once that's fixed, we hit a second snag a few lines down (here I'm showing the entire error stack):

RuntimeError                              Traceback (most recent call last)
Cell In[4], line 2
      1 with torch.no_grad():
----> 2     s, spec_rest, spec_reco = model._forward(spec, instrument=sdss, z=z)

File ~/research/spender/spender/model.py:426, in BaseAutoencoder._forward(self, y, instrument, z, s, aux)
    423     instrument = self.encoder.instrument
    425 x = self.decode(s)
--> 426 y = self.decoder.transform(x, instrument=instrument, z=z)
    428 return s, x, y

File ~/research/spender/spender/model.py:340, in SpectrumDecoder.transform(self, x, instrument, z)
    337 else:
    338     wave_obs = instrument.wave_obs
--> 340 spectrum = Interp1d()(wave_redshifted, x, wave_obs)
    341 
    342 # convolve with LSF
    343 if instrument.lsf is not None:

File ~/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch/autograd/function.py:315, in Function.__call__(self, *args, **kwargs)
    314 def __call__(self, *args, **kwargs):
--> 315     raise RuntimeError(
    316         "Legacy autograd function with non-static forward method is deprecated. "
    317         "Please use new-style autograd function with static forward method. "
    318         "(Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)")

RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)

In the updated syntax, we can simply replace line model.py:340 with:

spectrum = Interp1d.apply(wave_redshifted, x, wave_obs)

and it all works.
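For anyone unfamiliar with the difference, here is a minimal sketch of the new-style pattern that the error message asks for. The Scale class is a hypothetical toy, not spender's Interp1d; it only illustrates that forward and backward are static methods and that the function is invoked through .apply rather than by instantiating and calling it.

```python
import torch


class Scale(torch.autograd.Function):
    """Hypothetical toy op standing in for a custom function such as Interp1d."""

    @staticmethod
    def forward(ctx, x, factor):
        # Store what backward needs on the context object
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input; None for the non-tensor factor
        return grad_output * ctx.factor, None


x = torch.ones(3, requires_grad=True)

# New style: call the class-level apply, not Scale()(x, 2.0)
y = Scale.apply(x, 2.0)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```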

If I have time I'll make a PR, assuming you don't get to it first.

pmelchior commented 1 year ago

Hi John, thanks for the report. We're aware of both problems. For the second one, there's PR #35; I just haven't found the time to merge it yet.

For the first one, we have a new branch, new-aux, which will organize the input and output data in dicts and load them all with the same method, so you wouldn't have z on a different device than the rest. If you can wait a little longer, both of these problems should be fixed.

jwuphysics commented 1 year ago

Sounds great, and glad to hear you're on it! Will close this for now, then.

pmelchior commented 1 year ago

The latest main branch has the fix in for the RuntimeError. However, it looks like using dicts as input gives pretty poor performance for our data loader. Until we solve that, I recommend making sure to have all inputs to the spender functions reside on the same device. In your case, you want to make sure that z is on the GPU.
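A minimal sketch of that recommendation, assuming plain tensors: the to_device helper and the stand-in wave_rest values are hypothetical and not part of spender. It moves z to whatever device the rest-frame wavelengths live on before repeating the redshifting expression from the traceback above.

```python
import torch


def to_device(device, *tensors):
    """Hypothetical helper: return the given tensors moved to one device."""
    return tuple(t.to(device) for t in tensors)


# Use the GPU when available; fall back to CPU so the sketch runs anywhere
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for model.wave_rest (made-up values), living on the model's device
wave_rest = torch.linspace(3000.0, 7000.0, 5, device=device)

# Redshifts as they might come out of a data loader: created on the CPU
z = torch.tensor([0.1, 0.2])

# Move z to the same device as wave_rest before combining them
(z,) = to_device(wave_rest.device, z)

# Same expression as in SpectrumDecoder.transform: (n_wave, 1) * (batch,) -> (n_wave, batch), then transpose
wave_redshifted = (wave_rest.unsqueeze(1) * (1 + z)).T
print(wave_redshifted.shape)  # torch.Size([2, 5])
```

With the README example, the same idea would amount to calling .to(...) on spec and z after make_batch and before model._forward.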