huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
131.97k stars 26.29k forks source link

TypeError: Object of type Tensor is not JSON serializable #33134

Open dengchengxifrank opened 2 weeks ago

dengchengxifrank commented 2 weeks ago

System Info

Who can help?

@sanchit-gandhi

Information

Tasks

Reproduction

from datasets import load_dataset, load_metric
from datasets import ClassLabel, Audio
import datasets
import random
import pandas as pd
import numpy as np
import re, time
import json, librosa
from peft import PeftModel, PeftConfig
import soundfile as sf
import torch
torch.backends.cudnn.enabled = False
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
import pdb
from transformers import TrainingArguments
from transformers import Trainer
from transformers import WhisperProcessor, WhisperForConditionalGeneration , WhisperTokenizer
from transformers import WhisperFeatureExtractor
from transformers import Seq2SeqTrainingArguments
import evaluate
from datasets import load_dataset
metric = evaluate.load("cer")
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

def remove_special_characters(batch):
    batch["test_gt"] = re.sub(chars_to_ignore_regex, '', batch["test_gt"])
    batch["train_gt"] = re.sub(chars_to_ignore_regex, '', batch["train_gt"])
    return batch

def prepare_dataset(batch):
    audio_1 = batch["test_file"]
    audio_2 = batch["picked_file"]
    batch["input_features"] = feature_extractor(np.concatenate((audio_2["array"],audio_1["array"]),axis=0), sampling_rate=audio_1["sampling_rate"]).input_features[0]
    batch["labels"] = tokenizer(batch["train_gt"]+'。'+batch["test_gt"]).input_ids
    return batch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:

        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = tokenizer.pad_token_id
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    cer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}

chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'

train_csv = "./pick_data/tmp.csv"

dbank = load_dataset('csv', data_files={'train':train_csv})
dbank = dbank.map(remove_special_characters,keep_in_memory=True)

model_path = 'openai/whisper-large-v2'

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_path,cache_dir='./')
tokenizer = WhisperTokenizer.from_pretrained(model_path, language="chinese", task="transcribe",cache_dir='./')
processor = WhisperProcessor.from_pretrained(model_path,cache_dir='./')
model = WhisperForConditionalGeneration.from_pretrained(model_path,cache_dir='./').cuda()

config = LoraConfig(r=8, lora_alpha=64,target_modules=["q_proj", "v_proj","k_proj","out_proj"], lora_dropout=0.05, bias="lora_only",init_lora_weights="gaussian")
model = get_peft_model(model, config)

model.print_trainable_parameters()

model.generation_config.language = "chinese"
model.generation_config.task = "transcribe"
PREV_TOKEN = 50360
START_TRANSCRIPT = 50258
ZH_LANGUAGE = 50260
TRANSCRIBE = 50359
NO_TIMESTAMP = 50363
prompt_tokens = [START_TRANSCRIPT,ZH_LANGUAGE,TRANSCRIBE,NO_TIMESTAMP]

pdb.set_trace()
model.config.forced_decoder_ids = torch.LongTensor(prompt_tokens)

dbank = dbank.cast_column("test_file", Audio(sampling_rate=16000))
dbank = dbank.cast_column("picked_file", Audio(sampling_rate=16000))

dbank = dbank.map(prepare_dataset, remove_columns=dbank.column_names["train"], num_proc=1)

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

cer_metric = load_metric("cer")

model.config.ctc_zero_infinity = True

training_args = TrainingArguments(
    output_dir="./tmo_icl", 
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    learning_rate=3e-5, 
    warmup_steps=50, 
    num_train_epochs=15, 
    save_strategy='epoch',
    do_eval=False,
    save_total_limit=3,
    fp16=True, 
    per_device_eval_batch_size=2, 
    report_to=["tensorboard"],
    logging_steps=50, 
    remove_unused_columns=False, 
    label_names=["labels"], 
)

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    train_dataset=dbank["train"],
    tokenizer=processor.feature_extractor,
)

trainer.train()

Expected behavior

I get the errors as:TypeError: Object of type Tensor is not JSON serializable

I wonder how to set the decoder_start_token_id or forced_decoder_ids during training. And I want the training loss is not including the tokens: . Could you please give me some code examples to show how to set the decoder_start_token_id or forced_decoder_ids ?Thanks

ArthurZucker commented 2 weeks ago

cc @gante for serialization of decode input ids

gante commented 1 week ago

HI @dengchengxifrank 👋 Thank you for opening this issue 🤗

As shown in our documentation, decoder_start_token_id and forced_decoder_ids are not expected to be torch.Tensor, but rather an int or a list of int.

Changing the type should fix your issues 🤗

dengchengxifrank commented 11 hours ago

@gante Thanks!