bytedance / flux

A fast communication-overlapping library for tensor parallelism on GPUs.
Apache License 2.0

[QUESTION] Not supported on A6000? #46

Open Zhuohao-Li opened 1 month ago

Zhuohao-Li commented 1 month ago

Your question Hi,

When I run the test demo on a node consisting of 2 A6000 GPUs, it reports an error:

RuntimeError: /root/opensource/flux/src/cuda/op_registry.cu:36 Check failed: arch_num == 80 || arch_num == 89 || arch_num == 90. unsupported arch: 86

So Flux only supports these three architectures (cc = 80, 89, 90)? Correct me if I misunderstand it.

Thanks

zheng-ningxin commented 1 month ago

Yes, Flux is only compiled for architectures 80, 89, and 90 for now. However, I suspect that CUTLASS v2 should directly support architecture 86. Could you try adding the corresponding arch number and recompiling to see if it works?

Zhuohao-Li commented 1 month ago

Thanks,

I added 86 to the check at /flux/src/cuda/op_registry.cu line 36 like this:

void
init_arch_tag() {
  int major, minor;
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
  int arch_num = major * 10 + minor;
  FLUX_CHECK(arch_num == 80 || arch_num == 89 || arch_num == 90 || arch_num == 86)
      << "unsupported arch: " << arch_num;
  arch = ArchEnum{arch_num};
}

I recompiled it via Build from Source again, but when running ./scripts/launch.sh test/test_gemm_rs.py 4096 12288 49152 --dtype=float16 --iters=10 it fails with:

RuntimeError: ~/flux/include/flux/op_registry.h:220 Check failed: visit_iter != gemm_hparams_.end(). No registered hparams found for meta:GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32),arch=UNK,comm_op=CommNone,gemm_layout=RCR,impl=GemmV2,impl_spec=GemmV2Meta(fast_accum=0),comm_spec=None)
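The arch=UNK in that error suggests that ArchEnum{86} produced a value the registry does not recognize, since no enumerator for 86 exists yet. A minimal sketch of the effect, using a hypothetical enum (not Flux's actual definition in flux.h):

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for Flux's ArchEnum: only 80/89/90 are declared,
// so casting 86 yields a value with no matching enumerator, which any
// downstream switch/lookup falls through and reports as "UNK".
enum class ArchEnum : int { Sm80 = 80, Sm89 = 89, Sm90 = 90 };

std::string arch_name(ArchEnum arch) {
  switch (arch) {
    case ArchEnum::Sm80: return "Sm80";
    case ArchEnum::Sm89: return "Sm89";
    case ArchEnum::Sm90: return "Sm90";
    default: return "UNK";  // static_cast<ArchEnum>(86) lands here
  }
}
```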

The corresponding code in op_registry.h:

  // Iterate all hparams registered for a meta and call func.
  // This can be useful for tuning.
  template <class... Ts>
  void
  visit_hparams(std::function<void(UnifiedGemmHParams)> &&func, GemmMeta<Ts...> meta) {
    std::shared_lock<std::shared_mutex> lock(register_mutex_);
    auto unified_meta = unify_type(meta);
    auto iter = gemm_hparams_.find(unified_meta);
    FLUX_CHECK(iter != gemm_hparams_.end()) << "no op registered for meta:" << meta;
    for (const auto &hparams_pair : iter->second) {
      auto const &hparams = hparams_pair.second;
      func(hparams);
    }
  }
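The pattern above is a registry lookup: kernels register their hyper-parameters under a meta key (arch, dtype, layout, ...), and a lookup for a key nobody registered — such as a newly added arch 86 — fails the check. A simplified sketch with placeholder types (Meta and HParams are stand-ins, not Flux's real types):

```cpp
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

using Meta = std::string;  // stand-in for the unified GemmMeta key
using HParams = int;       // stand-in for UnifiedGemmHParams

// Simplified registry: lookups succeed only for keys that some kernel
// registered at startup; everything else trips the check.
struct Registry {
  std::map<Meta, std::vector<HParams>> gemm_hparams_;

  const std::vector<HParams> &find_or_throw(const Meta &meta) const {
    auto iter = gemm_hparams_.find(meta);
    if (iter == gemm_hparams_.end())
      throw std::runtime_error("no op registered for meta: " + meta);
    return iter->second;
  }
};
```

This matches the symptom: adding 86 to the arch check lets initialization pass, but no GEMM kernels were ever registered for that arch, so the hparams lookup still fails.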

I have not yet taken a deep look at what hparams is; could you quickly point me to it? Are any additional changes needed in the codebase? Thanks!

houqi commented 3 weeks ago


You should also:

  1. modify flux.h and add 86 to ArchEnum
  2. add sm86 handling to the workspace code: gemm_v2_reduce_scatter.hpp#L502 for GEMM+RS, and gemm_v2_ag_kernel.hpp#L174 for AG+GEMM
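Step 1 might look like the following. This is a sketch with assumed enumerator names; check the actual ArchEnum definition in flux.h before applying:

```cpp
// Hypothetical sketch of extending ArchEnum in flux.h with sm_86
// (A6000 / Ampere GA102, compute capability 8.6). The enumerator value
// must equal major * 10 + minor so ArchEnum{arch_num} in
// init_arch_tag() maps onto it.
enum class ArchEnum : int {
  Sm80 = 80,
  Sm86 = 86,  // newly added for A6000
  Sm89 = 89,
  Sm90 = 90,
};
```

After this, the GEMM operators for sm86 still need to be instantiated and registered (steps the maintainer lists above), or the hparams lookup will keep failing.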