Summary
The current KDA unit tests are too slow for regular development workflows. In particular, tests/test_kda.py and tests/test_kda_compare_fla.py both exercise large parameter grids, expensive variable-length traces, and full forward/backward comparisons across multiple implementations.
This provides strong coverage, but it also makes local iteration slow and increases CI cost.
We should significantly reduce KDA unit test runtime while still keeping enough coverage to catch correctness regressions. A practical direction is to split the suite into fast and slow modes, where fast covers representative correctness paths for default CI and local iteration, and slow preserves the broader stress coverage for nightly or manual runs.
Problem
Both of the current KDA test files are expensive:
- tests/test_kda.py
- tests/test_kda_compare_fla.py
Today, each file contains:
- 32 fixed-length parameterized cases
- 36 variable-length parameterized cases
- forward + backward checks per case
- multiple reference paths (naive_recurrent_kda or fla_chunk_kda)
- large varlen traces, up to long packed sequences
- extra parameter dimensions such as beta_dtype and disable_recompute
In practice, this means:
- slow local development loops
- higher CI runtime and GPU time consumption
- redundant coverage between the two test files
- large stress cases blocking quick correctness checks
Goal
Reduce the default runtime of KDA tests substantially while preserving enough coverage to:
- catch obvious correctness regressions quickly
- keep representative coverage for fixed-length and varlen paths
- keep at least some backward coverage in default runs
- retain broader stress coverage in an opt-in mode
Why The Current Tests Are Slow
1. Large Cartesian products
The current tests multiply:
- model/config cases
- beta_dtype
- disable_recompute
This expands the suite significantly even though not every large case needs every secondary toggle.
2. Heavy reference implementations
tests/test_kda.py compares against a naive recurrent reference, and tests/test_kda_compare_fla.py compares against the FLA Triton implementation. Both are useful, but running them both across broad parameter grids creates significant overlap.
3. Expensive backward checks on every case
Many cases validate:
- output
- final state
- gradients for q/k/v/g/beta/h0
- sometimes A_log and dt_bias
This is much heavier than forward-only validation.
4. Very large varlen traces in default test mode
Some varlen traces are closer to stress/performance validation than quick unit testing.
Proposed Direction
1. Split into fast and slow modes
Introduce two tiers of KDA tests:
Fast mode
Run by default in local development and standard CI.
Properties:
- small number of representative fixed-length cases
- small number of representative varlen cases
- limited use of disable_recompute=False
- limited beta_dtype coverage
- only a subset of cases perform full backward validation
Slow mode
Run manually or in nightly CI.
Properties:
- large varlen traces
- broader parameter combinations
- exhaustive backward checks
- compatibility comparisons against FLA across larger grids
This can be implemented with markers such as:
@pytest.mark.kda_fast
@pytest.mark.kda_slow
or with a single slow marker for the expensive cases.
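As a rough illustration of the single-marker variant, the sketch below attaches a `slow` mark to individual parameter sets with `pytest.param`, so unmarked cases form the fast tier. The shapes, ids, and the test name `test_kda_fixed_len` are hypothetical placeholders, not taken from the existing test files, and the test body is a stand-in for the real forward/backward comparison.

```python
import pytest

# Hypothetical (batch, seq_len, heads, head_dim) cases; the real grids live in
# tests/test_kda.py and tests/test_kda_compare_fla.py and are not reproduced here.
FIXED_LEN_CASES = [
    pytest.param(1, 64, 2, 64, id="small"),
    pytest.param(2, 512, 4, 128, id="medium"),
    # Only the expensive case carries the marker; everything else is "fast".
    pytest.param(4, 8192, 8, 128, id="large", marks=pytest.mark.slow),
]


@pytest.mark.parametrize("B, T, H, D", FIXED_LEN_CASES)
def test_kda_fixed_len(B, T, H, D):
    # Placeholder body: a real test would run the kernel and compare it
    # against the reference implementation.
    assert min(B, T, H, D) > 0
```

With this layout, `pytest -m "not slow"` selects the small and medium cases and `pytest -m slow` selects only the large one. The marker would also need to be registered (for example in the pytest configuration) to avoid unknown-marker warnings.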
2. Reduce redundant cross-product coverage
Instead of testing all combinations of:
- beta_dtype
- disable_recompute
- large shapes
we should:
- keep full cross-product only for a few small representative cases
- run large cases only with the default / most common settings
- test rare combinations on small or medium inputs only
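The effect of this rule on grid size can be sketched with plain Python. All case names, dtype choices, and the "default" settings below are illustrative assumptions, not values read from the current tests; the point is only the shape of the construction: a full product for small cases, one default combination for large ones.

```python
from itertools import product

# Illustrative placeholders, not the real test grid.
SMALL_CASES = ["small_fixed", "small_varlen"]
LARGE_CASES = ["large_fixed_a", "large_fixed_b", "large_varlen_a", "large_varlen_b"]
BETA_DTYPES = ["float32", "bfloat16"]
DISABLE_RECOMPUTE = [False, True]

# Assumed "most common" settings, applied to every large case.
DEFAULT_BETA_DTYPE = "bfloat16"
DEFAULT_DISABLE_RECOMPUTE = False


def full_grid():
    """Every case crossed with every secondary toggle."""
    return list(product(SMALL_CASES + LARGE_CASES, BETA_DTYPES, DISABLE_RECOMPUTE))


def reduced_grid():
    """Full cross-product for small cases, defaults only for large cases."""
    small = list(product(SMALL_CASES, BETA_DTYPES, DISABLE_RECOMPUTE))
    large = [(c, DEFAULT_BETA_DTYPE, DEFAULT_DISABLE_RECOMPUTE) for c in LARGE_CASES]
    return small + large


print(len(full_grid()))     # 24
print(len(reduced_grid()))  # 12
```

In this toy setup the reduced grid is half the size of the full one, and the cases it drops are the expensive large-shape combinations.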
3. Keep only representative backward-heavy cases in fast mode
Default fast mode should still include backward validation, but only on a small subset:
- one small fixed-length case
- one medium fixed-length case
- one varlen case
- one use_gate_in_kernel=True case
The remaining fast-mode cases can focus on forward outputs and final state.
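One possible way to express this is a per-case flag that a shared helper consults before running gradient checks. The sketch below is an assumption about how the helpers could be organized: `FastCase`, `run_case`, and the case names are hypothetical, and the forward and backward checks are passed in as callables so the example does not depend on the actual kernels.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FastCase:
    """Hypothetical description of one fast-mode test case."""
    name: str
    varlen: bool = False
    use_gate_in_kernel: bool = False
    check_backward: bool = False


# Four backward-checked representatives, plus forward-only cases.
FAST_CASES = [
    FastCase("fixed_small", check_backward=True),
    FastCase("fixed_medium", check_backward=True),
    FastCase("varlen_short", varlen=True, check_backward=True),
    FastCase("fixed_gated", use_gate_in_kernel=True, check_backward=True),
    FastCase("fixed_small_alt"),
    FastCase("varlen_alt", varlen=True),
]


def run_case(case, check_forward, check_backward):
    """Always check outputs and final state; check gradients only if flagged.

    Returns the list of checks that were performed, for easy inspection.
    """
    check_forward(case)
    performed = ["forward"]
    if case.check_backward:
        check_backward(case)
        performed.append("backward")
    return performed
```

Keeping the flag on the case definition makes it visible in one place which cases pay for gradient validation.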
4. Move the largest varlen traces into slow mode
The trace-like varlen cases are valuable, but they should not dominate regular unit test runtime.
These cases should be marked as slow and excluded from default runs.
5. Clarify the roles of the two files
The two test files currently overlap in intent:
- tests/test_kda.py: algorithmic correctness against naive recurrence
- tests/test_kda_compare_fla.py: implementation compatibility against FLA Triton
Suggested direction:
- keep tests/test_kda.py as the stronger correctness-oriented suite
- reduce tests/test_kda_compare_fla.py to a smaller compatibility-focused matrix
This avoids paying the cost of two nearly full broad sweeps.
6. Remove unnecessary graph retention where possible
The tests currently use retain_graph=True in backward calls. Some of these may be unnecessary if the graph is not reused. Removing unneeded graph retention may reduce memory pressure and slightly reduce runtime.
This should be reviewed carefully to avoid changing test behavior.
Possible Implementation Plan
Phase 1: Add markers and separate heavy cases
- mark largest varlen traces as slow
- mark broad FLA comparison cases as slow
- keep a compact default path for quick validation
Phase 2: Reduce matrix size
- trim cross-product expansion for beta_dtype
- trim cross-product expansion for disable_recompute
- keep special toggles only on representative cases
Phase 3: Reduce duplicate coverage
- keep naive-reference-heavy coverage in test_kda.py
- keep smaller FLA compatibility coverage in test_kda_compare_fla.py
Phase 4: Optional cleanup
- remove unnecessary retain_graph=True
- refactor common test helpers to make the fast/slow split easier to maintain
Acceptance Criteria
- default KDA unit test runtime is significantly reduced
- default mode still covers:
  - fixed-length path
  - varlen path
  - at least one backward path
  - at least one gated path
- slow mode preserves the current broader stress coverage
- test intent becomes clearer: correctness coverage vs compatibility coverage
- developers can choose between quick local checks and broader regression sweeps
Example Outcome
A possible end state could be:
- pytest tests/test_kda.py tests/test_kda_compare_fla.py runs a compact fast suite
- pytest -m slow tests/test_kda.py tests/test_kda_compare_fla.py runs the broader stress suite
- nightly CI includes slow coverage
- regular PR CI only runs fast coverage
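A plain `pytest` invocation runs every collected test, so the end state above needs something that skips slow cases unless they are requested. One option, sketched here under the assumption of a single `slow` marker and a `conftest.py` next to the tests, is a collection hook that skips slow items whenever no `-m` expression was given. The skip reason text is illustrative.

```python
# conftest.py (sketch)
import pytest


def pytest_configure(config):
    # Register the marker so pytest does not warn about an unknown mark.
    config.addinivalue_line("markers", "slow: expensive KDA stress case")


def pytest_collection_modifyitems(config, items):
    # If the user passed an explicit -m expression (e.g. "-m slow"),
    # leave selection entirely to pytest's own marker filtering.
    if config.getoption("markexpr"):
        return
    skip_slow = pytest.mark.skip(reason="slow KDA case; select with -m slow")
    for item in items:
        if "slow" in item.keywords:
            item.add_marker(skip_slow)
```

With this hook, the default command reports slow cases as skipped, while `pytest -m slow ...` runs only the marked cases, which matches the two commands listed above. An alternative with similar effect is a default `-m "not slow"` in the pytest configuration.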
Open Questions
- Should the default PR CI run only fast, or fast + a very small subset of slow?
- Should test_kda_compare_fla.py be reduced more aggressively than test_kda.py?
- Should gradient checks be made optional for some larger default cases?
- Do we want an explicit KDA_TEST_MODE=fast|slow switch in addition to pytest markers?
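If an explicit KDA_TEST_MODE switch is adopted, it could be as small as the helper sketched below. The function names, the `fast` default, and the choice to raise on unknown values are assumptions made for illustration; only the variable name and its two values come from the question above.

```python
import os

VALID_MODES = ("fast", "slow")


def kda_test_mode(environ=None):
    """Return the requested KDA test mode, defaulting to "fast".

    `environ` is injectable so the helper can be exercised without
    touching the real process environment.
    """
    env = os.environ if environ is None else environ
    mode = env.get("KDA_TEST_MODE", "fast").strip().lower()
    if mode not in VALID_MODES:
        raise ValueError(
            f"KDA_TEST_MODE must be one of {VALID_MODES}, got {mode!r}"
        )
    return mode


def include_slow_cases(environ=None):
    """True when the broader stress cases should be part of the run."""
    return kda_test_mode(environ) == "slow"
```

Test modules could call `include_slow_cases()` when building their parameter lists. Whether this should coexist with markers or replace them is the open question itself.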
Impact
This change should improve:
- local developer iteration speed
- GPU efficiency in CI
- clarity of test intent
- maintainability of the KDA test suite
while preserving meaningful correctness coverage.