These lines of code

```python
params = self._get_params(
    [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform],
    num_chunks=1,
    chunks_indices=(torch.tensor([0], device=flat_inputs[0].device),),
)[0]
```

raise an `AttributeError: object has no attribute 'device'` when `flat_inputs[0]` is not a subclass of tensor.
To get the input device, we should instead use a heuristic: retrieve the device from the first input that is a subclass of tensor.
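A minimal sketch of that heuristic, assuming `torch` is available. The helper name `get_input_device` and the CPU fallback when no tensor is present are hypothetical choices for illustration, not part of the library:

```python
from typing import Any, Sequence

import torch


def get_input_device(flat_inputs: Sequence[Any]) -> torch.device:
    """Return the device of the first input that is a torch.Tensor (or subclass).

    Hypothetical helper: non-tensor inputs (e.g. PIL images, ints, strings)
    are skipped instead of having `.device` accessed on them.
    """
    for inpt in flat_inputs:
        # isinstance also matches tensor subclasses.
        if isinstance(inpt, torch.Tensor):
            return inpt.device
    # Assumption: fall back to CPU when no tensor input is found.
    return torch.device("cpu")
```

At the call site, `flat_inputs[0].device` could then be replaced by `get_input_device(flat_inputs)`, so a non-tensor first input no longer triggers the `AttributeError`.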
The same problem occurs in the code that retrieves the batch size.
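The batch size could be inferred the same way. This sketch assumes the batch dimension is dim 0 of the first tensor input and raises rather than guessing when no usable tensor is found; `get_batch_size` is a hypothetical name, not an existing function:

```python
from typing import Any, Sequence

import torch


def get_batch_size(flat_inputs: Sequence[Any]) -> int:
    """Hypothetical helper: batch size taken from the first tensor input.

    Assumes the batch dimension is dim 0 of that tensor.
    """
    for inpt in flat_inputs:
        # Skip non-tensor inputs instead of calling `.shape` on them.
        if isinstance(inpt, torch.Tensor):
            if inpt.ndim == 0:
                raise ValueError("first tensor input has no batch dimension")
            return inpt.shape[0]
    raise TypeError("no tensor input found to infer the batch size from")
```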