google-deepmind/dm-haiku
JAX-based neural network library
https://dm-haiku.readthedocs.io
Apache License 2.0
2.91k stars
231 forks
Issues
Allow removal (and config) of biases in `MultiHeadAttention`.
#652
dwf
closed
1 year ago
1
Add MutableParams and MutableState return types for init_fns.
#651
copybara-service[bot]
closed
1 year ago
0
Add a note about more advanced uses of hk.vmap.
#650
copybara-service[bot]
closed
1 year ago
0
Don't assume key shape
#649
copybara-service[bot]
closed
1 year ago
0
Expose lower and upper bounds in TruncatedNormal init
#648
copybara-service[bot]
closed
1 year ago
1
Remove some spurious `transform_and_run` calls.
#647
copybara-service[bot]
closed
1 year ago
0
Better error type for non-empty state (and more granular error catching).
#646
copybara-service[bot]
closed
1 year ago
0
rnn classifies mnist
#645
never-to-never
opened
1 year ago
3
Make haiku compatible with jax.enable_custom_prng
#644
copybara-service[bot]
closed
1 year ago
0
Fix assert_is_prng_key when jax_enable_custom_prng=True
#643
copybara-service[bot]
closed
1 year ago
0
Remove unused import.
#642
copybara-service[bot]
closed
1 year ago
0
Added a few missing annotations to data_structures.Stack
#641
copybara-service[bot]
closed
1 year ago
0
Remove hk.experimental.named_call.
#640
copybara-service[bot]
closed
1 year ago
0
Improving compilation speed for repeated layers
#639
davisyoshida
closed
1 year ago
0
Use new config context in integration test.
#638
copybara-service[bot]
closed
1 year ago
0
Avoid an extra device-to-host copy in assert_is_prng_key()
#637
copybara-service[bot]
closed
1 year ago
1
For readability & to remove duplicate constant: re-use constant PAD_TOKEN when checking not in corpus.
#636
copybara-service[bot]
closed
1 year ago
1
Orthogonal initializer does not support MLP?
#635
rezunli96
closed
1 year ago
1
DeprecationWarning `transform_with_state` because of `jax.xla` vs `jax.interpreters.xla`
#634
joeryjoery
opened
1 year ago
0
Changing deprecated usages from Jax.
#633
copybara-service[bot]
closed
1 year ago
1
[Question] Is it perfectly okay to use hk.next_rng_key() inside of a hk.scan function? This takes place inside a call method of a haiku module.
#632
EdanToledo
opened
1 year ago
0
Update MNIST examples to do perfect shuffle then repeat
#631
copybara-service[bot]
closed
1 year ago
1
Skip TPU test for jax2tf native_serialization migration.
#630
copybara-service[bot]
closed
1 year ago
0
Add hk.experimental.{maybe_get_rng_sequence_state,maybe_replace_rng_sequence_state}.
#629
copybara-service[bot]
closed
1 year ago
1
Remove CompiledFunction from Haiku since the default jax version >= 0.4.6
#628
copybara-service[bot]
closed
1 year ago
0
Bump JAX version used for testing to 0.4.6.
#627
copybara-service[bot]
closed
1 year ago
0
Delete xla_call_p since it has been replaced with pjit.pjit_p
#626
copybara-service[bot]
closed
1 year ago
0
Revert: `custom_vjp` symbolic zeros support
#625
copybara-service[bot]
closed
1 year ago
0
Remove C++ jax.jit support. Also remove `GetEnableJaxArray` since that is always True.
#624
copybara-service[bot]
closed
1 year ago
0
Currently the test ReshapeTest.test_reshape_convert and Jax2TfTest
#623
copybara-service[bot]
closed
1 year ago
0
Is it impossible to turn a sequence of identical instance blocks into a compiled loop?
#622
cmunna0052
closed
1 year ago
2
Remove _PositionalSemantics class since it is not used anymore because jax.Array always has GLOBAL semantics
#621
copybara-service[bot]
closed
1 year ago
0
Add hk.experimental.{get_params,get_initial_state,get_current_state}.
#620
copybara-service[bot]
closed
1 year ago
0
Add `__signature__` to `hk.ModuleMetaclass`.
#619
copybara-service[bot]
closed
1 year ago
0
Replaces references to jax.numpy.DeviceArray with jax.Array.
#618
copybara-service[bot]
closed
1 year ago
0
Suppress some pytype errors related to jnp.DeviceArray == jax.Array.
#617
copybara-service[bot]
closed
1 year ago
0
[JAX] Fix type error in Haiku transformer train example.
#616
copybara-service[bot]
closed
1 year ago
0
[JAX] Fix up uses of PyTree and PyTreeDef types.
#615
copybara-service[bot]
closed
1 year ago
0
Internal change
#614
copybara-service[bot]
closed
1 year ago
0
Add context manager api for setting mixed precision policies.
#613
copybara-service[bot]
closed
1 year ago
0
Replace jnp.ndarray with jax.Array in type annotations.
#612
copybara-service[bot]
closed
1 year ago
0
[Haiku] Remove str argument from NextGetter type definition.
#611
copybara-service[bot]
closed
1 year ago
0
Analog to `flax.struct.dataclass`
#610
homerjed
closed
1 year ago
2
Changed how hk is defined in _src submodules
#609
copybara-service[bot]
closed
1 year ago
0
Fix _param_axis_passed_explicitly computation in layernorm.
#608
copybara-service[bot]
closed
1 year ago
1
Fix or ignore some pytype errors related to jnp.ndarray == jax.Array.
#607
copybara-service[bot]
closed
1 year ago
0
Haiku initialization
#606
adhikarirsr
opened
1 year ago
0
Fix pytype failures in Haiku if jnp.ndarray is defined as jax.Array.
#605
copybara-service[bot]
closed
1 year ago
0
Added a few missing imports to dot.py and jaxpr_info.py
#604
copybara-service[bot]
closed
1 year ago
0
[JAX] Replace uses of jax.xla_computation() with jax.jit().lower().
#603
copybara-service[bot]
closed
1 year ago
0