ant-research / Pyraformer

Apache License 2.0

argument of get_q_k #3

Open ShuyangCao opened 2 years ago

ShuyangCao commented 2 years ago

Thanks for the great work! I have two questions regarding the code base:

In Pyraformer_LR.py, the q_k_mask is constructed given the inner_size and window_size. https://github.com/alipay/Pyraformer/blob/3cde01c18384f82a74ef339d90e601ee33def36f/pyraformer/Pyraformer_LR.py#L30

However, according to the definition of the function, these two arguments seem to be exchanged. https://github.com/alipay/Pyraformer/blob/3cde01c18384f82a74ef339d90e601ee33def36f/pyraformer/Layers.py#L91

Is this correct?

Also, different inner_size and window_size are selected for different tasks. How should I select these two parameters for tasks other than those in the paper?

Zhazhan commented 2 years ago

Thanks for your question. For the first question, the code itself is correct; we apologize for the confusing variable names. The "window_size" argument of get_q_k() is the number of neighbors that each node can attend to, so it corresponds to opt.inner_size. The "stride" argument of get_q_k() is the number of child nodes of each parent node, so it corresponds to opt.window_size. It would be clearer to rename "stride" to "window_size" and "window_size" to "inner_size".

For the second question, we recommend first selecting a small A, such as 3 or 5, to reduce complexity. Then, to ensure the network has a receptive field of L, select a C that satisfies Equation (5). For more discussion of hyper-parameter selection, please refer to Appendix K of our paper.
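To make the two roles concrete, here is a minimal Python sketch of what each parameter controls. It is not the repository's get_q_k() implementation. The function names (pyramid_layer_sizes, intra_scale_neighbors, children_of) and parameter names (num_neighbors, children_per_parent) are hypothetical and chosen to match the renaming suggested above. It also assumes that A in the paper corresponds to opt.inner_size and C to opt.window_size, and that one C is shared by all scales.

```python
def pyramid_layer_sizes(input_size, children_per_parent, num_scales):
    """Length of each scale when every parent summarizes
    `children_per_parent` nodes of the scale below (the role of C)."""
    sizes = [input_size]
    for _ in range(num_scales - 1):
        sizes.append(sizes[-1] // children_per_parent)
    return sizes


def intra_scale_neighbors(index, layer_size, num_neighbors):
    """Indices in the same scale that node `index` may attend to,
    itself included (the role of A). Clipped at the scale boundaries."""
    half = num_neighbors // 2
    lo = max(index - half, 0)
    hi = min(index + half, layer_size - 1)
    return list(range(lo, hi + 1))


def children_of(parent_index, children_per_parent):
    """Indices in the finer scale that are summarized by `parent_index`."""
    start = parent_index * children_per_parent
    return list(range(start, start + children_per_parent))


if __name__ == "__main__":
    # Hypothetical settings: sequence length 64, C = 4, A = 3, four scales.
    print(pyramid_layer_sizes(64, 4, 4))
    print(intra_scale_neighbors(5, 16, 3))
    print(children_of(2, 4))
```

In this sketch a larger num_neighbors widens each node's attention within a scale, while a larger children_per_parent makes the pyramid shrink faster from one scale to the next.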