ZubinGou / multi-view-prompting

Repo for "MvP: Multi-view Prompting Improves Aspect Sentiment Tuple Prediction" [ACL'2023]
MIT License
73 stars 17 forks source link

Please provide an easier model inference method #5

Open RomuloNextly opened 1 year ago

RomuloNextly commented 1 year ago

I really want to test the performance of the model without having to fine-tune it for a specific task.

I tried to follow your code, something like this:

# NOTE(review): this snippet passes the raw sentence straight to
# ``generate``. The inference script later in this thread instead builds
# model inputs with get_para_targets and decodes with a
# ``prefix_allowed_tokens_fn`` (constrained decoding). That difference is a
# likely cause of the malformed output shown below — confirm.
# ``config`` and ``model_path`` are assumed to be defined elsewhere.
tokenizer = T5Tokenizer.from_pretrained(model_path)
tfm_model = MyT5ForConditionalGeneration.from_pretrained(model_path)
model = T5FineTuner(config, tfm_model, tokenizer)

text = "I will be back, I love the sushi badly!"

input_tokenized = tokenizer(text, return_tensors="pt")
summary_ids = model.model.generate(input_tokenized['input_ids'])
output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

print(output)

# Output: [I will be back [e] love sushi [I love badly sushi

However, I'm not sure what the config should contain, and I'm getting strange results (see the output above).

If you could provide an example, it would be fantastic!

Longer767 commented 1 year ago

Have you resolved the problem?

Akash-Shaji commented 2 months ago

Were you able to solve the issue? I would also like to test this model on custom inputs.

Zhuifeng414 commented 1 month ago

Try this, @RomuloNextly. You can feed your own data into `get_para_targets`.

import argparse
import os
import sys
import logging
import pickle
from functools import partial
import time
from tqdm import tqdm
from collections import Counter
import random
import numpy as np

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.callbacks import LearningRateMonitor

from transformers import AdamW, T5Tokenizer
from t5 import MyT5ForConditionalGeneration
from transformers import get_linear_schedule_with_warmup

from data_utils import ABSADataset, task_data_list, cal_entropy, get_para_targets, get_para_targets_dev
from const import *
from data_utils import read_line_examples_from_file
from eval_utils import compute_scores, extract_spans_para

class Args:
    """Hard-coded stand-in for the command-line arguments of the training code.

    The defaults describe an inference-only run (``do_train=False``,
    ``do_inference=True``). Any attribute can be overridden by keyword,
    e.g. ``Args(task='acos', dataset='rest16')``. Unknown names raise
    ``TypeError``, so typos are not silently ignored.

    Constructing an instance creates ``./outputs`` and ``output_dir`` on disk.
    """

    def __init__(self, **overrides):
        # Basic settings
        self.data_path = "../data/"
        self.task = 'asqp'
        self.dataset = 'rest15'
        self.eval_data_split = 'test'
        self.model_name_or_path = 't5-base'
        self.output_dir = 'outputs/temp'
        self.load_ckpt_name = None
        self.do_train = False
        self.do_inference = True

        # Other parameters
        self.max_seq_length = 200  # used as ``max_length`` for generation
        self.n_gpu = 0
        self.train_batch_size = 16
        self.eval_batch_size = 64
        self.gradient_accumulation_steps = 1
        self.learning_rate = 1e-4
        self.num_train_epochs = 20
        self.seed = 25

        # Training details
        self.weight_decay = 0.0
        self.adam_epsilon = 1e-8
        self.warmup_steps = 0.0
        self.top_k = 1
        self.multi_path = False
        self.num_path = 1  # passed as ``top_k`` to get_para_targets
        self.beam_size = 1  # passed as ``num_beams`` to generate
        self.save_top_k = 1
        self.check_val_every_n_epoch = 1
        self.single_view_type = "rank"
        self.ctrl_token = "post"
        self.sort_label = False
        self.load_path_cache = False
        self.lowercase = False
        self.multi_task = True
        # When True, generation is restricted via prefix_allowed_tokens_fn.
        self.constrained_decode = True
        self.agg_strategy = 'vote'
        self.data_ratio = 1.0

        # Apply caller overrides before touching the filesystem, so that a
        # custom ``output_dir`` is the directory that actually gets created.
        for key, value in overrides.items():
            if not hasattr(self, key):
                raise TypeError(f"Args() got an unexpected option {key!r}")
            setattr(self, key, value)

        # Create directories if they don't exist
        self.setup_output_dir()

    def setup_output_dir(self):
        """Create ``./outputs`` and ``self.output_dir``, including missing parents.

        ``os.makedirs(..., exist_ok=True)`` replaces a check-then-``mkdir``
        pair. It also handles nested paths such as ``outputs/a/b``. It does
        not fail if the directory appears between the check and the call.
        """
        os.makedirs('./outputs', exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)

def init_args():
    """Return a freshly built default configuration.

    This is a thin factory around ``Args()``. It exists so that the script
    mirrors the ``init_args`` entry point of the training code. Note that
    building the config creates its output directories.
    """
    config = Args()
    return config

def tokenize_input_target(inputs, targets, data_type, max_len=128, tokenizer=None):
    """Tokenize parallel source/target examples one at a time.

    Args:
        inputs: sequence of token lists; each list is joined with single
            spaces before encoding.
        targets: sequence of target strings, the same length as ``inputs``.
        data_type: split name. ``"test"`` gets a longer target budget
            (1024 tokens); any other split uses ``max_len``.
        max_len: ``max_length`` for the source encodings (and for targets
            outside the test split). Defaults to 128, the value that was
            previously hard-coded here.
        tokenizer: object exposing ``batch_encode_plus``. Defaults to the
            module-level ``model.tokenizer``.

    Returns:
        ``(tokenized_inputs, tokenized_targets)``: two lists holding one
        encoding per example. Each is a batch of one, padded/truncated to its
        max length, returned as PyTorch tensors.

    Raises:
        ValueError: if ``inputs`` and ``targets`` differ in length.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    if len(inputs) != len(targets):
        raise ValueError(
            f"inputs and targets differ in length: {len(inputs)} != {len(targets)}")

    # For the ACOS Restaurant and Laptop datasets, target sequences can be
    # much longer than the source limit, so allow a larger max length when
    # tokenizing targets for inference.
    target_max_length = 1024 if data_type == "test" else max_len

    def _encode(text, length):
        # Encode a single string as a batch of size one.
        return tokenizer.batch_encode_plus(
            [text],
            max_length=length,
            padding="max_length",
            truncation=True,
            return_tensors="pt")

    tokenized_inputs = []
    tokenized_targets = []
    for words, target in zip(inputs, targets):
        tokenized_inputs.append(_encode(' '.join(words), max_len))
        tokenized_targets.append(_encode(target, target_max_length))
    return tokenized_inputs, tokenized_targets

def run_case_parse(input_ids, attention_mask):
    """Generate and decode the model output for one tokenized example.

    Reads the module-level globals ``model``, ``device``, ``args``, ``task``
    and ``data``. When ``args.constrained_decode`` is set, generation is
    restricted through ``model.prefix_allowed_tokens_fn`` bound to
    ``task``, ``data`` and the (un-moved) ``input_ids``.

    Returns:
        list[str]: one decoded string per generated sequence, with special
        tokens removed.
    """
    if args.constrained_decode:
        allowed_tokens_fn = partial(
            model.prefix_allowed_tokens_fn, task, data, input_ids)
    else:
        allowed_tokens_fn = None

    generation = model.model.generate(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        max_length=args.max_seq_length,
        num_beams=args.beam_size,
        early_stopping=True,
        return_dict_in_generate=True,
        output_scores=True,
        prefix_allowed_tokens_fn=allowed_tokens_fn,
    )

    decode = model.tokenizer.decode
    return [decode(seq, skip_special_tokens=True) for seq in generation.sequences]

# --- Load the fine-tuned checkpoint saved under <output_dir>/final ---
args = init_args()
model_path = os.path.join(args.output_dir, "final")
tokenizer = T5Tokenizer.from_pretrained(model_path)
tfm_model = MyT5ForConditionalGeneration.from_pretrained(model_path)
# NOTE(review): T5FineTuner is not imported by this script. It presumably
# lives in the repo's training script (e.g. ``from main import T5FineTuner``)
# — confirm and add the import, otherwise this line raises NameError.
model = T5FineTuner(args, tfm_model, tokenizer)
# Fall back to CPU so the example also runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.model.to(device)
model.model.eval()

# --- Pick the split to sample an example from ---
# ``task`` and ``data`` are also read as globals by run_case_parse.
# NOTE(review): args.task / args.dataset keep their defaults ('asqp' /
# 'rest15') and are not synced with the values below; check whether
# T5FineTuner or get_para_targets read them.
task = 'acos'
data = 'rest16'
data_type = 'train'
data_file = os.path.join(args.data_path, task, data, f'{data_type}.txt')
tasks, datas, sents, labels = read_line_examples_from_file(
        data_file, task, data, lowercase=args.lowercase)

# Build model inputs/targets the same way as the training pipeline.
inputs, targets = get_para_targets(sents,
                 labels,
                 data_name=data,
                 data_type=data_type,
                 top_k=args.num_path,
                 task=task,
                 args=args)

tokenized_inputs, tokenized_targets = tokenize_input_target(inputs, targets, data_type)

# --- Decode a single example and compare with its gold target ---
index = 404
input_ids = tokenized_inputs[index]["input_ids"]
attention_mask = tokenized_inputs[index]["attention_mask"]
dec = run_case_parse(input_ids, attention_mask)

print('res: ', dec)
print('inputs: ', ' '.join(inputs[index]))
print('targets: ', targets[index])