KruASe76 / news-recommendation-service

Part of education at MIPT, HSSE under the lead of MTS
0 stars 0 forks source link

ML. Поиск модели машинного обучения #15

Open AlexanderBobryakov opened 1 month ago

AlexanderBobryakov commented 1 month ago

TO-BE:

AlexanderBobryakov commented 4 weeks ago

@KhodasM Тут прикладываю mvp (без ONNX):

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
import pandas as pd

# Загрузка предобученной модели и токенизатора
model_name = "cross-encoder/nli-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Создание пайплайна для классификации текстов
classifier = pipeline('zero-shot-classification', model=model, tokenizer=tokenizer)

# Пример текста для классификации
text = "Как настроить CI/CD для проекта на GitLab?"

# Возможные категории
candidate_labels = ["DevOps", "IT", "Frontend", "Backend", "Data Science", "Machine Learning", "Cybersecurity", "Cloud Computing", "Mobile Development", "Game Development", "Database Administration"]

# Классификация текста
result = classifier(text, candidate_labels)

# Форматирование результата в человекочитаемый вид
df = pd.DataFrame({
    'Category': result['labels'],
    'Score': result['scores']
})

print(f"Predicted Category: {predicted_category}\n")
print("Detailed Scores:")
print(df)

Так как я тут скидываю прям готовое решение, то ожидаю от тебя:

AlexanderBobryakov commented 1 week ago

Как можно сохранить модель и токены:

# Директория для сохранения ONNX модели
onnx_path = Path("onnx_model")
onnx_path.mkdir(exist_ok=True)

!pip install onnx onnxruntime transformers torch
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path  

inputs = tokenizer(text, return_tensors="pt")
torch.onnx.export(
   model, 
   (inputs["input_ids"], inputs["attention_mask"]), 
   onnx_path / "model.onnx",
   input_names=["input_ids", "attention_mask"],
   output_names=["output"],
   dynamic_axes={"input_ids": {0: "batch_size"}, "attention_mask": {0: "batch_size"}, "output": {0: "batch_size"}},
   opset_version=11
)

print("Модель успешно экспортирована в ONNX формат!")

tokenizer.save_pretrained("./onnx_model/tokenizer")
AlexanderBobryakov commented 1 week ago

Как можно работать с Java:

public static void main(final String[] args) throws Exception {
    // Подгружаем модель и настраиваем ONNX
    final var env = OrtEnvironment.getEnvironment();
    final var session = env.createSession(
        "/Users/asbobryako/IdeaProjects/MIPT/news-recommendation-service/app/src/main/resources/model.onnx",
        new OrtSession.SessionOptions());

    // Пример входного текста который нужно категоризировать
    final var inputText = "Как настроить CI/CD для проекта на GitLab?";
    // Список категорий
    final var candidateLabels = Arrays.asList("DevOps", /*"IT",*/ "Frontend", "Backend", "Data Science", "Machine Learning", /*"Cybersecurity",*/ "Cloud Comp", "Mobile Platform", "Game News", "Database Admin");

    // Подгружаем токены
    HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(
        Paths.get("/Users/asbobryako/IdeaProjects/MIPT/news-recommendation-service/app/src/main/resources/tokenizer/tokenizer.json"));

    // Прбегаемся по всем категориям чтобы передать в модель вход
    final var logits = new ArrayList<Float>();
    for (String label : candidateLabels) {
        // Токенизируем вход: текст + категория
        final var encode = tokenizer.encode(List.of(inputText, label));
        final var inputTokens = encode.getIds();
        final var attentionMask = encode.getAttentionMask();

        // Формируем два тензора - их можно было увидеть в самом python-файле когда работали с моделью
        OnnxTensor t1 = OnnxTensor.createTensor(env, new long[][]{inputTokens});
        OnnxTensor t2 = OnnxTensor.createTensor(env, new long[][]{attentionMask});

        // Создаем вход для модели - мапу из двух тензоров с определенными именами (их можно было увидеть в самом python-файле когда работали с моделью)
        var inputs = Map.of("input_ids", t1, "attention_mask", t2);
        // Вызываем модель
        try (var result = session.run(inputs)) {
            // Достаем числа - ответ от модели
            float[][] resultLogits= (float[][]) result.get(0).getValue();
            //1. **Entailment** (следование) — модель считает, что гипотеза логически следует из текста.
            //2. **Neutral** (нейтральное) — модель считает, что гипотеза не противоречит тексту, но и не следует из него.
            //3. **Contradiction** (противоречие) — модель считает, что гипотеза противоречит тексту.
            // Когда ты используешь её для zero-shot классификации, ты создаёшь пары "текст-гипотеза"
            // для каждой категории и оцениваешь вероятность "entailment" для каждой пары.
            // Это значение ты интерпретируешь как вероятность того, что текст относится к данной категории.
            System.out.println(label + ":   " + Arrays.deepToString(resultLogits)); // Пример вывода: [2.215198, -2.1031926, 0.88638294]
            logits.add(resultLogits[0][0]);  // Запоминаем значение именно для типа Entailment ("следование")
        }
        System.out.println();
    }

    // Вызываем метод softmax когда получили результат модели для каждой категории, чтобы уже их сравнить
    float[] probabilities = softmax(logits);
    System.out.println(Arrays.toString(probabilities)); // Пример вывода: [0.116824396, 0.15116894, 0.1472722, 0.101457104, 0.16129832, 0.07249092, 0.06975241, 0.11952588, 0.060209814]
}

private static float[] softmax(List<Float> logits) {
    float maxLogit = logits.stream().max(Float::compareTo).orElse(Float.NEGATIVE_INFINITY);
    float sumExp = 0.0f;
    float[] expLogits = new float[logits.size()];
    for (int i = 0; i < logits.size(); i++) {
        expLogits[i] = (float) Math.exp(logits.get(i) - maxLogit);
        sumExp += expLogits[i];
    }
    for (int i = 0; i < expLogits.length; i++) {
        expLogits[i] /= sumExp;
    }
    return expLogits;
}

Image