流程走通。(效果不行:数据集和数据标注都比较少)
This commit is contained in:
@@ -0,0 +1,135 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from PIL import Image
|
||||
|
||||
# 添加 src 目录到路径
|
||||
src_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if src_dir not in sys.path:
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
from dataset_util import load_dataset_config
|
||||
|
||||
|
||||
def validate_images_labels(images_dir, labels_dir):
|
||||
"""验证 images 和 labels 目录
|
||||
|
||||
Returns:
|
||||
list: 错误列表
|
||||
"""
|
||||
errors = []
|
||||
|
||||
if not os.path.exists(images_dir):
|
||||
errors.append(f"images 目录不存在: {images_dir}")
|
||||
return errors
|
||||
|
||||
image_files = [f for f in os.listdir(images_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
|
||||
|
||||
if not image_files:
|
||||
errors.append(f"images 目录中没有图片: {images_dir}")
|
||||
return errors
|
||||
|
||||
for img_file in image_files:
|
||||
img_path = os.path.join(images_dir, img_file)
|
||||
label_file = os.path.splitext(img_file)[0] + '.txt'
|
||||
label_path = os.path.join(labels_dir, label_file)
|
||||
|
||||
if not os.path.exists(label_path):
|
||||
errors.append(f"缺失标注文件: {label_file}")
|
||||
continue
|
||||
|
||||
try:
|
||||
img = Image.open(img_path)
|
||||
img_width, img_height = img.size
|
||||
except Exception as e:
|
||||
errors.append(f"无法打开图片: {img_file}")
|
||||
continue
|
||||
|
||||
with open(label_path, 'r') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
parts = line.split()
|
||||
if len(parts) != 5:
|
||||
errors.append(f"{label_file}: 格式错误,行 {line_num}")
|
||||
continue
|
||||
|
||||
try:
|
||||
class_id = int(parts[0])
|
||||
center_x = float(parts[1])
|
||||
center_y = float(parts[2])
|
||||
width = float(parts[3])
|
||||
height = float(parts[4])
|
||||
|
||||
if not (0 <= center_x <= 1 and 0 <= center_y <= 1):
|
||||
errors.append(f"{label_file}: 坐标超出范围,行 {line_num}")
|
||||
except ValueError:
|
||||
errors.append(f"{label_file}: 数字解析错误,行 {line_num}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_dataset_from_config(config_path=None):
|
||||
"""从配置文件验证数据集
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
|
||||
Returns:
|
||||
list: 错误列表
|
||||
"""
|
||||
config = load_dataset_config(config_path)
|
||||
|
||||
data_dir = config.get('path', './dataset')
|
||||
if not os.path.isabs(data_dir) and config_path:
|
||||
data_dir = os.path.join(os.path.dirname(config_path), data_dir)
|
||||
|
||||
errors = []
|
||||
|
||||
# 验证 train/val/test
|
||||
for split in ['train', 'val', 'test']:
|
||||
split_images_dir = os.path.join(data_dir, config.get(split, f'{split}/images'))
|
||||
split_labels_dir = os.path.join(data_dir, config.get(split, f'{split}/images').replace('images', 'labels'))
|
||||
|
||||
if os.path.exists(split_images_dir):
|
||||
print(f"检查 {split} 集...")
|
||||
split_errors = validate_images_labels(split_images_dir, split_labels_dir)
|
||||
errors.extend(split_errors)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_yolo_dataset(config_path=None):
|
||||
"""验证 YOLO 数据集
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
print("验证数据集...")
|
||||
|
||||
errors = validate_dataset_from_config(config_path)
|
||||
|
||||
if errors:
|
||||
print(f"\n发现 {len(errors)} 个错误:")
|
||||
for error in errors[:20]:
|
||||
print(f" - {error}")
|
||||
if len(errors) > 20:
|
||||
print(f" ... 还有 {len(errors) - 20} 个错误")
|
||||
else:
|
||||
print("\n验证通过!")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='验证 YOLO 数据集')
|
||||
parser.add_argument('config_path', nargs='?', default=None,
|
||||
help='dataset.yaml 配置文件路径(默认使用项目根目录的 dataset.yaml)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
validate_yolo_dataset(args.config_path)
|
||||
Reference in New Issue
Block a user