Implement Focal Loss for Class-Imbalanced Classification
Company: Datadog
Role: Machine Learning Engineer
Category: Machine Learning
Difficulty: hard
Interview Round: Technical Screen
# Implement Focal Loss for Class-Imbalanced Classification
In dense object detection and many real-world classification problems, the data is highly imbalanced: a huge number of easy, well-classified examples (e.g. background) dominate the standard cross-entropy loss, so the rare, hard, or positive examples contribute only a small share of the total loss and gradient. **Focal Loss** addresses this by *down-weighting* the loss assigned to easy examples, so training concentrates on the hard, misclassified ones.
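A back-of-the-envelope illustration of the imbalance argument. The counts are hypothetical and chosen only for illustration: 100,000 easy negatives at $p_t = 0.99$ and 100 hard positives at $p_t = 0.1$. The sketch ignores $\alpha$ and uses the modulating factor $(1-p_t)^\gamma$ from the focal loss definition given below:

```python
import math

# Hypothetical batch composition (illustrative counts, not real data).
n_easy, p_t_easy = 100_000, 0.99   # many well-classified background examples
n_hard, p_t_hard = 100, 0.10       # few badly-classified positives

def total(n, p_t, gamma):
    # Sum of (1 - p_t)**gamma * -log(p_t) over n identical examples (alpha omitted).
    return n * (1.0 - p_t) ** gamma * -math.log(p_t)

for gamma in (0.0, 2.0):
    easy = total(n_easy, p_t_easy, gamma)
    hard = total(n_hard, p_t_hard, gamma)
    print(f"gamma={gamma}: easy={easy:.1f} hard={hard:.1f} easy share={easy / (easy + hard):.2%}")
```

With these counts, the easy negatives account for roughly 81% of the summed cross-entropy ($\gamma = 0$). At $\gamma = 2$ they account for only about 0.05% of the summed focal loss.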
Implement Focal Loss from scratch using only basic tensor operations (no high-level "focal loss" library function). Write a function for **binary** classification that:
- takes a batch of raw **logits** $z$ (pre-sigmoid) and binary targets $y \in \{0, 1\}$ of the same shape,
- applies the focal modulation with a focusing parameter $\gamma \ge 0$ and a class-balancing weight $\alpha \in [0, 1]$,
- is **numerically stable** (must accept large-magnitude logits without overflow or `NaN`),
- supports reduction modes `"none"`, `"mean"`, `"sum"`,
- is differentiable with respect to $z$.
Let $p = \sigma(z)$. Define the probability of the true class $p_t = p$ if $y = 1$ else $1 - p$, and the balancing weight $\alpha_t = \alpha$ if $y = 1$ else $1 - \alpha$. The per-example focal loss is:
$$\mathrm{FL} = -\,\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t)$$
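A literal, scalar transcription of the definition can serve as a test oracle for the real implementation. This is a sketch: `focal_loss_reference` is an illustrative name, and it is deliberately *not* numerically stable. `math.exp(-z)` raises `OverflowError` once $z$ is below about $-709$, and `math.log(p_t)` raises `ValueError` once $p_t$ rounds to 0, for example a negative example with a large positive logit.

```python
import math

def focal_loss_reference(z: float, y: int, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Direct transcription of FL = -alpha_t * (1 - p_t)**gamma * log(p_t).

    Not numerically stable; use only on moderate logits as a reference.
    """
    p = 1.0 / (1.0 + math.exp(-z))              # sigmoid, computed naively
    p_t = p if y == 1 else 1.0 - p              # probability of the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# sigma(log 9) = 0.9, so these are an easy (p_t = 0.9) and a hard (p_t = 0.1) positive.
easy = focal_loss_reference(math.log(9.0), 1)
hard = focal_loss_reference(-math.log(9.0), 1)
```

Under plain cross-entropy the hard positive's loss is about 22 times the easy one's. With $\gamma = 2$ the modulating factors are $0.01$ versus $0.81$, so the ratio grows to roughly 1,770.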
```hint Where to start
Start from binary cross-entropy expressed **directly from logits** so you never compute `log(0)`. There's a standard "softplus identity" that rewrites $-\log\sigma(z)$ as a sum of a linear term in $z$ and a bounded `log1p(exp(...))` term — look up the stable formula for `binary_cross_entropy_with_logits` if you don't remember it. That stable BCE term *is exactly* $-\log(p_t)$, so focal loss is just $\alpha_t (1-p_t)^\gamma \cdot \text{BCE}$.
```
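For reference, the identity the hint alludes to can be derived case by case for hard labels. This is a standard derivation written in this document's notation, using $\log(1 + e^{-z}) = -z + \log(1 + e^{z})$ when $z < 0$:

$$
\begin{aligned}
y = 1:\quad -\log p_t &= -\log \sigma(z) = \log\left(1 + e^{-z}\right) = \max(-z, 0) + \log\left(1 + e^{-|z|}\right),\\
y = 0:\quad -\log p_t &= -\log\left(1 - \sigma(z)\right) = \log\left(1 + e^{z}\right) = \max(z, 0) + \log\left(1 + e^{-|z|}\right),\\
\text{both:}\quad -\log p_t &= \max(z, 0) - z\,y + \log\left(1 + e^{-|z|}\right).
\end{aligned}
$$

The last line follows because $\max(z, 0) - z = \max(-z, 0)$. Since $e^{-|z|} \in (0, 1]$, the final term never overflows and can be evaluated as `log1p(exp(-|z|))`.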
```hint Numerical stability
Do **not** compute `p = sigmoid(z)` and then `log(p)`. For large negative $z$, $p$ underflows to 0 and $\log(p) \to -\infty$. Symmetrically, for large positive $z$, $p$ rounds to exactly 1 and $\log(1 - p) \to -\infty$. Compute the cross-entropy term in logit space instead. Only the *modulating factor* $(1-p_t)^\gamma$ needs $p = \sigma(z)$, and that factor lives in $[0,1]$, so it is safe.
```
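A small float32 demonstration of the failure the hint describes, with logits chosen for illustration. The naive path produces `inf` for the two misclassified examples and `NaN` for a *correctly* classified negative, via $0 \cdot (-\infty)$. The logit-space form stays finite:

```python
import numpy as np

z = np.array([-100.0, 100.0, -100.0], dtype=np.float32)
y = np.array([1.0, 0.0, 0.0], dtype=np.float32)  # hard positive, hard negative, easy negative

with np.errstate(all="ignore"):  # silence overflow / log(0) / 0*inf warnings for the demo
    p = 1.0 / (1.0 + np.exp(-z))  # exp(100) overflows float32, so p == 0 for z = -100
    naive = -(y * np.log(p) + (1 - y) * np.log(1 - p))

# Stable BCE-with-logits form: exp(-|z|) <= 1, so nothing overflows.
stable = np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))
```

Here `naive` evaluates to `[inf, inf, nan]`, while `stable` is approximately `[100, 100, 0]`.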
```hint Vectorize the branches
Replace the `if y==1 else` logic with arithmetic so it runs branch-free on whole tensors. Since $y \in \{0,1\}$, you can use it directly as a selector weight to linearly blend the "positive" and "negative" cases for both $p_t$ and $\alpha_t$ in one expression each — no `torch.where` or indexing needed.
```
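The branch-free selection can be sketched in a few lines. The probabilities below are hypothetical sigmoid outputs used only to show that the arithmetic blend matches an explicit conditional:

```python
import numpy as np

y = np.array([1.0, 0.0, 1.0, 0.0])   # hard 0/1 targets
p = np.array([0.9, 0.9, 0.2, 0.2])   # hypothetical sigmoid outputs
alpha = 0.25

# y acts as a 0/1 selector: y * (positive case) + (1 - y) * (negative case).
p_t = y * p + (1.0 - y) * (1.0 - p)
alpha_t = y * alpha + (1.0 - y) * (1.0 - alpha)
```

This gives `p_t` ≈ `[0.9, 0.1, 0.2, 0.8]` and `alpha_t` = `[0.25, 0.75, 0.25, 0.75]`, the same as `np.where(y == 1, p, 1 - p)` but without building a boolean mask.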
### Constraints & Assumptions
- Logits $z$ and targets $y$ are tensors of identical shape; the batch may span arbitrary dimensions. $y$ is float-valued in $\{0,1\}$.
- $\gamma$ (gamma) is a non-negative scalar; $\gamma = 0$ reduces focal loss to (class-weighted) cross-entropy. A typical value is $\gamma = 2$.
- $\alpha$ (alpha) $\in [0,1]$ balances the positive vs. negative class. A typical value is $\alpha = 0.25$.
- Logits can be large, $|z|$ up to roughly $50$–$100$; the implementation must not produce `Inf` or `NaN`.
- No autograd-breaking operations: do not `detach()` anything on the path that needs a gradient w.r.t. $z$.
- `"mean"` averages over all elements unless a different normalization is requested.
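A minimal sketch of the requested function, assuming NumPy and float64 inputs; `focal_loss` is an illustrative name, not a library API. NumPy has no autograd, so this shows only the forward computation. To get a version PyTorch autograd can differentiate, swap `np.maximum(z, 0.0)` for `torch.clamp(z, min=0)` and `np.log1p`/`np.exp`/`np.abs` for `torch.log1p`/`torch.exp`/`torch.abs`. Nothing on the path to $z$ is detached.

```python
import numpy as np

def focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean"):
    """Binary focal loss from raw logits (NumPy sketch, forward pass only)."""
    z = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    if z.shape != y.shape:
        raise ValueError("logits and targets must have the same shape")
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(f"unknown reduction: {reduction!r}")

    # Shared bounded term: exp(-|z|) is in (0, 1], so this never overflows.
    log_term = np.log1p(np.exp(-np.abs(z)))
    # -log(p_t) evaluated directly in logit space (stable BCE-with-logits form).
    ce = np.maximum(z, 0.0) - z * y + log_term
    # p = sigma(z) = exp(-softplus(-z)); it underflows gracefully to 0 instead of overflowing.
    p = np.exp(-(np.maximum(-z, 0.0) + log_term))

    # Branch-free selection: y acts as a 0/1 blend weight.
    p_t = p * y + (1.0 - p) * (1.0 - y)
    alpha_t = alpha * y + (1.0 - alpha) * (1.0 - y)

    loss = alpha_t * (1.0 - p_t) ** gamma * ce
    if reduction == "mean":
        return loss.mean()   # average over every element
    if reduction == "sum":
        return loss.sum()
    return loss              # "none": per-element loss, same shape as the input
```

The `"mean"` branch averages over all elements, matching the stated default. Normalizing by the number of positives instead would only change the divisor.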
### Clarifying Questions to Ask
- Is the task **binary** (sigmoid) focal loss, or **multi-class** (softmax) focal loss? I will implement binary and can extend.
- Should $\alpha$ be a fixed scalar, or a per-class vector?
- For the `"mean"` reduction, do we average over the element count, or normalize by the number of positive examples (the RetinaNet convention)?
- Are targets hard labels $\{0,1\}$ or soft probabilities? Is there an `ignore_index` for masked elements?
- Should the gradient flow through the modulating factor $(1 - p_t)^\gamma$, or do we stop gradient there (as some implementations do)?
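To make the stop-gradient question concrete, here is a sketch of the per-element loss together with its analytic derivative $\partial\,\mathrm{FL}/\partial z$, with an option to treat $(1-p_t)^\gamma$ as a constant. In PyTorch that option corresponds to calling `.detach()` on the modulating factor. The function name and flag are hypothetical. With $s = 2y - 1$ and $q = 1 - p_t$, the derivative is $s\,\alpha_t\,q^\gamma\,(\gamma\,p_t \log p_t - q)$, and the detached variant keeps only $-s\,\alpha_t\,q^{\gamma+1}$.

```python
import numpy as np

def focal_loss_and_grad(z, y, alpha=0.25, gamma=2.0, stop_grad_modulator=False):
    """Per-element binary focal loss and its analytic derivative d(loss)/dz.

    stop_grad_modulator=True treats (1 - p_t)**gamma as a constant (a detached factor).
    """
    z = np.asarray(z, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    s = 2.0 * y - 1.0                        # +1 for positives, -1 for negatives
    ce = np.logaddexp(0.0, -s * z)           # -log(p_t) = softplus(-s*z), stable
    p_t = np.exp(-ce)                        # p_t = sigma(s*z)
    q = np.exp(-np.logaddexp(0.0, s * z))    # 1 - p_t = sigma(-s*z), no cancellation
    alpha_t = alpha * y + (1.0 - alpha) * (1.0 - y)

    loss = alpha_t * q**gamma * ce
    # Path through -log(p_t): d(ce)/dz = -s*q.
    grad_ce_path = -s * alpha_t * q ** (gamma + 1.0)
    if stop_grad_modulator:
        return loss, grad_ce_path
    # Path through the modulating factor: d(q**gamma)/dz = -s*gamma*q**gamma*p_t.
    grad_mod_path = -s * alpha_t * gamma * q**gamma * p_t * ce
    return loss, grad_ce_path + grad_mod_path
```

Both gradient paths have the same sign, so detaching the factor always shrinks each example's gradient magnitude. At $\gamma = 0$ the two variants coincide and reduce to $\alpha_t(\sigma(z) - y)$.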
### Follow-up Questions
- Extend the implementation to **multi-class** focal loss using a softmax over $C$ classes. How does $p_t$ change, and how do you index the true class efficiently?
- Why does the focusing term help when $\alpha$-balancing alone is insufficient? Concretely, how does $\gamma$ change the gradient contribution of an easy example with $p_t = 0.9$ versus a hard one with $p_t = 0.1$?
- How would you choose $\alpha$ and $\gamma$ in practice, and how are they coupled? What happens as $\gamma \to \infty$?
- Some implementations stop the gradient through $(1 - p_t)^\gamma$. What is the motivation, and what is the trade-off?
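For the multi-class follow-up, one possible sketch under the following assumptions:
- targets are integer class indices;
- `alpha`, if given, is a per-class weight vector;
- `softmax_focal_loss` is a hypothetical name.

Here $p_t$ becomes the softmax probability of the true class. The log-softmax is computed with the max-shift (log-sum-exp) trick, and the true class is gathered by integer indexing rather than a one-hot mask.

```python
import numpy as np

def softmax_focal_loss(logits, targets, gamma=2.0, alpha=None, reduction="mean"):
    """Multi-class focal loss sketch: -alpha[t] * (1 - p_t)**gamma * log(p_t).

    logits: (N, C) raw scores; targets: (N,) integer class indices.
    alpha: optional length-C per-class weights (one possible convention).
    """
    z = np.asarray(logits, dtype=np.float64)
    t = np.asarray(targets, dtype=np.int64)
    # Stable log-softmax: shift each row by its max before exponentiating.
    shifted = z - z.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Gather the true-class log-probability per row (advanced indexing, no one-hot).
    log_p_t = log_probs[np.arange(z.shape[0]), t]
    p_t = np.exp(log_p_t)
    loss = -((1.0 - p_t) ** gamma) * log_p_t
    if alpha is not None:
        loss = np.asarray(alpha, dtype=np.float64)[t] * loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    raise ValueError(f"unknown reduction: {reduction!r}")
```

With two classes and logits $[0, z]$, the softmax probability of class 1 is exactly $\sigma(z)$. The modulating factor and log term then match the binary form.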
Quick Answer: This question evaluates a candidate's ability to implement a specialized loss function for class-imbalanced classification from basic tensor operations rather than a library call. It tests numerical stability in logit-space computation, differentiability, and vectorized computation with reduction handling: practical deep learning implementation skill beyond conceptual familiarity.