As stated in the docs, there is no need to wrap a pmap in jit, since pmap already JIT-compiles the function it is given: https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#pmap-and-jit
Beyond being unnecessary, it actually appears to be problematic:
UserWarning: The jitted function <unnamed function> includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See [https://github.com/google/jax/issues/2926].
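For reference, here is a minimal sketch of the pattern the docs recommend: call the pmapped function directly instead of wrapping it in `jax.jit`. It assumes only a working JAX install. `step` is a hypothetical per-device function chosen for illustration, and the input is shaped so that its leading axis matches `jax.local_device_count()`, which is what `pmap` maps over. The jit-of-pmap variant is left commented out for contrast, because whether it emits the warning quoted above may depend on the JAX version.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-device computation, used for illustration only.
def step(x):
    return jnp.sin(x) * 2.0

# pmap maps over the leading axis, so its size must equal the number of devices.
n = jax.local_device_count()
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)

# Recommended: pmap compiles `step` itself, so no outer jit is needed.
p_step = jax.pmap(step)
y = p_step(x)

# Discouraged jit-of-pmap pattern, shown for contrast only:
# slow_step = jax.jit(jax.pmap(step))
```

Dropping the outer `jit` keeps the output as a per-device sharded array. According to the warning, an outer `jit` would instead gather inputs and outputs onto a single device.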