mganahl closed this pull request 4 years ago.
Merging #825 into master will decrease coverage by 0.07%. The diff coverage is 88.07%.
```
@@            Coverage Diff             @@
##           master     #825      +/-   ##
==========================================
- Coverage   98.44%   98.37%   -0.08%
==========================================
  Files         128      128
  Lines       21600    21620      +20
==========================================
+ Hits        21265    21268       +3
- Misses        335      352      +17
==========================================
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| tensornetwork/backends/jax/jax_backend_test.py | 99.15% <ø> (-0.01%) | :arrow_down: |
| tensornetwork/backends/backend_factory.py | 72.91% <60.60%> (-27.09%) | :arrow_down: |
| tensornetwork/__init__.py | 100.00% <100.00%> (ø) | |
| tensornetwork/backend_contextmanager.py | 100.00% <100.00%> (ø) | |
| tensornetwork/backends/abstract_backend.py | 95.57% <100.00%> (+0.07%) | :arrow_up: |
| tensornetwork/backends/jax/jax_backend.py | 100.00% <100.00%> (ø) | |
| tensornetwork/backends/jax/jitted_functions.py | 96.95% <100.00%> (-1.24%) | :arrow_down: |
| ...ensornetwork/backends/jax/jitted_functions_test.py | 100.00% <100.00%> (ø) | |
Continue to review the full report at Codecov.
Legend - Click here to learn more: Δ = absolute <relative> (impact), ø = not affected, ? = missing data.
Powered by Codecov. Last update 5f57d1a...53b0649. Read the comment docs.
I do not like this design. Global things like this make the codebase very messy.
How about this design instead: we could add a `config` argument to the initializations of the backends. This would allow us to have configurable backends, which is helpful for some ideas I want to pursue. It also prevents us from having a global configuration for a specific backend that may not even be used.
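A minimal sketch of what such a per-instance `config` argument could look like; the class name and config keys below are assumptions for illustration only, not the actual TensorNetwork API:

```python
# Hypothetical sketch of the "config argument at backend initialization" idea.
# Settings live on the backend instance rather than in any global state.
from typing import Any, Dict, Optional

import numpy as np


class ConfigurableBackend:
  """Toy backend whose settings are per-instance, not global."""

  def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
    config = config or {}
    self.dtype = config.get("dtype", np.float64)
    self.precision = config.get("precision", None)


backend_a = ConfigurableBackend({"dtype": np.float32})
backend_b = ConfigurableBackend()  # unaffected by backend_a's config
```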
Sure, would `config` be a `dict`? I assume current arguments like `dtype` would become part of `config`?
In that case I would suggest that backends are configurable as e.g. `tn.set_default_backend('jax', dtype=np.float64, precision='HIGHEST')`. A `**kwargs` argument could simply be passed through to the corresponding backend.
Backends could take an arbitrary number of keyword arguments at initialization. I would prefer this over a single `config` variable.
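A rough sketch of the `**kwargs` pass-through idea; the registry and `set_default_backend` below are simplified stand-ins, not the real `backend_factory` internals:

```python
# Simplified stand-in for a backend registry with **kwargs pass-through.
_BACKEND_CLASSES = {}
_DEFAULT_BACKEND = None


def register_backend(name, backend_cls):
  _BACKEND_CLASSES[name] = backend_cls


def set_default_backend(name, **kwargs):
  """Instantiate the named backend, forwarding all keyword arguments to it."""
  global _DEFAULT_BACKEND
  _DEFAULT_BACKEND = _BACKEND_CLASSES[name](**kwargs)
  return _DEFAULT_BACKEND
```

With this, a call like `set_default_backend('jax', dtype=np.float64, precision='HIGHEST')` would reach the JAX backend's constructor unchanged.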
This will actually also become messy. I think it's better to have a global config file for each backend from which the config params are loaded. Otherwise we would need to pass the config arguments to the backends via `backend_factory.get_backend()`, which means changing the code in a lot of places.
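For contrast, a minimal sketch of the per-backend global config file idea; the module path, dict name, and helper below are hypothetical:

```python
# Hypothetical module, e.g. tensornetwork/backends/jax/jax_config.py.
# The JAX backend would read these values at initialization, so
# backend_factory.get_backend() could keep its current signature.
import numpy as np

CONFIG = {"dtype": np.float64, "precision": None}


def update_config(**kwargs):
  """Globally override the JAX backend defaults before a backend is created."""
  CONFIG.update(kwargs)
```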
> How about this design instead: we could add a `config` argument to the initializations of the backends. This would allow us to have configurable backends, which is helpful for some ideas I want to pursue. It also prevents us from having a global configuration for a specific backend that may not even be used.
We should still allow for globally setting config params though (not only by having the user initialize their own backend).
@Thenerdstation let me know what you think of this design!
Hey Chase, how do you feel about this? Happy to change it further!
If it's okay with you, we can just pull in the precision stuff for the JAX backend and leave the configuration for now. The former would be great to have for some TPU tests I'm running!
I'll convert this to a draft for now. The precision stuff is now in PR #830.
Let’s actually close this PR for now. The most relevant changes are already introduced in the other ones.
Adding the precision argument required moving code into the scope of the variable `precision`. The code itself hasn't changed, it has merely moved to another position.
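As an illustration of that kind of move (not the actual diff), a helper defined inside a factory function can close over `precision` instead of receiving it through every call signature:

```python
# Illustrative only: the real PR restructures the GMRES helpers in
# jitted_functions.py; this just shows the closure pattern.
import jax
import jax.numpy as jnp


def build_matvec(precision):
  """Return a jitted matvec whose matmul precision is fixed by this scope."""

  @jax.jit
  def matvec(A, x):
    # `precision` is captured from build_matvec's scope, not passed per call.
    return jnp.matmul(A, x, precision=precision)

  return matvec


matvec_highest = build_matvec(jax.lax.Precision.HIGHEST)
```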
On Sep 16, 2020, at 9:37 PM, Chase Roberts (notifications@github.com) wrote:

@Thenerdstation requested changes on this pull request.

In tensornetwork/backends/jax/jitted_functions.py (https://github.com/google/TensorNetwork/pull/825#discussion_r489708972):
```python
def gmres_krylov_work(gmres_carry: GmresCarryType) -> GmresCarryType:
  """
  Performs a single iteration of gmres_krylov. See that function for a more
  detailed description.
  Args:
    gmres_carry: The gmres_carry from gmres_krylov.
  Returns:
    gmres_carry: The updated gmres_carry.
  """
  gmres_variables, gmres_constants = gmres_carry
  k, V, R, beta_vec, err, givens = gmres_variables
  tol, A_mv, A_args, b_norm = gmres_constants

  V, H = kth_arnoldi_step(k, A_mv, A_args, V, R, tol, precision)
  R_col, givens = apply_givens_rotation(H[:, k], givens, k)
  R = jax.ops.index_update(R, jax.ops.index[:, k], R_col[:])

  # Update the residual vector.
  cs, sn = givens[:, k] * beta_vec[k]
  beta_vec = jax.ops.index_update(beta_vec, jax.ops.index[k], cs)
  beta_vec = jax.ops.index_update(beta_vec, jax.ops.index[k + 1], sn)
  err = jnp.abs(sn) / b_norm

  gmres_variables = (k + 1, V, R, beta_vec, err, givens)
  return (gmres_variables, gmres_constants)


def gmres_krylov_loop_condition(gmres_carry: GmresCarryType) -> bool:
  """
  This function dictates whether the main GMRES while loop will proceed.
```

It's so difficult to know what you're actually changing when you have these huge refactors randomly.
Enable setting the precision argument for `jax.numpy.matmul`, `jax.numpy.tensordot`, and other functions used in `JaxBackend` via `tn.set_jax_precision`.
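For context, a hedged sketch of the underlying JAX mechanism this exposes; `tn.set_jax_precision` is the name used in this PR and may not exist in released versions:

```python
# Plain JAX: the precision argument controls matmul/tensordot accumulation,
# which matters on TPU where float32 matmuls otherwise use reduced precision.
import jax
import jax.numpy as jnp

A = jnp.ones((4, 4))
B = jnp.ones((4, 4))
C = jnp.tensordot(A, B, axes=1, precision=jax.lax.Precision.HIGHEST)
```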