ritheshkumar95 / pytorch-vqvae

Vector Quantized VAEs - PyTorch Implementation

One question about backward grad. #8

Closed: mazzzystar closed this issue 6 years ago

mazzzystar commented 6 years ago

Hi, thanks for your implementation! I'm now trying to implement the audio experiments of VQ-VAE, but when I try to imitate your code, there is something that confuses me.

When I train the model by computing the .grad:

optimizer.zero_grad()
x_recon, z_e_x, z_q_x = model(qt_var, speaker_var)
z_q_x.retain_grad()

loss_recon = cross_entropy_loss(x_recon.view(hp.BATCH_SIZE, hp.Q, -1), quantized_audio.view(hp.BATCH_SIZE, -1).long())

loss_recon.backward(retain_graph=True)

# Straight-through estimator
z_e_x.backward(z_q_x.grad, retain_graph=True)

Error happened:

RuntimeError: grad can be implicitly created only for scalar outputs

It means my z_q_x does not have a grad. Actually, because I do some quantization work, my z_q_x and z_e_x are LongTensors; is this the reason there is no grad?
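For reference, autograd only tracks floating-point (and complex) tensors, so a LongTensor can never carry a .grad. A minimal standalone check of this constraint (not from the repo):

import torch

z = torch.randint(0, 10, (4,))        # a LongTensor, like quantized audio codes
print(z.requires_grad)                # False
try:
    z.requires_grad_(True)            # autograd refuses integer dtypes
except RuntimeError as e:
    print(e)                          # "only Tensors of floating point ... dtype can require gradients"

z_f = z.float().requires_grad_(True)  # cast to float before tracking gradients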

csyanbin commented 5 years ago

Hi @mazzzystar, have you solved the problem? I think your solution is quite straightforward.

mazzzystar commented 5 years ago

@csyanbin Hi, I solved the problem, but not based on his work. I created an Embedding emb, and my z_q_x comes from the nearest embedding vector to z_e_x, based on the code below:

def l2_dist(a, b):
    # elementwise squared difference; summed over the feature dim below
    # (this helper is implied by the original snippet's .sum(2) call)
    return (a - b) ** 2

def find_nearest(zex, emb):
    """
    zex: (N, z_dim)  encoder outputs, flattened to (-1, z_dim)
    emb: (k, z_dim)  codebook embedding matrix
    """
    # broadcast to (N, k, z_dim), sum over z_dim, take the nearest code index
    j = l2_dist(zex[:, None], emb[None, :]).sum(2).min(1)[1]
    return emb[j], j  # nearest codebook vectors and their indices, e.g. (1250, 64)
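For context, a minimal usage sketch of find_nearest with an nn.Embedding codebook (k, z_dim, and the random input are illustrative, not from the repo):

import torch
import torch.nn as nn

k, z_dim = 512, 64
emb = nn.Embedding(k, z_dim)                # the learned codebook

z_e_x = torch.randn(1250, z_dim)            # flattened encoder outputs
z_q_x, j = find_nearest(z_e_x, emb.weight)  # z_q_x: (1250, 64), j: (1250,)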

Then I use a hook to pass the grad back from z_q_x to the encoder.

csyanbin commented 5 years ago

Thanks, that looks good. Can I ask how to use a hook function to pass the grad from z_q_x back to the encoder?

mazzzystar commented 5 years ago

I wrote a tiny script with comments to give you the general idea.

# keep a reference to z_e_x so the hook can reach it (captured by the closure,
# so no nonlocal declaration is needed just to read it)
org_h = z_e_x

# define the hook: it fires when the gradient of z_q_x is computed,
# and saves both that gradient and the encoder output for a second backward pass
def hook(grad):
    self.saved_grad = grad
    self.saved_h = org_h
    return grad

# register the hook on z_q_x
if z_q_x.requires_grad:
    z_q_x.register_hook(hook)

# define a method in model.py that backpropagates the saved grad
# from z_q_x through z_e_x into the encoder (the straight-through step)
def bwd(self):
    self.saved_h.backward(self.saved_grad)

# backward the loss (here wrapped in a model method), which triggers the hook,
# then route the saved grad into the encoder
model.backward()
model.bwd()
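For completeness, the same gradient routing is often written without hooks via the detach trick; a self-contained sketch (shapes and the toy loss are illustrative, not the author's code):

import torch

z_e_x = torch.randn(8, 64, requires_grad=True)  # encoder output
codebook = torch.randn(512, 64)                 # fixed codebook for this sketch
j = torch.cdist(z_e_x, codebook).argmin(1)      # nearest-code indices
z_q_x = codebook[j]                             # quantized vectors (no grad path)

# straight-through: forward value equals z_q_x, gradient flows to z_e_x
z_q_x_st = z_e_x + (z_q_x - z_e_x).detach()
loss = z_q_x_st.pow(2).mean()                   # stand-in for the real recon loss
loss.backward()
print(z_e_x.grad is not None)                   # True: the encoder receives gradients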
csyanbin commented 5 years ago

Thanks so much! That is really helpful!