diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index e2f86d02dbec..d152292635fe 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -701,13 +701,15 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
 
   __syncthreads();
 
+  // disable rtz conversion due to its impact on accuracy.
   constexpr bool LOGITS_RTZ_CONVERSION = false;
 
   // write logits to shared mem
   for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
     dout[token_depth] *= inv_sum_scale;
     if constexpr (LOGITS_RTZ_CONVERSION) {
-      // use rtz conversion for performance, with no visible impact on accuracy
+      // use rtz conversion for better performance, with negligible impact on
+      // accuracy.
       shared_logits[warpid][token_depth][lane16id][rowid] =
           from_floatx4_rtz(dout[token_depth]);
     } else {
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index b826559cd8db..10d984351f67 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -7,7 +7,7 @@
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
-from vllm.utils import get_max_shared_memory_bytes
+from vllm.utils import get_max_shared_memory_bytes, is_navi
 
 from .allclose_default import get_default_atol, get_default_rtol
 
@@ -33,7 +33,7 @@
 
 # This should be sync with get_supported_head_sizes() in
 # vllm.attention.ops.paged_attn.PagedAttention
-HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
 
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
@@ -116,7 +116,8 @@ def ref_single_query_cached_kv_attention(
 
 
 @pytest.mark.parametrize(
-    "version", ["v1", "v2"] if not current_platform.is_rocm() else ["rocm"])
+    "version",
+    ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -181,7 +182,11 @@ def test_paged_attention(
     key_cache, value_cache = key_caches[0], value_caches[0]
 
     # Using default kv_scale
-    k_scale = v_scale = torch.tensor(0.3, dtype=torch.float)
+    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32)
+
+    # additional argument for v1/v2 pa kernel
+    num_threads = 1024 if current_platform.is_rocm() \
+        and not is_navi() else 128
 
     # Call the paged attention kernel.
     output = torch.empty_like(query)
@@ -203,12 +208,12 @@ def test_paged_attention(
             v_scale,
         )
 
-        opcheck(torch.ops._C.paged_attention_v1,
-                (output, query, key_cache, value_cache, num_kv_heads, scale,
-                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
-                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
-                cond=(head_size == HEAD_SIZES[0]
-                      and block_size == BLOCK_SIZES[0]))
+        opcheck(
+            torch.ops._C.paged_attention_v1,
+            (output, query, key_cache, value_cache, num_kv_heads, scale,
+             block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
+             kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, num_threads),
+            cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]))
 
     elif version in ("v2", "rocm"):
         if current_platform.is_rocm():
@@ -247,13 +252,14 @@ def test_paged_attention(
                 v_scale,
             )
 
-            opcheck(torch.ops._C.paged_attention_v2,
-                    (output, exp_sums, max_logits, tmp_output, query,
-                     key_cache, value_cache, num_kv_heads, scale, block_tables,
-                     seq_lens, block_size, max_seq_len, alibi_slopes,
-                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
-                    cond=(head_size == HEAD_SIZES[0]
-                          and block_size == BLOCK_SIZES[0]))
+            opcheck(
+                torch.ops._C.paged_attention_v2,
+                (output, exp_sums, max_logits, tmp_output, query, key_cache,
+                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
+                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
+                 k_scale, v_scale, 0, 0, 0, 64, 0, num_threads),
+                cond=(head_size == HEAD_SIZES[0]
+                      and block_size == BLOCK_SIZES[0]))
 
         else:
             ops.paged_attention_rocm(
@@ -299,14 +305,14 @@ def test_paged_attention(
                                             dtype=dtype,
                                             device=device)
         ops.convert_fp8(dequantized_key_cache, key_cache)
-        key_cache = k_scale * dequantized_key_cache
+        key_cache = dequantized_key_cache
 
         value_cache_shape = value_cache.shape
         dequantized_value_cache = torch.empty(size=value_cache_shape,
                                               dtype=dtype,
                                               device=device)
         ops.convert_fp8(dequantized_value_cache, value_cache)
-        value_cache = v_scale * dequantized_value_cache
+        value_cache = dequantized_value_cache
 
     ref_output = torch.empty_like(query)
     ref_single_query_cached_kv_attention(
@@ -434,4 +440,4 @@ def test_multi_query_kv_attention(
     )
     atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
     rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
-    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
\ No newline at end of file
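
Background sketch (illustrative, not part of the patch): the kernel hunk above leaves LOGITS_RTZ_CONVERSION disabled because round-toward-zero (rtz) float-to-16-bit conversion truncates the discarded mantissa bits instead of rounding to the nearest representable value, which can nudge the stored logits toward zero. The standalone CUDA snippet below shows the difference between the two rounding modes using NVIDIA's public cuda_bf16.h intrinsics; the bfloat16 element type, the intrinsics, and the sample value are assumptions chosen for illustration and are not the vLLM from_floatx4_rtz path itself.

// Illustrative only: compare round-to-nearest-even vs. round-toward-zero
// float -> bfloat16 conversion, the trade-off referenced in the kernel
// comment above. Uses standard cuda_bf16.h intrinsics, not vLLM helpers.
#include <cstdio>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

__global__ void compare_rounding(float x) {
  float rn = __bfloat162float(__float2bfloat16_rn(x));  // round to nearest even
  float rz = __bfloat162float(__float2bfloat16_rz(x));  // round toward zero (truncate)
  printf("x=%.9f  rn=%.9f  rz=%.9f\n", x, rn, rz);
}

int main() {
  // 1 + 3*2^-9 sits above the midpoint of the two nearest bfloat16 values
  // (1.0 and 1.0078125): round-to-nearest goes up to 1.0078125, while
  // round-toward-zero drops back to 1.0, losing ~0.6% in one conversion.
  compare_rounding<<<1, 1>>>(1.005859375f);
  cudaDeviceSynchronize();
  return 0;
}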