Open BradLarson opened 2 years ago
I narrowed down the reproducer further, and the crash signature changed. It now crashes on toolchains as early as July 2020. If I add `@inlinable`, something peculiar happens: on older toolchains, it crashes just as it does without `@inlinable`; on newer toolchains, the crash signature changes to the one mentioned in this issue's first comment.
```swift
import _Differentiation

public extension SIMD where Self: Differentiable, Scalar == Float {
  // May be `@inlinable`, but adding that attribute may change the crash
  // signature.
  @derivative(of: sum)
  func _vjpSum() -> (value: Float, pullback: (Float) -> TangentVector) {
    fatalError()
  }
}
```
The crash for `@inlinable` changes between the 2022-03-09 and 2022-03-22 snapshots. I have reproduced the vanilla crash as far back as the 2020-07-22 toolchain; this is also the crash signature when `@inlinable` is not added.
So, here is the function in question:
```
// reverse-mode derivative of SIMD<>.sum()
sil hidden [thunk] [always_inline] [ossa] @$ss4SIMDPsSF6ScalarRpzrlE3sumADyFsAARz16_Differentiation14DifferentiableRzSfACs11SIMDStoragePRtzlTJrSpSr : $@convention(method) <τ_0_0 where τ_0_0 : SIMD, τ_0_0 : Differentiable, τ_0_0.Scalar == Float> (@in_guaranteed τ_0_0) -> (@out Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) {
// %0                                             // user: %3
// %1                                             // user: %3
bb0(%0 : $*Float, %1 : $*τ_0_0):
  // function_ref SIMD<>._vjpSum()
  %2 = function_ref @$ss4SIMDP3ser16_Differentiation14DifferentiableRzSf6Scalars11SIMDStoragePRtzrlE7_vjpSumSf5value_13TangentVectorAdEPQzSfc8pullbacktyF : $@convention(method) <τ_0_0 where τ_0_0 : SIMD, τ_0_0 : Differentiable, τ_0_0.Scalar == Float> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) // user: %3
  %3 = apply %2<τ_0_0>(%0, %1) : $@convention(method) <τ_0_0 where τ_0_0 : SIMD, τ_0_0 : Differentiable, τ_0_0.Scalar == Float> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) // user: %4
  (%4, %5) = destructure_tuple %3 : $(Float, @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) // users: %10, %6
  %6 = convert_function %5 : $@callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector> to $@callee_guaranteed (Float) -> @out τ_0_0.TangentVector // user: %8
  // function_ref thunk for @escaping @callee_guaranteed (@unowned Float) -> (@out A.Differentiable.TangentVector)
  %7 = function_ref @$sSf13TangentVector16_Differentiation14DifferentiablePQzIegyr_SfAEIegnr_s4SIMDRzAbCRzSf6Scalars11SIMDStoragePRtzlTR : $@convention(thin) <τ_0_0 where τ_0_0 : SIMD, τ_0_0 : Differentiable, τ_0_0.Scalar == Float> (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> @out τ_0_0.TangentVector) -> @out τ_0_0.TangentVector // user: %8
  %8 = partial_apply [callee_guaranteed] %7<τ_0_0>(%6) : $@convention(thin) <τ_0_0 where τ_0_0 : SIMD, τ_0_0 : Differentiable, τ_0_0.Scalar == Float> (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> @out τ_0_0.TangentVector) -> @out τ_0_0.TangentVector // user: %9
  %9 = convert_function %8 : $@callee_guaranteed (@in_guaranteed Float) -> @out τ_0_0.TangentVector to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed Float) -> @out τ_0_0 for <τ_0_0.TangentVector> // user: %10
  %10 = tuple (%4 : $Float, %9 : $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) // user: %11
  return %10 : $(Float, @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) // id: %11
} // end sil function '$ss4SIMDPsSF6ScalarRpzrlE3sumADyFsAARz16_Differentiation14DifferentiableRzSfACs11SIMDStoragePRtzlTJrSpSr'
```
Note that its first result is indirect (`@out Float`), so its address is passed as the first parameter. The same convention appears to be expected at the apply site:

```
%3 = apply %2<τ_0_0>(%0, %1) : $@convention(method) <τ_0_0 where τ_0_0 : SIMD, τ_0_0 : Differentiable, τ_0_0.Scalar == Float> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_0.TangentVector>) // user: %4
```

However, the callee actually returns its result directly, so one extra argument ends up being passed to the function; hence the assert. It looks like a reabstraction is missed somewhere.
The reproducer still crashes on the 05/24 toolchain.
Additional Detail from JIRA

| Field | Value |
|------------------|-----------------|
| Votes | 0 |
| Component/s | |
| Labels | Bug, AutoDiff |
| Assignee | None |
| Priority | Medium |

md5: fda31536276ac9af6789e3146d3c3482

Issue Description:
Starting with the 2022-03-13 nightly toolchain snapshot, the simple custom VJP shown above, when placed in a file and built with `swiftc file.swift`, causes an assertion failure: "SIL verification failed: internal/private function cannot be serialized or serializable: !F->isSerialized()". This was not present in the 2022-03-09 nightly snapshot toolchain, and seems to be a fairly recent regression.
The full assertion is as follows: