ChunchuanLv / AMR_AS_GRAPH_PREDICTION


KL Divergence expression incorrect? #7

Open StalVars opened 5 years ago

StalVars commented 5 years ago

https://github.com/ChunchuanLv/AMR_AS_GRAPH_PREDICTION/blob/5d2ddb06ab0da35066a5304b53df2815d4fc268e/src/train.py#L201

I think that, according to the paper "Learning Latent Permutations with Gumbel-Sinkhorn Networks" (https://arxiv.org/pdf/1802.08665.pdf), this line should be

S = S + r * scores[:l,i,:l].sum() + gamma_r * torch.exp(-scores[:l,i,:l] * r).sum()

and not

S = S + r / scores[:l,i,:l].sum() + gamma_r * torch.exp(-scores[:l,i,:l] * r).sum()
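For reference, here is a minimal sketch (not the repository's code) of the standard closed-form KL divergence between two Gumbel distributions: KL(Gumbel(mu1, beta1) || Gumbel(mu2, beta2)) = log(beta2/beta1) + gamma * (beta1/beta2 - 1) + (mu1 - mu2)/beta2 + exp((mu2 - mu1)/beta2) * Gamma(beta1/beta2 + 1) - 1, where gamma is the Euler-Mascheroni constant. The mapping onto train.py is an assumption: if scores holds the posterior locations mu, the prior is Gumbel(0, 1/r), and gamma_r stands for Gamma(beta1/beta2 + 1), then the location term is mu * r, a product rather than a quotient. With equal scales the expression reduces to r * mu + exp(-r * mu) - 1. The helper equal_scale_kl is hypothetical, and the Monte Carlo estimator is included only as a sanity check.

```python
import math
import random

EULER_GAMMA = 0.5772156649015329  # Euler-Mascheroni constant


def gumbel_kl(mu1, beta1, mu2, beta2):
    """Closed-form KL(Gumbel(mu1, beta1) || Gumbel(mu2, beta2))."""
    ratio = beta1 / beta2
    return (
        math.log(beta2 / beta1)
        + EULER_GAMMA * (ratio - 1.0)
        + (mu1 - mu2) / beta2  # location term: divided by the prior scale, i.e. times r
        + math.exp((mu2 - mu1) / beta2) * math.gamma(ratio + 1.0)
        - 1.0
    )


def equal_scale_kl(mu, r):
    """Hypothetical special case: posterior Gumbel(mu, 1/r), prior Gumbel(0, 1/r).

    Here the location enters as r * mu (a product), not r / mu.
    """
    return r * mu + math.exp(-r * mu) - 1.0


def gumbel_log_pdf(x, mu, beta):
    """Log-density of Gumbel(mu, beta) at x."""
    z = (x - mu) / beta
    return -math.log(beta) - (z + math.exp(-z))


def gumbel_kl_monte_carlo(mu1, beta1, mu2, beta2, n=100_000, seed=0):
    """Monte Carlo estimate of the same KL, used only as a sanity check."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        u = rng.random() or 1e-300  # avoid log(0) if random() returns exactly 0.0
        x = mu1 - beta1 * math.log(-math.log(u))  # inverse-CDF Gumbel sample
        total += gumbel_log_pdf(x, mu1, beta1) - gumbel_log_pdf(x, mu2, beta2)
    return total / n


if __name__ == "__main__":
    print(gumbel_kl(0.5, 1.0, 0.0, 1.0))              # closed form
    print(equal_scale_kl(0.5, 1.0))                   # same value via the special case
    print(gumbel_kl_monte_carlo(0.5, 1.0, 0.0, 1.0))  # should be close to both
```

Running it prints the closed-form value, the identical equal-scale special case, and a Monte Carlo estimate close to both. If the location term were divided by the scores instead, it would not match the sampled estimate.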

ChunchuanLv commented 5 years ago

Hi Stalin,

Thanks for pointing this out. I am busy with the ACL deadline and will come back with an answer later.

Chunchuan

On Mon, 25 Feb 2019 at 14:19, Stalin Varanasi <notifications@github.com> wrote:
