2021SC@SDUSC
util包的分析
这篇博客将剩下的util包的内容讲解完成
util_json
JSON (JavaScript Object Notation) 是一种轻量级的数据交换格式。Python3 中可以使用 json 模块来对 JSON 数据进行编解码,它主要提供了四个方法: dumps
、dump
、loads
、load
。
我们在这里为了和前端发送信息,主要用的是json.dumps
方法,这里就简单介绍一下即可
json.dump(obj, fp, *, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, cls=None, indent=None, separators=None, default=None, sort_keys=False, **kw)
obj
: 表示是要序列化的对象。
fp
: 文件描述符,将序列化的str保存到文件中。json模块总是生成str对象,而不是字节对象;因此,fp.write()必须支持str输入。
skipkeys
: 默认为False,如果skipkey=True
,(默认值:False),则将跳过不是基本类型(str,int,float,bool,None)的dict
键,不会引发TypeError。
sort_keys
: 默认值为False,如果sort_keys为True,则字典的输出将按键值排序。
util_plot
这个包主要是用来对画图的相关代码,但是画图的代码我们前面都已经介绍过了,所以这篇博客里我们主要介绍一些其他的配置设定。
首先是颜色的配色方面,通过对相关期刊上比较好的配色图取色,我们再加以改进,最后得到了相关的配色字典dict,如下,可以进行采用。
还有一部分就是画图代码的调用
我们用了一个字典来表示画图代号
image_type = {
'all_need': ['draw_umap', 'draw_shap', 'draw_ROC_PRC_curve', 'draw_negative_density', 'draw_positive_density',
'draw_tra_ROC_PRC_curve', 'draw_result_histogram','epoch_plot'],
'1_in_3': {'prot': 'draw_hist_image', 'DNA': 'draw_dna_hist_image', 'RNA': 'draw_rna_hist_image'},
False: ['draw_dna_rna_prot_length_distribution_image']}
然后再调用的时候只要eval
相关的画图函数就可以完成,这里简答又可以方便管理和使用。
相关的代码和注释都已经补充在下面了,感兴趣的可以逐步分析这看看
def draw_plots(data, config):
# 为了保证函数的优雅性,需要将画图函数的参数统一修改为data1,data2,config三个
# 把每个函数的plt.show()取消掉
# 测试数据中config参数不全&train\test data格式不对
# 传参给此函数时data为字典:{type:[[[data1],[data2]],[[data1],[data2]]...]}
# 问题:柱状图为什么plt.show()了两次
tag = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
if config['model_number'] % 2 == 0:
col = 2
else:
col = 3
# print(config['model_number'])
# fig1 = plt.figure()
if config['model_number'] == 2 or config['model_number'] == 3:
fig1 = plt.figure(figsize=(10, 3))
# print(2,config['model_number'])
else:
fig1 = plt.figure()
# print('hhhh')
if config['model_number'] == 1:
# pass
eval(image_type['all_need'][0])(data['all_need'][0][0][0], data['all_need'][0][1][0], config)
# l1, = plt.plot([], [], 'o', color='#00beca', label='positive')
# l2, = plt.plot([], [], 'o', color='#f87671', label='negative')
# plt.legend(bbox_to_anchor=(1, 0), loc=3, borderaxespad=0)
plt.savefig('{}/{}.{}'.format(config['savepath'], 'UMAP', 'png'))
plt.figure()
eval(image_type['all_need'][1])(data['all_need'][1][0][0], data['all_need'][1][1][0], config)
plt.savefig('{}/{}.{}'.format(config['savepath'], 'SHAP', 'png'))
# plt.show()
else:
for i in range(config['model_number']):
ax = fig1.add_subplot(int(config['model_number'] / col), col, i + 1)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
eval(image_type['all_need'][0])(data['all_need'][0][0][i], data['all_need'][0][1][i], config)
# if i == 2:
# fig1.set_figheight(10)
# fig1.set_figwidth(3)
# plt.axis('square')
# ax.title.set_text(config['names'][type_list[i]])
ax.set_title(config['names'][i])
trans = ts.ScaledTranslation(-20 / 72, 7 / 72, fig1.dpi_scale_trans)
if config['model_number'] == 2:
ax.text(0.05, 1.0, tag[i], transform=ax.transAxes + trans, fontweight='bold')
else:
ax.text(0.1, 1.0, tag[i], transform=ax.transAxes + trans, fontweight='bold')
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.2, hspace=0.4)
# l1, = plt.plot([], [], 'o', color='#00beca', label='positive')
# l2, = plt.plot([], [], 'o', color='#f87671', label='negative')
# plt.legend(bbox_to_anchor=(1, 0), loc=3, borderaxespad=0)
# plt.tight_layout()
plt.savefig('{}/{}.{}'.format(config['savepath'], 'UMAP', 'png'))
# shap plot
fig2 = plt.figure()
for i in range(config['model_number']):
# ax = fig.add_subplot(1, 3, i + 1)
ax = fig2.add_subplot(int(config['model_number'] / col), col, i + 1)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
eval(image_type['all_need'][1])(data['all_need'][1][0][i], data['all_need'][1][1][i], config)
ax.set_xlabel('SHAP value of ' + config['names'][i])
trans = ts.ScaledTranslation(-20 / 72, 7 / 72, fig2.dpi_scale_trans)
ax.text(-0.1, 1.0, tag[i], transform=ax.transAxes + trans, fontweight='bold',
fontdict={'weight': 'bold', 'size': 14})
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=1.1, hspace=0.3)
plt.savefig('{}/{}.{}'.format(config['savepath'], 'SHAP', 'png'))
# plt.show()
# 其他都需要画的图
for img in range(2, len(image_type['all_need'])):
eval(image_type['all_need'][img])(data['all_need'][img][0], data['all_need'][img][1], config)
# 3选1
# print(data['1_in_3'][0])
eval(image_type['1_in_3'][config['type']])(data['1_in_3'][0], data['1_in_3'][1], config)
# 画motif还是长度分布
if config['if_same']:
print("start plot motif")
motif_title = ['Motif statistics of the positives (Train)', 'Motif statistics of the negatives (Train)',
'Motif statistics of the positives (Test)', 'Motif statistics of the negatives (Test)']
for i in range(4):
motif = "/home/weilab/anaconda3/envs/wy/bin/weblogo --resolution 500 --format PNG -f " + \
config['fasta_list'][i] + " -o " + config['savepath'] + "/motif_" + (str)(
i) + ".png" + " --title " + " ' " + motif_title[i] + " ' "
os.system(motif)
else:
eval(image_type[False][0])(data['1_in_3'][0], data['1_in_3'][1], config)