mlverse / torch

R Interface to Torch
https://torch.mlverse.org
Other
490 stars 64 forks

nnf_multi_head_attention_forward (indices) #1133

Closed AGPatriota closed 4 months ago

AGPatriota commented 8 months ago

Dear fellows,

Can you please verify whether the following piece of code is correct:

    if (!is.null(in_proj_bias)) {
      q <- nnf_linear(query, q_proj_weight, in_proj_bias[1:embed_dim])
      k <- nnf_linear(key, k_proj_weight, in_proj_bias[embed_dim:(embed_dim * 2)])
      v <- nnf_linear(value, v_proj_weight, in_proj_bias[(embed_dim * 2):N])
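To make the off-by-one concrete, the following Python sketch emulates R's 1-based, inclusive a:b ranges with plain integers (no torch involved). The value embed_dim = 4 and the helper r_range are hypothetical and chosen only for illustration. N is taken to be the last position, 3 * embed_dim.

```python
# Hypothetical illustration: emulate the R {torch} ranges from the snippet above
# with plain Python integers. R ranges are 1-based and include both endpoints.
embed_dim = 4            # arbitrary toy value
n = 3 * embed_dim        # length of in_proj_bias; N refers to this last position


def r_range(first, last):
    """Positions selected by the R expression first:last (inclusive)."""
    return list(range(first, last + 1))


q_pos = r_range(1, embed_dim)                # in_proj_bias[1:embed_dim]
k_pos = r_range(embed_dim, embed_dim * 2)    # in_proj_bias[embed_dim:(embed_dim * 2)]
v_pos = r_range(embed_dim * 2, n)            # in_proj_bias[(embed_dim * 2):N]

# Each projection expects a bias of exactly embed_dim elements.
print([len(q_pos), len(k_pos), len(v_pos)])
# Positions shared between neighbouring slices.
print(sorted(set(q_pos) & set(k_pos)), sorted(set(k_pos) & set(v_pos)))
```

Under these assumptions, the k and v slices each pick up embed_dim + 1 elements, and positions embed_dim and 2 * embed_dim are each shared by two neighbouring slices.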

Shouldn't it be the following:

    if (!is.null(in_proj_bias)) {
      q <- nnf_linear(query, q_proj_weight, in_proj_bias[1:embed_dim])
      k <- nnf_linear(key, k_proj_weight, in_proj_bias[(embed_dim + 1):(embed_dim * 2)])
      v <- nnf_linear(value, v_proj_weight, in_proj_bias[(embed_dim * 2 + 1):N])

Also, I did not find where N is defined in the code.

In PyTorch (Python), the corresponding code is:

    if in_proj_bias is not None:
        q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
        k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
        v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])

which is correct.
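The general rule is as follows. A Python slice [a:b] is 0-based and excludes its stop, and it selects the same elements as the R range (a + 1):b, which is 1-based and includes both ends. An open-ended [a:] corresponds to (a + 1):N, assuming N denotes the last position. The sketch below checks this rule on plain Python lists. The toy bias values, embed_dim = 4, and the helpers py_slice_to_r and r_select are hypothetical illustrations, not {torch} functions.

```python
# Sketch of the index translation rule, using plain Python lists instead of tensors.
# Python slices are 0-based with an exclusive stop; R ranges are 1-based and inclusive.
embed_dim = 4                                   # arbitrary toy value
bias = list(range(10, 10 + 3 * embed_dim))      # stand-in for in_proj_bias
n = len(bias)                                   # what N would denote in R


def py_slice_to_r(start, stop=None):
    """Return (first, last) such that R's first:last selects the same elements as
    Python's [start:stop]. stop=None models an open-ended slice, i.e. ...:N in R."""
    return start + 1, (n if stop is None else stop)


def r_select(values, first, last):
    """Emulate R's values[first:last] (1-based, inclusive) on a Python list."""
    return values[first - 1:last]


# The three slices used for q, k and v in the PyTorch code.
slices = [(0, embed_dim), (embed_dim, embed_dim * 2), (embed_dim * 2, None)]
for start, stop in slices:
    first, last = py_slice_to_r(start, stop)
    assert r_select(bias, first, last) == bias[start:stop]
    print(f"[{start}:{stop}] -> {first}:{last}")
```

With embed_dim = 4, the translated ranges come out as 1:4, 5:8 and 9:12 (that is, 9:N). These are the ranges 1:embed_dim, (embed_dim + 1):(embed_dim * 2) and (embed_dim * 2 + 1):N proposed above.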

Thanks

cregouby commented 4 months ago

Hello @AGPatriota, N in {torch} is a built-in for last(), the equivalent of the Python negative index -1.

You are correct about the mistake in translating PyTorch indices into R {torch} indices. I'll propose a fix.