google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Jax decompiler #13398

Open PierrickPochelu opened 1 year ago

PierrickPochelu commented 1 year ago

A JAX decompiler would take jaxpr code and produce more readable Python code, even though some information about the original function, such as variable names, is lost (as with obfuscated code). Decompilers are an important tool for reverse engineering.

Here is an illustration of why a decompiler would be useful.

(step a) f(x) is my function:

from jax import grad
from jax import numpy as jnp
f = lambda x: jnp.log(1 + jnp.exp(x))

(step b) The derivative does not give the right answer:

print(grad(f)(100.)) # nan <- expected 1

(step c) The jaxpr of the derivative (with respect to x) is:

from jax import make_jaxpr
make_jaxpr(grad(f))(100.)

output:
{ lambda ; a:f32[]. let
    b:f32[] = exp a
    c:f32[] = add 1.0 b
    _:f32[] = log c
    d:f32[] = div 1.0 c
    e:f32[] = mul d b
  in (e,) }

(step d) The equivalent Python code, written by hand, is:

df=lambda a: (1 / (1 + jnp.exp(a))) * jnp.exp(a)
print(df(100.)) # nan <- expected 1
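A quick way to confirm where this nan comes from, assuming JAX's default float32 precision (x64 disabled), is to evaluate the pieces of step d one at a time:

```python
import jax.numpy as jnp

x = jnp.float32(100.0)
# e**100 is about 2.7e43, above the float32 maximum (about 3.4e38), so this overflows.
big = jnp.exp(x)
print(jnp.finfo(jnp.float32).max)  # largest finite float32 value
print(big)                         # inf
print(1 / (1 + big))               # 0.0
print((1 / (1 + big)) * big)       # nan, because 0 * inf is undefined
```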

(step e) Now the problem is easy to see: for large a, jnp.exp(a) overflows to inf in float32, so the expression evaluates 0 * inf, which is nan. I refactored the code slightly to improve its numerical stability:

# approximation: sigmoid(x) > 0.99995 for x > 10
df = lambda x: 1.0 if x > 10 else jnp.exp(x) / (1 + jnp.exp(x))
print(df(100.)) # 1.0 <- expected answer
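The Python if in step e only works when x is a concrete number; under jit or grad, x is a tracer and the comparison raises an error. As a sketch of an alternative that stays traceable, the function below swaps in jnp.logaddexp(x, 0.0), a standard stable way to write softplus (jax.nn.softplus computes the same function), rather than the manual refactoring above. stable_f and stable_df are illustrative names:

```python
import jax
from jax import numpy as jnp

# softplus(x) = log(1 + exp(x)) = log(exp(x) + exp(0)) = logaddexp(x, 0),
# a formulation that stays finite for large x.
stable_f = lambda x: jnp.logaddexp(x, 0.0)

# The derivative of softplus is the logistic sigmoid, exp(x) / (1 + exp(x)).
stable_df = jax.jit(jax.grad(stable_f))

print(stable_df(100.))       # gradient at the input that produced nan above
print(jax.nn.sigmoid(100.))  # closed-form derivative, for comparison
```

Because no Python control flow depends on the value of x, the same code works eagerly, under jit, and under further transformations such as vmap.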

A decompiler would automate step d.
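A minimal sketch of what automating step d could look like, assuming only the handful of primitives in the jaxpr above plus a few similar ones: it walks closed_jaxpr.jaxpr.eqns, gives each variable a generated name (v0, v1, ...), and emits one Python assignment per equation. TEMPLATES and decompile are hypothetical names for illustration, not part of JAX or of any released decompiler, and treating any object with a val attribute as a literal is an assumption about JAX's internal Literal class.

```python
import numpy as np
import jax
from jax import numpy as jnp

# Hypothetical table mapping a few jaxpr primitive names to Python expression templates.
TEMPLATES = {
    "exp": "jnp.exp({0})",
    "log": "jnp.log({0})",
    "sin": "jnp.sin({0})",
    "cos": "jnp.cos({0})",
    "neg": "-{0}",
    "add": "{0} + {1}",
    "sub": "{0} - {1}",
    "mul": "{0} * {1}",
    "div": "{0} / {1}",
}


def decompile(closed_jaxpr, fn_name="decompiled"):
    """Turn a ClosedJaxpr built only from the primitives above into Python source."""
    jaxpr = closed_jaxpr.jaxpr
    if jaxpr.constvars:
        raise NotImplementedError("captured constants are not handled in this sketch")
    names = {}  # jaxpr variable object -> generated Python name

    def name_of(v):
        if hasattr(v, "val"):  # a literal: inline its constant value
            return repr(np.asarray(v.val).item())
        if v not in names:
            names[v] = f"v{len(names)}"
        return names[v]

    params = ", ".join(name_of(v) for v in jaxpr.invars)
    lines = [f"def {fn_name}({params}):"]
    for eqn in jaxpr.eqns:
        prim = eqn.primitive.name
        if prim not in TEMPLATES:
            raise NotImplementedError(f"unsupported primitive: {prim}")
        expr = TEMPLATES[prim].format(*(name_of(v) for v in eqn.invars))
        targets = ", ".join(name_of(v) for v in eqn.outvars)
        lines.append(f"    {targets} = {expr}")
    results = ", ".join(name_of(v) for v in jaxpr.outvars)
    lines.append(f"    return {results}")
    return "\n".join(lines)


f = lambda x: jnp.log(1 + jnp.exp(x))
print(decompile(jax.make_jaxpr(f)(1.0), "f_py"))
```

The returned source can be passed to exec with jnp in the namespace. Emitting one assignment per equation keeps operator precedence trivial at the cost of verbose output; a fuller decompiler would also need to inline single-use variables and handle higher-order primitives such as those produced by jit, scan, or pmap, which this sketch rejects with NotImplementedError.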

jakevdp commented 1 year ago

Thanks for the suggestion! It's an interesting idea. There would be some complexity involved for a full solution; off the top of my head:

It could be a fun challenge though 😁

PierrickPochelu commented 1 year ago

I will release a first version that works for simple code samples in the next few days.

jakevdp commented 1 year ago

Cool - if you haven't seen it already, this might be a useful resource for crawling/interpreting jaxprs: https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
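Following the evaluation pattern from that notebook (look up each input in an environment, then call eqn.primitive.bind(*inputs, **eqn.params)), here is a minimal sketch of an interpreter that could locate the nan from step b without decompiling anything: it evaluates a jaxpr one equation at a time and records every equation whose output contains inf or nan. find_nonfinite is a hypothetical helper name, and treating any object with a val attribute as a literal is an assumption about JAX's internal Literal class.

```python
import numpy as np
import jax
from jax import numpy as jnp


def find_nonfinite(closed_jaxpr, *args):
    """Evaluate a closed jaxpr equation by equation and collect
    (primitive name, value) pairs for outputs that are inf or nan."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(v):
        # Literals carry their value inline; other variables live in env.
        return v.val if hasattr(v, "val") else env[v]

    for var, val in zip(jaxpr.constvars, closed_jaxpr.consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    reports = []
    for eqn in jaxpr.eqns:
        outs = eqn.primitive.bind(*[read(v) for v in eqn.invars], **eqn.params)
        if not eqn.primitive.multiple_results:
            outs = [outs]
        for var, val in zip(eqn.outvars, outs):
            env[var] = val
            if not np.all(np.isfinite(np.asarray(val))):
                reports.append((eqn.primitive.name, np.asarray(val)))
    return reports


f = lambda x: jnp.log(1 + jnp.exp(x))
x = jnp.float32(100.0)
for prim, val in find_nonfinite(jax.make_jaxpr(jax.grad(f))(x), x):
    print(prim, val)
```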

PierrickPochelu commented 1 year ago

I implemented a first version: https://github.com/PierrickPochelu/JaxDecompiler

It currently supports 23 common jaxpr operators, such as add, mul, neg, cos, and sin. It also partially supports pmap.