ValueError: Expected target size (64, 31), got torch.Size([64, 63])

Pred-Esti介绍

Preictor-Estimator是一个两阶段的神经质量评估模型,它包括两个神经模型:

  • a predictor:词预测器,使用额外的大规模平行语料进行训练
  • an estimator:质量评估器,使用质量标注了的平行语料(QE data)训练

问题

在训练estimator模型时出现了这样的问题:
ValueError: Expected target size (64, 31), got torch.Size([64, 63])
这是因为pred_tagstags的维度不同的。
ValueError: Expected target size (64, 31), got torch.Size([64, 63])
如图所示:使用的wmt19数据中给出的标注的tags文件中包括MT tagsGap tags,而预测的tags中只有MT tags

pred_tags = [1, 0, 0, 0, 0, 0, 1, 0, 0]		#	OK:0;BAD:1
tags = [BAD, BAD, OK, BAD, BAD, OK, OK, OK, OK, OK, OK, OK, OK, BAD, OK, BAD, OK, OK, OK]

解决方法

# path_tags:MT tags + Gap tags
# path_target_tags:生成的MT tags
path_tags = "openkiwi/data/WMT19/wordsent_level/dev.tags"
path_target_tags = "openkiwi/data/WMT19/wordsent_level/dev.target_tags"

def target_tags(path1, path2):
    with open(path1, "r") as file:
        for line in file:
            array = line.strip().split(" ")[1::2]
            string = ' '.join(array)
            with open(path2, "a") as f:
                f.write(string + '\n')

target_tags(path_tags, path_target_tags)
pred_tags = [1, 0, 0, 0, 0, 0, 1, 0, 0]		#	OK:0;BAD:1
target_tags = [BAD, BAD, OK, OK, OK, OK, BAD, BAD, OK]
ValueError: Expected target size (64, 31), got torch.Size([64, 63])ValueError: Expected target size (64, 31), got torch.Size([64, 63]) weixin_39103096 发布了2 篇原创文章 · 获赞 0 · 访问量 25 私信 关注
上一篇:kafka监控实战(jmxtrans+InfluxDb+Grafana)


下一篇:MongoDB学习笔记5——Python和MongoDB