Implement Top-p (Nucleus) Sampling in NumPy
Company: Amazon
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: medium
Interview Round: Onsite
Implement **top-p (nucleus) sampling** for next-token selection in a language model, using only NumPy.
Given a vector of raw logits over a vocabulary, top-p sampling keeps the smallest set of most-probable tokens whose cumulative probability mass is at least `p`, renormalizes that set into a probability distribution, and draws one token from it. It is a widely used decoding strategy for trading off diversity against coherence in text generation.
## Function signature
Implement:
```python
def top_p_sample(
    logits,            # np.ndarray of shape (vocab_size,): raw, unnormalized logits
    p,                 # float in (0, 1]: the nucleus cumulative-probability threshold
    temperature=1.0,   # float > 0: temperature applied to logits before softmax
    rng=None,          # optional np.random.Generator for reproducible sampling
):
    """
    Returns:
        int: the index (token id) sampled from the nucleus.
    """
```
## What the computation must do
1. **Temperature scaling.** Divide the logits by `temperature` before converting to probabilities. Lower temperature sharpens the distribution; higher temperature flattens it.
2. **Softmax.** Convert the temperature-scaled logits into a probability distribution using a **numerically stable** softmax (subtract the max logit before exponentiating).
3. **Sort by probability, descending.** Order the tokens from most to least probable.
4. **Find the nucleus.** Walk the sorted probabilities accumulating their sum, and keep tokens up to and **including** the first token at which the running cumulative probability is $\ge p$. The nucleus is this smallest prefix of the sorted list whose cumulative mass reaches `p`. The nucleus is never empty: even if the single most-probable token already has probability $\ge p$, that one token forms the nucleus.
5. **Renormalize.** Zero out all tokens outside the nucleus, then rescale the surviving probabilities so they sum to 1.
6. **Sample.** Draw one token index from the renormalized distribution. If `rng` is provided, use it (`rng.choice` / `rng.random`) so results are reproducible; otherwise fall back to NumPy's default random source.
7. **Return the original vocabulary index** (token id) of the sampled token — not its position in the sorted order.
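The seven steps above can be sketched in NumPy roughly as follows. This is one possible implementation under stated assumptions, not an official solution: it accepts any 1-D array-like `logits`, breaks ties with a stable sort on `-probs` (equal probabilities keep their original index order), and implements step 5 by slicing the sorted nucleus instead of literally zeroing the tail, which yields the same renormalized distribution. When `rng` is `None` it creates a fresh `np.random.default_rng()`, so that path is not reproducible.

```python
import numpy as np


def top_p_sample(logits, p, temperature=1.0, rng=None):
    """Sketch of top-p (nucleus) sampling; returns the sampled token id."""
    logits = np.asarray(logits, dtype=np.float64)
    if rng is None:
        rng = np.random.default_rng()  # fresh OS entropy: not reproducible

    # 1-2. Temperature scaling, then a numerically stable softmax.
    scaled = logits / temperature
    exp = np.exp(scaled - scaled.max())
    probs = exp / exp.sum()

    # 3. Sort descending; a stable sort on -probs keeps ties in index order.
    order = np.argsort(-probs, kind="stable")
    sorted_probs = probs[order]

    # 4. Nucleus size: first prefix whose cumulative sum is >= p (inclusive),
    #    clamped so float rounding (csum[-1] slightly below 1) cannot overrun.
    csum = np.cumsum(sorted_probs)
    k = int(np.searchsorted(csum, p, side="left")) + 1
    k = min(max(k, 1), probs.size)

    # 5. Keep only the nucleus and renormalize it to sum to 1.
    nucleus = sorted_probs[:k] / sorted_probs[:k].sum()

    # 6. Draw a position inside the nucleus (positions are in sorted order).
    pos = rng.choice(k, p=nucleus)

    # 7. Map the sorted position back to the original vocabulary index.
    return int(order[pos])


print(top_p_sample([2.0, 1.0, 0.1], 0.6, rng=np.random.default_rng(0)))  # 0
```

The draw happens in sorted coordinates and only the final `order[pos]` lookup returns to vocabulary ids, which keeps step 7 explicit. The usage line always prints `0`, whatever the seed, because at `p = 0.6` the nucleus is the single token `{0}`.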
## Constraints & Requirements
- Use **NumPy only**. Do not call any deep-learning framework or any library's built-in top-p / nucleus-sampling helper.
- The **softmax must be numerically stable**.
- `1 <= vocab_size <= 100000`.
- `0 < p <= 1`. When `p == 1`, the nucleus is the entire vocabulary (subject only to floating-point rounding).
- `temperature > 0`.
- The returned value is a Python `int` in `[0, vocab_size)`.
- Sampling must be **reproducible** when an `rng` (`np.random.Generator`) is passed: the same `logits`, `p`, `temperature`, and `rng` state must yield the same token.
- Ties in probability may be broken in any consistent order (e.g., stable sort), as long as the nucleus definition above is respected.
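Tie-breaking matters in practice because it fixes which token occupies each sorted position, and therefore which token id a given random draw maps to. The short check below is an illustration, not part of the required API: for a uniform distribution, a stable sort on `-probs` keeps ties in index order, while reversing an ascending stable sort puts them in reverse index order. Both respect the nucleus definition, but the same seeded draw can map to different token ids under the two conventions.

```python
import numpy as np

logits = np.array([1.0, 1.0, 1.0, 1.0])
z = logits - logits.max()                          # stable softmax, temperature = 1.0
probs = np.exp(z) / np.exp(z).sum()                # exactly [0.25, 0.25, 0.25, 0.25]

# Convention A: stable sort on negated probabilities, ties stay in index order.
order_a = np.argsort(-probs, kind="stable")        # [0, 1, 2, 3]

# Convention B: stable ascending sort, reversed, ties end up in reverse index order.
order_b = np.argsort(probs, kind="stable")[::-1]   # [3, 2, 1, 0]

# The same sampled position maps to different token ids under the two orders.
pos = 2
token_a, token_b = int(order_a[pos]), int(order_b[pos])  # 2 vs 1
```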
## Example
For `logits = [2.0, 1.0, 0.1]`, `temperature = 1.0`:
- Softmax probabilities are approximately `[0.659, 0.242, 0.099]`.
- With `p = 0.9`: cumulative mass after the top two tokens is `0.659 + 0.242 = 0.901 >= 0.9`, so the nucleus is the first two tokens `{0, 1}`. Token id 2 is excluded. Sampling then draws from the renormalized two-token distribution `[0.731, 0.269]`.
- With `p = 0.6`: the single most-probable token already reaches `0.659 >= 0.6`, so the nucleus is `{0}` and the function deterministically returns `0`.
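The numbers in this example can be reproduced with a few lines of NumPy. The snippet below is a standalone check of the arithmetic (variable names are illustrative): it computes the softmax, the descending cumulative sum, and the renormalized nucleus for both values of `p`.

```python
import numpy as np

logits = np.array([2.0, 1.0, 0.1])
temperature = 1.0

z = logits / temperature
probs = np.exp(z - z.max())
probs /= probs.sum()                      # approx [0.659, 0.242, 0.099]

order = np.argsort(-probs, kind="stable")
csum = np.cumsum(probs[order])            # approx [0.659, 0.901, 1.0]

results = {}
for p in (0.9, 0.6):
    k = int(np.searchsorted(csum, p, side="left")) + 1
    ids = order[:k]
    renorm = probs[ids] / probs[ids].sum()
    # p = 0.9 -> ids [0, 1], renorm approx [0.731, 0.269]; p = 0.6 -> ids [0], renorm [1.0]
    results[p] = (ids.tolist(), renorm)

for p, (ids, renorm) in results.items():
    print(p, ids, np.round(renorm, 3))
```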
Quick Answer: This coding question tests practical implementation of top-p (nucleus) sampling, a core decoding strategy in large language models. It evaluates NumPy proficiency including numerically stable softmax, cumulative probability thresholding, and reproducible stochastic sampling — skills central to machine learning engineering roles.
## Function signature
```python
def solution(logits, p, temperature=1.0, seed=None):
"""Returns: int — the vocabulary index (token id) sampled from the nucleus."""
```
**Note on `seed`:** The interview prompt passes an `np.random.Generator` (`rng`) for reproducibility. In this executable console the random source is given as an integer `seed` instead (or `None` for a fresh draw); internally build `np.random.default_rng(seed)`. With the same `logits`, `p`, `temperature`, and `seed`, the function must return the same token. When the nucleus is a single token, the return is deterministic regardless of `seed`.
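A minimal sketch of the seed handling described in this note, assuming the nucleus has already been computed (here hard-coded to the `p = 0.9` example's renormalized two-token distribution, rounded to three decimals). The helper name `draw` is hypothetical. Building a fresh `np.random.default_rng(seed)` on every call is what makes equal seeds give equal tokens; with `seed=None` the generator is seeded from fresh OS entropy, so only membership in the nucleus is guaranteed.

```python
import numpy as np

# Renormalized nucleus from the p = 0.9 example (token ids 0 and 1, rounded weights).
nucleus_ids = np.array([0, 1])
nucleus_probs = np.array([0.731, 0.269])
nucleus_probs = nucleus_probs / nucleus_probs.sum()  # make the weights sum to 1


def draw(seed=None):
    """Hypothetical helper: build a generator from `seed`, draw one token id."""
    rng = np.random.default_rng(seed)
    return int(rng.choice(nucleus_ids, p=nucleus_probs))


# Equal integer seeds give equal generator states, hence equal tokens.
assert draw(42) == draw(42)

# seed=None pulls fresh OS entropy: the token may vary, but it stays in the nucleus.
assert draw(None) in (0, 1)
```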
## What the computation must do
1. **Temperature scaling.** Divide the logits by `temperature` before converting to probabilities.
2. **Softmax.** Convert the temperature-scaled logits into a probability distribution using a **numerically stable** softmax (subtract the max logit before exponentiating).
3. **Sort by probability, descending.**
4. **Find the nucleus.** Walk the sorted probabilities accumulating their sum, and keep tokens up to and **including** the first token at which the running cumulative probability is `>= p`. The nucleus is never empty.
5. **Renormalize.** Zero out all tokens outside the nucleus, then rescale the survivors to sum to 1.
6. **Sample** one token index from the renormalized distribution using the seeded generator.
7. **Return the original vocabulary index** of the sampled token — not its position in the sorted order.
## Constraints
- Use **NumPy only**; no deep-learning framework or built-in top-p helper.
- Softmax must be numerically stable.
- `1 <= vocab_size <= 100000`.
- `0 < p <= 1`. When `p == 1`, the nucleus is the entire vocabulary.
- `temperature > 0`.
- Returned value is a Python `int` in `[0, vocab_size)`.
- Ties may be broken in any consistent order (e.g., stable sort).
## Example
For `logits = [2.0, 1.0, 0.1]`, `temperature = 1.0`, softmax ≈ `[0.659, 0.242, 0.099]`:
- `p = 0.9`: cumulative after the top two tokens is `0.901 >= 0.9`, so the nucleus is `{0, 1}`; token id 2 is excluded.
- `p = 0.6`: the top token already reaches `0.659 >= 0.6`, so the nucleus is `{0}` and the function deterministically returns `0`.
Constraints
- 1 <= vocab_size <= 100000
- 0 < p <= 1 (p == 1 selects the entire vocabulary, up to float rounding)
- temperature > 0
- Use NumPy only; no DL framework or built-in nucleus-sampling helper
- Softmax must be numerically stable (subtract max before exp)
- Nucleus is the smallest descending prefix with cumulative prob >= p, inclusive of the crossing token, and is never empty
- Return a Python int in [0, vocab_size); the original token id, not the sorted position
- Same logits, p, temperature, and (non-None) seed must yield the same token
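One way to enforce the numeric constraints up front is a small validation step before sampling. The helper below is hypothetical and not part of the required signature; it checks only the stated bounds on `vocab_size`, `p`, and `temperature`, and deliberately leaves the sampling logic out.

```python
import numpy as np


def validate_top_p_args(logits, p, temperature):
    """Hypothetical helper: reject inputs that violate the stated constraints."""
    logits = np.asarray(logits, dtype=np.float64)
    if logits.ndim != 1 or not 1 <= logits.size <= 100_000:
        raise ValueError("logits must be 1-D with 1 <= vocab_size <= 100000")
    if not 0.0 < p <= 1.0:
        raise ValueError("p must lie in (0, 1]")
    if not temperature > 0.0:
        raise ValueError("temperature must be > 0")
    return logits
```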
Examples
Input: ([2.0, 1.0, 0.1], 0.6, 1.0, None)
Expected Output: 0
Explanation: Softmax top prob ~0.659 >= 0.6, so the nucleus is the single token {0} and the return is deterministic regardless of seed.
Input: ([2.0, 1.0, 0.1], 0.9, 1.0, 42)
Expected Output: 1
Explanation: Cumulative mass of the top two tokens is ~0.901 >= 0.9, so the nucleus is {0, 1} (token 2 excluded). With seed 42 the reproducible draw from the renormalized two-token distribution lands on token id 1.
Input: ([5.0, 0.0, 0.0, 0.0], 0.5, 1.0, None)
Expected Output: 0
Explanation: The dominant logit gives the top token probability ~0.980, well above 0.5, so the nucleus is {0}; the result is deterministic regardless of seed.
Input: ([1.0, 1.0, 1.0, 1.0], 1.0, 1.0, 7)
Expected Output: 2
Explanation: Uniform logits give each token probability 0.25, and p == 1 keeps the entire vocabulary in the nucleus; seed 7 reproducibly selects token id 2.
Input: ([0.0], 1.0, 1.0, None)
Expected Output: 0
Explanation: Edge case: vocab_size == 1. The only token forms the nucleus and index 0 is returned.
Input: ([3.0, 2.5, 2.0, 1.0], 0.95, 2.0, 123)
Expected Output: 2
Explanation: Temperature 2.0 flattens the distribution to roughly [0.363, 0.283, 0.220, 0.134]. The top three tokens cover only ~0.866 < 0.95, so the nucleus is all four tokens; with seed 123 the reproducible draw returns token id 2.
Input: ([10.0, 1.0, 1.0], 0.99, 0.5, 1)
Expected Output: 0
Explanation: Temperature 0.5 sharpens the distribution so the top token's probability already exceeds 0.99; the nucleus collapses to {0} and the return is deterministic.
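Whether an example's answer can depend on the seed comes down to the nucleus size. The standalone check below (the helper name `nucleus_size` is hypothetical) recomputes that size for each input above; any example with a nucleus of size 1 is deterministic.

```python
import numpy as np


def nucleus_size(logits, p, temperature):
    """Hypothetical helper: size of the top-p nucleus (tie order does not affect it)."""
    z = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(z - z.max())
    probs /= probs.sum()
    csum = np.cumsum(np.sort(probs)[::-1])   # descending cumulative mass
    k = int(np.searchsorted(csum, p, side="left")) + 1
    return min(k, probs.size)


# (logits, p, temperature) for each example above, in order.
examples = [
    ([2.0, 1.0, 0.1], 0.6, 1.0),
    ([2.0, 1.0, 0.1], 0.9, 1.0),
    ([5.0, 0.0, 0.0, 0.0], 0.5, 1.0),
    ([1.0, 1.0, 1.0, 1.0], 1.0, 1.0),
    ([0.0], 1.0, 1.0),
    ([3.0, 2.5, 2.0, 1.0], 0.95, 2.0),
    ([10.0, 1.0, 1.0], 0.99, 0.5),
]
sizes = [nucleus_size(*ex) for ex in examples]
print(sizes)  # [1, 2, 1, 4, 1, 4, 1]
```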
Hints
- Stabilize the softmax by subtracting max(logits / temperature) before exponentiating — never exponentiate the raw scaled logits.
- Sort probabilities descending and take a cumulative sum; the nucleus size is the index of the first prefix sum that is >= p, plus one (so the crossing token is included). np.searchsorted(csum, p, side='left') + 1 gives this directly; clamp it to [1, vocab_size].
- After selecting the nucleus indices, renormalize ONLY those probabilities (divide by their sum) before sampling.
- Build the generator from the seed (np.random.default_rng(seed)) so the same seed reproduces the same draw; map the sampled position back through the saved sort permutation to recover the original token id.
- When the top token alone already reaches p, the nucleus is a single token and the result is deterministic regardless of the seed.
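The `searchsorted` hint hides two details worth checking directly: why `side='left'` is the right choice when a prefix sum lands exactly on `p`, and why the clamp is needed at `p = 1`. The standalone snippet below (the helper name `nucleus_k` is hypothetical) uses probabilities chosen so that the floating-point behavior is exact and predictable.

```python
import numpy as np


def nucleus_k(csum, p):
    """Hypothetical helper: nucleus size from a descending cumsum, clamped to [1, len]."""
    k = int(np.searchsorted(csum, p, side="left")) + 1
    return min(max(k, 1), csum.size)


# side="left" returns the first index with csum[i] >= p, so an exact hit is the crossing token.
csum = np.cumsum(np.full(4, 0.25))       # [0.25, 0.5, 0.75, 1.0], exact in binary floating point
k_left = nucleus_k(csum, 0.5)            # 2: the first two tokens already reach 0.5
k_right = int(np.searchsorted(csum, 0.5, side="right")) + 1  # 3: one token too many

# Float rounding: ten probabilities of 0.1 accumulate to 0.9999999999999999 < 1.0,
# so for p = 1 the raw index runs one past the end and the clamp is needed.
csum10 = np.cumsum(np.full(10, 0.1))
k_raw = int(np.searchsorted(csum10, 1.0, side="left")) + 1   # 11
k_clamped = nucleus_k(csum10, 1.0)                          # 10
```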