Closed vzantedeschi closed 7 months ago
Hello, I have just received your submission and accepted it; thank you. When I last worked on the tool I was not familiar with the parameters of model.generate(), so I did not account for non-tensor arguments. Your modification is very good.
Thank you for the very useful repo!
I made a small change to flops_counter.calculate_flops() so that non-tensor arguments can be passed to model.generate() (e.g., kwargs = {..., "max_new_tokens": 10}). I couldn't figure out how to pass such arguments without changing the code.
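To illustrate the idea (this is a rough sketch, not the actual diff): only values that behave like tensors get moved to the target device, and everything else is passed through untouched. The helper name `to_device` and the `FakeTensor` stand-in are hypothetical and exist only so the example runs without PyTorch; the real code may detect tensors differently.

```python
def to_device(kwargs, device):
    """Move tensor-like values (anything with a callable .to) to `device`.

    Non-tensor values such as ints, bools, or strings are kept as-is,
    so options like max_new_tokens survive the device transfer step.
    """
    moved = {}
    for name, value in kwargs.items():
        if callable(getattr(value, "to", None)):
            moved[name] = value.to(device)
        else:
            moved[name] = value
    return moved


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, used only in this sketch."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        # Like torch.Tensor.to, return a new object on the requested device.
        return FakeTensor(device)


kwargs = {"input_ids": FakeTensor(), "max_new_tokens": 10}
moved = to_device(kwargs, "cuda:0")
print(moved["input_ids"].device, moved["max_new_tokens"])
```

Duck-typing on `.to` is just one option; checking `isinstance(value, torch.Tensor)` would be stricter if PyTorch is available.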
Other minor changes: a small refactor, and an error is now raised when forward_mode is neither "forward" nor "generate", so that a typo in the mode is reported immediately instead of surfacing later as an obscure runtime error.
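The forward_mode check could look roughly like the sketch below. The names `run_model`, `VALID_FORWARD_MODES`, and `EchoModel` are hypothetical and used only for illustration; the exception type and message in the actual change may differ.

```python
VALID_FORWARD_MODES = ("forward", "generate")


def run_model(model, forward_mode="forward", args=(), kwargs=None):
    """Dispatch to model(...) or model.generate(...), failing fast on bad modes."""
    if forward_mode not in VALID_FORWARD_MODES:
        raise ValueError(
            f"forward_mode must be one of {VALID_FORWARD_MODES}, got {forward_mode!r}"
        )
    kwargs = kwargs or {}
    if forward_mode == "forward":
        return model(*args, **kwargs)
    return model.generate(*args, **kwargs)


class EchoModel:
    """Hypothetical dummy model that reports which entry point was called."""

    def __call__(self, *args, **kwargs):
        return ("forward", kwargs)

    def generate(self, *args, **kwargs):
        return ("generate", kwargs)


result = run_model(EchoModel(), "generate", kwargs={"max_new_tokens": 10})
print(result)
```

With this check, a misspelled mode such as "genrate" raises a ValueError at the call site rather than failing somewhere deeper in the counting logic.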