base_datamodule

Data module for creating the eye-tracking datasets and dataloaders.

ETDataModule

Bases: LightningDataModule

A PyTorch Lightning data module for the eye tracking data.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `cfg` | `Args` | The configuration object. |
| `text_dataset_path` | `Path` | The path to the text dataset. |
| `train_dataset` | `ETDataset` | The training dataset. |
| `val_datasets` | `list[ETDataset]` | The validation datasets. |
| `test_datasets` | `list[ETDataset]` | The test datasets. |

Source code in src/data/datamodules/base_datamodule.py
class ETDataModule(pl.LightningDataModule):
    """
    A PyTorch Lightning data module for the eye tracking data.

    Attributes:
        cfg (Args): The configuration object.
        text_dataset_path (Path): The path to the text dataset.
        train_dataset (ETDataset): The training dataset.
        val_datasets (list[ETDataset]): The validation datasets.
        test_datasets (list[ETDataset]): The test datasets.
    """

    def __init__(self, cfg: Args):
        """
        Initialize the ETDataModule instance.

        Args:
            cfg (Args): The configuration object.
        """
        super().__init__()
        self.cfg = cfg

        self.train_dataset: ETDataset
        self.val_datasets: list[ETDataset]
        self.test_datasets: list[ETDataset]

        self.text_dataset_path = (
            FEATURES_CACHE_FOLDER
            / f'{cfg.data.dataset_name}_{cfg.data.task}_{cfg.model.model_name}'
            / 'TextDataSet.pkl'
        )

        self.save_hyperparameters(asdict(self.cfg))

    def setup(self, stage: str | None = None) -> None:
        """
        Set up the data module for training, validation, or testing.

        Args:
            stage (str | None): The stage of the setup. Can be "fit", "test", or "predict".
        """

        ia_scaler = self.cfg.model.normalization_type.value()
        fixation_scaler = self.cfg.model.normalization_type.value()
        trial_features_scaler = self.cfg.model.normalization_type.value()

        self.train_dataset = self.create_etdataset(
            ia_scaler=ia_scaler,
            fixation_scaler=fixation_scaler,
            trial_features_scaler=trial_features_scaler,
            set_name=SetNames.TRAIN,
            regime_name=SetNames.TRAIN,
        )

        if stage in {'fit', 'predict'}:
            self.val_datasets = [
                self.create_etdataset(
                    ia_scaler=self.train_dataset.ia_scaler,
                    fixation_scaler=self.train_dataset.fixation_scaler,
                    trial_features_scaler=self.train_dataset.trial_features_scaler,
                    regime_name=regime_name,
                    set_name=SetNames.VAL,
                )
                for regime_name in REGIMES
            ]

        if stage in {'test', 'predict'}:
            self.test_datasets = [
                self.create_etdataset(
                    ia_scaler=self.train_dataset.ia_scaler,
                    fixation_scaler=self.train_dataset.fixation_scaler,
                    trial_features_scaler=self.train_dataset.trial_features_scaler,
                    regime_name=regime_name,
                    set_name=SetNames.TEST,
                )
                for regime_name in REGIMES
            ]

    @abstractmethod
    def create_etdataset(
        self,
        ia_scaler: Scaler | None,
        fixation_scaler: Scaler | None,
        trial_features_scaler: Scaler | None,
        set_name: SetNames,
        regime_name: SetNames,
    ) -> ETDataset:
        """
        Abstract method to create an ETDataset instance.

        Args:
            ia_scaler (Scaler | None): The IA scaler.
            fixation_scaler (Scaler | None): The fixation scaler.
            trial_features_scaler (Scaler | None): The trial features scaler.
            regime_name (SetNames): The name of the regime (e.g., unseen_subject_seen_item).
            set_name (SetNames): The name of the set (e.g., train, test, val).

        Returns:
            ETDataset: The created ETDataset instance.
        """
        raise NotImplementedError('Subclasses must implement this method.')

    def create_dataloader(
        self,
        dataset,
        shuffle,
        sample_m_per_class: bool = False,
        drop_last: bool = False,
    ) -> DataLoader:
        """
        Create a DataLoader for the given dataset.

        Args:
            dataset (ETDataset): The dataset to create the DataLoader for.
            shuffle (bool): Whether to shuffle the data.
            sample_m_per_class (bool): Whether to use an MPerClassSampler instead of shuffling.
            drop_last (bool): Whether to drop the last incomplete batch.

        Returns:
            DataLoader: The created DataLoader.
        """
        if sample_m_per_class:
            sampler = samplers.MPerClassSampler(
                labels=self.train_dataset.labels,
                m=1,
                length_before_new_iter=self.cfg.trainer.samples_per_epoch,
            )
            shuffle = None
            logger.info(
                f'Using MPerClassSampler with m=1 and {self.cfg.trainer.samples_per_epoch} samples per epoch. Shuffle is set to None.'
            )
        else:
            sampler = None

        return DataLoader(
            dataset,
            batch_size=self.cfg.model.batch_size,
            num_workers=self.cfg.trainer.num_workers,
            shuffle=shuffle,
            pin_memory=True,
            drop_last=drop_last,
            sampler=sampler,
        )

    def train_dataloader(self) -> DataLoader:
        """
        Create the DataLoader for the training dataset.

        Returns:
            DataLoader: The DataLoader for the training dataset.
        """
        return self.create_dataloader(
            self.train_dataset,
            shuffle=True,
            drop_last=False,
            sample_m_per_class=self.cfg.trainer.sample_m_per_class,
        )

    def val_dataloader(self) -> list[DataLoader]:
        """
        Create the DataLoader for the validation datasets.

        Returns:
            list[DataLoader]: A list of DataLoaders for the validation datasets.
        """
        return [
            self.create_dataloader(dataset, shuffle=False, drop_last=False)
            for dataset in self.val_datasets
        ]

    def test_dataloader(self) -> list[DataLoader]:
        """
        Create the DataLoader for the test datasets.

        Returns:
            list[DataLoader]: A list of DataLoaders for the test datasets.
        """
        return [
            self.create_dataloader(dataset, shuffle=False, drop_last=False)
            for dataset in self.test_datasets
        ]

    def predict_dataloader(self) -> list[DataLoader]:
        """
        Create the DataLoader for the prediction datasets.

        Returns:
            list[DataLoader]: A list of DataLoaders for the prediction datasets.
        """
        return self.val_dataloader() + self.test_dataloader()

    def prepare_data(self) -> None:
        """
        Prepare the data for the module.

        """

        self.text_dataset_create_if_needed()

    def text_dataset_create_if_needed(self) -> None:
        """
        If the text dataset does not exist or overwrite_data is True, create and save the text dataset.
        """

        if self.cfg.model.use_eyes_only:
            logger.info('Using eyes only, no text dataset will be created.')
            return

        if self.cfg.trainer.overwrite_data or not self.text_dataset_path.exists():
            self.text_dataset_path.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f'Creating and saving textDataSet to {self.text_dataset_path}')
            # create and save to pkl
            text_data = TextDataSet(cfg=self.cfg)
            with open(self.text_dataset_path, 'wb') as f:
                pickle.dump(text_data, f)
        else:
            logger.info(
                f'TextDataSet already exists at: {self.text_dataset_path} and overwrite is False'
            )

    def load_text_dataset(self) -> TextDataSet:
        """
        Load the text dataset from a pickle file.

        Returns:
            TextDataSet: The loaded text dataset.
        """
        logger.info(f'Loading textDataSet from {self.text_dataset_path}')
        with open(self.text_dataset_path, 'rb') as f:
            text_data = pickle.load(f)
        return text_data
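
For orientation, here is a minimal, hypothetical sketch of how a concrete subclass could be wired into a Lightning `Trainer`. The names `MyETDataModule` and `ConcreteETDataset` are illustrative assumptions; the only abstract piece in the class above is `create_etdataset`.

```python
import pytorch_lightning as pl

class MyETDataModule(ETDataModule):
    # The sole required override: build an ETDataset for one (set, regime) pair.
    def create_etdataset(self, ia_scaler, fixation_scaler,
                         trial_features_scaler, set_name, regime_name):
        # ConcreteETDataset stands in for a real ETDataset subclass.
        return ConcreteETDataset(
            cfg=self.cfg,
            ia_scaler=ia_scaler,
            fixation_scaler=fixation_scaler,
            trial_features_scaler=trial_features_scaler,
            set_name=set_name,
            regime_name=regime_name,
        )

# dm = MyETDataModule(cfg)            # cfg: an Args instance
# trainer = pl.Trainer(max_epochs=1)
# trainer.fit(model, datamodule=dm)   # Lightning calls prepare_data()/setup()
```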

__init__(cfg)

Initialize the ETDataModule instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `Args` | The configuration object. | required |
Source code in src/data/datamodules/base_datamodule.py
def __init__(self, cfg: Args):
    """
    Initialize the ETDataModule instance.

    Args:
        cfg (Args): The configuration object.
    """
    super().__init__()
    self.cfg = cfg

    self.train_dataset: ETDataset
    self.val_datasets: list[ETDataset]
    self.test_datasets: list[ETDataset]

    self.text_dataset_path = (
        FEATURES_CACHE_FOLDER
        / f'{cfg.data.dataset_name}_{cfg.data.task}_{cfg.model.model_name}'
        / 'TextDataSet.pkl'
    )

    self.save_hyperparameters(asdict(self.cfg))
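
Note that `asdict` implies `Args` is a dataclass; `save_hyperparameters` then stores the resulting dict under `self.hparams`, so the full configuration is logged and restored from checkpoints. A minimal sketch of the same pattern, with a made-up `Args` for illustration:

```python
from dataclasses import asdict, dataclass

import pytorch_lightning as pl

@dataclass
class Args:  # stand-in for the real Args configuration object
    batch_size: int = 32

class Demo(pl.LightningDataModule):
    def __init__(self, cfg: Args):
        super().__init__()
        # Recorded under self.hparams and written out by the logger/checkpoint.
        self.save_hyperparameters(asdict(cfg))

dm = Demo(Args())
print(dm.hparams.batch_size)  # 32
```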

create_dataloader(dataset, shuffle, sample_m_per_class=False, drop_last=False)

Create a DataLoader for the given dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset` | `ETDataset` | The dataset to create the DataLoader for. | required |
| `shuffle` | `bool` | Whether to shuffle the data. | required |
| `sample_m_per_class` | `bool` | Whether to use an MPerClassSampler instead of shuffling. | `False` |
| `drop_last` | `bool` | Whether to drop the last incomplete batch. | `False` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `DataLoader` | `DataLoader` | The created DataLoader. |

Source code in src/data/datamodules/base_datamodule.py
def create_dataloader(
    self,
    dataset,
    shuffle,
    sample_m_per_class: bool = False,
    drop_last: bool = False,
) -> DataLoader:
    """
    Create a DataLoader for the given dataset.

    Args:
        dataset (ETDataset): The dataset to create the DataLoader for.
        shuffle (bool): Whether to shuffle the data.
        sample_m_per_class (bool): Whether to use an MPerClassSampler instead of shuffling.
        drop_last (bool): Whether to drop the last incomplete batch.

    Returns:
        DataLoader: The created DataLoader.
    """
    if sample_m_per_class:
        sampler = samplers.MPerClassSampler(
            labels=self.train_dataset.labels,
            m=1,
            length_before_new_iter=self.cfg.trainer.samples_per_epoch,
        )
        shuffle = None
        logger.info(
            f'Using MPerClassSampler with m=1 and {self.cfg.trainer.samples_per_epoch} samples per epoch. Shuffle is set to None.'
        )
    else:
        sampler = None

    return DataLoader(
        dataset,
        batch_size=self.cfg.model.batch_size,
        num_workers=self.cfg.trainer.num_workers,
        shuffle=shuffle,
        pin_memory=True,
        drop_last=drop_last,
        sampler=sampler,
    )
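
The `samplers` module is presumably pytorch-metric-learning's; its `MPerClassSampler` yields dataset indices such that each pass covers `m` samples per class label, which suits metric-learning losses better than plain shuffling. A small self-contained sketch of the sampler on toy labels, assuming that package:

```python
import torch
from pytorch_metric_learning import samplers

labels = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])  # toy class labels
sampler = samplers.MPerClassSampler(
    labels=labels,
    m=1,                        # one sample per class per pass
    length_before_new_iter=6,   # total indices yielded per epoch
)
# Each consecutive group of 3 indices (one per class) covers every class once.
print(list(iter(sampler)))
```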

create_etdataset(ia_scaler, fixation_scaler, trial_features_scaler, set_name, regime_name) abstractmethod

Abstract method to create an ETDataset instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ia_scaler` | `Scaler \| None` | The IA scaler. | required |
| `fixation_scaler` | `Scaler \| None` | The fixation scaler. | required |
| `trial_features_scaler` | `Scaler \| None` | The trial features scaler. | required |
| `set_name` | `SetNames` | The name of the set (e.g., train, test, val). | required |
| `regime_name` | `SetNames` | The name of the regime (e.g., unseen_subject_seen_item). | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `ETDataset` | `ETDataset` | The created ETDataset instance. |

Source code in src/data/datamodules/base_datamodule.py
@abstractmethod
def create_etdataset(
    self,
    ia_scaler: Scaler | None,
    fixation_scaler: Scaler | None,
    trial_features_scaler: Scaler | None,
    set_name: SetNames,
    regime_name: SetNames,
) -> ETDataset:
    """
    Abstract method to create an ETDataset instance.

    Args:
        ia_scaler (Scaler | None): The IA scaler.
        fixation_scaler (Scaler | None): The fixation scaler.
        trial_features_scaler (Scaler | None): The trial features scaler.
        regime_name (SetNames): The name of the regime (e.g., unseen_subject_seen_item).
        set_name (SetNames): The name of the set (e.g., train, test, val).

    Returns:
        ETDataset: The created ETDataset instance.
    """
    raise NotImplementedError('Subclasses must implement this method.')

load_text_dataset()

Load the text dataset from a pickle file.

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `TextDataSet` | `TextDataSet` | The loaded text dataset. |

Source code in src/data/datamodules/base_datamodule.py
def load_text_dataset(self) -> TextDataSet:
    """
    Load the text dataset from a pickle file.

    Returns:
        TextDataSet: The loaded text dataset.
    """
    logger.info(f'Loading textDataSet from {self.text_dataset_path}')
    with open(self.text_dataset_path, 'rb') as f:
        text_data = pickle.load(f)
    return text_data
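
Together with `text_dataset_create_if_needed` below, this forms a simple pickle cache: `prepare_data` writes the file once, and later runs read it back. The same cache-or-build pattern in isolation, with hypothetical names:

```python
import pickle
from pathlib import Path

def load_or_build(path: Path, build):
    """Return the pickled object at path, building and caching it on a miss."""
    if path.exists():
        with open(path, 'rb') as f:
            return pickle.load(f)
    obj = build()
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f)
    return obj

# text_data = load_or_build(Path('cache/TextDataSet.pkl'),
#                           lambda: TextDataSet(cfg=cfg))
```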

predict_dataloader()

Create the DataLoader for the prediction datasets.

Returns:

| Type | Description |
| --- | --- |
| `list[DataLoader]` | A list of DataLoaders for the prediction datasets. |

Source code in src/data/datamodules/base_datamodule.py
def predict_dataloader(self) -> list[DataLoader]:
    """
    Create the DataLoader for the prediction datasets.

    Returns:
        list[DataLoader]: A list of DataLoaders for the prediction datasets.
    """
    return self.val_dataloader() + self.test_dataloader()

prepare_data()

Prepare the data for the module by creating the text dataset if needed.

Source code in src/data/datamodules/base_datamodule.py
def prepare_data(self) -> None:
    """
    Prepare the data for the module.

    """

    self.text_dataset_create_if_needed()

setup(stage=None)

Set up the data module for training, validation, or testing.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `stage` | `str \| None` | The stage of the setup. Can be "fit", "test", or "predict". | `None` |
Source code in src/data/datamodules/base_datamodule.py
def setup(self, stage: str | None = None) -> None:
    """
    Set up the data module for training, validation, or testing.

    Args:
        stage (str | None): The stage of the setup. Can be "fit", "test", or "predict".
    """

    ia_scaler = self.cfg.model.normalization_type.value()
    fixation_scaler = self.cfg.model.normalization_type.value()
    trial_features_scaler = self.cfg.model.normalization_type.value()

    self.train_dataset = self.create_etdataset(
        ia_scaler=ia_scaler,
        fixation_scaler=fixation_scaler,
        trial_features_scaler=trial_features_scaler,
        set_name=SetNames.TRAIN,
        regime_name=SetNames.TRAIN,
    )

    if stage in {'fit', 'predict'}:
        self.val_datasets = [
            self.create_etdataset(
                ia_scaler=self.train_dataset.ia_scaler,
                fixation_scaler=self.train_dataset.fixation_scaler,
                trial_features_scaler=self.train_dataset.trial_features_scaler,
                regime_name=regime_name,
                set_name=SetNames.VAL,
            )
            for regime_name in REGIMES
        ]

    if stage in {'test', 'predict'}:
        self.test_datasets = [
            self.create_etdataset(
                ia_scaler=self.train_dataset.ia_scaler,
                fixation_scaler=self.train_dataset.fixation_scaler,
                trial_features_scaler=self.train_dataset.trial_features_scaler,
                regime_name=regime_name,
                set_name=SetNames.TEST,
            )
            for regime_name in REGIMES
        ]
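
Note the leakage guard built into this method: fresh scalers are created only for the train split (where they are presumably fitted inside the train ETDataset), and the val/test datasets reuse `self.train_dataset`'s scalers. A minimal sketch of the idea with scikit-learn (the concrete `Scaler` class is chosen by `cfg.model.normalization_type`):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

X_train = np.array([[0.0], [2.0], [4.0]])
X_val = np.array([[1.0], [3.0]])

scaler = StandardScaler().fit(X_train)  # statistics come from train only
X_val_scaled = scaler.transform(X_val)  # val reuses the train mean/std
```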

test_dataloader()

Create the DataLoader for the test datasets.

Returns:

| Type | Description |
| --- | --- |
| `list[DataLoader]` | A list of DataLoaders for the test datasets. |

Source code in src/data/datamodules/base_datamodule.py
def test_dataloader(self) -> list[DataLoader]:
    """
    Create the DataLoader for the test datasets.

    Returns:
        list[DataLoader]: A list of DataLoaders for the test datasets.
    """
    return [
        self.create_dataloader(dataset, shuffle=False, drop_last=False)
        for dataset in self.test_datasets
    ]

text_dataset_create_if_needed()

If the text dataset does not exist or overwrite_data is True, create and save the text dataset.

Source code in src/data/datamodules/base_datamodule.py
def text_dataset_create_if_needed(self) -> None:
    """
    If the text dataset does not exist or overwrite_data is True, create and save the text dataset.
    """

    if self.cfg.model.use_eyes_only:
        logger.info('Using eyes only, no text dataset will be created.')
        return

    if self.cfg.trainer.overwrite_data or not self.text_dataset_path.exists():
        self.text_dataset_path.parent.mkdir(parents=True, exist_ok=True)
        logger.info(f'Creating and saving textDataSet to {self.text_dataset_path}')
        # create and save to pkl
        text_data = TextDataSet(cfg=self.cfg)
        with open(self.text_dataset_path, 'wb') as f:
            pickle.dump(text_data, f)
    else:
        logger.info(
            f'TextDataSet already exists at: {self.text_dataset_path} and overwrite is False'
        )

train_dataloader()

Create the DataLoader for the training dataset.

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `DataLoader` | `DataLoader` | The DataLoader for the training dataset. |

Source code in src/data/datamodules/base_datamodule.py
def train_dataloader(self) -> DataLoader:
    """
    Create the DataLoader for the training dataset.

    Returns:
        DataLoader: The DataLoader for the training dataset.
    """
    return self.create_dataloader(
        self.train_dataset,
        shuffle=True,
        drop_last=False,
        sample_m_per_class=self.cfg.trainer.sample_m_per_class,
    )

val_dataloader()

Create the DataLoader for the validation datasets.

Returns:

| Type | Description |
| --- | --- |
| `list[DataLoader]` | A list of DataLoaders for the validation datasets. |

Source code in src/data/datamodules/base_datamodule.py
def val_dataloader(self) -> list[DataLoader]:
    """
    Create the DataLoader for the validation datasets.

    Returns:
        list[DataLoader]: A list of DataLoaders for the validation datasets.
    """
    return [
        self.create_dataloader(dataset, shuffle=False, drop_last=False)
        for dataset in self.val_datasets
    ]
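
Because `val_dataloader` (and `test_dataloader`) return one loader per regime, Lightning evaluates each regime separately and passes a `dataloader_idx` to the model's step hooks. A sketch of the corresponding hook on the `LightningModule` side, assuming per-regime logging is wanted:

```python
import pytorch_lightning as pl

class Model(pl.LightningModule):
    # With a list of val loaders, Lightning supplies dataloader_idx,
    # so metrics can be tagged with the regime they belong to.
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        regime = REGIMES[dataloader_idx]  # REGIMES as used in setup() above
        ...
```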

ETDataModuleFast

Bases: ETDataModule

A subclass of ETDataModule that includes checks to prevent redundant data preparation and setup. Based on the solution provided in https://github.com/Lightning-AI/pytorch-lightning/issues/16005

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `prepare_data_done` | `bool` | A flag indicating whether the prepare_data method has been called. |
| `setup_stages_done` | `set` | A set storing the stages for which the setup method has been called. |

Source code in src/data/datamodules/base_datamodule.py
class ETDataModuleFast(ETDataModule):
    """
    A subclass of ETDataModule that includes checks to prevent redundant data preparation and setup.
    Based on the solution provided in https://github.com/Lightning-AI/pytorch-lightning/issues/16005

    Attributes:
        prepare_data_done (bool): A flag indicating whether the prepare_data method has been called.
        setup_stages_done (set): A set storing the stages for which the setup method has been called.
    """

    def __init__(self, *args: object, **kwargs: object) -> None:
        """
        Initialize the ETDataModuleFast instance.

        Args:
            *args: Variable length argument list to be passed to the ETDataModule constructor.
            **kwargs: Arbitrary keyword arguments to be passed to the ETDataModule constructor.
        """
        super().__init__(*args, **kwargs)
        self.prepare_data_done = False
        self.setup_stages_done = set()

    def prepare_data(self) -> None:
        """
        Prepare data for the module. If this method has been called before, it does nothing.
        """
        if not self.prepare_data_done:
            super().prepare_data()
            self.prepare_data_done = True

    def setup(self, stage: str) -> None:
        """
        Set up the module for a specific stage. If this method has been
        called before for the same stage, it does nothing.

        Args:
            stage (str): The stage for which to set up the module.
        """
        if stage not in self.setup_stages_done:
            super().setup(stage)
            self.setup_stages_done.add(stage)
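
The guard pattern in isolation, as a standalone illustration (not the project's API): repeated calls for the same stage become no-ops, which matters because Lightning may invoke `setup` once per trainer entry point (fit, validate, test, predict).

```python
class IdempotentSetup:
    """Run each stage's setup exactly once, mirroring ETDataModuleFast."""

    def __init__(self):
        self.setup_stages_done = set()

    def setup(self, stage: str) -> None:
        if stage in self.setup_stages_done:
            return                      # already done for this stage: no-op
        print(f'running setup for {stage!r}')
        self.setup_stages_done.add(stage)

dm = IdempotentSetup()
dm.setup('fit')  # prints
dm.setup('fit')  # silent no-op
```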

__init__(*args, **kwargs)

Initialize the ETDataModuleFast instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `*args` | `object` | Variable length argument list passed to the ETDataModule constructor. | `()` |
| `**kwargs` | `object` | Arbitrary keyword arguments passed to the ETDataModule constructor. | `{}` |
Source code in src/data/datamodules/base_datamodule.py
def __init__(self, *args: object, **kwargs: object) -> None:
    """
    Initialize the ETDataModuleFast instance.

    Args:
        *args: Variable length argument list to be passed to the ETDataModule constructor.
        **kwargs: Arbitrary keyword arguments to be passed to the ETDataModule constructor.
    """
    super().__init__(*args, **kwargs)
    self.prepare_data_done = False
    self.setup_stages_done = set()

prepare_data()

Prepare data for the module. If this method has been called before, it does nothing.

Source code in src/data/datamodules/base_datamodule.py
def prepare_data(self) -> None:
    """
    Prepare data for the module. If this method has been called before, it does nothing.
    """
    if not self.prepare_data_done:
        super().prepare_data()
        self.prepare_data_done = True

setup(stage)

Set up the module for a specific stage. If this method has been called before for the same stage, it does nothing.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `stage` | `str` | The stage for which to set up the module. | required |
Source code in src/data/datamodules/base_datamodule.py
def setup(self, stage: str) -> None:
    """
    Set up the module for a specific stage. If this method has been
    called before for the same stage, it does nothing.

    Args:
        stage (str): The stage for which to set up the module.
    """
    if stage not in self.setup_stages_done:
        super().setup(stage)
        self.setup_stages_done.add(stage)