jieli-matrix / JSoC2021

Proposal for JSoC2021
MIT License
0 stars 0 forks source link

An outline for JSoC2021 Proposal #1

Open jieli-matrix opened 3 years ago

jieli-matrix commented 3 years ago

最近主要在了解NiLang和ChainRules~所以先提个关于ChainRules的问题。 项目要求里提到

We will port the generated backward rules to ChainRules.jl as an extension

这里写个代码示例

using ChainRules 
# 定义函数foo
function foo(x)
    u = sin(x)
    v = asin(u)
    return v
end

# Reverse Mode
# forward pass calculation
x = π / 4;
u, u_pullback = rrule(sin, x);
v, v_pullback = rrule(asin, u);

# output 
u, v
(0.7071067811865475, 0.7853981633974482)
u_pullback, v_pullback
(ChainRules.var"#sin_pullback#1451"{Float64}(0.7071067811865476), ChainRules.var"#asin_pullback#1458"{Float64}(0.7071067811865475))

# backpropogation process
v̅ = 1;
_, u̅ = v_pullback(v̅);
_, x̄ = u_pullback(u̅)

# output
(Zero(), 1.0)

我想了解的是,关于项目成果,是将用NiLang写成的sparse matrix operation的可逆形式,嵌入到ChainRules.jl提供的自定义rrule吗?用户通过rrule,传入operation和sparse matrix,就可以获得其值与pullback。

自定义rrule

rrule(::typeof(foo), args...; kws...)

function rrule(::typeof(foo), args...; kwargs...)
    y = ...
    # 这里嵌入可逆函数?
    function foo_pullback(ȳ)
        ...
        return s̄elf, ārgs...
    end
    return y, foo_pullback
end

第二个是关于pytorch repo的,这里有一列是“Sparse grad?” Yes表示支持,No表示不支持;我发现其实Yes还真的很少,有一部分原因是一些运算是重复的,比如torch.sparse.mm() torch.mm()都是矩阵乘,用户可以选择使用是保留稀疏矩阵梯度还是不保留的算子;另一部分原因我不太清楚,比如torch.bmm()不支持梯度,当然其接受的输入是稀疏tensor不是矩阵,或许是超过本项目的讨论范围,但我以为在NiLang的框架下实际是可行的。

我目前的思路是,关于实现稀疏矩阵运算,其实运算是分两类的:一类是低阶的运算,比如稀疏矩阵最重要的乘法;一类是高阶的运算,比如之前实现的lowrank_svd,是通过乘法把稀疏矩阵映射到低维的稠密阵做SVD分解,是在低阶运算的基础上完成的。所以其实我个人以为,实现尽可能多的低阶运算是很有意义的,之后高阶运算有更多的扩展。

借鉴了陈老师的模板,我列出了本次提案的主体部分,以及每一部分我想展开的要点;如果两位老师不介意这份有些过于简陋的初稿...还请多提提大方向上的建议,比如哪些section是冗余的或者缺少了,也能降低“写完发现走题再次从0开始”的概率...

其实这是我第一次正式用英文写作啦,原本是打算中文写的,但很佩服@GiggleLiu @johnnychen94 的英文写作能力,就厚着脸皮麻烦两位老师多多指教下我吧~十分感谢!!!

GiggleLiu commented 3 years ago
  1. 👍🏾 你的理解没有任何问题,更加具体的说是这样

    function rrule(::typeof(foo), args...; kwargs...)
    y = call 可逆函数 f(args...; kwargs...)
    function foo_pullback(ȳ)
        gargs... = GVar(gargs)
        ārgs = uncall 可逆函数 ~f(gargs...; kwargs)
        return s̄elf, grad(ārgs)...
    end
    return y, foo_pullback
    end

    我下次写个更加具体的例子。

  2. bmm不要管,这个数据结构在julia下面不存在。然后,你说的没错,从低阶开始。之后可以尝试实现一两个算法的自动微分来展示成果。鼓励用英文,这是个很好的锻炼。

johnnychen94 commented 3 years ago
y = call 可逆函数 f(args...; kwargs...)

这样定义 foorrule 似乎会涉及到 type piracy,除非 foo 是 NiLang 下的函数?

GiggleLiu commented 3 years ago
y = call 可逆函数 f(args...; kwargs...)

这样定义 foorrule 似乎会涉及到 type piracy,除非 foo 是 NiLang 下的函数?

不会的,这是独立的。

julia> function ChainRules.rrule(::typeof(sin), x::Real)
        @show y = 999999
        y, adjy->(NO_FIELDS, 3333333)
        end

julia> Zygote.gradient(sin, 4)
y = 999999 = 999999
(3333333,)

julia> sin(4)
-0.7568024953079282
johnnychen94 commented 3 years ago

不会的,这是独立的。

