2021SC@SDUSC
mkiters函数也是dataset类中的一个重要的类函数,我的队友已经在她的博客中详细分析过这个函数,此处不再赘述。
def mktestset(self, args):
path = args.path.replace("train",'test')
fields=self.fields
ds = data.TabularDataset(path=path, format='tsv',fields=fields)
ds.fields["rawent"] = data.RawField()
for x in ds:
x.rawent = x.ent.split(" ; ")
x.ent = self.vec_ents(x.ent,self.ENT)
x.rel = self.mkGraphs(x.rel,len(x.ent[1]))
if args.sparse:
x.rel = (self.adjToSparse(x.rel[0]),x.rel[1])
x.tgt = x.out
x.out = [y.split("_")[0]+">" if "_" in y else y for y in x.out]
x.sordertgt = torch.LongTensor([int(y)+3 for y in x.sorder.split(" ")])
x.sorder = [[int(z) for z in y.strip().split(" ")] for y in x.sorder.split("-1")[:-1]]
ds.fields["tgt"] = self.TGT
ds.fields["rawent"] = data.RawField()
ds.fields["sordertgt"] = data.RawField()
dat_iter = data.Iterator(ds,1,device=args.device,sort_key=lambda x:len(x.src), train=False, sort=False)
return dat_iter
mktestset函数是dataset类中一个用来形成测试集的函数,对数据集进行遍历之后,返回一个迭代器。其余的函数都是对数据集进行一些修饰工作,不再一一展开详细分析。
让我们回到最最开始的地方, 继续分析train.py程序。之前我们详细分析了dataset类,而pargs.py由我的队友来着重分析,那么我们就继续看:
m = model(args)
我们开始分析model类。
class model(nn.Module):
首先我们看model类的init函数。
def __init__(self,args):
super().__init__()
self.args = args
cattimes = 3 if args.title else 2
self.emb = nn.Embedding(args.ntoks,args.hsz)
self.lstm = nn.LSTMCell(args.hsz*cattimes,args.hsz)
self.out = nn.Linear(args.hsz*cattimes,args.tgttoks)
self.le = list_encode(args)
self.entout = nn.Linear(args.hsz,1)
self.switch = nn.Linear(args.hsz*cattimes,1)
这个model类继承了torch.nn,其中的参数都是调用了torch.nn中的函数。cattimes是分类的次数,如果有标题,就设置为3,如果没有标题,就为2。emb为用args.ntoks和args.hsz组成的矩阵(args.ntoks是输出的vocab长度,会在pargs.py的代码分析中详细介绍)。lstm是用hsz和分类次数乘积作为构建LSTM中的一个Cell的输入特征维度,hsz作为构建LSTM中的一个Cell的隐状态的维度,torch.nn中的LSTM和LSTMCell的操作如下图:
out、entout、switch都是调用了nn.Linear()函数,其中的参数都是指维度,对二维变量进行线性变换,如图所示。
self.attn = MultiHeadAttention(args.hsz,args.hsz,args.hsz,h=4,dropout_p=args.drop)
self.mattn = MatrixAttn(args.hsz*cattimes,args.hsz)
self.graph = (args.model in ['graph','gat','gtrans'])
print(args.model)
MultiHeadAttention()是attention.py中的类,继承Module,这里的操作是返回一个连接后的4*4维度的attn。MatrixAttn()也是attention.py中的类,继承Module,这里是对hsz和分类次数的乘积和hsz作线性变换(如上图所示)。graph则是模型生成的图,然后在终端打印出来。