Hi~ Thanks for your great work.
Is it possible to perform inference with model.generate() after wrapping a model with XLA FSDP on TPU?
I found that directly calling generate() raises an error, as mentioned in this issue. However, if we rebuild the full parameters before generating and free them afterward (similar to what FSDP.forward does internally), it results in an OOM error on a v3-8 TPU (even for 3B models).
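For context, the workaround I tried follows this pattern (a minimal, self-contained sketch: the rebuild/free method names and the dummy model below are stand-ins to illustrate the idea, not the actual XLA FSDP API):

```python
from contextlib import contextmanager

@contextmanager
def full_params(model):
    """Gather full (unsharded) params before generate(), re-shard
    afterward, mirroring what FSDP.forward does around computation.
    Method names here are hypothetical placeholders."""
    model.rebuild_full_params()   # hypothetical: gather full params
    try:
        yield model
    finally:
        model.free_full_params()  # hypothetical: release full params

# Dummy stand-in model so the sketch runs without TPU/torch_xla.
class DummyFSDPModel:
    def __init__(self):
        self.calls = []
    def rebuild_full_params(self):
        self.calls.append("rebuild")
    def free_full_params(self):
        self.calls.append("free")
    def generate(self, prompt):
        self.calls.append("generate")
        return prompt + " ..."

model = DummyFSDPModel()
with full_params(model):
    out = model.generate("hello")
print(model.calls)  # ['rebuild', 'generate', 'free']
```

Even with this rebuild-then-free ordering, the full unsharded parameters must be materialized for the whole duration of generate(), which is where the OOM appears to come from.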
I would like to know if you have encountered the same issue or have any related suggestions. Thanks!