Open sugatoray opened 3 months ago
Please assign this one to me. I have started working on this one already. Will send a draft PR soon to collaborate / get suggestions.
cc: @awni @mzbac
Cool! Although I'm wondering how well encoder/decoder style models will fit into MLX LM. We have a T5 example you can use as a reference.
If it doesn't add too much complexity I would support allowing T5 style models in MLX LM, but otherwise it might make sense to have an alternative package or repo for such things.
@awni Is there any mlx.nn equivalent of torch.nansum (if np.nansum is to be avoided)?
We don't have such an operation, sorry! You could do something like:
def nansum(x):
    return mx.sum(mx.where(mx.isnan(x), 0, x))
Thanks, I looked at the torch.nansum implementation as well.
def nansum(x: mx.array, axis: int = -1):
    return mx.sum(mx.where(mx.isnan(x), 0, x), axis=axis)

mx.nansum = nansum
@awni Can we add this as a function to mlx.nn? Would you suggest some edits in that case?
import numpy as np
import torch
import mlx.core as mx

testnp = np.array([1, 2., 0., np.nan, -3.5])
testmx = mx.array(testnp)
testch = torch.tensor(testnp)
torch.nansum(testch) == torch.sum(torch.where(torch.isnan(testch), 0, testch), dim=-1)  # tensor(True)
torch.nansum(testch)  # tensor(-0.5000, dtype=torch.float64)
mx.nansum(testmx)  # array(-0.5, dtype=float32)
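One possible edit, offered as a sketch rather than a proposed MLX API: torch.nansum with no dim argument reduces over every element, whereas a default of axis=-1 reduces only the last axis, so the two disagree on inputs with more than one dimension. The snippet below uses NumPy so it runs anywhere; the axis=None default is an assumption about the desired behavior, and the where / isnan / sum pattern is the same one used with mx in the thread.

```python
import numpy as np

def nansum(x, axis=None):
    """Sum treating NaN as zero; axis=None reduces over every element."""
    return np.sum(np.where(np.isnan(x), 0, x), axis=axis)

x = np.array([[1.0, np.nan],
              [2.0, 3.0]])

print(nansum(x))           # single value over all elements
print(nansum(x, axis=-1))  # one sum per row
```

On a 1-D input such as the test array above, both defaults give the same result, which is why the comparison there does not expose the difference.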
@awni What should I use as a substitute for torch.tensor.median() on an mx.array? Would np.median(x, axis=1) work as a placeholder for torch.median(x, dim=-1)?
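A caveat on that placeholder, with a minimal sketch: for an even number of elements, np.median averages the two middle values, while torch.median returns the lower of the two (and, when dim is given, returns both values and indices). The helper below, lower_median, is a hypothetical name and is written in NumPy for illustration; it assumes the input contains no NaNs. A sort followed by an index would presumably port to MLX, but that is an assumption to verify against the MLX docs.

```python
import numpy as np

def lower_median(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Median along `axis`, picking the lower middle value for even lengths."""
    n = x.shape[axis]
    sorted_x = np.sort(x, axis=axis)
    # Index (n - 1) // 2 is the middle for odd n and the lower middle for even n.
    return np.take(sorted_x, (n - 1) // 2, axis=axis)

x = np.array([[1.0, 4.0, 2.0, 3.0],
              [5.0, 7.0, 6.0, 8.0]])

print(lower_median(x, axis=-1))  # lower middle value of each row
print(np.median(x, axis=-1))     # average of the two middle values of each row
```

For rows of odd length the two functions agree, so the mismatch only shows up on even-length axes.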
I would like to add the chronos model(s). Looking for feedback or suggestions from the maintainers.