目前主流的Object Detection模型,如YOLO系列(YOLOv3, YOLOv5, YOLOv7, YOLOv8),Faster R-CNN,Mask R-CNN和DETR等,在训练和验证阶段所使用的数据集格式各不相同。
YOLO系列采用的是YOLO格式数据集,MMDetection实现的Faster R-CNN等模型采用的是COCO格式数据集。
这里记录一些不同数据集格式之间转换的代码。
YOLOv5格式数据集的目录结构为:images/train和images/val存放图片,labels/train和labels/val存放与图片同名的txt标注文件。
COCO格式数据集
将YOLO格式数据集转换为COCO格式的代码如下:
import os
import json
import cv2
def yolo_to_coco(yolo_file, img_id, ann_id_start, img_width, img_height):
    """Convert one YOLO label file into COCO annotation dicts.

    Each non-blank line of ``yolo_file`` must hold five whitespace-separated
    numbers: ``class_id center_x center_y width height``, with the four box
    values normalized by the image size.

    Args:
        yolo_file: Path of the YOLO ``.txt`` label file.
        img_id: Value stored in ``"image_id"`` of every produced annotation.
        ann_id_start: Id given to the first produced annotation; following
            annotations get consecutive ids.
        img_width: Image width in pixels, used to de-normalize x and width.
        img_height: Image height in pixels, used to de-normalize y and height.

    Returns:
        A tuple ``(annotations, next_ann_id)`` where ``annotations`` is the
        list of annotation dicts and ``next_ann_id`` is the first unused id.

    Raises:
        ValueError: If a non-blank line does not contain exactly five
            numeric fields.
    """
    annotations = []
    with open(yolo_file, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank lines (typically a trailing newline) describe no box;
                # unpacking them would raise ValueError.
                continue
            class_id, center_x, center_y, width, height = map(float, fields)
            # Convert YOLO format (normalized center/size) to COCO format
            # (absolute top-left corner plus size).
            x = (center_x - width / 2) * img_width
            y = (center_y - height / 2) * img_height
            w = width * img_width
            h = height * img_height
            annotations.append({
                "id": ann_id_start,
                "image_id": img_id,
                "category_id": int(class_id),
                "bbox": [x, y, w, h],
                "area": w * h,
                "iscrowd": 0
            })
            ann_id_start += 1
    return annotations, ann_id_start
def create_coco_json(img_folder, yolo_folder, output_file, categories):
    """Build a COCO annotation JSON file from a folder of YOLO labels.

    Every ``.jpg`` / ``.png`` file directly inside ``img_folder`` becomes one
    entry of ``"images"``. If ``yolo_folder`` holds a ``.txt`` file with the
    same base name, its boxes are converted with ``yolo_to_coco`` and added
    to ``"annotations"``. Image ids and annotation ids both start at 1.

    Args:
        img_folder: Folder containing the images.
        yolo_folder: Folder containing the YOLO ``.txt`` label files.
        output_file: Path of the JSON file to write.
        categories: List written unchanged into the ``"categories"`` field.
    """
    images = []
    annotations = []
    ann_id = 1
    img_id = 1
    # os.listdir returns entries in arbitrary order; sorting makes the
    # assigned ids reproducible between runs and machines.
    for img_file in sorted(os.listdir(img_folder)):
        if not img_file.lower().endswith((".jpg", ".png")):
            continue
        img_path = os.path.join(img_folder, img_file)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning
            # None instead of raising; skip it rather than crash on .shape.
            print("warning: cannot read image, skipped:", img_path)
            continue
        height, width = img.shape[:2]
        images.append({
            "id": img_id,
            "width": width,
            "height": height,
            "file_name": img_file
        })
        yolo_file = os.path.join(yolo_folder, os.path.splitext(img_file)[0] + ".txt")
        # An image without a label file is kept as an image with no boxes.
        if os.path.exists(yolo_file):
            img_annotations, ann_id = yolo_to_coco(yolo_file, img_id, ann_id, width, height)
            annotations.extend(img_annotations)
        img_id += 1
    coco_format = {
        "images": images,
        "annotations": annotations,
        "categories": categories
    }
    with open(output_file, 'w') as f:
        json.dump(coco_format, f, indent=4)
# Category table written unchanged into the "categories" field of the output
# JSON. yolo_to_coco copies each YOLO class index straight into
# "category_id", so every "id" here must equal the matching YOLO class index.
categories = [
{"id": 0, "name": "person", "supercategory": "none"},
# Add one entry per additional class in the dataset
]
# Image folder, YOLO label folder and output JSON file for ONE dataset split;
# edit these paths and run again to convert another split.
img_folder = "/home/dataset/images/val"
yolo_folder = "/home/dataset/labels/val"
output_file = "./val_coco_annotations.json"
create_coco_json(img_folder, yolo_folder, output_file, categories)
该Python脚本每次只转换一个数据子集(训练集或验证集),需要手动修改对应的路径后分别运行。categories为数据集中的类别信息,根据需要进行添加和删减。
官方YOLOv7库支持的数据集格式和YOLOv5有一点小的区别,可以从二者数据集配置文件的差异对比查看。
YOLOv7的数据集配置文件:
# train and val data as 1) directory: path/images/, 2) file: path/images.txt, or 3) list: [path1/images/, path2/images/]
train: ./coco/train2017.txt # 118287 images
val: ./coco/val2017.txt # 5000 images
# test: ./coco/test-dev2017.txt # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
# number of classes
nc: 80
# class names
names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
'hair drier', 'toothbrush' ]
YOLOv5的数据集配置文件:
path: /home/dataset/
train: images/train
val: images/val
# Classes
names:
0: person
1: bicycle
2: car
3: motorcycle
4: airplane
5: bus
6: train
7: truck
8: boat
9: traffic light
10: fire hydrant
11: stop sign
12: parking meter
13: bench
14: bird
15: cat
16: dog
17: horse
18: sheep
19: cow
20: elephant
21: bear
22: zebra
23: giraffe
24: backpack
25: umbrella
26: handbag
27: tie
28: suitcase
29: frisbee
30: skis
31: snowboard
32: sports ball
33: kite
34: baseball bat
35: baseball glove
36: skateboard
37: surfboard
38: tennis racket
39: bottle
40: wine glass
41: cup
42: fork
43: knife
44: spoon
45: bowl
46: banana
47: apple
48: sandwich
49: orange
50: broccoli
51: carrot
52: hot dog
53: pizza
54: donut
55: cake
56: chair
57: couch
58: potted plant
59: bed
60: dining table
61: toilet
62: tv
63: laptop
64: mouse
65: remote
66: keyboard
67: cell phone
68: microwave
69: oven
70: toaster
71: sink
72: refrigerator
73: book
74: clock
75: vase
76: scissors
77: teddy bear
78: hair drier
79: toothbrush
对比可见,不同于YOLOv5指定数据集中images的路径,YOLOv7改为指定包含有images路径的txt文件。因此需要准备这样一个txt文件,代码如下:
import os
def list_files_in_directory(directory, output_file):
    """Write the path of every file found under a directory tree to a text file.

    The tree is traversed with ``os.walk`` and one path is written per line,
    each built by joining the walked folder with the file name.

    Args:
        directory (str): Root directory whose files are listed, including
            those in subdirectories.
        output_file (str): Text file that receives the paths; it is opened
            in ``'w'`` mode, so existing content is replaced.
    """
    with open(output_file, 'w') as sink:
        for parent, _subdirs, filenames in os.walk(directory):
            sink.writelines(
                os.path.join(parent, filename) + '\n' for filename in filenames
            )
