以遥感DIOR数据集为例,其标注文件为.xml格式:每个目标对应annotation下的一个object节点,其中name为类别名,bndbox给出边界框的xmin、ymin、xmax、ymax。
若想把某一类目标从中取出来并生成针对此类的mask,实现方法是将.xml转化为json后读取object中的内容,将boundingbox的值取出生成mask图。需要用到的包如下:
import simplejson
import xmltodict
import numpy as np
用open()打开.xml文件并通过f.read()读出其内容(记为xmlstr),转换为json字符串后再解析为字典:
# Parse the XML text with xmltodict, serialise the result to a JSON string,
# then load that string back so the annotation can be indexed by key.
simplejson_list = simplejson.loads(
    simplejson.dumps(xmltodict.parse(xmlstr), indent=1),
    encoding='utf-8',
    strict=False,
)
取出object信息
# 'object' is a single dict when the image has one object and a list of
# dicts when it has several, so later code must handle both shapes.
annotation_objs = simplejson_list['annotation']['object']
由于DIOR数据集中一张图上可能有多个目标,故需要循环查找所需类别。在上个语句中,若一张图中只有一个目标,则annotation_objs为dict类型,若有多个目标则为list,所以要针对这两种情况分别讨论,不能直接循环查找。先判断是否为list,再循环list中每个obj的信息,将名字和boundingbox取出,对mask赋值
if isinstance(annotation_objs, list):  # several objects: a list of dicts
    for obj in annotation_objs:
        if obj['name'] == obj_str:
            k += 1
            box = obj['bndbox']
            left, top = int(box['xmin']), int(box['ymin'])
            right, bottom = int(box['xmax']), int(box['ymax'])
            # Rows are indexed by y and columns by x.
            mask[top:bottom, left:right] = 1
else:  # not a list, so there is exactly one object, stored as a dict
    if annotation_objs['name'] == obj_str:
        k += 1
        box = annotation_objs['bndbox']
        left, top = int(box['xmin']), int(box['ymin'])
        right, bottom = int(box['xmax']), int(box['ymax'])
        mask[top:bottom, left:right] = 1
其中k用于统计匹配到的所需类别目标个数(处理每张图之前需置0),同时充当是否含所需类别的flag:若k不等于0,则说明本张图含有所需目标。
完整代码如下:
# -*- coding: utf-8 -*-
import simplejson
import xmltodict
import numpy as np
import cv2
import os
def _as_object_list(objects):
    """Return the annotation's objects as a list, however many there are."""
    if objects is None:
        # The annotation has no <object> element at all.
        return []
    if isinstance(objects, list):
        return objects
    # A single <object> is parsed as one dict rather than a one-item list.
    return [objects]


def _mask_shape(annotation, default=(800, 800)):
    """Return (height, width) from the annotation's size entry, or *default*."""
    size = annotation.get('size') or {}
    try:
        return int(size['height']), int(size['width'])
    except (KeyError, TypeError, ValueError):
        return default


def xmltojson(xmlstr, k, filename, obj_str='ship', mask_dir='./',
              image_dir='./', imgread_dir='./'):
    """Build a binary mask for one class from an XML annotation.

    Every bounding box whose ``name`` equals ``obj_str`` is filled with 255
    in a single-channel mask. If the resulting count is non-zero, the mask
    and a copy of the source image are written as ``<stem>.jpg``, where
    ``<stem>`` is ``filename`` without its extension.

    The three directories default to the current directory as placeholders.
    Point them at different directories: when ``mask_dir`` and ``image_dir``
    are the same, both files share one path and the image, written last,
    replaces the mask.

    Args:
        xmlstr: Content of the XML annotation file.
        k: Starting value of the match counter (normally 0).
        filename: Name of the annotation file; only its stem is used.
        obj_str: Class name to extract.
        mask_dir: Directory the mask is written to.
        image_dir: Directory the image copy is written to.
        imgread_dir: Directory the source image is read from.

    Returns:
        ``k`` plus the number of matching objects found in this annotation.

    Raises:
        OSError: If matching objects exist but the source image cannot be
            read.
    """
    # xmltodict already yields nested dicts/lists, so no JSON round trip is
    # needed before indexing into the result.
    annotation = xmltodict.parse(xmlstr)['annotation']

    height, width = _mask_shape(annotation)
    mask = np.zeros((height, width), dtype=np.uint8)

    for obj in _as_object_list(annotation.get('object')):
        if obj['name'] != obj_str:
            continue
        k += 1
        box = obj['bndbox']
        xmin, ymin = int(box['xmin']), int(box['ymin'])
        xmax, ymax = int(box['xmax']), int(box['ymax'])
        # Rows are indexed by y and columns by x.
        mask[ymin:ymax, xmin:xmax] = 255

    if k != 0:
        name = os.path.splitext(filename)[0] + '.jpg'
        src_path = os.path.join(imgread_dir, name)
        # Read the source before writing anything: with identical directories
        # the mask would otherwise overwrite the image before it is loaded.
        img_rgb = cv2.imread(src_path)
        if img_rgb is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise OSError('could not read image: ' + src_path)
        cv2.imwrite(os.path.join(mask_dir, name), mask)
        cv2.imwrite(os.path.join(image_dir, name), img_rgb)
    return k
if __name__ == "__main__":
    xml_dir = r"./"  # folder that holds the .xml annotation files
    for filename in os.listdir(xml_dir):
        # The folder can hold other files too (for example the .jpg files
        # written to './'), and only annotations can be parsed as XML.
        if not filename.lower().endswith('.xml'):
            continue
        print(filename)
        # The context manager closes the file even if reading fails.
        with open(os.path.join(xml_dir, filename)) as f:
            data = f.read()
        k = 0  # the match counter starts from zero for every annotation
        xmltojson(data, k, filename)