Files

104 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import datetime
from ultralytics import YOLO
import argparse
from model_utils import get_project_root, find_latest_model
def export_to_onnx(model_path=None, imgsz=640, simplify=True, opset=12):
"""将 PyTorch 模型导出为 ONNX 格式
Args:
model_path: 模型路径(如果为 None 则自动查找)
imgsz: 输入图片尺寸
simplify: 是否简化模型
opset: ONNX opset 版本
Returns:
导出的 ONNX 模型路径
"""
project_root = get_project_root()
# 如果没提供模型路径,自动查找最新模型
if model_path is None or not os.path.exists(model_path):
found_model = find_latest_model()
if found_model:
model_path = found_model
print(f"找到最新模型: {model_path}")
else:
print(f"\n错误: 找不到模型文件!")
print(f"\n可能的原因:")
print(f" 1. 还没有运行训练,先运行: python src/train.py")
print(f" 2. 使用 --model 参数指定正确路径")
return None
else:
print(f"模型路径: {model_path}")
try:
# 加载模型
model = YOLO(model_path)
# 导出为 ONNX
print(f"\n开始导出 ONNX 模型...")
print(f"输入尺寸: {imgsz}x{imgsz}")
print(f"简化模型: {simplify}")
print(f"ONNX opset: {opset}")
# 导出
export_result = model.export(
format='onnx',
imgsz=imgsz,
simplify=simplify,
opset=opset
)
if export_result:
# 重命名为 yyyy_MM_dd_HHmmss.onnx 格式
export_dir = os.path.dirname(export_result)
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H%M%S")
new_filename = f"{timestamp}.onnx"
new_path = os.path.join(export_dir, new_filename)
# 如果文件已存在,先删除
if os.path.exists(new_path):
os.remove(new_path)
# 重命名文件
os.rename(export_result, new_path)
print(f"\n✅ 导出成功!")
print(f"ONNX 模型路径: {new_path}")
return new_path
else:
print(f"\n❌ 导出失败")
return None
except Exception as e:
print(f"\n导出失败: {e}")
return None
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='将 YOLO 模型导出为 ONNX 格式')
parser.add_argument('--model', default=None,
help='模型路径(默认:自动查找最新模型)')
parser.add_argument('--imgsz', type=int, default=640,
help='输入图片尺寸(默认: 640')
parser.add_argument('--simplify', action='store_true', default=True,
help='简化模型(默认: True')
parser.add_argument('--no-simplify', action='store_false', dest='simplify',
help='不简化模型')
parser.add_argument('--opset', type=int, default=12,
help='ONNX opset 版本(默认: 12')
args = parser.parse_args()
export_to_onnx(
model_path=args.model,
imgsz=args.imgsz,
simplify=args.simplify,
opset=args.opset
)