lf1-io / padl

Functional deep learning

The output of `(t1 + t2) / t3` isn't flattened #453

Open philip-bl opened 2 years ago

philip-bl commented 2 years ago

Either a 🐞 Bug or Unclear Documentation

https://lf1-io.github.io/padl/latest/usage/combining_transforms.html#grouping-transforms says

By default, Pipelines, such as rollouts and parallels, are flattened. This means that even if you use parentheses to group them, the output will be a flat tuple.

Here is a code example (run with padl 0.2.5) where I expect the output to be a flat tuple, but instead it is a tuple with a tuple inside:

import padl
pipeline = (padl.Identity() + padl.transform(lambda x: x**2)) / padl.transform(lambda y: y + 100)
print(pipeline((2, 5)))  # prints namedtuple(out_0=namedtuple(out_0=2, out_1=4), out_1=105)
(pipeline >> padl.Identity() / padl.Identity() / padl.Identity())((2, 5))  # raises IndexError: tuple index out of range
jasonkhadka commented 2 years ago

@philip-bl

Thanks for pointing out the confusion. The documentation is indeed unclear.

The output is a flat tuple only when the pipeline is built from a single kind of operation: either all rollouts or all parallels. This is what the example in the documentation shows:

(t1 + (t2 + t3))(x) == ((t1 + t2) + t3)(x) == (t1 + t2 + t3)(x) == (t1(x), t2(x), t3(x))
True
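
As a concrete, runnable illustration of that identity (my own sketch, with plain lambda transforms standing in for t1, t2 and t3; the outputs in the comments assume the flattening behavior described above):

import padl

t1 = padl.transform(lambda x: x + 1)
t2 = padl.transform(lambda x: x * 2)
t3 = padl.transform(lambda x: x - 3)

# however the rollout is parenthesized, the result is one flat namedtuple
print((t1 + (t2 + t3))(10))  # namedtuple(out_0=11, out_1=20, out_2=7)
print(((t1 + t2) + t3)(10))  # same flat triple
print((t1 + t2 + t3)(10))    # same flat triple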

But if you mix the operations (parallel, rollout and compose), the pipeline is built according to Python's operator precedence.

That is why, in the example below, a rollout is created first, and then a parallel is built on top of it.

import padl
pipeline = (padl.Identity() + padl.transform(lambda x: x**2)) / padl.transform(lambda y: y + 100)

pipeline
Parallel - "pipeline":

   │└─▶ 0: Identity() + lambda x: x**2
   │  /  
   └──▶ 1: lambda y: y + 100

Similarly, in your second example, the parallels are created first, since / takes precedence over >>. The resulting parallel is then composed with the pipeline, as in pipeline >> new-parallel-with-three-identities.
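
Spelled out with explicit parentheses (my own sketch of how the expression parses):

import padl

pipeline = (padl.Identity() + padl.transform(lambda x: x**2)) / padl.transform(lambda y: y + 100)

# / binds tighter than >>, so the three identities form one parallel first ...
identity3 = padl.Identity() / padl.Identity() / padl.Identity()

# ... and only then is that parallel composed with the pipeline
composed = pipeline >> identity3  # same as the original expression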

(pipeline >> padl.Identity() / padl.Identity() / padl.Identity())
Compose:

      │└───────────┐
      │            │
      ▼ args       ▼ y
   0: [..+..] / lambda y: y + 100
      ││└──────────────────────────────┐
      │└───────────┐                   │
      │            │                   │
      ▼ args       ▼ args              ▼ args
   1: Identity() / Identity()        / Identity()

I understand that the arrows above might not be helpful, but imagine it as:

Compose:

      │
      │            
      ▼ args     
   0: Pipeline
      ││└──────────────────────────────┐
      │└───────────┐                   │
      │            │                   │
      ▼ args       ▼ args              ▼ args
   1: Identity() / Identity()        / Identity()
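
Concretely, the compose fails because step 0 outputs two items while step 1 expects three. A minimal check (my own sketch, not from the thread):

import padl

pipeline = (padl.Identity() + padl.transform(lambda x: x**2)) / padl.transform(lambda y: y + 100)
out = pipeline((2, 5))
print(len(out))  # 2: the nested pair and 105

# the downstream parallel Identity() / Identity() / Identity() reads positions
# 0, 1 and 2 of its input, so position 2 raises IndexError: tuple index out of range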

If you replace >> with another / in your example, then things are fine again:

(pipeline / padl.Identity() / padl.Identity() / padl.Identity())
Out:
Parallel:

   ││││└─▶ 0: Identity() + lambda x: x**2
   ││││  /  
   │││└──▶ 1: lambda y: y + 100
   │││   /  
   ││└───▶ 2: Identity()
   ││    /  
   │└────▶ 3: Identity()
   │     /  
   └─────▶ 4: Identity()

In : (pipeline / padl.Identity() / padl.Identity() / padl.Identity())((0, 1, 2, 3, 4))
Out: namedtuple(out_0=namedtuple(out_0=0, out_1=0), out_1=101, out_2=2, out_3=3, out_4=4)
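
For the opposite direction, deliberately keeping a sub-pipeline nested, the documentation section linked at the top describes grouping. A sketch assuming padl.group behaves as documented there:

import padl

t1 = padl.transform(lambda x: x + 1)
t2 = padl.transform(lambda x: x * 2)
t3 = padl.transform(lambda x: x - 3)

# the grouped rollout is kept as one unit instead of being flattened into the outer rollout
grouped = t1 + padl.group(t2 + t3)
print(grouped(10))  # expected: namedtuple(out_0=11, out_1=namedtuple(out_0=20, out_1=7))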
philip-bl commented 2 years ago

In that case, maybe you shouldn't flatten it automatically at all. Or come up with a way to flatten any combinations of parallel and rollout. Idk.
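
For what it's worth, a manual flatten along those lines works as a user-side workaround (my own sketch; flatten_all is a hypothetical helper, not a padl API):

import padl

def _flatten(t):
    # recursively flatten nested tuples (namedtuples included) into one flat tuple
    out = []
    for item in t:
        if isinstance(item, tuple):
            out.extend(_flatten(item))
        else:
            out.append(item)
    return tuple(out)

flatten_all = padl.transform(_flatten)

# e.g. pipeline >> flatten_all >> padl.Identity() / padl.Identity() / padl.Identity()
# would turn the nested ((2, 4), 105) into (2, 4, 105) before the parallel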

jasonkhadka commented 2 years ago

Thanks for your suggestion. We will think about flattening in general.

jasonkhadka commented 2 years ago

Let's start work on the documentation.