import os import datetime from ultralytics import YOLO import argparse from model_utils import get_project_root, find_latest_model def export_to_onnx(model_path=None, imgsz=640, simplify=True, opset=12): """将 PyTorch 模型导出为 ONNX 格式 Args: model_path: 模型路径(如果为 None 则自动查找) imgsz: 输入图片尺寸 simplify: 是否简化模型 opset: ONNX opset 版本 Returns: 导出的 ONNX 模型路径 """ project_root = get_project_root() # 如果没提供模型路径,自动查找最新模型 if model_path is None or not os.path.exists(model_path): found_model = find_latest_model() if found_model: model_path = found_model print(f"找到最新模型: {model_path}") else: print(f"\n错误: 找不到模型文件!") print(f"\n可能的原因:") print(f" 1. 还没有运行训练,先运行: python src/train.py") print(f" 2. 使用 --model 参数指定正确路径") return None else: print(f"模型路径: {model_path}") try: # 加载模型 model = YOLO(model_path) # 导出为 ONNX print(f"\n开始导出 ONNX 模型...") print(f"输入尺寸: {imgsz}x{imgsz}") print(f"简化模型: {simplify}") print(f"ONNX opset: {opset}") # 导出 export_result = model.export( format='onnx', imgsz=imgsz, simplify=simplify, opset=opset ) if export_result: # 重命名为 yyyy_MM_dd_HHmmss.onnx 格式 export_dir = os.path.dirname(export_result) timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H%M%S") new_filename = f"{timestamp}.onnx" new_path = os.path.join(export_dir, new_filename) # 如果文件已存在,先删除 if os.path.exists(new_path): os.remove(new_path) # 重命名文件 os.rename(export_result, new_path) print(f"\n✅ 导出成功!") print(f"ONNX 模型路径: {new_path}") return new_path else: print(f"\n❌ 导出失败") return None except Exception as e: print(f"\n导出失败: {e}") return None if __name__ == '__main__': parser = argparse.ArgumentParser(description='将 YOLO 模型导出为 ONNX 格式') parser.add_argument('--model', default=None, help='模型路径(默认:自动查找最新模型)') parser.add_argument('--imgsz', type=int, default=640, help='输入图片尺寸(默认: 640)') parser.add_argument('--simplify', action='store_true', default=True, help='简化模型(默认: True)') parser.add_argument('--no-simplify', action='store_false', dest='simplify', help='不简化模型') parser.add_argument('--opset', type=int, default=12, help='ONNX opset 版本(默认: 12)') args = parser.parse_args() export_to_onnx( model_path=args.model, imgsz=args.imgsz, simplify=args.simplify, opset=args.opset )