基于ALBERT及BiLSTM进行中文文本分类的通用过程。
1.语料准备(基于csv文件),语料类
from __future__ import annotations

import os
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import pandas as pd
from tensorflow.keras.utils import get_file

from kashgari import macros as K
from kashgari import utils
from kashgari.logger import logger
from kashgari.tokenizers.base_tokenizer import Tokenizer
from kashgari.tokenizers.bert_tokenizer import BertTokenizer
class Corpus:
    """Text-classification corpus backed by a single CSV file.

    The CSV must contain a ``contents`` column (raw text) and a ``labels``
    column (label strings).  On construction, row indices are randomly
    partitioned into train (~70%), test (~15%) and valid (~15%) subsets.

    NOTE(review): the split uses the global NumPy RNG, so it is only
    reproducible if the caller seeds ``np.random`` beforehand.
    """

    def __init__(self,
                 corpus_train_csv_path: str,
                 sample_count: Optional[int] = None,
                 tokenizer: Optional[Tokenizer] = None) -> None:
        """Record the CSV path, pick a tokenizer and draw the random split.

        Args:
            corpus_train_csv_path: path to the corpus CSV file.
            sample_count: number of leading rows to use; ``None`` means
                all rows in the file.
            tokenizer: tokenizer used by :meth:`_text_process`; defaults
                to :class:`BertTokenizer`.
        """
        self.file_path = corpus_train_csv_path
        self.train_ids: List[int] = []
        self.test_ids: List[int] = []
        self.valid_ids: List[int] = []
        # Fall back to the default BERT tokenizer when none is supplied.
        self.tokenizer: Tokenizer
        if tokenizer is None:
            self.tokenizer = BertTokenizer()
        else:
            self.tokenizer = tokenizer
        if sample_count is None:
            # Only the row count is needed; release the frame immediately.
            sample_count = len(pd.read_csv(self.file_path))
        self.sample_count = sample_count
        # Assign every row index to exactly one of the three subsets.
        for i in range(self.sample_count):
            prob = np.random.random()
            if prob <= 0.7:
                self.train_ids.append(i)
            elif prob <= 0.85:
                self.test_ids.append(i)
            else:
                self.valid_ids.append(i)

    @classmethod
    def _extract_label(cls, row: pd.Series) -> int:
        """Map the row's label string to its integer index.

        NOTE(review): ``labels`` is not defined anywhere in this file; a
        module-level ``labels`` collection must exist before
        :meth:`load_data` is called, otherwise this raises ``NameError``.
        """
        return list(labels).index(row['labels'])

    def _text_process(self, text: str) -> List[str]:
        """Tokenize ``text`` (coerced to ``str``) into a token list."""
        return self.tokenizer.tokenize(str(text))

    def load_data(self,
                  subset_name: str = 'train',
                  shuffle: bool = True) -> Tuple[List[List[str]], List[int]]:
        """Load one subset as parallel lists of token lists and label ids.

        Args:
            subset_name: ``'train'`` or ``'valid'``; any other value
                (typically ``'test'``) selects the test split.
            shuffle: shuffle xs/ys in unison before returning.

        Returns:
            ``(xs, ys)`` where ``xs`` are tokenized texts and ``ys`` are
            integer label indices.
        """
        df = pd.read_csv(self.file_path)
        df = df[:self.sample_count]
        df['y'] = df.apply(self._extract_label, axis=1)
        df['x'] = df['contents'].apply(self._text_process)
        df = df[['x', 'y']]
        if subset_name == 'train':
            df = df.loc[self.train_ids]
        elif subset_name == 'valid':
            df = df.loc[self.valid_ids]
        else:
            df = df.loc[self.test_ids]
        xs, ys = list(df['x'].values), list(df['y'].values)
        if shuffle:
            xs, ys = utils.unison_shuffled_copies(xs, ys)
        return xs, ys
# Build the corpus; ./corpus.csv must contain the columns: contents, labels
corpus = Corpus('./corpus.csv')
2.生成训练集、验证集及测试集
# Materialize the three random splits as (tokenized texts, label indices).
train_x, train_y = corpus.load_data('train')
valid_x, valid_y = corpus.load_data('valid')
test_x, test_y = corpus.load_data('test')
3.训练模型
from kashgari.tasks.classification import BiLSTM_Model
from kashgari.embeddings import BertEmbedding, TransformerEmbedding
# Use the tensorflow.keras backend consistently with the rest of the file
# (mixing standalone `keras` callbacks with a tf.keras model can fail).
from tensorflow.keras.callbacks import EarlyStopping

# Pretrained ALBERT (small, Chinese) checkpoint directory.
model_folder = './albert_small_zh_google'
vocab_path = os.path.join(model_folder, 'vocab.txt')
config_path = os.path.join(model_folder, 'albert_config_small_google.json')
checkpoint_path = os.path.join(model_folder, 'albert_model.ckpt')

# 'albert' tells TransformerEmbedding how to interpret the checkpoint files.
embedding = TransformerEmbedding(vocab_path, config_path, checkpoint_path, 'albert')
model = BiLSTM_Model(embedding)

# Stop early once validation accuracy has not improved for 5 epochs.
callbacks_list = [EarlyStopping(monitor='val_accuracy', patience=5)]
model.fit(train_x, train_y, valid_x, valid_y,
          batch_size=64, epochs=100, callbacks=callbacks_list)
4.测试模型效果
# Evaluate on the held-out test split.
succ = 0
fail = 0
# Iterate pairs directly: the previous `test_x.index(x)` lookup was O(n)
# per sample and paired a duplicated tokenized text with the label of its
# *first* occurrence, corrupting the score when texts repeat.
for x, y in zip(test_x, test_y):
    if y == model.predict([x])[0]:
        succ += 1
    else:
        fail += 1
total = succ + fail
# Guard against an empty test split (zero-division).
print('成功率:', succ / total if total else 0.0)