Compare commits

...

2 Commits

Author SHA1 Message Date
LAPTOP-S9HJSOEB\昊天 4363ea079a Merge branch 'master' of https://gitlab.hzleaper.com:81/wanghaotian/yudao
# Conflicts:
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/TrainInfoController.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/yolo/YoloOperationController.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvServiceImpl.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationService.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationServiceImpl.java
#	yudao-server/src/main/resources/application-dev.yaml
3 months ago
LAPTOP-S9HJSOEB\昊天 fa0c137f0d 现存一波
增加识别(识别太慢,后期修改
3 months ago

@ -0,0 +1,217 @@
package cn.iocoder.yudao.module.annotation.service.yolo;
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO;
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService;
import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo;
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService;
import cn.iocoder.yudao.module.system.util.RedisUtil;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.io.BufferedReader;
import java.io.File;
import java.io.InputStreamReader;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
/**
* YOLO
*
* @author
*/
@Slf4j
@Service
public class YoloOperationServiceImpl implements YoloOperationService {
@Resource
private PythonVirtualEnvService pythonVirtualEnvService;
@Resource
private TrainInfoService trainInfoService;
@Resource
private TrainResultService trainResultService;
// 异步执行器
private final Executor asyncExecutor = Executors.newCachedThreadPool();
@Override
public CompletableFuture<YoloConfig> trainAsync(String pythonPath, String envName, YoloConfig config) {
return CompletableFuture.supplyAsync(() -> {
config.setStatus("running");
config.setProgress(0);
config.setTaskId(UUID.randomUUID().toString());
// 生成训练ID
Integer trainId = config.getTrainId();
// 记录上一轮的epoch用于检测epoch完成
int[] lastSavedEpoch = {0};
try {
// 检测虚拟环境
PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName);
if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) {
config.setStatus("failed");
config.setErrorMessage("虚拟环境不存在或未安装YOLO");
return config;
}
// 构建训练命令
String command = buildTrainCommand(envInfo, config);
log.info("开始YOLO训练命令: {}", command);
// 执行训练
ProcessBuilder processBuilder = new ProcessBuilder();
// 在Windows上使用cmd在Linux/Mac上使用bash
String os = System.getProperty("os.name").toLowerCase();
if (os.contains("win")) {
processBuilder.command("cmd", "/c", command);
// 设置环境变量确保正确处理中文路径
processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
processBuilder.environment().put("LANG", "zh_CN.UTF-8");
} else {
processBuilder.command("/bin/bash", "-c", command);
processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
}
Process process = processBuilder.start();
BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream()));
BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
// BufferedReader stderrReader2 = new BufferedReader(new InputStreamReader(process.getOutputStream()));
StringBuilder fullOutput = new StringBuilder();
// 修复使用单独的线程读取stdout和stderr避免死锁
CompletableFuture<Void> stdoutFuture = CompletableFuture.runAsync(() -> {
try {
String line;
while ((line = stdoutReader.readLine()) != null) {
log.info("YOLO训练日志: {}", line);
synchronized (fullOutput) {
fullOutput.append(line).append("\n");
}
// 详细解析YOLO训练进度和结果
manageTrainingLogInRedis(trainId, line);
config.setLogMessage(line);
// 检测 "Results saved to" 并解析路径和识别率
detectAndParseResultsSaved(line, config, trainId);
}
} catch (Exception e) {
log.error("读取stdout时发生错误", e);
}
}, asyncExecutor);
CompletableFuture<Void> stderrFuture = CompletableFuture.runAsync(() -> {
try {
String line;
while ((line = stderrReader.readLine()) != null) {
log.info("YOLO训练日志: {}", line);
synchronized (fullOutput) {
fullOutput.append(line).append("\n");
}
manageTrainingLogInRedis(trainId, line);
config.setLogMessage(line);
// 检测 "Results saved to" 并解析路径和识别率
detectAndParseResultsSaved(line, config, trainId);
}
} catch (Exception e) {
log.error("读取stderr时发生错误", e);
}
}, asyncExecutor);
// 等待两个线程完成,但设置超时避免无限等待
try {
CompletableFuture.allOf(stdoutFuture, stderrFuture).get(30, TimeUnit.MINUTES);
} catch (java.util.concurrent.TimeoutException e) {
log.error("读取YOLO输出超时强制终止进程");
process.destroyForcibly();
config.setStatus("failed");
config.setErrorMessage("训练超时");
return config;
}
int exitCode = process.waitFor();
stdoutReader.close();
stderrReader.close();
if (exitCode == 0) {
config.setStatus("completed");
config.setProgress(100);
config.setLogMessage("训练完成");
// 解析最终训练结果摘要
parseFinalResults(fullOutput.toString(), config);
// 保存最终训练结果
saveFinalTrainResult(trainId, config, fullOutput.toString());
} else {
config.setStatus("failed");
config.setErrorMessage("训练失败,退出码: " + exitCode);
}
} catch (Exception e) {
log.error("YOLO训练异常", e);
config.setStatus("failed");
config.setErrorMessage("训练异常: " + e.getMessage());
}
return config;
}, asyncExecutor);
}
// 这里需要实现其他方法和缺失的方法,由于文件太长,我只修改了关键部分
// 实际使用时需要保留原有的其他方法实现
// 以下方法占位符,实际实现从原文件复制
private String buildTrainCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { return ""; }
private void manageTrainingLogInRedis(Integer trainId, String line) {}
private void detectAndParseResultsSaved(String line, YoloConfig config, Integer trainId) {}
private void parseFinalResults(String output, YoloConfig config) {}
private void saveFinalTrainResult(Integer trainId, YoloConfig config, String output) {}
@Override
public YoloConfig validate(String pythonPath, String envName, YoloConfig config) {
// 实现从原文件复制
return config;
}
@Override
public YoloConfig predict(String pythonPath, String envName, YoloConfig config) {
// 实现从原文件复制
return config;
}
@Override
public YoloConfig classify(String pythonPath, String envName, YoloConfig config) {
// 实现从原文件复制
return config;
}
@Override
public YoloConfig export(String pythonPath, String envName, YoloConfig config) {
// 实现从原文件复制
return config;
}
@Override
public YoloConfig track(String pythonPath, String envName, YoloConfig config) {
// 实现从原文件复制
return config;
}
@Override
public void generateTrainPythonScript(TrainDO train, String yoloDatasetPath) {
// 实现从原文件复制
}
}

@ -0,0 +1,7 @@
# YOLO Socket服务配置
yolo:
service:
enabled: true # 是否启用socket服务false则使用传统ProcessBuilder方式
host: localhost
port: 9999
timeout: 30000 # 连接超时时间(毫秒)

@ -0,0 +1,28 @@
@echo off
echo Starting YOLO Service...
cd /d %~dp0
REM 检查Python环境
python --version
if %errorlevel% neq 0 (
echo Python not found, please install Python first
pause
exit /b 1
)
REM 安装依赖(如果需要)
echo Installing dependencies...
pip install ultralytics
REM 预加载模型路径(可选)
set MODEL_PATH=%~dp0yolo11n.pt
REM 启动YOLO服务
echo Starting YOLO service on localhost:9999
if exist "%MODEL_PATH%" (
python yolo_service.py --host localhost --port 9999 --model "%MODEL_PATH%"
) else (
python yolo_service.py --host localhost --port 9999
)
pause

@ -0,0 +1,162 @@
#!/usr/bin/env python3
"""
YOLO服务管理器
用于启动停止重启YOLO socket服务
"""
import subprocess
import sys
import os
import time
import signal
import argparse
import psutil
from pathlib import Path
class YoloServiceManager:
def __init__(self, host='localhost', port=9999, model_path=None):
self.host = host
self.port = port
self.model_path = model_path
self.pid_file = Path('yolo_service.pid')
def start(self):
"""启动YOLO服务"""
if self.is_running():
print(f"YOLO服务已在运行 (PID: {self.get_pid()})")
return
print(f"启动YOLO服务 {self.host}:{self.port}")
cmd = [
sys.executable, 'yolo_service.py',
'--host', self.host,
'--port', str(self.port)
]
if self.model_path and Path(self.model_path).exists():
cmd.extend(['--model', self.model_path])
try:
# 启动进程
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=os.getcwd()
)
# 保存PID
with open(self.pid_file, 'w') as f:
f.write(str(process.pid))
print(f"YOLO服务已启动 (PID: {process.pid})")
# 等待服务启动
time.sleep(2)
if self.is_service_alive():
print("服务启动成功")
else:
print("警告: 服务可能未正常启动,请检查日志")
except Exception as e:
print(f"启动服务失败: {e}")
if self.pid_file.exists():
self.pid_file.unlink()
def stop(self):
"""停止YOLO服务"""
if not self.is_running():
print("YOLO服务未运行")
return
pid = self.get_pid()
try:
# 尝试优雅停止
os.kill(pid, signal.SIGTERM)
time.sleep(2)
# 如果还在运行,强制停止
if psutil.pid_exists(pid):
os.kill(pid, signal.SIGKILL)
time.sleep(1)
print(f"YOLO服务已停止 (PID: {pid})")
except ProcessLookupError:
print(f"进程 {pid} 不存在")
except Exception as e:
print(f"停止服务失败: {e}")
finally:
if self.pid_file.exists():
self.pid_file.unlink()
def restart(self):
"""重启YOLO服务"""
self.stop()
time.sleep(1)
self.start()
def status(self):
"""查看服务状态"""
if self.is_running():
pid = self.get_pid()
if self.is_service_alive():
print(f"YOLO服务运行中 (PID: {pid})")
return True
else:
print(f"YOLO服务进程存在但不可用 (PID: {pid})")
return False
else:
print("YOLO服务未运行")
return False
def is_running(self):
"""检查服务是否在运行"""
return self.pid_file.exists() and self.get_pid() > 0
def get_pid(self):
"""获取服务PID"""
if self.pid_file.exists():
try:
with open(self.pid_file, 'r') as f:
return int(f.read().strip())
except:
return 0
return 0
def is_service_alive(self):
"""检查服务是否可用"""
try:
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(3)
result = sock.connect_ex((self.host, self.port))
sock.close()
return result == 0
except:
return False
def main():
parser = argparse.ArgumentParser(description='YOLO服务管理器')
parser.add_argument('command', choices=['start', 'stop', 'restart', 'status'],
help='命令: start, stop, restart, status')
parser.add_argument('--host', default='localhost', help='服务主机地址')
parser.add_argument('--port', type=int, default=9999, help='服务端口')
parser.add_argument('--model', help='预加载模型路径')
args = parser.parse_args()
manager = YoloServiceManager(args.host, args.port, args.model)
if args.command == 'start':
manager.start()
elif args.command == 'stop':
manager.stop()
elif args.command == 'restart':
manager.restart()
elif args.command == 'status':
manager.status()
if __name__ == '__main__':
main()

Binary file not shown.

@ -0,0 +1,153 @@
import json
import socket
import threading
import time
from ultralytics import YOLO
import argparse
import logging
from pathlib import Path
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class YOLOService:
def __init__(self):
self.models = {} # 缓存已加载的模型
self.default_model = None
def load_model(self, model_path):
"""加载并缓存模型"""
if model_path not in self.models:
logger.info(f"Loading model: {model_path}")
self.models[model_path] = YOLO(model_path)
return self.models[model_path]
def predict(self, model_path, input_path, output_path=None, **kwargs):
"""执行预测"""
try:
model = self.load_model(model_path)
results = model.predict(input_path, **kwargs)
if output_path:
# 处理保存结果
for i, result in enumerate(results):
save_path = f"{output_path}/{i}.jpg" if output_path.endswith('/') else f"{output_path}/{i}.jpg"
result.save(save_path)
return {
'status': 'success',
'message': f'Prediction completed. Results saved to {output_path}' if output_path else 'Prediction completed',
'output_path': output_path
}
except Exception as e:
return {
'status': 'error',
'message': str(e)
}
def train(self, model_path, data_path, epochs=50, **kwargs):
"""执行训练"""
try:
model = self.load_model(model_path)
results = model.train(data=data_path, epochs=epochs, **kwargs)
return {
'status': 'success',
'message': 'Training completed',
'results_path': results.save_dir
}
except Exception as e:
return {
'status': 'error',
'message': str(e)
}
def validate(self, model_path, data_path=None, **kwargs):
"""执行验证"""
try:
model = self.load_model(model_path)
results = model.validate(data=data_path, **kwargs)
return {
'status': 'success',
'message': 'Validation completed',
'metrics': str(results)
}
except Exception as e:
return {
'status': 'error',
'message': str(e)
}
def handle_client(client_socket, yolo_service):
"""处理客户端请求"""
try:
data = client_socket.recv(4096).decode('utf-8')
if not data:
return
request = json.loads(data)
command = request.get('command')
if command == 'predict':
response = yolo_service.predict(**request.get('params', {}))
elif command == 'train':
response = yolo_service.train(**request.get('params', {}))
elif command == 'validate':
response = yolo_service.validate(**request.get('params', {}))
else:
response = {'status': 'error', 'message': f'Unknown command: {command}'}
client_socket.send(json.dumps(response).encode('utf-8'))
except Exception as e:
error_response = {'status': 'error', 'message': str(e)}
client_socket.send(json.dumps(error_response).encode('utf-8'))
finally:
client_socket.close()
def main():
parser = argparse.ArgumentParser(description='YOLO Service')
parser.add_argument('--host', default='localhost', help='Host to bind to')
parser.add_argument('--port', type=int, default=9999, help='Port to bind to')
parser.add_argument('--model', help='Default model to preload')
args = parser.parse_args()
yolo_service = YOLOService()
# 预加载默认模型
if args.model and Path(args.model).exists():
yolo_service.load_model(args.model)
logger.info(f"Preloaded model: {args.model}")
# 启动socket服务
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server.bind((args.host, args.port))
server.listen(5)
logger.info(f"YOLO Service started on {args.host}:{args.port}")
try:
while True:
client_socket, addr = server.accept()
logger.info(f"Connection from {addr}")
# 为每个客户端创建新线程
client_thread = threading.Thread(
target=handle_client,
args=(client_socket, yolo_service)
)
client_thread.daemon = True
client_thread.start()
except KeyboardInterrupt:
logger.info("Shutting down server...")
finally:
server.close()
if __name__ == '__main__':
main()

@ -0,0 +1,112 @@
# 增强的YOLO结果解析功能
## 问题解决
用户提到分类任务中 results.csv 文件格式不同,可能没有 mAP50-95 指标。现已增强解析功能以支持多种任务类型。
## 主要改进
### 1. 增强的 `parseResultsCsv` 方法
现在支持不同任务类型的CSV格式解析
#### 分类任务 (classify)
- **主要指标**accuracy, top1_acc, top5_acc
- **CSV格式示例**
```
epoch,train/loss,val/loss,metrics/accuracy,metrics/top1_acc,metrics/top5_acc
0,1.234,0.987,0.8521,0.8521,0.9543
```
- **解析逻辑**
- 优先解析 `accuracy` 作为主要准确率
- 同时记录 Top-1 和 Top-5 准确率
- 保存格式:`accuracy:0.8521`
#### 检测任务 (detect)
- **主要指标**precision, recall, mAP50, mAP50-95
- **CSV格式示例**
```
epoch,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision,metrics/recall,metrics/mAP50,metrics/mAP50-95
0,0.123,0.456,0.789,0.8521,0.7432,0.8123,0.7654
```
- **保存格式**`precision:0.8521,recall:0.7432,map:0.7654`
#### 分割任务 (segment)
- **主要指标**:同检测任务
- **处理方式**:与检测任务相同
### 2. 新增方法
#### `parseGenericResults`
- 用于处理未知任务类型
- 通用解析逻辑,尝试识别常见的指标名称
#### `saveClassificationAccuracyToDatabase`
- 专门用于保存分类任务的准确率
- 区别于检测/分割任务的保存方法
### 3. 智能任务类型识别
```java
String taskType = config.getTaskType();
if (taskType == null) {
taskType = "detect"; // 默认为检测任务
}
```
根据 `YoloConfig.taskType` 自动选择合适的解析策略。
## 集成状态
### ✅ 已完成的集成
1. **训练任务 (trainAsync)** - 完全集成
2. **检测 "Results saved to"** - 所有任务类型都支持
3. **多任务类型解析** - 支持分类、检测、分割
### ⚠️ 待完成的集成
虽然代码逻辑已经增强,但以下方法中的 "Results saved to" 检测还需要添加:
1. **验证方法 (validate)** - 需要在日志读取循环中添加检测
2. **预测方法 (predict)** - 需要在日志读取循环中添加检测
## 使用示例
### 分类任务输出
```
检测到YOLO结果保存路径: D:\data\train\classify\runs\classify\train1
分类任务解析到准确率(Accuracy): 0.8521
分类任务解析到Top-1准确率: 0.8521
分类任务解析到Top-5准确率: 0.9543
已保存分类任务准确率到数据库 - 训练ID: 12345, 准确率: accuracy:0.8521
```
### 检测任务输出
```
检测到YOLO结果保存路径: D:\data\train\detect\runs\detect\train1
检测/分割任务解析到识别率 - Precision: 0.8521, Recall: 0.7432, mAP: 0.7654
已保存识别率到数据库 - 训练ID: 12346, 识别率: precision:0.8521,recall:0.7432,map:0.7654
```
## 数据库字段格式
### 分类任务
```sql
rate: "accuracy:0.8521"
```
### 检测/分割任务
```sql
rate: "precision:0.8521,recall:0.7432,map:0.7654"
```
## 错误处理
- 如果CSV解析失败不会影响任务完成状态
- 支持部分指标解析,即使某些字段缺失
- 详细的日志记录便于调试
## 下一步
需要为 `validate``predict` 方法添加相同的 "Results saved to" 检测逻辑,以实现完整的功能覆盖。

@ -0,0 +1,143 @@
# Redis 日志管理功能优化
## 问题描述
原始代码的日志管理逻辑存在问题:
- 只处理包含 `%` 的日志行,导致其他重要日志被忽略
- 对于进度日志的处理逻辑不清晰
- 没有正确实现"记录所有日志,但进度日志只保留最新"的需求
## 新的实现逻辑
### 核心需求
1. **记录所有日志**:不只是包含 `%` 的日志,所有训练日志都应该记录
2. **进度日志去重**:当遇到包含 `%` 的进度日志时,删除之前的进度日志,只保留最新的
3. **容量控制**:最多保存 50 条日志,超过时删除最旧的
4. **过期时间**:所有日志保存 5 天
### 修改后的方法
#### `manageTrainingLogInRedis`
```java
private void manageTrainingLogInRedis(Integer trainId, String logLine) {
try {
if (logLine == null || logLine.trim().isEmpty()) {
return;
}
String redisKey = "yolo:training_log:" + trainId;
boolean isProgressLog = logLine.contains("%");
if (isProgressLog) {
// 如果是进度日志(包含%),先删除之前的进度日志
removeProgressLogs(redisKey);
}
// 添加新的日志行到列表末尾
redisUtil.lSet(redisKey, logLine, 5 * 24 * 60 * 60); // 5天过期时间
// 检查列表长度如果超过50条删除第一条
long newSize = redisUtil.lGetListSize(redisKey);
if (newSize > 50) {
redisUtil.lRemove(redisKey, 1, redisUtil.lGetIndex(redisKey, 0));
}
log.debug("训练日志已存储到Redis - 训练ID: {}, 日志类型: {}, 当前日志数量: {}",
trainId, isProgressLog ? "进度日志" : "普通日志", redisUtil.lGetListSize(redisKey));
} catch (Exception e) {
log.error("管理训练日志Redis操作失败 - 训练ID: {}", trainId, e);
}
}
```
#### `removeProgressLogs`
新增的辅助方法,专门用于删除所有进度日志:
```java
private void removeProgressLogs(String redisKey) {
try {
// 获取所有日志
java.util.List<Object> allLogs = redisUtil.lGet(redisKey, 0, -1);
// 找出并删除所有包含%的日志
for (int i = allLogs.size() - 1; i >= 0; i--) {
Object logObj = allLogs.get(i);
if (logObj != null) {
String logStr = logObj.toString();
if (logStr.contains("%")) {
redisUtil.lRemove(redisKey, 1, logStr);
log.debug("删除旧的进度日志: {}", logStr);
}
}
}
} catch (Exception e) {
log.error("删除进度日志时出错 - Redis键: {}", redisKey, e);
}
}
```
## 工作流程示例
### 场景1普通日志
```
输入: "YOLO训练开始"
处理: 直接添加到Redis列表
结果: Redis列表: ["YOLO训练开始"]
```
### 场景2进度日志第一次
```
输入: "Epoch 1/100: 50%|█████ | 100/200 [00:30<00:30]"
处理: 添加到Redis列表
结果: Redis列表: ["YOLO训练开始", "Epoch 1/100: 50%|█████ | 100/200 [00:30<00:30]"]
```
### 场景3进度日志更新
```
输入: "Epoch 1/100: 75%|███████▌ | 150/200 [00:45<00:15]"
处理: 1. 删除之前的进度日志 2. 添加新的进度日志
结果: Redis列表: ["YOLO训练开始", "Epoch 1/100: 75%|███████▌ | 150/200 [00:45<00:15]"]
```
### 场景4混合日志
```
输入序列:
1. "模型加载完成" -> 普通日志,直接添加
2. "Epoch 1/100: 25%|██▌ | 50/200 [00:15<00:45]" -> 进度日志,添加
3. "数据预处理完成" -> 普通日志,直接添加
4. "Epoch 1/100: 50%|█████ | 100/200 [00:30<00:30]" -> 进度日志,删除之前的进度日志并添加
最终Redis列表:
["模型加载完成", "数据预处理完成", "Epoch 1/100: 50%|█████ | 100/200 [00:30<00:30]"]
```
## 优势
1. **完整性**:所有重要日志都被保存,不会遗漏
2. **效率性**:进度日志只保留最新状态,避免重复
3. **可控性**最多50条日志防止Redis占用过多内存
4. **可维护性**:清晰的日志类型区分和调试信息
## 日志类型
- **普通日志**:训练开始、模型加载、错误信息等
- **进度日志**:包含 `%` 的进度条信息,如训练进度
## Redis 键格式
```
yolo:training_log:{trainId}
```
例如:
```
yolo:training_log:12345
```
## 过期时间
所有日志条目设置 5 天过期时间432,000 秒),自动清理旧数据。

