
tf dataset with torch model slower ? (torch backend) #18873

BenjaminDug commented 5 months ago

Hello,

First, I want to thank you for the release of this amazing framework. I am a TensorFlow user, but sometimes I have no choice but to use PyTorch, and I don't like the DataLoader (I prefer tf.data.Dataset, which is, in my opinion, the fastest option when there is a lot of I/O).

I decided to implement a very simple case on MNIST where I use different backends with the same data pipeline, plus a torch model trained with a DataLoader:

1 - jax backend with tf.data.Dataset - model is full Keras layers - virtualenv from your requirement-jax-gpu.txt
2 - tensorflow backend with tf.data.Dataset - model is full Keras layers - virtualenv from your requirement-tensorflow-gpu.txt
3 - torch backend with tf.data.Dataset - model is an nn.Module built in the __init__ of a keras.Model - virtualenv from your requirement-torch-gpu.txt
4 - torch backend with a DataLoader and a torch training loop (no Keras here) - virtualenv from your requirement-torch-gpu.txt

hardware: gpu rtx 2070 - cpu i5 9400f

install: ubuntu 22.04 - cuda 12.2 - cudnn 8.9.5.30 - driver 535.129.03

The code for cases 1, 2, and 3 is just below:

import os

os.environ["KERAS_BACKEND"] = "torch"
import keras
from keras.layers import TorchModuleWrapper
import torch
from torch import nn
import torch.nn.functional as F
import tensorflow as tf

from datetime import datetime
from mlxtend.data import loadlocal_mnist
import numpy as np

def load_mnist():
    X_train_full, y_train_full = loadlocal_mnist(
        images_path='/media/benjamin/shared/AIproject/keras_core/data/train-images.idx3-ubyte',
        labels_path='/media/benjamin/shared/AIproject/keras_core/data/train-labels.idx1-ubyte')
    X_test_full, y_test_full = loadlocal_mnist(
        images_path='/media/benjamin/shared/AIproject/keras_core/data/t10k-images.idx3-ubyte',
        labels_path='/media/benjamin/shared/AIproject/keras_core/data/t10k-labels.idx1-ubyte')

    X_train_full = np.reshape(X_train_full, (60000, 28, 28, 1)) / 255
    X_test_full = np.reshape(X_test_full, (10000, 28, 28, 1)) / 255

    if "torch" in os.environ['KERAS_BACKEND']:
        X_train_full=np.transpose(X_train_full,(0,3,1,2))
        X_test_full=np.transpose(X_test_full,(0,3,1,2))
    return X_train_full, y_train_full[:, np.newaxis], X_test_full, y_test_full[:, np.newaxis]

def build_dataset():
    Xtrain, ytrain, Xtest, ytest = load_mnist()

    train = tf.data.Dataset.from_tensor_slices((Xtrain, ytrain))
    test = tf.data.Dataset.from_tensor_slices((Xtest, ytest))

    return train, test

def convnet_keras():
    learning_rate = 1e-3

    input = keras.layers.Input(shape=[28, 28, 1])

    net = keras.layers.Conv2D(filters=32,
                              use_bias=True,
                              kernel_size=3,
                              strides=2,
                              padding='SAME')(input)

    net = keras.layers.ReLU()(net)

    net = keras.layers.Conv2D(filters=64,
                              use_bias=True,
                              kernel_size=3,
                              strides=2,
                              padding='SAME')(net)

    net = keras.layers.ReLU()(net)

    net = keras.layers.Flatten()(net)
    logit_out = keras.layers.Dense(10)(net)

    model = keras.models.Model(input, logit_out)
    model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=keras.optimizers.Nadam(learning_rate=learning_rate),
                  metrics=[keras.metrics.sparse_categorical_accuracy])

    callbacks = []

    return model, callbacks

class Classifier(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Wrap `torch.nn.Module`s with `TorchModuleWrapper`
        # if they contain parameters
        self.conv1 = nn.Conv2d(in_channels=1,stride=2, out_channels=32, kernel_size=(3, 3))

        self.conv2 = nn.Conv2d(in_channels=32,stride=2, out_channels=64, kernel_size=(3, 3))

        # self.pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.flatten = nn.Flatten()
        # self.dropout = nn.Dropout(p=0.5)
        self.fc = nn.Linear(2304, 10)

    def call(self, inputs):
        x = F.relu(self.conv1(inputs))
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.fc(x)
        return x

def convnet_vanille_torchbackend_class():
    learning_rate = 1e-3

    model = Classifier()

    model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=keras.optimizers.Nadam(learning_rate=learning_rate),
                  metrics=[keras.metrics.sparse_categorical_accuracy])
    model.build(input_shape=(1, 28, 28))
    callbacks = []

    return model, callbacks

if __name__ == "__main__":
    print(f"keras version: {keras.__version__}")
    print(f"tf version: {tf.__version__}")
    print(f"torch version: {torch.__version__}")

    train, valid = build_dataset()

    train = train.repeat().batch(64).prefetch(10)
    valid = valid.repeat(1).batch(64).prefetch(10).cache()

    model, cb = convnet_keras()
    if "torch" in os.environ['KERAS_BACKEND']:
        model, cb = convnet_vanille_torchbackend_class()

    logdir = os.path.join('./log/', datetime.now().strftime('%Y%m%d_%H%M'))
    os.makedirs(logdir)

    main_cb = [
        keras.callbacks.TensorBoard(logdir),
        keras.callbacks.ModelCheckpoint(logdir + '/bestconv_loss.keras', save_best_only=True),
    ]

    model.summary()
    print("training part...")
    model.fit(train,
              validation_data=valid,
              steps_per_epoch=938,
              callbacks=cb + main_cb,
              epochs=10000,
              )

The code for case 4 is just below:

import torch
import argparse
import idx2numpy
import os
import gzip
import glob
import logging
import yaml

import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

def get_cmd() -> argparse.Namespace:
    """Get command line parameters.

    Returns:
        params (argparse): argparse object containing all command line parameters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config",
                        dest="config_path",
                        required=True,
                        help="Path of config file to load.")
    parser.add_argument("--data",
                        dest="data_path",
                        required=True)
    parser.add_argument("--save-model",
                        action=argparse.BooleanOptionalAction,
                        dest="save_model",
                        help="Save the model.")
    params = parser.parse_args()

    return params

def load_config(config_path: str = "config.yaml") -> dict:
    """Load the config file. Should be a .yaml file.

    Args:
        config_path (str): path to the config file. Defaults to "config.yaml".

    Returns:
        config (dict): dictionary containing all parameters in the config file.
    """
    logging.info("Checking config file")
    with open(config_path) as inf:
        try:
            config = yaml.safe_load(inf)
            logging.info("Config file loaded sucessfully")
        except yaml.YAMLError as exc:
            logging.info(exc)

    return config

class MNIST(Dataset):
    """Custom class for MNIST dataset.
    """
    def __init__(self, data_path: str, transform: transforms.Compose = None):
        """Initialisation of the dataset.

        Args:
            data_path (str): path to the folder containing the MNIST data.
            transform (transforms.Compose, optional): transformations to apply to MNIST images. Defaults to None.
        """
        super(MNIST, self).__init__()
        self.transform = transform
        self.images, self.labels = self._load_data(data_path)

    def _load_data(self, data_path: str) -> tuple[np.ndarray, np.ndarray]:
        """Load MNIST data located in data_path.

        Args:
            data_path (str): path to the folder containing the MNIST data.

        Returns:
            images (numpy.array): images of the MNIST dataset.
            labels (numpy.array): labels of the MNIST dataset.
        """
        images_path = glob.glob(os.path.join(data_path, "train-images*"))[0]
        labels_path = glob.glob(os.path.join(data_path, "train-labels*"))[0]

        if os.path.splitext(images_path)[-1] == '.gz':
            with gzip.open(images_path, 'r') as f:
                images = idx2numpy.convert_from_file(f)
        else:
            images = idx2numpy.convert_from_file(images_path)

        if os.path.splitext(labels_path)[-1] == '.gz':
            with gzip.open(labels_path, 'r') as f:
                labels = idx2numpy.convert_from_file(f)
        else:
            labels = idx2numpy.convert_from_file(labels_path)

        return images, labels

    def __getitem__(self, index: int):
        """Get a single couple of (image, label) from the dataset.

        Args:
            index (int): index of the element in the dataset.

        Returns:
            image: image of the dataset corresponding to the index.
            label: label of the dataset corresponding to the index.
        """
        image, label = self.images[index], int(self.labels[index])

        image = Image.fromarray(image, mode='L')

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def __len__(self) -> int:
        """Get the length of the dataset.

        Returns:
            int: length of the dataset.
        """
        return len(self.labels)

class CNN(nn.Module):
    """Custom class for CNN model.
    """
    def __init__(self):
        """Iinitialisation of the CNN model.
        """
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(2304, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward method of the CNN model.

        Args:
            x (torch.tensor): inputs of the model (images of the MNIST dataset).

        Returns:
            output (torch.tensor): model prediction.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        output = F.log_softmax(x, dim=1)

        return output

def train(model: CNN,
          train_loader: DataLoader,
          val_loader: DataLoader,
          optimizer: torch.optim.Optimizer,
          epoch: int,
          device: str):
    """Training loop.

    Args:
        model (CNN): model to train, in this case the CNN.
        train_loader (DataLoader): training dataloader.
        val_loader (DataLoader): validation dataloader.
        optimizer (torch.optim.Optimizer): optimizer to use.
        epoch (int): current epoch.
        device (str): device to use.
    """
    # Training
    model.train()
    pbar = tqdm(total=len(train_loader) + len(val_loader),
                desc=f"Epoch {epoch}",
                bar_format="{l_bar}{bar:25}{r_bar}{bar:-5b}",
                unit='batch')

    train_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        pbar.set_postfix(train_loss=train_loss / (batch_idx + 1))
        pbar.update()

    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            val_loss += F.nll_loss(output, target, reduction="sum").item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            pbar.update()
    val_loss /= len(val_loader.sampler)
    accuracy = 100 * correct / len(val_loader.sampler)
    pbar.set_postfix(train_loss=train_loss / (batch_idx + 1), val_loss=val_loss, accuracy=f"{accuracy:.2f}%")

if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO)

    # Get command line parameters
    params = get_cmd()

    # Get config
    config = load_config(params.config_path)

    # Get device
    device = torch.device(config["train_torch"]["device"])

    # Create transformations
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load dataset
    dataset = MNIST(params.data_path, transform=transform)
    logging.info("Dataset loaded sucessfully")

    # Split dataset
    indices = list(range(len(dataset)))
    split = int(np.floor(config["train_torch"]["val_split"] * len(dataset)))
    np.random.seed(27)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    val_sampler = SubsetRandomSampler(val_indices)

    # Create dataloaders
    batch_size = config["train_torch"]["batch_size"]
    num_workers = config["train_torch"]["num_workers"]
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              pin_memory=True if device == "cuda" else False,  # Accelerate transfer from cpu to GPU
                              sampler=train_sampler)
    val_loader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            pin_memory=True if device == "cuda" else False,  # Accelerate transfer from cpu to GPU
                            sampler=val_sampler)

    # Load model
    model = CNN().to(device)

    # Load optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=config["train_torch"]["lr"])
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=config["train_torch"]["gamma"])

    # Train model
    logging.info("Start of training")
    for epoch in range(1, config["train_torch"]["epochs"] + 1):
        train(model, train_loader, val_loader, optimizer, epoch, device)
        scheduler.step()
    logging.info("Training done")

    # Save model if parameter is true
    if params.save_model:
        torch.save(model.state_dict(), "MNIST_CNN.pt")
        logging.info("Model saved sucessfully")

I have the following stdout for case 1 (jax backend with tf.data.Dataset - model is full Keras layers - virtualenv from your requirement-jax-gpu.txt):

Total params: 50,186 (196.04 KB)
 Trainable params: 50,186 (196.04 KB)
 Non-trainable params: 0 (0.00 B)
training part...
2023-12-02 21:01:06.822887: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 376320000 exceeds 10% of free system memory.
Epoch 1/10000
906/938 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.5088 - sparse_categorical_accuracy: 0.86092023-12-02 21:01:10.700383: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
938/938 ━━━━━━━━━━━━━━━━━━━━ 4s 2ms/step - loss: 0.4999 - sparse_categorical_accuracy: 0.8633 - val_loss: 0.0960 - val_sparse_categorical_accuracy: 0.9698
Epoch 2/10000
908/938 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0936 - sparse_categorical_accuracy: 0.97272023-12-02 21:01:12.094358: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
938/938 ━━━━━━━━━━━━━━━━━━━━ 1s 1ms/step - loss: 0.0933 - sparse_categorical_accuracy: 0.9728 - val_loss: 0.0647 - val_sparse_categorical_accuracy: 0.9772
Epoch 3/10000
917/938 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.0645 - sparse_categorical_accuracy: 0.98062023-12-02 21:01:13.515612: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence

I have the following stdout for case 2 (tensorflow backend with tf.data.Dataset - model is full Keras layers - virtualenv from your requirement-tensorflow-gpu.txt):

Total params: 50,186 (196.04 KB)
 Trainable params: 50,186 (196.04 KB)
 Non-trainable params: 0 (0.00 B)
training part...
2023-12-02 21:01:52.509621: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
Epoch 1/10000
2023-12-02 21:01:52.716252: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 376320000 exceeds 10% of free system memory.
2023-12-02 21:01:53.488823: I external/local_xla/xla/service/service.cc:144] XLA service 0x7f6d6c008a90 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2023-12-02 21:01:53.488862: I external/local_xla/xla/service/service.cc:152]   StreamExecutor device (0): NVIDIA GeForce RTX 2070, Compute Capability 7.5
2023-12-02 21:01:53.509684: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2023-12-02 21:01:53.598504: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:456] Loaded cuDNN version 8906
2023-12-02 21:01:54.237403: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1701547314.626129   10031 device_compiler.h:187] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
938/938 ━━━━━━━━━━━━━━━━━━━━ 5s 4ms/step - loss: 0.5074 - sparse_categorical_accuracy: 0.8689 - val_loss: 0.1133 - val_sparse_categorical_accuracy: 0.9639
Epoch 2/10000
938/938 ━━━━━━━━━━━━━━━━━━━━ 3s 3ms/step - loss: 0.0946 - sparse_categorical_accuracy: 0.9723 - val_loss: 0.0774 - val_sparse_categorical_accuracy: 0.9743
Epoch 3/10000
938/938 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 0.0650 - sparse_categorical_accuracy: 0.9803 - val_loss: 0.0594 - val_sparse_categorical_accuracy: 0.9804
Epoch 4/10000

I have the following stdout for case 3 (torch backend with tf.data.Dataset - model is an nn.Module in the __init__ of a keras.Model - virtualenv from your requirement-torch-gpu.txt):

Total params: 41,866 (163.54 KB)
 Trainable params: 41,866 (163.54 KB)
 Non-trainable params: 0 (0.00 B)
training part...
2023-12-02 20:57:34.754596: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 376320000 exceeds 10% of free system memory.
Epoch 1/10000
935/938 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.4942 - sparse_categorical_accuracy: 0.86022023-12-02 20:57:40.420302: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
938/938 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - loss: 0.4932 - sparse_categorical_accuracy: 0.8605 - val_loss: 0.1068 - val_sparse_categorical_accuracy: 0.9655
Epoch 2/10000
938/938 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.0972 - sparse_categorical_accuracy: 0.97082023-12-02 20:57:45.591035: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
938/938 ━━━━━━━━━━━━━━━━━━━━ 5s 6ms/step - loss: 0.0972 - sparse_categorical_accuracy: 0.9708 - val_loss: 0.0682 - val_sparse_categorical_accuracy: 0.9769

Of course, I changed the os.environ['KERAS_BACKEND'] for each backend.

We can see that JAX is pretty fast (about 1 s per epoch) and TensorFlow is good too (about 2 s per epoch), but torch with tf.data.Dataset is much slower (about 5 s per epoch).

I have the following stdout for case 4 (torch backend with a DataLoader and a torch training loop, no Keras here - virtualenv from your requirement-torch-gpu.txt):

2023-12-02 21:05:29 - INFO - Checking config file
2023-12-02 21:05:29 - INFO - Config file loaded successfully
2023-12-02 21:05:29 - INFO - Dataset loaded successfully
2023-12-02 21:05:29 - INFO - Start of training
Epoch 1: 100%|█████████████████████████| 938/938 [00:02<00:00, 342.97batch/s, accuracy=97.28%, train_
Epoch 2: 100%|█████████████████████████| 938/938 [00:02<00:00, 365.75batch/s, accuracy=97.84%, train_
Epoch 3: 100%|█████████████████████████| 938/938 [00:02<00:00, 373.15batch/s, accuracy=98.11%, train_
Epoch 4:   8%|█▉                       | 75/938 [00:00<00:04, 175.00batch/s, train_loss=0.0261]  

We can see that the torch model with a DataLoader and a torch training loop runs at around 2 s per epoch, like TensorFlow with the tf.data pipeline. The latter is as fast as the DataLoader here because everything is already in CPU RAM; I know that tf.data really shines when the data comes from TFRecords.

I know that the torch model is a bit different (padding SAME in TensorFlow versus VALID in torch) and the channel order differs between TensorFlow and torch. But the comparison that matters is between the same torch model fed by two different pipelines.

Why is the torch backend with a torch model in Keras and a tf.data.Dataset slower than the same model with a DataLoader in a plain PyTorch training loop?

Is there some kind of conversion between the tf.data.Dataset outputs (tf.Tensor, maybe?) and the model inputs, which are torch tensors?

Maybe I have done something wrong in the code? Or maybe this case is too simple to make a fair comparison?

I really want to be able to use tf.data.Dataset with the torch backend, the way I can with TensorFlow.

Thank you for your help.

fchollet commented 5 months ago

Most likely this has to do with GPU memory prefetching. Both tf.data and torch DataLoader can do prefetching, but when using a different backend they have to convert to the right tensor type in CPU memory, which cancels the benefits of prefetching. I believe this could be optimized further. @haifeng-jin to advise.
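
Roughly, the per-batch hop looks like this (an illustrative sketch only, not the actual Keras internals):

import tensorflow as tf
import torch

# stand-in tf.data pipeline, shaped like the one in the issue (channels-first for torch)
tf_dataset = tf.data.Dataset.from_tensor_slices(
    tf.zeros((640, 1, 28, 28))).batch(64).prefetch(10)

for tf_batch in tf_dataset:                   # tf.data prefetches tf.Tensors on the host
    np_batch = tf_batch.numpy()               # tf.Tensor -> numpy (host copy)
    torch_batch = torch.from_numpy(np_batch)  # numpy -> torch (shares host memory)
    torch_batch = torch_batch.to("cuda")      # synchronous host-to-device copy from
                                              # pageable memory, not overlapped with compute
    # ... forward/backward on torch_batch ...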

BenjaminDug commented 5 months ago

I tried removing the prefetch from the tf.data pipeline; the epoch time goes from 5 s to 6 s with the torch backend.

I observe that my CPU cores are maxed out with the torch backend and the tf.data pipeline:

[screenshot: CPU usage with the torch backend and the tf.data pipeline]

Below is the same picture for the tensorflow backend with the same tf.data pipeline:

[screenshot: CPU usage with the tensorflow backend and the same tf.data pipeline]

haifeng-jin commented 4 months ago

I will look into it. BTW, model.fit() should work with torch DataLoader directly. @BenjaminDug
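
For example, something along these lines should work (an untested sketch, reusing the load_mnist and convnet_vanille_torchbackend_class helpers from your script above):

import torch
from torch.utils.data import TensorDataset, DataLoader

Xtrain, ytrain, Xtest, ytest = load_mnist()
train_ds = TensorDataset(torch.from_numpy(Xtrain).float(),
                         torch.from_numpy(ytrain).squeeze(-1).long())
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,
                          num_workers=2, pin_memory=True)

model, cb = convnet_vanille_torchbackend_class()
model.fit(train_loader, epochs=3, callbacks=cb)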

BenjaminDug commented 4 months ago

Thank you !

Yes, I know that with Keras 3.0 we can pass a DataLoader to .fit(), but in my use case I have TFRecords and I need to use a torch model. With TensorFlow, the best pipeline is tf.data, so my hope with Keras 3.0 is to use tf.data.Dataset with a torch model to load my TFRecords efficiently on the torch backend.

I hope the data stays at a low level and that no Python-level conversion is needed for the data coming out of tf.data. Some time ago, I wrote a DataLoader that loaded TFRecords and converted the tf.Tensor data to torch.Tensor; it was really slow because of that conversion in Python, so I had to give up TFRecords at the time.

haohuanw commented 4 months ago

@BenjaminDug my team had the same setup, and what we found out is that you need to make sure the numpy->torch copy overlaps with the compute. tf does this natively with tf.data, but if you go through tf.data.Dataset.as_numpy_iterator() for torch, you need to handle the to-GPU buffering yourself.

A code snippet would look like the one below:

from concurrent.futures import ThreadPoolExecutor

def pin_mem_fn(b: dict[str, torch.Tensor]):
    # pin host memory so the later copy to the GPU can be asynchronous
    return {k: v.pin_memory() for k, v in b.items()}

executor = ThreadPoolExecutor(max_workers=1)
it = iter(...)  # iterator yielding dicts of (CPU) torch.Tensor batches

future_batch = executor.submit(pin_mem_fn, next(it))
for step in ...:
    batch = future_batch.result()
    future_batch = executor.submit(pin_mem_fn, next(it))
    # <do work with batch>

This is similar to what the torch DataLoader does internally, but it has to be done manually if you are not using one. If you are okay with using a torch DataLoader, you can also wrap the tf.data.Dataset in a torch IterableDataset and then use a DataLoader on top, as in the sketch below.
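
Untested sketch of that second option (the wrapper class name is illustrative, and train is the already-batched tf.data.Dataset from the script above):

import torch
from torch.utils.data import IterableDataset, DataLoader

class TFDatasetWrapper(IterableDataset):
    """Expose an already-batched tf.data.Dataset as a torch IterableDataset."""

    def __init__(self, tf_dataset):
        self.tf_dataset = tf_dataset

    def __iter__(self):
        for x, y in self.tf_dataset.as_numpy_iterator():
            yield torch.from_numpy(x), torch.from_numpy(y)

# batch_size=None because the tf.data pipeline already batches;
# num_workers=0 because tf.data does its own parallelism; pin_memory for async H2D copies
loader = DataLoader(TFDatasetWrapper(train), batch_size=None,
                    num_workers=0, pin_memory=True)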

haifeng-jin commented 3 weeks ago

I believe @hertschuh has more insights on this issue. It was either resolved already or very hard to resolve.

Assigning to @hertschuh temporarily. Feel free to assign it back.

hertschuh commented 3 weeks ago

@BenjaminDug ,

I did some rework of the "DataAdapters" after this bug was created; however, I don't believe the performance of your specific use case has changed. After doing some benchmarking, I came to the conclusion that there is no easy way to improve the performance of feeding a tf.data.Dataset to a Torch model. The slowdown comes mostly from the copying of tensors.

The bottom line is that you get better performance with a Torch DataLoader.

github-actions[bot] commented 1 week ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.