import os
import sys

from PIL import Image

# Make the directory containing this script importable so the sibling module
# `dataset_util` resolves when the file is run directly as a script.
src_dir = os.path.dirname(os.path.abspath(__file__))
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

from dataset_util import load_dataset_config  # noqa: E402  (needs the sys.path tweak above)

# Lower-case file extensions treated as images by validate_images_labels().
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _validate_label_file(label_path, label_file):
    """Check every line of one YOLO label file.

    Each non-blank line must contain exactly five whitespace-separated fields,
    ``class_id center_x center_y width height``. ``class_id`` must be a
    non-negative integer and the four box values must be floats in [0, 1].

    Args:
        label_path: Path of the label file to read.
        label_file: File name used as the prefix of error messages.

    Returns:
        list: Error messages; empty when the file is valid.
    """
    try:
        with open(label_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except (OSError, UnicodeDecodeError):
        # Report instead of aborting the whole validation run.
        return [f"无法读取标注文件: {label_file}"]

    errors = []
    for line_num, line in enumerate(lines, 1):
        line = line.strip()
        if not line:
            continue
        parts = line.split()
        if len(parts) != 5:
            errors.append(f"{label_file}: 格式错误,行 {line_num}")
            continue
        try:
            class_id = int(parts[0])
            center_x, center_y, width, height = (float(p) for p in parts[1:])
        except ValueError:
            errors.append(f"{label_file}: 数字解析错误,行 {line_num}")
            continue

        if class_id < 0:
            errors.append(f"{label_file}: 类别 ID 为负数,行 {line_num}")
        # Written as "not (in range)" so NaN values are also reported.
        if not (0 <= center_x <= 1 and 0 <= center_y <= 1):
            errors.append(f"{label_file}: 坐标超出范围,行 {line_num}")
        if not (0 <= width <= 1 and 0 <= height <= 1):
            errors.append(f"{label_file}: 尺寸超出范围,行 {line_num}")
    return errors


def validate_images_labels(images_dir, labels_dir, extensions=IMAGE_EXTENSIONS):
    """Validate an images directory against its YOLO labels directory.

    For every image file in ``images_dir`` (matched case-insensitively by
    extension), this checks three things:

    - a same-named ``.txt`` label file exists in ``labels_dir``;
    - the image can be opened by PIL;
    - every line of the label file is well formed.

    Args:
        images_dir: Directory containing the images.
        labels_dir: Directory expected to contain one ``.txt`` per image.
        extensions: Image file extensions to check (defaults to
            ``IMAGE_EXTENSIONS``).

    Returns:
        list: Error messages; empty when everything is valid.
    """
    if not os.path.isdir(images_dir):
        return [f"images 目录不存在: {images_dir}"]

    exts = tuple(ext.lower() for ext in extensions)
    # Sorted so the error report is deterministic (os.listdir order is arbitrary).
    image_files = sorted(
        name for name in os.listdir(images_dir)
        if name.lower().endswith(exts)
        and os.path.isfile(os.path.join(images_dir, name))
    )
    if not image_files:
        return [f"images 目录中没有图片: {images_dir}"]

    errors = []
    for img_file in image_files:
        label_file = os.path.splitext(img_file)[0] + '.txt'
        label_path = os.path.join(labels_dir, label_file)
        if not os.path.exists(label_path):
            errors.append(f"缺失标注文件: {label_file}")
            continue

        try:
            # Opening the image is the readability check; the context manager
            # closes the file handle that the original code leaked.
            with Image.open(os.path.join(images_dir, img_file)):
                pass
        except Exception:  # broad on purpose: any open failure = unreadable image
            errors.append(f"无法打开图片: {img_file}")
            continue

        errors.extend(_validate_label_file(label_path, label_file))
    return errors


def _labels_dir_for(images_dir):
    """Derive the labels directory for an images directory (YOLO convention).

    The last path component named exactly ``images`` is replaced with
    ``labels`` (``train/images`` -> ``train/labels``). If no component is named
    ``images``, the last occurrence of the substring ``images`` is replaced.
    If there is none, the path is returned unchanged, and labels are then
    expected next to the images.
    """
    parts = os.path.normpath(images_dir).split(os.sep)
    for i in reversed(range(len(parts))):
        if parts[i] == 'images':
            parts[i] = 'labels'
            return os.sep.join(parts)
    if 'images' in images_dir:
        head, _, tail = images_dir.rpartition('images')
        return head + 'labels' + tail
    return images_dir


def _split_image_dirs(split_value):
    """Normalise one split entry of the config into a list of directories.

    A missing or empty entry (e.g. ``test:`` with no value) yields ``[]``.
    A list or tuple yields its non-empty items. Anything else is treated as a
    single path.
    """
    if not split_value:
        return []
    if isinstance(split_value, (list, tuple)):
        return [item for item in split_value if item]
    return [split_value]


def validate_dataset_from_config(config_path=None):
    """Validate the train/val/test splits described by a dataset config.

    Args:
        config_path: Config file path, passed to ``load_dataset_config``.

    Returns:
        list: Error messages from all splits.
    """
    config = load_dataset_config(config_path)
    data_dir = config.get('path')
    if data_dir is None:  # key missing or present with an empty value
        data_dir = './dataset'
    # A relative `path` is resolved against the config file's directory when
    # config_path is given. NOTE(review): when it is None, the relative path
    # is resolved against the current working directory; confirm this is
    # intended.
    if not os.path.isabs(data_dir) and config_path:
        data_dir = os.path.join(os.path.dirname(config_path), data_dir)

    errors = []
    for split in ('train', 'val', 'test'):
        for rel_dir in _split_image_dirs(config.get(split, f'{split}/images')):
            split_images_dir = os.path.join(data_dir, rel_dir)
            if not os.path.isdir(split_images_dir):
                if os.path.exists(split_images_dir):
                    # e.g. a .txt image list; only directory splits are checked.
                    print(f"跳过 {split} 集(不是目录): {split_images_dir}")
                continue
            # Derive from the config-relative part only, so an "images"
            # substring inside data_dir itself is never rewritten.
            split_labels_dir = os.path.join(data_dir, _labels_dir_for(rel_dir))
            print(f"检查 {split} 集...")
            errors.extend(validate_images_labels(split_images_dir, split_labels_dir))
    return errors


def validate_yolo_dataset(config_path=None, max_display=20):
    """Validate a YOLO dataset and print a summary report.

    Args:
        config_path: Dataset config file path. When ``None``,
            ``load_dataset_config`` chooses its default (per the CLI help, the
            project-root ``dataset.yaml``).
        max_display: Maximum number of errors printed individually.

    Returns:
        list: All error messages found.
    """
    print("验证数据集...")
    errors = validate_dataset_from_config(config_path)

    if not errors:
        print("\n验证通过!")
        return errors

    print(f"\n发现 {len(errors)} 个错误:")
    for error in errors[:max_display]:
        print(f" - {error}")
    if len(errors) > max_display:
        print(f" ... 还有 {len(errors) - max_display} 个错误")
    return errors


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='验证 YOLO 数据集')
    parser.add_argument('config_path', nargs='?', default=None,
                        help='dataset.yaml 配置文件路径(默认使用项目根目录的 dataset.yaml)')
    args = parser.parse_args()
    # Exit non-zero when problems are found so shell scripts / CI can react.
    sys.exit(1 if validate_yolo_dataset(args.config_path) else 0)