
[Bugfix] Fix OOM caused by cumem allocator inflating memory_reserved() #37111

Open

haosdent wants to merge 1 commit into vllm-project:main from haosdent:fix-37096

Conversation

@haosdent (Contributor) commented Mar 15, 2026

Purpose

When the cumem allocator's cleanup (use_memory_pool() exit in cumem.py) manually calls unmap_and_release() to free cached blocks, it bypasses PyTorch's allocator tracking. This inflates torch.cuda.memory_reserved(), making non_torch_memory (cuda_memory - memory_reserved()) go deeply negative. The downstream effect is that non_kv_cache_memory is underestimated, causing vLLM to over-allocate KV cache and OOM on large models.
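To make the failure mode concrete, here is a minimal sketch (not the actual vLLM code; variable names are illustrative) of why the old memory_reserved()-based estimate goes negative once pages are unmapped behind PyTorch's back:

```python
import torch

# Driver-level view of device memory: always reflects physical usage.
free_bytes, total_bytes = torch.cuda.mem_get_info()
cuda_memory = total_bytes - free_bytes      # everything actually consumed on the GPU

# PyTorch's own bookkeeping for its caching allocator.
reserved = torch.cuda.memory_reserved()

# Old-style estimate: whatever torch does not account for is assumed "non-torch".
# If unmap_and_release() frees pages without going through the torch allocator,
# `reserved` stays inflated while `cuda_memory` drops, so this difference can
# go deeply negative.
non_torch_memory = cuda_memory - reserved
```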

This was exposed after #32947, when a Python syntax bug fix (with A and B: → with (A, B):) properly enabled the cumem allocator for model weight loading.

The fix replaces the memory_reserved()-based formula with a mem_get_info()-based measurement (total_consumed = before_create.free_memory - after_profile.free_memory). The CUDA driver's mem_get_info() always returns accurate physical free/total memory regardless of which allocator is used. Additionally, the fix uses transient_peak_headroom (torch_peak - torch_allocated) instead of torch_peak_increase, to avoid double-counting persistent torch allocations that are already included in total_consumed.
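A rough sketch of the measurement described above (variable names are illustrative, and the exact composition in vllm/utils/mem_utils.py and gpu_worker.py may differ):

```python
import torch

# Snapshot the driver's view of free memory before the model is created ...
before_create_free, _ = torch.cuda.mem_get_info()

# ... load model weights and run the profiling forward pass here ...

# ... and again after profiling. The difference is everything this process
# consumed, regardless of which allocator (torch or cumem) did the work.
after_profile_free, _ = torch.cuda.mem_get_info()
total_consumed = before_create_free - after_profile_free

# Only the transient part of the torch peak is added on top; persistent torch
# allocations are already inside total_consumed and must not be counted twice.
torch_peak = torch.cuda.max_memory_allocated()
torch_allocated = torch.cuda.memory_allocated()
transient_peak_headroom = torch_peak - torch_allocated

non_kv_cache_memory = total_consumed + transient_peak_headroom
```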

Fixes #37096

Test Plan

  • Added test_memory_profiling_persistent_torch to verify persistent torch allocations are not double-counted in non_kv_cache_memory.
  • E2E verification with cumem allocator on NVIDIA GB10:
| Metric | Without Fix | With Fix |
| --- | --- | --- |
| non_torch_increase | -496.1 MiB (negative!) | N/A (not used) |
| total_consumed (mem_get_info) | N/A | 527.9 MiB (accurate) |
| non_kv_cache_memory | 15.9 MiB (underestimates ~49x) | 783.9 MiB (correct) |
| Result | Over-allocates KV cache → OOM | Safe |

Test Result

tests/utils_/test_mem_utils.py::test_memory_profiling_persistent_torch PASSED
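For reference, a simplified sketch of what such a double-counting check could look like (the real test lives in tests/utils_/test_mem_utils.py and uses vLLM's profiling helpers; this standalone version talks to torch.cuda directly):

```python
import torch

def test_persistent_torch_not_double_counted():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    free_before, _ = torch.cuda.mem_get_info()

    # A persistent allocation that outlives profiling (stands in for model weights).
    persistent = torch.empty(64 * 1024 * 1024, dtype=torch.uint8, device="cuda")

    free_after, _ = torch.cuda.mem_get_info()
    total_consumed = free_before - free_after            # includes the persistent tensor

    torch_peak = torch.cuda.max_memory_allocated()
    torch_allocated = torch.cuda.memory_allocated()
    transient_peak_headroom = torch_peak - torch_allocated

    # The persistent tensor is counted once in total_consumed; the headroom
    # term should stay near zero rather than adding it a second time.
    assert transient_peak_headroom < persistent.numel() // 2
```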

@mergify bot added the v1 and bug labels Mar 15, 2026
Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request addresses an Out-Of-Memory (OOM) error caused by inaccurate memory measurement when using the cumem allocator. The fix correctly changes the memory profiling logic to use mem_get_info() instead of the problematic memory_reserved(), providing a more accurate measure of consumed memory and preventing over-allocation of the KV cache. The changes are consistently propagated through the memory calculation logic in gpu_worker.py, and a test case is updated to validate the new behavior. The implementation appears correct and effectively addresses the described bug. I have no further comments as the changes are sound.

@haosdent changed the title from [WIP] [Bugfix] Fix OOM caused by cumem allocator inflating memory_reserved() to [Bugfix] Fix OOM caused by cumem allocator inflating memory_reserved() Mar 16, 2026
@haosdent marked this pull request as ready for review March 16, 2026 03:32
@haosdent requested a review from njhill as a code owner March 16, 2026 03:32
Review comment thread on vllm/utils/mem_utils.py (outdated)

mergify Bot commented Mar 19, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @haosdent.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Mar 19, 2026
@haosdent force-pushed the fix-37096 branch 4 times, most recently from 960b2c1 to b27b401 on March 19, 2026 05:55
@mergify bot removed the needs-rebase label Mar 19, 2026
When the cumem allocator's cleanup bypasses PyTorch's allocator tracking
via direct cuMemUnmap, memory_reserved() becomes inflated, making
non_torch_memory negative and underestimating non-KV cache memory usage.
This causes OOM on large models (e.g. gpt-oss-120b on GH200 144GB).

Replace the memory_reserved()-based formula with a mem_get_info()-based
measurement (total_consumed) that is always accurate regardless of which
allocator is used. Use transient_peak_headroom (torch_peak -
torch_allocated) instead of torch_peak_increase to avoid double-counting
persistent torch allocations already included in total_consumed.

Fixes vllm-project#37096

Signed-off-by: haosdent <haosdent@gmail.com>
@MatthewBonanni (Collaborator) left a comment


cc @tjtanaa @jikunshang @JartX: will this PR cause issues on ROCm? It relies on mem_get_info().

@haosdent, see related PR #36720

@jikunshang (Collaborator)

Not an AMD expert, but I did notice this before: #12624
cc @gshtras


JartX commented Mar 19, 2026

Hi @MatthewBonanni, I just tested the PR on my ROCm setup and it loaded correctly. Thanks for mentioning me. 😊 @AndreasKaratzas


mergify Bot commented Mar 20, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @haosdent.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Mar 20, 2026
@mergify bot removed the needs-rebase label Apr 29, 2026

mergify Bot commented Apr 29, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @haosdent.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label Apr 29, 2026

Labels

bug, needs-rebase, v1


Development

Successfully merging this pull request may close these issues.

[Bug]: v0.17.0-aarch64 onwards will run out of CUDA memory for gpt-oss-120b on GH200 144GB
