mlverse / tft

R implementation of Temporal Fusion Transformers
https://mlverse.github.io/tft/
Other

Mixed tensor devices when running the README.md example #9

Closed: cregouby closed this issue 3 years ago

cregouby commented 3 years ago

Current behavior

Running the README.md example fails with:

> fit <- tft_fit(rec, vic_elec_train, epochs = 100, batch_size=100, total_time_steps=12, num_encoder_steps=10, verbose=TRUE)
 Error in (function (self, other, alpha)  : 
  Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Exception raised from compute_types at ../aten/src/ATen/TensorIterator.cpp:388 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x69 (0x7fae102ebb29 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xd2 (0x7fae102e8ab2 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libc10.so)
frame #2: at::TensorIteratorBase::compute_types(at::TensorIteratorConfig const&) + 0x2d1 (0x7fadf27c0771 in /home/home/creg/R/x86_64-pc-linux-gnu-library/4.0/torch/deps/./libtorch_cpu.so)
frame #3: at::TensorIteratorBase::build(at::TensorIteratorConfig&) + 0x7a (0x7fadf27c40da in /home/home 

Expected behavior

The example runs without error.

Workaround

Add an explicit device = "cuda" to the tft configuration parameters (TBC: to be confirmed).
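The error is raised by torch, not by tft itself: an elementwise operation received one operand on cuda:0 and the other on the CPU. The sketch below is a minimal illustration using the torch package for R, not tft's own code. The report does not identify which tensor tft leaves on the CPU, and the names weights and batch are hypothetical. It shows how the mismatch arises and how moving both operands to one device avoids it:

```r
library(torch)

# Use the GPU when one is available, otherwise fall back to the CPU
device <- if (cuda_is_available()) "cuda" else "cpu"

# Hypothetical tensor created on the chosen device (e.g. model weights)
weights <- torch_randn(3, device = device)

# Hypothetical tensor created without a device argument: it lands on the CPU
batch <- torch_randn(3)

# Adding weights + batch directly would raise "Expected all tensors to be on
# the same device" when device is "cuda"; moving batch first avoids that
batch <- batch$to(device = device)
out <- weights + batch

print(out$device)
```

On a CPU-only machine both tensors already share a device, so the error only appears when a GPU is used. That matches the report, where the model ends up on cuda:0 while some input stays on the CPU.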