[Bug] FusedRMSNorm leaks 2 CUDA tensors per forward call under torch.no_grad() #1999

@dustnehowl

Describe the Bug

apex.normalization.FusedRMSNorm leaks 2 CUDA tensors on every forward call, even under torch.no_grad(). The accumulation is linear in the number of calls and is not freed by gc.collect() or torch.cuda.empty_cache(), which causes OOM in long-running inference workloads.
This is particularly impactful because Hugging Face transformers automatically uses FusedRMSNorm for T5 layer norms when apex is importable. T5-XXL has 49 LayerNorm calls per forward pass, so each T5EncoderModel.forward(...) leaks 49 × 2 = 98 CUDA tensors. In a real-world FLUX inference workload on a single B200 (180 GB VRAM), this caused OOM after ~90 batches at batch_size=4 (~8 GB of GPU memory accumulated per batch).
The leak disappears completely when apex is blocked from being imported (forcing transformers to fall back to native PyTorch RMSNorm), which strongly points to FusedRMSNorm as the source.

Minimal Steps/Code to Reproduce the Bug

This reproduces the leak using only apex.normalization.FusedRMSNorm (no transformers / FLUX dependencies):

import torch
import gc
from apex.normalization import FusedRMSNorm

device = 'cuda:0'
ln = FusedRMSNorm(4096).to(device, dtype=torch.bfloat16)
x = torch.randn(1, 256, 4096, dtype=torch.bfloat16, device=device)

# warm-up
with torch.no_grad():
    _ = ln(x)

def count_cuda_tensors():
    # Count live CUDA tensors among all objects tracked by the garbage collector.
    return sum(1 for o in gc.get_objects()
               if torch.is_tensor(o) and o.is_cuda)

prev = count_cuda_tensors()
for i in range(5):
    with torch.no_grad():
        _ = ln(x)
    gc.collect()
    torch.cuda.empty_cache()
    cur = count_cuda_tensors()
    print(f'iter {i}: {cur} CUDA tensors (delta=+{cur - prev})')
    prev = cur

Output:

iter 0: 6 CUDA tensors (delta=+2)
iter 1: 8 CUDA tensors (delta=+2)
iter 2: 10 CUDA tensors (delta=+2)
iter 3: 12 CUDA tensors (delta=+2)
iter 4: 14 CUDA tensors (delta=+2)

gc.collect() and torch.cuda.empty_cache() between calls do not free these tensors. The two leaked tensors per call are:

  • (1, 256, 4096) bfloat16 — input/output activation
  • (256,) float32 — sequence-length-shaped auxiliary tensor

gc.get_referrers(...) shows that no Python object directly holds a reference to these tensors after the call returns, suggesting they are retained at the C++ / torch._library.custom_ops autograd metadata level.
This is unexpected under torch.no_grad(), where autograd should not be saving tensors for a backward pass.
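For anyone who wants to repeat the identification step, one possible approach is to diff two snapshots of gc.get_objects() by object identity and then inspect the newcomers. The sketch below is generic and standard-library only. The helper names snapshot_ids and new_objects are hypothetical, and the predicate is an assumption that would be swapped for a CUDA-tensor check.

```python
import gc


def snapshot_ids(predicate):
    """Return the ids of all GC-tracked objects that satisfy `predicate`."""
    ids = set()
    for obj in gc.get_objects():
        try:
            if predicate(obj):
                ids.add(id(obj))
        except Exception:
            # Some objects raise on inspection; skip them.
            continue
    return ids


def new_objects(before_ids, predicate):
    """Return objects matching `predicate` whose ids are not in `before_ids`."""
    found = []
    for obj in gc.get_objects():
        try:
            if predicate(obj) and id(obj) not in before_ids:
                found.append(obj)
        except Exception:
            continue
    return found


# Hypothetical usage with the repro above (needs torch and a CUDA device):
#   is_cuda_tensor = lambda o: torch.is_tensor(o) and o.is_cuda
#   before = snapshot_ids(is_cuda_tensor)
#   with torch.no_grad():
#       _ = ln(x)
#   for t in new_objects(before, is_cuda_tensor):
#       print(tuple(t.shape), t.dtype, len(gc.get_referrers(t)))
```

Two caveats: the fresh output bound to `_` will also be reported as new, and because ids can be reused once an object is freed, a new object may occasionally share an id with a freed one and be missed.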

Real-world impact via Hugging Face Transformers / T5

We isolated the leak by running each stage of FluxPipeline independently:

Component                          | CUDA tensors leaked per call
pipe.encode_prompt(...)            | +98 (= 49 × 2, matches T5-XXL's 49 LayerNorms)
pipe.scheduler.set_timesteps(...)  | 0
pipe.vae.decode(...)               | 0
pipe.transformer(...)              | 0

Only the T5-encoding path (which uses FusedRMSNorm via apex) leaks.
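The per-stage isolation can be sketched roughly as follows. This is not the exact script used for the table; leaked_per_stage is a hypothetical helper, and `count` stands in for any live-object counter such as count_cuda_tensors from the repro.

```python
import gc


def leaked_per_stage(stages, count):
    """Run each (name, fn) stage once and report how much `count()` grew.

    `count` is any zero-argument callable returning an integer. Garbage is
    collected before each reading, so only objects that are still alive
    after the stage returns are attributed to it.
    """
    report = {}
    for name, fn in stages:
        gc.collect()
        before = count()
        fn()  # the return value is discarded so it cannot inflate the count
        gc.collect()
        report[name] = count() - before
    return report
```

Each stage would be passed as a (name, callable) pair wrapping one pipeline call. Calling every stage once as a warm-up before measuring is advisable, since a first call may allocate caches that are not leaks.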

Workaround

Blocking apex's import before transformers is loaded forces T5 to fall back to its native PyTorch RMSNorm:

import sys
sys.modules['apex'] = None
sys.modules['apex.normalization'] = None
sys.modules['apex.normalization.fused_layer_norm'] = None
# now import transformers / diffusers / ...

After this, the same loop above shows delta=+0 for every iteration, and GPU memory stays flat across thousands of calls.
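The workaround relies on standard Python behavior: a None entry in sys.modules makes the import machinery raise ModuleNotFoundError (a subclass of ImportError) for that name. A minimal, standard-library sketch of the same idea as a reversible context manager is shown here; blocked_modules and can_import are hypothetical helper names, not part of apex or transformers.

```python
import contextlib
import importlib
import sys


@contextlib.contextmanager
def blocked_modules(*names):
    """Temporarily make importing each of the given module names fail."""
    missing = object()
    saved = {name: sys.modules.get(name, missing) for name in names}
    try:
        for name in names:
            # A None entry makes the import system raise ModuleNotFoundError.
            sys.modules[name] = None
        yield
    finally:
        # Restore the previous state so later imports behave normally.
        for name, module in saved.items():
            if module is missing:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = module


def can_import(name):
    """Return True if `name` imports successfully, False on ImportError."""
    try:
        importlib.import_module(name)
        return True
    except ImportError:
        return False
```

Under the assumption that transformers picks its T5 layer norm implementation at import time, wrapping the first `import transformers` in `with blocked_modules('apex'):` should have the same effect as the snippet above. It would have no effect if the T5 module was already imported earlier in the process.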

Expected Behavior

Forward calls under torch.no_grad() should not retain GPU tensors after returning. Repeated forward calls should not cause linear GPU memory growth.

Environment

  • OS: Ubuntu 24.04.3 LTS (x86_64)
  • Python: 3.12.3
  • PyTorch: 2.11.0+cu130
  • CUDA: 13.0 (PyTorch built); 13.1.80 (runtime)
  • GPU: NVIDIA B200 (driver 580.95.05)
  • apex: 0.1 (pip-installed; specific commit unknown)
  • transformers / diffusers: 4.57.6 / 0.37.1 (only relevant for the T5/FLUX context above; not needed for the minimal repro)
