base

DatasetProcessor

Base class for dataset processors

Source code in src/data/preprocessing/dataset_preprocessing/base.py, lines 18-139

class DatasetProcessor:
    """Base class for dataset processors"""

    def __init__(self, data_args: DataArgs):
        self.data_args = data_args

    def process(self) -> dict[str, pd.DataFrame]:
        """Process the dataset"""
        # Load raw data
        raw_data = {}
        if self.data_args.raw_ia_path:
            raw_data[DataType.IA] = self.load_raw_data(self.data_args.raw_ia_path)
        else:
            raw_data[DataType.IA] = None

        if self.data_args.raw_fixations_path:
            raw_data[DataType.FIXATIONS] = self.load_raw_data(
                self.data_args.raw_fixations_path
            )
        else:
            raw_data[DataType.FIXATIONS] = None

        # Standardize column names
        processed_data = {}
        for data_type, df in raw_data.items():
            if df is not None:
                processed_data[data_type] = self.standardize_column_names(
                    df, data_type=data_type
                )

        # Dataset-specific processing
        processed_data = self.dataset_specific_processing(processed_data)

        # Filter data
        for data_type in processed_data:
            # Trial-level data are computed, so keep them as is
            if data_type != DataType.TRIAL_LEVEL:
                processed_data[data_type] = self.filter_data(processed_data[data_type])

        return processed_data

    def load_raw_data(self, data_path: Path) -> pd.DataFrame:
        return load_raw_data(data_path)

    def standardize_column_names(
        self, df: pd.DataFrame, data_type: DataType
    ) -> pd.DataFrame:
        """Standardize column names for the dataset"""
        # Map of original column names to standardized names
        column_map = self.get_column_map(data_type)

        if column_map:
            # Create a dictionary of only the columns that exist in the dataframe
            valid_columns = {k: v for k, v in column_map.items() if k in df.columns}
            logger.info(
                f'Standardizing column names for {self.data_args.dataset_name} {data_type}: {valid_columns}'
            )
            return df.rename(columns=valid_columns)

        logger.info(
            f'{self.data_args.dataset_name} not found in column maps. No changes made.'
        )
        return df

    def filter_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Filter data based on dataset-specific criteria"""
        columns_to_keep = self.get_columns_to_keep()

        if columns_to_keep:
            missing_columns = [col for col in columns_to_keep if col not in df.columns]
            if missing_columns:
                logger.warning(
                    f'Columns to keep not found in the dataframe: {missing_columns}'
                )
                logger.warning(f'Available columns: {list(df.columns)}')
            existing_columns = [col for col in columns_to_keep if col in df.columns]
            df = df[existing_columns]
            logger.info(
                f'Filtering columns for {self.data_args.dataset_name}: {existing_columns}'
            )

        return df

    def save_processed_data(self, processed_data: dict[str, pd.DataFrame]) -> None:
        """
        Save processed data to processed data folder.

        Args:
            processed_data: Dictionary containing processed dataframes
        """
        self.data_args.processed_data_path.mkdir(parents=True, exist_ok=True)

        # Save each dataframe to the processed folder
        for data_type, df in processed_data.items():
            if df is not None:
                output_path = (
                    self.data_args.processed_data_path / f'{data_type}.feather'
                )
                df.to_feather(output_path)
                logger.info(f'Saved {data_type} to {output_path}')

    @abstractmethod
    def get_column_map(self, data_type: DataType) -> dict:
        """Get column mapping for the dataset"""
        return {}

    @abstractmethod
    def get_columns_to_keep(self) -> list:
        """Get list of columns to keep after filtering"""
        return []

    @abstractmethod
    def dataset_specific_processing(
        self, data_dict: dict[str, pd.DataFrame]
    ) -> dict[str, pd.DataFrame]:
        """Dataset-specific processing steps"""
        # Can use the following for surprisal and other metric calculations
        # from text_metrics.merge_metrics_with_eye_movements import (
        #     add_metrics_to_word_level_eye_tracking_report,
        # )
        # from text_metrics.surprisal_extractors import extractor_switch
        return data_dict
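
A concrete processor only has to supply the three abstract hooks. The following is a minimal sketch, not code from the repository: the subclass name, the import path (derived from the source location above), and all original column names are assumptions.

import pandas as pd

# Import path assumed from the source location shown above; DatasetProcessor
# and DataType are the names used in this module.
from src.data.preprocessing.dataset_preprocessing.base import DatasetProcessor, DataType


class ExampleProcessor(DatasetProcessor):
    """Hypothetical processor for a single eye-tracking dataset."""

    def get_column_map(self, data_type: DataType) -> dict:
        # Hypothetical original names -> standardized names for the IA report.
        if data_type == DataType.IA:
            return {'IA_LABEL': 'word', 'IA_DWELL_TIME': 'total_fixation_duration'}
        return {}

    def get_columns_to_keep(self) -> list:
        return ['subject_id', 'word', 'total_fixation_duration']

    def dataset_specific_processing(
        self, data_dict: dict[str, pd.DataFrame]
    ) -> dict[str, pd.DataFrame]:
        # No extra computation in this sketch.
        return data_dict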

dataset_specific_processing(data_dict) abstractmethod

Dataset-specific processing steps

Source code in src/data/preprocessing/dataset_preprocessing/base.py, lines 129-139

@abstractmethod
def dataset_specific_processing(
    self, data_dict: dict[str, pd.DataFrame]
) -> dict[str, pd.DataFrame]:
    """Dataset-specific processing steps"""
    # Can use the following for surprisal and other metric calculations
    # from text_metrics.merge_metrics_with_eye_movements import (
    #     add_metrics_to_word_level_eye_tracking_report,
    # )
    # from text_metrics.surprisal_extractors import extractor_switch
    return data_dict
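
Inside a subclass, an override can add computed columns or build trial-level summaries before filtering. A hypothetical sketch (the 'word' column and the derived feature are illustrative only, not from the repository):

    def dataset_specific_processing(
        self, data_dict: dict[str, pd.DataFrame]
    ) -> dict[str, pd.DataFrame]:
        """Hypothetical override: add a simple word-level feature to the IA report."""
        ia = data_dict.get(DataType.IA)
        if ia is not None:
            # Assumes the standardized IA report has a 'word' column.
            ia = ia.assign(word_length=ia['word'].astype(str).str.len())
            data_dict[DataType.IA] = ia
        return data_dict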

filter_data(df)

Filter data based on dataset-specific criteria

Source code in src/data/preprocessing/dataset_preprocessing/base.py, lines 82-99

def filter_data(self, df: pd.DataFrame) -> pd.DataFrame:
    """Filter data based on dataset-specific criteria"""
    columns_to_keep = self.get_columns_to_keep()

    if columns_to_keep:
        missing_columns = [col for col in columns_to_keep if col not in df.columns]
        if missing_columns:
            logger.warning(
                f'Columns to keep not found in the dataframe: {missing_columns}'
            )
            logger.warning(f'Available columns: {list(df.columns)}')
        existing_columns = [col for col in columns_to_keep if col in df.columns]
        df = df[existing_columns]
        logger.info(
            f'Filtering columns for {self.data_args.dataset_name}: {existing_columns}'
        )

    return df
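
For illustration, assuming a subclass whose get_columns_to_keep() returns the hypothetical list below: listed columns that exist are kept, unlisted ones are dropped, and listed-but-missing ones are only logged. A self-contained sketch of the equivalent pandas behavior:

import pandas as pd

df = pd.DataFrame({'subject_id': [1, 1], 'word': ['the', 'cat'], 'raw_gaze_x': [0.2, 0.4]})
columns_to_keep = ['subject_id', 'word', 'total_fixation_duration']  # hypothetical

# filter_data keeps only the listed columns that exist ...
existing = [c for c in columns_to_keep if c in df.columns]
print(df[existing].columns.tolist())  # ['subject_id', 'word']

# ... and reports missing ones via logger.warning instead of raising an error.
missing = [c for c in columns_to_keep if c not in df.columns]
print(missing)  # ['total_fixation_duration']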

get_column_map(data_type) abstractmethod

Get column mapping for the dataset

Source code in src/data/preprocessing/dataset_preprocessing/base.py, lines 119-122

@abstractmethod
def get_column_map(self, data_type: DataType) -> dict:
    """Get column mapping for the dataset"""
    return {}
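
A subclass typically branches on the report type; the mapping below is purely illustrative (the original column names are invented, not from the repository):

    def get_column_map(self, data_type: DataType) -> dict:
        # Hypothetical original -> standardized names per report type.
        if data_type == DataType.IA:
            return {'IA_FIRST_FIXATION_DURATION': 'first_fixation_duration'}
        if data_type == DataType.FIXATIONS:
            return {'CURRENT_FIX_DURATION': 'fixation_duration'}
        return {}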

get_columns_to_keep() abstractmethod

Get list of columns to keep after filtering

Source code in src/data/preprocessing/dataset_preprocessing/base.py, lines 124-127

@abstractmethod
def get_columns_to_keep(self) -> list:
    """Get list of columns to keep after filtering"""
    return []

process()

Process the dataset

Source code in src/data/preprocessing/dataset_preprocessing/base.py, lines 24-57

def process(self) -> dict[str, pd.DataFrame]:
    """Process the dataset"""
    # Load raw data
    raw_data = {}
    if self.data_args.raw_ia_path:
        raw_data[DataType.IA] = self.load_raw_data(self.data_args.raw_ia_path)
    else:
        raw_data[DataType.IA] = None

    if self.data_args.raw_fixations_path:
        raw_data[DataType.FIXATIONS] = self.load_raw_data(
            self.data_args.raw_fixations_path
        )
    else:
        raw_data[DataType.FIXATIONS] = None

    # Standardize column names
    processed_data = {}
    for data_type, df in raw_data.items():
        if df is not None:
            processed_data[data_type] = self.standardize_column_names(
                df, data_type=data_type
            )

    # Dataset-specific processing
    processed_data = self.dataset_specific_processing(processed_data)

    # Filter data
    for data_type in processed_data:
        # Trial-level data are computed, so keep them as is
        if data_type != DataType.TRIAL_LEVEL:
            processed_data[data_type] = self.filter_data(processed_data[data_type])

    return processed_data
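
End to end, a caller builds DataArgs, instantiates a concrete subclass, runs process(), and usually writes the result with save_processed_data(). A hedged usage sketch, reusing the hypothetical ExampleProcessor above; the paths and dataset name are invented, DataArgs may require additional fields not shown on this page, and its import location via base is an assumption:

from pathlib import Path

from src.data.preprocessing.dataset_preprocessing.base import DataArgs  # import location assumed

data_args = DataArgs(
    dataset_name='example_dataset',
    raw_ia_path=Path('data/raw/example/ia_report.csv'),
    raw_fixations_path=Path('data/raw/example/fixations_report.csv'),
    processed_data_path=Path('data/processed/example'),
)

processor = ExampleProcessor(data_args)
processed = processor.process()            # dict keyed by DataType
processor.save_processed_data(processed)   # one .feather file per report type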

save_processed_data(processed_data)

Save processed data to processed data folder.

Parameters:

processed_data (dict[str, DataFrame]): Dictionary containing processed dataframes. Required.
Source code in src/data/preprocessing/dataset_preprocessing/base.py, lines 101-117

def save_processed_data(self, processed_data: dict[str, pd.DataFrame]) -> None:
    """
    Save processed data to processed data folder.

    Args:
        processed_data: Dictionary containing processed dataframes
    """
    self.data_args.processed_data_path.mkdir(parents=True, exist_ok=True)

    # Save each dataframe to the processed folder
    for data_type, df in processed_data.items():
        if df is not None:
            output_path = (
                self.data_args.processed_data_path / f'{data_type}.feather'
            )
            df.to_feather(output_path)
            logger.info(f'Saved {data_type} to {output_path}')

standardize_column_names(df, data_type)

Standardize column names for the dataset

Source code in src/data/preprocessing/dataset_preprocessing/base.py, lines 62-80

def standardize_column_names(
    self, df: pd.DataFrame, data_type: DataType
) -> pd.DataFrame:
    """Standardize column names for the dataset"""
    # Map of original column names to standardized names
    column_map = self.get_column_map(data_type)

    if column_map:
        # Create a dictionary of only the columns that exist in the dataframe
        valid_columns = {k: v for k, v in column_map.items() if k in df.columns}
        logger.info(
            f'Standardizing column names for {self.data_args.dataset_name} {data_type}: {valid_columns}'
        )
        return df.rename(columns=valid_columns)

    logger.info(
        f'{self.data_args.dataset_name} not found in column maps. No changes made.'
    )
    return df
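
As a small self-contained illustration of the renaming behavior (column names invented): only mapped columns that actually exist in the dataframe are renamed, and unmapped columns pass through unchanged.

import pandas as pd

df = pd.DataFrame({'IA_LABEL': ['the'], 'IA_DWELL_TIME': [215], 'extra': [1]})
# Hypothetical column map, e.g. what get_column_map(DataType.IA) might return:
column_map = {'IA_LABEL': 'word', 'IA_DWELL_TIME': 'total_fixation_duration', 'MISSING_COL': 'x'}

# standardize_column_names only applies mappings whose source column exists:
valid = {k: v for k, v in column_map.items() if k in df.columns}
print(df.rename(columns=valid).columns.tolist())
# ['word', 'total_fixation_duration', 'extra']  -- 'MISSING_COL' is ignored, 'extra' passes through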