pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Support gathering nested lists in DataParallel #13447

Open nikhilweee opened 5 years ago

nikhilweee commented 5 years ago

🚀 Feature

The current implementation of the gather function supports simple data structures such as tensors and dictionaries, but cannot handle nested structures such as a list of lists. How about supporting them as well?

Motivation

Complex networks don't always return simple outputs such as tensors. Consider the following situation, where the output of the forward pass of a question answering network is a dictionary with the keys loss, answers and qids. Notice that answers and qids are lists of lists.

# Output from forward pass
{'loss': tensor(8.6422, device='cuda:0', grad_fn=<ThAddBackward>),
  'answers': [['July 4', 'American Independence'], ['July 1', 'Canadian Independence']],
  'qids': [['C28fbfa5abq#1', 'C28fbfa5abq#2'], ['C2c6ca8b089q#1', 'C2c6ca8b089q#2']]}

When this model runs on multiple GPUs with nn.DataParallel, the following is passed as input to the gather function:

# Input to the `gather` function
[{'loss': tensor(8.6422, device='cuda:0', grad_fn=<ThAddBackward>),
  'answers': [['July 4', 'American Independence'], ['July 1', 'Canadian Independence']],
  'qids': [['C28fbfa5abq#1', 'C28fbfa5abq#2'], ['C2c6ca8b089q#1', 'C2c6ca8b089q#2']]},
 {'loss': tensor(8.7826, device='cuda:0', grad_fn=<ThAddBackward>),
  'answers': [['September 16', 'Mexican Independence'], ['November 3', 'Panamanian Independence']],
  'qids': [['C28fbfcdefab#1', 'C28fbfcdefab#2'], ['C2c623ef089q#1', 'C2c623ef089q#2']]}]

Here's the expected output, where the tensors and the lists from both replicas are concatenated:

# Expected behaviour
{'loss': tensor([8.6422, 8.7826], device='cuda:0', grad_fn=<ThAddBackward>),
  'answers': [['July 4', 'American Independence'], ['July 1', 'Canadian Independence'], ['September 16', 'Mexican Independence'], ['November 3', 'Panamanian Independence']],
  'qids': [['C28fbfa5abq#1', 'C28fbfa5abq#2'], ['C2c6ca8b089q#1', 'C2c6ca8b089q#2'], ['C28fbfcdefab#1', 'C28fbfcdefab#2'], ['C2c623ef089q#1', 'C2c623ef089q#2']]}

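But this is what happens right now. For lists, gather_map falls through to its last branch, type(out)(map(gather_map, zip(*outputs))). That branch zips the replicas' lists element-wise instead of concatenating them. When the recursion reaches a string, str() is called on a map object.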

# Current behaviour
{'loss': tensor([8.6422, 8.7826], device='cuda:0', grad_fn=<ThAddBackward>),
  'answers': [['<map object at 0x2ad2445a7f60>', '<map object at 0x2ad2445a75f8>'], ['<map object at 0x2b6862d87358>', '<map object at 0x2b6862d87668>']],
  'qids': [['<map object at 0x2b6862d87668>', '<map object at 0x2b6862d875c0>'], ['<map object at 0x2b6862d87438>', '<map object at 0x2b6862d87828>']]}
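The placeholders can be reproduced without torch. The sketch below keeps only the generic fallback branch of gather_map, assuming the tensor and dict branches are never reached for these plain lists. zip(*outputs) pairs the i-th element of every replica, or the i-th characters once it reaches strings. Then type(out) is applied to a lazy map object, and for a str that produces its repr:

```python
def gather_map_fallback(outputs):
    """Only the generic fallback branch of gather_map; tensor and dict branches omitted."""
    out = outputs[0]
    if out is None:
        return None
    # zip(*outputs) groups the i-th element of every replica (for strings: the
    # i-th characters), then type(out) is called on a lazy map object.
    return type(out)(map(gather_map_fallback, zip(*outputs)))


replica_0 = [['July 4', 'American Independence'], ['July 1', 'Canadian Independence']]
replica_1 = [['September 16', 'Mexican Independence'], ['November 3', 'Panamanian Independence']]

result = gather_map_fallback([replica_0, replica_1])
print(result)  # a 2x2 nested list of '<map object at 0x...>' strings
```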

Pitch

One solution is to handle lists explicitly by concatenating them. This boils down to adding one extra condition for lists to gather_map in torch/nn/parallel/scatter_gather.py, before the final map branch:

if isinstance(out, list):
    # Concatenate the lists returned by all replicas instead of zipping them
    return [item for output in outputs for item in output]
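A torch-free sketch of gather_map with the proposed list branch in place. It assumes floats can stand in for the scalar loss tensors. The float branch is a hypothetical stand-in for Gather.apply that just collects one value per replica. The dict and fallback branches follow the upstream structure:

```python
def gather_map(outputs):
    """gather_map with the proposed list branch; floats stand in for scalar loss tensors."""
    out = outputs[0]
    if isinstance(out, float):
        # Hypothetical stand-in for Gather.apply: one value per replica.
        return list(outputs)
    if out is None:
        return None
    if isinstance(out, dict):
        if not all(len(out) == len(d) for d in outputs):
            raise ValueError('All dicts must have the same number of keys')
        return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
    if isinstance(out, list):
        # Proposed branch: concatenate the replicas' lists instead of zipping them.
        return [item for output in outputs for item in output]
    return type(out)(map(gather_map, zip(*outputs)))


outputs = [
    {'loss': 8.6422,
     'answers': [['July 4', 'American Independence'], ['July 1', 'Canadian Independence']]},
    {'loss': 8.7826,
     'answers': [['September 16', 'Mexican Independence'], ['November 3', 'Panamanian Independence']]},
]
merged = gather_map(outputs)
print(merged['answers'])  # four [date, event] pairs, replica 0's first
```

The check only matches list, so tuples (for example a (loss, logits) return value) still go through the element-wise zip branch. The trade-off is that a list of tensors would now come back as a plain Python list of per-replica tensors rather than being gathered onto the output device.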

Alternatives

One workaround is to avoid returning nested outputs from the forward pass in the first place, but that is a workaround rather than a solution.

Additional context

I'm not sure this is the right way to address the larger problem of complex data structures. For example, one can always come up with a more deeply nested structure that the proposed solution would not handle either. Should PyTorch allow custom gather and scatter functions, as it does with collate_fn?
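The broader question can be explored without new library API. nn.DataParallel has an overridable gather(self, outputs, output_device) method, so a subclass could route its outputs through a custom recursive merge. Below is a torch-free sketch of such a merge. gather_nested and its gather_leaf hook are hypothetical names. The merge rules are assumptions, not PyTorch behaviour:

- dicts are merged key by key;
- lists are concatenated;
- tuples are zipped;
- everything else is handed to the hook, which for tensors could wrap Gather.apply.

```python
from typing import Any, Callable, Sequence


def gather_nested(outputs: Sequence[Any], gather_leaf: Callable[[Sequence[Any]], Any]) -> Any:
    """Merge one output per replica, recursing through dicts and tuples."""
    out = outputs[0]
    if isinstance(out, dict):
        if any(d.keys() != out.keys() for d in outputs):
            raise ValueError('All dicts must have the same keys')
        # Merge each key independently, so arbitrarily nested dicts work.
        return {k: gather_nested([d[k] for d in outputs], gather_leaf) for k in out}
    if isinstance(out, list):
        # Lists are treated as per-example collections: concatenate across replicas.
        return [item for output in outputs for item in output]
    if isinstance(out, tuple):
        # Tuples are treated as fixed-arity records: merge position by position.
        return tuple(gather_nested(group, gather_leaf) for group in zip(*outputs))
    # Leaves (tensors, numbers, strings, None, ...) are up to the caller.
    return gather_leaf(outputs)


outputs = [
    {'loss': 1.5, 'qids': [['q1', 'q2']], 'pair': (1, 'a')},
    {'loss': 2.5, 'qids': [['q3', 'q4']], 'pair': (2, 'b')},
]
# Here the leaf hook just collects one value per replica into a list.
merged = gather_nested(outputs, gather_leaf=list)
print(merged)
```

Keeping the leaf handling in a caller-supplied function plays a role similar to collate_fn: the recursion over containers is fixed, while the treatment of the actual data is pluggable.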

sseveran commented 4 years ago

Are there any updates or workarounds for this issue?

zou3519 commented 4 years ago

cc @mrshenli, what do you think about this?

quasimik commented 4 years ago

I wanted to gather a dict whose values have a different type for every key. One of the types is a List[List[Tensor]]. I wanted to keep these tensors as separate elements of the list, because they are not split along the dimension used by scatter and would not be rectangular when combined. I also wanted to transfer all the tensors to the target device, regardless of where they sit in this nested structure.

An example of expected input and output:

>>> gather([{'a': [[tensor([1]), tensor([2])], [tensor([3])]]}, {'a': [[tensor([4])], [tensor([5]), tensor([6])]]}])
{'a': [[tensor([1]), tensor([2]), tensor([4])], [tensor([3]), tensor([5]), tensor([6])]]}
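The expected output above merges the i-th inner list of every replica rather than appending the outer lists one after another. Below is a torch-free sketch of just that merge rule. merge_inner_lists is a hypothetical helper, ints stand in for the tensors, and device transfer is left out:

```python
def merge_inner_lists(per_replica):
    """Merge List[List[x]] values position by position.

    Assumes every replica returns the same number of inner lists.
    """
    if len({len(lists) for lists in per_replica}) != 1:
        raise ValueError('All replicas must return the same number of inner lists')
    # zip groups the i-th inner list of every replica; chain each group into one list.
    return [[item for inner in group for item in inner] for group in zip(*per_replica)]


merged = merge_inner_lists([[[1, 2], [3]], [[4], [5, 6]]])
print(merged)  # [[1, 2, 4], [3, 5, 6]]
```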

Here's my solution. There are likely more elegant ways.

import torch
from torch.nn.parallel._functions import Gather


def gather(outputs, target_device, dim=0):
    r"""
    Gathers tensors from different GPUs on a specified device
      (-1 means the CPU).
    """
    # Tensor.to() does not interpret -1 as the CPU, so map it explicitly.
    to_device = torch.device('cpu') if target_device == -1 else target_device

    def recursive_to_device(inputs):
        # Move every tensor in an arbitrarily nested list/tuple/dict to the target device.
        if isinstance(inputs, torch.Tensor):
            return inputs.to(to_device)
        if isinstance(inputs, (list, tuple)):
            return type(inputs)(recursive_to_device(e) for e in inputs)
        if isinstance(inputs, dict):
            return type(inputs)((k, recursive_to_device(v)) for k, v in inputs.items())
        return inputs

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            # Concatenate the replicas' tensors along `dim` on the target device.
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all(len(out) == len(d) for d in outputs):
                raise ValueError('All dicts must have the same number of keys')
            ret = {}
            for k in out:
                if isinstance(out[k], torch.Tensor):
                    ret[k] = Gather.apply(target_device, dim, *[d[k] for d in outputs])
                else:
                    # Non-tensor values (e.g. lists): concatenate the replicas' sequences
                    # and move any nested tensors to the target device.
                    ret[k] = type(out[k])([recursive_to_device(e) for d in outputs for e in d[k]])
            return type(out)(ret)
        # Other sequences: zip the replicas' outputs element-wise and move each group
        # to the target device (tensors here are not concatenated).
        return type(out)(map(recursive_to_device, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        res = gather_map(outputs)
    finally:
        recursive_to_device = None
        gather_map = None
    return res
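For a non-tensor dict value such as List[List[Tensor]], the dict branch above evaluates type(out[k])([recursive_to_device(e) for d in outputs for e in d[k]]). The check below runs just that expression without torch, with ints standing in for tensors. recursive_to_device is dropped, on the assumption that it only moves tensors and leaves the list layout unchanged. The check shows that the outer lists are appended one after another:

```python
# Torch-free stand-in for the non-tensor case of the dict branch.
outputs = [{'a': [[1, 2], [3]]}, {'a': [[4], [5, 6]]}]
out = outputs[0]
ret = {k: type(out[k])([e for d in outputs for e in d[k]]) for k in out}
print(ret)  # {'a': [[1, 2], [3], [4], [5, 6]]}
```

This layout differs from the position-by-position merge shown in the example output, so the intended structure of the outer list decides which rule is appropriate.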

Sorry, I do not have time to make a proper pull request.