microsoft / nnscaler

MIT License

Got RuntimeError when using parallel module #10

Closed. MachineGunLin closed this issue 2 hours ago.

MachineGunLin commented 2 hours ago

Hi, I am trying to use nnscaler to parallelize an Attention module's forward pass (based on fairseq2's implementation).

I managed to use the parallelize method to parallelize my module and obtained the generated code (gencode).

However, when I try to run the parallelized module I get a RuntimeError and don't know how to fix it.

Here is my code (parallelize_attn.py):

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.utils.hooks import RemovableHandle
from collections.abc import Callable
from typing_extensions import override

from fairseq2.nn.incremental_state import IncrementalState, IncrementalStateBag
from fairseq2.nn.ops import repeat_interleave
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.position_encoder import PositionEncoder
from fairseq2.nn.projection import Linear, Projection
from fairseq2.nn.transformer.attention import SDPA, create_default_sdpa
from fairseq2.nn.transformer.attention_mask import AttentionMask, AttentionMaskFactory
from fairseq2.typing import DataType, Device
from fairseq2.nn.transformer.multihead_attention import AttentionStateFactory, StandardMultiheadAttention

model_dim=256
kv_dim=256
num_heads = 8

class fairseq_attn_module(StandardMultiheadAttention):
    def __init__(
        self,
        model_dim: int,
        num_heads: int,
        *,
        kv_dim: int | None = None,
        num_key_value_heads: int | None = None,
        q_proj: Projection | None = None,
        k_proj: Projection | None = None,
        v_proj: Projection | None = None,
        qkv_proj_init_fn: Callable[[Linear], None] | None = None,
        attn_mask_factory: AttentionMaskFactory | None = None,
        pos_encoder: PositionEncoder | None = None,
        sdpa: SDPA | None = None,
        scale_heads: bool = False,
        output_proj: Projection | None = None,
        output_proj_init_fn: Callable[[Linear], None] | None = None,
        bias: bool = True,
        state_factory: AttentionStateFactory | None = None,
        device: Device | None = None,
        dtype: DataType | None = None,
    ) -> None:
        super().__init__(model_dim, num_heads)

    @override
    def forward(
        self,
        seqs: Tensor,
        keys: Tensor,
        values: Tensor,
    ) -> Tensor:
        padding_mask = None
        key_padding_mask = None
        attn_mask = None
        state_bag = None

        # (N, S, M) -> (N, H, S, K_h)
        q = self._project_q(seqs, padding_mask, state_bag)

        if self.training or state_bag is None:
            # k: (N, S_kv, M) -> (N, H_kv, S_kv, K_h)
            # v: (N, S_kv, M) -> (N, H_kv, S_kv, V_h)
            k, v = self._project_kv(keys, key_padding_mask, values)
        else:
            if seqs is keys:  # Self attention
                if key_padding_mask is not None:
                    raise ValueError(
                        "`key_padding_mask` must be `None` during incremental decoding."
                    )

                # k: (N, S_step, M) -> (N, H_kv, S_step, K_h)
                # v: (N, S_step, M) -> (N, H_kv, S_step, V_h)
                k, v = self._project_kv(keys, key_padding_mask, values, state_bag)

                state = state_bag.get_state(self, AttentionState)
                if state is None:
                    state_factory = self.state_factory or FullAttentionState

                    state = state_factory(
                        k, v, state_bag.max_num_steps, state_bag.capacity_increment
                    )

                    state_bag.set_state(self, state)
                else:
                    state.append(k, v)

                    # k: (N, H_kv, S_kv, K_h)
                    # v: (N, H_kv, S_kv, V_h)
                    k, v = state.get()
            else:
                state = state_bag.get_state(self, AttentionState)
                if state is None:
                    # k: (N, S_kv, M) -> (N, H_kv, S_kv, K_h)
                    # v: (N, S_kv, M) -> (N, H_kv, S_kv, V_h)
                    k, v = self._project_kv(keys, key_padding_mask, values)

                    state_factory = self.state_factory or StaticAttentionState

                    state = state_factory(
                        k, v, max_seq_len=k.size(2), capacity_increment=None
                    )

                    state_bag.set_state(self, state)
                else:
                    # k: (N, H_kv, S_kv, K_h)
                    # v: (N, H_kv, S_kv, V_h)
                    k, v = state.get()

        # With Grouped Query Attention, each key/value head is repeated.
        if (num_query_groups := self.num_heads // self.num_key_value_heads) > 1:
            # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, K_h)
            k = repeat_interleave(k, dim=1, repeat=num_query_groups)
            # (N, H_kv, S_kv, K_h) -> (N, H, S_kv, V_h)
            v = repeat_interleave(v, dim=1, repeat=num_query_groups)

        if self.attn_mask_factory is not None:
            attn_mask = self.attn_mask_factory(
                seqs, keys=keys, training=self.training, state_bag=state_bag
            )

        needs_weights = len(self._attn_weight_hooks) > 0

        # attn:         (N, H, S, V_h)
        # attn_weights: (N, H, S, S_kv)
        attn, attn_weights = self.sdpa(
            q,
            k,
            key_padding_mask,
            v,
            attn_mask=attn_mask,
            needs_weights=needs_weights,
        )

        if attn_weights is not None:
            for hook in self._attn_weight_hooks.values():
                hook(self, attn, attn_weights)

        # (N, H, S, V_h) -> (N, S, H, V_h)
        attn = attn.transpose(1, 2)

        if self.head_scale_weight is not None:
            attn = torch.einsum("nshv,h->nshv", attn, self.head_scale_weight)

        # (N, S, H, V_h) -> (N, S, V_proj)
        attn = attn.flatten(2, 3)

        # (N, S, V_proj) -> (N, S, M)
        attn = self.output_proj(attn)

        return attn  # type: ignore[no-any-return]

# attn_module = StandardMultiheadAttention(
#     model_dim=model_dim, num_heads=num_heads
# )
attn_module = fairseq_attn_module(
    model_dim=model_dim, num_heads=num_heads
)

batch_size = 3
input_len = 11
prefix_len = 5
if kv_dim is None:
    kv_dim = model_dim
inputs = torch.randn([batch_size, input_len, kv_dim])
prefix = torch.randn([batch_size, prefix_len, model_dim])

# result = attn_module(
#     seqs=prefix,
#     keys=inputs,
#     values=inputs,
#     padding_mask=None,
#     key_padding_mask=None,
# )
result = attn_module(
    seqs=prefix,
    keys=inputs,
    values=inputs,
)
assert result.shape == prefix.shape
print("\n")
print(f"result.shape: {result.shape}")
print(f"prefix.shape: {prefix.shape}")
print("\n")

from nnscaler.parallel import parallelize, ReuseType, ComputeConfig
from nnscaler.graph import IRGraph
from nnscaler.ir.operator import IRFwOperation

def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph:
    ngpus = resource.plan_ngpus
    partitioned = False
    for idx, node in enumerate(graph.select(ntype=IRFwOperation)):
        if not partitioned:
            print('Partitioned node: ', node)
            sub_nodes = graph.partition(
                node, node.algorithms('dim'), idx=1, dim=0, num=ngpus)
            partitioned = True
        else:
            # raise RuntimeError("partitioned==false")
            sub_nodes = graph.replicate(node, times=ngpus)
        for idx, sub_node in enumerate(sub_nodes):
            graph.assign(sub_node, idx)
    return graph

# parallelize attn
import nnscaler
nnscaler.init()
parallelized_module = parallelize(
    module_or_module_class=attn_module,
    dummy_forward_args={
        'seqs': torch.randn(3, 5, 256),
        'keys': torch.randn(3, 11, 256),
        'values': torch.randn(3, 11, 256),
        'padding_mask': None,
        'key_padding_mask': None,
        'attn_mask': None,
        'state_bag': None,
    },
    # pas_policy='autodist',
    pas_policy=policy,
    compute_config=ComputeConfig(2, 8),
    gen_savedir='./.nnscaler',
    reuse=ReuseType.OVERRIDE,
    instance_name="gen_attn_parallel_module",
    load_module=True,
    broadcast_strategy="all",
)

# Test whether the parallelized module works (since it's not an end-to-end module, I suppose the usage is the same as a torch.nn.Module?)
print("\n")
print(f"parallelized_module: {parallelized_module}")        # fairseq_attn_module()
new_result = parallelized_module(
    seqs=prefix,
    keys=inputs,
    values=inputs,
)
print(f"new_result.shape: {new_result.shape}")
print("\n")

This is the output on the terminal:

root@12011bc99720:/home/infiniai/linrongjian/nnscaler# torchrun --nproc_per_node=8 parallelize_attn.py
[2024-10-22 06:18:45,218] torch.distributed.run: [WARNING] 
[2024-10-22 06:18:45,218] torch.distributed.run: [WARNING] *****************************************
[2024-10-22 06:18:45,218] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-10-22 06:18:45,218] torch.distributed.run: [WARNING] *****************************************

result.shape: torch.Size([3, 5, 256])
prefix.shape: torch.Size([3, 5, 256])

result.shape: torch.Size([3, 5, 256])

prefix.shape: torch.Size([3, 5, 256])

result.shape: torch.Size([3, 5, 256])
prefix.shape: torch.Size([3, 5, 256])

result.shape: torch.Size([3, 5, 256])
prefix.shape: torch.Size([3, 5, 256])

result.shape: torch.Size([3, 5, 256])
prefix.shape: torch.Size([3, 5, 256])

result.shape: torch.Size([3, 5, 256])
prefix.shape: torch.Size([3, 5, 256])

result.shape: torch.Size([3, 5, 256])
prefix.shape: torch.Size([3, 5, 256])

result.shape: torch.Size([3, 5, 256])
prefix.shape: torch.Size([3, 5, 256])

Find unknown pytorch operation: torch.unflatten
Find unknown pytorch operation: torch.unflatten
Find unknown pytorch operation: torch.unflatten
Partitioned node:  FwOp1-()(name=linear, inputs=(t103(p101,(3, 5, 256),d(),v(0/1)), w81(p7,(256, 256),d(),v(0/1)), w82(p9,(256,),d(),v(0/1))), outputs=(t83(p11,(3, 5, 256),d(),v(0/1)),))

parallelized_module: fairseq_attn_module()

parallelized_module: fairseq_attn_module()

parallelized_module: fairseq_attn_module()

parallelized_module: fairseq_attn_module()

parallelized_module: fairseq_attn_module()

parallelized_module: fairseq_attn_module()

parallelized_module: fairseq_attn_module()

parallelized_module: fairseq_attn_module()
Traceback (most recent call last):
  File "/home/infiniai/linrongjian/nnscaler/parallelize_attn.py", line 237, in <module>
    new_result = parallelized_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/nnscaler/runtime/module.py", line 846, in forward
    return self._forward_impl(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode6.py", line 125, in _forward_impl
    linear_3_80 = self.segment176(seqs, keys, values)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode6.py", line 78, in segment176
    linear_138 = torch.nn.functional.linear(seqs_103, self.q_proj_weight_134, self.q_proj_bias_136)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:6 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
Traceback (most recent call last):
  File "/home/infiniai/linrongjian/nnscaler/parallelize_attn.py", line 237, in <module>
    new_result = parallelized_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
Traceback (most recent call last):
  File "/home/infiniai/linrongjian/nnscaler/parallelize_attn.py", line 237, in <module>
Traceback (most recent call last):
  File "/home/infiniai/linrongjian/nnscaler/parallelize_attn.py", line 237, in <module>
    new_result = parallelized_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    new_result = parallelized_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return forward_call(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/nnscaler/runtime/module.py", line 846, in forward
    return self._call_impl(*args, **kwargs)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
  File "/home/infiniai/linrongjian/nnscaler/parallelize_attn.py", line 237, in <module>
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return self._forward_impl(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode3.py", line 125, in _forward_impl
        new_result = parallelized_module(linear_3_80 = self.segment178(seqs, keys, values)

  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode3.py", line 78, in segment178
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return forward_call(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/nnscaler/runtime/module.py", line 846, in forward
    linear_139 = torch.nn.functional.linear(seqs_103, self.q_proj_weight_135, self.q_proj_bias_137)
    return forward_call(*args, **kwargs)
RuntimeError  File "/home/infiniai/linrongjian/nnscaler/nnscaler/runtime/module.py", line 846, in forward
: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
    return self._forward_impl(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode2.py", line 125, in _forward_impl
    return self._forward_impl(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode1.py", line 125, in _forward_impl
    linear_3_80 = self.segment176(seqs, keys, values)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode2.py", line 78, in segment176
        linear_3_80 = self.segment178(seqs, keys, values)return self._call_impl(*args, **kwargs)

  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode1.py", line 78, in segment178
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    linear_138 = torch.nn.functional.linear(seqs_103, self.q_proj_weight_134, self.q_proj_bias_136)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
    linear_139 = torch.nn.functional.linear(seqs_103, self.q_proj_weight_135, self.q_proj_bias_137)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
    return forward_call(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/nnscaler/runtime/module.py", line 846, in forward
    return self._forward_impl(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode5.py", line 125, in _forward_impl
    linear_3_80 = self.segment178(seqs, keys, values)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode5.py", line 78, in segment178
    linear_139 = torch.nn.functional.linear(seqs_103, self.q_proj_weight_135, self.q_proj_bias_137)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:5 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
Traceback (most recent call last):
  File "/home/infiniai/linrongjian/nnscaler/parallelize_attn.py", line 237, in <module>
Traceback (most recent call last):
  File "/home/infiniai/linrongjian/nnscaler/parallelize_attn.py", line 237, in <module>
    new_result = parallelized_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    new_result = parallelized_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/nnscaler/runtime/module.py", line 846, in forward
    return forward_call(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/nnscaler/runtime/module.py", line 846, in forward
    return self._forward_impl(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode7.py", line 125, in _forward_impl
    linear_3_80 = self.segment178(seqs, keys, values)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode7.py", line 78, in segment178
    linear_139 = torch.nn.functional.linear(seqs_103, self.q_proj_weight_135, self.q_proj_bias_137)
RuntimeError:     Expected all tensors to be on the same device, but found at least two devices, cuda:7 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)return self._forward_impl(*args, **kwargs)

  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode4.py", line 125, in _forward_impl
    linear_3_80 = self.segment176(seqs, keys, values)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode4.py", line 78, in segment176
    linear_138 = torch.nn.functional.linear(seqs_103, self.q_proj_weight_134, self.q_proj_bias_136)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:4 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
Traceback (most recent call last):
  File "/home/infiniai/linrongjian/nnscaler/parallelize_attn.py", line 237, in <module>
    new_result = parallelized_module(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/nnscaler/runtime/module.py", line 846, in forward
    return self._forward_impl(*args, **kwargs)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode0.py", line 125, in _forward_impl
    linear_3_80 = self.segment176(seqs, keys, values)
  File "/home/infiniai/linrongjian/nnscaler/.nnscaler/_parallel_modules/__main__/fairseq_attn_module/gen_attn_parallel_module/gencode0.py", line 78, in segment176
    linear_138 = torch.nn.functional.linear(seqs_103, self.q_proj_weight_134, self.q_proj_bias_136)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
[2024-10-22 06:18:55,237] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 4220) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
parallelize_attn.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-10-22_06:18:55
  host      : 12011bc99720
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 4221)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2024-10-22_06:18:55
  host      : 12011bc99720
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 4222)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2024-10-22_06:18:55
  host      : 12011bc99720
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 4223)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[4]:
  time      : 2024-10-22_06:18:55
  host      : 12011bc99720
  rank      : 4 (local_rank: 4)
  exitcode  : 1 (pid: 4224)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[5]:
  time      : 2024-10-22_06:18:55
  host      : 12011bc99720
  rank      : 5 (local_rank: 5)
  exitcode  : 1 (pid: 4225)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[6]:
  time      : 2024-10-22_06:18:55
  host      : 12011bc99720
  rank      : 6 (local_rank: 6)
  exitcode  : 1 (pid: 4226)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[7]:
  time      : 2024-10-22_06:18:55
  host      : 12011bc99720
  rank      : 7 (local_rank: 7)
  exitcode  : 1 (pid: 4227)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-10-22_06:18:55
  host      : 12011bc99720
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 4220)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

Before calling the parallelized module's forward, everything was fine. The generated gencode0.py looks like this:


########## Generated Model Code ###########
from typing import *
from pathlib import Path
import torch
import torch.utils.checkpoint as ckpt
import nnscaler
import _operator
from numpy import inf
import builtins

import apex.normalization.fused_layer_norm

import apex.normalization.fused_layer_norm

import apex.normalization.fused_layer_norm

import apex.normalization.fused_layer_norm

class GenModel(nnscaler.runtime.module.ParallelModule):
    use_scheduler = False
    nmicros_per_scheduler_step = 1
    rank = 0

    def __init__(self, init_params=True):
        super().__init__()
        # communication groups
        self.init_group(ranks=[0, 2, 4, 6])
        self.init_group(ranks=[1, 3, 5, 7])
        self.init_group(ranks=[0, 1])
        self.init_group(ranks=[2, 3])
        self.init_group(ranks=[4, 5])
        self.init_group(ranks=[6, 7])

        self.register_parameter('q_proj_weight_134', torch.nn.Parameter(torch.empty((128, 256), dtype=torch.float32)))
        self.add_full_map('q_proj_weight_134', 7, True, 'q_proj.weight', (256, 256), (slice(0, 128, None), slice(0, 256, None)), 1)

        self.register_parameter('q_proj_bias_136', torch.nn.Parameter(torch.empty((128,), dtype=torch.float32)))
        self.add_full_map('q_proj_bias_136', 9, True, 'q_proj.bias', (256,), (slice(0, 128, None),), 1)

        self.register_parameter('k_proj_weight_86', torch.nn.Parameter(torch.empty((256, 256), dtype=torch.float32)))
        self.add_full_map('k_proj_weight_86', 17, True, 'k_proj.weight', (256, 256), (slice(0, 256, None), slice(0, 256, None)), 1)

        self.register_parameter('k_proj_bias_87', torch.nn.Parameter(torch.empty((256,), dtype=torch.float32)))
        self.add_full_map('k_proj_bias_87', 19, True, 'k_proj.bias', (256,), (slice(0, 256, None),), 1)

        self.register_parameter('v_proj_weight_89', torch.nn.Parameter(torch.empty((256, 256), dtype=torch.float32)))
        self.add_full_map('v_proj_weight_89', 23, True, 'v_proj.weight', (256, 256), (slice(0, 256, None), slice(0, 256, None)), 1)

        self.register_parameter('v_proj_bias_90', torch.nn.Parameter(torch.empty((256,), dtype=torch.float32)))
        self.add_full_map('v_proj_bias_90', 25, True, 'v_proj.bias', (256,), (slice(0, 256, None),), 1)

        self.register_parameter('output_proj_weight_99', torch.nn.Parameter(torch.empty((256, 256), dtype=torch.float32)))
        self.add_full_map('output_proj_weight_99', 43, True, 'output_proj.weight', (256, 256), (slice(0, 256, None), slice(0, 256, None)), 1)

        self.register_parameter('output_proj_bias_100', torch.nn.Parameter(torch.empty((256,), dtype=torch.float32)))
        self.add_full_map('output_proj_bias_100', 45, True, 'output_proj.bias', (256,), (slice(0, 256, None),), 1)

        self.wreducer180 = nnscaler.runtime.adapter.Reducer(ranks=[0, 2, 4, 6], reduce_op='sum', async_op=False, zero=False, max_bucket_size_bytes=137217728, zero_ngroups=1)
        self.wreducer180.add_param(self.q_proj_weight_134)
        self.wreducer180.add_param(self.q_proj_bias_136)
        self.wreducer180.add_param(self.output_proj_weight_99)
        self.wreducer180.add_param(self.output_proj_bias_100)
        self.wreducer180.add_param(self.k_proj_weight_86)
        self.wreducer180.add_param(self.k_proj_bias_87)
        self.wreducer180.add_param(self.v_proj_weight_89)
        self.wreducer180.add_param(self.v_proj_bias_90)
        self.add_reducer(self.wreducer180)

        self._post_init(init_params)

    def segment176(self, seqs_103, keys_106, values_109):
        seqs_103 = nnscaler.runtime.adapter.nn.identity_allreduce(seqs_103, ranks=[0, 1])
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/projection.py", line 130, in forward,  return linear(x, self.weight, self.bias)
        linear_138 = torch.nn.functional.linear(seqs_103, self.q_proj_weight_134, self.q_proj_bias_136)
        del seqs_103
        linear_83 = nnscaler.runtime.adapter.nn.allgather_split(linear_138, dim=2, ranks=[0, 1])
        del linear_138
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/transformer/multihead_attention.py", line 490, in _project_q,  q = q.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)
        unflatten_84 = torch.unflatten(linear_83, -1, (8, -1))
        del linear_83
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/transformer/multihead_attention.py", line 490, in _project_q,  q = q.unflatten(-1, (self.num_heads, -1)).transpose(1, 2)
        transpose_85 = torch.transpose(unflatten_84, dim0=1, dim1=2)
        del unflatten_84
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/projection.py", line 130, in forward,  return linear(x, self.weight, self.bias)
        linear_1_88 = torch.nn.functional.linear(keys_106, self.k_proj_weight_86, self.k_proj_bias_87)
        del keys_106
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/projection.py", line 130, in forward,  return linear(x, self.weight, self.bias)
        linear_2_91 = torch.nn.functional.linear(values_109, self.v_proj_weight_89, self.v_proj_bias_90)
        del values_109
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/transformer/multihead_attention.py", line 510, in _project_kv,  k = k.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2)
        unflatten_1_92 = torch.unflatten(linear_1_88, -1, (8, -1))
        del linear_1_88
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/transformer/multihead_attention.py", line 510, in _project_kv,  k = k.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2)
        transpose_1_93 = torch.transpose(unflatten_1_92, dim0=1, dim1=2)
        del unflatten_1_92
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/transformer/multihead_attention.py", line 512, in _project_kv,  v = v.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2)
        unflatten_2_94 = torch.unflatten(linear_2_91, -1, (8, -1))
        del linear_2_91
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/transformer/multihead_attention.py", line 512, in _project_kv,  v = v.unflatten(-1, (self.num_key_value_heads, -1)).transpose(1, 2)
        transpose_2_95 = torch.transpose(unflatten_2_94, dim0=1, dim1=2)
        del unflatten_2_94
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/transformer/attention.py", line 176, in forward,  attn = F.scaled_dot_product_attention(  # type: ignore[attr-defined]
        scaled_dot_product_attention_96 = torch._C._nn.scaled_dot_product_attention(transpose_85, transpose_1_93, transpose_2_95, dropout_p=0.0, is_causal=False)
        del transpose_85, transpose_1_93, transpose_2_95
        # File "/home/infiniai/linrongjian/nnscaler/parallelize_attn.py", line 142, in forward,  attn = attn.transpose(1, 2)
        transpose_3_97 = torch.transpose(scaled_dot_product_attention_96, dim0=1, dim1=2)
        del scaled_dot_product_attention_96
        # File "/home/infiniai/linrongjian/nnscaler/parallelize_attn.py", line 148, in forward,  attn = attn.flatten(2, 3)
        flatten_98 = torch.flatten(transpose_3_97, start_dim=2, end_dim=3)
        del transpose_3_97
        # File "/usr/local/lib/python3.10/dist-packages/fairseq2/nn/projection.py", line 130, in forward,  return linear(x, self.weight, self.bias)
        linear_3_80 = torch.nn.functional.linear(flatten_98, self.output_proj_weight_99, self.output_proj_bias_100)
        del flatten_98
        return linear_3_80

    def reducer180(self):
        self.wreducer180.sync_grads()
        return 

    def _forward_impl(self, seqs, keys, values):
        linear_3_80 = self.segment176(seqs, keys, values)
        return linear_3_80

If you know how to fix this, or if you need more information, please let me know.

Thank you for your help.

zyeric commented 2 hours ago

By default, nnscaler assumes that computation is launched on the GPU. I think you should move prefix and inputs to the GPU before passing them to the parallelized module:

dev = torch.cuda.current_device()
new_result = parallelized_module(
    seqs=prefix.to(dev),
    keys=inputs.to(dev),
    values=inputs.to(dev),
)
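
For completeness, a minimal sketch of the call site under torchrun, assuming CUDA is available and that torch.cuda.current_device() already points at this rank's GPU (as in the snippet above); the variable names and the final asserts are illustrative, not from the original script:

import torch

# Hedged sketch: pick the GPU index already selected for this worker.
dev = torch.device("cuda", torch.cuda.current_device())

# Move the CPU tensors to that device before calling the parallelized module.
seqs_gpu = prefix.to(dev)
kv_gpu = inputs.to(dev)

new_result = parallelized_module(
    seqs=seqs_gpu,
    keys=kv_gpu,
    values=kv_gpu,
)

# If moving the inputs is sufficient, the output should be a CUDA tensor
# with the same shape as the original (non-parallel) result.
assert new_result.device.type == "cuda"
assert new_result.shape == prefix.shape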