ml-explore / mlx-examples

Examples in the MLX framework

Instruct tuning for lora/finetune? #484

Open fblissjr opened 9 months ago

fblissjr commented 9 months ago

Please correct me if I'm wrong, but it looks like the current examples for LoRA training all build the loss function around completion, which lines up with the lora example using only the 'text' field from the jsonl dataset.

Are there any forks or plans to allow for instruct tuning, where the input is a prompt and the loss function targets the input/output pair?

Or did I miss something?

Thanks!

edit: example below:

{ "prompt": "[INST] Your input prompt here[/INST]", "text": "The expected output result here" }

Whereas it looks like the current lora process is: { "text": "Predict what comes [next]" }
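
For concreteness, here is a purely illustrative sketch of what each record shape implies for training; the field names follow the examples above, and none of this is actual mlx_lm code:

# Illustrative only: what each record shape implies for training.
completion_record = {"text": "Predict what comes [next]"}
instruct_record = {
    "prompt": "[INST] Your input prompt here [/INST]",
    "text": "The expected output result here",
}

# Completion-style tuning: the loss is computed over every token of "text".
completion_target = completion_record["text"]

# Instruct-style tuning: the model is fed prompt + output, but the loss
# would only count the output tokens (the prompt positions get masked).
instruct_input = instruct_record["prompt"] + " " + instruct_record["text"]
instruct_target = instruct_record["text"]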

Solido commented 9 months ago

This needs confirmation, but the input inside 'text' can be multiline. My understanding is that the line break separates the input from the completion.

fblissjr commented 9 months ago

I think it's still just going to optimize for completion of the full text field and doesn't differentiate between the input and output? At least based on the lora code in mlx_lm.

from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/lora.py

import json
from pathlib import Path


class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path, key: str = "text"):
        if not path.exists():
            self._data = None
        else:
            with open(path, "r") as fid:
                self._data = [json.loads(l) for l in fid]
        self._key = key

    def __getitem__(self, idx: int):
        # Only the single `key` field (default "text") is ever returned, so the
        # trainer sees one undifferentiated string per example.
        return self._data[idx][self._key]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)
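
For illustration only, a minimal sketch of how that wrapper could be extended for prompt/output records; the InstructDataset name and the "prompt"/"text" keys are assumptions taken from the example records above, not anything that exists in mlx_lm:

import json
from pathlib import Path


class InstructDataset:
    """
    Hypothetical variant of the wrapper above that keeps the prompt and the
    completion separate so a loss function can mask the prompt tokens.
    """

    def __init__(self, path: Path, prompt_key: str = "prompt", text_key: str = "text"):
        if not path.exists():
            self._data = None
        else:
            with open(path, "r") as fid:
                self._data = [json.loads(line) for line in fid]
        self._prompt_key = prompt_key
        self._text_key = text_key

    def __getitem__(self, idx: int):
        # Return the pair; the trainer decides how to concatenate and mask.
        record = self._data[idx]
        return record[self._prompt_key], record[self._text_key]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)
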
fblissjr commented 9 months ago

Looking at the mlx_lm code, I think we need to adjust the Dataset class to account for various dataset types (instruct, chat, etc.). That starts to turn into an axolotl-type project eventually, but for simplicity it could probably just allow passing custom dataset types.

Then I think the tuner.trainer needs to be modified as well, either changing the default loss function or adding a new one for instruct templates.

Let me know if I'm wrong here, but from what I'm seeing, the only training function available right now is for completions.
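
As a rough illustration (not mlx_lm's actual trainer code), an instruct-style loss could mask the prompt positions in addition to the padding. The (inputs, targets, lengths, prompt_lengths) batch layout is an assumption about how the batch iterator would have to be extended:

import mlx.core as mx
import mlx.nn as nn


def instruct_loss(model, inputs, targets, lengths, prompt_lengths):
    # inputs/targets: [batch, seq]; lengths: total tokens per example;
    # prompt_lengths: number of prompt tokens per example to exclude.
    logits = model(inputs).astype(mx.float32)

    positions = mx.arange(inputs.shape[1])[None, :]
    in_sequence = positions < lengths[:, None]          # drop padding
    past_prompt = positions >= prompt_lengths[:, None]  # drop prompt tokens
    mask = mx.logical_and(in_sequence, past_prompt).astype(mx.float32)

    ce = nn.losses.cross_entropy(logits, targets) * mask
    ntoks = mask.sum()
    return ce.sum() / ntoks, ntoks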

fblissjr commented 9 months ago

I believe this fork handles it correctly: https://github.com/chimezie/mlx-tuning-fork/blob/main/src/mlx_tuning_fork/training.py (https://github.com/ml-explore/mlx-examples/pull/235)

edit: saw this PR (https://github.com/ml-explore/mlx-examples/pull/213); it looks like the goal is to keep the lora code purely as an example, but I do think it may cause confusion for folks trying to do SFT on instruct- or chat-style datasets. Maybe just an edit to LORA.md?

awni commented 9 months ago

We have a more fully featured version of LoRA in mlx-lm: https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md

Assuming it doesn't add much code complexity, I think it would be cool to update it to support alternative losses / styles of training. Depending on how niche and/or complex the approach is, it may also make sense to do it as a standalone package. I'll mark this issue as an enhancement for now.

chimezie commented 8 months ago

The recent changes that allow the loss and iterate_batches functions to be specified for the tuning process have made this a lot more straightforward. I have done this in mlx-tuning-fork, a happily shrinking thin layer over mlx_lm.lora. I can create a PR specifically for instruction tuning with (optional) masking of the input in the loss calculation.

However, depending on how this particular kind of tuning is specified in configs/options, I don't know how niche that would be.
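
A very rough sketch of what wiring that in might look like; only the loss / iterate_batches hooks are taken from the description above, while the remaining argument names and order are assumptions about the trainer's signature, and instruct_loss / iterate_instruct_batches are placeholder names:

# Hypothetical wiring: pass the custom functions through the hooks described
# above. The positional arguments are assumptions, not checked against the
# current mlx_lm; instruct_loss and iterate_instruct_batches are placeholders.
from mlx_lm.tuner import trainer

trainer.train(
    model,
    tokenizer,
    optimizer,
    train_set,
    valid_set,
    loss=instruct_loss,                        # e.g. a prompt-masking loss
    iterate_batches=iterate_instruct_batches,  # batcher that also yields prompt lengths
)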

Solido commented 8 months ago

MLX is growing fast, and a community will soon build a lot around it. Anything that can serve as common ground for those projects should be welcome. I'm working exclusively on instruct-style tuning myself and am patiently waiting for more options.

chimezie commented 2 weeks ago

Gentle ping regarding this: #1086. I don't think this approach would be too niche or add too much complexity (depending on how feasible it is to continue to rely on apply_chat_template for handling the prompt formatting while keeping the distinction between where input tokens end and output tokens begin).
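
For what it's worth, one plausible way to keep that distinction, sketched against the Hugging Face tokenizer's apply_chat_template (how mlx_lm would surface this is exactly the open question):

# Sketch: recover the prompt/response boundary while still using the
# tokenizer's chat template. Assumes a Hugging Face tokenizer exposing
# apply_chat_template; the helper name is hypothetical.
def encode_with_prompt_length(tokenizer, prompt, response):
    user_turn = [{"role": "user", "content": prompt}]

    # Token ids for the prompt alone, ending where the assistant turn starts.
    prompt_ids = tokenizer.apply_chat_template(
        user_turn, tokenize=True, add_generation_prompt=True
    )

    # Token ids for the full exchange (prompt plus assistant response).
    full_ids = tokenizer.apply_chat_template(
        user_turn + [{"role": "assistant", "content": response}],
        tokenize=True,
    )

    # Positions before len(prompt_ids) can be masked out of the loss
    # (modulo template quirks such as trailing special tokens).
    return full_ids, len(prompt_ids)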