**Description**

I encountered an Out Of Memory (OOM) error while training the model. Upon investigating model.py, I found a memory leak: the computation graph is retained across batches, so VRAM usage grows linearly until the GPU runs out of memory.
**Root Cause**

In model.py, inside training_step and validation_step, the model predictions (cell_preds) and targets (cell_targets) are appended to the class attributes self.train_cell_inside_preds, self.val_cell_inside_preds, etc., without being detached from the computation graph. Because these lists persist for the entire epoch, PyTorch must keep the gradients and the computation graph of every single batch in VRAM until the epoch ends.
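The mechanism can be reproduced in isolation: a stored tensor that is still attached to the graph carries a grad_fn and keeps the whole batch's graph alive, while .detach().cpu() drops that reference. A minimal sketch (toy tensors, not the project's actual model):

```python
import torch

# Toy "model" parameter so the outputs are part of an autograd graph.
weight = torch.randn(4, 4, requires_grad=True)

leaky, safe = [], []
for _ in range(3):  # three fake batches
    batch = torch.randn(2, 4)
    preds = batch @ weight  # attached to the computation graph

    # Leaky: the stored tensor still has a grad_fn, so the entire
    # graph for this batch stays alive (on the GPU, during training).
    leaky.append(preds)

    # Safe: detach() drops the graph reference; .cpu() moves the data off the GPU.
    safe.append(preds.detach().cpu())

print(leaky[0].grad_fn is not None)  # True  -> graph retained
print(safe[0].grad_fn is None)       # True  -> graph freed
```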
**Location**: model.py, training_step (and similar blocks in validation_step):

```python
# Current Code
if inside_mask.any():
    self.train_cell_inside_preds.append(cell_preds[inside_mask])
    self.train_cell_inside_targets.append(cell_targets[inside_mask])
```
**Suggested Fix**

Detach the tensors and move them to the CPU before appending them to the lists. This preserves the values for the metrics calculation at the end of the epoch while freeing the GPU memory immediately:

```python
# Proposed Fix
if inside_mask.any():
    self.train_cell_inside_preds.append(cell_preds[inside_mask].detach().cpu())
    self.train_cell_inside_targets.append(cell_targets[inside_mask].detach().cpu())
```
Applying this fix to all list appends in both training_step and validation_step resolves the OOM error.
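For completeness, the stored CPU tensors should be concatenated and the lists cleared at epoch end, so host memory does not accumulate across epochs either. A hedged sketch of such a hook (the name on_train_epoch_end follows PyTorch Lightning conventions; the accuracy metric here is purely illustrative, not the project's actual metric):

```python
import torch

def on_train_epoch_end_sketch(preds_list, targets_list):
    """Consume the per-batch CPU tensors accumulated during the epoch."""
    if not preds_list:
        return None
    # torch.cat is safe here because every stored tensor is already
    # detached and on the CPU.
    preds = torch.cat(preds_list)
    targets = torch.cat(targets_list)
    # Illustrative metric; substitute the project's real metric computation.
    accuracy = (preds.argmax(dim=-1) == targets).float().mean().item()
    # Clear the lists so they do not keep growing across epochs.
    preds_list.clear()
    targets_list.clear()
    return accuracy
```

Clearing the lists in the epoch-end hook mirrors the per-batch fix: the detach/cpu calls cap VRAM growth within an epoch, and the clear() calls cap host-memory growth across epochs.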