yzhangcs / parser

:rocket: State-of-the-art parsers for natural language.
https://parser.yzhang.site/

Understanding biaffine operation. #2

Closed pranoy-k closed 5 years ago

pranoy-k commented 5 years ago

https://github.com/zysite/biaffine-parser/blob/e54c2104658443e10df4e27a392041a559fcc745/parser/modules/biaffine.py#L43

Hi zysite, I am really confused by this operation. In the official code they take the adjoint of a matrix (adjoint_b=True), whereas here you do s = x @ self.weight @ torch.transpose(y, -1, -2).

In particular, I want to understand how it is equivalent to the following:

# Do the multiplications
# (bn x d) @ (d x rd) -> (bn x rd)
lin = tf.matmul(tf.reshape(inputs1, [-1, inputs1_size + add_bias1]),
                tf.reshape(weights, [inputs1_size + add_bias1, -1]))

# (b x nr x d) @ (b x n x d)^T -> (b x nr x n)
bilin = tf.matmul(tf.reshape(lin, [batch_size, inputs1_bucket_size * output_size, inputs2_size + add_bias2]),
                  inputs2, adjoint_b=True)
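(As an illustrative aside: the same two-step route can be sketched in PyTorch with made-up sizes; the variable names below just mirror the TensorFlow snippet above and are not taken from either codebase.)

```python
import torch

# made-up sizes, for illustration only
batch_size, seq_len, d, n_out = 2, 5, 4, 3

inputs1 = torch.randn(batch_size, seq_len, d)   # (b, n, d)
inputs2 = torch.randn(batch_size, seq_len, d)   # (b, n, d)
weights = torch.randn(d, n_out * d)             # (d, r*d), with r = n_out output channels

# (bn x d) @ (d x rd) -> (bn x rd)
lin = inputs1.reshape(-1, d) @ weights

# (b x nr x d) @ (b x n x d)^T -> (b x nr x n)
bilin = lin.reshape(batch_size, seq_len * n_out, d) @ inputs2.transpose(-1, -2)

print(lin.shape)    # torch.Size([10, 12])
print(bilin.shape)  # torch.Size([2, 15, 5])
```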

What was your intuition for coming up with this? Could you help me understand it in detail? Thanks a lot for your help.

Sincerely, Pranoy

yzhangcs commented 5 years ago

The @ operator is Python's matrix multiplication operator (the same one NumPy uses for matmul); PyTorch retains it as well, and it is equivalent to torch.matmul. Here the shapes of x and y are [batch_size, 1, seq_len, d], and self.weight is an [n_out, d, d] tensor, so x @ self.weight is a [batch_size, n_out, seq_len, d] tensor (the leading dimensions broadcast). The final shape of x @ self.weight @ torch.transpose(y, -1, -2) is therefore [batch_size, n_out, seq_len, seq_len]. You can refer to the documentation of torch.matmul for more information.
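Here is a minimal sketch of that broadcasting (made-up sizes, not code from this repository), together with a check that it yields the same scores as the reshape-based route quoted in the question:

```python
import torch

# made-up sizes, for illustration only
batch_size, seq_len, d, n_out = 2, 5, 4, 3

x = torch.randn(batch_size, 1, seq_len, d)   # [batch_size, 1, seq_len, d]
y = torch.randn(batch_size, 1, seq_len, d)   # [batch_size, 1, seq_len, d]
weight = torch.randn(n_out, d, d)            # [n_out, d, d]

# [batch_size, 1, seq_len, d] @ [n_out, d, d] -> [batch_size, n_out, seq_len, d]
# (the leading batch dimensions (batch_size, 1) and (n_out,) broadcast)
xw = x @ weight

# [batch_size, n_out, seq_len, d] @ [batch_size, 1, d, seq_len] -> [batch_size, n_out, seq_len, seq_len]
s = xw @ torch.transpose(y, -1, -2)
print(s.shape)                               # torch.Size([2, 3, 5, 5])

# the reshape-based route from the question, with the weight flattened to (d, n_out*d)
w_flat = weight.permute(1, 0, 2).reshape(d, n_out * d)                                 # (d, r*d)
lin = x.squeeze(1).reshape(-1, d) @ w_flat                                             # (b*n, r*d)
bilin = lin.reshape(batch_size, seq_len * n_out, d) @ y.squeeze(1).transpose(-1, -2)   # (b, n*r, n)

# bilin stacks word index i and output channel o as i*n_out + o,
# so unflatten and permute before comparing with the broadcasted result
s_ref = bilin.reshape(batch_size, seq_len, n_out, seq_len).permute(0, 2, 1, 3)
assert torch.allclose(s, s_ref, atol=1e-5)
```

In other words, the broadcasted one-liner lets matmul's batch broadcasting line up the output-channel and sequence dimensions that the official code arranges explicitly with reshapes, which is presumably the intuition behind it.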

pranoy-k commented 5 years ago

Oh thanks a lot, that really clarifies it. :)

yzhangcs commented 5 years ago

You're welcome :blush: