JordiBolibar opened this issue 2 years ago
This is just an upstream performance issue with the pmap adjoint. My student should put in a PR by the end of the week.
https://github.com/JuliaDiff/ChainRules.jl/pull/566 is the solution here.
If I understood the PR correctly, will that fix only be effective for `batch_size = 1` for now? In my case I'm working with bigger batch sizes, so I'm just trying to anticipate things. Thanks!
As discussed with @gaurav-arya, we have verified with a couple of MWEs that his fix https://github.com/JuliaDiff/ChainRules.jl/pull/566 solves this issue. I'll close this once it is merged. Thanks Gaurav!!!
As discussed on Slack, this is just to track the fix for the performance issue when using `EnsembleDistributed`. There are performance gains on the forward solve with `EnsembleDistributed`, but on the reverse pass no performance improvement can be seen compared to `EnsembleSerial`.