keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Issue `keras.utils.split_dataset` #637

Closed: innat closed this issue 1 year ago

innat commented 1 year ago

keras.utils.split_dataset returns a _PrefetchDataset, which I am thus unable to pass to the torch.utils.data.TensorDataset API.

!pip install keras-core -q
import os
os.environ["KERAS_BACKEND"] = "torch"

import keras_core as keras
(x_train, y_train), held_out = keras.datasets.cifar10.load_data()
test_set, val_set = keras.utils.split_dataset(held_out, left_size=0.2)
test_set

<_PrefetchDataset element_spec=(TensorSpec(shape=(32, 32, 3), dtype=tf.uint8, name=None), TensorSpec(shape=(1,), dtype=tf.uint8, name=None))>

fchollet commented 1 year ago

thus unable to pass to torch.utils.data.TensorDataset API.

Can you expand on this? What's this API and why can it not use a prefetched dataset?

The reason we are prefetching the datasets is that it represents a huge potential performance issue, and most users forget to do it themselves after getting the dataset.

innat commented 1 year ago

Thanks for the prompt response.

Can you expand on this?

I need to build a torch data loader, that's why I need to use torch.utils.data.TensorDataset API.

The reason we are prefetching the datasets is that it represents a huge potential performance issue, and most users forget to do it themselves after getting the dataset.

Agreed. But what if I don't have TensorFlow installed, or am not using it (using the torch or jax backend instead)? This split API uses tf functions.


Background: I am translating starter code written in torch (source). The provider will host a Kaggle competition, but for some reason they stated that a submission will be valid only if the data loader and model are written in torch code (source). I am now trying to translate their torch starter code as faithfully as possible to Keras-core with the torch backend (previously done with tf.keras).

fchollet commented 1 year ago

I need to build a torch data loader

The simplest way to get a torch data loader from a tf.data dataset is likely to write a torch Dataset subclass that yields samples from the tf.data dataset in __getitem__.
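A minimal sketch of that suggestion, under stated assumptions: the class name `IterableBackedDataset` is hypothetical, and the whole dataset is materialized in memory because a tf.data dataset cannot be indexed at random. The input here is a plain list standing in for an unbatched tf.data dataset; with TensorFlow installed, passing `ds.as_numpy_iterator()` should play the same role. The guarded import is only there so the sketch also runs where torch is missing.

```python
import numpy as np

try:
    from torch.utils.data import Dataset  # real torch base class when available
except ImportError:
    Dataset = object  # assumption: fallback so the sketch runs without torch


class IterableBackedDataset(Dataset):
    """Hypothetical map-style wrapper around any iterable of (x, y) samples."""

    def __init__(self, samples):
        # Materialize once so samples can be looked up by index.
        # For a tf.data.Dataset, pass `ds.as_numpy_iterator()` here.
        self._samples = [(np.asarray(x), np.asarray(y)) for x, y in samples]

    def __len__(self):
        return len(self._samples)

    def __getitem__(self, index):
        # Return one (image, label) pair as NumPy arrays.
        return self._samples[index]


# Stand-in for an unbatched tf.data dataset: 10 fake CIFAR-sized samples.
fake_samples = [
    (np.zeros((32, 32, 3), dtype=np.uint8), np.array([i % 10], dtype=np.uint8))
    for i in range(10)
]
wrapped = IterableBackedDataset(fake_samples)
```

An object like `wrapped` can then be handed to `torch.utils.data.DataLoader(wrapped, batch_size=...)`, whose default collation should turn the NumPy arrays into batched tensors. For datasets too large to hold in memory, an iterable-style dataset that streams samples would likely be the better fit.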

Only accepting submissions in PyTorch seems odd to me -- tf.data is significantly faster and safer (no Python multiprocessing, no random segfaults), as well as portable (you can run in a Python-less environment). Plus, since the competition is LLM related: all production-grade LLMs are written in JAX these days (e.g. those from Google, DeepMind, Cohere, Anthropic, Apple -- meanwhile OpenAI uses its own custom CUDA stack). This is unserious.

innat commented 1 year ago

The simplest way to get a torch data loader from a tf.data dataset is likely to write a torch Dataset subclass that yields samples from the tf.data dataset in __getitem__.

My goal was to use APIs at the same level, if possible, across the reimplementation notebook. For example, torch.utils.data.random_split in torch and tf.keras.utils.split_dataset in tf.keras.

Only accepting submissions in PyTorch seems odd to me

This is uncommon in Kaggle competitions; only a few make such requirements. For example, this google-asl competition asks for submissions in tflite format regardless of the framework people use.

But I have now reimplemented the official torch starter in Keras-core with the torch backend. Here is the updated notebook. Their main concern (the unlearning method) is covered in this notebook.

The competition is LLM related

I think this competition is CV related. Details.


Any thoughts on the following cases? Now, I think, if I use the torch backend, I can use torch.utils.data.random_split, and keras.utils.split_dataset for the tf backend. Not sure if there's anything in the jax API. But the keras.utils.split_dataset API could be general, like model_selection.train_test_split.

what if I don't have TensorFlow installed, or am not using it (using the torch or jax backend instead)? This split API uses tf functions.

fchollet commented 1 year ago

Any thoughts on the following cases? Now, I think, if I use the torch backend, I can use torch.utils.data.random_split, and keras.utils.split_dataset for the tf backend. Not sure if there's anything in the jax API.

For simplicity, perhaps you can stick to using torch data APIs only -- a torch DataLoader should work with a Keras model with any backend.

JAX does not have its own data loading API (JAX users canonically use tf.data). And Keras is intended to be agnostic to what data loading API you use.

But keras.utils.split_dataset API could be general, like model_selection.train_test_split

General as in, able to return either a tf.data Dataset or a torch DataLoader? I think we could do that, by adding a return type argument (defaulting to tf.data).
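To make the "general, like train_test_split" idea concrete, here is a rough sketch of a split that touches neither tf.data nor torch. It is not the Keras implementation and does not add the return type argument discussed above; `split_arrays` and its parameters are hypothetical names chosen for illustration, and it assumes the data already sits in equally long NumPy arrays.

```python
import numpy as np


def split_arrays(arrays, left_size=0.2, shuffle=False, seed=None):
    """Hypothetical backend-agnostic split; not the Keras implementation.

    `arrays` is a tuple of equally long NumPy arrays, e.g. (x, y).
    Returns (left, right), each a tuple of arrays in the same order.
    """
    n = len(arrays[0])
    if any(len(a) != n for a in arrays):
        raise ValueError("all arrays must have the same length")
    if not 0.0 < left_size < 1.0:
        raise ValueError("left_size must be a fraction in (0, 1)")

    indices = np.arange(n)
    if shuffle:
        # Seeded generator so a shuffled split can be reproduced.
        np.random.default_rng(seed).shuffle(indices)

    n_left = int(n * left_size)
    left_idx, right_idx = indices[:n_left], indices[n_left:]
    left = tuple(a[left_idx] for a in arrays)
    right = tuple(a[right_idx] for a in arrays)
    return left, right


# Toy data: 50 samples with 2 features each, plus integer labels.
x = np.arange(100).reshape(50, 2)
y = np.arange(50)
(x_left, y_left), (x_right, y_right) = split_arrays((x, y), left_size=0.2)
```

Because the halves come back as plain NumPy arrays, the caller decides what to build from them: `tf.data.Dataset.from_tensor_slices` on the tf side, or `torch.utils.data.TensorDataset` after converting each array with `torch.from_numpy` on the torch side.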

innat commented 1 year ago

I think we could do that, by adding a return type argument (defaulting to tf.data).

That would be fantastic.

asingh9530 commented 1 year ago

General as in, able to return either a tf.data Dataset or a torch DataLoader? I think we could do that, by adding a return type argument (defaulting to tf.data).

@fchollet This is something I actually added in my last PR for split dataset. If needed, I can add it again. What do you think?