rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

Quick question about Overview Code Examples #2

Closed. adam-hartshorne closed this issue 1 year ago.

adam-hartshorne commented 1 year ago

In the overview, the single-output usage example's "definition" call to torch2jax includes an output_shapes parameter. However, the multiple-outputs example provides no output_shapes parameter.

jax_fn = torch2jax(torch_fn, a, b, output_shapes=Size(a.shape))  # torch_fn will NOT be evaluated

vs.

jax_fn = torch2jax(torch_fn, a, b)  # with example arguments

Obviously, I understand that the second example returns multiple values. Is output_shapes required? If not, why is it provided in the first definition (single output) but not for multiple outputs?

Secondly, when "defining" your torch2jax function, a and b are torch tensors. Rather than having to instantiate these tensors, would it be possible to provide shape / dtype information as arguments instead?

Also, what happens if you pass in an array whose size differs from the defined sizes, particularly a smaller size? I am thinking specifically of the case where you have batched data and the last batch is smaller than the defined batch size. Do you need to ensure it's the correct size using padding?

rdyro commented 1 year ago

Hey, thanks for engaging with the package!

Providing output shapes (or rather, output shapes and dtypes) lets torch2jax skip calling the PyTorch function to determine its output structure (the number of outputs and their shapes and dtypes).

If no output_shapes is provided, then torch2jax has to call the PyTorch function once, so the example arguments must be valid torch tensors. However, if output_shapes is provided, then torch2jax skips that example evaluation step, and the example arguments only have to carry information about their shapes and dtypes.

So yes, it's definitely possible to skip actual tensor instantiation if output_shapes is provided: you still have to provide example inputs, but since they only need to carry shape and dtype information, they can be shape structs (like JAX's helper class jax.ShapeDtypeStruct) or JAX array instantiations.

Note also that these arguments can be nested, e.g. dictionaries: pretty much anything that pytrees allow (i.e., a nesting of Python containers).
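A minimal sketch of defining the wrapper from shape/dtype structs alone (assuming a simple element-wise torch_fn and the import path shown, both made up for illustration):

import jax
import jax.numpy as jnp
from torch2jax import torch2jax  # assumed import path

# hypothetical PyTorch function, for illustration only
def torch_fn(a, b):
    return a + b

spec = jax.ShapeDtypeStruct((4, 8), jnp.float32)

# no torch tensors are instantiated: inputs and outputs are described by structs alone
jax_fn = torch2jax(torch_fn, spec, spec, output_shapes=spec)

a, b = jnp.ones((4, 8)), jnp.ones((4, 8))
print(jax_fn(a, b))  # torch_fn runs on actual data only here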

Unfortunately, both input and output shapes have to be exactly fixed, which I believe is mostly how JAX itself works (if you call a jitted JAX function with arguments of a different shape, it'll recompile it). At this time, I am not sure how to support dynamic recompilation for a change in argument shape. However, you could simply define a new torch2jax-wrapped function for the last, smaller batch, as sketched below. The custom op that torch2jax defines is both really cheap to define and compiles extremely fast, because JAX just treats it as a black box.
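A minimal sketch of that two-wrapper approach, reusing the hypothetical torch_fn and assumed import from the sketch above, with made-up batch sizes:

import jax
import jax.numpy as jnp

batch_size, last_batch_size, dim = 32, 17, 8
spec_full = jax.ShapeDtypeStruct((batch_size, dim), jnp.float32)
spec_last = jax.ShapeDtypeStruct((last_batch_size, dim), jnp.float32)

# one wrapper per fixed input/output shape; each definition is cheap and compiles fast
jax_fn_full = torch2jax(torch_fn, spec_full, spec_full, output_shapes=spec_full)
jax_fn_last = torch2jax(torch_fn, spec_last, spec_last, output_shapes=spec_last)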

I added a notebook, examples/input_output_specification.ipynb, which goes over alternative ways of defining the input/output structure.

adam-hartshorne commented 1 year ago

Initialising Shape Sizes

That's perfect. I would personally lean toward showing the jax.ShapeDtypeStruct approach as the primary way of doing it on the GitHub front page; otherwise, people might get the impression that you have to instantiate tensors a priori, and ShapeDtypeStruct is already widely used.

(if you call a jitted JAX function with arguments of a different shape, it'll recompile it)

Dynamic shapes and XLA / JIT are a long-running sore point when it comes to using JAX.

However, when you run a training loop over batches and the final batch comes along with a smaller size, as far as I am aware JAX handles that without any recompilation or making two copies of the jitted function. I wonder how it is handling that?

Edit - I use https://github.com/birkhoffg/jax-dataloader; I haven't actually checked whether it automatically pads that last batch itself.

rdyro commented 1 year ago

That's a good idea, actually; I updated the README to show how to specify the function without instantiating the data.

Yes, unfortunately.

I'm not sure; anecdotally, in my work I notice compilation time when calling a previously compiled function with arguments of slightly different shapes. torch2jax uses JAX's compilation, so it should behave similarly.

Your question made me think it's probably fine to call the torch2jax transform directly inside a JAX function. I added some tests for this, and the example below works. Maybe that's how you can handle a changing batch size.

import jax
from torch2jax import torch2jax_with_vjp, dtype_t2j  # assumed import path

# torch_fn, a, and b are assumed to be defined as in the earlier examples
@jax.jit
def compute(a, b, c):
    # the torch2jax custom op is cheap to define, so it can be created inside the jitted function
    d = torch2jax_with_vjp(
        torch_fn,
        jax.ShapeDtypeStruct(a.shape, dtype_t2j(a.dtype)),
        jax.ShapeDtypeStruct(b.shape, dtype_t2j(b.dtype)),
        output_shapes=jax.ShapeDtypeStruct(a.shape, dtype_t2j(a.dtype)),
    )(a, b)
    return d - c

print(compute(a, b, a))
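As a possible usage note (a sketch, assuming a and b above are JAX arrays of the full batch shape): a smaller final batch then just triggers one extra trace of compute, which re-creates the custom op with the new shapes.

a_last, b_last = a[:3], b[:3]  # hypothetical smaller final batch
print(compute(a_last, b_last, a_last))  # recompiled once for the new shape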