openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators

Transposing to different layout permutations results in different numerics #17276

Open elfiegg opened 1 week ago

elfiegg commented 1 week ago

Hello, we stumbled upon a numerical issue with the two modules below while training FP8-quantized models.

ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    %b = f8e4m3fn[4096,12288]{1,0} bitcast(%p0)
    %transpose = f8e4m3fn[12288,4096]{1,0} transpose(%b), dimensions={1,0}
    %p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
    %p2 = f32[] parameter(2)
    %constant_1 = f32[] constant(1)
    ROOT %cublas-gemm.1.0 = (bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[12288,4096]{1,0} %transpose,  f8e4m3fn[4096,16384]{0,1} %p1,f32[] %p2, f32[]%constant_1, f32[]%constant_1, f32[]%constant_1), custom_call_target="__cublas$lt$matmul$f8", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"6","lhs_stride":"50331648","rhs_stride":"67108864","grad_x":false,"grad_y":false},"force_earliest_schedule":false}
  }

ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    %transpose = f8e4m3fn[4096,12288]{0,1} transpose(%p0), dimensions={1,0}
    %p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
    %p2 = f32[] parameter(2)
    %constant_1 = f32[] constant(1)
    ROOT %cublas-gemm.1.0 = (bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1} %transpose, f8e4m3fn[4096,16384]{0,1} %p1, f32[] %p2, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1), custom_call_target="__cublas$lt$matmul$f8", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"7","lhs_stride":"50331648","rhs_stride":"67108864","grad_x":false,"grad_y":false},"force_earliest_schedule":false}
  }

The two modules produce different numerics. We checked the cuBLAS runtime thunk: it processed the logical layout correctly, and buffer assignment worked exactly the same in both cases.

We then wrote a unit test that checks the transpose numerics in isolation, as below:

ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    %b = f8e4m3fn[4096,12288]{1,0} reshape(%p0)
    %transpose = f8e4m3fn[12288,4096]{1,0} transpose(%b), dimensions={1,0}
    ROOT bitcast = f8e4m3fn[4096,12288]{0,1} reshape(%transpose)
  }

ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    ROOT %transpose = f8e4m3fn[4096,12288]{0,1} transpose(%p0), dimensions={1,0}
  }

The results of the two modules differed in 99% of the elements, with relative errors > 1e-2.

Could you please help us understand why transposing to a different layout permutation would result in a numerical difference? Is a transpose with a default vs. non-default layout a known issue, or are we making unintentional assumptions or mistakes?
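For readers less familiar with the `{0,1}` / `{1,0}` suffixes in the HLO above: they are minor-to-major lists, where the first entry names the fastest-varying dimension in memory. The sketch below shows how the same logical index lands at different physical offsets under the two layouts. `LinearOffset` is a hypothetical helper written for this illustration, not an XLA function, and it assumes a dense layout with no tiling or padding.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Physical (linear) offset of a logical index, given the dimension sizes and
// a minor-to-major list (first entry = fastest-varying dimension).
// Hypothetical helper for illustration; assumes a dense, untiled layout.
int64_t LinearOffset(const std::vector<int64_t>& dims,
                     const std::vector<int64_t>& minor_to_major,
                     const std::vector<int64_t>& index) {
  assert(dims.size() == minor_to_major.size());
  assert(dims.size() == index.size());
  int64_t offset = 0;
  int64_t stride = 1;
  // Walk dimensions from most minor to most major, accumulating strides.
  for (int64_t d : minor_to_major) {
    offset += index[d] * stride;
    stride *= dims[d];
  }
  return offset;
}
```

For a `[2,3]` shape, logical index `(1,0)` sits at offset 3 under `{1,0}` (row-major, the default) and at offset 1 under `{0,1}`. A consumer that reads the buffer with the wrong layout assumption therefore sees permuted data rather than slightly perturbed values.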

elfiegg commented 1 week ago

@kaixih @wenscarl

elfiegg commented 1 week ago

Unit-test reproducer (which we also modified to test the transpose as the root of the modules):

TEST_F(GpuCompilerTest, LayoutNormalizationRequiredForCublasF8) {
  auto cc = backend()
                .default_stream_executor()
                ->GetDeviceDescription()
                .cuda_compute_capability();
  if (!cc.IsAtLeastAmpere()) {
    GTEST_SKIP() << "Autotuning results have only been generated for Ampere "
                 << "and Hopper GPUs";
  }
  const absl::string_view good_hlo_string = R"( 
  HloModule test 

  ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    %b = f8e4m3fn[4096,12288]{1,0} bitcast(%p0)
    %transpose = f8e4m3fn[12288,4096]{1,0} transpose(%b), dimensions={1,0}
    %p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
    %p2 = f32[] parameter(2)
    %constant_1 = f32[] constant(1)
    ROOT %cublas-gemm.1.0 = (bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[12288,4096]{1,0} %transpose,  f8e4m3fn[4096,16384]{0,1} %p1,f32[] %p2, f32[]%constant_1, f32[]%constant_1, f32[]%constant_1), custom_call_target="__cublas$lt$matmul$f8", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"6","lhs_stride":"50331648","rhs_stride":"67108864","grad_x":false,"grad_y":false},"force_earliest_schedule":false}
  })";

  HloModuleConfig config;
  DebugOptions debug_options = GetDebugOptionsForTest();
  debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false);
  debug_options.set_xla_gpu_enable_triton_gemm(false);
  debug_options.set_xla_gpu_cublas_fallback(true);
  config.set_debug_options(debug_options);
  config.set_replica_count(1);
  config.set_num_partitions(1);

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> good_module,
      ParseAndReturnVerifiedModule(good_hlo_string, config));

  const absl::string_view bad_hlo_string = R"(
  HloModule test 

  ENTRY main {
    %p0 = f8e4m3fn[12288,4096]{0,1} parameter(0)
    %transpose = f8e4m3fn[4096,12288]{0,1} transpose(%p0), dimensions={1,0}
    %p1 = f8e4m3fn[4096,16384]{0,1} parameter(1)
    %p2 = f32[] parameter(2)
    %constant_1 = f32[] constant(1)
    ROOT %cublas-gemm.1.0 = (bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1} %transpose, f8e4m3fn[4096,16384]{0,1} %p1, f32[] %p2, f32[] %constant_1, f32[] %constant_1, f32[] %constant_1), custom_call_target="__cublas$lt$matmul$f8", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"alpha_imag":0,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["0"],"rhs_contracting_dimensions":["0"],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","damax_output":false,"selected_algorithm":"7","lhs_stride":"50331648","rhs_stride":"67108864","grad_x":false,"grad_y":false},"force_earliest_schedule":false}
  })";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> bad_module,
                          ParseAndReturnVerifiedModule(bad_hlo_string, config));

  EXPECT_TRUE(RunAndCompareTwoModules(good_hlo_string, bad_hlo_string,
                                      ErrorSpec{1e-10, 1e-10}, false));
}

sergachev commented 1 week ago

@akuegel is my understanding right that transposes should always use the default layout and that's normally ensured by the layout normalization? If so, should we try to detect the wrong ones at codegen or in the HLO verifier?

mooskagh commented 1 week ago

It's indeed expected that cuBLAS gemms get their layout normalized, and if you feed the HLO to the optimization passes, it already fails (with a slightly different error):

layout_assignment.cc:321] Check failed: !IsCublasGemm(*instruction) Gemm rewriting should run after layout assignment

We can add a check somewhere that ensures that the layout that gets into the custom call is normalized, but that would only enforce an internal invariant; it's (at least in theory) not possible to get to this state from a pre-optimized HLO.

sergachev commented 1 week ago

It's not about cuBLAS, it's about transpose alone, see the second reproducer in the first message.

akuegel commented 1 week ago

@sergachev While Layout Normalization will make sure that transposes have the default layout, there could be passes later in the pipeline that create transposes with non-default layout. Note that anything that calls MakeTransposeHlo from hlo_creation_utils will most likely have a non-default layout, as that function infers a layout that will make the transpose a bitcast. This is something we want to avoid, so if you see any pass that runs after LayoutNormalization that calls MakeTransposeHlo, please file a bug or send a PR.
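To make the "layout that will make the transpose a bitcast" remark concrete, here is a minimal sketch of how such a layout can be derived. It is an illustration of the idea only, not the code in hlo_creation_utils; `BitcastTransposeLayout` is a hypothetical name, and the sketch assumes dense layouts described purely by a minor-to-major list.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns the output minor-to-major list that keeps the physical element
// order unchanged, so the transpose needs no data movement (a bitcast).
// `permutation` follows HLO semantics: output dimension i is operand
// dimension permutation[i]. Hypothetical helper, not an XLA API.
std::vector<int64_t> BitcastTransposeLayout(
    const std::vector<int64_t>& operand_minor_to_major,
    const std::vector<int64_t>& permutation) {
  assert(operand_minor_to_major.size() == permutation.size());
  // inverse[d] = output dimension that operand dimension d maps to.
  std::vector<int64_t> inverse(permutation.size());
  for (std::size_t i = 0; i < permutation.size(); ++i) {
    inverse[permutation[i]] = static_cast<int64_t>(i);
  }
  // Keep the same physical ordering, renamed to output dimension numbers.
  std::vector<int64_t> result;
  result.reserve(operand_minor_to_major.size());
  for (int64_t d : operand_minor_to_major) {
    result.push_back(inverse[d]);
  }
  return result;
}
```

With a default-layout operand `{1,0}` and `dimensions={1,0}`, this yields `{0,1}`, which is non-default. That is why a transpose created this way after LayoutNormalization tends to break the default-layout invariant.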

elfiegg commented 1 week ago

@akuegel it sounds to me like, generally speaking, we should ensure that layout normalization and its associated passes run after all rewriters and op-changing passes, and before codegen, so that they account for all ops? The layouts normalized by the pass might be a strict requirement.

akuegel commented 1 week ago

We already have HLO passes that rely on having only transposes with default layout. For example the one I added recently (TransposeDimensionGrouper) only works on transposes with default layout and will return an error otherwise. So just running the layout normalization once again at the end of the pipeline will not fix the issue. So the suggestion of @sergachev to make it part of the HloVerifier sounds better to me. It would need to be a verifier option that is off by default, but can be turned on in our pipeline after LayoutNormalization pass.
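A rough sketch of what such an opt-in check could look like, in standalone form. Everything here is hypothetical: `TransposeOp`, `HasDefaultLayout`, and `FindNonDefaultTranspose` are illustrative names and do not correspond to the actual HloVerifier interface; the sketch only assumes that the default layout of a rank-n shape is `{n-1, ..., 1, 0}`.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for an HLO transpose instruction; not an XLA type.
struct TransposeOp {
  std::string name;
  std::vector<int64_t> result_minor_to_major;
};

// True if the list is the default (descending) layout {rank-1, ..., 1, 0}.
bool HasDefaultLayout(const std::vector<int64_t>& minor_to_major) {
  const int64_t rank = static_cast<int64_t>(minor_to_major.size());
  for (int64_t k = 0; k < rank; ++k) {
    if (minor_to_major[k] != rank - 1 - k) return false;
  }
  return true;
}

// Returns the name of the first transpose with a non-default layout, or an
// empty string if there is none. The check is opt-in (off by default), so it
// can be switched on only for the part of the pipeline that runs after
// layout normalization.
std::string FindNonDefaultTranspose(const std::vector<TransposeOp>& ops,
                                    bool check_enabled = false) {
  if (!check_enabled) return "";
  for (const TransposeOp& op : ops) {
    if (!HasDefaultLayout(op.result_minor_to_major)) return op.name;
  }
  return "";
}
```

Applied to the second module from the first message, a check like this would flag `%transpose`, whose result layout is `{0,1}`.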

elfiegg commented 1 week ago

OK, that sounds good! My original comment was more about the other instructions involved in layout normalization, in a broader sense. Are all the instructions that the layout normalization pass standardizes considered a strict requirement? Or is transpose a special case we stumbled upon that happens to affect correctness?

akuegel commented 5 days ago

Once LayoutNormalization has run, it is quite unlikely that other passes will introduce ops that don't have the default layout. Normally the layout of a new op is derived from the ops surrounding it, so if all those ops have the default layout, the new op will have the default layout as well. Transpose is special because of the MakeTransposeHlo() method, which will choose a non-default layout. I believe it was a mistake to make that function assign a non-default layout, but that would probably be quite hard to change now. And then, most of the code would still work with any layout, as LayoutNormalization is kind of new and the code was written to support any layout. Only newly added code might be relying on the default layout.

elfiegg commented 4 days ago

The layout of new ops is indeed derived from the surrounding ops, and the "bug" arises because some ops never get a chance to go through the LayoutNormalization pass. Triton first fuses the FP8 GEMMs, but during layout normalization the transpose has not yet been inserted by GemmRewriter, and ops within these fusions are not handled either. Then, when the autotuner falls back to cuBLAS and the fused computations are inlined, the cuBLAS GemmRewriter might insert a non-default transpose based on the context. In this situation, would you consider it a bug in that layout normalization should also run after the computations are inlined, or would it be better to insert a non-default transpose in the GemmRewriter?

akuegel commented 4 days ago

Ideally we would insert a transpose with default layout in the GemmRewriter. If you have a transpose that preserves the non-default layout of its operand, it can be normalized to have a default layout by adding a bitcast transpose in front of it and another one after it. Unfortunately we still don't normalize Dots, which means we often have a bitcast operand of a dot with non-default layout, so if a transpose is inserted between the bitcast and the dot, it would have non-default layout as well.
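The normalization described here can be checked on a toy buffer. The sketch below simulates a 2-D array with a minor-to-major layout and compares a direct `{0,1}`-preserving transpose against the bitcast, default-layout transpose, bitcast sequence. `Array2D`, `Transpose`, `Bitcast`, `DirectPath`, and `NormalizedPath` are hypothetical names for this illustration only, and the model ignores tiling and element type.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical 2-D array: logical dims, minor-to-major layout, physical data.
struct Array2D {
  int64_t dims[2];
  int64_t minor_to_major[2];
  std::vector<int> data;
};

// Physical offset of logical element (i, j) under the array's layout.
int64_t Offset(const Array2D& a, int64_t i, int64_t j) {
  const int64_t idx[2] = {i, j};
  const int64_t minor = a.minor_to_major[0];
  const int64_t major = a.minor_to_major[1];
  return idx[minor] + a.dims[minor] * idx[major];
}

// Logical transpose (dimensions={1,0}) written into the requested layout.
Array2D Transpose(const Array2D& in, int64_t out_minor, int64_t out_major) {
  Array2D out{{in.dims[1], in.dims[0]},
              {out_minor, out_major},
              std::vector<int>(in.data.size())};
  for (int64_t i = 0; i < out.dims[0]; ++i) {
    for (int64_t j = 0; j < out.dims[1]; ++j) {
      out.data[Offset(out, i, j)] = in.data[Offset(in, j, i)];
    }
  }
  return out;
}

// Bitcast: same physical buffer, reinterpreted with new dims and layout.
Array2D Bitcast(const Array2D& in, int64_t d0, int64_t d1, int64_t minor,
                int64_t major) {
  assert(d0 * d1 == static_cast<int64_t>(in.data.size()));
  return Array2D{{d0, d1}, {minor, major}, in.data};
}

// Direct: [R,C]{0,1} -> transpose -> [C,R]{0,1} (non-default layout kept).
std::vector<int> DirectPath(const Array2D& p0) {
  return Transpose(p0, 0, 1).data;
}

// Normalized: bitcast to [C,R]{1,0}, transpose to [R,C]{1,0} (default
// layout), then bitcast back to [C,R]{0,1}.
std::vector<int> NormalizedPath(const Array2D& p0) {
  const Array2D b = Bitcast(p0, p0.dims[1], p0.dims[0], 1, 0);
  const Array2D t = Transpose(b, 1, 0);
  return Bitcast(t, t.dims[1], t.dims[0], 0, 1).data;
}
```

In this model both paths produce the same physical buffer, which is the property the normalization relies on: only the middle transpose moves data, and it has the default layout.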

elfiegg commented 3 days ago

OK, in that case could you please also take a look at https://github.com/openxla/xla/pull/17440 and leave any comments? @wenscarl has a fix there for inserting a default-layout transpose in the GemmRewriter.