yoyololicon / pytorch-NMF

A pytorch package for non-negative matrix factorization.
https://pytorch-nmf.readthedocs.io/
MIT License
220 stars 24 forks

Learning end2end with a neural network #17

Open jonnor opened 2 years ago

jonnor commented 2 years ago

Hi, thank you for this nice project.

Could one connect a neural network to the NMF module and learn both at the same time? Any example code or tips on how to do that? I am interested in using a convolutional neural network frontend on spectrogram data, to capture somewhat more complex activations than single stationary spectrogram frames.

yoyololicon commented 2 years ago

Hi @jonnor ,

> Could one connect a neural network to the NMF module and learn both at the same time? Any example code or tips on how to do that? I am interested in using a convolutional neural network frontend on spectrogram data, to capture somewhat more complex activations than single stationary spectrogram frames.

I do plan to add some examples as Jupyter notebooks, but I'm currently busy with other projects. Your application sounds totally doable to me, but you have to make sure that all the gradients passed from the loss to the NMF parameters are non-negative.

For example, if you want to train a model that predicts the activations and jointly learn a shared non-negative template, you can do something like this:

import torch
from torch import nn
from torch import optim
from torchnmf.trainer import BetaMu
from torchnmf import NMF

# pick an activation function so the output is non-negative
H = nn.Sequential(AnotherModel(), nn.Softplus())
W = NMF(W=(out_channels, in_channels))

optimizer = optim.Adam(H.parameters())
trainer = BetaMu(W.parameters())

for x, y in dataloader:
    # optimize NMF
    def closure():
        trainer.zero_grad()
        with torch.no_grad():
            h = H(x)
        return y, W(H=h)
    trainer.step(closure)

    # optimize neural net
    h = H(x)
    predict = W(H=h)
    loss = ... # you can use other types of loss here
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

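AnotherModel() above is just a placeholder for whatever network you use to predict the activations. For spectrogram input, a minimal sketch of a convolutional frontend could look like this (the class name, layer sizes, and shapes are only illustrative; the output shape has to match whatever you configure W to expect):

import torch
from torch import nn

class ConvFrontend(nn.Module):
    # hypothetical CNN that maps a (batch, 1, freq, time) spectrogram
    # to per-frame activations of shape (batch, rank, time)
    def __init__(self, n_freq, rank):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, rank, kernel_size=(n_freq, 1)),  # collapse the frequency axis
        )

    def forward(self, spec):
        return self.net(spec).squeeze(2)  # (batch, rank, 1, time) -> (batch, rank, time)

# used in place of AnotherModel(); Softplus keeps the activations non-negative
H = nn.Sequential(ConvFrontend(n_freq=513, rank=16), nn.Softplus())
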
jonnor commented 2 years ago

Hi @yoyololicon - thank you for the response and example code! In this framework, the loss would be something that compares the output of the NMF (decomposed and re-composed)? Like RMS as a simple case, or a perceptual metric for something more advanced?

yoyololicon commented 2 years ago

> Hi @yoyololicon - thank you for the response and example code! In this framework, the loss would be something that compares the output of the NMF (decomposed and re-composed)? Like RMS as a simple case, or a perceptual metric for something more advanced?

@jonnor Yes, in the above code you are free to use those kinds of loss functions, not only the beta divergence. The NMF part is still trained with the beta divergence, though.
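
For instance (just a sketch, assuming predict and y from the loop above are the reconstructed and target spectrograms), the loss = ... line could be a simple MSE:

import torch.nn.functional as F

# in place of `loss = ...` in the training loop above
loss = F.mse_loss(predict, y)  # or any other differentiable loss for the neural-net step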