pasqal-io / qadence

Digital-analog quantum programming interface
https://pasqal-io.github.io/qadence/latest/
Apache License 2.0

[Differentiability, Refactoring] Rethink parameter dictionaries in backends / Introduce hybrid differentiation modes #255

Open dominikandreasseitz opened 7 months ago

dominikandreasseitz commented 7 months ago

Issue:

Right now, when we do:

```python
quantum_backend = SomeBackend()
conv = quantum_backend.convert(circuit, obs)
conv_circ, conv_obs, embedding_fn, params = conv
```

We store all of the following in the initial params dict:

(a) all variational, user-facing parameters of the circuit AND the observable
(b) all fixed parameters of both the circuit AND the observable

Issue 1: When using torch, the (torch-based) backend knows which params to compute gradients for via the requires_grad flag. However, this doesn't work for JAX, whose arrays carry no such flag.
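
To make the JAX problem concrete: jax.grad differentiates with respect to every leaf of the positional argument(s) selected by argnums, so there is no per-entry flag to skip the fixed parameters in a flat dict. The sketch below imitates that behaviour in plain Python; `grad` is a hypothetical central-difference stand-in for jax.grad (not qadence or JAX code), and the expectation functions are toys.

```python
import math
from typing import Callable, Dict

Params = Dict[str, float]

def grad(f: Callable[..., float], eps: float = 1e-6) -> Callable[..., Params]:
    """Hypothetical stand-in for jax.grad(f, argnums=0): central differences
    with respect to every entry of the first argument; other args held fixed."""
    def g(params: Params, *rest) -> Params:
        out = {}
        for name in params:
            up = {**params, name: params[name] + eps}
            down = {**params, name: params[name] - eps}
            out[name] = (f(up, *rest) - f(down, *rest)) / (2 * eps)
        return out
    return g

# Toy expectation value: trainable angle "theta", fixed angle "phi".
def expectation_flat(params: Params) -> float:
    return math.cos(params["theta"]) * math.cos(params["phi"])

# Same function, but fixed params passed as a separate, non-differentiated argument.
def expectation_split(vparams: Params, fixed: Params) -> float:
    return expectation_flat({**vparams, **fixed})

flat = {"theta": 0.3, "phi": 1.1}
print(sorted(grad(expectation_flat)(flat)))  # ['phi', 'theta']: phi is differentiated too

vparams, fixed = {"theta": 0.3}, {"phi": 1.1}
print(sorted(grad(expectation_split)(vparams, fixed)))  # ['theta'] only
```

Splitting the dict is what lets a JAX backend decide which parameters get gradients without relying on a tensor-level flag.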

Issue 2: Neither of the diff_modes ADJOINT and GPSR supports parametric observables.
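
A minimal single-qubit example illustrates why observable parameters need separate treatment. It assumes a circuit $R_X(\theta)|0\rangle$ and a hypothetical parametric observable $O(w) = wZ$, both chosen for illustration only. The parameter-shift rule (valid for gates of the form $e^{-i\theta P/2}$ with a Pauli $P$, such as $R_X$) differentiates with respect to a gate angle by re-running the circuit at shifted angles; $w$ is not a gate angle.

```latex
\begin{aligned}
|\psi(\theta)\rangle &= R_X(\theta)|0\rangle
  = \cos\tfrac{\theta}{2}\,|0\rangle - i\sin\tfrac{\theta}{2}\,|1\rangle \\
f(\theta, w) &= \langle\psi(\theta)|\,wZ\,|\psi(\theta)\rangle
  = w\left(\cos^2\tfrac{\theta}{2} - \sin^2\tfrac{\theta}{2}\right) = w\cos\theta \\
\frac{\partial f}{\partial \theta}
  &= \frac{f(\theta + \tfrac{\pi}{2}, w) - f(\theta - \tfrac{\pi}{2}, w)}{2}
   = \frac{-w\sin\theta - w\sin\theta}{2} = -w\sin\theta \\
\frac{\partial f}{\partial w} &= \langle\psi(\theta)|Z|\psi(\theta)\rangle = \cos\theta
\end{aligned}
```

The $\theta$-derivative comes from two extra circuit evaluations at shifted angles. The $w$-derivative has no gate angle to shift, so a rule that only shifts gate parameters cannot produce it.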

Possible Solution:

Introduce separate parameter dicts for the initial fixed params and vparams of both the circuit and the observable: `initial_params = {'circuit_vparams': ..., 'circuit_fixedparams': ..., 'obs_vparams': ..., 'obs_fixedparams': ...}`
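
Here is a minimal sketch of the proposed nested layout. It uses the four keys from the proposal; the parameter names and values inside each group, and the helper `flatten`, are hypothetical. Because the groups are kept separate, the trainable subset can be selected by key rather than by a per-tensor flag, and a flat view is still one merge away.

```python
from typing import Dict

# Proposed layout: four separate groups (keys from the proposal);
# the parameter names and values are made up for illustration.
initial_params: Dict[str, Dict[str, float]] = {
    "circuit_vparams": {"theta_0": 0.1, "theta_1": 0.2},
    "circuit_fixedparams": {"phi": 0.5},
    "obs_vparams": {"w": 1.5},
    "obs_fixedparams": {"scale": 2.0},
}

def flatten(groups: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """Hypothetical helper: merge all groups into one flat dict, failing
    loudly if two groups reuse a parameter name."""
    flat: Dict[str, float] = {}
    for group in groups.values():
        clash = flat.keys() & group.keys()
        if clash:
            raise ValueError(f"duplicate parameter names: {sorted(clash)}")
        flat.update(group)
    return flat

# Trainable subset, selected by group name instead of a requires_grad flag.
trainable = {**initial_params["circuit_vparams"], **initial_params["obs_vparams"]}
print(sorted(trainable))             # ['theta_0', 'theta_1', 'w']
print(len(flatten(initial_params)))  # 5
```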

  1. This way, we can easily distinguish between vparams and fixed params in JAX.
  2. We can start thinking about introducing hybrid diff_modes, e.g. circuit_diffmode = GPSR / ADJOINT and observable_diffmode = AD, and apply a particular differentiation routine to each subset of the parameters.
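
Point 2 can be sketched as a dispatcher that applies one differentiation routine per parameter group. This is a toy under stated assumptions, not qadence code. The expectation is the single-qubit $f(\theta, w) = w\cos\theta$ (an $R_X(\theta)$ circuit measured with a hypothetical observable $wZ$). The circuit group uses the parameter-shift rule, and a central finite difference stands in for AD on the observable group. The names `DIFF_RULES` and `hybrid_grad` are made up.

```python
import math
from typing import Callable, Dict

Params = Dict[str, float]
Expectation = Callable[[Params], float]

def expectation(p: Params) -> float:
    # <psi(theta)| w Z |psi(theta)> for |psi(theta)> = RX(theta)|0>
    return p["w"] * math.cos(p["theta"])

def shifted(p: Params, name: str, delta: float) -> Params:
    return {**p, name: p[name] + delta}

def gpsr(f: Expectation, p: Params, name: str) -> float:
    # Two-term parameter-shift rule for a Pauli-rotation gate angle.
    s = math.pi / 2
    return (f(shifted(p, name, s)) - f(shifted(p, name, -s))) / 2

def finite_diff(f: Expectation, p: Params, name: str, eps: float = 1e-6) -> float:
    # Stand-in for AD on observable parameters.
    return (f(shifted(p, name, eps)) - f(shifted(p, name, -eps))) / (2 * eps)

# Hypothetical per-group assignment: circuit_diffmode=GPSR, observable_diffmode=AD.
DIFF_RULES = {"circuit_vparams": gpsr, "obs_vparams": finite_diff}

def hybrid_grad(f: Expectation, groups: Dict[str, Params]) -> Params:
    flat = {k: v for g in groups.values() for k, v in g.items()}
    return {
        name: DIFF_RULES[group](f, flat, name)
        for group, params in groups.items()
        for name in params
    }

groups = {"circuit_vparams": {"theta": 0.4}, "obs_vparams": {"w": 1.5}}
print(hybrid_grad(expectation, groups))
# analytic values: d/dtheta = -w*sin(theta), d/dw = cos(theta)
```

The routine is looked up per group, so swapping GPSR for an adjoint implementation would only change one entry in the table.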

nmheim commented 7 months ago

If I understand correctly, the problem is that conv.params returns a dict that contains both fixed and variational parameters, right? Would it be easier/more elegant to introduce a conv.vparams and conv.circuit.vparams (and the same for the observable) that return only the variational parameters? Then we don't have to change all the code that assumes conv.params to be one non-nested dict.
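
Here is a minimal sketch of this suggestion. It assumes a hypothetical `ConvertedCircuit` container (not the actual qadence class) that leaves the flat `params` dict untouched and exposes `vparams` as a filtered view, so existing code that reads `params` keeps working.

```python
from dataclasses import dataclass, field
from typing import Dict, FrozenSet

@dataclass
class ConvertedCircuit:
    """Hypothetical container: `params` stays one flat dict, and the names of
    the trainable entries are recorded separately."""
    params: Dict[str, float]
    trainable: FrozenSet[str] = field(default_factory=frozenset)

    @property
    def vparams(self) -> Dict[str, float]:
        # Only the variational (trainable) entries, returned as a new dict.
        return {k: v for k, v in self.params.items() if k in self.trainable}

circ = ConvertedCircuit(params={"theta": 0.1, "phi": 0.5}, trainable=frozenset({"theta"}))
print(circ.params)   # {'theta': 0.1, 'phi': 0.5}
print(circ.vparams)  # {'theta': 0.1}
```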

dominikandreasseitz commented 7 months ago

Yes, great idea, but I would try to avoid changing the low-level interface, so I would be inclined to keep conv.params and let it just return the composition of conv.circuit.vparams, ...
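
Keeping `conv.params` as the low-level entry point while storing the groups separately could look like the sketch below. `Converted` and `ParamGroups` are hypothetical names. `params` is derived on access as the merge of the circuit and observable groups, so callers that expect one non-nested dict are unaffected.

```python
from dataclasses import dataclass
from typing import Dict

@dataclass(frozen=True)
class ParamGroups:
    """Hypothetical: variational and fixed parameters of one component."""
    vparams: Dict[str, float]
    fixedparams: Dict[str, float]

@dataclass(frozen=True)
class Converted:
    """Hypothetical stand-in for the object returned by convert()."""
    circuit: ParamGroups
    observable: ParamGroups

    @property
    def params(self) -> Dict[str, float]:
        # Flat view composed from the four groups (later groups win on
        # name clashes); the groups stay the source of truth.
        return {
            **self.circuit.vparams,
            **self.circuit.fixedparams,
            **self.observable.vparams,
            **self.observable.fixedparams,
        }

conv = Converted(
    circuit=ParamGroups(vparams={"theta": 0.1}, fixedparams={"phi": 0.5}),
    observable=ParamGroups(vparams={"w": 1.5}, fixedparams={}),
)
print(conv.params)           # {'theta': 0.1, 'phi': 0.5, 'w': 1.5}
print(conv.circuit.vparams)  # {'theta': 0.1}
```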

Next question: do we want to give the user the option to choose which diff_mode to use for a particular part of the model?

GJBoth commented 7 months ago

If I remember correctly, I originally designed this with @awennersteen. Indeed, we didn't consider fixing some parameters or having different gradient backends. There are two options, both without changing the API, I think:

  1. Have the returned params dict always contain only trainable parameters, and let the embedding_fn take care of adding in the non-trainable params. In my opinion this is not a great option, for various reasons, but it is possible.
  2. Instead of params being a basic dict, make it a slightly more involved object with, for example, trainable_params, fixed_params, and the corresponding AD rules for each group. This is very JAX-style (have a look at optax). I think this might cover everything we need. This is my preferred option, and I think it is in line with @dominikandreasseitz's idea, if I understand it correctly?
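
Both options can be sketched side by side. Everything here is hypothetical. `ParamSet` (option 2) bundles trainable and fixed parameters with a differentiation rule for the trainable group, loosely in the spirit of optax as the comment suggests. Its `grad` applies that rule only to trainable entries while merging the fixed ones in for evaluation. `make_embedding_fn` (option 1) passes around only trainable parameters and adds the fixed ones back before evaluation. The central difference is a stand-in for whatever AD/GPSR/adjoint rule a backend would plug in.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict

Params = Dict[str, float]
Rule = Callable[[Callable[[Params], float], Params, str], float]

def central_diff(f: Callable[[Params], float], p: Params, name: str, eps: float = 1e-6) -> float:
    # Stand-in differentiation rule; a backend would plug in AD/GPSR/adjoint.
    up, down = {**p, name: p[name] + eps}, {**p, name: p[name] - eps}
    return (f(up) - f(down)) / (2 * eps)

@dataclass(frozen=True)
class ParamSet:
    """Option 2 (hypothetical): parameters grouped with their diff rule."""
    trainable_params: Params
    fixed_params: Params
    trainable_rule: Rule = central_diff

    def merged(self) -> Params:
        return {**self.fixed_params, **self.trainable_params}

    def grad(self, f: Callable[[Params], float]) -> Params:
        p = self.merged()
        # Gradients are computed for trainable entries only.
        return {k: self.trainable_rule(f, p, k) for k in self.trainable_params}

# Option 1 (hypothetical): only trainable params are passed around; the
# embedding function adds the fixed ones back before evaluation.
def make_embedding_fn(fixed: Params) -> Callable[[Params], Params]:
    return lambda trainable: {**fixed, **trainable}

def expectation(p: Params) -> float:
    return math.sin(p["theta"]) * p["scale"]

ps = ParamSet(trainable_params={"theta": 0.2}, fixed_params={"scale": 3.0})
print(ps.grad(expectation))  # only 'theta' gets a gradient, about 3*cos(0.2)

embed = make_embedding_fn({"scale": 3.0})
print(expectation(embed({"theta": 0.2})) == expectation(ps.merged()))  # True
```

Option 2 keeps the rule next to the parameters it applies to, which is what later makes per-group differentiation possible.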

awennersteen commented 7 months ago

> next question: do we want to give the user the option to choose which diff_mode to use for a particular part of the model?

IMHO, this sounds dangerous. But I guess it makes a lot of sense if we consider, for example, a hybrid model where a classical NN is composed with a QNN. I think it should be strictly defined.

I, like @GJBoth, have no recollection of why we did this and what we may or may not have considered :p The one thing I do remember is that after the initial design, over the next month or so, there were many hacks and patches to make it actually work...

Since @nmheim was asking about namedtuples the other day, maybe this is another place to use them? That way we have a more solid object that keeps all the different data in one place (Gert-Jax' option number 2), and then we go for it?
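
Here is a minimal sketch with `typing.NamedTuple`. It assumes a hypothetical `ParamState` with the two groups from option 2. Because the tuple is immutable, an optimizer step returns a new state via `_replace`, which fits a functional, JAX-style update loop.

```python
from typing import Dict, NamedTuple

class ParamState(NamedTuple):
    """Hypothetical immutable parameter container (fields are illustrative)."""
    trainable_params: Dict[str, float]
    fixed_params: Dict[str, float]

def sgd_step(state: ParamState, grads: Dict[str, float], lr: float = 0.1) -> ParamState:
    # Plain gradient descent on trainable entries only; returns a new tuple.
    updated = {k: v - lr * grads[k] for k, v in state.trainable_params.items()}
    return state._replace(trainable_params=updated)

s0 = ParamState(trainable_params={"theta": 1.0}, fixed_params={"phi": 0.5})
s1 = sgd_step(s0, {"theta": 2.0}, lr=0.25)
print(s1.trainable_params)  # {'theta': 0.5}
print(s0.trainable_params)  # {'theta': 1.0}: the old state is unchanged
```

Note that a NamedTuple is only shallowly immutable: the dicts inside can still be mutated in place, so the update above deliberately builds a new dict.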

My only concern is how this might interact with the idea of different diff modes for different parts. But maybe this is the best way of achieving that too? Suppose we end up saying that, in order to use different diff modes, you compose multiple QuantumModels (or maybe DifferentiableBackends, or whatever the current name is). Then the namedtuple that keeps track of the parameters could also keep track of which model they belong to. By using Gert-Jax' idea of "looking up the AD rules for each group", I guess that could be achieved arbitrarily. This quickly becomes overengineering, though, and we should perhaps think it through before implementing that part.
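
One possible reading of this idea, sketched with hypothetical names, tags each parameter record with the sub-model that owns it and consults a strictly defined per-model table of diff modes when computing gradients. In the toy composite, a classical scale `a * X` multiplies a one-qubit expectation $\langle Z\rangle = \cos\theta$ of $R_X(\theta)|0\rangle$. The "AD" entry is a central finite difference standing in for autodiff, and "GPSR" is the two-term shift rule.

```python
import math
from typing import Callable, Dict, NamedTuple

class Param(NamedTuple):
    """Hypothetical parameter record tagged with the model that owns it."""
    value: float
    owner: str  # name of the sub-model, e.g. "classical_nn" or "qnn"

Params = Dict[str, Param]
Model = Callable[[Params], float]

def _with(p: Params, name: str, value: float) -> Params:
    return {**p, name: p[name]._replace(value=value)}

def ad_stand_in(f: Model, p: Params, name: str, eps: float = 1e-6) -> float:
    v = p[name].value
    return (f(_with(p, name, v + eps)) - f(_with(p, name, v - eps))) / (2 * eps)

def gpsr(f: Model, p: Params, name: str) -> float:
    v, s = p[name].value, math.pi / 2
    return (f(_with(p, name, v + s)) - f(_with(p, name, v - s))) / 2

# Strictly defined: each sub-model gets exactly one diff mode.
DIFF_MODE_BY_OWNER = {"classical_nn": ad_stand_in, "qnn": gpsr}

def composed_grad(f: Model, p: Params) -> Dict[str, float]:
    # Raises KeyError if a parameter's owner has no declared diff mode.
    return {name: DIFF_MODE_BY_OWNER[rec.owner](f, p, name) for name, rec in p.items()}

X = 2.0  # fixed input to the classical layer

def composite(p: Params) -> float:
    # classical scale a*X times the one-qubit expectation <Z> = cos(theta)
    return p["a"].value * X * math.cos(p["theta"].value)

params = {"a": Param(0.7, "classical_nn"), "theta": Param(0.3, "qnn")}
print(composed_grad(composite, params))
# analytic: d/da = X*cos(theta), d/dtheta = -a*X*sin(theta)
```

An unregistered owner fails loudly instead of silently falling back to some default rule, which is one way to keep the per-part choice "strictly defined".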