这里的 type piracy 不是对 sin 而是对 rrule. 这里 MyAD1 并不拥有 rrule, sinReal 中的任何一个所以在 MyAD1 这个模块里扩展定义是不被允许的一种 type piracy:

julia> using Zygote

julia> Zygote.gradient(sin, 4)
(-0.6536436208636119,)

julia> module MyAD1
           using ChainRules
           function ChainRules.rrule(::typeof(sin), x::Real)
               @show y = 999999
               y, adjy->(NO_FIELDS, 3333333)
           end
       end
Main.MyAD1

julia> Zygote.gradient(sin, 4)
y = 999999 = 999999
(3333333,)

避免这种 type-piracy 的话,应该只有这三种方式:

通过包装一个薄薄的 MyAD1.rrule,可以在依然借用 ChainRules 实现的前提下不改变其他的 AD 框架的行为

julia> module MyAD1
           import ChainRules
           rrule(args...) = ChainRules.rrule(args...)
           function rrule(::typeof(sin), x::Real)
               @show y = 999999
               y, adjy->(NO_FIELDS, 3333333)
           end
       end
Main.MyAD1

julia> using Zygote

julia> Zygote.gradient(sin, 4)
(-0.6536436208636119,)
jieli-matrix commented 3 years ago

感谢两位老师的分析与讨论~

借用 ChainRules 实现的前提下不改变其他的 AD 框架的行为

这一点确实是在导出对应backward规则需要考虑到的, 我冒然揣测下陈老师想表达的是,由于是rewrite functions related to sparse matrix in Julia Base by NiLang, 所以实质上是用NiLang形式的重载运算符(比如稀疏矩阵的乘法),如果rrule没有做NiLang层封装,则会破坏涉及到*的矩阵gradient运算?

比如说,这里有个rrule(::typeof(*), mat1::SparseMatrixCSC, mat2::AbstractArray) 只有在引入NiLang后,才会调用我们对其的实现rrule;而其他AD模块依然可以提供其对于稀疏矩阵乘法的实现?

之前没有仔细考虑过type-piracy(期末考完一定好好研究下julia的类型!),理解的有偏差也请两位老师多指点下~

johnnychen94 commented 3 years ago

type piracy 现象来源于 Julia 的多重派发,你可以通过添加一个新的方法来改变原有函数的定义,但是这种“添加新的方法”是一种侵入式的过程(会传递到整个生态),所以为了避免这种侵入式的改变被滥用,就引入了所谓的 type piracy 规则,即:如果模块 M 不包含要扩展的函数及参数类型中的任何一个,这就是一个 type piracy。比如说下面这种就是一个典型的 type piracy:MyProjectPkg 不拥有 * 和 SparseMatrixCSC,这时候只要项目里面有任何一个包引入了 MyProjectPkg,整个生态就都会受到影响。

module MyProjectPkg
    Base.:*(A::SparseMatrixCSC, B::SparseMatrixCSC) = A + B
end

下面这两种都不是 type piracy,因为它对其中的某一个是有所有权的:

module MyProjectPkg
    # 拥有 f
    f(A::SparseMatrixCSC, B::SparseMatrixCSC) = A + B
end

module MyProjectPkg
    # 拥有 MyArray
    struct MyArray{T,N} <: AbstractArray{T, N} end
    Base.:*(A::SparseMatrixCSC, B::MyArray) = A + B
end

所以严格说来,https://giggleliu.github.io/NiLang.jl/stable/examples/port_zygote/ 里面给的 Zygote.@adjoint 的例子也是某种一种 type piracy,因为 norm2 的 pullback 的所有人应该和 norm2 是同一个。换句话说,如果 PkgAnorm2 的所有人,并且 PkgA 或者 ChainRules 给了 Chainrules(::typeof(norm2), args...) 的定义的话,我们就不能通过 type piracy 去插入 NiLang 的实现,因为会改变之前的结果。

除此之外,也存在一些所谓的 licensed/sanctioned type piracy (欧洲殖民时期会给一些海盗颁发所谓的许可证,也就是所谓的 sanctioned pirate 或者说 privateer),许可的 type piracy 行为一般发生在:PkgAPkgB 互相之间没有依赖关系,然后 PkgC 将两个包粘合在一起,并提供了一些方法的扩展,比如说:NNLibCUDA 关于 CUDA.CuArray 类型 扩展了 NNLib 里的方法。这里 NNLibCUDA 的 type piracy 就是一种典型的被接受的行为:因为他们是同一波人维护的...


不过据 @GiggleLiu 说在 AD 里面这种对 pullback 的 type piracy 是很常见的。另一方面,由于目前 ChainRules 也没有给出 SparseMatrixCSC 的 rrule 扩展,所以不会产生冲突,暂时来说可能不是什么大的问题。