Oneflow-Inc / swin-transformer


Swin-Transformer: getting the graph running with 1d & 2d sbp #8

Closed Ldpe2G closed 2 years ago

Ldpe2G commented 2 years ago

Experiment branch: swin_clean_ldp_graph

Some issues encountered so far; after working around the following, the graph runs with 1d sbp:

1. Constructing a rand tensor inside a module's forward raises an error. Repro script:

import oneflow as flow

class TestGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()

    def build(self, x):
        shape = (x.shape[0], 1, 1)
        output = x * flow.rand(*shape, dtype=x.dtype, placement=x.placement, sbp=x.sbp)
        # Rewriting it as the line below does not raise the error:
        # output = x * flow.randint(0, 1, shape, placement=x.placement, sbp=x.sbp)
        return output

tensor = flow.ones(32, 100, 100, placement=flow.env.all_device_placement("cuda"), sbp=flow.sbp.split(0))
test = TestGraph()
out = test(tensor)

Launch command:

python3 -m oneflow.distributed.launch --nproc_per_node 2 --master_port 12345 test.py

Error message:

F0111 12:10:01.930294 28125 exec_graph.cpp:122]
  File "../oneflow/core/graph/exec_graph.cpp", line 122, in InferBlobDescs
CheckPhysicalBlobDesc( *op(), op()->output_bns(), std ... nd_sbp_signature, parallel_ctx, GetBlobDesc4BnInOp)

  File "../oneflow/core/graph/exec_graph.cpp", line 97, in CheckPhysicalBlobDesc
    CheckPhysicalBlobDesc(*::oneflow::private_details:: ... op_parallel_desc, parallel_ctx, *physical_blob_desc)

  File "../oneflow/core/graph/exec_graph.cpp", line 78, in CheckPhysicalBlobDesc

    Check failed: (physical.shape()) == (*::oneflow::private_details::RemoveRValConst(({ 
auto&& _just_value_to_check_ = GetPhysicalShape(logical.shape(), nd_sbp, parallel_desc,
 *parallel_ctx); if (!::oneflow::private_details::JustIsOk(_just_value_to_check_)) { return 
::oneflow::private_details::JustErrorAddStackFrame( 
::oneflow::private_details::JustGetError(_just_value_to_check_), 
"../oneflow/core/graph/exec_graph.cpp", 78, __FUNCTION__, 
"GetPhysicalShape(logical.shape(), nd_sbp, parallel_desc, *parallel_ctx)"); } 
std::forward<decltype(_just_value_to_check_)>(_just_value_to_check_); 
})).Data_YouAreNotAllowedToCallThisFuncOutsideThisFile()) ((32,1,1) vs (16,1,1))

*** Check failure stack trace: ***
    @     0x7f3765588f8d  google::LogMessage::Fail()
    @     0x7f376558a84d  google::LogMessage::SendToLog()
    @     0x7f3765588a4d  google::LogMessage::Flush()
    @     0x7f376558c4a9  google::LogMessageFatal::~LogMessageFatal()                                                                                     
    @     0x7f37567ab71b  oneflow::ExecNode::InferBlobDescs()                                                                                             
    @     0x7f37567b0103  oneflow::Graph<>::TopoForEachNodeWithErrorCaptured()
    @     0x7f37567b0d46  oneflow::Graph<>::TopoForEachNode()
    @     0x7f375682b66f  oneflow::NormalForwardCompTaskNode::BuildExecGphAndRegst()
    @     0x7f37567ff1a9  oneflow::Graph<>::TopoForEachNodeWithErrorCaptured()
    @     0x7f37567ff98b  oneflow::Graph<>::TopoForEachNode()
    @     0x7f375687f856  oneflow::Compiler::Compile()
    @     0x7f37563e37e4  oneflow::NNGraph::CompileAndInitRuntime()
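For reference, the shape mismatch at the end of the trace, (32,1,1) vs (16,1,1), is the logical shape versus the per-rank physical shape: the graph input is split(0) over 2 devices, so each rank holds half of dim 0, while the rand op built from x.shape uses the full logical batch of 32. A minimal plain-Python sketch of that rule (physical_shape is a hypothetical helper for illustration, not the oneflow API):

```python
# Plain-Python model of the physical-shape rule checked in exec_graph.cpp:
# under sbp split(axis) over world_size ranks, the split axis of the
# logical shape is divided evenly across the ranks.

def physical_shape(logical_shape, sbp_axis, world_size):
    """Per-rank shape for a tensor split along sbp_axis.

    sbp_axis=None models broadcast (every rank holds the full tensor).
    """
    shape = list(logical_shape)
    if sbp_axis is not None:
        assert shape[sbp_axis] % world_size == 0, "split axis must divide evenly"
        shape[sbp_axis] //= world_size
    return tuple(shape)

# The check expects the per-rank (16, 1, 1), but the rand op produced the
# logical (32, 1, 1), hence "(32,1,1) vs (16,1,1)".
print(physical_shape((32, 1, 1), 0, 2))  # (16, 1, 1)
```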

2. Calling a module with keyword arguments inside a graph raises an error. Repro script:

class Test1Module(flow.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, mask=None):
        return (x + mask).sum()

class TestGraph(flow.nn.Graph):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def build(self, x, mask):
        return self.module(x, mask=mask)

tensor = flow.ones(32, 100, 100, placement=flow.env.all_device_placement("cuda"), sbp=flow.sbp.split(0))
mask = flow.ones(32, 100, 100, placement=flow.env.all_device_placement("cuda"), sbp=flow.sbp.split(0))

module = Test1Module()
test = TestGraph(module)
out = test(tensor, mask)

Launch command:

python3 -m oneflow.distributed.launch --nproc_per_node 2 --master_port 12345 test.py

Error message:

[ERROR](GRAPH:TestGraph_0:TestGraph) building graph got error: <class 'TypeError'> 
__call__() got an unexpected keyword argument 'mask'
Traceback (most recent call last):
  File "test.py", line 38, in <module>
    out = test(tensor, mask)
  File "/home/liangdepeng/ldp/oneflow/python/oneflow/nn/graph/graph.py", line 258, in __call__
    self._compile(*args)
  File "/home/liangdepeng/ldp/oneflow/python/oneflow/nn/graph/graph.py", line 506, in _compile
    eager_outputs = self._build_graph(*args)
  File "/home/liangdepeng/ldp/oneflow/python/oneflow/nn/graph/graph.py", line 598, in _build_graph
    outputs = self.build(*lazy_args)
  File "test.py", line 31, in build
    return self.module(x, mask=mask)
TypeError: __call__() got an unexpected keyword argument 'mask'
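In this repro the immediate workaround is to call the module positionally, i.e. self.module(x, mask). The general fix the graph side needs, binding keyword arguments back to positions, can be sketched in plain Python (call_positionally is a hypothetical illustration, not oneflow code):

```python
import inspect

def call_positionally(fn, *args, **kwargs):
    """Reorder keyword arguments into positional ones via fn's signature,
    then invoke fn with positional arguments only."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    return fn(*bound.args)

# Mirrors the failing call self.module(x, mask=mask) from the repro.
def forward(x, mask=None):
    return x + mask

print(call_positionally(forward, 3, mask=4))  # 7
```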

3. Swin uses CosineLRScheduler from flowvision, a custom scheduler; for now it is replaced with flow.optim.lr_scheduler.CosineAnnealingLR.

4. The lazy clip-grad implementation needs to support clip_grad_max_norm > 1.0.
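For context on point 4, the max-norm rule itself places no restriction on the threshold; supporting clip_grad_max_norm > 1.0 is the same formula with a larger bound. A plain-Python sketch of the usual rule (clip_grads_by_max_norm is an illustrative helper, not oneflow's implementation):

```python
import math

# Max-norm gradient clipping: compute the global 2-norm over all grads
# and, only if it exceeds max_norm, scale every grad by max_norm / total_norm.
# Nothing here assumes max_norm <= 1.0.

def clip_grads_by_max_norm(grads, max_norm):
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [[g * scale for g in grad] for grad in grads]
    return grads, total_norm

grads = [[3.0, 4.0]]  # global 2-norm is 5.0
clipped, norm = clip_grads_by_max_norm(grads, 2.5)
print(clipped)  # [[1.5, 2.0]], scaled by 2.5 / 5.0
```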

strint commented 2 years ago

For the rand one, try randint instead? Reference example: https://github.com/Oneflow-Inc/oneflow/pull/7092/files

I'll add kwargs support tomorrow.

Ldpe2G commented 2 years ago

For the rand one, try randint instead? Reference example: https://github.com/Oneflow-Inc/oneflow/pull/7092/files

I'll add kwargs support tomorrow.

Changing flow.rand to flow.randint makes the error go away.

Ldpe2G commented 2 years ago

With 2d sbp, hit a reshape sbp inference error.

Experiment branch: swin_clean_ldp_graph, script: debug_with_real_data_ddp.sh, oneflow branch: fix-lazy_copy_cost

strint commented 2 years ago

3. Swin uses CosineLRScheduler from flowvision, a custom scheduler; for now it is replaced with flow.optim.lr_scheduler.CosineAnnealingLR.

Is this an LR scheduler that oneflow does not have yet? Is it CosineDecayLR?

Ldpe2G commented 2 years ago

With 2d sbp, hit a reshape sbp inference error.

Experiment branch: swin_clean_ldp_graph, script: debug_with_real_data_ddp.sh, oneflow branch: fix-lazy_copy_cost

After merging master into the oneflow branch fix-lazy_copy_cost, and adding to_consistent at 3 places in the model code on the experiment branch swin_clean_ldp_graph, 2d sbp runs:

Add two consecutive to_consistent calls here to change the grad's sbp: https://github.com/Oneflow-Inc/swin-transformer/blob/swin_clean_ldp_graph/swin_transformer/models/swin_transformer.py#L235

unsqueeze_relative_position_bias = relative_position_bias.unsqueeze(0)
unsqueeze_relative_position_bias = unsqueeze_relative_position_bias.to_consistent(grad_sbp=unsqueeze_relative_position_bias.sbp)
unsqueeze_relative_position_bias = unsqueeze_relative_position_bias.to_consistent(grad_sbp=(flow.sbp.broadcast, flow.sbp.broadcast))
attn = attn + unsqueeze_relative_position_bias

Changing the grad sbp twice in a row here was found by trial and error: omitting both lines, or adding only the first, hits the reshape_like op's sbp inference mismatch. So I tried converting the grad's sbp to [B, B] first and then to the forward output's sbp.

And here, change the sbp of the forward feature: https://github.com/Oneflow-Inc/swin-transformer/blob/swin_clean_ldp_graph/swin_transformer/models/swin_transformer.py#L244

nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.to_consistent(sbp=(flow.sbp.split(0), flow.sbp.split(0)))
attn = attn.view(-1, self.num_heads, N, N)

The reason for this change: the attn produced by the first line has sbp [S(0), S(1)], and the second view then fails because that operation merges dims 0 and 1.
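A simplified model of the reshape constraint at play here (split_axis_after_merge is a hypothetical illustration; oneflow's actual sbp inference for reshape is more general): when a view fuses two dims, a split on the leading fused dim can carry over to the fused output axis, but a split on a dim strictly inside the fused block has no single output axis to map to.

```python
# Simplified rule for how a split(axis) sbp survives a view that fuses
# the dims in merged_axes into one output dim (as view(-1, H, N, N) fuses
# dims 0 and 1 of a 5-d attn tensor).

def split_axis_after_merge(split_axis, merged_axes=(0, 1)):
    """Output split axis after fusing merged_axes, or None if the split
    can no longer be represented (sbp inference would fail)."""
    lo, hi = min(merged_axes), max(merged_axes)
    if split_axis == lo:
        return lo                      # split on the leading fused dim survives
    if lo < split_axis <= hi:
        return None                    # split inside the fused block is lost
    if split_axis > hi:
        return split_axis - (hi - lo)  # later axes shift left
    return split_axis

print(split_axis_after_merge(0))  # 0    -> S(0) stays S(0)
print(split_axis_after_merge(1))  # None -> S(1) cannot be inferred
```

Under this simplified rule, converting attn to (S(0), S(0)) before the second view gives both mesh dims a split that survives the merge, which matches the fix above.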

leaves-zwx commented 2 years ago

You can add logging to the reshape_like op: https://github.com/Oneflow-Inc/OneTeam/issues/978#issuecomment-1014356871 to see what exactly is mismatched.

Ldpe2G commented 2 years ago

You can add logging to the reshape_like op: Oneflow-Inc/OneTeam#978 (comment) to see what exactly is mismatched.

OK.

Ldpe2G commented 2 years ago

Just noticed that after merging master into the oneflow branch fix-lazy_copy_cost, 2d sbp also runs without manually adding to_consistent in the model code.

rentainhe commented 2 years ago

WarmupCosineLR can be replaced with this:

import logging

import oneflow as flow

logger = logging.getLogger(__name__)

def WarmupCosineLR(
    optimizer: flow.optim.Optimizer,
    max_iters: int,
    warmup_factor: float,
    warmup_iters: int,
    alpha: float = 0.0,
    warmup_method: str = "linear",
    **kwargs,
):
    cosine_decay_lr = flow.optim.lr_scheduler.CosineDecayLR(
        optimizer, decay_steps=max_iters, alpha=alpha
    )
    if warmup_iters == 0:
        logger.warning("warmup_iters equals zero, returning cosine decay LR directly")
        return cosine_decay_lr
    elif warmup_iters > max_iters:
        logger.warning("warmup iters is larger than the total training iters")
    warmup_cosine_lr = flow.optim.lr_scheduler.WarmUpLR(
        cosine_decay_lr,
        warmup_factor=warmup_factor,
        warmup_iters=warmup_iters,
        warmup_method=warmup_method,
        **kwargs,
    )
    return warmup_cosine_lr

strint commented 2 years ago

Problems 1 and 2 are now supported on master.