asteroid-team / asteroid-filterbanks

Asteroid's filterbanks :rocket:
https://asteroid-team.github.io/
MIT License
80 stars, 20 forks

Trainable Per-Channel Energy Normalization #12

Closed: cameronmaske closed this pull request 3 years ago.

cameronmaske commented 3 years ago

Howdy!

This PR implements Trainable Per-Channel Energy Normalization (PCEN), based on this paper.
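For reference, this is PCEN as it is usually written. E(t, f) is the filterbank energy, M(t, f) is its smoothed version, s is the smoothing coefficient, \alpha is the gain-normalization exponent, \delta is the bias, and \epsilon is a small constant. The compression is written as an exponent 1/r with a root r >= 1, matching the one_over_root = 1.0 / root line quoted later in the thread. In the trainable variant these parameters are learned, typically one value per frequency channel.

```latex
\begin{aligned}
M(t, f) &= (1 - s)\, M(t - 1, f) + s\, E(t, f) \\
\mathrm{PCEN}(t, f) &= \left( \frac{E(t, f)}{\left(\epsilon + M(t, f)\right)^{\alpha}} + \delta \right)^{1/r} - \delta^{1/r}
\end{aligned}
```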

It is heavily inspired by leaf-audio's implementation.
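The following is a minimal sketch of a trainable PCEN module in PyTorch. It assumes a non-negative energy (or magnitude) input of shape (batch, n_freq, time). The class name SimplePCEN, its default values, the initialization of the smoother with the first frame, and the plain Python loop used for smoothing are all illustrative assumptions. This is not the code in this PR or in leaf-audio.

```python
import torch
import torch.nn as nn


class SimplePCEN(nn.Module):
    """Hypothetical minimal trainable PCEN (illustration only, not the PR's code).

    Expects a non-negative energy/magnitude tensor of shape (batch, n_freq, time).
    """

    def __init__(self, n_freq: int, alpha: float = 0.96, delta: float = 2.0,
                 root: float = 2.0, smooth: float = 0.04, eps: float = 1e-6):
        super().__init__()
        # One trainable value per frequency channel (the "per-channel" part).
        # Default values are assumptions chosen for illustration.
        self.alpha = nn.Parameter(torch.full((n_freq,), alpha))
        self.delta = nn.Parameter(torch.full((n_freq,), delta))
        self.root = nn.Parameter(torch.full((n_freq,), root))
        self.smooth = nn.Parameter(torch.full((n_freq,), smooth))
        self.eps = eps

    def forward(self, energy: torch.Tensor) -> torch.Tensor:
        # Trainable values can drift, so bound them on every forward pass.
        alpha = torch.clamp(self.alpha, max=1.0)
        root = torch.clamp(self.root, min=1.0)
        smooth = torch.clamp(self.smooth, min=0.0, max=1.0)
        one_over_root = 1.0 / root

        # First-order IIR smoother along time: M[t] = (1 - s) * M[t-1] + s * E[t],
        # initialized here with the first frame (an assumption).
        frames = [energy[..., 0]]
        for t in range(1, energy.shape[-1]):
            frames.append((1 - smooth) * frames[-1] + smooth * energy[..., t])
        smoothed = torch.stack(frames, dim=-1)

        # Reshape per-channel params to (n_freq, 1) so they broadcast over time.
        alpha, delta, one_over_root = (
            p[:, None] for p in (alpha, self.delta, one_over_root)
        )
        gain = energy / (self.eps + smoothed) ** alpha
        return (gain + delta) ** one_over_root - delta ** one_over_root
```

The time loop keeps the recursion explicit, but it is slow for long inputs. It would also presumably be unrolled to a fixed length by torch.jit.trace, so a vectorized or scripted smoother is likely preferable in practice.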

I thought it might be good to open this early to get feedback and to double-check whether it's a contribution of interest!

I still need to finalize the docstrings, and I want to verify this behaves as expected in a notebook.

mpariente commented 3 years ago

Hey ! Yes, this is a very nice contribution, thanks !

I believe @popcornell and @michelolzam have some experience with PCEN, if they want to review.

BTW, we will drop support for Python<3.8, feel free to assume Python 3.8 or higher :-)

popcornell commented 3 years ago

This is great !

cameronmaske commented 3 years ago

Thanks for the warm reception @mpariente and @popcornell! This should be ready for review.

Here's a Google Colab notebook where you can play around with the PCEN parameters.

Visually, it looks like it makes sense to me.
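As a numeric complement to the visual check, here is a rough sanity test of what alpha controls. It uses a hypothetical scalar-parameter helper, pcen (not the PR's API), on random positive data. With alpha = 1, the output should be nearly unchanged when the whole input is scaled up 100x. With alpha = 0 (no gain normalization), the output should change.

```python
import torch


def pcen(energy, alpha=0.98, delta=2.0, root=2.0, smooth=0.04, eps=1e-6):
    """Hypothetical functional PCEN with scalar parameters (for quick checks only).

    energy: non-negative tensor of shape (..., time).
    """
    # Same first-order smoother as in the PCEN definition, started at frame 0.
    smoothed = [energy[..., 0]]
    for t in range(1, energy.shape[-1]):
        smoothed.append((1 - smooth) * smoothed[-1] + smooth * energy[..., t])
    m = torch.stack(smoothed, dim=-1)
    gain = energy / (eps + m) ** alpha
    return (gain + delta) ** (1.0 / root) - delta ** (1.0 / root)


torch.manual_seed(0)
quiet = torch.rand(1, 8, 50) + 0.1  # stand-in "spectrogram", strictly positive
loud = 100.0 * quiet                 # same signal, scaled up 100x

# alpha = 1: E / M is scale-invariant, so the overall level cancels (up to eps).
same = torch.allclose(pcen(quiet, alpha=1.0), pcen(loud, alpha=1.0), atol=1e-3)
# alpha = 0: (eps + M) ** 0 == 1, so the raw level passes straight through.
differ = not torch.allclose(pcen(quiet, alpha=0.0), pcen(loud, alpha=0.0), atol=1e-3)
print(same, differ)
```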

mpariente commented 3 years ago


I'd be happy to try the sliders :nerd_face: Could you fix the notebook please?

cameronmaske commented 3 years ago

@mpariente Oops! The notebook should be working now.

cameronmaske commented 3 years ago

@mpariente Changes made! It should be JIT compatible. I see there are a few warnings, but as I understand them, they should not affect inference.

tests/pcen_test.py: 24 warnings
  /home/cam/dev/cameronmaske/asteroid-filterbanks/asteroid_filterbanks/pcen.py:145: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    if not transforms.is_asteroid_complex(tf_rep):

tests/pcen_test.py: 24 warnings
  /home/cam/dev/cameronmaske/asteroid-filterbanks/asteroid_filterbanks/transforms.py:224: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    if not is_asteroid_complex(tensor, dim):

tests/pcen_test.py: 24 warnings
  /home/cam/dev/cameronmaske/asteroid-filterbanks/asteroid_filterbanks/pcen.py:177: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
    alpha = torch.min(self.alpha, torch.tensor(1.0))

tests/pcen_test.py: 24 warnings
  /home/cam/dev/cameronmaske/asteroid-filterbanks/asteroid_filterbanks/pcen.py:178: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
    root = torch.max(self.root, torch.tensor(1.0))
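The two torch.tensor warnings (pcen.py:177 and pcen.py:178) appear to be the benign kind. The trace stores 1.0 as a constant, and 1.0 is exactly what the eager code uses on every call. The sketch below traces the same two ops and checks the traced result against eager execution on new values. It uses a standalone function, clip_params, which is hypothetical and mirrors those two lines.

```python
import torch


def clip_params(alpha: torch.Tensor, root: torch.Tensor):
    # Same ops as pcen.py:177-178 in the log: cap alpha at 1, floor root at 1.
    alpha = torch.min(alpha, torch.tensor(1.0))
    root = torch.max(root, torch.tensor(1.0))
    return alpha, 1.0 / root


# Tracing records torch.tensor(1.0) as a constant (hence the TracerWarning),
# but since the bound never changes, the traced graph still generalizes.
example = (torch.tensor([0.5, 1.2, 0.9]), torch.tensor([0.8, 3.0, 1.1]))
traced = torch.jit.trace(clip_params, example)

# Fresh values, including ones on the other side of the bounds.
new_alpha = torch.tensor([0.2, 0.99, 4.0])
new_root = torch.tensor([1.5, 0.1, 2.0])
eager_out = clip_params(new_alpha, new_root)
traced_out = traced(new_alpha, new_root)
for eager, trace in zip(eager_out, traced_out):
    assert torch.allclose(eager, trace)
```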

mpariente commented 3 years ago

Great, the JIT looks fine, thanks!

cameronmaske commented 3 years ago

@mpariente @popcornell Changes made. I'm not sure why the torch.stft CI check is failing (I don't think it is due to anything introduced on this branch).

mpariente commented 3 years ago

Indeed, it's not due to the branch, no problem.

Thanks a lot for the PR!
Somehow, I think the code could be simpler:

I'm ok to merge it like that, but I think the intent could be much more readable.

popcornell commented 3 years ago

alpha = torch.min(self.alpha, torch.tensor(1.0))
root = torch.max(self.root, torch.tensor(1.0))

Maybe torch.clamp would be better?

cameronmaske commented 3 years ago

root and alpha are trainable, so they may change between forward passes.

alpha = torch.min(self.alpha, torch.tensor(1.0))
root = torch.max(self.root, torch.tensor(1.0))
one_over_root = 1.0 / root

I don't think clamp makes sense here? The torch.max ensures that root is always >= 1. With torch.clamp you have to set a min AND max.
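For what it's worth, torch.clamp does not need both bounds. Either min or max can be passed on its own, and the bounds can be plain Python floats. The sketch below uses made-up parameter values to check that the one-sided clamp gives the same values and gradients as the torch.min/torch.max formulation. Since no tensor is created inside the computation, this variant would presumably also avoid the torch.tensor TracerWarnings shown earlier in the thread.

The example also illustrates why the bound is applied in forward. The raw trainable values (0.7 for root, 1.3 for alpha here) can sit outside the valid range, and only the bounded copies are guaranteed to stay inside it.

```python
import torch

# Hypothetical raw (unconstrained) parameter values, as an optimizer might leave them.
alpha_raw = torch.tensor([0.5, 0.96, 1.3], requires_grad=True)
root_raw = torch.tensor([0.7, 2.0, 3.5], requires_grad=True)

# Formulation used in the PR: element-wise min/max against a 1.0 tensor.
alpha_minmax = torch.min(alpha_raw, torch.tensor(1.0))
root_minmax = torch.max(root_raw, torch.tensor(1.0))

# One-sided clamp: only the bound that matters is given.
alpha_clamp = torch.clamp(alpha_raw, max=1.0)
root_clamp = torch.clamp(root_raw, min=1.0)

assert torch.equal(alpha_minmax, alpha_clamp)  # values 0.5, 0.96, 1.0
assert torch.equal(root_minmax, root_clamp)    # values 1.0, 2.0, 3.5


def grads(alpha, root):
    # Gradient of a simple scalar loss w.r.t. the raw parameters.
    return torch.autograd.grad(alpha.sum() + root.sum(), (alpha_raw, root_raw))


# Away from the exact boundary, both versions pass gradient only where the
# bound is inactive (and zero where it clips).
g_minmax = grads(alpha_minmax, root_minmax)
g_clamp = grads(alpha_clamp, root_clamp)
```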

mpariente commented 3 years ago

I see I was wrong in lots of comments, sorry about that.

Thanks again, let's merge!