Chinese sentiment classification with Hugging Face's bert-base-chinese, using PyTorch

First, some data preprocessing. The dataset used throughout this article is lansinuote/ChnSentiCorp.

from transformers import BertTokenizer
token = BertTokenizer.from_pretrained('bert-base-chinese')
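
To see what the tokenizer produces, encode a short sample sentence (the sentence here is an arbitrary example, not taken from the dataset):

sample = token('这家酒店的服务很好')
print(sample['input_ids'])       # token IDs, wrapped in [CLS] ... [SEP]
print(sample['attention_mask'])  # 1 for every real token
print(sample['token_type_ids'])  # all 0 for a single sentence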

import torch
from datasets import load_dataset

dataset = load_dataset('lansinuote/ChnSentiCorp')
print(type(dataset))  # datasets.DatasetDict with train / validation / test splits
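
Printing the DatasetDict itself shows the splits and their sizes:

print(dataset)  # train / validation / test splits with their num_rows
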
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        label = self.dataset[idx]['label']
        return text, label

dataset = Dataset(dataset['train'])
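
A quick sanity check of the wrapper (the exact first example depends on the dataset's ordering):

print(len(dataset))      # number of training examples
text, label = dataset[0]
print(label, text[:30])  # the label and the first 30 characters of the text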

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    #encode the whole batch of sentences
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt',
                                   return_length=True)

    #input_ids: the token IDs after encoding
    #attention_mask: 0 at padded positions, 1 everywhere else
    input_ids = data['input_ids'].to(device)
    attention_mask = data['attention_mask'].to(device)
    token_type_ids = data['token_type_ids'].to(device)
    labels = torch.LongTensor(labels).to(device)

    #optional: inspect the real (unpadded) sequence lengths
    #print(data['length'], data['length'].max())

    return input_ids, attention_mask, token_type_ids, labels

loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn, shuffle=True, drop_last=True)
len(loader)  # number of batches per epoch
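
As a sanity check, you can pull a single batch and inspect the tensor shapes (the shapes below assume batch_size=32 and max_length=500 as configured above):

input_ids, attention_mask, token_type_ids, labels = next(iter(loader))
print(input_ids.shape)       # torch.Size([32, 500])
print(attention_mask.shape)  # torch.Size([32, 500])
print(labels.shape)          # torch.Size([32])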

Load the pretrained bert-base-chinese model

from transformers import BertModel

pretrained = BertModel.from_pretrained('bert-base-chinese').to(device)
sum(i.numel() for i in pretrained.parameters())/1e6  # total parameter count in millions (about 102M for bert-base-chinese)

for param in pretrained.parameters():
    param.requires_grad = False  # freeze the pretrained parameters; only the new head will be trained
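
To confirm the freeze took effect, count the parameters that still require gradients; for the pretrained backbone it should be zero:

sum(p.numel() for p in pretrained.parameters() if p.requires_grad)  # 0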

Add a few layers on top of the pretrained model

class Model(torch.nn.Module):
    def __init__(self, pretrained):
        super(Model, self).__init__()
        self.bert = pretrained
        self.fn1 = torch.nn.Linear(768, 256)
        self.relu = torch.nn.ReLU()
        self.fn2 = torch.nn.Linear(256, 768)
        self.classifier = torch.nn.Linear(768, 2)  # 768 is BERT's hidden size, 2 is the number of classes

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        #two linear layers with a ReLU in between; [:, 0] takes the [CLS] token's hidden state
        output = self.fn1(output.last_hidden_state[:, 0])
        output = self.relu(output)
        output = self.fn2(output)
        out = self.classifier(output)
        return out

Instantiate the model; the training and test code below both refer to it as model:

model = Model(pretrained).to(device)

Define the training loop

from torch.optim import AdamW  # transformers' AdamW is deprecated; use the PyTorch implementation
from transformers.optimization import get_scheduler

def train():
    optimizer = AdamW(model.parameters(), lr=1e-5)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = get_scheduler("linear", optimizer=optimizer,
                              num_training_steps=len(loader),  # the loop below runs a single epoch
                              num_warmup_steps=0)
    model.train()
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        if i % 10 == 0:
            out = outputs.argmax(dim=1)
            accuracy = (out == labels).sum().item() / len(labels)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), accuracy, lr)

Start training

train()  # run one epoch of training

Testing

def test():
    loader_test = torch.utils.data.DataLoader(
        Dataset(load_dataset('lansinuote/ChnSentiCorp')['test']),
        batch_size=32,
        collate_fn=collate_fn,
        shuffle=True,    # shuffle so the 5 batches evaluated below form a random sample
        drop_last=True
    )
    model.eval()
    correct = 0
    total = 0
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i == 5: break  # evaluate only the first 5 batches
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        out = outputs.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)
    print('Accuracy:', correct / total)

test()  # run the evaluation
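
Finally, a minimal inference sketch for a single raw sentence. predict is a helper added here for illustration; in ChnSentiCorp, label 1 is positive and 0 is negative:

def predict(sentence):
    # encode one sentence with the same settings used in collate_fn
    enc = token(sentence, truncation=True, padding='max_length',
                max_length=500, return_tensors='pt')
    model.eval()
    with torch.no_grad():
        out = model(input_ids=enc['input_ids'].to(device),
                    attention_mask=enc['attention_mask'].to(device),
                    token_type_ids=enc['token_type_ids'].to(device))
    return out.argmax(dim=1).item()  # 0 or 1

print(predict('房间很干净,服务也很周到'))  # expected: 1 (positive)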