parent
acdec70da5
commit
fa0c137f0d
@ -0,0 +1,106 @@
|
||||
task: detect
|
||||
mode: train
|
||||
model: yolov8n.pt
|
||||
data: 'null'
|
||||
epochs: 50
|
||||
time: null
|
||||
patience: 100
|
||||
batch: 4
|
||||
imgsz: 640
|
||||
save: true
|
||||
save_period: -1
|
||||
cache: false
|
||||
device: cpu
|
||||
workers: 8
|
||||
project: 'null'
|
||||
name: train
|
||||
exist_ok: false
|
||||
pretrained: true
|
||||
optimizer: auto
|
||||
verbose: true
|
||||
seed: 0
|
||||
deterministic: true
|
||||
single_cls: false
|
||||
rect: false
|
||||
cos_lr: false
|
||||
close_mosaic: 10
|
||||
resume: false
|
||||
amp: true
|
||||
fraction: 1.0
|
||||
profile: false
|
||||
freeze: null
|
||||
multi_scale: false
|
||||
compile: false
|
||||
overlap_mask: true
|
||||
mask_ratio: 4
|
||||
dropout: 0.0
|
||||
val: true
|
||||
split: val
|
||||
save_json: false
|
||||
conf: null
|
||||
iou: 0.7
|
||||
max_det: 300
|
||||
half: false
|
||||
dnn: false
|
||||
plots: true
|
||||
source: null
|
||||
vid_stride: 1
|
||||
stream_buffer: false
|
||||
visualize: false
|
||||
augment: false
|
||||
agnostic_nms: false
|
||||
classes: null
|
||||
retina_masks: false
|
||||
embed: null
|
||||
show: false
|
||||
save_frames: false
|
||||
save_txt: false
|
||||
save_conf: false
|
||||
save_crop: false
|
||||
show_labels: true
|
||||
show_conf: true
|
||||
show_boxes: true
|
||||
line_width: null
|
||||
format: torchscript
|
||||
keras: false
|
||||
optimize: false
|
||||
int8: false
|
||||
dynamic: false
|
||||
simplify: true
|
||||
opset: null
|
||||
workspace: null
|
||||
nms: false
|
||||
lr0: 0.01
|
||||
lrf: 0.01
|
||||
momentum: 0.937
|
||||
weight_decay: 0.0005
|
||||
warmup_epochs: 3.0
|
||||
warmup_momentum: 0.8
|
||||
warmup_bias_lr: 0.1
|
||||
box: 7.5
|
||||
cls: 0.5
|
||||
dfl: 1.5
|
||||
pose: 12.0
|
||||
kobj: 1.0
|
||||
nbs: 64
|
||||
hsv_h: 0.015
|
||||
hsv_s: 0.7
|
||||
hsv_v: 0.4
|
||||
degrees: 0.0
|
||||
translate: 0.1
|
||||
scale: 0.5
|
||||
shear: 0.0
|
||||
perspective: 0.0
|
||||
flipud: 0.0
|
||||
fliplr: 0.5
|
||||
bgr: 0.0
|
||||
mosaic: 1.0
|
||||
mixup: 0.0
|
||||
cutmix: 0.0
|
||||
copy_paste: 0.0
|
||||
copy_paste_mode: flip
|
||||
auto_augment: randaugment
|
||||
erasing: 0.4
|
||||
cfg: null
|
||||
tracker: botsort.yaml
|
||||
save_dir: D:\git\lpNewgit\null\train
|
||||
@ -0,0 +1,106 @@
|
||||
task: detect
|
||||
mode: train
|
||||
model: yolov8n.pt
|
||||
data: 'null'
|
||||
epochs: 50
|
||||
time: null
|
||||
patience: 100
|
||||
batch: 4
|
||||
imgsz: 640
|
||||
save: true
|
||||
save_period: -1
|
||||
cache: false
|
||||
device: cpu
|
||||
workers: 8
|
||||
project: 'null'
|
||||
name: train2
|
||||
exist_ok: false
|
||||
pretrained: true
|
||||
optimizer: auto
|
||||
verbose: true
|
||||
seed: 0
|
||||
deterministic: true
|
||||
single_cls: false
|
||||
rect: false
|
||||
cos_lr: false
|
||||
close_mosaic: 10
|
||||
resume: false
|
||||
amp: true
|
||||
fraction: 1.0
|
||||
profile: false
|
||||
freeze: null
|
||||
multi_scale: false
|
||||
compile: false
|
||||
overlap_mask: true
|
||||
mask_ratio: 4
|
||||
dropout: 0.0
|
||||
val: true
|
||||
split: val
|
||||
save_json: false
|
||||
conf: null
|
||||
iou: 0.7
|
||||
max_det: 300
|
||||
half: false
|
||||
dnn: false
|
||||
plots: true
|
||||
source: null
|
||||
vid_stride: 1
|
||||
stream_buffer: false
|
||||
visualize: false
|
||||
augment: false
|
||||
agnostic_nms: false
|
||||
classes: null
|
||||
retina_masks: false
|
||||
embed: null
|
||||
show: false
|
||||
save_frames: false
|
||||
save_txt: false
|
||||
save_conf: false
|
||||
save_crop: false
|
||||
show_labels: true
|
||||
show_conf: true
|
||||
show_boxes: true
|
||||
line_width: null
|
||||
format: torchscript
|
||||
keras: false
|
||||
optimize: false
|
||||
int8: false
|
||||
dynamic: false
|
||||
simplify: true
|
||||
opset: null
|
||||
workspace: null
|
||||
nms: false
|
||||
lr0: 0.01
|
||||
lrf: 0.01
|
||||
momentum: 0.937
|
||||
weight_decay: 0.0005
|
||||
warmup_epochs: 3.0
|
||||
warmup_momentum: 0.8
|
||||
warmup_bias_lr: 0.1
|
||||
box: 7.5
|
||||
cls: 0.5
|
||||
dfl: 1.5
|
||||
pose: 12.0
|
||||
kobj: 1.0
|
||||
nbs: 64
|
||||
hsv_h: 0.015
|
||||
hsv_s: 0.7
|
||||
hsv_v: 0.4
|
||||
degrees: 0.0
|
||||
translate: 0.1
|
||||
scale: 0.5
|
||||
shear: 0.0
|
||||
perspective: 0.0
|
||||
flipud: 0.0
|
||||
fliplr: 0.5
|
||||
bgr: 0.0
|
||||
mosaic: 1.0
|
||||
mixup: 0.0
|
||||
cutmix: 0.0
|
||||
copy_paste: 0.0
|
||||
copy_paste_mode: flip
|
||||
auto_augment: randaugment
|
||||
erasing: 0.4
|
||||
cfg: null
|
||||
tracker: botsort.yaml
|
||||
save_dir: D:\git\lpNewgit\null\train2
|
||||
@ -0,0 +1,28 @@
|
||||
@echo off
|
||||
echo Starting YOLO Service...
|
||||
cd /d %~dp0
|
||||
|
||||
REM 检查Python环境
|
||||
python --version
|
||||
if %errorlevel% neq 0 (
|
||||
echo Python not found, please install Python first
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
REM 安装依赖(如果需要)
|
||||
echo Installing dependencies...
|
||||
pip install ultralytics
|
||||
|
||||
REM 预加载模型路径(可选)
|
||||
set MODEL_PATH=%~dp0yolo11n.pt
|
||||
|
||||
REM 启动YOLO服务
|
||||
echo Starting YOLO service on localhost:9999
|
||||
if exist "%MODEL_PATH%" (
|
||||
python yolo_service.py --host localhost --port 9999 --model "%MODEL_PATH%"
|
||||
) else (
|
||||
python yolo_service.py --host localhost --port 9999
|
||||
)
|
||||
|
||||
pause
|
||||
@ -0,0 +1,162 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
YOLO服务管理器
|
||||
用于启动、停止、重启YOLO socket服务
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import signal
|
||||
import argparse
|
||||
import psutil
|
||||
from pathlib import Path
|
||||
|
||||
class YoloServiceManager:
|
||||
def __init__(self, host='localhost', port=9999, model_path=None):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.model_path = model_path
|
||||
self.pid_file = Path('yolo_service.pid')
|
||||
|
||||
def start(self):
|
||||
"""启动YOLO服务"""
|
||||
if self.is_running():
|
||||
print(f"YOLO服务已在运行 (PID: {self.get_pid()})")
|
||||
return
|
||||
|
||||
print(f"启动YOLO服务 {self.host}:{self.port}")
|
||||
|
||||
cmd = [
|
||||
sys.executable, 'yolo_service.py',
|
||||
'--host', self.host,
|
||||
'--port', str(self.port)
|
||||
]
|
||||
|
||||
if self.model_path and Path(self.model_path).exists():
|
||||
cmd.extend(['--model', self.model_path])
|
||||
|
||||
try:
|
||||
# 启动进程
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=os.getcwd()
|
||||
)
|
||||
|
||||
# 保存PID
|
||||
with open(self.pid_file, 'w') as f:
|
||||
f.write(str(process.pid))
|
||||
|
||||
print(f"YOLO服务已启动 (PID: {process.pid})")
|
||||
|
||||
# 等待服务启动
|
||||
time.sleep(2)
|
||||
if self.is_service_alive():
|
||||
print("服务启动成功")
|
||||
else:
|
||||
print("警告: 服务可能未正常启动,请检查日志")
|
||||
|
||||
except Exception as e:
|
||||
print(f"启动服务失败: {e}")
|
||||
if self.pid_file.exists():
|
||||
self.pid_file.unlink()
|
||||
|
||||
def stop(self):
|
||||
"""停止YOLO服务"""
|
||||
if not self.is_running():
|
||||
print("YOLO服务未运行")
|
||||
return
|
||||
|
||||
pid = self.get_pid()
|
||||
try:
|
||||
# 尝试优雅停止
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
time.sleep(2)
|
||||
|
||||
# 如果还在运行,强制停止
|
||||
if psutil.pid_exists(pid):
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
time.sleep(1)
|
||||
|
||||
print(f"YOLO服务已停止 (PID: {pid})")
|
||||
|
||||
except ProcessLookupError:
|
||||
print(f"进程 {pid} 不存在")
|
||||
except Exception as e:
|
||||
print(f"停止服务失败: {e}")
|
||||
finally:
|
||||
if self.pid_file.exists():
|
||||
self.pid_file.unlink()
|
||||
|
||||
def restart(self):
|
||||
"""重启YOLO服务"""
|
||||
self.stop()
|
||||
time.sleep(1)
|
||||
self.start()
|
||||
|
||||
def status(self):
|
||||
"""查看服务状态"""
|
||||
if self.is_running():
|
||||
pid = self.get_pid()
|
||||
if self.is_service_alive():
|
||||
print(f"YOLO服务运行中 (PID: {pid})")
|
||||
return True
|
||||
else:
|
||||
print(f"YOLO服务进程存在但不可用 (PID: {pid})")
|
||||
return False
|
||||
else:
|
||||
print("YOLO服务未运行")
|
||||
return False
|
||||
|
||||
def is_running(self):
|
||||
"""检查服务是否在运行"""
|
||||
return self.pid_file.exists() and self.get_pid() > 0
|
||||
|
||||
def get_pid(self):
|
||||
"""获取服务PID"""
|
||||
if self.pid_file.exists():
|
||||
try:
|
||||
with open(self.pid_file, 'r') as f:
|
||||
return int(f.read().strip())
|
||||
except:
|
||||
return 0
|
||||
return 0
|
||||
|
||||
def is_service_alive(self):
|
||||
"""检查服务是否可用"""
|
||||
try:
|
||||
import socket
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(3)
|
||||
result = sock.connect_ex((self.host, self.port))
|
||||
sock.close()
|
||||
return result == 0
|
||||
except:
|
||||
return False
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='YOLO服务管理器')
|
||||
parser.add_argument('command', choices=['start', 'stop', 'restart', 'status'],
|
||||
help='命令: start, stop, restart, status')
|
||||
parser.add_argument('--host', default='localhost', help='服务主机地址')
|
||||
parser.add_argument('--port', type=int, default=9999, help='服务端口')
|
||||
parser.add_argument('--model', help='预加载模型路径')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
manager = YoloServiceManager(args.host, args.port, args.model)
|
||||
|
||||
if args.command == 'start':
|
||||
manager.start()
|
||||
elif args.command == 'stop':
|
||||
manager.stop()
|
||||
elif args.command == 'restart':
|
||||
manager.restart()
|
||||
elif args.command == 'status':
|
||||
manager.status()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Binary file not shown.
@ -0,0 +1,153 @@
|
||||
import json
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from ultralytics import YOLO
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class YOLOService:
|
||||
def __init__(self):
|
||||
self.models = {} # 缓存已加载的模型
|
||||
self.default_model = None
|
||||
|
||||
def load_model(self, model_path):
|
||||
"""加载并缓存模型"""
|
||||
if model_path not in self.models:
|
||||
logger.info(f"Loading model: {model_path}")
|
||||
self.models[model_path] = YOLO(model_path)
|
||||
|
||||
return self.models[model_path]
|
||||
|
||||
def predict(self, model_path, input_path, output_path=None, **kwargs):
|
||||
"""执行预测"""
|
||||
try:
|
||||
model = self.load_model(model_path)
|
||||
results = model.predict(input_path, **kwargs)
|
||||
|
||||
if output_path:
|
||||
# 处理保存结果
|
||||
for i, result in enumerate(results):
|
||||
save_path = f"{output_path}/{i}.jpg" if output_path.endswith('/') else f"{output_path}/{i}.jpg"
|
||||
result.save(save_path)
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'message': f'Prediction completed. Results saved to {output_path}' if output_path else 'Prediction completed',
|
||||
'output_path': output_path
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
'status': 'error',
|
||||
'message': str(e)
|
||||
}
|
||||
|
||||
def train(self, model_path, data_path, epochs=50, **kwargs):
|
||||
"""执行训练"""
|
||||
try:
|
||||
model = self.load_model(model_path)
|
||||
results = model.train(data=data_path, epochs=epochs, **kwargs)
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'message': 'Training completed',
|
||||
'results_path': results.save_dir
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
'status': 'error',
|
||||
'message': str(e)
|
||||
}
|
||||
|
||||
def validate(self, model_path, data_path=None, **kwargs):
|
||||
"""执行验证"""
|
||||
try:
|
||||
model = self.load_model(model_path)
|
||||
results = model.validate(data=data_path, **kwargs)
|
||||
|
||||
return {
|
||||
'status': 'success',
|
||||
'message': 'Validation completed',
|
||||
'metrics': str(results)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
'status': 'error',
|
||||
'message': str(e)
|
||||
}
|
||||
|
||||
def handle_client(client_socket, yolo_service):
|
||||
"""处理客户端请求"""
|
||||
try:
|
||||
data = client_socket.recv(4096).decode('utf-8')
|
||||
if not data:
|
||||
return
|
||||
|
||||
request = json.loads(data)
|
||||
command = request.get('command')
|
||||
|
||||
if command == 'predict':
|
||||
response = yolo_service.predict(**request.get('params', {}))
|
||||
elif command == 'train':
|
||||
response = yolo_service.train(**request.get('params', {}))
|
||||
elif command == 'validate':
|
||||
response = yolo_service.validate(**request.get('params', {}))
|
||||
else:
|
||||
response = {'status': 'error', 'message': f'Unknown command: {command}'}
|
||||
|
||||
client_socket.send(json.dumps(response).encode('utf-8'))
|
||||
|
||||
except Exception as e:
|
||||
error_response = {'status': 'error', 'message': str(e)}
|
||||
client_socket.send(json.dumps(error_response).encode('utf-8'))
|
||||
finally:
|
||||
client_socket.close()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='YOLO Service')
|
||||
parser.add_argument('--host', default='localhost', help='Host to bind to')
|
||||
parser.add_argument('--port', type=int, default=9999, help='Port to bind to')
|
||||
parser.add_argument('--model', help='Default model to preload')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
yolo_service = YOLOService()
|
||||
|
||||
# 预加载默认模型
|
||||
if args.model and Path(args.model).exists():
|
||||
yolo_service.load_model(args.model)
|
||||
logger.info(f"Preloaded model: {args.model}")
|
||||
|
||||
# 启动socket服务
|
||||
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
server.bind((args.host, args.port))
|
||||
server.listen(5)
|
||||
|
||||
logger.info(f"YOLO Service started on {args.host}:{args.port}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
client_socket, addr = server.accept()
|
||||
logger.info(f"Connection from {addr}")
|
||||
|
||||
# 为每个客户端创建新线程
|
||||
client_thread = threading.Thread(
|
||||
target=handle_client,
|
||||
args=(client_socket, yolo_service)
|
||||
)
|
||||
client_thread.daemon = True
|
||||
client_thread.start()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutting down server...")
|
||||
finally:
|
||||
server.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Binary file not shown.
@ -0,0 +1,85 @@
|
||||
# YOLO 训练结果解析功能说明
|
||||
|
||||
## 功能概述
|
||||
|
||||
在 `YoloOperationServiceImpl.java` 中新增了自动检测和解析 YOLO 训练结果的功能,当训练日志中出现 "Results saved to" 时,系统会:
|
||||
|
||||
1. **自动提取保存路径**:从日志中解析出类似 `D:\data\train\1231_5\runs\train\train18` 的路径
|
||||
2. **解析识别率信息**:从保存路径中的 `results.csv` 文件或其他日志文件中提取精度、召回率、mAP 等指标
|
||||
3. **保存到数据库**:将路径和识别率信息保存到 `TrainResultDO` 表中
|
||||
|
||||
## 新增方法
|
||||
|
||||
### 1. `detectAndParseResultsSaved(String logLine, YoloConfig config, Integer trainId)`
|
||||
- 检测日志中是否包含 "Results saved to" 或 "Saved to" 关键字
|
||||
- 调用后续方法进行路径提取和识别率解析
|
||||
|
||||
### 2. `extractSavedPath(String logLine)`
|
||||
- 使用正则表达式从日志行中提取保存路径
|
||||
- 支持多种格式,如带引号、带统计信息等
|
||||
|
||||
### 3. `parseAccuracyFromSavedPath(String savedPath, YoloConfig config, Integer trainId)`
|
||||
- 查找保存目录中的 `results.csv` 文件
|
||||
- 解析 CSV 文件中的最后一行(最新结果)
|
||||
- 提取 precision、recall、mAP50、mAP50-95 等指标
|
||||
|
||||
### 4. `parseResultsCsv(File csvFile, YoloConfig config, Integer trainId)`
|
||||
- 解析 YOLO 生成的 `results.csv` 文件
|
||||
- CSV 格式通常为:`epoch, train/box_loss, train/cls_loss, train/dfl_loss, metrics/precision, metrics/recall, metrics/mAP50, metrics/mAP50-95`
|
||||
|
||||
### 5. `saveAccuracyToDatabase(Integer trainId, double precision, double recall, double map)`
|
||||
- 将解析到的识别率保存到数据库
|
||||
- 格式:`precision:0.8521,recall:0.7432,map:0.8123`
|
||||
|
||||
## 集成位置
|
||||
|
||||
在 `trainAsync` 方法中,已经集成到日志读取循环中:
|
||||
|
||||
```java
|
||||
if (stdoutLine != null) {
|
||||
// ... 现有代码 ...
|
||||
|
||||
// 检测 "Results saved to" 并解析路径和识别率
|
||||
detectAndParseResultsSaved(stdoutLine, config, trainId);
|
||||
}
|
||||
|
||||
if (stderrLine != null) {
|
||||
// ... 现有代码 ...
|
||||
|
||||
// 检测 "Results saved to" 并解析路径和识别率
|
||||
detectAndParseResultsSaved(stderrLine, config, trainId);
|
||||
}
|
||||
```
|
||||
|
||||
## 预期输出格式
|
||||
|
||||
当 YOLO 训练完成时,日志中会出现类似内容:
|
||||
```
|
||||
Results saved to D:\data\train\1231_5\runs\train\train18
|
||||
```
|
||||
|
||||
系统会自动:
|
||||
1. 提取路径:`D:\data\train\1231_5\runs\train\train18`
|
||||
2. 查找并解析该目录下的 `results.csv` 文件
|
||||
3. 保存路径和识别率到数据库
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **文件权限**:确保 Java 进程有权限访问 YOLO 保存的文件
|
||||
2. **CSV 格式**:依赖 YOLO 标准的 `results.csv` 输出格式
|
||||
3. **异步处理**:识别率解析不会阻塞训练过程
|
||||
4. **错误处理**:解析失败不会影响训练完成状态
|
||||
|
||||
## 数据库字段
|
||||
|
||||
保存的识别率格式:
|
||||
```sql
|
||||
-- TrainResultDO 表的 rate 字段
|
||||
precision:0.8521,recall:0.7432,map:0.8123
|
||||
```
|
||||
|
||||
路径字段:
|
||||
```sql
|
||||
-- TrainResultDO 表的 path 字段
|
||||
D:\data\train\1231_5\runs\train\train18
|
||||
```
|
||||
@ -0,0 +1,9 @@
|
||||
package cn.iocoder.yudao.module.annotation.controller.admin.datas.vo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DatasEnhance {
|
||||
private Long id;
|
||||
private String[] enhancements;
|
||||
}
|
||||
@ -0,0 +1,9 @@
|
||||
package cn.iocoder.yudao.module.annotation.controller.admin.train.vo;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class RecognitionResult {
|
||||
private String imageUrl;
|
||||
private String trainId;
|
||||
}
|
||||
@ -0,0 +1,108 @@
|
||||
package cn.iocoder.yudao.module.annotation.controller.admin.train.vo;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Schema(description = "系统信息 Response VO")
|
||||
@Data
|
||||
public class SystemInfoRespVO {
|
||||
|
||||
@Schema(description = "CPU信息")
|
||||
private CpuInfo cpu;
|
||||
|
||||
@Schema(description = "GPU信息列表")
|
||||
private List<GpuInfo> gpus;
|
||||
|
||||
@Schema(description = "Python信息")
|
||||
private PythonInfo python;
|
||||
|
||||
@Data
|
||||
@Schema(description = "CPU信息")
|
||||
public static class CpuInfo {
|
||||
@Schema(description = "CPU型号")
|
||||
private String model;
|
||||
|
||||
@Schema(description = "CPU核心数")
|
||||
private Integer cores;
|
||||
|
||||
@Schema(description = "CPU使用率")
|
||||
private Double usage;
|
||||
}
|
||||
|
||||
@Data
|
||||
@Schema(description = "GPU信息")
|
||||
public static class GpuInfo {
|
||||
@Schema(description = "GPU ID")
|
||||
private Integer id;
|
||||
|
||||
@Schema(description = "GPU名称")
|
||||
private String name;
|
||||
|
||||
@Schema(description = "GPU内存总量(MB)")
|
||||
private Long memoryTotal;
|
||||
|
||||
@Schema(description = "GPU已用内存(MB)")
|
||||
private Long memoryUsed;
|
||||
|
||||
@Schema(description = "GPU可用内存(MB)")
|
||||
private Long memoryFree;
|
||||
|
||||
@Schema(description = "GPU使用率")
|
||||
private Double usage;
|
||||
|
||||
@Schema(description = "GPU温度")
|
||||
private Double temperature;
|
||||
}
|
||||
|
||||
@Data
|
||||
@Schema(description = "Python信息")
|
||||
public static class PythonInfo {
|
||||
@Schema(description = "是否安装Python")
|
||||
private Boolean installed;
|
||||
|
||||
@Schema(description = "Python版本")
|
||||
private String version;
|
||||
|
||||
@Schema(description = "Python路径")
|
||||
private String path;
|
||||
|
||||
@Schema(description = "是否支持venv虚拟环境")
|
||||
private Boolean venvSupported;
|
||||
|
||||
@Schema(description = "是否安装virtualenv")
|
||||
private Boolean virtualenvInstalled;
|
||||
|
||||
@Schema(description = "是否安装conda")
|
||||
private Boolean condaInstalled;
|
||||
|
||||
@Schema(description = "yolo虚拟环境是否存在")
|
||||
private Boolean yoloEnvExists;
|
||||
|
||||
@Schema(description = "yolo虚拟环境中是否安装YOLO")
|
||||
private Boolean yoloInstalled;
|
||||
|
||||
@Schema(description = "YOLO版本")
|
||||
private String yoloVersion;
|
||||
|
||||
@Schema(description = "Python虚拟环境列表")
|
||||
private List<VirtualEnvInfo> virtualEnvs;
|
||||
}
|
||||
|
||||
@Data
|
||||
@Schema(description = "虚拟环境信息")
|
||||
public static class VirtualEnvInfo {
|
||||
@Schema(description = "虚拟环境名称")
|
||||
private String name;
|
||||
|
||||
@Schema(description = "虚拟环境路径")
|
||||
private String path;
|
||||
|
||||
@Schema(description = "虚拟环境类型 (venv/virtualenv/conda)")
|
||||
private String type;
|
||||
|
||||
@Schema(description = "Python版本")
|
||||
private String pythonVersion;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,105 @@
|
||||
package cn.iocoder.yudao.module.annotation.controller.admin.traininfo;
|
||||
|
||||
import cn.iocoder.yudao.framework.apilog.core.annotation.ApiAccessLog;
|
||||
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils;
|
||||
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoPageReqVO;
|
||||
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoRespVO;
|
||||
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO;
|
||||
import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO;
|
||||
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
|
||||
import cn.iocoder.yudao.module.system.util.RedisUtil;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import jakarta.annotation.Resource;
|
||||
import jakarta.servlet.http.HttpServletResponse;
|
||||
import jakarta.validation.Valid;
|
||||
import org.springframework.security.access.prepost.PreAuthorize;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.EXPORT;
|
||||
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||
|
||||
@Tag(name = "管理后台 - 识别结果")
|
||||
@RestController
|
||||
@RequestMapping("/annotation/train-Info")
|
||||
@Validated
|
||||
public class TrainInfoController {
|
||||
|
||||
@Resource
|
||||
private TrainInfoService trainInfoService;
|
||||
|
||||
@PostMapping("/create")
|
||||
@Operation(summary = "创建识别结果")
|
||||
@PreAuthorize("@ss.hasPermission('annotation:train-Info:create')")
|
||||
public CommonResult<Integer> createTrainInfo(@Valid @RequestBody TrainInfoSaveReqVO createReqVO) {
|
||||
return success(trainInfoService.createTrainInfo(createReqVO));
|
||||
}
|
||||
|
||||
@PutMapping("/update")
|
||||
@Operation(summary = "更新识别结果")
|
||||
@PreAuthorize("@ss.hasPermission('annotation:train-Info:update')")
|
||||
public CommonResult<Boolean> updateTrainInfo(@Valid @RequestBody TrainInfoSaveReqVO updateReqVO) {
|
||||
trainInfoService.updateTrainInfo(updateReqVO);
|
||||
return success(true);
|
||||
}
|
||||
|
||||
@DeleteMapping("/delete")
|
||||
@Operation(summary = "删除识别结果")
|
||||
@Parameter(name = "id", description = "编号", required = true)
|
||||
@PreAuthorize("@ss.hasPermission('annotation:train-Info:delete')")
|
||||
public CommonResult<Boolean> deleteTrainInfo(@RequestParam("id") Integer id) {
|
||||
trainInfoService.deleteTrainInfo(id);
|
||||
return success(true);
|
||||
}
|
||||
|
||||
@GetMapping("/get")
|
||||
@Operation(summary = "获得识别结果")
|
||||
@Parameter(name = "id", description = "编号", required = true, example = "1024")
|
||||
@PreAuthorize("@ss.hasPermission('annotation:train-Info:query')")
|
||||
public CommonResult<TrainInfoRespVO> getTrainInfo(@RequestParam("id") Integer id) {
|
||||
TrainInfoDO trainInfo = trainInfoService.getTrainInfo(id);
|
||||
return success(BeanUtils.toBean(trainInfo, TrainInfoRespVO.class));
|
||||
}
|
||||
@Resource
|
||||
RedisUtil redisUtil;
|
||||
|
||||
@GetMapping("/get-list")
|
||||
@Operation(summary = "获得识别结果")
|
||||
@Parameter(name = "id", description = "编号", required = true, example = "1024")
|
||||
@PreAuthorize("@ss.hasPermission('annotation:train-Info:query')")
|
||||
public CommonResult<List<String>> getTrainInfoList(@RequestParam("trainId") Integer id) {
|
||||
List<Object> result = redisUtil.lGet("yolo:training_log:" + id, 0, -1);
|
||||
List<String> result1 = result.stream().map(Object::toString).toList();
|
||||
return success(result1);
|
||||
}
|
||||
@GetMapping("/page")
|
||||
@Operation(summary = "获得识别结果分页")
|
||||
@PreAuthorize("@ss.hasPermission('annotation:train-Info:query')")
|
||||
public CommonResult<PageResult<TrainInfoRespVO>> getTrainInfoPage(@Valid TrainInfoPageReqVO pageReqVO) {
|
||||
PageResult<TrainInfoDO> pageInfo = trainInfoService.getTrainInfoPage(pageReqVO);
|
||||
return success(BeanUtils.toBean(pageInfo, TrainInfoRespVO.class));
|
||||
}
|
||||
|
||||
@GetMapping("/export-excel")
|
||||
@Operation(summary = "导出识别结果 Excel")
|
||||
@PreAuthorize("@ss.hasPermission('annotation:train-Info:export')")
|
||||
@ApiAccessLog(operateType = EXPORT)
|
||||
public void exportTrainInfoExcel(@Valid TrainInfoPageReqVO pageReqVO,
|
||||
HttpServletResponse response) throws IOException {
|
||||
pageReqVO.setPageSize(PageParam.PAGE_SIZE_NONE);
|
||||
List<TrainInfoDO> list = trainInfoService.getTrainInfoPage(pageReqVO).getList();
|
||||
// 导出 Excel
|
||||
ExcelUtils.write(response, "识别结果.xls", "数据", TrainInfoRespVO.class,
|
||||
BeanUtils.toBean(list, TrainInfoRespVO.class));
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,34 @@
|
||||
package cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo;
|
||||
|
||||
import lombok.*;
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||
import org.springframework.format.annotation.DateTimeFormat;
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND;
|
||||
|
||||
@Schema(description = "管理后台 - 训练信息分页 Request VO")
|
||||
@Data
|
||||
public class TrainInfoPageReqVO extends PageParam {
|
||||
|
||||
@Schema(description = "训练id", example = "27530")
|
||||
private Integer trainId;
|
||||
|
||||
@Schema(description = "识别批次", example = "1")
|
||||
private Integer round;
|
||||
|
||||
@Schema(description = "识别的总批次", example = "100")
|
||||
private Integer roundTotal;
|
||||
|
||||
@Schema(description = "训练信息")
|
||||
private String info;
|
||||
|
||||
@Schema(description = "识别结果rate")
|
||||
private String rate;
|
||||
|
||||
@Schema(description = "创建时间")
|
||||
@DateTimeFormat(pattern = FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND)
|
||||
private LocalDateTime[] createTime;
|
||||
|
||||
}
|
||||
@ -0,0 +1,44 @@
|
||||
package cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.*;
|
||||
import java.time.LocalDateTime;
|
||||
import com.alibaba.excel.annotation.*;
|
||||
|
||||
@Schema(description = "管理后台 - 训练信息 Response VO")
|
||||
@Data
|
||||
@ExcelIgnoreUnannotated
|
||||
public class TrainInfoRespVO {
|
||||
|
||||
@Schema(description = "id", requiredMode = Schema.RequiredMode.REQUIRED, example = "15312")
|
||||
@ExcelProperty("id")
|
||||
private Integer id;
|
||||
|
||||
@Schema(description = "训练id", example = "27530")
|
||||
@ExcelProperty("训练id")
|
||||
private Integer trainId;
|
||||
|
||||
@Schema(description = "识别批次", example = "1")
|
||||
@ExcelProperty("识别批次")
|
||||
private Integer round;
|
||||
|
||||
@Schema(description = "识别的总批次", example = "100")
|
||||
@ExcelProperty("识别的总批次")
|
||||
private Integer roundTotal;
|
||||
|
||||
@Schema(description = "训练信息")
|
||||
@ExcelProperty("训练信息")
|
||||
private String info;
|
||||
|
||||
@Schema(description = "识别结果rate")
|
||||
@ExcelProperty("识别结果rate")
|
||||
private String rate;
|
||||
|
||||
@Schema(description = "创建时间")
|
||||
@ExcelProperty("创建时间")
|
||||
private LocalDateTime createTime;
|
||||
|
||||
@Schema(description = "修改时间")
|
||||
@ExcelProperty("修改时间")
|
||||
private LocalDateTime updateTime;
|
||||
}
|
||||
@ -0,0 +1,24 @@
|
||||
package cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.*;
|
||||
|
||||
@Schema(description = "管理后台 - 训练信息新增/修改 Request VO")
|
||||
@Data
|
||||
public class TrainInfoSaveReqVO {
|
||||
|
||||
@Schema(description = "训练id", requiredMode = Schema.RequiredMode.REQUIRED, example = "27530")
|
||||
private Integer trainId;
|
||||
|
||||
@Schema(description = "识别批次", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||
private Integer round;
|
||||
|
||||
@Schema(description = "识别的总批次", requiredMode = Schema.RequiredMode.REQUIRED, example = "100")
|
||||
private Integer roundTotal;
|
||||
|
||||
@Schema(description = "训练信息")
|
||||
private String info;
|
||||
|
||||
@Schema(description = "识别结果rate")
|
||||
private String rate;
|
||||
}
|
||||
@ -0,0 +1,47 @@
|
||||
package cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo;
|
||||
|
||||
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.*;
|
||||
|
||||
/**
|
||||
* 识别结果 DO
|
||||
*
|
||||
* @author 管理员
|
||||
*/
|
||||
@TableName("annotation_train_info")
|
||||
@KeySequence("annotation_train_info_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
@ToString(callSuper = true)
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class TrainInfoDO extends BaseDO {
|
||||
|
||||
/**
|
||||
* id
|
||||
*/
|
||||
@TableId
|
||||
private Integer id;
|
||||
/**
|
||||
* 训练id
|
||||
*/
|
||||
private Integer trainId;
|
||||
// 识别批次
|
||||
private Integer round;
|
||||
private String info;
|
||||
|
||||
/**
|
||||
* 识别的总批次
|
||||
*/
|
||||
private Integer roundTotal;
|
||||
/**
|
||||
* 识别结果rate
|
||||
*/
|
||||
private String rate;
|
||||
|
||||
|
||||
}
|
||||
@ -0,0 +1,31 @@
|
||||
package cn.iocoder.yudao.module.annotation.dal.mysql.traininfo;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
||||
import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.*;
|
||||
|
||||
/**
|
||||
* 训练信息 Mapper
|
||||
*
|
||||
* @author 管理员
|
||||
*/
|
||||
@Mapper
|
||||
public interface TrainInfoMapper extends BaseMapperX<TrainInfoDO> {
|
||||
|
||||
default PageResult<TrainInfoDO> selectPage(TrainInfoPageReqVO reqVO) {
|
||||
return selectPage(reqVO, new LambdaQueryWrapperX<TrainInfoDO>()
|
||||
.eqIfPresent(TrainInfoDO::getTrainId, reqVO.getTrainId())
|
||||
.eqIfPresent(TrainInfoDO::getRound, reqVO.getRound())
|
||||
.eqIfPresent(TrainInfoDO::getRoundTotal, reqVO.getRoundTotal())
|
||||
.likeIfPresent(TrainInfoDO::getInfo, reqVO.getInfo())
|
||||
.eqIfPresent(TrainInfoDO::getRate, reqVO.getRate())
|
||||
.betweenIfPresent(TrainInfoDO::getCreateTime, reqVO.getCreateTime())
|
||||
.orderByDesc(TrainInfoDO::getId));
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,70 @@
|
||||
package cn.iocoder.yudao.module.annotation.service.traininfo;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoPageReqVO;
|
||||
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoRespVO;
|
||||
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO;
|
||||
import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO;
|
||||
import jakarta.validation.Valid;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 训练信息 Service 接口
|
||||
*
|
||||
* @author 管理员
|
||||
*/
|
||||
public interface TrainInfoService {
|
||||
|
||||
/**
|
||||
* 创建训练信息
|
||||
*
|
||||
* @param createReqVO 创建信息
|
||||
* @return 编号
|
||||
*/
|
||||
Integer createTrainInfo(@Valid TrainInfoSaveReqVO createReqVO);
|
||||
|
||||
/**
|
||||
* 更新训练信息
|
||||
*
|
||||
* @param updateReqVO 更新信息
|
||||
*/
|
||||
void updateTrainInfo(@Valid TrainInfoSaveReqVO updateReqVO);
|
||||
|
||||
/**
|
||||
* 删除训练信息
|
||||
*
|
||||
* @param id 编号
|
||||
*/
|
||||
void deleteTrainInfo(Integer id);
|
||||
|
||||
/**
|
||||
* 获得训练信息
|
||||
*
|
||||
* @param id 编号
|
||||
* @return 训练信息
|
||||
*/
|
||||
TrainInfoDO getTrainInfo(Integer id);
|
||||
|
||||
/**
|
||||
* 获得训练信息分页
|
||||
*
|
||||
* @param pageReqVO 分页查询
|
||||
* @return 训练信息分页
|
||||
*/
|
||||
PageResult<TrainInfoDO> getTrainInfoPage(TrainInfoPageReqVO pageReqVO);
|
||||
|
||||
/**
|
||||
* 保存训练轮次信息
|
||||
*
|
||||
* @param trainId 训练ID
|
||||
* @param round 轮次
|
||||
* @param roundTotal 总轮次
|
||||
* @param info 轮次信息
|
||||
* @param rate 识别率
|
||||
* @return 编号
|
||||
*/
|
||||
Integer saveTrainRoundInfo(Integer trainId, Integer round, Integer roundTotal, String info, String rate);
|
||||
|
||||
List<TrainInfoRespVO> getTrainInfoList(Integer id);
|
||||
}
|
||||
@ -0,0 +1,92 @@
|
||||
package cn.iocoder.yudao.module.annotation.service.traininfo;
|
||||
|
||||
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoPageReqVO;
|
||||
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoRespVO;
|
||||
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO;
|
||||
import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO;
|
||||
import cn.iocoder.yudao.module.annotation.dal.mysql.traininfo.TrainInfoMapper;
|
||||
import jakarta.annotation.Resource;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.validation.annotation.Validated;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 训练信息 Service 实现类
|
||||
*
|
||||
* @author 管理员
|
||||
*/
|
||||
@Service
|
||||
@Validated
|
||||
public class TrainInfoServiceImpl implements TrainInfoService {
|
||||
|
||||
@Resource
|
||||
private TrainInfoMapper trainInfoMapper;
|
||||
|
||||
@Override
|
||||
public Integer createTrainInfo(TrainInfoSaveReqVO createReqVO) {
|
||||
// 插入
|
||||
TrainInfoDO data = BeanUtils.toBean(createReqVO, TrainInfoDO.class);
|
||||
trainInfoMapper.insert(data);
|
||||
// 返回
|
||||
return data.getId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateTrainInfo(TrainInfoSaveReqVO updateReqVO) {
|
||||
// 校验存在
|
||||
// validateTrainInfoExists(updateReqVO.getId());
|
||||
// 更新
|
||||
TrainInfoDO updateObj = BeanUtils.toBean(updateReqVO, TrainInfoDO.class);
|
||||
trainInfoMapper.updateById(updateObj);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteTrainInfo(Integer id) {
|
||||
// 校验存在
|
||||
validateTrainInfoExists(id);
|
||||
// 删除
|
||||
trainInfoMapper.deleteById(id);
|
||||
}
|
||||
|
||||
private void validateTrainInfoExists(Integer id) {
|
||||
if (trainInfoMapper.selectById(id) == null) {
|
||||
// throw exception("训练信息不存在");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public TrainInfoDO getTrainInfo(Integer id) {
|
||||
return trainInfoMapper.selectById(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageResult<TrainInfoDO> getTrainInfoPage(TrainInfoPageReqVO pageReqVO) {
|
||||
return trainInfoMapper.selectPage(pageReqVO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Integer saveTrainRoundInfo(Integer trainId, Integer round, Integer roundTotal, String info, String rate) {
|
||||
TrainInfoDO trainInfo = TrainInfoDO.builder()
|
||||
.trainId(trainId)
|
||||
.round(round)
|
||||
.roundTotal(roundTotal)
|
||||
.info(info)
|
||||
.rate(rate)
|
||||
.build();
|
||||
trainInfoMapper.insert(trainInfo);
|
||||
return trainInfo.getId();
|
||||
}
|
||||
@Override
|
||||
public List<TrainInfoRespVO> getTrainInfoList(Integer trainId) {
|
||||
List<TrainInfoDO> trainInfoList = trainInfoMapper.selectList(new LambdaQueryWrapperX<TrainInfoDO>()
|
||||
.eq(TrainInfoDO::getTrainId, trainId)
|
||||
.orderByAsc(TrainInfoDO::getRound)
|
||||
.last("limit 15"));
|
||||
return BeanUtils.toBean(trainInfoList, TrainInfoRespVO.class);
|
||||
}
|
||||
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,219 @@
|
||||
package cn.iocoder.yudao.module.annotation.service.yolo;
|
||||
|
||||
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.io.*;
|
||||
import java.net.Socket;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
/**
|
||||
* YOLO服务客户端 - 通过socket与Python长连接服务通信
|
||||
*
|
||||
* @author 管理员
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
public class YoloServiceClient {
|
||||
|
||||
private static final String DEFAULT_HOST = "localhost";
|
||||
private static final int DEFAULT_PORT = 9999;
|
||||
private static final int CONNECTION_TIMEOUT = 30000; // 30秒超时
|
||||
|
||||
private final ObjectMapper objectMapper = new ObjectMapper();
|
||||
|
||||
/**
|
||||
* 通过socket服务执行预测
|
||||
*/
|
||||
public YoloConfig predictViaSocket(String modelPath, String inputPath, String outputPath, YoloConfig config) {
|
||||
try {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("model_path", modelPath);
|
||||
params.put("input_path", inputPath);
|
||||
if (outputPath != null) {
|
||||
params.put("output_path", outputPath);
|
||||
}
|
||||
|
||||
// 添加其他YOLO配置参数
|
||||
addYoloConfigParams(params, config);
|
||||
|
||||
Map<String, Object> request = new HashMap<>();
|
||||
request.put("command", "predict");
|
||||
request.put("params", params);
|
||||
|
||||
String response = sendRequest(request);
|
||||
JsonNode responseJson = objectMapper.readTree(response);
|
||||
|
||||
if ("success".equals(responseJson.get("status").asText())) {
|
||||
config.setStatus("completed");
|
||||
config.setLogMessage(responseJson.get("message").asText());
|
||||
if (responseJson.has("output_path")) {
|
||||
config.setOutputPath(responseJson.get("output_path").asText());
|
||||
}
|
||||
return config;
|
||||
} else {
|
||||
config.setStatus("failed");
|
||||
config.setErrorMessage(responseJson.get("message").asText());
|
||||
return config;
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Socket预测失败", e);
|
||||
config.setStatus("failed");
|
||||
config.setErrorMessage("Socket预测失败: " + e.getMessage());
|
||||
return config;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过socket服务执行训练
|
||||
*/
|
||||
public YoloConfig trainViaSocket(String modelPath, String dataPath, YoloConfig config) {
|
||||
try {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("model_path", modelPath);
|
||||
params.put("data_path", dataPath);
|
||||
params.put("epochs", config.getEpochs());
|
||||
|
||||
// 添加其他训练配置参数
|
||||
addYoloConfigParams(params, config);
|
||||
|
||||
Map<String, Object> request = new HashMap<>();
|
||||
request.put("command", "train");
|
||||
request.put("params", params);
|
||||
|
||||
String response = sendRequest(request);
|
||||
JsonNode responseJson = objectMapper.readTree(response);
|
||||
|
||||
if ("success".equals(responseJson.get("status").asText())) {
|
||||
config.setStatus("completed");
|
||||
config.setLogMessage(responseJson.get("message").asText());
|
||||
if (responseJson.has("results_path")) {
|
||||
config.setOutputPath(responseJson.get("results_path").asText());
|
||||
}
|
||||
return config;
|
||||
} else {
|
||||
config.setStatus("failed");
|
||||
config.setErrorMessage(responseJson.get("message").asText());
|
||||
return config;
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Socket训练失败", e);
|
||||
config.setStatus("failed");
|
||||
config.setErrorMessage("Socket训练失败: " + e.getMessage());
|
||||
return config;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过socket服务执行验证
|
||||
*/
|
||||
public YoloConfig validateViaSocket(String modelPath, String dataPath, YoloConfig config) {
|
||||
try {
|
||||
Map<String, Object> params = new HashMap<>();
|
||||
params.put("model_path", modelPath);
|
||||
if (dataPath != null) {
|
||||
params.put("data_path", dataPath);
|
||||
}
|
||||
|
||||
// 添加其他验证配置参数
|
||||
addYoloConfigParams(params, config);
|
||||
|
||||
Map<String, Object> request = new HashMap<>();
|
||||
request.put("command", "validate");
|
||||
request.put("params", params);
|
||||
|
||||
String response = sendRequest(request);
|
||||
JsonNode responseJson = objectMapper.readTree(response);
|
||||
|
||||
if ("success".equals(responseJson.get("status").asText())) {
|
||||
config.setStatus("completed");
|
||||
config.setLogMessage(responseJson.get("message").asText());
|
||||
if (responseJson.has("metrics")) {
|
||||
config.setMetrics(responseJson.get("metrics").asText());
|
||||
}
|
||||
return config;
|
||||
} else {
|
||||
config.setStatus("failed");
|
||||
config.setErrorMessage(responseJson.get("message").asText());
|
||||
return config;
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Socket验证失败", e);
|
||||
config.setStatus("failed");
|
||||
config.setErrorMessage("Socket验证失败: " + e.getMessage());
|
||||
return config;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加YOLO配置参数到请求中
|
||||
*/
|
||||
private void addYoloConfigParams(Map<String, Object> params, YoloConfig config) {
|
||||
if (config.getImgsz() > 0) {
|
||||
params.put("imgsz", config.getImgsz());
|
||||
}
|
||||
if (config.getConfThresh() > 0) {
|
||||
params.put("conf", config.getConfThresh());
|
||||
}
|
||||
if (config.getIouThresh() > 0) {
|
||||
params.put("iou", config.getIouThresh());
|
||||
}
|
||||
if (config.getDevice() != null) {
|
||||
params.put("device", config.getDevice());
|
||||
}
|
||||
if (config.getBatchSize() > 0) {
|
||||
params.put("batch", config.getBatchSize());
|
||||
}
|
||||
params.put("save", config.isSave());
|
||||
params.put("save_txt", config.isSaveTxt());
|
||||
params.put("save_crops", config.isSaveCrops());
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送请求到socket服务
|
||||
*/
|
||||
private String sendRequest(Map<String, Object> request) throws IOException {
|
||||
try (Socket socket = new Socket(DEFAULT_HOST, DEFAULT_PORT);
|
||||
PrintWriter out = new PrintWriter(socket.getOutputStream(), true);
|
||||
BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()))) {
|
||||
|
||||
// 设置超时
|
||||
socket.setSoTimeout(CONNECTION_TIMEOUT);
|
||||
|
||||
// 发送请求
|
||||
String requestJson = objectMapper.writeValueAsString(request);
|
||||
out.println(requestJson);
|
||||
|
||||
// 读取响应
|
||||
StringBuilder response = new StringBuilder();
|
||||
String line;
|
||||
while ((line = in.readLine()) != null) {
|
||||
response.append(line);
|
||||
}
|
||||
|
||||
return response.toString();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查服务是否可用
|
||||
*/
|
||||
public boolean isServiceAvailable() {
|
||||
try {
|
||||
try (Socket socket = new Socket()) {
|
||||
socket.connect(new java.net.InetSocketAddress(DEFAULT_HOST, DEFAULT_PORT), 3000);
|
||||
return true;
|
||||
}
|
||||
} catch (IOException e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue