
Implement Grouped-Query Attention (GQA)

Last updated: Jul 1, 2026

Quick Overview

This question evaluates a candidate's understanding of transformer attention mechanisms, specifically how grouped-query attention balances the memory efficiency of multi-query attention with the quality of full multi-head attention. It tests practical implementation skill with tensor reshaping, head grouping, and causal masking, a common way to probe machine learning engineering depth in system-level model design interviews.


Company: Datadog

Role: Machine Learning Engineer

Category: Machine Learning

Difficulty: hard

Interview Round: Technical Screen

# Implement Grouped-Query Attention (GQA)

Modern decoder-only transformers spend most of their autoregressive-decoding time, and much of their memory budget, on the **key/value (KV) cache**, which has to be read from GPU memory at every decoding step. Multi-Head Attention (MHA) keeps one K head and one V head per query head, which is expensive; Multi-Query Attention (MQA) shares a single K/V head across *all* query heads, which is cheap but can hurt quality. **Grouped-Query Attention (GQA)** is the middle ground: the query heads are partitioned into $G$ groups, and all query heads within a group share one K/V head.

Implement the forward pass of a GQA module from scratch using only basic tensor ops (linear layers + matmul + softmax; no high-level attention helper). Given:

- input hidden states $x$ of shape $(B, T, d_{\text{model}})$,
- `num_query_heads` $H$ and `num_kv_heads` $G$, with $H$ divisible by $G$,
- per-head dimension $d_{\text{head}} = d_{\text{model}} / H$,

your module should:

1. Project $x$ to queries $Q$ ($H$ heads) and keys/values $K, V$ ($G$ heads each).
2. Reshape into heads and **share each KV head across $H/G$ query heads**.
3. Compute **causal** scaled dot-product attention.
4. Concatenate heads and apply the output projection back to $d_{\text{model}}$.

```hint The crux: sharing KV heads
After projecting, $Q$ has shape $(B, H, T, d_{\text{head}})$ while $K, V$ only have shape $(B, G, T, d_{\text{head}})$, so every query head needs to read from one of the $G$ KV heads. Before you can run ordinary multi-head attention, you have to bring $K, V$ up to $H$ heads in a way that preserves which group each query head belongs to. Getting the head ordering wrong silently scrambles which queries attend to which keys without raising any error.
```

```hint Projection sizes are asymmetric
The Q projection is $d_{\text{model}} \to H \cdot d_{\text{head}}\ (= d_{\text{model}})$, but the K and V projections are $d_{\text{model}} \to G \cdot d_{\text{head}}$, which is **smaller** whenever $G < H$. That asymmetry is the whole point: fewer KV parameters and a smaller KV cache at decode time.
```

```hint Attention core
Once $Q$, $K$, $V$ all have $H$ heads, this is ordinary causal scaled dot-product attention. Make sure the causal mask is applied **before** the softmax, not after, and double-check how a $T \times T$ mask broadcasts across the batch and head axes.
```

### Constraints & Assumptions

- Shapes: $x$ is $(B, T, d_{\text{model}})$; the output is $(B, T, d_{\text{model}})$.
- $H \bmod G = 0$ (e.g. $H=8,\ G=2$ → each KV head is shared by 4 query heads). $G = H$ is MHA; $G = 1$ is MQA.
- $d_{\text{model}} \bmod H = 0$; $d_{\text{head}} = d_{\text{model}} / H$ (an integer).
- Scale scores by $1/\sqrt{d_{\text{head}}}$.
- Decoder **self-attention**: apply a causal mask so position $t$ attends only to positions $\le t$.
- Single-precision floats; the mask fill should use a large negative value (or $-\infty$) compatible with softmax.
- Dropout and rotary embeddings (RoPE) may be ignored in the core implementation (mention them as extensions).

### Clarifying Questions to Ask

- Is this causal self-attention in a decoder, or cross-attention? I will assume causal self-attention.
- Should I implement only the full-sequence forward (training / prefill), or also the **incremental single-token decode** with a KV cache?
- Are positional encodings (e.g. RoPE) applied to $Q/K$ inside the module, or handled outside?
- Besides the causal mask, is there a padding mask for variable-length sequences in the batch?
- Do the Q/K/V/O linear projections include bias terms?

### Follow-up Questions

- During autoregressive decoding with a KV cache, how much memory does GQA save versus MHA, and why is memory **bandwidth** (not FLOPs) the decode-time bottleneck?
- How would you "uptrain" an existing MHA checkpoint into GQA, and how do you initialize the $G$ KV heads from the original $H$?
- Where do rotary position embeddings (RoPE) get applied, and does GQA change that?
- Implement the **incremental decode step**: given cached $K, V$ of length $t$ and one new token, produce the next output and update the cache. Which shapes change?
- How does FlashAttention interact with GQA, and what changes about the memory-access pattern?
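The head-ordering pitfall in the first hint can be made concrete with a toy tensor. The sketch below uses NumPy as a stand-in for whatever tensor library the interview allows, and it assumes the common convention that query head `h` reads KV head `h // (H // G)`; the problem statement does not fix that convention, so treat it as an assumption. The variable names are illustrative only.

```python
import numpy as np

H, G = 8, 2                 # query heads, KV heads (the example from the constraints)
group_size = H // G         # 4 query heads share each KV head

# Stand-in for K with shape (B=1, G, T=1, d_head=1): each value is its KV head index.
k = np.arange(G, dtype=np.float32).reshape(1, G, 1, 1)

# Assumed convention: query head h reads KV head h // group_size.
# np.repeat duplicates each KV head in place along the head axis.
k_grouped = np.repeat(k, group_size, axis=1)

# np.tile cycles through the KV heads instead. The shape is identical and no
# error is raised, but e.g. query head 1 would now read KV head 1 instead of 0.
k_scrambled = np.tile(k, (1, group_size, 1, 1))

print(k_grouped[0, :, 0, 0].astype(int).tolist())    # [0, 0, 0, 0, 1, 1, 1, 1]
print(k_scrambled[0, :, 0, 0].astype(int).tolist())  # [0, 1, 0, 1, 0, 1, 0, 1]
```

In PyTorch the analogous pair would be `torch.repeat_interleave` (grouped order) versus `Tensor.repeat` (cycled order). Either ordering can be correct as long as it matches how the query heads are assigned to groups; the bug arises when the two disagree.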
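One possible shape for the four steps in the task statement is sketched below. It is an illustration under stated assumptions, not a reference solution: NumPy replaces the framework's linear layers with plain weight matrices, the projections are bias-free, the weights are randomly initialized, and query head `h` is assumed to use KV head `h // (H // G)`. The class name `GroupedQueryAttention` and the attributes `w_q`, `w_k`, `w_v`, `w_o` are hypothetical.

```python
import math
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # subtract the row max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

class GroupedQueryAttention:
    """Illustrative GQA forward pass (bias-free projections are an assumption)."""

    def __init__(self, d_model, num_query_heads, num_kv_heads, seed=0):
        assert d_model % num_query_heads == 0
        assert num_query_heads % num_kv_heads == 0
        self.H, self.G = num_query_heads, num_kv_heads
        self.d_head = d_model // num_query_heads
        rng = np.random.default_rng(seed)

        def init(n_in, n_out):
            w = rng.standard_normal((n_in, n_out)) / math.sqrt(n_in)
            return w.astype(np.float32)

        self.w_q = init(d_model, self.H * self.d_head)   # full width
        self.w_k = init(d_model, self.G * self.d_head)   # narrower when G < H
        self.w_v = init(d_model, self.G * self.d_head)
        self.w_o = init(self.H * self.d_head, d_model)

    def forward(self, x):
        B, T, _ = x.shape
        H, G, d = self.H, self.G, self.d_head
        # 1. Project and split into heads: (B, T, heads * d) -> (B, heads, T, d).
        q = (x @ self.w_q).reshape(B, T, H, d).transpose(0, 2, 1, 3)
        k = (x @ self.w_k).reshape(B, T, G, d).transpose(0, 2, 1, 3)
        v = (x @ self.w_v).reshape(B, T, G, d).transpose(0, 2, 1, 3)
        # 2. Share each KV head across H // G consecutive query heads.
        k = np.repeat(k, H // G, axis=1)
        v = np.repeat(v, H // G, axis=1)
        # 3. Causal scaled dot-product attention; the mask goes in before softmax.
        scores = (q @ k.transpose(0, 1, 3, 2)) * (1.0 / math.sqrt(d))  # (B, H, T, T)
        future = np.triu(np.ones((T, T), dtype=bool), k=1)  # True above the diagonal
        scores = np.where(future, -np.inf, scores)           # broadcasts over B and H
        out = softmax(scores, axis=-1) @ v                    # (B, H, T, d)
        # 4. Concatenate heads and project back to d_model.
        out = out.transpose(0, 2, 1, 3).reshape(B, T, H * d)
        return out @ self.w_o

x = np.random.default_rng(1).standard_normal((2, 5, 32)).astype(np.float32)
gqa = GroupedQueryAttention(d_model=32, num_query_heads=8, num_kv_heads=2)
y = gqa.forward(x)
print(y.shape)  # (2, 5, 32)
```

Materializing the repeated K and V keeps the attention core identical to ordinary multi-head attention, which makes the sketch easy to check. An alternative is to reshape Q to `(B, G, H // G, T, d_head)` and broadcast against K and V, which avoids the copy. Setting `num_kv_heads` equal to `num_query_heads` gives MHA and setting it to 1 gives MQA, as the constraints note.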
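For the first follow-up question, the size of the KV cache can be estimated with simple arithmetic: the cache stores one K and one V tensor per layer, each of shape $(B, T, \text{kv heads}, d_{\text{head}})$, so GQA shrinks it by a factor of $H/G$ relative to MHA. The sketch below works through one case. The layer count, sequence length, and head dimension are hypothetical values chosen for illustration, and `kv_cache_bytes` is an illustrative helper, not part of any library.

```python
def kv_cache_bytes(num_layers, batch, seq_len, num_kv_heads, d_head, bytes_per_elem=4):
    """Bytes needed to cache K and V for every layer (4 bytes = single precision)."""
    return 2 * num_layers * batch * seq_len * num_kv_heads * d_head * bytes_per_elem

# Hypothetical model: 32 layers, d_head = 128, context of 4096 tokens, batch of 1.
H, G = 8, 2   # the example from the constraints
mha = kv_cache_bytes(num_layers=32, batch=1, seq_len=4096, num_kv_heads=H, d_head=128)
gqa = kv_cache_bytes(num_layers=32, batch=1, seq_len=4096, num_kv_heads=G, d_head=128)

print(mha / 2**30)   # 1.0  (GiB with one KV head per query head)
print(gqa / 2**30)   # 0.25 (GiB with G = 2 KV heads)
print(mha // gqa)    # 4, i.e. H // G
```

The same factor applies to the bytes that must be read from the cache at each decode step, since every step reads all cached keys and values.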

Jun 25, 2026, 12:00 AM
