流程走通。(效果不行:数据集和数据标注都比较少)
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
/dataset/tags/LabelImg/20260505_01/images
|
||||
/runs
|
||||
dataset/coco/val2017.zip
|
||||
@@ -0,0 +1,10 @@
|
||||
|
||||
path: ./dataset/tags/visiofirm/20260505_01
|
||||
train: train/images
|
||||
val: val/images
|
||||
test: test/images
|
||||
|
||||
nc: 2
|
||||
names:
|
||||
- person
|
||||
- bicycle
|
||||
@@ -0,0 +1,10 @@
|
||||
|
||||
path: ./dataset/tags/visiofirm/20260505_01
|
||||
train: train/images
|
||||
val: val/images
|
||||
test: test/images
|
||||
|
||||
nc: 2
|
||||
names:
|
||||
- person
|
||||
- bicycle
|
||||
Binary file not shown.
@@ -0,0 +1,14 @@
|
||||
VisioFirm:
|
||||
Software: VisioFirm
|
||||
contributor: ''
|
||||
date_created: '2026-05-05'
|
||||
project_description: ''
|
||||
project_name: '1'
|
||||
year: 2026
|
||||
names:
|
||||
- person
|
||||
- bicycle
|
||||
nc: 2
|
||||
test: test/images
|
||||
train: train/images
|
||||
val: val/images
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 130 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 310 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 200 KiB |
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 158 KiB |
Binary file not shown.
+148
-404
@@ -1,3 +1,4 @@
|
||||
|
||||
# YOLO 模型测试流程指南
|
||||
|
||||
本文档基于知乎文章整理,帮助你从零开始完成一个 YOLO 目标检测项目的完整流程。
|
||||
@@ -9,10 +10,9 @@
|
||||
1. [环境准备](#1-环境准备)
|
||||
2. [数据收集与准备](#2-数据收集与准备)
|
||||
3. [数据标注](#3-数据标注)
|
||||
4. [数据格式转换](#4-数据格式转换)
|
||||
5. [数据集划分](#5-数据集划分)
|
||||
6. [模型训练](#6-模型训练)
|
||||
7. [模型评估与测试](#7-模型评估与测试)
|
||||
4. [配置数据集](#4-配置数据集)
|
||||
5. [模型训练](#5-模型训练)
|
||||
6. [模型评估与测试](#6-模型评估与测试)
|
||||
|
||||
***
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
|
||||
### 1.1 安装 Python 环境
|
||||
|
||||
确保 Python 版本 >= 3.8
|
||||
确保 Python 版本 >= 3.8
|
||||
|
||||
```bash
|
||||
python --version
|
||||
@@ -30,11 +30,10 @@ python --version
|
||||
|
||||
为了避免影响全局 Python 环境,建议为项目创建独立的虚拟环境。
|
||||
|
||||
**方式一:使用 conda**
|
||||
**方式一:使用 conda
|
||||
|
||||
**首先安装 Anaconda 或 Miniconda:**
|
||||
|
||||
1. 下载 Miniconda(推荐,更轻量):<https://docs.conda.io/en/latest/miniconda.html>
|
||||
1. 下载 Miniconda(推荐,更轻量):<https://docs.conda.io/en/latest/miniconda.html>
|
||||
2. 运行安装程序,勾选 "Add Miniconda to PATH"(或安装后手动配置)
|
||||
3. 验证安装:
|
||||
|
||||
@@ -50,7 +49,7 @@ conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/ma
|
||||
conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r
|
||||
conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/msys2
|
||||
|
||||
//切换用清华的源
|
||||
# 切换用清华的源
|
||||
conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
|
||||
|
||||
# 创建虚拟环境
|
||||
@@ -69,7 +68,6 @@ conda install -v pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c
|
||||
# pip 来安装 PyTorch 更稳定
|
||||
conda activate yolo_demo ; pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
|
||||
# 验证安装是否成功
|
||||
python -c "import torch; print(f'PyTorch版本: {torch.__version__}'); print(f'CUDA可用: {torch.cuda.is_available()}')"
|
||||
```
|
||||
@@ -81,15 +79,13 @@ conda env list
|
||||
```
|
||||
|
||||
**常见问题:**
|
||||
|
||||
如果遇到 `CondaToSNonInteractiveError` 错误,需要先运行上面的 `conda tos accept` 命令接受服务条款。
|
||||
|
||||
**验证虚拟环境已激活:**
|
||||
|
||||
激活后,终端提示符前会显示虚拟环境名称:
|
||||
|
||||
```bash
|
||||
(yolo_demo) D:\Codes\AI\Yolo\YoloDemo>
|
||||
(yolo_demo) D:\Codes\AI\Yolo\YoloDemo>
|
||||
```
|
||||
|
||||
**退出虚拟环境:**
|
||||
@@ -102,7 +98,7 @@ conda deactivate
|
||||
### 1.3 安装 YOLOv8(推荐)
|
||||
|
||||
```bash
|
||||
conda activate yolo_demo ;pip install ultralytics
|
||||
conda activate yolo_demo ; pip install ultralytics
|
||||
```
|
||||
|
||||
### 1.4 验证安装
|
||||
@@ -114,7 +110,7 @@ python -c "from ultralytics import YOLO; print('YOLOv8 安装成功')"
|
||||
### 1.5 检查 GPU 支持(可选但推荐)
|
||||
|
||||
```bash
|
||||
python -c "import torch; print(f'CUDA 可用: {torch.cuda.is_available()}')"
|
||||
python -c "import torch; print(f'CUDA可用: {torch.cuda.is_available()}')"
|
||||
```
|
||||
|
||||
***
|
||||
@@ -124,41 +120,28 @@ python -c "import torch; print(f'CUDA 可用: {torch.cuda.is_available()}')"
|
||||
### 2.1 确定检测目标
|
||||
|
||||
首先明确你要检测的类别,例如:
|
||||
|
||||
- car(汽车)
|
||||
- person(行人)
|
||||
- bicycle(自行车)
|
||||
|
||||
### 2.2 数据量建议
|
||||
|
||||
| 项目类型 | 每类最少图片 | 推荐图片数 |
|
||||
| ---- | --------- | ------ |
|
||||
| 快速原型 | 100-200 张 | 500 张 |
|
||||
| 项目类型 | 每类最少图片 | 推荐图片数 |
|
||||
| ---- | ------- | ------ |
|
||||
| 快速原型 | 100-200 张 | 500 张 |
|
||||
| 生产应用 | 1000 张 | 3000 张 |
|
||||
|
||||
### 2.3 数据来源
|
||||
|
||||
**方式一:使用公开数据集**
|
||||
|
||||
- COCO 数据集:<https://cocodataset.org>
|
||||
- <br />
|
||||
- Open Images:<https://storage.googleapis.com/openimages>
|
||||
- COCO 数据集:<https://cocodataset.org>
|
||||
- Open Images:<https://storage.googleapis.com/openimages>
|
||||
|
||||
**方式二:自己拍摄/收集图片**
|
||||
|
||||
- 确保图片清晰,目标可见
|
||||
- 覆盖不同场景、光照、角度
|
||||
- 统一格式为 JPG 或 PNG
|
||||
|
||||
### 2.4 创建数据目录结构
|
||||
|
||||
```bash
|
||||
mkdir -p dataset/images
|
||||
mkdir -p dataset/labels
|
||||
```
|
||||
|
||||
将收集的图片放入 `dataset/images` 目录。
|
||||
|
||||
***
|
||||
|
||||
## 3. 数据标注
|
||||
@@ -167,36 +150,21 @@ mkdir -p dataset/labels
|
||||
|
||||
推荐使用以下工具之一:
|
||||
|
||||
- **TjMakeBot**(在线工具,支持 AI 辅助标注):<https://www.tjmakebot.com>
|
||||
- **VisioFirm**(推荐标注工具)⭐
|
||||
- **TjMakeBot**(在线工具,支持 AI 辅助标注):<https://www.tjmakebot.com>
|
||||
- **LabelImg**(本地工具)
|
||||
- **Roboflow**(在线工具)
|
||||
|
||||
### 3.2 使用 LabelImg 标注(本地方式)
|
||||
**快速启动 VisioFirm:**
|
||||
|
||||
**安装 LabelImg:**
|
||||
|
||||
```bash
|
||||
pip install labelImg
|
||||
```powershell
|
||||
# 在项目根目录运行
|
||||
.\start_visiofirm.ps1
|
||||
```
|
||||
|
||||
**启动 LabelImg:**
|
||||
详细的 VisioFirm 使用说明请查看:[VisioFirm 标注工具使用指南](./2.VisioFirm标注工具使用指南.md)
|
||||
|
||||
```bash
|
||||
labelImg
|
||||
```
|
||||
|
||||
**标注步骤:**
|
||||
|
||||
1. 打开 LabelImg
|
||||
2. 点击 "Open Dir" 选择图片目录
|
||||
3. 点击 "Change Save Dir" 设置标注保存目录
|
||||
4. **重要**:点击 "PascalVOC" 切换为 "YOLO" 格式
|
||||
5. 使用快捷键 `W` 绘制边界框
|
||||
6. 选择类别名称
|
||||
7. 点击 "Save" 保存标注
|
||||
8. 使用 `D` 键切换到下一张图片
|
||||
|
||||
### 3.3 YOLO 标注格式说明
|
||||
### 3.2 YOLO 标注格式说明
|
||||
|
||||
每张图片对应一个 `.txt` 文件,格式如下:
|
||||
|
||||
@@ -212,430 +180,206 @@ class_id center_x center_y width height
|
||||
```
|
||||
|
||||
**说明:**
|
||||
|
||||
- `class_id`:类别 ID(从 0 开始)
|
||||
- `center_x, center_y`:边界框中心点坐标(归一化 0-1)
|
||||
- `width, height`:边界框宽高(归一化 0-1)
|
||||
|
||||
***
|
||||
|
||||
## 4. 数据格式转换
|
||||
## 4. 配置数据集
|
||||
|
||||
### 4.1 验证标注文件
|
||||
### 4.1 数据集目录结构
|
||||
|
||||
创建验证脚本 `validate_dataset.py`:
|
||||
支持多种数据集组织方式:
|
||||
|
||||
```python
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
def validate_yolo_dataset(dataset_dir):
|
||||
images_dir = os.path.join(dataset_dir, 'images')
|
||||
labels_dir = os.path.join(dataset_dir, 'labels')
|
||||
|
||||
errors = []
|
||||
image_files = [f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))]
|
||||
|
||||
for img_file in image_files:
|
||||
img_path = os.path.join(images_dir, img_file)
|
||||
label_file = os.path.splitext(img_file)[0] + '.txt'
|
||||
label_path = os.path.join(labels_dir, label_file)
|
||||
|
||||
if not os.path.exists(label_path):
|
||||
errors.append(f"缺失标注文件: {label_file}")
|
||||
continue
|
||||
|
||||
try:
|
||||
img = Image.open(img_path)
|
||||
img_width, img_height = img.size
|
||||
except Exception as e:
|
||||
errors.append(f"无法打开图片: {img_file}")
|
||||
continue
|
||||
|
||||
with open(label_path, 'r') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
parts = line.split()
|
||||
if len(parts) != 5:
|
||||
errors.append(f"{label_file}: 格式错误")
|
||||
continue
|
||||
|
||||
try:
|
||||
class_id = int(parts[0])
|
||||
center_x = float(parts[1])
|
||||
center_y = float(parts[2])
|
||||
width = float(parts[3])
|
||||
height = float(parts[4])
|
||||
|
||||
if not (0 <= center_x <= 1 and 0 <= center_y <= 1):
|
||||
errors.append(f"{label_file}: 坐标超出范围")
|
||||
except ValueError:
|
||||
errors.append(f"{label_file}: 数字解析错误")
|
||||
|
||||
if errors:
|
||||
print("发现错误:")
|
||||
for error in errors[:10]:
|
||||
print(f" - {error}")
|
||||
else:
|
||||
print("验证通过!")
|
||||
|
||||
if __name__ == '__main__':
|
||||
validate_yolo_dataset('./dataset')
|
||||
**方式一:VisioFirm 导出(推荐)**
|
||||
```
|
||||
dataset/tags/visiofirm/20260505_01/
|
||||
├── data.yaml
|
||||
├── train/images/
|
||||
├── train/labels/
|
||||
├── val/images/
|
||||
├── val/labels/
|
||||
└── test/images/
|
||||
└── test/labels/
|
||||
```
|
||||
|
||||
运行验证:
|
||||
|
||||
```bash
|
||||
python validate_dataset.py
|
||||
**方式二:LabelImg 标注**
|
||||
```
|
||||
dataset/tags/LabelImg/20260505_01/
|
||||
├── data.yaml
|
||||
├── train/images/
|
||||
├── train/labels/
|
||||
├── val/images/
|
||||
├── val/labels/
|
||||
└── test/images/
|
||||
└── test/labels/
|
||||
```
|
||||
|
||||
### 4.2 创建数据集配置文件
|
||||
### 4.2 配置 dataset.yaml
|
||||
|
||||
创建 `dataset.yaml` 文件:
|
||||
在项目根目录创建 `dataset.yaml` 文件,配置示例见:[configs/dataset-traditional.yaml](../configs/dataset-traditional.yaml)
|
||||
|
||||
**关键配置项:**
|
||||
- `path`: 数据集根目录(支持相对路径)
|
||||
- `train`: 训练集图片目录(相对于 path)
|
||||
- `val`: 验证集图片目录(相对于 path)
|
||||
- `test`: 测试集图片目录(相对于 path)
|
||||
- `nc`: 类别数量
|
||||
- `names`: 类别名称列表
|
||||
|
||||
**示例配置:**
|
||||
```yaml
|
||||
path: ./dataset
|
||||
train: images/train
|
||||
val: images/val
|
||||
test: images/test
|
||||
|
||||
nc: 3
|
||||
path: ./dataset/tags/visiofirm/20260505_01
|
||||
train: train/images
|
||||
val: val/images
|
||||
test: test/images
|
||||
|
||||
nc: 2
|
||||
names:
|
||||
0: car
|
||||
1: person
|
||||
2: bicycle
|
||||
- person
|
||||
- bicycle
|
||||
```
|
||||
|
||||
**注意:** 根据你的实际类别修改 `nc` 和 `names`。
|
||||
### 4.3 验证数据集
|
||||
|
||||
***
|
||||
|
||||
## 5. 数据集划分
|
||||
|
||||
### 5.1 创建划分脚本
|
||||
|
||||
创建 `split_dataset.py`:
|
||||
|
||||
```python
|
||||
import os
|
||||
import shutil
|
||||
import random
|
||||
|
||||
def split_dataset(source_dir, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
|
||||
random.seed(seed)
|
||||
|
||||
images_dir = os.path.join(source_dir, 'images')
|
||||
labels_dir = os.path.join(source_dir, 'labels')
|
||||
|
||||
images = [f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))]
|
||||
random.shuffle(images)
|
||||
|
||||
total = len(images)
|
||||
train_end = int(total * train_ratio)
|
||||
val_end = train_end + int(total * val_ratio)
|
||||
|
||||
train_images = images[:train_end]
|
||||
val_images = images[train_end:val_end]
|
||||
test_images = images[val_end:]
|
||||
|
||||
print(f"总图片数: {total}")
|
||||
print(f"训练集: {len(train_images)}")
|
||||
print(f"验证集: {len(val_images)}")
|
||||
print(f"测试集: {len(test_images)}")
|
||||
|
||||
for split, img_list in [('train', train_images), ('val', val_images), ('test', test_images)]:
|
||||
split_images_dir = os.path.join(source_dir, 'images', split)
|
||||
split_labels_dir = os.path.join(source_dir, 'labels', split)
|
||||
|
||||
os.makedirs(split_images_dir, exist_ok=True)
|
||||
os.makedirs(split_labels_dir, exist_ok=True)
|
||||
|
||||
for img in img_list:
|
||||
shutil.copy(os.path.join(images_dir, img), os.path.join(split_images_dir, img))
|
||||
|
||||
label_name = os.path.splitext(img)[0] + '.txt'
|
||||
src_label = os.path.join(labels_dir, label_name)
|
||||
dst_label = os.path.join(split_labels_dir, label_name)
|
||||
|
||||
if os.path.exists(src_label):
|
||||
shutil.copy(src_label, dst_label)
|
||||
|
||||
print("数据集划分完成!")
|
||||
|
||||
if __name__ == '__main__':
|
||||
split_dataset('./dataset')
|
||||
```
|
||||
|
||||
### 5.2 执行划分
|
||||
使用 [src/validate_dataset.py](../src/validate_dataset.py) 脚本验证数据集:
|
||||
|
||||
```bash
|
||||
python split_dataset.py
|
||||
```
|
||||
|
||||
### 5.3 验证划分结果
|
||||
|
||||
```bash
|
||||
ls dataset/images/train
|
||||
ls dataset/images/val
|
||||
ls dataset/images/test
|
||||
python src/validate_dataset.py
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## 6. 模型训练
|
||||
## 5. 模型训练
|
||||
|
||||
### 6.1 创建训练脚本
|
||||
### 5.1 开始训练
|
||||
|
||||
创建 `train.py`:
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
model = YOLO('yolov8n.pt')
|
||||
|
||||
results = model.train(
|
||||
data='dataset.yaml',
|
||||
epochs=100,
|
||||
imgsz=640,
|
||||
batch=16,
|
||||
device=device,
|
||||
lr0=0.01,
|
||||
patience=50,
|
||||
save=True,
|
||||
plots=True,
|
||||
project='runs/detect',
|
||||
name='my_model',
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
print("训练完成!")
|
||||
print(f"最佳模型保存在: runs/detect/my_model/weights/best.pt")
|
||||
```
|
||||
|
||||
### 6.2 开始训练
|
||||
使用 [src/train.py](../src/train.py) 脚本训练模型:
|
||||
|
||||
```bash
|
||||
python train.py
|
||||
python src/train.py
|
||||
```
|
||||
|
||||
### 6.3 训练参数说明
|
||||
常用参数:
|
||||
- `--config`: dataset.yaml 配置文件路径(默认使用项目根目录)
|
||||
- `--model`: 预训练模型(默认: yolov8n.pt)
|
||||
- `--epochs`: 训练轮数(默认: 100)
|
||||
- `--batch`: 批次大小(默认: 16)
|
||||
- `--imgsz`: 输入图片尺寸(默认: 640)
|
||||
- `--lr`: 初始学习率(默认: 0.01)
|
||||
- `--device`: 设备(cuda/cpu,默认自动检测)
|
||||
|
||||
| 参数 | 说明 | 建议值 |
|
||||
| -------- | ------ | ----------- |
|
||||
| epochs | 训练轮数 | 100-300 |
|
||||
| imgsz | 输入图片尺寸 | 640 |
|
||||
| batch | 批次大小 | 根据 GPU 内存调整 |
|
||||
| lr0 | 初始学习率 | 0.01 |
|
||||
| patience | 早停耐心值 | 50 |
|
||||
### 5.2 模型选择建议
|
||||
|
||||
### 6.4 模型选择建议
|
||||
|
||||
| 模型 | 参数量 | 速度 | 精度 | 适用场景 |
|
||||
| ------- | ----- | -- | -- | ---- |
|
||||
| yolov8n | 3.2M | 最快 | 较低 | 实时检测 |
|
||||
| yolov8s | 11.2M | 快 | 中等 | 平衡 |
|
||||
| 模型 | 参数量 | 速度 | 精度 | 适用场景 |
|
||||
| ---- | ----- | -- | -- | ---- |
|
||||
| yolov8n | 3.2M | 最快 | 较低 | 实时检测 |
|
||||
| yolov8s | 11.2M | 快 | 中等 | 平衡 |
|
||||
| yolov8m | 25.9M | 中等 | 较高 | 生产环境 |
|
||||
| yolov8l | 43.7M | 较慢 | 高 | 高精度 |
|
||||
| yolov8l | 43.7M | 较慢 | 高 | 高精度 |
|
||||
|
||||
***
|
||||
|
||||
## 7. 模型评估与测试
|
||||
## 6. 模型评估与测试
|
||||
|
||||
### 7.1 评估模型
|
||||
### 6.1 评估模型
|
||||
|
||||
创建 `evaluate.py`:
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO('runs/detect/my_model/weights/best.pt')
|
||||
|
||||
metrics = model.val(data='dataset.yaml', split='val')
|
||||
|
||||
print("=" * 50)
|
||||
print("模型评估结果")
|
||||
print("=" * 50)
|
||||
print(f"mAP50: {metrics.box.map50:.4f}")
|
||||
print(f"mAP50-95: {metrics.box.map:.4f}")
|
||||
print(f"Precision: {metrics.box.mp:.4f}")
|
||||
print(f"Recall: {metrics.box.mr:.4f}")
|
||||
print("=" * 50)
|
||||
```
|
||||
|
||||
运行评估:
|
||||
使用 [src/evaluate.py](../src/evaluate.py) 脚本评估模型:
|
||||
|
||||
```bash
|
||||
python evaluate.py
|
||||
python src/evaluate.py
|
||||
```
|
||||
|
||||
### 7.2 测试单张图片
|
||||
**特点:**
|
||||
- 自动查找最新训练的模型(无需手动指定路径)
|
||||
- 支持从多个可能的位置查找已保存模型
|
||||
|
||||
创建 `predict.py`:
|
||||
参数:
|
||||
- `--model`: 模型路径(默认:自动查找最新模型)
|
||||
- `--config`: dataset.yaml 配置文件路径(默认使用项目根目录)
|
||||
- `--split`: 评估集(train/val/test,默认: val)
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
### 6.2 预测
|
||||
|
||||
model = YOLO('runs/detect/my_model/weights/best.pt')
|
||||
|
||||
results = model('dataset/images/test/test_image.jpg', save=True, conf=0.25)
|
||||
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
for i in range(len(boxes)):
|
||||
class_name = model.names[int(boxes.cls[i])]
|
||||
conf = boxes.conf[i]
|
||||
print(f"检测到: {class_name}, 置信度: {conf:.2f}")
|
||||
```
|
||||
|
||||
运行测试:
|
||||
使用 [src/predict.py](../src/predict.py) 脚本进行预测:
|
||||
|
||||
```bash
|
||||
python predict.py
|
||||
python src/predict.py
|
||||
```
|
||||
|
||||
### 7.3 批量测试
|
||||
**特点:**
|
||||
- 自动查找最新训练的模型(无需手动指定路径)
|
||||
- 自动从 dataset.yaml 读取测试集路径
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
参数:
|
||||
- `--model`: 模型路径(默认:自动查找最新模型)
|
||||
- `--source`: 预测源(图片/目录,默认:读取 dataset.yaml)
|
||||
- `--conf`: 置信度阈值(默认: 0.25)
|
||||
- `--nosave`: 不保存结果(默认保存)
|
||||
|
||||
model = YOLO('runs/detect/my_model/weights/best.pt')
|
||||
### 6.3 模型导出为 ONNX 格式
|
||||
|
||||
results = model('dataset/images/test', save=True, conf=0.25)
|
||||
使用 [src/export_onnx.py](../src/export_onnx.py) 脚本将 PyTorch (.pt) 模型导出为 ONNX 格式:
|
||||
|
||||
```bash
|
||||
python src/export_onnx.py
|
||||
```
|
||||
|
||||
### 7.4 性能基准
|
||||
**特点:**
|
||||
- 自动查找最新训练的模型(无需手动指定路径)
|
||||
- 支持模型简化,减小文件体积
|
||||
- 可自定义输入尺寸和 ONNX opset 版本
|
||||
|
||||
| 应用场景 | mAP50 目标 | mAP50-95 目标 |
|
||||
| ----- | -------- | ----------- |
|
||||
| 快速原型 | > 0.5 | > 0.3 |
|
||||
| 生产环境 | > 0.7 | > 0.5 |
|
||||
| 高精度应用 | > 0.9 | > 0.7 |
|
||||
参数:
|
||||
- `--model`: 模型路径(默认:自动查找最新模型)
|
||||
- `--imgsz`: 输入图片尺寸(默认: 640)
|
||||
- `--simplify`: 简化模型(默认: True)
|
||||
- `--no-simplify`: 不简化模型
|
||||
- `--opset`: ONNX opset 版本(默认: 12)
|
||||
|
||||
***
|
||||
**使用示例:**
|
||||
```bash
|
||||
# 导出最新训练的模型(自动查找)
|
||||
python src/export_onnx.py
|
||||
|
||||
## 常见问题排查
|
||||
# 导出指定模型
|
||||
python src/export_onnx.py --model runs/detect/my_model/weights/best.pt
|
||||
|
||||
### 问题 1:Loss 不下降
|
||||
# 自定义输入尺寸
|
||||
python src/export_onnx.py --imgsz 1280
|
||||
|
||||
**解决方案:**
|
||||
# 不简化模型
|
||||
python src/export_onnx.py --no-simplify
|
||||
|
||||
- 调整学习率(尝试 0.001-0.01)
|
||||
- 检查数据质量
|
||||
- 尝试更大的模型
|
||||
|
||||
### 问题 2:过拟合
|
||||
|
||||
**解决方案:**
|
||||
|
||||
- 增加数据量
|
||||
- 使用更小的模型
|
||||
- 启用更多数据增强
|
||||
|
||||
### 问题 3:训练很慢
|
||||
|
||||
**解决方案:**
|
||||
|
||||
- 使用 GPU 训练
|
||||
- 增大批次大小
|
||||
- 减小图片尺寸
|
||||
|
||||
***
|
||||
|
||||
## 目录结构总结
|
||||
|
||||
完成后的目录结构:
|
||||
|
||||
```
|
||||
project/
|
||||
├── venv/ # 虚拟环境目录(不应提交到版本控制)
|
||||
├── dataset/
|
||||
│ ├── images/
|
||||
│ │ ├── train/
|
||||
│ │ ├── val/
|
||||
│ │ └── test/
|
||||
│ └── labels/
|
||||
│ ├── train/
|
||||
│ ├── val/
|
||||
│ └── test/
|
||||
├── dataset.yaml
|
||||
├── train.py
|
||||
├── evaluate.py
|
||||
├── predict.py
|
||||
├── split_dataset.py
|
||||
└── validate_dataset.py
|
||||
# 自定义 opset 版本
|
||||
python src/export_onnx.py --opset 17
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## 快速开始命令汇总
|
||||
|
||||
```bash
|
||||
# 创建并激活虚拟环境
|
||||
python -m venv venv
|
||||
venv\Scripts\activate
|
||||
|
||||
# 安装依赖
|
||||
pip install ultralytics
|
||||
|
||||
# 后续步骤
|
||||
python validate_dataset.py
|
||||
python split_dataset.py
|
||||
python train.py
|
||||
python evaluate.py
|
||||
python predict.py
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## 一键运行环境准备
|
||||
|
||||
项目已提供 `run_demo.py` 脚本,可一键完成环境搭建:
|
||||
**注意:请在项目根目录(YoloDemo)下运行以下命令**
|
||||
|
||||
```bash
|
||||
python run_demo.py
|
||||
```
|
||||
# 1. 配置 dataset.yaml(首次使用)
|
||||
# 复制 configs/dataset-traditional.yaml 到项目根目录,修改 path 指向你的数据集
|
||||
|
||||
该脚本会自动完成:
|
||||
# 2. 验证数据集
|
||||
python src/validate_dataset.py
|
||||
|
||||
1. 检查 Python 版本
|
||||
2. 创建虚拟环境
|
||||
3. 安装 ultralytics 依赖
|
||||
4. 验证安装
|
||||
5. 检查 GPU 支持
|
||||
# 3. 训练模型
|
||||
python src/train.py
|
||||
|
||||
## 在 VS Code 中直接运行代码
|
||||
# 4. 评估模型(自动查找最新训练好的模型)
|
||||
python src/evaluate.py
|
||||
|
||||
### 方法一:使用 Code Runner 插件
|
||||
# 5. 预测(自动查找最新模型)
|
||||
python src/predict.py
|
||||
|
||||
1. 打开 VS Code,搜索并安装 `Code Runner` 插件
|
||||
2. 将光标放在代码块内,点击右上角的 ▶️ 按钮运行
|
||||
|
||||
### 方法二:右键运行
|
||||
|
||||
在代码块内右键,选择 "Run Code" 或使用快捷键 `Ctrl+Alt+N`
|
||||
|
||||
### 方法三:终端运行
|
||||
|
||||
```bash
|
||||
# 进入项目目录
|
||||
cd .
|
||||
|
||||
# 激活虚拟环境
|
||||
venv\Scripts\activate
|
||||
|
||||
# 运行脚本
|
||||
python run_demo.py
|
||||
python validate_dataset.py
|
||||
python split_dataset.py
|
||||
python train.py
|
||||
# 6. 导出为 ONNX 格式(自动查找最新模型)
|
||||
python src/export_onnx.py
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
@@ -0,0 +1,352 @@
|
||||
|
||||
# VisioFirm 标注工具使用指南
|
||||
|
||||
本指南将帮助你使用 VisioFirm 进行 YOLO 目标检测数据标注。
|
||||
|
||||
***
|
||||
|
||||
## 目录
|
||||
|
||||
1. [VisioFirm 简介](#1-visiofirm-简介)
|
||||
2. [安装与启动](#2-安装与启动)
|
||||
3. [项目设置](#3-项目设置)
|
||||
4. [标注操作](#4-标注操作)
|
||||
5. [导出 YOLO 格式](#5-导出-yolo-格式)
|
||||
6. [使用导出数据](#6-使用导出数据)
|
||||
7. [快捷键](#7-快捷键)
|
||||
|
||||
***
|
||||
|
||||
## 1. VisioFirm 简介
|
||||
|
||||
VisioFirm 是一款开源、跨平台的 AI 辅助图像标注工具,专为计算机视觉任务设计。
|
||||
|
||||
**官方链接:**
|
||||
|
||||
- GitHub 仓库:<https://github.com/OschAI/VisioFirm>
|
||||
- PyPI 项目:<https://pypi.org/project/visiofirm/>
|
||||
- 论文地址:<https://arxiv.org/abs/2509.04180>
|
||||
|
||||
**主要特性:**
|
||||
|
||||
- AI 驱动的预标注(YOLO、SAM2、Grounding DINO),节省高达 80% 人工标注
|
||||
- 支持多种标注格式:YOLO、COCO、PascalVOC 等
|
||||
- 自动划分训练/验证/测试集
|
||||
- 支持视频标注和标签传播
|
||||
- 直观的用户界面
|
||||
- 批量处理功能
|
||||
- 标注验证和预览
|
||||
- 支持分类、目标检测、分割任务
|
||||
|
||||
***
|
||||
|
||||
## 2. 安装与启动
|
||||
|
||||
### 2.1 系统要求
|
||||
|
||||
- Python 3.10 或更高版本
|
||||
- 推荐使用 CUDA 11.8(如需 GPU 加速)
|
||||
- 16GB 或更多内存
|
||||
|
||||
### 2.2 安装(推荐方式:pip)
|
||||
|
||||
**注意:** VisioFirm 直接安装在项目的 `yolo_demo` 环境中,无需创建新环境。
|
||||
|
||||
1. 激活项目虚拟环境:
|
||||
```bash
|
||||
conda activate yolo_demo
|
||||
```
|
||||
2. 设置 UTF-8 环境变量(Windows 中文环境必须):
|
||||
```powershell
|
||||
# PowerShell
|
||||
$env:PYTHONUTF8=1
|
||||
```
|
||||
3. 使用 pip 安装 VisioFirm:
|
||||
```bash
|
||||
pip install visiofirm
|
||||
```
|
||||
|
||||
**或者使用一键启动脚本:**
|
||||
|
||||
```powershell
|
||||
.\start_visiofirm.ps1
|
||||
```
|
||||
|
||||
### 2.3 启动 VisioFirm
|
||||
|
||||
**方法一:使用一键启动脚本(推荐)**
|
||||
|
||||
直接运行项目根目录下的启动脚本:
|
||||
|
||||
```powershell
|
||||
.\start_visiofirm.ps1
|
||||
```
|
||||
|
||||
**方法二:手动启动**
|
||||
|
||||
1. 激活环境并设置 UTF-8:
|
||||
```powershell
|
||||
conda activate yolo_demo
|
||||
$env:PYTHONUTF8=1
|
||||
```
|
||||
2. 启动 VisioFirm:
|
||||
```bash
|
||||
visiofirm
|
||||
```
|
||||
|
||||
启动后,浏览器会自动打开 VisioFirm 界面(通常是 [http://localhost:8000)。>](http://localhost:8000)。>)
|
||||
|
||||
***
|
||||
|
||||
## 3. 项目设置
|
||||
|
||||
### 3.1 创建新项目
|
||||
|
||||
1. 打开 VisioFirm
|
||||
2. 点击 **"New Project"** 或 **"创建新项目"**
|
||||
3. 输入项目名称(如:YOLO_Demo)
|
||||
4. 选择项目保存位置
|
||||
|
||||
### 3.2 配置项目
|
||||
|
||||
**导入图片:**
|
||||
|
||||
1. 点击 **"Import Images"** 或 **"导入图片"**
|
||||
2. 选择你的图片文件夹
|
||||
3. 选择要标注的图片文件(可全选)
|
||||
|
||||
**设置类别:**
|
||||
|
||||
1. 点击 **"Classes"** 或 **"类别"** 标签
|
||||
2. 点击 **"Add Class"** 或 **"添加类别"**
|
||||
3. 输入类别名称(如:car, person, bicycle)
|
||||
4. 为每个类别选择颜色(便于区分)
|
||||
5. 保存类别配置
|
||||
|
||||
示例类别配置:
|
||||
|
||||
```
|
||||
0: person (行人)
|
||||
1: bicycle (自行车)
|
||||
```
|
||||
|
||||
**数据集划分设置(可选):**
|
||||
|
||||
VisioFirm 在导出时会自动将数据集划分为 train/val/test。默认划分比例:
|
||||
- 训练集(train):70%
|
||||
- 验证集(val):15%
|
||||
- 测试集(test):15%
|
||||
|
||||
***
|
||||
|
||||
## 4. 标注操作
|
||||
|
||||
### 4.1 基本标注步骤
|
||||
|
||||
1. 从左侧图片列表中选择要标注的图片
|
||||
2. 选择标注工具(通常是矩形框工具)
|
||||
3. 在图片上拖动鼠标绘制边界框,框选目标
|
||||
4. 从类别列表中选择对应的类别
|
||||
5. 确认标注无误后,点击保存或自动保存
|
||||
|
||||
### 4.2 标注技巧
|
||||
|
||||
**精确框选:**
|
||||
|
||||
- 尽量紧贴目标边缘绘制边界框
|
||||
- 避免框选过多背景
|
||||
- 对于部分遮挡的目标,仍然标注可见部分
|
||||
|
||||
**类别一致性:**
|
||||
|
||||
- 保持同类目标标注标准一致
|
||||
- 对于模糊不清的目标,可以选择跳过或标记为不确定
|
||||
|
||||
**标注顺序:**
|
||||
|
||||
- 按类别顺序标注,减少类别切换
|
||||
- 使用快捷键提高效率
|
||||
|
||||
**AI 预标注(推荐):**
|
||||
|
||||
VisioFirm 支持 AI 预标注功能,可以大幅提高标注效率:
|
||||
1. 点击 AI 预标注按钮
|
||||
2. 选择合适的模型(如 YOLO)
|
||||
3. 自动生成标注框
|
||||
4. 人工检查和修正
|
||||
|
||||
***
|
||||
|
||||
## 5. 导出 YOLO 格式
|
||||
|
||||
### 5.1 导出设置
|
||||
|
||||
1. 完成标注后,点击 **"Export"** 或 **"导出"**
|
||||
2. 选择 **"YOLO"** 格式
|
||||
3. 设置导出选项:
|
||||
- 导出路径:选择项目内的 `dataset/tags/visiofirm/` 目录
|
||||
- 勾选 "自动划分 train/val/test"
|
||||
- 类别映射:确认类别名称正确
|
||||
- 坐标格式:选择归一化坐标(0-1)
|
||||
4. 可以选择导出为 ZIP 压缩包或文件夹
|
||||
|
||||
### 5.2 执行导出
|
||||
|
||||
1. 点击 **"Export"** 或 **"开始导出"**
|
||||
2. 等待导出完成
|
||||
3. 导出完成后,会在 `dataset/tags/visiofirm/` 目录下生成新文件夹(如 `20260505_01/`)
|
||||
|
||||
### 5.3 导出结果说明
|
||||
|
||||
导出后,目录结构如下:
|
||||
|
||||
```
|
||||
dataset/tags/visiofirm/
|
||||
├── 20260505_01/ # 每次导出会生成带日期的文件夹
|
||||
│ ├── data.yaml # 数据集配置文件(VisioFirm 格式)
|
||||
│ ├── train/
|
||||
│ │ ├── images/ # 训练集图片
|
||||
│ │ └── labels/ # 训练集标注
|
||||
│ ├── val/
|
||||
│ │ ├── images/ # 验证集图片
|
||||
│ │ └── labels/ # 验证集标注
|
||||
│ └── test/
|
||||
│ ├── images/ # 测试集图片
|
||||
│ └── labels/ # 测试集标注
|
||||
└── 1_YOLO.zip # ZIP 压缩包(可选)
|
||||
```
|
||||
|
||||
**标注文件格式:**
|
||||
|
||||
```
|
||||
class_id center_x center_y width height
|
||||
```
|
||||
|
||||
**示例:**
|
||||
|
||||
```
|
||||
0 0.5 0.5 0.3 0.4
|
||||
1 0.2 0.3 0.1 0.2
|
||||
```
|
||||
|
||||
**说明:**
|
||||
|
||||
- `class_id`:类别 ID(从 0 开始)
|
||||
- `center_x, center_y`:边界框中心点坐标(归一化 0-1)
|
||||
- `width, height`:边界框宽高(归一化 0-1)
|
||||
|
||||
***
|
||||
|
||||
## 6. 使用导出数据
|
||||
|
||||
### 6.1 配置项目
|
||||
|
||||
1. 复制 `configs/dataset-traditional.yaml` 到项目根目录,重命名为 `dataset.yaml`
|
||||
2. 编辑 `dataset.yaml`,修改 `path` 指向刚导出的数据文件夹:
|
||||
|
||||
```yaml
|
||||
path: ./dataset/tags/visiofirm/20260505_01
|
||||
train: train/images
|
||||
val: val/images
|
||||
test: test/images
|
||||
|
||||
nc: 2
|
||||
names:
|
||||
- person
|
||||
- bicycle
|
||||
```
|
||||
|
||||
**注意:** `nc` 和 `names` 需要根据你实际标注的类别进行修改。
|
||||
|
||||
### 6.2 验证数据集
|
||||
|
||||
使用项目验证脚本检查导出的标注:
|
||||
|
||||
```bash
|
||||
python src/validate_dataset.py
|
||||
```
|
||||
|
||||
验证通过后,即可开始训练模型!
|
||||
|
||||
***
|
||||
|
||||
## 7. 快捷键
|
||||
|
||||
| 快捷键 | 功能 |
|
||||
| --------- | ------ |
|
||||
| `W` 或 `R` | 矩形框工具 |
|
||||
| `A` | 上一张图片 |
|
||||
| `D` | 下一张图片 |
|
||||
| `Ctrl+S` | 保存标注 |
|
||||
| `Ctrl+Z` | 撤销 |
|
||||
| `Ctrl+Y` | 重做 |
|
||||
| `Del` | 删除选中的框 |
|
||||
| `Esc` | 取消当前操作 |
|
||||
| `Space` | 平移画布 |
|
||||
| `鼠标滚轮` | 缩放 |
|
||||
|
||||
**注意:** 具体快捷键可能因 VisioFirm 版本不同而略有差异,请以实际软件为准。
|
||||
|
||||
***
|
||||
|
||||
## 8. 标注质量检查
|
||||
|
||||
完成标注后,建议进行以下检查:
|
||||
|
||||
1. **完整性检查**
|
||||
- 确认所有图片都已标注
|
||||
- 确认所有目标都已框选
|
||||
|
||||
2. **准确性检查**
|
||||
- 检查边界框是否准确
|
||||
- 检查类别是否正确
|
||||
|
||||
3. **格式验证**
|
||||
- 使用项目中的验证脚本检查标注格式:
|
||||
```bash
|
||||
python src/validate_dataset.py
|
||||
```
|
||||
|
||||
***
|
||||
|
||||
## 9. 常见问题
|
||||
|
||||
### Q: VisioFirm 如何获取?
|
||||
|
||||
A: 推荐使用 pip 安装:`pip install visiofirm`,详细步骤见"安装与启动"章节。
|
||||
也可以访问 GitHub 仓库:<https://github.com/OschAI/VisioFirm> 获取源码。
|
||||
|
||||
### Q: VisioFirm 是免费的吗?
|
||||
|
||||
A: 是的,VisioFirm 是开源软件,采用 Apache-2.0 许可证,完全免费使用。
|
||||
|
||||
### Q: 导出后如何使用这些数据?
|
||||
|
||||
A: 导出的数据已经自动划分好 train/val/test。
|
||||
只需修改项目根目录的 `dataset.yaml` 中的 `path` 指向导出文件夹即可开始训练。
|
||||
|
||||
### Q: 如何批量处理图片?
|
||||
|
||||
A: VisioFirm 支持批量导入图片,可以一次选择多张或整个文件夹进行标注。
|
||||
导出时也支持批量导出所有标注。
|
||||
|
||||
### Q: 标注过程中如何暂停和继续?
|
||||
|
||||
A: VisioFirm 会自动保存进度,关闭浏览器标签页后重新打开 <http://localhost:8000> 即可继续。
|
||||
|
||||
### Q: 如何修改已有标注?
|
||||
|
||||
A: 选择要修改的框,可以拖动调整位置或大小,也可以删除后重新标注。
|
||||
对于 AI 预标注的结果,可以直接进行修正。
|
||||
|
||||
### Q: 遇到技术问题怎么办?
|
||||
|
||||
A: 可以在 GitHub Issues 页面提交问题:<https://github.com/OschAI/VisioFirm/issues>
|
||||
|
||||
***
|
||||
|
||||
## 下一步
|
||||
|
||||
标注完成后,请继续参考 [YOLO 模型测试流程指南](./1.YOLO模型测试流程指南.md) 进行模型训练。
|
||||
|
||||
祝你标注顺利!
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,24 @@
|
||||
|
||||
import os
|
||||
import yaml
|
||||
|
||||
|
||||
def load_dataset_config(config_path=None):
|
||||
"""加载数据集配置
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径,默认使用项目根目录的 dataset.yaml
|
||||
|
||||
Returns:
|
||||
dict: 数据集配置
|
||||
"""
|
||||
if config_path is None:
|
||||
# 默认使用项目根目录的 dataset.yaml
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
config_path = os.path.join(project_root, 'dataset.yaml')
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
raise FileNotFoundError(f"配置文件不存在: {config_path}")
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
return yaml.safe_load(f)
|
||||
@@ -0,0 +1,84 @@
|
||||
|
||||
import os
|
||||
from ultralytics import YOLO
|
||||
import argparse
|
||||
from model_utils import get_project_root, find_latest_model
|
||||
|
||||
|
||||
def evaluate_model(model_path=None, config_path=None, split='val'):
|
||||
"""评估模型
|
||||
|
||||
Args:
|
||||
model_path: 模型路径(如果为 None 则自动查找)
|
||||
config_path: dataset.yaml 配置文件路径
|
||||
split: 评估集(train/val/test)
|
||||
"""
|
||||
project_root = get_project_root()
|
||||
|
||||
# 如果没提供 config_path,使用项目根目录的 dataset.yaml
|
||||
if config_path is None:
|
||||
config_path = os.path.join(project_root, 'dataset.yaml')
|
||||
|
||||
print(f"配置文件: {config_path}")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
print(f"\n错误: 配置文件不存在!")
|
||||
print(f"路径: {config_path}")
|
||||
print(f"\n请先配置 dataset.yaml:")
|
||||
print(f" 1. 复制 configs/dataset-traditional.yaml 到项目根目录")
|
||||
print(f" 2. 重命名为 dataset.yaml")
|
||||
print(f" 3. 修改 path 指向你的数据集目录")
|
||||
return None
|
||||
|
||||
# 如果没提供模型路径,自动查找最新模型
|
||||
if model_path is None or not os.path.exists(model_path):
|
||||
found_model = find_latest_model()
|
||||
if found_model:
|
||||
model_path = found_model
|
||||
print(f"找到最新模型: {model_path}")
|
||||
else:
|
||||
print(f"\n错误: 找不到模型文件!")
|
||||
print(f"\n可能的原因:")
|
||||
print(f" 1. 还没有运行训练,先运行: python src/train.py")
|
||||
print(f" 2. 使用 --model 参数指定正确路径")
|
||||
return None
|
||||
else:
|
||||
print(f"模型路径: {model_path}")
|
||||
|
||||
try:
|
||||
# 加载模型
|
||||
model = YOLO(model_path)
|
||||
|
||||
# 评估
|
||||
print(f"\n开始评估 {split} 集...")
|
||||
metrics = model.val(data=config_path, split=split)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("模型评估结果")
|
||||
print("=" * 50)
|
||||
print(f"mAP50: {metrics.box.map50:.4f}")
|
||||
print(f"mAP50-95: {metrics.box.map:.4f}")
|
||||
print(f"Precision: {metrics.box.mp:.4f}")
|
||||
print(f"Recall: {metrics.box.mr:.4f}")
|
||||
print("=" * 50)
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n评估失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='评估 YOLO 模型')
|
||||
parser.add_argument('--model', default=None,
|
||||
help='模型路径(默认:自动查找最新模型)')
|
||||
parser.add_argument('--config', default=None,
|
||||
help='dataset.yaml 配置文件路径(默认使用项目根目录)')
|
||||
parser.add_argument('--split', default='val',
|
||||
help='评估集(train/val/test,默认: val)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
evaluate_model(args.model, args.config, args.split)
|
||||
@@ -0,0 +1,103 @@
|
||||
|
||||
import os
|
||||
import datetime
|
||||
from ultralytics import YOLO
|
||||
import argparse
|
||||
from model_utils import get_project_root, find_latest_model
|
||||
|
||||
|
||||
def export_to_onnx(model_path=None, imgsz=640, simplify=True, opset=12):
|
||||
"""将 PyTorch 模型导出为 ONNX 格式
|
||||
|
||||
Args:
|
||||
model_path: 模型路径(如果为 None 则自动查找)
|
||||
imgsz: 输入图片尺寸
|
||||
simplify: 是否简化模型
|
||||
opset: ONNX opset 版本
|
||||
|
||||
Returns:
|
||||
导出的 ONNX 模型路径
|
||||
"""
|
||||
project_root = get_project_root()
|
||||
|
||||
# 如果没提供模型路径,自动查找最新模型
|
||||
if model_path is None or not os.path.exists(model_path):
|
||||
found_model = find_latest_model()
|
||||
if found_model:
|
||||
model_path = found_model
|
||||
print(f"找到最新模型: {model_path}")
|
||||
else:
|
||||
print(f"\n错误: 找不到模型文件!")
|
||||
print(f"\n可能的原因:")
|
||||
print(f" 1. 还没有运行训练,先运行: python src/train.py")
|
||||
print(f" 2. 使用 --model 参数指定正确路径")
|
||||
return None
|
||||
else:
|
||||
print(f"模型路径: {model_path}")
|
||||
|
||||
try:
|
||||
# 加载模型
|
||||
model = YOLO(model_path)
|
||||
|
||||
# 导出为 ONNX
|
||||
print(f"\n开始导出 ONNX 模型...")
|
||||
print(f"输入尺寸: {imgsz}x{imgsz}")
|
||||
print(f"简化模型: {simplify}")
|
||||
print(f"ONNX opset: {opset}")
|
||||
|
||||
# 导出
|
||||
export_result = model.export(
|
||||
format='onnx',
|
||||
imgsz=imgsz,
|
||||
simplify=simplify,
|
||||
opset=opset
|
||||
)
|
||||
|
||||
if export_result:
|
||||
# 重命名为 yyyy_MM_dd_HHmmss.onnx 格式
|
||||
export_dir = os.path.dirname(export_result)
|
||||
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H%M%S")
|
||||
new_filename = f"{timestamp}.onnx"
|
||||
new_path = os.path.join(export_dir, new_filename)
|
||||
|
||||
# 如果文件已存在,先删除
|
||||
if os.path.exists(new_path):
|
||||
os.remove(new_path)
|
||||
|
||||
# 重命名文件
|
||||
os.rename(export_result, new_path)
|
||||
|
||||
print(f"\n✅ 导出成功!")
|
||||
print(f"ONNX 模型路径: {new_path}")
|
||||
return new_path
|
||||
else:
|
||||
print(f"\n❌ 导出失败")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n导出失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='将 YOLO 模型导出为 ONNX 格式')
|
||||
parser.add_argument('--model', default=None,
|
||||
help='模型路径(默认:自动查找最新模型)')
|
||||
parser.add_argument('--imgsz', type=int, default=640,
|
||||
help='输入图片尺寸(默认: 640)')
|
||||
parser.add_argument('--simplify', action='store_true', default=True,
|
||||
help='简化模型(默认: True)')
|
||||
parser.add_argument('--no-simplify', action='store_false', dest='simplify',
|
||||
help='不简化模型')
|
||||
parser.add_argument('--opset', type=int, default=12,
|
||||
help='ONNX opset 版本(默认: 12)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
export_to_onnx(
|
||||
model_path=args.model,
|
||||
imgsz=args.imgsz,
|
||||
simplify=args.simplify,
|
||||
opset=args.opset
|
||||
)
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
|
||||
import os
|
||||
import glob
|
||||
|
||||
|
||||
def get_project_root():
|
||||
"""获取项目根目录
|
||||
|
||||
Returns:
|
||||
项目根目录的绝对路径
|
||||
"""
|
||||
current_file = os.path.abspath(__file__)
|
||||
src_dir = os.path.dirname(current_file)
|
||||
project_root = os.path.dirname(src_dir)
|
||||
return project_root
|
||||
|
||||
|
||||
def find_latest_model(project_dir='runs/detect', name='my_model'):
|
||||
"""查找最新的最佳模型
|
||||
|
||||
Args:
|
||||
project_dir: 项目目录
|
||||
name: 实验名称
|
||||
|
||||
Returns:
|
||||
最新的 best.pt 路径,如果找不到则返回 None
|
||||
"""
|
||||
project_root = get_project_root()
|
||||
|
||||
# 尝试在多个可能的位置查找
|
||||
possible_paths = [
|
||||
# 1. 默认位置
|
||||
os.path.join(project_root, project_dir, name, 'weights', 'best.pt'),
|
||||
# 2. 查找所有可能的 best.pt
|
||||
*glob.glob(os.path.join(project_root, 'runs', '**', 'best.pt'), recursive=True)
|
||||
]
|
||||
|
||||
# 过滤存在的文件
|
||||
existing_models = [p for p in possible_paths if os.path.exists(p)]
|
||||
|
||||
if not existing_models:
|
||||
return None
|
||||
|
||||
# 去重,保持优先级(同一文件可能出现在多个路径中)
|
||||
unique_models = list(dict.fromkeys(existing_models))
|
||||
|
||||
# 返回最新修改的那个
|
||||
unique_models.sort(key=lambda x: os.path.getmtime(x), reverse=True)
|
||||
return unique_models[0]
|
||||
@@ -0,0 +1,86 @@
|
||||
|
||||
import os
|
||||
from ultralytics import YOLO
|
||||
import argparse
|
||||
from model_utils import get_project_root, find_latest_model
|
||||
|
||||
|
||||
def predict_image(model_path=None, source=None, conf=0.25, save=True):
|
||||
"""预测单张图片或目录
|
||||
|
||||
Args:
|
||||
model_path: 模型路径(如果为 None 则自动查找)
|
||||
source: 预测源(图片/目录)
|
||||
conf: 置信度阈值
|
||||
save: 是否保存结果
|
||||
"""
|
||||
project_root = get_project_root()
|
||||
|
||||
# 如果没提供模型路径,自动查找最新模型
|
||||
if model_path is None or not os.path.exists(model_path):
|
||||
found_model = find_latest_model()
|
||||
if found_model:
|
||||
model_path = found_model
|
||||
print(f"找到最新模型: {model_path}")
|
||||
else:
|
||||
print(f"\n错误: 找不到模型文件!")
|
||||
print(f"\n可能的原因:")
|
||||
print(f" 1. 还没有运行训练,先运行: python src/train.py")
|
||||
print(f" 2. 使用 --model 参数指定正确路径")
|
||||
return None
|
||||
else:
|
||||
print(f"模型路径: {model_path}")
|
||||
|
||||
# 如果没提供 source,使用配置文件中的 test 目录
|
||||
if source is None:
|
||||
config_path = os.path.join(project_root, 'dataset.yaml')
|
||||
if os.path.exists(config_path):
|
||||
import yaml
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
if 'path' in config and 'test' in config:
|
||||
source = os.path.join(config['path'], config['test'])
|
||||
|
||||
if source is None or not os.path.exists(source):
|
||||
print(f"\n警告: 找不到预测源,请使用 --source 参数指定图片路径")
|
||||
|
||||
try:
|
||||
# 加载模型
|
||||
model = YOLO(model_path)
|
||||
|
||||
# 预测
|
||||
results = model.predict(source, save=save, conf=conf)
|
||||
|
||||
# 打印结果
|
||||
for result in results:
|
||||
boxes = result.boxes
|
||||
print(f"\n图片: {result.path}")
|
||||
if len(boxes) == 0:
|
||||
print(f" 未检测到目标")
|
||||
else:
|
||||
for i in range(len(boxes)):
|
||||
class_name = model.names[int(boxes.cls[i])]
|
||||
confidence = float(boxes.conf[i])
|
||||
print(f" 检测到: {class_name}, 置信度: {confidence:.2f}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n预测失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='YOLO 预测')
|
||||
parser.add_argument('--model', default=None,
|
||||
help='模型路径(默认:自动查找最新模型)')
|
||||
parser.add_argument('--source', default=None,
|
||||
help='预测源(图片/目录,默认:读取 dataset.yaml)')
|
||||
parser.add_argument('--conf', type=float, default=0.25,
|
||||
help='置信度阈值(默认: 0.25)')
|
||||
parser.add_argument('--nosave', action='store_true',
|
||||
help='不保存结果(默认保存)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
predict_image(args.model, args.source, conf=args.conf, save=not args.nosave)
|
||||
+112
@@ -0,0 +1,112 @@
|
||||
|
||||
import os
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
import argparse
|
||||
from model_utils import get_project_root
|
||||
|
||||
|
||||
def train_yolo(config_path=None, model_name='yolov8n.pt', epochs=100, imgsz=640,
|
||||
batch=16, device=None, lr0=0.01, patience=50,
|
||||
project='runs/detect', name='my_model', exist_ok=True):
|
||||
"""训练 YOLO 模型
|
||||
|
||||
Args:
|
||||
config_path: dataset.yaml 配置文件路径
|
||||
"""
|
||||
project_root = get_project_root()
|
||||
|
||||
# 如果没提供 config_path,使用项目根目录的 dataset.yaml
|
||||
if config_path is None:
|
||||
config_path = os.path.join(project_root, 'dataset.yaml')
|
||||
|
||||
print(f"配置文件: {config_path}")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(config_path):
|
||||
print(f"\n错误: 配置文件不存在!")
|
||||
print(f"路径: {config_path}")
|
||||
print(f"\n请先配置 dataset.yaml:")
|
||||
print(f" 1. 复制 configs/dataset-traditional.yaml 到项目根目录")
|
||||
print(f" 2. 重命名为 dataset.yaml")
|
||||
print(f" 3. 修改 path 指向你的数据集目录")
|
||||
return None
|
||||
|
||||
# 设置设备
|
||||
if device is None:
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
try:
|
||||
# 加载模型
|
||||
print(f"\n加载模型: {model_name}")
|
||||
model = YOLO(model_name)
|
||||
|
||||
# 训练 - 确保 project 路径是相对于项目根目录的
|
||||
print(f"开始训练...")
|
||||
results = model.train(
|
||||
data=config_path,
|
||||
epochs=epochs,
|
||||
imgsz=imgsz,
|
||||
batch=batch,
|
||||
device=device,
|
||||
lr0=lr0,
|
||||
patience=patience,
|
||||
save=True,
|
||||
plots=True,
|
||||
project=os.path.join(project_root, project),
|
||||
name=name,
|
||||
exist_ok=exist_ok,
|
||||
)
|
||||
|
||||
best_model_path = results.save_dir / 'weights' / 'best.pt'
|
||||
print(f"\n训练完成!")
|
||||
print(f"最佳模型保存在: {best_model_path}")
|
||||
print(f"\n现在可以运行:")
|
||||
print(f" 评估模型: python src/evaluate.py --model {best_model_path}")
|
||||
print(f" 进行预测: python src/predict.py --model {best_model_path}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n训练失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='训练 YOLO 模型')
|
||||
parser.add_argument('--config', default=None,
|
||||
help='dataset.yaml 配置文件路径(默认使用项目根目录)')
|
||||
parser.add_argument('--model', default='yolov8n.pt',
|
||||
help='预训练模型 (默认: yolov8n.pt)')
|
||||
parser.add_argument('--epochs', type=int, default=100,
|
||||
help='训练轮数 (默认: 100)')
|
||||
parser.add_argument('--imgsz', type=int, default=640,
|
||||
help='输入图片尺寸 (默认: 640)')
|
||||
parser.add_argument('--batch', type=int, default=16,
|
||||
help='批次大小 (默认: 16)')
|
||||
parser.add_argument('--device', default=None,
|
||||
help='设备 (cuda/cpu, 默认: 自动检测)')
|
||||
parser.add_argument('--lr', type=float, default=0.01,
|
||||
help='初始学习率 (默认: 0.01)')
|
||||
parser.add_argument('--patience', type=int, default=50,
|
||||
help='早停耐心值 (默认: 50)')
|
||||
parser.add_argument('--project', default='runs/detect',
|
||||
help='项目目录 (默认: runs/detect)')
|
||||
parser.add_argument('--name', default='my_model',
|
||||
help='实验名称 (默认: my_model)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
train_yolo(
|
||||
config_path=args.config,
|
||||
model_name=args.model,
|
||||
epochs=args.epochs,
|
||||
imgsz=args.imgsz,
|
||||
batch=args.batch,
|
||||
device=args.device,
|
||||
lr0=args.lr,
|
||||
patience=args.patience,
|
||||
project=args.project,
|
||||
name=args.name
|
||||
)
|
||||
@@ -0,0 +1,135 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from PIL import Image
|
||||
|
||||
# 添加 src 目录到路径
|
||||
src_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if src_dir not in sys.path:
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
from dataset_util import load_dataset_config
|
||||
|
||||
|
||||
def validate_images_labels(images_dir, labels_dir):
|
||||
"""验证 images 和 labels 目录
|
||||
|
||||
Returns:
|
||||
list: 错误列表
|
||||
"""
|
||||
errors = []
|
||||
|
||||
if not os.path.exists(images_dir):
|
||||
errors.append(f"images 目录不存在: {images_dir}")
|
||||
return errors
|
||||
|
||||
image_files = [f for f in os.listdir(images_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
|
||||
|
||||
if not image_files:
|
||||
errors.append(f"images 目录中没有图片: {images_dir}")
|
||||
return errors
|
||||
|
||||
for img_file in image_files:
|
||||
img_path = os.path.join(images_dir, img_file)
|
||||
label_file = os.path.splitext(img_file)[0] + '.txt'
|
||||
label_path = os.path.join(labels_dir, label_file)
|
||||
|
||||
if not os.path.exists(label_path):
|
||||
errors.append(f"缺失标注文件: {label_file}")
|
||||
continue
|
||||
|
||||
try:
|
||||
img = Image.open(img_path)
|
||||
img_width, img_height = img.size
|
||||
except Exception as e:
|
||||
errors.append(f"无法打开图片: {img_file}")
|
||||
continue
|
||||
|
||||
with open(label_path, 'r') as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
parts = line.split()
|
||||
if len(parts) != 5:
|
||||
errors.append(f"{label_file}: 格式错误,行 {line_num}")
|
||||
continue
|
||||
|
||||
try:
|
||||
class_id = int(parts[0])
|
||||
center_x = float(parts[1])
|
||||
center_y = float(parts[2])
|
||||
width = float(parts[3])
|
||||
height = float(parts[4])
|
||||
|
||||
if not (0 <= center_x <= 1 and 0 <= center_y <= 1):
|
||||
errors.append(f"{label_file}: 坐标超出范围,行 {line_num}")
|
||||
except ValueError:
|
||||
errors.append(f"{label_file}: 数字解析错误,行 {line_num}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_dataset_from_config(config_path=None):
|
||||
"""从配置文件验证数据集
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
|
||||
Returns:
|
||||
list: 错误列表
|
||||
"""
|
||||
config = load_dataset_config(config_path)
|
||||
|
||||
data_dir = config.get('path', './dataset')
|
||||
if not os.path.isabs(data_dir) and config_path:
|
||||
data_dir = os.path.join(os.path.dirname(config_path), data_dir)
|
||||
|
||||
errors = []
|
||||
|
||||
# 验证 train/val/test
|
||||
for split in ['train', 'val', 'test']:
|
||||
split_images_dir = os.path.join(data_dir, config.get(split, f'{split}/images'))
|
||||
split_labels_dir = os.path.join(data_dir, config.get(split, f'{split}/images').replace('images', 'labels'))
|
||||
|
||||
if os.path.exists(split_images_dir):
|
||||
print(f"检查 {split} 集...")
|
||||
split_errors = validate_images_labels(split_images_dir, split_labels_dir)
|
||||
errors.extend(split_errors)
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def validate_yolo_dataset(config_path=None):
|
||||
"""验证 YOLO 数据集
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
print("验证数据集...")
|
||||
|
||||
errors = validate_dataset_from_config(config_path)
|
||||
|
||||
if errors:
|
||||
print(f"\n发现 {len(errors)} 个错误:")
|
||||
for error in errors[:20]:
|
||||
print(f" - {error}")
|
||||
if len(errors) > 20:
|
||||
print(f" ... 还有 {len(errors) - 20} 个错误")
|
||||
else:
|
||||
print("\n验证通过!")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='验证 YOLO 数据集')
|
||||
parser.add_argument('config_path', nargs='?', default=None,
|
||||
help='dataset.yaml 配置文件路径(默认使用项目根目录的 dataset.yaml)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
validate_yolo_dataset(args.config_path)
|
||||
@@ -0,0 +1,49 @@
|
||||
# VisioFirm 一键启动脚本
|
||||
Write-Host "========================================" -ForegroundColor Cyan
|
||||
Write-Host " 正在启动 VisioFirm..." -ForegroundColor Cyan
|
||||
Write-Host "========================================" -ForegroundColor Cyan
|
||||
Write-Host ""
|
||||
|
||||
# 设置 UTF-8 环境变量(解决 Windows 中文环境编码问题)
|
||||
$env:PYTHONUTF8 = 1
|
||||
Write-Host "[1/4] 已设置 PYTHONUTF8=1" -ForegroundColor Green
|
||||
|
||||
# 检查 Conda 是否可用
|
||||
try {
|
||||
$condaVersion = conda --version 2>&1
|
||||
Write-Host "[2/4] Conda 已检测到: $condaVersion" -ForegroundColor Green
|
||||
} catch {
|
||||
Write-Host "[错误] 未检测到 Conda,请先安装 Anaconda 或 Miniconda" -ForegroundColor Red
|
||||
Read-Host "按任意键退出"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# 激活 Conda 环境
|
||||
Write-Host "[3/4] 正在激活 Conda 环境 (yolo_demo)..." -ForegroundColor Green
|
||||
conda activate yolo_demo
|
||||
if ($LASTEXITCODE -ne 0) {
|
||||
Write-Host "[警告] 激活环境时可能存在问题,但将继续尝试..." -ForegroundColor Yellow
|
||||
}
|
||||
|
||||
# 检查 VisioFirm 是否已安装
|
||||
try {
|
||||
$visiofirmInstalled = python -c "import visiofirm; print('OK')" 2>&1
|
||||
if ($visiofirmInstalled -eq 'OK') {
|
||||
Write-Host "[4/4] VisioFirm 已安装" -ForegroundColor Green
|
||||
}
|
||||
} catch {
|
||||
Write-Host "[提示] 正在尝试安装 VisioFirm..." -ForegroundColor Yellow
|
||||
pip install visiofirm
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
Write-Host "🎉 VisioFirm 即将启动!" -ForegroundColor Yellow
|
||||
Write-Host "📱 请在浏览器中打开: http://localhost:8000" -ForegroundColor Yellow
|
||||
Write-Host ""
|
||||
Write-Host "按 Ctrl+C 停止服务器" -ForegroundColor Gray
|
||||
Write-Host ""
|
||||
Write-Host "========================================" -ForegroundColor Cyan
|
||||
Write-Host ""
|
||||
|
||||
# 启动 VisioFirm
|
||||
visiofirm
|
||||
BIN
Binary file not shown.
Reference in New Issue
Block a user