Adding cuda-nvcc so that ptxas is readily available in the environment (XLA relies on ptxas to work correctly). Note that a missing ptxas would not necessarily have resulted in any errors in TF/JAX, but without ptxas (cuda-nvcc), XLA simply doesn't work, so it falls back to unoptimized performance.
ptxas isn't available on conda-forge at the moment. It is available in some images (i.e. any image built on top of NVIDIA's NGC images) and on NVIDIA's own conda channel (nvidia). I am told this is a licensing issue, but I don't know the details. It will likely be bundled inside JAX going forward, but that's beyond me.
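
Since a missing ptxas may not raise any error, a quick check inside the built image can help confirm that this change took effect. The snippet below is only a rough sketch: the helper names `find_ptxas` and `ptxas_version` are made up for illustration, and it assumes that ptxas being on `PATH` is a reasonable proxy for XLA being able to find it (XLA may also look in other locations).

```python
import shutil
import subprocess


def find_ptxas():
    """Return the full path to ptxas if it is on PATH, otherwise None."""
    return shutil.which("ptxas")


def ptxas_version(path):
    """Return the text printed by `ptxas --version`, or None if the call fails."""
    try:
        result = subprocess.run(
            [path, "--version"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # The binary is missing, not executable, or exited with an error.
        return None
    return result.stdout.strip()


if __name__ == "__main__":
    path = find_ptxas()
    if path is None:
        print("ptxas not found on PATH; XLA may not work as expected")
    else:
        print(f"ptxas found at {path}")
        print(ptxas_version(path))
```

Running this in an image without cuda-nvcc should print the "not found" message, which would otherwise be easy to miss because TF/JAX may not complain.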
@TomAugspurger this is ready for review. Thank you!
xref https://github.com/pangeo-data/pangeo-docker-images/pull/378 for ease, but mainly: