Open jieli-matrix opened 3 years ago
👍🏾 你的理解没有任何问题,更加具体的说是这样
function rrule(::typeof(foo), args...; kwargs...)
y = call 可逆函数 f(args...; kwargs...)
function foo_pullback(ȳ)
gargs... = GVar(gargs)
ārgs = uncall 可逆函数 ~f(gargs...; kwargs)
return s̄elf, grad(ārgs)...
end
return y, foo_pullback
end
我下次写个更加具体的例子。
bmm不要管,这个数据结构在julia下面不存在。然后,你说的没错,从低阶开始。之后可以尝试实现一两个算法的自动微分来展示成果。鼓励用英文,这是个很好的锻炼。
y = call 可逆函数 f(args...; kwargs...)
这样定义 foo
的 rrule
似乎会涉及到 type piracy,除非 foo
是 NiLang 下的函数?
y = call 可逆函数 f(args...; kwargs...)
这样定义
foo
的rrule
似乎会涉及到 type piracy,除非foo
是 NiLang 下的函数?
不会的,这是独立的。
julia> function ChainRules.rrule(::typeof(sin), x::Real)
@show y = 999999
y, adjy->(NO_FIELDS, 3333333)
end
julia> Zygote.gradient(sin, 4)
y = 999999 = 999999
(3333333,)
julia> sin(4)
-0.7568024953079282
不会的,这是独立的。
这里的 type piracy 不是对 sin
而是对 rrule
. 这里 MyAD1 并不拥有 rrule
, sin
和 Real
中的任何一个所以在 MyAD1 这个模块里扩展定义是不被允许的一种 type piracy:
julia> using Zygote
julia> Zygote.gradient(sin, 4)
(-0.6536436208636119,)
julia> module MyAD1
using ChainRules
function ChainRules.rrule(::typeof(sin), x::Real)
@show y = 999999
y, adjy->(NO_FIELDS, 3333333)
end
end
Main.MyAD1
julia> Zygote.gradient(sin, 4)
y = 999999 = 999999
(3333333,)
避免这种 type-piracy 的话,应该只有这三种方式:
rrule
扩展 NiLang 自己的算子 foo
NiLang.rrule
,然后只扩展 NiLang.rrule
而不扩展 ChainRules.rrule
。这应该是唯一不引入 type piracy 就能扩展 +, -, *, / 等Julia基本算子的方案。通过包装一个薄薄的 MyAD1.rrule
,可以在依然借用 ChainRules 实现的前提下不改变其他的 AD 框架的行为
julia> module MyAD1
import ChainRules
rrule(args...) = ChainRules.rrule(args...)
function rrule(::typeof(sin), x::Real)
@show y = 999999
y, adjy->(NO_FIELDS, 3333333)
end
end
Main.MyAD1
julia> using Zygote
julia> Zygote.gradient(sin, 4)
(-0.6536436208636119,)
感谢两位老师的分析与讨论~
借用 ChainRules 实现的前提下不改变其他的 AD 框架的行为
这一点确实是在导出对应backward规则需要考虑到的, 我冒然揣测下陈老师想表达的是,由于是rewrite functions related to sparse matrix in Julia Base by NiLang, 所以实质上是用NiLang形式的重载运算符(比如稀疏矩阵的乘法),如果rrule没有做NiLang层封装,则会破坏涉及到*的矩阵gradient运算?
比如说,这里有个rrule(::typeof(*), mat1::SparseMatrixCSC, mat2::AbstractArray) 只有在引入NiLang后,才会调用我们对其的实现rrule;而其他AD模块依然可以提供其对于稀疏矩阵乘法的实现?
之前没有仔细考虑过type-piracy(期末考完一定好好研究下julia的类型!),理解的有偏差也请两位老师多指点下~
type piracy 现象来源于 Julia 的多重派发,你可以通过添加一个新的方法来改变原有函数的定义,但是这种“添加新的方法”是一种侵入式的过程(会传递到整个生态),所以为了避免这种侵入式的改变被滥用,就引入了所谓的 type piracy 规则,即:如果模块 M 不包含要扩展的函数及参数类型中的任何一个,这就是一个 type piracy。比如说下面这种就是一个典型的 type piracy:MyProjectPkg
不拥有 * 和 SparseMatrixCSC
,这时候只要项目里面有任何一个包引入了 MyProjectPkg
,整个生态就都会受到影响。
module MyProjectPkg
Base.:*(A::SparseMatrixCSC, B::SparseMatrixCSC) = A + B
end
下面这两种都不是 type piracy,因为它对其中的某一个是有所有权的:
module MyProjectPkg
# 拥有 f
f(A::SparseMatrixCSC, B::SparseMatrixCSC) = A + B
end
module MyProjectPkg
# 拥有 MyArray
struct MyArray{T,N} <: AbstractArray{T, N} end
Base.:*(A::SparseMatrixCSC, B::MyArray) = A + B
end
所以严格说来,https://giggleliu.github.io/NiLang.jl/stable/examples/port_zygote/ 里面给的 Zygote.@adjoint
的例子也是某种一种 type piracy,因为 norm2
的 pullback 的所有人应该和 norm2
是同一个。换句话说,如果 PkgA
是 norm2
的所有人,并且 PkgA
或者 ChainRules
给了 Chainrules(::typeof(norm2), args...)
的定义的话,我们就不能通过 type piracy 去插入 NiLang 的实现,因为会改变之前的结果。
除此之外,也存在一些所谓的 licensed/sanctioned type piracy (欧洲殖民时期会给一些海盗颁发所谓的许可证,也就是所谓的 sanctioned pirate 或者说 privateer),许可的 type piracy 行为一般发生在:PkgA
和 PkgB
互相之间没有依赖关系,然后 PkgC
将两个包粘合在一起,并提供了一些方法的扩展,比如说:NNLibCUDA
关于 CUDA.CuArray
类型 扩展了 NNLib
里的方法。这里 NNLibCUDA
的 type piracy 就是一种典型的被接受的行为:因为他们是同一波人维护的...
不过据 @GiggleLiu 说在 AD 里面这种对 pullback 的 type piracy 是很常见的。另一方面,由于目前 ChainRules 也没有给出 SparseMatrixCSC 的 rrule 扩展,所以不会产生冲突,暂时来说可能不是什么大的问题。
最近主要在了解NiLang和ChainRules~所以先提个关于ChainRules的问题。 项目要求里提到
这里写个代码示例
我想了解的是,关于项目成果,是将用NiLang写成的sparse matrix operation的可逆形式,嵌入到ChainRules.jl提供的自定义rrule吗?用户通过rrule,传入operation和sparse matrix,就可以获得其值与pullback。
自定义rrule
第二个是关于pytorch repo的,这里有一列是“Sparse grad?” Yes表示支持,No表示不支持;我发现其实Yes还真的很少,有一部分原因是一些运算是重复的,比如torch.sparse.mm() torch.mm()都是矩阵乘,用户可以选择使用是保留稀疏矩阵梯度还是不保留的算子;另一部分原因我不太清楚,比如torch.bmm()不支持梯度,当然其接受的输入是稀疏tensor不是矩阵,或许是超过本项目的讨论范围,但我以为在NiLang的框架下实际是可行的。
我目前的思路是,关于实现稀疏矩阵运算,其实运算是分两类的:一类是低阶的运算,比如稀疏矩阵最重要的乘法;一类是高阶的运算,比如之前实现的lowrank_svd,是通过乘法把稀疏矩阵映射到低维的稠密阵做SVD分解,是在低阶运算的基础上完成的。所以其实我个人以为,实现尽可能多的低阶运算是很有意义的,之后高阶运算有更多的扩展。
借鉴了陈老师的模板,我列出了本次提案的主体部分,以及每一部分我想展开的要点;如果两位老师不介意这份有些过于简陋的初稿...还请多提提大方向上的建议,比如哪些section是冗余的或者缺少了,也能降低“写完发现走题再次从0开始”的概率...
其实这是我第一次正式用英文写作啦,原本是打算中文写的,但很佩服@GiggleLiu @johnnychen94 的英文写作能力,就厚着脸皮麻烦两位老师多多指教下我吧~十分感谢!!!