cooelf / Auto-GUI

Official implementation for "You Only Look at Screens: Multimodal Chain-of-Action Agents" (Findings of ACL 2024)
https://arxiv.org/abs/2309.11436
Apache License 2.0

click accuracy and scroll accuracy #6

Open njucckevin opened 10 months ago

njucckevin commented 10 months ago

Hi, thanks for the good work. I wonder how the click accuracy and scroll accuracy in Section 5.1 are calculated. I cannot find the corresponding code in main.py or action_matching.py. Thanks~

njucckevin commented 10 months ago

There is also a problem with the typed-text accuracy. In main.py, the text accuracy is calculated as: if check_match and (action_1_typed_text in action_2_typed_text or action_2_typed_text in action_1_typed_text): text_correct += 1. Because text_correct is only incremented when check_match is already true, the text accuracy computed this way can never exceed the total action accuracy (as indicated by check_match). However, in Section 5.1 the reported text accuracy is over 90%.

cooelf commented 10 months ago

Hi, please refer to the code below. In Section 5.1, the text accuracy is measured only over typing actions: a prediction counts as correct when the predicted and reference texts match or overlap (see below). It is not calculated by the code in main.py.

# Per-example evaluation (inside the loop over decoded predictions).
# Parse the predicted action string into its fields.
pred = eval("{" + pred + "}")
action_1_touch_yx = eval(pred["touch_point"])
action_1_lift_yx = eval(pred["lift_point"])
action_1_action_type = int(pred["action_type"])
action_1_typed_text = pred["typed_text"].lower()

# Parse the reference action; repair a missing quote before the fifth comma if needed.
try:
    reference = pred_dict["target"]
    lift_id = [i for i, x in enumerate(reference) if x == ","][4] - 1
    lift_punk = reference[lift_id]
    if lift_punk != "'":
        str_list = list(reference)
        str_list.insert(lift_id + 1, "'")
        reference = ''.join(str_list)
    reference = eval("{" + reference + "}")
except:
    print("reference error")
    continue

# Reference (ground-truth) action fields.
action_2_touch_yx = eval(reference["touch_point"])
action_2_lift_yx = eval(reference["lift_point"])
action_2_action_type = int(reference["action_type"])
action_2_typed_text = reference["typed_text"].lower()

# Annotated UI element positions for this step.
annotation_positions = gold_ui[idx]

# Overall action match (logic from action_matching.py).
try:
    check_match = action_matching.check_actions_match(
        action_1_touch_yx,
        action_1_lift_yx,
        action_1_action_type,
        action_2_touch_yx,
        action_2_lift_yx,
        action_2_action_type,
        annotation_positions
    )
except Exception as exc:
    print(idx, action_1_touch_yx, action_1_lift_yx)
    check_match = False
    match_label = "invalid"

episode_acc = 0

if check_match:
    partial_correct += 1
    match_label = 1
else:
    match_label = 0

if action_1_action_type == action_2_action_type:
    type_correct += 1

# Dual-point actions (action_type == 4): is_tap_action separates clicks from scrolls.
if action_2_action_type == 4:
    if is_tap_action(action_2_touch_yx, action_2_lift_yx):
        click_num += 1
        if match_label:
            click_correct += 1
    else:
        scroll_num += 1
        if match_label:
            scroll_correct += 1

# Typing actions (action_type == 3): text is correct if it matches or overlaps the reference.
if action_2_action_type == 3:
    text_num += 1
    if (action_2_typed_text == action_1_typed_text) or (action_1_typed_text in action_2_typed_text) or (action_2_typed_text in action_1_typed_text):
        text_correct += 1
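
The per-category accuracies in Section 5.1 are then derived from these counters after the loop. A minimal sketch (assumed, not taken verbatim from the evaluation script), using the counter names above:

# Sketch (assumed): per-category accuracies from the counters accumulated above.
click_acc = 100 * click_correct / click_num if click_num else 0.0
scroll_acc = 100 * scroll_correct / scroll_num if scroll_num else 0.0
text_acc = 100 * text_correct / text_num if text_num else 0.0
print(f"click acc: {click_acc:.2f}, scroll acc: {scroll_acc:.2f}, text acc: {text_acc:.2f}")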

njucckevin commented 10 months ago

thanks for the explanation.

frankfengdi commented 2 months ago

Thanks for the great work!

I have a question regarding evaluation metrics. In main.py, the metrics are computed in the following code snippets.

It seems that the action accuracy score is computed as action_correct divided by the total number of frames. This differs from the original definition of the action accuracy score, which is action_correct / len(episode) averaged over the number of episodes.

Correct me if I misunderstand the metrics ...

Could you also share the full evaluation code mentioned here? https://github.com/cooelf/Auto-GUI/issues/6#issuecomment-1763939973

    # Run inference, replace the -100 label padding with the pad token, and decode to text.
    predict_results = trainer.predict(test_dataset=test_set, max_length=args.output_len)
    if trainer.is_world_process_zero():
        preds, targets = predict_results.predictions, predict_results.label_ids
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)

        action_correct = 0
        text_correct = 0
        type_correct = 0

        reference_test_positions = test_set.anno_positions

        output_data = []

        # Extract the action string that follows "Action Decision:" in the decoded text.
        pattern = r'(?<=Action Decision:\s).*'

        assert len(preds) == len(targets) == len(reference_test_positions)
        for idx, pred in enumerate(preds):
            try:
                result = re.search(pattern, targets[idx])
                target_text = result.group(0)
                target_text = target_text.strip()

                reference = eval("{" + target_text + "}")
            except:
                print("reference error")
                continue

            try:
                result = re.search(pattern, preds[idx])
                pred_text = result.group(0)
                pred_text = pred_text.strip()

                pred = eval("{" + pred_text + "}")
                action_1_touch_yx = eval(pred["touch_point"])
                action_1_lift_yx = eval(pred["lift_point"])
                action_1_action_type = action_type.ActionType[pred["action_type"]].value
                action_1_typed_text = pred["typed_text"].lower()
                action_1_typed_text = action_1_typed_text.strip()

                action_1_wrap = f'"action_type": "{action_1_action_type}", "touch_point": "{action_1_touch_yx}", "lift_point": "{action_1_lift_yx}", "typed_text": "{action_1_typed_text}"'
                action_1_wrap = action_1_wrap.replace('"', "'")
            except:
                # Unparsable predictions fall back to a placeholder (invalid) action.
                pred = '{ "action_type": "TYPE", "touch_point": "[-1.0, -1.0]", "lift_point": "[-1.0, -1.0]", "typed_text": "Invalid"}'

            action_2_touch_yx = eval(reference["touch_point"])
            action_2_lift_yx = eval(reference["lift_point"])
            action_2_action_type = action_type.ActionType[reference["action_type"]].value
            action_2_typed_text = reference["typed_text"].lower()

            action_2_wrap = f'"action_type": "{action_2_action_type}", "touch_point": "{action_2_touch_yx}", "lift_point": "{action_2_lift_yx}", "typed_text": "{action_2_typed_text}"'
            action_2_wrap = action_2_wrap.replace('"', "'")

            annotation_positions = reference_test_positions[idx]

            try:
                check_match = action_matching.check_actions_match(
                    action_1_touch_yx,
                    action_1_lift_yx,
                    action_1_action_type,
                    action_2_touch_yx,
                    action_2_lift_yx,
                    action_2_action_type,
                    annotation_positions
                )

            except Exception as exc:
                print(idx, action_1_touch_yx, action_1_lift_yx)
                check_match = False
                match_label = "invalid"

            if check_match:
                action_correct += 1
                match_label = 1
            else:
                match_label = 0
            if check_match and (action_1_typed_text in action_2_typed_text or action_2_typed_text in action_1_typed_text):
                text_correct += 1
            if action_1_action_type == action_2_action_type:
                type_correct += 1

            action_data = {"pred": action_1_wrap, "target": action_2_wrap, "match_label": match_label}
            output_data.append(action_data)

        metrics["accuracy"] = "{:.2f}".format(action_correct/len(targets) * 100)
        metrics["text_acc"] = "{:.2f}".format(text_correct/len(targets) * 100)
        metrics["type_acc"] = "{:.2f}".format(type_correct/len(targets) * 100)
        metrics["action_correct"] = action_correct
        metrics["text_correct"] = text_correct
        metrics["type_correct"] = type_correct
        metrics["total_num"] = len(targets)