swiftlang / swift

The Swift Programming Language
https://swift.org
Apache License 2.0

[SR-14297] [AutoDiff] Storing an inactive [Float] array causes a precondition fail in Array.DifferentiableView during the backward pass #56656

Closed vojtamolda closed 3 years ago

vojtamolda commented 3 years ago
Previous ID SR-14297
Radar rdar://problem/75032459
Original Reporter @vojtamolda
Type Bug
Status Resolved
Resolution Done
Environment macOS 11.2.2 with Swift development snapshot 2021-02-24: Apple Swift version 5.4-dev (LLVM 0eb4a6165bbbce5, Swift 288a0db849d8506) Target: x86_64-apple-darwin20.3.0
Additional Detail from JIRA

| | |
|------------------|-----------------|
|Votes | 0 |
|Component/s | Compiler |
|Labels | Bug, AutoDiff |
|Assignee | @vojtamolda |
|Priority | Medium |

md5: a2e01f966545d298724b1a0508fa6c26

Issue Description:

Hello everyone,

I believe I stumbled upon a bug in `Array.DifferentiableView`. The code below triggers a precondition failure during the backward pass.

Here are my "amateur astronomer" observations about why the precondition fails. The calculation involving `lastFlat1` is inactive, i.e. it isn't used to produce the sum, but it is nevertheless stored in the `Flatter` struct during execution of the `flatten(...)` call. When the backward pass runs, `Flatter.TangentVector.lastFlat1` is empty, which triggers the failure.

I tried investigating the ArrayDifferentiation.swift file, but I don't have enough experience with the standard library to fix it.

The reproducer looks strangely useless because it's taken from a more complex application that has an alternative loss function which uses both `lastFlat0` and `lastFlat1`. But I believe this kind of behavior, where only a subset of differentiable attributes contribute to the cost function, will be common in bigger projects.

import _Differentiation

struct Flatter: Differentiable {
    var lastFlat0: [Float] = []
    var lastFlat1: [Float] = []

    @differentiable(reverse)
    mutating func flatten(_ matrix: [[Float]]) {
        self.lastFlat0 = matrix.differentiableReduce([], +)
        self.lastFlat1 = matrix.differentiableReduce([], +)
    }
}

let resolution = 2
let row = [Float](repeating: 1.0, count: resolution)
let matrix = [[Float]](repeating: row, count: resolution)

let (sum, grad) = valueWithGradient(at: matrix) { matrix -> Float in
    var flatter = Flatter()
    flatter.flatten(matrix)
    // Note: flatter.lastFlat1 is unused and inactive here
    return flatter.lastFlat0.differentiableReduce(0.0, +)
}

print(sum, grad)

This is the exact runtime error produced when running:

_Differentiation/ArrayDifferentiation.swift:220: Precondition failed: Tangent vector with invalid count; expected to equal the sum of operand counts 2 and 2
2021-03-03 08:44:18.789757-0600 DifferentiableReduceTest[63658:1907302] _Differentiation/ArrayDifferentiation.swift:220: Precondition failed: Tangent vector with invalid count; expected to equal the sum of operand counts 2 and 2

typesanitizer commented 3 years ago

@swift-ci create

rxwei commented 3 years ago

@vojtamolda Perhaps the tangent vector was a zero? In that case you just need to detect the zero case `base.isEmpty` in the stdlib code. I'm sorry but I don't have the bandwidth to look into it this week, but would be happy to answer your questions if you'd like to take a stab.
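The zero check suggested above can be sketched in plain Swift, without `_Differentiation`. The function `concatenationPullback` and its parameters are hypothetical stand-ins for the pullback of array concatenation, not the actual standard library code; the sketch assumes that an empty array is how a zero tangent arrives at the pullback.

```swift
// Hypothetical stand-in for the pullback of `lhs + rhs` on arrays.
// `lhsCount` and `rhsCount` are the operand counts seen in the forward pass.
func concatenationPullback(
    _ tangent: [Float], lhsCount: Int, rhsCount: Int
) -> (lhs: [Float], rhs: [Float]) {
    // Assumption: an empty tangent represents zero, so propagate zeros
    // to both operands instead of failing the count check.
    if tangent.isEmpty {
        return ([], [])
    }
    precondition(
        tangent.count == lhsCount + rhsCount,
        "Tangent vector with invalid count; expected to equal the sum of operand counts \(lhsCount) and \(rhsCount)")
    // Split the incoming tangent back into the two operand tangents.
    return (Array(tangent[..<lhsCount]), Array(tangent[lhsCount...]))
}

// Active case: the tangent has the full shape.
let active = concatenationPullback([1, 2, 3, 4], lhsCount: 2, rhsCount: 2)
// Inactive case: the tangent is an empty zero, as in the reproducer.
let inactive = concatenationPullback([], lhsCount: 2, rhsCount: 2)
print(active.lhs, active.rhs)
print(inactive.lhs, inactive.rhs)
```

Returning empty arrays keeps the zero cheap to propagate; a real fix would return the tangent type's own zero value.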

vguerra commented 3 years ago

If I may 🙂. I think that is the case. In @vojtamolda's sample code, `v.base.count` is 0.

As a side note, a nice improvement to those precondition messages would be to include the actual count found in addition to the expected one.
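A message along those lines might look like the sketch below. The helper `invalidCountMessage` and its wording are illustrative assumptions, not the text used in the standard library.

```swift
// Hypothetical helper that builds a precondition message reporting both
// the count that was found and the operand counts that were expected.
func invalidCountMessage(found: Int, lhsCount: Int, rhsCount: Int) -> String {
    return "Tangent vector with invalid count \(found); expected to equal "
        + "the sum of operand counts \(lhsCount) and \(rhsCount)"
}

// With the reproducer's values, the message makes the empty tangent visible.
let message = invalidCountMessage(found: 0, lhsCount: 2, rhsCount: 2)
print(message)
```

With the found count in the message, an empty tangent would have been recognizable directly from the crash log.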

vojtamolda commented 3 years ago

Thanks for all the answers. It's nice that all 3 of us agree that the value passed to the vjp pullback closure is an empty array. I wasn't sure if this was the correct and expected behavior.

I've got a question to make sure that my understanding is correct here. Is passing empty arrays to vjp closures a performance optimization? Does it avoid frequently creating arrays full of zeros that are in the end added to something else and have no chance of affecting the result?

I'll try to submit a PR with a fix to the standard library with some tests. Please keep your fingers crossed for me 😉 It's going to be a wild ride since I've never done that before!

rxwei commented 3 years ago

Yes, the compiler always calls `.zero` to initialize a zero. It is not guaranteed to have the same shape as the corresponding original value. The custom derivatives should be checking for `.zero` values.
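To illustrate what checking for a zero of a different shape can mean in practice, here is a minimal plain-Swift sketch of tangent accumulation. The function `addTangents` is hypothetical, and treating an empty array as zero is an assumption made for this example.

```swift
// Hypothetical accumulation of two array tangents, where an empty array
// stands for a zero tangent of unknown shape.
func addTangents(_ lhs: [Float], _ rhs: [Float]) -> [Float] {
    // Zero plus anything is that thing; no zero-filled array is allocated.
    if lhs.isEmpty { return rhs }
    if rhs.isEmpty { return lhs }
    precondition(
        lhs.count == rhs.count,
        "Count mismatch: \(lhs.count) and \(rhs.count)")
    // Both tangents have the full shape: add element-wise.
    return zip(lhs, rhs).map { $0 + $1 }
}

let accumulated = addTangents([], [1, 1])      // zero + full tangent
let summed = addTangents([1, 2], [3, 4])       // full + full tangent
print(accumulated, summed)
```

The empty case short-circuits before the shape check, which is why a zero does not need to match the shape of the original value.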

vguerra commented 3 years ago

Feel free to reach out if you need help building the code base and/or writing/running the tests. Would be happy to help.

vojtamolda commented 3 years ago

Thanks @vguerra. I'll try to swim on my own and if I drown I'll ping you 😉

So far the getting started guide got me to a point where the whole "behemoth" compiled...

vojtamolda commented 3 years ago

To the best of my knowledge, this issue will be resolved once PR #36257 is merged.