google / TensorNetwork

A library for easy and efficient manipulation of tensor networks.
Apache License 2.0

Jax precision #825

Closed: mganahl closed this PR 4 years ago.

mganahl commented 4 years ago

Enable setting the precision argument of jax.numpy.matmul, jax.numpy.tensordot, and other functions used in JaxBackend via tn.set_jax_precision.

codecov-commenter commented 4 years ago

Codecov Report

Merging #825 into master will decrease coverage by 0.07%. The diff coverage is 88.07%.


@@            Coverage Diff             @@
##           master     #825      +/-   ##
==========================================
- Coverage   98.44%   98.37%   -0.08%     
==========================================
  Files         128      128              
  Lines       21600    21620      +20     
==========================================
+ Hits        21265    21268       +3     
- Misses        335      352      +17     
Impacted Files Coverage Δ
tensornetwork/backends/jax/jax_backend_test.py 99.15% <ø> (-0.01%) ↓
tensornetwork/backends/backend_factory.py 72.91% <60.60%> (-27.09%) ↓
tensornetwork/__init__.py 100.00% <100.00%> (ø)
tensornetwork/backend_contextmanager.py 100.00% <100.00%> (ø)
tensornetwork/backends/abstract_backend.py 95.57% <100.00%> (+0.07%) ↑
tensornetwork/backends/jax/jax_backend.py 100.00% <100.00%> (ø)
tensornetwork/backends/jax/jitted_functions.py 96.95% <100.00%> (-1.24%) ↓
...ensornetwork/backends/jax/jitted_functions_test.py 100.00% <100.00%> (ø)


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 5f57d1a...53b0649.

chaserileyroberts commented 4 years ago

I do not like this design. Global things like this make the codebase very messy.

chaserileyroberts commented 4 years ago

How about this design instead.

We could add a config argument to the backend initializers.

This would allow us to have configurable backends, which would help with some ideas I want to pursue. It also avoids having a global configuration for a specific backend that may not even be used.
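One way to read this proposal is sketched below, assuming the config is a plain dict of settings passed when the backend is constructed. ConfigurableJaxBackend and the "precision" key are hypothetical names for illustration, not TensorNetwork's actual JaxBackend API.

```python
from typing import Any, Dict, Optional

import jax
import jax.numpy as jnp
import numpy as np


class ConfigurableJaxBackend:
  """Hypothetical backend that reads its settings from a config dict."""

  def __init__(self, config: Optional[Dict[str, Any]] = None):
    # Copy so later changes to the caller's dict do not leak in.
    self.config = dict(config or {})
    # Missing key -> None, i.e. JAX's default precision.
    self.precision = self.config.get("precision", None)

  def tensordot(self, a, b, axes):
    return jnp.tensordot(a, b, axes=axes, precision=self.precision)


# Two backends with different settings can coexist without global state.
fast = ConfigurableJaxBackend()
accurate = ConfigurableJaxBackend({"precision": jax.lax.Precision.HIGHEST})
```

The design choice here is that configuration lives on the instance, so a setting only affects code that actually uses that backend object.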

mganahl commented 4 years ago

Sure, would config be a dict?

mganahl commented 4 years ago

I assume current arguments like dtype would become part of config?

mganahl commented 4 years ago

In that case, I would suggest making backends configurable as e.g. tn.set_default_backend('jax', dtype=np.float64, precision='HIGHEST'). A **kwargs argument could simply be passed through to the corresponding backend, so backends could take an arbitrary number of keyword arguments at initialization. I would prefer this over a single config variable.
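A minimal sketch of the **kwargs pass-through described here, assuming each backend class accepts arbitrary keyword arguments. The registry, JaxBackendSketch, and this set_default_backend signature are hypothetical; they are not the existing tn.set_default_backend.

```python
from typing import Any, Dict, Optional

import numpy as np


class JaxBackendSketch:
  """Hypothetical backend accepting arbitrary keyword settings."""

  def __init__(self, dtype: Optional[Any] = None, **kwargs: Any):
    self.dtype = dtype
    # Everything else (e.g. precision) is kept as backend options.
    self.options: Dict[str, Any] = kwargs


# Hypothetical registry mapping backend names to classes.
_BACKEND_CLASSES = {"jax": JaxBackendSketch}
_default_backend = None


def set_default_backend(name: str, **kwargs: Any) -> None:
  """Hypothetical: build the named backend, passing **kwargs straight through."""
  global _default_backend
  _default_backend = _BACKEND_CLASSES[name](**kwargs)


def get_default_backend():
  return _default_backend


set_default_backend("jax", dtype=np.float64, precision="HIGHEST")
```

One cost of this design is that every code path that constructs a backend has to forward the same keyword arguments.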

mganahl commented 4 years ago

This will actually also become messy. I think it's better to have a global config file for each backend from which config params are loaded. Otherwise we need to pass the config arguments to the backends via backend_factory.get_backend(), which means changing the code in a lot of places.

mganahl commented 4 years ago

(Quoting @Thenerdstation's suggestion to add a config argument to the backend initializers.)

We should still allow config params to be set globally, though (not only by having users initialize their own backends).

mganahl commented 4 years ago

@Thenerdstation let me know what you think of this design!

mganahl commented 4 years ago

Hey Chase, how do you feel about this? Happy to change it further!

mganahl commented 4 years ago

If it's okay with you, we can just pull in the precision changes for the JAX backend and leave the configuration for now. The former would be great to have for some TPU tests I'm running!

mganahl commented 4 years ago

I'll convert this to a draft for now. The precision changes are now in PR #830.

mganahl commented 4 years ago

Let's actually close this PR for now. The most relevant changes have already been introduced in the other PRs. Adding the precision argument required moving code into the scope of the precision variable. The code itself hasn't changed; it has merely moved to another position.
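The scoping issue mentioned here can be illustrated with a minimal sketch, assuming the jitted helpers are generated inside a factory function so that precision is a closure variable. make_matvec is a hypothetical name, not a function from jitted_functions.py; the quoted gmres_krylov_work appears to follow the same pattern when it passes precision to kth_arnoldi_step from an enclosing scope.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_matvec(precision=None):
  """Hypothetical factory returning a jitted matvec that closes over precision.

  Defining matvec inside this function is what makes `precision` visible
  to it, so existing helpers have to move into such a scope even though
  their bodies stay the same.
  """

  @jax.jit
  def matvec(A, x):
    return jnp.matmul(A, x, precision=precision)

  return matvec


matvec_highest = make_matvec(jax.lax.Precision.HIGHEST)
A = np.random.rand(3, 3).astype(np.float32)
x = np.random.rand(3).astype(np.float32)
y = matvec_highest(A, x)
```

Because precision is a Python value captured at trace time rather than a traced argument, each distinct precision produces its own compiled function.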

On Sep 16, 2020, at 9:37 PM, Chase Roberts notifications@github.com wrote:

@Thenerdstation requested changes on this pull request.

In tensornetwork/backends/jax/jitted_functions.py https://github.com/google/TensorNetwork/pull/825#discussion_r489708972:

    def gmres_krylov_work(gmres_carry: GmresCarryType) -> GmresCarryType:
      """
      Performs a single iteration of gmres_krylov. See that function for a more
      detailed description.

      Args:
        gmres_carry: The gmres_carry from gmres_krylov.
      Returns:
        gmres_carry: The updated gmres_carry.
      """
      gmres_variables, gmres_constants = gmres_carry
      k, V, R, beta_vec, err, givens = gmres_variables
      tol, A_mv, A_args, b_norm, = gmres_constants
      V, H = kth_arnoldi_step(k, A_mv, A_args, V, R, tol, precision)
      R_col, givens = apply_givens_rotation(H[:, k], givens, k)
      R = jax.ops.index_update(R, jax.ops.index[:, k], R_col[:])

      # Update the residual vector.
      cs, sn = givens[:, k] * beta_vec[k]
      beta_vec = jax.ops.index_update(beta_vec, jax.ops.index[k], cs)
      beta_vec = jax.ops.index_update(beta_vec, jax.ops.index[k + 1], sn)
      err = jnp.abs(sn) / b_norm
      gmres_variables = (k + 1, V, R, beta_vec, err, givens)
      return (gmres_variables, gmres_constants)

    def gmres_krylov_loop_condition(gmres_carry: GmresCarryType) -> bool:
      """
      This function dictates whether the main GMRES while loop will proceed.

It's so difficult to know what you're actually changing when you make huge refactors like this at random.
