DOTA数据集分割,并将txt转为xml

文章目录

DOTA数据集简介

DOTA数据集包含2806张航空图像,尺寸大约从800x800到4000x4000不等,包含15个类别共计188282个实例。其标注方式为四点确定的任意形状和方向的四边形(区别于传统的对边平行bbox)。类别分别为:plane, ship, storage tank, baseball diamond, tennis court, swimming pool, ground track field, harbor, bridge, large vehicle, small vehicle, helicopter, roundabout, soccer ball field, basketball court。

可以看出DOTA数据集里的照片,有的尺寸非常大,而且照片大小也不统一,因此我们在进行目标检测网络训练时需要对照片进行分割,下面为分割代码:

照片分割

import cv2
import os


def tianchong_you(img, size_w=608):
    """Pad an image on its right edge until it is ``size_w`` pixels wide.

    The new columns are filled with the constant colour (107, 113, 115),
    which the original author describes as the dataset mean.

    :param img: image array as returned by ``cv2.imread``; ``img.shape[1]``
        is used as the current width.
    :param size_w: target width in pixels. Prefer a multiple of 32 and keep
        it identical to the tile width used when splitting the label files.
    :return: the padded image. An image that is already at least ``size_w``
        wide gets no extra columns.
    """
    # Clamp at zero: a negative border width is rejected by OpenCV.
    pad_right = max(size_w - img.shape[1], 0)
    return cv2.copyMakeBorder(img, 0, 0, 0, pad_right, cv2.BORDER_CONSTANT,
                              value=(107, 113, 115))


def tianchong_xia(img, size_h=608):
    """Pad an image on its bottom edge until it is ``size_h`` pixels high.

    The new rows are filled with the constant colour (107, 113, 115).

    :param img: image array as returned by ``cv2.imread``; ``img.shape[0]``
        is used as the current height.
    :param size_h: target height in pixels; keep it identical to the tile
        height used when splitting the label files.
    :return: the padded image. An image that is already at least ``size_h``
        high gets no extra rows.
    """
    # Clamp at zero: a negative border width is rejected by OpenCV.
    pad_bottom = max(size_h - img.shape[0], 0)
    return cv2.copyMakeBorder(img, 0, pad_bottom, 0, 0, cv2.BORDER_CONSTANT,
                              value=(107, 113, 115))


def tianchong_xy(img, size_w=608, size_h=608):
    """Pad an image on its bottom and right edges up to ``size_h`` x ``size_w``.

    The new pixels are filled with the constant colour (107, 113, 115).

    :param img: image array as returned by ``cv2.imread``; ``img.shape[0]``
        and ``img.shape[1]`` are used as the current height and width.
    :param size_w: target width in pixels.
    :param size_h: target height in pixels.
    :return: the padded image. A side that is already large enough is left
        unpadded.
    """
    # Clamp at zero: a negative border width is rejected by OpenCV.
    pad_bottom = max(size_h - img.shape[0], 0)
    pad_right = max(size_w - img.shape[1], 0)
    return cv2.copyMakeBorder(img, 0, pad_bottom, 0, pad_right,
                              cv2.BORDER_CONSTANT, value=(107, 113, 115))


def _tile_starts(length, window, step):
    """Return the tile start offsets along one axis (needs length >= window)."""
    # Regular grid of starts, plus one final tile flush with the far edge so
    # the whole axis is covered without ever cropping outside the image.
    last = length - window
    starts = list(range(0, last, step))
    starts.append(last)
    return starts


