Closed AsakusaRinne closed 1 year ago
感觉可以先试试第二种方案,如果影响比较小的话是更好一点的,不需要额外的说明。
更新: #16 用第二种方案进行了处理,同时也提供了global_wrap_oneflow_all
的API(留着以后可能突然遇到错误,作为一种临时解决方案)。
TODO:
global_wrap
加入到recompile中,会有比较多的额外的判定,可以考虑单独分离出来一个API,有些地方就可以硬编码进去减少操作。当前还有一种情况是仍然处理不了的,那就是在外部进行traced tensor的运算,然后生成GraphModule,可以参考这段代码,整个过程很像tensorflow中的tf.keras.Model(input, output)
这个用来生成模型的API。
当然,如果手动开启全局代理是ok的,但这样用户使用上就和torch.fx不对齐了。
这种情况能想到的解决方案有两个:
__torch_function__
的机制,这一机制简而言之就是当torch function的入参是proxy的时候当场对当前函数进行wrap,但这个目前来看实现还比较麻烦。
在torch.fx的一个example中,出现了symbolic trace一个模块之后,将
Proxy
变量作为参数去执行这个生成的模块的做法,这样做的原因是要单独生成一个新Node然后进行插入、连接操作。这里会因为 #2 这个issue里面的解决方案产生与torch.fx之间的差异,torch.fx借助
__torch_function__
自动代理了所有torch里面的函数,但oneflow这边没有这个机制所以我们对oneflow所有函数当作普通函数进行代理,这二者的差异在于一个是全局的,一个是非全局的。所以在我们把module -> fx.graph -> module这个过程当作端到端的过程,只把生成的module当成nn.Module
来用的时候,不会有什么差异,但如果我们传入Proxy
就会有一定差异。这一问题有三种解决方案:
global_wrap_oneflow_all
的API,让用户每次这样使用GraphModule的时候暂时全局wrap掉所有的oneflow函数,因为考虑到用户直接去使用代理作为参数传入GraphModule的情况一般都发生在对模型进行调整的阶段,这里对性能不是很敏感。缺点是API没和torch.fx对齐。GraphModule.reconpile
的call_wrapped
里面,加上对于当前python code的global wrap,这样可以和torch.fx在用户使用上对齐,但这样每次调用这个模块的forward的时候会多出一步处理,应该会对性能产生一点损耗。具体损耗多少需要后续做一个benchmark测试。__torch_function__
类似的机制,这个需要再探索一下。