[copilot generated] Gemma4 Per-Layer Embedding: Architecture Analysis & Alternatives
Problem Statement
Gemma4 VLMs use a 3-model split (decoder, vision_encoder, embedding) for ORT GenAI deployment. The standard VLM pipeline is:
embedding(input_ids, image_features) → inputs_embeds → decoder(inputs_embeds) → logits
But Gemma4 E2B and E4B have per-layer input embeddings: each decoder layer has its own embedding table that looks up the original input_ids to compute a per-layer conditioning signal. This means the decoder needs input_ids in addition to inputs_embeds.
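A minimal sketch of the two flows using the ORT Python API — the file names, input/output names, and shapes below are illustrative placeholders (not the actual exported names), and KV-cache/attention-mask inputs are omitted:

```python
# Illustrative only: hypothetical model/input names, no KV cache or masks.
import numpy as np
import onnxruntime as ort

embedding = ort.InferenceSession("embedding.onnx")
decoder = ort.InferenceSession("decoder.onnx")

input_ids = np.array([[2, 106, 2364]], dtype=np.int64)        # [B, S]
image_features = np.zeros((1, 256, 2048), dtype=np.float32)   # placeholder vision features

# Standard VLM decoder: only inputs_embeds crosses the model boundary.
(inputs_embeds,) = embedding.run(None, {"input_ids": input_ids,
                                        "image_features": image_features})
(logits,) = decoder.run(None, {"inputs_embeds": inputs_embeds})

# Gemma4 E2B/E4B: the decoder additionally needs the raw token IDs for its
# per-layer embedding lookups.
(logits,) = decoder.run(None, {"inputs_embeds": inputs_embeds,
                               "input_ids": input_ids})
```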
Which models are affected?
| Model | hidden_size_per_layer_input | Affected? |
|---|---|---|
| Gemma4 E2B (2B) | 256 | ✅ Yes — 35 layers × 262K vocab × 256 dim |
| Gemma4 E4B (4B) | 256 | ✅ Yes — 42 layers × 262K vocab × 256 dim |
| Gemma4 26B-A4B | 0 | ❌ No |
| Gemma4 31B | 0 | ❌ No |
Per-layer input mechanism (current implementation)
In Gemma4TextModel._compute_per_layer_inputs() (gemma4.py:1424-1481):
- Model projection: inputs_embeds → Linear → [B, S, L*D] (project to all layers at once)
- Scale: multiply by hidden_size^(-0.5)
- Reshape + norm: [B, S, L, D] → per-layer norm
- Per-layer token embedding: For each layer i, look up input_ids in embed_tokens_per_layer[i] (each table is [262K, 256])
- Combine: proj_i + token_emb_i, scale by 0.5^0.5
Then in each decoder layer, the per-layer input is consumed via a gated addition:
gated = act(gate(hidden_states)) * per_layer_input
hidden_states += norm(project(gated))
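A condensed PyTorch sketch of this mechanism — module and attribute names (PerLayerInputs, projection, embed_tokens_per_layer) are illustrative and not the exact gemma4.py code:

```python
import torch
import torch.nn as nn

class PerLayerInputs(nn.Module):
    """Sketch of _compute_per_layer_inputs with the current per-layer split."""
    def __init__(self, hidden_size=2560, num_layers=42, per_layer_dim=256, vocab=262_144):
        super().__init__()
        self.H, self.L, self.D = hidden_size, num_layers, per_layer_dim
        self.projection = nn.Linear(hidden_size, num_layers * per_layer_dim, bias=False)
        self.norm = nn.RMSNorm(per_layer_dim)   # stand-in for the model's per-layer norm (PyTorch >= 2.4)
        # Current mobius approach: L separate [V, D] tables (per-layer split).
        self.embed_tokens_per_layer = nn.ModuleList(
            nn.Embedding(vocab, per_layer_dim) for _ in range(num_layers)
        )

    def forward(self, input_ids, inputs_embeds):
        B, S, _ = inputs_embeds.shape
        proj = self.projection(inputs_embeds) * self.H ** -0.5       # [B, S, L*D], scaled
        proj = self.norm(proj.view(B, S, self.L, self.D))            # [B, S, L, D]
        tok = torch.stack([e(input_ids) for e in self.embed_tokens_per_layer], dim=2)
        return (proj + tok) * 0.5 ** 0.5                             # combine + rescale
```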
Current mobius approach
The current implementation passes both inputs_embeds AND input_ids to the decoder ONNX model:
- inputs_embeds — from the embedding model (with vision features scattered)
- input_ids — the raw token IDs for per-layer embedding lookup
This works architecturally but requires GenAI runtime to pass input_ids to the decoder, which is non-standard for VLM decoders (they normally only receive inputs_embeds).
Alternative Analysis
Alternative A: Pre-compute ALL per-layer embeddings in the embedding model
Approach: Move _compute_per_layer_inputs() and all 35-42 embed_tokens_per_layer tables into the embedding model. Output: [num_layers, B, S, per_layer_dim] as an additional output.
Memory cost: For E4B with seq_len=2048 (see the arithmetic check after this list):
- Per call: 42 × 1 × 2048 × 256 × 2 bytes (bf16) ≈ 44 MB — transmitted between models
- Embedding model weights: adds 42 × 262144 × 256 × 2 bytes ≈ 5.6 GB to the embedding model (E4B)
- Plus the projection Linear: 2560 → 42*256 = 10752 → ~27.5M params → ~55 MB
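A quick back-of-envelope check of these figures (pure arithmetic using the E4B values quoted above):

```python
# E4B: 42 layers, 262144 vocab, per_layer_dim 256, hidden_size 2560,
# seq_len 2048, bf16 = 2 bytes.
L, V, D, H, S, BYTES = 42, 262_144, 256, 2_560, 2_048, 2

per_call   = L * 1 * S * D * BYTES      # per-layer embeddings activation, per forward
tables     = L * V * D * BYTES          # per-layer embedding table weights
projection = H * (L * D) * BYTES        # Linear(2560 -> 10752) weights

print(f"{per_call / 1e6:.0f} MB, {tables / 1e9:.1f} GB, {projection / 1e6:.0f} MB")
# -> 44 MB, 5.6 GB, 55 MB
```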
Pros:
- Decoder becomes standard VLM decoder (inputs_embeds only) — no GenAI runtime changes needed
- Clean separation: embedding model owns all embedding logic
Cons:
- Large inter-model tensor: 44 MB per forward pass for the per-layer embeddings output (num_layers × B × S × per_layer_dim in bf16). This is transmitted from the embedding model to the decoder every step.
- Model balance shift: Per-layer embedding tables (5.6 GB for E4B, 2.8 GB for E2B) move from decoder to embedding model. The embedding model becomes disproportionately large while the decoder shrinks. This affects memory planning — the embedding model runs once (prefill) while the decoder runs every token.
- Embedding model complexity: The embedding model grows from a simple "lookup + scatter" to also computing per-layer projections and N embedding lookups. Adds _compute_per_layer_inputs logic plus N+1 output tensors.
- Weight splitting changes: The preprocess_weights logic for decoder and embedding models must be updated to route per-layer embedding tables to the embedding model instead of the decoder.
Note: Total model weight is unchanged — the per-layer tables are MOVED from decoder to embedding, not duplicated.
Feasibility: MEDIUM — semantically correct (the projection operates on inputs_embeds at the start of forward(), before any transformer layers, so it CAN be computed in the embedding model). But the model balance shift and inter-model tensor cost make it less attractive than Alternatives B and D, which keep the tables in the decoder.
Alternative B: Keep input_ids flowing to decoder (previous recommendation)
Approach: The decoder takes both inputs_embeds and input_ids. The embedding model just handles the standard input_ids → embed_tokens + scatter_vision_features → inputs_embeds pipeline. Per-layer computation stays in the decoder, using L separate [V, D] embedding tables (split from HF's fused [V, L*D] table to work around ORT Gather int32 overflow).
Note: This was the recommended approach before the ORT Gather int32 fix (onnxruntime#28107) was merged. It is now superseded by Alternative D, which uses the fused table directly. Alternative B remains a valid fallback for ORT versions < 1.25.
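For reference, a sketch of the weight split Alternative B implies in preprocess_weights — the key names are illustrative (not the exact mobius checkpoint keys), and the [V, L, D] view assumes the HF fused layout described above:

```python
import torch

def split_fused_per_layer_table(fused: torch.Tensor, num_layers: int, per_layer_dim: int):
    """Slice HF's fused [V, L*D] per-layer embedding into L separate [V, D] tables."""
    vocab = fused.shape[0]
    fused = fused.view(vocab, num_layers, per_layer_dim)    # [V, L, D]
    return {
        f"embed_tokens_per_layer.{i}.weight": fused[:, i, :].contiguous()
        for i in range(num_layers)
    }
```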
GenAI runtime impact: The decoder's ONNX graph has input_ids as a graph input. _introspect_inputs() in auto_export.py discovers this automatically and includes it in genai_config.json. The GenAI runtime must:
- Know to feed input_ids to the decoder alongside inputs_embeds
- Feed the SAME input_ids that were fed to the embedding model
Pros:
- All per-layer embedding tables stay in the decoder — no model rebalancing needed
- Zero additional inter-model tensor overhead
- Minimal code complexity — _compute_per_layer_inputs is already implemented
- Works for all Gemma4 variants (per_layer_dim=0 models just ignore input_ids)
Cons:
- Non-standard decoder interface: VLM decoders normally take only inputs_embeds, not both inputs_embeds AND input_ids. This requires GenAI runtime awareness.
- Runtime coupling: GenAI runtime must propagate input_ids to the decoder, which it doesn't do today for VLMs (DecoderOnly_Model handles this for text-only, but the MultiModal pipeline doesn't)
GenAI runtime change needed: The MultiModal pipeline's decoder session needs to bind input_ids as an additional input. This is a config-driven change — the introspected decoder inputs already include input_ids, so the runtime just needs to respect all declared inputs rather than assuming VLM decoders only take inputs_embeds.
Feasibility: HIGH — this is the current implementation. The only gap is GenAI runtime support for the extra decoder input.
Alternative C: Shared embedding with per-layer projection (no per-layer vocab)
Approach: Instead of N separate embedding tables (one per layer), use a single shared embedding table and N small linear projections to create per-layer variants.
Analysis: This would require changing the model architecture itself — Gemma4 as designed by Google uses per-layer embedding tables. We can't change the architecture and still load HuggingFace weights. Not applicable.
Feasibility: NONE — incompatible with pretrained weights.
Alternative D: Fuse per-layer embeddings into the decoder's embed_tokens
Approach: Pre-compute the per-layer token embeddings as a single fused lookup table, [V, L*D], and index it once per forward pass. This is a computation optimization within the decoder, not an architectural change.
Analysis: This is actually what the original HF implementation does (one big [V, L*D] embedding), but the mobius implementation intentionally splits it into per-layer [V, D] tables to avoid ORT CUDA Gather int32 overflow (onnxruntime#28107). From gemma4.py:1400-1403:
```python
# Use per-layer embedding tables instead of one giant [V, L*D] table.
# Each [V, D] table has only V*D elements (e.g. 262144*256 = 67M),
# well under the ORT CUDA Gather int32 limit (~2.1B).
```
The fused table for E4B would be 262144 × 42 × 256 = 2.8B elements, which previously exceeded the ORT CUDA Gather int32 limit (~2.1B). However, this has been fixed in onnxruntime#28107, which was merged and is available in ORT ≥ 1.25.
With the ORT fix, the fused [V, L*D] approach works correctly. It is actually simpler than the per-layer split (see the sketch after this list):
- 1 Gather on [V, L*D] instead of L Gathers on [V, D]
- Matches the HuggingFace implementation directly (one nn.Embedding(V, L*D))
- Simpler preprocess_weights — no need to split the fused HF weight into per-layer tables
- Fewer ONNX ops in the graph
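A sketch of the fused lookup (illustrative names; assumes ORT ≥ 1.25 at export and inference so the single large Gather is safe):

```python
import torch
import torch.nn as nn

V, L, D = 262_144, 42, 256                  # E4B values
assert V * L * D == 2_818_572_288           # > 2**31 - 1: needs onnxruntime#28107

def per_layer_token_embeddings(embed_tokens_per_layer: nn.Embedding,
                               input_ids: torch.Tensor) -> torch.Tensor:
    """One Gather on the fused [V, L*D] table + one reshape, replacing the
    L separate Gathers of the per-layer split."""
    B, S = input_ids.shape
    return embed_tokens_per_layer(input_ids).view(B, S, L, D)

# The fused table matches the HF layout directly — nn.Embedding(V, L * D),
# ~2.8B parameters (~5.6 GB in bf16) — so preprocess_weights can pass the HF
# weight through unchanged.
```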
GenAI runtime impact: Same as Alternative B — the decoder still needs input_ids as an input. The fused table is an internal optimization within the decoder graph, not a model boundary change.
Feasibility: HIGH — ORT fix merged (onnxruntime#28107). This is the simplest and most efficient approach.
Alternative E: Lazy per-layer embedding via ONNX subgraph
Approach: Instead of computing all per-layer embeddings upfront, compute each layer's embedding inside the decoder layer itself using an ONNX Scan or Loop. The decoder still takes input_ids.
Analysis: This is essentially what the current implementation does — _compute_per_layer_inputs runs in the text model's forward(), which is inside the decoder ONNX graph. Moving it into individual layers wouldn't change the model boundary.
Feasibility: SAME AS B — doesn't change the fundamental issue of needing input_ids in the decoder.
Recommendation
Alternative D is the recommended approach (fused [V, L*D] table).
Note: Alternative B was the previous recommendation when the ORT Gather int32 overflow bug (onnxruntime#28107) was still open. The per-layer split (L separate [V, D] tables) was a workaround for that bug. Now that the fix is merged (ORT ≥ 1.25), Alternative D supersedes B as the cleaner solution.
Rationale:
- Simpler code: A single nn.Embedding(V, L*D) replaces L separate nn.Embedding(V, D) tables. One Gather op instead of L Gather ops in the ONNX graph. The _compute_per_layer_inputs logic simplifies from a loop of per-layer lookups to a single lookup + reshape.
- Matches HuggingFace implementation: The original HF Gemma4 uses one fused embed_tokens_per_layer table of shape [V, L*D]. Alternative D matches this directly, making weight loading trivial — no need to split or reshape the HF weight in preprocess_weights.
- Fewer ONNX ops: For E4B (42 layers), Alternative B requires 42 Gather ops + 42 intermediate tensors. Alternative D uses 1 Gather + 1 Reshape — simpler graph, better optimization potential, fewer kernel launches on CUDA.
- Same GenAI runtime requirement as B: The decoder still needs input_ids as an input alongside inputs_embeds. The runtime change (~20 lines to propagate input_ids to the VLM decoder) is identical to Alternative B.
- Zero overhead: No model rebalancing or balance shift, and no extra inter-model tensors. Per-layer tables stay in the decoder.
Required GenAI runtime work
Same as Alternative B — the GenAI runtime MultiModal pipeline needs to:
- Detect that the decoder declares input_ids as an input (from genai_config.json decoder inputs)
- Propagate the same input_ids that were fed to the embedding model to the decoder session
- This is likely a ~20-line change in the pipeline orchestration code
Why D over B?
Both B and D require the same GenAI runtime change. The difference is internal to the decoder ONNX graph:
| Aspect | Alternative B (per-layer split) | Alternative D (fused table) |
|---|---|---|
| Embedding tables | L × [V, D] (split) | 1 × [V, L*D] (fused) |
| Gather ops | L (one per layer) | 1 (single lookup) |
| preprocess_weights | Must split HF's fused [V, L*D] → L tables | Direct — HF weight used as-is |
| Code complexity | Loop over per-layer tables | Single lookup + reshape |
| HF parity | Transformed from HF layout | Matches HF exactly |
| ORT requirement | Any ORT version | ORT ≥ 1.25 (Gather int32 fix) |
Why not Alternative A?
Alternative A is technically feasible — the embedding model already produces inputs_embeds (with vision features scattered), and could also compute per_layer_model_projection(inputs_embeds) + per-layer token lookups. The total model weight is the same (tables MOVE from decoder to embedding, not duplicated).
However, the practical downsides outweigh the benefit of a "standard" decoder interface:
- Model balance shift: E4B's embedding model grows by 5.6 GB of per-layer tables. The embedding runs once (prefill only), while the decoder runs every token — moving weights from decoder to embedding increases peak memory during prefill without reducing per-token decode cost.
- 44 MB inter-model tensor: [num_layers × B × S × per_layer_dim] must be passed from embedding to decoder every call.
- Embedding model complexity: Goes from ~5 ops (lookup + scatter) to computing N projections + N embedding lookups + output packing.
- All this to avoid a ~20-line GenAI runtime change that simply propagates input_ids to the decoder.
Summary
| Alternative | Feasibility | Memory Overhead | Runtime Change | Complexity |
|---|---|---|---|---|
| A: Pre-compute in embedding | MEDIUM | +44 MB/call (weights move, not duplicated) | None | HIGH |
| B: input_ids to decoder (per-layer split) | HIGH | None | ~20 lines | LOW |
| C: Shared embedding + projections | NONE | N/A | N/A | N/A |
| D: Fused [V, L*D] table ⭐ | HIGH (ORT fix merged) | None | ~20 lines (same as B) | LOWEST |
| E: Per-layer lazy compute | SAME AS B | None | Same as B | SAME |
Recommendation: Alternative D (fused [V, L*D] table). Simplest code, fewest ONNX ops, matches HF implementation directly. Requires ORT ≥ 1.25 and the same ~20-line GenAI runtime change as Alternative B.