fblissjr opened this issue 9 months ago
I need confirmation, but the input inside 'text' is multiline. My understanding is that a line break separates the input from the completion.
I think it still just optimizes for completion of the full 'text' field and doesn't differentiate between input and output, at least based on the LoRA code in mlx-lm.
From https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/lora.py:

```python
import json
from pathlib import Path


class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path, key: str = "text"):
        if not path.exists():
            self._data = None
        else:
            with open(path, "r") as fid:
                self._data = [json.loads(l) for l in fid]
        self._key = key

    def __getitem__(self, idx: int):
        # Only the single `key` field (default "text") is returned, so the
        # trainer never sees a separate prompt/completion split.
        return self._data[idx][self._key]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)
```
Looking at the mlx-lm code, I think the Dataset class needs to be adjusted to account for different dataset types (instruct, chat, etc.). That eventually starts to turn into an axolotl-style project, but for simplicity it would probably be enough to be able to pass custom dataset types.
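As a rough illustration of what "custom dataset types" could mean, here is a minimal sketch of a variant of the Dataset wrapper quoted above that returns the prompt and the completion as a pair instead of a single string. The class name and the key defaults ("prompt" and "text", matching the example record at the end of this issue) are hypothetical and not part of mlx-lm.

```python
import json
from pathlib import Path


class PromptCompletionDataset:
    """
    Hypothetical variant of the jsonl Dataset wrapper that keeps the prompt
    and the completion separate instead of returning one "text" string.
    """

    def __init__(self, path: Path, prompt_key: str = "prompt", completion_key: str = "text"):
        # Missing file -> empty dataset (len() == 0), like the original wrapper.
        if not path.exists():
            self._data = []
        else:
            with open(path, "r") as fid:
                # One JSON object per non-empty line (jsonl).
                self._data = [json.loads(line) for line in fid if line.strip()]
        self._prompt_key = prompt_key
        self._completion_key = completion_key

    def __getitem__(self, idx: int):
        record = self._data[idx]
        # Return the pair so a loss function can tell input tokens from output tokens.
        return record[self._prompt_key], record[self._completion_key]

    def __len__(self):
        return len(self._data)
```

Keeping the two fields apart at load time is what lets a later step know where the input ends; concatenating them into one string (as a plain 'text' field does) loses that boundary.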
Then I think tuner.trainer needs its default loss function modified (or a new loss function added for instruct templates).
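The core change to the loss is small: compute cross-entropy only on completion positions. The sketch below is plain Python rather than MLX, purely to stay dependency-free; the function names and the 0/1 mask convention are assumptions for illustration, and a real trainer would express the same idea with array operations.

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    # Numerically stable log-softmax over one vocabulary distribution.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_cross_entropy(
    logits: Sequence[Sequence[float]],  # one row of vocab logits per position
    targets: Sequence[int],             # next-token id at each position
    mask: Sequence[int],                # 1 = completion token (scored), 0 = prompt/padding
) -> float:
    """Mean negative log-likelihood over positions where mask == 1."""
    total, count = 0.0, 0
    for row, target, keep in zip(logits, targets, mask):
        if keep:
            total -= log_softmax(row)[target]
            count += 1
    # Guard against an example with nothing to score (e.g. all prompt).
    return total / max(count, 1)
```

With a mask of all ones this reduces to the ordinary completion-style loss, so the completion-only behavior is just the special case where nothing is masked.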
Let me know if I'm wrong here, but from what I'm seeing, the only training function available right now is for completions.
I believe this fork handles it correctly: https://github.com/chimezie/mlx-tuning-fork/blob/main/src/mlx_tuning_fork/training.py (https://github.com/ml-explore/mlx-examples/pull/235)
Edit: I saw this PR (https://github.com/ml-explore/mlx-examples/pull/213), and it looks like the goal is to keep the LoRA code purely as an example. I do think that may confuse people trying to do SFT on instruct- or chat-style datasets, though. Maybe just an edit to the LoRA.md?
We have a more full-featured version of LoRA in mlx-lm: https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md
Assuming it doesn’t add much code complexity, I think it would be cool to update it to support alternative losses and styles of training. Depending on how niche and/or complex the approach is, it may also make sense to do it as a standalone package. I’ll mark this issue as an enhancement for now.
The recent changes that allow the loss and iterate_batches functions to be specified for the tuning process have made this a lot more straightforward. I have done this in mlx-tuning-fork, a happily shrinking thin layer over mlx_lm.lora. I can create a PR specifically for instruction tuning with (optional) masking of the input in the loss calculation.
However, depending on how this particular kind of tuning is specified in configs/options, I don't know how niche that would be.
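With a pluggable batching function, the main job on the batching side is to carry the prompt/completion boundary through to the loss. Here is a hypothetical, framework-free sketch of that step; it is not mlx_lm's iterate_batches and does not mirror its signature. It assumes examples arrive already tokenized as (prompt_ids, completion_ids) pairs and that pad_id is a placeholder padding token.

```python
from typing import Sequence


def build_batch(
    pairs: Sequence[tuple[list[int], list[int]]],  # (prompt_ids, completion_ids) per example
    pad_id: int = 0,                                # hypothetical padding token id
):
    """
    Hypothetical batching helper: concatenates prompt and completion tokens,
    shifts them into (inputs, targets) for next-token prediction, pads to the
    longest example, and builds a mask that is 1 only on completion targets.
    """
    sequences = [list(p) + list(c) for p, c in pairs]
    max_len = max(len(s) for s in sequences) - 1  # inputs/targets are one token shorter
    inputs, targets, mask = [], [], []
    for (prompt, _), seq in zip(pairs, sequences):
        x, y = seq[:-1], seq[1:]
        # Target position i predicts seq[i + 1]; it is a completion token
        # once i + 1 >= len(prompt).
        m = [1 if i + 1 >= len(prompt) else 0 for i in range(len(y))]
        pad = max_len - len(x)
        inputs.append(x + [pad_id] * pad)
        targets.append(y + [pad_id] * pad)
        mask.append(m + [0] * pad)  # never score padding
    return inputs, targets, mask
```

Making the input masking optional would then just mean setting the prompt part of the mask to 1 as well, which recovers the plain completion objective.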
MLX is growing fast, and the community will soon build a lot around it. Anything that can serve as common ground for those projects should be welcome. I'm working exclusively on instruct tuning myself and patiently waiting for more options.
Gentle ping regarding this: #1086. I don't think this approach would be too niche or add too much complexity (depending on how feasible it is to keep relying on apply_chat_template for the prompt formatting while preserving the distinction between where the input tokens end and the output tokens begin).
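One common way to keep that distinction while still using a chat template is to render the conversation twice: once without the final assistant turn (with the generation prompt added) and once in full, then treat the length of the first rendering as the prompt boundary. With a Hugging Face tokenizer the calls would presumably be `apply_chat_template(messages[:-1], add_generation_prompt=True)` and `apply_chat_template(messages)`; the sketch below substitutes a hypothetical toy template and whitespace "tokenizer" so it runs on its own.

```python
def toy_apply_chat_template(messages, add_generation_prompt=False):
    """
    Hypothetical stand-in for a tokenizer's chat template: renders the
    [INST] ... [/INST] format from this issue and "tokenizes" on whitespace.
    """
    parts = []
    for msg in messages:
        if msg["role"] == "user":
            parts += ["[INST]", *msg["content"].split(), "[/INST]"]
        else:
            parts += msg["content"].split()
    # In this toy format "[/INST]" already opens the assistant turn, so
    # add_generation_prompt adds nothing; real templates may append a header.
    return parts


def prompt_length(messages, apply_template=toy_apply_chat_template):
    """Number of leading tokens that belong to the prompt (not scored by the loss)."""
    prompt_ids = apply_template(messages[:-1], add_generation_prompt=True)
    full_ids = apply_template(messages)
    # The boundary is only meaningful if the prompt tokens are an exact prefix.
    if full_ids[: len(prompt_ids)] != prompt_ids:
        raise ValueError("chat template is not prefix-stable for this conversation")
    return len(prompt_ids)
```

The prefix check matters because some templates or tokenizers render the prompt slightly differently on its own than inside the full conversation; when that happens, the boundary cannot be recovered this way and has to be handled per template.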
Please correct me if I'm wrong, but it looks like the current LoRA training examples all build the loss function around completion, which lines up with the LoRA example using only the 'text' field from the jsonl dataset.
Are there any forks or plans to allow instruct tuning, where the input is a prompt and the loss function targets the input/output pair?
Or did I miss something?
Thanks!
Edit: example below:
{ "prompt": "[INST] Your input prompt here[/INST]", "text": "The expected output result here" }
Whereas the current LoRA process seems to be: { "text": "Predict what comes [next]" }
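To make the difference between the two record shapes concrete, here is a minimal sketch of which tokens would count toward the loss in each case. The function name is hypothetical and `str.split` stands in for a real tokenizer; the shift into inputs/targets for next-token prediction would happen afterwards.

```python
def loss_mask(record, tokenize=str.split):
    """
    Return (tokens, mask) for one jsonl record, where mask[i] == 1 marks
    tokens in the span the loss is computed over. `tokenize` is a
    hypothetical stand-in for a real tokenizer (whitespace split by default).
    """
    completion = tokenize(record["text"])
    if "prompt" not in record:
        # Completion-only record: every token is trained on.
        return completion, [1] * len(completion)
    prompt = tokenize(record["prompt"])
    # Instruct record: prompt tokens are context only; completion tokens are scored.
    return prompt + completion, [0] * len(prompt) + [1] * len(completion)
```

For the instruct record above, the five whitespace-separated prompt pieces get mask 0 and the five completion words get mask 1; for a text-only record, the whole sequence is scored.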