andrzejnovak closed this 4 years ago.
This looks awesome @andrzejnovak :)
jax didn't want to let me differentiate through just adding two functions
If you can be a bit more specific here, I can try and help with this!
From what I can see, there's no additional markdown in your example notebook. Could you add a brief description of the problem setup and an explanation of the resulting visualisation? That way it will also hopefully render in the docs :D
Sure. I mean this whole block, which I had to duplicate for the second model: https://github.com/pyhf/neos/pull/14/files#diff-5a13b7576a4dbd63339b93da3bde3e21R131-R179
I tried to set it up such that the structure would be:
def cls_maker2(nn_model_maker1, nn_model_maker2, solver_kwargs):
    @jax.jit
    def cls_jax(nn_params, test_mu):
        def one_model(nn_model_maker):
            ...  # <- the original cls_maker structure ->
            return CL
        # Sum the CLs values from the two models
        return one_model(nn_model_maker1) + one_model(nn_model_maker2)
    return cls_jax
But jax didn't seem to like the last line of cls_jax, where the two one_model results are added.
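For what it's worth, a minimal sketch suggests JAX itself has no trouble differentiating through the sum of two inner closures in this shape. It rests on assumptions: make_model and toy_cls are hypothetical stand-ins for the per-model CLs computation (not the neos cls_maker code), and nn_params is assumed to be a dict holding separate parameters for each model. If the real error came from something inside the per-model computation, this toy would not reproduce it.

```python
import jax
import jax.numpy as jnp


def make_model(scale):
    """Hypothetical stand-in for an nn_model_maker: returns a function mapping
    (params, test_mu) to a smooth, CLs-like scalar in (0, 1)."""
    def toy_cls(params, test_mu):
        # Sigmoid of a quadratic score; purely illustrative, not a real CLs.
        return jax.nn.sigmoid(scale * jnp.dot(params, params) - test_mu)
    return toy_cls


def cls_maker2(model1, model2):
    @jax.jit
    def cls_jax(nn_params, test_mu):
        # Each model reads its own parameters from the pytree nn_params.
        def one_model(model, params):
            return model(params, test_mu)
        # Adding the two closures' outputs is an ordinary traced operation.
        return one_model(model1, nn_params["m1"]) + one_model(model2, nn_params["m2"])
    return cls_jax


cls_jax = cls_maker2(make_model(1.0), make_model(2.0))
params = {"m1": jnp.array([0.1, 0.2]), "m2": jnp.array([0.3, -0.1])}

value = cls_jax(params, 1.0)
# Gradient with respect to both models' parameters in a single call.
grads = jax.grad(cls_jax)(params, 1.0)
print(value, grads["m1"], grads["m2"])
```

Keeping both models' parameters in one pytree means a single jax.grad call returns gradients for both networks at once, in the same dict structure as nn_params.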
I'll add some docs
@phinate I factored out the changes exclusively into the new notebook, so this should be good to go. Thanks for helping me fix it up!
Awesome, thanks @andrzejnovak! 👍
@phinate Here's my example of fitting for two models together to divide the phase space into 3 regions.
I think my cls_maker2 could be simplified (by wrapping each model's computation in its own function), but jax didn't want to let me differentiate through just adding two functions. I suspect there's a trick to do it somehow.
If this goes well, I'll also add the optimization based on Asimov significance.
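For the planned Asimov-significance objective, here is a minimal sketch assuming independent counting bins with known background (no systematic uncertainty). It uses the standard Asimov discovery significance Z_A = sqrt(sum_i 2((s_i + b_i) ln(1 + s_i/b_i) - s_i)), where s_i and b_i are the expected signal and background in bin i. Whether the PR would use this exact form (rather than one including background uncertainty) is an assumption, and asimov_significance and the yields below are hypothetical.

```python
import jax
import jax.numpy as jnp


def asimov_significance(s, b):
    """Median discovery significance on Asimov data for independent counting
    bins with known background (no systematic uncertainty).

    s, b: arrays of expected signal and background counts per bin (b > 0).
    """
    # Per-bin contribution to q0 = -2 ln(lambda(0)) evaluated on Asimov data;
    # the contributions add because the bins are independent.
    q0_bins = 2.0 * ((s + b) * jnp.log1p(s / b) - s)
    return jnp.sqrt(jnp.sum(q0_bins))


# Hypothetical per-bin yields, e.g. from splitting the phase space into 3 regions.
s = jnp.array([5.0, 10.0, 2.0])
b = jnp.array([50.0, 20.0, 100.0])

z = asimov_significance(s, b)
# Gradient with respect to the signal yields; in training, s and b would be
# functions of the network parameters and jax.grad would flow through them.
dz_ds = jax.grad(asimov_significance)(s, b)
print(z, dz_ds)
```

In training, the loss to minimise would be -asimov_significance(s, b). Analytically, dZ_A/ds_i = ln(1 + s_i/b_i) / Z_A, which is positive and largest in the bins with the highest s/b.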