Skip to content

utils

convert_string_to_list(s)

Converts a Pandas Series containing stringified lists into actual lists.

Parameters:

Name Type Description Default
s Series

Series with stringified lists.

required

Returns:

Type Description
list[list[float]]

list[list[float]]: List of lists with float values.

Source code in src/run/single_run/utils.py
358
359
360
361
362
363
364
365
366
367
368
def convert_string_to_list(s: pd.Series) -> list[list[float]]:
    """
    Parse every string in a Series (e.g. ``"[0.1, 0.2]"``) into a Python list.

    Each entry is parsed with ``ast.literal_eval``, so only Python literals
    are accepted and no arbitrary code is executed. Element types are
    whatever ``literal_eval`` produces (``"[1, 2]"`` yields ints); they are
    not coerced to float.

    Args:
        s (pd.Series): Series whose entries are stringified lists.

    Returns:
        list[list[float]]: One parsed list per entry, in Series order.
    """
    parsed = s.map(ast.literal_eval)
    return parsed.to_list()

get_config(config_path)

Load the config of a previous Hydra run, re-compose it from its saved overrides, and instantiate it (used for testing).

Source code in src/run/single_run/utils.py
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
def get_config(config_path: Path) -> Args:
    """
    Recreate and instantiate the config of a previous Hydra run (for testing).

    ``config_path`` is resolved to an absolute path and must contain the
    ``.hydra/overrides.yaml`` and ``.hydra/hydra.yaml`` files of that run.
    The previous job's primary config name and its overrides are read from
    those files. The config is then composed again from scratch and
    instantiated with ``_convert_='object'``.

    NOTE(review): unlike ``instantiate_config``, this does not fill in the
    derived data/model fields or call ``_configure_model_backbone``. Confirm
    that callers do not rely on them.
    NOTE(review): presumably Hydra must already be initialized (e.g. via
    ``hydra.initialize``) for ``compose`` to work; confirm at call sites.

    Args:
        config_path (Path): Output directory of the previous Hydra run.

    Returns:
        Args: The freshly composed and instantiated configuration object.
    """
    run_dir = to_absolute_path(str(config_path))
    previous_overrides = OmegaConf.load(join(run_dir, '.hydra/overrides.yaml'))
    previous_hydra_cfg = OmegaConf.load(join(run_dir, '.hydra/hydra.yaml'))

    # Reuse the primary config name the earlier job was launched with,
    # then re-apply that job's overrides on top of it.
    composed_cfg = compose(
        previous_hydra_cfg.hydra.job.config_name,
        overrides=previous_overrides,
    )
    return instantiate(composed_cfg, _convert_='object')

instantiate_config(cfg)

Instantiate the config object with the appropriate datamodule and model.

Parameters:

Name Type Description Default
cfg DictConfig

The configuration object.

required

Returns:

Name Type Description
Args Args

The instantiated configuration object.

Source code in src/run/single_run/utils.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def instantiate_config(cfg: DictConfig) -> Args:
    """
    Build the runtime ``Args`` object from a Hydra config.

    The config is instantiated with Hydra (``_convert_='object'``), and a
    few derived fields are filled in on the ``data`` and ``model`` sections.
    The result is then passed through ``_configure_model_backbone``.

    Args:
        cfg (DictConfig): The Hydra configuration to instantiate.

    Returns:
        Args: The instantiated and post-processed configuration object.
    """
    args: Args = instantiate(config=cfg, _convert_='object')
    data_section = args.data
    model_section = args.model

    # Derived fields, copied from their source fields.
    data_section.full_dataset_name = data_section.dataset_name
    model_section.full_model_name = model_section.model_name
    model_section.max_time_limit = model_section.max_time
    model_section.is_ml = model_section.base_model_name in MLModelNames

    # Class-weighted loss is forced off unless there is more than one class.
    multi_class = len(list(data_section.class_names)) > 1
    model_section.use_class_weighted_loss = (
        model_section.use_class_weighted_loss if multi_class else False
    )
    return _configure_model_backbone(args)

update_cfg_with_wandb(cfg)

Update the configuration object with the wandb config. This function overwrites the config with the wandb config if the wandb config is not empty.

Parameters:

cfg (Args): The configuration object.

Returns:

Args: The updated configuration object.

Source code in src/run/single_run/utils.py
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
def update_cfg_with_wandb(cfg: Args) -> Args:
    """
    Overwrite fields of the configuration object with values from ``wandb.config``.

    Every item of ``wandb.config`` is copied onto ``cfg``. A dict-valued item
    addresses a nested section (``cfg.<key>``), and its items are copied onto
    that section. One further level of dict nesting
    (``cfg.<key>.<sub_key>``) is also supported; dicts nested deeper than
    that are assigned as plain values. Only attributes that already exist
    are overwritten. An unknown name raises instead of creating a new field.
    When ``wandb.config`` has no items, nothing is overwritten.

    Args:
        cfg (Args): The configuration object; it is mutated via ``setattr``.

    Returns:
        Args: The updated configuration object, after it has been passed
        through ``_configure_model_backbone``.

    Raises:
        AttributeError: If a key in ``wandb.config`` does not name an
        existing attribute at the corresponding level of ``cfg``.
    """
    logger.info('Overwriting args with wandb config')

    def _overwrite(target: object, name: str, new_value, dotted_path: str) -> None:
        # Log first, then validate and assign (same order as the per-level
        # logging in the original implementation).
        logger.info(f'Setting cfg.{dotted_path} to {new_value}')
        if not hasattr(target, name):
            raise AttributeError(
                f"Attribute '{name}' does not exist in {target.__class__.__name__}",
            )
        setattr(target, name, new_value)

    def _section(parent: object, name: str, owner_label: str) -> object:
        # Resolve a nested config section, rejecting unknown section names.
        if not hasattr(parent, name):
            raise AttributeError(
                f"Attribute '{name}' does not exist in {owner_label}",
            )
        return getattr(parent, name)

    for key, value in wandb.config.items():
        if not isinstance(value, dict):
            _overwrite(cfg, key, value, f'{key}')
            continue

        group = _section(cfg, key, cfg.__class__.__name__)
        for sub_key, sub_value in value.items():
            if not isinstance(sub_value, dict):
                _overwrite(group, sub_key, sub_value, f'{key}.{sub_key}')
                continue

            subgroup = _section(group, sub_key, key)
            for leaf_key, leaf_value in sub_value.items():
                _overwrite(
                    subgroup,
                    leaf_key,
                    leaf_value,
                    f'{key}.{sub_key}.{leaf_key}',
                )

    return _configure_model_backbone(cfg)