flaport / inverse_design

https://flaport.github.io/inverse_design
Apache License 2.0

`value_and_grad` on the straight thru estimator #13

Open jan-david-fischbach opened 1 year ago

jan-david-fischbach commented 1 year ago

I seem to run into the problem that when I use `jax.value_and_grad`, the forward pass of the generator is not actually executed. I think this is related to these lines: https://github.com/flaport/inverse_design/blob/86486d609d447a5247e5b8fe97f6e24ebdd376e3/inverse_design/conditional_generator.py#L188-L190

I believe one could solve this by passing the primals through the generator in the `custom_jvp` rule. However, this leads to unnecessary computational cost if I only want the gradients, i.e. when using `jax.grad`. Any idea how to get it to work efficiently in both cases?
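For illustration, the idea of passing the primals through the generator inside the `custom_jvp` rule might look roughly like the sketch below. It is not the repository's code: `toy_generator` is a hypothetical stand-in (a hard threshold) for the real generator, and `loss` is an arbitrary example objective.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def toy_generator(latent):
    # Hypothetical stand-in for the real generator: a hard +1/-1 threshold.
    return jnp.where(latent > 0, jnp.ones_like(latent), -jnp.ones_like(latent))


@toy_generator.defjvp
def toy_generator_jvp(primals, tangents):
    (latent,) = primals
    (latent_dot,) = tangents
    # Run the actual forward pass so the primal output is the generated design.
    out = toy_generator(latent)
    # Straight-through estimator: pass the tangent through unchanged.
    return out, latent_dot


def loss(latent):
    # Arbitrary example objective, assumed only for this sketch.
    return jnp.sum((toy_generator(latent) - 1.0) ** 2)


latent = jnp.array([0.5, -0.3, 2.0])
value, grad = jax.value_and_grad(loss)(latent)
```

With this rule the value returned by `jax.value_and_grad` comes from the generator output rather than from the raw latent, at the cost of running the forward pass every time the rule is used.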

lucasgrjn commented 1 year ago

Hmm... JAX is actually built to return both the result of the function and the tangents when using `custom_jvp`. The only way to do it would be to use a 'hack'. The simplest would be to define two different functions depending on which mode is used (either `value_and_grad` or `grad`). But that is really dirty...
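A minimal sketch of the two-function idea, under stated assumptions: `_forward` is a hypothetical stand-in for the expensive generator, both function names are invented for this example, and the gradient-only rule is assumed to return the raw latent as its primal output.

```python
import jax
import jax.numpy as jnp


def _forward(latent):
    # Hypothetical stand-in for the expensive generator forward pass.
    return jnp.where(latent > 0, jnp.ones_like(latent), -jnp.ones_like(latent))


@jax.custom_jvp
def generator_with_value(latent):
    return _forward(latent)


@generator_with_value.defjvp
def _generator_with_value_jvp(primals, tangents):
    (latent,), (latent_dot,) = primals, tangents
    # Runs the forward pass, so the primal output is correct.
    return _forward(latent), latent_dot


@jax.custom_jvp
def generator_grad_only(latent):
    return _forward(latent)


@generator_grad_only.defjvp
def _generator_grad_only_jvp(primals, tangents):
    (latent,), (latent_dot,) = primals, tangents
    # Skips the forward pass: the primal output is just the raw latent.
    return latent, latent_dot


weights = jnp.array([1.0, 2.0, 3.0])


def make_loss(generator):
    # Example loss that is linear in the generator output (assumed for this sketch).
    return lambda latent: jnp.sum(weights * generator(latent))


latent = jnp.array([0.5, -0.3, 2.0])
value, grad_full = jax.value_and_grad(make_loss(generator_with_value))(latent)
grad_cheap = jax.grad(make_loss(generator_grad_only))(latent)
# Using the gradient-only variant with value_and_grad reports a value
# computed from the raw latent instead of the generated design.
skipped_value, _ = jax.value_and_grad(make_loss(generator_grad_only))(latent)
```

Note that the two gradients agree here only because the example loss is linear in the generator output. If the downstream computation depends nonlinearly on the design, the gradient-only variant evaluates it at the raw latent and the gradients can differ.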