mlverse / torch

R Interface to Torch
https://torch.mlverse.org

Add a currently unexported `torch_tensor_free` that deletes tensors without requiring us to wait for GC. #1194

Closed dfalbel closed 2 months ago

dfalbel commented 2 months ago

@sebffischer Do you want to try this out and see whether it gives you any speedup on your setup? I didn't see speedups on MPS (memory management on MPS is very different), but maybe it helps with CUDA:

library(torch)

p = 100
steps = 1000
n = 1000

# Device used for this benchmark; set to "cuda" to test on an NVIDIA GPU.
device = "mps"

X = torch_randn(n, p, device = device)
beta = torch_randn(p, 1, device = device)
Y = X$matmul(beta)

latent = 5000

net = nn_sequential(
  nn_linear(p, latent),
  nn_relu(),
  nn_linear(latent, 1)
)

net$to(device = device)
opt = optim_adam(net$parameters, lr = 0.01)

t1 = Sys.time()

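for (i in 1:steps) {
  opt$zero_grad()
  Y_hat = net(X)
  # nnf_mse_loss() takes (input, target)
  loss = nnf_mse_loss(Y_hat, Y)
  loss$backward()
  # Release the per-step tensors immediately instead of waiting for R's GC.
  # torch_tensor_free is unexported, hence the `:::`.
  torch:::torch_tensor_free(loss)
  torch:::torch_tensor_free(Y_hat)
}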

t2 = Sys.time()

# format() keeps the difftime units, which paste0() alone would drop
print(paste0("Total time: ", format(t2 - t1)))

sebffischer commented 2 months ago

Awesome, thanks! I will try it out and report back.