Implement dynamic batching for token decoding
Company: xAI
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: medium
Interview Round: Onsite
You are given a black-box “simulated language model” interface that can advance many sequences in a batch.
## Model interface
- Tokens are integers.
- `model_next(batch_prefixes) -> next_tokens`
- `batch_prefixes` is a list of token lists, one per active sequence in the current batch.
- `next_tokens` is a list of integers of the same length, where `next_tokens[i]` is the next generated token for `batch_prefixes[i]`.
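For concreteness, a minimal Python stand-in for this black box might look like the sketch below; the 100-token toy vocabulary and random sampling are assumptions for illustration, not part of the stated interface.

```python
import random
from typing import List

def model_next(batch_prefixes: List[List[int]]) -> List[int]:
    # Hypothetical stand-in for the black box: one next token per
    # active sequence. A real engine would run a batched forward pass;
    # here we sample uniformly from a toy 100-token vocabulary.
    return [random.randrange(100) for _ in batch_prefixes]
```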
## Requests
Each request/sequence has:
- `prompt_tokens`: initial prefix tokens
- `max_tokens`: maximum number of *generated* tokens allowed (not counting the prompt)
- A stopping rule:
- either a `stop_token` (single token), or
- a `stop_sequence` (a list of tokens that, when it appears as a suffix of the generated output, ends generation)
- A callback to return the final generated tokens (or you may collect results to return at the end).
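One way to capture this per-request state is a small dataclass; field names such as `request_id`, `on_done`, and `generated` are illustrative choices, not mandated by the prompt.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Optional

@dataclass
class Request:
    request_id: int
    prompt_tokens: List[int]
    max_tokens: int                             # budget for *generated* tokens only
    stop_token: Optional[int] = None            # single-token stopping rule
    stop_sequence: Optional[List[int]] = None   # suffix-based stopping rule
    on_done: Optional[Callable[[List[int]], None]] = None  # optional callback
    generated: List[int] = field(default_factory=list)     # per-request state
```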
## Batch execution requirement
Implement a decoding/sampling engine with **dynamic batching**:
- There is a fixed batch capacity `B`.
- Requests arrive in a waiting queue (you can assume they are all available initially, or you can model an input queue).
- You repeatedly call `model_next` to advance active sequences.
- Sequences may finish at different times due to:
- reaching `max_tokens`, or
- hitting the stop condition.
- When a sequence finishes, its slot becomes free and should be **refilled** from the waiting queue if possible.
- Near the end, the batch may be partially filled; your code must handle `len(active) < B` correctly.
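Both finish conditions can be checked once per decode step, right after the new token is appended. A sketch, with an assumed name and signature:

```python
from typing import List, Optional

def is_finished(generated: List[int], max_tokens: int,
                stop_token: Optional[int],
                stop_sequence: Optional[List[int]]) -> bool:
    # A sequence leaves the batch when it exhausts its token budget
    # or satisfies its stopping rule.
    if len(generated) >= max_tokens:
        return True
    if stop_token is not None and generated and generated[-1] == stop_token:
        return True
    if stop_sequence and len(generated) >= len(stop_sequence):
        # Suffix match: the stop sequence must end the generated output.
        if generated[-len(stop_sequence):] == stop_sequence:
            return True
    return False
```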
## Correctness requirement
Maintain a correct mapping between **batch slots** and **requests** so that tokens and final outputs are never mixed up after refilling (e.g., via a `slot_id -> request_id` mapping).
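One concrete shape for that mapping, assuming fixed physical slots (names like `slot_prefixes` and `release_slot` are illustrative):

```python
from typing import Dict, List, Optional

B = 4                                                   # batch capacity (illustrative)
slot_prefixes: List[Optional[List[int]]] = [None] * B   # None marks a free slot
slot_to_request: Dict[int, int] = {}                    # slot_id -> request_id

def release_slot(slot_id: int) -> None:
    # Freeing a slot must clear both the tokens and the mapping entry,
    # so a refilled slot can never be attributed to its old request.
    slot_prefixes[slot_id] = None
    slot_to_request.pop(slot_id, None)
```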
## Task
Write a function (or class) that runs this dynamic-batching decoding loop until all requests are completed, and returns the generated outputs per request (or delivers them via the callbacks).
Clearly define:
- your data structures (active slots, waiting queue, per-request state),
- the main loop and termination condition,
- how you detect stop conditions (especially `stop_sequence`),
- and how you handle partially filled batches.
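A reference sketch that addresses all four points, building on the `Request` dataclass and `is_finished` helper above (the compacted-list batch is one design choice among several):

```python
from collections import deque
from typing import Callable, Deque, Dict, List

def run_dynamic_batching(
    requests: List[Request],
    batch_capacity: int,
    model_next_fn: Callable[[List[List[int]]], List[int]],
) -> Dict[int, List[int]]:
    # Drives decoding until every request completes.
    # model_next_fn is the black-box batched decoder.
    waiting: Deque[Request] = deque(requests)  # input queue
    slots: List[Request] = []                  # slot i holds the i-th active request
    results: Dict[int, List[int]] = {}

    def refill() -> None:
        # Pull from the waiting queue into free capacity. Near the end
        # the queue drains and the batch runs partially filled
        # (len(slots) < batch_capacity), which the loop handles as-is.
        while waiting and len(slots) < batch_capacity:
            slots.append(waiting.popleft())

    refill()
    while slots:  # terminates once waiting and active are both empty
        # Build inputs in slot order, so next_tokens[i] belongs to slots[i].
        batch = [req.prompt_tokens + req.generated for req in slots]
        next_tokens = model_next_fn(batch)

        survivors: List[Request] = []
        for req, tok in zip(slots, next_tokens):
            req.generated.append(tok)
            if is_finished(req.generated, req.max_tokens,
                           req.stop_token, req.stop_sequence):
                results[req.request_id] = req.generated
                if req.on_done is not None:
                    req.on_done(req.generated)  # deliver via callback if given
            else:
                survivors.append(req)

        slots = survivors  # finished sequences free their slots...
        refill()           # ...which are refilled immediately if work remains
    return results
```

Rebuilding `slots` each step keeps the slot-to-request mapping implicit: output index `i` always belongs to `slots[i]`. The fixed-slot dict shown earlier is the alternative when slot ids must stay stable, e.g., if the backend keys per-slot state such as a KV cache.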
Quick Answer: This question evaluates dynamic batching, per-request state management, and decoding correctness for language-model inference: handling stop conditions and max-token limits, and maintaining a correct slot-to-request mapping as finished slots are refilled.