Skip to content

utils

count_hyperparameter_configs(config, log_specific_values=True, n_hours=24)

Traverse an arbitrarily nested dict, find all dicts with a 'values' key whose value is a list, and: 1. Print the full key path, the number of values, and the values themselves 2. Compute the product of all those lengths 3. Print the time each run should get if runs are evenly distributed over 24 hours

Source code in src/run/multi_run/utils.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
def count_hyperparameter_configs(
    config: dict, log_specific_values: bool = True, n_hours: int = 24
) -> tuple:
    """
    Traverse an arbitrarily nested dict, find all dicts with a 'values' key
    whose value is a list, and:
        1. Print the full key path, the number of values, and the values themselves
        2. Compute the product of all those lengths
        3. Print the time each run should get if runs are evenly distributed over 24 hours
    """

    counts = {}
    values_map = {}

    def recurse(d: dict, path: str = ''):
        for key, val in d.items():
            current_path = f'{path}.{key}' if path else key
            if isinstance(val, dict):
                if 'values' in val and isinstance(val['values'], list):
                    counts[current_path] = len(val['values'])
                    values_map[current_path] = val['values']
                else:
                    recurse(val, current_path)

    recurse(config)

    if log_specific_values:
        # Print per-parameter counts and their values
        for param, n in counts.items():
            if n > 1:
                logger.info(
                    f'{param.replace(".parameters", "")}: {n} possible values → {values_map[param]}'
                )

    # Compute total combinations
    total = 1
    for n in counts.values():
        total *= n

    # Compute time per run over a 24-hour period
    seconds = n_hours * 3600
    time_per_run = seconds / total if total else 0
    # Format as DD:HH:MM:SS
    td = datetime.timedelta(seconds=round(time_per_run))
    days = td.days
    hours, remainder = divmod(td.seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    formatted = f'{days:02}:{hours:02}:{minutes:02}:{seconds:02}'
    logger.info(f'Total number of runs: {total}. Time per run: {formatted}')
    return formatted, int(total)