Implement decoder-only GPT-style transformer
Company: Amazon
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: medium
Interview Round: Onsite
### Goal
Implement a simplified **decoder-only Transformer language model** (similar in spirit to GPT) for next-token prediction. The implementation should be modular, using four main classes:
1. `MultiHeadAttention`
2. `FeedForward`
3. `DecoderLayer`
4. `GPT` (the full model)
You may assume a deep learning framework (e.g., PyTorch or TensorFlow), but you should clearly specify tensor shapes and operations.
---
### Model details and requirements
Assume:
- Batch size: `B`
- Sequence length: `T`
- Embedding dimension: `d_model`
- Number of attention heads: `num_heads` (assume `d_model` is divisible by `num_heads`)
- Vocabulary size: `V`
#### 1. Multi-head self-attention (`MultiHeadAttention`)
Implement a class `MultiHeadAttention` that performs **masked self-attention**:
- Inputs:
  - `x`: Tensor of shape `(B, T, d_model)` (token embeddings or the output of the previous layer).
- Internal learnable parameters:
  - Linear projections that compute queries `Q`, keys `K`, and values `V` for all heads (in this section `V` denotes the value matrix, not the vocabulary size).
- Operations:
  1. Project `x` into `Q`, `K`, and `V`, each of shape `(B, T, d_model)`.
  2. Split each of them into `num_heads` heads of dimension `d_head = d_model / num_heads`, typically laid out as `(B, num_heads, T, d_head)`.
  3. Compute scaled dot-product attention for each head:
     \[
     \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_{head}}} + M\right)V,
     \]
     where `M` is an additive **causal mask** of shape `(T, T)`: `M[t, s] = 0` when `s <= t` and `-inf` when `s > t`, so that position `t` gives zero weight to future positions.
  4. Concatenate the outputs of all heads back to shape `(B, T, d_model)`.
  5. Apply a final linear projection from `d_model` to `d_model`, which keeps the shape `(B, T, d_model)`.
- Output:
  - Tensor of shape `(B, T, d_model)`.
Ensure you correctly implement the **causal (autoregressive) mask** so that position `t` can only attend to positions `0..t`.
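
The operations above can be sketched in PyTorch as follows. This is one possible layout under stated assumptions, not a reference solution: the fused `qkv` projection, the `max_len` argument, and the buffer name `causal_mask` are illustrative choices that do not appear in the problem statement. The mask is applied with `masked_fill`, which is equivalent to adding `M` before the softmax.

```python
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Masked (causal) multi-head self-attention over a (B, T, d_model) input."""

    def __init__(self, d_model: int, num_heads: int, max_len: int = 1024):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        # Assumption: one fused projection produces Q, K and V for all heads.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        # Lower-triangular boolean matrix: entry (t, s) is True when s <= t.
        mask = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
        self.register_buffer("causal_mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)  # each (B, T, d_model)
        # (B, T, d_model) -> (B, T, H, d_head) -> (B, H, T, d_head)
        q = q.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)  # (B, H, T, T)
        # The (T, T) mask broadcasts over the batch and head dimensions.
        scores = scores.masked_fill(~self.causal_mask[:T, :T], float("-inf"))
        weights = torch.softmax(scores, dim=-1)
        out = weights @ v  # (B, H, T, d_head)
        # Concatenate heads: (B, H, T, d_head) -> (B, T, d_model)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(out)


mha = MultiHeadAttention(d_model=16, num_heads=4, max_len=8)
y = mha(torch.randn(2, 5, 16))  # y has shape (2, 5, 16)
```

A useful sanity check for causality is to perturb the inputs at positions after `t` and confirm that the outputs at positions `0..t` do not change.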
#### 2. Position-wise feed-forward network (`FeedForward`)
Implement a class `FeedForward` for the position-wise MLP:
- Inputs:
  - `x`: Tensor of shape `(B, T, d_model)`.
- Internal structure (typical choice):
  - Linear layer from `d_model` to `d_ff` (e.g., `4 * d_model`).
  - Non-linear activation (e.g., GELU or ReLU).
  - Linear layer from `d_ff` back to `d_model`.
  - All operations are applied **independently** to each position in the sequence.
- Output:
  - Tensor of shape `(B, T, d_model)`.
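
A minimal PyTorch sketch of this MLP is shown below. The default `d_ff = 4 * d_model` and the choice of GELU follow the "typical choice" mentioned above; the attribute name `net` is an illustrative assumption.

```python
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    """Position-wise MLP: d_model -> d_ff -> d_model."""

    def __init__(self, d_model: int, d_ff=None):
        super().__init__()
        # Assumption: default hidden width of 4 * d_model when d_ff is not given.
        d_ff = d_ff if d_ff is not None else 4 * d_model
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last dimension only, so every (batch, position)
        # pair is transformed independently of the others.
        return self.net(x)


ff = FeedForward(d_model=8)
out = ff(torch.randn(2, 5, 8))  # out has shape (2, 5, 8)
```

Because `nn.Linear` only touches the last dimension, no explicit loop over positions is needed.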
#### 3. Decoder layer (`DecoderLayer`)
Implement a single Transformer **decoder block** that combines attention, feed-forward, residual connections, and layer normalization:
- Inputs:
  - `x`: Tensor of shape `(B, T, d_model)`.
- Components (applied in this order, i.e., a pre-LayerNorm block):
  1. LayerNorm `ln1` before self-attention.
  2. `MultiHeadAttention` block (masked self-attention), applied to `ln1(x)`.
  3. Residual connection: `x = x + attention_output`.
  4. LayerNorm `ln2` before feed-forward.
  5. `FeedForward` block, applied to `ln2(x)`.
  6. Residual connection: `x = x + ff_output`.
- Output:
  - Tensor of shape `(B, T, d_model)`.
Your `DecoderLayer` should be reusable so that multiple layers can be stacked.
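
The wiring of the block can be sketched as below. To keep the example self-contained, the attention and feed-forward sub-modules are passed in through the constructor, which is an illustrative assumption; in a full solution the layer would more likely build its own `MultiHeadAttention` and `FeedForward`. The stand-in sub-modules at the bottom are hypothetical and only exercise the shapes.

```python
import torch
import torch.nn as nn


class DecoderLayer(nn.Module):
    """Pre-LayerNorm decoder block: x + attn(ln1(x)), then x + ff(ln2(x))."""

    def __init__(self, d_model: int, attn: nn.Module, ff: nn.Module):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = attn  # expected to map (B, T, d_model) -> (B, T, d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = ff      # expected to map (B, T, d_model) -> (B, T, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln1(x))  # residual around masked self-attention
        x = x + self.ff(self.ln2(x))    # residual around the feed-forward MLP
        return x


# Hypothetical stand-ins for the real sub-modules, used only to check shapes.
layer = DecoderLayer(d_model=8, attn=nn.Identity(), ff=nn.Linear(8, 8))
out = layer(torch.randn(2, 5, 8))
print(out.shape)  # torch.Size([2, 5, 8])
```

Since input and output shapes match, any number of such layers can be chained, as long as each layer gets its own sub-module instances.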
#### 4. GPT model (`GPT`)
Implement a `GPT` class representing the full decoder-only Transformer:
- Inputs:
  - `input_ids`: Integer tensor of shape `(B, T)` representing token indices.
- Components:
  1. **Token embedding** layer mapping token IDs to vectors of size `d_model`.
  2. **Positional encodings** (learned or fixed) of shape `(T, d_model)`, added to the token embeddings and broadcast across the batch.
  3. A stack of `N` identical `DecoderLayer` blocks.
  4. A final linear layer mapping from `d_model` to vocabulary size `V`.
- Forward pass:
  1. Embed tokens and add positional encodings to obtain `(B, T, d_model)`.
  2. Pass through the `N` decoder layers sequentially.
  3. Apply the final linear layer to obtain logits of shape `(B, T, V)`.
- Output:
  - Logits for next-token prediction: `(B, T, V)`.
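
The forward pass above can be sketched as follows. Several details are assumptions made for illustration: learned positional embeddings, a `max_len` limit on sequence length, and a `make_layer` factory that returns a fresh block mapping `(B, T, d_model)` to `(B, T, d_model)`. The factory keeps the example self-contained; `nn.Identity` is used at the bottom as a hypothetical stand-in for a real decoder block.

```python
from typing import Callable

import torch
import torch.nn as nn


class GPT(nn.Module):
    """Decoder-only language model: embeddings -> N blocks -> vocabulary logits."""

    def __init__(self, vocab_size: int, d_model: int, num_layers: int,
                 max_len: int, make_layer: Callable[[], nn.Module]):
        super().__init__()
        self.max_len = max_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # Assumption: learned positional embeddings, one vector per position.
        self.pos_emb = nn.Embedding(max_len, d_model)
        # Call the factory once per layer so that no parameters are shared.
        self.layers = nn.ModuleList([make_layer() for _ in range(num_layers)])
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        B, T = input_ids.shape
        if T > self.max_len:
            raise ValueError(f"sequence length {T} exceeds max_len {self.max_len}")
        pos = torch.arange(T, device=input_ids.device)
        # (B, T, d_model) + (T, d_model): positions broadcast over the batch.
        x = self.tok_emb(input_ids) + self.pos_emb(pos)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(x)  # (B, T, V)


# nn.Identity is a hypothetical stand-in for a real decoder block.
model = GPT(vocab_size=50, d_model=16, num_layers=2, max_len=12,
            make_layer=nn.Identity)
logits = model(torch.randint(0, 50, (3, 7)))
print(logits.shape)  # torch.Size([3, 7, 50])
```

Pre-LayerNorm models such as GPT-2 also apply one more LayerNorm just before the output projection. The problem statement does not list it, so it is omitted here.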
#### Additional notes
- You should define the `forward` method for each class with correct tensor transformations.
- Be careful about tensor shapes when splitting/combining heads.
- Ensure the causal mask is correctly broadcast and applied in attention.
- You do **not** need to implement training loops or optimization; focus on correct and clean model implementation.
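
The two shape-related notes above can be checked on small tensors. The snippet below is a sketch with arbitrary sizes chosen for illustration: it shows a head split and merge that round-trips, a common reshaping mistake, and a `(T, T)` causal mask broadcasting over a `(B, H, T, T)` score tensor.

```python
import torch

B, T, H, d_head = 2, 4, 3, 5
x = torch.randn(B, T, H * d_head)

# Split: (B, T, d_model) -> (B, T, H, d_head) -> (B, H, T, d_head)
heads = x.view(B, T, H, d_head).transpose(1, 2)
# Merge: undo the transpose first, then flatten the last two dimensions.
# .reshape copies when the transposed layout is not contiguous, where a
# plain .view would raise an error.
merged = heads.transpose(1, 2).reshape(B, T, H * d_head)
assert torch.equal(merged, x)

# Pitfall: viewing straight to (B, H, T, d_head) yields the right shape
# but assigns elements to the wrong (head, position) pairs.
wrong = x.view(B, H, T, d_head)
assert wrong.shape == heads.shape
assert not torch.equal(wrong, heads)

# Causal mask: a (T, T) boolean matrix broadcasts over the leading (B, H) dims.
scores = torch.randn(B, H, T, T)
allowed = torch.tril(torch.ones(T, T, dtype=torch.bool))
weights = torch.softmax(scores.masked_fill(~allowed, float("-inf")), dim=-1)
# Future positions receive exactly zero weight, and each row still sums to 1.
assert torch.all(weights.triu(diagonal=1) == 0)
assert torch.allclose(weights.sum(dim=-1), torch.ones(B, H, T))
```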
Quick Answer: This question evaluates implementation-level mastery of transformer architectures, including multi-head masked self-attention, position-wise feed-forward networks, residual connections, and precise tensor-shape reasoning for a decoder-only language model, within the Coding & Algorithms domain of deep learning and neural network engineering.