Implement dynamic batching for token decoding
Company: xAI
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: medium
Interview Round: Onsite
Quick Answer: This question evaluates dynamic batching, per-request state management, and sequence-decoding correctness for language-model inference, including handling stop conditions, max-token limits, and maintaining a correct slot-to-request mapping.
Constraints
- 1 <= batch_size <= 1000
- 0 <= len(requests) <= 10000
- 0 <= request['max_tokens'] <= 10000
- Exactly one of request['stop_token'] or request['stop_sequence'] is non-None
- If present, request['stop_sequence'] is non-empty
- All tokens are integers
- The total number of tokens actually generated across all requests is at most 100000
- Every prefix your engine queries exists in next_token_map
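The per-request constraints above can be checked before decoding starts. The following is a minimal sketch. It assumes input dicts with exactly the four keys used in the examples. validate_request is a hypothetical helper and is not part of the reference solution. It does not check the global limits (batch_size, len(requests), or the 100000-token total).

```python
def validate_request(req):
    """Hypothetical helper: raise ValueError if one request dict violates
    the per-request constraints (assumes the four keys from the examples)."""
    if not 0 <= req['max_tokens'] <= 10000:
        raise ValueError('max_tokens must be in [0, 10000]')
    has_token = req['stop_token'] is not None
    has_sequence = req['stop_sequence'] is not None
    if has_token == has_sequence:
        # Either both stop conditions are set or neither is.
        raise ValueError('exactly one of stop_token / stop_sequence must be non-None')
    if has_sequence and len(req['stop_sequence']) == 0:
        raise ValueError('stop_sequence must be non-empty')
    # Prompt tokens and stop tokens must all be integers.
    tokens = list(req['prompt_tokens'])
    tokens += [req['stop_token']] if has_token else list(req['stop_sequence'])
    if not all(isinstance(t, int) for t in tokens):
        raise ValueError('all tokens must be integers')


validate_request({'prompt_tokens': [1], 'max_tokens': 4, 'stop_token': 4, 'stop_sequence': None})
validate_request({'prompt_tokens': [], 'max_tokens': 3, 'stop_token': None, 'stop_sequence': [7, 8]})
```

Both calls at the bottom use requests taken from the examples and return without raising.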
Examples
Input: (2, [{'prompt_tokens': [1], 'max_tokens': 4, 'stop_token': 4, 'stop_sequence': None}, {'prompt_tokens': [5], 'max_tokens': 5, 'stop_token': None, 'stop_sequence': [7, 8]}, {'prompt_tokens': [9], 'max_tokens': 2, 'stop_token': 99, 'stop_sequence': None}], {(1,): 2, (1, 2): 3, (1, 2, 3): 4, (5,): 6, (5, 6): 7, (5, 6, 7): 8, (9,): 10, (9, 10): 11})
Expected Output: [[2, 3, 4], [6, 7, 8], [10, 11]]
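Explanation: Requests 0 and 1 fill both slots. Both finish on the third model call (request 0 on stop token 4, request 1 on stop sequence [7, 8]), which frees both slots. Request 2 then takes slot 0 and runs alone until it reaches max_tokens = 2. The stop token 4 and the stop sequence [7, 8] are kept in the outputs.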
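The following sketch makes the slot refills visible. It repeats the scheduling policy of the reference solution (FIFO admission, with finished slots refilled after each call) and records which request IDs form each batch. batch_timeline is a hypothetical name, and a plain dictionary lookup stands in for the model.

```python
def batch_timeline(batch_size, requests, next_token_map):
    """Hypothetical helper: decode with slot refills and record, for each
    model call, the request IDs in the batch (in slot order)."""
    n = len(requests)
    outputs = [[] for _ in range(n)]
    prefixes = [list(r['prompt_tokens']) for r in requests]
    slots = [None] * batch_size  # slot -> request ID, or None if free
    next_req = 0
    timeline = []

    def refill(slot):
        nonlocal next_req
        # Requests with max_tokens == 0 never occupy a slot.
        while next_req < n and requests[next_req]['max_tokens'] == 0:
            next_req += 1
        if next_req < n:
            slots[slot] = next_req
            next_req += 1
        else:
            slots[slot] = None

    for s in range(batch_size):
        refill(s)

    while any(r is not None for r in slots):
        active = [(s, r) for s, r in enumerate(slots) if r is not None]
        timeline.append([r for _, r in active])
        finished = []
        for s, r in active:
            # A table lookup stands in for one row of the batched model call.
            tok = next_token_map[tuple(prefixes[r])]
            prefixes[r].append(tok)
            outputs[r].append(tok)
            req = requests[r]
            seq = req['stop_sequence']
            if (len(outputs[r]) >= req['max_tokens']
                    or tok == req['stop_token']
                    or (seq is not None and outputs[r][-len(seq):] == seq)):
                finished.append(s)
        # Refill finished slots in slot order, as the reference solution does.
        for s in finished:
            refill(s)
    return timeline, outputs


if __name__ == '__main__':
    reqs = [
        {'prompt_tokens': [1], 'max_tokens': 4, 'stop_token': 4, 'stop_sequence': None},
        {'prompt_tokens': [5], 'max_tokens': 5, 'stop_token': None, 'stop_sequence': [7, 8]},
        {'prompt_tokens': [9], 'max_tokens': 2, 'stop_token': 99, 'stop_sequence': None},
    ]
    table = {(1,): 2, (1, 2): 3, (1, 2, 3): 4, (5,): 6, (5, 6): 7,
             (5, 6, 7): 8, (9,): 10, (9, 10): 11}
    timeline, outputs = batch_timeline(2, reqs, table)
    print(timeline)  # [[0, 1], [0, 1], [0, 1], [2], [2]]
    print(outputs)   # [[2, 3, 4], [6, 7, 8], [10, 11]]
```

On Example 3 the same helper gives the timeline [[0, 1], [2, 1], [2, 3], [2]]. Request 2 takes over slot 0 while request 1 is still decoding in slot 1, so batch order follows slots rather than request IDs.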
Input: (3, [{'prompt_tokens': [4], 'max_tokens': 0, 'stop_token': 9, 'stop_sequence': None}, {'prompt_tokens': [], 'max_tokens': 3, 'stop_token': None, 'stop_sequence': [7, 8]}, {'prompt_tokens': [2], 'max_tokens': 1, 'stop_token': 5, 'stop_sequence': None}], {(): 7, (7,): 8, (2,): 3})
Expected Output: [[], [7, 8], [3]]
Explanation: The first request has max_tokens = 0, so it finishes immediately with an empty output and never occupies a slot. As a result, only two of the three slots are filled. The second request has an empty prompt and stops once its generated suffix is [7, 8]. The third request stops after one token because it reaches max_tokens = 1.
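The admission step in this example skips zero-budget requests and leaves a slot empty once the queue runs out. It can be isolated as a small helper. admit is a hypothetical name. It assumes the same FIFO order and slot array as the reference solution, where each slot holds a request ID or None when free.

```python
def admit(slots, requests, next_req):
    """Hypothetical helper: fill every empty slot from the request queue in
    FIFO order and return the updated queue pointer.

    Requests with max_tokens == 0 are skipped. Their output is [] and they
    never occupy a slot, so a batch may stay partially filled.
    """
    for slot in range(len(slots)):
        if slots[slot] is not None:
            continue  # slot is busy with an unfinished request
        while next_req < len(requests) and requests[next_req]['max_tokens'] == 0:
            next_req += 1
        if next_req == len(requests):
            break  # queue exhausted; remaining empty slots stay None
        slots[slot] = next_req
        next_req += 1
    return next_req


# Example 2: batch_size = 3, and request 0 has max_tokens = 0.
requests = [
    {'prompt_tokens': [4], 'max_tokens': 0, 'stop_token': 9, 'stop_sequence': None},
    {'prompt_tokens': [], 'max_tokens': 3, 'stop_token': None, 'stop_sequence': [7, 8]},
    {'prompt_tokens': [2], 'max_tokens': 1, 'stop_token': 5, 'stop_sequence': None},
]
slots = [None] * 3
next_req = admit(slots, requests, 0)
print(slots)     # [1, 2, None]
print(next_req)  # 3
```

Once request 2 finishes and frees slot 1, calling admit again leaves that slot empty, because no requests remain.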
Input: (2, [{'prompt_tokens': [1], 'max_tokens': 4, 'stop_token': 2, 'stop_sequence': None}, {'prompt_tokens': [1], 'max_tokens': 3, 'stop_token': None, 'stop_sequence': [2, 3]}, {'prompt_tokens': [1], 'max_tokens': 4, 'stop_token': 4, 'stop_sequence': None}, {'prompt_tokens': [5], 'max_tokens': 1, 'stop_token': 9, 'stop_sequence': None}], {(1,): 2, (1, 2): 3, (1, 2, 3): 4, (5,): 6})
Expected Output: [[2], [2, 3], [2, 3, 4], [6]]
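Explanation: Several requests share the same prompt [1], so each request's state must follow its request ID rather than its slot. Request 0 stops at its stop token 2 after the first call, and request 2 refills its slot. Request 2 starts again from [1], while request 1 continues from [1, 2] and stops on [2, 3]. Request 3 then takes request 1's slot.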
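Tracking state by request ID also relies on every request owning its own prefix list. The reference solution copies each prompt with req['prompt_tokens'][:]. The following sketch shows what goes wrong without that copy. For illustration, it assumes that a caller reuses one list object for two identical prompts. build_prefixes is a hypothetical helper.

```python
def build_prefixes(requests, copy=True):
    """Hypothetical helper: one decoding prefix per request, indexed by
    request ID. With copy=False the prefixes alias the caller's lists."""
    if copy:
        return [list(r['prompt_tokens']) for r in requests]
    return [r['prompt_tokens'] for r in requests]


# Assumption for illustration: both requests reuse one prompt list object.
shared_prompt = [1]
reqs = [{'prompt_tokens': shared_prompt}, {'prompt_tokens': shared_prompt}]

safe = build_prefixes(reqs, copy=True)
safe[0].append(2)        # request 0 generates token 2
print(safe[1])           # [1]: request 1 still sees only its prompt

unsafe = build_prefixes(reqs, copy=False)
unsafe[0].append(2)      # same append, but on an aliased list
print(unsafe[1])         # [1, 2]: request 1's prefix is corrupted
print(shared_prompt)     # [1, 2]: and so is the caller's input
```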
Input: (2, [], {})
Expected Output: []
Explanation: No requests means there is nothing to decode.
Solution
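```python
def solution(batch_size, requests, next_token_map):
    def model_next(batch_prefixes):
        # Mock model: one batched call returns the next token for each prefix.
        next_tokens = []
        for prefix in batch_prefixes:
            key = tuple(prefix)
            if key not in next_token_map:
                raise KeyError(f'Missing next token for prefix {key}')
            next_tokens.append(next_token_map[key])
        return next_tokens

    n = len(requests)
    results = [[] for _ in range(n)]
    # Per-request state is indexed by request ID, never by slot.
    # Copy each prompt so appending tokens never mutates the caller's input.
    prefixes = [req['prompt_tokens'][:] for req in requests]
    # active_slots[slot] holds the request ID in that slot, or None if empty.
    active_slots = [None] * batch_size
    waiting = 0  # index of the next request not yet admitted (FIFO order)

    def fill_slot(slot):
        # Admit the next waiting request into `slot`, skipping requests with
        # max_tokens == 0 (they finish immediately with an empty output).
        nonlocal waiting
        while waiting < n:
            if requests[waiting]['max_tokens'] == 0:
                waiting += 1
                continue
            active_slots[slot] = waiting
            waiting += 1
            return
        active_slots[slot] = None  # queue exhausted; slot stays empty

    for slot in range(batch_size):
        fill_slot(slot)

    while True:
        # Gather the prefixes of all occupied slots for one batched call.
        batch_prefixes = []
        slot_order = []
        for slot, req_id in enumerate(active_slots):
            if req_id is not None:
                slot_order.append(slot)
                batch_prefixes.append(prefixes[req_id])
        if not batch_prefixes:
            break  # every slot is empty and the queue is exhausted

        next_tokens = model_next(batch_prefixes)

        finished_slots = []
        for i, slot in enumerate(slot_order):
            req_id = active_slots[slot]
            token = next_tokens[i]
            prefixes[req_id].append(token)
            results[req_id].append(token)
            req = requests[req_id]
            # Stop on max_tokens, on the stop token, or when the output ends
            # with the stop sequence; the stopping token(s) stay in the output.
            done = len(results[req_id]) >= req['max_tokens']
            if not done and req.get('stop_token') is not None and token == req['stop_token']:
                done = True
            stop_sequence = req.get('stop_sequence')
            if not done and stop_sequence is not None:
                m = len(stop_sequence)
                if m <= len(results[req_id]) and results[req_id][-m:] == stop_sequence:
                    done = True
            if done:
                finished_slots.append(slot)

        # Free the finished slots first, then refill them from the queue.
        for slot in finished_slots:
            active_slots[slot] = None
        for slot in finished_slots:
            fill_slot(slot)

    return results
```

Time complexity: O(N + P + T * L + C * B), where N is len(requests), P is the total prompt length (every prompt is copied once), T is the total number of generated tokens, L is the maximum stop_sequence length, C is the number of model calls, and B is batch_size. This excludes the work inside model_next itself. Building tuple(prefix) takes time proportional to the current prefix length, for every active request on every call. Space complexity: O(N + P + T + B), covering the copied prompts, the generated outputs, and the slot array.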
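Example 1 gives a worked check of what T, L, C and B measure. The model is called three times with requests 0 and 1 in the batch and twice with request 2 alone, so C = 5. The outputs contain 3 + 3 + 2 = 8 tokens, so T = 8. The only stop sequence is [7, 8], so L = 2, and B = 2. Plugging these values into the T · L + C · B part of the bound:

$$
T \cdot L + C \cdot B = 8 \cdot 2 + 5 \cdot 2 = 16 + 10 = 26
$$

The C · B term charges a full slot scan on every call, including the two calls where one of the two slots is idle. T · L is an upper bound, because only request 1 actually runs the stop-sequence comparison. Both terms count steps up to constant factors, not exact operations.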
Hints
- Use a fixed-size array for batch slots, where each slot stores the request ID currently occupying that slot.
- For a stop sequence, you only need to compare the end of the generated output after appending the newest token.
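The second hint can be written as a standalone check. ends_with is a hypothetical helper. It assumes a non-empty stop_sequence, as the constraints guarantee, and compares only the tail of the output after each append.

```python
def ends_with(output, stop_sequence):
    """Return True if `output` ends with `stop_sequence`.

    Only the last m = len(stop_sequence) tokens are inspected, so each check
    costs O(m) no matter how long the output has grown. Assumes a non-empty
    stop_sequence (an empty one would trivially match).
    """
    m = len(stop_sequence)
    if m > len(output):
        return False
    start = len(output) - m
    # Compare the tail elementwise instead of slicing a new list.
    return all(output[start + i] == stop_sequence[i] for i in range(m))


output = []
for token in [6, 7, 8, 9]:
    output.append(token)  # check right after appending the newest token
    print(output, ends_with(output, [7, 8]))
```

The check first returns True right after token 8 is appended. A request with stop_sequence [7, 8] would therefore stop with output [6, 7, 8].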