Implement PyTorch training loop
Company: Amazon
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: medium
Interview Round: Onsite
Quick Answer: This question evaluates practical implementation skills in PyTorch, focusing on model and device management, batch-wise training mechanics, gradient handling, and optimizer interaction.
Constraints
- 0 <= num_epochs <= 100
- 1 <= len(model['weights']) <= 20
- Each sample contains exactly len(model['weights']) features
- 0 <= total number of samples across all batches <= 10^4
- Every batch is non-empty, but `train_loader` itself may be empty (zero batches)
- loss_fn is always 'mse'
Examples
Input: ({'weights': [0.0], 'bias': 0.0}, [([[1.0], [2.0]], [2.0, 4.0])], {'lr': 0.1}, 'mse', 'cpu', 1)
Expected Output: {'weights': [1.0], 'bias': 0.6, 'losses': [10.0], 'device': 'cpu'}
Explanation: Starting from zero, the batch predictions are [0, 0], so the batch MSE is (2^2 + 4^2) / 2 = 10.0. The gradients are -10.0 for the weight and -6.0 for the bias, so one gradient descent step with lr = 0.1 updates the weight to 1.0 and the bias to 0.6.
Input: ({'weights': [0.0], 'bias': 0.0}, [([[1.0]], [1.0]), ([[2.0]], [2.0])], {'lr': 0.1}, 'mse', 'cuda', 2)
Expected Output: {'weights': [0.7696], 'bias': 0.4608, 'losses': [1.48, 0.039168], 'device': 'cuda'}
Explanation: This case has two epochs and two batches, so the loop must repeat zero-grad, forward, backward, and step for every batch. Each returned loss is the mean of that epoch's batch losses; in the first epoch the batch losses are 1.0 and 1.96, giving 1.48. The device string is returned unchanged.
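The per-batch sequence in this example can be traced with a minimal sketch for a single-feature model. It simplifies the input format to flat lists of scalars, and the helper `run_epochs` with its `zero_each_batch` flag is hypothetical, introduced only to show what changes when gradients are not reset between batches.

```python
def run_epochs(batches, lr, num_epochs, zero_each_batch=True):
    """Train y = w * x + b on scalar inputs; hypothetical helper for illustration."""
    w, b = 0.0, 0.0
    grad_w, grad_b = 0.0, 0.0
    epoch_losses = []
    for _ in range(num_epochs):
        batch_losses = []
        for xs, ys in batches:
            if zero_each_batch:
                grad_w, grad_b = 0.0, 0.0            # 1) zero gradients
            preds = [w * x + b for x in xs]          # 2) forward pass
            n = len(xs)
            loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n  # 3) loss
            batch_losses.append(loss)
            for x, p, y in zip(xs, preds, ys):       # 4) backward pass (accumulates)
                coeff = 2.0 * (p - y) / n
                grad_w += coeff * x
                grad_b += coeff
            w -= lr * grad_w                         # 5) optimizer step
            b -= lr * grad_b
        epoch_losses.append(sum(batch_losses) / len(batch_losses))
    return w, b, epoch_losses


batches = [([1.0], [1.0]), ([2.0], [2.0])]

w, b, losses = run_epochs(batches, lr=0.1, num_epochs=2)
print(round(w, 6), round(b, 6), [round(l, 6) for l in losses])
# 0.7696 0.4608 [1.48, 0.039168]

# Skipping the reset lets the first batch's gradient leak into the second step.
w_stale, b_stale, _ = run_epochs(batches, lr=0.1, num_epochs=1, zero_each_batch=False)
print(round(w_stale, 6), round(b_stale, 6))
# 0.96 0.68 (with the reset, one epoch gives 0.76 0.48)
```

The second call illustrates why the zero-grad step belongs inside the batch loop: the backward pass adds to whatever gradient is already stored.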
Input: ({'weights': [1.0, -1.0], 'bias': 0.5}, [], {'lr': 0.01}, 'mse', 'cpu', 2)
Expected Output: {'weights': [1.0, -1.0], 'bias': 0.5, 'losses': [0.0, 0.0], 'device': 'cpu'}
Explanation: With no batches, no parameter updates occur. By definition in this problem, each epoch's average loss is 0.0.
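The zero-loss convention for an empty loader amounts to guarding the division by the batch count. A minimal sketch, where `epoch_average` is a hypothetical helper name used only for illustration:

```python
def epoch_average(batch_losses):
    """Mean of the batch losses, or 0.0 when no batch was processed."""
    return sum(batch_losses) / len(batch_losses) if batch_losses else 0.0


# An empty loader yields one 0.0 entry per epoch instead of a ZeroDivisionError.
losses = [epoch_average([]) for _ in range(2)]
print(losses)  # [0.0, 0.0]
```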
Input: ({'weights': [0.0, 0.0], 'bias': 0.0}, [([[1.0, 2.0], [3.0, 4.0]], [5.0, 11.0])], {'lr': 0.01}, 'mse', 'cpu', 1)
Expected Output: {'weights': [0.38, 0.54], 'bias': 0.16, 'losses': [73.0], 'device': 'cpu'}
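Explanation: This verifies that gradients are computed separately for each weight in a multi-feature linear model. With zero initial parameters the batch MSE is (5^2 + 11^2) / 2 = 73.0, and the gradients are -38.0 and -54.0 for the two weights and -16.0 for the bias.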
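To make the per-weight gradients concrete, the following sketch recomputes them for this example. `mse_gradients` is a hypothetical helper that follows the same analytic formula as the backward pass; it is not part of the required interface.

```python
def mse_gradients(weights, bias, inputs, targets):
    """Return (grad_w, grad_b) of the batch MSE for a linear model (illustrative)."""
    n = len(inputs)
    grad_w = [0.0] * len(weights)
    grad_b = 0.0
    for sample, target in zip(inputs, targets):
        pred = bias + sum(w * x for w, x in zip(weights, sample))
        coeff = 2.0 * (pred - target) / n
        for j, x in enumerate(sample):
            grad_w[j] += coeff * x  # each weight is scaled by its own feature value
        grad_b += coeff
    return grad_w, grad_b


grad_w, grad_b = mse_gradients([0.0, 0.0], 0.0, [[1.0, 2.0], [3.0, 4.0]], [5.0, 11.0])
print(grad_w, grad_b)  # [-38.0, -54.0] -16.0
```

Multiplying these gradients by lr = 0.01 and subtracting from the zero-initialized parameters gives the expected weights [0.38, 0.54] and bias 0.16.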
Solution
def solution(model, train_loader, optimizer, loss_fn, device, num_epochs):
    if loss_fn != 'mse':
        raise ValueError("Only 'mse' is supported")
    weights = [float(w) for w in model.get('weights', [])]
    bias = float(model.get('bias', 0.0))
    lr = float(optimizer.get('lr', 0.0))
    epoch_losses = []
    def clean(x):
        # Round to 6 decimals and normalize negative zero to 0.0.
        x = round(float(x), 6)
        if x == -0.0:
            x = 0.0
        return x
    # Simulate moving the model to the requested device.
    moved_device = device
    for _ in range(num_epochs):
        total_loss = 0.0
        batch_count = 0
        for inputs, targets in train_loader:
            batch_size = len(inputs)
            if batch_size == 0:
                continue
            # 1) Zero gradients
            grad_w = [0.0] * len(weights)
            grad_b = 0.0
            # 2) Forward pass
            preds = []
            for sample in inputs:
                pred = bias
                for j, value in enumerate(sample):
                    pred += weights[j] * value
                preds.append(pred)
            # 3) Loss computation (mean squared error)
            loss = 0.0
            for pred, target in zip(preds, targets):
                diff = pred - target
                loss += diff * diff
            loss /= batch_size
            # 4) Backward pass (analytic gradients)
            for sample, pred, target in zip(inputs, preds, targets):
                coeff = 2.0 * (pred - target) / batch_size
                for j, value in enumerate(sample):
                    grad_w[j] += coeff * value
                grad_b += coeff
            # 5) Optimizer step
            for j in range(len(weights)):
                weights[j] -= lr * grad_w[j]
            bias -= lr * grad_b
            total_loss += loss
            batch_count += 1
        # By definition, an epoch with no batches has an average loss of 0.0.
        avg_loss = total_loss / batch_count if batch_count else 0.0
        epoch_losses.append(clean(avg_loss))
    return {
        'weights': [clean(w) for w in weights],
        'bias': clean(bias),
        'losses': epoch_losses,
        'device': moved_device
    }
Time complexity: O(num_epochs * total_samples * num_features). Space complexity: O(num_features + num_epochs + max_batch_size), since the per-batch `preds` list holds one prediction per sample.
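The `clean` helper in the solution matters because plain floating-point updates rarely land on the exact decimal values listed in the expected outputs. The sketch below reproduces the helper in isolation to show both of its effects; the sample inputs are chosen for illustration.

```python
def clean(x):
    """Round to 6 decimals and normalize negative zero, as in the solution."""
    x = round(float(x), 6)
    if x == -0.0:  # also true for 0.0, because -0.0 == 0.0 in Python
        x = 0.0
    return x


# Bias update from the first example: 0.0 - lr * grad_b with lr = 0.1, grad_b = -6.0.
raw_bias = 0.0 - 0.1 * (-6.0)
print(raw_bias == 0.6)         # False: the raw result is 0.6000000000000001
print(clean(raw_bias) == 0.6)  # True

# Rounding a tiny negative number produces negative zero, which prints as "-0.0".
print(str(round(-1e-9, 6)))    # -0.0
print(str(clean(-1e-9)))       # 0.0
```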
Hints
- Keep the training loop order strict: zero gradients, forward pass, loss computation, backward pass, then optimizer step.
- For MSE on a batch of size n, the derivative with respect to each prediction is `2 * (pred - target) / n`.
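The derivative in the last hint can be sanity-checked numerically. The sketch below compares the analytic value `2 * (pred - target) / n` with a central finite difference of the batch MSE. The function `batch_mse`, the sample predictions and targets, and the step size `eps` are illustrative assumptions.

```python
def batch_mse(preds, targets):
    """Mean squared error over one batch."""
    n = len(preds)
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / n


preds = [0.6, 1.24, -0.5]
targets = [2.0, 1.0, 0.25]
n = len(preds)
eps = 1e-6

max_gap = 0.0
for i in range(n):
    analytic = 2.0 * (preds[i] - targets[i]) / n
    # Perturb only the i-th prediction up and down by eps.
    up = preds[:i] + [preds[i] + eps] + preds[i + 1:]
    down = preds[:i] + [preds[i] - eps] + preds[i + 1:]
    numeric = (batch_mse(up, targets) - batch_mse(down, targets)) / (2 * eps)
    max_gap = max(max_gap, abs(analytic - numeric))

print(max_gap < 1e-6)  # True
```

Because MSE is quadratic in each prediction, the central difference agrees with the analytic derivative up to floating-point rounding.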