lizekang / ITDD

The source code of our ACL2019 paper "Incremental Transformer with Deliberation Decoder for Document Grounded Conversations"
MIT License

argmax between first encoder and second decoder #10

dishavarshney082 closed this issue 4 years ago

dishavarshney082 commented 4 years ago

You are using argmax between the first encoder and second decoder. How does backpropagation work, given that argmax is non-differentiable?
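To make the premise of the question concrete: argmax is piecewise constant in the logits, so nudging any logit slightly leaves the chosen index unchanged, and its derivative is zero almost everywhere (undefined at exact ties). The plain-Python sketch below checks this with central finite differences. The logits, the one-number-per-token `EMBEDDINGS` table and the squared loss are invented toy values, not code from ITDD.

```python
def argmax(xs):
    """Index of the largest element; ties go to the first occurrence."""
    return max(range(len(xs)), key=lambda i: xs[i])


def numeric_grad(f, xs, eps=1e-4):
    """Central finite-difference gradient of a scalar function f at xs."""
    grad = []
    for i in range(len(xs)):
        up, down = list(xs), list(xs)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


# Hypothetical toy values: one scalar "embedding" per token and a target.
EMBEDDINGS = [0.5, -1.0, 2.0]
TARGET = 1.5


def downstream_loss(logits):
    token = argmax(logits)                     # hard, non-differentiable choice
    return (EMBEDDINGS[token] - TARGET) ** 2   # loss computed on the chosen token


logits = [0.3, 2.1, -0.5]
print(argmax(logits))                          # 1
print(numeric_grad(downstream_loss, logits))   # [0.0, 0.0, 0.0]
```

Whatever loss sits downstream, the gradient it sends back through the argmax to the logits is zero, which is why some other mechanism (or no backpropagation at all) is needed across that hand-off.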

lizekang commented 4 years ago

Unlike the original Deliberation Network (Xia et al., 2017), which proposes a joint learning framework trained with a Monte Carlo method, we don't backpropagate from the second decoder to the first decoder; in this respect we follow "Modeling coherence for discourse neural machine translation".

dishavarshney082 commented 4 years ago

So you are updating the weights of the first decoder and the second decoder separately, each with its own loss?

Then why do you detach the first decoder at the end of every decoder time step?

lizekang commented 4 years ago

Yes, we update the weights of each decoder separately, using its own loss. As for your second question, could you provide some more details?
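A minimal sketch of what "each decoder is updated with its own loss" means, under simplifying assumptions: a vector of free logits stands in for the first decoder's parameters, a single scalar `w2` stands in for the second decoder, and `EMB`, `REF_TOKEN` and `TARGET` are made-up values. The hard argmax hand-off plays the role of detaching the first-pass output, so the second loss never produces a gradient for the first pass. This illustrates the training scheme described in the thread; it is not the repository's trainer.

```python
import math

# Hypothetical toy values, not taken from ITDD.
EMB = [0.5, -1.0, 2.0]   # one scalar "embedding" per token, read by the second pass
REF_TOKEN = 2            # reference token for the first-pass cross-entropy loss
TARGET = 3.0             # target for the toy second-pass squared loss


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(xs):
    return max(range(len(xs)), key=lambda i: xs[i])


def first_pass_loss(logits):
    """Cross-entropy of the first-pass distribution against REF_TOKEN."""
    return -math.log(softmax(logits)[REF_TOKEN])


def first_pass_grad(logits):
    """Gradient of that cross-entropy w.r.t. the logits: softmax(logits) - one_hot."""
    return [p - (1.0 if i == REF_TOKEN else 0.0) for i, p in enumerate(softmax(logits))]


def second_pass_loss(w2, token):
    return (w2 * EMB[token] - TARGET) ** 2


def second_pass_grad(w2, token):
    """d/dw2 of the second-pass loss; `token` is a constant, nothing flows back."""
    return 2 * (w2 * EMB[token] - TARGET) * EMB[token]


def train_step(logits, w2, lr=0.1):
    # First pass: updated only by its own loss.
    g1 = first_pass_grad(logits)
    # Hard hand-off: the argmax token is a constant input to the second pass.
    token = argmax(logits)
    # Second pass: updated only by its own loss.
    g2 = second_pass_grad(w2, token)
    new_logits = [z - lr * g for z, g in zip(logits, g1)]
    return new_logits, w2 - lr * g2


logits, w2 = [0.0, 1.0, 0.5], 0.0
for _ in range(50):
    logits, w2 = train_step(logits, w2)
print(argmax(logits), round(w2, 4))   # 2 1.5
```

One consequence of this design: the second pass is trained on whatever the first pass currently outputs. In this toy run the first-pass choice switches from token 1 to token 2 after a few steps, and `w2` has to re-adapt to the new input.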

dishavarshney082 commented 4 years ago

When computing gradients with loss.backward() for the second decoder, shouldn't the gradients of the encoder change as well?

I printed the gradients of some encoder parameters, but they did not change:

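    # First-pass (first decoder) loss. As in OpenNMT-py's loss module,
    # sharded_compute_loss is assumed to call backward() on each shard,
    # which is what populates param.grad below.
    print('first_decoder')
    batch_stats1 = self.train_loss.sharded_compute_loss(
        batch, first_outputs, first_attns, j,
        trunc_size, self.shard_size, normalization)

    # Gradient of one context-encoder parameter after the first loss.
    for name, param in self.model.named_parameters():
        if 'encoder.htransformer.layer_norm.weight' in name:
            print(param.grad)

    # Second-pass (second decoder) loss. .grad accumulates across backward()
    # calls, so identical printed values mean this loss contributed no gradient.
    print('second_decoder')
    batch_stats2 = self.train_loss.sharded_compute_loss(
        batch, second_outputs, second_attns, j,
        trunc_size, self.shard_size, normalization)

    for name, param in self.model.named_parameters():
        if 'encoder.htransformer.layer_norm.weight' in name:
            print(param.grad)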

In cases like this, how do you ensure that backpropagation works properly?

lizekang commented 4 years ago

The context encoder's gradient doesn't change because the second decoder only uses the knowledge representation and the first-pass output representation, so no gradient propagates from the second decoder's loss to the context encoder (encoder.htransformer).
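One way to reason about which modules a given loss can update is to walk the computation graph backwards from that loss and stop at non-differentiable edges such as argmax. The sketch below does this for a deliberately simplified, hypothetical graph based on the description in this reply. The node names are illustrative, and two details are assumptions rather than facts confirmed in the thread: that the knowledge representation comes from a separate `knowledge_encoder` node, and that the first decoder also reads it.

```python
# Hypothetical, simplified data-flow graph. Each node maps to the nodes it
# reads from, paired with whether that edge is differentiable.
GRAPH = {
    "context_encoder": [],
    "knowledge_encoder": [],
    "first_decoder": [("context_encoder", True), ("knowledge_encoder", True)],
    # The second decoder reads the knowledge representation and the first-pass
    # output; the latter goes through argmax, so that edge carries no gradient.
    "second_decoder": [("knowledge_encoder", True), ("first_decoder", False)],
    "loss1": [("first_decoder", True)],
    "loss2": [("second_decoder", True)],
}


def reached_by_backward(graph, loss):
    """Nodes that receive a gradient when backpropagating from `loss`."""
    reached, stack = set(), [loss]
    while stack:
        node = stack.pop()
        for parent, differentiable in graph[node]:
            if differentiable and parent not in reached:
                reached.add(parent)
                stack.append(parent)
    return reached


print(sorted(reached_by_backward(GRAPH, "loss1")))
# ['context_encoder', 'first_decoder', 'knowledge_encoder']
print(sorted(reached_by_backward(GRAPH, "loss2")))
# ['knowledge_encoder', 'second_decoder']
```

With a hard argmax edge, `context_encoder` only appears in the set for `loss1`, which is consistent with its gradient staying the same after the second decoder's loss is backpropagated.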

dishavarshney082 commented 4 years ago

Thanks for your response, that cleared up a lot. By the way, have you ever used the Gumbel-softmax?

lizekang commented 4 years ago

No, I haven't used the Gumbel-softmax, but I think it would work. You could give it a try.
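For anyone who wants to try that suggestion, here is a minimal plain-Python sketch of the idea. Adding Gumbel noise g = -log(-log(U)) to the logits and taking argmax draws an exact sample from softmax(logits) (the Gumbel-max trick); replacing that argmax with softmax((logits + g) / tau) gives a relaxed, differentiable sample, so with the noise held fixed a downstream loss has a nonzero gradient with respect to the logits. The embedding table, target and loss are invented toy values. In PyTorch, `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=True)` provides a straight-through variant (one-hot in the forward pass, gradients of the soft sample in the backward pass); whether it helps ITDD has not been tested here.

```python
import math
import random


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def sample_gumbel(n, rng):
    """n draws of standard Gumbel noise, -log(-log(U)) with U ~ Uniform(0, 1)."""
    # rng.random() lies in [0, 1); clamp away from 0 so log() stays finite.
    return [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in range(n)]


def gumbel_softmax(logits, noise, tau):
    """Relaxed one-hot sample softmax((logits + noise) / tau); smooth in the logits."""
    return softmax([(z + g) / tau for z, g in zip(logits, noise)])


def gumbel_max(logits, noise):
    """Hard sample: argmax(logits + noise) is distributed as softmax(logits)."""
    perturbed = [z + g for z, g in zip(logits, noise)]
    return max(range(len(perturbed)), key=lambda i: perturbed[i])


def numeric_grad(f, xs, eps=1e-4):
    """Central finite-difference gradient of a scalar function f at xs."""
    grad = []
    for i in range(len(xs)):
        up, down = list(xs), list(xs)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad


# Hypothetical toy values standing in for a second pass that reads token embeddings.
EMBEDDINGS = [0.5, -1.0, 2.0]
TARGET = 1.5

rng = random.Random(0)
logits = [0.3, 2.1, -0.5]
noise = sample_gumbel(len(logits), rng)   # held fixed while differentiating


def relaxed_loss(z, tau=1.0):
    # Mix the token embeddings with the relaxed sample instead of indexing by argmax.
    y = gumbel_softmax(z, noise, tau)
    mixed = sum(w * e for w, e in zip(y, EMBEDDINGS))
    return (mixed - TARGET) ** 2


print(numeric_grad(relaxed_loss, logits))   # generally nonzero; a hard argmax gives zeros
```

Lower values of `tau` make the relaxed sample closer to one-hot but also make its gradients larger and noisier, so `tau` is usually treated as a tuning knob.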