[FEATURE] Add SKT model #34

Open · wants to merge 12 commits into base: main
2 changes: 2 additions & 0 deletions AUTHORS.md
@@ -10,4 +10,6 @@

 [Jie Ouyang](https://github.com/0russwest0)
 
+[Weizhe Huang](https://github.com/weizhehuang0827)
+
 The starred author is the corresponding author
16 changes: 10 additions & 6 deletions EduKTM/GKT/GKT.py
@@ -24,7 +24,8 @@ def __init__(self, ku_num, graph, hidden_num, net_params: dict = None, loss_para
         self.loss_params = loss_params if loss_params is not None else {}
 
     def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
-        loss_function = SLMLoss(**self.loss_params)
+        loss_function = SLMLoss(**self.loss_params).to(device)
+        self.gkt_model = self.gkt_model.to(device)
         trainer = torch.optim.Adam(self.gkt_model.parameters(), lr)
 
         for e in range(epoch):
@@ -39,9 +40,11 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
                 label_mask: torch.Tensor = label_mask.to(device)
 
                 # real training
-                predicted_response, _ = self.gkt_model(question, data, data_mask)
+                predicted_response, _ = self.gkt_model(
+                    question, data, data_mask)
 
-                loss = loss_function(predicted_response, pick_index, label, label_mask)
+                loss = loss_function(predicted_response,
+                                     pick_index, label, label_mask)
 
                 # back propagation
                 trainer.zero_grad()
@@ -52,8 +55,9 @@ def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.00
print("[Epoch %d] SLMoss: %.6f" % (e, float(np.mean(losses))))

if test_data is not None:
auc, accuracy = self.eval(test_data)
print("[Epoch %d] auc: %.6f, accuracy: %.6f" % (e, auc, accuracy))
auc, accuracy = self.eval(test_data, device=device)
print("[Epoch %d] auc: %.6f, accuracy: %.6f" %
(e, auc, accuracy))

def eval(self, test_data, device="cpu") -> tuple:
self.gkt_model.eval()
@@ -75,7 +79,7 @@ def eval(self, test_data, device="cpu") -> tuple:
             output = pick(output, pick_index.to(output.device))
             pred = tensor2list(output)
             label = tensor2list(label)
-            for i, length in enumerate(label_mask.numpy().tolist()):
+            for i, length in enumerate(label_mask.cpu().tolist()):
                 length = int(length)
                 y_true.extend(label[i][:length])
                 y_pred.extend(pred[i][:length])
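Reviewer note: the GKT.py changes above are all about device placement: the model and the loss move to `device` up front, every batch tensor follows, and `eval` is now called with the same device. A minimal sketch of the training entry point this enables follows; it assumes GKT is exported from EduKTM's top level, and `load_batches` plus the file paths are hypothetical stand-ins, not part of this PR:

```python
# A minimal sketch, assuming batches yield
# (question, data, data_mask, label, pick_index, label_mask) tuples.
# load_batches and the data paths are hypothetical, not EduKTM APIs.
import torch
from EduKTM import GKT

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = GKT(ku_num=835, graph="data/graph.json", hidden_num=5)
train_batches = load_batches("data/train.json")  # hypothetical loader
test_batches = load_batches("data/test.json")    # hypothetical loader

# with the .to(device) fixes above, CUDA training works end to end
model.train(train_batches, test_batches, epoch=2, device=device)
model.save("gkt.params")
```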
54 changes: 35 additions & 19 deletions EduKTM/GKT/GKTNet.py
@@ -15,8 +15,10 @@ class GKTNet(nn.Module):
     def __init__(self, ku_num, graph, hidden_num=None, latent_dim=None, dropout=0.0):
         super(GKTNet, self).__init__()
         self.ku_num = int(ku_num)
-        self.hidden_num = self.ku_num if hidden_num is None else int(hidden_num)
-        self.latent_dim = self.ku_num if latent_dim is None else int(latent_dim)
+        self.hidden_num = self.ku_num if hidden_num is None else int(
+            hidden_num)
+        self.latent_dim = self.ku_num if latent_dim is None else int(
+            latent_dim)
         self.neighbor_dim = self.hidden_num + self.latent_dim
         self.graph = nx.DiGraph()
         self.graph.add_nodes_from(list(range(ku_num)))
@@ -25,10 +27,12 @@ def __init__(self, ku_num, graph, hidden_num=None, latent_dim=None, dropout=0.0)
                 self.graph.add_weighted_edges_from(json.load(f))
         except ValueError:
             with open(graph) as f:
-                self.graph.add_weighted_edges_from([e + [1.0] for e in json.load(f)])
+                self.graph.add_weighted_edges_from(
+                    [e + [1.0] for e in json.load(f)])
 
         self.rnn = GRUCell(self.hidden_num)
-        self.response_embedding = nn.Embedding(2 * self.ku_num, self.latent_dim)
+        self.response_embedding = nn.Embedding(
+            2 * self.ku_num, self.latent_dim)
         self.concept_embedding = nn.Embedding(self.ku_num, self.latent_dim)
         self.f_self = nn.Linear(self.neighbor_dim, self.hidden_num)
         self.n_out = nn.Linear(2 * self.neighbor_dim, self.hidden_num)
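Reviewer note: the try/except in the hunk above accepts two on-disk edge formats, weighted triples or bare pairs that receive a default weight of 1.0 in the except-branch. A quick sketch of both, with illustrative file names:

```python
# A sketch of the two graph JSON formats GKTNet accepts; paths are illustrative.
import json

# weighted: [source, target, weight] triples load directly in the try-branch
with open("weighted_graph.json", "w") as f:
    json.dump([[0, 1, 0.8], [1, 2, 0.5]], f)

# unweighted: [source, target] pairs make add_weighted_edges_from raise
# ValueError, so the except-branch re-reads and pads each edge with 1.0
with open("unweighted_graph.json", "w") as f:
    json.dump([[0, 1], [1, 2]], f)
```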
@@ -38,7 +42,7 @@ def __init__(self, ku_num, graph, hidden_num=None, latent_dim=None, dropout=0.0)

     def in_weight(self, x, ordinal=True, with_weight=True):
         if isinstance(x, torch.Tensor):
-            x = x.numpy().tolist()
+            x = x.cpu().numpy().tolist()
         if isinstance(x, list):
             return [self.in_weight(_x) for _x in x]
         elif isinstance(x, (int, float)):
@@ -57,7 +61,7 @@ def in_weight(self, x, ordinal=True, with_weight=True):

     def out_weight(self, x, ordinal=True, with_weight=True):
         if isinstance(x, torch.Tensor):
-            x = x.numpy().tolist()
+            x = x.cpu().numpy().tolist()
         if isinstance(x, list):
             return [self.out_weight(_x) for _x in x]
         elif isinstance(x, (int, float)):
@@ -76,7 +80,7 @@ def out_weight(self, x, ordinal=True, with_weight=True):

     def neighbors(self, x, ordinal=True, with_weight=False):
         if isinstance(x, torch.Tensor):
-            x = x.numpy().tolist()
+            x = x.cpu().numpy().tolist()
         if isinstance(x, list):
             return [self.neighbors(_x) for _x in x]
         elif isinstance(x, (int, float)):
@@ -95,10 +99,12 @@ def neighbors(self, x, ordinal=True, with_weight=False):

     def forward(self, questions, answers, valid_length=None, compressed_out=True, layout="NTC"):
         length = questions.shape[1]
-        inputs, axis, batch_size = format_sequence(length, questions, layout, False)
+        device = questions.device
+        inputs, axis, batch_size = format_sequence(
+            length, questions, layout, False)
         answers, _, _ = format_sequence(length, answers, layout, False)
 
         states = begin_states([(batch_size, self.ku_num, self.hidden_num)])[0]
+        states = states.to(device)
         outputs = []
         all_states = []
         for i in range(length):
@@ -107,12 +113,15 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la
             answer_i = answers[i].reshape([batch_size, ])
 
             _neighbors = self.neighbors(inputs_i)
-            neighbors_mask = expand_tensor(torch.Tensor(_neighbors), -1, self.hidden_num)
-            _neighbors_mask = expand_tensor(torch.Tensor(_neighbors), -1, self.hidden_num + self.latent_dim)
+            neighbors_mask = expand_tensor(torch.tensor(
+                _neighbors, device=device), -1, self.hidden_num)
+            _neighbors_mask = expand_tensor(torch.tensor(
+                _neighbors, device=device), -1, self.hidden_num + self.latent_dim)
 
             # get concept embedding
             concept_embeddings = self.concept_embedding.weight.data
-            concept_embeddings = expand_tensor(concept_embeddings, 0, batch_size)
+            concept_embeddings = expand_tensor(
+                concept_embeddings, 0, batch_size)
 
             agg_states = torch.cat((concept_embeddings, states), dim=-1)
 
@@ -121,20 +130,25 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la

             # self - aggregate
             _concept_embedding = get_states(inputs_i, states)
-            _self_hidden_states = torch.cat((_concept_embedding, self.response_embedding(answer_i)), dim=-1)
+            _self_hidden_states = torch.cat(
+                (_concept_embedding, self.response_embedding(answer_i)), dim=-1)
 
             _self_mask = F.one_hot(inputs_i, self.ku_num)  # p
             _self_mask = expand_tensor(_self_mask, -1, self.hidden_num)
 
-            self_hidden_states = expand_tensor(_self_hidden_states, 1, self.ku_num)
+            self_hidden_states = expand_tensor(
+                _self_hidden_states, 1, self.ku_num)
 
             # aggregate
-            _hidden_states = torch.cat((_neighbors_states, self_hidden_states), dim=-1)
+            _hidden_states = torch.cat(
+                (_neighbors_states, self_hidden_states), dim=-1)
 
             _in_state = self.n_in(_hidden_states)
             _out_state = self.n_out(_hidden_states)
-            in_weight = expand_tensor(torch.Tensor(self.in_weight(inputs_i)), -1, self.hidden_num)
-            out_weight = expand_tensor(torch.Tensor(self.out_weight(inputs_i)), -1, self.hidden_num)
+            in_weight = expand_tensor(torch.tensor(self.in_weight(
+                inputs_i), device=device), -1, self.hidden_num)
+            out_weight = expand_tensor(torch.tensor(self.out_weight(
+                inputs_i), device=device), -1, self.hidden_num)
 
             next_neighbors_states = in_weight * _in_state + out_weight * _out_state
 
@@ -146,7 +160,8 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la
             next_states = neighbors_mask * next_neighbors_states + next_self_states
 
             next_states, _ = self.rnn(next_states, [states])
-            next_states = (_self_mask + neighbors_mask) * next_states + (1 - _self_mask - neighbors_mask) * states
+            next_states = (_self_mask + neighbors_mask) * \
+                next_states + (1 - _self_mask - neighbors_mask) * states
 
             states = self.dropout(next_states)
             output = torch.sigmoid(self.out(states).squeeze(axis=-1))  # p
@@ -157,6 +172,7 @@ def forward(self, questions, answers, valid_length=None, compressed_out=True, la
         if valid_length is not None:
             if compressed_out:
                 states = None
-            outputs = mask_sequence_variable_length(torch, outputs, length, valid_length, axis, merge=True)
+            outputs = mask_sequence_variable_length(
+                torch, outputs, length, valid_length, axis, merge=True)
 
         return outputs, states
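Reviewer note: the recurring fix in `forward` is one idiom: `torch.Tensor(...)` always allocates on the CPU, so constants built from Python lists must instead be created with `torch.tensor(..., device=...)`, taking the device from the input tensor. A standalone sketch of the pattern (illustrative code, not EduKTM's):

```python
# A standalone sketch of the device-following pattern used above.
# torch.Tensor(data) allocates on CPU and breaks under CUDA inputs;
# torch.tensor(data, device=...) keeps every operand on one device.
import torch

def masked_states(states: torch.Tensor, neighbors: list) -> torch.Tensor:
    device = states.device  # follow whatever device the inputs live on
    mask = torch.tensor(neighbors, device=device, dtype=states.dtype)
    return mask.unsqueeze(-1) * states  # broadcast mask over the hidden dim

states = torch.zeros(4, 8)  # (ku_num, hidden_num); works on CPU or CUDA
masked = masked_states(states, [1.0, 0.0, 1.0, 0.0])
```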
94 changes: 94 additions & 0 deletions EduKTM/SKT/SKT.py
@@ -0,0 +1,94 @@
# coding: utf-8
# 2023/3/17 @ weizhehuang0827

import logging
import numpy as np
import torch
from tqdm import tqdm
from EduKTM import KTM
from .SKTNet import SKTNet
from EduKTM.utils import SLMLoss, tensor2list, pick
from sklearn.metrics import roc_auc_score, accuracy_score


class SKT(KTM):
    def __init__(self, ku_num, graph_params, hidden_num, net_params: dict = None, loss_params=None):
        super(SKT, self).__init__()
        self.skt_model = SKTNet(
            ku_num,
            graph_params,
            hidden_num,
            **(net_params if net_params is not None else {})
        )
        self.loss_params = loss_params if loss_params is not None else {}

    def train(self, train_data, test_data=None, *, epoch: int, device="cpu", lr=0.001) -> ...:
        loss_function = SLMLoss(**self.loss_params).to(device)
        self.skt_model = self.skt_model.to(device)
        trainer = torch.optim.Adam(self.skt_model.parameters(), lr)

        for e in range(epoch):
            losses = []
            for (question, data, data_mask, label, pick_index, label_mask) in tqdm(train_data, "Epoch %s" % e):
                # convert to device
                question: torch.Tensor = question.to(device)
                data: torch.Tensor = data.to(device)
                data_mask: torch.Tensor = data_mask.to(device)
                label: torch.Tensor = label.to(device)
                pick_index: torch.Tensor = pick_index.to(device)
                label_mask: torch.Tensor = label_mask.to(device)

                # real training
                predicted_response, _ = self.skt_model(
                    question, data, data_mask)

                loss = loss_function(predicted_response,
                                     pick_index, label, label_mask)

                # back propagation
                trainer.zero_grad()
                loss.backward()
                trainer.step()

                losses.append(loss.mean().item())
print("[Epoch %d] SLMoss: %.6f" % (e, float(np.mean(losses))))

            if test_data is not None:
                auc, accuracy = self.eval(test_data, device=device)
                print("[Epoch %d] auc: %.6f, accuracy: %.6f" %
                      (e, auc, accuracy))

    def eval(self, test_data, device="cpu") -> tuple:
        self.skt_model.eval()
        y_true = []
        y_pred = []

        for (question, data, data_mask, label, pick_index, label_mask) in tqdm(test_data, "evaluating"):
            # convert to device
            question: torch.Tensor = question.to(device)
            data: torch.Tensor = data.to(device)
            data_mask: torch.Tensor = data_mask.to(device)
            label: torch.Tensor = label.to(device)
            pick_index: torch.Tensor = pick_index.to(device)
            label_mask: torch.Tensor = label_mask.to(device)

            # real evaluating
            output, _ = self.skt_model(question, data, data_mask)
            output = output[:, :-1]
            output = pick(output, pick_index.to(output.device))
            pred = tensor2list(output)
            label = tensor2list(label)
            for i, length in enumerate(label_mask.cpu().tolist()):
                length = int(length)
                y_true.extend(label[i][:length])
                y_pred.extend(pred[i][:length])
        self.skt_model.train()
        return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5)

    def save(self, filepath) -> ...:
        torch.save(self.skt_model.state_dict(), filepath)
        logging.info("save parameters to %s" % filepath)

    def load(self, filepath):
        self.skt_model.load_state_dict(torch.load(filepath))
        logging.info("load parameters from %s" % filepath)