流程走通。(效果不行:数据集和数据标注都比较少)

This commit is contained in:
ShaoHua
2026-05-05 03:02:34 +08:00
parent 41e3bb113a
commit 3fa64a95f8
28 changed files with 1179 additions and 404 deletions
+135
View File
@@ -0,0 +1,135 @@
import os
import sys
from PIL import Image
# Put this file's directory on sys.path so the sibling module dataset_util can be imported
src_dir = os.path.dirname(os.path.abspath(__file__))
if src_dir not in sys.path:
sys.path.insert(0, src_dir)
from dataset_util import load_dataset_config
def validate_images_labels(images_dir, labels_dir):
    """Check that every image has a readable image file and a valid label file.

    Every ``.jpg`` / ``.jpeg`` / ``.png`` file in ``images_dir`` (extension
    matched case-insensitively) is expected to have a ``.txt`` file with the
    same stem in ``labels_dir``. The image must be openable by PIL, and each
    non-blank label line must hold five fields ``class cx cy w h`` where the
    class is a non-negative integer, the centre lies in ``[0, 1]`` and the
    width/height lie in ``(0, 1]``.

    Args:
        images_dir: Directory that holds the image files.
        labels_dir: Directory that holds the label ``.txt`` files.

    Returns:
        list: Human-readable error messages; empty when nothing is wrong.
    """
    errors = []
    # isdir (not exists): a regular file at this path would make os.listdir raise.
    if not os.path.isdir(images_dir):
        errors.append(f"images 目录不存在: {images_dir}")
        return errors

    # Sorted so the report does not depend on the order os.listdir() returns.
    image_files = sorted(
        name for name in os.listdir(images_dir)
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    if not image_files:
        errors.append(f"images 目录中没有图片: {images_dir}")
        return errors

    for img_file in image_files:
        label_file = os.path.splitext(img_file)[0] + '.txt'
        label_path = os.path.join(labels_dir, label_file)
        if not os.path.exists(label_path):
            errors.append(f"缺失标注文件: {label_file}")
            continue

        try:
            # Only checks that PIL can open the file; the context manager
            # closes the underlying file handle so large datasets do not
            # exhaust file descriptors.
            with Image.open(os.path.join(images_dir, img_file)):
                pass
        except Exception:
            # Deliberately broad: any failure to open means the image is unusable.
            errors.append(f"无法打开图片: {img_file}")
            continue

        errors.extend(_validate_label_file(label_path, label_file))
    return errors


def _validate_label_file(label_path, label_file):
    """Return the format errors found in a single label file."""
    errors = []
    try:
        with open(label_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                parts = line.split()
                if not parts:
                    # Blank lines are allowed and ignored.
                    continue
                if len(parts) != 5:
                    errors.append(f"{label_file}: 格式错误,行 {line_num}")
                    continue
                try:
                    class_id = int(parts[0])
                    center_x, center_y, width, height = map(float, parts[1:])
                except ValueError:
                    errors.append(f"{label_file}: 数字解析错误,行 {line_num}")
                    continue
                if class_id < 0:
                    errors.append(f"{label_file}: 类别 ID 为负数,行 {line_num}")
                if not (0 <= center_x <= 1 and 0 <= center_y <= 1):
                    errors.append(f"{label_file}: 坐标超出范围,行 {line_num}")
                # Box size is normalised too; the original parsed it but never checked it.
                if not (0 < width <= 1 and 0 < height <= 1):
                    errors.append(f"{label_file}: 宽高超出范围,行 {line_num}")
    except (OSError, UnicodeDecodeError):
        # Report an unreadable label file instead of aborting the whole run.
        errors.append(f"无法读取标注文件: {label_file}")
    return errors
def validate_dataset_from_config(config_path=None):
    """Validate the train/val/test splits described by a dataset config.

    The config is loaded with ``load_dataset_config``. Its ``path`` entry
    (default ``./dataset``) is the dataset root; when it is relative and
    ``config_path`` is given, it is resolved against the config file's
    directory. Each split entry (default ``<split>/images``) is joined onto
    that root, and the matching labels directory is derived by swapping the
    last ``images`` path component for ``labels``. Splits whose images
    directory does not exist are skipped silently.

    Args:
        config_path: Path of the config file, passed through to
            ``load_dataset_config``; may be None.

    Returns:
        list: Error messages collected from all existing splits.
    """
    config = load_dataset_config(config_path)
    # `or` also covers a key that is present but empty (YAML "path:" -> None).
    data_dir = config.get('path') or './dataset'
    if not os.path.isabs(data_dir) and config_path:
        data_dir = os.path.join(os.path.dirname(config_path), data_dir)

    errors = []
    for split in ['train', 'val', 'test']:
        # An empty entry such as "test:" yields None, which os.path.join
        # rejects; fall back to the conventional location instead.
        entries = config.get(split) or f'{split}/images'
        # A split may be a list of directories; normalise to a list.
        if isinstance(entries, str):
            entries = [entries]
        for images_rel in entries:
            split_images_dir = os.path.join(data_dir, images_rel)
            split_labels_dir = os.path.join(data_dir, _labels_dir_for(images_rel))
            # isdir: entries that are not directories (or do not exist) are skipped.
            if os.path.isdir(split_images_dir):
                print(f"检查 {split} 集...")
                errors.extend(
                    validate_images_labels(split_images_dir, split_labels_dir)
                )
    return errors


def _labels_dir_for(images_rel):
    """Map an images directory path to the corresponding labels directory path."""
    parts = os.path.normpath(images_rel).split(os.sep)
    # Swap only the LAST component that is exactly "images", so a path like
    # "my_images/train/images" becomes "my_images/train/labels".
    for idx in reversed(range(len(parts))):
        if parts[idx] == 'images':
            parts[idx] = 'labels'
            return os.sep.join(parts)
    # No such component: keep the previous substring replacement behaviour.
    return images_rel.replace('images', 'labels')
def validate_yolo_dataset(config_path=None):
    """Run the dataset validation and print a summary of the outcome.

    At most the first 20 error messages are printed, followed by a count of
    the remaining ones.

    Args:
        config_path: Path of the config file; forwarded unchanged to
            ``validate_dataset_from_config``.

    Returns:
        list: All error messages found (not only the printed ones).
    """
    print("验证数据集...")
    errors = validate_dataset_from_config(config_path)

    # Success path first, so the reporting code below is not nested.
    if not errors:
        print("\n验证通过!")
        return errors

    total = len(errors)
    print(f"\n发现 {total} 个错误:")
    preview = errors[:20]
    for message in preview:
        print(f" - {message}")

    hidden = total - 20
    if hidden > 0:
        print(f" ... 还有 {hidden} 个错误")
    return errors
if __name__ == '__main__':
    # Command-line entry point: one optional positional argument, the config path.
    import argparse

    parser = argparse.ArgumentParser(description='验证 YOLO 数据集')
    parser.add_argument(
        'config_path', nargs='?', default=None,
        # The closing parenthesis was missing from this help text.
        help='dataset.yaml 配置文件路径(默认使用项目根目录的 dataset.yaml)',
    )
    args = parser.parse_args()
    # Exit non-zero when errors were found so shell scripts and CI can react.
    sys.exit(1 if validate_yolo_dataset(args.config_path) else 0)