# Example usage: list every file under the validation image folder.
# The parent folder of output_file (./dataset) must already exist, because
# list_files_in_directory opens the file with mode 'w' and does not create
# directories.
directory = "/home/dataset/images/val/"
output_file = './dataset/val_with_patch.txt'
list_files_in_directory(directory, output_file)
虽然经历了这样的改动,但image和label的相对位置仍不变,为YOLOv5格式数据集的目录结构。
将COCO格式标注转换为YOLO格式的代码如下:
import json
import os
import argparse
import numpy as np
# Command-line interface; both options are required.
parser = argparse.ArgumentParser(description='Test yolo data.')
# -j: COCO annotation JSON file to convert.
parser.add_argument('-j', help='JSON file', dest='json', required=True)
# -o: folder that receives the generated YOLO .txt label files.
parser.add_argument('-o', help='path to output folder', dest='out',required=True)
args = parser.parse_args()
# Module-level values that COCO2YOLO reads directly.
json_file = args.json
output = args.out
class COCO2YOLO:
    """Convert a COCO annotation JSON file into per-image YOLO label files.

    Only objects of one category that also pass a minimum normalized-height
    filter are written. Each output line has the form
    ``<class_index> <center_x> <center_y> <width> <height>`` with the box
    values normalized by the image size.
    """

    def __init__(self, json_path=None, output_dir=None,
                 category_name='person', min_height=0.2):
        """Load the COCO file and prepare the output folder.

        Args:
            json_path: COCO JSON file. Defaults to the module-level
                ``json_file`` taken from the command line.
            output_dir: Folder for the ``.txt`` files, created if missing.
                Defaults to the module-level ``output``.
            category_name: Name of the only category to keep, or ``None`` to
                keep every category.
            min_height: Objects are kept only when their normalized height is
                strictly greater than this value.

        Raises:
            ValueError: If the JSON file does not exist.
        """
        # Falling back to the module-level values keeps the original
        # zero-argument construction working.
        self.json_path = json_file if json_path is None else json_path
        self.output_dir = output if output_dir is None else output_dir
        self.category_name = category_name
        self.min_height = min_height
        self._check_file_and_dir(self.json_path, self.output_dir)
        # Use a context manager so the file handle is closed after loading.
        with open(self.json_path, 'r', encoding='utf-8') as f:
            self.labels = json.load(f)
        self.coco_id_name_map = self._categories()
        self.coco_name_list = list(self.coco_id_name_map.values())
        print("total images", len(self.labels['images']))
        print("total categories", len(self.labels['categories']))
        print("total labels", len(self.labels['annotations']))

    def _check_file_and_dir(self, file_path, dir_path):
        """Require ``file_path`` to exist and create ``dir_path`` if needed."""
        if not os.path.exists(file_path):
            raise ValueError("file not found")
        os.makedirs(dir_path, exist_ok=True)

    def _categories(self):
        """Return a dict mapping COCO category id to category name."""
        return {cls['id']: cls['name'] for cls in self.labels['categories']}

    def _load_images_info(self):
        """Return a dict mapping image id to ``(file_name, width, height)``."""
        images_info = {}
        for image in self.labels['images']:
            file_name = image['file_name']
            # Drop everything up to and including the first backslash.
            if '\\' in file_name:
                file_name = file_name[file_name.index('\\') + 1:]
            images_info[image['id']] = (file_name, image['width'], image['height'])
        return images_info

    def _bbox_2_yolo(self, bbox, img_w, img_h):
        """Convert a COCO ``[x, y, w, h]`` box to normalized YOLO values.

        Returns:
            Tuple ``(center_x, center_y, width, height)``, each divided by
            the corresponding image dimension.
        """
        x, y, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
        center_x = (x + w / 2) / img_w
        center_y = (y + h / 2) / img_h
        return center_x, center_y, w / img_w, h / img_h

    def _convert_anno(self, images_info):
        """Group converted annotations by image id.

        Args:
            images_info: Mapping produced by ``_load_images_info``.

        Returns:
            Dict mapping image id to a list of
            ``(image_name, category_id, yolo_box)`` tuples.

        Raises:
            ValueError: If an annotation refers to an image id that is not
                present in ``images_info``.
        """
        anno_dict = dict()
        for anno in self.labels['annotations']:
            image_id = anno['image_id']
            image_info = images_info.get(image_id)
            if image_info is None:
                # Fail with a message that names the bad id instead of an
                # opaque "NoneType is not subscriptable".
                raise ValueError(
                    "annotation refers to unknown image_id {}".format(image_id))
            image_name, img_w, img_h = image_info
            yolo_box = self._bbox_2_yolo(anno['bbox'], img_w, img_h)
            anno_dict.setdefault(image_id, []).append(
                (image_name, anno['category_id'], yolo_box))
        return anno_dict

    def save_classes(self):
        """Write the category names, ordered by category id, to ``coco.names``."""
        ordered = sorted(self.labels['categories'], key=lambda x: x['id'])
        sorted_classes = [cls['name'] for cls in ordered]
        print('coco names', sorted_classes)
        with open('coco.names', 'w', encoding='utf-8') as f:
            f.writelines(cls + '\n' for cls in sorted_classes)

    def coco2yolo(self):
        """Run the full conversion: load image info, convert, write files."""
        print("loading image info...")
        images_info = self._load_images_info()
        print("loading done, total images", len(images_info))
        print("start converting...")
        anno_dict = self._convert_anno(images_info)
        print("converting done, total labels", len(anno_dict))
        print("saving txt file...")
        self._save_txt(anno_dict)
        print("saving done")

    def _save_txt(self, anno_dict):
        """Append the filtered objects of every image to its ``.txt`` file.

        A file is only touched when at least one object of the image passes
        the filters. Files are opened in append mode, so running the
        conversion twice into the same folder duplicates the lines.
        """
        for annos in anno_dict.values():
            # splitext removes only the final extension, so names that
            # contain extra dots (e.g. "img.001.jpg") are not truncated.
            file_name = os.path.splitext(annos[0][0])[0] + ".txt"
            lines = []
            for obj in annos:
                category_id, box = obj[1], obj[2]
                cat_name = self.coco_id_name_map.get(category_id)
                if cat_name is None:
                    # Category id missing from the "categories" table.
                    continue
                if self.category_name is not None and cat_name != self.category_name:
                    continue
                # Every normalized value must be strictly positive; testing
                # each value also rejects negative coordinates.
                if not all(value > 0 for value in box):
                    continue
                if not box[-1] > self.min_height:
                    continue
                print(obj)
                # NOTE(review): the class index is the position in the JSON
                # "categories" order, while save_classes sorts by id; the two
                # agree only when the JSON list is already sorted by id.
                class_index = self.coco_name_list.index(cat_name)
                coords = ' '.join('{:.6f}'.format(value) for value in box)
                lines.append(str(class_index) + ' ' + coords + '\n')
            if lines:
                # Open once per image instead of once per object.
                with open(os.path.join(self.output_dir, file_name), 'a+', encoding='utf-8') as f:
                    f.writelines(lines)
if __name__ == '__main__':
    # Script entry point: build the converter and run the whole conversion.
    converter = COCO2YOLO()
    converter.coco2yolo()
该代码只保留类别为person的目标,并支持根据标注中object归一化后的height进行筛选,例如只保留height大于0.2的objects。