Implement K-means and solve interval/frequency tasks
Company: Amazon
Role: Machine Learning Engineer
Category: Coding & Algorithms
Difficulty: medium
Interview Round: Onsite
Quick Answer: This multi-part question evaluates understanding of K-means clustering, interval-merging algorithms, and frequency-based top-k retrieval, assessing competencies in unsupervised learning concepts, algorithmic problem solving, data structures, and complexity analysis.
Part 1: Deterministic K-means Clustering
Constraints
- 1 <= n <= 10^3
- 1 <= d <= 20
- 1 <= k <= n
- 1 <= max_iters <= 100
- All rows of X have the same length d
- Each coordinate fits in standard Python int/float ranges
Examples
Input: ([[1, 1], [1, 2], [4, 4], [4, 5]], 2)
Expected Output: {'centroids': [[1.0, 1.5], [4.0, 4.5]], 'labels': [0, 0, 1, 1]}
Explanation: Two obvious groups form, and the centroids converge to the means of those groups.
Input: ([[0, 0], [2, 2], [4, 4]], 1)
Expected Output: {'centroids': [[2.0, 2.0]], 'labels': [0, 0, 0]}
Explanation: With one cluster, the centroid becomes the mean of all points.
Input: ([[0], [0], [10]], 2)
Expected Output: {'centroids': [[10.0], [0.0]], 'labels': [1, 1, 0]}
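Explanation: The initial centroids are identical, so every point is assigned to cluster 0 in the first pass and cluster 1 is empty. The centroid of an empty cluster must remain unchanged, so it stays at [0.0]; on the next pass both zeros move to cluster 1 and 10 stays in cluster 0.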
Input: ([[5], [9], [13]], 3)
Expected Output: {'centroids': [[5.0], [9.0], [13.0]], 'labels': [0, 1, 2]}
Explanation: When k equals the number of points, each point remains its own cluster.
Solution
def solution(X, k, max_iters=100):
    # Guard against empty input or an invalid cluster count.
    if not X or k <= 0 or k > len(X):
        return {'centroids': [], 'labels': []}
    d = len(X[0])
    # Deterministic initialization: the first k points become the centroids.
    centroids = [list(map(float, X[i])) for i in range(k)]
    prev_labels = None
    labels = [0] * len(X)
    for _ in range(max_iters):
        # Assignment step: nearest centroid by squared distance; ties go to the lowest index.
        labels = []
        for point in X:
            best_idx = 0
            best_dist = None
            for idx, centroid in enumerate(centroids):
                dist = 0.0
                for j in range(d):
                    diff = point[j] - centroid[j]
                    dist += diff * diff
                if best_dist is None or dist < best_dist:
                    best_dist = dist
                    best_idx = idx
            labels.append(best_idx)
        # Stop once the assignment no longer changes.
        if labels == prev_labels:
            break
        # Update step: accumulate per-cluster coordinate sums and counts.
        counts = [0] * k
        sums = [[0.0] * d for _ in range(k)]
        for label, point in zip(labels, X):
            counts[label] += 1
            for j in range(d):
                sums[label][j] += point[j]
        for idx in range(k):
            # An empty cluster keeps its previous centroid.
            if counts[idx] > 0:
                centroids[idx] = [sums[idx][j] / counts[idx] for j in range(d)]
        prev_labels = labels[:]
    return {'centroids': centroids, 'labels': labels}
Time complexity: O(max_iters * n * k * d). Space complexity: O(k * d + n).
Hints
- Use squared Euclidean distance so you do not need to call sqrt.
- During the update step, keep a sum vector and a count for each cluster; if a count is 0, leave that centroid unchanged.
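The two hints can be sketched as a pair of small helpers. The names `squared_distance` and `update_centroids` are hypothetical and do not appear in the solution above; this is a minimal sketch that assumes every point is a list of numbers of the same length.

```python
def squared_distance(p, q):
    """Squared Euclidean distance; it orders centroids the same way as the true distance."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def update_centroids(X, labels, centroids):
    """Return new centroids; a cluster with no assigned points keeps its old centroid."""
    k, d = len(centroids), len(X[0])
    sums = [[0.0] * d for _ in range(k)]
    counts = [0] * k
    for label, point in zip(labels, X):
        counts[label] += 1
        for j in range(d):
            sums[label][j] += point[j]
    return [
        [s / counts[i] for s in sums[i]] if counts[i] > 0 else list(centroids[i])
        for i in range(k)
    ]


# Third example, first update: both initial centroids are [0.0] and every point has label 0.
new_centroids = update_centroids([[0], [0], [10]], [0, 0, 0], [[0.0], [0.0]])
```

In this call cluster 1 is empty, so its centroid stays at [0.0] while cluster 0 moves to the mean of all three points. On the next assignment pass the two zeros are closer to the unchanged centroid, which is how the expected labels [1, 1, 0] arise.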
Part 2: Merge Overlapping Closed Intervals
Constraints
- 0 <= len(intervals) <= 2 * 10^5
- -2^31 <= start <= end <= 2^31 - 1
Examples
Input: ([[1, 3], [2, 6], [8, 10], [15, 18]],)
Expected Output: [[1, 6], [8, 10], [15, 18]]
Explanation: The first two intervals overlap and become [1, 6].
Input: ([[1, 4], [4, 5]],)
Expected Output: [[1, 5]]
Explanation: Closed intervals that touch at an endpoint still overlap.
Input: ([],)
Expected Output: []
Explanation: An empty input has no intervals to merge.
Input: ([[6, 8], [1, 9], [2, 4], [4, 7]],)
Expected Output: [[1, 9]]
Explanation: After sorting, every interval overlaps with [1, 9].
Solution
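def solution(intervals):
    if not intervals:
        return []
    # Sort by start so each interval can only overlap the most recent merged one.
    intervals = sorted(intervals, key=lambda interval: interval[0])
    # Copy the first interval so the caller's input is not mutated.
    merged = [intervals[0][:]]
    for start, end in intervals[1:]:
        last = merged[-1]
        # Closed intervals: touching at an endpoint (start == last[1]) counts as overlap.
        if start <= last[1]:
            if end > last[1]:
                last[1] = end
        else:
            merged.append([start, end])
    return merged
Time complexity: O(n log n). Space complexity: O(n).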
Hints
- Sort the intervals by their start value first.
- As you scan left to right, only compare the current interval with the last merged interval.
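To see why comparing only against the last merged interval is enough, the scan can be instrumented to record each decision. The function `merge_with_trace` below is a hypothetical illustration, not part of the original solution; it assumes closed intervals given as two-element lists.

```python
def merge_with_trace(intervals):
    """Merge closed intervals and record what happened to each one (illustration only)."""
    merged, trace = [], []
    for start, end in sorted(intervals, key=lambda iv: iv[0]):
        if merged and start <= merged[-1][1]:
            # Overlaps or touches the last merged interval: extend its right end if needed.
            merged[-1][1] = max(merged[-1][1], end)
            trace.append(("extend", [start, end], list(merged[-1])))
        else:
            # Starts strictly after the last merged interval ends: open a new one.
            merged.append([start, end])
            trace.append(("new", [start, end], [start, end]))
    return merged, trace


# Fourth example: after sorting, [1, 9] comes first and absorbs every later interval.
merged, trace = merge_with_trace([[6, 8], [1, 9], [2, 4], [4, 7]])
```

Because the intervals are processed in order of start, any interval that begins after the last merged interval ends also begins after every earlier merged interval ends, so no earlier interval needs to be revisited.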
Part 3: Top-k Most Frequent Elements with Deterministic Tie-Breaking
Constraints
- 1 <= len(nums) <= 2 * 10^5
- Values fit in 32-bit signed integers
- 1 <= k <= number of distinct values in nums
Examples
Input: ([1, 1, 1, 2, 2, 3], 2)
Expected Output: [1, 2]
Explanation: 1 appears 3 times, 2 appears 2 times, and 3 appears once.
Input: ([4, 4, 1, 1, 2, 2], 2)
Expected Output: [1, 2]
Explanation: All three values appear twice, so the smaller values come first.
Input: ([7], 1)
Expected Output: [7]
Explanation: A single-element array returns that element.
Input: ([-1, -1, -2, -2, -3], 2)
Expected Output: [-2, -1]
Explanation: -1 and -2 are tied in frequency, so the smaller number -2 comes first.
Solution
from collections import Counter

def solution(nums, k):
    # Count occurrences of each distinct value.
    freq = Counter(nums)
    # Order by frequency descending, then by value ascending to break ties.
    ordered = sorted(freq.items(), key=lambda item: (-item[1], item[0]))
    return [value for value, _ in ordered[:k]]
Time complexity: O(n + m log m), where m is the number of distinct values. Space complexity: O(m).
Hints
- Count how many times each value appears before trying to choose the top k.
- Once you know the frequencies, sort distinct values by frequency descending and value ascending.
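When k is much smaller than the number of distinct values m, the full sort can be swapped for a selection step. The sketch below uses the standard library's `heapq.nsmallest` with the same ordering key; `top_k_heap` is a hypothetical name, and this is an alternative to the sorting approach in the hints rather than the original solution.

```python
import heapq
from collections import Counter


def top_k_heap(nums, k):
    """Top-k by (frequency descending, value ascending) without sorting all distinct values."""
    freq = Counter(nums)
    # The smallest keys (-count, value) are the highest counts, with ties going to smaller values.
    best = heapq.nsmallest(k, freq.items(), key=lambda item: (-item[1], item[0]))
    return [value for value, _ in best]


result = top_k_heap([-1, -1, -2, -2, -3], 2)
```

`heapq.nsmallest` returns its results in ascending key order, so the output order matches what sorting and slicing would give. The selection step costs roughly O(m log k) instead of O(m log m), which only matters when k is small relative to m.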