lf1-io / padl

Functional deep learning
Apache License 2.0

Assignment breaks the saving #427

Closed jasonkhadka closed 2 years ago

jasonkhadka commented 2 years ago

🐞 Bug

Assignment breaks the saving

## Works BUT

@transform
def times_two(x):
    return x * 2

times_two.pd_to(DEVICE)

save(times_two, 'm.padl')
m = load('m.padl')

content of `transform.py`

from padl import transform

@transform
def times_two(x):
    return x * 2

_pd_main = times_two


## Not working code
* triggers `visit_Attribute` 6 times.
* the creation of `times_two` is not written to the file.

DEVICE = 'cpu'

@transform
def times_two(x):
    return x * 2

times_two = times_two.pd_to(DEVICE)

save(times_two, 'm.padl')
m = load('m.padl')


content of `transform.py`

DEVICE = 'cpu'
times_two = times_two.pd_to(DEVICE)
_pd_main = times_two


## Works BUT
* forgets about the `device` assignment.
* forgets about `new_name`
* does not trigger `visit_Attribute`

import padl
from padl import transform, save, load

DEVICE = 'cpu'

@transform
def times_two(x):
    return x * 2

new_name = times_two.pd_to(DEVICE)

save(new_name, 'm.padl', force_overwrite=True)
m2 = load('m.padl')

This works, but it forgets about `new_name` and the `new_name` assignment.
Below is the code written to `transform.py`.

from padl import transform

@transform
def times_two(x):
    return x * 2

_pd_main = times_two



sjrl commented 2 years ago
wuhu commented 2 years ago

@sjrl I think the automatic check is a good idea, at least for now so that if something isn't saved properly it doesn't fail silently. Just saving the layers could also be useful.

@jasonkhadka the two "Works BUT" examples are expected:

(1) Saving does not dump attribute calls if they are not part of an assignment (see https://lf1-io.github.io/padl/latest/advanced/saving.html#mutated-objects).

(2) Function transforms always go by the name of the function, not by the name of a variable they're assigned to. So when you dump them, the saver tracks down the name of the function, not the name of the variable. So it's not strictly a bug; still, we could think about whether this is the most intuitive behavior and perhaps change it.
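Point (2) mirrors a plain-Python fact that can be checked without padl: a function object carries the name from its `def` statement, whatever variable it is later bound to. The sketch below only illustrates that fact. It does not use padl, and it makes no claim about how the saver is implemented.

```python
def times_two(x):
    return x * 2

# Binding the same function object to a second variable does not rename it.
new_name = times_two

print(new_name.__name__)      # times_two
print(new_name is times_two)  # True
```

This is why a dump that looks up the function by name would emit `times_two`, even when the object was saved through `new_name`.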

jasonkhadka commented 2 years ago

PR for the fix: https://github.com/lf1-io/padl/pull/432

This still has 1 issue.

from padl import transform

recursive = 2

@transform
def recursive(x):
    if x == 0:
        return x
    return 1 + recursive(x - 1)

Gives:

from padl import transform

recursive = 2

@transform
def recursive(x):
    if x == 0:
        return x
    return 1 + recursive(x - 1)

_pd_main = recursive

There is no harm in having recursive = 2. But at the same time, it is not needed.

jasonkhadka commented 2 years ago

I see. We can also think about including the attribute calls if needed. But that can be a future feature.

jasonkhadka commented 2 years ago

Why does PR https://github.com/lf1-io/padl/pull/432 fix this issue?

Current problem:

@transform
def times_two(x):
    return x * 2

times_two = times_two.pd_to(DEVICE) # <- this line

In the code above, the line marked `# <- this line` contains `times_two`, which is identified as `ScopedName('times_two',..,n=0)`, and `times_two.pd_to`, which is identified as `ScopedName('times_two.pd_to',..,n=0)`.

If you now look up `ScopedName('times_two.pd_to',..,n=0)` in the `ast` tree of the source code, you end up at the same line, the one marked `# <- this line`. That is because `n=0` denotes the latest line defining `times_two`. The suffix `.pd_to` does not matter when searching for the source here.
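The meaning of `n` can be sketched with the standard `ast` module alone. The helpers `binds` and `find_nth_latest` are hypothetical names made up for this illustration; they are not padl's `symfinder` functions. The sketch assumes that a lookup scans the top-level statements from last to first and counts `n` down on every statement that binds the name.

```python
import ast

SOURCE = '''
DEVICE = 'cpu'

def times_two(x):
    return x * 2

times_two = times_two.pd_to(DEVICE)
'''


def binds(node, name):
    """Return True if a top-level statement (re)binds `name`."""
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
        return node.name == name
    if isinstance(node, ast.Assign):
        return any(isinstance(t, ast.Name) and t.id == name for t in node.targets)
    return False


def find_nth_latest(source, name, n):
    """Scan statements from last to first, counting `n` down to zero."""
    for node in reversed(ast.parse(source).body):
        if binds(node, name):
            if n == 0:
                return node
            n -= 1
    raise LookupError(f'not enough bindings of {name!r} in the source')


latest = find_nth_latest(SOURCE, 'times_two', 0)
previous = find_nth_latest(SOURCE, 'times_two', 1)
print(type(latest).__name__, latest.lineno)      # Assign 7
print(type(previous).__name__, previous.lineno)  # FunctionDef 4
```

With `n=0` the search lands on the assignment itself, so the right-hand side of `times_two = times_two.pd_to(DEVICE)` has to be looked up with `n=1` to reach the function definition.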

The current `increment_same_name_var` tries to fix this by adding the initial `n` to the new `n`. But when both are 0, 0 + 0 = 0, and that was the source of this failure. I updated `increment_same_name_var` to handle a few different cases, which fixes this issue and also an issue with overridden variables. Example:

a = 2

def a(arg):
    if arg == 0:
        return 1
    return arg * a(arg - 1)

In the code above, function `a` depends on `arg` and on itself. Function `a` is first identified as `ScopedName(a, ..., n=0)`, and inside its body it finds another `a`, again `ScopedName(a, ..., n=0)`. If we just add 1 to `n`, the new `ScopedName(a, ..., n=1)` actually points to `a = 2`. This is clearly wrong, as function `a` does not depend on the declaration `a = 2`.

So, there are special checks when variables are overridden and still can have self-dependence, which is true for FunctionDef and ClassDef.
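A self-dependence check of this kind could look roughly like the sketch below. It uses only the standard `ast` module; `refers_to_itself` and `DEF_TYPES` are hypothetical names for illustration and are not taken from the PR.

```python
import ast

SOURCE = '''
a = 2

def a(arg):
    if arg == 0:
        return 1
    return arg * a(arg - 1)
'''

# Statement types that bind a name and can mention that same name in their body.
DEF_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)


def refers_to_itself(node):
    """True if a def/class statement uses its own name inside its body."""
    if not isinstance(node, DEF_TYPES):
        return False
    return any(
        isinstance(child, ast.Name) and child.id == node.name
        for stmt in node.body
        for child in ast.walk(stmt)
    )


assign, funcdef = ast.parse(SOURCE).body
print(refers_to_itself(assign))   # False
print(refers_to_itself(funcdef))  # True

# At run time the inner `a` resolves to the function, never to `a = 2`.
namespace = {}
exec(SOURCE, namespace)
print(namespace['a'](3))  # 6
```

The last lines confirm the point made above: the recursive call works even though `a = 2` precedes the definition, so the dump does not need that assignment.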

sjrl commented 2 years ago

> So, there are special checks when variables are overridden and still can have self-dependence, which is true for FunctionDef and ClassDef.

Are there any other Def types that should be caught here as well? I guess what I'm asking is: is there an exhaustive list of Def types in Python somewhere?

jasonkhadka commented 2 years ago

This should basically check for any possibility of recursion. I can only think of recursion being possible through FunctionDef and ClassDef; I cannot find another source of it. I agree that a list of possible defs would be great.

wuhu commented 2 years ago

here's a list of all ast node types: https://docs.python.org/3/library/ast.html
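The question about "Def types" can also be answered programmatically. The sketch below filters the `ast` module for statement node classes whose name ends in `Def`. The naming convention is an assumption used only as a filter here; it is not something the `ast` documentation guarantees.

```python
import ast

# Collect every statement node class whose name ends in "Def".
def_nodes = sorted(
    name
    for name, obj in vars(ast).items()
    if isinstance(obj, type) and issubclass(obj, ast.stmt) and name.endswith('Def')
)
print(def_nodes)
```

On recent CPython versions this is expected to list `AsyncFunctionDef`, `ClassDef`, and `FunctionDef`, but the result depends on the interpreter version, so it is worth running against the version padl targets.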

wuhu commented 2 years ago

@jasonkhadka thanks for the explanation! It seems that you have slightly misunderstood the role of increment_same_name_var. This is not where the n should be incremented in this case, and I don't think it's the source of the bug. It's a bit hard to explain here, so I've written a notebook: increment_same_name_var.zip

I think the problem is here: https://github.com/lf1-io/padl/blob/main/padl/dumptools/var2mod.py#L249

Let me know if you want to discuss this!

jasonkhadka commented 2 years ago

Thanks, I had gone through that and updated the list to cover FunctionDef, ClassDef, and AsyncFunctionDef.

Thanks for the explanation, I am going through the notebook. I will ping you if I need more clarification. I will update the PR with the changes. I think the changes and conditions for n+1 that I introduced are needed, but maybe they are in the wrong place at the moment.

jasonkhadka commented 2 years ago

From the discussion on PR: https://github.com/lf1-io/padl/pull/432

@wuhu pointed out:

from padl import *
import numpy.random
numpy = numpy.random

@transform
def f(bla):
    return numpy(10)

print(f._pd_dumps())

would not work.

Here too, `numpy` with `n = 1` would not work; the lookup needs `numpy.random` with `n = 0`.

Another example that fails:

from padl import *

class A:
    def __init__(self, arg):
        self.A = arg

@transform
class B:
    def __init__(self, arg):
        self.B = arg

a = A(2)

b = B(a)

b = b.B.A

b = B(b)
print(b._pd_dumps())
>> RuntimeError: Graph has a circle or dangling roots.

While:

a = A(2)

b = B(a)

b.B.A = 3
print(b._pd_dumps())

would skip the last `b.B.A = 3` assignment but would still produce the dump:

from padl import transform

class A:
    def __init__(self, arg):
        self.A = arg

a = A(2)

@transform
class B:
    def __init__(self, arg):
        self.B = arg

_pd_main = B(a)

wuhu commented 2 years ago

Right, though this

[...]
b = B(b)
print(b._pd_dumps())

seems to be a different problem, even just:

@transform
class B:
    def __init__(self, x):
        ...
b = 1
b = B(b)
print(b._pd_dumps())

goes wrong; the dump is:

from padl import transform

b = 1

@transform
class B:
    def __init__(self, x):
        ...

b = B(b)
_pd_main = B(b)

That's because the transform to dump (`B(b)`) depends on a variable that has the same name as the transform itself, which isn't accounted for.
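Why this dump is wrong can be seen by executing it with a plain class standing in for the transform. The sketch below is padl-free: the `@transform` decorator and the padl import are dropped, and `__init__` stores its argument as `self.x` (an assumption added here so the result can be inspected).

```python
class B:
    """Plain stand-in for the padl transform class in the example."""

    def __init__(self, x):
        self.x = x


# Body of the dump shown above, minus the padl import and the class definition.
DUMP = '''
b = 1
b = B(b)
_pd_main = B(b)
'''

# Original program: the saved object wraps the integer 1.
original = B(1)

# Executing the dump: the second B(b) sees the already rebound b.
namespace = {'B': B}
exec(DUMP, namespace)
restored = namespace['_pd_main']

print(type(original.x).__name__)  # int
print(type(restored.x).__name__)  # B
```

The restored object wraps another `B` instead of `1`, because the dump emits the `b = B(b)` line and then builds `_pd_main` from the rebound `b`.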

We should be able to deal with this by applying increment_same_name_var to todo in Transform._pd_build_codegraph with Transform.pd_varname (I think).

wuhu commented 2 years ago

The other problem seems a bit tricky. To get rid of the bug, it seems, we need to have (and deal with) multiple "conditional" ns (e.g. for x = x.y: "if it's x: n=1, if it's y: n=0"). We should think carefully about how to do that; I believe it has the potential to complicate things a lot if we don't find the right abstraction. The ns are used in the various symfinder.find.. functions (ultimately in symfinder.find_in_source as the i parameter): basically, when something's found, it decrements i and continues searching until i == 0.

jasonkhadka commented 2 years ago

After discussion: I will implement variants for ScopedName. So for the code:

import numpy.random
numpy = numpy.random

When ScopedName('numpy') finds its dependence on numpy.random, it will create a new ScopedName with two variants of the possible name and n in the same scope: ScopedName(scope=Scope, variants=[('numpy', 1), ('numpy.random', 0)])

So both variants will be looked up.
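How such variants might be derived from an assignment can be sketched with the standard `ast` module. `dotted_name` and `variants_for` are hypothetical helpers written for this illustration; they are not part of padl and only handle a right-hand side that is a plain name or attribute chain.

```python
import ast


def dotted_name(node):
    """Return 'a.b.c' for a Name/Attribute chain, or None for anything else."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        return None
    parts.append(node.id)
    return '.'.join(reversed(parts))


def variants_for(statement):
    """Hypothetical: (name, n) lookups for the right-hand side of `x = x.y`."""
    assign = ast.parse(statement).body[0]
    target = assign.targets[0].id
    full = dotted_name(assign.value)
    base = full.split('.')[0]
    if base != target:
        # Assumption: no name clash, so a single lookup with n=0 is enough.
        return [(full, 0)]
    # The bare name must skip the assignment itself (n=1); the dotted name
    # may have been bound earlier, e.g. by `import numpy.random` (n=0).
    return [(base, 1), (full, 0)]


print(variants_for('numpy = numpy.random'))  # [('numpy', 1), ('numpy.random', 0)]
```

Each variant would then be looked up separately, which matches the "both variants will be looked up" idea above.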

Another example:

a = obj()      #1
a.b = 2        #2
a = a.b + 1    #3

Order of variable finds on #3: [(a, 1), (a.b, 0)]

(a, 1) will find a = obj() on #1
(a.b, 0) will find a.b = 2 on #2

a.b = 2 will be picked up, as that is the latest. That is a problem. If the target here is an attribute, probably both a and a.b need to be looked up for the dependency.