arghosh / AKT

MIT License

target response issue in AKT model #13

Closed skewondr closed 2 years ago

skewondr commented 2 years ago

Hello, I want to ask your opinion on the AKT model architecture.

[Figure: AKT model architecture, from the paper]

The image above is the AKT model architecture figure from your paper.

    if self.n_pid > 0:
        q_embed_diff_data = self.q_embed_diff(q_data)  # d_ct
        pid_embed_data = self.difficult_param(pid_data)  # u_q
        q_embed_data = q_embed_data + pid_embed_data * \
            q_embed_diff_data  # u_q * d_ct + c_ct
        qa_embed_diff_data = self.qa_embed_diff(
            qa_data)  # f_(ct,rt) or h_rt
        if self.separate_qa:
            qa_embed_data = qa_embed_data + pid_embed_data * \
                qa_embed_diff_data  # u_q * f_(ct,rt) + e_(ct,rt)
        else:
            qa_embed_data = qa_embed_data + pid_embed_data * \
                (qa_embed_diff_data + q_embed_diff_data)  # + u_q * (h_rt + d_ct)
        c_reg_loss = (pid_embed_data ** 2.).sum() * self.l2

The code above is your implementation in akt.py.
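For reference, the Rasch-model-style embedding combination in that snippet can be sketched in isolation. This is a minimal standalone sketch, not the actual akt.py module; the vocabulary sizes and dimensions below are hypothetical, chosen only to show the tensor shapes involved:

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
n_question, n_pid, d_model = 10, 5, 8
batch, seqlen = 2, 4

q_embed = nn.Embedding(n_question + 1, d_model)       # c_ct: concept embedding
q_embed_diff = nn.Embedding(n_question + 1, d_model)  # d_ct: concept variation
difficult_param = nn.Embedding(n_pid + 1, 1)          # u_q: scalar question difficulty

q_data = torch.randint(0, n_question + 1, (batch, seqlen))
pid_data = torch.randint(0, n_pid + 1, (batch, seqlen))

q_embed_data = q_embed(q_data)              # c_ct, shape (batch, seqlen, d_model)
q_embed_diff_data = q_embed_diff(q_data)    # d_ct, shape (batch, seqlen, d_model)
pid_embed_data = difficult_param(pid_data)  # u_q,  shape (batch, seqlen, 1)

# Rasch-style combination: x_t = c_ct + u_q * d_ct
# (the scalar u_q broadcasts over the embedding dimension)
x = q_embed_data + pid_embed_data * q_embed_diff_data
print(x.shape)  # torch.Size([2, 4, 8])
```

The same broadcasting pattern applies to the response-side combination of `qa_embed_data` with `qa_embed_diff_data`.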

My point is that the AKT model may be able to peek at the target answers through the "f(c_t, r_t) variation vector" (in the paper), which is `qa_embed_diff_data` (in your code): at position t, the model attends to a value built from the response r_t it is supposed to predict. In my opinion, this is an instance of the well-known target-leakage issue.

To resolve the issue, I would carefully suggest modifying the forward function of the Architecture class as in the following code:

        else:  # don't peek at the current response
            pad_zero = torch.zeros(batch_size, 1, x.size(-1)).to(self.device)
            q = x
            k = torch.cat([pad_zero, x[:, :-1, :]], dim=1)  # shift keys by one step
            v = torch.cat([pad_zero, y[:, :-1, :]], dim=1)  # shift values by one step
            x = block(mask=0, query=q, key=k, values=v, apply_pos=True)
            flag_first = True
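The effect of this one-step shift can be checked in isolation. The sketch below (with hypothetical sizes, and `y` standing in for the response-aware sequence) shows that after the shift, position t's key and value come from step t-1, and step 0 sees only zero padding, so position t can no longer read its own response even if the attention mask allows the diagonal:

```python
import torch

batch_size, seqlen, d_model = 2, 5, 4
# x: question-side sequence; y: a stand-in for the response-aware sequence.
x = torch.arange(batch_size * seqlen * d_model, dtype=torch.float)
x = x.reshape(batch_size, seqlen, d_model)
y = -x

pad_zero = torch.zeros(batch_size, 1, d_model)
k = torch.cat([pad_zero, x[:, :-1, :]], dim=1)  # keys shifted right by one
v = torch.cat([pad_zero, y[:, :-1, :]], dim=1)  # values shifted right by one

# Position t now indexes step t-1; position 0 sees only zeros.
assert torch.equal(k[:, 1:, :], x[:, :-1, :])
assert torch.equal(v[:, 1:, :], y[:, :-1, :])
assert torch.equal(k[:, 0, :], torch.zeros(batch_size, d_model))
```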

Thank you for your attention :)