
sweep_creator

This script creates and launches wandb sweeps for different models, data tasks, and trainers. It generates bash and Slurm scripts for running the sweeps on multiple GPUs or in a Slurm environment, and it uses the wandb library to create and manage the sweeps.

HyperArgs

Bases: Tap

Usage:

1. Check that 'search_space_by_model' has the correct hyperparameter search space for the model you wish to sweep (a hypothetical example is sketched after this list).
2. Run 'python src/run/multi_run/sweep_creator.py --models <models> --data_task_names <data_task_names> --trainer_names <trainer_names>'. To run multiple models/data_tasks/trainers, separate them with spaces.
3. The script creates executable bash scripts for each sweep (fold_idx), which launch the wandb sweeps.
4. Run the bash script ./<bash_script>.sh.
5. To run on multiple GPUs, use the --gpu_count flag. Not tested for more than 1 GPU.
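
The search space for each model is passed unchanged into the sweep config's parameters field (see create_sweep_configs below), so it must follow the W&B sweep parameters specification. A minimal sketch of a hypothetical search_space_by_model entry, with placeholder model and hyperparameter names rather than the repository's actual values:

# Hypothetical entry, illustrative names only. Each parameter uses the W&B
# sweep 'parameters' spec; the 'values' lists are what grid search enumerates.
search_space_by_model = {
    'MyModel': {
        'trainer.learning_rate': {'values': [1e-4, 3e-4, 1e-3]},
        'trainer.batch_size': {'values': [32, 64]},
        'model.dropout': {'values': [0.1, 0.3]},
    },
}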

Source code in src/run/multi_run/sweep_creator.py
class HyperArgs(Tap):
    """
        Usage:
        1. check that 'search_space_by_model' has the correct hyperparameter search space
    for the model you wish to sweep.
        2. run 'python src/run/multi_run/sweep_creator.py --models <models>
    --data_task_names <data_task_names> --trainer_names <trainer_names>'
           * To run multiple models/data_tasks/trainers, separate them with spaces
        3. the script will create executable bash scripts for each sweep (fold_idx),
    which will launch the wandb sweeps.
        4. run the bash script ./<bash_script>.sh
        5. If you want to run on multiple GPUs, use the --gpu_count flag. Not tested for >1 GPUs.
    """

    run_cap: int = (
        250  # Maximum number of runs to execute. Relevant for non-grid search.
    )
    wandb_project: str = 'debug'  # Name of the wandb project to log to.
    wandb_entity: str = 'EyeRead'  # Name of the wandb entity to log to.
    folds: list[int] = [0]  # List of fold indices to run.
    gpu_count: int = 1  # Number of GPUs to use. >1  not tested.
    search_algorithm: Literal['bayes', 'grid', 'random'] = (
        'grid'  # Search algorithm to use.
    )

    # Slurm settings.
    slurm_cpus: int = 10  # Number of CPUs to use. Ideally number of workers + 2.
    slurm_mem: str = '75G'  # Amount of memory to use.
    slurm_mailto: str = 'shubi@campus.technion.ac.il'  # Email to send notifications to.
    num_duplicates_per_gpu: int = 1  # Number of duplicates to run on each GPU.

    # Model, data, and trainer settings as lists to support multiple values
    models: list[str] = []  # List of models to sweep
    base_models: list[str] = []  # List of base models to sweep
    data_tasks: list[str] = []  # List of data tasks to sweep
    trainers: list[str] = [
        'TrainerDL'
    ]  # List of trainers to sweep (default is 'TrainerDL')

    # Filled in by the script
    trainer: str | None = None
    data_task: str | None = None
    model: str | None = None
    base_model: str | None = None
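
Since HyperArgs is a Tap class, its annotated attributes become command-line flags and the script parses them directly. A minimal parsing sketch, assuming the usual typed-argument-parser pattern (this is not the script's actual __main__ block):

# Assumed usage sketch; flag names follow the attribute names above, e.g.
#   python src/run/multi_run/sweep_creator.py --models MyModel --folds 0 1 2
hyper_args = HyperArgs().parse_args()
print(hyper_args.models, hyper_args.folds, hyper_args.search_algorithm)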

create_bash_scripts(hyper_args, sweep_ids, mode)

Create bash scripts for the given sweep ids.

Parameters:
- hyper_args (HyperArgs): Hyperparameters for the sweep. Required.
- sweep_ids (list[str]): List of sweep ids. Required.
- mode ('lacc' | 'david'): Machine profile that selects the conda setup and working directory used in the generated script. Required.

Source code in src/run/multi_run/sweep_creator.py
def create_bash_scripts(
    hyper_args: HyperArgs, sweep_ids: list[str], mode: Literal['lacc', 'david']
) -> None:
    """
    Create bash scripts for the given sweep ids.
    Args:
        hyper_args (HyperArgs): Hyperparameters for the sweep.
        sweep_ids (list[str]): List of sweep ids.
    """
    assert hyper_args.model is not None
    base_path = (
        Path('sweeps') / hyper_args.wandb_project / 'bash' / mode / hyper_args.model
    )
    base_path.mkdir(parents=True, exist_ok=True)
    filename = base_path / (
        f'{hyper_args.model}_{hyper_args.data_task}_folds_'
        + '_'.join(map(str, hyper_args.folds))
        + '.sh'
    )
    main_command = '; '.join(
        ['conda activate eyebench']
        + [
            f'CUDA_VISIBLE_DEVICES=${{GPU_NUM}} wandb agent '
            f'{hyper_args.wandb_entity}/{hyper_args.wandb_project}/{sweep_id}'
            for sweep_id in sweep_ids
        ]
    )
    write_bash_script(
        filename=filename, main_command=main_command, sweep_ids=sweep_ids, mode=mode
    )
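
For reference, the joined main_command activates the conda environment and then runs one wandb agent per sweep id. A small illustration of the string this produces for two placeholder sweep ids, using the default entity/project values from HyperArgs:

# Illustration only; the sweep ids are placeholders.
sweep_ids = ['abc123xy', 'def456zw']
main_command = '; '.join(
    ['conda activate eyebench']
    + [
        f'CUDA_VISIBLE_DEVICES=${{GPU_NUM}} wandb agent EyeRead/debug/{sweep_id}'
        for sweep_id in sweep_ids
    ]
)
# -> 'conda activate eyebench; CUDA_VISIBLE_DEVICES=${GPU_NUM} wandb agent EyeRead/debug/abc123xy; CUDA_VISIBLE_DEVICES=${GPU_NUM} wandb agent EyeRead/debug/def456zw'

The generated script takes the GPU index as its first positional argument and the number of duplicate agents per GPU as an optional second argument (see write_bash_script below).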

create_slurm_scripts(hyper_args, sweep_ids, slurm_qos)

Create slurm scripts for the given sweep ids.

Parameters:
- hyper_args (HyperArgs): Hyperparameters for the sweep. Required.
- sweep_ids (list[str]): List of sweep ids. Required.
- slurm_qos (str): 'normal' or 'basic' for DGX. Required.

Source code in src/run/multi_run/sweep_creator.py
def create_slurm_scripts(
    hyper_args: HyperArgs, sweep_ids: list[str], slurm_qos: Literal['normal', 'basic']
) -> None:
    """
    Create slurm scripts for the given sweep ids.
    Args:
        hyper_args (HyperArgs): Hyperparameters for the sweep.
        sweep_ids (list[str]): List of sweep ids.
        slurm_qos (str): 'normal' or 'basic' for DGX.
    """
    assert hyper_args.model is not None
    base_path = Path('sweeps') / hyper_args.wandb_project / 'slurm' / hyper_args.model
    base_path.mkdir(parents=True, exist_ok=True)
    filename = base_path / (
        f'{hyper_args.model}_{hyper_args.data_task}_folds_'
        + '_'.join(map(str, hyper_args.folds))
        + f'{slurm_qos}.job'
    )
    write_slurm_script(
        filename=filename,
        hyper_args=hyper_args,
        sweep_ids=sweep_ids,
        slurm_qos=slurm_qos,
    )

create_sweep_configs(args)

Create sweep configurations for the given hyperparameters.

Parameters:
- args (HyperArgs): Hyperparameters for the sweep. Required.

Returns:
- list[dict]: List of sweep configurations.

Source code in src/run/multi_run/sweep_creator.py
def create_sweep_configs(args: HyperArgs) -> list[dict]:
    """
    Create sweep configurations for the given hyperparameters.

    Args:
        args (HyperArgs): Hyperparameters for the sweep.

    Returns:
        list[dict]: List of sweep configurations.
    """
    search_space = search_space_by_model[args.base_model]
    logger.info(f'Creating sweep configs for {args.base_model}')
    _, total_count = count_hyperparameter_configs(search_space)
    logger.info(args)
    if total_count > args.run_cap:
        logger.warning(
            f'Warning: The number of hyperparameter configurations ({total_count}) exceeds the run cap ({args.run_cap}).'
        )

    sweep_configs = [
        {
            'program': 'src/run/single_run/train.py',
            'method': args.search_algorithm,
            'metric': {
                'goal': 'minimize',
                'name': 'loss/val_all',
            },
            'entity': args.wandb_entity,
            'project': args.wandb_project,
            'name': f'{args.model}_{args.data_task}_fold_{fold_idx}',
            'parameters': search_space,
            'run_cap': args.run_cap,
            'command': [
                '${env}',
                '${interpreter}',
                '${program}',
                '${args_no_hypens}',
                f'+model={args.model}',
                f'+data={args.data_task}',
                f'+trainer={args.trainer}',
                f'data.fold_index={fold_idx}',
                f'trainer.devices={args.gpu_count}',
                f'trainer.wandb_job_type={args.model}_{args.data_task}',
            ],
        }
        for fold_idx in args.folds
    ]

    return sweep_configs
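
The run-cap check above compares the size of the search space against run_cap; for a grid search that size is the product of the per-parameter value counts. A rough sketch of that computation, under the assumption that count_hyperparameter_configs behaves like this for 'values'-style parameters (the real helper may cover more cases):

import math

def approx_grid_size(search_space: dict) -> int:
    # Product of the lengths of each 'values' list; an assumption about
    # count_hyperparameter_configs, not its actual implementation.
    return math.prod(
        len(spec['values']) for spec in search_space.values() if 'values' in spec
    )

# 3 learning rates x 2 batch sizes x 2 dropout values -> 12 grid configurations
approx_grid_size({
    'trainer.learning_rate': {'values': [1e-4, 3e-4, 1e-3]},
    'trainer.batch_size': {'values': [32, 64]},
    'model.dropout': {'values': [0.1, 0.3]},
})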

launch_sweeps(entity, project, sweep_configs)

Launch wandb sweeps for the given configurations.

Parameters:
- entity (str): Name of the wandb entity. Required.
- project (str): Name of the wandb project. Required.
- sweep_configs (list[dict]): List of sweep configurations. Required.

Returns:
- list[str]: List of sweep ids.

Source code in src/run/multi_run/sweep_creator.py
def launch_sweeps(entity: str, project: str, sweep_configs: list[dict]) -> list[str]:
    """
    Launch wandb sweeps for the given configurations.

    Args:
        entity (str): Name of the wandb entity.
        project (str): Name of the wandb project.
        sweep_configs (List[Dict]): List of sweep configurations.

    Returns:
        List[str]: List of sweep ids.
    """
    sweep_ids = [
        wandb.sweep(cfg, entity=entity, project=project) for cfg in sweep_configs
    ]
    return sweep_ids
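
Putting the pieces together, the flow per model/data-task/trainer combination is: build one config per fold, register the sweeps with W&B, then emit launch scripts. A minimal end-to-end sketch, assuming this is roughly how the script's main block wires the functions together (the per-combination field assignment is an assumption):

# Assumed end-to-end flow, not the script's verbatim main block.
hyper_args = HyperArgs().parse_args()
hyper_args.model = hyper_args.models[0]            # filled in per combination
hyper_args.base_model = hyper_args.base_models[0]
hyper_args.data_task = hyper_args.data_tasks[0]
hyper_args.trainer = hyper_args.trainers[0]

sweep_configs = create_sweep_configs(hyper_args)   # one config per fold
sweep_ids = launch_sweeps(
    entity=hyper_args.wandb_entity,
    project=hyper_args.wandb_project,
    sweep_configs=sweep_configs,
)
create_bash_scripts(hyper_args, sweep_ids, mode='lacc')          # tmux launcher
create_slurm_scripts(hyper_args, sweep_ids, slurm_qos='normal')  # Slurm array job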

write_bash_script(filename, main_command, sweep_ids, mode)

Write a bash script to launch wandb agents in tmux sessions.

Parameters:
- filename (Path): Name of the bash script file. Required.
- main_command (str): Main command to run in the tmux session. Required.
- sweep_ids (list[str]): List of sweep ids. Required.
- mode ('lacc' | 'david'): Machine profile that selects the conda setup and working directory. Required.
Source code in src/run/multi_run/sweep_creator.py
def write_bash_script(
    filename: Path,
    main_command: str,
    sweep_ids: list[str],
    mode: Literal['lacc', 'david'],
) -> None:
    """
    Write a bash script to launch wandb agents in tmux sessions.

    Args:
        filename (str): Name of the bash script file.
        main_command (str): Main command to run in the tmux session.
        sweep_ids (list[str]): List of sweep ids.
    """
    if mode == 'lacc':
        conda_path = 'source $HOME/miniforge3/etc/profile.d/conda.sh'
        cd_path = 'cd $HOME/eyebench_private'
    elif mode == 'david':
        conda_path = 'source ~/.conda/envs/eyebench/etc/profile.d/mamba.sh'
        cd_path = 'cd /mnt/mlshare/reich3/eyebench_private'
    else:
        raise ValueError(f'Invalid mode: {mode}')

    full_command = f"""#!/bin/bash

if ! command -v tmux &>/dev/null; then
    echo "tmux could not be found, please install tmux first."
    exit 1  
fi

{conda_path}
{cd_path}

GPU_NUM=$1
RUNS_ON_GPU=${{2:-1}}
for ((i=1; i<=RUNS_ON_GPU; i++)); do
    session_name="wandb-gpu${{GPU_NUM}}-dup${{i}}-unified-{sweep_ids[0]}-{len(sweep_ids)}"
    tmux new-session -d -s "${{session_name}}" "{main_command}"; tmux set-option -t "${{session_name}}" remain-on-exit off
    echo "Launched W&B agent for GPU ${{GPU_NUM}}, Dup ${{i}} in tmux session ${{session_name}}"
done
"""

    with open(filename, 'w', encoding='utf-8') as f:
        f.write(full_command)
    os.chmod(filename, os.stat(filename).st_mode | stat.S_IEXEC)
    logger.info(f'Created bash script: {filename}')

write_slurm_script(filename, hyper_args, sweep_ids, slurm_qos)

Write a slurm script to launch wandb agents in slurm jobs.

Parameters:
- filename (Path): Name of the slurm script file. Required.
- hyper_args (HyperArgs): Hyperparameters for the sweep. Required.
- sweep_ids (list[str]): List of sweep ids. Required.
- slurm_qos (str): Slurm quality of service ('normal' or 'basic'). Required.
Source code in src/run/multi_run/sweep_creator.py
def write_slurm_script(
    filename: Path,
    hyper_args: HyperArgs,
    sweep_ids: list[str],
    slurm_qos: Literal['normal', 'basic'],
) -> None:
    """
    Write a slurm script to launch wandb agents in slurm jobs.

    Args:
        filename (Path): Name of the slurm script file.
        hyper_args (HyperArgs): Hyperparameters for the sweep.
        sweep_ids (list[str]): List of sweep ids.
        slurm_qos (str): Slurm quality of service (normal or basic).
    """
    base_srun_command = f"""
srun --overlap --ntasks=1 --nodes=1 --cpus-per-task=$SLURM_CPUS_PER_TASK -p work,mig \\
    --container-image=/rg/berzak_prj/shubi/prj/rev05_pytorchlightning+pytorch_lightning.sqsh \\
    --container-mounts="/rg/berzak_prj/shubi:/home/shubi" \\
    --container-workdir=/home/shubi/eyebench_private \\
    bash -c "
echo 'Starting job on $(date)'
source /home/shubi/prj/nvidia_pytorch_25_03_py3_mamba_wrapper.sh
conda activate eyebench
wandb agent {hyper_args.wandb_entity}/{hyper_args.wandb_project}/$SWEEP_ID"
    """
    srun_command = f"""{base_srun_command}"""
    if hyper_args.num_duplicates_per_gpu > 1:
        srun_command = f"""{srun_command}\nsleep 600\nwait"""
        for _ in range(hyper_args.num_duplicates_per_gpu - 1):
            srun_command += f"""{base_srun_command}\nsleep 10\n"""
        srun_command += 'wait'

    with open(filename, 'w', encoding='utf-8') as f:
        f.write(
            f"""#!/bin/bash

#SBATCH --job-name={hyper_args.model}_{hyper_args.data_task}-array
#SBATCH --output=logs/{hyper_args.model}_{hyper_args.data_task}-%A_%a.out
#SBATCH --error=logs/{hyper_args.model}_{hyper_args.data_task}-%A_%a.err
#SBATCH --partition=work,mig
#SBATCH --ntasks={hyper_args.num_duplicates_per_gpu}
#SBATCH --nodes=1
#SBATCH --gpus={hyper_args.gpu_count}
#SBATCH --qos={slurm_qos}
#SBATCH --cpus-per-task={hyper_args.slurm_cpus}
#SBATCH --mem={hyper_args.slurm_mem}
#SBATCH --array=0-{len(sweep_ids) - 1}
#SBATCH --mail-type=ALL
#SBATCH --mail-user={hyper_args.slurm_mailto}

sweep_ids=({' '.join(sweep_ids)})
SWEEP_ID=${{sweep_ids[$SLURM_ARRAY_TASK_ID]}}

{srun_command}
"""
        )
    os.chmod(filename, os.stat(filename).st_mode | stat.S_IEXEC)
    logger.info(f'Created Slurm array job script: {filename}')
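
The generated .job file is a regular Slurm batch script with one array task per sweep id, so it can be submitted with sbatch (sbatch path/to/<file>.job). If you prefer to submit every generated job file programmatically, a small optional sketch, assuming the default 'debug' project and the sweeps/<project>/slurm/<model>/ layout produced above:

import subprocess
from pathlib import Path

# Submit every generated Slurm job file for the 'debug' project.
for job_file in sorted(Path('sweeps/debug/slurm').rglob('*.job')):
    subprocess.run(['sbatch', str(job_file)], check=True)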