parent
acdec70da5
commit
fa0c137f0d
@ -0,0 +1,106 @@
|
|||||||
|
task: detect
|
||||||
|
mode: train
|
||||||
|
model: yolov8n.pt
|
||||||
|
data: 'null'
|
||||||
|
epochs: 50
|
||||||
|
time: null
|
||||||
|
patience: 100
|
||||||
|
batch: 4
|
||||||
|
imgsz: 640
|
||||||
|
save: true
|
||||||
|
save_period: -1
|
||||||
|
cache: false
|
||||||
|
device: cpu
|
||||||
|
workers: 8
|
||||||
|
project: 'null'
|
||||||
|
name: train
|
||||||
|
exist_ok: false
|
||||||
|
pretrained: true
|
||||||
|
optimizer: auto
|
||||||
|
verbose: true
|
||||||
|
seed: 0
|
||||||
|
deterministic: true
|
||||||
|
single_cls: false
|
||||||
|
rect: false
|
||||||
|
cos_lr: false
|
||||||
|
close_mosaic: 10
|
||||||
|
resume: false
|
||||||
|
amp: true
|
||||||
|
fraction: 1.0
|
||||||
|
profile: false
|
||||||
|
freeze: null
|
||||||
|
multi_scale: false
|
||||||
|
compile: false
|
||||||
|
overlap_mask: true
|
||||||
|
mask_ratio: 4
|
||||||
|
dropout: 0.0
|
||||||
|
val: true
|
||||||
|
split: val
|
||||||
|
save_json: false
|
||||||
|
conf: null
|
||||||
|
iou: 0.7
|
||||||
|
max_det: 300
|
||||||
|
half: false
|
||||||
|
dnn: false
|
||||||
|
plots: true
|
||||||
|
source: null
|
||||||
|
vid_stride: 1
|
||||||
|
stream_buffer: false
|
||||||
|
visualize: false
|
||||||
|
augment: false
|
||||||
|
agnostic_nms: false
|
||||||
|
classes: null
|
||||||
|
retina_masks: false
|
||||||
|
embed: null
|
||||||
|
show: false
|
||||||
|
save_frames: false
|
||||||
|
save_txt: false
|
||||||
|
save_conf: false
|
||||||
|
save_crop: false
|
||||||
|
show_labels: true
|
||||||
|
show_conf: true
|
||||||
|
show_boxes: true
|
||||||
|
line_width: null
|
||||||
|
format: torchscript
|
||||||
|
keras: false
|
||||||
|
optimize: false
|
||||||
|
int8: false
|
||||||
|
dynamic: false
|
||||||
|
simplify: true
|
||||||
|
opset: null
|
||||||
|
workspace: null
|
||||||
|
nms: false
|
||||||
|
lr0: 0.01
|
||||||
|
lrf: 0.01
|
||||||
|
momentum: 0.937
|
||||||
|
weight_decay: 0.0005
|
||||||
|
warmup_epochs: 3.0
|
||||||
|
warmup_momentum: 0.8
|
||||||
|
warmup_bias_lr: 0.1
|
||||||
|
box: 7.5
|
||||||
|
cls: 0.5
|
||||||
|
dfl: 1.5
|
||||||
|
pose: 12.0
|
||||||
|
kobj: 1.0
|
||||||
|
nbs: 64
|
||||||
|
hsv_h: 0.015
|
||||||
|
hsv_s: 0.7
|
||||||
|
hsv_v: 0.4
|
||||||
|
degrees: 0.0
|
||||||
|
translate: 0.1
|
||||||
|
scale: 0.5
|
||||||
|
shear: 0.0
|
||||||
|
perspective: 0.0
|
||||||
|
flipud: 0.0
|
||||||
|
fliplr: 0.5
|
||||||
|
bgr: 0.0
|
||||||
|
mosaic: 1.0
|
||||||
|
mixup: 0.0
|
||||||
|
cutmix: 0.0
|
||||||
|
copy_paste: 0.0
|
||||||
|
copy_paste_mode: flip
|
||||||
|
auto_augment: randaugment
|
||||||
|
erasing: 0.4
|
||||||
|
cfg: null
|
||||||
|
tracker: botsort.yaml
|
||||||
|
save_dir: D:\git\lpNewgit\null\train
|
||||||
@ -0,0 +1,106 @@
|
|||||||
|
task: detect
|
||||||
|
mode: train
|
||||||
|
model: yolov8n.pt
|
||||||
|
data: 'null'
|
||||||
|
epochs: 50
|
||||||
|
time: null
|
||||||
|
patience: 100
|
||||||
|
batch: 4
|
||||||
|
imgsz: 640
|
||||||
|
save: true
|
||||||
|
save_period: -1
|
||||||
|
cache: false
|
||||||
|
device: cpu
|
||||||
|
workers: 8
|
||||||
|
project: 'null'
|
||||||
|
name: train2
|
||||||
|
exist_ok: false
|
||||||
|
pretrained: true
|
||||||
|
optimizer: auto
|
||||||
|
verbose: true
|
||||||
|
seed: 0
|
||||||
|
deterministic: true
|
||||||
|
single_cls: false
|
||||||
|
rect: false
|
||||||
|
cos_lr: false
|
||||||
|
close_mosaic: 10
|
||||||
|
resume: false
|
||||||
|
amp: true
|
||||||
|
fraction: 1.0
|
||||||
|
profile: false
|
||||||
|
freeze: null
|
||||||
|
multi_scale: false
|
||||||
|
compile: false
|
||||||
|
overlap_mask: true
|
||||||
|
mask_ratio: 4
|
||||||
|
dropout: 0.0
|
||||||
|
val: true
|
||||||
|
split: val
|
||||||
|
save_json: false
|
||||||
|
conf: null
|
||||||
|
iou: 0.7
|
||||||
|
max_det: 300
|
||||||
|
half: false
|
||||||
|
dnn: false
|
||||||
|
plots: true
|
||||||
|
source: null
|
||||||
|
vid_stride: 1
|
||||||
|
stream_buffer: false
|
||||||
|
visualize: false
|
||||||
|
augment: false
|
||||||
|
agnostic_nms: false
|
||||||
|
classes: null
|
||||||
|
retina_masks: false
|
||||||
|
embed: null
|
||||||
|
show: false
|
||||||
|
save_frames: false
|
||||||
|
save_txt: false
|
||||||
|
save_conf: false
|
||||||
|
save_crop: false
|
||||||
|
show_labels: true
|
||||||
|
show_conf: true
|
||||||
|
show_boxes: true
|
||||||
|
line_width: null
|
||||||
|
format: torchscript
|
||||||
|
keras: false
|
||||||
|
optimize: false
|
||||||
|
int8: false
|
||||||
|
dynamic: false
|
||||||
|
simplify: true
|
||||||
|
opset: null
|
||||||
|
workspace: null
|
||||||
|
nms: false
|
||||||
|
lr0: 0.01
|
||||||
|
lrf: 0.01
|
||||||
|
momentum: 0.937
|
||||||
|
weight_decay: 0.0005
|
||||||
|
warmup_epochs: 3.0
|
||||||
|
warmup_momentum: 0.8
|
||||||
|
warmup_bias_lr: 0.1
|
||||||
|
box: 7.5
|
||||||
|
cls: 0.5
|
||||||
|
dfl: 1.5
|
||||||
|
pose: 12.0
|
||||||
|
kobj: 1.0
|
||||||
|
nbs: 64
|
||||||
|
hsv_h: 0.015
|
||||||
|
hsv_s: 0.7
|
||||||
|
hsv_v: 0.4
|
||||||
|
degrees: 0.0
|
||||||
|
translate: 0.1
|
||||||
|
scale: 0.5
|
||||||
|
shear: 0.0
|
||||||
|
perspective: 0.0
|
||||||
|
flipud: 0.0
|
||||||
|
fliplr: 0.5
|
||||||
|
bgr: 0.0
|
||||||
|
mosaic: 1.0
|
||||||
|
mixup: 0.0
|
||||||
|
cutmix: 0.0
|
||||||
|
copy_paste: 0.0
|
||||||
|
copy_paste_mode: flip
|
||||||
|
auto_augment: randaugment
|
||||||
|
erasing: 0.4
|
||||||
|
cfg: null
|
||||||
|
tracker: botsort.yaml
|
||||||
|
save_dir: D:\git\lpNewgit\null\train2
|
||||||
@ -0,0 +1,28 @@
|
|||||||
|
@echo off
|
||||||
|
echo Starting YOLO Service...
|
||||||
|
cd /d %~dp0
|
||||||
|
|
||||||
|
REM 检查Python环境
|
||||||
|
python --version
|
||||||
|
if %errorlevel% neq 0 (
|
||||||
|
echo Python not found, please install Python first
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
REM 安装依赖(如果需要)
|
||||||
|
echo Installing dependencies...
|
||||||
|
pip install ultralytics
|
||||||
|
|
||||||
|
REM 预加载模型路径(可选)
|
||||||
|
set MODEL_PATH=%~dp0yolo11n.pt
|
||||||
|
|
||||||
|
REM 启动YOLO服务
|
||||||
|
echo Starting YOLO service on localhost:9999
|
||||||
|
if exist "%MODEL_PATH%" (
|
||||||
|
python yolo_service.py --host localhost --port 9999 --model "%MODEL_PATH%"
|
||||||
|
) else (
|
||||||
|
python yolo_service.py --host localhost --port 9999
|
||||||
|
)
|
||||||
|
|
||||||
|
pause
|
||||||
@ -0,0 +1,162 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
YOLO服务管理器
|
||||||
|
用于启动、停止、重启YOLO socket服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import signal
|
||||||
|
import argparse
|
||||||
|
import psutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
class YoloServiceManager:
|
||||||
|
def __init__(self, host='localhost', port=9999, model_path=None):
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self.model_path = model_path
|
||||||
|
self.pid_file = Path('yolo_service.pid')
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""启动YOLO服务"""
|
||||||
|
if self.is_running():
|
||||||
|
print(f"YOLO服务已在运行 (PID: {self.get_pid()})")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"启动YOLO服务 {self.host}:{self.port}")
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
sys.executable, 'yolo_service.py',
|
||||||
|
'--host', self.host,
|
||||||
|
'--port', str(self.port)
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.model_path and Path(self.model_path).exists():
|
||||||
|
cmd.extend(['--model', self.model_path])
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 启动进程
|
||||||
|
process = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
cwd=os.getcwd()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 保存PID
|
||||||
|
with open(self.pid_file, 'w') as f:
|
||||||
|
f.write(str(process.pid))
|
||||||
|
|
||||||
|
print(f"YOLO服务已启动 (PID: {process.pid})")
|
||||||
|
|
||||||
|
# 等待服务启动
|
||||||
|
time.sleep(2)
|
||||||
|
if self.is_service_alive():
|
||||||
|
print("服务启动成功")
|
||||||
|
else:
|
||||||
|
print("警告: 服务可能未正常启动,请检查日志")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"启动服务失败: {e}")
|
||||||
|
if self.pid_file.exists():
|
||||||
|
self.pid_file.unlink()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""停止YOLO服务"""
|
||||||
|
if not self.is_running():
|
||||||
|
print("YOLO服务未运行")
|
||||||
|
return
|
||||||
|
|
||||||
|
pid = self.get_pid()
|
||||||
|
try:
|
||||||
|
# 尝试优雅停止
|
||||||
|
os.kill(pid, signal.SIGTERM)
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
# 如果还在运行,强制停止
|
||||||
|
if psutil.pid_exists(pid):
|
||||||
|
os.kill(pid, signal.SIGKILL)
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
print(f"YOLO服务已停止 (PID: {pid})")
|
||||||
|
|
||||||
|
except ProcessLookupError:
|
||||||
|
print(f"进程 {pid} 不存在")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"停止服务失败: {e}")
|
||||||
|
finally:
|
||||||
|
if self.pid_file.exists():
|
||||||
|
self.pid_file.unlink()
|
||||||
|
|
||||||
|
def restart(self):
|
||||||
|
"""重启YOLO服务"""
|
||||||
|
self.stop()
|
||||||
|
time.sleep(1)
|
||||||
|
self.start()
|
||||||
|
|
||||||
|
def status(self):
|
||||||
|
"""查看服务状态"""
|
||||||
|
if self.is_running():
|
||||||
|
pid = self.get_pid()
|
||||||
|
if self.is_service_alive():
|
||||||
|
print(f"YOLO服务运行中 (PID: {pid})")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print(f"YOLO服务进程存在但不可用 (PID: {pid})")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
print("YOLO服务未运行")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def is_running(self):
|
||||||
|
"""检查服务是否在运行"""
|
||||||
|
return self.pid_file.exists() and self.get_pid() > 0
|
||||||
|
|
||||||
|
def get_pid(self):
|
||||||
|
"""获取服务PID"""
|
||||||
|
if self.pid_file.exists():
|
||||||
|
try:
|
||||||
|
with open(self.pid_file, 'r') as f:
|
||||||
|
return int(f.read().strip())
|
||||||
|
except:
|
||||||
|
return 0
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def is_service_alive(self):
|
||||||
|
"""检查服务是否可用"""
|
||||||
|
try:
|
||||||
|
import socket
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
sock.settimeout(3)
|
||||||
|
result = sock.connect_ex((self.host, self.port))
|
||||||
|
sock.close()
|
||||||
|
return result == 0
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='YOLO服务管理器')
|
||||||
|
parser.add_argument('command', choices=['start', 'stop', 'restart', 'status'],
|
||||||
|
help='命令: start, stop, restart, status')
|
||||||
|
parser.add_argument('--host', default='localhost', help='服务主机地址')
|
||||||
|
parser.add_argument('--port', type=int, default=9999, help='服务端口')
|
||||||
|
parser.add_argument('--model', help='预加载模型路径')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
manager = YoloServiceManager(args.host, args.port, args.model)
|
||||||
|
|
||||||
|
if args.command == 'start':
|
||||||
|
manager.start()
|
||||||
|
elif args.command == 'stop':
|
||||||
|
manager.stop()
|
||||||
|
elif args.command == 'restart':
|
||||||
|
manager.restart()
|
||||||
|
elif args.command == 'status':
|
||||||
|
manager.status()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Binary file not shown.
@ -0,0 +1,153 @@
|
|||||||
|
import json
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class YOLOService:
|
||||||
|
def __init__(self):
|
||||||
|
self.models = {} # 缓存已加载的模型
|
||||||
|
self.default_model = None
|
||||||
|
|
||||||
|
def load_model(self, model_path):
|
||||||
|
"""加载并缓存模型"""
|
||||||
|
if model_path not in self.models:
|
||||||
|
logger.info(f"Loading model: {model_path}")
|
||||||
|
self.models[model_path] = YOLO(model_path)
|
||||||
|
|
||||||
|
return self.models[model_path]
|
||||||
|
|
||||||
|
def predict(self, model_path, input_path, output_path=None, **kwargs):
|
||||||
|
"""执行预测"""
|
||||||
|
try:
|
||||||
|
model = self.load_model(model_path)
|
||||||
|
results = model.predict(input_path, **kwargs)
|
||||||
|
|
||||||
|
if output_path:
|
||||||
|
# 处理保存结果
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
save_path = f"{output_path}/{i}.jpg" if output_path.endswith('/') else f"{output_path}/{i}.jpg"
|
||||||
|
result.save(save_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'status': 'success',
|
||||||
|
'message': f'Prediction completed. Results saved to {output_path}' if output_path else 'Prediction completed',
|
||||||
|
'output_path': output_path
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
'status': 'error',
|
||||||
|
'message': str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def train(self, model_path, data_path, epochs=50, **kwargs):
|
||||||
|
"""执行训练"""
|
||||||
|
try:
|
||||||
|
model = self.load_model(model_path)
|
||||||
|
results = model.train(data=data_path, epochs=epochs, **kwargs)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'status': 'success',
|
||||||
|
'message': 'Training completed',
|
||||||
|
'results_path': results.save_dir
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
'status': 'error',
|
||||||
|
'message': str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def validate(self, model_path, data_path=None, **kwargs):
|
||||||
|
"""执行验证"""
|
||||||
|
try:
|
||||||
|
model = self.load_model(model_path)
|
||||||
|
results = model.validate(data=data_path, **kwargs)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'status': 'success',
|
||||||
|
'message': 'Validation completed',
|
||||||
|
'metrics': str(results)
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
'status': 'error',
|
||||||
|
'message': str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def handle_client(client_socket, yolo_service):
|
||||||
|
"""处理客户端请求"""
|
||||||
|
try:
|
||||||
|
data = client_socket.recv(4096).decode('utf-8')
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
|
request = json.loads(data)
|
||||||
|
command = request.get('command')
|
||||||
|
|
||||||
|
if command == 'predict':
|
||||||
|
response = yolo_service.predict(**request.get('params', {}))
|
||||||
|
elif command == 'train':
|
||||||
|
response = yolo_service.train(**request.get('params', {}))
|
||||||
|
elif command == 'validate':
|
||||||
|
response = yolo_service.validate(**request.get('params', {}))
|
||||||
|
else:
|
||||||
|
response = {'status': 'error', 'message': f'Unknown command: {command}'}
|
||||||
|
|
||||||
|
client_socket.send(json.dumps(response).encode('utf-8'))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_response = {'status': 'error', 'message': str(e)}
|
||||||
|
client_socket.send(json.dumps(error_response).encode('utf-8'))
|
||||||
|
finally:
|
||||||
|
client_socket.close()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='YOLO Service')
|
||||||
|
parser.add_argument('--host', default='localhost', help='Host to bind to')
|
||||||
|
parser.add_argument('--port', type=int, default=9999, help='Port to bind to')
|
||||||
|
parser.add_argument('--model', help='Default model to preload')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
yolo_service = YOLOService()
|
||||||
|
|
||||||
|
# 预加载默认模型
|
||||||
|
if args.model and Path(args.model).exists():
|
||||||
|
yolo_service.load_model(args.model)
|
||||||
|
logger.info(f"Preloaded model: {args.model}")
|
||||||
|
|
||||||
|
# 启动socket服务
|
||||||
|
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
server.bind((args.host, args.port))
|
||||||
|
server.listen(5)
|
||||||
|
|
||||||
|
logger.info(f"YOLO Service started on {args.host}:{args.port}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
client_socket, addr = server.accept()
|
||||||
|
logger.info(f"Connection from {addr}")
|
||||||
|
|
||||||
|
# 为每个客户端创建新线程
|
||||||
|
client_thread = threading.Thread(
|
||||||
|
target=handle_client,
|
||||||
|
args=(client_socket, yolo_service)
|
||||||
|
)
|
||||||
|
client_thread.daemon = True
|
||||||
|
client_thread.start()
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Shutting down server...")
|
||||||
|
finally:
|
||||||
|
server.close()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Binary file not shown.
@ -0,0 +1,85 @@
|
|||||||
|
# YOLO 训练结果解析功能说明
|
||||||
|
|
||||||
|
## 功能概述
|
||||||
|
|
||||||
|
在 `YoloOperationServiceImpl.java` 中新增了自动检测和解析 YOLO 训练结果的功能,当训练日志中出现 "Results saved to" 时,系统会:
|
||||||
|
|
||||||
|
1. **自动提取保存路径**:从日志中解析出类似 `D:\data\train\1231_5\runs\train\train18` 的路径
|
||||||
|
2. **解析识别率信息**:从保存路径中的 `results.csv` 文件或其他日志文件中提取精度、召回率、mAP 等指标
|
||||||
|
3. **保存到数据库**:将路径和识别率信息保存到 `TrainResultDO` 表中
|
||||||
|
|
||||||
|
## 新增方法
|
||||||
|
|
||||||
|
### 1. `detectAndParseResultsSaved(String logLine, YoloConfig config, Integer trainId)`
|
||||||
|
- 检测日志中是否包含 "Results saved to" 或 "Saved to" 关键字
|
||||||
|
- 调用后续方法进行路径提取和识别率解析
|
||||||
|
|
||||||
|
### 2. `extractSavedPath(String logLine)`
|
||||||
|
- 使用正则表达式从日志行中提取保存路径
|
||||||
|
- 支持多种格式,如带引号、带统计信息等
|
||||||
|
|
||||||
|
### 3. `parseAccuracyFromSavedPath(String savedPath, YoloConfig config, Integer trainId)`
|
||||||
|
- 查找保存目录中的 `results.csv` 文件
|
||||||
|
- 解析 CSV 文件中的最后一行(最新结果)
|
||||||
|
- 提取 precision、recall、mAP50、mAP50-95 等指标
|
||||||
|
|
||||||
|
### 4. `parseResultsCsv(File csvFile, YoloConfig config, Integer trainId)`
|
||||||
|
- 解析 YOLO 生成的 `results.csv` 文件
|
||||||
|
- CSV 格式通常为:`epoch, train/box_loss, train/cls_loss, train/dfl_loss, metrics/precision, metrics/recall, metrics/mAP50, metrics/mAP50-95`
|
||||||
|
|
||||||
|
### 5. `saveAccuracyToDatabase(Integer trainId, double precision, double recall, double map)`
|
||||||
|
- 将解析到的识别率保存到数据库
|
||||||
|
- 格式:`precision:0.8521,recall:0.7432,map:0.8123`
|
||||||
|
|
||||||
|
## 集成位置
|
||||||
|
|
||||||
|
在 `trainAsync` 方法中,已经集成到日志读取循环中:
|
||||||
|
|
||||||
|
```java
|
||||||
|
if (stdoutLine != null) {
|
||||||
|
// ... 现有代码 ...
|
||||||
|
|
||||||
|
// 检测 "Results saved to" 并解析路径和识别率
|
||||||
|
detectAndParseResultsSaved(stdoutLine, config, trainId);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (stderrLine != null) {
|
||||||
|
// ... 现有代码 ...
|
||||||
|
|
||||||
|
// 检测 "Results saved to" 并解析路径和识别率
|
||||||
|
detectAndParseResultsSaved(stderrLine, config, trainId);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 预期输出格式
|
||||||
|
|
||||||
|
当 YOLO 训练完成时,日志中会出现类似内容:
|
||||||
|
```
|
||||||
|
Results saved to D:\data\train\1231_5\runs\train\train18
|
||||||
|
```
|
||||||
|
|
||||||
|
系统会自动:
|
||||||
|
1. 提取路径:`D:\data\train\1231_5\runs\train\train18`
|
||||||
|
2. 查找并解析该目录下的 `results.csv` 文件
|
||||||
|
3. 保存路径和识别率到数据库
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. **文件权限**:确保 Java 进程有权限访问 YOLO 保存的文件
|
||||||
|
2. **CSV 格式**:依赖 YOLO 标准的 `results.csv` 输出格式
|
||||||
|
3. **异步处理**:识别率解析不会阻塞训练过程
|
||||||
|
4. **错误处理**:解析失败不会影响训练完成状态
|
||||||
|
|
||||||
|
## 数据库字段
|
||||||
|
|
||||||
|
保存的识别率格式:
|
||||||
|
```sql
|
||||||
|
-- TrainResultDO 表的 rate 字段
|
||||||
|
precision:0.8521,recall:0.7432,map:0.8123
|
||||||
|
```
|
||||||
|
|
||||||
|
路径字段:
|
||||||
|
```sql
|
||||||
|
-- TrainResultDO 表的 path 字段
|
||||||
|
D:\data\train\1231_5\runs\train\train18
|
||||||
|
```
|
||||||
@ -0,0 +1,9 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.controller.admin.datas.vo;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class DatasEnhance {
|
||||||
|
private Long id;
|
||||||
|
private String[] enhancements;
|
||||||
|
}
|
||||||
@ -0,0 +1,9 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.controller.admin.train.vo;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class RecognitionResult {
|
||||||
|
private String imageUrl;
|
||||||
|
private String trainId;
|
||||||
|
}
|
||||||
@ -0,0 +1,108 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.controller.admin.train.vo;
|
||||||
|
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@Schema(description = "系统信息 Response VO")
|
||||||
|
@Data
|
||||||
|
public class SystemInfoRespVO {
|
||||||
|
|
||||||
|
@Schema(description = "CPU信息")
|
||||||
|
private CpuInfo cpu;
|
||||||
|
|
||||||
|
@Schema(description = "GPU信息列表")
|
||||||
|
private List<GpuInfo> gpus;
|
||||||
|
|
||||||
|
@Schema(description = "Python信息")
|
||||||
|
private PythonInfo python;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Schema(description = "CPU信息")
|
||||||
|
public static class CpuInfo {
|
||||||
|
@Schema(description = "CPU型号")
|
||||||
|
private String model;
|
||||||
|
|
||||||
|
@Schema(description = "CPU核心数")
|
||||||
|
private Integer cores;
|
||||||
|
|
||||||
|
@Schema(description = "CPU使用率")
|
||||||
|
private Double usage;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Schema(description = "GPU信息")
|
||||||
|
public static class GpuInfo {
|
||||||
|
@Schema(description = "GPU ID")
|
||||||
|
private Integer id;
|
||||||
|
|
||||||
|
@Schema(description = "GPU名称")
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
@Schema(description = "GPU内存总量(MB)")
|
||||||
|
private Long memoryTotal;
|
||||||
|
|
||||||
|
@Schema(description = "GPU已用内存(MB)")
|
||||||
|
private Long memoryUsed;
|
||||||
|
|
||||||
|
@Schema(description = "GPU可用内存(MB)")
|
||||||
|
private Long memoryFree;
|
||||||
|
|
||||||
|
@Schema(description = "GPU使用率")
|
||||||
|
private Double usage;
|
||||||
|
|
||||||
|
@Schema(description = "GPU温度")
|
||||||
|
private Double temperature;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Schema(description = "Python信息")
|
||||||
|
public static class PythonInfo {
|
||||||
|
@Schema(description = "是否安装Python")
|
||||||
|
private Boolean installed;
|
||||||
|
|
||||||
|
@Schema(description = "Python版本")
|
||||||
|
private String version;
|
||||||
|
|
||||||
|
@Schema(description = "Python路径")
|
||||||
|
private String path;
|
||||||
|
|
||||||
|
@Schema(description = "是否支持venv虚拟环境")
|
||||||
|
private Boolean venvSupported;
|
||||||
|
|
||||||
|
@Schema(description = "是否安装virtualenv")
|
||||||
|
private Boolean virtualenvInstalled;
|
||||||
|
|
||||||
|
@Schema(description = "是否安装conda")
|
||||||
|
private Boolean condaInstalled;
|
||||||
|
|
||||||
|
@Schema(description = "yolo虚拟环境是否存在")
|
||||||
|
private Boolean yoloEnvExists;
|
||||||
|
|
||||||
|
@Schema(description = "yolo虚拟环境中是否安装YOLO")
|
||||||
|
private Boolean yoloInstalled;
|
||||||
|
|
||||||
|
@Schema(description = "YOLO版本")
|
||||||
|
private String yoloVersion;
|
||||||
|
|
||||||
|
@Schema(description = "Python虚拟环境列表")
|
||||||
|
private List<VirtualEnvInfo> virtualEnvs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Schema(description = "虚拟环境信息")
|
||||||
|
public static class VirtualEnvInfo {
|
||||||
|
@Schema(description = "虚拟环境名称")
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
@Schema(description = "虚拟环境路径")
|
||||||
|
private String path;
|
||||||
|
|
||||||
|
@Schema(description = "虚拟环境类型 (venv/virtualenv/conda)")
|
||||||
|
private String type;
|
||||||
|
|
||||||
|
@Schema(description = "Python版本")
|
||||||
|
private String pythonVersion;
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -0,0 +1,105 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.controller.admin.traininfo;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.apilog.core.annotation.ApiAccessLog;
|
||||||
|
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
|
||||||
|
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||||
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
|
import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils;
|
||||||
|
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoPageReqVO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoRespVO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
|
||||||
|
import cn.iocoder.yudao.module.system.util.RedisUtil;
|
||||||
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
|
import io.swagger.v3.oas.annotations.Parameter;
|
||||||
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
|
import jakarta.annotation.Resource;
|
||||||
|
import jakarta.servlet.http.HttpServletResponse;
|
||||||
|
import jakarta.validation.Valid;
|
||||||
|
import org.springframework.security.access.prepost.PreAuthorize;
|
||||||
|
import org.springframework.validation.annotation.Validated;
|
||||||
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.EXPORT;
|
||||||
|
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
|
||||||
|
|
||||||
|
@Tag(name = "管理后台 - 识别结果")
|
||||||
|
@RestController
|
||||||
|
@RequestMapping("/annotation/train-Info")
|
||||||
|
@Validated
|
||||||
|
public class TrainInfoController {
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private TrainInfoService trainInfoService;
|
||||||
|
|
||||||
|
@PostMapping("/create")
|
||||||
|
@Operation(summary = "创建识别结果")
|
||||||
|
@PreAuthorize("@ss.hasPermission('annotation:train-Info:create')")
|
||||||
|
public CommonResult<Integer> createTrainInfo(@Valid @RequestBody TrainInfoSaveReqVO createReqVO) {
|
||||||
|
return success(trainInfoService.createTrainInfo(createReqVO));
|
||||||
|
}
|
||||||
|
|
||||||
|
@PutMapping("/update")
|
||||||
|
@Operation(summary = "更新识别结果")
|
||||||
|
@PreAuthorize("@ss.hasPermission('annotation:train-Info:update')")
|
||||||
|
public CommonResult<Boolean> updateTrainInfo(@Valid @RequestBody TrainInfoSaveReqVO updateReqVO) {
|
||||||
|
trainInfoService.updateTrainInfo(updateReqVO);
|
||||||
|
return success(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@DeleteMapping("/delete")
|
||||||
|
@Operation(summary = "删除识别结果")
|
||||||
|
@Parameter(name = "id", description = "编号", required = true)
|
||||||
|
@PreAuthorize("@ss.hasPermission('annotation:train-Info:delete')")
|
||||||
|
public CommonResult<Boolean> deleteTrainInfo(@RequestParam("id") Integer id) {
|
||||||
|
trainInfoService.deleteTrainInfo(id);
|
||||||
|
return success(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
@GetMapping("/get")
|
||||||
|
@Operation(summary = "获得识别结果")
|
||||||
|
@Parameter(name = "id", description = "编号", required = true, example = "1024")
|
||||||
|
@PreAuthorize("@ss.hasPermission('annotation:train-Info:query')")
|
||||||
|
public CommonResult<TrainInfoRespVO> getTrainInfo(@RequestParam("id") Integer id) {
|
||||||
|
TrainInfoDO trainInfo = trainInfoService.getTrainInfo(id);
|
||||||
|
return success(BeanUtils.toBean(trainInfo, TrainInfoRespVO.class));
|
||||||
|
}
|
||||||
|
@Resource
|
||||||
|
RedisUtil redisUtil;
|
||||||
|
|
||||||
|
@GetMapping("/get-list")
|
||||||
|
@Operation(summary = "获得识别结果")
|
||||||
|
@Parameter(name = "id", description = "编号", required = true, example = "1024")
|
||||||
|
@PreAuthorize("@ss.hasPermission('annotation:train-Info:query')")
|
||||||
|
public CommonResult<List<String>> getTrainInfoList(@RequestParam("trainId") Integer id) {
|
||||||
|
List<Object> result = redisUtil.lGet("yolo:training_log:" + id, 0, -1);
|
||||||
|
List<String> result1 = result.stream().map(Object::toString).toList();
|
||||||
|
return success(result1);
|
||||||
|
}
|
||||||
|
@GetMapping("/page")
|
||||||
|
@Operation(summary = "获得识别结果分页")
|
||||||
|
@PreAuthorize("@ss.hasPermission('annotation:train-Info:query')")
|
||||||
|
public CommonResult<PageResult<TrainInfoRespVO>> getTrainInfoPage(@Valid TrainInfoPageReqVO pageReqVO) {
|
||||||
|
PageResult<TrainInfoDO> pageInfo = trainInfoService.getTrainInfoPage(pageReqVO);
|
||||||
|
return success(BeanUtils.toBean(pageInfo, TrainInfoRespVO.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@GetMapping("/export-excel")
|
||||||
|
@Operation(summary = "导出识别结果 Excel")
|
||||||
|
@PreAuthorize("@ss.hasPermission('annotation:train-Info:export')")
|
||||||
|
@ApiAccessLog(operateType = EXPORT)
|
||||||
|
public void exportTrainInfoExcel(@Valid TrainInfoPageReqVO pageReqVO,
|
||||||
|
HttpServletResponse response) throws IOException {
|
||||||
|
pageReqVO.setPageSize(PageParam.PAGE_SIZE_NONE);
|
||||||
|
List<TrainInfoDO> list = trainInfoService.getTrainInfoPage(pageReqVO).getList();
|
||||||
|
// 导出 Excel
|
||||||
|
ExcelUtils.write(response, "识别结果.xls", "数据", TrainInfoRespVO.class,
|
||||||
|
BeanUtils.toBean(list, TrainInfoRespVO.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,34 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo;
|
||||||
|
|
||||||
|
import lombok.*;
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import cn.iocoder.yudao.framework.common.pojo.PageParam;
|
||||||
|
import org.springframework.format.annotation.DateTimeFormat;
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
|
||||||
|
import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND;
|
||||||
|
|
||||||
|
@Schema(description = "管理后台 - 训练信息分页 Request VO")
|
||||||
|
@Data
|
||||||
|
public class TrainInfoPageReqVO extends PageParam {
|
||||||
|
|
||||||
|
@Schema(description = "训练id", example = "27530")
|
||||||
|
private Integer trainId;
|
||||||
|
|
||||||
|
@Schema(description = "识别批次", example = "1")
|
||||||
|
private Integer round;
|
||||||
|
|
||||||
|
@Schema(description = "识别的总批次", example = "100")
|
||||||
|
private Integer roundTotal;
|
||||||
|
|
||||||
|
@Schema(description = "训练信息")
|
||||||
|
private String info;
|
||||||
|
|
||||||
|
@Schema(description = "识别结果rate")
|
||||||
|
private String rate;
|
||||||
|
|
||||||
|
@Schema(description = "创建时间")
|
||||||
|
@DateTimeFormat(pattern = FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND)
|
||||||
|
private LocalDateTime[] createTime;
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo;
|
||||||
|
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import lombok.*;
|
||||||
|
import java.time.LocalDateTime;
|
||||||
|
import com.alibaba.excel.annotation.*;
|
||||||
|
|
||||||
|
@Schema(description = "管理后台 - 训练信息 Response VO")
|
||||||
|
@Data
|
||||||
|
@ExcelIgnoreUnannotated
|
||||||
|
public class TrainInfoRespVO {
|
||||||
|
|
||||||
|
@Schema(description = "id", requiredMode = Schema.RequiredMode.REQUIRED, example = "15312")
|
||||||
|
@ExcelProperty("id")
|
||||||
|
private Integer id;
|
||||||
|
|
||||||
|
@Schema(description = "训练id", example = "27530")
|
||||||
|
@ExcelProperty("训练id")
|
||||||
|
private Integer trainId;
|
||||||
|
|
||||||
|
@Schema(description = "识别批次", example = "1")
|
||||||
|
@ExcelProperty("识别批次")
|
||||||
|
private Integer round;
|
||||||
|
|
||||||
|
@Schema(description = "识别的总批次", example = "100")
|
||||||
|
@ExcelProperty("识别的总批次")
|
||||||
|
private Integer roundTotal;
|
||||||
|
|
||||||
|
@Schema(description = "训练信息")
|
||||||
|
@ExcelProperty("训练信息")
|
||||||
|
private String info;
|
||||||
|
|
||||||
|
@Schema(description = "识别结果rate")
|
||||||
|
@ExcelProperty("识别结果rate")
|
||||||
|
private String rate;
|
||||||
|
|
||||||
|
@Schema(description = "创建时间")
|
||||||
|
@ExcelProperty("创建时间")
|
||||||
|
private LocalDateTime createTime;
|
||||||
|
|
||||||
|
@Schema(description = "修改时间")
|
||||||
|
@ExcelProperty("修改时间")
|
||||||
|
private LocalDateTime updateTime;
|
||||||
|
}
|
||||||
@ -0,0 +1,24 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo;
|
||||||
|
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import lombok.*;
|
||||||
|
|
||||||
|
@Schema(description = "管理后台 - 训练信息新增/修改 Request VO")
|
||||||
|
@Data
|
||||||
|
public class TrainInfoSaveReqVO {
|
||||||
|
|
||||||
|
@Schema(description = "训练id", requiredMode = Schema.RequiredMode.REQUIRED, example = "27530")
|
||||||
|
private Integer trainId;
|
||||||
|
|
||||||
|
@Schema(description = "识别批次", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
|
||||||
|
private Integer round;
|
||||||
|
|
||||||
|
@Schema(description = "识别的总批次", requiredMode = Schema.RequiredMode.REQUIRED, example = "100")
|
||||||
|
private Integer roundTotal;
|
||||||
|
|
||||||
|
@Schema(description = "训练信息")
|
||||||
|
private String info;
|
||||||
|
|
||||||
|
@Schema(description = "识别结果rate")
|
||||||
|
private String rate;
|
||||||
|
}
|
||||||
@ -0,0 +1,47 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
|
||||||
|
import com.baomidou.mybatisplus.annotation.KeySequence;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableId;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
|
import lombok.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 识别结果 DO
|
||||||
|
*
|
||||||
|
* @author 管理员
|
||||||
|
*/
|
||||||
|
@TableName("annotation_train_info")
|
||||||
|
@KeySequence("annotation_train_info_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
|
||||||
|
@Data
|
||||||
|
@EqualsAndHashCode(callSuper = true)
|
||||||
|
@ToString(callSuper = true)
|
||||||
|
@Builder
|
||||||
|
@NoArgsConstructor
|
||||||
|
@AllArgsConstructor
|
||||||
|
public class TrainInfoDO extends BaseDO {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* id
|
||||||
|
*/
|
||||||
|
@TableId
|
||||||
|
private Integer id;
|
||||||
|
/**
|
||||||
|
* 训练id
|
||||||
|
*/
|
||||||
|
private Integer trainId;
|
||||||
|
// 识别批次
|
||||||
|
private Integer round;
|
||||||
|
private String info;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 识别的总批次
|
||||||
|
*/
|
||||||
|
private Integer roundTotal;
|
||||||
|
/**
|
||||||
|
* 识别结果rate
|
||||||
|
*/
|
||||||
|
private String rate;
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,31 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.dal.mysql.traininfo;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
|
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||||
|
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
|
||||||
|
import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO;
|
||||||
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
|
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 训练信息 Mapper
|
||||||
|
*
|
||||||
|
* @author 管理员
|
||||||
|
*/
|
||||||
|
@Mapper
|
||||||
|
public interface TrainInfoMapper extends BaseMapperX<TrainInfoDO> {
|
||||||
|
|
||||||
|
default PageResult<TrainInfoDO> selectPage(TrainInfoPageReqVO reqVO) {
|
||||||
|
return selectPage(reqVO, new LambdaQueryWrapperX<TrainInfoDO>()
|
||||||
|
.eqIfPresent(TrainInfoDO::getTrainId, reqVO.getTrainId())
|
||||||
|
.eqIfPresent(TrainInfoDO::getRound, reqVO.getRound())
|
||||||
|
.eqIfPresent(TrainInfoDO::getRoundTotal, reqVO.getRoundTotal())
|
||||||
|
.likeIfPresent(TrainInfoDO::getInfo, reqVO.getInfo())
|
||||||
|
.eqIfPresent(TrainInfoDO::getRate, reqVO.getRate())
|
||||||
|
.betweenIfPresent(TrainInfoDO::getCreateTime, reqVO.getCreateTime())
|
||||||
|
.orderByDesc(TrainInfoDO::getId));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.service.traininfo;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
|
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoPageReqVO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoRespVO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO;
|
||||||
|
import jakarta.validation.Valid;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 训练信息 Service 接口
|
||||||
|
*
|
||||||
|
* @author 管理员
|
||||||
|
*/
|
||||||
|
public interface TrainInfoService {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建训练信息
|
||||||
|
*
|
||||||
|
* @param createReqVO 创建信息
|
||||||
|
* @return 编号
|
||||||
|
*/
|
||||||
|
Integer createTrainInfo(@Valid TrainInfoSaveReqVO createReqVO);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 更新训练信息
|
||||||
|
*
|
||||||
|
* @param updateReqVO 更新信息
|
||||||
|
*/
|
||||||
|
void updateTrainInfo(@Valid TrainInfoSaveReqVO updateReqVO);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 删除训练信息
|
||||||
|
*
|
||||||
|
* @param id 编号
|
||||||
|
*/
|
||||||
|
void deleteTrainInfo(Integer id);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获得训练信息
|
||||||
|
*
|
||||||
|
* @param id 编号
|
||||||
|
* @return 训练信息
|
||||||
|
*/
|
||||||
|
TrainInfoDO getTrainInfo(Integer id);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获得训练信息分页
|
||||||
|
*
|
||||||
|
* @param pageReqVO 分页查询
|
||||||
|
* @return 训练信息分页
|
||||||
|
*/
|
||||||
|
PageResult<TrainInfoDO> getTrainInfoPage(TrainInfoPageReqVO pageReqVO);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 保存训练轮次信息
|
||||||
|
*
|
||||||
|
* @param trainId 训练ID
|
||||||
|
* @param round 轮次
|
||||||
|
* @param roundTotal 总轮次
|
||||||
|
* @param info 轮次信息
|
||||||
|
* @param rate 识别率
|
||||||
|
* @return 编号
|
||||||
|
*/
|
||||||
|
Integer saveTrainRoundInfo(Integer trainId, Integer round, Integer roundTotal, String info, String rate);
|
||||||
|
|
||||||
|
List<TrainInfoRespVO> getTrainInfoList(Integer id);
|
||||||
|
}
|
||||||
@ -0,0 +1,92 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.service.traininfo;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.framework.common.pojo.PageResult;
|
||||||
|
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
|
||||||
|
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
|
||||||
|
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoPageReqVO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoRespVO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO;
|
||||||
|
import cn.iocoder.yudao.module.annotation.dal.mysql.traininfo.TrainInfoMapper;
|
||||||
|
import jakarta.annotation.Resource;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.validation.annotation.Validated;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 训练信息 Service 实现类
|
||||||
|
*
|
||||||
|
* @author 管理员
|
||||||
|
*/
|
||||||
|
@Service
|
||||||
|
@Validated
|
||||||
|
public class TrainInfoServiceImpl implements TrainInfoService {
|
||||||
|
|
||||||
|
@Resource
|
||||||
|
private TrainInfoMapper trainInfoMapper;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer createTrainInfo(TrainInfoSaveReqVO createReqVO) {
|
||||||
|
// 插入
|
||||||
|
TrainInfoDO data = BeanUtils.toBean(createReqVO, TrainInfoDO.class);
|
||||||
|
trainInfoMapper.insert(data);
|
||||||
|
// 返回
|
||||||
|
return data.getId();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void updateTrainInfo(TrainInfoSaveReqVO updateReqVO) {
|
||||||
|
// 校验存在
|
||||||
|
// validateTrainInfoExists(updateReqVO.getId());
|
||||||
|
// 更新
|
||||||
|
TrainInfoDO updateObj = BeanUtils.toBean(updateReqVO, TrainInfoDO.class);
|
||||||
|
trainInfoMapper.updateById(updateObj);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deleteTrainInfo(Integer id) {
|
||||||
|
// 校验存在
|
||||||
|
validateTrainInfoExists(id);
|
||||||
|
// 删除
|
||||||
|
trainInfoMapper.deleteById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void validateTrainInfoExists(Integer id) {
|
||||||
|
if (trainInfoMapper.selectById(id) == null) {
|
||||||
|
// throw exception("训练信息不存在");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TrainInfoDO getTrainInfo(Integer id) {
|
||||||
|
return trainInfoMapper.selectById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public PageResult<TrainInfoDO> getTrainInfoPage(TrainInfoPageReqVO pageReqVO) {
|
||||||
|
return trainInfoMapper.selectPage(pageReqVO);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Integer saveTrainRoundInfo(Integer trainId, Integer round, Integer roundTotal, String info, String rate) {
|
||||||
|
TrainInfoDO trainInfo = TrainInfoDO.builder()
|
||||||
|
.trainId(trainId)
|
||||||
|
.round(round)
|
||||||
|
.roundTotal(roundTotal)
|
||||||
|
.info(info)
|
||||||
|
.rate(rate)
|
||||||
|
.build();
|
||||||
|
trainInfoMapper.insert(trainInfo);
|
||||||
|
return trainInfo.getId();
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public List<TrainInfoRespVO> getTrainInfoList(Integer trainId) {
|
||||||
|
List<TrainInfoDO> trainInfoList = trainInfoMapper.selectList(new LambdaQueryWrapperX<TrainInfoDO>()
|
||||||
|
.eq(TrainInfoDO::getTrainId, trainId)
|
||||||
|
.orderByAsc(TrainInfoDO::getRound)
|
||||||
|
.last("limit 15"));
|
||||||
|
return BeanUtils.toBean(trainInfoList, TrainInfoRespVO.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,219 @@
|
|||||||
|
package cn.iocoder.yudao.module.annotation.service.yolo;
|
||||||
|
|
||||||
|
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.io.*;
|
||||||
|
import java.net.Socket;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* YOLO服务客户端 - 通过socket与Python长连接服务通信
|
||||||
|
*
|
||||||
|
* @author 管理员
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
public class YoloServiceClient {
|
||||||
|
|
||||||
|
private static final String DEFAULT_HOST = "localhost";
|
||||||
|
private static final int DEFAULT_PORT = 9999;
|
||||||
|
private static final int CONNECTION_TIMEOUT = 30000; // 30秒超时
|
||||||
|
|
||||||
|
private final ObjectMapper objectMapper = new ObjectMapper();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 通过socket服务执行预测
|
||||||
|
*/
|
||||||
|
public YoloConfig predictViaSocket(String modelPath, String inputPath, String outputPath, YoloConfig config) {
|
||||||
|
try {
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("model_path", modelPath);
|
||||||
|
params.put("input_path", inputPath);
|
||||||
|
if (outputPath != null) {
|
||||||
|
params.put("output_path", outputPath);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加其他YOLO配置参数
|
||||||
|
addYoloConfigParams(params, config);
|
||||||
|
|
||||||
|
Map<String, Object> request = new HashMap<>();
|
||||||
|
request.put("command", "predict");
|
||||||
|
request.put("params", params);
|
||||||
|
|
||||||
|
String response = sendRequest(request);
|
||||||
|
JsonNode responseJson = objectMapper.readTree(response);
|
||||||
|
|
||||||
|
if ("success".equals(responseJson.get("status").asText())) {
|
||||||
|
config.setStatus("completed");
|
||||||
|
config.setLogMessage(responseJson.get("message").asText());
|
||||||
|
if (responseJson.has("output_path")) {
|
||||||
|
config.setOutputPath(responseJson.get("output_path").asText());
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
} else {
|
||||||
|
config.setStatus("failed");
|
||||||
|
config.setErrorMessage(responseJson.get("message").asText());
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Socket预测失败", e);
|
||||||
|
config.setStatus("failed");
|
||||||
|
config.setErrorMessage("Socket预测失败: " + e.getMessage());
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 通过socket服务执行训练
|
||||||
|
*/
|
||||||
|
public YoloConfig trainViaSocket(String modelPath, String dataPath, YoloConfig config) {
|
||||||
|
try {
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("model_path", modelPath);
|
||||||
|
params.put("data_path", dataPath);
|
||||||
|
params.put("epochs", config.getEpochs());
|
||||||
|
|
||||||
|
// 添加其他训练配置参数
|
||||||
|
addYoloConfigParams(params, config);
|
||||||
|
|
||||||
|
Map<String, Object> request = new HashMap<>();
|
||||||
|
request.put("command", "train");
|
||||||
|
request.put("params", params);
|
||||||
|
|
||||||
|
String response = sendRequest(request);
|
||||||
|
JsonNode responseJson = objectMapper.readTree(response);
|
||||||
|
|
||||||
|
if ("success".equals(responseJson.get("status").asText())) {
|
||||||
|
config.setStatus("completed");
|
||||||
|
config.setLogMessage(responseJson.get("message").asText());
|
||||||
|
if (responseJson.has("results_path")) {
|
||||||
|
config.setOutputPath(responseJson.get("results_path").asText());
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
} else {
|
||||||
|
config.setStatus("failed");
|
||||||
|
config.setErrorMessage(responseJson.get("message").asText());
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Socket训练失败", e);
|
||||||
|
config.setStatus("failed");
|
||||||
|
config.setErrorMessage("Socket训练失败: " + e.getMessage());
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 通过socket服务执行验证
|
||||||
|
*/
|
||||||
|
public YoloConfig validateViaSocket(String modelPath, String dataPath, YoloConfig config) {
|
||||||
|
try {
|
||||||
|
Map<String, Object> params = new HashMap<>();
|
||||||
|
params.put("model_path", modelPath);
|
||||||
|
if (dataPath != null) {
|
||||||
|
params.put("data_path", dataPath);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加其他验证配置参数
|
||||||
|
addYoloConfigParams(params, config);
|
||||||
|
|
||||||
|
Map<String, Object> request = new HashMap<>();
|
||||||
|
request.put("command", "validate");
|
||||||
|
request.put("params", params);
|
||||||
|
|
||||||
|
String response = sendRequest(request);
|
||||||
|
JsonNode responseJson = objectMapper.readTree(response);
|
||||||
|
|
||||||
|
if ("success".equals(responseJson.get("status").asText())) {
|
||||||
|
config.setStatus("completed");
|
||||||
|
config.setLogMessage(responseJson.get("message").asText());
|
||||||
|
if (responseJson.has("metrics")) {
|
||||||
|
config.setMetrics(responseJson.get("metrics").asText());
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
} else {
|
||||||
|
config.setStatus("failed");
|
||||||
|
config.setErrorMessage(responseJson.get("message").asText());
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("Socket验证失败", e);
|
||||||
|
config.setStatus("failed");
|
||||||
|
config.setErrorMessage("Socket验证失败: " + e.getMessage());
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 添加YOLO配置参数到请求中
|
||||||
|
*/
|
||||||
|
private void addYoloConfigParams(Map<String, Object> params, YoloConfig config) {
|
||||||
|
if (config.getImgsz() > 0) {
|
||||||
|
params.put("imgsz", config.getImgsz());
|
||||||
|
}
|
||||||
|
if (config.getConfThresh() > 0) {
|
||||||
|
params.put("conf", config.getConfThresh());
|
||||||
|
}
|
||||||
|
if (config.getIouThresh() > 0) {
|
||||||
|
params.put("iou", config.getIouThresh());
|
||||||
|
}
|
||||||
|
if (config.getDevice() != null) {
|
||||||
|
params.put("device", config.getDevice());
|
||||||
|
}
|
||||||
|
if (config.getBatchSize() > 0) {
|
||||||
|
params.put("batch", config.getBatchSize());
|
||||||
|
}
|
||||||
|
params.put("save", config.isSave());
|
||||||
|
params.put("save_txt", config.isSaveTxt());
|
||||||
|
params.put("save_crops", config.isSaveCrops());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 发送请求到socket服务
|
||||||
|
*/
|
||||||
|
private String sendRequest(Map<String, Object> request) throws IOException {
|
||||||
|
try (Socket socket = new Socket(DEFAULT_HOST, DEFAULT_PORT);
|
||||||
|
PrintWriter out = new PrintWriter(socket.getOutputStream(), true);
|
||||||
|
BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()))) {
|
||||||
|
|
||||||
|
// 设置超时
|
||||||
|
socket.setSoTimeout(CONNECTION_TIMEOUT);
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
String requestJson = objectMapper.writeValueAsString(request);
|
||||||
|
out.println(requestJson);
|
||||||
|
|
||||||
|
// 读取响应
|
||||||
|
StringBuilder response = new StringBuilder();
|
||||||
|
String line;
|
||||||
|
while ((line = in.readLine()) != null) {
|
||||||
|
response.append(line);
|
||||||
|
}
|
||||||
|
|
||||||
|
return response.toString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查服务是否可用
|
||||||
|
*/
|
||||||
|
public boolean isServiceAvailable() {
|
||||||
|
try {
|
||||||
|
try (Socket socket = new Socket()) {
|
||||||
|
socket.connect(new java.net.InetSocketAddress(DEFAULT_HOST, DEFAULT_PORT), 3000);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} catch (IOException e) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue