openai / jukebox

Code for the paper "Jukebox: A Generative Model for Music"
https://openai.com/blog/jukebox/

Upsampling in parts? #133

Open ladis212 opened 4 years ago

ladis212 commented 4 years ago

Can I, let's say, upsample half of the generated sample on one day, and then continue from the checkpoint and do the other half the next day?

robinsloan commented 4 years ago

I'm not one of the project's creators or maintainers, but I have been using it a lot and, from what I understand, this is totally possible. You would just separate your final top-level zs into chunks ahead of time, like so (I'm sure this is a horrible way of doing it; I am a very poor Python programmer):

split_point = math.floor( zs[0].shape[1] / 2 )
zs_part1 = [zs[0][:, 0:split_point], zs[1][:, 0:split_point], zs[2][:, 0:split_point]]
zs_part2 = [zs[0][:, split_point:], zs[1][:, split_point:], zs[2][:, split_point:]]

And then you would save your zs_part1 and zs_part2 separately and run them through the upsampler separately.

I suspect there will be a sonic "discontinuity" at the breakpoint so it might make sense to build in some overlap -- zs_part2 begins at split_point-1000 or something.

I'm not 100% sure that will work but it's what I would try if I was doing this!
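
The overlap idea can be sketched as plain index arithmetic. This is a minimal sketch that has not been tested against Jukebox: `split_with_overlap` is a hypothetical helper, the total length of 20000 tokens is made up, and the overlap of 1000 tokens is simply the figure floated above.

```python
import math

def split_with_overlap(total_len, frac=0.5, overlap=1000):
    """Return (start, end) token ranges for two parts of a sequence.

    Part 2 starts `overlap` tokens before the split point, so the two
    upsampled halves share some material to cross-fade or trim later.
    """
    split_point = math.floor(total_len * frac)
    part2_start = max(0, split_point - overlap)  # never go below index 0
    return (0, split_point), (part2_start, total_len)

# Example with a made-up top-level length of 20000 tokens
part1, part2 = split_with_overlap(20000)
print(part1, part2)  # (0, 10000) (9000, 20000)
```

Each range would then be applied as a column slice, e.g. `zs[2][:, start:end]`. Whether the overlap also needs to line up with the upsampler's window boundaries is not something this sketch addresses.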

samuel-larsson commented 4 years ago
split_point = math.floor( zs[0].shape[1] / 2 )
zs_part1 = [zs[0][:, 0:split_point], zs[1][:, 0:split_point], zs[2][:, 0:split_point]]
zs_part2 = [zs[0][:, split_point:], zs[1][:, split_point:], zs[2][:, split_point:]]

And then you would save your zs_part1 and zs_part2 separately and run them through the upsampler separately.

Sadly, this results in an error for me when I split it just before the upsampling oneliner. The oneliner looks like this: zs = upsample(zs_part1, labels, sampling_kwargs, [*upsamplers, top_prior], hps)

This is the output I get:

Sampling level 1
Sampling 8192 tokens for [0,8192]. Conditioning on 0 tokens
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-25-216495ebcafe> in <module>()
----> 1 zs_part1 = upsample(zs_part1, labels, sampling_kwargs, [*upsamplers, top_prior], hps)

4 frames
/usr/local/lib/python3.6/dist-packages/jukebox/sample.py in upsample(zs, labels, sampling_kwargs, priors, hps)
    137 def upsample(zs, labels, sampling_kwargs, priors, hps):
    138     sample_levels = list(range(len(priors) - 1))
--> 139     zs = _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
    140     return zs
    141 

/usr/local/lib/python3.6/dist-packages/jukebox/sample.py in _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
    100         total_length = hps.sample_length//prior.raw_to_tokens
    101         hop_length = int(hps.hop_fraction[level]*prior.n_ctx)
--> 102         zs = sample_level(zs, labels[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)
    103 
    104         prior.cpu()

/usr/local/lib/python3.6/dist-packages/jukebox/sample.py in sample_level(zs, labels, sampling_kwargs, level, prior, total_length, hop_length, hps)
     83     if total_length >= prior.n_ctx:
     84         for start in get_starts(total_length, prior.n_ctx, hop_length):
---> 85             zs = sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
     86     else:
     87         zs = sample_partial_window(zs, labels, sampling_kwargs, level, prior, total_length, hps)

/usr/local/lib/python3.6/dist-packages/jukebox/sample.py in sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
     51 
     52     # get z_conds from level above
---> 53     z_conds = prior.get_z_conds(zs, start, end)
     54 
     55     # set y offset, sample_length and lyrics tokens

/usr/local/lib/python3.6/dist-packages/jukebox/prior/prior.py in get_z_conds(self, zs, start, end)
    160             assert start % self.cond_downsample == end % self.cond_downsample == 0
    161             z_cond = zs[self.level + 1][:,start//self.cond_downsample:end//self.cond_downsample]
--> 162             assert z_cond.shape[1] == self.n_ctx//self.cond_downsample
    163             z_conds = [z_cond]
    164         else:

AssertionError: 
robinsloan commented 4 years ago

@samuel-larsson, I'm not 100% sure what's causing this error for you, but it might be interesting to throw a print statement in there, just before the error, to see what the shape of z_cond actually is. I wonder if something is off by one, owing to the rounding we're doing with the fraction?

In the meantime, here are some updated splitting/truncation functions that have been working for me; it might be worth pasting these into your Colab and trying them:

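import math
import torch as t  # the Jukebox code aliases torch as t

# these functions rely on zs, hps and vqvae already existing in the notebook

def zs_clean():
  # note: zs[0] is *all the level 0s for the samples*
  # and so forth; so, in this notebook, we only care about level 2s

  # empty out levels 0 and 1 so only the top-level (level 2) tokens remain
  zs[0] = t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda')
  zs[1] = t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda')

def take_zs_start_ending_at(frac):
  # keep the first `frac` (0.0-1.0) of the level 2 tokens
  orig_len_2 = len(zs[2][0])
  new_len_2 = math.floor( orig_len_2 * frac )
  zs[2] = zs[2][:,0:new_len_2].clone()

  zs_clean()
  # decode the truncated level 2 tokens to audio (x is local and not returned)
  x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()

def take_zs_end_starting_from(frac):
  # keep the level 2 tokens from `frac` (0.0-1.0) onward
  orig_len_2 = len(zs[2][0])
  new_len_2 = math.ceil( orig_len_2 * frac )
  # note: the upper bound orig_len_2-1 leaves out the final token
  zs[2] = zs[2][:,new_len_2:(orig_len_2-1)].clone()

  zs_clean()
  x = vqvae.decode(zs[2:], start_level=2).cpu().numpy()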
Broccaloo commented 4 years ago

@robinsloan I was able to run your code without error (at least the zs_part1), but it didn't work as intended - Jukebox upsampled the full length sample anyway instead of just the first half.

This is the code I run:

import math
split_point = math.floor( zs[0].shape[1] / 2 )
zs_part1 = [zs[0][:, 0:split_point], zs[1][:, 0:split_point], zs[2][:, 0:split_point]]
zs_part2 = [zs[0][:, split_point:], zs[1][:, split_point:], zs[2][:, split_point:]]
zs = upsample(zs_part1, labels, sampling_kwargs, [*upsamplers, top_prior], hps)

Maybe it has to do with this part of your explanation that I don't quite understand:

And then you would save your zs_part1 and zs_part2 separately and run them through the upsampler separately.

How can you "save" these? I didn't find any files named that way. Sorry if this is a super noob-question, but I don't know much more than the basics of Python.

Appreciate your efforts, best regards!

robinsloan commented 4 years ago

@Broccaloo Ah, sorry for the confusion; I always save my zs as files (using something like t.save(zs, 'zs-checkpoint.t')) and then run a separate script that reads those files and does the upsampling step. So, yeah, saving/restoring isn't relevant if you're doing it all in the same notebook. (Although I do encourage you to save checkpoints so your work isn't lost if the notebook stalls out!)

You know, at the time I wrote the little snippet you quoted & used, I did not quite understand the entire zs data structure. The functions I pasted further down in the thread reflect a better understanding, and they might do the trick for you. I'd recommend trying those functions, or, if you want to tie them back into the "split zs into two parts" idea, maybe something like:

import math
frac = 0.5 # this is the point, 0.0-1.0, where you want to split the zs
orig_len_zs_2 = len(zs[2][0])
new_len_zs_2 = math.floor( orig_len_zs_2 * frac )

zs_part1 = [ t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             zs[2][:,0:new_len_zs_2].clone() ]

zs_part2 = [ t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             zs[2][:,new_len_zs_2:].clone() ]

Following that, if you zs = upsample(zs_part1... you'll get what you expect… ONE HOPES 🤞

Broccaloo commented 4 years ago

EDIT: @robinsloan I got it to work! Initially, I got this error:

Sampling 8192 tokens for [32768,40960]. Conditioning on 4096 tokens
Primed sampling 3 samples with temp=0.99, top_k=0, top_p=0.0
128/128 [00:11<00:00, 11.44it/s]
4096/4096 [04:21<00:00, 15.64it/s]
Sampling 8192 tokens for [36864,45056]. Conditioning on 4096 tokens

---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

<ipython-input-16-e7a476d5f0c7> in <module>()
----> 1 zs = upsample(zs_part1, labels, sampling_kwargs, [*upsamplers, top_prior], hps)

4 frames

/usr/local/lib/python3.6/dist-packages/jukebox/prior/prior.py in get_z_conds(self, zs, start, end)
    160             assert start % self.cond_downsample == end % self.cond_downsample == 0
    161             z_cond = zs[self.level + 1][:,start//self.cond_downsample:end//self.cond_downsample]
--> 162             assert z_cond.shape[1] == self.n_ctx//self.cond_downsample
    163             z_conds = [z_cond]
    164         else:

AssertionError: 

But then I changed the sample_length_in_seconds from the original 60 seconds I generated to 30 seconds to reflect the new length and it worked. Thank you so much!
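
A possible explanation for why shortening the sample length cleared the assertion, based only on the lines visible in the traceback: `get_z_conds` slices the level above at `start//cond_downsample : end//cond_downsample` and asserts that the slice is `n_ctx//cond_downsample` tokens long. If the sample length still describes the full 60 seconds, later windows reach past the end of the truncated top-level tokens and the slice comes back short. The sketch below replays that arithmetic in plain Python. `cond_downsample = 4` and the top-level length of 10335 tokens (roughly half of a 60-second sample at 44.1 kHz with 128 raw samples per token) are assumptions, not values confirmed in the thread.

```python
def z_cond_length(upper_len, start, end, cond_downsample=4):
    """Length of the conditioning slice taken from the level above.

    Mirrors zs[level + 1][:, start//cond_downsample : end//cond_downsample]
    from the traceback, using a plain list in place of one tensor row.
    """
    upper_row = list(range(upper_len))
    return len(upper_row[start // cond_downsample:end // cond_downsample])

n_ctx = 8192           # window size shown in the log output
expected = n_ctx // 4  # what the assertion requires (assumed stride of 4)
upper_len = 10335      # assumed: top-level tokens left after halving

# The window that sampled successfully in the log
print(z_cond_length(upper_len, 32768, 40960) == expected)  # True
# The window that raised the AssertionError
print(z_cond_length(upper_len, 36864, 45056) == expected)  # False
```

Under these assumptions the second window asks for top-level tokens up to index 11264, which the halved sequence does not have, so the slice is shorter than the 2048 tokens the assertion expects.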

robinsloan commented 4 years ago

I'm glad it worked! In my script, I don't have to change the sample_length_in_seconds, so I wonder what's different here… 🤔

Ah, okay -- I just actually looked at my standalone upsampling script and noticed these lines:

top_prior_raw_to_tokens = 128
hps.sample_length = zs[2].shape[1] * top_prior_raw_to_tokens

So, adding those might accomplish the same thing as resetting the sample_length_in_seconds manually, or it might even be a little more reliable/flexible -- keep it in mind.
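
The relationship those two lines rely on can be written out as a small helper. This is a sketch under assumptions: `sample_length_for_tokens` and `seconds_for_tokens` are hypothetical names, 128 is the raw-samples-per-token value from the snippet above, and 44100 Hz is assumed as the sample rate. The token count is meant to come from the part being upsampled rather than from the original full-length zs.

```python
def sample_length_for_tokens(n_top_tokens, raw_to_tokens=128):
    """Raw audio samples covered by n_top_tokens top-level tokens."""
    return n_top_tokens * raw_to_tokens

def seconds_for_tokens(n_top_tokens, raw_to_tokens=128, sr=44100):
    """Approximate duration in seconds (sr is an assumed sample rate)."""
    return sample_length_for_tokens(n_top_tokens, raw_to_tokens) / sr

# Made-up example: a part holding 10335 top-level tokens
n_tokens = 10335
print(sample_length_for_tokens(n_tokens))      # 1322880
print(round(seconds_for_tokens(n_tokens), 2))  # 30.0
```

In a notebook this would correspond to something like `hps.sample_length = sample_length_for_tokens(zs_part1[2].shape[1])`, so the length always follows whichever part is about to be upsampled.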

If you produce anything interesting, post a link! 😝 Here's my Jukebox-powered EP.

leonardog27 commented 4 years ago

Good day, sirs. I continued the test done by my friend Broccaloo, using the code above for part two, and level 1 was generated successfully.

Once this run finishes, I will rename the level 1 and level 0 folders and then upsample part 1 to levels 1 and 0, with top_prior_raw_to_tokens = 64, which Broccaloo already tested for part 1.

Thank you very much for sharing your knowledge. 🤖🎼🎹👌🏼🥦👨‍💻👩‍💻🎼🤖

[Image: IMG-20201025-WA0013]

leonardog27 commented 4 years ago

I am going to generate 3 samples of the Imperial March with the 5B model (Symphonic), 8 minutes each, using upsampling by parts (4 parts). [Image: IMG_20201025_210319]

🎼🤖👨‍💻 Thank you very much.

leonardog27 commented 4 years ago

top_prior_raw_to_tokens = 32
hps.sample_length = zs[2].shape[1] * top_prior_raw_to_tokens
zs = upsample(zs_part4, labels, sampling_kwargs, [*upsamplers, top_prior], hps)

The following error message appears when I try to upsample part 4. Please help?

Sampling level 1
Sampling 8192 tokens for [0,8192]. Conditioning on 0 tokens

IndexError                                Traceback (most recent call last)
 in ()
      1 top_prior_raw_to_tokens = 32
      2 hps.sample_length = zs[2].shape[1] * top_prior_raw_to_tokens
----> 3 zs = upsample(zs_part4, labels, sampling_kwargs, [*upsamplers, top_prior], hps)

4 frames
/usr/local/lib/python3.6/dist-packages/jukebox/sample.py in upsample(zs, labels, sampling_kwargs, priors, hps)
    137 def upsample(zs, labels, sampling_kwargs, priors, hps):
    138     sample_levels = list(range(len(priors) - 1))
--> 139     zs = _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
    140     return zs
    141 

/usr/local/lib/python3.6/dist-packages/jukebox/sample.py in _sample(zs, labels, sampling_kwargs, priors, sample_levels, hps)
    100         total_length = hps.sample_length//prior.raw_to_tokens
    101         hop_length = int(hps.hop_fraction[level]*prior.n_ctx)
--> 102         zs = sample_level(zs, labels[level], sampling_kwargs[level], level, prior, total_length, hop_length, hps)
    103 
    104         prior.cpu()

/usr/local/lib/python3.6/dist-packages/jukebox/sample.py in sample_level(zs, labels, sampling_kwargs, level, prior, total_length, hop_length, hps)
     83     if total_length >= prior.n_ctx:
     84         for start in get_starts(total_length, prior.n_ctx, hop_length):
---> 85             zs = sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
     86     else:
     87         zs = sample_partial_window(zs, labels, sampling_kwargs, level, prior, total_length, hps)

/usr/local/lib/python3.6/dist-packages/jukebox/sample.py in sample_single_window(zs, labels, sampling_kwargs, level, prior, start, hps)
     51 
     52     # get z_conds from level above
---> 53     z_conds = prior.get_z_conds(zs, start, end)
     54 
     55     # set y offset, sample_length and lyrics tokens

/usr/local/lib/python3.6/dist-packages/jukebox/prior/prior.py in get_z_conds(self, zs, start, end)
    159         if self.level != self.levels - 1:
    160             assert start % self.cond_downsample == end % self.cond_downsample == 0
--> 161             z_cond = zs[self.level + 1][:,start//self.cond_downsample:end//self.cond_downsample]
    162             assert z_cond.shape[1] == self.n_ctx//self.cond_downsample
    163             z_conds = [z_cond]

IndexError: too many indices for tensor of dimension 1

This was run beforehand:

frac025 = 0.25
frac050 = 0.5 # these are the points, 0.0-1.0, where you want to split the zs
frac075 = 0.75
orig_len_zs_2 = len(zs[2][0])
new_len_zs_025 = math.floor( orig_len_zs_2 * frac025 )
new_len_zs_050 = math.floor( orig_len_zs_2 * frac050 )
new_len_zs_075 = math.floor( orig_len_zs_2 * frac075 )

zs_part1 = [ t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             zs[2][:,0:new_len_zs_025].clone() ]

zs_part2 = [ t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             zs[2][:,new_len_zs_025:new_len_zs_050].clone() ]

zs_part3 = [ t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             zs[2][:,new_len_zs_050:new_len_zs_075:].clone() ]

zs_part4 = [ t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda'),
             zs[2][:,new_len_zs_075].clone() ]

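
One thing worth checking in the four-part split (an observation, not a diagnosis confirmed in the thread): zs_part4 indexes zs[2] with the single integer new_len_zs_075, where the other parts use a start:end slice. An integer index selects one column and drops a dimension, which would be consistent with the "tensor of dimension 1" message. One way to avoid writing each boundary by hand is to generate the slices. `part_slices` below is a hypothetical helper, shown on a plain nested list rather than a tensor.

```python
import math

def part_slices(total_len, n_parts):
    """Return n_parts slice objects that cover range(total_len) in order."""
    bounds = [math.floor(total_len * i / n_parts) for i in range(n_parts)]
    bounds.append(total_len)  # the last part runs to the end
    return [slice(bounds[i], bounds[i + 1]) for i in range(n_parts)]

# Stand-in for zs[2]: 2 samples x 10 tokens
rows = [list(range(10)), list(range(10, 20))]

for s in part_slices(10, 4):
    part = [row[s] for row in rows]  # with a tensor: zs[2][:, s]
    print(s.start, s.stop, part)
```

Because every part is taken with a slice, each result keeps both dimensions (samples by tokens), and the parts together cover the whole sequence with no gaps.
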
leonardog27 commented 4 years ago

[Image: IMG-20201026-WA0016] An AssertionError appears on this attempt... :(

leonardog27 commented 4 years ago

Plan B is running as follows: [Image: IMG-20201026-WA0007] [Image: IMG-20201026-WA0018]

robinsloan commented 4 years ago

@leonardog27, I don't know if this is the root problem, but in the parts where you set hps.sample_length, you are using the shape from the old zs rather than the shape of the new zs_part2 or zs_part3 etc.

If it was me, I would get this up and running with just two parts before expanding to more. Once you have a pattern that works, it's easy to expand, but when you're doing a bunch of different/new things at once -- with more code -- it makes it more difficult to figure out what/where the problem is.

leonardog27 commented 4 years ago

Hi @robinsloan, upsampling in two parts while setting the time length manually was the only way I could get upsampling by parts to work: first split into two parts, and then split each part in two. Here are the results of the first experiment. Thank you for your help. https://www.youtube.com/watch?v=Sm53U3zo3kk&t=26s