keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

StringLookup does not return expected dtype for multi_hot #19660

Closed Toku11 closed 1 week ago

Toku11 commented 2 weeks ago

The documentation states that we should expect float32 when using 'multi_hot'; however, an int64 tensor is returned.

```python
import tensorflow as tf

vocab = ["a", "b", "c", "d"]
data = [["a", "c", "d", "d"], ["d", "z", "b", "z"]]
layer = tf.keras.layers.StringLookup(vocabulary=vocab, output_mode='multi_hot')
layer(data)
```

tf version: 2.16.1

Output: <tf.Tensor: shape=(2, 5), dtype=int64, numpy=array([[0, 1, 0, 1, 1], [1, 0, 1, 0, 1]])>

SuryanarayanaY commented 1 week ago

Hi @Toku11 ,

I have tested the code snippets from the API docs, and the dtype should be int64 when output_mode is one_hot, multi_hot, or count. The documentation needs to be changed to say int64 instead of float32. Thanks!
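For anyone whose downstream code needs float32 in the meantime, an explicit cast is one possible workaround. The snippet below is only a minimal sketch and not an official recommendation: it reuses the vocabulary and data from the report, and the variable names `encoded` and `encoded_f32` are made up for illustration. The dtype printed for the raw layer output may differ between TensorFlow/Keras versions, so no value is assumed for it here.

```python
import tensorflow as tf

vocab = ["a", "b", "c", "d"]
data = tf.constant([["a", "c", "d", "d"], ["d", "z", "b", "z"]])

layer = tf.keras.layers.StringLookup(vocabulary=vocab, output_mode="multi_hot")

# Raw layer output; the report above observed int64 on TF 2.16.1.
encoded = layer(data)
print("layer output dtype:", encoded.dtype)

# Hypothetical workaround: cast explicitly so later code does not
# depend on whichever dtype the layer happens to return.
encoded_f32 = tf.cast(encoded, tf.float32)
print("after cast:", encoded_f32.dtype)
```

The cast leaves the multi-hot values unchanged and only changes the dtype, so it should be safe to keep even if the layer's default output dtype changes in a later release.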