nikhilweee opened this issue 5 years ago
Are there any updates or workarounds for this issue?
cc @mrshenli, what do you think about this?
I wanted to gather a `Dict` whose values have a different type for every key. One of those types is a `List[List[Tensor]]`. I wanted to keep these tensors as separate elements of the list, since they do not lie along the dimension split by scatter and are not rectangular when combined. I also wanted to transfer all the tensors to the target device, regardless of where they sit in this nested structure.
An example of expected input and output:
```python
>>> gather([{'a': [[tensor([1]), tensor([2])], [tensor([3])]]}, {'a': [[tensor([4])], [tensor([5]), tensor([6])]]}])
{'a': [[tensor([1]), tensor([2]), tensor([4])], [tensor([3]), tensor([5]), tensor([6])]]}
```
Here's my solution. There are likely more elegant ways.
```python
import torch
from torch.nn.parallel._functions import Gather


def gather(outputs, target_device, dim=0):
    r"""
    Gathers tensors from different GPUs on a specified device
    (-1 means the CPU).
    """
    def recursive_to_device(inputs):
        # Move every tensor in an arbitrarily nested structure to the
        # target device, preserving the container types along the way.
        if isinstance(inputs, torch.Tensor):
            return inputs.to(target_device)
        if isinstance(inputs, (list, tuple)):
            return type(inputs)(recursive_to_device(e) for e in inputs)
        if isinstance(inputs, dict):
            return type(inputs)((k, recursive_to_device(v)) for k, v in inputs.items())
        return inputs

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError('All dicts must have the same number of keys')
            ret = {}
            for k in out:
                if isinstance(out[k], torch.Tensor):
                    ret[k] = Gather.apply(target_device, dim, *[d[k] for d in outputs])
                elif isinstance(out[k], (list, tuple)) and out[k] and isinstance(out[k][0], (list, tuple)):
                    # Nested lists: concatenate the i-th inner lists from
                    # every replica, matching the expected output above.
                    ret[k] = type(out[k])(
                        [recursive_to_device(e) for inner in inners for e in inner]
                        for inners in zip(*[d[k] for d in outputs])
                    )
                else:
                    # Flat lists: chain the replicas' lists together.
                    ret[k] = type(out[k])(recursive_to_device(e) for d in outputs for e in d[k])
            return type(out)(ret)
        return type(out)(map(recursive_to_device, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the functions to None clears the refcycle.
    try:
        res = gather_map(outputs)
    finally:
        recursive_to_device = None
        gather_map = None
    return res
```
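For what it's worth, here's a quick sanity check run entirely on the CPU (passing `target_device='cpu'` rather than -1); it reproduces the expected output from the example above:

```python
import torch

outputs = [
    {'a': [[torch.tensor([1]), torch.tensor([2])], [torch.tensor([3])]]},
    {'a': [[torch.tensor([4])], [torch.tensor([5]), torch.tensor([6])]]},
]
print(gather(outputs, target_device='cpu'))
# {'a': [[tensor([1]), tensor([2]), tensor([4])], [tensor([3]), tensor([5]), tensor([6])]]}
```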
Sorry, I don't have time to make a proper pull request.
🚀 Feature
The current implementation of the `gather` function supports simple data structures like tensors and dictionaries, but it cannot handle nested structures such as a list of lists. How about supporting them as well?

Motivation
Complex networks don't always return simple outputs such as tensors. Consider a situation where the output from the forward pass of a question answering network is a dictionary with the keys `loss`, `answers`, and `qids`, where `answers` and `qids` are lists of lists. While running this model over multiple GPUs using `nn.DataParallel`, one such dictionary per replica is passed as input to the gather function, and the expected output is a single dictionary in which the tensors and the nested lists are concatenated together.
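To illustrate, the per-replica outputs might look something like this (hypothetical keys' values; only the shape of the problem matters here):

```python
import torch

# Hypothetical outputs from two replicas of the QA model:
outputs = [
    {'loss': torch.tensor(0.5),
     'answers': [['ans_a', 'ans_b'], ['ans_c']],
     'qids': [['q1', 'q2'], ['q3']]},
    {'loss': torch.tensor(0.7),
     'answers': [['ans_d'], ['ans_e', 'ans_f']],
     'qids': [['q4'], ['q5', 'q6']]},
]
# Expected gathered result: the loss tensors combined into one tensor,
# and the i-th inner lists of answers/qids concatenated together:
# {'loss': tensor([0.5, 0.7]),
#  'answers': [['ans_a', 'ans_b', 'ans_d'], ['ans_c', 'ans_e', 'ans_f']],
#  'qids': [['q1', 'q2', 'q4'], ['q3', 'q5', 'q6']]}
```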
But this is not what happens right now: the `map` clause in the current implementation cannot handle nested lists.

Pitch
One solution is to explicitly handle the merging of lists. This basically boils down to adding an extra condition for handling lists in the `gather` function here; a sketch follows.
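A minimal sketch of what that extra condition could do, written as a standalone helper (hypothetical; the names are illustrative, not PyTorch API):

```python
def merge_lists(outputs):
    """Merge one key's list values from all replicas."""
    out = outputs[0]
    if out and isinstance(out[0], list):
        # List of lists: concatenate the i-th inner lists across replicas.
        return [sum(inners, []) for inners in zip(*outputs)]
    # Flat list: chain the replicas' lists together.
    return [e for o in outputs for e in o]

merge_lists([[['a', 'b'], ['c']], [['d'], ['e', 'f']]])
# -> [['a', 'b', 'd'], ['c', 'e', 'f']]
```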
Alternatives

One option is to avoid nested outputs from the forward pass in the first place, but that is a workaround, not a solution.
Additional context
I'm not sure if this is the right way to handle the larger problem of supporting complex data structures. For example, one can always come up with an even more deeply nested structure that might not work with the proposed solution either. Should PyTorch allow custom `gather` and `scatter` functions, as it has done with `collate_fn`?
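For comparison, this is roughly how `DataLoader` already hands control over batching to the user via `collate_fn` (the `dataset` here is assumed to exist):

```python
from torch.utils.data import DataLoader

def my_collate(batch):
    # Keep nested structures intact instead of stacking them into tensors.
    return batch

loader = DataLoader(dataset, batch_size=4, collate_fn=my_collate)
```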