Additional Detail from JIRA
| | |
|------------------|-----------------|
|Votes | 0 |
|Component/s | Swift for TensorFlow |
|Labels | Bug |
|Assignee | @marcrasi |
|Priority | Medium |
md5: 4ec014473c6134e71973707eceb02505
Issue Description:
Currently, AD treats functions and methods the same way. However, this is not consistent with how method adjoints are defined.
We require the adjoint of a static function to be defined in the same type context; in Swift source, static functions follow the same typing rule as top-level functions.
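A minimal sketch of that rule (the type `Point`, the functions `foo`/`dFoo`, and the adjoint signature are all hypothetical; the `@differentiable(reverse, adjoint:)` attribute from the early Swift for TensorFlow design is omitted so this compiles on plain Swift):

```swift
struct Point {
    // Primal: a static function on Point. In the early Swift for TensorFlow
    // design this would carry @differentiable(reverse, adjoint: dFoo).
    static func foo(_ x: Float) -> Float {
        return x * x
    }

    // Adjoint, defined in the same type context as the primal. Like a
    // top-level adjoint, it takes the primal argument, the original result,
    // and the incoming seed, and returns the derivative with respect to x.
    static func dFoo(_ x: Float, originalResult: Float, seed: Float) -> Float {
        return seed * 2 * x
    }
}

print(Point.dFoo(3, originalResult: 9, seed: 1)) // derivative of x*x at x = 3
```

At the source level, nothing distinguishes `dFoo` from a top-level adjoint; the mismatch only appears after SIL lowering.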
However, when this is lowered to SIL, `dFoo` has an unexpected type: there is a metatype at the end!
This makes the compiler unable to match the adjoint type, because AD thinks the gradient type is
... because the metatype is one of the original parameters!
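To make the mismatch concrete, a hedged sketch (hypothetical `Point`/`foo`; the exact lowering and gradient-type convention may differ from what swiftc actually produces):

```
// Lowered primal (sketch): the metatype is the last parameter.
//   @convention(method) (Float, @thin Point.Type) -> Float
//
// Gradient type AD currently forms, treating every lowered parameter,
// metatype included, as a differentiation parameter:
//   @convention(method) (Float, @thin Point.Type) -> (Float, @thin Point.Type)
//
// Gradient type that would match the user-defined adjoint: differentiate
// only the value parameters and pass the metatype through at the end:
//   @convention(method) (Float, @thin Point.Type) -> Float
```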
We need to teach AD to handle `@convention(method)` functions. Whenever we differentiate such a function (in `AdjointEmitter::visitApplyInst`), we need to make sure not to treat the last parameter of the original function as a differentiation parameter, and instead pass it through at the end.
We also need to change `SILFunctionType::getGradientType` to form the correct gradient type when the input is a `@convention(method)` function, so that `createGradFunction` in AD creates the correct function prototypes.
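The intended parameter handling can be modeled in a small standalone sketch (hypothetical helper, not compiler code — `SILFunctionType` and friends are only mimicked here): for a `@convention(method)` function, the trailing self/metatype parameter is excluded from the differentiation parameters and carried through unchanged at the end:

```swift
// Hypothetical model of SIL-level parameter lists; not actual compiler types.
enum Param: Equatable {
    case value(String)    // an ordinary differentiable parameter
    case metatype(String) // the trailing self/metatype parameter of a method
}

// Sketch of the fix: for @convention(method), the last parameter is not a
// differentiation parameter; it is kept, unchanged, at the end.
func gradientParams(of original: [Param],
                    isMethod: Bool) -> (wrt: [Param], trailing: [Param]) {
    guard isMethod, let last = original.last else { return (original, []) }
    return (Array(original.dropLast()), [last])
}

let params: [Param] = [.value("x"), .metatype("Point.Type")]
let (wrt, trailing) = gradientParams(of: params, isMethod: true)
// wrt contains only the value parameter; trailing carries the metatype.
```

With this split, `getGradientType` would build the gradient's results from `wrt` only and re-append `trailing`, and `visitApplyInst` would stop emitting an adjoint argument for the metatype.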