cooljoseph1 closed this 2 months ago.
On further thought, I think a combination of `ensure_tuple()` and `haliax.concat_axis_specs` is enough.
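To make the workaround concrete, here is a minimal sketch of what "combine two specs with `ensure_tuple()` and concatenation" amounts to. It does not import haliax: `Axis`, `ensure_tuple`, and `concat_specs` below are local, hypothetical stand-ins written for illustration, and the real `haliax.concat_axis_specs` may behave differently (for example, it may also check for duplicate axes).

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass(frozen=True)
class Axis:
    # Hypothetical stand-in for haliax.Axis: a named axis with a size.
    name: str
    size: int


# The current (flat) definition discussed in this thread.
AxisSpec = Union[Axis, Sequence[Axis]]


def ensure_tuple(spec: AxisSpec) -> Tuple[Axis, ...]:
    # Local stand-in for ensure_tuple(): wrap a single Axis in a 1-tuple,
    # otherwise convert the sequence of axes to a tuple.
    if isinstance(spec, Axis):
        return (spec,)
    return tuple(spec)


def concat_specs(a: AxisSpec, b: AxisSpec) -> Tuple[Axis, ...]:
    # Hypothetical helper approximating how concat_axis_specs is used here:
    # join two specs into one flat tuple, whether each is one Axis or several.
    return ensure_tuple(a) + ensure_tuple(b)


Batch = Axis("batch", 8)
In = Axis("in", 16)
Out = Axis("out", 32)

combined = concat_specs(Batch, (In, Out))
print(combined)
```

The caller never has to check whether each side is a single axis or a tuple; the cost is that the combination must go through a helper rather than a plain tuple literal.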
It's an interesting thought, though! It hadn't occurred to me. I agree that `concat_axis_specs` is enough, but I also think things could be nicer.
Right now, `AxisSpec` is defined as `AxisSpec = Axis | Sequence[Axis]`. I think it should be changed to `AxisSpec = Axis | Sequence[AxisSpec]`. The reason is that, morally, you should be able to build an axis spec by pairing two other axis specs without worrying about whether each one is a single axis or a tuple of axes.

The current definition caused problems for me when I was trying to make a linear layer that outputs a linear layer when given input.
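As a rough illustration of the proposal, the sketch below defines a recursive `AxisSpec` and a flattening function that turns any nested spec back into a flat tuple of axes. It is self-contained and does not use haliax: `Axis` is a hypothetical stand-in for `haliax.Axis`, `flatten_axis_spec` is a hypothetical helper, and the layer-producing-a-layer scenario is modeled only by its axis specs.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass(frozen=True)
class Axis:
    # Hypothetical stand-in for haliax.Axis.
    name: str
    size: int


# Proposed recursive definition: a spec is an Axis or a sequence of specs.
AxisSpec = Union[Axis, Sequence["AxisSpec"]]


def flatten_axis_spec(spec: AxisSpec) -> Tuple[Axis, ...]:
    # Hypothetical helper: recursively flatten a nested spec into a tuple
    # of Axis objects, preserving left-to-right order.
    if isinstance(spec, Axis):
        return (spec,)
    out: Tuple[Axis, ...] = ()
    for sub in spec:
        out += flatten_axis_spec(sub)
    return out


In = Axis("in", 16)
Out = Axis("out", 32)
Hidden = Axis("hidden", 64)

# Hypothetical scenario: an outer layer whose output is the weight of an
# inner linear layer. The inner weight spec is itself a tuple, so under the
# recursive definition it can be nested directly into the outer spec.
inner_weight_spec = (In, Out)
outer_weight_spec = (Hidden, inner_weight_spec)

print(flatten_axis_spec(outer_weight_spec))
```

With the recursive type, library code would flatten once at the boundary, and users could nest specs freely instead of calling a concatenation helper at every combination point.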
I am aware of `haliax.concat_axis_specs`, which covers pretty much all the cases where nested axis specs would be useful, but it feels like a bit of a crutch.