CINN的架构如下图所示,分为前端后端和执行器,其中前端的主要功能是基于 PIR 进行图层级别的优化,并对子图进行划分为后端高性能 Kernel 代码生成提供支持;后端主要负责将前端处理后的 IR 转换为目标硬件可执行的代码或硬件描述,主要功能包括基于硬件特性的 IR 优化、高效内存管理和代码生成等;最后再由执行器的运行调度接口对编译器生成的 Kernel 进行封装。
这里出现了两类IR,PIR和后端的AST IR,它们都起到了计算和数据进行表示表达的作用,上述的编译期的整个工作过程其实也可以说是对IR的分析变换,我们抽象为前后端Pass以及后端特有的编排调优Schedule
比如很简单的一个子图:
# shape of x, y is [64, 128]
def forward(self, x, y):
tmp = x - y
out = tmp * x
return out
一、任务背景与列表
深度学习编译器是一种专门为深度学习模型优化和部署而设计的工具,其功能是将高层次的深度学习模型转换为低层次的、高效的、底层硬件可执行的代码。飞桨3.0推出了与框架一体化的CINN编译器,能同时支持训练和推理过程,并且具备处理动态可变形状输入的能力。目前,CINN的编译器主要被分为两个阶段:前端与后端。前端主要在PIR层面做一些图层的优化,经过lowering之后上层表示会转化为更贴近硬件实现的后端AST IR表示,后端会在AST IR的基础上进行一系列的分析与变换,最终产生更高效的硬件实现。对IR的分析与变换在编译器中被抽象为了Pass。然而CINN后端IR在设计初期未划分层次结构,难于进行分析与变换,且Pass的编写形式也缺少规范化。 近期,我们升级了后端IR表示,使用了更清晰的层次结构,并提供了方便的访问形式与易用的Pass机制。有了这些核心基础组件,后端的已有的一些转换函数也需要伴随这次升级进行相应的改造。
⭐️ 提交PR 模版 ⭐️:
// ------- PR 标题 --------
// ------- PR 内容 --------
二、任务详情
2.1 CINN编译器介绍
CINN的架构如下图所示,分为前端后端和执行器,其中前端的主要功能是基于 PIR 进行图层级别的优化,并对子图进行划分为后端高性能 Kernel 代码生成提供支持;后端主要负责将前端处理后的 IR 转换为目标硬件可执行的代码或硬件描述,主要功能包括基于硬件特性的 IR 优化、高效内存管理和代码生成等;最后再由执行器的运行调度接口对编译器生成的 Kernel 进行封装。 这里出现了两类IR,PIR和后端的AST IR,它们都起到了计算和数据进行表示表达的作用,上述的编译期的整个工作过程其实也可以说是对IR的分析变换,我们抽象为前后端Pass以及后端特有的编排调优Schedule 比如很简单的一个子图:
转换成PIR就变成了如下Tensor级别的高层次表示,不体现底层的计算逻辑:
经过CINN的前端变换会得到一组组的可以融合起来的FusionOp,这里例子里只有一组subtract+multiply的FusionOp:
后端会对这个FusionOp进行代码生成,然后编译成Jit Kernel以供执行器调用,这里的第一步就是需要将前端IR lowering转换成后端AST IR。AST IR更直观地表达出一个子图到底是怎么算的:
可以看出,两个shape为[64, 128]的tensor的相减后乘是通过两层串行的for循环实现的,循环体进行tensor特定元素的减法和乘法。除了for和加减乘除这样的常见语法,这里还出现了ScheduleBlock这个概念。这其实是为了后续的Schedule编排优化而对语句做的封装,经过Schedule之后这段代码能更贴近硬件实现比如使用32个block,每个block256个线程来完成上述的计算:
2.2 CINN后端升级后的IR层次结构及Pass写法
在此次升级前,的IR中的所有元素都用Expr表示,并不区分语句和表达式,缺少层次结构。这种扁平的设计导致用户在编写后端转换Pass时十分困难且容易出错。在这个基础上,之前对IR的分析与变换主要使用IRMutator/Visitor,每一类语句、表达式都需要实现不同的访问时的反馈函数,后端的很多优化其实是在语句级别的,IRMutator/Visitor却可能遍历访问到最内层的表达式,非常低效且难以理解。此次后端升级主要进行了三大改造:IR结构、IR访问方法、pass编写模式
2.2.1 IR层次结构
升级后,后端IR主要由以下元素构成:module、function、block、statement、expr,其层次结构如下:
module的语义可以对标一个cpp/cu文件,可以包含多个function。一个function含有一个block body,block表示一个代码段,可以理解为一个c++的一个花括号。一个block里面可以包含零条或多条statement,一条statement里面又可以包含零个或多个block以及表达式,比如for语句内包含一个block而if-then-else语句包含两个block;具体的语句定义以及组成元素可以参考Paddle/paddle/cinn/ir/stmt.h。对statement(后简称为stmt)和block的抽象以及关系的描述是本次IR结构升级的核心。
2.2.2 IR访问方法
在对stmt和block进行抽象封装的基础上,我们进一步改造了IR的访问方法,主要提供各种简洁的stmt遍历方法包括:
void Visit(const StmtRef &stmt, const std::function<void(const StmtRef &)> &pre_callback, const std::function<void(const StmtRef &)> &post_callback); // Mutators // ...
2.2.3 Pass编写模式
在此次升级前,我们并未对后端变换pass编写做任何规范,只是要求访问时使用IRMutator/Visitor,此次升级我们将pass分为四类:FunctionPass、BlockPass、StatementPass以及ExprPass,选择何种pass的核心是:该层次是否包含了所有变换所需信息。比如你想做的变换仅仅需要对stmt内部的信息进行分析和处理或是改变其本身,而不需要其上下stmt信息或是外层block的信息,那你的选择应该是StatementPass。编写pass时只需要继承相应pass基类然后重写相应的run函数即可,剩下的工作都由CINN的pass管理机制来完成即可。
三、可参考PR