microsoft / knossos-ksc

Compiler with automatic differentiation

Respect generate_lm argument when generating autograd function #837

Closed: dcrc2 closed this 3 years ago

dcrc2 commented 3 years ago

When wrapping ks code as a torch.autograd.Function, PR #762 unconditionally uses sufrev_entry to implement backward (see commit https://github.com/microsoft/knossos-ksc/pull/762/commits/25db62cdb7d9edef53f596c9d2ed973de5b12abd). This breaks ts2mod when generate_lm=True is requested.

Similarly, PR #820 is broken: although you can request SUF in ts2mod, the resulting module cannot actually be used as an autograd.Function, because its backward would call the non-existent function rev_entry.

To fix this, this PR passes the generate_lm argument down to the place where backward is implemented, so that backward uses rev_entry when generate_lm=True and sufrev_entry otherwise.
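As a rough illustration of the dispatch described above, here is a minimal pure-Python sketch. It assumes the compiled module exposes either rev_entry (when built with generate_lm=True) or sufrev_entry (when built for SUF). The helper names backward_entry_name and make_backward, and the fake modules, are hypothetical and are not the actual knossos-ksc code, which wraps the chosen function inside a torch.autograd.Function.

```python
from types import SimpleNamespace


def backward_entry_name(generate_lm: bool) -> str:
    """Hypothetical helper: name of the reverse-mode entry point for a generate_lm setting."""
    # generate_lm=True  -> linear-map based reverse mode, exposed as rev_entry
    # generate_lm=False -> SUF reverse mode, exposed as sufrev_entry
    return "rev_entry" if generate_lm else "sufrev_entry"


def make_backward(compiled_module, generate_lm: bool):
    """Hypothetical helper: pick the function that implements backward.

    Uses the same generate_lm value the module was compiled with, instead of
    hard-coding one entry point (the failure mode described in #762 and #820).
    """
    name = backward_entry_name(generate_lm)
    try:
        return getattr(compiled_module, name)
    except AttributeError:
        raise RuntimeError(
            f"compiled module has no {name}; was it built with generate_lm={generate_lm}?"
        ) from None


# Stand-ins for compiled modules (assumption: each exposes only one reverse entry point).
lm_module = SimpleNamespace(rev_entry=lambda x, dret: ("rev", x, dret))
suf_module = SimpleNamespace(sufrev_entry=lambda x, dret: ("sufrev", x, dret))

backward_lm = make_backward(lm_module, generate_lm=True)
backward_suf = make_backward(suf_module, generate_lm=False)
```

Keeping generate_lm as an explicit argument means a mismatch between how the module was compiled and how backward is wired surfaces as a clear error at wrapping time rather than as a missing-attribute failure during the backward pass.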