tensorflow / rust

Rust language bindings for TensorFlow
Apache License 2.0

String tensor vs utf8 encoding #422

Open sir-earl opened 8 months ago

sir-earl commented 8 months ago

I'm trying to use raw_ops::decode_image to load an image directly from a u8 slice (as opposed to from a file, as in the example), but it seems I must first convert the slice to a scalar string tensor.

It appears I can make this work with something like:

let s = unsafe { String::from_utf8_unchecked(image_bytes.to_vec()) };

My concern is that Rust expects all strings to be UTF-8 encoded, which the above certainly is not.

Am I missing something obvious? Is there a better way to approach this?
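To make the concern concrete, here is a minimal sketch using only the standard library. It shows that neither of the safe conversions can carry arbitrary image bytes: the checked one rejects them, and the lossy one rewrites them. The helper names fits_in_string and lossy_roundtrip_preserves are hypothetical and exist only for illustration.

```rust
/// Returns true if `bytes` could be stored in a Rust `String` without
/// violating the UTF-8 invariant. (Hypothetical helper for illustration.)
fn fits_in_string(bytes: &[u8]) -> bool {
    std::str::from_utf8(bytes).is_ok()
}

/// Returns true if a lossy conversion leaves the bytes unchanged.
/// Invalid sequences are replaced with U+FFFD, which changes the data.
/// (Hypothetical helper for illustration.)
fn lossy_roundtrip_preserves(bytes: &[u8]) -> bool {
    String::from_utf8_lossy(bytes).as_bytes() == bytes
}
```

For example, the byte 0xFF that starts a JPEG file never appears in valid UTF-8, so fits_in_string returns false for such data and the lossy conversion does not round-trip it.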

dskkato commented 8 months ago

There is no need to check that the data is valid UTF-8, since TensorFlow uses strings as byte buffer containers.

How about using raw_ops::read_file if you want to decode it using this TensorFlow wrapper?

https://github.com/tensorflow/rust/blob/master/examples%2Fmobilenetv3.rs#L38-L38

sir-earl commented 8 months ago

For my use case, the file is already in memory (received via network) so it would be wasteful to load it from disk with raw_ops::read_file.

My concern with putting non-UTF-8 bytes into a String is that Rust is likely to be unhappy, and may for instance return the wrong length, causing data corruption or other undefined behaviour.

It feels like using a different type to represent the string data type might be more sensible, especially given the hoops required to convert a Rust byte container to a Rust String.

dskkato commented 8 months ago

As indicated by the namespace raw_ops, this Op is merely a thin Rust wrapper around TensorFlow's functionality. For further details on this API, please refer to the following documentation:

https://www.tensorflow.org/api_docs/python/tf/raw_ops/DecodeImage

While it might be possible to wrap raw_ops to create a more Rust-like API, currently, nobody seems to have undertaken that effort.

adamcrume commented 8 months ago

I've tried to convert a rank-1 tensor of dtype=uint8 to a rank-0 tensor of dtype=string using Cast, ReduceJoin, and DecodeRaw, and they all fail. There doesn't seem to be any way to convert individual bytes to a string in TensorFlow (i.e. the inverse of BytesSplit).

I think we actually need to introduce either Tensor<Vec<u8>> or TString and Tensor<TString> (analogous to OsString or CString) and deprecate Tensor<String> since TensorFlow strings are not necessarily UTF-8.

Your concern about calling from_utf8_unchecked on something that is not valid UTF-8 is quite valid. The docs say:

Constructing a non-UTF-8 string slice is not immediate undefined behavior, but any function called on a string slice may assume that it is valid UTF-8, which means that a non-UTF-8 string slice can lead to undefined behavior down the road.
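As a rough sketch of what the TString idea proposed above might look like, the following newtype wraps a Vec<u8> and only exposes a &str view through a checked conversion, in the same spirit as CString and OsString. This is an assumption for illustration, not the tensorflow crate's API: the methods as_bytes, len, is_empty and to_str, and the From impls, are all hypothetical.

```rust
/// Hypothetical owned byte string for TensorFlow string tensors.
/// Not part of the tensorflow crate; a sketch of the proposal only.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct TString(Vec<u8>);

impl TString {
    /// Borrows the raw bytes, with no encoding assumed.
    pub fn as_bytes(&self) -> &[u8] {
        &self.0
    }

    /// Length in bytes.
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// True if the byte string is empty.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Checked view as `&str`; returns an error instead of assuming UTF-8.
    pub fn to_str(&self) -> Result<&str, std::str::Utf8Error> {
        std::str::from_utf8(&self.0)
    }
}

impl From<Vec<u8>> for TString {
    /// Takes ownership of arbitrary bytes without validation or copying.
    fn from(bytes: Vec<u8>) -> Self {
        TString(bytes)
    }
}

impl From<String> for TString {
    /// A valid `String` is always a valid byte string.
    fn from(s: String) -> Self {
        TString(s.into_bytes())
    }
}
```

With a type like this, image bytes received over the network could be wrapped with TString::from(bytes) and no unsafe code, while text remains recoverable through to_str when the contents happen to be valid UTF-8.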