Critical Memory Leak: Linear VRAM growth during training due to graph retention #2

@GravityBoi

Description

I encountered an Out Of Memory (OOM) error while training the model. Investigating model.py, I found a memory leak: the computation graph is retained across batches, so VRAM usage grows linearly until the GPU runs out of memory.

Root Cause: In model.py, inside training_step and validation_step, the model predictions (cell_preds) and targets (cell_targets) are appended to the class attributes self.train_cell_inside_preds, self.val_cell_inside_preds, etc., without being detached from the computation graph.

Because these lists persist for the entire epoch, PyTorch is forced to keep the gradients and computation graph for every single batch in VRAM until the epoch ends.
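The retention mechanism described above can be reproduced in isolation. The sketch below is a hypothetical minimal reproduction (the variable names mirror the issue but are not the actual model.py code): any tensor produced from a parameter with requires_grad=True carries a grad_fn, and appending it to a long-lived list keeps that graph alive, whereas a detached CPU copy does not.

```python
import torch

# Hypothetical minimal reproduction of the leak: a list that accumulates
# predictions across "batches", as training_step in model.py does.
retained = []
detached = []

weight = torch.randn(4, requires_grad=True)  # stand-in for a model parameter
for _ in range(3):  # stand-in for the batches of one epoch
    preds = weight * 2.0                   # stand-in for cell_preds
    retained.append(preds)                 # keeps the autograd graph alive
    detached.append(preds.detach().cpu())  # frees the graph immediately

# Tensors appended without detach still reference the graph...
assert all(t.grad_fn is not None for t in retained)
# ...while detached copies do not, so their graphs can be freed.
assert all(t.grad_fn is None for t in detached)
```

With N batches per epoch, the `retained` pattern holds N full graphs in VRAM simultaneously; the `detached` pattern holds at most one.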

Location: model.py, training_step (similar blocks exist in validation_step):

# Current Code
if inside_mask.any():
    self.train_cell_inside_preds.append(cell_preds[inside_mask]) 
    self.train_cell_inside_targets.append(cell_targets[inside_mask])

Suggested Fix: Detach the tensors and move them to the CPU before appending them to the list. This preserves the values for the end-of-epoch metrics calculation while freeing the GPU memory immediately.

# Proposed Fix
if inside_mask.any():
    self.train_cell_inside_preds.append(cell_preds[inside_mask].detach().cpu())
    self.train_cell_inside_targets.append(cell_targets[inside_mask].detach().cpu())

Applying this fix to all list appends in both training_step and validation_step resolves the OOM error.
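To confirm the fix does not break the epoch-end metrics, here is a hedged sketch of the consuming side (the list name matches the issue, but the epoch-end hook itself is assumed rather than quoted from model.py): the accumulated CPU tensors concatenate normally, and clearing the list afterwards keeps it from growing across epochs.

```python
import torch

# Stand-in for the per-batch appends after the fix: detached CPU tensors.
train_cell_inside_preds = [torch.randn(8).detach().cpu() for _ in range(3)]

# At epoch end, the values are still available for metric computation:
all_preds = torch.cat(train_cell_inside_preds)  # one (24,) CPU tensor

# Clearing the list resets the accumulator for the next epoch.
train_cell_inside_preds.clear()
```

Note that if the existing epoch-end code does not already empty these lists, clearing them (as above) is also needed so CPU memory does not grow across epochs the way VRAM did.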
