AUTOMATIC1111 / TorchDeepDanbooru

Pure pytorch implementation of DeepDanbooru
MIT License
125 stars 30 forks source link

Expanding + more tags? #4

Open thijsi123 opened 1 year ago

fredzhang7 commented 1 year ago

I fine-tuned TorchDeepDanbooru a few times. This model has 123,461,760 parameters and seems to struggle to learn/converge on small, balanced datasets (e.g., 1-8800 images with 1-25 unique labels).

Edit: I haven't fine-tuned on a very large dataset yet, so you may need to adjust the hyperparameters.

I loaded a CSV file combined_captions.csv with two columns (img_filename,labels) using pandas. You can take a look at the top of combined_captions.csv. The images in the ./dataset folder are 512x512 and converted to RGB.

import pandas as pd
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import tqdm
from multiprocessing import freeze_support
import deep_danbooru_model

class DeepDanbooruDataset(Dataset):
    """Image/tag dataset backed by a CSV file with ``img_filename,labels`` columns.

    Each item is a pair ``(image, label_tensor)``:

    * ``image`` is the picture converted to RGB, as a float32 tensor scaled to
      [0, 1] in the layout produced by ``np.array`` on a PIL image.
    * ``label_tensor`` is a ``(1, num_tags)`` multi-hot float tensor holding a 1
      at the index of every tag listed for the image.

    Both tensors are left on the CPU: the dataset may run inside DataLoader
    worker processes, and the training loop moves each batch to the target
    device itself.

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.
        image_dir: Directory that the ``img_filename`` values are relative to.
        tags: Optional ordered tag list that defines the label indices. When
            omitted, the module-level ``model.tags`` is read when an item is
            requested.
    """

    def __init__(self, csv_file, image_dir, tags=None):
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self._tags = None if tags is None else list(tags)
        # Cache of the tag -> index mapping and the tag list it was built from.
        self._mapped_tags = None
        self._label_to_index = None

    def __len__(self):
        return len(self.data)

    def _label_mapping(self):
        """Return the tag -> index dict, rebuilding it only when the tag list changes."""
        tags = self._tags if self._tags is not None else model.tags
        if self._label_to_index is None or self._mapped_tags is not tags:
            self._label_to_index = {label: index for index, label in enumerate(tags)}
            self._mapped_tags = tags
        return self._label_to_index

    def __getitem__(self, idx):
        """Load image ``idx`` and build its multi-hot label tensor.

        Raises:
            KeyError: If the row lists a tag that is not in the tag list.
        """
        img_name = self.data['img_filename'][idx]
        img_path = f"{self.image_dir}/{img_name}"

        # The context manager closes the file handle as soon as the pixels are read.
        with Image.open(img_path) as img:
            pixels = np.array(img.convert("RGB"), dtype=np.float32) / 255
        image = torch.from_numpy(pixels)

        labels = self.data['labels'][idx].split(', ')
        label_mapping = self._label_mapping()

        # Shape (1, num_tags): the training loop squeezes dim 1 after batching.
        label_tensor = torch.zeros(1, len(self._mapped_tags))
        for label in labels:
            try:
                label_idx = label_mapping[label]
            except KeyError:
                raise KeyError(
                    f"label {label!r} of {img_name!r} is not in the tag list"
                ) from None
            label_tensor[0, label_idx] = 1

        return image, label_tensor

# Create the dataset and data loader
batch_size = 16
dataset = DeepDanbooruDataset('combined_captions.csv', './dataset')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Load the pre-trained model. map_location='cpu' lets a checkpoint that was
# saved from a GPU load on a CPU-only machine; the model is moved to the
# selected device further down.
model = deep_danbooru_model.DeepDanbooruModel()
model.load_state_dict(torch.load('model-resnet_custom_v3.pt', map_location='cpu'))

# Get the labels in the dataset (rows without labels contribute nothing)
labels_in_dataset = set(dataset.data['labels'].dropna().str.split(', ').explode())

# Change the last layer to output the number of labels in the dataset
model.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=len(labels_in_dataset), bias=False)

# Set the tags in the model. Sorting gives a deterministic tag -> index order:
# iterating a set of strings depends on per-process hash randomisation, so a
# plain list(set) can come out in a different order in every process that
# executes this module (e.g. spawned DataLoader workers).
model.tags = sorted(labels_in_dataset)
labels_in_dataset = None

# Define loss function and optimizer.
# NOTE(review): this script thresholds the model outputs at 0.6 as confidence
# scores, so forward() presumably already applies a sigmoid -- confirm in
# deep_danbooru_model. BCELoss consumes probabilities directly, whereas
# BCEWithLogitsLoss would apply a second sigmoid to them.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)

# Move the model to CUDA if available and set to training mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.train()

# Print model parameters
print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")

# Fine-tuning loop
num_epochs = 5

def main():
    """Fine-tune the model, tag ``test.jpg`` as a smoke test, and save the result.

    Uses the module-level ``model``, ``dataloader``, ``criterion``,
    ``optimizer``, ``device`` and ``num_epochs``. For every epoch it prints the
    mean batch loss and the fraction of individual label entries predicted
    correctly at a 0.6 threshold. Afterwards the model is switched to eval mode
    and half precision, ``test.jpg`` is tagged, and the weights plus the tag
    list are written to ``data_labeller_v1.pt``.
    """
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        running_correct_labels = 0
        total_labels = 0
        num_batches = 0

        for images, labels in tqdm.tqdm(dataloader):
            images = images.to(device)
            # The dataset yields (1, num_tags) labels; drop the extra axis so
            # the batch is (batch, num_tags).
            labels = torch.squeeze(labels, dim=1).to(device)

            optimizer.zero_grad()

            # Forward pass
            outputs = model(images)

            # Calculate the loss
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

            # Calculate accuracy over every (sample, tag) entry
            predicted_labels = (outputs.detach() > 0.6).int()
            running_correct_labels += (predicted_labels == labels).sum().item()
            total_labels += labels.numel()

        # Guard against an empty dataloader instead of dividing by zero.
        epoch_acc = running_correct_labels / total_labels if total_labels else 0.0
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

    # Set the model to evaluation mode; half() converts the weights in place
    # and keeps them on their current device.
    model.eval()
    model.half()

    # Load and preprocess the test image
    pic = Image.open("test.jpg").convert("RGB").resize((512, 512))
    a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255

    with torch.no_grad(), torch.autocast(device.type):
        x = torch.from_numpy(a).to(device)

        # Make predictions on the test image. Cast to float32 before the numpy
        # conversion: autocast may yield a reduced-precision dtype (bfloat16 on
        # CPU) that numpy cannot represent.
        y = model(x)[0].detach().float().cpu().numpy()

    # Print the tags with confidence scores greater than or equal to 0.6
    for i, p in enumerate(y):
        if p >= 0.6:
            print(model.tags[i], p)

    # Create a dictionary to hold model.tags and the model state dict.
    # The weights are stored in half precision because of model.half() above.
    model_data = model.state_dict()
    model_data['tags'] = model.tags

    # Save the model data to a single file
    torch.save(model_data, 'data_labeller_v1.pt')

# Script entry point: only when this file is executed directly (not imported),
# call multiprocessing's freeze_support() and then run the fine-tuning.
if __name__ == "__main__":
    freeze_support()
    main()

I'm still experimenting with my fine-tuning methods. Let me know what you think about using BCEWithLogitsLoss for the loss function and multi-hot encoding (one 0/1 entry per tag) for the labels.