Open · PetrochukM opened this issue 3 years ago
As expected, the issue can be worked around by using copy.deepcopy, like so:

partial = functools.partial(worker_init_fn, data=copy.deepcopy(data))
Another workaround is to replace
data = (tensor, tensor.numpy())
by
data = (tensor, tensor.numpy().copy())
It looks like only arrays that don't own their data are affected by this problem, because tensor.numpy() is just a way to point at the tensor's data from a numpy object (see here). Since the data is moved to a shared memory zone, I guess the problem is that the pointer handed to the numpy array is no longer valid, and should be updated to reflect the new location of the tensor's data.
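A minimal sketch of that sharing, triggering the storage move directly with share_memory_() instead of going through DataLoader (this example is an illustration, not from the original report):

import torch

t = torch.arange(4, dtype=torch.float32)
a = t.numpy()                 # a view over t's storage; it owns no data
print(a.flags['OWNDATA'])     # False

t.share_memory_()             # reallocates the underlying storage in shared memory
t.fill_(42.0)                 # writes go to the new storage
print(t)                      # tensor([42., 42., 42., 42.])
print(a)                      # stale view: still points at the old storage (contents undefined)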
The comment from @ClementPinard is correct. tensor and tensor.numpy() share the same data storage. When we want to share a Tensor with other processes, we call share_memory_() to move the underlying storage, which tensor.numpy() also points to.
You can try this to see that the data behind the NumPy array has been moved:
import functools
import torch
from torch.utils.data import DataLoader

def worker_init_fn(worker_id, data=None):
    pass  # placeholder; the original worker_init_fn is not shown here

def run_worker(rank):
    tensor = torch.randn(10)
    np_tensor = tensor.numpy()
    data = (tensor, np_tensor)  # data was not defined in the original snippet
    partial = functools.partial(worker_init_fn, data=data)
    data_loader = DataLoader(list(range(1000)), num_workers=1, worker_init_fn=partial)
    for data in data_loader:
        pass
    print(np_tensor)       # stale: the storage behind this view was moved
    print(tensor.numpy())  # re-created from the tensor's current storage
You can also pass tensor into worker_init_fn and call tensor.numpy() within it.
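A sketch of that approach (the worker_init_fn body here is assumed, not from the thread):

import functools
import torch
from torch.utils.data import DataLoader

def worker_init_fn(worker_id, tensor=None):
    # Create the numpy view inside the worker, after the tensor's storage has
    # already been moved, so the view points at the current data.
    data = (tensor, tensor.numpy())
    print(data[1][:3])

tensor = torch.randn(10)
partial = functools.partial(worker_init_fn, tensor=tensor)
loader = DataLoader(list(range(1000)), num_workers=1, worker_init_fn=partial)
for _ in loader:
    pass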
Thanks for the workarounds. Either way, I think this behavior is unexpected. How can we prevent this behavior from causing issues downstream in the future?
So it seems to me that the core issue is that the .numpy() array fails to keep the original storage around.
In [25]: x = torch.randn(3)
In [26]: x
Out[26]: tensor([ 0.9417, 1.5524, -0.1936])
In [27]: xnp = x.numpy()
In [28]: xnp
Out[28]: array([ 0.9416804 , 1.5524077 , -0.19364882], dtype=float32)
In [29]: y = torch.zeros(10)
In [30]: x.set_(y.storage())
Out[30]: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
In [31]: x
Out[31]: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
In [32]: xnp
Out[32]: array([ 7.4516139e+13, 7.0064923e-45, -1.9364882e-01], dtype=float32)
This is not really data loader related so I am removing that flag.
Marking high priority based on @SsnL's investigation.
Well, at least we apparently already knew about this problem.
  auto array = THPObjectPtr(PyArray_New(
      &PyArray_Type,
      tensor.dim(),
      sizes.data(),
      dtype,
      strides.data(),
      tensor.data_ptr(),
      0,
      NPY_ARRAY_ALIGNED | NPY_ARRAY_WRITEABLE,
      nullptr));
  if (!array) return nullptr;

  // TODO: This attempts to keep the underlying memory alive by setting the base
  // object of the ndarray to the tensor and disabling resizes on the storage.
  // This is not sufficient. For example, the tensor's storage may be changed
  // via Tensor.set_, which can free the underlying memory.
  PyObject* py_tensor = THPVariable_Wrap(tensor);
  if (!py_tensor) throw python_error();
  if (PyArray_SetBaseObject((PyArrayObject*)array.get(), py_tensor) == -1) {
    return nullptr;
  }
  // Use the private storage API
  tensor.storage().unsafeGetStorageImpl()->set_resizable(false);

  return array.release();
}
from torch/csrc/utils/tensor_numpy.cpp
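That TODO describes exactly the failure shown in the session above: the ndarray's base object keeps the tensor alive, but nothing pins the storage itself. A small check from Python (an illustration, not part of the original thread):

import torch

x = torch.randn(3)
xnp = x.numpy()
print(isinstance(xnp.base, torch.Tensor))  # True: the base keeps the tensor object alive
print(xnp.flags['OWNDATA'])                # False: the data belongs to the tensor's storage

x.set_(torch.zeros(10).storage())  # swaps the storage; the old allocation can be freed
print(xnp)                         # the ndarray's data pointer now dangles (contents undefined)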
🐛 Bug
For some reason creating a DataLoader instance with a numpy array changes the data in the underlying numpy array. DataLoader has some sort of side-effect on the arguments that are passed to it.

To Reproduce
Run this:
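(The original script is not preserved in this excerpt; the sketch below is reconstructed from the workarounds quoted earlier in the thread, so the exact worker_init_fn body and data tuple are assumptions.)

import functools
import torch
from torch.utils.data import DataLoader

def worker_init_fn(worker_id, data=None):
    tensor, array = data
    print('worker :', array[0], tensor[0].item())

tensor = torch.randn(2)
data = (tensor, tensor.numpy())
print('before :', data[1][0], tensor[0].item())

partial = functools.partial(worker_init_fn, data=data)
loader = DataLoader(list(range(1000)), num_workers=1, worker_init_fn=partial)
for _ in loader:
    pass

print('after  :', data[1][0], tensor[0].item())  # the numpy value has changed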
Sample output:
For some reason, my data has been changed from (-1.0115291, 2.1784658) to (0.0, 2.1784658) in the main process, and (-0.6122731, 3.8310927e+33) in the worker process.

Expected behavior
DataLoader shouldn't have any side-effects on its arguments.

Environment
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @mruberry @rgommers @heitorschueroff @SsnL @VitalyFedyunin @ejguan