nipype / pydra

Pydra Dataflow Engine
https://nipype.github.io/pydra/

splitting list of torch.Tensor causes "repeating" behavior of the output #761

Open wilke0818 opened 1 month ago

wilke0818 commented 1 month ago
import torch
import pydra

@pydra.mark.task
def test_task(fake_audio_input):
    return fake_audio_input + 2

wf = pydra.Workflow(name='wf_test', input_spec=['audio_input'])

wf.split('audio_input', audio_input=[torch.tensor([1]), torch.tensor([2])]).combine('audio_input')

wf.add(test_task(name='testing', fake_audio_input=wf.lzin.audio_input))
# wf.combine('audio_input')
wf.set_output([('wf_out', wf.testing.lzout.out)])

with pydra.Submitter(plugin='cf') as sub:
    sub(wf)
print(wf.done)
# results = wf.result(return_inputs='val')
results = wf(plugin='cf')
print(results)
# output: [Result(output=Output(wf_out=tensor([3])), runtime=None, errored=False), Result(output=Output(wf_out=tensor([3])), runtime=None, errored=False)]

Similarly, using other tensor structures doesn't change the behavior:

wf.split('audio_input', audio_input=[torch.tensor([1, 2]), torch.tensor([3, 4])]).combine('audio_input')
# output: [Result(output=Output(wf_out=tensor([3, 4])), runtime=None, errored=False), Result(output=Output(wf_out=tensor([3, 4])), runtime=None, errored=False)]

or

wf.split('audio_input', audio_input=[torch.tensor([[1], [2]]), torch.tensor([[3], [4]])]).combine('audio_input')
# output: [Result(output=Output(wf_out=tensor([[3],
#        [4]])), runtime=None, errored=False), Result(output=Output(wf_out=tensor([[3],
#        [4]])), runtime=None, errored=False)]

Notably, using numpy does give the expected behavior (likely as a result of #340)

wf.split('audio_input', audio_input=[np.array([1, 2]), np.array([3, 4])]).combine('audio_input')
# output: [Result(output=Output(wf_out=array([3, 4])), runtime=None, errored=False), Result(output=Output(wf_out=array([5, 6])), runtime=None, errored=False)]
wilke0818 commented 1 month ago

A further look shows that the checksum/hashing process might be the root cause. The checksums/hashes for tensors seem to be the same even when the tensor values differ, whereas numpy arrays with different values hash differently.
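The failure mode can be sketched without torch or pydra at all. The `FakeTensor` class and both hash helpers below are hypothetical illustrations (pydra's real fallback is more elaborate): if a generic hasher only inspects `__dict__`, an object whose payload lives in C storage (or in `__slots__`, as mimicked here) contributes nothing distinguishing to the hash, so every instance collides.

```python
import hashlib

class FakeTensor:
    """Hypothetical stand-in for a C-extension type: payload not in __dict__."""
    __slots__ = ("_data",)

    def __init__(self, data):
        self._data = data

def dict_based_hash(obj):
    # Mimics a generic fallback that walks obj.__dict__; slotted (or
    # C-extension) objects expose nothing here, so the payload is lost.
    state = getattr(obj, "__dict__", {})
    h = hashlib.blake2b(digest_size=16)
    h.update(repr(sorted(state.items())).encode())
    return h.hexdigest()

def data_based_hash(obj):
    # A type-specific serializer that feeds the actual payload into the hash.
    h = hashlib.blake2b(digest_size=16)
    h.update(repr(obj._data).encode())
    return h.hexdigest()

a, b = FakeTensor([1]), FakeTensor([2])
print(dict_based_hash(a) == dict_based_hash(b))  # True: different data, same hash
print(data_based_hash(a) == data_based_hash(b))  # False: payload now distinguishes them
```

This is consistent with numpy behaving correctly: numpy arrays get a dedicated serializer that hashes their raw bytes, while tensors fall through to the generic path.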

effigies commented 1 month ago

If that's the case, then what you probably need to do is register a serializer for the type:

from typing import Iterator

from pydra.utils.hash import register_serializer, Cache

@register_serializer(torch.Tensor)  # register on the class, not the torch.tensor factory function
def bytes_repr_tensor(obj: torch.Tensor, cache: Cache) -> Iterator[bytes]:
    # Some efficient method for turning the object into a byte sequence to hash
    ...

See https://github.com/nipype/pydra/blob/master/pydra/utils/hash.py for examples.

If you have an approach that will work with all array-likes, we could update bytes_repr_numpy to apply more broadly.
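One plausible body for such a serializer is to yield the dtype and shape followed by the raw data bytes. The sketch below (my suggestion, not pydra's canonical implementation) keeps the generator shape expected by `register_serializer` but is written as a plain function so it can be exercised standalone; in pydra you would decorate it with `@register_serializer(torch.Tensor)` as shown above.

```python
import hashlib
import torch

def bytes_repr_tensor(obj):
    # Emit dtype and shape first, so tensors with identical raw bytes but
    # different shapes (e.g. [2, 1] vs [1, 2]) still hash differently.
    yield f"torch.Tensor:{obj.dtype}:{tuple(obj.shape)}:".encode()
    # Raw data bytes; detach() and cpu() make this safe for grad-tracking
    # and GPU-resident tensors.
    yield obj.detach().cpu().numpy().tobytes()

def hash_tensor(t):
    # Helper for demonstration only: fold the chunks into a digest.
    h = hashlib.blake2b(digest_size=16)
    for chunk in bytes_repr_tensor(t):
        h.update(chunk)
    return h.hexdigest()

print(hash_tensor(torch.tensor([1])) == hash_tensor(torch.tensor([2])))  # False
```

With a serializer like this registered, the two split inputs in the original report would hash differently and the workflow would stop reusing one cached result for both.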

wilke0818 commented 1 month ago

Hi @effigies, I have coded up a solution locally, but when pushing I get:

remote: Permission to nipype/pydra.git denied to wilke0818.
fatal: unable to access 'https://github.com/nipype/pydra.git/': The requested URL returned error: 403

Is this related to the repo or to my account? Not sure if I need to be added to the repository to make contributions.

effigies commented 1 month ago

You'll need to fork the repository, push to a branch on your own fork, and then create a pull request.

See https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/getting-started/about-collaborative-development-models.

satra commented 1 month ago

it should be the case that, without registering any additional serializers, pydra behaves appropriately when hashing an arbitrary object (however inefficiently). this is the second example where this has broken down. that seems like a pydra bug.

effigies commented 1 month ago

C extension objects are going to be difficult, as they may not be introspectable (or distinguishable from one another) in the same way as pure Python objects via __dict__ or __slots__. I believe this is why numpy arrays were special-cased in the beginning.

In #762 I have suggested that we identify numpy array API objects with a protocol, which should cover many of these use cases.

satra commented 1 month ago

perhaps this issue is more about how we are detecting the types of objects then. maybe if we are not confident, we can and should fall back to hashing the pickled bytestream (that should generally work). i believe if this code had reached the "pickle and hash the bytestream" path, @wilke0818 wouldn't have seen the erroneous behavior he noticed.
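A minimal sketch of that fallback (the function name and error handling here are my own, not pydra's API): pickle captures an object's full reduced state, including payloads invisible to `__dict__` introspection, so two tensors with different values would produce different bytestreams and thus different hashes.

```python
import hashlib
import pickle

def pickle_fallback_hash(obj):
    # Hypothetical last-resort hasher: when type-specific serialization and
    # introspection are unavailable, hash the pickled bytestream instead.
    try:
        payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    except Exception as exc:
        raise TypeError(f"cannot hash {type(obj)!r}: not picklable") from exc
    return hashlib.blake2b(payload, digest_size=16).hexdigest()

print(pickle_fallback_hash([1, 2]) == pickle_fallback_hash([3, 4]))  # False
```

One caveat worth noting: pickle output is not guaranteed byte-stable across Python versions (or even across protocol choices), which is presumably why type-specific serializers are preferred when available; the pickled bytestream is only a best-effort fallback.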