The initial tensor resides on the CPU. For me, this failed when running RHO-LOSS, because the very first call to
self.rho_loss = torch.cat([self.rho_loss, training_loss - irreducible_loss]).to(training_loss.dtype)
concatenated tensors on two different devices: the CPU (initial tensor) and the GPU (training loss and irreducible (IR) loss). This PR fixes that by moving the initial tensor to the correct device.
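
For illustration, here is a minimal sketch of the kind of fix described, not the actual diff. The class name RhoLossTracker, the update method, and the use of torch.empty(0) for the initial tensor are assumptions made for this example; only the torch.cat line is taken from the description above. The sketch falls back to the CPU when no GPU is available, so it runs anywhere.

```python
import torch


class RhoLossTracker:
    """Hypothetical stand-in for the object that owns self.rho_loss."""

    def __init__(self):
        # Assumption: the accumulator starts as an empty tensor created
        # without a device argument, so it lives on the CPU.
        self.rho_loss = torch.empty(0)

    def update(self, training_loss, irreducible_loss):
        # Move the accumulator to the device of the incoming losses first;
        # torch.cat raises a RuntimeError if its inputs are on different devices.
        self.rho_loss = self.rho_loss.to(training_loss.device)
        self.rho_loss = torch.cat(
            [self.rho_loss, training_loss - irreducible_loss]
        ).to(training_loss.dtype)
        return self.rho_loss


# Use the GPU when present, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tracker = RhoLossTracker()
training_loss = torch.tensor([0.9, 0.4], device=device)
irreducible_loss = torch.tensor([0.5, 0.1], device=device)
result = tracker.update(training_loss, irreducible_loss)
```

Moving the accumulator inside update, rather than at construction time, means the code does not need to know the target device in advance. Allocating the initial tensor directly on the right device would work equally well if the device is known up front.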