From 261843bfaee9e4f8e6a81f9d6d92a0afd9eb275a Mon Sep 17 00:00:00 2001 From: YigongQin Date: Fri, 10 Apr 2026 14:40:23 -0700 Subject: [PATCH 1/4] Handle empty tensors in dequantize for CUDA graph compatibility Signed-off-by: YigongQin --- tests/cpp/operator/test_dequantize_mxfp8.cu | 2 ++ .../common/cast/dispatch/dequantize.cuh | 4 ++++ .../pytorch/module/grouped_linear.py | 22 ++++--------------- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/tests/cpp/operator/test_dequantize_mxfp8.cu b/tests/cpp/operator/test_dequantize_mxfp8.cu index a529f93d7c..1ac794a20e 100644 --- a/tests/cpp/operator/test_dequantize_mxfp8.cu +++ b/tests/cpp/operator/test_dequantize_mxfp8.cu @@ -370,6 +370,8 @@ void performTest_x2(const size_t rows, } std::vector> tensor_dims = { + {0, 128}, + {0, 256}, {1, 16}, {16, 48}, {65, 96}, diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index 81304981d3..c63aac2831 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -25,6 +25,10 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t CheckInputTensor(input, "cast_input"); CheckOutputTensor(*output, "cast_output"); + if (input.numel() == 0) { + return; + } + switch (input.scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { NVTE_CHECK(is_fp8_dtype(input.dtype()), "Input must have FP8 type."); diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 2cce6c3ef8..76c36af0ee 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -515,25 +515,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) elif ctx.backward_override == "dequantized": inputmats_dequant = [] - for m_split, inputmat in zip(ctx.m_splits, inputmats): + for inputmat in inputmats: if isinstance(inputmat, QuantizedTensorStorage): - if m_split == 0: - # Dequant kernels for some quantized storage formats - # (e.g. MXFP8/Float8BlockScaling) do not accept empty - # M-dimension inputs. For empty grouped splits, materialize - # an explicit empty high-precision matrix instead of invoking - # dequantize(). - inputmats_dequant.append( - torch.empty( - (0, ctx.weights_shape_1), - dtype=ctx.activation_dtype, - device=ctx.device, - ) - ) - else: - inputmats_dequant.append( - inputmat.dequantize(dtype=ctx.activation_dtype) - ) + inputmats_dequant.append( + inputmat.dequantize(dtype=ctx.activation_dtype) + ) else: inputmats_dequant.append(cast_if_needed(inputmat, ctx.activation_dtype)) inputmats = inputmats_dequant From f6f15256e65efe9db12fe27a748879c74a33eb9c Mon Sep 17 00:00:00 2001 From: YigongQin Date: Fri, 10 Apr 2026 14:40:23 -0700 Subject: [PATCH 2/4] dequant with swizzled scales Signed-off-by: YigongQin --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_dequantize_mxfp8.cu | 204 +++++++++++++ tests/cpp/operator/test_dequantize_nvfp4.cu | 277 ++++++++++++++++++ tests/cpp/test_common.h | 10 + .../common/cast/mxfp8/dequantize_mxfp8.cuh | 60 ++-- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 31 +- 6 files changed, 554 insertions(+), 29 deletions(-) create mode 100644 tests/cpp/operator/test_dequantize_nvfp4.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 5e73675f4f..f9f05ede53 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -15,6 +15,7 @@ add_executable(test_operator test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu test_dequantize_mxfp8.cu + test_dequantize_nvfp4.cu test_transpose.cu test_cast_transpose.cu test_cast_transpose_current_scaling.cu diff --git a/tests/cpp/operator/test_dequantize_mxfp8.cu b/tests/cpp/operator/test_dequantize_mxfp8.cu index 1ac794a20e..4a579848a1 100644 --- a/tests/cpp/operator/test_dequantize_mxfp8.cu +++ b/tests/cpp/operator/test_dequantize_mxfp8.cu @@ -18,6 +18,7 @@ #include #include +#include #include "../test_common.h" #include "transformer_engine/transformer_engine.h" @@ -369,6 +370,134 @@ void performTest_x2(const size_t rows, compareResults("output_colwise", output, ref_output_colwise.get(), false, atol, rtol); } +// Dequantize with GEMM-swizzled scales (single dimension) +template +void performTest_x1_swizzled(const size_t rows, + const size_t cols, + const bool rowwise, + const bool colwise) +{ + using namespace test; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t block_size_rows = rowwise ? 1 : 32; + const size_t block_size_cols = colwise ? 1 : 32; + + const size_t unpadded_blocks_Y_rowwise = rows; + const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols); + const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X_colwise = cols; + + const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise, + scale_tensor_alignment_Y_rowwise); + const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise, + scale_tensor_alignment_X_rowwise); + const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise, + scale_tensor_alignment_Y_colwise); + const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise, + scale_tensor_alignment_X_colwise); + + const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise; + const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise; + + const size_t blocks_num = rowwise ? blocks_num_rowwise : blocks_num_colwise; + const size_t scales_stride = rowwise ? blocks_X_rowwise : blocks_X_colwise; + + Tensor input_compact_scales("input_compact_scales", std::vector{ rows, cols }, itype, + rowwise, colwise, NVTE_MXFP8_1D_SCALING); + + Tensor input_swizzled_scales("input_swizzled_scales", std::vector{ rows, cols }, itype, + rowwise, colwise, NVTE_MXFP8_1D_SCALING); + input_swizzled_scales.set_with_gemm_swizzled_scales(true); + + Tensor output("output", std::vector{ rows, cols }, otype, true, false); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr scales = std::make_unique(blocks_num); + + fill_tensor_data(input_compact_scales, scales.get(), scales.get(), rowwise, colwise, + rows, cols, blocks_num_rowwise, blocks_num_colwise); + + const size_t data_bytes = rows * cols * sizeof(InputType); + if (rowwise && data_bytes > 0) { + cudaMemcpy(input_swizzled_scales.rowwise_dptr(), input_compact_scales.rowwise_dptr(), + data_bytes, cudaMemcpyDeviceToDevice); + } + if (colwise && data_bytes > 0) { + cudaMemcpy(input_swizzled_scales.columnwise_dptr(), input_compact_scales.columnwise_dptr(), + data_bytes, cudaMemcpyDeviceToDevice); + } + + if (data_bytes > 0) { + nvte_swizzle_scaling_factors(input_compact_scales.data(), input_swizzled_scales.data(), 0); + } + + nvte_dequantize(input_swizzled_scales.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + InputType *data_ptr = rowwise + ? input_compact_scales.rowwise_cpu_dptr() + : input_compact_scales.columnwise_cpu_dptr(); + + compute_ref_x1(data_ptr, + ref_output.get(), + scales.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_swizzled", output, ref_output.get(), true, atol, rtol); +} + +// Quantize with swizzled scales, then dequantize — round-trip test +template +void performTest_quantize_then_dequantize_swizzled(const size_t rows, + const size_t cols, + const bool rowwise, + const bool colwise) +{ + using namespace test; + using EncodingType = fp32; + DType in_type = TypeInfo::dtype; + DType intermed_type = TypeInfo::dtype; + DType out_type = TypeInfo::dtype; + + std::unique_ptr output_cpu = std::make_unique(rows * cols); + + Tensor input("input", std::vector{ rows, cols }, in_type); + Tensor quantized("quantized", std::vector{ rows, cols }, intermed_type, + rowwise, colwise, NVTE_MXFP8_1D_SCALING); + quantized.set_with_gemm_swizzled_scales(true); + + Tensor output("output", std::vector{ rows, cols }, out_type, true, false); + + fillCase(&input, InputsFillCase::uniform); + + if (rows > 0 && cols > 0) { + nvte_quantize(input.data(), quantized.data(), 0); + cudaDeviceSynchronize(); + } + + nvte_dequantize(quantized.data(), output.data(), 0); + cudaDeviceSynchronize(); + + const size_t copy_size = sizeof(InputType) * rows * cols; + cudaMemcpy(output_cpu.get(), output.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol, rtol] = getTolerances(intermed_type); + compareResults("Quantize-Dequantize-Swizzled", input, output_cpu.get(), true, atol, rtol); +} + std::vector> tensor_dims = { {0, 128}, {0, 256}, @@ -472,3 +601,78 @@ INSTANTIATE_TEST_SUITE_P( return name; } ); + +/***************************************************************************** + * Swizzled-scale dequantization tests + *****************************************************************************/ + +class DequantizeMXFP8SwizzledTestSuite : public ::testing::TestWithParam + , + std::pair, + transformer_engine::DType, + transformer_engine::DType, + bool>> {}; + +TEST_P(DequantizeMXFP8SwizzledTestSuite, TestDequantizeMXFP8Swizzled) +{ + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const auto tensor_size = std::get<0>(GetParam()); + const auto block_size = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const bool quantize_then_dequantize = std::get<4>(GetParam()); + + const bool rowwise = block_size.second != 1; + const bool colwise = block_size.first != 1; + + if (rowwise && colwise) { + GTEST_SKIP(); + } + + if (rowwise && tensor_size.second % 32 != 0) { + GTEST_SKIP(); + } + if (colwise && tensor_size.first % 32 != 0) { + GTEST_SKIP(); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, + if (quantize_then_dequantize) { + performTest_quantize_then_dequantize_swizzled( + tensor_size.first, tensor_size.second, rowwise, colwise); + } else { + performTest_x1_swizzled( + tensor_size.first, tensor_size.second, rowwise, colwise); + } + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + DequantizeMXFP8SwizzledTestSuite, + ::testing::Combine( + ::testing::ValuesIn(tensor_dims), + ::testing::ValuesIn(block_sizes), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(false)), + [](const testing::TestParamInfo& info) + { + std::string name = std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + std::to_string(std::get<1>(info.param).first) + "X" + + std::to_string(std::get<1>(info.param).second) + "X" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + (std::get<4>(info.param) ? "QD_Swizzled" : "D_Swizzled"); + return name; + } +); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu new file mode 100644 index 0000000000..535f744d99 --- /dev/null +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -0,0 +1,277 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#if FP4_TYPE_SUPPORTED +#include +#endif + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +#if FP4_TYPE_SUPPORTED + +namespace { + +float2 cvt_fp4x2_to_float2(fp4e2m1x2 fp4_pair) { + const __half2_raw raw = + __nv_cvt_fp4x2_to_halfraw2( + *reinterpret_cast<__nv_fp4x2_storage_t *>(&fp4_pair), __NV_E2M1); + const __half2 h2(raw); + return {static_cast(h2.x), static_cast(h2.y)}; +} + +template +void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, + const fp8e4m3 *scales, + float amax, + OType *output, + size_t rows, + size_t cols, + size_t scale_stride) { + constexpr float factor_inv = 1.0f / (6.0f * 448.0f); + constexpr size_t BLOCK_SIZE = 16; + const size_t Mread = cols / BLOCK_SIZE; + const size_t bytes_per_block = BLOCK_SIZE / 2; + + for (size_t row = 0; row < rows; ++row) { + for (size_t block = 0; block < Mread; ++block) { + const fp8e4m3 scale = scales[row * scale_stride + block]; + const float final_scale = static_cast(scale) * amax * factor_inv; + + for (size_t pair_idx = 0; pair_idx < bytes_per_block; ++pair_idx) { + const size_t byte_idx = + (row * Mread + block) * bytes_per_block + pair_idx; + fp4e2m1x2 fp4_pair; + std::memcpy(&fp4_pair, &packed_data[byte_idx], 1); + const float2 values = cvt_fp4x2_to_float2(fp4_pair); + + const size_t col0 = block * BLOCK_SIZE + pair_idx * 2; + output[row * cols + col0] = + static_cast(values.x * final_scale); + output[row * cols + col0 + 1] = + static_cast(values.y * final_scale); + } + } + } +} + +// Quantize a high-precision input to NVFP4, then dequantize and compare +// against a CPU reference computed from the quantized data. +template +void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { + using namespace test; + DType otype = TypeInfo::dtype; + + Tensor input("input", std::vector{rows, cols}, otype); + fillCase(&input, InputsFillCase::uniform); + + Tensor quantized("quantized", std::vector{rows, cols}, + DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); + quantized.set_amax(0.0f); + setRandomScale(&quantized); + + if (rows > 0 && cols > 0) { + nvte_quantize(input.data(), quantized.data(), 0); + cudaDeviceSynchronize(); + } + + Tensor output("output", std::vector{rows, cols}, otype, true, false); + nvte_dequantize(quantized.data(), output.data(), 0); + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + if (rows > 0 && cols > 0) { + quantized.to_cpu(); + const uint8_t *fp4_data = + reinterpret_cast(quantized.rowwise_cpu_dptr()); + const fp8e4m3 *scales = quantized.rowwise_cpu_scale_inv_ptr(); + const float amax_val = quantized.amax(); + const NVTEShape scale_shape = quantized.rowwise_scale_inv_shape(); + const size_t scale_stride = scale_shape.data[scale_shape.ndim - 1]; + + std::unique_ptr ref_output = + std::make_unique(rows * cols); + compute_ref_dequantize_nvfp4( + fp4_data, scales, amax_val, ref_output.get(), + rows, cols, scale_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_nvfp4", output, ref_output.get(), true, atol, rtol); + } +} + +// Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path. +template +void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) { + using namespace test; + DType otype = TypeInfo::dtype; + + Tensor input("input", std::vector{rows, cols}, otype); + fillCase(&input, InputsFillCase::uniform); + + Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, + DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); + quantized_compact.set_amax(0.0f); + setRandomScale(&quantized_compact); + + if (rows > 0 && cols > 0) { + nvte_quantize(input.data(), quantized_compact.data(), 0); + cudaDeviceSynchronize(); + } + + // Dequantize with compact scales → reference output + Tensor output_compact("output_compact", std::vector{rows, cols}, otype, true, false); + nvte_dequantize(quantized_compact.data(), output_compact.data(), 0); + cudaDeviceSynchronize(); + + // Create tensor with same FP4 data but swizzled scales + Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, + DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); + quantized_swizzled.set_amax(0.0f); + setRandomScale(&quantized_swizzled); + quantized_swizzled.set_with_gemm_swizzled_scales(true); + + // Copy FP4 data + const size_t data_bytes = rows * cols / 2; + if (data_bytes > 0) { + cudaMemcpy(quantized_swizzled.rowwise_dptr(), quantized_compact.rowwise_dptr(), + data_bytes, cudaMemcpyDeviceToDevice); + } + + // Copy amax from compact to swizzled (GPU-to-GPU) + // Read amax from compact tensor, set on swizzled tensor, upload + quantized_compact.to_cpu(); + quantized_swizzled.set_amax(quantized_compact.amax()); + quantized_swizzled.set_scale(quantized_compact.scale()); + quantized_swizzled.from_cpu(); + + // Swizzle scales + if (data_bytes > 0) { + nvte_swizzle_scaling_factors(quantized_compact.data(), quantized_swizzled.data(), 0); + } + + // Dequantize with swizzled scales + Tensor output_swizzled("output_swizzled", std::vector{rows, cols}, otype, true, false); + nvte_dequantize(quantized_swizzled.data(), output_swizzled.data(), 0); + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // Read compact output as reference + const size_t num_elems = rows * cols; + std::unique_ptr ref_output = std::make_unique(num_elems); + if (num_elems > 0) { + cudaMemcpy(ref_output.get(), output_compact.rowwise_dptr(), + num_elems * sizeof(OutputType), cudaMemcpyDeviceToHost); + } + + auto [atol, rtol] = getTolerances(otype); + if (num_elems > 0) { + compareResults("output_nvfp4_swizzled", output_swizzled, + ref_output.get(), true, atol, rtol); + } +} + +std::vector> nvfp4_tensor_dims = { + {0, 32}, + {32, 32}, + {64, 64}, + {128, 128}, + {256, 256}, + {128, 512}, + {256, 1024}, +}; + +} // namespace + +class DequantizeNVFP4TestSuite : public ::testing::TestWithParam + , + transformer_engine::DType>> {}; + +TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) +{ + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const auto tensor_size = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, + performTest_dequantize_nvfp4( + tensor_size.first, tensor_size.second); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + DequantizeNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(nvfp4_tensor_dims), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + [](const testing::TestParamInfo& info) + { + std::string name = std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + test::typeName(std::get<1>(info.param)); + return name; + } +); + +class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam + , + transformer_engine::DType>> {}; + +TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) +{ + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const auto tensor_size = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, + performTest_dequantize_nvfp4_swizzled( + tensor_size.first, tensor_size.second); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + DequantizeNVFP4SwizzledTestSuite, + ::testing::Combine( + ::testing::ValuesIn(nvfp4_tensor_dims), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + [](const testing::TestParamInfo& info) + { + std::string name = std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + test::typeName(std::get<1>(info.param)) + "X" + + "Swizzled"; + return name; + } +); + +#endif // FP4_TYPE_SUPPORTED diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index b5a7f26d14..3f3ca03139 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -286,6 +286,16 @@ class Tensor { tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); } + void set_amax(float amax_val) { + if (!amax_cpu_data_) { + amax_cpu_data_ = std::make_shared(0); + float* amax_gpu = nullptr; + cudaMalloc((void**)&amax_gpu, sizeof(float)); // NOLINT(*) + tensor_.set_amax(amax_gpu, DType::kFloat32, std::vector{1}); + } + *amax_cpu_data_ = amax_val; + } + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales){ tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); } diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index f8fecaa4e1..18e55d2dce 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -20,6 +20,7 @@ #include "../../util/math.h" #include "../../util/ptx.cuh" #include "../../utils.cuh" +#include "swizzle.cuh" namespace transformer_engine { namespace dispatch { @@ -42,12 +43,13 @@ constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 static_assert(ITERATIONS >= 1); -template +template __global__ void __launch_bounds__(THREADS_PER_CHUNK) dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, - const size_t scales_stride) { + const size_t scales_stride, const size_t num_scale_tiles_X) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; @@ -158,7 +160,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ? (scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE) : (scales_colwise_chunk_offset_X + tid_colwise_X); - const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; + size_t scale_idx; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + if constexpr (USE_ROWWISE_SCALING) { + scale_idx = swizzle::gemm_swizzled_scale_idx(scale_offset_Y, scale_offset_X, + num_scale_tiles_X); + } else { + scale_idx = swizzle::gemm_swizzled_scale_idx(scale_offset_X, scale_offset_Y, + num_scale_tiles_X); + } + } else { + scale_idx = scale_offset_Y * scales_stride + scale_offset_X; + } const e8m0_t biased_exponent = scales_ptr[scale_idx]; const float block_scale = ptx::exp2f(biased_exponent); @@ -239,10 +252,11 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type."); } - NVTE_CHECK(!input.with_gemm_swizzled_scales, "Input must have scales in compact format."); NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match."); + const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales; + // TODO: Make more general const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; @@ -276,6 +290,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const size_t scales_stride = use_rowwise_scaling ? scales_X_rowwise : scales_X_colwise; + const size_t num_scale_tiles_X = use_rowwise_scaling ? DIVUP(cols, static_cast(128)) + : DIVUP(rows, static_cast(128)); + const SimpleTensor &input_data = use_rowwise_scaling ? input.data : input.columnwise_data; const dim3 block(THREADS_PER_CHUNK); @@ -289,21 +306,26 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_output{}; - - create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype())); - - dequantize_mxfp8_kernel - <<>>(tensor_map_input, tensor_map_output, scales_ptr, - rows, cols, scales_stride);); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_SWITCH_CONDITION( + with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype())); + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, typeToNumBits(output->dtype())); + + dequantize_mxfp8_kernel + <<>>(tensor_map_input, tensor_map_output, scales_ptr, + rows, cols, scales_stride, + num_scale_tiles_X);); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace mxfp8 diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index ccdc4c93e3..63a7755556 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -20,6 +20,7 @@ #include "../../util/math.h" #include "../../util/ptx.cuh" #include "../../utils.cuh" +#include "../mxfp8/swizzle.cuh" #if FP4_TYPE_SUPPORTED #include @@ -30,11 +31,11 @@ namespace dispatch { namespace nvfp4 { namespace dequantize_kernel { #if FP4_TYPE_SUPPORTED -template +template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, const float *const tensor_amax, const size_t N, const size_t M, - const size_t scale_stride) { + const size_t scale_stride, const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; const size_t y = thread_idx / M; @@ -52,7 +53,12 @@ __global__ void __launch_bounds__(512) OVec *output_vec = reinterpret_cast(output); const size_t my_index = x + y * M; - const size_t my_scale_index = x + y * scale_stride; + size_t my_scale_index; + if constexpr (WITH_GEMM_SWIZZLED_SCALES) { + my_scale_index = mxfp8::swizzle::gemm_swizzled_scale_idx(y, x, num_scale_tiles_X); + } else { + my_scale_index = x + y * scale_stride; + } const size_t my_output_index = (x + y * M) * 4; fp4vec value; value.vec = input_vectorized[my_index]; @@ -80,10 +86,11 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) CheckInputTensor(input, "input"); CheckOutputTensor(*output, "output"); NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); - NVTE_CHECK(!input.with_gemm_swizzled_scales, "Input must have scales in compact format."); NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales; + constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); const size_t M = input.flat_last_dim(); @@ -95,15 +102,19 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const size_t total = N * Mread; const size_t threads = 512; const size_t blocks = DIVUP(total, threads); + const size_t num_scale_tiles_X = DIVUP(Mread, static_cast(4)); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( output->data.dtype, OType, - - dequantize_fp4_kernel<<>>( - input.data.dptr, reinterpret_cast(output->data.dptr), - reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), N, Mread, - input.scale_inv.shape.back());); // NOLINT(*) + TRANSFORMER_ENGINE_SWITCH_CONDITION( + with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + + dequantize_fp4_kernel<<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back(), num_scale_tiles_X);); // NOLINT(*) + ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); From 39c0fb1bde5f162f2f37ab2a867c5749a2632b2d Mon Sep 17 00:00:00 2001 From: YigongQin Date: Fri, 10 Apr 2026 14:40:24 -0700 Subject: [PATCH 3/4] pass nvfp4 dequant tests Signed-off-by: YigongQin --- tests/cpp/operator/test_dequantize_nvfp4.cu | 16 ++++++++-------- tests/cpp/test_common.cu | 16 ++++++++++------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 535f744d99..eb857da6e7 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -151,20 +151,20 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) setRandomScale(&quantized_swizzled); quantized_swizzled.set_with_gemm_swizzled_scales(true); - // Copy FP4 data + // Copy amax and scale from compact to swizzled before FP4 data, + // since from_cpu() uploads all CPU buffers (including zero-init data). + quantized_compact.to_cpu(); + quantized_swizzled.set_amax(quantized_compact.amax()); + quantized_swizzled.set_scale(quantized_compact.scale()); + quantized_swizzled.from_cpu(); + + // Copy FP4 data after from_cpu() to avoid being overwritten const size_t data_bytes = rows * cols / 2; if (data_bytes > 0) { cudaMemcpy(quantized_swizzled.rowwise_dptr(), quantized_compact.rowwise_dptr(), data_bytes, cudaMemcpyDeviceToDevice); } - // Copy amax from compact to swizzled (GPU-to-GPU) - // Read amax from compact tensor, set on swizzled tensor, upload - quantized_compact.to_cpu(); - quantized_swizzled.set_amax(quantized_compact.amax()); - quantized_swizzled.set_scale(quantized_compact.scale()); - quantized_swizzled.from_cpu(); - // Swizzle scales if (data_bytes > 0) { nvte_swizzle_scaling_factors(quantized_compact.data(), quantized_swizzled.data(), 0); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 5180a81612..0f43b3c100 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -441,17 +441,20 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } if (isFp8Type(dtype()) || isFp4Type(dtype())) { - if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) { if (tensor_.amax() != nullptr){ cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), sizeof(float), cudaMemcpyDeviceToHost); } - cudaMemcpy(scale_cpu_data_.get(), - tensor_.scale(), - sizeof(float), - cudaMemcpyDeviceToHost); + if (tensor_.scale() != nullptr) { + cudaMemcpy(scale_cpu_data_.get(), + tensor_.scale(), + sizeof(float), + cudaMemcpyDeviceToHost); + } } auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { @@ -509,7 +512,8 @@ void Tensor::from_cpu() const { void Tensor::set_scale(float scale) { if (isFp8Type(dtype()) || isFp4Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING + || tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { *scale_cpu_data_ = scale; from_cpu(); } From 305a3b448e502750c52827f3ec04367f127f088c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Apr 2026 22:07:01 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh | 8 ++++---- transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 18e55d2dce..6441a567a6 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -163,11 +163,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) size_t scale_idx; if constexpr (WITH_GEMM_SWIZZLED_SCALES) { if constexpr (USE_ROWWISE_SCALING) { - scale_idx = swizzle::gemm_swizzled_scale_idx(scale_offset_Y, scale_offset_X, - num_scale_tiles_X); + scale_idx = + swizzle::gemm_swizzled_scale_idx(scale_offset_Y, scale_offset_X, num_scale_tiles_X); } else { - scale_idx = swizzle::gemm_swizzled_scale_idx(scale_offset_X, scale_offset_Y, - num_scale_tiles_X); + scale_idx = + swizzle::gemm_swizzled_scale_idx(scale_offset_X, scale_offset_Y, num_scale_tiles_X); } } else { scale_idx = scale_offset_Y * scales_stride + scale_offset_X; diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 63a7755556..4143208153 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -112,9 +112,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) dequantize_fp4_kernel<<>>( input.data.dptr, reinterpret_cast(output->data.dptr), reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), N, Mread, - input.scale_inv.shape.back(), num_scale_tiles_X);); // NOLINT(*) - ); // NOLINT(*) + reinterpret_cast(input.amax.dptr), N, Mread, input.scale_inv.shape.back(), + num_scale_tiles_X);); // NOLINT(*) + ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");