mlverse / tft

R implementation of Temporal Fusion Transformers
https://mlverse.github.io/tft/

Crashes when running on GPU (cuda11.3) #38

Closed · Sanaxen closed this issue 2 years ago

Sanaxen commented 2 years ago

In the sample code, the crash occurs in the following two places:

result <- luz::lr_finder(
  model, 
  transform(spec, train), 
  end_lr = 1,
  dataloader_options = list(
    batch_size = 64
  ),
  verbose = FALSE
)

and

fitted <- model %>% 
  fit(
    transform(spec),
    valid_data = transform(spec, new_data = valid),
    epochs = 100,
    callbacks = list(
      luz::luz_callback_keep_best_model(monitor = "valid_loss"),
      luz::luz_callback_early_stopping(
        monitor = "valid_loss", 
        patience = 5, 
        min_delta = 0.001
      )
    ),
    verbose = FALSE,
    dataloader_options = list(batch_size = 64, num_workers = 4)
  )

However, the CPU version works correctly. I have been able to install torch and tft without any trouble.
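
As a quick sanity check that the GPU build of torch is actually visible, the standard device queries can be run first (these are generic torch calls, nothing tft-specific):

library(torch)
cuda_is_available()   # TRUE if the CUDA build initialized correctly
cuda_device_count()   # how many GPUs torch can see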

Sanaxen commented 2 years ago

A similar problem already existed: "Rstudio crashes when fitting using luz" (#861, #102).

One thing I did find out:

On a PC with two GPUs it crashes, but on a PC with one GPU it works fine and exits with the correct result.

Does luz or torch support multi-GPU at all? If it is mentioned anywhere, I'm sure I've missed it; sorry if that is the case. Is there a way to specify which GPU to run on?
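
For reference, one generic way to pin the process to a single GPU, independent of luz, is the standard CUDA_VISIBLE_DEVICES environment variable, set before torch initializes CUDA; a minimal sketch:

# Expose only the first GPU to CUDA. This must run before the first
# CUDA call (ideally at the very top of the script, or in .Renviron).
Sys.setenv(CUDA_VISIBLE_DEVICES = "0")

library(torch)
cuda_device_count()   # should now report a single device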

Sanaxen commented 2 years ago

fitted <- model %>% 
  fit(
    transform(spec),
    epochs = n_epochs,
    verbose = TRUE,
    valid_data = transform(spec, new_data = valid),
    callbacks = list(
      luz::luz_callback_keep_best_model(monitor = "valid_loss"),
      luz::luz_callback_early_stopping(
        monitor = "valid_loss", 
        patience = 5, 
        min_delta = 0.001
      )
    ),
    accelerator = luz::accelerator(
      device_placement = TRUE,
      cpu = TRUE,      # forces CPU execution; cuda_index only applies when cpu = FALSE
      cuda_index = 0
    ),
    dataloader_options = list(batch_size = n_batch_size, num_workers = n_num_workers)
  )

We have found that explicitly specifying accelerator = luz::accelerator() allows for normal execution.
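
Note that with cpu = TRUE the accelerator runs everything on the CPU, which is presumably why the crash disappears. To stay on CUDA but pin a single card instead, a minimal sketch (assuming luz::accelerator's documented cpu and cuda_index arguments, and keeping the 0-based index the snippet above passes):

# Select GPU 0 explicitly rather than falling back to the CPU
acc <- luz::accelerator(
  device_placement = TRUE,
  cpu = FALSE,     # keep training on CUDA
  cuda_index = 0   # same index as in the snippet above
)

The resulting accelerator object is then passed to fit() as accelerator = acc, exactly as in the call above.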