ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Add chronos model(s) to mlx-lm #587

Open sugatoray opened 3 months ago

sugatoray commented 3 months ago

I would like to add the Chronos model(s) to mlx-lm. Looking for feedback or suggestions from the maintainers.

sugatoray commented 3 months ago

Please assign this one to me. I have already started working on it and will send a draft PR soon to collaborate and get suggestions.

cc: @awni @mzbac

awni commented 3 months ago

Cool! Although I'm wondering how well encoder/decoder style models will fit into MLX LM. We have a T5 example you can use as a reference.

If it doesn't add too much complexity I would support allowing T5 style models in MLX LM, but otherwise it might make sense to have an alternative package or repo for such things.

sugatoray commented 3 months ago

@awni Is there any mlx.nn equivalent of torch.nansum (if np.nansum is to be avoided)?

awni commented 3 months ago

We don't have such an operation, sorry! You could do something like:

import mlx.core as mx

def nansum(x):
  return mx.sum(mx.where(mx.isnan(x), 0, x))

sugatoray commented 3 months ago

Thanks, I looked at the torch.nansum implementation as well.

https://github.com/pytorch/pytorch/blob/014f91a9d9f94ac9a7f0711600240d7cd7f69844/torch/_decomp/decompositions.py#L4277-L4278

def nansum(x: mx.array, axis: int=-1):
    return mx.sum(mx.where(mx.isnan(x), 0, x), axis=axis)

mx.nansum = nansum

@awni Can we add this as a function to mlx.nn? Would you suggest some edits in that case?

Test

import numpy as np
import torch
import mlx.core as mx

testnp = np.array([1, 2., 0., np.nan, -3.5])
testmx = mx.array(testnp)
testch = torch.tensor(testnp)

torch.nansum(testch) == torch.sum(torch.where(torch.isnan(testch), 0, testch), dim=-1) # True
torch.nansum(testch) # tensor(-0.5000, dtype=torch.float64)

mx.nansum(testmx) # array(-0.5, dtype=float32)

sugatoray commented 3 months ago

@awni What should I use as a substitute for torch.Tensor.median()?