main_config

Configuration file for the model, trainer, data paths and data.

Args dataclass

Configuration class for the model, trainer, data paths, and data.

Attributes:

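model (BaseModelArgs): Configuration for the model.

trainer (BaseTrainer): Configuration for the trainer.

data (DataArgs): Configuration for the data.

eval_path (str | None): Path for evaluation. Defaults to None.

hydra (Any): Hydra settings for the run and sweep output directories. By default, data.fold_index and trainer.devices are excluded from the override directory name.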
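The attribute layout above can be illustrated with a minimal, self-contained sketch. The stub classes below are hypothetical stand-ins for BaseModelArgs and DataArgs (their real fields are not shown on this page, except for base_model_name and fold_index, which the source references), and the path passed to eval_path is invented for illustration. The sketch only shows how nested dataclass defaults built with default_factory behave: each instance gets its own nested objects.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class StubModelArgs:
    # Hypothetical stand-in for BaseModelArgs.
    base_model_name: str = 'stub'


@dataclass
class StubDataArgs:
    # Hypothetical stand-in for DataArgs.
    fold_index: int = 0


@dataclass
class StubArgs:
    # Mirrors the shape of Args, using the stubs above.
    model: StubModelArgs = field(default_factory=StubModelArgs)
    data: StubDataArgs = field(default_factory=StubDataArgs)
    eval_path: Optional[str] = None


first = StubArgs()
second = StubArgs(eval_path='checkpoints/example')  # hypothetical path

# default_factory creates a fresh nested object per instance,
# so mutating one config does not leak into another.
first.data.fold_index = 3
print(first.data.fold_index, second.data.fold_index)
```

Using default_factory rather than a shared default instance is what keeps the two configurations independent.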

Source code in src/configs/main_config.py
@dataclass
class Args:
    """
    Configuration class for the model, trainer, data paths, and data.

    Attributes:
        model (BaseModelArgs): Configuration for the model.
        trainer (BaseTrainer): Configuration for the trainer.
        data (DataArgs): Configuration for the data.
        eval_path (str | None): Path for evaluation.
        hydra (Any): Configuration for Hydra.
    """

    model: BaseModelArgs = field(default_factory=BaseModelArgs)
    trainer: BaseTrainer = field(default_factory=BaseTrainer)
    data: DataArgs = field(default_factory=DataArgs)
    eval_path: str | None = None
    # https://hydra.cc/docs/1.3/configure_hydra/workdir/
    hydra: Any = field(
        default_factory=lambda: {
            'run': {
                'dir': 'outputs/${hydra:job.override_dirname}/fold_index=${data.fold_index}',
            },
            'sweep': {
                'dir': 'cross_validation_runs',
                # https://github.com/facebookresearch/hydra/issues/1786#issuecomment-1017005470
                'subdir': '${hydra:job.override_dirname}/fold_index=${data.fold_index}',
            },
            'job': {
                'config': {
                    'override_dirname': {
                        # Don't include fold_index and devices in the directory name
                        'exclude_keys': [
                            'data.fold_index',
                            'trainer.devices',
                        ],
                    },
                },
            },
        },
    )
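To make the effect of exclude_keys concrete, here is a rough pure-Python approximation of how Hydra derives job.override_dirname from command-line overrides: excluded keys are dropped, the remaining overrides are sorted and joined. This is a sketch that assumes Hydra's default separators (',' between items, '=' between key and value), not Hydra's actual implementation. The override keys model.lr and data.batch_size are hypothetical.

```python
def override_dirname(overrides, exclude_keys, item_sep=',', kv_sep='='):
    """Approximate Hydra's override_dirname: filter, sort, join."""
    kept = []
    for override in overrides:
        # Strip the '+' / '~' prefixes Hydra allows before the key.
        key = override.split('=', 1)[0].lstrip('+~')
        if key not in exclude_keys:
            kept.append(override)
    kept.sort()
    return item_sep.join(kept).replace('=', kv_sep)


overrides = [
    'model.lr=0.1',          # hypothetical key
    'data.fold_index=2',
    'trainer.devices=1',
    'data.batch_size=32',    # hypothetical key
]
exclude_keys = ['data.fold_index', 'trainer.devices']

dirname = override_dirname(overrides, exclude_keys)
# Mirrors the 'run.dir' pattern from the Args defaults.
run_dir = f'outputs/{dirname}/fold_index=2'
print(run_dir)
```

Because fold_index is excluded from the directory name and appended as a subdirectory instead, all folds of one hyperparameter setting end up under the same parent directory.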

get_model(cfg)

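Instantiates the model class registered under cfg.model.base_model_name and, if cfg.trainer.use_torch_compile is set, wraps the model with torch.compile.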

Parameters:

cfg (Args): Configuration object containing model parameters. Required.

Returns:

BaseModel | BaseMLModel: An instance of the model class.

Source code in src/configs/main_config.py
def get_model(cfg: Args) -> BaseModel | BaseMLModel:
    """
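    Builds the model selected by `cfg.model.base_model_name`.

    The model is wrapped with `torch.compile` when
    `cfg.trainer.use_torch_compile` is set.

    Args:
        cfg (Args): Configuration object containing model parameters.

    Returns:
        BaseModel | BaseMLModel: An instance of the model class.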
    """

    model_class = ModelFactory.get(cfg.model.base_model_name)
    model = model_class(
        trainer_args=cfg.trainer,
        model_args=cfg.model,
        data_args=cfg.data,
    )

    if getattr(cfg.trainer, 'use_torch_compile', False):
        logger.info('Using torch.compile')
        model = torch.compile(
            model,
            mode='reduce-overhead',
        )

    return model
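The lookup-then-construct flow in get_model can be sketched with a small name-to-class registry. The real ModelFactory is not shown on this page, so everything below is an assumption made for illustration: the register decorator, the DummyModel class, and the SimpleNamespace config are all hypothetical. Only the ModelFactory.get(name) call and the three constructor keywords come from the source.

```python
from types import SimpleNamespace


class ModelFactory:
    """Hypothetical registry mapping model names to model classes."""

    _registry: dict = {}

    @classmethod
    def register(cls, name):
        def decorator(model_class):
            cls._registry[name] = model_class
            return model_class
        return decorator

    @classmethod
    def get(cls, name):
        try:
            return cls._registry[name]
        except KeyError:
            known = ', '.join(sorted(cls._registry))
            raise ValueError(f'Unknown model {name!r}; known: {known}') from None


@ModelFactory.register('dummy')
class DummyModel:
    # Hypothetical model taking the same keyword arguments as in get_model.
    def __init__(self, trainer_args, model_args, data_args):
        self.trainer_args = trainer_args
        self.model_args = model_args
        self.data_args = data_args


# Stand-in for an Args instance; only the fields used here are set.
cfg = SimpleNamespace(
    model=SimpleNamespace(base_model_name='dummy'),
    trainer=SimpleNamespace(),
    data=SimpleNamespace(fold_index=0),
)

model_class = ModelFactory.get(cfg.model.base_model_name)
model = model_class(
    trainer_args=cfg.trainer,
    model_args=cfg.model,
    data_args=cfg.data,
)

# Same optional-flag pattern as get_model: a missing attribute means "off".
use_compile = getattr(cfg.trainer, 'use_torch_compile', False)
print(type(model).__name__, use_compile)
```

Reading the flag with getattr and a False default lets trainer configurations that do not define use_torch_compile skip compilation without raising an AttributeError.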