Oneflow-Inc / one-fx

A toolkit for developers to simplify the transformation of nn.Module instances. It's now corresponding to Pytorch.fx.
Other
13 stars 1 forks source link

一个如果对齐可能影响一定性能的点 #15

Closed AsakusaRinne closed 1 year ago

AsakusaRinne commented 1 year ago

在torch.fx的一个example中,出现了symbolic trace一个模块之后,将Proxy变量作为参数去执行这个生成的模块的做法,这样做的原因是要单独生成一个新Node然后进行插入、连接操作。

这里会因为 #2 这个issue里面的解决方案产生与torch.fx之间的差异,torch.fx借助__torch_function__自动代理了所有torch里面的函数,但oneflow这边没有这个机制所以我们对oneflow所有函数当作普通函数进行代理,这二者的差异在于一个是全局的,一个是非全局的。所以在我们把module -> fx.graph -> module这个过程当作端到端的过程,只把生成的module当成nn.Module来用的时候,不会有什么差异,但如果我们传入Proxy就会有一定差异。

这一问题有三种解决方案:

  1. 添加一个global_wrap_oneflow_all的API,让用户每次这样使用GraphModule的时候暂时全局wrap掉所有的oneflow函数,因为考虑到用户直接去使用代理作为参数传入GraphModule的情况一般都发生在对模型进行调整的阶段,这里对性能不是很敏感。缺点是API没和torch.fx对齐。
  2. GraphModule.reconpilecall_wrapped里面,加上对于当前python code的global wrap,这样可以和torch.fx在用户使用上对齐,但这样每次调用这个模块的forward的时候会多出一步处理,应该会对性能产生一点损耗。具体损耗多少需要后续做一个benchmark测试。
  3. 做一个和__torch_function__类似的机制,这个需要再探索一下。
BBuf commented 1 year ago

感觉可以先试试第二种方案,如果影响比较小的话是更好一点的,不需要额外的说明。

AsakusaRinne commented 1 year ago

更新: #16 用第二种方案进行了处理,同时也提供了global_wrap_oneflow_all的API(留着以后可能突然遇到错误,作为一种临时解决方案)。

TODO:

AsakusaRinne commented 1 year ago

当前还有一种情况是仍然处理不了的,那就是在外部进行traced tensor的运算,然后生成GraphModule,可以参考这段代码,整个过程很像tensorflow中的tf.keras.Model(input, output)这个用来生成模型的API。

当然,如果手动开启全局代理是ok的,但这样用户使用上就和torch.fx不对齐了。

这种情况能想到的解决方案有两个:

  1. 实现类似__torch_function__的机制,这一机制简而言之就是当torch function的入参是proxy的时候当场对当前函数进行wrap,但这个目前来看实现还比较麻烦。
  2. 只要有了Proxy变量,立刻开启oneflow函数全局代理。但这个缺点是不好判断关闭全局代理的时机。