JangYeongSil / JettaRLLLM

Jetta-Reinforcement-Learning-Hybrid-LLM-Architecture
Apache License 2.0

give you an advice #1

Open win10ogod opened 1 month ago

win10ogod commented 1 month ago

Hi friend, maybe you can try using the Llama architecture instead of the original Transformer? (You can refer to the Llama architecture in llama2.c.)
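
For reference, here is a rough, self-contained sketch (not this repo's code, and only loosely following llama2.c) of what that change amounts to: RMSNorm pre-normalization, rotary position embeddings (RoPE), causal self-attention, and a SwiGLU feed-forward. All names and sizes below are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the root-mean-square of the features, then rescale.
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def rope_cache(seq_len, head_dim, device=None, base=10000.0):
    # Precompute cos/sin tables of shape (1, 1, seq_len, head_dim // 2).
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len, device=device).float(), inv_freq)
    return freqs.cos()[None, None], freqs.sin()[None, None]

def apply_rope(x, cos, sin):
    # x: (batch, heads, seq, head_dim); rotate feature pairs by position-dependent angles.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

class LlamaStyleBlock(nn.Module):
    def __init__(self, dim, n_heads, ffn_dim):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)
        # SwiGLU feed-forward: silu(x @ W1) * (x @ W3), projected back by W2.
        self.w1 = nn.Linear(dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, ffn_dim, bias=False)
        self.attn_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)

    def forward(self, x, cos, sin):
        b, t, d = x.shape
        h = self.attn_norm(x)  # pre-norm, as in Llama
        q = self.wq(h).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.wk(h).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.wv(h).view(b, t, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # causal self-attention
        x = x + self.wo(attn.transpose(1, 2).reshape(b, t, d))
        h = self.ffn_norm(x)
        x = x + self.w2(F.silu(self.w1(h)) * self.w3(h))
        return x

# Example: one block over a batch of 2 sequences of length 16, with the same sizes as below.
block = LlamaStyleBlock(dim=256, n_heads=8, ffn_dim=1024)
cos, sin = rope_cache(seq_len=16, head_dim=256 // 8)
out = block(torch.randn(2, 16, 256), cos, sin)  # -> (2, 16, 256)

Compared with the TransformerEncoderLayer in the script below, the main differences are RMSNorm in place of LayerNorm, pre-norm instead of post-norm residuals, a SwiGLU feed-forward instead of ReLU, and RoPE, which also supplies the positional information that the script below does not add.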

win10ogod commented 1 month ago

A script using a Hugging Face model tokenizer:

import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import logging
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from typing import Dict, List
from transformers import AutoTokenizer

from torch.optim.lr_scheduler import ReduceLROnPlateau
#from torch.utils.tensorboard import SummaryWriter
import torch.cuda.amp as amp

# Basic configuration for logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Constants and Hyperparameters
folder_directory = "data"
learning_rate = 1e-4
num_epochs = 1000
batch_size = 32
max_seq_length = 512
dropout_rate = 0.01
weight_decay = 0.01

# Define device(s)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.device_count() > 1:
    logging.info(f"Using {torch.cuda.device_count()} GPUs")

# Load pre-trained tokenizer (e.g., GPT-2)
tokenizer = AutoTokenizer.from_pretrained("unsloth/Mistral-Nemo-Instruct-2407")

# Add padding token if it doesn't exist
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Update vocab_size to account for the new padding token
vocab_size = len(tokenizer)

# Initial model parameters
initial_embedding_dim = 256
initial_hidden_dim = 512
initial_num_layers = 6
initial_num_heads = 8
initial_ffn_dim = 1024

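# A standard post-norm Transformer encoder layer: multi-head self-attention followed by a
# ReLU feed-forward block, each wrapped in a residual connection and LayerNorm.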
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = nn.ReLU()

    def forward(self, src, src_mask=None):
        src2 = self.self_attn(src, src, src, attn_mask=src_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

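# Standard top-k / nucleus (top-p) filtering applied to a 1-D logits vector before sampling.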
def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

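# Language model: token embedding -> stack of Transformer layers -> linear head over the
# vocabulary. The forward pass returns both the logits and the final hidden states.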
class HybridBrainModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, ffn_dim, dropout_rate):
        super(HybridBrainModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.transformer_layers = nn.ModuleList([
            TransformerEncoderLayer(embedding_dim, num_heads, ffn_dim, dropout_rate)
            for _ in range(num_layers)
        ])
        self.fc = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x)
        # Causal mask so each position can only attend to earlier positions
        # (the model is trained on shifted, next-token targets).
        seq_len = x.size(1)
        causal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf'), device=x.device), diagonal=1)
        for layer in self.transformer_layers:
            x = layer(x, src_mask=causal_mask)
        logits = self.fc(x)
        return logits, x  # Return both logits and transformer output

    def generate(self, input_ids, max_length, temperature=1.0, top_k=50, top_p=0.95):
        self.eval()
        with torch.no_grad():
            for _ in range(max_length - len(input_ids)):
                inputs = torch.tensor(input_ids).unsqueeze(0).to(device)
                outputs, _ = self(inputs)
                next_token_logits = outputs[0, -1, :] / temperature
                filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
                input_ids.append(next_token.item())
                if next_token.item() == tokenizer.eos_token_id:
                    break
        return input_ids

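# Actor-critic wrapper: the actor is the language model above, and the critic maps its
# final hidden states to a per-token scalar value estimate.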
class ReinforcementLearningModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, ffn_dim, dropout_rate):
        super(ReinforcementLearningModel, self).__init__()
        self.actor = HybridBrainModel(vocab_size, embedding_dim, hidden_dim, num_layers, num_heads, ffn_dim, dropout_rate)
        self.critic = nn.Linear(embedding_dim, 1)

    def forward(self, x):
        logits, actor_output = self.actor(x)
        value = self.critic(actor_output)
        return logits, value.squeeze(-1)

    def generate(self, input_ids, max_length, temperature=1.0, top_k=50, top_p=0.95):
        return self.actor.generate(input_ids, max_length, temperature, top_k, top_p)

def calculate_accuracy(logits, targets):
    predictions = torch.argmax(logits, dim=-1)
    correct_predictions = (predictions == targets).float()
    accuracy = correct_predictions.mean().item()
    return accuracy

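# Gradient scaler for mixed-precision (AMP) training.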
scaler = amp.GradScaler()

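# Language-modeling loss: token-level cross-entropy over the vocabulary, ignoring padded positions.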
def calculate_actor_loss(logits, targets):
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id, reduction='mean')
    return loss_fn(logits.view(-1, logits.size(-1)), targets.view(-1))

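# Critic loss: Smooth L1 between the per-token value estimates and the (float-cast) target token ids.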
def calculate_critic_loss(value, targets):
    loss_fn = nn.SmoothL1Loss()
    return loss_fn(value, targets.float())

# Modified weight initialization function
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.weight, 1.0)
        nn.init.constant_(m.bias, 0)

# Modified data preprocessing function
def preprocess_documents(filenames: List[str], folder_directory: str) -> List[List[int]]:
    all_indices = []
    for filename in tqdm(filenames, desc="Preprocessing documents"):
        file_path = os.path.join(folder_directory, filename)
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        # Encode the full document here; chunking to max_seq_length happens later in
        # create_dataset. Truncating and padding to max_length at this stage would make
        # every document exactly max_seq_length tokens long, and create_dataset would
        # then produce no (input, target) pairs at all.
        encoded = tokenizer.encode(text, add_special_tokens=True)
        all_indices.append(encoded)
    return all_indices


def collate_fn(batch):
    inputs, targets = zip(*batch)
    # Pad with the tokenizer's pad id so padded target positions are skipped by the
    # cross-entropy loss (which uses ignore_index=tokenizer.pad_token_id).
    inputs_padded = pad_sequence(inputs, batch_first=True, padding_value=tokenizer.pad_token_id)
    targets_padded = pad_sequence(targets, batch_first=True, padding_value=tokenizer.pad_token_id)
    return inputs_padded, targets_padded

# Modified dataset creation function
def create_dataset(indices, chunk_size):
    # Split each encoded document into non-overlapping chunks and build
    # (input, next-token target) pairs shifted by one position.
    dataset = []
    for encoded in indices:
        for i in range(0, len(encoded) - chunk_size, chunk_size):
            input_seq = torch.tensor(encoded[i:i+chunk_size], dtype=torch.long)
            target_seq = torch.tensor(encoded[i+1:i+chunk_size+1], dtype=torch.long)
            dataset.append((input_seq, target_seq))
    return dataset

# Modified sample-text generation function
def generate_sample_text(model, tokenizer, device, input_text, max_length=50):
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    output_ids = model.generate(input_ids[0].tolist(), max_length=max_length)
    output_text = tokenizer.decode(output_ids, skip_special_tokens=True)
    return output_text

# Modified model training function
def train_reinforcement_model(model, optimizer, train_loader, val_loader, num_epochs, device):
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

    model.train()
    for epoch in range(num_epochs):
        total_train_loss = 0
        total_train_accuracy = 0

        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for batch in train_pbar:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            with amp.autocast():
                logits, value = model(inputs)
                loss_actor = calculate_actor_loss(logits, targets)
                loss_critic = calculate_critic_loss(value, targets)
                loss = loss_actor + loss_critic

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            accuracy = calculate_accuracy(logits, targets)
            total_train_loss += loss.item()
            total_train_accuracy += accuracy

            train_pbar.set_postfix({'loss': f'{loss.item():.4f}', 'accuracy': f'{accuracy:.4f}'})

        avg_train_loss = total_train_loss / len(train_loader)
        avg_train_accuracy = total_train_accuracy / len(train_loader)

        # Validation loop
        model.eval()
        total_val_loss = 0
        total_val_accuracy = 0
        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for batch in val_pbar:
                inputs, targets = batch
                inputs, targets = inputs.to(device), targets.to(device)

                with amp.autocast():
                    logits, value = model(inputs)
                    loss_actor = calculate_actor_loss(logits, targets)
                    loss_critic = calculate_critic_loss(value, targets)
                    loss = loss_actor + loss_critic

                accuracy = calculate_accuracy(logits, targets)
                total_val_loss += loss.item()
                total_val_accuracy += accuracy

                val_pbar.set_postfix({'loss': f'{loss.item():.4f}', 'accuracy': f'{accuracy:.4f}'})

        avg_val_loss = total_val_loss / len(val_loader)
        avg_val_accuracy = total_val_accuracy / len(val_loader)

        print(f'Epoch {epoch+1}/{num_epochs}, '
              f'Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy:.4f}, '
              f'Val Loss: {avg_val_loss:.4f}, Val Accuracy: {avg_val_accuracy:.4f}')

        scheduler.step(avg_val_loss)

        # Generate sample text
        if (epoch + 1) % 10 == 0:
            sample_text = generate_sample_text(model, tokenizer, device, "God is")
            print(f"Sample generated text: {sample_text}")

        model.train()

    #writer.close()

# Main execution
if __name__ == "__main__":
    # Step 1: Preprocess documents
    filenames = [f for f in os.listdir(folder_directory) if os.path.isfile(os.path.join(folder_directory, f))]
    all_indices = preprocess_documents(filenames, folder_directory)

    # Step 2: Create dataset and train model
    dataset = create_dataset(all_indices, chunk_size=max_seq_length)
    train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_data, batch_size=batch_size, collate_fn=collate_fn)

    model = ReinforcementLearningModel(vocab_size, initial_embedding_dim, initial_hidden_dim,
                                       initial_num_layers, initial_num_heads, initial_ffn_dim, dropout_rate)
    model = model.to(device)
    model.apply(init_weights)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    print("Starting model training...")
    train_reinforcement_model(model, optimizer, train_loader, val_loader, num_epochs, device)

    torch.save(model.state_dict(), 'reinforcement_language_model.pth')
    tokenizer.save_pretrained('tokenizer')
    logging.info("Model training completed and model saved.")