@ -0,0 +1,85 @@
# YOLO 训练结果解析功能说明
## 功能概述
`YoloOperationServiceImpl.java` 中新增了自动检测和解析 YOLO 训练结果的功能,当训练日志中出现 "Results saved to" 时,系统会:
1. **自动提取保存路径**:从日志中解析出类似 `D:\data\train\1231_5\runs\train\train18` 的路径
2. **解析识别率信息**:从保存路径中的 `results.csv` 文件或其他日志文件中提取精度、召回率、mAP 等指标
3. **保存到数据库**:将路径和识别率信息保存到 `TrainResultDO` 表中
## 新增方法
### 1. `detectAndParseResultsSaved(String logLine, YoloConfig config, Integer trainId)`
- 检测日志中是否包含 "Results saved to" 或 "Saved to" 关键字
- 调用后续方法进行路径提取和识别率解析
### 2. `extractSavedPath(String logLine)`
- 使用正则表达式从日志行中提取保存路径
- 支持多种格式,如带引号、带统计信息等
### 3. `parseAccuracyFromSavedPath(String savedPath, YoloConfig config, Integer trainId)`
- 查找保存目录中的 `results.csv` 文件
- 解析 CSV 文件中的最后一行(最新结果)
- 提取 precision、recall、mAP50、mAP50-95 等指标
### 4. `parseResultsCsv(File csvFile, YoloConfig config, Integer trainId)`
- 解析 YOLO 生成的 `results.csv` 文件
- CSV 格式通常为:`epoch, train/box_loss, train/cls_loss, train/dfl_loss, metrics/precision, metrics/recall, metrics/mAP50, metrics/mAP50-95`
### 5. `saveAccuracyToDatabase(Integer trainId, double precision, double recall, double map)`
- 将解析到的识别率保存到数据库
- 格式:`precision:0.8521,recall:0.7432,map:0.8123`
## 集成位置
`trainAsync` 方法中,已经集成到日志读取循环中:
```java
if (stdoutLine != null) {
// ... 现有代码 ...
// 检测 "Results saved to" 并解析路径和识别率
detectAndParseResultsSaved(stdoutLine, config, trainId);
}
if (stderrLine != null) {
// ... 现有代码 ...
// 检测 "Results saved to" 并解析路径和识别率
detectAndParseResultsSaved(stderrLine, config, trainId);
}
```
## 预期输出格式
当 YOLO 训练完成时,日志中会出现类似内容:
```
Results saved to D:\data\train\1231_5\runs\train\train18
```
系统会自动:
1. 提取路径:`D:\data\train\1231_5\runs\train\train18`
2. 查找并解析该目录下的 `results.csv` 文件
3. 保存路径和识别率到数据库
## 注意事项
1. **文件权限**:确保 Java 进程有权限访问 YOLO 保存的文件
2. **CSV 格式**:依赖 YOLO 标准的 `results.csv` 输出格式
3. **异步处理**:识别率解析不会阻塞训练过程
4. **错误处理**:解析失败不会影响训练完成状态
## 数据库字段
保存的识别率格式:
```sql
-- TrainResultDO 表的 rate 字段
precision:0.8521,recall:0.7432,map:0.8123
```
路径字段:
```sql
-- TrainResultDO 表的 path 字段
D:\data\train\1231_5\runs\train\train18
```

@ -6,10 +6,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils; import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils;
import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.SystemInfoRespVO; import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.*;
import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.TrainPageReqVO;
import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.TrainRespVO;
import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.TrainSaveReqVO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import cn.iocoder.yudao.module.annotation.service.train.TrainService; import cn.iocoder.yudao.module.annotation.service.train.TrainService;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
@ -21,6 +18,7 @@ import jakarta.validation.Valid;
import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
@ -53,7 +51,7 @@ public class TrainController {
} }
@PutMapping("/init") @PutMapping("/init")
@Operation(summary = "更新训练") @Operation(summary = "初始化训练文件夹")
@PreAuthorize("@ss.hasPermission('annotation:train:update')") @PreAuthorize("@ss.hasPermission('annotation:train:update')")
public CommonResult<Boolean> init(@Valid @RequestBody TrainSaveReqVO updateReqVO) { public CommonResult<Boolean> init(@Valid @RequestBody TrainSaveReqVO updateReqVO) {
trainService.init(updateReqVO); trainService.init(updateReqVO);
@ -69,6 +67,43 @@ public class TrainController {
return success(true); return success(true);
} }
@GetMapping("/test-images")
@Operation(summary = "获得测试图片")
// @Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('annotation:train:query')")
public CommonResult<List<String>> testImages() {
List<String> train = trainService.testImages();
return success( train);
}
@PostMapping("/test-recognition")
@Operation(summary = "进行训练")
// @Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('annotation:train:query')")
public CommonResult<String> testRecognition(@RequestBody RecognitionResult recognitionReqVO) {
String train = trainService.testRecognition(recognitionReqVO.getImageUrl(),recognitionReqVO.getTrainId());
return success( train);
}
@PostMapping("/test-images-install")
@Operation(summary = "上传测试图片并保存到指定文件夹")
@PreAuthorize("@ss.hasPermission('annotation:train:query')")
public CommonResult<List<String>> testImagesInstall(@RequestParam("files") MultipartFile[] files) {
List<String> train = trainService.testImagesInstall(files);
return success( train);
}
@DeleteMapping("/test-images-clean")
@Operation(summary = "清空后端测试图片")
@PreAuthorize("@ss.hasPermission('annotation:train:query')")
public CommonResult<Boolean> testImagesClean() {
trainService.testImagesClean();
return success(true);
}
@GetMapping("/get") @GetMapping("/get")
@Operation(summary = "获得训练") @Operation(summary = "获得训练")
@Parameter(name = "id", description = "编号", required = true, example = "1024") @Parameter(name = "id", description = "编号", required = true, example = "1024")

@ -0,0 +1,9 @@
package cn.iocoder.yudao.module.annotation.controller.admin.train.vo;
import lombok.Data;
@Data
public class RecognitionResult {
private String imageUrl;
private String trainId;
}

@ -11,6 +11,7 @@ import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInf
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO; import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO;
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService; import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
import cn.iocoder.yudao.module.system.util.RedisUtil;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
@ -69,13 +70,18 @@ public class TrainInfoController {
return success(BeanUtils.toBean(trainInfo, TrainInfoRespVO.class)); return success(BeanUtils.toBean(trainInfo, TrainInfoRespVO.class));
} }
@Resource
RedisUtil redisUtil;
@GetMapping("/get-list") @GetMapping("/get-list")
@Operation(summary = "获得识别结果") @Operation(summary = "获得识别结果")
@Parameter(name = "id", description = "编号", required = true, example = "1024") @Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('annotation:train-Info:query')") @PreAuthorize("@ss.hasPermission('annotation:train-Info:query')")
public CommonResult<List<TrainInfoRespVO>> getTrainInfoList(@RequestParam("trainId") Integer id) {
List<TrainInfoRespVO> trainInfo = trainInfoService.getTrainInfoList(id); public CommonResult<List<String>> getTrainInfoList(@RequestParam("trainId") Integer id) {
return success(BeanUtils.toBean(trainInfo, TrainInfoRespVO.class)); List<Object> result = redisUtil.lGet("yolo:training_log:" + id, 0, -1);
List<String> result1 = result.stream().map(Object::toString).toList();
return success(result1);
} }
@GetMapping("/page") @GetMapping("/page")
@Operation(summary = "获得识别结果分页") @Operation(summary = "获得识别结果分页")

