PierrickPochelu opened 1 year ago
Thanks for the suggestion! It's an interesting idea. There would be some complexity involved for a full solution; off the top of my head:
scatter_p, gather_p, conv_general_dilated_p, etc.?
It could be a fun challenge though 😁
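The primitives named above appear directly in jaxprs, and none of them maps one-to-one onto a simple Python expression. The quick check below is a sketch that assumes a recent JAX. It searches the printed jaxpr text, which also covers nested sub-jaxprs. The helper jaxpr_text and the sample inputs are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import lax

def jaxpr_text(fn, *args):
    """Return the printed jaxpr of fn, including any nested sub-jaxprs."""
    return str(jax.make_jaxpr(fn)(*args))

x = jnp.arange(5.0)
idx = jnp.array([0, 2])

# Integer-array indexing is traced to the gather primitive.
gather_src = jaxpr_text(lambda v: v[idx], x)
# Functional index updates (.at[...].set) are traced to the scatter primitive.
scatter_src = jaxpr_text(lambda v: v.at[idx].set(1.0), x)
# lax.conv is traced to conv_general_dilated.
img = jnp.ones((1, 1, 4, 4))     # NCHW layout
kernel = jnp.ones((1, 1, 3, 3))  # OIHW layout
conv_src = jaxpr_text(lambda a, k: lax.conv(a, k, (1, 1), "SAME"), img, kernel)

print(gather_src)
```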
I will release a first version that works for simple code samples in the next few days.
Cool - if you haven't seen it already, this might be a useful resource for crawling/interpreting jaxprs: https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
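The central pattern in that tutorial is to walk the equations of a ClosedJaxpr while keeping an environment that maps variables to values. Below is a condensed sketch of its eval_jaxpr idea. The names read and eval_jaxpr are illustrative. Literals are detected through their .val attribute (duck typing) rather than by importing the Literal class.

```python
import jax
import jax.numpy as jnp

def read(env, v):
    # Literals (inline constants) carry their value in `.val`;
    # ordinary variables are looked up in the environment.
    return v.val if hasattr(v, "val") else env[v]

def eval_jaxpr(jaxpr, consts, *args):
    env = {}
    for var, val in zip(jaxpr.constvars, consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val
    for eqn in jaxpr.eqns:
        invals = [read(env, v) for v in eqn.invars]
        # Re-apply the primitive with the parameters recorded in the equation.
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val
    return [read(env, v) for v in jaxpr.outvars]

def f(x):
    return jnp.sin(x) * x + jnp.cos(x)

x = jnp.float32(2.0)
closed = jax.make_jaxpr(f)(x)
for eqn in closed.jaxpr.eqns:
    print(eqn.primitive.name)  # one primitive name per equation

(result,) = eval_jaxpr(closed.jaxpr, closed.consts, x)
```

A decompiler follows the same traversal. Instead of binding each primitive, it emits a line of source text for it.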
I implemented a first version: https://github.com/PierrickPochelu/JaxDecompiler
Currently, it supports 23 common jaxpr operators such as add, mul, neg, cos, sin, etc., and it partially supports pmap.
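To give a flavor of what "supporting an operator" involves, here is a toy sketch. It is not the JaxDecompiler implementation. It maps a handful of primitive names to jax.numpy expressions and emits one Python assignment per jaxpr equation. TEMPLATES and decompile are hypothetical names. The sketch ignores eqn.params, which is only acceptable for the simple elementwise primitives listed. Anything outside the table, including pmap and other higher-order primitives, is rejected.

```python
import jax
import jax.numpy as jnp

# Hypothetical subset: jaxpr primitive name -> Python expression template.
TEMPLATES = {
    "add": "{0} + {1}",
    "sub": "{0} - {1}",
    "mul": "{0} * {1}",
    "div": "{0} / {1}",
    "neg": "-{0}",
    "sin": "jnp.sin({0})",
    "cos": "jnp.cos({0})",
    "exp": "jnp.exp({0})",
    "log": "jnp.log({0})",
}

def decompile(closed_jaxpr, fn_name="decompiled"):
    jaxpr = closed_jaxpr.jaxpr
    if jaxpr.constvars:
        raise NotImplementedError("captured constants are not handled in this sketch")
    names = {}  # original names are lost, so invent v0, v1, ...

    def name(v):
        if hasattr(v, "val"):  # Literal: inline its value
            val = v.val
            return repr(val.item()) if hasattr(val, "item") else repr(val)
        if v not in names:
            names[v] = f"v{len(names)}"
        return names[v]

    lines = [f"def {fn_name}({', '.join(name(v) for v in jaxpr.invars)}):"]
    for eqn in jaxpr.eqns:
        prim = eqn.primitive.name
        if prim not in TEMPLATES:
            raise NotImplementedError(f"primitive {prim!r} is not supported")
        expr = TEMPLATES[prim].format(*(name(v) for v in eqn.invars))
        lines.append(f"    {', '.join(name(v) for v in eqn.outvars)} = {expr}")
    lines.append(f"    return {', '.join(name(v) for v in jaxpr.outvars)}")
    return "\n".join(lines)

def g(x):
    return jnp.sin(x) * x + jnp.cos(x)

source = decompile(jax.make_jaxpr(g)(1.5))
print(source)
```

The generated source can be run with exec in a namespace that provides jnp, which gives a quick round-trip check against the original function.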
A JAX decompiler would take jaxpr code and produce more readable Python code, even though some information about the original function, such as variable names, is lost (the result resembles obfuscated code). Decompilers are important tools for reverse engineering.
Here is an illustration of why a decompiler is useful.
(step a) f(x) is my function:
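The definition of f is not shown here. Judging from the jaxpr in step (c) (exp, then add 1.0, then log), it is presumably softplus, f(x) = log(1 + exp(x)), whose exact derivative at x = 100 is essentially 1. Here is a sketch under that assumption:

```python
import jax.numpy as jnp
from jax import grad

# Assumption: f is softplus, reconstructed from the jaxpr in step (c).
def f(x):
    return jnp.log(1.0 + jnp.exp(x))

print(f(0.0))         # log(2) ≈ 0.693
print(grad(f)(100.))  # nan in the default float32 precision
```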
(step b) The derivative does not give the right answer:
print(grad(f)(100.)) # nan <- expected 1
(step c) The jaxpr of the derivative (with respect to x) is:
output:
{ lambda ; a:f32[]. let
    b:f32[] = exp a
    c:f32[] = add 1.0 b
    _:f32[] = log c
    d:f32[] = div 1.0 c
    e:f32[] = mul d b
  in (e,) }
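A jaxpr like the one above can be printed with jax.make_jaxpr. Since f itself is not shown, this sketch assumes the softplus definition implied by the jaxpr. The exact printout may vary between JAX versions, including variable names, literal formatting, and whether the unused log equation survives.

```python
import jax
import jax.numpy as jnp
from jax import grad

# Assumed definition of f (softplus), inferred from the jaxpr above.
def f(x):
    return jnp.log(1.0 + jnp.exp(x))

# make_jaxpr traces grad(f) at a scalar argument and returns a ClosedJaxpr.
closed = jax.make_jaxpr(grad(f))(100.)
print(closed)
```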
(step d) The Python equivalent (manually written) is:
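The hand-written code is not shown. A line-by-line translation of the jaxpr from step (c), keeping its variable names, would look roughly like this. It is a sketch, not the author's original code.

```python
import jax.numpy as jnp

def grad_f(a):
    b = jnp.exp(a)   # b:f32[] = exp a
    c = 1.0 + b      # c:f32[] = add 1.0 b
    _ = jnp.log(c)   # _:f32[] = log c  (primal value, unused by the gradient)
    d = 1.0 / c      # d:f32[] = div 1.0 c
    e = d * b        # e:f32[] = mul d b
    return e

print(grad_f(jnp.float32(100.)))  # nan
```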
(step e) Now the problem is easy to see: at a = 100, b = exp(100) overflows to inf in float32, so d = 1 / (1 + inf) = 0 and e = 0 * inf = nan. I refactored the code a little to improve its numerical stability:
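The refactored code is not shown. One standard fix, which may differ from the author's exact refactor, uses the identity e = exp(a) / (1 + exp(a)) = 1 / (1 + exp(-a)), the logistic sigmoid. This form does not overflow for large positive a.

```python
import jax
import jax.numpy as jnp

def grad_f_stable(a):
    # exp(a) / (1 + exp(a)) rewritten as 1 / (1 + exp(-a)):
    # for large positive a, exp(-a) underflows toward 0 instead of overflowing to inf;
    # for large negative a, exp(-a) = inf and the result correctly tends to 0.
    return 1.0 / (1.0 + jnp.exp(-a))

print(grad_f_stable(jnp.float32(100.)))   # 1.0
print(jax.nn.sigmoid(jnp.float32(100.)))  # library equivalent
```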
A decompiler would automate step d.