Describe the Bug
apex.normalization.FusedRMSNorm leaks 2 CUDA tensors on every forward call, even under torch.no_grad(). The accumulation is linear in the number of calls and is not freed by gc.collect() or torch.cuda.empty_cache(), which causes OOM in long-running inference workloads.
This is particularly impactful because Hugging Face transformers automatically uses FusedRMSNorm for T5 layer norms when apex is importable. T5-XXL has 49 LayerNorm calls per forward pass, so 49 × 2 = 98 CUDA tensors leak per T5EncoderModel.forward(...). In a real-world FLUX inference workload on a single B200 (180 GB VRAM), this caused OOM after ~90 batches at batch_size=4 (~8 GB of GPU memory accumulated per batch).
The leak disappears completely when apex is blocked from being imported (forcing transformers to fall back to native PyTorch RMSNorm), which strongly points to FusedRMSNorm as the source.
Minimal Steps/Code to Reproduce the Bug
This reproduces the leak using only apex.normalization.FusedRMSNorm (no transformers / FLUX dependencies):
import gc

import torch
from apex.normalization import FusedRMSNorm

device = 'cuda:0'
ln = FusedRMSNorm(4096).to(device, dtype=torch.bfloat16)
x = torch.randn(1, 256, 4096, dtype=torch.bfloat16, device=device)

# warm-up
with torch.no_grad():
    _ = ln(x)


def count_cuda_tensors():
    # Count every live CUDA tensor the Python garbage collector can see.
    return sum(1 for o in gc.get_objects()
               if torch.is_tensor(o) and o.is_cuda)


prev = count_cuda_tensors()
for i in range(5):
    with torch.no_grad():
        _ = ln(x)
    # Rule out uncollected garbage and allocator caching before counting.
    gc.collect()
    torch.cuda.empty_cache()
    cur = count_cuda_tensors()
    print(f'iter {i}: {cur} CUDA tensors (delta=+{cur - prev})')
    prev = cur
Output:
iter 0: 6 CUDA tensors (delta=+2)
iter 1: 8 CUDA tensors (delta=+2)
iter 2: 10 CUDA tensors (delta=+2)
iter 3: 12 CUDA tensors (delta=+2)
iter 4: 14 CUDA tensors (delta=+2)
gc.collect() and torch.cuda.empty_cache() between calls do not free these tensors. The two leaked tensors per call are:
(1, 256, 4096) bfloat16 — input/output activation
(256,) float32 — sequence-length-shaped auxiliary tensor
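For a sense of scale, the byte cost of the two shapes listed above can be worked out directly. This is a back-of-the-envelope sketch, not a measurement: it assumes both tensors are contiguous, counts only element storage (no allocator rounding), and reuses the figure of 49 norm calls per T5-XXL forward quoted earlier. The helper name tensor_bytes is made up for illustration.

```python
from math import prod

# Element sizes in bytes for the two dtypes reported in the repro.
ITEMSIZE = {"bfloat16": 2, "float32": 4}


def tensor_bytes(shape, dtype):
    """Storage for a contiguous tensor, ignoring allocator rounding."""
    return prod(shape) * ITEMSIZE[dtype]


activation = tensor_bytes((1, 256, 4096), "bfloat16")  # input/output activation
auxiliary = tensor_bytes((256,), "float32")            # sequence-length-shaped tensor
per_call = activation + auxiliary
per_t5_forward = 49 * per_call  # 49 norm calls per T5-XXL forward (from the report)

print(f"per FusedRMSNorm call: {per_call} bytes ({per_call / 2**20:.2f} MiB)")
print(f"per T5-XXL forward at this shape: {per_t5_forward / 2**20:.1f} MiB")
```

At the repro's shape the activation accounts for almost all of the roughly 2 MiB retained per call; larger batch sizes and sequence lengths would scale that figure proportionally.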
gc.get_referrers(...) shows that no Python object directly holds a reference to these tensors after the call returns, suggesting they are retained at the C++ / torch._library.custom_ops autograd metadata level.
This is unexpected under torch.no_grad(), where autograd should not record a graph or save tensors for a backward pass.
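The referrer check described above could be done along these lines. This is a minimal sketch using only the standard library, with a plain Python object standing in for a leaked tensor so that it runs without a GPU. The names python_referrers and Payload are hypothetical, and filtering out frames and the module globals is an assumption about what counts as noise. An empty result for an object that is still alive would be consistent with the reference being held outside Python.

```python
import gc
import types


def python_referrers(obj, ignore=()):
    """Return GC-tracked Python objects that refer to obj, minus known noise."""
    gc.collect()
    ignored = {id(x) for x in ignore}
    return [
        r for r in gc.get_referrers(obj)
        if id(r) not in ignored and not isinstance(r, types.FrameType)
    ]


class Payload:
    """Stand-in for a leaked tensor (assumption: any GC-tracked object works)."""


suspect = Payload()
holder = {"cached": suspect}  # a Python-level container keeping it alive

# The module globals also refer to `suspect`, so they are excluded as noise.
for ref in python_referrers(suspect, ignore=[globals()]):
    print(type(ref).__name__, "holds a reference")
```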
Real-world impact via Hugging Face Transformers / T5
We isolated the leak by running each stage of FluxPipeline independently:
| Component | CUDA tensors leaked per call |
| --- | --- |
| pipe.encode_prompt(...) | +98 (= 49 × 2, matches T5-XXL's 49 LayerNorms) |
| pipe.scheduler.set_timesteps(...) | 0 |
| pipe.vae.decode(...) | 0 |
| pipe.transformer(...) | 0 |
Only the T5-encoding path (which uses FusedRMSNorm via apex) leaks.
Workaround
Blocking apex's import before transformers is loaded forces T5 to fall back to its native PyTorch RMSNorm:
import sys
sys.modules['apex'] = None
sys.modules['apex.normalization'] = None
sys.modules['apex.normalization.fused_layer_norm'] = None
# now import transformers / diffusers / ...
After this, the same loop above shows delta=+0 for every iteration, and GPU memory stays flat across thousands of calls.
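The same idea can be wrapped in a small helper. A None entry in sys.modules makes any later import of that name raise ModuleNotFoundError (a subclass of ImportError), which is presumably what triggers the fallback in a guarded import. The sketch below is an illustration only: block_imports and can_import are hypothetical helper names, and the call has to happen before transformers is first imported.

```python
import importlib
import sys

APEX_MODULES = (
    "apex",
    "apex.normalization",
    "apex.normalization.fused_layer_norm",
)


def block_imports(names=APEX_MODULES):
    """Make later imports of the given module names fail with ImportError."""
    for name in names:
        sys.modules[name] = None  # None entry => import raises ModuleNotFoundError


def can_import(name):
    """True if importing `name` succeeds in the current process."""
    try:
        importlib.import_module(name)
    except ImportError:
        return False
    return True


block_imports()
print("apex importable:", can_import("apex"))  # False once blocked
```

Checking can_import("apex") right after blocking is a cheap way to confirm the workaround took effect before any model code runs.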
Expected Behavior
Forward calls under torch.no_grad() should not retain GPU tensors after returning. Repeated forward calls should not cause linear GPU memory growth.
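One way to state this expectation as a check is to count the live objects of interest before and after each call and require zero growth. The sketch below is stdlib-only and uses a stand-in class and two stand-in functions instead of FusedRMSNorm so that it runs without a GPU. The names growth_per_call, Blob, leaky and clean are hypothetical. For the real case the predicate would be the torch.is_tensor(o) and o.is_cuda test from the repro.

```python
import gc


def growth_per_call(fn, predicate, iters=5):
    """Call fn repeatedly; return the per-call change in matching live objects."""
    def count():
        gc.collect()
        return sum(1 for o in gc.get_objects() if predicate(o))

    fn()  # warm-up, so one-time allocations are not counted
    deltas, prev = [], count()
    for _ in range(iters):
        fn()
        cur = count()
        deltas.append(cur - prev)
        prev = cur
    return deltas


class Blob:
    """Stand-in for a CUDA tensor (assumption for illustration)."""


_hidden = []  # simulates a reference retained outside the caller's view


def leaky():
    _hidden.extend([Blob(), Blob()])  # retains 2 objects per call


def clean():
    Blob()  # freed as soon as the call returns


def is_blob(o):
    return type(o) is Blob


print("leaky:", growth_per_call(leaky, is_blob))
print("clean:", growth_per_call(clean, is_blob))
```

A forward call that behaves as expected would look like clean here, with a delta of zero on every iteration.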
Environment
- OS: Ubuntu 24.04.3 LTS (x86_64)
- Python: 3.12.3
- PyTorch: 2.11.0+cu130
- CUDA: 13.0 (PyTorch built); 13.1.80 (runtime)
- GPU: NVIDIA B200 (driver 580.95.05)
- apex: 0.1 (pip-installed; specific commit unknown)
- transformers / diffusers: 4.57.6 / 0.37.1 (only relevant for the T5/FLUX context above; not needed for the minimal repro)