Fireblossom opened this issue 1 year ago
Hi,
thanks for your efforts to contribute to the open source MAGMA code. The codepath for inference during training should indeed be fixed and your help is appreciated :-)
In addition to the comments I added, I would in general prefer not to overload the forward function of the model too much. Maybe you could try to just change the `inference_step` method to invoke `model.generate` instead of changing the forward pass, see https://github.com/Aleph-Alpha/magma/blob/4d01e5172115ab4a8f4b4bf8da76dbc08b6cf36c/magma/train_loop.py#L85
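For context, a minimal sketch of the pattern being suggested, with the MAGMA model and DeepSpeed engine replaced by toy stand-ins so it runs standalone (the names `ToyModel`, `ToyEngine`, and the `inference_step` signature are assumptions, not the actual repo code):

```python
class ToyModel:
    """Stand-in for the MAGMA model: forward() is the training path,
    generate() is the autoregressive inference path."""
    def forward(self, inputs):
        return f"loss({inputs})"

    def generate(self, inputs, max_steps=3):
        # Pretend decoding: just repeat the prompt a few times.
        return [inputs] * max_steps


class ToyEngine:
    """Stand-in for a DeepSpeed engine: calling it routes to forward()."""
    def __init__(self, model):
        self.module = model  # DeepSpeed stores the wrapped model as .module

    def __call__(self, inputs):
        return self.module.forward(inputs)


def inference_step(model_engine, inputs):
    # Call generate() on the underlying model instead of threading an
    # inference flag through the engine's forward pass.
    return model_engine.module.generate(inputs)


engine = ToyEngine(ToyModel())
print(inference_step(engine, "caption this image"))
```

The point of the sketch is that the training loop keeps calling the engine for forward/backward, while `inference_step` reaches past the engine to the model's own `generate` method.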
Thanks again and let me know what you think.
Best,
Constantin
Hi Constantin,
thank you for your advice. I was going to do the same.
But in practice, I found that DeepSpeed's `model_engine` cannot call methods other than `forward`.
(I'm a beginner with DeepSpeed, as this is my first time using it, so please point it out if I'm wrong.)
But if I call the model directly, it may lead to unexpected errors, like a device mismatch.
For the above reasons, I had to add the code to `forward`.
Best,
Changxu
Ok, I think you can just access the model via `model_engine.model`. I'm not 100% sure that always works, but maybe give it a try.
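For reference, DeepSpeed stores the wrapped network on the engine as `engine.module`, so that attribute is worth trying as well if `model_engine.model` does not exist on your version. A toy illustration (no real DeepSpeed) of why a non-`forward` method has to go through the wrapped model:

```python
class Wrapped:
    """Stand-in for the model owned by the engine."""
    def forward(self, x):
        return x + 1

    def generate(self, x):
        return [x, x + 1]


class Engine:
    """Stand-in for a DeepSpeed engine: only forward() is routed."""
    def __init__(self, model):
        self.module = model  # DeepSpeed keeps the original model here

    def __call__(self, x):
        return self.module.forward(x)


engine = Engine(Wrapped())
assert engine(1) == 2  # training path: engine(x) -> model.forward(x)

try:
    engine.generate(1)  # the engine itself exposes no generate()
except AttributeError:
    print("engine has no generate(); unwrap via engine.module instead")

assert engine.module.generate(1) == [1, 2]  # inference path unwraps first
```

This keeps the engine in charge of device placement and sharding during training while still letting `inference_step` reach the model's extra methods.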
Best,
Constantin
`inference_step` passes `inference=True` to `model_engine`. However, the `forward` of the MAGMA model does not accept this parameter, which causes an error during training. I fixed it by simply copying the inference code from `example_inference.py`.
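The failure mode described here is a plain Python one: passing a keyword argument that the method signature does not declare raises a `TypeError`. A minimal reproduction, with one possible fix sketched as an assumption (not the actual MAGMA patch):

```python
class ForwardOnly:
    """A forward() that does not declare an inference flag."""
    def forward(self, x):
        return x


try:
    # Mimics inference_step forwarding inference=True into the model.
    ForwardOnly().forward(0, inference=True)
except TypeError as e:
    print("crashes:", e)


class ForwardWithFlag:
    """One possible fix: declare the flag and branch on it."""
    def forward(self, x, inference=False):
        return ("generated" if inference else "loss", x)


assert ForwardWithFlag().forward(0) == ("loss", 0)
assert ForwardWithFlag().forward(0, inference=True) == ("generated", 0)
```

An alternative with the same effect would be to accept `**kwargs` in `forward` and ignore unknown flags, though declaring the parameter explicitly makes the training/inference branch easier to read.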