keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Add support for RaggedTensors to JAX and Pytorch Backend #844

Closed by IMvision12 1 year ago

IMvision12 commented 1 year ago

Object detection models in KerasCV require ragged tensors. The TensorFlow backend supports them, but the JAX and PyTorch backends do not.

@ianstenbit @fchollet @tirthasheshpatel

ianstenbit commented 1 year ago

Object detection (OD) does not require ragged tensors; in many cases we pad to a dense representation.
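Padding to a dense representation turns a batch of variable-length box lists into one fixed-shape array that every backend can handle. The sketch below uses NumPy as a stand-in for backend arrays. It assumes each image's boxes arrive as an `(N_i, 4)` array and that `-1` marks padded rows. The helper name `pad_boxes` and the sentinel value are hypothetical and are not KerasCV's API. A separate boolean mask records which rows are real, so a genuine coordinate equal to the sentinel is never mistaken for padding.

```python
import numpy as np


def pad_boxes(boxes_per_image, max_boxes=None, pad_value=-1.0):
    """Hypothetical helper: stack variable-length (N_i, 4) box arrays
    into a dense (B, M, 4) array plus a (B, M) validity mask."""
    if max_boxes is None:
        # Pad to the longest box list in the batch.
        max_boxes = max(len(b) for b in boxes_per_image)
    batch_size = len(boxes_per_image)
    dense = np.full((batch_size, max_boxes, 4), pad_value, dtype=np.float32)
    counts = np.zeros(batch_size, dtype=np.int64)
    for i, boxes in enumerate(boxes_per_image):
        boxes = np.asarray(boxes, dtype=np.float32).reshape(-1, 4)
        n = min(len(boxes), max_boxes)  # truncate if an image has too many boxes
        dense[i, :n] = boxes[:n]
        counts[i] = n
    # mask[i, j] is True where row j of image i holds a real box.
    mask = np.arange(max_boxes)[None, :] < counts[:, None]
    return dense, mask


images = [
    [[10, 10, 50, 50], [20, 30, 80, 90]],  # 2 boxes
    [[5, 5, 15, 15]],                      # 1 box
    [],                                    # no boxes
]
dense, mask = pad_boxes(images)
print(dense.shape)        # (3, 2, 4)
print(mask.sum(axis=1))   # [2 1 0]
```

Because the padded array has a static shape, it also suits JAX's tracing and compilation model, which expects fixed array shapes.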

JAX does not support ragged arrays, and I'm not aware of a strong PyTorch convention for them either, so for now ragged support is limited to TensorFlow, the backend that offers ragged tensors natively.
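A backend without native ragged arrays can still carry ragged data as two dense arrays: one flat array of values and a `row_splits` array of offsets. `tf.RaggedTensor` uses the same row partitioning. The following is a minimal sketch with NumPy stand-ins. The helper names `to_row_splits`, `get_row`, and `row_sums` are hypothetical and are not part of Keras Core. The per-row reduction converts the splits into segment ids, which is the encoding that JAX's `jax.ops.segment_sum` also works with.

```python
import numpy as np


def to_row_splits(rows):
    """Hypothetical helper: encode variable-length rows as (flat_values, row_splits).

    Row i occupies flat_values[row_splits[i]:row_splits[i + 1]].
    """
    lengths = np.array([len(r) for r in rows], dtype=np.int64)
    row_splits = np.concatenate([[0], np.cumsum(lengths)]).astype(np.int64)
    if rows:
        flat_values = np.concatenate([np.asarray(r, dtype=np.float32) for r in rows])
    else:
        flat_values = np.zeros(0, dtype=np.float32)
    return flat_values, row_splits


def get_row(flat_values, row_splits, i):
    """Slice out row i using the offsets."""
    return flat_values[row_splits[i]:row_splits[i + 1]]


def row_sums(flat_values, row_splits):
    """Sum each row by mapping every value to its row (segment) id."""
    num_rows = len(row_splits) - 1
    lengths = np.diff(row_splits)
    segment_ids = np.repeat(np.arange(num_rows), lengths)
    # Empty rows get 0.0 because minlength covers every row.
    return np.bincount(segment_ids, weights=flat_values, minlength=num_rows)


rows = [[1.0, 2.0, 3.0], [], [4.0, 5.0]]
flat, splits = to_row_splits(rows)
print(splits)                       # [0 3 3 5]
print(get_row(flat, splits, 2))     # [4. 5.]
print(row_sums(flat, splits))       # [6. 0. 9.]
```

Supporting this generally would mean every op in every backend must understand the offsets, which is why padding is usually the simpler route.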