qiskit-community / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-community.github.io/qiskit-dynamics/
Apache License 2.0
105 stars, 61 forks

Parameter dependent simulation for optimal bayesian experimentation - JAX @jit compatibility #243

Closed abailly-at-anl closed 8 months ago

abailly-at-anl commented 1 year ago

What is the expected behavior?

I am trying to parameterize simulations by decoherence rates and Hamiltonian parameters. While implementing this, we have run into many compatibility issues with JAX just-in-time (jit) compilation. Notably, it is not possible to return Solver or Signal objects from a jit-compiled function, or to pass them in as arguments.

I've attached a Jupyter notebook that outlines what we have tried so far. Since it does not appear possible to create a new Solver instance inside a jit-compiled function, we instead edit the static_hamiltonian of an existing Solver. However, the same does not seem to work for the hamiltonian_operators or lindblad_dissipators.

Is there a better way to implement these parameter-dependent simulations? If not, we would greatly appreciate it if the Qiskit team built in better support for JAX just-in-time compilation in this area.

Parameter Dependent Simulation.pdf
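For context on the restriction described in this question: jax.jit only accepts and returns arrays, scalars, and registered pytrees, so an arbitrary Python object is rejected at the function boundary. The sketch below uses two hypothetical container classes (not Solver or Signal) to show the rejection and the general JAX mechanism, pytree registration, that lets a custom object cross that boundary. Whether that mechanism is appropriate for qiskit-dynamics objects is not something this thread establishes.

```python
import jax
import jax.numpy as jnp


class PlainModel:
    """Hypothetical container of operator arrays; NOT a qiskit-dynamics class."""

    def __init__(self, static_op):
        self.static_op = static_op


@jax.tree_util.register_pytree_node_class
class PytreeModel:
    """Same container, registered as a pytree so jax.jit can flatten it."""

    def __init__(self, static_op):
        self.static_op = static_op

    def tree_flatten(self):
        # (children that JAX traces, auxiliary static data)
        return (self.static_op,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def rescale(model, c):
    # Build and return a new object of the same type inside the compiled function
    return type(model)(c * model.static_op)


rescale_jit = jax.jit(rescale)

try:
    rescale_jit(PlainModel(jnp.eye(2)), 2.0)
    plain_accepted = True
except TypeError:
    # An unregistered object is "not a valid JAX type" as a jit argument
    plain_accepted = False

scaled = rescale_jit(PytreeModel(jnp.eye(2)), 2.0)
```

The registered class round-trips through the compiled function because JAX only ever sees its array leaves.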

DanPuzzuoli commented 1 year ago

Hi, thanks for the inquiry! Can you attach the original ipynb file rather than a pdf? I can show you some options for how to get this to work.

abailly-at-anl commented 1 year ago

Hi Daniel!

Sorry for the delayed reply. I attached the ipynb here, as I couldn't attach it to the GitHub issue. Thank you for your help!


From: Daniel Puzzuoli. Sent: Wednesday, July 12, 2023 5:02 PM. To: Qiskit-Extensions/qiskit-dynamics. Cc: Bailly, Atlas Sebastien; Author. Subject: Re: [Qiskit-Extensions/qiskit-dynamics] Parameter dependent simulation for optimal bayesian experimentation - JAX @jit compatibility (Issue #243)


DanPuzzuoli commented 1 year ago

I'm not sure if I'm missing the attachment in the email, but I don't see the notebook anywhere. Are you not able to attach a notebook in a comment? Maybe try zipping it?

abailly-at-anl commented 1 year ago

I had attached the notebook to the email; I've also zipped it and attached it here. Let me know if that works!

Parameter Dependent Simulation.zip

DanPuzzuoli commented 1 year ago

I've attached an updated version of the notebook. In the section labelled "Dan Alternate code suggestion", I've put some alternate code you can use to compile these things.

Parameter Dependent Simulation.ipynb.zip

I've written some explanation in there, but my suggestion is to drop the RWA (rotating wave approximation) when constructing the Solver. With that change you can build the Solver directly inside the function you want to compile, as long as you set validate=False. Neither the RWA code nor the validation code is JAX compatible, because both depend on the values in the model operator arrays, and inside a compiled function those values are abstract tracers rather than concrete numbers (the RWA could probably be made JAX compatible, but that's extremely low priority).
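To make the point about value-dependent code concrete, here is a minimal pure-JAX sketch. It assumes nothing about what Solver's validation actually checks: check_hermitian and build_static_hamiltonian are hypothetical stand-ins. A Python `if` on array values works eagerly but fails under jax.jit, while skipping the check (the analogue of validate=False) compiles fine.

```python
import jax
import jax.numpy as jnp


def check_hermitian(op):
    """Hypothetical value-dependent check, in the spirit of validate=True."""
    # A Python `if` needs a concrete bool, which a traced array cannot provide
    if not jnp.allclose(op, op.conj().T):
        raise ValueError("operator is not Hermitian")
    return op


def build_static_hamiltonian(w, validate=True):
    # Hypothetical single-qubit model: 2*pi*w*Z/2
    Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])
    H = 2 * jnp.pi * w * Z / 2
    return check_hermitian(H) if validate else H


# `validate` is a Python flag, so it is marked static rather than traced
build_jit = jax.jit(build_static_hamiltonian, static_argnames="validate")

try:
    build_jit(5.0, validate=True)
    validated_under_jit = True
except jax.errors.ConcretizationTypeError:
    # Raised while tracing: the check tried to turn a tracer into a bool
    validated_under_jit = False

H = build_jit(5.0, validate=False)
```

Run eagerly, build_static_hamiltonian(5.0) passes its check; only the compiled path needs the check switched off.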

As an aside, in this case the simulation is just as performant with or without the RWA. What I've observed with this package is that much of the numerical benefit of the RWA is already present just from entering the rotating frame (without actually making the approximation).
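As a worked single-qubit illustration of that distinction (the model and the symbols ω, ω_d, Ω, Δ, Σ are introduced here for illustration; ħ = 1, angular frequencies), entering the frame of H_0 = (ω/2) Z means working with ψ̃(t) = e^{i H_0 t} ψ(t), which transforms the Hamiltonian as follows:

```latex
\begin{aligned}
H(t) &= \tfrac{\omega}{2} Z + \Omega \cos(\omega_d t)\, X, \qquad H_0 = \tfrac{\omega}{2} Z,\\
\tilde{H}(t) &= e^{i H_0 t}\bigl(H(t) - H_0\bigr)e^{-i H_0 t}
  = \Omega \cos(\omega_d t)\bigl[\cos(\omega t)\, X - \sin(\omega t)\, Y\bigr]\\
&= \tfrac{\Omega}{2}\Bigl[\bigl(\cos\Delta t + \cos\Sigma t\bigr) X
  + \bigl(\sin\Delta t - \sin\Sigma t\bigr) Y\Bigr],
  \qquad \Delta = \omega_d - \omega,\ \Sigma = \omega_d + \omega,\\
\tilde{H}_{\mathrm{RWA}}(t) &= \tfrac{\Omega}{2}\bigl[\cos(\Delta t)\, X + \sin(\Delta t)\, Y\bigr].
\end{aligned}
```

Entering the frame alone already removes (ω/2) Z, the largest energy scale in H, leaving only terms that oscillate at Δ and Σ; the RWA additionally drops the Σ terms. That is consistent with the observation above, though the algebra by itself says nothing about solver performance.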

One issue I couldn't resolve while playing with your notebook is a discrepancy between the output of your function and the one I've just made, even if I drop the RWA from yours. I haven't been able to figure this out, though I haven't dug too deeply. I am more inclined to trust my version, as modifying the Solver after the fact could result in sketchy behaviour. (In fact, I forgot these setter functions for the operators even existed. I would need to think about it again, but it might make sense to remove them entirely to discourage this behaviour. It's been a while since I've worked on this code, but I typically treat the Solver as immutable once I've created it.) If you are able to determine why there is a discrepancy, I'd be interested to know.

An alternative to what I've shown here is to still create the Solver outside the jitted function, but to put the parts of the model you want to modify into hamiltonian_operators and dissipator_operators. These are the terms whose coefficients are meant to be updated on the fly. However, I think what I've written is a bit more natural.
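In qiskit-dynamics, the coefficients of hamiltonian_operators and dissipator_operators are supplied as signals when calling Solver.solve, so the operators are fixed at construction while the coefficients can change on every call. The pure-JAX sketch below is not qiskit-dynamics code; it only mimics that structure with a hypothetical two-operator qubit model: the operator basis is built once, and only the coefficient vector is traced, which also makes it easy to vmap over many parameter sets.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Fixed operator basis, built once outside the compiled function (hypothetical model)
X = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]], dtype=jnp.complex64)
OPERATORS = jnp.stack([Z, X])


@jax.jit
def evolve(coeffs, t):
    """Evolve |0> under H = sum_j coeffs[j] * OPERATORS[j] for time t.

    Only `coeffs` and `t` are traced; the operators are compile-time constants.
    H is time independent here, so a single matrix exponential is exact.
    """
    H = jnp.tensordot(coeffs, OPERATORS, axes=1)
    psi0 = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    return expm(-1j * t * H) @ psi0


# One call per parameter set ...
psi = evolve(jnp.array([0.0, jnp.pi / 2]), 1.0)

# ... or a whole batch of parameter sets at once, e.g. for experiment-design sweeps
batch = jnp.array([[0.0, jnp.pi / 2], [0.0, jnp.pi / 4], [1.0, 0.0]])
psis = jax.vmap(evolve, in_axes=(0, None))(batch, 1.0)
excited_pop = jnp.abs(psis[:, 1]) ** 2
```

Because the coefficients are ordinary array inputs, changing them never requires touching the model object itself.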

Lastly, I was amazed to discover that:

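from qiskit_dynamics import Signal
from qiskit_dynamics.array import Array

def signal_from_input(pulse_input):
    amp = Array(pulse_input[0])
    w = pulse_input[1]
    # Construct with a concrete placeholder frequency, then overwrite it with the traced value
    signal = [Signal(amp, carrier_freq=1.0)]
    signal[0].carrier_freq = w
    return signal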

can't be changed to the following and still be compatible with JAX compilation:

def signal_from_input(pulse_input):
    amp = Array(pulse_input[0])
    w = pulse_input[1]
    # Passing the traced w straight to the constructor is what broke compilation
    signal = [Signal(amp, carrier_freq=w)]
    return signal

I'm going to create an issue so that this gets fixed.
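The thread doesn't say what inside Signal made the second version fail. One common cause of exactly this symptom is a constructor that converts its argument with NumPy, which needs a concrete value, while a later plain attribute assignment skips that conversion. The ToySignal class below is a hypothetical illustration of that mechanism, not the qiskit-dynamics implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToySignal:
    """Hypothetical stand-in for a signal class; NOT qiskit-dynamics' Signal."""

    def __init__(self, envelope, carrier_freq=0.0):
        self.envelope = envelope
        # NumPy conversion in the constructor needs a concrete value,
        # so it raises if carrier_freq is a JAX tracer
        self.carrier_freq = np.asarray(carrier_freq, dtype=float)

    def __call__(self, t):
        # Real part of envelope * exp(2*pi*i*f*t)
        return jnp.real(self.envelope * jnp.exp(2j * jnp.pi * self.carrier_freq * t))


def set_after_construction(amp, w, t):
    sig = ToySignal(amp, carrier_freq=1.0)  # concrete placeholder passes np.asarray
    sig.carrier_freq = w                    # plain attribute assignment skips the conversion
    return sig(t)


def pass_to_constructor(amp, w, t):
    return ToySignal(amp, carrier_freq=w)(t)


value = jax.jit(set_after_construction)(0.5, 2.0, 0.25)

try:
    jax.jit(pass_to_constructor)(0.5, 2.0, 0.25)
    constructor_ok = True
except jax.errors.TracerArrayConversionError:
    constructor_ok = False
```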

abailly-at-anl commented 1 year ago

Thanks for the suggestions, I will experiment with constructing the Solver directly in the method to see if it's more performant.

I also noticed a discrepancy similar to the one you mentioned this morning. In the original notebook I uploaded, I use a separate function to update the Hamiltonian. Copying the same code from update_static_hamiltonian into the simulator function itself produces yet another, different set of results. I'm not sure what's going on there. I've updated the notebook again so you can see what I mean.

I was likewise surprised by the carrier_freq issue, which I only noticed while writing this example notebook.

Parameter Dependent Simulation again.zip

DanPuzzuoli commented 1 year ago

Your new function extracts r, w, b from parameters, but it needs to use hamiltonian_parameters instead. After I make this change, it agrees with your original function.
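The naming slip is the actual bug here, but under jax.jit it is also easy to miss: a global that a function reads during tracing is baked into the compiled code as a constant, and rebinding it later does not trigger a retrace. A minimal sketch with hypothetical names mirroring the notebook's:

```python
import jax
import jax.numpy as jnp

parameters = jnp.array([1.0, 2.0, 3.0])  # hypothetical global, like the notebook's


@jax.jit
def reads_global(hamiltonian_parameters):
    # Bug: ignores its argument and reads the global instead
    r, w, B = parameters
    return r + w + B


@jax.jit
def reads_argument(hamiltonian_parameters):
    r, w, B = hamiltonian_parameters
    return r + w + B


new_params = jnp.array([10.0, 20.0, 30.0])
from_global = reads_global(new_params)      # uses the global, not new_params
from_argument = reads_argument(new_params)  # uses the argument as intended

# Rebinding the global after the first call does not trigger a retrace either
parameters = jnp.array([100.0, 200.0, 300.0])
stale = reads_global(new_params)
```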

Also, by the way, you set up the initial construction with

r, w, B = parameters

but later your functions use a different order:

    w = parameters[0]
    r = parameters[1]
    B = parameters[2]

I'm not sure whether this changes the comparison, but it could be another source of mix-ups.
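One way to rule out this kind of ordering mix-up is to pass parameters in a container whose fields are read by name. A typing.NamedTuple is a built-in JAX pytree, so it can go straight into a jitted function. The field names follow the notebook's r, w, B, but the Hamiltonian below is a hypothetical model chosen only to show the pattern.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class HamiltonianParams(NamedTuple):
    """Hypothetical named container; field names follow the notebook's r, w, B."""
    r: float
    w: float
    B: float


X = jnp.array([[0.0, 1.0], [1.0, 0.0]])
Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])


@jax.jit
def static_hamiltonian(p: HamiltonianParams):
    # Hypothetical model; fields are read by name, so their order cannot be mixed up
    return p.w * Z + p.r * X + p.B * jnp.eye(2)


p = HamiltonianParams(r=0.1, w=5.0, B=0.02)
H = static_hamiltonian(p)

# Positional unpacking still follows the declared field order (r, w, B)
r, w, B = p
```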

DanPuzzuoli commented 1 year ago

Okay, so if I remove the step in your function that sets the Hamiltonian parameters, and pass hamiltonian_parameters=parameters into my version of the function, the results agree 🎉. If I change the parameters and walk through the whole notebook again, they keep agreeing.

So, something fishy is definitely going on with updating model operators. If you don't mind I'll change the name of this issue to point to this specific bug. Unless you disagree, I feel your issue has more-or-less been resolved, and now what remains as far as Dynamics-development is concerned is this remaining problem (along with the issue about Signal carrier frequency). I'm thinking it may make sense to simply make the models immutable.
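Making the models immutable, as suggested above, means an "update" has to build a new object instead of changing one in place, so nothing derived from the old operators can silently go out of sync with the new ones. The thread never pins down the exact cause of the discrepancy; stale derived state is just one common way this class of bug arises. A generic Python sketch of the pattern, using a hypothetical frozen dataclass rather than any qiskit-dynamics class:

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class ModelSpec:
    """Hypothetical immutable model description; NOT a qiskit-dynamics class."""
    static_hamiltonian: jnp.ndarray
    hamiltonian_operators: tuple


spec = ModelSpec(
    static_hamiltonian=jnp.diag(jnp.array([1.0, -1.0])),
    hamiltonian_operators=(jnp.array([[0.0, 1.0], [1.0, 0.0]]),),
)

try:
    spec.static_hamiltonian = jnp.zeros((2, 2))  # in-place update is rejected
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# "Updating" produces a new object; the original is left untouched
updated = dataclasses.replace(spec, static_hamiltonian=2 * spec.static_hamiltonian)
```

Fields that are not replaced are shared with the original object, so building an updated model this way is cheap.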

DanPuzzuoli commented 1 year ago

@abailly-at-anl just FYI, PR #247 will fix the carrier frequency tracing issue #245.

DanPuzzuoli commented 8 months ago

Closing this issue as the discussion is out of date.