ml-explore / mlx-examples

Examples in the MLX framework
MIT License

[BUG] generate hangs after multiple iterations #1019

Closed: hschaeufler closed this issue 1 month ago

hschaeufler commented 1 month ago

Describe the bug

I use MLX_LM to generate tests for different classes (entries in a data frame) with a model fine-tuned using MLX_LM. Depending on the model, the generate_test step hangs after a certain number of entries or on certain entries: the generate method does not return an answer even after several hours. It looks as if MLX keeps generating text endlessly and never reaches an end token.

Does anyone have an idea how I can avoid this problem, or is it possible to set a timeout? Could it also be caused by an internal cache filling up?
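On the timeout question: as far as I know, mlx_lm's generate() has no timeout option, but a wall-clock budget can be sketched by consuming a token stream yourself and stopping once a deadline passes. collect_with_timeout and GenerationTimeout below are hypothetical names, not part of mlx_lm. The helper accepts any iterable of chunks that are either plain strings or objects with a .text attribute; the assumption is that mlx_lm's stream_generate yields one of those (plain text segments around version 0.19, response objects with .text in later versions).

```python
import time
from typing import Callable, Iterable


class GenerationTimeout(Exception):
    """Hypothetical exception: generation exceeded its wall-clock budget."""

    def __init__(self, partial_text: str):
        super().__init__("generation timed out")
        self.partial_text = partial_text  # text produced before the deadline


def collect_with_timeout(
    chunks: Iterable,
    timeout_s: float,
    clock: Callable[[], float] = time.monotonic,
) -> str:
    """Concatenate streamed chunks, raising GenerationTimeout after timeout_s seconds.

    Each chunk may be a str or an object with a .text attribute (assumption
    about the stream's element type). The deadline is only checked between
    chunks, so a single step that never returns is not interrupted.
    """
    deadline = clock() + timeout_s
    parts = []
    for chunk in chunks:
        # str has no .text attribute, so getattr falls back to the chunk itself
        parts.append(getattr(chunk, "text", chunk))
        if clock() > deadline:
            raise GenerationTimeout("".join(parts))
    return "".join(parts)
```

Used with something like collect_with_timeout(stream_generate(model, tokenizer, prompt, max_tokens=35000, **generation_args), timeout_s=600) inside generate_test, a runaway row would raise GenerationTimeout carrying the partial text instead of blocking the whole progress_apply loop; catching it and returning e.partial_text (or None) lets the remaining rows continue.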

To Reproduce

Include code snippet

from mlx_lm import load, generate
import pandas as pd
from transformers import GenerationConfig
from tqdm import tqdm

tqdm.pandas()
data_frame = pd.read_csv("validation/test_set.csv")

model_path = "results/llama3_1_8B_instruct_lora/tuning_03/lora_fused_model"
model, tokenizer = load(model_path)

generation_config = GenerationConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
generation_args = {
    "temp": generation_config.temperature,
    "top_p": generation_config.top_p,
    "min_p": generation_config.min_p,
}
print(generation_args)

print(data_frame.columns)

def generate_test(code, code_file_name, code_type, test_type) -> str:
    print(f"Generating test for {code_file_name}")
    # Single quotes inside the f-string keep this valid on Python < 3.12 (PEP 701)
    prompt = (f"Generate {test_type.replace('-', ' ').lower()}s in Dart for the following {code_type.lower()}.\n"
              f"### Code:\n"
              f"{code}\n"
              f"### Test:"
              )
    print(f"--------------------------------------\n{prompt}\n--------------------------------------")
    response = generate(
        model,
        tokenizer,
        prompt=prompt,
        verbose=False,
        max_tokens=128000,
        **generation_args
    )
    return response

data_frame["test_gen"] = data_frame.progress_apply(lambda row: generate_test(
    code=row["code"],
    code_file_name=row["code_file_name"],
    code_type=row["code_type"],
    test_type=row["test_type"]
), axis=1)
data_frame

I had to redact the output because some of it contains sensitive data, but you can see that nothing happens for about two and a half hours (elapsed time 1:45:40 vs. 4:13:18 at 149/258) before I cancelled the processing.

58%|█████▊    | 149/258 [1:45:40<51:36, 28.41s/it]

Generating test for xyz.dart
--------------------------------------
Generate widget tests in Dart for the following widget.
### Code:
...
### Test:
--------------------------------------

58%|█████▊    | 149/258 [4:13:18<3:05:18, 102.00s/it]

If I generate a test only for this class (xyz.dart), I get a result:

from mlx_lm import load, generate
import pandas as pd
from transformers import GenerationConfig
from tqdm import tqdm

tqdm.pandas()
data_frame = pd.read_csv("validation/test_set.csv")

model_path = "results/llama3_1_8B_instruct_lora/tuning_03/lora_fused_model"
model, tokenizer = load(model_path)

generation_config = GenerationConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
generation_args = {
    "temp": generation_config.temperature,
    "top_p": generation_config.top_p,
    "min_p": generation_config.min_p,
}
print(generation_args)

print(data_frame.columns)

def generate_test(code, code_file_name, code_type, test_type) -> str:
    print(f"Generating test for {code_file_name}")
    # Single quotes inside the f-string keep this valid on Python < 3.12 (PEP 701)
    prompt = (f"Generate {test_type.replace('-', ' ').lower()}s in Dart for the following {code_type.lower()}.\n"
              f"### Code:\n"
              f"{code}\n"
              f"### Test:"
              )
    print(f"--------------------------------------\n{prompt}\n--------------------------------------")
    response = generate(
        model,
        tokenizer,
        prompt=prompt,
        verbose=False,
        max_tokens=128000,
        **generation_args
    )
    return response

data_frame["test_gen"] = data_frame[data_frame["code_file_name"] == "xyz.dart"].progress_apply(lambda row: generate_test(
    code=row["code"],
    code_file_name=row["code_file_name"],
    code_type=row["code_type"],
    test_type=row["test_type"]
), axis=1)
data_frame
 0%|          | 0/1 [00:00<?, ?it/s]

Generating test for xyz.dart
--------------------------------------
Generate widget tests in Dart for the following widget.
### Code:
...
### Test:
--------------------------------------

100%|██████████| 1/1 [01:12<00:00, 72.78s/it]

Expected behavior

I would expect a response after a few minutes or, if generation does not finish within a reasonable time, a timeout error.

Desktop

OS version: macOS 14.16.1
Version: 0.19.0

Additional context

The fused model was fine-tuned with a previous MLX version and fused with the current one. I don't use a prompt template because I just want the code back, as in the LoRA fine-tuning set, without explanations.
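On the prompt-template question, one hedged option is to store the training pairs in chat format so the tokenizer's chat template is applied consistently during training and generation. This sketch assumes mlx_lm's LoRA data loader accepts JSONL lines of the form {"messages": [...]} (alongside {"prompt": ..., "completion": ...} and {"text": ...}); to_chat_record is a hypothetical helper. The assistant turn holds only the test code, so the model would still be trained to answer without explanations.

```python
import json


def to_chat_record(code: str, code_type: str, test_type: str, test_code: str) -> str:
    """Hypothetical helper: build one chat-format JSONL line for LoRA fine-tuning.

    The user turn reuses the issue's prompt wording; the assistant turn
    contains only the expected Dart test code.
    """
    user_prompt = (
        f"Generate {test_type.replace('-', ' ').lower()}s in Dart "
        f"for the following {code_type.lower()}.\n"
        f"### Code:\n{code}\n### Test:"
    )
    record = {
        "messages": [
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": test_code},
        ]
    }
    # json.dumps escapes newlines, so each record stays on a single JSONL line
    return json.dumps(record, ensure_ascii=False)
```

At inference time the prompt would then be built the same way, for example with tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True), so generation sees the same format the model was trained on.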

hschaeufler commented 1 month ago

Update: I have now set temp to 0 to get reproducible results and enabled verbose output. It seems that in a few cases the fine-tuned model repeats specific sentences forever. Does anyone have an idea why this might happen or how I can prevent it?

Is this caused by the LoRA fine-tuning? Should I also use the chat template for my training data set? Or is it because I am LoRA fine-tuning all attention and MLP projection layers ('self_attn.q_proj', 'self_attn.v_proj', 'self_attn.k_proj', 'self_attn.o_proj', 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj')?

    expect(
      find.byWidgetPredicate(
        (widget) =>
            widget is OutlinedButton &&
            widget.onPressed == null &&
            (widget.style!.side! as BorderSide).color ==
                Theme.of(find.byType(XYZButton)).disabledColor,
      ),
      findsOneWidget,
    );
    expect(
      find.byWidgetPredicate(
        (widget) =>
            widget is OutlinedButton &&
            widget.onPressed == null &&
            (widget.style!.side! as BorderSide).color ==
                Theme.of(find.byType(XYZButton)).disabledColor,
      ),
      findsOneWidget,
    );
    expect(
      find.byWidgetPredicate(
        (widget) =>
            widget is OutlinedButton &&
            widget.onPressed == null &&
            (widget.style!.side! as BorderSide).color ==
                Theme.of(find.byType(XYZButton)).disabledColor,
      ),
      findsOneWidget,
    );
    expect(
      find.byWidgetPredicate(
        (widget) =>
            widget is OutlinedButton &&
            widget.onPressed == null &&
            (widget.style!.side! as BorderSide).color ==
                Theme.of(find.byType(XYZButton)).disabledColor,
      ),
      findsOneWidget,
    );
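The output above ends in the same expect(...) block repeated verbatim. One extra safeguard is to check the streamed text for a tail made of one block repeated several times and stop early. ends_in_loop is a hypothetical character-level heuristic, not an mlx_lm feature; min_period, max_period and repeats are assumed thresholds that would need tuning, because legitimately similar test code can also trigger it.

```python
def ends_in_loop(text: str, min_period: int = 20, max_period: int = 2000, repeats: int = 3) -> bool:
    """Return True if text ends with the same block of characters repeated `repeats` times.

    Tries every block length from min_period up to max_period (capped so
    that `repeats` copies fit into text). Hypothetical heuristic.
    """
    longest = min(max_period, len(text) // repeats)
    for period in range(min_period, longest + 1):
        block = text[-period:]  # candidate repeating unit at the end of the text
        if text.endswith(block * repeats):
            return True
    return False
```

Each call scans up to max_period block lengths, so it is cheaper to run it every few hundred characters of accumulated output than after every token.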

hschaeufler commented 1 month ago

I have now set repetition_penalty to 1.1 and max_tokens to 35000. In initial tests, the endless repetitions no longer occur. I'll run it on the full dataset tonight.
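For context on why repetition_penalty helps: the sketch below shows a CTRL-style repetition penalty, which, as far as I know, is how mlx_lm applies repetition_penalty to the logits of tokens seen in the last repetition_context_size generated tokens (20 by default). penalize_recent_tokens is a hypothetical, list-based illustration, not mlx_lm's code.

```python
from typing import List, Sequence


def penalize_recent_tokens(
    logits: List[float],
    generated: Sequence[int],
    penalty: float = 1.1,
    context_size: int = 20,
) -> List[float]:
    """Return a copy of logits with a CTRL-style repetition penalty applied.

    Every token id among the last `context_size` generated tokens is made
    less likely: positive logits are divided by `penalty`, negative logits
    are multiplied by it. Each id is penalized once, however often it occurs.
    """
    out = list(logits)
    window = list(generated)[-context_size:] if context_size else list(generated)
    for tok in set(window):
        out[tok] = out[tok] * penalty if out[tok] < 0 else out[tok] / penalty
    return out
```

With penalty=1.1 the per-step effect is mild, and because only a short window of recent tokens is penalized, it discourages but does not strictly forbid a long repeated block; a finite max_tokens therefore remains a useful upper bound on run time.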

hschaeufler commented 1 month ago

The run completed with these settings, so I am closing the ticket. If anyone has tips on how to prevent this during LoRA fine-tuning, I'd be happy to hear suggestions.