stanford-crfm / haliax

Named Tensors for Legible Deep Learning in JAX
Apache License 2.0

AxisSpec should allow nested sequences--and all functions that consume an AxisSpec should support nested sequences. #107

Closed: cooljoseph1 closed this 2 months ago

cooljoseph1 commented 2 months ago

Right now AxisSpec is defined as AxisSpec = Axis | Sequence[Axis]. I think it should be changed to the recursive definition AxisSpec = Axis | Sequence[AxisSpec]. The reason is that (morally) you should be able to build an axis spec by pairing two other axis specs, without worrying about whether each one is a single axis or a tuple of axes.
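To make the proposal concrete, here is a minimal sketch of what the recursive definition could look like and how a consumer might flatten it. It does not use haliax itself: Axis below is a hypothetical stand-in for haliax.Axis, and NestedAxisSpec and flatten_axis_spec are illustrative names, not part of the library.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass(frozen=True)
class Axis:
    """Hypothetical stand-in for haliax.Axis: a name plus a size."""
    name: str
    size: int


# Proposed recursive definition (illustrative name, not the haliax one).
NestedAxisSpec = Union[Axis, Sequence["NestedAxisSpec"]]


def flatten_axis_spec(spec: NestedAxisSpec) -> Tuple[Axis, ...]:
    """Flatten an arbitrarily nested axis spec into a flat tuple of axes."""
    if isinstance(spec, Axis):
        return (spec,)
    flat = []
    for sub in spec:
        # Each element may itself be an Axis or another nested sequence.
        flat.extend(flatten_axis_spec(sub))
    return tuple(flat)


Batch = Axis("batch", 8)
In = Axis("in", 16)
Out = Axis("out", 32)

# Pairing two specs works whether each side is an axis or a tuple of axes.
pair = (Batch, (In, Out))
flat = flatten_axis_spec(pair)
```

Under this sketch, every function that accepts an axis spec would call something like flatten_axis_spec at its boundary, which is the cost of the proposal: each consumer has to handle the nested case.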

The current definition caused problems for me when I was trying to write a linear layer that, given an input, outputs another linear layer.

I am aware that there is haliax.concat_axis_specs which works for pretty much all cases where having nested axis specs would be useful, but it feels like a bit of a crutch.

cooljoseph1 commented 2 months ago

On further thought, I think a combination of ensure_tuple() and haliax.concat_axis_specs is enough.
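A rough sketch of that combination, for readers following along. The helpers below are simplified stand-ins written for this example (the _sketch suffix marks them as hypothetical); they only mimic the behavior described in this thread, and the real ensure_tuple and haliax.concat_axis_specs may differ, for instance in how they validate their arguments.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass(frozen=True)
class Axis:
    """Hypothetical stand-in for haliax.Axis: a name plus a size."""
    name: str
    size: int


# The current, non-recursive definition quoted in the issue.
AxisSpec = Union[Axis, Sequence[Axis]]


def ensure_tuple_sketch(spec: AxisSpec) -> Tuple[Axis, ...]:
    """Stand-in: wrap a single axis in a tuple, pass sequences through."""
    if isinstance(spec, Axis):
        return (spec,)
    return tuple(spec)


def concat_axis_specs_sketch(a: AxisSpec, b: AxisSpec) -> Tuple[Axis, ...]:
    """Stand-in: join two flat axis specs into one flat axis spec."""
    return ensure_tuple_sketch(a) + ensure_tuple_sketch(b)


In = Axis("in", 16)
Hidden = Axis("hidden", 64)
Out = Axis("out", 32)

# A layer whose output is itself a layer: the outer layer's output spec is
# the inner layer's (input, output) axes joined, with no nesting required.
inner_spec = concat_axis_specs_sketch(Hidden, Out)
outer_out_spec = concat_axis_specs_sketch(In, inner_spec)
```

The result stays flat at every step, so the existing non-recursive AxisSpec type is sufficient as long as specs are joined with the concat helper rather than by building a tuple of tuples.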

dlwh commented 2 months ago

It's an interesting thought though! It hadn't occurred to me. I think I agree that concat_axis_specs is enough, but also that things could be nicer.