stanford-crfm/haliax
Named Tensors for Legible Deep Learning in JAX
Apache License 2.0
153 stars, 11 forks
Issues (newest first)
#111 re-fix the sharded to cpu (backport from levanter) [dlwh, closed 1 week ago, 0 comments]
#110 Rename overlapping_axes -> intersect_axes & make overlapping_axes instead the intersection of names [cooljoseph1, closed 2 months ago, 0 comments]
#109 Changed Linear to rename In and Out to avoid conflicting names (fixes #53) [cooljoseph1, opened 2 months ago, 2 comments]
#108 `overlapping_axes` should be renamed to `intersect_axes` [cooljoseph1, closed 2 months ago, 2 comments]
#107 AxisSpec should allow nested sequences--and all functions that consume an AxisSpec should support nested sequences. [cooljoseph1, closed 2 months ago, 2 comments]
#106 Cleaner All Pairs Difference [0xc1c4da, opened 2 months ago, 3 comments]
#105 make out_first = True the default for mlp [dlwh, closed 2 months ago, 0 comments]
#104 Proposal: Make haliax.nn modules have an __init__ function which can be used to construct a skeleton tree [cooljoseph1, closed 3 months ago, 2 comments]
#103 Add a tutorial about Stacked/scan layers [dlwh, opened 4 months ago, 0 comments]
#102 fix import for newest JAX [dlwh, closed 4 months ago, 0 comments]
#101 add reduce_loss [dlwh, closed 4 months ago, 0 comments]
#100 add scan_aware_tree_map [dlwh, closed 4 months ago, 0 comments]
#99 Make a `scan_aware_tree_map` [dlwh, closed 4 months ago, 1 comment]
#98 narrow range where we use jit for sharding [dlwh, closed 4 months ago, 0 comments]
#97 Use Generics [thomasahle, opened 4 months ago, 1 comment]
#96 Another attempt at shard [dlwh, closed 5 months ago, 0 comments]
#95 Run levanter tests [dlwh, closed 5 months ago, 0 comments]
#94 No default in_resources [dlwh, closed 5 months ago, 1 comment]
#93 Use sqrt(fan_in) as default in and truncated_normal [dlwh, closed 5 months ago, 0 comments]
#92 scan/fold should not consider the `init` args for scanning over [rjpower, closed 5 months ago, 5 comments]
#91 Expose the filter_checkpoint helper: it's useful for custom training loops. [rjpower, closed 5 months ago, 0 comments]
#90 Move StateDict to Haliax, clean it up a lot. [dlwh, closed 2 weeks ago, 0 comments]
#89 revert previous PR [blahBlahhhJ, closed 6 months ago, 0 comments]
#88 Fix pspec method for levanter [blahBlahhhJ, closed 6 months ago, 0 comments]
#87 Add replica axes [blahBlahhhJ, closed 6 months ago, 1 comment]
#86 fix hax.where in the presence of scalar NamedArrays [dlwh, closed 7 months ago, 0 comments]
#85 add in-place mutation ops for NamedArrays, support lists/1-d JAX Arrays for indexing [dlwh, closed 7 months ago, 0 comments]
#84 allow einsum to accept aliases similar to rearrange [dlwh, closed 7 months ago, 0 comments]
#83 switch to pdm: [dlwh, closed 7 months ago, 0 comments]
#82 fix revision number calculation for auto-publishing [dlwh, closed 7 months ago, 0 comments]
#81 Poetry + Dev Builds [dlwh, closed 7 months ago, 0 comments]
#80 Fix multi device shard [dlwh, closed 8 months ago, 0 comments]
#79 Simplify sharding [dlwh, closed 8 months ago, 0 comments]
#78 Make Conv work when input type is not the same as kernel type [dlwh, opened 8 months ago, 0 comments]
#77 Add FP8 support to Haliax [dlwh, closed 8 months ago, 0 comments]
#76 No Scan Layers [dlwh, closed 8 months ago, 0 comments]
#75 jax.random.KeyArray deprecated [jennifgcrl, closed 7 months ago, 2 comments]
#74 workaround current limitation in jax_metals [dlwh, closed 8 months ago, 0 comments]
#73 explicitly shard scalars in named_jit [dlwh, closed 8 months ago, 0 comments]
#72 Support for partitioning/sharded data with Pallas kernels? [G-Levine, opened 8 months ago, 10 comments]
#71 tile and repeat [dlwh, closed 9 months ago, 0 comments]
#70 add a non-scan-layers version of Stacked [dlwh, closed 8 months ago, 0 comments]
#69 Add single argument mode for hax.where [blahBlahhhJ, closed 9 months ago, 0 comments]
#68 Don't check shapes in tree_unflatten. Fixes #66 [dlwh, closed 9 months ago, 0 comments]
#67 Make a helper function to squash/unsquash all axes (except some) into a single batch axis [dlwh, closed 9 months ago, 1 comment]
#66 named arrays and `eqxi.while_loop(mode=checkpointed)` do not get along [dlwh, closed 9 months ago, 0 comments]
#65 Mixed Precision and Resource Envs [dlwh, opened 9 months ago, 0 comments]
#64 Lp norm function? [rohan-mehta-1024, opened 9 months ago, 1 comment]
#63 Einsum [dlwh, closed 9 months ago, 0 comments]
#62 auto_shard inside zeros etc [dlwh, opened 10 months ago, 0 comments]