mlverse / torch

R Interface to Torch
https://torch.mlverse.org

Backward pass fails on torch_max with the "inplace operation" error #1185

Open SemyonTab opened 3 months ago

SemyonTab commented 3 months ago

It seems that the backward pass does not work with the torch_max function when it is called with a dim argument, or I made a mistake somewhere. R code:

m_tensor <- torch_tensor(matrix(1:8, nrow = 2), dtype = torch_float64(), requires_grad = TRUE)
n <- torch_max(m_tensor, dim = 2)[[1]]
n_sum <- torch_sum(n)
n_sum$backward()
m_tensor$grad

Fails with an error:

Error in (function (self, inputs, gradient, retain_graph, create_graph) : one of the variables needed for gradient computation has been modified by an inplace operation

Analogous code in Python seems to work fine:

import torch

m_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.float64, requires_grad=True)
n = torch.max(m_tensor, dim=1).values
n_sum = torch.sum(n)
n_sum.backward()
m_tensor.grad

Returns:

tensor([[0., 0., 0., 1.],
        [0., 0., 0., 1.]], dtype=torch.float64)
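For reference, this gradient can be checked by hand: the sum of row-wise maxima depends only on the largest element of each row, so that element gets gradient 1 and every other element gets 0. Below is a minimal plain-Python sketch of that reasoning (no torch involved; `grad_of_sum_of_row_max` is a hypothetical helper written only for this illustration). Note that `matrix(1:8, nrow = 2)` in R fills column by column, so the R tensor is [[1, 3, 5, 7], [2, 4, 6, 8]] rather than the Python one, but the row maxima sit in the last column in both cases, so the expected gradient is the same.

```python
# Hand-derived gradient of sum(row-wise max). Plain Python, not the torch implementation.
def grad_of_sum_of_row_max(rows):
    grad = []
    for row in rows:
        # Only the maximum of each row contributes to the sum,
        # so it gets gradient 1 and every other element gets 0.
        k = row.index(max(row))
        grad.append([1.0 if j == k else 0.0 for j in range(len(row))])
    return grad

python_values = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
r_values = [[1.0, 3.0, 5.0, 7.0], [2.0, 4.0, 6.0, 8.0]]  # column-major fill, as in R

expected_grad = grad_of_sum_of_row_max(python_values)
print(expected_grad)  # [[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]
print(grad_of_sum_of_row_max(r_values) == expected_grad)  # True
```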

I would greatly appreciate your help, thanks!

SemyonTab commented 3 months ago

By the way, torch_amax works fine:

m_tensor <- torch_tensor(matrix(1:8, nrow = 2), dtype = torch_float64(), requires_grad = TRUE)
n <- torch_amax(m_tensor, dim = 2)
n_sum <- torch_sum(n)
n_sum$backward()
m_tensor$grad

Returns:

torch_tensor
 0  0  0  1
 0  0  0  1
[ CPUDoubleType{2,4} ]
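One caveat if torch_amax is used as a workaround: as far as I know from the PyTorch documentation, amax splits the gradient evenly between tied maxima, while max with a dim argument sends it to a single index. The example above has no ties, so both give the same result. A rough plain-Python sketch of the two rules follows (both helper names are hypothetical, and picking the first tied index is an assumption of this sketch, not something confirmed for torch):

```python
# Two gradient rules for a row maximum when values are tied. Plain Python sketch.
def grad_single_index(row):
    # All gradient goes to one index (here: the first maximum, an assumption).
    k = row.index(max(row))
    return [1.0 if j == k else 0.0 for j in range(len(row))]

def grad_even_split(row):
    # Gradient is shared equally among all elements equal to the maximum.
    top = max(row)
    n_ties = row.count(top)
    return [1.0 / n_ties if v == top else 0.0 for v in row]

tied_row = [1.0, 4.0, 4.0, 2.0]
single = grad_single_index(tied_row)
split = grad_even_split(tied_row)
print(single)  # [0.0, 1.0, 0.0, 0.0]
print(split)   # [0.0, 0.5, 0.5, 0.0]
```

Both rules sum to 1 per row, so they only differ in how the gradient is distributed when a row has repeated maxima.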