def caijian(path, path_out, size_w=608, size_h=608, step=576):
    """Cut every image in ``path`` into ``size_h`` x ``size_w`` tiles.

    Images smaller than one tile along an axis are first padded on the
    bottom and/or right with the constant colour (107, 113, 115). Tiles are
    then taken every ``step`` pixels; the last tile of each row/column is
    shifted back so that it ends exactly on the image border. With the
    defaults, neighbouring tiles overlap by 608 - 576 = 32 pixels.

    Each tile is written to ``path_out`` as ``<name>_<top>_<left>.jpg``,
    where ``<top>``/``<left>`` are the tile's start coordinates in the
    (padded) source image. The label-splitting step relies on this naming to
    recover the tile offset.

    :param path: directory containing the source images.
    :param path_out: directory receiving the tiles (must already exist).
    :param size_w: tile width in pixels.
    :param size_h: tile height in pixels.
    :param step: distance in pixels between consecutive tile starts.
    """
    fill = (107, 113, 115)
    count = 0
    for im_list in os.listdir(path):
        name = im_list.split('.')[0]  # file name without its extension
        # Read from the ``path`` argument, not from a module-level variable.
        img = cv2.imread(os.path.join(path, im_list))
        if img is None:
            # cv2.imread yields None for files it cannot decode.
            print('图片{}无法读取,已跳过'.format(im_list))
            continue
        count = count + 1

        height, width = img.shape[:2]
        pad_bottom = max(size_h - height, 0)
        pad_right = max(size_w - width, 0)
        if pad_bottom and pad_right:
            print('图片{}需要在下面和右面补齐'.format(name))
        elif pad_right:
            print('图片{}需要在右面补齐'.format(name))
        elif pad_bottom:
            print('图片{}需要在下面补齐'.format(name))
        if pad_bottom or pad_right:
            img = cv2.copyMakeBorder(img, 0, pad_bottom, 0, pad_right,
                                     cv2.BORDER_CONSTANT, value=fill)
            height, width = img.shape[:2]

        number = 0  # distinct tiles written for this image
        for star_h in _tile_starts(height, size_h, step):
            for star_w in _tile_starts(width, size_w, step):
                cropped = img[star_h:star_h + size_h, star_w:star_w + size_w]
                name_img = '{}_{}_{}'.format(name, star_h, star_w)
                # Change the extension here if another image format is wanted.
                cv2.imwrite(os.path.join(path_out, name_img + '.jpg'), cropped)
                number = number + 1
        print('图片{}切割成{}张'.format(name, number))
        print('共完成{}张图片'.format(count))


if __name__ == '__main__':
    # Script entry point: tile every source image into 608x608 crops.
    ims_path = 'G:/cj/JPEGImages1/'  # directory of the source images
    # txt_path = 'G:/cj/Annotations/'
    path = 'G:/cj/JPEGImages/'  # directory that receives the cropped tiles
    caijian(ims_path, path, size_w = 608, size_h = 608, step = 576)

txt分割

import cv2
import os

# Category names that are kept when splitting labels; annotation lines whose
# label is not listed here are dropped.
category_set = ['plane', 'baseball-diamond', 'bridge', 'ground-track-field',
                'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
                'basketball-court', 'storage-tank', 'soccer-ball-field',
                'roundabout', 'harbor', 'swimming-pool', 'helicopter']


def tqtxt(path, path_txt, path_out, size_h=608, size_w=608):
    """Write one label file per image tile, in tile-local coordinates.

    Every file in ``path`` named ``<base>_<top>_<left>.<ext>`` is treated as
    a tile cut from image ``<base>`` at row offset ``<top>`` and column
    offset ``<left>``. The labels are read from ``<path_txt>/<base>.txt`` and
    written to ``<path_out>/<base>_<top>_<left>.txt``:

    * the first two lines of the source file are copied unchanged;
    * every following line is expected to hold eight corner coordinates, a
      category name and a difficulty flag, separated by whitespace;
    * a line is kept only if its category is in ``category_set`` and all
      four corners lie inside the tile (bounds inclusive); its coordinates
      are truncated to integers and shifted by the tile offset.

    Files in ``path`` whose name does not follow the tile pattern are
    skipped. An existing output file is overwritten, so the function can be
    re-run safely.

    :param path: directory containing the image tiles.
    :param path_txt: directory containing the original label files.
    :param path_out: directory receiving the per-tile label files.
    :param size_h: tile height in pixels.
    :param size_w: tile width in pixels.
    :raises FileNotFoundError: if the label file of a tile's source image
        does not exist.
    """
    for im_list in os.listdir(path):
        name = im_list.split('.')[0]
        # Split from the right so base names containing '_' still parse.
        parts = name.rsplit('_', 2)
        if len(parts) < 3:
            continue
        try:
            h = int(parts[1])
            w = int(parts[2])
        except ValueError:
            continue  # not a tile produced by the cropping step

        txtpath = os.path.join(path_txt, parts[0] + '.txt')
        txt_outpath = os.path.join(path_out, name + '.txt')
        # Read the source first so a missing label file leaves no empty output.
        with open(txtpath, 'r') as f_in:
            lines = f_in.readlines()

        with open(txt_outpath, 'w') as f:
            f.writelines(lines[:2])  # the two header lines are copied as-is
            for line in lines[2:]:
                fields = line.split()
                if len(fields) < 10:
                    continue  # blank or malformed line
                label, kunnan = fields[8], fields[9]
                if label not in category_set:  # keep listed categories only
                    continue
                coords = [int(float(v)) for v in fields[:8]]
                xs, ys = coords[0::2], coords[1::2]
                inside = (all(w <= x <= w + size_w for x in xs)
                          and all(h <= y <= h + size_h for y in ys))
                if not inside:
                    continue
                shifted = []
                for x, y in zip(xs, ys):
                    shifted.append(float(x - w))
                    shifted.append(float(y - h))
                # Explicit newline: do not depend on the source line's ending.
                f.write('{} {} {} {} {} {} {} {} {} {}\n'.format(
                    *shifted, label, kunnan))
        print('转换成功')


if __name__ == '__main__':
    # Inputs: the cropped tiles and the label files of the un-split dataset.
    tile_dir = 'G:/cj/JPEGImages/'
    source_label_dir = 'G:/cj/Annotations1/'
    # Output: the per-tile label files are written here.
    tile_label_dir = 'G:/cj/Annotations2/'
    tqtxt(tile_dir, source_label_dir, tile_label_dir, size_h=608, size_w=608)

txt转xml

import os
from xml.dom.minidom import Document
from xml.dom.minidom import parse
import xml.dom.minidom
import numpy as np
import csv
import cv2
import string


def WriterXMLFiles(filename, path, box_list, label_list, w, h, d):
    """Write one VOC-style XML annotation file with four-corner boxes.

    The document root is ``<annotation>``. It holds fixed ``folder``,
    ``path``, ``source`` and ``segmented`` entries, a ``size`` entry built
    from ``w``/``h``/``d``, and one ``<object>`` per box. Each object has a
    ``<name>`` (the label) and a ``<bndbox>`` whose children
    ``x1, y1, ..., x4, y4`` are the first eight values of the box.

    :param filename: name of the XML file to create; the same string is also
        stored in the ``<filename>`` element.
    :param path: directory in which the file is created.
    :param box_list: sequence of boxes, each indexable with at least eight
        values in the order x1, y1, x2, y2, x3, y3, x4, y4.
    :param label_list: sequence of labels, paired with ``box_list`` by
        position.
    :param w: image width stored in ``<size>/<width>``.
    :param h: image height stored in ``<size>/<height>``.
    :param d: image depth (channel count) stored in ``<size>/<depth>``.
    """
    doc = xml.dom.minidom.Document()

    def add_text_element(parent, tag, text):
        # Append <tag>text</tag> to ``parent``.
        node = doc.createElement(tag)
        node.appendChild(doc.createTextNode(str(text)))
        parent.appendChild(node)

    root = doc.createElement('annotation')
    doc.appendChild(root)

    add_text_element(root, 'folder', 'JPEGImages')
    add_text_element(root, 'filename', filename)
    add_text_element(root, 'path', 'xxxx')

    source = doc.createElement('source')
    add_text_element(source, 'database', 'Unknown')
    add_text_element(source, 'annotation', 'xxx')
    add_text_element(source, 'image', 'xxx')
    add_text_element(source, 'flickrid', '0')
    root.appendChild(source)

    size = doc.createElement('size')
    add_text_element(size, 'width', w)
    add_text_element(size, 'height', h)
    add_text_element(size, 'depth', d)
    root.appendChild(size)

    add_text_element(root, 'segmented', '0')

    corner_tags = ('x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4')
    for box, label in zip(box_list, label_list):
        obj = doc.createElement('object')
        add_text_element(obj, 'name', label)
        bndbox = doc.createElement('bndbox')
        # Index explicitly so a box with fewer than eight values still fails
        # loudly instead of producing a truncated <bndbox>.
        for index, tag in enumerate(corner_tags):
            add_text_element(bndbox, tag, box[index])
        obj.appendChild(bndbox)
        root.appendChild(obj)

    # The XML declaration written by minidom names no encoding, which means
    # UTF-8 to XML parsers, so the file is written as UTF-8 explicitly.
    with open(os.path.join(path, filename), 'w', encoding='utf-8') as fp:
        # indent='\n' makes minidom start every element on its own line.
        doc.writexml(fp, indent='\n')


def load_annoataion(p):
    """Load the boxes and labels stored in one label text file.

    The first two lines of the file are skipped as a header. Every other
    line is split on whitespace: the first eight fields are the corner
    coordinates x1, y1, ..., x4, y4 and the ninth is the category name.
    Lines with fewer than nine fields (for example blank lines) are ignored.

    :param p: path of the label file.
    :return: a tuple ``(boxes, labels)``; ``boxes`` is a float array of
        shape ``(n, 8)`` and ``labels`` a string array of length ``n``. Both
        are empty when the file does not exist or holds no annotations.
    """
    text_polys = []
    text_tags = []
    if os.path.exists(p):
        with open(p, 'r') as f:
            for line in f.readlines()[2:]:
                fields = line.split()
                if len(fields) < 9:
                    continue
                text_polys.append(fields[0:8])
                text_tags.append(fields[8])
    # The builtin float/str are used as dtypes because the ``np.float`` and
    # ``np.str`` aliases no longer exist in current NumPy releases.
    # reshape keeps the (n, 8) layout even when there are no boxes.
    boxes = np.array(text_polys, dtype=float).reshape(-1, 8)
    labels = np.array(text_tags, dtype=str)
    return boxes, labels


# Convert every per-tile label file into an XML annotation file.
txt_path = 'G:/cj/Annotations2/'  # per-tile label files (input)
xml_path = 'G:/cj/Annotations/'  # XML annotation files (output)
img_path = 'G:/cj/JPEGImages/'  # image tiles, read only for their size
print(os.path.exists(txt_path))
txts = os.listdir(txt_path)
for count, t in enumerate(txts):
    if not t.endswith('.txt'):
        continue  # ignore anything that is not a label file
    boxes, labels = load_annoataion(os.path.join(txt_path, t))
    stem = t[:-len('.txt')]
    xml_name = stem + '.xml'
    img_name = stem + '.jpg'  # adjust the extension to the tile image format
    print(img_name)
    img = cv2.imread(os.path.join(img_path, img_name))
    if img is None:
        # cv2.imread yields None when the tile is missing or unreadable;
        # without the image its size is unknown, so no XML can be written.
        print('图片{}无法读取,已跳过'.format(img_name))
        continue
    print(xml_name, xml_path, boxes, labels)
    h, w, d = img.shape
    print(xml_name, xml_path, boxes, labels, w, h, d)
    WriterXMLFiles(xml_name, xml_path, boxes, labels, w, h, d)
    if count % 1000 == 0:
        print(count)
上一篇:前端导出下载文件功能


下一篇:安装Nginx