keras-team / keras-hub

Pretrained model hub for Keras 3
Apache License 2.0

Convert input dictionary to tensors during train_on_batch #1919

Closed by wenxindongwork 1 month ago

wenxindongwork commented 1 month ago

The causal_lm_preprocessor returns x as a dict of {"token_ids": Array, "padding_mask": Array}, so calling ops.convert_to_tensor(x) on the whole dict will fail. Using x = tree.map_structure(ops.convert_to_tensor, x) resolves the problem by converting each array in the dict individually.
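
To illustrate why the per-leaf conversion is needed, here is a minimal sketch in plain Python. It does not call Keras: to_tensor and map_leaves are hypothetical stand-ins for ops.convert_to_tensor and tree.map_structure, and the sample values are made up. Unlike the real tree.map_structure, this simplified version only recurses into dicts and treats lists as leaves.

```python
def to_tensor(value):
    """Hypothetical stand-in for ops.convert_to_tensor: accepts array-like leaves only."""
    if isinstance(value, dict):
        raise TypeError("cannot convert a dict to a single tensor")
    return tuple(value)  # a tuple plays the role of a tensor in this sketch


def map_leaves(fn, structure):
    """Simplified stand-in for tree.map_structure: apply fn to each leaf of nested dicts."""
    if isinstance(structure, dict):
        return {key: map_leaves(fn, sub) for key, sub in structure.items()}
    return fn(structure)


# Same shape as the preprocessor output described above (values are made up).
x = {"token_ids": [1, 2, 3, 0], "padding_mask": [1, 1, 1, 0]}

# Converting the whole dict at once fails, as in the original bug.
try:
    to_tensor(x)
    whole_dict_failed = False
except TypeError:
    whole_dict_failed = True

# Mapping the converter over the leaves keeps the dict structure intact.
converted = map_leaves(to_tensor, x)
```

The result keeps the "token_ids" and "padding_mask" keys, with each value converted separately, which is the behavior the fix relies on.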