import torch
from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, num_class):
        super(TextClassificationModel, self).__init__()
        # The input is a fixed 100-dimensional text vector
        # (e.g. an averaged word-embedding representation).
        self.fc = nn.Linear(100, num_class)

    def forward(self, text):
        text = text.float()   # ensure float input for the linear layer
        return self.fc(text)
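The model maps a fixed 100-dimensional text vector straight to class logits, with no embedding or recurrent layers. A quick shape check with dummy data (the batch size and class count here are illustrative, not taken from this section):

# Illustrative smoke test: a batch of 4 dummy vectors, 2 hypothetical classes.
dummy_batch = torch.randn(4, 100)
demo_model  = TextClassificationModel(num_class=2)
print(demo_model(dummy_batch).shape)   # torch.Size([4, 2])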
Initialize the model
num_class  = len(label_name)   # label_name comes from the data-preparation step
vocab_size = 100000            # kept for reference; not used by this linear model
em_size    = 12                # kept for reference; not used by this linear model
model = TextClassificationModel(num_class).to(device)   # device was set earlier
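The train/evaluate functions below also expect a loss function criterion and an optimizer to exist. A minimal sketch of that setup, assuming cross-entropy loss and SGD (the specific optimizer and learning rate are assumptions, not fixed by this section):

# Assumed setup: the loss/optimizer choice and the learning rate are
# illustrative; swap in whatever the rest of the tutorial uses.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)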
Define the training and evaluation functions
import time

def train(dataloader):
    model.train()   # switch to training mode
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 50
    start_time = time.time()

    for idx, (text, label) in enumerate(dataloader):
        # text and label are assumed to already be on `device`
        # (e.g. handled by the dataloader's collate_fn)
        predicted_label = model(text)

        optimizer.zero_grad()                     # reset gradients
        loss = criterion(predicted_label, label)
        loss.backward()                           # back-propagate
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # clip gradients
        optimizer.step()                          # update the weights

        # accumulate running statistics
        total_acc   += (predicted_label.argmax(1) == label).sum().item()
        train_loss  += loss.item()
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            # `epoch` is supplied by the surrounding training loop
            print('| epoch {:1d} | {:4d}/{:4d} batches '
                  '| train_acc {:4.3f} train_loss {:4.5f}'.format(
                      epoch, idx, len(dataloader),
                      total_acc/total_count, train_loss/total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()
def evaluate(dataloader):
    model.eval()   # switch to evaluation mode
    total_acc, val_loss, total_count = 0, 0, 0

    with torch.no_grad():   # no gradient tracking needed for evaluation
        for idx, (text, label) in enumerate(dataloader):
            predicted_label = model(text)
            loss = criterion(predicted_label, label)

            total_acc   += (predicted_label.argmax(1) == label).sum().item()
            val_loss    += loss.item()
            total_count += label.size(0)

    return total_acc/total_count, val_loss/total_count
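Neither function runs on its own; an outer loop drives them epoch by epoch. A minimal sketch of that driver, assuming train_dataloader and valid_dataloader were built in the data-preparation step (those names, and EPOCHS, are assumptions):

# Hypothetical epoch loop; train() reads the global `epoch` for its logging.
EPOCHS = 10

for epoch in range(1, EPOCHS + 1):
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)
    print('| epoch {:1d} | val_acc {:4.3f} val_loss {:4.5f}'.format(
        epoch, val_acc, val_loss))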