Open swamidass opened 1 year ago
Hello, @swamidass! I hope you are having a nice Christmas holiday.
I've been considering getting into Jax, just never had the time. Mostly use TF myself, but use torch from time to time.
I might make an attempt to add a Jax backend, if it is crucial for a study of yours. Do you have a deadline for when you need this feature to be ready?
A few days ago I made PRs to add full backend support for Reinhard and a modified Reinhard variant. I could add Jax support, but would need to make an attempt first.
Thanks for the response, and I hope you also are having a nice holiday.
We are in the middle of porting our WSI pipeline from tensorflow to Jax. So there isn't a rush, because we are still using tensorflow. However, we would be well situated to test a jax backend.
We are in the middle of porting our WSI pipeline from tensorflow to Jax.
That's interesting. What is the main reason you are moving to Jax from TF? I was considering it, but for training purposes tf/keras offer most of what I need. For deployment, on the other hand, I find that using inference engines such as ONNX Runtime or TensorRT is better, but that is not a problem as "all" TF models can be converted to the ONNX format, which both IEs support.
However, we would be well situated to test a jax backend.
I'm partly back from vacation now and will likely make an attempt at the Jax backend soon. Will keep you updated :]
It was too tempting...
Jax backend with Macenko seems to work. Did not spend lots of time with it. Hence, there are probably lots of ways to optimize it and whatnot, but at least we have a baseline that you could test and benchmark, if you'd like, @swamidass. Then we can try to improve it later on if necessary.
To test, be sure to install from this branch by:
pip install git+https://github.com/andreped/torchstain.git@jax-backend
Then to use, simply do:
import torchstain
import numpy as np
normalizer_jax = torchstain.normalizers.MacenkoNormalizer(backend='jax')
normalizer_jax.fit(target)
norm_jax, _, _ = normalizer_jax.normalize(to_transform)
norm = np.asarray(norm_jax)
When @carloalbertobarbano is back from vacation, we will resolve some existing PRs and then I can start adding Jax to the main branch for all stain normalization techniques. Should be quite seamless, as Jax has a surprisingly neat numpy-like API when using jax.numpy as a direct replacement for numpy.
When all that is working, I believe we are ready for a new release :] Will keep you updated, @swamidass.
Very nice! In our experience, JAX is substantially faster than TF and pytorch, which is why we are migrating to it.
So it will be interesting to see if that's true in this case too.
One thing you want to be sure to test though, is whether or not the entrypoint you are providing is compilable. You can test this easily (and it should be a test case) with this modification to your code:
@jax.jit
def jax_normalize(to_transform):
    norm_jax, _, _ = normalizer_jax.normalize(to_transform)
    return norm_jax

norm_jax = jax_normalize(to_transform)
If that throws an error, there is more work to do. Usually, there will be an internal function you are calling that needs to be wrapped in a jax.jit, identifying which arguments are static.
If that does not throw an error, you might be done!
So, running this, we do get an error. It's a classic example of where code needs to be refactored for JAX:
ODhat = OD[~jnp.any(OD < beta, axis=1)]
This needs to be refactored so that all intermediate matrices are a fixed size. The boolean select here yields an array with indeterminate size. The way to refactor this is with a mask.
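To make the failure concrete, here is a minimal sketch of why boolean selection breaks under `jax.jit` and how a mask keeps every shape fixed. The `OD` values and the helper names `select_rows` and `masked_row_mean` are made up for illustration and are not taken from torchstain.

```python
import jax
import jax.numpy as jnp

beta = 0.15
# Made-up optical-density values: rows are pixels, columns are channels.
OD = jnp.array([[0.5, 0.6, 0.7],
                [0.1, 0.9, 0.8],   # rejected: one channel is below beta
                [0.3, 0.4, 0.2]])

@jax.jit
def select_rows(OD):
    # The output shape depends on the data, so tracing cannot succeed.
    return OD[~jnp.any(OD < beta, axis=1)]

try:
    select_rows(OD)
    jit_failed = False
except Exception:  # typically jax.errors.NonConcreteBooleanIndexError
    jit_failed = True

@jax.jit
def masked_row_mean(OD):
    # Keep every row, but give the rejected rows zero weight.
    mask = ~jnp.any(OD < beta, axis=1)
    weights = mask.astype(OD.dtype)[:, None]
    return (OD * weights).sum(axis=0) / weights.sum()

mean_valid = masked_row_mean(OD)  # mean over the two accepted rows
```

The masked version never changes the array size; the rejected rows simply contribute nothing to the reduction.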
I'm willing to make some of the changes, if I can. If I do, would you mind adding me as an author? I was also thinking of adding a utility we developed to enable parallelized application of this to large slides too.
Also saw a few of these:
Inorm.at[Inorm > 255].set(255)
They are no-ops as written. You'd need to do:
Inorm = Inorm.at[Inorm > 255].set(255)
But I'm not sure they will compile, regardless. A better way to write that is:
Inorm = jnp.where(Inorm > 255, 255, Inorm)
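The difference between these variants can be checked on a toy array. This is a small sketch with made-up values; the function names `clamp_where` and `clamp_min` are hypothetical. It relies on JAX arrays being immutable, so `.at[...].set(...)` returns a new array rather than modifying the original.

```python
import jax
import jax.numpy as jnp

Inorm = jnp.array([10.0, 300.0, 255.0, 999.0])

# The result is discarded, so Inorm itself is left untouched.
Inorm.at[Inorm > 255].set(255)
unchanged = Inorm

@jax.jit
def clamp_where(x):
    # Elementwise select: the output has the same shape as the input.
    return jnp.where(x > 255, 255, x)

@jax.jit
def clamp_min(x):
    # jnp.minimum expresses the same upper clamp.
    return jnp.minimum(x, 255)

clamped = clamp_where(Inorm)
clamped_alt = clamp_min(Inorm)
```

Both jitted versions compile because they never index with a data-dependent boolean mask.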
It's rather late in Norway, but I can try to incorporate your ideas, @swamidass. Give me a second, and I will make a commit if I get it working as intended.
Hmm, as far as I can tell, there does not seem to exist a masking mechanism in Jax as of yet to mimic y = x[mask]. Rather strange I must say, but there are numerous threads about it.
Even more surprising: https://github.com/google/jax/issues/11557
Any ideas? It is this line that is giving me a headache:
ODhat = OD[~jnp.any(OD < beta, axis=1)]
I got further, but the masking in Jax is giving me a headache. At least now, assuming there is a fix for this one line, adding the @jax.jit decorator works and hence, Macenko should work with Jax-backend.
You do not need to add it yourself externally, it is added directly within the class (see here).
Just reinstall the latest version of the same branch and run the same commands as mentioned previously to test: https://github.com/EIDOSLAB/torchstain/issues/31#issuecomment-1369004760
I'm willing to make some of the changes, if I can. If I do, would you mind adding me as an author? I was also thinking of adding a utility we developed to enable parallelized application of this to large slides too.
@swamidass Oh, and of course, contributors are always welcome. Regarding authorship, I am not the owner of this project, just contributing to it, but I'm open to the idea :] You have already been helpful in the Jax-backend implementation.
I will take a look at it in a moment, but regarding:
ODhat = OD[~jnp.any(OD < beta, axis=1)]
The trick is to refactor the code so ODhat isn't needed. I believe the correct solution would be to change line 61,
_, eigvecs = jnp.linalg.eigh(jnp.cov(ODhat.T))
To something close to:
mask = ~jnp.any(OD < beta, axis=1)
cov = jnp.cov(OD.T, fweights=mask.astype(jnp.int32))  # cast the boolean mask to integer weights
_, eigvecs = jnp.linalg.eigh(cov)
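As a sanity check on this idea, the sketch below compares the masked covariance against the covariance of the explicitly selected rows. The random `OD` array and the helper name `masked_cov` are made up for illustration. The mask is cast to an integer dtype on the assumption that `jnp.cov` expects integer frequency weights.

```python
import jax
import jax.numpy as jnp

beta = 0.15
key = jax.random.PRNGKey(0)
OD = jax.random.uniform(key, (200, 3))  # made-up optical densities

@jax.jit
def masked_cov(OD):
    mask = ~jnp.any(OD < beta, axis=1)
    # Each row is counted 0 or 1 times, so rejected rows drop out.
    return jnp.cov(OD.T, fweights=mask.astype(jnp.int32))

# Reference computed eagerly with ordinary boolean selection.
ODhat = OD[~jnp.any(OD < beta, axis=1)]
cov_ref = jnp.cov(ODhat.T)
cov_masked = masked_cov(OD)
```

With 0/1 frequency weights the two matrices should agree up to floating-point error, which is what makes the eigenvector step safe to run on the full-size `OD`.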
That leaves lines 27-32 to be refactored,
That = ODhat.dot(eigvecs[:, 1:3])
phi = jnp.arctan2(That[:, 1], That[:, 0])
minPhi = jnp.percentile(phi, alpha)
maxPhi = jnp.percentile(phi, 100 - alpha)
Into something like:
Th = OD.dot(eigvecs[:, 1:3])
phi = jnp.arctan2(Th[:, 1], Th[:, 0])
phi = jnp.where(mask, phi, jnp.inf)
pvalid = mask.mean() # proportion that is valid and not masked
minPhi = jnp.percentile(phi, alpha * pvalid)
maxPhi = jnp.percentile(phi, (100 - alpha) * pvalid)
I think those two changes will make the fix complete.
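The percentile rescaling can be checked on synthetic data. In this sketch the angles, the alternating mask, and the helper name `masked_percentiles` are all made up. It compares the masked computation with percentiles taken over the valid entries only.

```python
import jax
import jax.numpy as jnp

alpha = 1

# Made-up angles; every other entry is treated as masked out.
phi_all = jnp.linspace(-1.0, 1.0, 1000)
mask = (jnp.arange(1000) % 2) == 0

@jax.jit
def masked_percentiles(phi, mask):
    # Masked entries become +inf so they sort after every valid entry.
    phi = jnp.where(mask, phi, jnp.inf)
    pvalid = mask.mean()  # fraction of entries that are valid
    lo = jnp.percentile(phi, alpha * pvalid)
    hi = jnp.percentile(phi, (100 - alpha) * pvalid)
    return lo, hi

lo, hi = masked_percentiles(phi_all, mask)

# Reference: percentiles over the valid entries only (eager selection).
lo_ref = jnp.percentile(phi_all[mask], alpha)
hi_ref = jnp.percentile(phi_all[mask], 100 - alpha)
```

The two results agree only approximately, because the interpolation position is computed from the full array length rather than the number of valid entries. The gap is less than one sorted element, so it should be negligible for image-sized inputs, but it is worth keeping in mind if `alpha` is 0 or very few pixels survive the mask.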
Also, the way you are doing jit in the code base is incorrect. You should not do:
@partial(jax.jit, static_argnums=(0,))
def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):
The problem is that the 'stains' argument is also a static argument, and as written, it will not work correctly with self. This is one of the weird things about Jax, which they explain at length: https://jax.readthedocs.io/en/latest/faq.html?highlight=self#how-to-use-jit-with-methods. You are taking Strategy #2, and the docs explain how that breaks.
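One way around this, which the FAQ describes as moving the work into a jit-compiled helper function, is sketched below. `ToyNormalizer` and `_normalize_impl` are hypothetical names, and the arithmetic is a stand-in rather than the Macenko computation. The point is only that `stains` is declared static and `self` never passes through `jax.jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("stains",))
def _normalize_impl(I, Io, stains):
    # Hypothetical stand-in for the real computation.
    out = jnp.clip(I / Io, 0.0, 1.0)
    if stains:  # plain Python branch, allowed because `stains` is static
        return out, out * 0.5
    return out, None

class ToyNormalizer:
    """Hypothetical class; only illustrates the method/helper split."""

    def normalize(self, I, Io=240.0, stains=True):
        # `self` is not an argument of the jitted function at all.
        return _normalize_impl(I, Io, stains=stains)

normalizer = ToyNormalizer()
norm, extra = normalizer.normalize(jnp.array([120.0, 480.0]))
norm_only, no_extra = normalizer.normalize(jnp.array([120.0, 480.0]), stains=False)
```

Changing `stains` triggers a recompile, which is the expected cost of a static argument; array inputs of the same shape reuse the compiled function.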
Incorporated your suggestions and it seems to work. Great! :) Cheers!
Will add a unit test later on to verify that we get "identical" results compared to the other backends.
It's 7 AM in Norway right now, so I think I will head to bed. Just started using Jax today, so lots of new things to consider, such as how to use jit with methods.
I will look into Strategy #3 tomorrow, which seems like the appropriate approach, but I am pushing what I have now to the jax-backend branch for you to test, if you want.
JAX Support is a good idea, given we already have tf and torch. We can mark it as experimental for now
I have the PR ready which adds Jax backend support for Macenko (as discussed above). Will make a PR after all the existing PRs have been merged.
@andreped can you link me to it to review for any issues?
The edits are in this branch. I noticed now that I had not yet fixed the methods thing we discussed above. I have been focusing on my own PhD work lately. Note that it was Reinhard we are adding jax backend support for, not Macenko (yet).
I have a deadline tomorrow, but I can make an attempt to fix that afterwards. Alternatively, if you have time, @swamidass, you can fork my branch and add the final fix there.
We can look into adding Macenko (+ modified Reinhard) jax backend support in a future PR, if that is of interest to you, @swamidass :]
But it might be a good idea to wait until the development branch has been merged with the main, as there will likely be several merge conflicts.
Added the PyTree fix as we discussed in a new commit, @swamidass. See latest version in branch here.
At least it runs with the @jax.jit decorator. Added to both normalize and fit methods.
Not sure if I understood the whole static/dynamic value thing in the PyTree, as for this specific class, it made sense to not have any. What do you think? Note that the class itself does not have any arguments, but maybe that is required to work properly? Not sure.
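On the static/dynamic question, the following is a minimal sketch of PyTree registration (Strategy #3 in the FAQ). The class, its attributes `max_conc` and `Io`, and the `scale` method are hypothetical and not torchstain's actual code. It assumes that arrays stored on the instance should be traced (children) and that plain Python configuration should be static (auxiliary data).

```python
import jax
import jax.numpy as jnp
from jax import tree_util

@tree_util.register_pytree_node_class
class ToyNormalizer:
    """Hypothetical normalizer used only to illustrate PyTree registration."""

    def __init__(self, max_conc, Io=240):
        self.max_conc = max_conc  # dynamic: an array, e.g. something set by fit
        self.Io = Io              # static: plain Python configuration value

    @jax.jit
    def scale(self, C):
        # `self` is flattened into its children, which are traced like C.
        return C * (self.max_conc / self.Io)

    def tree_flatten(self):
        children = (self.max_conc,)  # traced by jit
        aux_data = (self.Io,)        # hashable; a new value means a recompile
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

normalizer = ToyNormalizer(jnp.array([2.0, 4.0]), Io=2)
scaled = normalizer.scale(jnp.array([1.0, 1.0]))
n_leaves = len(tree_util.tree_leaves(normalizer))
```

If a class stores neither arrays nor configuration, both tuples can be empty, so having no static or dynamic values is a valid registration. Once `fit` stores reference arrays on the instance, though, they would most likely belong in `children`.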
Oh, important note! Just added a unit test which is run for each new commit, and I observed that the Jax output differs from the numpy output.
From the look of the CI log, by quite a lot: https://github.com/andreped/torchstain/actions/runs/3982389028/jobs/6826815783
I swear it used to work, at least before we added the three main fixes discussed above to get it working with the @jax.jit decorator. I can see if I can make a gist tomorrow which reproduces the issue, which we can use to debug this further, and hopefully finalize this feature.
EDIT: @carloalbertobarbano I believe this feature is not ready for the upcoming new release v1.3.0. It is not a critical feature either, as I believe @swamidass does not need it for his current study.
I believe I have fixed it. It both yields output that visually appears similar to the numpy backend output and passes the unit test.
Hence, I made a PR that adds Macenko JAX backend support: https://github.com/EIDOSLAB/torchstain/pull/36
Runtime-wise the current JAX implementation is a lot slower than the alternative backends. There is a long thread likely explaining why that might be the case here. However, it is also likely that further improvements can be made to further optimize the JAX backend, but right now I'm happy with having something that runs.
I ran a simple benchmark (single run) which yielded:
| backend | numpy | jax | torch | tf |
|---|---|---|---|---|
| runtime [s] | 0.455 | 2.427 | 0.201 | 0.442 |
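One caveat when timing JAX: the first call to a jitted function includes tracing and compilation, and results are dispatched asynchronously, so a single run can overstate the steady-state cost. How the numbers above were measured is not stated, so the sketch below is only a general pattern. The workload `toy_normalize` and the helper `timed` are hypothetical.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def toy_normalize(I):
    # Hypothetical stand-in workload, not the Macenko implementation.
    OD = -jnp.log((I + 1.0) / 240.0)
    return jnp.clip(240.0 * jnp.exp(-OD), 0, 255)

I = jnp.full((256, 256, 3), 128.0)

def timed(fn, x):
    start = time.perf_counter()
    out = fn(x)
    out.block_until_ready()  # wait for asynchronous dispatch to finish
    return out, time.perf_counter() - start

_, first_call = timed(toy_normalize, I)  # includes tracing and compilation
out, steady = timed(toy_normalize, I)    # reuses the compiled function
```

Reporting `steady` (ideally averaged over several runs) alongside `first_call` would show how much of the gap, if any, comes from compilation rather than the computation itself.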
If you have time, @swamidass, it would be great if you could review the implementation in the PR.
Thanks for doing this. Performance issues are important, but not required to solve for the first implementation.
Always happy to contribute :] Let me know when you start testing it further, and if any modifications are made to make it faster. And of course, PRs are always welcome!
When you have time, @carloalbertobarbano, you can make the next release :] After the PRs have been merged into the development branch, and merged with master of course.
Excellent tool. We are likely to cite you in the future.
Would you consider building out a Jax backend?