test_ml

summary

Given a list of experiments, where each experiment is a dictionary mapping fold_idx to a completed W&B sweep, this script 1. takes the best hyperparameters from each fold's sweep (or from a designated run_id, if requested), 2. fits the model on that fold, and 3. saves the test predictions to file.
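
A minimal sketch of that flow, assuming experiments_list and hyper_args are built elsewhere in the script; fit_model_and_save_predictions is a hypothetical stand-in for the fitting and prediction code here:

api = wandb.Api()
for experiment in experiments_list:
    for sweep in experiment.sweeps:
        # Step 1: best hyperparameters for this fold's sweep.
        sweep.cfg_of_best = get_config_from_sweep(
            api,
            entity=hyper_args.wandb_entity,
            project=experiment.wandb_project,
            sweep_id=sweep.sweep_id,
        )
        # Steps 2-3: fit on the fold and save test predictions
        # (hypothetical helper, not part of the documented API).
        fit_model_and_save_predictions(
            sweep.cfg_of_best,
            fold_index=sweep.fold_index,
            save_folder=experiment.save_folder_name,
        )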

Experiment dataclass

Class representing an experiment.

Source code in src/run/single_run/test_ml.py
@dataclass
class Experiment:
    """
    Class representing an experiment.
    """

    dataset_name: str = field(init=False)
    model_name: str = field(init=False)
    model_args: type  # the model's argument class; its __name__ is used below
    data_args: type  # the data task's argument class; its __name__ is used below
    save_folder_name: Path = field(default_factory=Path)
    sweeps: list[Sweep] = field(default_factory=list)
    wandb_project: str = 'ml_debug'

    def load_sweep_ids_from_yaml(self, yaml_path: str) -> list[str]:
        """
        Load sweep IDs from a YAML file.
        """
        with open(yaml_path, 'r', encoding='utf-8') as f:
            sweep_cfg = yaml.safe_load(f)
        return sweep_cfg.get('sweep_ids', [])

    def __post_init__(self):
        self.dataset_name = self.data_args.__name__
        self.model_name = self.model_args.__name__
        # Load sweep IDs from the YAML file
        sweep_ids = self.load_sweep_ids_from_yaml(
            f'sweeps/{self.wandb_project}/configs/{self.model_name}_{self.dataset_name}.yaml'
        )
        # Create Sweep objects for each sweep ID
        self.sweeps = [Sweep(sweep_id=sweep_id) for sweep_id in sweep_ids]

        self.save_folder_name = Path(
            f'results/raw/+data={self.dataset_name},'
            f'+model={self.model_name},+trainer=TrainerML,trainer.wandb_job_type='
            f'{self.model_name}_{self.dataset_name}',
        )

        self.save_folder_name.mkdir(parents=True, exist_ok=True)
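
A hedged construction example. DummyModelArgs and DummyDataArgs are hypothetical stand-ins for the real argument classes, and the call assumes the matching sweep config YAML exists under sweeps/ml_debug/configs/:

from dataclasses import dataclass

@dataclass
class DummyModelArgs:  # hypothetical model-args class
    n_estimators: int = 100

@dataclass
class DummyDataArgs:  # hypothetical data-args class
    fold_count: int = 3

# __post_init__ reads sweeps/ml_debug/configs/DummyModelArgs_DummyDataArgs.yaml
# and creates the results folder.
exp = Experiment(model_args=DummyModelArgs, data_args=DummyDataArgs)
print(exp.save_folder_name)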

load_sweep_ids_from_yaml(yaml_path)

Load sweep IDs from a YAML file.

Source code in src/run/single_run/test_ml.py
def load_sweep_ids_from_yaml(self, yaml_path: str) -> list[str]:
    """
    Load sweep IDs from a YAML file.
    """
    with open(yaml_path, 'r', encoding='utf-8') as f:
        sweep_cfg = yaml.safe_load(f)
    return sweep_cfg.get('sweep_ids', [])
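
The YAML is expected to hold a top-level sweep_ids list; a self-contained check of that layout with placeholder IDs:

import yaml

# Hypothetical file contents:
#   sweep_ids:
#     - abc123
#     - def456
cfg_text = 'sweep_ids:\n  - abc123\n  - def456\n'
sweep_cfg = yaml.safe_load(cfg_text)
assert sweep_cfg.get('sweep_ids', []) == ['abc123', 'def456']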

HyperArgs

Bases: Tap

Command line arguments for the script.

Source code in src/run/single_run/test_ml.py
class HyperArgs(Tap):
    """
    Command line arguments for the script.
    """

    wandb_entity: str = 'EyeRead'  # Name of the wandb entity to log to.
    wandb_run_id: str | None = None  # Provide if you want a single run.
    data_task: str = 'CopCo_TYP'  # Name of the data task (e.g., CopCo_TYP).
    wandb_project: str = 'CopCo_TYP_20250714'  # Name of the wandb project.
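
Typical Tap usage, assuming the script is invoked directly; the flag values below are placeholders:

# python test_ml.py --wandb_entity EyeRead --wandb_run_id abc123
hyper_args = HyperArgs().parse_args()
print(hyper_args.wandb_project)  # 'CopCo_TYP_20250714' unless overridden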

Sweep dataclass

Class representing a sweep in wandb.

Attributes:

Name         Type        Description
sweep_id     str         The ID of the sweep in wandb.
cfg_of_best  dict        The configuration of the best run in the sweep.
fold_index   int | None  The index of the fold, if applicable.

Source code in src/run/single_run/test_ml.py
@dataclass
class Sweep:
    """
    Class representing a sweep in wandb.

    Attributes:
        sweep_id (str): The ID of the sweep in wandb.
        cfg_of_best (dict): The configuration of the best run in the sweep.
        fold_index (int | None): The index of the fold, if applicable.
    """

    sweep_id: str
    cfg_of_best: dict = field(default_factory=dict)
    fold_index: int | None = None
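
Instantiation is minimal; cfg_of_best is normally filled in later from wandb (the ID below is a placeholder):

sweep = Sweep(sweep_id='abc123', fold_index=0)
sweep.cfg_of_best = {'learning_rate': 0.1}  # e.g., via get_config_from_sweep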

checks(experiments_list)

Basic consistency check: ensure that each sweep_id is unique across experiments.

Source code in src/run/single_run/test_ml.py
def checks(experiments_list: list[Experiment]) -> None:
    """
    Basic consistency check: ensure that each sweep_id is unique across experiments.
    """
    sweep_ids = [sweep.sweep_id for exp in experiments_list for sweep in exp.sweeps]
    assert len(sweep_ids) == len(set(sweep_ids)), 'Duplicate sweep IDs found!'
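
A minimal demonstration that a shared sweep ID trips the assertion. SimpleNamespace stands in for Experiment here because constructing a real one reads a YAML file and creates directories:

from types import SimpleNamespace

exp_a = SimpleNamespace(sweeps=[Sweep(sweep_id='abc123')])
exp_b = SimpleNamespace(sweeps=[Sweep(sweep_id='abc123')])
try:
    checks([exp_a, exp_b])  # type: ignore[arg-type]
except AssertionError as err:
    print(err)  # Duplicate sweep IDs found!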

get_config_from_run(api, entity, project, run_id)

Fetches the config of a single run by run_id.

Source code in src/run/single_run/test_ml.py
def get_config_from_run(
    api: wandb.Api, entity: str, project: str, run_id: str
) -> dict[str, Any]:
    """
    Fetches the config of a single run by run_id.
    """
    return api.run(path=f'{entity}/{project}/{run_id}').config
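
Usage with placeholder identifiers:

api = wandb.Api()
cfg = get_config_from_run(
    api, entity='EyeRead', project='CopCo_TYP_20250714', run_id='abc123'
)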

get_config_from_sweep(api, entity, project, sweep_id)

Fetches the config of the best run (by the sweep's objective) from a given sweep_id.

Source code in src/run/single_run/test_ml.py
def get_config_from_sweep(
    api: wandb.Api,
    entity: str,
    project: str,
    sweep_id: str,
) -> dict[str, Any]:
    """
    Fetches the config of the *best run* (by the sweep's objective) from a given sweep_id.
    """
    sweep_obj = api.sweep(path=f'{entity}/{project}/{sweep_id}')
    best_run = sweep_obj.best_run()
    return best_run.config
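
wandb ranks runs by the metric declared in the sweep's configuration, so best_run() returns the top run under that objective. Usage with a placeholder sweep ID:

api = wandb.Api()
best_cfg = get_config_from_sweep(
    api, entity='EyeRead', project='CopCo_TYP_20250714', sweep_id='abc123'
)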

predict_on_val_and_test(model, val_datasets, test_datasets)

Predict on all val and test datasets, returning one list of results.

Source code in src/run/single_run/test_ml.py
def predict_on_val_and_test(
    model: torch.nn.Module,
    val_datasets: list[np.ndarray],
    test_datasets: list[np.ndarray],
) -> list[np.ndarray]:
    """Predict on all val and test datasets, returning one list of results."""
    results = []
    # Predict on validation datasets
    for val_dataset in val_datasets:
        results.append(model.predict(val_dataset))

    # Predict on test datasets
    for test_dataset in test_datasets:
        results.append(model.predict(test_dataset))

    return results
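
Despite the torch.nn.Module annotation, the function only relies on a sklearn-style .predict method. A self-contained sketch with a dummy classifier and random arrays:

import numpy as np
from sklearn.dummy import DummyClassifier

rng = np.random.default_rng(0)
model = DummyClassifier(strategy='most_frequent')
model.fit(rng.random((8, 4)), np.array([0, 1] * 4))

preds = predict_on_val_and_test(
    model,  # type: ignore[arg-type]
    val_datasets=[rng.random((5, 4))],
    test_datasets=[rng.random((5, 4))],
)
print(len(preds))  # 2: one prediction array per dataset, val first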

process_results(results, dm, cfg, fold_index)

Given all results from val and test datasets, build a unified DataFrame with all relevant columns.

TODO almost duplicate code with test_dl.py

Source code in src/run/single_run/test_ml.py
def process_results(
    results: list[tuple[torch.Tensor, ...]],
    dm: base_datamodule.ETDataModuleFast,
    cfg: Args,
    fold_index: int,
) -> pd.DataFrame:
    """
    Given all results from val and test datasets, build
    a unified DataFrame with all relevant columns.
    # TODO almost duplicate code with test_dl.py
    """
    group_level_metrics = []

    for index, eval_type_results in enumerate(results):
        # Based on predict_dataloader: the first three results are val,
        # the last three are test (one per regime).
        eval_type = 'val' if index in [0, 1, 2] else 'test'
        if eval_type == 'val':
            dataset = dm.val_datasets[index]
        else:
            dataset = dm.test_datasets[index % 3]

        # Attach the grouping columns (trial-level keys) to the predictions.
        trial_info = extract_trial_info(
            dataset, cols_to_keep=cfg.data.groupby_columns
        ).reset_index(drop=True)

        # Unpack model outputs
        preds, probs, y_true = eval_type_results
        if probs is None:
            probs = preds
        df = pd.DataFrame(
            {
                'label': y_true.numpy(),
                'prediction_prob': probs.numpy().tolist(),
                'eval_regime': REGIMES[index % 3],
                'eval_type': eval_type,
                'fold_index': fold_index,
            },
        )

        group_level_metrics.append(pd.concat([df, trial_info], axis=1))

    res = pd.concat(group_level_metrics)

    return res
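
The index arithmetic assumes exactly three val and three test result sets (one per regime). The resulting mapping, assuming REGIMES has three entries:

# index -> (eval_type, dataset / regime slot)
for index in range(6):
    eval_type = 'val' if index in [0, 1, 2] else 'test'
    print(index, eval_type, index % 3)
# prints: 0 val 0, 1 val 1, 2 val 2, 3 test 0, 4 test 1, 5 test 2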