zhanglab-aim / cancer-net

Diagnosing cancers using deep learning.
GNU General Public License v2.0

scatter_nd returns flat vector #12

Closed zj-zhang closed 2 years ago

zj-zhang commented 2 years ago

https://github.com/ttesileanu/cancer-net/blob/0186da97b1014ab240ce6c19f0c9c396d538c88f/cancernet/util/tensor.py#L25-L27

reshape is not in-place: it returns a new tensor, so unless the result is assigned back, res is still a flat vector. This will fix it:

res = res.scatter_add_(0, ind1d, weights).reshape(*shape)
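
To illustrate the point, here is a minimal sketch of the behavior. It is not the repository's actual code: the scatter_nd helper, its signature, and the way ind1d is computed are assumptions made for the example. Only the scatter_add_(...).reshape(*shape) line is taken from the fix above.

```python
import math
import torch


def scatter_nd(ind, weights, shape):
    """Hypothetical helper: accumulate `weights` at N-d indices `ind`.

    ind:     (n, len(shape)) integer tensor of N-d positions
    weights: (n,) tensor of values to add at those positions
    shape:   tuple giving the shape of the returned tensor
    """
    # Row-major strides, so that an N-d index maps to a flat 1-d offset.
    strides = [math.prod(shape[i + 1:]) for i in range(len(shape))]
    ind1d = (ind * torch.tensor(strides)).sum(dim=1)

    res = torch.zeros(math.prod(shape), dtype=weights.dtype)
    # reshape returns a new tensor, so its result must be assigned.
    res = res.scatter_add_(0, ind1d, weights).reshape(*shape)
    return res


# Calling reshape without assignment leaves the original tensor flat.
flat = torch.zeros(6)
flat.reshape(2, 3)  # result is discarded
flat_shape = tuple(flat.shape)

# With the fix, the result has the requested shape; repeated indices add up.
ind = torch.tensor([[0, 1], [1, 2], [0, 1]])
weights = torch.tensor([1.0, 2.0, 3.0])
out = scatter_nd(ind, weights, (2, 3))
```

Here flat keeps its original one-dimensional shape, while out has shape (2, 3) with the two weights for index (0, 1) summed together.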