First, this is an amazing project. I found that it converts models and runs inference very well.
Then I tried to convert a model and train it with TF. It works, but the Keras model still seems to have a static batch size in its shape, e.g. 1 or 128, depending on the dummy input I used for conversion. I can only train if I set trace_shape=True; without it, training fails.
Training does work, but it is much slower than running the original PyTorch or TF code.
Could you give some suggestions on how I could speed up training?
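To make clear what I mean by a "static batch size", here is a small sketch in plain tf.keras. It does not use the converter. The model and the helpers build_model and batch_dim are hypothetical and exist only for illustration. A fixed batch_size shows up as a concrete number in the input shape, while a dynamic batch dimension shows up as None.

```python
import tensorflow as tf

def build_model(batch_size=None):
    # batch_size=None leaves the batch dimension dynamic (reported as None);
    # an integer fixes it, similar to what I see after converting with a dummy input.
    inputs = tf.keras.Input(shape=(4,), batch_size=batch_size)
    outputs = tf.keras.layers.Dense(2)(inputs)
    return tf.keras.Model(inputs, outputs)

def batch_dim(model):
    # Hypothetical helper: first dimension of the model's first input.
    # None means the batch size is dynamic.
    return tuple(model.inputs[0].shape)[0]

static_model = build_model(batch_size=1)
dynamic_model = build_model()
print(batch_dim(static_model))   # 1
print(batch_dim(dynamic_model))  # None
```

The converted model behaves like static_model here, with the batch dimension fixed to whatever the dummy input used.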