trainers
This module contains dataclasses for configuring trainers.
The module defines a hierarchy of configuration classes:
- BaseTrainer: The root configuration class with common attributes
- TrainerDL: Configuration for deep learning trainers, inheriting from BaseTrainer
- TrainerML: Configuration for machine learning trainers, inheriting from BaseTrainer
Each configuration class is defined using the @dataclass decorator and specifies
the relevant attributes and their default values.
The @register_config decorator registers each configuration class under the group
defined by ConfigName.TRAINER.
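The registration step can be pictured with a minimal, self-contained sketch. Everything in it is hypothetical: this page does not document `register_config`'s real signature or storage, so the sketch assumes it is a decorator factory taking a group and recording the class in a module-level dict (`CONFIG_REGISTRY`, a made-up name). `ConfigName` is reduced to a single stand-in member, and `ExampleTrainer` exists only to exercise the decorator.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict


class ConfigName(str, Enum):
    # Hypothetical stand-in: only the TRAINER group is modelled.
    TRAINER = "trainer"


# Hypothetical registry mapping group name -> {class name -> config class}.
CONFIG_REGISTRY: Dict[str, Dict[str, type]] = {}


def register_config(group: ConfigName) -> Callable[[type], type]:
    """Hypothetical decorator factory that files a config class under `group`."""

    def decorator(cls: type) -> type:
        CONFIG_REGISTRY.setdefault(group.value, {})[cls.__name__] = cls
        return cls  # hand the class back unchanged so it stays usable

    return decorator


# Decorators apply bottom-up: @dataclass runs first, then registration.
@register_config(group=ConfigName.TRAINER)
@dataclass
class ExampleTrainer:
    # Hypothetical config class used only to exercise the decorator.
    seed: int = 42


print(CONFIG_REGISTRY["trainer"])
```

Returning the class unchanged keeps the decorator transparent: the registered class is still an ordinary dataclass that can be instantiated directly.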
BaseTrainer
dataclass
Base configuration class for all trainers.
This class defines common attributes shared by both deep learning and machine learning trainers.
Attributes:
| Name | Type | Description |
|---|---|---|
| `num_workers` | `int` | Number of worker processes for data loading. Default is 4. |
| `profiler` | `str \| None` | Profiler to use ('simple', 'advanced', or None). Default is None. |
| `precision` | `Precision` | Numerical precision for training. Default is Precision.THIRTY_TWO_TRUE. |
| `float32_matmul_precision` | `MatmulPrecisionLevel` | Matrix multiplication precision level. Default is MatmulPrecisionLevel.HIGH. |
| `seed` | `int` | Random seed for reproducibility. Default is 42. |
| `devices` | `Any` | Device configuration for training. Default is 1. |
| `run_mode` | `RunModes` | Mode for running the trainer (e.g., 'train', 'test', 'debug'). Default is RunModes.TRAIN. |
| `wandb_job_type` | `str` | Type of job for Weights & Biases logging. Default is "MISSING". |
| `wandb_project` | `str` | Weights & Biases project name. Default is "reading-comprehension-from-eye-movements". |
| `wandb_entity` | `str` | Weights & Biases entity name. Default is "EyeRead". |
| `wandb_notes` | `str` | Additional notes for Weights & Biases logging. Default is an empty string. |
| `overwrite_data` | `bool` | If True, overwrites the relevant TextDataSet and ETDataset features even if they already exist. |
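The attribute table above maps onto a flat dataclass with defaults. The sketch below is a hypothetical reconstruction, not the source in `src/configs/trainers.py`: the enum member values (for example `"32-true"` for `Precision.THIRTY_TWO_TRUE`) are assumptions, `MatmulPrecisionLevel` is assumed to mirror the strings accepted by `torch.set_float32_matmul_precision`, and `overwrite_data` is given a default of `False` only because the table states none.

```python
from dataclasses import dataclass, fields
from enum import Enum
from typing import Any, Optional


class Precision(str, Enum):
    # Assumed value; only the default member named in the table is modelled.
    THIRTY_TWO_TRUE = "32-true"


class MatmulPrecisionLevel(str, Enum):
    # Assumed to mirror torch.set_float32_matmul_precision's accepted strings.
    HIGHEST = "highest"
    HIGH = "high"
    MEDIUM = "medium"


class RunModes(str, Enum):
    # Assumed values for the modes mentioned in the table.
    TRAIN = "train"
    TEST = "test"
    DEBUG = "debug"


@dataclass
class BaseTrainer:
    num_workers: int = 4
    profiler: Optional[str] = None  # 'simple', 'advanced', or None
    precision: Precision = Precision.THIRTY_TWO_TRUE
    float32_matmul_precision: MatmulPrecisionLevel = MatmulPrecisionLevel.HIGH
    seed: int = 42
    devices: Any = 1
    run_mode: RunModes = RunModes.TRAIN
    wandb_job_type: str = "MISSING"
    wandb_project: str = "reading-comprehension-from-eye-movements"
    wandb_entity: str = "EyeRead"
    wandb_notes: str = ""
    overwrite_data: bool = False  # assumption: the table gives no default


# Override a couple of fields; everything else keeps its default.
cfg = BaseTrainer(seed=0, devices=[0, 1])
print(len(fields(BaseTrainer)), cfg.seed, cfg.devices)
```

Because every field has a default, subclasses such as TrainerDL and TrainerML can append their own defaulted fields without tripping the dataclass rule that non-default fields may not follow default ones.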
Source code in src/configs/trainers.py
__post_init__()
Post-initialization hook to adjust attributes based on the run mode.
If the run mode is set to 'debug', the number of workers is set to 0 and the Weights & Biases job type is set to "debug".
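The debug-mode adjustment can be sketched with a trimmed-down dataclass. This is an illustration under assumptions, not the project's code: `BaseTrainer` is reduced to the three fields the hook touches, and the `RunModes` member values are assumed.

```python
from dataclasses import dataclass
from enum import Enum


class RunModes(str, Enum):
    # Assumed values; only TRAIN and DEBUG matter for this hook.
    TRAIN = "train"
    DEBUG = "debug"


@dataclass
class BaseTrainer:
    # Reduced to the fields that __post_init__ reads or writes.
    num_workers: int = 4
    run_mode: RunModes = RunModes.TRAIN
    wandb_job_type: str = "MISSING"

    def __post_init__(self) -> None:
        # Debug mode: use no worker processes and tag the W&B job as "debug".
        if self.run_mode == RunModes.DEBUG:
            self.num_workers = 0
            self.wandb_job_type = "debug"


debug_cfg = BaseTrainer(run_mode=RunModes.DEBUG, num_workers=8)
print(debug_cfg.num_workers, debug_cfg.wandb_job_type)
```

Since `__post_init__` runs after the generated `__init__`, the debug settings win even when the caller passes an explicit `num_workers`.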
Source code in src/configs/trainers.py
TrainerDL
dataclass
Bases: BaseTrainer
Configuration class for deep learning trainers.
Inherits from BaseTrainer and adds attributes specific to deep learning models.
Attributes:
| Name | Type | Description |
|---|---|---|
| `learning_rate` | `float` | Optimizer learning rate. Must be specified by derived classes. |
| `gradient_clip_val` | `float \| None` | Gradient clipping value. Default is None. |
| `accelerator` | `Accelerators` | Accelerator to use (e.g., 'cpu', 'gpu', 'tpu'). Default is Accelerators.AUTO. |
| `log_gradients` | `bool` | Whether to log gradients. Default is False. |
| `optimize_for_loss` | `bool` | Whether to optimize for loss instead of metrics. Default is True. |
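Since `learning_rate` must be supplied by derived classes, one way to picture the pattern is a subclass that redefines the inherited field with a concrete default. This is a sketch under stated assumptions: the placeholder `None` for an unset learning rate, the `Accelerators` member values, the trimmed `BaseTrainer`, and the derived class `MyModelTrainer` are all hypothetical.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Accelerators(str, Enum):
    # Assumed values for the accelerators named in the table.
    AUTO = "auto"
    CPU = "cpu"
    GPU = "gpu"
    TPU = "tpu"


@dataclass
class BaseTrainer:
    # Trimmed stand-in for the full BaseTrainer.
    seed: int = 42


@dataclass
class TrainerDL(BaseTrainer):
    # Assumption: None marks "not yet set"; the real placeholder may differ.
    learning_rate: Optional[float] = None
    gradient_clip_val: Optional[float] = None
    accelerator: Accelerators = Accelerators.AUTO
    log_gradients: bool = False
    optimize_for_loss: bool = True


@dataclass
class MyModelTrainer(TrainerDL):
    # Hypothetical derived config: redefining a field replaces its default
    # while the field keeps its original position in the generated __init__.
    learning_rate: float = 1e-3
    gradient_clip_val: Optional[float] = 1.0


trainer_cfg = MyModelTrainer()
print(trainer_cfg.learning_rate, trainer_cfg.accelerator.value, trainer_cfg.seed)
```

Inherited fields (`seed`, `accelerator`, and so on) keep their defaults unless the subclass redefines them.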
Source code in src/configs/trainers.py
TrainerML
dataclass
Bases: BaseTrainer
Configuration class for machine learning trainers. Inherits from BaseTrainer and adds attributes specific to machine learning models.
Source code in src/configs/trainers.py