Skip to content

ahn_model

Ahn et al. baseline models. Based on https://github.com/aeye-lab/etra-reading-comprehension/blob/main/ahn_baseline/evaluate_ahn_baseline.py and https://github.com/ahnchive/SB-SAT/blob/master/model/model_training.ipynb.

AhnCNNModel

Bases: AhnModel

CNN model for Ahn et al. baseline

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model_args` | `AhnCNN` | The model arguments. | *required* |
| `trainer_args` | `TrainerDL` | The trainer arguments. | *required* |
| `data_args` | `DataArgs` | The data arguments. | *required* |
Source code in `src/models/ahn_model.py` (lines 142–215):
@register_model
class AhnCNNModel(AhnModel):
    """
    CNN model for Ahn et al. baseline.

    Three stacked unpadded ``Conv1d`` + ReLU layers over the scanpath, then
    max-pooling, dropout and flattening, followed by a three-layer MLP head.

    Args:
        model_args (AhnCNN): The model arguments.
        trainer_args (TrainerDL): The trainer arguments.
        data_args (DataArgs): The data arguments.
    """

    def __init__(
        self,
        model_args: AhnCNN,
        trainer_args: TrainerDL,
        data_args: DataArgs,
    ):
        super().__init__(model_args, trainer_args, data_args=data_args)

        self.model_args = model_args
        hidden_dim = self.model_args.hidden_dim
        kernel_size = self.model_args.conv_kernel_size
        pooling_kernel_size = self.model_args.pooling_kernel_size
        fc_dropout = self.model_args.fc_dropout
        fc_hidden_dim1 = self.model_args.fc_hidden_dim1
        fc_hidden_dim2 = self.model_args.fc_hidden_dim2

        num_conv_layers = 3
        self.conv_model = nn.Sequential(
            # input: (batch size, input_dim, max seq len)
            nn.Conv1d(
                in_channels=self.input_dim,
                out_channels=hidden_dim,
                kernel_size=kernel_size,
            ),  # (batch size, hidden_dim, max seq len - (kernel_size - 1))
            nn.ReLU(),
            nn.Conv1d(
                in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=kernel_size
            ),  # (batch size, hidden_dim, max seq len - 2 * (kernel_size - 1))
            nn.ReLU(),
            nn.Conv1d(
                in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=kernel_size
            ),  # (batch size, hidden_dim, max seq len - 3 * (kernel_size - 1))
            nn.ReLU(),
            nn.MaxPool1d(
                kernel_size=pooling_kernel_size
            ),  # (batch size, hidden_dim, pooled_len)
            nn.Dropout(fc_dropout),
            nn.Flatten(),  # (batch size, hidden_dim * pooled_len)
        )

        # Derive the flattened conv output size from the configured kernels
        # instead of hard-coding it: each unpadded, stride-1 Conv1d shortens the
        # sequence by (kernel_size - 1), and MaxPool1d with its default stride
        # (= its kernel size) floor-divides the length. For kernel_size=3 and
        # pooling_kernel_size=2 this equals the former (max_len - 6) // 2.
        # NOTE(review): assumes both kernel sizes are plain ints — confirm config.
        conv_len = self.max_scanpath_length - num_conv_layers * (kernel_size - 1)
        pooled_len = conv_len // pooling_kernel_size

        self.fc = nn.Sequential(
            nn.Linear(
                pooled_len * hidden_dim, fc_hidden_dim1
            ),  # (batch size, fc_hidden_dim1)
            nn.ReLU(),
            nn.Dropout(fc_dropout),
            nn.Linear(fc_hidden_dim1, fc_hidden_dim2),  # (batch size, fc_hidden_dim2)
            nn.ReLU(),
            nn.Linear(fc_hidden_dim2, self.num_classes),  # (batch size, num_classes)
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the CNN model.

        Args:
            x (torch.Tensor): The input tensor. Dims 1 and 2 are swapped before
                the convolutions, so it is presumably
                (batch size, max seq len, input_dim) — confirm with the caller.

        Returns:
            tuple: ``(logits, hidden_representations)``, where
            ``hidden_representations`` is the flattened output of
            ``self.conv_model`` that is fed to ``self.fc``.
        """
        x = x.transpose(1, 2)  # Conv1d wants the feature axis on dim 1
        hidden_representations = self.conv_model(x)
        logits = self.fc(hidden_representations)
        return logits, hidden_representations

forward(x)

Forward pass of the CNN model.

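Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Tensor` | The input tensor. | *required* |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `tuple` | `tuple[Tensor, Tensor]` | A tuple containing the output tensor (logits) and the hidden representations. |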

Source code in `src/models/ahn_model.py` (lines 202–215):
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Run the CNN: convolutional feature extractor, then the MLP head.

    Args:
        x (torch.Tensor): Input batch. Dims 1 and 2 are swapped so that the
            feature axis becomes the Conv1d channel axis.

    Returns:
        tuple: ``(output, hidden_representations)``, where the second element
        is what ``self.conv_model`` produced and the first is ``self.fc``
        applied to it.
    """
    conv_features = self.conv_model(x.transpose(1, 2))
    return self.fc(conv_features), conv_features

AhnModel

Bases: BaseModel

Base model for Ahn et al.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model_args` | `Ahn` | The model arguments. | *required* |
| `trainer_args` | `TrainerDL` | The trainer arguments. | *required* |
| `data_args` | `DataArgs` | The data arguments. | *required* |
Source code in `src/models/ahn_model.py` (lines 16–80):
class AhnModel(BaseModel):
    """
    Base model for Ahn et al.

    Chooses the per-step input feature size and implements the shared
    training/validation step. Subclasses build their layers and override
    ``forward`` to return ``(logits, hidden_representations)``.

    Args:
        model_args (Ahn): The model arguments.
        trainer_args (TrainerDL): The trainer arguments.
        data_args (DataArgs): The data arguments.
    """

    def __init__(
        self,
        model_args: Ahn,
        trainer_args: TrainerDL,
        data_args: DataArgs,
    ):
        super().__init__(
            model_args=model_args, trainer_args=trainer_args, data_args=data_args
        )
        self.model_args = model_args
        # Feature size of each input step: fixation-report features when
        # enabled, otherwise the raw eye-movement features.
        self.input_dim = (
            model_args.fixation_dim
            if model_args.use_fixation_report
            else model_args.eyes_dim
        )
        self.preorder = model_args.preorder
        self.model: nn.Module

        self.train()
        self.save_hyperparameters()

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the model; must be overridden by subclasses.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            tuple: ``(logits, hidden_representations)`` — the pair that
            ``shared_step`` unpacks.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(f'{type(self).__name__} must implement forward()')

    def shared_step(
        self, batch: list
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Shared step for training and validation.

        Args:
            batch (list): The input batch, decoded via ``self.unpack_batch``.

        Returns:
            tuple: ``(labels, loss, logits)``. ``logits`` gets a leading batch
            dimension when the model returns a 1-D tensor.

        Raises:
            AssertionError: If the unpacked batch has no fixation features.
        """
        batch_data = self.unpack_batch(batch)
        assert (
            batch_data.fixation_features is not None
        ), 'fixation_features not in batch'
        labels = batch_data.labels
        logits, _ = self(x=batch_data.fixation_features)

        # A 1-D output (presumably a batch of one) gets a batch dim for the loss.
        if logits.ndim == 1:
            logits = logits.unsqueeze(0)
        loss = self.loss(logits, labels)

        return labels, loss, logits

forward(x)

Forward pass of the model.

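Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Tensor` | The input tensor. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Tensor` | The output tensor. |

In this base class the method always raises `NotImplementedError`; subclasses must override it.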

Source code in `src/models/ahn_model.py` (lines 47–57):
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass of the model; abstract in the base class.

    Subclasses must override this and return
    ``(logits, hidden_representations)``, which is the pair ``shared_step``
    unpacks.

    Args:
        x (torch.Tensor): The input tensor.

    Returns:
        tuple: ``(logits, hidden_representations)`` in subclasses.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError(f'{type(self).__name__} must implement forward()')

shared_step(batch)

Shared step for training and validation.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `batch` | `list` | The input batch. | *required* |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `tuple` | `tuple[Tensor, Tensor, Tensor]` | A tuple `(labels, loss, logits)`. |

Source code in `src/models/ahn_model.py` (lines 59–80):
def shared_step(
    self, batch: list
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Shared step for training and validation.

    Args:
        batch (list): The input batch, decoded via ``self.unpack_batch``.

    Returns:
        tuple: ``(labels, loss, logits)``. ``logits`` gets a leading batch
        dimension when the model returns a 1-D tensor.

    Raises:
        AssertionError: If the unpacked batch has no fixation features.
    """
    batch_data = self.unpack_batch(batch)
    assert (
        batch_data.fixation_features is not None
    ), 'fixation_features not in batch'
    labels = batch_data.labels
    logits, _ = self(x=batch_data.fixation_features)

    # A 1-D output (presumably a batch of one) gets a batch dim for the loss.
    if logits.ndim == 1:
        logits = logits.unsqueeze(0)
    loss = self.loss(logits, labels)

    return labels, loss, logits

AhnRNNModel

Bases: AhnModel

RNN model for Ahn et al. baseline

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `model_args` | `AhnRNN` | The model arguments. | *required* |
| `trainer_args` | `TrainerDL` | The trainer arguments. | *required* |
| `data_args` | `DataArgs` | The data arguments. | *required* |
Source code in `src/models/ahn_model.py` (lines 83–139):
@register_model
class AhnRNNModel(AhnModel):
    """
    RNN model for Ahn et al. baseline.

    A bidirectional LSTM encodes the scanpath. The final hidden states of the
    top layer's forward and backward directions are concatenated and
    classified by an MLP head that outputs raw logits.

    Args:
        model_args (AhnRNN): The model arguments.
        trainer_args (TrainerDL): The trainer arguments.
        data_args (DataArgs): The data arguments.
    """

    def __init__(
        self,
        model_args: AhnRNN,
        trainer_args: TrainerDL,
        data_args: DataArgs,
    ):
        super().__init__(model_args, trainer_args, data_args=data_args)
        self.lstm = nn.LSTM(
            input_size=self.input_dim,
            hidden_size=self.model_args.hidden_dim,
            bidirectional=True,
            batch_first=True,
            num_layers=model_args.num_lstm_layers,
        )
        self.fc = nn.Sequential(
            nn.Dropout(self.model_args.fc_dropout),  # (batch_size, hidden_dim * 2)
            nn.Linear(
                model_args.hidden_dim * 2, model_args.hidden_dim * 2
            ),  # (batch_size, hidden_dim * 2)
            nn.ReLU(),
            nn.Dropout(self.model_args.fc_dropout),
            nn.Linear(
                model_args.hidden_dim * 2,
                model_args.fc_hidden_dim,
            ),  # (batch_size, fc_hidden_dim)
            nn.ReLU(),
            # Output raw logits with no activation. A trailing ReLU here would
            # clamp negative logits to 0 and zero their gradients. Dropping it
            # keeps the parameter indices (and state_dict keys) unchanged, since
            # ReLU has no parameters.
            nn.Linear(
                model_args.fc_hidden_dim, self.num_classes
            ),  # (batch_size, num_classes)
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the RNN model.

        Args:
            x (torch.Tensor): The input tensor, (batch_size, seq_len, input_dim)
                since the LSTM is built with ``batch_first=True``.

        Returns:
            tuple: ``(logits, hidden_representations)``, where
            ``hidden_representations`` is (batch_size, hidden_dim * 2).
        """
        # h_n is (num_layers * 2, batch_size, hidden_dim). Its last two entries
        # are the top layer's forward and backward final states. Taking
        # output[:, -1, :] instead would yield a backward state that has seen
        # only the last time step.
        # NOTE(review): if scanpaths are right-padded, the forward state also
        # consumes the padding; packing by length would avoid that.
        _, (h_n, _) = self.lstm(x)
        hidden_representations = torch.cat((h_n[-2], h_n[-1]), dim=-1)
        logits = self.fc(hidden_representations)  # (batch_size, num_classes)
        return logits, hidden_representations

forward(x)

Forward pass of the RNN model.

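Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Tensor` | The input tensor. | *required* |

Returns:

| Name | Type | Description |
|------|------|-------------|
| `tuple` | `tuple[Tensor, Tensor]` | A tuple containing the output tensor (logits) and the hidden representations. |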

Source code in `src/models/ahn_model.py` (lines 124–139):
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass of the RNN model.

    Args:
        x (torch.Tensor): The input tensor, presumably
            (batch_size, seq_len, input_dim) for a ``batch_first`` LSTM.

    Returns:
        tuple: ``(output, hidden_representations)``, where
        ``hidden_representations`` concatenates the top layer's forward and
        backward final hidden states, (batch_size, hidden_dim * 2).
    """
    # h_n is (num_layers * 2, batch_size, hidden_dim). Its last two entries are
    # the top layer's forward and backward final states. Taking
    # output[:, -1, :] instead would yield a backward state that has seen only
    # the last time step.
    _, (h_n, _) = self.lstm(x)
    hidden_representations = torch.cat((h_n[-2], h_n[-1]), dim=-1)
    output = self.fc(hidden_representations)
    return output, hidden_representations