流程走通。(效果不行:数据集和数据标注都比较少)
This commit is contained in:
@@ -0,0 +1,103 @@
|
||||
|
||||
import os
|
||||
import datetime
|
||||
from ultralytics import YOLO
|
||||
import argparse
|
||||
from model_utils import get_project_root, find_latest_model
|
||||
|
||||
|
||||
def export_to_onnx(model_path=None, imgsz=640, simplify=True, opset=12):
|
||||
"""将 PyTorch 模型导出为 ONNX 格式
|
||||
|
||||
Args:
|
||||
model_path: 模型路径(如果为 None 则自动查找)
|
||||
imgsz: 输入图片尺寸
|
||||
simplify: 是否简化模型
|
||||
opset: ONNX opset 版本
|
||||
|
||||
Returns:
|
||||
导出的 ONNX 模型路径
|
||||
"""
|
||||
project_root = get_project_root()
|
||||
|
||||
# 如果没提供模型路径,自动查找最新模型
|
||||
if model_path is None or not os.path.exists(model_path):
|
||||
found_model = find_latest_model()
|
||||
if found_model:
|
||||
model_path = found_model
|
||||
print(f"找到最新模型: {model_path}")
|
||||
else:
|
||||
print(f"\n错误: 找不到模型文件!")
|
||||
print(f"\n可能的原因:")
|
||||
print(f" 1. 还没有运行训练,先运行: python src/train.py")
|
||||
print(f" 2. 使用 --model 参数指定正确路径")
|
||||
return None
|
||||
else:
|
||||
print(f"模型路径: {model_path}")
|
||||
|
||||
try:
|
||||
# 加载模型
|
||||
model = YOLO(model_path)
|
||||
|
||||
# 导出为 ONNX
|
||||
print(f"\n开始导出 ONNX 模型...")
|
||||
print(f"输入尺寸: {imgsz}x{imgsz}")
|
||||
print(f"简化模型: {simplify}")
|
||||
print(f"ONNX opset: {opset}")
|
||||
|
||||
# 导出
|
||||
export_result = model.export(
|
||||
format='onnx',
|
||||
imgsz=imgsz,
|
||||
simplify=simplify,
|
||||
opset=opset
|
||||
)
|
||||
|
||||
if export_result:
|
||||
# 重命名为 yyyy_MM_dd_HHmmss.onnx 格式
|
||||
export_dir = os.path.dirname(export_result)
|
||||
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H%M%S")
|
||||
new_filename = f"{timestamp}.onnx"
|
||||
new_path = os.path.join(export_dir, new_filename)
|
||||
|
||||
# 如果文件已存在,先删除
|
||||
if os.path.exists(new_path):
|
||||
os.remove(new_path)
|
||||
|
||||
# 重命名文件
|
||||
os.rename(export_result, new_path)
|
||||
|
||||
print(f"\n✅ 导出成功!")
|
||||
print(f"ONNX 模型路径: {new_path}")
|
||||
return new_path
|
||||
else:
|
||||
print(f"\n❌ 导出失败")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n导出失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='将 YOLO 模型导出为 ONNX 格式')
|
||||
parser.add_argument('--model', default=None,
|
||||
help='模型路径(默认:自动查找最新模型)')
|
||||
parser.add_argument('--imgsz', type=int, default=640,
|
||||
help='输入图片尺寸(默认: 640)')
|
||||
parser.add_argument('--simplify', action='store_true', default=True,
|
||||
help='简化模型(默认: True)')
|
||||
parser.add_argument('--no-simplify', action='store_false', dest='simplify',
|
||||
help='不简化模型')
|
||||
parser.add_argument('--opset', type=int, default=12,
|
||||
help='ONNX opset 版本(默认: 12)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
export_to_onnx(
|
||||
model_path=args.model,
|
||||
imgsz=args.imgsz,
|
||||
simplify=args.simplify,
|
||||
opset=args.opset
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user