mlr-org / mlr3torch

Deep learning framework for the mlr3 ecosystem based on torch
https://mlr3torch.mlr-org.com

Add a MultiModal example Task #292

Status: Open. sebffischer opened this issue 4 days ago.

sebffischer commented 4 days ago

One of the strengths of mlr3torch is that it can easily handle multimodal data, because a neural network built out of PipeOpTorch operators can have multiple inputs (one PipeOpTorchIngress per input). To showcase this feature, we need a multimodal example dataset. We can use the SIIM-ISIC Melanoma Classification data (skin-lesion images plus tabular patient metadata such as sex, age and anatomical site): https://www.kaggle.com/c/siim-isic-melanoma-classification/data. Some predefined image tasks already exist in mlr3torch, so integrating this new task will work similarly to https://github.com/mlr-org/mlr3torch/blob/main/R/TaskClassif_mnist.R.

To add a new task to mlr3torch, we need to add a function that takes an ID and returns a task:

load_task_melanoma = function(id = "melanoma") {
  ...
  return(task)
}

Then, we need to add this function to the dictionary of tasks as below:

register_task("melanoma", load_task_melanoma)
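The registration pattern can be illustrated with a small Python sketch. The dictionary stores the loader function rather than the task itself, so nothing is loaded until a task is requested by its ID. `TASK_REGISTRY`, `register_task_loader` and `get_task` are hypothetical names used only for illustration. In mlr3torch, the actual registration goes through `register_task()` and mlr3's task dictionary.

```python
from typing import Any, Callable, Dict

# Hypothetical registry mapping task IDs to loader functions (not an mlr3 API).
TASK_REGISTRY: Dict[str, Callable[[str], Any]] = {}


def register_task_loader(task_id: str, loader: Callable[[str], Any]) -> None:
    """Store the loader; no data is loaded at registration time."""
    if task_id in TASK_REGISTRY:
        raise ValueError(f"task {task_id!r} is already registered")
    TASK_REGISTRY[task_id] = loader


def get_task(task_id: str) -> Any:
    """Look up the loader and call it with the ID, so the task is built on demand."""
    return TASK_REGISTRY[task_id](task_id)


def load_task_melanoma(task_id: str = "melanoma") -> Dict[str, str]:
    # Placeholder: a real loader would build the lazy backend and the classification task.
    return {"id": task_id, "target": "target"}


register_task_loader("melanoma", load_task_melanoma)
task = get_task("melanoma")
```

Storing a function instead of a constructed task keeps package loading cheap, which matters for large datasets like this one.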

Because the dataset is too large to be included in the mlr3torch package, we use a DataBackendLazy as the task's backend, so the data is only constructed when it is first accessed. Therefore, the load_task_melanoma function first needs to construct this DataBackendLazy and then create a TaskClassif from it.

The DataBackendLazy:

sebffischer commented 4 days ago

I was also sent some preprocessing code. I am not sure how much of it we need, but I am posting it here in case it is useful.

  1. Step in R (preprocess-1.R):
library(dplyr)
library(mgcv)

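# read.csv uses sep = "," and dec = "."; read.csv2 keeps dec = "," even with sep = ",",
# so decimal-point values such as "45.0" would not be parsed as numbers
df <- read.csv("data/ISIC_2020_Training_GroundTruth_v2.csv")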

### remove empty
keep_mask <- (df$sex != "") & (df$anatom_site_general_challenge != "") & !is.na(df$age_approx)
cat("removing", sum(!keep_mask), "rows with empty columns\n")
df <- df[keep_mask,]

### encode
df$sex <- factor(df$sex)
df$diagnosis <- factor(df$diagnosis)
df$site <- factor(df$anatom_site_general_challenge)
df$benign_malignant <- factor(df$benign_malignant)
df$patient_id <- factor(df$patient_id)

pats <- df %>% group_by(patient_id) %>% summarise(n=n()) %>% filter(n>=4)
df <- df[df$patient_id %in% pats$patient_id,]
cat("kept", nrow(df), "lesions from patients with at least four\n")
# shuffle the patients; no seed is set, so the split differs between runs
pats <- pats[sample(1:nrow(pats)),]

test_pats <- pats[1:170, "patient_id"]
tune_pats <- pats[171:340, "patient_id"]
train_pats <- pats[341:nrow(pats), "patient_id"]

test_df <- df %>% filter(patient_id %in% test_pats$patient_id)
tune_df <- df %>% filter(patient_id %in% tune_pats$patient_id)
train_df <- df %>% filter(patient_id %in% train_pats$patient_id)

cat("got", nrow(test_df), "lesions from test patients\n")
cat("got", nrow(tune_df), "lesions from tuning patients\n")
cat("got", nrow(train_df), "lesions from training patients\n")

test_df$subset <- "test"
tune_df$subset <- "tune"
train_df$subset <- "trainval"

df_all <- rbind(train_df, tune_df, test_df)
saveRDS(df_all, "data/train-processed.RDS")

# model matrix for structured effects
#mdl <- gam(target ~ site + sex + s(age_approx), family = "binomial", data = df)
mdl <- bam(
    target ~ site + sex + s(age_approx, by=sex),
    family = "binomial", data = df_all, discrete = TRUE, nthreads = 4
)

x_struc <- as.data.frame(model.matrix(mdl))
x_struc$target <- df_all$target
x_struc$image <- df_all$image_name  # full column name; df_all$image only works via partial matching
x_struc$patient_id <- df_all$patient_id
x_struc$subset <- df_all$subset

write.csv2(x_struc, "data/x_struc.csv")
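The R step above splits by patient rather than by image. As a result, all lesions of one patient end up in the same subset, and the test and tuning sets contain no patients seen during training. A minimal Python sketch of the same logic follows, assuming the per-image patient IDs are available as a list. `split_patients` is a hypothetical helper, and unlike the R code it uses a fixed seed.

```python
import random
from collections import Counter
from typing import Dict, List


def split_patients(
    patient_ids: List[str],
    n_test: int = 170,
    n_tune: int = 170,
    min_lesions: int = 4,
    seed: int = 0,
) -> Dict[str, str]:
    """Assign each eligible patient (not each image) to exactly one subset.

    patient_ids holds one entry per image/lesion; patients with fewer than
    min_lesions images are dropped, mirroring the filter(n >= 4) step in R.
    """
    counts = Counter(patient_ids)
    # sort first so that the shuffle depends only on the seed, not on input order
    eligible = sorted(pid for pid, n in counts.items() if n >= min_lesions)
    random.Random(seed).shuffle(eligible)

    subset: Dict[str, str] = {}
    for i, pid in enumerate(eligible):
        if i < n_test:
            subset[pid] = "test"
        elif i < n_test + n_tune:
            subset[pid] = "tune"
        else:
            subset[pid] = "trainval"
    return subset


# Example: ten patients with four images each, plus one patient with a single image.
image_patients = [f"p{i}" for i in range(10) for _ in range(4)] + ["px"]
subsets = split_patients(image_patients, n_test=2, n_tune=3)
# Per-image labels; None marks images of dropped patients.
image_subsets = [subsets.get(pid) for pid in image_patients]
```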
  2. Step in Python:
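import os

import torch
import torchvision
from tqdm import tqdm

images = []
files = []
# resize every image to 128x128 pixels
tx = torchvision.transforms.Resize((128, 128))
# sort the listing so the order of the saved tensor is reproducible
for f in tqdm(sorted(os.listdir("train"))):
    # read_image returns a uint8 tensor of shape (C, H, W)
    img = torchvision.io.read_image(os.path.join("train", f))
    # scale pixel values to [0, 1] before resizing
    images.append(tx(img.float() / 255))
    files.append(f)

# save the file names together with one stacked tensor of shape (N, C, 128, 128)
torch.save({
    'names': files,
    'images': torch.stack(images),
}, 'x_train_resized_normalized.pt')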
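The saved tensor is stacked in the order of the file names. Joining it with the table from the R step therefore requires mapping each image ID to a row of the tensor. A minimal sketch follows, assuming that the file names are the table's image IDs plus a file extension (worth checking against the actual data). `build_image_index` is a hypothetical helper, and the IDs in the example are toy values.

```python
import os
from typing import Dict, List


def build_image_index(names: List[str]) -> Dict[str, int]:
    """Map each image ID (file name without extension) to its row in the stacked tensor."""
    index: Dict[str, int] = {}
    for pos, name in enumerate(names):
        image_id, _ext = os.path.splitext(name)
        if image_id in index:
            raise ValueError(f"duplicate image id {image_id!r}")
        index[image_id] = pos
    return index


# Toy example: file order on disk differs from row order in the table.
names = ["img_2.jpg", "img_1.jpg"]
table_ids = ["img_1", "img_2"]
index = build_image_index(names)
positions = [index[i] for i in table_ids]
```

With the saved dictionary loaded via `torch.load`, indexing the `images` tensor with `positions` reorders it to match the table rows.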
cxzhang4 commented 1 day ago

The eventual data representation is a single table that holds both the images and the tabular features:

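library(mlr3pipelines)
library(mlr3torch)

# image branch: the lazy-tensor input is flattened before the linear layer (16 is an arbitrary width)
entry1 = po("torch_ingress_ltnsr") %>>% po("nn_flatten") %>>% po("nn_linear", id = "nn_linear_1", out_features = 16)
# tabular branch: out_features must match the image branch so the outputs can be summed
entry2 = po("torch_ingress_num") %>>% po("nn_linear", id = "nn_linear_2", out_features = 16)

# nn_merge_sum takes multiple inputs, so it handles multimodal data
graph = gunion(list(entry1, entry2)) %>>% po("nn_merge_sum")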
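Conceptually, the sum merge adds the outputs of its input branches element-wise, so each branch must first be mapped to the same number of features (here by a linear layer). The following framework-free Python sketch shows this for a single sample. `linear` and `merge_sum` are illustrative helpers, not mlr3torch or torch functions, and the weights are toy values.

```python
from typing import List, Sequence


def linear(x: Sequence[float], weight: List[List[float]], bias: Sequence[float]) -> List[float]:
    """y = W x + b for a single sample; weight has shape (out_features, in_features)."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b for row, b in zip(weight, bias)]


def merge_sum(*branches: Sequence[float]) -> List[float]:
    """Element-wise sum of branch outputs; all branches must have the same length."""
    if len({len(b) for b in branches}) != 1:
        raise ValueError("all branches must produce outputs of the same shape")
    return [sum(vals) for vals in zip(*branches)]


# Image branch: 3 flattened pixel features -> 2 outputs.
img = [1.0, 0.0, 2.0]
h_img = linear(img, [[1, 0, 0], [0, 1, 1]], [0, 0])
# Tabular branch: 2 numeric features -> 2 outputs.
tab = [0.5, -1.0]
h_tab = linear(tab, [[2, 0], [0, 1]], [0, 0.5])
merged = merge_sum(h_img, h_tab)
```

Summing (rather than concatenating) keeps the merged width equal to the branch width, at the cost of requiring every branch to produce that width.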

For more fine-grained control, the graph can be built step by step, which looks something like this:

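graph = Graph$new()
graph$add_pipeop(po("torch_ingress_num"))
graph$add_pipeop(po("nn_linear", id = "nn_linear_2", out_features = 16))
# connect the PipeOps explicitly by their IDs
graph$add_edge("torch_ingress_num", "nn_linear_2")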