awf / awfutils

AWF's Utilities
MIT License
3 stars 0 forks source link

AWFUTILS: AWF's Utilities

A few utils for general python programming

Usage

This is very much code-in-progress. When I use it, I typically just put it as a submodule under whatever else I'm building, so I can easily bugfix awfutils as I do other work:

$ git submodule add https://github.com/awf/awfutils

Typecheck

I love run time type checkers, particularly with JAX but by default they (OK, beartype doesn't) don't check statement-level annotations like these:

def foo(x : int, y : float):
  z : int = x * y # This should error, but doesn't
  w : float = z * 3.2
  return w

foo(3, 1.3)

With the awfutils typecheck decorator, they can...

@typecheck
def foo(x : int, y : float):
  z : int = x * y # Now it raises TypeError: z not of type int
  w : float = z * 3.2
  return w

foo(3, 1.3) # Error comes from this call

This works by AST transformation, replacing the function foo above with the function

def foo_typecheck_wrap(x: int, y: float):
    assert isinstance(x, int), 'x not of type int'
    assert isinstance(y, float), 'y not of type float'
    z: int = x * y
    assert isinstance(z, int), 'z not of type int'
    w: float = z * 3.2
    assert isinstance(w, float), 'w not of type float'
    return w

Because it's AST transformation it is basically literally the above code, which you can see with the optional argument show_src=True

@functools.partial(typecheck, show_src=True)
def foo(x : int, y : float):
  z : int = x * y # Now it does
  w : float = z * 3.2
  return w

Arg

A distributed argument parser, like absl flags, but a little more convenient and less stringy. If you want a config value anywhere in your program, just declare an Arg nearby (at top level), and use it:

# my-random-file.py
from jaxutils.Arg import Arg

tau = Arg("mqh-tau", 0.01, "Scale factor for softmax in my_quick_hack")

def my_quick_hack(xs, qs):
    ...
    softmax(tau() * xs @ qs.T)
    ...

Now, even if my_quick_hack is far down the call tree from main, you can quickly try some values of tau by just running

$ python main.py -mqh-tau 0.0001

More conventionally, Arg is also useful in main:

def main():
    param1 = Arg("p1", default=34, help="Set first parameter")
    switch = Arg("s", False, "Turn on the switch")

    if switch(): # This is where arg parsing first happens. See notes.
        foo(param1())

    # To see which args have been set:
    print(Arg.config())

It's all a thin wrapper around argparse, so inherits a bunch of goodness from there, but avoids a lot of long-chain plumbing when my_quick_hack is far down the call tree from main.

And yes, these are global variables. This is absolutely reasonable, because the command line is a global resource. You'll see that a little bit of namespacing has been illustrated above, where tau was given the flag mqh-tau. Feel free to formalize that as much as you like.

Parsing happens the first time any of the Args is read, and is then cached.

There is a potential gotcha if you want to act on an arg during module load time, e.g. at the top level:

jit_foo = Arg("jit-foo", False, "Run JIT on foo")
if jit_foo():   # Prefer jit_foo.peek() for load-time checks
    foo = jit(foo)

The call to jit_foo() will know only about arguments declared before that point, so a call to --help will produce too short a list.
This is remedied later but is better avoided:

jit_foo = Arg("jit-foo", False, "Run JIT on foo")
if jit_foo.peek(): # Just check for this arg in sys.argv
    foo = jit(foo)

PyTree Utils

Various smoothers for torch.utils._pytree.

Given a nest of lists and tuples, perform various useful funcions.

For example, given the object val as follows:

val = ( # tuple
        [ # list
          1, 
          (np.random.rand(2, 3), "b", np.random.rand(12, 13)),
          3
        ], 
        "fred"
      )

Then PyTree.map(foo, val) will make these six calls to foo:

foo(1), 
foo(np.random.rand(2, 3))
foo("b")
foo(np.random.rand(12, 13))
foo(3)
foo("fred")

And given a numeric-only pytree, e.g.

val = ( # tuple
        [ # list
          1, 
          (np.random.rand(2, 3), np.random.rand(12, 13))
        ]
      )

Then arithmetic can be performed on PyTrees, e.g.

(PyTree(val) + val) * val == 2 * PyTree(val) * val