thjashin / multires-conv

Sequence Modeling with Multiresolution Convolutional Memory (ICML 2023)
MIT License

A couple of fixes and generative applications #3

Open jbmaxwell opened 1 year ago

jbmaxwell commented 1 year ago

First of all, congratulations on what looks like very exciting work!

I wanted to check out your generative modelling application by running autoregressive.py; however, I had to change a couple of things to get it running: 1) I got an error that MultiresLayer doesn't take the mixing argument. Since it was being passed False, I figured I could probably just comment it out, and the script then runs. 2) I also got an error where forward() calls out = self.output_mapping(x), since output_mapping doesn't seem to exist. I guessed this should probably be out = self.decoder(x), which seems to be running... and converging. :) However, if I'm wrong about these "fixes", please let me know!

More generally, I'm very curious about the application of this model in generative contexts. I'm specifically in the music and audio field, where powerful sequence models are (obviously) essential.

So, a few things I'm wondering about: 1) Is there a relatively painless way to get images from the autoregressive_eval.py script, so I can see the output? (I can obviously dig in and figure this out, but if there's a quick mod you can suggest, that would be great.) 2) Do you see opportunities for conditional generation? 3) Since autoregressive generation can be somewhat limited, do you see opportunities for different training methods, for example "infilling" generation via masking?

Again, thanks for your work.

thjashin commented 1 year ago

Hi @jbmaxwell,

Yes, your fixes are right. I messed up a few things when merging the code from different experiments. And yes, there's a way, but it has to be the autoregressive way: basically, feeding each new prediction back in as the next input. I can work on an example when I have time, if that would be helpful.
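To make "feeding each new prediction back in as the next input" concrete, here is a minimal sketch of the loop structure only. It assumes a causal model that maps a full-length buffer to one probability distribution per position. `sample_autoregressive` and `toy_causal_model` are hypothetical names, not code from the repo, and the real autoregressive_eval.py samples pixel values from a logistic mixture rather than from a small categorical distribution.

```python
import random
from typing import Callable, List

# A "model" here maps the full-length buffer to one probability list per position.
Model = Callable[[List[int]], List[List[float]]]


def sample_autoregressive(model: Model, length: int, num_values: int, seed: int = 0) -> List[int]:
    """Fill a sequence left to right, feeding each new prediction back in as input."""
    rng = random.Random(seed)
    buf = [0] * length  # placeholder values for positions not generated yet
    for t in range(length):
        dists = model(buf)  # one full forward pass per step (simple but slow)
        # In a causal model, dists[t] depends only on buf[:t], so the
        # placeholders at positions >= t cannot leak into this prediction.
        buf[t] = rng.choices(range(num_values), weights=dists[t], k=1)[0]
    return buf


def toy_causal_model(buf: List[int]) -> List[List[float]]:
    """Hypothetical causal model over 4 values: predicts (previous value + 1) % 4."""
    num_values = 4
    dists = []
    for t in range(len(buf)):
        prev = buf[t - 1] if t > 0 else num_values - 1
        probs = [0.0] * num_values
        probs[(prev + 1) % num_values] = 1.0
        dists.append(probs)
    return dists


print(sample_autoregressive(toy_causal_model, length=8, num_values=4))
# -> [0, 1, 2, 3, 0, 1, 2, 3]
```

Rerunning the whole forward pass at every step costs one pass per position (1024 for a flattened 32x32 image), but it is the simplest correct version of the idea.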

Since MultiresConv is just a sequence modeling layer, it is applicable to any method or task that needs a sequence-to-sequence mapping. You could certainly do conditional generation with this AR model example (no differently from other AR models), as well as "infilling"/diffusion-type methods (where the sequence-to-sequence map is used for one step of infilling).
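As a rough illustration of how a masked ("infilling") training example and a simple form of conditioning could be fed to a sequence-to-sequence layer, here is a sketch under stated assumptions. The scheme (zero out hidden positions, add a mask-indicator channel, append a one-hot class label broadcast along the sequence as extra channels) is one common choice, not something the repo implements, and `make_infilling_input` is a hypothetical helper. Unlike the AR setup, an infilling model sees the whole partially hidden sequence, and the loss would be computed only on the hidden positions.

```python
import random
from typing import List, Optional, Tuple

Seq = List[List[float]]  # channels x length, mirroring one (C, L) sample


def make_infilling_input(x: Seq, mask_frac: float, seed: int = 0,
                         label: Optional[int] = None,
                         num_classes: int = 0) -> Tuple[Seq, List[bool]]:
    """Build the input for one masked ("infilling") training example.

    Returns (model_input, hidden): hidden[t] is True where the model must
    reconstruct x, and a training loss would be computed only there.
    """
    rng = random.Random(seed)
    length = len(x[0])
    hidden_idx = set(rng.sample(range(length), round(mask_frac * length)))
    hidden = [t in hidden_idx for t in range(length)]

    # Data channels with the hidden positions zeroed out.
    model_input = [[0.0 if hidden[t] else row[t] for t in range(length)] for row in x]
    # Indicator channel telling the model which positions were hidden.
    model_input.append([1.0 if h else 0.0 for h in hidden])

    # Optional conditioning: a one-hot class label broadcast along the sequence.
    if label is not None:
        for k in range(num_classes):
            model_input.append([1.0 if k == label else 0.0] * length)
    return model_input, hidden


x = [[1.0] * 10, [2.0] * 10]  # 2 channels, length 10
inp, hidden = make_infilling_input(x, mask_frac=0.3, seed=1, label=2, num_classes=4)
print(len(inp), sum(hidden))  # -> 7 3
```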

Thanks for your interest in our work!

jbmaxwell commented 1 year ago

Thanks for the speedy reply!

By "there's a way", I'm guessing you mean generating images from the output? I see that the input shape is (64, 3, 1024), which I assume is a batch of images reshaped into three "sequences" of pixels (R, G, and B sequences, I mean). The output is (64, 100, 1024), which I see comes from the decoder... So yes, if you could put together a quick example, or even point me in the right direction, that would be great!
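A hedged reading of those shapes: 1024 = 32 × 32, so each input channel is probably one colour plane of a 32x32 image flattened in raster (row-major) order. The 100 output channels look like distribution parameters rather than pixels, so a sample would be drawn from them first (which is what sample_from_discretized_mix_logistic is for) before reshaping. The sketch below covers only the last step, turning a sampled (3, 1024) sequence back into a (3, 32, 32) image. It assumes raster order and samples scaled to [-1, 1] (as in PixelCNN++-style preprocessing); both are assumptions, and the helper names are hypothetical.

```python
from typing import List


def unflatten_raster(seq: List[int], height: int, width: int) -> List[List[int]]:
    """Turn one flattened channel (length height * width) back into rows of pixels."""
    if len(seq) != height * width:
        raise ValueError(f"expected {height * width} values, got {len(seq)}")
    # Row-major (raster) order: position t is row t // width, column t % width.
    return [seq[r * width:(r + 1) * width] for r in range(height)]


def to_pixel(value: float) -> int:
    """Map a sample assumed to lie in [-1, 1] to an 8-bit value in [0, 255]."""
    clipped = max(-1.0, min(1.0, value))
    return round((clipped + 1.0) * 127.5)


def sample_to_image(sample: List[List[float]], height: int = 32, width: int = 32):
    """(C, H*W) sample -> (C, H, W) nested lists of 8-bit pixel values."""
    return [unflatten_raster([to_pixel(v) for v in channel], height, width)
            for channel in sample]


fake = [[-1.0, 0.0, 1.0, 0.5]] * 3  # 3 channels, a 2x2 "image"
print(sample_to_image(fake, height=2, width=2)[0])  # -> [[0, 128], [255, 191]]
```

The resulting (C, H, W) values can then be handed to whatever image library you already use for saving; with tensors, the same step is a single reshape to (batch, 3, 32, 32).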

thjashin commented 1 year ago

Hi @jbmaxwell, please see autoregressive_eval.py for an example of generating unconditional samples from the trained model.

jbmaxwell commented 1 year ago

I'm hitting an error:

Traceback (most recent call last):
  File "/home/james/src/somms/multires-conv/autoregressive_eval.py", line 208, in <module>
    mp.spawn(main, args=(world_size, args), nprocs=world_size)
  File "/home/james/anaconda3/envs/multires/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 239, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/james/anaconda3/envs/multires/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
    while not context.join():
  File "/home/james/anaconda3/envs/multires/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/james/anaconda3/envs/multires/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/james/src/somms/multires-conv/autoregressive_eval.py", line 164, in main
    samples = sample(rank, model, 64, data_shape)
  File "/home/james/src/somms/multires-conv/autoregressive_eval.py", line 68, in sample
    out = sample_from_discretized_mix_logistic(out.reshape(*out.shape[:2], *data_shape[1:]), model.module.nr_logistic_mix)
  File "/home/james/anaconda3/envs/multires/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'MultiresAR' object has no attribute 'nr_logistic_mix'

I can see that nr_logistic_mix is clearly a parameter of MultiresAR, so I'm not sure what's going on. Maybe it's mistaking the object type, but I'm not sure how...?

(Just as an initial test, I first trained for 50 epochs with autoregressive.py, btw.)

UPDATE: Actually, although __init__ takes that argument, I see it doesn't store it, so the module can't provide it to the sample() function. I got it running just by storing it in the class.
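For anyone hitting the same AttributeError: a constructor argument is only a local variable unless __init__ assigns it to self. The sketch below uses plain-Python stand-ins (`BrokenAR`, `FixedAR`, `Wrapper` and `read_nr_mix` are hypothetical, not the repo's classes); with torch.nn.Module, a plain int assigned in __init__ is stored the same way. The `Wrapper` mimics a parallel wrapper such as DistributedDataParallel, which exposes the original model as `.module`; that is presumably why the traceback reads `model.module.nr_logistic_mix`.

```python
class BrokenAR:
    """Stand-in for the original class: accepts nr_logistic_mix but never stores it."""

    def __init__(self, nr_logistic_mix: int = 10):
        # Hypothetical use of the argument; it only survives inside this derived value.
        self.out_channels = 10 * nr_logistic_mix


class FixedAR:
    """Same constructor, but the argument is kept as an attribute (the fix)."""

    def __init__(self, nr_logistic_mix: int = 10):
        self.nr_logistic_mix = nr_logistic_mix
        self.out_channels = 10 * nr_logistic_mix


class Wrapper:
    """Mimics DistributedDataParallel, which exposes the wrapped model as .module."""

    def __init__(self, module):
        self.module = module


def read_nr_mix(model) -> int:
    # Same access pattern as sample() in the traceback: model.module.nr_logistic_mix
    return model.module.nr_logistic_mix


print(read_nr_mix(Wrapper(FixedAR(10))))  # -> 10
try:
    read_nr_mix(Wrapper(BrokenAR(10)))
except AttributeError as err:
    print("AttributeError:", err)
```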

thjashin commented 1 year ago

Hi @jbmaxwell, yep, the solution is just storing that in the class!

thjashin commented 1 year ago

I updated the code. Please let me know if there are any other issues!

jbmaxwell commented 1 year ago

Actually, I was curious about training on data with different shapes. I tried to get non-image data (shape = (bs, 32, 512)) working, but kept hitting errors in the loss calculation.
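One hedged guess about those loss errors: the names sample_from_discretized_mix_logistic and nr_logistic_mix suggest the PixelCNN++ discretized logistic mixture. Its reference implementation appears to be written for 3-channel (RGB) images, with each mixture component carrying one logit, a mean and log-scale per channel, and coefficients linking the channels (R to G, R to B, G to B). The sketch below only does the parameter bookkeeping for C channels. It assumes that the "one coefficient per pair of channels" pattern generalizes, which is an assumption and not something the repo does, and that there are 10 components, which reproduces the 100 output channels reported earlier in the thread. `mixture_param_channels` and `check_head` are hypothetical helpers.

```python
def mixture_param_channels(num_channels: int, nr_mix: int, couple_channels: bool = True) -> int:
    """Decoder output channels a discretized logistic mixture head would need.

    Per mixture component: 1 mixture logit, plus a mean and a log-scale for each
    channel, plus (if couple_channels) one linear coefficient per pair of channels,
    generalizing PixelCNN++'s R->G, R->B and G->B coefficients.
    """
    per_component = 1 + 2 * num_channels
    if couple_channels:
        per_component += num_channels * (num_channels - 1) // 2
    return nr_mix * per_component


def check_head(output_shape, data_shape, nr_mix: int) -> None:
    """Raise early if the decoder output doesn't fit the mixture parameterization."""
    _, out_ch, out_len = output_shape
    _, data_ch, data_len = data_shape
    if out_len != data_len:
        raise ValueError(f"sequence lengths differ: {out_len} vs {data_len}")
    expected = mixture_param_channels(data_ch, nr_mix)
    if out_ch != expected:
        raise ValueError(f"decoder gives {out_ch} channels; a mixture over "
                         f"{data_ch} data channels needs {expected}")


print(mixture_param_channels(3, 10))   # -> 100, matching the (64, 100, 1024) output
print(mixture_param_channels(32, 10))  # -> 5610
check_head((64, 100, 1024), (64, 3, 1024), nr_mix=10)  # passes silently
```

Under these assumptions the coupling terms grow quadratically with the channel count, so for 32-channel data a loss that treats channels independently (couple_channels=False, 650 channels with 10 components) may be the more practical starting point.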