You are given a large in-memory Python list named records, where each element is a dictionary with at least the following keys:
- country: str
- device: str
- timestamp: str | datetime
Each dictionary may also contain additional fields.
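Given this record shape, grouping by country takes one pass over `records`. The sketch below assumes that a missing `country` key and an explicit `None` belong in the same bucket. It labels that bucket with a hypothetical string sentinel, `MISSING_COUNTRY`, so every key stays a sortable string. The helper name `group_by_country` is also illustrative.

```python
from collections import defaultdict
from typing import Any, Dict, List

# Hypothetical sentinel for records whose "country" is missing or None.
MISSING_COUNTRY = "__MISSING__"


def group_by_country(records: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Bucket records by country in a single O(n) pass.

    Buckets hold references to the original dicts, not copies.
    """
    groups: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
    for record in records:
        country = record.get("country")
        # A missing key and an explicit None land in the same bucket.
        key = MISSING_COUNTRY if country is None else country
        groups[key].append(record)
    return dict(groups)
```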
Write a Python function:
def stratified_sample_by_country(records, sample_size, random_state=None):
The function should return exactly sample_size records, sampled without replacement, such that the distribution of country in the sample matches the distribution in the original dataset as closely as possible. For example, if 10% of the input records have country == "US", then approximately 10% of the sampled records should come from the US.
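Matching the distribution "as closely as possible" with whole-number counts is an apportionment problem. One common choice is the largest-remainder (Hamilton) method, sketched below. Each country first gets the floor of its exact quota `len(group) * sample_size / len(records)`. The few leftover slots then go to the countries with the largest fractional parts. The helper name `allocate_counts` and its dictionary-of-sizes input are illustrative assumptions, not part of the required signature.

```python
from typing import Dict, List, Tuple


def allocate_counts(group_sizes: Dict[str, int], sample_size: int) -> Dict[str, int]:
    """Split sample_size across groups in proportion to group_sizes using the
    largest-remainder (Hamilton) method; the result always sums to sample_size."""
    total = sum(group_sizes.values())
    if not 0 <= sample_size <= total:
        raise ValueError("sample_size must be between 0 and the total group size")
    if sample_size == 0:
        return {key: 0 for key in group_sizes}

    counts: Dict[str, int] = {}
    remainders: List[Tuple[int, str]] = []
    for key, size in group_sizes.items():
        # Exact quota is size * sample_size / total; divmod keeps it exact as an
        # integer floor plus a remainder numerator (fraction = remainder / total).
        floor_count, remainder = divmod(size * sample_size, total)
        counts[key] = floor_count
        remainders.append((remainder, key))

    # The floors fall short by the sum of the fractional parts.
    leftover = sample_size - sum(counts.values())
    # Largest fractional parts first; ties broken by key for determinism.
    remainders.sort(key=lambda item: (-item[0], item[1]))
    for _, key in remainders[:leftover]:
        counts[key] += 1
    return counts
```

Take 10 US, 30 DE, and 60 FR records with `sample_size = 7`. The exact quotas are 0.7, 2.1, and 4.2, and the floors sum to 6. The single leftover slot goes to US, which has the largest remainder (0.7), so `allocate_counts({"US": 10, "DE": 30, "FR": 60}, 7)` returns `{"US": 1, "DE": 2, "FR": 4}`.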
Your solution should address:
- How to group records by country.
- How to compute the target number of samples per country.
- How to handle non-integer target counts while still returning exactly sample_size rows.
- How to make the sampling random but reproducible when random_state is provided.
- Edge cases such as sample_size = 0, missing country values, and sample_size > len(records).
You should also briefly discuss the time and space complexity of your approach.
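The sketch below puts these steps together in one possible implementation of the requested function. It makes several policy assumptions that the prompt leaves open:

- `sample_size > len(records)` and negative sizes raise `ValueError` rather than silently returning fewer records.
- Records with a missing or `None` country form their own stratum under a hypothetical `MISSING_COUNTRY` sentinel.
- `random_state` may be an integer seed, `None`, or an existing `random.Random` instance.

```python
import random
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

Record = Dict[str, Any]

# Hypothetical sentinel stratum for records whose "country" is missing or None.
MISSING_COUNTRY = "__MISSING__"


def stratified_sample_by_country(
    records: List[Record],
    sample_size: int,
    random_state: Optional[Union[int, random.Random]] = None,
) -> List[Record]:
    """Return exactly sample_size records, drawn without replacement, with
    per-country counts proportional to the country distribution of records."""
    n = len(records)
    if sample_size < 0:
        raise ValueError("sample_size must be non-negative")
    if sample_size > n:
        # Assumed policy: without replacement we cannot return more than n rows.
        raise ValueError(f"sample_size={sample_size} exceeds len(records)={n}")
    if sample_size == 0:
        return []

    # A private generator gives reproducible draws for a fixed seed without
    # touching the module-level random state.
    if isinstance(random_state, random.Random):
        rng = random_state
    else:
        rng = random.Random(random_state)

    # 1. Group records by country in a single pass (stores references only).
    groups: Dict[str, List[Record]] = defaultdict(list)
    for record in records:
        country = record.get("country")
        groups[MISSING_COUNTRY if country is None else country].append(record)

    # Fixed stratum order so the same seed always produces the same sample.
    keys = sorted(groups, key=str)

    # 2. Floor of each exact quota len(group) * sample_size / n, computed with
    #    integer divmod so no floating-point rounding creeps in.
    counts: Dict[str, int] = {}
    remainders = []
    for key in keys:
        floor_count, remainder = divmod(len(groups[key]) * sample_size, n)
        counts[key] = floor_count
        remainders.append((remainder, key))

    # 3. Largest-remainder step: hand the leftover slots to the strata with the
    #    largest fractional parts. sorted() is stable, so ties keep key order.
    leftover = sample_size - sum(counts.values())
    for _, key in sorted(remainders, key=lambda item: -item[0])[:leftover]:
        counts[key] += 1

    # 4. Sample without replacement inside each stratum.
    result: List[Record] = []
    for key in keys:
        result.extend(rng.sample(groups[key], counts[key]))

    # Shuffle so the output is not ordered by country.
    rng.shuffle(result)
    return result
```

Using a private `random.Random` instead of the module-level functions keeps the sample reproducible for a given seed without disturbing global state. Python only guarantees that `random()` produces the same sequence across versions, so `sample` and `shuffle` results may still differ between Python releases.

Ties in the largest-remainder step are broken by sorted country order. This is deterministic but mildly favors earlier keys; breaking ties with `rng` instead would remove that bias while staying reproducible.

For n records and k distinct countries, time is O(n + k log k). That covers one grouping pass, sorting the country keys and remainders, and per-stratum sampling that is at most linear in each stratum's size. Extra space is O(n) for the buckets of record references plus O(sample_size) for the output.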