136 lines
4.0 KiB
Python
136 lines
4.0 KiB
Python
|
||
import os
|
||
import sys
|
||
from PIL import Image
|
||
|
||
# Add this file's directory (src) to sys.path so modules next to it can be imported
|
||
src_dir = os.path.dirname(os.path.abspath(__file__))
|
||
if src_dir not in sys.path:
|
||
sys.path.insert(0, src_dir)
|
||
|
||
from dataset_util import load_dataset_config
|
||
|
||
|
||
def _check_label_file(label_path, label_file):
    """Return the format errors found in a single label file.

    Every non-blank line must contain exactly five whitespace-separated
    fields: an integer class id followed by four floats (center x, center y,
    width, height). All four floats must lie within [0, 1].

    Args:
        label_path: Path of the label file to read.
        label_file: Name used to identify the file in error messages.

    Returns:
        list: Error messages, one per offending line (empty when clean).
    """
    problems = []
    # Decode explicitly so the result does not depend on the platform locale.
    # Undecodable bytes turn into U+FFFD and are then reported as a
    # number-parsing error instead of aborting the whole validation run.
    with open(label_path, 'r', encoding='utf-8', errors='replace') as f:
        for line_num, line in enumerate(f, 1):
            parts = line.split()
            if not parts:
                # Blank lines are tolerated.
                continue

            if len(parts) != 5:
                problems.append(f"{label_file}: 格式错误,行 {line_num}")
                continue

            try:
                int(parts[0])  # class id: only checked for being an integer
                box = [float(p) for p in parts[1:]]
            except ValueError:
                problems.append(f"{label_file}: 数字解析错误,行 {line_num}")
                continue

            # Width and height are normalized like the center, so all four
            # values share the same [0, 1] bound. NaN fails the comparison
            # and is reported as well.
            if not all(0 <= v <= 1 for v in box):
                problems.append(f"{label_file}: 坐标超出范围,行 {line_num}")
    return problems


def validate_images_labels(images_dir, labels_dir):
    """Validate an images directory against its labels directory.

    For every ``.jpg``/``.jpeg``/``.png`` file (case-insensitive) in
    ``images_dir`` this checks that a ``<stem>.txt`` file exists in
    ``labels_dir``, that the image can be opened, and that each label line is
    well formed (see ``_check_label_file``).

    Args:
        images_dir: Directory containing the image files.
        labels_dir: Directory containing the matching ``.txt`` label files.

    Returns:
        list: Error messages; empty when nothing is wrong. If ``images_dir``
        is not a directory or contains no images, the list holds that single
        error and nothing else is checked.
    """
    errors = []

    # isdir rather than exists: a regular file passed here would otherwise
    # make os.listdir raise NotADirectoryError.
    if not os.path.isdir(images_dir):
        errors.append(f"images 目录不存在: {images_dir}")
        return errors

    # Sorted so the report order is stable regardless of directory order.
    image_files = sorted(
        f for f in os.listdir(images_dir)
        if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )

    if not image_files:
        errors.append(f"images 目录中没有图片: {images_dir}")
        return errors

    for img_file in image_files:
        label_file = os.path.splitext(img_file)[0] + '.txt'
        label_path = os.path.join(labels_dir, label_file)

        if not os.path.isfile(label_path):
            errors.append(f"缺失标注文件: {label_file}")
            continue

        try:
            # Opening is enough to detect an unreadable image; the context
            # manager releases the file handle right away.
            with Image.open(os.path.join(images_dir, img_file)):
                pass
        except Exception:
            # Deliberately broad: any failure to open counts as a bad image
            # and must not stop the remaining files from being checked.
            errors.append(f"无法打开图片: {img_file}")
            continue

        errors.extend(_check_label_file(label_path, label_file))

    return errors
|
||
|
||
|
||
def validate_dataset_from_config(config_path=None):
    """Validate every dataset split described by a dataset config.

    The config is obtained from ``load_dataset_config(config_path)``. Its
    ``path`` entry (default ``./dataset``) is the dataset root; a relative
    root is resolved against the config file's directory when ``config_path``
    is given. For each of ``train``, ``val`` and ``test`` the config entry
    names the images directory relative to the root (default
    ``<split>/images``). The labels directory is derived by swapping the last
    ``images`` in that entry for ``labels``. Splits whose images directory
    does not exist are skipped.

    Args:
        config_path: Path of the config file; passed through unchanged to
            ``load_dataset_config``.

    Returns:
        list: Error messages collected from all checked splits.
    """
    config = load_dataset_config(config_path)

    # `or` instead of a .get default: a key that is present but empty yields
    # None, which os.path.isabs / os.path.join would reject with TypeError.
    data_dir = config.get('path') or './dataset'
    if not os.path.isabs(data_dir) and config_path:
        data_dir = os.path.join(os.path.dirname(config_path), data_dir)

    errors = []

    # Validate train/val/test
    for split in ('train', 'val', 'test'):
        images_rel = config.get(split) or f'{split}/images'

        # Swap only the LAST "images": a blanket str.replace would also
        # rewrite earlier parts of the path (e.g. "images_v2/train/images").
        head, found, tail = images_rel.rpartition('images')
        labels_rel = f'{head}labels{tail}' if found else images_rel

        split_images_dir = os.path.join(data_dir, images_rel)
        split_labels_dir = os.path.join(data_dir, labels_rel)

        if os.path.exists(split_images_dir):
            print(f"检查 {split} 集...")
            errors.extend(validate_images_labels(split_images_dir, split_labels_dir))

    return errors
|
||
|
||
|
||
def validate_yolo_dataset(config_path=None):
    """Run the dataset validation and print a summary of the outcome.

    Prints a success message when no errors are found. Otherwise prints the
    error count, the first 20 messages, and how many more were left out.

    Args:
        config_path: Path of the config file, forwarded to
            ``validate_dataset_from_config``.

    Returns:
        The error list produced by ``validate_dataset_from_config``.
    """
    max_shown = 20  # cap on the number of messages echoed to the console

    print("验证数据集...")
    found = validate_dataset_from_config(config_path)

    # Nothing to report: announce success and hand the (empty) result back.
    if not found:
        print("\n验证通过!")
        return found

    print(f"\n发现 {len(found)} 个错误:")
    for message in found[:max_shown]:
        print(f" - {message}")

    if len(found) > max_shown:
        print(f" ... 还有 {len(found) - 20} 个错误")

    return found
|
||
|
||
|
||
if __name__ == '__main__':
    import argparse

    # Command-line entry point: a single optional positional argument names
    # the config file; when omitted, None is passed on.
    cli = argparse.ArgumentParser(description='验证 YOLO 数据集')
    cli.add_argument(
        'config_path',
        nargs='?',
        default=None,
        help='dataset.yaml 配置文件路径(默认使用项目根目录的 dataset.yaml)',
    )

    validate_yolo_dataset(cli.parse_args().config_path)
|