# Reduce KDA Unit Test Runtime While Preserving Coverage #78

@icavan

Summary

The current KDA unit tests are too slow for regular development workflows. In particular, tests/test_kda.py and tests/test_kda_compare_fla.py both exercise large parameter grids, expensive variable-length traces, and full forward/backward comparisons across multiple implementations.

This provides strong coverage, but it also makes local iteration slow and increases CI cost.

We should significantly reduce KDA unit test runtime while still keeping enough coverage to catch correctness regressions. A practical direction is to split the suite into fast and slow modes, where fast covers representative correctness paths for default CI and local iteration, and slow preserves the broader stress coverage for nightly or manual runs.

Problem

Both of the current KDA test files are expensive:

  • tests/test_kda.py
  • tests/test_kda_compare_fla.py

Today, each file contains:

  • 32 fixed-length parameterized cases
  • 36 variable-length parameterized cases
  • forward + backward checks per case
  • one reference path per file (naive_recurrent_kda in tests/test_kda.py, fla_chunk_kda in tests/test_kda_compare_fla.py)
  • large varlen traces, including long packed sequences
  • extra parameter dimensions such as beta_dtype and disable_recompute

In practice, this means:

  • slow local development loops
  • higher CI runtime and GPU time consumption
  • redundant coverage between the two test files
  • large stress cases blocking quick correctness checks

Goal

Reduce the default runtime of KDA tests substantially while preserving enough coverage to:

  • catch obvious correctness regressions quickly
  • keep representative coverage for fixed-length and varlen paths
  • keep at least some backward coverage in default runs
  • retain broader stress coverage in an opt-in mode

Why The Current Tests Are Slow

1. Large Cartesian products

The current tests multiply:

  • model/config cases
  • beta_dtype
  • disable_recompute

This expands the suite significantly even though not every large case needs every secondary toggle.

2. Heavy reference implementations

tests/test_kda.py compares against a naive recurrent reference, and tests/test_kda_compare_fla.py compares against the FLA Triton implementation. Both are useful, but running them both across broad parameter grids creates significant overlap.

3. Expensive backward checks on every case

Many cases validate:

  • output
  • final state
  • gradients for q/k/v/g/beta/h0
  • sometimes A_log and dt_bias

This is much heavier than forward-only validation.

4. Very large varlen traces in default test mode

Some varlen traces are closer to stress/performance validation than quick unit testing.

Proposed Direction

1. Split into fast and slow modes

Introduce two tiers of KDA tests:

Fast mode

Run by default in local development and standard CI.

Properties:

  • small number of representative fixed-length cases
  • small number of representative varlen cases
  • limited use of disable_recompute=False
  • limited beta_dtype coverage
  • only a subset of cases perform full backward validation

Slow mode

Run manually or in nightly CI.

Properties:

  • large varlen traces
  • broader parameter combinations
  • exhaustive backward checks
  • compatibility comparisons against FLA across larger grids

This can be implemented with markers such as:

  • @pytest.mark.kda_fast
  • @pytest.mark.kda_slow

or with a single slow marker for the expensive cases.
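
A single slow marker could be wired up in a conftest.py roughly as follows. This is a minimal sketch, not the project's actual configuration: the marker description and the helper name split_by_slow are assumptions made for illustration. It registers the marker and deselects slow-marked tests unless the caller passes an explicit -m expression.

```python
# conftest.py (sketch; marker text and helper name are illustrative)

def pytest_configure(config):
    # Register the marker so that --strict-markers does not reject it.
    config.addinivalue_line(
        "markers", "slow: broad KDA stress coverage for nightly or manual runs"
    )


def split_by_slow(items):
    """Partition collected test items into (fast, slow) by the `slow` marker."""
    fast, slow = [], []
    for item in items:
        (slow if item.get_closest_marker("slow") else fast).append(item)
    return fast, slow


def pytest_collection_modifyitems(config, items):
    # An explicit -m expression (for example `-m slow`) takes precedence
    # over the default behaviour of hiding slow tests.
    if config.getoption("markexpr"):
        return
    fast, slow = split_by_slow(items)
    if slow:
        # Report the hidden tests as deselected and keep only the fast ones.
        config.hook.pytest_deselected(items=slow)
        items[:] = fast
```

With hooks like these, a plain pytest invocation would collect only unmarked tests, while pytest -m slow would select the marked ones. Deselecting (rather than skipping) keeps the default report free of skip noise.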

2. Reduce redundant cross-product coverage

Instead of testing all combinations of:

  • beta_dtype
  • disable_recompute
  • large shapes

we should:

  • keep full cross-product only for a few small representative cases
  • run large cases only with the default / most common settings
  • test rare combinations on small or medium inputs only
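
One way to express this trimmed matrix is to build the case list explicitly instead of stacking parametrize decorators. The sketch below uses placeholder shapes, placeholder dtype names, and assumed default settings; none of these values come from the current test files.

```python
from itertools import product

# Hypothetical (B, T, H, D) shapes; the real grids live in the test files.
SMALL_CASES = [(1, 64, 2, 64), (2, 128, 4, 64)]
LARGE_CASES = [(4, 4096, 8, 128)]

# Placeholder values for the two secondary toggles.
BETA_DTYPES = ["float32", "bfloat16"]
DISABLE_RECOMPUTE = [False, True]

# Assumed "most common" settings; adjust to the suite's actual defaults.
DEFAULT_BETA_DTYPE = "float32"
DEFAULT_DISABLE_RECOMPUTE = False


def build_cases():
    """Full cross-product for small shapes, defaults only for large shapes."""
    cases = []
    for shape, beta_dtype, disable_recompute in product(
        SMALL_CASES, BETA_DTYPES, DISABLE_RECOMPUTE
    ):
        cases.append((shape, beta_dtype, disable_recompute))
    for shape in LARGE_CASES:
        cases.append((shape, DEFAULT_BETA_DTYPE, DEFAULT_DISABLE_RECOMPUTE))
    return cases
```

With these placeholder values the list has 9 entries instead of the 12 a full product over all three shapes would produce. The list could then be passed to a single pytest.mark.parametrize call.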

3. Keep only representative backward-heavy cases in fast mode

Default fast mode should still include backward validation, but only on a small subset:

  • one small fixed-length case
  • one medium fixed-length case
  • one varlen case
  • one use_gate_in_kernel=True case

The remaining fast-mode cases can focus on forward outputs and final state.
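
A per-case flag is one possible way to encode which fast-mode cases run backward validation. The sketch below is illustrative only: KdaCase, the case names, and the quantity labels are hypothetical and would need to be mapped onto the existing test helpers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class KdaCase:
    """Hypothetical description of one fast-mode test case."""
    name: str
    varlen: bool = False
    use_gate_in_kernel: bool = False
    check_backward: bool = False


FAST_CASES = [
    # The four backward-validating cases listed above.
    KdaCase("fixed_small", check_backward=True),
    KdaCase("fixed_medium", check_backward=True),
    KdaCase("varlen_small", varlen=True, check_backward=True),
    KdaCase("fixed_gated", use_gate_in_kernel=True, check_backward=True),
    # Remaining cases compare forward outputs and final state only.
    KdaCase("fixed_forward_only"),
    KdaCase("varlen_forward_only", varlen=True),
]


def quantities_to_check(case):
    """Return the labels of the tensors a test would compare for this case."""
    names = ["o", "final_state"]
    if case.check_backward:
        names += ["dq", "dk", "dv", "dg", "dbeta", "dh0"]
    return names
```

A test body could then skip the backward call entirely when case.check_backward is false, which avoids both the gradient computation and the reference backward pass.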

4. Move the largest varlen traces into slow mode

The trace-like varlen cases are valuable, but they should not dominate regular unit test runtime.

These cases should be marked as slow and excluded from default runs.
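
Marking can be done per parameter set rather than per test function, so a single parametrized test keeps both small and large traces. The following sketch assumes hypothetical cu_seqlens traces and an arbitrary token threshold; the real cutoff should come from measured runtimes.

```python
import pytest

# Hypothetical cumulative sequence offsets for packed varlen batches.
VARLEN_TRACES = [
    [0, 15, 100, 300],
    [0, 256, 500, 1000],
    [0, 8192, 16384, 32768],
]

# Assumed cutoff in total packed tokens; tune from measured runtimes.
SLOW_TOKEN_THRESHOLD = 4096


def total_tokens(cu_seqlens):
    """Total number of packed tokens described by the offsets."""
    return cu_seqlens[-1] - cu_seqlens[0]


def as_param(cu_seqlens):
    """Wrap a trace in pytest.param, attaching the slow mark to large traces."""
    n = total_tokens(cu_seqlens)
    marks = [pytest.mark.slow] if n > SLOW_TOKEN_THRESHOLD else []
    return pytest.param(cu_seqlens, marks=marks, id=f"tokens{n}")


VARLEN_PARAMS = [as_param(trace) for trace in VARLEN_TRACES]


@pytest.mark.parametrize("cu_seqlens", VARLEN_PARAMS)
def test_varlen_trace_is_well_formed(cu_seqlens):
    # Placeholder body: the real test would run the KDA kernel on the
    # packed batch and compare it against the reference implementation.
    assert all(a < b for a, b in zip(cu_seqlens, cu_seqlens[1:]))
```

Running with -m "not slow" would then keep the two small traces and drop the large one without duplicating the test function.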

5. Clarify the roles of the two files

The two test files currently overlap in coverage, although their intended roles differ:

  • tests/test_kda.py: algorithmic correctness against naive recurrence
  • tests/test_kda_compare_fla.py: implementation compatibility against FLA Triton

Suggested direction:

  • keep tests/test_kda.py as the stronger correctness-oriented suite
  • reduce tests/test_kda_compare_fla.py to a smaller compatibility-focused matrix

This avoids paying for two nearly complete sweeps of the same broad grid.

6. Remove unnecessary graph retention where possible

The tests currently use retain_graph=True in backward calls. Some of these may be unnecessary if the graph is not reused. Removing unneeded graph retention may reduce memory pressure and slightly reduce runtime.

This should be reviewed carefully to avoid changing test behavior.

Possible Implementation Plan

Phase 1: Add markers and separate heavy cases

  • mark largest varlen traces as slow
  • mark broad FLA comparison cases as slow
  • keep a compact default path for quick validation

Phase 2: Reduce matrix size

  • trim cross-product expansion for beta_dtype
  • trim cross-product expansion for disable_recompute
  • keep special toggles only on representative cases

Phase 3: Reduce duplicate coverage

  • keep naive-reference-heavy coverage in test_kda.py
  • keep smaller FLA compatibility coverage in test_kda_compare_fla.py

Phase 4: Optional cleanup

  • remove unnecessary retain_graph=True
  • refactor common test helpers to make the fast/slow split easier to maintain

Acceptance Criteria

  • default KDA unit test runtime is significantly reduced
  • default mode still covers:
    • fixed-length path
    • varlen path
    • at least one backward path
    • at least one gated path
  • slow mode preserves the current broader stress coverage
  • test intent becomes clearer: correctness coverage vs compatibility coverage
  • developers can choose between quick local checks and broader regression sweeps
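
The coverage criteria for default mode could themselves be guarded by a small meta-test, so that future trimming does not silently drop a required path. This is a sketch under assumptions: the case dictionaries and path labels below are hypothetical and do not reflect the current test files.

```python
REQUIRED_PATHS = ("fixed_length", "varlen", "backward", "gated")

# Hypothetical fast-mode cases, each tagged with the paths it exercises.
FAST_CASES = [
    {"id": "fixed_small", "varlen": False, "backward": True, "gated": False},
    {"id": "varlen_small", "varlen": True, "backward": False, "gated": False},
    {"id": "fixed_gated", "varlen": False, "backward": True, "gated": True},
]


def covered_paths(cases):
    """Collect the set of required paths exercised by the given cases."""
    paths = set()
    for case in cases:
        paths.add("varlen" if case["varlen"] else "fixed_length")
        if case["backward"]:
            paths.add("backward")
        if case["gated"]:
            paths.add("gated")
    return paths


def missing_paths(cases):
    """Return required paths that no case covers, in REQUIRED_PATHS order."""
    covered = covered_paths(cases)
    return [path for path in REQUIRED_PATHS if path not in covered]


def test_fast_suite_covers_required_paths():
    assert missing_paths(FAST_CASES) == []
```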

Example Outcome

A possible end state could be:

  • pytest tests/test_kda.py tests/test_kda_compare_fla.py runs a compact fast suite, with slow-marked cases deselected by default
  • pytest -m slow tests/test_kda.py tests/test_kda_compare_fla.py runs the broader stress suite
  • nightly CI includes slow coverage
  • regular PR CI only runs fast coverage

Open Questions

  • Should the default PR CI run only fast, or fast + a very small subset of slow?
  • Should test_kda_compare_fla.py be reduced more aggressively than test_kda.py?
  • Should gradient checks be made optional for some larger default cases?
  • Do we want an explicit KDA_TEST_MODE=fast|slow switch in addition to pytest markers?
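
If an explicit KDA_TEST_MODE switch is adopted, it could look something like the sketch below. The semantics shown here are assumptions rather than a decided design: the default is fast, and slow mode runs the fast cases plus the slow ones. The helper names are hypothetical.

```python
import os

VALID_MODES = ("fast", "slow")


def kda_test_mode(environ=None):
    """Read KDA_TEST_MODE from the environment, defaulting to 'fast'."""
    env = os.environ if environ is None else environ
    mode = env.get("KDA_TEST_MODE", "fast").strip().lower()
    if mode not in VALID_MODES:
        raise ValueError(
            f"KDA_TEST_MODE must be one of {VALID_MODES}, got {mode!r}"
        )
    return mode


def select_cases(fast_cases, slow_cases, environ=None):
    """Assumed semantics: slow mode is a superset of fast mode."""
    if kda_test_mode(environ) == "slow":
        return list(fast_cases) + list(slow_cases)
    return list(fast_cases)
```

A possible usage would be KDA_TEST_MODE=slow pytest tests/test_kda.py, with the parametrize lists built through select_cases at import time. One trade-off compared with markers is that cases excluded this way never appear in the collection report, so they cannot be selected with -m or -k.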

Impact

This change should improve:

  • local developer iteration speed
  • GPU efficiency in CI
  • clarity of test intent
  • maintainability of the KDA test suite

while preserving meaningful correctness coverage.
