zj-zhang closed this issue 2 years ago.
https://github.com/ttesileanu/cancer-net/blob/0186da97b1014ab240ce6c19f0c9c396d538c88f/cancernet/util/tensor.py#L25-L27
`reshape` is not in-place, so `res` is still a flat vector after these lines. This will fix it:

```python
res = res.scatter_add_(0, ind1d, weights).reshape(*shape)
```
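The difference can be reproduced with a small standalone sketch. It assumes that `res` starts as a flat zero tensor with as many elements as `shape` has. The linked lines do not show that initialization, so it is an assumption here, as are the hypothetical helper names `scatter_buggy` and `scatter_fixed`. The fix works because `scatter_add_` modifies `res` in place and returns it, so `.reshape(*shape)` can be chained and its result assigned back.

```python
import math

import torch


def scatter_buggy(shape, ind1d, weights):
    # Hypothetical reproduction of the bug: assumes `res` is a flat zero tensor.
    res = torch.zeros(math.prod(shape), dtype=weights.dtype)
    res.scatter_add_(0, ind1d, weights)  # in-place accumulation into res
    res.reshape(*shape)  # NOT in-place: the reshaped tensor is discarded
    return res  # still 1-D


def scatter_fixed(shape, ind1d, weights):
    # Same assumed setup, with the fix from this issue applied.
    res = torch.zeros(math.prod(shape), dtype=weights.dtype)
    # scatter_add_ returns res itself, so reshape can be chained and assigned.
    res = res.scatter_add_(0, ind1d, weights).reshape(*shape)
    return res


shape = (2, 3)
ind1d = torch.tensor([0, 4, 4, 5])  # flat indices (int64, as scatter_add_ requires)
weights = torch.tensor([1.0, 2.0, 3.0, 4.0])

print(scatter_buggy(shape, ind1d, weights).shape)  # torch.Size([6])
print(scatter_fixed(shape, ind1d, weights).shape)  # torch.Size([2, 3])
```

Repeated indices are summed (index 4 receives 2.0 + 3.0), so the fixed version yields `[[1, 0, 0], [0, 5, 4]]`.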