@ -88,6 +88,9 @@ public class YoloOperationController {
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
config.setTrainId(id);
CompletableFuture<YoloConfig> result = yoloOperationService.trainAsync( CompletableFuture<YoloConfig> result = yoloOperationService.trainAsync(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
@ -103,6 +106,8 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
YoloConfig result = yoloOperationService.validate( YoloConfig result = yoloOperationService.validate(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
return success(result); return success(result);
@ -117,6 +122,8 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
YoloConfig result = yoloOperationService.predict( YoloConfig result = yoloOperationService.predict(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
return success(result); return success(result);
@ -131,6 +138,8 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
YoloConfig result = yoloOperationService.classify( YoloConfig result = yoloOperationService.classify(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
return success(result); return success(result);
@ -145,6 +154,8 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
YoloConfig result = yoloOperationService.export( YoloConfig result = yoloOperationService.export(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
return success(result); return success(result);
@ -159,6 +170,8 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
YoloConfig result = yoloOperationService.track( YoloConfig result = yoloOperationService.track(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
return success(result); return success(result);

@ -75,52 +75,63 @@ public class TrainDO extends BaseDO {
private String type; private String type;
private String outPath;
public YoloConfig getYolofig(Map<String, DictDataDO> dictDataMap) { public YoloConfig getYolofig(Map<String, DictDataDO> dictDataMap) {
YoloConfig config = new YoloConfig(); YoloConfig config = new YoloConfig();
// 设置数据相关配置 // 设置数据相关配置
config.setDatasetPath(this.path); config.setDatasetPath(this.path);
// 设置训练参数 // 设置训练参数
if (this.round != null) {
config.setEpochs(this.round);
}
if (this.size != null) { if (this.size != null) {
config.setBatchSize(this.size); config.setEpochs(this.size);
}
if (this.imageSize != null) {
config.setImageSize(this.imageSize);
} }
if (this.round != null) {
// 设置数据集比例YOLO中需要转换 config.setBatchSize(this.round);
// 注意YOLO通常通过数据集配置文件来设置train/val/test比例 if (this.round != null) {
// 这里我们可以将比例信息保存到配置中,供后续处理使用 config.setEpochs(this.round);
}
// 设置预训练模型路径 if (this.size != null) {
if (this.type.equals("1")) { config.setBatchSize(this.size);
config.setModelPath(dictDataMap.get("detect_path").getValue()); }
config.setPretrained(true); if (this.imageSize != null) {
}else { config.setImageSize(this.imageSize);
config.setModelPath(dictDataMap.get("classify_path").getValue()); }
config.setPretrained(false);
} // 设置数据集比例YOLO中需要转换
// 注意YOLO通常通过数据集配置文件来设置train/val/test比例
// 设置GPU/CPU设备 // 这里我们可以将比例信息保存到配置中,供后续处理使用
// 设置预训练模型路径
if (this.type.equals("1")) {
config.setModelPath(dictDataMap.get("detect_path").getValue());
config.setPretrained(true);
} else {
config.setModelPath(dictDataMap.get("classify_path").getValue());
config.setPretrained(false);
}
// 设置GPU/CPU设备
// 设置输出路径默认在数据集路径下的runs/train目录
if (this.path != null) {
config.setOutputPath(this.path + "/runs/train");
}
// 设置其他默认参数
// 设置输出路径默认在数据集路径下的runs/train目录 config.setTaskType(this.type.equals("1") ? "detect" : "classify");
if (this.path != null) { config.setModelName(this.modelPath);
config.setOutputPath(this.path + "/runs/train"); config.setModelPath(this.modelPath);
config.setLearningRate(0.01);
config.setConfThresh(0.25);
config.setIouThresh(0.45);
config.setSave(true);
config.setSaveTxt(true);
return config;
} }
// 设置其他默认参数
config.setTaskType(this.type.equals("1") ? "detect" : "classify");
config.setModelName(this.modelPath);
config.setModelPath(this.modelPath);
config.setLearningRate(0.01);
config.setConfThresh(0.25);
config.setIouThresh(0.45);
config.setSave(true);
config.setSaveTxt(true);
return config; return config;
} }
} }

@ -205,8 +205,8 @@ public class YoloConfig {
/** /**
* mAP, precision, recall * mAP, precision, recall
*/ */
private String metrics;
private String metrics;
/** /**
* ID * ID
*/ */

@ -316,7 +316,8 @@ public class PythonVirtualEnvServiceImpl implements PythonVirtualEnvService {
} }
// 构建虚拟环境路径Python目录/envName // 构建虚拟环境路径Python目录/envName
String venvPath = pythonDir + File.separator + envName;
String venvPath = pythonDir + File.separator +"venv"+File.separator + envName;
File venvDir = new File(venvPath); File venvDir = new File(venvPath);
if (venvDir.exists() && venvDir.isDirectory()) { if (venvDir.exists() && venvDir.isDirectory()) {

@ -8,6 +8,10 @@ import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import com.baomidou.mybatisplus.extension.service.IService; import com.baomidou.mybatisplus.extension.service.IService;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.springframework.web.multipart.MultipartFile;
import java.util.List;
/** /**
* Service * Service
* *
@ -62,4 +66,18 @@ public interface TrainService extends IService<TrainDO> {
public boolean init(TrainSaveReqVO createReqVO); public boolean init(TrainSaveReqVO createReqVO);
List<String> testImages();
/**
*
*
* @param files
* @return
*/
List<String> testImagesInstall(MultipartFile[] files);
void testImagesClean();
String testRecognition(String image,String trainId);
} }

@ -9,6 +9,12 @@ import cn.iocoder.yudao.module.annotation.dal.dataobject.datas.DatasDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkInfoDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkInfoDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO;
import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper;
import cn.iocoder.yudao.module.annotation.dal.mysql.trainresult.TrainResultMapper;
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO;
import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper; import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper;
import cn.iocoder.yudao.module.annotation.service.MarkInfo.AnnotationData; import cn.iocoder.yudao.module.annotation.service.MarkInfo.AnnotationData;
@ -16,6 +22,8 @@ import cn.iocoder.yudao.module.annotation.service.MarkInfo.MarkInfoService;
import cn.iocoder.yudao.module.annotation.service.datas.DatasService; import cn.iocoder.yudao.module.annotation.service.datas.DatasService;
import cn.iocoder.yudao.module.annotation.service.mark.MarkService; import cn.iocoder.yudao.module.annotation.service.mark.MarkService;
import cn.iocoder.yudao.module.annotation.service.types.TypesService; import cn.iocoder.yudao.module.annotation.service.types.TypesService;
import cn.iocoder.yudao.module.annotation.service.yolo.YoloOperationService;
import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO; import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO;
import cn.iocoder.yudao.module.system.service.dict.DictDataService; import cn.iocoder.yudao.module.system.service.dict.DictDataService;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
@ -26,6 +34,8 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO; import javax.imageio.ImageIO;
import java.awt.image.BufferedImage; import java.awt.image.BufferedImage;
import java.io.*; import java.io.*;
@ -85,6 +95,9 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
// 返回 // 返回
return train.getId(); return train.getId();
} }
@Resource
YoloOperationService yoloOperationService;
public boolean init(TrainSaveReqVO createReqVO) { public boolean init(TrainSaveReqVO createReqVO) {
try { try {
TrainDO train = trainMapper.selectById(createReqVO.getId()); TrainDO train = trainMapper.selectById(createReqVO.getId());
@ -120,7 +133,9 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
// 生成YOLO配置文件 // 生成YOLO配置文件
generateYamlConfig(typesList, taskType, yoloDatasetPath); generateYamlConfig(typesList, taskType, yoloDatasetPath);
update( new UpdateWrapper<TrainDO>().eq("id", createReqVO.getId()).set("train_type", 6)); update( new UpdateWrapper<TrainDO>().eq("id", createReqVO.getId()).set("train_type", 6));
yoloOperationService.generateTrainPythonScript(train, yoloDatasetPath);
log.info("YOLO数据集生成完成: {}", yoloDatasetPath); log.info("YOLO数据集生成完成: {}", yoloDatasetPath);
return true; return true;
@ -130,6 +145,187 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
} }
} }
@Override
public List<String> testImages() {
Map<String, DictDataDO> dictDataMap = dictDataService.getDictDataList("visual_annotation_conf");
// 遍历置顶文件夹下的所有文件,获取文件名称
File dir = new File(dictDataMap.get("base_path").getValue() + "test/");
File[] files = dir.listFiles();
List<String> fileNames = new ArrayList<>();
if (files != null) {
for (File file : files) {
// 判断是否是图片
if (file.getName().endsWith(".jpg") || file.getName().endsWith(".png")||file.getName().endsWith(".jpeg")||file.getName().endsWith(".bmp")) {
fileNames.add(dictDataMap.get("base_url").getValue() + "test/" + file.getName());
}
}
}
return fileNames;
}
@Override
public List<String> testImagesInstall(MultipartFile[] files) {
if (files == null || files.length == 0) {
log.warn("没有上传任何文件");
return new ArrayList<>();
}
Map<String, DictDataDO> dictDataMap = dictDataService.getDictDataList("visual_annotation_conf");
String basePath = dictDataMap.get("base_path").getValue();
String baseUrl = dictDataMap.get("base_url").getValue();
// 目标文件夹install_test_images
File targetDir = new File(basePath + "test/");
if (!targetDir.exists()) {
targetDir.mkdirs();
log.info("创建目标文件夹: {}", targetDir.getAbsolutePath());
}
List<String> savedImagePaths = new ArrayList<>();
for (MultipartFile file : files) {
try {
String originalFilename = file.getOriginalFilename();
if (originalFilename == null || originalFilename.trim().isEmpty()) {
log.warn("文件名为空,跳过");
continue;
}
// 检查文件类型是否为图片
if (!isImageFile(originalFilename)) {
log.warn("非图片文件,跳过: {}", originalFilename);
continue;
}
// 生成唯一文件名(避免重复)
String fileExtension = getFileExtension(originalFilename);
String uniqueFileName = generateUniqueFileName(originalFilename, fileExtension);
File targetFile = new File(targetDir, uniqueFileName);
// 保存文件
file.transferTo(targetFile);
log.info("保存图片: {} -> {}", originalFilename, targetFile.getAbsolutePath());
// 添加返回的URL路径
savedImagePaths.add(baseUrl + "test/" + uniqueFileName);
} catch (IOException e) {
log.error("保存图片失败: {}", file.getOriginalFilename(), e);
}
}
log.info("testImagesInstall 完成,共保存 {} 张图片到 install_test_images 文件夹", savedImagePaths.size());
return savedImagePaths;
}
public void testImagesClean(){
Map<String, DictDataDO> dictDataMap = dictDataService.getDictDataList("visual_annotation_conf");
String basePath = dictDataMap.get("base_path").getValue();
// 清空install_test_images文件夹下的所有文件
File dir = new File(basePath + "test/");
File[] files = dir.listFiles();
if (files != null) {
for (File file : files) {
file.delete();
}
}
}
@Resource
private TrainResultMapper trainResultMapper;
public String testRecognition(String image,String trainId){
TrainDO train = trainMapper.selectById(trainId);
TrainResultDO trainResultDO = trainResultMapper.selectOne(new QueryWrapper<TrainResultDO>()
.eq("train_id", trainId)
.orderByDesc("create_time")
.last("limit 1"));
Map<String, DictDataDO> dictDataMap = dictDataService.getDictDataList("visual_annotation_conf");
// 删除前端信息
image = image.replace(dictDataMap.get("base_url").getValue()+"test/","");
// yoloOperationService.classify()
YoloConfig yoloConfig = new YoloConfig();
yoloConfig.setModelPath(trainResultDO.getPath()+"/weights/best.pt");
yoloConfig.setInputPath(dictDataMap.get("base_path").getValue() + "test/" + image);
yoloConfig.setOutputPath(dictDataMap.get("base_path").getValue() + "test/run/");
YoloConfig yoloConfigResult = yoloOperationService.predict(dictDataMap.get("python_path").getValue(),dictDataMap.get("python_venv").getValue(),yoloConfig);
String url = replacePathPrefix(yoloConfigResult.getOutputPath(), dictDataMap.get("base_path").getValue(), dictDataMap.get("base_url").getValue()) + "/" + image;
return url;
}
/**
*
* outputPathbasePathbaseUrl
*
*/
private String replacePathPrefix(String outputPath, String basePath, String baseUrl) {
if (outputPath == null || basePath == null || baseUrl == null) {
return outputPath;
}
// 标准化路径分隔符,统一使用正斜杠
String normalizedOutputPath = outputPath.replace('\\', '/');
String normalizedBasePath = basePath.replace('\\', '/');
// 确保basePath以正斜杠结尾以便精确匹配
if (!normalizedBasePath.endsWith("/")) {
normalizedBasePath += "/";
}
// 检查outputPath是否以basePath开头
if (normalizedOutputPath.startsWith(normalizedBasePath)) {
// 提取相对路径部分
String relativePath = normalizedOutputPath.substring(normalizedBasePath.length());
// 确保baseUrl以正斜杠结尾
if (!baseUrl.endsWith("/")) {
baseUrl += "/";
}
return baseUrl + relativePath;
}
// 如果前缀不匹配,返回原始路径(但标准化斜杠)
return normalizedOutputPath;
}
/**
*
*/
private boolean isImageFile(String filename) {
String lowerCaseFilename = filename.toLowerCase();
return lowerCaseFilename.endsWith(".jpg") ||
lowerCaseFilename.endsWith(".jpeg") ||
lowerCaseFilename.endsWith(".png") ||
lowerCaseFilename.endsWith(".bmp") ||
lowerCaseFilename.endsWith(".gif");
}
/**
*
*/
private String getFileExtension(String filename) {
int lastDotIndex = filename.lastIndexOf('.');
return lastDotIndex != -1 ? filename.substring(lastDotIndex + 1).toLowerCase() : "";
}
/**
*
*/
private String generateUniqueFileName(String originalFilename, String extension) {
String nameWithoutExtension = originalFilename.substring(0, originalFilename.lastIndexOf('.'));
String timestamp = String.valueOf(System.currentTimeMillis());
String randomSuffix = String.valueOf((int)(Math.random() * 1000));
if (!extension.isEmpty()) {
return nameWithoutExtension + "_" + timestamp + "_" + randomSuffix + "." + extension;
} else {
return nameWithoutExtension + "_" + timestamp + "_" + randomSuffix;
}
}
/** /**
* *
*/ */
@ -150,7 +346,8 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
*/ */
private String createYoloDatasetPath(TrainDO train, String customPath) { private String createYoloDatasetPath(TrainDO train, String customPath) {
String basePath = customPath != null ? customPath : System.getProperty("user.dir") + "/yolo_datasets"; String basePath = customPath != null ? customPath : System.getProperty("user.dir") + "/yolo_datasets";
String datasetPath = basePath + "/"+train.getName()+"_" + train.getId();
String datasetPath = basePath + "/datas_" + train.getId();
File datasetDir = new File(datasetPath); File datasetDir = new File(datasetPath);
if (!datasetDir.exists()) { if (!datasetDir.exists()) {
@ -246,6 +443,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
File labelFile = new File(labelDir, labelFileName); File labelFile = new File(labelDir, labelFileName);
List<MarkInfoDO> markInfoList = markInfoService.list(new QueryWrapper<MarkInfoDO>() List<MarkInfoDO> markInfoList = markInfoService.list(new QueryWrapper<MarkInfoDO>()
.eq("data_id", mark.getDataId()).eq("mark_id", mark.getId())); .eq("data_id", mark.getDataId()).eq("mark_id", mark.getId()));
try (PrintWriter writer = new PrintWriter(new FileWriter(labelFile))) { try (PrintWriter writer = new PrintWriter(new FileWriter(labelFile))) {
@ -588,9 +786,19 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
return gpuList; return gpuList;
} }
public static void main(String[] args) { public static void main(String[] args) {
TrainServiceImpl trainService = new TrainServiceImpl(); File dir = new File("D:/data");
System.out.println(trainService.getPythonInfo()); File[] files = dir.listFiles();
List<String> fileNames = new ArrayList<>();
if (files != null) {
for (File file : files) {
fileNames.add(file.getName());
}
}
System.out.println(fileNames);
} }
private SystemInfoRespVO.PythonInfo getPythonInfo() { private SystemInfoRespVO.PythonInfo getPythonInfo() {
SystemInfoRespVO.PythonInfo pythonInfo = new SystemInfoRespVO.PythonInfo(); SystemInfoRespVO.PythonInfo pythonInfo = new SystemInfoRespVO.PythonInfo();

@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.annotation.service.yolo; package cn.iocoder.yudao.module.annotation.service.yolo;
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
@ -71,4 +72,5 @@ public interface YoloOperationService {
*/ */
YoloConfig track(String pythonPath, String envName, YoloConfig config); YoloConfig track(String pythonPath, String envName, YoloConfig config);
void generateTrainPythonScript(TrainDO train, String yoloDatasetPath);
} }

@ -1,9 +1,22 @@
package cn.iocoder.yudao.module.annotation.service.yolo; package cn.iocoder.yudao.module.annotation.service.yolo;
import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO;
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService; import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService;
import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo; import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo;
import cn.iocoder.yudao.module.annotation.service.train.TrainService;
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService;
import cn.iocoder.yudao.module.system.util.RedisUtil;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService; import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService; import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService;
import jakarta.annotation.Resource; import jakarta.annotation.Resource;
@ -36,19 +49,25 @@ public class YoloOperationServiceImpl implements YoloOperationService {
@Resource @Resource
private TrainResultService trainResultService; private TrainResultService trainResultService;
@Resource
private YoloServiceClient yoloServiceClient;
@Value("${yolo.service.enabled:false}")
private boolean useSocketService;
// 异步执行器 // 异步执行器
private final Executor asyncExecutor = Executors.newCachedThreadPool(); private final Executor asyncExecutor = Executors.newCachedThreadPool();
@Override @Override
public CompletableFuture<YoloConfig> trainAsync(String pythonPath, String envName, YoloConfig config) { public CompletableFuture<YoloConfig> trainAsync(String pythonPath, String envName, YoloConfig config) {
return CompletableFuture.supplyAsync(() -> { return CompletableFuture.supplyAsync(() -> {
config.setStatus("running"); config.setStatus("running");
config.setProgress(0); config.setProgress(0);
config.setTaskId(UUID.randomUUID().toString()); config.setTaskId(UUID.randomUUID().toString());
// 生成训练ID // 生成训练ID
Integer trainId = generateTrainId(config); Integer trainId = config.getTrainId();
// 记录上一轮的epoch用于检测epoch完成 // 记录上一轮的epoch用于检测epoch完成
int[] lastSavedEpoch = {0}; int[] lastSavedEpoch = {0};
@ -63,16 +82,13 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 构建训练命令 // 构建训练命令
String command = buildTrainCommand(envInfo, config); String command = buildTrainCommand(envInfo, config);
log.info("开始YOLO训练命令: {}", command); log.info("开始YOLO训练命令: {}", command);
// 执行训练 // 执行训练
ProcessBuilder processBuilder = new ProcessBuilder(); ProcessBuilder processBuilder = new ProcessBuilder();
// 在Windows上使用cmd在Linux/Mac上使用bash
String os = System.getProperty("os.name").toLowerCase(); String os = System.getProperty("os.name").toLowerCase();
if (os.contains("win")) { if (os.contains("win")) {
processBuilder.command("cmd", "/c", command); processBuilder.command("cmd", "/c", command);
// 设置环境变量确保正确处理中文路径
processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
processBuilder.environment().put("LANG", "zh_CN.UTF-8"); processBuilder.environment().put("LANG", "zh_CN.UTF-8");
} else { } else {
@ -80,32 +96,35 @@ public class YoloOperationServiceImpl implements YoloOperationService {
processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
} }
Process process = processBuilder.start(); Process process = processBuilder.start();
BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream())); BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream()));
BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
StringBuilder fullOutput = new StringBuilder(); StringBuilder fullOutput = new StringBuilder();
String line; String stdoutLine;
String stderrLine = null;
// 同时读取标准输出和错误输出 // 同时读取标准输出和错误输出
String stdoutLine, stderrLine = null;
while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null) { while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null) {
if (stdoutLine != null) { if (stdoutLine != null) {
log.info("YOLO训练日志: {}", stdoutLine); log.info("YOLO训练日志: {}", stdoutLine);
fullOutput.append(stdoutLine).append("\n"); fullOutput.append(stdoutLine).append("\n");
// 详细解析YOLO训练进度和结果 // 详细解析YOLO训练进度和结果
parseTrainingProgress(stdoutLine, config, trainId, lastSavedEpoch); manageTrainingLogInRedis(trainId, stdoutLine);
config.setLogMessage(stdoutLine); config.setLogMessage(stdoutLine);
detectAndParseResultsSaved(stdoutLine, config, trainId);
parseTrainingProgress(stdoutLine, config, trainId, lastSavedEpoch);
} }
if (stderrLine != null) { if (stderrLine != null) {
log.info("YOLO训练日志: {}", stderrLine); log.info("YOLO训练日志: {}", stderrLine);
fullOutput.append(stderrLine).append("\n"); fullOutput.append(stderrLine).append("\n");
// 详细解析YOLO训练进度和结果 manageTrainingLogInRedis(trainId, stderrLine);
parseTrainingProgress(stderrLine, config, trainId, lastSavedEpoch);
config.setLogMessage(stderrLine); config.setLogMessage(stderrLine);
detectAndParseResultsSaved(stderrLine, config, trainId);
parseTrainingProgress(stderrLine, config, trainId, lastSavedEpoch);
} }
} }
@ -138,6 +157,19 @@ public class YoloOperationServiceImpl implements YoloOperationService {
@Override @Override
public YoloConfig validate(String pythonPath, String envName, YoloConfig config) { public YoloConfig validate(String pythonPath, String envName, YoloConfig config) {
// 优先尝试使用socket服务如果启用且可用
if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) {
log.info("使用socket服务执行YOLO验证");
return yoloServiceClient.validateViaSocket(
config.getModelPath(),
config.getValDatasetPath(),
config
);
}
// 回退到原有的ProcessBuilder方式
log.info("使用传统方式执行YOLO验证");
try { try {
PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName); PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName);
if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) { if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) {
@ -205,6 +237,20 @@ public class YoloOperationServiceImpl implements YoloOperationService {
@Override @Override
public YoloConfig predict(String pythonPath, String envName, YoloConfig config) { public YoloConfig predict(String pythonPath, String envName, YoloConfig config) {
// 优先尝试使用socket服务如果启用且可用
if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) {
log.info("使用socket服务执行YOLO预测");
return yoloServiceClient.predictViaSocket(
config.getModelPath(),
config.getInputPath(),
config.getOutputPath(),
config
);
}
// 回退到原有的ProcessBuilder方式
log.info("使用传统方式执行YOLO预测");
try { try {
PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName); PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName);
if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) { if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) {
@ -240,10 +286,22 @@ public class YoloOperationServiceImpl implements YoloOperationService {
if (stdoutLine != null) { if (stdoutLine != null) {
output.append(stdoutLine).append("\n"); output.append(stdoutLine).append("\n");
log.info("YOLO预测日志: {}", stdoutLine); log.info("YOLO预测日志: {}", stdoutLine);
if (stdoutLine.contains("Results saved to")) {
stdoutLine = stdoutLine.replace("\u001B","").replace("[1m","").replace("[0m","");
config.setOutputPath(stdoutLine.split(" ")[3]);
}
} }
if (stderrLine != null) { if (stderrLine != null) {
output.append(stderrLine).append("\n"); output.append(stderrLine).append("\n");
log.info("YOLO预测日志: {}", stderrLine); log.info("YOLO预测日志: {}", stderrLine);
if (stderrLine.contains("Results saved to")) {
stderrLine = stderrLine.replace("\u001B","").replace("[1m","").replace("[0m","");
config.setOutputPath(stderrLine.split(" ")[3]);
}
} }
} }
@ -394,6 +452,8 @@ public class YoloOperationServiceImpl implements YoloOperationService {
if ("conda".equals(envInfo.getVirtualEnvType())) { if ("conda".equals(envInfo.getVirtualEnvType())) {
command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" "); command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" ");
} else { } else {
command.append(normalizePath(envInfo.getVirtualEnvPythonPath())).append(" ");
command.append(envInfo.getVirtualEnvPythonPath()).append(" "); command.append(envInfo.getVirtualEnvPythonPath()).append(" ");
} }
@ -402,6 +462,40 @@ public class YoloOperationServiceImpl implements YoloOperationService {
String outputPath = normalizePath(config.getOutputPath()); String outputPath = normalizePath(config.getOutputPath());
String modelPath = config.getModelPath() != null ? normalizePath(config.getModelPath()) : null; String modelPath = config.getModelPath() != null ? normalizePath(config.getModelPath()) : null;
// 调试日志:检查模型路径
log.info("模型路径配置 - modelPath: {}, modelName: {}", config.getModelPath(), config.getModelName());
// command.append("-c \"");
command.append(datasetPath).append("/train.py ");
// command.append("from ultralytics import YOLO; ");
//
// if (modelPath != null && !modelPath.trim().isEmpty()) {
// command.append("model = YOLO('").append(escapePythonString(modelPath)).append("'); ");
// } else {
// command.append("model = YOLO('").append(config.getModelName()).append("'); ");
// }
//
// // 使用raw string来处理中文路径
// command.append("model.train(data=r'").append(escapePythonString(datasetPath+"\\dataset.yaml")).append("', ");
// command.append("epochs=").append(config.getEpochs()).append(", ");
// command.append("batch=").append(config.getBatchSize()).append(", ");
// command.append("imgsz=").append(config.getImageSize()).append(", ");
//
// // 取消注释设备和学习率配置
//// command.append("device=").append(config.getDevice()).append(", ");
//// command.append("lr0=").append(config.getLearningRate()).append(", ");
//// command.append("workers=").append(config.getWorkers()).append(", ");
////
// if (config.getSavePeriod() > 0) {
// command.append("save_period=").append(config.getSavePeriod()).append(", ");
// }
//
// command.append("project=r'").append(escapePythonString(outputPath)).append("')\"");
//
command.append("-c \""); command.append("-c \"");
@ -472,8 +566,8 @@ public class YoloOperationServiceImpl implements YoloOperationService {
} }
String inputPath = normalizePath(config.getInputPath()); String inputPath = normalizePath(config.getInputPath());
String outputPath = normalizePath(config.getOutputPath());
String modelPath = normalizePath(config.getModelPath()); String modelPath = normalizePath(config.getModelPath());
command.append("-c \""); command.append("-c \"");
@ -481,18 +575,17 @@ public class YoloOperationServiceImpl implements YoloOperationService {
command.append("from ultralytics import YOLO; "); command.append("from ultralytics import YOLO; ");
command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); "); command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); ");
command.append("results = model.predict(source=r'").append(escapePythonString(inputPath)).append("', "); command.append("results = model.predict(source=r'").append(escapePythonString(inputPath)).append("', ");
command.append("conf=").append(config.getConfThresh()).append(", ");
command.append("iou=").append(config.getIouThresh()).append(", "); // 在这里我们将 save=true 改为 save=True
command.append("imgsz=").append(config.getImageSize()).append(", "); command.append("save=").append(config.isSave() ? "True" : "False").append(", ");
command.append("device='").append(config.getDevice()).append("', "); command.append("project='").append(escapePythonString(outputPath)).append("')\"");
command.append("save=").append(config.isSave()).append(", ");
command.append("show=").append(config.isShow()).append(", ");
command.append("save_txt=").append(config.isSaveTxt()).append(", ");
command.append("save_crop=").append(config.isSaveCrops()).append(")\"");
return command.toString(); return command.toString();
} }
private String buildExportCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { private String buildExportCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) {
StringBuilder command = new StringBuilder(); StringBuilder command = new StringBuilder();
@ -534,6 +627,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return command.toString(); return command.toString();
} }
/** /**
* *
*/ */
@ -631,6 +725,8 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 忽略解析错误,继续处理下一行 // 忽略解析错误,继续处理下一行
} }
} }
/** /**
* *
@ -820,6 +916,12 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 不影响训练完成状态 // 不影响训练完成状态
} }
} }
@Resource
@Lazy
TrainService trainService;
/** /**
* TrainResultDO * TrainResultDO
@ -827,8 +929,11 @@ public class YoloOperationServiceImpl implements YoloOperationService {
private void saveTrainResultDO(TrainResultDO trainResult) { private void saveTrainResultDO(TrainResultDO trainResult) {
try { try {
// 创建SaveReqVO // 创建SaveReqVO
cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO saveReqVO =
new cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO(); TrainResultSaveReqVO saveReqVO =
new TrainResultSaveReqVO();
// 设置值 // 设置值
saveReqVO.setTrainId(trainResult.getTrainId()); saveReqVO.setTrainId(trainResult.getTrainId());
@ -838,6 +943,11 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 保存 // 保存
trainResultService.createTrainResult(saveReqVO); trainResultService.createTrainResult(saveReqVO);
trainService.update(new UpdateWrapper<TrainDO>().eq("id", trainResult.getTrainId())
.set("train_type", 4));
} catch (Exception e) { } catch (Exception e) {
log.error("保存训练结果失败: {}", trainResult, e); log.error("保存训练结果失败: {}", trainResult, e);
@ -959,4 +1069,687 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return escaped.toString(); return escaped.toString();
} }
@Resource
private RedisUtil redisUtil;
/**
* "Results saved to"
*
* @param logLine
* @param config YOLO
* @param trainId ID
*/
private void detectAndParseResultsSaved(String logLine, YoloConfig config, Integer trainId) {
try {
if (logLine == null || logLine.trim().isEmpty()) {
return;
}
String trimmedLine = logLine.trim();
// 检测 "Results saved to" 或 "Saved to" 关键字
if (trimmedLine.contains("Results saved to ") || trimmedLine.contains("Saved to ")) {
// 提取保存路径
String savedPath = extractSavedPath(trimmedLine);
savedPath = savedPath.replace("\u001B","").replace("[1m","").replace("[0m","");
if (savedPath != null && !savedPath.trim().isEmpty()) {
log.info("检测到YOLO结果保存路径: {}", savedPath);
// 保存路径到配置中
config.setOutputPath(savedPath);
// 尝试从保存路径中获取识别率信息
parseAccuracyFromSavedPath(savedPath, config, trainId);
// 保存结果路径到数据库
saveResultsPath(trainId, savedPath, config);
}
}
} catch (Exception e) {
log.error("检测和解析Results saved to时出错", e);
}
}
/**
*
*
* @param logLine "Results saved to"
* @return null
*/
private String extractSavedPath(String logLine) {
try {
// 处理不同格式的日志输出
String[] patterns = {
"Results saved to (.*?)\\s*$",
"Saved to (.*?)\\s*$",
"Results saved to (.*?)(?:\\s+\\(\\d+.*?\\))?$", // 带统计信息的格式
"Saved to (.*?)(?:\\s+\\(\\d+.*?\\))?$"
};
for (String pattern : patterns) {
java.util.regex.Pattern p = java.util.regex.Pattern.compile(pattern);
java.util.regex.Matcher m = p.matcher(logLine);
if (m.find()) {
String path = m.group(1).trim();
// 移除可能的引号
if (path.startsWith("\"") && path.endsWith("\"")) {
path = path.substring(1, path.length() - 1);
}
if (path.startsWith("'") && path.endsWith("'")) {
path = path.substring(1, path.length() - 1);
}
return path;
}
}
// 如果正则匹配失败,尝试简单的字符串分割
String[] parts = logLine.split("Results saved to|Saved to");
if (parts.length > 1) {
String path = parts[1].trim();
// 移除可能的引号
if (path.startsWith("\"") && path.endsWith("\"")) {
path = path.substring(1, path.length() - 1);
}
if (path.startsWith("'") && path.endsWith("'")) {
path = path.substring(1, path.length() - 1);
}
return path;
}
} catch (Exception e) {
log.warn("提取保存路径失败: {}", logLine, e);
}
return null;
}
/**
*
* YOLO
*
* @param savedPath YOLO
* @param config YOLO
* @param trainId ID
*/
private void parseAccuracyFromSavedPath(String savedPath, YoloConfig config, Integer trainId) {
try {
log.info("开始从保存路径解析识别率: {}", savedPath);
java.io.File saveDir = new java.io.File(savedPath);
if (!saveDir.exists() || !saveDir.isDirectory()) {
log.warn("保存路径不存在或不是目录: {}", savedPath);
return;
}
// 查找可能包含识别率的文件
java.io.File[] files = saveDir.listFiles();
if (files == null) {
log.warn("无法读取保存目录内容: {}", savedPath);
return;
}
// 优先查找results.csv文件YOLO通常会在训练完成后生成
for (java.io.File file : files) {
if (file.getName().equals("results.csv")) {
parseResultsCsv(file, config, trainId);
return;
}
}
// 如果没有results.csv查找其他可能的文件
for (java.io.File file : files) {
String fileName = file.getName().toLowerCase();
// 查找train_batch或val_batch的图片文件文件名中可能包含性能指标
if (fileName.contains("val_batch") && fileName.endsWith(".jpg")) {
log.info("找到验证批次图片,但无法直接从中提取数值指标");
}
// 查找可能的日志文件
if (fileName.endsWith(".log") || fileName.endsWith(".txt")) {
parseLogFileForAccuracy(file, config, trainId);
}
}
// 如果没有找到结果文件,尝试解析路径中的训练信息
parseTrainingInfoFromPath(savedPath, config);
} catch (Exception e) {
log.error("从保存路径解析识别率时出错: {}", savedPath, e);
}
}
/**
* results.csv
* (detect)(classify)(segment)
*
* @param csvFile results.csv
* @param config YOLO
* @param trainId ID
*/
private void parseResultsCsv(java.io.File csvFile, YoloConfig config, Integer trainId) {
try {
log.info("解析results.csv文件: {} (任务类型: {})", csvFile.getAbsolutePath(), config.getTaskType());
java.util.List<String> lines = java.nio.file.Files.readAllLines(csvFile.toPath());
if (lines.isEmpty()) {
return;
}
// 读取标题行和最后一行(最新的结果)
String headerLine = lines.get(0);
String lastLine = lines.get(lines.size() - 1);
String[] headers = headerLine.split(",");
String[] values = lastLine.split(",");
log.debug("CSV标题行: {}", headerLine);
log.debug("CSV数据行: {}", lastLine);
// 根据任务类型和标题解析不同的指标
String taskType = config.getTaskType();
if (taskType == null) {
taskType = "detect"; // 默认为检测任务
}
double precision = 0.0, recall = 0.0, map = 0.0, accuracy = 0.0, f1Score = 0.0;
try {
if ("classify".equals(taskType) || "classification".equals(taskType)) {
// 分类任务:解析 accuracy, top1_acc, top5_acc, loss 等
for (int i = 0; i < headers.length && i < values.length; i++) {
String header = headers[i].trim().toLowerCase();
String value = values[i].trim();
try {
if (header.contains("accuracy") || header.contains("acc")) {
accuracy = Double.parseDouble(value);
config.setPrecision(accuracy); // 分类任务用accuracy作为precision
log.info("分类任务解析到准确率(Accuracy): {}", accuracy);
} else if (header.contains("top1") || header.contains("top_1")) {
double top1Acc = Double.parseDouble(value);
log.info("分类任务解析到Top-1准确率: {}", top1Acc);
} else if (header.contains("top5") || header.contains("top_5")) {
double top5Acc = Double.parseDouble(value);
log.info("分类任务解析到Top-5准确率: {}", top5Acc);
} else if (header.contains("loss")) {
// 记录损失值
log.debug("分类任务解析到损失值: {}", value);
}
} catch (NumberFormatException e) {
log.debug("跳过无法解析的数值: {}", value);
}
}
// 保存分类任务的准确率到数据库
if (accuracy > 0) {
saveClassificationAccuracyToDatabase(trainId, accuracy);
}
} else if ("detect".equals(taskType) || "segment".equals(taskType)) {
// 检测或分割任务:解析 precision, recall, mAP50, mAP50-95 等
for (int i = 0; i < headers.length && i < values.length; i++) {
String header = headers[i].trim().toLowerCase();
String value = values[i].trim();
try {
if (header.contains("precision")) {
precision = Double.parseDouble(value);
config.setPrecision(precision);
} else if (header.contains("recall")) {
recall = Double.parseDouble(value);
config.setRecall(recall);
} else if (header.contains("map50-95") || header.contains("map50_95")) {
map = Double.parseDouble(value);
config.setMap(map);
} else if (header.contains("map50")) {
double map50 = Double.parseDouble(value);
log.debug("解析到mAP50: {}", map50);
// 如果没有mAP50-95使用mAP50作为map值
if (map == 0.0) {
map = map50;
config.setMap(map);
}
}
} catch (NumberFormatException e) {
log.debug("跳过无法解析的数值: {}", value);
}
}
log.info("检测/分割任务解析到识别率 - Precision: {}, Recall: {}, mAP: {}",
precision, recall, map);
// 保存检测/分割任务的识别率到数据库
if (precision > 0 || recall > 0 || map > 0) {
saveAccuracyToDatabase(trainId, precision, recall, map);
}
} else {
// 其他未知任务类型,尝试通用解析
log.warn("未知任务类型: {}, 尝试通用解析", taskType);
parseGenericResults(headers, values, config);
}
} catch (Exception e) {
log.warn("解析results.csv数值时出错: {}", lastLine, e);
}
} catch (Exception e) {
log.error("读取results.csv文件时出错", e);
}
}
/**
* results.csv
*
* @param headers CSV
* @param values CSV
* @param config YOLO
*/
private void parseGenericResults(String[] headers, String[] values, YoloConfig config) {
try {
double precision = 0.0, recall = 0.0, map = 0.0, accuracy = 0.0;
for (int i = 0; i < headers.length && i < values.length; i++) {
String header = headers[i].trim().toLowerCase();
String value = values[i].trim();
try {
if (header.contains("precision")) {
precision = Double.parseDouble(value);
config.setPrecision(precision);
} else if (header.contains("recall")) {
recall = Double.parseDouble(value);
config.setRecall(recall);
} else if (header.contains("map")) {
map = Double.parseDouble(value);
config.setMap(map);
} else if (header.contains("accuracy") || header.contains("acc")) {
accuracy = Double.parseDouble(value);
config.setPrecision(accuracy); // 用accuracy替代precision
}
} catch (NumberFormatException e) {
log.debug("通用解析跳过无法解析的数值: {}", value);
}
}
log.info("通用解析结果 - Precision/Accuracy: {}, Recall: {}, mAP: {}",
Math.max(precision, accuracy), recall, map);
} catch (Exception e) {
log.error("通用解析失败", e);
}
}
/**
*
*
* @param trainId ID
* @param accuracy
*/
private void saveClassificationAccuracyToDatabase(Integer trainId, double accuracy) {
try {
// 构建分类任务的准确率字符串
String rateStr = String.format("accuracy:%.4f", accuracy);
// 更新数据库中的准确率
cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO trainResult =
cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO.builder()
.trainId(trainId)
.rate(rateStr)
.build();
saveTrainResultDO(trainResult);
log.info("已保存分类任务准确率到数据库 - 训练ID: {}, 准确率: {}", trainId, rateStr);
} catch (Exception e) {
log.error("保存分类任务准确率到数据库失败", e);
}
}
/**
*
*
* @param logFile
* @param config YOLO
* @param trainId ID
*/
private void parseLogFileForAccuracy(java.io.File logFile, YoloConfig config, Integer trainId) {
try {
log.info("尝试从日志文件解析识别率: {}", logFile.getAbsolutePath());
java.util.List<String> lines = java.nio.file.Files.readAllLines(logFile.toPath());
// 反向遍历,查找最后的验证结果
for (int i = lines.size() - 1; i >= 0; i--) {
String line = lines.get(i);
// 查找包含精度、召回率、mAP等指标的行
if (line.contains("mAP50") || line.contains("mAP50-95") ||
line.contains("precision") || line.contains("recall")) {
parseMetricsLine(line, config, trainId);
break;
}
}
} catch (Exception e) {
log.error("读取日志文件时出错", e);
}
}
/**
*
*
* @param line
* @param config YOLO
* @param trainId ID
*/
private void parseMetricsLine(String line, YoloConfig config, Integer trainId) {
try {
log.info("解析指标行: {}", line);
// 使用正则表达式提取数值
java.util.regex.Pattern precisionPattern = java.util.regex.Pattern.compile("precision[:\\s=]+([0-9.]+)");
java.util.regex.Pattern recallPattern = java.util.regex.Pattern.compile("recall[:\\s=]+([0-9.]+)");
java.util.regex.Pattern mapPattern = java.util.regex.Pattern.compile("mAP50-95[:\\s=]+([0-9.]+)");
java.util.regex.Pattern map50Pattern = java.util.regex.Pattern.compile("mAP50[:\\s=]+([0-9.]+)");
java.util.regex.Matcher m;
if ((m = precisionPattern.matcher(line)).find()) {
config.setPrecision(Double.parseDouble(m.group(1)));
}
if ((m = recallPattern.matcher(line)).find()) {
config.setRecall(Double.parseDouble(m.group(1)));
}
if ((m = mapPattern.matcher(line)).find()) {
config.setMap(Double.parseDouble(m.group(1)));
}
if ((m = map50Pattern.matcher(line)).find()) {
// 如果没有mAP50-95使用mAP50作为mAP值
if (config.getMap() == 0.0) {
config.setMap(Double.parseDouble(m.group(1)));
}
}
// 保存识别率到数据库
if (config.getPrecision() > 0 || config.getRecall() > 0 || config.getMap() > 0) {
saveAccuracyToDatabase(trainId, config.getPrecision(), config.getRecall(), config.getMap());
log.info("成功解析识别率 - Precision: {}, Recall: {}, mAP: {}",
config.getPrecision(), config.getRecall(), config.getMap());
}
} catch (Exception e) {
log.error("解析指标行时出错", e);
}
}
/**
*
*
* @param savedPath
* @param config YOLO
*/
private void parseTrainingInfoFromPath(String savedPath, YoloConfig config) {
try {
// 从路径中可能包含的训练信息如train18等
java.io.File pathFile = new java.io.File(savedPath);
String pathName = pathFile.getName();
if (pathName.startsWith("train") && pathName.length() > 5) {
try {
int trainNumber = Integer.parseInt(pathName.substring(5));
log.info("从路径解析到训练编号: {}", trainNumber);
} catch (NumberFormatException e) {
log.debug("无法从路径解析训练编号: {}", pathName);
}
}
} catch (Exception e) {
log.warn("从路径解析训练信息时出错", e);
}
}
/**
*
*
* @param trainId ID
* @param savedPath
* @param config YOLO
*/
private void saveResultsPath(Integer trainId, String savedPath, YoloConfig config) {
try {
// 更新TrainResultDO
TrainResultDO trainResult = TrainResultDO.builder()
.trainId(trainId)
.path(savedPath)
.build();
saveTrainResultDO(trainResult);
TrainDO train = trainService.getTrain(trainId);
train.setOutPath(savedPath);
trainService.updateById(train);
log.info("已保存结果路径到数据库 - 训练ID: {}, 路径: {}", trainId, savedPath);
} catch (Exception e) {
log.error("保存结果路径到数据库失败", e);
}
}
/**
*
*
* @param trainId ID
* @param precision
* @param recall
* @param map mAP
*/
private void saveAccuracyToDatabase(Integer trainId, double precision, double recall, double map) {
try {
// 构建识别率字符串
StringBuilder rateStr = new StringBuilder();
if (precision > 0) {
rateStr.append("precision:").append(String.format("%.4f", precision));
}
if (recall > 0) {
if (rateStr.length() > 0) rateStr.append(",");
rateStr.append("recall:").append(String.format("%.4f", recall));
}
if (map > 0) {
if (rateStr.length() > 0) rateStr.append(",");
rateStr.append("map:").append(String.format("%.4f", map));
}
// 更新数据库中的识别率
TrainResultDO trainResult =
cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO.builder()
.trainId(trainId)
.rate(rateStr.toString())
.build();
saveTrainResultDO(trainResult);
log.info("已保存识别率到数据库 - 训练ID: {}, 识别率: {}", trainId, rateStr.toString());
} catch (Exception e) {
log.error("保存识别率到数据库失败", e);
}
}
/**
* Redis
* %505
*
* @param trainId ID
* @param logLine
*/
private void manageTrainingLogInRedis(Integer trainId, String logLine) {
try {
if (logLine == null || logLine.trim().isEmpty()) {
return;
}
String redisKey = "yolo:training_log:" + trainId;
boolean isProgressLog = logLine.contains("%");
if (isProgressLog) {
// 如果是进度日志(包含%),先删除之前的进度日志
removeProgressLogs(redisKey);
}
// 添加新的日志行到列表末尾
redisUtil.lSet(redisKey, logLine, 5 * 24 * 60 * 60); // 5天过期时间
// 检查列表长度如果超过50条删除第一条
long newSize = redisUtil.lGetListSize(redisKey);
if (newSize > 50) {
redisUtil.lRemove(redisKey, 1, redisUtil.lGetIndex(redisKey, 0));
}
log.debug("训练日志已存储到Redis - 训练ID: {}, 日志类型: {}, 当前日志数量: {}",
trainId, isProgressLog ? "进度日志" : "普通日志", redisUtil.lGetListSize(redisKey));
} catch (Exception e) {
log.error("管理训练日志Redis操作失败 - 训练ID: {}", trainId, e);
}
}
/**
* Redis%
*
* @param redisKey Redis
*/
private void removeProgressLogs(String redisKey) {
try {
long listSize = redisUtil.lGetListSize(redisKey);
if (listSize == 0) {
return;
}
// 获取所有日志
java.util.List<Object> allLogs = redisUtil.lGet(redisKey, 0, -1);
if (allLogs == null || allLogs.isEmpty()) {
return;
}
// 找出并删除最后一条包含%的日志保留100%的)
for (int i = allLogs.size() - 1; i >= 0; i--) {
Object logObj = allLogs.get(i);
if (logObj != null) {
String logStr = logObj.toString();
if (logStr.contains("%")) {
// 检查是否是100%的日志,如果是则跳过
if (logStr.contains("100%")) {
break;
}
redisUtil.lRemove(redisKey, 1, logStr);
log.debug("删除旧的进度日志: {}", logStr);
break; // 只删除最后一条非100%的进度日志
}
}
}
} catch (Exception e) {
log.error("删除进度日志时出错 - Redis键: {}", redisKey, e);
}
}
/**
* Python
*
* @return
*/
public void generateTrainPythonScript(TrainDO train, String yoloDatasetPath) {
try {
StringBuilder script = new StringBuilder();
// 文件头部和导入
script.append("# -*- coding: utf-8 -*-");
script.append("\n");
script.append("import sys");
script.append("\n");
script.append("import matplotlib.pyplot as plt");
script.append("\n");
script.append("from matplotlib import rcParams");
script.append("\n");
script.append("import platform");
script.append("\n");
// 字体设置函数
script.append("# 自动选择字体");
script.append("\n");
script.append("def set_font():");
script.append("\n");
script.append(" system_platform = platform.system()");
script.append("\n");
script.append(" if system_platform == \"Windows\":");
script.append("\n");
script.append(" rcParams['font.sans-serif'] = ['Microsoft YaHei'] # Windows 下使用微软雅黑");
script.append("\n");
script.append(" else:");
script.append("\n");
script.append(" rcParams['font.sans-serif'] = ['DejaVu Sans'] # Linux 下使用 DejaVu Sans");
script.append("\n");
script.append(" rcParams['axes.unicode_minus'] = False # 解决负号显示问题");
script.append("\n");
// 训练函数
script.append("# 训练函数");
script.append("\n");
script.append("def main():");
script.append("\n");
script.append(" set_font()");
script.append("\n");
script.append(" sys.path.insert(0, '.')");
script.append("\n");
script.append(" from ultralytics import YOLO");
script.append("\n");
if (train.getModelPath() != null && !train.getModelPath().trim().isEmpty()) {
script.append(" model = YOLO('").append(escapePythonString(train.getModelPath())).append("')");
} else {
script.append(" model = YOLO('yolo11n.pt')");
}
script.append("\n");
script.append(" ");
script.append("\n");
script.append(" # 训练模型");
script.append("\n");
script.append(" model.train(");
script.append("\n");
script.append(" data=r'").append(escapePythonString(train.getPath() + "/dataset.yaml")).append("',");
script.append("\n");
script.append(" epochs=").append(train.getRound()).append(",");
script.append("\n");
script.append(" batch=").append(train.getSize()).append(",");
script.append("\n");
script.append(" imgsz=").append(train.getImageSize()).append(",");
script.append("\n");
script.append(" project=r'").append(escapePythonString(train.getPath())).append("'");
script.append("\n");
script.append(" )");
script.append("\n");
// 程序入口
script.append("if __name__ == '__main__':");
script.append("\n");
script.append(" main()");
script.append("\n");
// 写入文件
try (java.io.FileWriter writer = new java.io.FileWriter(train.getPath() + "/train.py", java.nio.charset.StandardCharsets.UTF_8)) {
writer.write(script.toString());
}
log.info("训练Python脚本生成成功: {}", train.getPath());
} catch (Exception e) {
log.error("生成训练Python脚本失败", e);
}
}
} }

@ -0,0 +1,219 @@
package cn.iocoder.yudao.module.annotation.service.yolo;
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.io.*;
import java.net.Socket;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
* YOLO - socketPython
*
* @author
*/
@Slf4j
@Component
public class YoloServiceClient {
private static final String DEFAULT_HOST = "localhost";
private static final int DEFAULT_PORT = 9999;
private static final int CONNECTION_TIMEOUT = 30000; // 30秒超时
private final ObjectMapper objectMapper = new ObjectMapper();
/**
* socket
*/
public YoloConfig predictViaSocket(String modelPath, String inputPath, String outputPath, YoloConfig config) {
try {
Map<String, Object> params = new HashMap<>();
params.put("model_path", modelPath);
params.put("input_path", inputPath);
if (outputPath != null) {
params.put("output_path", outputPath);
}
// 添加其他YOLO配置参数
addYoloConfigParams(params, config);
Map<String, Object> request = new HashMap<>();
request.put("command", "predict");
request.put("params", params);
String response = sendRequest(request);
JsonNode responseJson = objectMapper.readTree(response);
if ("success".equals(responseJson.get("status").asText())) {
config.setStatus("completed");
config.setLogMessage(responseJson.get("message").asText());
if (responseJson.has("output_path")) {
config.setOutputPath(responseJson.get("output_path").asText());
}
return config;
} else {
config.setStatus("failed");
config.setErrorMessage(responseJson.get("message").asText());
return config;
}
} catch (Exception e) {
log.error("Socket预测失败", e);
config.setStatus("failed");
config.setErrorMessage("Socket预测失败: " + e.getMessage());
return config;
}
}
/**
* socket
*/
public YoloConfig trainViaSocket(String modelPath, String dataPath, YoloConfig config) {
try {
Map<String, Object> params = new HashMap<>();
params.put("model_path", modelPath);
params.put("data_path", dataPath);
params.put("epochs", config.getEpochs());
// 添加其他训练配置参数
addYoloConfigParams(params, config);
Map<String, Object> request = new HashMap<>();
request.put("command", "train");
request.put("params", params);
String response = sendRequest(request);
JsonNode responseJson = objectMapper.readTree(response);
if ("success".equals(responseJson.get("status").asText())) {
config.setStatus("completed");
config.setLogMessage(responseJson.get("message").asText());
if (responseJson.has("results_path")) {
config.setOutputPath(responseJson.get("results_path").asText());
}
return config;
} else {
config.setStatus("failed");
config.setErrorMessage(responseJson.get("message").asText());
return config;
}
} catch (Exception e) {
log.error("Socket训练失败", e);
config.setStatus("failed");
config.setErrorMessage("Socket训练失败: " + e.getMessage());
return config;
}
}
/**
* socket
*/
public YoloConfig validateViaSocket(String modelPath, String dataPath, YoloConfig config) {
try {
Map<String, Object> params = new HashMap<>();
params.put("model_path", modelPath);
if (dataPath != null) {
params.put("data_path", dataPath);
}
// 添加其他验证配置参数
addYoloConfigParams(params, config);
Map<String, Object> request = new HashMap<>();
request.put("command", "validate");
request.put("params", params);
String response = sendRequest(request);
JsonNode responseJson = objectMapper.readTree(response);
if ("success".equals(responseJson.get("status").asText())) {
config.setStatus("completed");
config.setLogMessage(responseJson.get("message").asText());
if (responseJson.has("metrics")) {
config.setMetrics(responseJson.get("metrics").asText());
}
return config;
} else {
config.setStatus("failed");
config.setErrorMessage(responseJson.get("message").asText());
return config;
}
} catch (Exception e) {
log.error("Socket验证失败", e);
config.setStatus("failed");
config.setErrorMessage("Socket验证失败: " + e.getMessage());
return config;
}
}
/**
* YOLO
*/
private void addYoloConfigParams(Map<String, Object> params, YoloConfig config) {
if (config.getImgsz() > 0) {
params.put("imgsz", config.getImgsz());
}
if (config.getConfThresh() > 0) {
params.put("conf", config.getConfThresh());
}
if (config.getIouThresh() > 0) {
params.put("iou", config.getIouThresh());
}
if (config.getDevice() != null) {
params.put("device", config.getDevice());
}
if (config.getBatchSize() > 0) {
params.put("batch", config.getBatchSize());
}
params.put("save", config.isSave());
params.put("save_txt", config.isSaveTxt());
params.put("save_crops", config.isSaveCrops());
}
/**
* socket
*/
private String sendRequest(Map<String, Object> request) throws IOException {
try (Socket socket = new Socket(DEFAULT_HOST, DEFAULT_PORT);
PrintWriter out = new PrintWriter(socket.getOutputStream(), true);
BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()))) {
// 设置超时
socket.setSoTimeout(CONNECTION_TIMEOUT);
// 发送请求
String requestJson = objectMapper.writeValueAsString(request);
out.println(requestJson);
// 读取响应
StringBuilder response = new StringBuilder();
String line;
while ((line = in.readLine()) != null) {
response.append(line);
}
return response.toString();
}
}
/**
*
*/
public boolean isServiceAvailable() {
try {
try (Socket socket = new Socket()) {
socket.connect(new java.net.InetSocketAddress(DEFAULT_HOST, DEFAULT_PORT), 3000);
return true;
}
} catch (IOException e) {
return false;
}
}
}

@ -693,6 +693,17 @@ public class RedisUtil {
return false; return false;
} }
} }
public boolean lPush(String key, Object value, long time) {
try {
redisTemplate.opsForList().rightPush(key, value);
if (time > 0)
expire(key, time);
return true;
} catch (Exception e) {
log.error(key, e);
return false;
}
}
/** /**
* list * list

Loading…
Cancel
Save