You are given access to an auto-regressive language model that can score the next token based on a prefix.

Assume:

- `next_logprobs(prefix_tokens) -> Map[token, logp]` returns log-probabilities (natural log) for the next token given the current prefix.
- `EOS` is a special end-of-sequence token.
- You are also given a length limit `max_len` and an initial prefix `prompt_tokens`.
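For concreteness, here is a minimal sketch of the assumed interface in Python. The `Token` alias, the `"<eos>"` spelling, and the toy distribution are invented purely so the decoders below can be exercised; a real implementation would query the actual model.

```python
import math
from typing import Dict, List

Token = str
EOS = "<eos>"  # assumed spelling of the end-of-sequence token

def next_logprobs(prefix_tokens: List[Token]) -> Dict[Token, float]:
    """Toy stand-in for the model: natural-log probabilities over the
    next token given the prefix. The numbers are invented for testing."""
    if prefix_tokens and prefix_tokens[-1] == "hello":
        probs = {"world": 0.7, "there": 0.2, EOS: 0.1}
    else:
        probs = {"hello": 0.6, "world": 0.3, EOS: 0.1}
    return {tok: math.log(p) for tok, p in probs.items()}
```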
Implement two decoding methods:

- Greedy decoding (greedy sampling)
  - Repeatedly select the single token with the highest log-probability at each step.
  - Stop when you generate `EOS` or when the total length reaches `max_len`.
  - Return the generated token sequence, stating clearly whether it includes the prompt or only the newly generated tokens. (One possible sketch follows this item.)
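A sketch of one acceptable greedy implementation, assuming the interface above. It returns the full sequence with the prompt included and breaks score ties deterministically by token value; both are choices the problem statement leaves open.

```python
def greedy_decode(prompt_tokens: List[Token], max_len: int) -> List[Token]:
    """Greedy decoding: take the argmax token at every step.

    Returns the full sequence including the prompt. Stops after
    emitting EOS or once the total length reaches max_len.
    """
    seq = list(prompt_tokens)
    while len(seq) < max_len:
        logprobs = next_logprobs(seq)
        if not logprobs:  # empty vocabulary: nothing left to emit
            break
        # Ties on log-probability are broken by the larger token value,
        # so the result is deterministic for a fixed model.
        token = max(logprobs, key=lambda t: (logprobs[t], t))
        seq.append(token)
        if token == EOS:
            break
    return seq
```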
- Beam search (keep top-k paths)
  - Given an integer `k` (the number of beams to keep), implement beam search that maintains the top-k candidate sequences by total sequence score.
  - Use the cumulative log-probability (sum of token log-probs) as the sequence score; length normalization is optional, but explain your choice if you add it.
  - At each decoding step, expand each current beam by one token using `next_logprobs`, then keep only the top-k resulting sequences overall.
  - Stop when all beams have generated `EOS` or when reaching `max_len`.
  - Return the best sequence (and optionally the top-k sequences if requested).

Clarify any edge cases you handle (e.g., early-finished beams, ties, empty vocabulary, whether finished beams remain in the beam set without further expansion). The sketch below makes one set of choices explicit.
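A possible beam search sketch under the same assumptions. Its edge-case choices: finished beams (those that emitted `EOS` or hit `max_len`) stay in the pool without further expansion and keep competing on raw score; ties are resolved by `heapq.nlargest`'s internal ordering; an empty vocabulary freezes the affected beam.

```python
import heapq
from typing import List, Tuple

def beam_search(prompt_tokens: List[Token], max_len: int, k: int) -> List[Token]:
    """Beam search keeping the top-k sequences by cumulative log-probability."""
    # Each beam is (cumulative logp, token sequence, finished flag).
    beams: List[Tuple[float, List[Token], bool]] = [(0.0, list(prompt_tokens), False)]
    while True:
        candidates: List[Tuple[float, List[Token], bool]] = []
        any_expanded = False
        for score, seq, finished in beams:
            if finished or len(seq) >= max_len:
                # Early-finished or length-capped beams are carried over
                # unchanged so they still compete with active beams.
                candidates.append((score, seq, True))
                continue
            logprobs = next_logprobs(seq)
            if not logprobs:  # empty vocabulary: freeze this beam
                candidates.append((score, seq, True))
                continue
            any_expanded = True
            for tok, lp in logprobs.items():
                candidates.append((score + lp, seq + [tok], tok == EOS))
        if not any_expanded:
            break  # every beam is finished or at max_len
        # Keep only the top-k candidates overall by total score.
        beams = heapq.nlargest(k, candidates, key=lambda c: c[0])
    best = max(beams, key=lambda b: b[0])
    return best[1]
```

With the toy model above, `greedy_decode(["hello"], max_len=4)` yields `["hello", "world", "hello", "world"]`, and `beam_search(["hello"], max_len=4, k=2)` agrees because that distribution has a single dominant path. No length normalization is applied, so shorter finished sequences compete on raw cumulative score.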