Feature request: Gemma 4 model type support (gemma4, gemma4_text, gemma4_any_to_any) #2098

@justinchuby

Description


Summary

Request runtime support for Google Gemma 4 models in ORT GenAI. The ONNX models are built by mobius and the preprocessing kernels are in flight in onnxruntime-extensions#1052. The missing piece is model type registration and runtime handling in ORT GenAI itself.

Model type: gemma4

A single gemma4 model type with the runtime detecting the pipeline variant based on which ONNX files are present:

| Variant | ONNX files | Detection |
| --- | --- | --- |
| Text-only | model.onnx | No vision/speech/embedding files |
| Vision-language (VLM) | model.onnx + vision.onnx + embedding.onnx | Has vision + embedding |
| Multimodal (any-to-any) | model.onnx + vision.onnx + speech.onnx + embedding.onnx | Has vision + speech + embedding |

Reference genai_config.json files: examples/gemma4/ort_genai/
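The detection rule in the table above can be sketched as follows. This is an illustrative helper, not ORT GenAI API; the function and variant names are hypothetical.

```python
import os

# Hypothetical sketch of the proposed auto-detection: the pipeline variant is
# inferred from which ONNX files sit next to model.onnx in the model directory.
def detect_gemma4_variant(model_dir):
    def has(name):
        return os.path.exists(os.path.join(model_dir, name))

    vision, speech, embedding = has("vision.onnx"), has("speech.onnx"), has("embedding.onnx")
    if vision and speech and embedding:
        return "gemma4_any_to_any"   # multimodal
    if vision and embedding:
        return "gemma4_vlm"          # vision-language
    return "gemma4_text"             # text-only
```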

Architecture details requiring special handling

1. KV cache sharing (15 layers, not 35)

Gemma 4 E2B has 35 decoder layers, but the last 20 layers reuse K,V projections from earlier layers (num_kv_shared_layers = 20). Only the first 15 layers have independent KV caches.

  • num_hidden_layers in genai_config is 15 (KV cache depth)
  • past_key_values.{0..14}.key/value are the only KV inputs/outputs
  • The ONNX model internally routes shared K,V via present.N.key/value → past_key_values.M.key/value connections

Note: Related issue #2097 discusses the interaction between KV sharing and GroupQueryAttention for shared layers.
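The resulting KV cache binding surface can be sketched as below: with num_kv_shared_layers = 20, only the first 15 of the 35 decoder layers expose past/present tensors. The helper name is illustrative.

```python
# Sketch of the KV cache I/O naming implied above; the 20 shared layers are
# wired inside the ONNX graph and never appear as session inputs/outputs.
NUM_DECODER_LAYERS = 35
NUM_KV_SHARED_LAYERS = 20
NUM_KV_LAYERS = NUM_DECODER_LAYERS - NUM_KV_SHARED_LAYERS  # 15, the KV cache depth

def kv_io_names(prefix="past_key_values"):
    """Enumerate KV cache input names: past_key_values.{0..14}.key/value."""
    return [f"{prefix}.{i}.{kv}"
            for i in range(NUM_KV_LAYERS)
            for kv in ("key", "value")]
```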

2. Dual head_dim (sliding vs full attention)

Sliding-window (local) layers use head_dim = 256; full-attention (global) layers use global_head_dim = 512. The genai_config sets head_size = 256 (local value). The runtime must handle the fact that KV cache entries have different shapes depending on layer type:

  • Sliding layers (indices 0-3, 5-8, 10-13): KV shape [B, 1, S, 256]
  • Full layers (indices 4, 9, 14): KV shape [B, 1, S, 512]
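A minimal sketch of the per-layer shape rule, using the indices listed above (full-attention layers fall on every 5th index, matching the 5:1 local:global pattern); the function name is illustrative.

```python
# Dual head_dim rule: KV cache shape depends on whether a layer uses sliding
# (local) or full (global) attention. Constants taken from the issue.
FULL_ATTENTION_LAYERS = {4, 9, 14}
LOCAL_HEAD_DIM, GLOBAL_HEAD_DIM = 256, 512

def kv_cache_shape(layer_idx, batch, seq_len, num_kv_heads=1):
    """Return [B, num_kv_heads, S, head_dim], choosing head_dim per layer type."""
    head_dim = GLOBAL_HEAD_DIM if layer_idx in FULL_ATTENTION_LAYERS else LOCAL_HEAD_DIM
    return (batch, num_kv_heads, seq_len, head_dim)
```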

3. Sliding window attention (per-layer)

The 5:1 local:global pattern means sliding window applies to specific layers only. The sliding_window.layers field in genai_config lists which KV cache indices use sliding window:

```json
"sliding_window": {
    "window_size": 512,
    "layers": [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12, 13]
}
```
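One way a runtime could interpret this config is sketched below: sliding-window layers see only the last window_size key positions under the causal mask, while global layers see everything. The helper is illustrative, not ORT GenAI internals.

```python
# Sketch of per-layer sliding window semantics from the config above.
SLIDING_CONFIG = {
    "window_size": 512,
    "layers": [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12, 13],
}

def attendable_range(layer_idx, query_pos, cfg=SLIDING_CONFIG):
    """Inclusive [start, end] range of key positions visible to a causal
    query at query_pos for the given KV cache layer index."""
    if layer_idx in cfg["layers"]:
        return (max(0, query_pos - cfg["window_size"] + 1), query_pos)
    return (0, query_pos)  # full-attention (global) layer
```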

4. Vision encoder inputs

Unlike other VLMs that take raw pixel tensors, Gemma 4 vision takes pre-patchified inputs:

  • pixel_values [batch, num_patches, 3 * 16 * 16] — flattened patch pixels
  • pixel_position_ids [batch, num_patches, 2] — (row, col) patch coordinates

The preprocessing is handled by Gemma4ImageTransform in onnxruntime-extensions#1052.
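For intuition, a NumPy sketch of the pre-patchified layout is below. This approximates what the real Gemma4ImageTransform kernel produces; the exact flattening order of the 3 * 16 * 16 patch pixels (channel-first vs channel-last) is an assumption here, shown channel-last.

```python
import numpy as np

PATCH = 16  # patch side length implied by the 3 * 16 * 16 input dim

def patchify(image):
    """image: [H, W, 3] with H, W divisible by 16. Returns
    pixel_values [1, num_patches, 768] and pixel_position_ids
    [1, num_patches, 2] holding (row, col) patch coordinates.
    Channel ordering within each flattened patch is an assumption."""
    h, w, c = image.shape
    rows, cols = h // PATCH, w // PATCH
    # [rows, PATCH, cols, PATCH, c] -> [rows, cols, PATCH, PATCH, c]
    patches = image.reshape(rows, PATCH, cols, PATCH, c).transpose(0, 2, 1, 3, 4)
    pixel_values = patches.reshape(1, rows * cols, c * PATCH * PATCH)
    rr, cc = np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij")
    pixel_position_ids = np.stack([rr, cc], axis=-1).reshape(1, rows * cols, 2)
    return pixel_values, pixel_position_ids
```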

5. Audio encoder (multimodal variant only)

The speech.onnx model takes input_features [batch, time, 128] (128-dim log-mel spectrogram) and outputs audio_features [batch, time/4, 1536]. Audio token positions in input_ids are marked by audio_token_id = 258881.

The preprocessing is handled by Gemma4LogMel in onnxruntime-extensions#1052.
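Locating where the encoder's audio_features must be substituted into the token stream can be sketched as follows; the token id comes from the config table, the helper name is illustrative.

```python
AUDIO_TOKEN_ID = 258881  # marks audio positions in input_ids

def audio_token_positions(input_ids):
    """Return the indices in input_ids occupied by the audio placeholder token."""
    return [i for i, tok in enumerate(input_ids) if tok == AUDIO_TOKEN_ID]
```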

Key config values (google/gemma-4-E2B-it)

| Field | Value |
| --- | --- |
| vocab_size | 262,144 |
| hidden_size | 1,536 |
| num_attention_heads | 8 |
| num_key_value_heads | 1 |
| head_size (local) | 256 |
| global_head_dim (full) | 512 |
| num_hidden_layers (KV depth) | 15 |
| sliding_window | 512 |
| bos_token_id | 2 |
| eos_token_id | [1, 106] |
| image_token_id | 255,999 |
| audio_token_id | 258,881 |
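A sketch of how a subset of these values might map into genai_config.json, following the nesting used by existing ORT GenAI configs (token ids under model, decoder fields under model.decoder); the exact layout for gemma4 is defined by the reference configs in examples/gemma4/ort_genai/, so treat this fragment as illustrative only:

```json
{
  "model": {
    "type": "gemma4",
    "bos_token_id": 2,
    "eos_token_id": [1, 106],
    "vocab_size": 262144,
    "decoder": {
      "head_size": 256,
      "hidden_size": 1536,
      "num_attention_heads": 8,
      "num_hidden_layers": 15,
      "num_key_value_heads": 1,
      "sliding_window": {
        "window_size": 512,
        "layers": [0, 1, 2, 3, 5, 6, 7, 8, 10, 11, 12, 13]
      }
    }
  }
}
```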

Dependencies

  • onnxruntime-extensions#1052: Image + audio preprocessing kernels (in review)
  • onnxruntime/mobius PR#190: ONNX model export with GroupQueryAttention + sliding window
  • onnxruntime-genai#2097: KV-shared layer + GQA interaction (separate but related)

Implementation notes

The closest existing model types are gemma3 (VLM) and its text-only variant. The main new runtime work is:

  1. Register the gemma4 model type with auto-detection of pipeline variant
  2. Handle dual head_dim in KV cache allocation
  3. Handle per-layer sliding window (already partially supported by sliding_window.layers)
  4. Wire speech.onnx for the multimodal variant
  5. Map pixel_position_ids as a vision input (in addition to pixel_values)
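Items 1, 4, and 5 amount to a wiring table from detected variant to the extra ONNX sessions loaded alongside model.onnx; a hypothetical sketch (variant names and helper are illustrative, not ORT GenAI internals):

```python
# Which extra encoder files each pipeline variant needs next to model.onnx.
VARIANT_ENCODERS = {
    "gemma4_text": (),
    "gemma4_vlm": ("vision.onnx", "embedding.onnx"),
    "gemma4_any_to_any": ("vision.onnx", "speech.onnx", "embedding.onnx"),
}

def required_encoders(variant):
    """Return the encoder sessions to create for a detected variant."""
    return VARIANT_ENCODERS[variant]
```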
