Open AlexanderBobryakov opened 1 month ago
@KhodasM Тут прикладываю mvp (без ONNX):
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline
import pandas as pd
# Загрузка предобученной модели и токенизатора
model_name = "cross-encoder/nli-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Создание пайплайна для классификации текстов
classifier = pipeline('zero-shot-classification', model=model, tokenizer=tokenizer)
# Пример текста для классификации
text = "Как настроить CI/CD для проекта на GitLab?"
# Возможные категории
candidate_labels = ["DevOps", "IT", "Frontend", "Backend", "Data Science", "Machine Learning", "Cybersecurity", "Cloud Computing", "Mobile Development", "Game Development", "Database Administration"]
# Классификация текста
result = classifier(text, candidate_labels)
# Форматирование результата в человекочитаемый вид
df = pd.DataFrame({
'Category': result['labels'],
'Score': result['scores']
})
print(f"Predicted Category: {predicted_category}\n")
print("Detailed Scores:")
print(df)
Так как я тут скидываю прям готовое решение, то ожидаю от тебя:
Как можно сохранить модель и токены:
# Директория для сохранения ONNX модели
onnx_path = Path("onnx_model")
onnx_path.mkdir(exist_ok=True)
!pip install onnx onnxruntime transformers torch
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from pathlib import Path
inputs = tokenizer(text, return_tensors="pt")
torch.onnx.export(
model,
(inputs["input_ids"], inputs["attention_mask"]),
onnx_path / "model.onnx",
input_names=["input_ids", "attention_mask"],
output_names=["output"],
dynamic_axes={"input_ids": {0: "batch_size"}, "attention_mask": {0: "batch_size"}, "output": {0: "batch_size"}},
opset_version=11
)
print("Модель успешно экспортирована в ONNX формат!")
tokenizer.save_pretrained("./onnx_model/tokenizer")
Как можно работать с Java:
public static void main(final String[] args) throws Exception {
// Подгружаем модель и настраиваем ONNX
final var env = OrtEnvironment.getEnvironment();
final var session = env.createSession(
"/Users/asbobryako/IdeaProjects/MIPT/news-recommendation-service/app/src/main/resources/model.onnx",
new OrtSession.SessionOptions());
// Пример входного текста который нужно категоризировать
final var inputText = "Как настроить CI/CD для проекта на GitLab?";
// Список категорий
final var candidateLabels = Arrays.asList("DevOps", /*"IT",*/ "Frontend", "Backend", "Data Science", "Machine Learning", /*"Cybersecurity",*/ "Cloud Comp", "Mobile Platform", "Game News", "Database Admin");
// Подгружаем токены
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(
Paths.get("/Users/asbobryako/IdeaProjects/MIPT/news-recommendation-service/app/src/main/resources/tokenizer/tokenizer.json"));
// Прбегаемся по всем категориям чтобы передать в модель вход
final var logits = new ArrayList<Float>();
for (String label : candidateLabels) {
// Токенизируем вход: текст + категория
final var encode = tokenizer.encode(List.of(inputText, label));
final var inputTokens = encode.getIds();
final var attentionMask = encode.getAttentionMask();
// Формируем два тензора - их можно было увидеть в самом python-файле когда работали с моделью
OnnxTensor t1 = OnnxTensor.createTensor(env, new long[][]{inputTokens});
OnnxTensor t2 = OnnxTensor.createTensor(env, new long[][]{attentionMask});
// Создаем вход для модели - мапу из двух тензоров с определенными именами (их можно было увидеть в самом python-файле когда работали с моделью)
var inputs = Map.of("input_ids", t1, "attention_mask", t2);
// Вызываем модель
try (var result = session.run(inputs)) {
// Достаем числа - ответ от модели
float[][] resultLogits= (float[][]) result.get(0).getValue();
//1. **Entailment** (следование) — модель считает, что гипотеза логически следует из текста.
//2. **Neutral** (нейтральное) — модель считает, что гипотеза не противоречит тексту, но и не следует из него.
//3. **Contradiction** (противоречие) — модель считает, что гипотеза противоречит тексту.
// Когда ты используешь её для zero-shot классификации, ты создаёшь пары "текст-гипотеза"
// для каждой категории и оцениваешь вероятность "entailment" для каждой пары.
// Это значение ты интерпретируешь как вероятность того, что текст относится к данной категории.
System.out.println(label + ": " + Arrays.deepToString(resultLogits)); // Пример вывода: [2.215198, -2.1031926, 0.88638294]
logits.add(resultLogits[0][0]); // Запоминаем значение именно для типа Entailment ("следование")
}
System.out.println();
}
// Вызываем метод softmax когда получили результат модели для каждой категории, чтобы уже их сравнить
float[] probabilities = softmax(logits);
System.out.println(Arrays.toString(probabilities)); // Пример вывода: [0.116824396, 0.15116894, 0.1472722, 0.101457104, 0.16129832, 0.07249092, 0.06975241, 0.11952588, 0.060209814]
}
private static float[] softmax(List<Float> logits) {
float maxLogit = logits.stream().max(Float::compareTo).orElse(Float.NEGATIVE_INFINITY);
float sumExp = 0.0f;
float[] expLogits = new float[logits.size()];
for (int i = 0; i < logits.size(); i++) {
expLogits[i] = (float) Math.exp(logits.get(i) - maxLogit);
sumExp += expLogits[i];
}
for (int i = 0; i < expLogits.length; i++) {
expLogits[i] /= sumExp;
}
return expLogits;
}
TO-BE: