# -*- coding: UTF-8 -*-  
import xml.etree.ElementTree as ET
from os import getcwd
import glob
import shutil

# Alternative class list, currently unused:
# classes = ["mark", "front", "back", "top", "bottle"]

# Ordered class names. A name's position in this list is the class id that
# is written at the start of each label line. XML objects named
# "feright car" are remapped to "feright_car" by convert_annotation.
classes = [
    'car',
    'van',
    'bus',
    'feright_car',
    'truck',
]

# Directory containing the source XML annotations, and directory that
# receives the generated .txt labels. Both are combined with file names by
# plain string concatenation, so the trailing slash matters.
xml_path = "./datasets/DroneVehicle/train/trainlabel/"
out_path = "./datasets/DroneVehicle/train/labels/"

def _object_corners(obj):
    """Return an <object>'s four corners as [x1, y1, ..., x4, y4] in pixels.

    A non-empty <polygon> (children x1..y4) is used as given. Otherwise a
    non-empty <bndbox> (xmin/ymin/xmax/ymax) is expanded to its corners in
    cyclic order: top-left, top-right, bottom-right, bottom-left. Returns
    None when neither element is usable.
    """
    polygon = obj.find('polygon')
    # Explicit checks instead of Element truthiness, which counts children
    # and is deprecated (it will always be True in future Python versions).
    if polygon is not None and len(polygon) > 0:
        return [float(polygon.find(tag).text)
                for tag in ('x1', 'y1', 'x2', 'y2', 'x3', 'y3', 'x4', 'y4')]

    bndbox = obj.find('bndbox')
    if bndbox is not None and len(bndbox) > 0:
        xmin = float(bndbox.find('xmin').text)
        ymin = float(bndbox.find('ymin').text)
        xmax = float(bndbox.find('xmax').text)
        ymax = float(bndbox.find('ymax').text)
        # Walk the rectangle's perimeter so the polygon does not self-intersect.
        return [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]

    return None


def convert_annotation(xml_name, *, xml_dir=None, out_dir=None, class_names=None):
    """Convert one XML annotation file into a YOLO-OBB label file.

    Every object whose name is in ``class_names`` becomes one line
    ``<class_id> x1 y1 x2 y2 x3 y3 x4 y4``. The x coordinates are divided by
    the image width and the y coordinates by the image height, formatted
    with 7 decimals. Objects with neither a usable <polygon> nor a usable
    <bndbox> are reported with "error box" and skipped.

    Args:
        xml_name: File name of the XML file inside ``xml_dir``. Its last 4
            characters (e.g. ".xml") are replaced by ".txt" for the output.
        xml_dir: Directory of the XML file, including the trailing separator.
            Defaults to the module-level ``xml_path``.
        out_dir: Directory for the label file, including the trailing
            separator. Defaults to the module-level ``out_path``.
        class_names: Ordered class names; a name's index is its class id.
            Defaults to the module-level ``classes``.

    The label file is overwritten rather than appended to, so re-running the
    conversion is idempotent. Nothing is written when the XML reports a
    zero width or height.
    """
    if xml_dir is None:
        xml_dir = xml_path
    if out_dir is None:
        out_dir = out_path
    if class_names is None:
        class_names = classes

    with open(xml_dir + xml_name, 'r', encoding='utf-8') as f:
        root = ET.fromstring(f.read())

    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    if w == 0 or h == 0:
        # Coordinates cannot be normalised against a zero dimension.
        return

    # Divisor for each of the 8 coordinates: x -> width, y -> height.
    scale = (w, h) * 4
    lines = []
    for obj in root.iter('object'):
        cls = obj.findtext('name')
        # The dataset's XML spells this class with a space; the class list
        # uses an underscore.
        if cls == "feright car":
            cls = "feright_car"
        if cls not in class_names:
            continue
        corners = _object_corners(obj)
        if corners is None:
            print("error box")
            continue
        obb_label = " ".join('%.7f' % (v / s) for v, s in zip(corners, scale))
        lines.append(str(class_names.index(cls)) + " " + obb_label + '\n')

    # Write once, truncating any output left over from a previous run.
    with open(out_dir + xml_name[:-4] + '.txt', 'w') as out_file:
        out_file.writelines(lines)



if __name__ == '__main__':
    import os

    print("Begin->")
    # open(..., 'w') on the label path fails if the directory is missing.
    os.makedirs(out_path, exist_ok=True)
    # Sorted so the processing order (and log output) is deterministic.
    for xml_file in sorted(glob.glob(xml_path + "*.xml")):  # every XML annotation file
        print(xml_file)
        # basename is separator-agnostic; splitting on '/' broke on Windows,
        # where glob joins the matched file name with '\\'.
        xml_name = os.path.basename(xml_file)
        print("xml_name--->", xml_name)
        convert_annotation(xml_name)
    print("End->")

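# Footer text from the web page this script was copied from (not code);
# kept as comments so the file remains valid Python:
# Logo
#
# 技术共进,成长同行——讯飞AI开发者社区
#
# 更多推荐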