google-deepmind/kfac-jax
Second Order Optimization and Curvature Estimation with K-FAC in JAX.
Apache License 2.0
250 stars · 23 forks
Issues
Minor internal changes to optimizer code.
#303 · copybara-service[bot] opened 1 hour ago · 0 comments
Adding support for keyword arguments to staged methods.
#302 · copybara-service[bot] opened 3 hours ago · 0 comments
Making pmap axis names consistent in examples code to support things like cross-replica batch norm layers.
#301 · copybara-service[bot] opened 4 days ago · 0 comments
Going back to old debug mode behavior, but with fixed handling of non-broadcast "scalar" params passed to staged functions. Also putting disable_jit around method calls to prevent compilation in JAX control flow constructs.
#300 · copybara-service[bot] closed 8 hours ago · 0 comments
Add a check to TNT scale to deal with zero factors.
#299 · copybara-service[bot] closed 1 week ago · 0 comments
chore(build): drop `MANIFEST.in`
#298 · SauravMaheshkar opened 1 week ago · 2 comments
- Adding support in the graph scanner for Haiku & Flax normalization layers without learnable shift/offset params.
#297 · copybara-service[bot] closed 1 week ago · 0 comments
Add graph scan support for arbitrary number of input and output dimensions for dense layers
#296 · copybara-service[bot] closed 1 week ago · 0 comments
Make conv2d tag graph matcher more general
#295 · copybara-service[bot] closed 1 week ago · 0 comments
Add reshape parameter to normalization tag to find flax LayerNorms
#294 · copybara-service[bot] closed 2 weeks ago · 0 comments
Adding step rejection feature
#293 · copybara-service[bot] closed 2 weeks ago · 0 comments
Enabled greater range of preconditioner powers. Some math utilities added.
#292 · copybara-service[bot] closed 1 week ago · 0 comments
Bumping JAX version requirement to 0.4.25 due to requirement of jax.tree API.
#291 · copybara-service[bot] closed 2 weeks ago · 0 comments
- Using version guard to fix change that broke backwards compatibility with some older versions of JAX.
#290 · copybara-service[bot] closed 2 weeks ago · 0 comments
Removing overzealous check that broke QMC. It turns out there are non-higher-order equations with more than 1 output.
#289 · copybara-service[bot] closed 2 weeks ago · 0 comments
- Fixing bug that made graph scanner register repeated dense layers as regular dense layers.
#288 · copybara-service[bot] closed 2 weeks ago · 0 comments
- Improved and simplified implementation of "debug" mode based on jax.disable_jit().
#287 · copybara-service[bot] closed 3 weeks ago · 0 comments
Modifying internal function clean_jaxpr to properly eliminate unused output variables from higher order primitives. Should have no effect on optimizer behavior.
#286 · copybara-service[bot] closed 2 weeks ago · 0 comments
use of `core.unsafe_get_axis_names_DO_NOT_USE` which no longer exists
#285 · svandenhaute opened 3 weeks ago · 1 comment
Removing hacky "fixes" to test_graph_matcher. Basically, the test insists that the manual registration includes all of the params from the main equation in the match found by the graph scanner. Instead of filtering these out, we now ensure that they are included in the manual registrations done in tests/models.py. Note that passing all these params won't be required when using manual registration in practice. Only certain params are mandatory for particular layers (based on the type of curvature block that gets assigned to them).
#284 · copybara-service[bot] closed 4 weeks ago · 0 comments
Add sharding rules to some more primitives so that backward pass of minformer passes. There are a couple of changes here:
#283 · copybara-service[bot] closed 1 month ago · 0 comments
Fixing break in graph_matcher_test due to recent internal change in JAX.
#282 · copybara-service[bot] closed 1 month ago · 0 comments
Avoid depending on JAX internals, which are about to change.
#281 · copybara-service[bot] closed 1 month ago · 0 comments
Removing check that initial_damping is not set when use_adaptive_damping is False.
#280 · copybara-service[bot] closed 1 month ago · 0 comments
Bug: Jax 0.4.13 support
#279 · arnon-1 opened 1 month ago · 5 comments
Adding additional argument validity check to stepwise_schedule in examples
#278 · copybara-service[bot] closed 1 month ago · 0 comments
Changing optimizer to throw an exception when using burnin without a provided data iterator instead of silently skipping burnin.
#277 · copybara-service[bot] closed 1 month ago · 0 comments
- Changing automatic registration (aka the graph scanner) so that it doesn't automatically register a parameter if said parameter is used more than once in the graph. In that case, it resorts to the default "generic" registration (which doesn't make any structure assumptions about how the parameter is used).
#276 · copybara-service[bot] closed 1 month ago · 0 comments
Updating constant_schedule to return a Python value instead of JAX array.
#275 · copybara-service[bot] closed 1 month ago · 0 comments
parameter update
#274 · nickhalmagyi opened 1 month ago · 0 comments
Incorrect pytree recognition by KFAC optimizer
#273 · Uernd opened 1 month ago · 1 comment
Stackless yashful
#272 · copybara-service[bot] opened 1 month ago · 0 comments
Improving polynomial schedule in the examples codebase so that it works as expected when the initial value is *lower* than the final value.
#271 · copybara-service[bot] closed 2 months ago · 0 comments
Model parameters being marked as orphan when using KFAC optimizer
#270 · Uernd opened 2 months ago · 2 comments
[kfac-jax] Update graph matching test to support the new "algorithm" tuning parameters for dot_general that will be included in the next JAX release.
#269 · copybara-service[bot] closed 2 months ago · 0 comments
- Adding support for the "Schedule-free" method to be used as a wrapper for Optax optimizers in the examples codebase.
#268 · copybara-service[bot] closed 2 months ago · 0 comments
- Passing stats to _post_param_update_processing in examples code.
#267 · copybara-service[bot] closed 2 months ago · 0 comments
Move optax interface into the main kfac codebase.
#266 · copybara-service[bot] closed 2 months ago · 0 comments
Pass the state to the update_polyak function
#265 · copybara-service[bot] closed 2 months ago · 0 comments
Separated the `optimizers` module in kfac examples into separate modules
#264 · copybara-service[bot] closed 2 months ago · 0 comments
Adding the repeated dense graph patterns.
#263 · copybara-service[bot] closed 2 months ago · 0 comments
Split `curvature_blocks.py` module into a package.
#262 · copybara-service[bot] closed 3 months ago · 0 comments
Split `curvature_estimator.py` module into a package.
#261 · copybara-service[bot] closed 3 months ago · 0 comments
Fix progress off by one.
#260 · copybara-service[bot] closed 3 months ago · 0 comments
Add an option to specify a different value function for the preconditioner's curvature estimator.
#259 · copybara-service[bot] closed 3 months ago · 0 comments
Add TNT blocks to kfac_jax.
#258 · copybara-service[bot] closed 3 months ago · 0 comments
Remove the `TwoKroneckerFactored` class and use the `KroneckerFactored` class instead.
#257 · copybara-service[bot] closed 3 months ago · 0 comments
ImportError: cannot import name 'psd_inv_cholesky' from 'kfac_jax._src.utils'
#256 · eul8 opened 3 months ago · 1 comment
Minor non-functional change.
#255 · copybara-service[bot] closed 3 months ago · 0 comments
- Adding handling of jitted functions to graph scanner.
#254 · copybara-service[bot] closed 3 months ago · 0 comments