CTUAvastLab / Mill.jl

Build flexible hierarchical multi-instance learning models.
https://ctuavastlab.github.io/Mill.jl/stable/
MIT License

Bug in gradient of SegmentedSum #117

Closed pevnak closed 1 year ago

pevnak commented 1 year ago

SegmentedSum does not compute gradients correctly with ScatteredBags when several instances in a bag point to the same item in x. Here is an MWE:

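using Mill, Flux, Test

# ds1 stores the repeated instance explicitly; ds2 references item 1 twice.
ds1 = BagNode(ArrayNode([1 2 1; 1 2 1]), [1:3])
ds2 = BagNode(ArrayNode([1 2; 1 2]), ScatteredBags([[1, 2, 1]]))
m = BagModel(ArrayModel(Dense(2, 2)), SegmentedSum(2), identity)
ps = Flux.params(m)
fval1, gs1 = Flux.withgradient(() -> sum(m(ds1)), ps)
fval2, gs2 = Flux.withgradient(() -> sum(m(ds2)), ps)
# The forward values agree.
@test fval1 ≈ fval2
# Squared difference between the two gradients; it should be zero, but is not.
sum([sum(abs2.(gs1[p] .- gs2[p])) for p in ps if gs1[p] !== nothing])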

The tests do not cover this pattern, which is useful for deduplication. The fix is simple: change the assignment (=) in https://github.com/CTUAvastLab/Mill.jl/blob/30a61f0c0e8ed045cb0612a28ecf89f1b4c09b3c/src/aggregations/segmented_sum.jl#L70 to += and initialize dx with zero instead of similar.
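
To illustrate why both changes are needed, here is a minimal stand-alone sketch of a segmented sum and its backward pass. It does not use Mill.jl and is not the library's implementation: the function names segsum_forward and segsum_backward and the accumulate keyword are hypothetical, introduced only for this example. Bags are assumed to be plain vectors of column indices, as in ScatteredBags([[1,2,1]]).

```julia
# Hypothetical sketch, not Mill.jl code.

# Forward pass: for each bag, sum the columns of x selected by its indices.
function segsum_forward(x::AbstractMatrix, bags::Vector{Vector{Int}})
    y = zeros(eltype(x), size(x, 1), length(bags))
    for (b, idxs) in enumerate(bags), i in idxs
        y[:, b] .+= x[:, i]
    end
    return y
end

# Backward pass: route the output gradient Δ back to the columns of x.
# accumulate=true adds into a zero-initialized buffer (the proposed fix);
# accumulate=false overwrites, which mimics the reported bug.
function segsum_backward(Δ::AbstractMatrix, x::AbstractMatrix,
                         bags::Vector{Vector{Int}}; accumulate::Bool=true)
    dx = zero(x)  # must start from zeros so that += is well defined
    for (b, idxs) in enumerate(bags), i in idxs
        if accumulate
            dx[:, i] .+= Δ[:, b]
        else
            dx[:, i] .= Δ[:, b]
        end
    end
    return dx
end

x = [1.0 2.0; 1.0 2.0]
bags = [[1, 2, 1]]      # item 1 appears twice in the single bag
Δ = ones(2, 1)          # gradient of sum(y) with respect to y

dx_good = segsum_backward(Δ, x, bags)                    # [2.0 1.0; 2.0 1.0]
dx_bad  = segsum_backward(Δ, x, bags; accumulate=false)  # [1.0 1.0; 1.0 1.0]
```

Because item 1 contributes to the sum twice, its gradient must be twice that of item 2. Overwriting loses the second contribution, and a buffer created with similar would additionally leave arbitrary values in any column that no bag references.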