CAFECA-IO / KnowledgeManagement

Creating, Sharing, Using and Managing the knowledge and information of CAFECA
https://mermer.com.tw/knowledge-management
MIT License

Trying to use the openAI resources provided by TWCC to train a model that can generate ESG reports, Part 3: Training the Model v2 #167

Open TzuHanLiang opened 3 weeks ago

TzuHanLiang commented 3 weeks ago

Use a generative pre-trained model such as GPT-2 or GPT-3; these models can generate coherent text.
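As a quick illustration of that claim, a minimal sketch along the following lines (assuming the Hugging Face transformers text-generation pipeline and the base "gpt2" checkpoint; the prompt is only an example and not part of the original issue) shows the kind of coherent continuation a pretrained GPT-2 produces before any fine-tuning:

from transformers import pipeline, set_seed

# Minimal sketch: sample a short continuation from the pretrained GPT-2
generator = pipeline("text-generation", model="gpt2")
set_seed(42)
result = generator(
    "An ESG report summarizes a company's environmental performance",
    max_new_tokens=50,
    num_return_sequences=1,
)
print(result[0]["generated_text"])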

TzuHanLiang commented 3 weeks ago
import os
import torch
import matplotlib.pyplot as plt
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

# 1. Load the model and tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Add a pad_token (GPT-2 has no padding token by default) and resize the embeddings for the new token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model.resize_token_embeddings(len(tokenizer))

# Move the model to the GPU when one is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

# 2. Read the training data
data_dir = "/data/summerizes"

train_texts = []
window_size = 256
stride = 128
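# With window_size = 256 and stride = 128, consecutive windows overlap by 128 tokens,
# so text near a window boundary appears in two training examples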

for filename in os.listdir(data_dir):
    if filename.endswith(".txt"):
        with open(os.path.join(data_dir, filename), "r", encoding="utf-8") as f:
            text = f.read()
            paragraphs = text.split('\n\n')  # split on blank lines, one chunk per paragraph

            # Tokenize each paragraph and use a sliding window so long paragraphs keep their context
            for paragraph in paragraphs:
                tokenized_paragraph = tokenizer(paragraph, truncation=False, return_tensors="pt")
                # squeeze(0) removes only the batch dimension, so single-token paragraphs stay 1-D
                input_ids = tokenized_paragraph["input_ids"].squeeze(0)

                # Slide a fixed-size window over long paragraphs
                for i in range(0, len(input_ids), stride):
                    end_index = min(i + window_size, len(input_ids))
                    window_input_ids = input_ids[i:end_index]
                    # Pad the last window so every training example has the same length
                    if len(window_input_ids) < window_size:
                        padding_length = window_size - len(window_input_ids)
                        padding = torch.full((padding_length,), tokenizer.pad_token_id, dtype=torch.long)
                        window_input_ids = torch.cat([window_input_ids, padding])
                    train_texts.append(window_input_ids)

# 3. Collect the windowed token ID tensors (each window is already a PyTorch tensor)
input_ids_list = train_texts

# 4. Set up the dataset and data loader
class TextDataset(Dataset):
    def __init__(self, input_ids):
        self.input_ids = input_ids

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {"input_ids": self.input_ids[idx]}

dataset = TextDataset(input_ids_list)
train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# 5. Train the model and record the loss
optimizer = AdamW(model.parameters(), lr=5e-5)
epochs = 10  # more training epochs for better results
loss_values = []

for epoch in range(epochs):
    print(f"Epoch: {epoch+1}/{epochs}")
    epoch_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(model.device)
        # Attend only to real tokens and exclude pad positions from the loss (-100 is ignored by the loss)
        attention_mask = (input_ids != tokenizer.pad_token_id).long()
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        print(f"Batch Loss: {loss.item()}")
    average_epoch_loss = epoch_loss / len(train_loader)
    loss_values.append(average_epoch_loss)
    print(f"Epoch: {epoch+1}/{epochs}, Average Loss: {average_epoch_loss}")

# 6. Save the fine-tuned model and tokenizer
model.save_pretrained("/data/models_v4")
tokenizer.save_pretrained("/data/tokenizers_v4")

# 7. Plot the training loss
plt.plot(loss_values)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Time')
plt.show()

# 8. Generate text with the fine-tuned model
model = GPT2LMHeadModel.from_pretrained("/data/models_v4")
tokenizer = GPT2Tokenizer.from_pretrained("/data/tokenizers_v4")

model.eval()
input_text = "2023年度國泰化工的環境管理狀況如下:"
input_ids = tokenizer.encode(input_text, return_tensors="pt")

# Build the attention mask (all ones here, since the prompt contains no padding)
attention_mask = (input_ids != tokenizer.pad_token_id).long()

# Generate text with tuned sampling parameters (temperature, top-k and top-p trade diversity against coherence)
with torch.no_grad():
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=1000,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)