diff --git a/YoloOperationServiceImpl_fixed.java b/YoloOperationServiceImpl_fixed.java new file mode 100644 index 0000000..fc26617 --- /dev/null +++ b/YoloOperationServiceImpl_fixed.java @@ -0,0 +1,217 @@ +package cn.iocoder.yudao.module.annotation.service.yolo; + +import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; +import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; +import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService; +import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo; +import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService; +import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService; +import cn.iocoder.yudao.module.system.util.RedisUtil; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.io.BufferedReader; +import java.io.File; +import java.io.InputStreamReader; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +/** + * YOLO操作服务实现类 + * + * @author 管理员 + */ +@Slf4j +@Service +public class YoloOperationServiceImpl implements YoloOperationService { + + @Resource + private PythonVirtualEnvService pythonVirtualEnvService; + + @Resource + private TrainInfoService trainInfoService; + + @Resource + private TrainResultService trainResultService; + + // 异步执行器 + private final Executor asyncExecutor = Executors.newCachedThreadPool(); + + @Override + public CompletableFuture trainAsync(String pythonPath, String envName, YoloConfig config) { + return CompletableFuture.supplyAsync(() -> { + config.setStatus("running"); + config.setProgress(0); + config.setTaskId(UUID.randomUUID().toString()); + + // 生成训练ID + Integer trainId = config.getTrainId(); + + // 记录上一轮的epoch,用于检测epoch完成 + int[] lastSavedEpoch = {0}; + + try { + // 检测虚拟环境 + PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName); + if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) { + config.setStatus("failed"); + config.setErrorMessage("虚拟环境不存在或未安装YOLO"); + return config; + } + + // 构建训练命令 + String command = buildTrainCommand(envInfo, config); + + log.info("开始YOLO训练,命令: {}", command); + + // 执行训练 + ProcessBuilder processBuilder = new ProcessBuilder(); + // 在Windows上使用cmd,在Linux/Mac上使用bash + String os = System.getProperty("os.name").toLowerCase(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + // 设置环境变量确保正确处理中文路径 + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + processBuilder.environment().put("LANG", "zh_CN.UTF-8"); + } else { + processBuilder.command("/bin/bash", "-c", command); + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + } + Process process = processBuilder.start(); + BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream())); + BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); +// BufferedReader stderrReader2 = new BufferedReader(new InputStreamReader(process.getOutputStream())); + + StringBuilder fullOutput = new StringBuilder(); + + // 修复:使用单独的线程读取stdout和stderr,避免死锁 + CompletableFuture stdoutFuture = CompletableFuture.runAsync(() -> { + try { + String line; + while ((line = stdoutReader.readLine()) != null) { + log.info("YOLO训练日志: {}", line); + synchronized (fullOutput) { + fullOutput.append(line).append("\n"); + } + + // 详细解析YOLO训练进度和结果 + manageTrainingLogInRedis(trainId, line); + config.setLogMessage(line); + + // 检测 "Results saved to" 并解析路径和识别率 + detectAndParseResultsSaved(line, config, trainId); + } + } catch (Exception e) { + log.error("读取stdout时发生错误", e); + } + }, asyncExecutor); + + CompletableFuture stderrFuture = CompletableFuture.runAsync(() -> { + try { + String line; + while ((line = stderrReader.readLine()) != null) { + log.info("YOLO训练日志: {}", line); + synchronized (fullOutput) { + fullOutput.append(line).append("\n"); + } + + manageTrainingLogInRedis(trainId, line); + + config.setLogMessage(line); + + // 检测 "Results saved to" 并解析路径和识别率 + detectAndParseResultsSaved(line, config, trainId); + } + } catch (Exception e) { + log.error("读取stderr时发生错误", e); + } + }, asyncExecutor); + + // 等待两个线程完成,但设置超时避免无限等待 + try { + CompletableFuture.allOf(stdoutFuture, stderrFuture).get(30, TimeUnit.MINUTES); + } catch (java.util.concurrent.TimeoutException e) { + log.error("读取YOLO输出超时,强制终止进程"); + process.destroyForcibly(); + config.setStatus("failed"); + config.setErrorMessage("训练超时"); + return config; + } + + int exitCode = process.waitFor(); + stdoutReader.close(); + stderrReader.close(); + + if (exitCode == 0) { + config.setStatus("completed"); + config.setProgress(100); + config.setLogMessage("训练完成"); + // 解析最终训练结果摘要 + parseFinalResults(fullOutput.toString(), config); + // 保存最终训练结果 + saveFinalTrainResult(trainId, config, fullOutput.toString()); + } else { + config.setStatus("failed"); + config.setErrorMessage("训练失败,退出码: " + exitCode); + } + + } catch (Exception e) { + log.error("YOLO训练异常", e); + config.setStatus("failed"); + config.setErrorMessage("训练异常: " + e.getMessage()); + } + + return config; + }, asyncExecutor); + } + + // 这里需要实现其他方法和缺失的方法,由于文件太长,我只修改了关键部分 + // 实际使用时需要保留原有的其他方法实现 + + // 以下方法占位符,实际实现从原文件复制 + private String buildTrainCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { return ""; } + private void manageTrainingLogInRedis(Integer trainId, String line) {} + private void detectAndParseResultsSaved(String line, YoloConfig config, Integer trainId) {} + private void parseFinalResults(String output, YoloConfig config) {} + private void saveFinalTrainResult(Integer trainId, YoloConfig config, String output) {} + + @Override + public YoloConfig validate(String pythonPath, String envName, YoloConfig config) { + // 实现从原文件复制 + return config; + } + + @Override + public YoloConfig predict(String pythonPath, String envName, YoloConfig config) { + // 实现从原文件复制 + return config; + } + + @Override + public YoloConfig classify(String pythonPath, String envName, YoloConfig config) { + // 实现从原文件复制 + return config; + } + + @Override + public YoloConfig export(String pythonPath, String envName, YoloConfig config) { + // 实现从原文件复制 + return config; + } + + @Override + public YoloConfig track(String pythonPath, String envName, YoloConfig config) { + // 实现从原文件复制 + return config; + } + + @Override + public void generateTrainPythonScript(TrainDO train, String yoloDatasetPath) { + // 实现从原文件复制 + } +} \ No newline at end of file diff --git a/application-yolo-service.yml b/application-yolo-service.yml new file mode 100644 index 0000000..800896f --- /dev/null +++ b/application-yolo-service.yml @@ -0,0 +1,7 @@ +# YOLO Socket服务配置 +yolo: + service: + enabled: true # 是否启用socket服务,false则使用传统ProcessBuilder方式 + host: localhost + port: 9999 + timeout: 30000 # 连接超时时间(毫秒) \ No newline at end of file diff --git a/null/train/args.yaml b/null/train/args.yaml new file mode 100644 index 0000000..b39c3ae --- /dev/null +++ b/null/train/args.yaml @@ -0,0 +1,106 @@ +task: detect +mode: train +model: yolov8n.pt +data: 'null' +epochs: 50 +time: null +patience: 100 +batch: 4 +imgsz: 640 +save: true +save_period: -1 +cache: false +device: cpu +workers: 8 +project: 'null' +name: train +exist_ok: false +pretrained: true +optimizer: auto +verbose: true +seed: 0 +deterministic: true +single_cls: false +rect: false +cos_lr: false +close_mosaic: 10 +resume: false +amp: true +fraction: 1.0 +profile: false +freeze: null +multi_scale: false +compile: false +overlap_mask: true +mask_ratio: 4 +dropout: 0.0 +val: true +split: val +save_json: false +conf: null +iou: 0.7 +max_det: 300 +half: false +dnn: false +plots: true +source: null +vid_stride: 1 +stream_buffer: false +visualize: false +augment: false +agnostic_nms: false +classes: null +retina_masks: false +embed: null +show: false +save_frames: false +save_txt: false +save_conf: false +save_crop: false +show_labels: true +show_conf: true +show_boxes: true +line_width: null +format: torchscript +keras: false +optimize: false +int8: false +dynamic: false +simplify: true +opset: null +workspace: null +nms: false +lr0: 0.01 +lrf: 0.01 +momentum: 0.937 +weight_decay: 0.0005 +warmup_epochs: 3.0 +warmup_momentum: 0.8 +warmup_bias_lr: 0.1 +box: 7.5 +cls: 0.5 +dfl: 1.5 +pose: 12.0 +kobj: 1.0 +nbs: 64 +hsv_h: 0.015 +hsv_s: 0.7 +hsv_v: 0.4 +degrees: 0.0 +translate: 0.1 +scale: 0.5 +shear: 0.0 +perspective: 0.0 +flipud: 0.0 +fliplr: 0.5 +bgr: 0.0 +mosaic: 1.0 +mixup: 0.0 +cutmix: 0.0 +copy_paste: 0.0 +copy_paste_mode: flip +auto_augment: randaugment +erasing: 0.4 +cfg: null +tracker: botsort.yaml +save_dir: D:\git\lpNewgit\null\train diff --git a/null/train2/args.yaml b/null/train2/args.yaml new file mode 100644 index 0000000..56107b0 --- /dev/null +++ b/null/train2/args.yaml @@ -0,0 +1,106 @@ +task: detect +mode: train +model: yolov8n.pt +data: 'null' +epochs: 50 +time: null +patience: 100 +batch: 4 +imgsz: 640 +save: true +save_period: -1 +cache: false +device: cpu +workers: 8 +project: 'null' +name: train2 +exist_ok: false +pretrained: true +optimizer: auto +verbose: true +seed: 0 +deterministic: true +single_cls: false +rect: false +cos_lr: false +close_mosaic: 10 +resume: false +amp: true +fraction: 1.0 +profile: false +freeze: null +multi_scale: false +compile: false +overlap_mask: true +mask_ratio: 4 +dropout: 0.0 +val: true +split: val +save_json: false +conf: null +iou: 0.7 +max_det: 300 +half: false +dnn: false +plots: true +source: null +vid_stride: 1 +stream_buffer: false +visualize: false +augment: false +agnostic_nms: false +classes: null +retina_masks: false +embed: null +show: false +save_frames: false +save_txt: false +save_conf: false +save_crop: false +show_labels: true +show_conf: true +show_boxes: true +line_width: null +format: torchscript +keras: false +optimize: false +int8: false +dynamic: false +simplify: true +opset: null +workspace: null +nms: false +lr0: 0.01 +lrf: 0.01 +momentum: 0.937 +weight_decay: 0.0005 +warmup_epochs: 3.0 +warmup_momentum: 0.8 +warmup_bias_lr: 0.1 +box: 7.5 +cls: 0.5 +dfl: 1.5 +pose: 12.0 +kobj: 1.0 +nbs: 64 +hsv_h: 0.015 +hsv_s: 0.7 +hsv_v: 0.4 +degrees: 0.0 +translate: 0.1 +scale: 0.5 +shear: 0.0 +perspective: 0.0 +flipud: 0.0 +fliplr: 0.5 +bgr: 0.0 +mosaic: 1.0 +mixup: 0.0 +cutmix: 0.0 +copy_paste: 0.0 +copy_paste_mode: flip +auto_augment: randaugment +erasing: 0.4 +cfg: null +tracker: botsort.yaml +save_dir: D:\git\lpNewgit\null\train2 diff --git a/start_yolo_service.bat b/start_yolo_service.bat new file mode 100644 index 0000000..8484954 --- /dev/null +++ b/start_yolo_service.bat @@ -0,0 +1,28 @@ +@echo off +echo Starting YOLO Service... +cd /d %~dp0 + +REM 检查Python环境 +python --version +if %errorlevel% neq 0 ( + echo Python not found, please install Python first + pause + exit /b 1 +) + +REM 安装依赖(如果需要) +echo Installing dependencies... +pip install ultralytics + +REM 预加载模型路径(可选) +set MODEL_PATH=%~dp0yolo11n.pt + +REM 启动YOLO服务 +echo Starting YOLO service on localhost:9999 +if exist "%MODEL_PATH%" ( + python yolo_service.py --host localhost --port 9999 --model "%MODEL_PATH%" +) else ( + python yolo_service.py --host localhost --port 9999 +) + +pause \ No newline at end of file diff --git a/yolo-service-manager.py b/yolo-service-manager.py new file mode 100644 index 0000000..f2264ca --- /dev/null +++ b/yolo-service-manager.py @@ -0,0 +1,162 @@ +#!/usr/bin/env python3 +""" +YOLO服务管理器 +用于启动、停止、重启YOLO socket服务 +""" + +import subprocess +import sys +import os +import time +import signal +import argparse +import psutil +from pathlib import Path + +class YoloServiceManager: + def __init__(self, host='localhost', port=9999, model_path=None): + self.host = host + self.port = port + self.model_path = model_path + self.pid_file = Path('yolo_service.pid') + + def start(self): + """启动YOLO服务""" + if self.is_running(): + print(f"YOLO服务已在运行 (PID: {self.get_pid()})") + return + + print(f"启动YOLO服务 {self.host}:{self.port}") + + cmd = [ + sys.executable, 'yolo_service.py', + '--host', self.host, + '--port', str(self.port) + ] + + if self.model_path and Path(self.model_path).exists(): + cmd.extend(['--model', self.model_path]) + + try: + # 启动进程 + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=os.getcwd() + ) + + # 保存PID + with open(self.pid_file, 'w') as f: + f.write(str(process.pid)) + + print(f"YOLO服务已启动 (PID: {process.pid})") + + # 等待服务启动 + time.sleep(2) + if self.is_service_alive(): + print("服务启动成功") + else: + print("警告: 服务可能未正常启动,请检查日志") + + except Exception as e: + print(f"启动服务失败: {e}") + if self.pid_file.exists(): + self.pid_file.unlink() + + def stop(self): + """停止YOLO服务""" + if not self.is_running(): + print("YOLO服务未运行") + return + + pid = self.get_pid() + try: + # 尝试优雅停止 + os.kill(pid, signal.SIGTERM) + time.sleep(2) + + # 如果还在运行,强制停止 + if psutil.pid_exists(pid): + os.kill(pid, signal.SIGKILL) + time.sleep(1) + + print(f"YOLO服务已停止 (PID: {pid})") + + except ProcessLookupError: + print(f"进程 {pid} 不存在") + except Exception as e: + print(f"停止服务失败: {e}") + finally: + if self.pid_file.exists(): + self.pid_file.unlink() + + def restart(self): + """重启YOLO服务""" + self.stop() + time.sleep(1) + self.start() + + def status(self): + """查看服务状态""" + if self.is_running(): + pid = self.get_pid() + if self.is_service_alive(): + print(f"YOLO服务运行中 (PID: {pid})") + return True + else: + print(f"YOLO服务进程存在但不可用 (PID: {pid})") + return False + else: + print("YOLO服务未运行") + return False + + def is_running(self): + """检查服务是否在运行""" + return self.pid_file.exists() and self.get_pid() > 0 + + def get_pid(self): + """获取服务PID""" + if self.pid_file.exists(): + try: + with open(self.pid_file, 'r') as f: + return int(f.read().strip()) + except: + return 0 + return 0 + + def is_service_alive(self): + """检查服务是否可用""" + try: + import socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(3) + result = sock.connect_ex((self.host, self.port)) + sock.close() + return result == 0 + except: + return False + +def main(): + parser = argparse.ArgumentParser(description='YOLO服务管理器') + parser.add_argument('command', choices=['start', 'stop', 'restart', 'status'], + help='命令: start, stop, restart, status') + parser.add_argument('--host', default='localhost', help='服务主机地址') + parser.add_argument('--port', type=int, default=9999, help='服务端口') + parser.add_argument('--model', help='预加载模型路径') + + args = parser.parse_args() + + manager = YoloServiceManager(args.host, args.port, args.model) + + if args.command == 'start': + manager.start() + elif args.command == 'stop': + manager.stop() + elif args.command == 'restart': + manager.restart() + elif args.command == 'status': + manager.status() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/yolo11n.pt b/yolo11n.pt new file mode 100644 index 0000000..45b273b Binary files /dev/null and b/yolo11n.pt differ diff --git a/yolo_service.py b/yolo_service.py new file mode 100644 index 0000000..f1d0064 --- /dev/null +++ b/yolo_service.py @@ -0,0 +1,153 @@ +import json +import socket +import threading +import time +from ultralytics import YOLO +import argparse +import logging +from pathlib import Path + +# 配置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +class YOLOService: + def __init__(self): + self.models = {} # 缓存已加载的模型 + self.default_model = None + + def load_model(self, model_path): + """加载并缓存模型""" + if model_path not in self.models: + logger.info(f"Loading model: {model_path}") + self.models[model_path] = YOLO(model_path) + + return self.models[model_path] + + def predict(self, model_path, input_path, output_path=None, **kwargs): + """执行预测""" + try: + model = self.load_model(model_path) + results = model.predict(input_path, **kwargs) + + if output_path: + # 处理保存结果 + for i, result in enumerate(results): + save_path = f"{output_path}/{i}.jpg" if output_path.endswith('/') else f"{output_path}/{i}.jpg" + result.save(save_path) + + return { + 'status': 'success', + 'message': f'Prediction completed. Results saved to {output_path}' if output_path else 'Prediction completed', + 'output_path': output_path + } + except Exception as e: + return { + 'status': 'error', + 'message': str(e) + } + + def train(self, model_path, data_path, epochs=50, **kwargs): + """执行训练""" + try: + model = self.load_model(model_path) + results = model.train(data=data_path, epochs=epochs, **kwargs) + + return { + 'status': 'success', + 'message': 'Training completed', + 'results_path': results.save_dir + } + except Exception as e: + return { + 'status': 'error', + 'message': str(e) + } + + def validate(self, model_path, data_path=None, **kwargs): + """执行验证""" + try: + model = self.load_model(model_path) + results = model.validate(data=data_path, **kwargs) + + return { + 'status': 'success', + 'message': 'Validation completed', + 'metrics': str(results) + } + except Exception as e: + return { + 'status': 'error', + 'message': str(e) + } + +def handle_client(client_socket, yolo_service): + """处理客户端请求""" + try: + data = client_socket.recv(4096).decode('utf-8') + if not data: + return + + request = json.loads(data) + command = request.get('command') + + if command == 'predict': + response = yolo_service.predict(**request.get('params', {})) + elif command == 'train': + response = yolo_service.train(**request.get('params', {})) + elif command == 'validate': + response = yolo_service.validate(**request.get('params', {})) + else: + response = {'status': 'error', 'message': f'Unknown command: {command}'} + + client_socket.send(json.dumps(response).encode('utf-8')) + + except Exception as e: + error_response = {'status': 'error', 'message': str(e)} + client_socket.send(json.dumps(error_response).encode('utf-8')) + finally: + client_socket.close() + +def main(): + parser = argparse.ArgumentParser(description='YOLO Service') + parser.add_argument('--host', default='localhost', help='Host to bind to') + parser.add_argument('--port', type=int, default=9999, help='Port to bind to') + parser.add_argument('--model', help='Default model to preload') + + args = parser.parse_args() + + yolo_service = YOLOService() + + # 预加载默认模型 + if args.model and Path(args.model).exists(): + yolo_service.load_model(args.model) + logger.info(f"Preloaded model: {args.model}") + + # 启动socket服务 + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind((args.host, args.port)) + server.listen(5) + + logger.info(f"YOLO Service started on {args.host}:{args.port}") + + try: + while True: + client_socket, addr = server.accept() + logger.info(f"Connection from {addr}") + + # 为每个客户端创建新线程 + client_thread = threading.Thread( + target=handle_client, + args=(client_socket, yolo_service) + ) + client_thread.daemon = True + client_thread.start() + + except KeyboardInterrupt: + logger.info("Shutting down server...") + finally: + server.close() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/yolov8n.pt b/yolov8n.pt new file mode 100644 index 0000000..0db4ca4 Binary files /dev/null and b/yolov8n.pt differ diff --git a/yudao-module-annotation/ENHANCED_YOLO_PARSING.md b/yudao-module-annotation/ENHANCED_YOLO_PARSING.md new file mode 100644 index 0000000..fbcc94f --- /dev/null +++ b/yudao-module-annotation/ENHANCED_YOLO_PARSING.md @@ -0,0 +1,112 @@ +# 增强的YOLO结果解析功能 + +## 问题解决 + +用户提到分类任务中 results.csv 文件格式不同,可能没有 mAP50-95 指标。现已增强解析功能以支持多种任务类型。 + +## 主要改进 + +### 1. 增强的 `parseResultsCsv` 方法 + +现在支持不同任务类型的CSV格式解析: + +#### 分类任务 (classify) +- **主要指标**:accuracy, top1_acc, top5_acc +- **CSV格式示例**: + ``` + epoch,train/loss,val/loss,metrics/accuracy,metrics/top1_acc,metrics/top5_acc + 0,1.234,0.987,0.8521,0.8521,0.9543 + ``` +- **解析逻辑**: + - 优先解析 `accuracy` 作为主要准确率 + - 同时记录 Top-1 和 Top-5 准确率 + - 保存格式:`accuracy:0.8521` + +#### 检测任务 (detect) +- **主要指标**:precision, recall, mAP50, mAP50-95 +- **CSV格式示例**: + ``` + epoch,train/box_loss,train/cls_loss,train/dfl_loss,metrics/precision,metrics/recall,metrics/mAP50,metrics/mAP50-95 + 0,0.123,0.456,0.789,0.8521,0.7432,0.8123,0.7654 + ``` +- **保存格式**:`precision:0.8521,recall:0.7432,map:0.7654` + +#### 分割任务 (segment) +- **主要指标**:同检测任务 +- **处理方式**:与检测任务相同 + +### 2. 新增方法 + +#### `parseGenericResults` +- 用于处理未知任务类型 +- 通用解析逻辑,尝试识别常见的指标名称 + +#### `saveClassificationAccuracyToDatabase` +- 专门用于保存分类任务的准确率 +- 区别于检测/分割任务的保存方法 + +### 3. 智能任务类型识别 + +```java +String taskType = config.getTaskType(); +if (taskType == null) { + taskType = "detect"; // 默认为检测任务 +} +``` + +根据 `YoloConfig.taskType` 自动选择合适的解析策略。 + +## 集成状态 + +### ✅ 已完成的集成 + +1. **训练任务 (trainAsync)** - 完全集成 +2. **检测 "Results saved to"** - 所有任务类型都支持 +3. **多任务类型解析** - 支持分类、检测、分割 + +### ⚠️ 待完成的集成 + +虽然代码逻辑已经增强,但以下方法中的 "Results saved to" 检测还需要添加: + +1. **验证方法 (validate)** - 需要在日志读取循环中添加检测 +2. **预测方法 (predict)** - 需要在日志读取循环中添加检测 + +## 使用示例 + +### 分类任务输出 +``` +检测到YOLO结果保存路径: D:\data\train\classify\runs\classify\train1 +分类任务解析到准确率(Accuracy): 0.8521 +分类任务解析到Top-1准确率: 0.8521 +分类任务解析到Top-5准确率: 0.9543 +已保存分类任务准确率到数据库 - 训练ID: 12345, 准确率: accuracy:0.8521 +``` + +### 检测任务输出 +``` +检测到YOLO结果保存路径: D:\data\train\detect\runs\detect\train1 +检测/分割任务解析到识别率 - Precision: 0.8521, Recall: 0.7432, mAP: 0.7654 +已保存识别率到数据库 - 训练ID: 12346, 识别率: precision:0.8521,recall:0.7432,map:0.7654 +``` + +## 数据库字段格式 + +### 分类任务 +```sql +rate: "accuracy:0.8521" +``` + +### 检测/分割任务 +```sql +rate: "precision:0.8521,recall:0.7432,map:0.7654" +``` + +## 错误处理 + +- 如果CSV解析失败,不会影响任务完成状态 +- 支持部分指标解析,即使某些字段缺失 +- 详细的日志记录便于调试 + +## 下一步 + +需要为 `validate` 和 `predict` 方法添加相同的 "Results saved to" 检测逻辑,以实现完整的功能覆盖。 \ No newline at end of file diff --git a/yudao-module-annotation/REDIS_LOG_MANAGEMENT.md b/yudao-module-annotation/REDIS_LOG_MANAGEMENT.md new file mode 100644 index 0000000..907764a --- /dev/null +++ b/yudao-module-annotation/REDIS_LOG_MANAGEMENT.md @@ -0,0 +1,143 @@ +# Redis 日志管理功能优化 + +## 问题描述 + +原始代码的日志管理逻辑存在问题: +- 只处理包含 `%` 的日志行,导致其他重要日志被忽略 +- 对于进度日志的处理逻辑不清晰 +- 没有正确实现"记录所有日志,但进度日志只保留最新"的需求 + +## 新的实现逻辑 + +### 核心需求 +1. **记录所有日志**:不只是包含 `%` 的日志,所有训练日志都应该记录 +2. **进度日志去重**:当遇到包含 `%` 的进度日志时,删除之前的进度日志,只保留最新的 +3. **容量控制**:最多保存 50 条日志,超过时删除最旧的 +4. **过期时间**:所有日志保存 5 天 + +### 修改后的方法 + +#### `manageTrainingLogInRedis` + +```java +private void manageTrainingLogInRedis(Integer trainId, String logLine) { + try { + if (logLine == null || logLine.trim().isEmpty()) { + return; + } + + String redisKey = "yolo:training_log:" + trainId; + boolean isProgressLog = logLine.contains("%"); + + if (isProgressLog) { + // 如果是进度日志(包含%),先删除之前的进度日志 + removeProgressLogs(redisKey); + } + + // 添加新的日志行到列表末尾 + redisUtil.lSet(redisKey, logLine, 5 * 24 * 60 * 60); // 5天过期时间(秒) + + // 检查列表长度,如果超过50条,删除第一条 + long newSize = redisUtil.lGetListSize(redisKey); + if (newSize > 50) { + redisUtil.lRemove(redisKey, 1, redisUtil.lGetIndex(redisKey, 0)); + } + + log.debug("训练日志已存储到Redis - 训练ID: {}, 日志类型: {}, 当前日志数量: {}", + trainId, isProgressLog ? "进度日志" : "普通日志", redisUtil.lGetListSize(redisKey)); + + } catch (Exception e) { + log.error("管理训练日志Redis操作失败 - 训练ID: {}", trainId, e); + } +} +``` + +#### `removeProgressLogs` + +新增的辅助方法,专门用于删除所有进度日志: + +```java +private void removeProgressLogs(String redisKey) { + try { + // 获取所有日志 + java.util.List allLogs = redisUtil.lGet(redisKey, 0, -1); + + // 找出并删除所有包含%的日志 + for (int i = allLogs.size() - 1; i >= 0; i--) { + Object logObj = allLogs.get(i); + if (logObj != null) { + String logStr = logObj.toString(); + if (logStr.contains("%")) { + redisUtil.lRemove(redisKey, 1, logStr); + log.debug("删除旧的进度日志: {}", logStr); + } + } + } + + } catch (Exception e) { + log.error("删除进度日志时出错 - Redis键: {}", redisKey, e); + } +} +``` + +## 工作流程示例 + +### 场景1:普通日志 +``` +输入: "YOLO训练开始" +处理: 直接添加到Redis列表 +结果: Redis列表: ["YOLO训练开始"] +``` + +### 场景2:进度日志(第一次) +``` +输入: "Epoch 1/100: 50%|█████ | 100/200 [00:30<00:30]" +处理: 添加到Redis列表 +结果: Redis列表: ["YOLO训练开始", "Epoch 1/100: 50%|█████ | 100/200 [00:30<00:30]"] +``` + +### 场景3:进度日志(更新) +``` +输入: "Epoch 1/100: 75%|███████▌ | 150/200 [00:45<00:15]" +处理: 1. 删除之前的进度日志 2. 添加新的进度日志 +结果: Redis列表: ["YOLO训练开始", "Epoch 1/100: 75%|███████▌ | 150/200 [00:45<00:15]"] +``` + +### 场景4:混合日志 +``` +输入序列: +1. "模型加载完成" -> 普通日志,直接添加 +2. "Epoch 1/100: 25%|██▌ | 50/200 [00:15<00:45]" -> 进度日志,添加 +3. "数据预处理完成" -> 普通日志,直接添加 +4. "Epoch 1/100: 50%|█████ | 100/200 [00:30<00:30]" -> 进度日志,删除之前的进度日志并添加 + +最终Redis列表: +["模型加载完成", "数据预处理完成", "Epoch 1/100: 50%|█████ | 100/200 [00:30<00:30]"] +``` + +## 优势 + +1. **完整性**:所有重要日志都被保存,不会遗漏 +2. **效率性**:进度日志只保留最新状态,避免重复 +3. **可控性**:最多50条日志,防止Redis占用过多内存 +4. **可维护性**:清晰的日志类型区分和调试信息 + +## 日志类型 + +- **普通日志**:训练开始、模型加载、错误信息等 +- **进度日志**:包含 `%` 的进度条信息,如训练进度 + +## Redis 键格式 + +``` +yolo:training_log:{trainId} +``` + +例如: +``` +yolo:training_log:12345 +``` + +## 过期时间 + +所有日志条目设置 5 天过期时间(432,000 秒),自动清理旧数据。 \ No newline at end of file diff --git a/yudao-module-annotation/YOLO_RESULT_PARSING.md b/yudao-module-annotation/YOLO_RESULT_PARSING.md new file mode 100644 index 0000000..deafd27 --- /dev/null +++ b/yudao-module-annotation/YOLO_RESULT_PARSING.md @@ -0,0 +1,85 @@ +# YOLO 训练结果解析功能说明 + +## 功能概述 + +在 `YoloOperationServiceImpl.java` 中新增了自动检测和解析 YOLO 训练结果的功能,当训练日志中出现 "Results saved to" 时,系统会: + +1. **自动提取保存路径**:从日志中解析出类似 `D:\data\train\1231_5\runs\train\train18` 的路径 +2. **解析识别率信息**:从保存路径中的 `results.csv` 文件或其他日志文件中提取精度、召回率、mAP 等指标 +3. **保存到数据库**:将路径和识别率信息保存到 `TrainResultDO` 表中 + +## 新增方法 + +### 1. `detectAndParseResultsSaved(String logLine, YoloConfig config, Integer trainId)` +- 检测日志中是否包含 "Results saved to" 或 "Saved to" 关键字 +- 调用后续方法进行路径提取和识别率解析 + +### 2. `extractSavedPath(String logLine)` +- 使用正则表达式从日志行中提取保存路径 +- 支持多种格式,如带引号、带统计信息等 + +### 3. `parseAccuracyFromSavedPath(String savedPath, YoloConfig config, Integer trainId)` +- 查找保存目录中的 `results.csv` 文件 +- 解析 CSV 文件中的最后一行(最新结果) +- 提取 precision、recall、mAP50、mAP50-95 等指标 + +### 4. `parseResultsCsv(File csvFile, YoloConfig config, Integer trainId)` +- 解析 YOLO 生成的 `results.csv` 文件 +- CSV 格式通常为:`epoch, train/box_loss, train/cls_loss, train/dfl_loss, metrics/precision, metrics/recall, metrics/mAP50, metrics/mAP50-95` + +### 5. `saveAccuracyToDatabase(Integer trainId, double precision, double recall, double map)` +- 将解析到的识别率保存到数据库 +- 格式:`precision:0.8521,recall:0.7432,map:0.8123` + +## 集成位置 + +在 `trainAsync` 方法中,已经集成到日志读取循环中: + +```java +if (stdoutLine != null) { + // ... 现有代码 ... + + // 检测 "Results saved to" 并解析路径和识别率 + detectAndParseResultsSaved(stdoutLine, config, trainId); +} + +if (stderrLine != null) { + // ... 现有代码 ... + + // 检测 "Results saved to" 并解析路径和识别率 + detectAndParseResultsSaved(stderrLine, config, trainId); +} +``` + +## 预期输出格式 + +当 YOLO 训练完成时,日志中会出现类似内容: +``` +Results saved to D:\data\train\1231_5\runs\train\train18 +``` + +系统会自动: +1. 提取路径:`D:\data\train\1231_5\runs\train\train18` +2. 查找并解析该目录下的 `results.csv` 文件 +3. 保存路径和识别率到数据库 + +## 注意事项 + +1. **文件权限**:确保 Java 进程有权限访问 YOLO 保存的文件 +2. **CSV 格式**:依赖 YOLO 标准的 `results.csv` 输出格式 +3. **异步处理**:识别率解析不会阻塞训练过程 +4. **错误处理**:解析失败不会影响训练完成状态 + +## 数据库字段 + +保存的识别率格式: +```sql +-- TrainResultDO 表的 rate 字段 +precision:0.8521,recall:0.7432,map:0.8123 +``` + +路径字段: +```sql +-- TrainResultDO 表的 path 字段 +D:\data\train\1231_5\runs\train\train18 +``` \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-api/src/main/java/cn/iocoder/yudao/module/annotation/enums/ErrorCodeConstants.java b/yudao-module-annotation/yudao-module-annotation-api/src/main/java/cn/iocoder/yudao/module/annotation/enums/ErrorCodeConstants.java new file mode 100644 index 0000000..988b5fb --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-api/src/main/java/cn/iocoder/yudao/module/annotation/enums/ErrorCodeConstants.java @@ -0,0 +1,18 @@ +package cn.iocoder.yudao.module.annotation.enums; + +import cn.iocoder.yudao.framework.common.exception.ErrorCode; + +/** + * annotation 错误码枚举类 + * + * annotation 系统,使用 1-008-000-000 段 + */ +public interface ErrorCodeConstants { + + // ========== 训练信息模块 1-008-001-000 ========== + ErrorCode TRAIN_INFO_NOT_EXISTS = new ErrorCode(1_008_001_000, "训练信息不存在"); + + // ========== 训练结果模块 1-008-002-000 ========== + ErrorCode TRAIN_RESULT_NOT_EXISTS = new ErrorCode(1_008_002_000, "训练结果不存在"); + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/pom.xml b/yudao-module-annotation/yudao-module-annotation-biz/pom.xml index b337793..c63635e 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/pom.xml +++ b/yudao-module-annotation/yudao-module-annotation-biz/pom.xml @@ -39,7 +39,10 @@ jna-platform 5.13.0 - + + com.fasterxml.jackson.core + jackson-databind + diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/DatasController.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/DatasController.java index fe83678..c28e759 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/DatasController.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/DatasController.java @@ -6,11 +6,13 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils; +import cn.iocoder.yudao.module.annotation.controller.admin.datas.vo.DatasEnhance; import cn.iocoder.yudao.module.annotation.controller.admin.datas.vo.DatasPageReqVO; import cn.iocoder.yudao.module.annotation.controller.admin.datas.vo.DatasRespVO; import cn.iocoder.yudao.module.annotation.controller.admin.datas.vo.DatasSaveReqVO; import cn.iocoder.yudao.module.annotation.dal.dataobject.datas.DatasDO; import cn.iocoder.yudao.module.annotation.service.datas.DatasService; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; import io.swagger.v3.oas.annotations.tags.Tag; @@ -70,6 +72,25 @@ public class DatasController { return success(BeanUtils.toBean(datas, DatasRespVO.class)); } + // 刷新数据集 + + /** + * 刷新数据集 + * 步骤1.根据path,将 + * @param id + * @return + */ + @GetMapping("/refreshDatas") + @Operation(summary = "获得数据集管理") + @Parameter(name = "id", description = "编号", required = true, example = "1024") + public CommonResult refreshDatas(@RequestParam("id") Integer id) { + List list = datasService.list(new LambdaQueryWrapper().eq(DatasDO::getId, id)); + for (DatasDO datasDO : list){ + datasService.refreshDatas(datasDO); + } + return success(true); + } + @PostMapping("/list") @Operation(summary = "获得数据集管理列表") @PreAuthorize("@ss.hasPermission('annotation:datas:query')") @@ -78,6 +99,21 @@ public class DatasController { return success(result); } + @PostMapping("/listStatus") + @Operation(summary = "获得数据集管理列表") + @PreAuthorize("@ss.hasPermission('annotation:datas:query')") + public CommonResult> listStatus(@RequestBody DatasPageReqVO pageReqVO) { + List result = datasService.list(new LambdaQueryWrapper().eq(DatasDO::getStatus, pageReqVO.getStatus())); + return success(result); + } +///默认每张图片都生成一个新图片,循环遍历选择增强的类型 + @PostMapping("/enhance") + @Operation(summary = "图片增强") +// @PreAuthorize("@ss.hasPermission('annotation:datas:query')") + public CommonResult enhance( @RequestBody DatasEnhance datasEnhance) { + datasService.enhance(datasEnhance); + return success(true); + } @GetMapping("/page") @Operation(summary = "获得数据集管理分页") @PreAuthorize("@ss.hasPermission('annotation:datas:query')") diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/vo/DatasEnhance.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/vo/DatasEnhance.java new file mode 100644 index 0000000..d664b2d --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/vo/DatasEnhance.java @@ -0,0 +1,9 @@ +package cn.iocoder.yudao.module.annotation.controller.admin.datas.vo; + +import lombok.Data; + +@Data +public class DatasEnhance { + private Long id; + private String[] enhancements; +} diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/vo/DatasPageReqVO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/vo/DatasPageReqVO.java index 2e03a45..3764c64 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/vo/DatasPageReqVO.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/datas/vo/DatasPageReqVO.java @@ -1,10 +1,12 @@ package cn.iocoder.yudao.module.annotation.controller.admin.datas.vo; -import lombok.*; -import java.util.*; -import io.swagger.v3.oas.annotations.media.Schema; import cn.iocoder.yudao.framework.common.pojo.PageParam; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.ToString; import org.springframework.format.annotation.DateTimeFormat; + import java.time.LocalDateTime; import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND; @@ -21,7 +23,7 @@ public class DatasPageReqVO extends PageParam { @Schema(description = "描述", example = "你猜") private String description; - @Schema(description = "类型,还未标注,正在标注,正在训练,训练完成", example = "1") + @Schema(description = "类型,还未标注,标注完成,正在训练,训练完成", example = "1") private String status; @Schema(description = "路径") diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/mark/MarkController.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/mark/MarkController.java index 096141a..d5f4116 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/mark/MarkController.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/mark/MarkController.java @@ -11,6 +11,8 @@ import cn.iocoder.yudao.module.annotation.controller.admin.mark.vo.MarkRespVO; import cn.iocoder.yudao.module.annotation.controller.admin.mark.vo.MarkSaveReqVO; import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkDO; import cn.iocoder.yudao.module.annotation.service.mark.MarkService; +import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO; +import cn.iocoder.yudao.module.system.service.dict.DictDataService; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -51,12 +53,20 @@ public class MarkController { markService.updateMark(updateReqVO); return success(true); } + @Resource + private DictDataService dictDataService; @PostMapping("/list") @Operation(summary = "获得图片列表") @PreAuthorize("@ss.hasPermission('annotation:datas:query')") public CommonResult> list(@Valid @RequestBody MarkSaveReqVO createReqVO) { + DictDataDO basePath = dictDataService.parseDictData("visual_annotation_conf","base_url"); List result = markService.list(new QueryWrapper().eq("data_id",createReqVO.getDataId())); + for (MarkDO markDO : result){ + if (!markDO.getPath().startsWith("http")){ + markDO.setPath(basePath.getValue() + markDO.getPath()); + } + } return success(result); } @PutMapping("/update") diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/mark/MarkInfoController.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/mark/MarkInfoController.java index 4bb33c3..b0e3fc4 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/mark/MarkInfoController.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/mark/MarkInfoController.java @@ -1,8 +1,12 @@ package cn.iocoder.yudao.module.annotation.controller.admin.mark; import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.module.annotation.controller.admin.mark.vo.MarkSaveReqVO; import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkInfoDO; +import cn.iocoder.yudao.module.annotation.service.MarkInfo.AnnotationData; import cn.iocoder.yudao.module.annotation.service.MarkInfo.MarkInfoService; +import cn.iocoder.yudao.module.annotation.service.mark.MarkService; +import cn.iocoder.yudao.module.system.service.dict.DictDataService; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -24,16 +28,27 @@ import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; public class MarkInfoController { @Resource - private MarkInfoService markService; + private MarkInfoService markInfoService; + @Resource + private MarkService markService; + @Resource + private DictDataService dictDataService; @PostMapping("/create") @Operation(summary = "创建标注详情") @PreAuthorize("@ss.hasPermission('annotation:mark:create')") public CommonResult createMark(@Valid @RequestBody List createReqVO) { + for (MarkInfoDO markInfoDO : createReqVO){ + markInfoDO.setCreator(null); + markInfoDO.setUpdater(null); markInfoDO.setAnnotationDataString(markInfoDO.getAnnotationData().toString()); - markService.saveOrUpdate(markInfoDO); + markInfoService.saveOrUpdate(markInfoDO); } + markInfoService.remove(new QueryWrapper().eq("mark_id",createReqVO.get(0).getMarkId()).notIn("id",createReqVO.stream().map(MarkInfoDO::getId).toList())); + markService.updateMark(new MarkSaveReqVO() + .setId(createReqVO.get(0).getMarkId()) + .setStatus(1)); return success(1); } @@ -41,8 +56,14 @@ public class MarkInfoController { @Operation(summary = "获得标注详情列表") @PreAuthorize("@ss.hasPermission('annotation:datas:query')") public CommonResult> list(@RequestParam("markId") Integer markId) { - List result = markService.list(new QueryWrapper() + + + List result = markInfoService.list(new QueryWrapper() .eq("mark_id",markId)); + for (MarkInfoDO markInfoDO : result){ + markInfoDO.setAnnotationData(AnnotationData.fromString(markInfoDO.getAnnotationDataString())); + + } return success(result); } @@ -52,7 +73,7 @@ public class MarkInfoController { @Parameter(name = "id", description = "编号", required = true) @PreAuthorize("@ss.hasPermission('annotation:mark:delete')") public CommonResult deleteMark(@RequestParam("id") Integer id) { - markService.deleteMarkInfo(id); + markInfoService.deleteMarkInfo(id); return success(true); } diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java index d0dcef1..401e2d5 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java @@ -1,33 +1,30 @@ package cn.iocoder.yudao.module.annotation.controller.admin.train; -import org.springframework.web.bind.annotation.*; -import jakarta.annotation.Resource; -import org.springframework.validation.annotation.Validated; -import org.springframework.security.access.prepost.PreAuthorize; -import io.swagger.v3.oas.annotations.tags.Tag; -import io.swagger.v3.oas.annotations.Parameter; -import io.swagger.v3.oas.annotations.Operation; - -import jakarta.validation.constraints.*; -import jakarta.validation.*; -import jakarta.servlet.http.*; -import java.util.*; -import java.io.IOException; - +import cn.iocoder.yudao.framework.apilog.core.annotation.ApiAccessLog; +import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; - import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils; - -import cn.iocoder.yudao.framework.apilog.core.annotation.ApiAccessLog; -import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.*; - import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.*; import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import cn.iocoder.yudao.module.annotation.service.train.TrainService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.annotation.Resource; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.validation.Valid; +import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.bind.annotation.*; +import org.springframework.web.multipart.MultipartFile; + +import java.io.IOException; +import java.util.List; + +import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.EXPORT; +import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; @Tag(name = "管理后台 - 训练") @RestController @@ -53,6 +50,14 @@ public class TrainController { return success(true); } + @PutMapping("/init") + @Operation(summary = "初始化训练文件夹") + @PreAuthorize("@ss.hasPermission('annotation:train:update')") + public CommonResult init(@Valid @RequestBody TrainSaveReqVO updateReqVO) { + trainService.init(updateReqVO); + return success(true); + } + @DeleteMapping("/delete") @Operation(summary = "删除训练") @Parameter(name = "id", description = "编号", required = true) @@ -62,6 +67,43 @@ public class TrainController { return success(true); } + + @GetMapping("/test-images") + @Operation(summary = "获得测试图片") +// @Parameter(name = "id", description = "编号", required = true, example = "1024") + @PreAuthorize("@ss.hasPermission('annotation:train:query')") + public CommonResult> testImages() { + List train = trainService.testImages(); + return success( train); + } + + @PostMapping("/test-recognition") + @Operation(summary = "进行训练") +// @Parameter(name = "id", description = "编号", required = true, example = "1024") + @PreAuthorize("@ss.hasPermission('annotation:train:query')") + public CommonResult testRecognition(@RequestBody RecognitionResult recognitionReqVO) { + + String train = trainService.testRecognition(recognitionReqVO.getImageUrl(),recognitionReqVO.getTrainId()); + return success( train); + } + + @PostMapping("/test-images-install") + @Operation(summary = "上传测试图片并保存到指定文件夹") + @PreAuthorize("@ss.hasPermission('annotation:train:query')") + public CommonResult> testImagesInstall(@RequestParam("files") MultipartFile[] files) { + List train = trainService.testImagesInstall(files); + return success( train); + } + + + @DeleteMapping("/test-images-clean") + @Operation(summary = "清空后端测试图片") + @PreAuthorize("@ss.hasPermission('annotation:train:query')") + public CommonResult testImagesClean() { + trainService.testImagesClean(); + return success(true); + } + @GetMapping("/get") @Operation(summary = "获得训练") @Parameter(name = "id", description = "编号", required = true, example = "1024") @@ -92,4 +134,11 @@ public class TrainController { BeanUtils.toBean(list, TrainRespVO.class)); } + @GetMapping("/system-info") + @Operation(summary = "获取系统信息(CPU、GPU、Python环境)") + @PreAuthorize("@ss.hasPermission('annotation:train:query')") + public CommonResult getSystemInfo() { + return success(trainService.getSystemInfo()); + } + } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/RecognitionResult.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/RecognitionResult.java new file mode 100644 index 0000000..ad204cd --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/RecognitionResult.java @@ -0,0 +1,9 @@ +package cn.iocoder.yudao.module.annotation.controller.admin.train.vo; + +import lombok.Data; + +@Data +public class RecognitionResult { + private String imageUrl; + private String trainId; +} diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/SystemInfoRespVO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/SystemInfoRespVO.java new file mode 100644 index 0000000..fddf4f8 --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/SystemInfoRespVO.java @@ -0,0 +1,108 @@ +package cn.iocoder.yudao.module.annotation.controller.admin.train.vo; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +import java.util.List; + +@Schema(description = "系统信息 Response VO") +@Data +public class SystemInfoRespVO { + + @Schema(description = "CPU信息") + private CpuInfo cpu; + + @Schema(description = "GPU信息列表") + private List gpus; + + @Schema(description = "Python信息") + private PythonInfo python; + + @Data + @Schema(description = "CPU信息") + public static class CpuInfo { + @Schema(description = "CPU型号") + private String model; + + @Schema(description = "CPU核心数") + private Integer cores; + + @Schema(description = "CPU使用率") + private Double usage; + } + + @Data + @Schema(description = "GPU信息") + public static class GpuInfo { + @Schema(description = "GPU ID") + private Integer id; + + @Schema(description = "GPU名称") + private String name; + + @Schema(description = "GPU内存总量(MB)") + private Long memoryTotal; + + @Schema(description = "GPU已用内存(MB)") + private Long memoryUsed; + + @Schema(description = "GPU可用内存(MB)") + private Long memoryFree; + + @Schema(description = "GPU使用率") + private Double usage; + + @Schema(description = "GPU温度") + private Double temperature; + } + + @Data + @Schema(description = "Python信息") + public static class PythonInfo { + @Schema(description = "是否安装Python") + private Boolean installed; + + @Schema(description = "Python版本") + private String version; + + @Schema(description = "Python路径") + private String path; + + @Schema(description = "是否支持venv虚拟环境") + private Boolean venvSupported; + + @Schema(description = "是否安装virtualenv") + private Boolean virtualenvInstalled; + + @Schema(description = "是否安装conda") + private Boolean condaInstalled; + + @Schema(description = "yolo虚拟环境是否存在") + private Boolean yoloEnvExists; + + @Schema(description = "yolo虚拟环境中是否安装YOLO") + private Boolean yoloInstalled; + + @Schema(description = "YOLO版本") + private String yoloVersion; + + @Schema(description = "Python虚拟环境列表") + private List virtualEnvs; + } + + @Data + @Schema(description = "虚拟环境信息") + public static class VirtualEnvInfo { + @Schema(description = "虚拟环境名称") + private String name; + + @Schema(description = "虚拟环境路径") + private String path; + + @Schema(description = "虚拟环境类型 (venv/virtualenv/conda)") + private String type; + + @Schema(description = "Python版本") + private String pythonVersion; + } +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainPageReqVO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainPageReqVO.java index 09efe8a..24ad063 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainPageReqVO.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainPageReqVO.java @@ -1,10 +1,12 @@ package cn.iocoder.yudao.module.annotation.controller.admin.train.vo; -import lombok.*; -import java.util.*; -import io.swagger.v3.oas.annotations.media.Schema; import cn.iocoder.yudao.framework.common.pojo.PageParam; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.ToString; import org.springframework.format.annotation.DateTimeFormat; + import java.time.LocalDateTime; import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND; @@ -15,7 +17,10 @@ import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_ @ToString(callSuper = true) public class TrainPageReqVO extends PageParam { - @Schema(description = "项目id", example = "14776") + @Schema(description = "项目名", example = "14776") + private String name; + + @Schema(description = "数据集id", example = "14776") private Integer dataId; @Schema(description = "训练集比例") diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainRespVO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainRespVO.java index 8dddbda..47cd8c9 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainRespVO.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainRespVO.java @@ -1,12 +1,11 @@ package cn.iocoder.yudao.module.annotation.controller.admin.train.vo; +import com.alibaba.excel.annotation.ExcelIgnoreUnannotated; +import com.alibaba.excel.annotation.ExcelProperty; import io.swagger.v3.oas.annotations.media.Schema; -import lombok.*; -import java.util.*; -import java.util.*; -import org.springframework.format.annotation.DateTimeFormat; +import lombok.Data; + import java.time.LocalDateTime; -import com.alibaba.excel.annotation.*; @Schema(description = "管理后台 - 训练 Response VO") @Data @@ -17,7 +16,11 @@ public class TrainRespVO { @ExcelProperty("id") private Integer id; - @Schema(description = "项目id", example = "14776") + @Schema(description = "name", requiredMode = Schema.RequiredMode.REQUIRED, example = "32206") + @ExcelProperty("name") + private String name; + + @Schema(description = "数据集id", example = "14776") @ExcelProperty("项目id") private Integer dataId; diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainSaveReqVO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainSaveReqVO.java index a42b4f4..ab0e52b 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainSaveReqVO.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/vo/TrainSaveReqVO.java @@ -1,9 +1,7 @@ package cn.iocoder.yudao.module.annotation.controller.admin.train.vo; import io.swagger.v3.oas.annotations.media.Schema; -import lombok.*; -import java.util.*; -import jakarta.validation.constraints.*; +import lombok.Data; @Schema(description = "管理后台 - 训练新增/修改 Request VO") @Data @@ -12,7 +10,9 @@ public class TrainSaveReqVO { @Schema(description = "id", requiredMode = Schema.RequiredMode.REQUIRED, example = "32206") private Integer id; - @Schema(description = "项目id", example = "14776") + @Schema(description = "name", requiredMode = Schema.RequiredMode.REQUIRED, example = "32206") + private String name; + @Schema(description = "数据集id", example = "14776") private Integer dataId; @Schema(description = "训练集比例") diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/TrainInfoController.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/TrainInfoController.java new file mode 100644 index 0000000..c3b06e6 --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/TrainInfoController.java @@ -0,0 +1,105 @@ +package cn.iocoder.yudao.module.annotation.controller.admin.traininfo; + +import cn.iocoder.yudao.framework.apilog.core.annotation.ApiAccessLog; +import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.framework.common.pojo.PageParam; +import cn.iocoder.yudao.framework.common.pojo.PageResult; +import cn.iocoder.yudao.framework.common.util.object.BeanUtils; +import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils; +import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoPageReqVO; +import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoRespVO; +import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO; +import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService; +import cn.iocoder.yudao.module.system.util.RedisUtil; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.annotation.Resource; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.validation.Valid; +import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.bind.annotation.*; + +import java.io.IOException; +import java.util.List; + +import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.EXPORT; +import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; + +@Tag(name = "管理后台 - 识别结果") +@RestController +@RequestMapping("/annotation/train-Info") +@Validated +public class TrainInfoController { + + @Resource + private TrainInfoService trainInfoService; + + @PostMapping("/create") + @Operation(summary = "创建识别结果") + @PreAuthorize("@ss.hasPermission('annotation:train-Info:create')") + public CommonResult createTrainInfo(@Valid @RequestBody TrainInfoSaveReqVO createReqVO) { + return success(trainInfoService.createTrainInfo(createReqVO)); + } + + @PutMapping("/update") + @Operation(summary = "更新识别结果") + @PreAuthorize("@ss.hasPermission('annotation:train-Info:update')") + public CommonResult updateTrainInfo(@Valid @RequestBody TrainInfoSaveReqVO updateReqVO) { + trainInfoService.updateTrainInfo(updateReqVO); + return success(true); + } + + @DeleteMapping("/delete") + @Operation(summary = "删除识别结果") + @Parameter(name = "id", description = "编号", required = true) + @PreAuthorize("@ss.hasPermission('annotation:train-Info:delete')") + public CommonResult deleteTrainInfo(@RequestParam("id") Integer id) { + trainInfoService.deleteTrainInfo(id); + return success(true); + } + + @GetMapping("/get") + @Operation(summary = "获得识别结果") + @Parameter(name = "id", description = "编号", required = true, example = "1024") + @PreAuthorize("@ss.hasPermission('annotation:train-Info:query')") + public CommonResult getTrainInfo(@RequestParam("id") Integer id) { + TrainInfoDO trainInfo = trainInfoService.getTrainInfo(id); + return success(BeanUtils.toBean(trainInfo, TrainInfoRespVO.class)); + } + @Resource + RedisUtil redisUtil; + + @GetMapping("/get-list") + @Operation(summary = "获得识别结果") + @Parameter(name = "id", description = "编号", required = true, example = "1024") + @PreAuthorize("@ss.hasPermission('annotation:train-Info:query')") + public CommonResult> getTrainInfoList(@RequestParam("trainId") Integer id) { + List result = redisUtil.lGet("yolo:training_log:" + id, 0, -1); + List result1 = result.stream().map(Object::toString).toList(); + return success(result1); + } + @GetMapping("/page") + @Operation(summary = "获得识别结果分页") + @PreAuthorize("@ss.hasPermission('annotation:train-Info:query')") + public CommonResult> getTrainInfoPage(@Valid TrainInfoPageReqVO pageReqVO) { + PageResult pageInfo = trainInfoService.getTrainInfoPage(pageReqVO); + return success(BeanUtils.toBean(pageInfo, TrainInfoRespVO.class)); + } + + @GetMapping("/export-excel") + @Operation(summary = "导出识别结果 Excel") + @PreAuthorize("@ss.hasPermission('annotation:train-Info:export')") + @ApiAccessLog(operateType = EXPORT) + public void exportTrainInfoExcel(@Valid TrainInfoPageReqVO pageReqVO, + HttpServletResponse response) throws IOException { + pageReqVO.setPageSize(PageParam.PAGE_SIZE_NONE); + List list = trainInfoService.getTrainInfoPage(pageReqVO).getList(); + // 导出 Excel + ExcelUtils.write(response, "识别结果.xls", "数据", TrainInfoRespVO.class, + BeanUtils.toBean(list, TrainInfoRespVO.class)); + } + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/vo/TrainInfoPageReqVO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/vo/TrainInfoPageReqVO.java new file mode 100644 index 0000000..c90e171 --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/vo/TrainInfoPageReqVO.java @@ -0,0 +1,34 @@ +package cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo; + +import lombok.*; +import io.swagger.v3.oas.annotations.media.Schema; +import cn.iocoder.yudao.framework.common.pojo.PageParam; +import org.springframework.format.annotation.DateTimeFormat; +import java.time.LocalDateTime; + +import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND; + +@Schema(description = "管理后台 - 训练信息分页 Request VO") +@Data +public class TrainInfoPageReqVO extends PageParam { + + @Schema(description = "训练id", example = "27530") + private Integer trainId; + + @Schema(description = "识别批次", example = "1") + private Integer round; + + @Schema(description = "识别的总批次", example = "100") + private Integer roundTotal; + + @Schema(description = "训练信息") + private String info; + + @Schema(description = "识别结果rate") + private String rate; + + @Schema(description = "创建时间") + @DateTimeFormat(pattern = FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND) + private LocalDateTime[] createTime; + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/vo/TrainInfoRespVO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/vo/TrainInfoRespVO.java new file mode 100644 index 0000000..074063f --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/vo/TrainInfoRespVO.java @@ -0,0 +1,44 @@ +package cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.*; +import java.time.LocalDateTime; +import com.alibaba.excel.annotation.*; + +@Schema(description = "管理后台 - 训练信息 Response VO") +@Data +@ExcelIgnoreUnannotated +public class TrainInfoRespVO { + + @Schema(description = "id", requiredMode = Schema.RequiredMode.REQUIRED, example = "15312") + @ExcelProperty("id") + private Integer id; + + @Schema(description = "训练id", example = "27530") + @ExcelProperty("训练id") + private Integer trainId; + + @Schema(description = "识别批次", example = "1") + @ExcelProperty("识别批次") + private Integer round; + + @Schema(description = "识别的总批次", example = "100") + @ExcelProperty("识别的总批次") + private Integer roundTotal; + + @Schema(description = "训练信息") + @ExcelProperty("训练信息") + private String info; + + @Schema(description = "识别结果rate") + @ExcelProperty("识别结果rate") + private String rate; + + @Schema(description = "创建时间") + @ExcelProperty("创建时间") + private LocalDateTime createTime; + + @Schema(description = "修改时间") + @ExcelProperty("修改时间") + private LocalDateTime updateTime; +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/vo/TrainInfoSaveReqVO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/vo/TrainInfoSaveReqVO.java new file mode 100644 index 0000000..d666c47 --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/vo/TrainInfoSaveReqVO.java @@ -0,0 +1,24 @@ +package cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.*; + +@Schema(description = "管理后台 - 训练信息新增/修改 Request VO") +@Data +public class TrainInfoSaveReqVO { + + @Schema(description = "训练id", requiredMode = Schema.RequiredMode.REQUIRED, example = "27530") + private Integer trainId; + + @Schema(description = "识别批次", requiredMode = Schema.RequiredMode.REQUIRED, example = "1") + private Integer round; + + @Schema(description = "识别的总批次", requiredMode = Schema.RequiredMode.REQUIRED, example = "100") + private Integer roundTotal; + + @Schema(description = "训练信息") + private String info; + + @Schema(description = "识别结果rate") + private String rate; +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/trainresult/TrainResultController.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/trainresult/TrainResultController.java index 13e9efe..4d450f3 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/trainresult/TrainResultController.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/trainresult/TrainResultController.java @@ -1,33 +1,31 @@ package cn.iocoder.yudao.module.annotation.controller.admin.trainresult; -import org.springframework.web.bind.annotation.*; -import jakarta.annotation.Resource; -import org.springframework.validation.annotation.Validated; -import org.springframework.security.access.prepost.PreAuthorize; -import io.swagger.v3.oas.annotations.tags.Tag; -import io.swagger.v3.oas.annotations.Parameter; -import io.swagger.v3.oas.annotations.Operation; - -import jakarta.validation.constraints.*; -import jakarta.validation.*; -import jakarta.servlet.http.*; -import java.util.*; -import java.io.IOException; - +import cn.iocoder.yudao.framework.apilog.core.annotation.ApiAccessLog; +import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.common.pojo.CommonResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; -import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; - import cn.iocoder.yudao.framework.excel.core.util.ExcelUtils; - -import cn.iocoder.yudao.framework.apilog.core.annotation.ApiAccessLog; -import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.*; - -import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.*; +import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultPageReqVO; +import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultRespVO; +import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO; import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.annotation.Resource; +import jakarta.servlet.http.HttpServletResponse; +import jakarta.validation.Valid; +import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.bind.annotation.*; + +import java.io.IOException; +import java.util.List; + +import static cn.iocoder.yudao.framework.apilog.core.enums.OperateTypeEnum.EXPORT; +import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; @Tag(name = "管理后台 - 识别结果") @RestController @@ -71,6 +69,16 @@ public class TrainResultController { return success(BeanUtils.toBean(trainResult, TrainResultRespVO.class)); } + + @GetMapping("/status") + @Operation(summary = "获得识别结果") + @Parameter(name = "id", description = "编号", required = true, example = "1024") + @PreAuthorize("@ss.hasPermission('annotation:train-result:query')") + public CommonResult status(@RequestParam("trainId") Integer trainId) { + TrainResultDO trainResult = trainResultService.getStatus(trainId); + return success(BeanUtils.toBean(trainResult, TrainResultRespVO.class)); + } + @GetMapping("/page") @Operation(summary = "获得识别结果分页") @PreAuthorize("@ss.hasPermission('annotation:train-result:query')") diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/trainresult/vo/TrainResultSaveReqVO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/trainresult/vo/TrainResultSaveReqVO.java index c62fdaf..c046244 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/trainresult/vo/TrainResultSaveReqVO.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/trainresult/vo/TrainResultSaveReqVO.java @@ -2,11 +2,22 @@ package cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo; import io.swagger.v3.oas.annotations.media.Schema; import lombok.*; -import java.util.*; import jakarta.validation.constraints.*; @Schema(description = "管理后台 - 识别结果新增/修改 Request VO") @Data public class TrainResultSaveReqVO { + @Schema(description = "训练id", requiredMode = Schema.RequiredMode.REQUIRED, example = "27530") + @NotNull(message = "训练id不能为空") + private Integer trainId; + + @Schema(description = "输出路径") + private String path; + + @Schema(description = "识别结果rate") + private String rate; + + @Schema(description = "数据集", example = "196") + private Integer dataId; } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/yolo/YoloOperationController.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/yolo/YoloOperationController.java new file mode 100644 index 0000000..29a57f6 --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/yolo/YoloOperationController.java @@ -0,0 +1,184 @@ +package cn.iocoder.yudao.module.annotation.controller.admin.yolo; + +import cn.iocoder.yudao.framework.common.pojo.CommonResult; +import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; +import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; +import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService; +import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo; +import cn.iocoder.yudao.module.annotation.service.train.TrainService; +import cn.iocoder.yudao.module.annotation.service.yolo.YoloOperationService; +import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO; +import cn.iocoder.yudao.module.system.service.dict.DictDataService; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.Parameter; +import io.swagger.v3.oas.annotations.tags.Tag; +import jakarta.annotation.Resource; +import jakarta.annotation.security.PermitAll; +import lombok.extern.slf4j.Slf4j; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.bind.annotation.*; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success; + +@Tag(name = "管理后台 - YOLO操作") +@RestController +@RequestMapping("/annotation/yolo") +@Validated +@Slf4j +public class YoloOperationController { + + @Resource + private PythonVirtualEnvService pythonVirtualEnvService; + + @Resource + private YoloOperationService yoloOperationService; + + @Resource + private DictDataService dictDataService; + + @PostMapping("/detect-env") + @Operation(summary = "检测Python和虚拟环境") + @PermitAll + public CommonResult detectPythonAndVirtualEnv() { + + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + PythonVirtualEnvInfo info = pythonVirtualEnvService.detectPythonAndVirtualEnv( + dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue()); + + // 如果虚拟环境不存在,尝试创建虚拟环境 + if (!info.getVirtualEnvExists() && info.getPythonExists()) { + log.info("虚拟环境不存在,尝试创建虚拟环境: {}", dictDataMap.get("python_venv").getValue()); + info = pythonVirtualEnvService.createVirtualEnvironment( + dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue()); + } + + // 如果YOLO不存在,尝试安装YOLO + if (info.getVirtualEnvExists() && !info.getYoloInstalled()) { + log.info("YOLO未安装,尝试安装YOLO"); + info = pythonVirtualEnvService.installYolo( + dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue()); + } + + log.info("最终环境信息: {}", info); + return success(info); + } + + + @PostMapping("/activate-env") + @Operation(summary = "进入虚拟环境") + @PermitAll + public CommonResult activateVirtualEnv() { + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + PythonVirtualEnvInfo info = pythonVirtualEnvService.activateVirtualEnv( + dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue()); + return success(info); + } + @Resource + private TrainService trainService; + + @PostMapping("/train") + @Operation(summary = "YOLO训练(异步)") + @PermitAll + public CommonResult> trainAsync(@RequestParam("trainId") Integer id) { + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + TrainDO trainDO = trainService.getTrain(id); +// trainDO.setPath(dictDataMap.get("training_address").getValue()); + YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); + config.setTrainId(id); + + + CompletableFuture result = yoloOperationService.trainAsync( + dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); + return success(result); + } + + @PostMapping("/validate") + @Operation(summary = "YOLO验证") + @PermitAll + public CommonResult validate(@RequestParam("trainId") Integer id) { + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + + TrainDO trainDO = trainService.getTrain(id); +// trainDO.setPath(dictDataMap.get("training_address").getValue()); + YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); + YoloConfig result = yoloOperationService.validate( + dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); + return success(result); + } + + @PostMapping("/predict") + @Operation(summary = "YOLO预测") + @PermitAll + public CommonResult predict(@RequestParam("trainId") Integer id) { + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + + TrainDO trainDO = trainService.getTrain(id); +// trainDO.setPath(dictDataMap.get("training_address").getValue()); + YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); + YoloConfig result = yoloOperationService.predict( + dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); + return success(result); + } + + @PostMapping("/classify") + @Operation(summary = "YOLO分类") + @PermitAll + public CommonResult classify(@RequestParam("trainId") Integer id) { + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + + TrainDO trainDO = trainService.getTrain(id); +// trainDO.setPath(dictDataMap.get("training_address").getValue()); + YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); + YoloConfig result = yoloOperationService.classify( + dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); + return success(result); + } + + @PostMapping("/export") + @Operation(summary = "导出YOLO模型") + @PermitAll + public CommonResult export(@RequestParam("trainId") Integer id) { + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + + TrainDO trainDO = trainService.getTrain(id); +// trainDO.setPath(dictDataMap.get("training_address").getValue()); + YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); + YoloConfig result = yoloOperationService.export( + dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); + return success(result); + } + + @PostMapping("/track") + @Operation(summary = "YOLO追踪") + @PermitAll + public CommonResult track(@RequestParam("trainId") Integer id) { + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + + TrainDO trainDO = trainService.getTrain(id); +// trainDO.setPath(dictDataMap.get("training_address").getValue()); + YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); + YoloConfig result = yoloOperationService.track( + dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); + return success(result); + } + + @GetMapping("/task-status") + @Operation(summary = "获取异步任务状态") + @PermitAll + public CommonResult getTaskStatus( + @RequestParam @Parameter(description = "任务ID", required = true) String taskId) { + // 这里可以添加任务状态查询逻辑 + // 实际项目中可能需要使用缓存或数据库来存储任务状态 + return success("Task status query not implemented yet"); + } + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/mark/MarkDO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/mark/MarkDO.java index 9d15685..aca1e10 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/mark/MarkDO.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/mark/MarkDO.java @@ -2,13 +2,10 @@ package cn.iocoder.yudao.module.annotation.dal.dataobject.mark; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; import com.baomidou.mybatisplus.annotation.KeySequence; -import com.baomidou.mybatisplus.annotation.TableField; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import lombok.*; -import java.util.List; - /** * 标注 DO * @@ -45,8 +42,8 @@ public class MarkDO extends BaseDO { /** * 标注,类型是[{对象}],class_id,center_x,center_y,width,height,polygon_points,angle */ - @TableField(typeHandler = com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler.class) - private List annotation; +// @TableField(typeHandler = com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler.class) +// private List annotation; /** * 1.标注完成,0,未标注 */ diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/mark/MarkInfoDO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/mark/MarkInfoDO.java index 4ebfdfc..382d213 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/mark/MarkInfoDO.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/mark/MarkInfoDO.java @@ -24,7 +24,7 @@ public class MarkInfoDO extends BaseDO { private Long id; private Integer classId; @TableField(exist = false) - private Integer className; + private String className; private Integer markId; private Double centerX; private Double centerY; diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java index 1837d82..010197d 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java @@ -1,11 +1,14 @@ package cn.iocoder.yudao.module.annotation.dal.dataobject.train; -import lombok.*; -import java.util.*; -import java.time.LocalDateTime; -import java.time.LocalDateTime; -import com.baomidou.mybatisplus.annotation.*; import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; +import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; +import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO; +import com.baomidou.mybatisplus.annotation.KeySequence; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.*; + +import java.util.Map; /** * 训练 DO @@ -27,6 +30,8 @@ public class TrainDO extends BaseDO { */ @TableId private Integer id; + + private String name; /** * 项目id */ @@ -68,4 +73,56 @@ public class TrainDO extends BaseDO { */ private Integer trainType; + private String type; + + private String outPath; + + public YoloConfig getYolofig(Map dictDataMap) { + YoloConfig config = new YoloConfig(); + // 设置数据相关配置 + config.setDatasetPath(this.path); + + // 设置训练参数 + if (this.size != null) { + config.setEpochs(this.size); + } + if (this.round != null) { + config.setBatchSize(this.round); + } + if (this.imageSize != null) { + config.setImageSize(this.imageSize); + } + + // 设置数据集比例(YOLO中需要转换) + // 注意:YOLO通常通过数据集配置文件来设置train/val/test比例 + // 这里我们可以将比例信息保存到配置中,供后续处理使用 + + // 设置预训练模型路径 + if (this.type.equals("1")) { + config.setModelPath(dictDataMap.get("detect_path").getValue()); + config.setPretrained(true); + }else { + config.setModelPath(dictDataMap.get("classify_path").getValue()); + config.setPretrained(false); + } + + // 设置GPU/CPU设备 + + + // 设置输出路径(默认在数据集路径下的runs/train目录) + if (this.path != null) { + config.setOutputPath(this.path + "/runs/train"); + } + + // 设置其他默认参数 + config.setTaskType(this.type.equals("1") ? "detect" : "classify"); + config.setModelName(this.modelPath); + config.setModelPath(this.modelPath); + config.setLearningRate(0.01); + config.setConfThresh(0.25); + config.setIouThresh(0.45); + config.setSave(true); + config.setSaveTxt(true); + return config; + } } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/traininfo/TrainInfoDO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/traininfo/TrainInfoDO.java new file mode 100644 index 0000000..6934585 --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/traininfo/TrainInfoDO.java @@ -0,0 +1,47 @@ +package cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo; + +import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO; +import com.baomidou.mybatisplus.annotation.KeySequence; +import com.baomidou.mybatisplus.annotation.TableId; +import com.baomidou.mybatisplus.annotation.TableName; +import lombok.*; + +/** + * 识别结果 DO + * + * @author 管理员 + */ +@TableName("annotation_train_info") +@KeySequence("annotation_train_info_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。 +@Data +@EqualsAndHashCode(callSuper = true) +@ToString(callSuper = true) +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class TrainInfoDO extends BaseDO { + + /** + * id + */ + @TableId + private Integer id; + /** + * 训练id + */ + private Integer trainId; +// 识别批次 + private Integer round; + private String info; + + /** + * 识别的总批次 + */ + private Integer roundTotal; + /** + * 识别结果rate + */ + private String rate; + + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/types/TypesDO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/types/TypesDO.java index 37262f6..1885c1c 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/types/TypesDO.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/types/TypesDO.java @@ -33,11 +33,7 @@ public class TypesDO extends BaseDO { * 项目id */ private Integer dataId; - /** - * index - */ - @TableField(value = "`index`") - private Integer index; + /** * 颜色 */ diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/mark/MarkMapper.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/mark/MarkMapper.java index 3e0e753..da01f1f 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/mark/MarkMapper.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/mark/MarkMapper.java @@ -1,13 +1,11 @@ package cn.iocoder.yudao.module.annotation.dal.mysql.mark; -import java.util.*; - import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; +import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; +import cn.iocoder.yudao.module.annotation.controller.admin.mark.vo.MarkPageReqVO; import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkDO; import org.apache.ibatis.annotations.Mapper; -import cn.iocoder.yudao.module.annotation.controller.admin.mark.vo.*; /** * 标注 Mapper @@ -23,4 +21,5 @@ public interface MarkMapper extends BaseMapperX { .orderByDesc(MarkDO::getId)); } + void updateByIdSetStatus(MarkDO updateObj); } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/traininfo/TrainInfoMapper.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/traininfo/TrainInfoMapper.java new file mode 100644 index 0000000..9498151 --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/traininfo/TrainInfoMapper.java @@ -0,0 +1,31 @@ +package cn.iocoder.yudao.module.annotation.dal.mysql.traininfo; + +import java.util.*; + +import cn.iocoder.yudao.framework.common.pojo.PageResult; +import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; +import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; +import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO; +import org.apache.ibatis.annotations.Mapper; +import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.*; + +/** + * 训练信息 Mapper + * + * @author 管理员 + */ +@Mapper +public interface TrainInfoMapper extends BaseMapperX { + + default PageResult selectPage(TrainInfoPageReqVO reqVO) { + return selectPage(reqVO, new LambdaQueryWrapperX() + .eqIfPresent(TrainInfoDO::getTrainId, reqVO.getTrainId()) + .eqIfPresent(TrainInfoDO::getRound, reqVO.getRound()) + .eqIfPresent(TrainInfoDO::getRoundTotal, reqVO.getRoundTotal()) + .likeIfPresent(TrainInfoDO::getInfo, reqVO.getInfo()) + .eqIfPresent(TrainInfoDO::getRate, reqVO.getRate()) + .betweenIfPresent(TrainInfoDO::getCreateTime, reqVO.getCreateTime()) + .orderByDesc(TrainInfoDO::getId)); + } + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/types/TypesMapper.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/types/TypesMapper.java index 2664080..8743882 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/types/TypesMapper.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/mysql/types/TypesMapper.java @@ -1,13 +1,11 @@ package cn.iocoder.yudao.module.annotation.dal.mysql.types; -import java.util.*; - import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX; +import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; +import cn.iocoder.yudao.module.annotation.controller.admin.types.vo.TypesPageReqVO; import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO; import org.apache.ibatis.annotations.Mapper; -import cn.iocoder.yudao.module.annotation.controller.admin.types.vo.*; /** * 类别 Mapper @@ -21,7 +19,7 @@ public interface TypesMapper extends BaseMapperX { return selectPage(reqVO, new LambdaQueryWrapperX() .likeIfPresent(TypesDO::getName, reqVO.getName()) .eqIfPresent(TypesDO::getDataId, reqVO.getDataId()) - .eqIfPresent(TypesDO::getIndex, reqVO.getIndex()) +// .eqIfPresent(TypesDO::getIndex, reqVO.getIndex()) .betweenIfPresent(TypesDO::getCreateTime, reqVO.getCreateTime()) .eqIfPresent(TypesDO::getColor, reqVO.getColor()) .orderByDesc(TypesDO::getId)); diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java index fad86a3..2fe0b18 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java @@ -2,16 +2,262 @@ package cn.iocoder.yudao.module.annotation.dal.yolo; import lombok.Data; -// YoloConfig.java +/** + * YOLO配置类 + * + * @author 管理员 + */ @Data public class YoloConfig { + Integer trainId; + + // ========== Python环境配置 ========== + + /** + * Python可执行文件路径 + */ + private String pythonPath; + + /** + * 虚拟环境名称 + */ + private String virtualEnvName; + + // ========== 基础配置 ========== + + /** + * 模型路径(用于验证、预测、导出) + */ private String modelPath; + + /** + * 数据集路径(用于训练、验证) + */ private String datasetPath; + + /** + * 输出路径(训练结果、验证结果、预测结果等) + */ private String outputPath; - private int epochs = 100; - private int batchSize = 16; + + /** + * 模型名称(如 yolov8n, yolov8s, yolov8m, yolov8l, yolov8x) + */ + private String modelName = "yolov8n"; + + /** + * 任务类型(detect, classify, segment) + */ + private String taskType = "detect"; + + // ========== 训练配置 ========== + + /** + * 训练轮数 + */ + private int epochs = 50; + + /** + * 批次大小 + */ + private int batchSize = 4; + + /** + * 学习率 + */ private double learningRate = 0.01; - private String device = "cpu"; // or "cuda" + + /** + * 设备("0"表示GPU 0,"1"表示GPU 1,"cpu"表示CPU) + */ + private String device = "0"; + + /** + * 是否使用预训练模型 + */ private boolean pretrained = true; + + /** + * 图像大小(像素) + */ + private int imageSize = 640; + + /** + * 工作线程数 + */ + private int workers = 8; + + /** + * 保存间隔(多少轮保存一次模型) + */ + private int savePeriod = -1; + + /** + * 优化器(SGD, Adam, AdamW等) + */ + private String optimizer = "SGD"; + + // ========== 验证配置 ========== + + /** + * 验证数据集路径(可选,如果不提供则使用datasetPath中的验证集) + */ + private String valDatasetPath; + + /** + * 置信度阈值 + */ + private double confThresh = 0.25; + + /** + * IOU阈值 + */ + private double iouThresh = 0.45; + + /** + * 最大检测数 + */ + private int maxDet = 1000; + + // ========== 预测配置 ========== + + /** + * 输入图片/视频路径(用于预测) + */ + private String inputPath; + + /** + * 保存预测结果 + */ + private boolean save = true; + + /** + * 是否显示预测结果 + */ + private boolean show = false; + + /** + * 保存文本结果 + */ + private boolean saveTxt = false; + + /** + * 保存裁剪结果 + */ + private boolean saveCrops = false; + + // ========== 导出配置 ========== + + /** + * 导出格式(onnx, torchscript, coreml等) + */ + private String format = "onnx"; + + /** + * 是否优化导出模型 + */ + private boolean optimize = true; + + /** + * 导出图像大小 + */ + private int imgsz = 640; + + // ========== 追踪配置 ========== + + /** + * 追踪器类型(bytetrack, botsort等) + */ + private String tracker = "bytetrack"; + + /** + * 追踪置信度阈值 + */ + private double trackConf = 0.5; + + /** + * 追踪IOU阈值 + */ + private double trackIou = 0.5; + + // ========== 运行结果 ========== + + /** + * 任务状态(pending, running, completed, failed) + */ + private String status = "pending"; + + /** + * 进度(0-100) + */ + private int progress = 0; + + /** + * 日志信息 + */ + private String logMessage; + + /** + * 错误信息 + */ + private String errorMessage; + + /** + * 训练结果指标(如mAP, precision, recall等) + */ + private String metrics; + + /** + * 任务ID(用于异步任务追踪) + */ + private String taskId; + + // ========== 详细训练进度信息 ========== + + /** + * 当前训练轮次 + */ + private int currentEpoch = 0; + + /** + * 当前批次进度(0-100) + */ + private int batchProgress = 0; + + /** + * 当前批次索引 + */ + private int currentBatch = 0; + + /** + * 总批次数 + */ + private int totalBatches = 0; + + /** + * 损失值 + */ + private double loss = 0.0; + + /** + * 训练精度 + */ + private double precision = 0.0; + + /** + * 训练召回率 + */ + private double recall = 0.0; + + /** + * mAP值 + */ + private double map = 0.0; + + /** + * 最终训练结果摘要 + */ + private String trainingSummary; + } diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/ImageService.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/ImageService.java new file mode 100644 index 0000000..daf83bf --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/ImageService.java @@ -0,0 +1,287 @@ +package cn.iocoder.yudao.module.annotation.service; + +import org.springframework.stereotype.Service; + +import javax.imageio.ImageIO; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.File; +import java.util.Random; + +@Service +public class ImageService { + + /** + * 亮度调整 + * @param image 原始图像 + * @param factor 亮度因子 (-1.0到1.0之间,0为原图) + * @return 调整后的图像 + */ + public BufferedImage adjustBrightness(BufferedImage image, float factor) { + BufferedImage result = new BufferedImage( + image.getWidth(), image.getHeight(), image.getType()); + + // 限制factor范围 + factor = Math.max(-1.0f, Math.min(1.0f, factor)); + int offset = (int)(factor * 100); // 转换为亮度偏移值 + + for (int x = 0; x < image.getWidth(); x++) { + for (int y = 0; y < image.getHeight(); y++) { + int rgb = image.getRGB(x, y); + int r = Math.min(255, Math.max(0, ((rgb >> 16) & 0xFF) + offset)); + int g = Math.min(255, Math.max(0, ((rgb >> 8) & 0xFF) + offset)); + int b = Math.min(255, Math.max(0, (rgb & 0xFF) + offset)); + int newRgb = (r << 16) | (g << 8) | b; + result.setRGB(x, y, newRgb); + } + } + return result; + } + + /** + * 对比度调整 + * @param image 原始图像 + * @param factor 对比度因子 (0.0到2.0之间,1.0为原图) + * @return 调整后的图像 + */ + public BufferedImage adjustContrast(BufferedImage image, float factor) { + BufferedImage result = new BufferedImage( + image.getWidth(), image.getHeight(), image.getType()); + + factor = Math.max(0.0f, Math.min(2.0f, factor)); + float contrast = (100.0f * (factor - 1.0f)) / 100.0f + 1.0f; // 转换对比度值 + float intercept = 128 * (1 - contrast); + + for (int x = 0; x < image.getWidth(); x++) { + for (int y = 0; y < image.getHeight(); y++) { + int rgb = image.getRGB(x, y); + int r = Math.min(255, Math.max(0, (int)((((rgb >> 16) & 0xFF) * contrast) + intercept))); + int g = Math.min(255, Math.max(0, (int)((((rgb >> 8) & 0xFF) * contrast) + intercept))); + int b = Math.min(255, Math.max(0, (int)((((rgb & 0xFF) * contrast) + intercept)))); + int newRgb = (r << 16) | (g << 8) | b; + result.setRGB(x, y, newRgb); + } + } + return result; + } + + /** + * 饱和度调整 + * @param image 原始图像 + * @param factor 饱和度因子 (0.0到2.0之间,1.0为原图) + * @return 调整后的图像 + */ + public BufferedImage adjustSaturation(BufferedImage image, float factor) { + BufferedImage result = new BufferedImage( + image.getWidth(), image.getHeight(), image.getType()); + + factor = Math.max(0.0f, Math.min(2.0f, factor)); + + for (int x = 0; x < image.getWidth(); x++) { + for (int y = 0; y < image.getHeight(); y++) { + int rgb = image.getRGB(x, y); + int r = (rgb >> 16) & 0xFF; + int g = (rgb >> 8) & 0xFF; + int b = rgb & 0xFF; + + // 计算灰度值 + int gray = (int)(0.299 * r + 0.587 * g + 0.114 * b); + + // 调整饱和度 + int newR = (int)(gray + factor * (r - gray)); + int newG = (int)(gray + factor * (g - gray)); + int newB = (int)(gray + factor * (b - gray)); + + newR = Math.min(255, Math.max(0, newR)); + newG = Math.min(255, Math.max(0, newG)); + newB = Math.min(255, Math.max(0, newB)); + + int newRgb = (newR << 16) | (newG << 8) | newB; + result.setRGB(x, y, newRgb); + } + } + return result; + } + + /** + * 色相调整 + * @param image 原始图像 + * @param degrees 色相调整角度 (-180到180度) + * @return 调整后的图像 + */ + public BufferedImage adjustHue(BufferedImage image, float degrees) { + BufferedImage result = new BufferedImage( + image.getWidth(), image.getHeight(), image.getType()); + + degrees = degrees % 360; + float hueShift = degrees / 360.0f; + + for (int x = 0; x < image.getWidth(); x++) { + for (int y = 0; y < image.getHeight(); y++) { + int rgb = image.getRGB(x, y); + int r = (rgb >> 16) & 0xFF; + int g = (rgb >> 8) & 0xFF; + int b = rgb & 0xFF; + + float[] hsb = Color.RGBtoHSB(r, g, b, null); + hsb[0] = (hsb[0] + hueShift) % 1.0f; // 调整色相 + if (hsb[0] < 0) hsb[0] += 1.0f; + + int newRgb = Color.HSBtoRGB(hsb[0], hsb[1], hsb[2]); + result.setRGB(x, y, newRgb); + } + } + return result; + } + + /** + * 灰度化处理 + * @param image 原始图像 + * @return 灰度图像 + */ + public BufferedImage convertToGrayscale(BufferedImage image) { + BufferedImage result = new BufferedImage( + image.getWidth(), image.getHeight(), BufferedImage.TYPE_BYTE_GRAY); + + Graphics g = result.getGraphics(); + g.drawImage(image, 0, 0, null); + g.dispose(); + + return result; + } + /** + * 部分区域灰度化处理 + * @param image 原始图像 + * @param ratio 灰度化区域比例 (0.0到1.0之间) + * @return 部分灰度化图像 + */ + public BufferedImage partialGrayscale(BufferedImage image, float ratio) { + // 限制ratio范围在0.3到1.0之间 + ratio = Math.max(0.3f, Math.min(1.0f, ratio)); + + BufferedImage result = new BufferedImage( + image.getWidth(), image.getHeight(), image.getType()); + + // 复制原图到结果图像 + Graphics2D g = result.createGraphics(); + g.drawImage(image, 0, 0, null); + + // 计算需要灰度化的区域宽度 + int grayscaleWidth = (int)(image.getWidth() * ratio); + + // 创建灰度化区域 + BufferedImage grayscaleRegion = convertToGrayscale( + image.getSubimage(0, 0, grayscaleWidth, image.getHeight())); + + // 将灰度化区域绘制到结果图像上 + g.drawImage(grayscaleRegion, 0, 0, null); + g.dispose(); + + return result; + } + + private Random random = new Random(); + + /** + * 随机亮度调整 + * @param image 原始图像 + * @return 调整后的图像 + */ + public BufferedImage adjustBrightness(BufferedImage image) { + float factor = (random.nextFloat() - 0.5f) ; // -1.0 到 1.0之间的随机值 + return adjustBrightness(image, factor); + } + + /** + * 随机对比度调整 + * @param image 原始图像 + * @return 调整后的图像 + */ + public BufferedImage adjustContrast(BufferedImage image) { + float factor = 0.5f + random.nextFloat() * 1.5f; // 0.5 到 2.0之间的随机值 + return adjustContrast(image, factor); + } + + /** + * 随机饱和度调整 + * @param image 原始图像 + * @return 调整后的图像 + */ + public BufferedImage adjustSaturation(BufferedImage image) { + float factor = random.nextFloat() * 2.0f; // 0.0 到 2.0之间的随机值 + return adjustSaturation(image, factor); + } + + /** + * 随机色相调整 + * @param image 原始图像 + * @return 调整后的图像 + */ + public BufferedImage adjustHue(BufferedImage image) { + float degrees = (random.nextFloat() - 0.5f) * 360.0f; // -180 到 180之间的随机值 + return adjustHue(image, degrees); + } + + /** + * 随机灰度化处理(有一定概率应用灰度化) + * @param image 原始图像 + * @return 可能处理后的图像 + */ + public BufferedImage randomGrayscale(BufferedImage image) { + + float degrees = random.nextFloat(); // -180 到 180之间的随机值 + if (random.nextBoolean()) { // 50%概率应用灰度化 + return partialGrayscale(image,degrees); + } + return image; + } + /** + * 图像增强处理方法 + * @param inputPath 输入图片路径 + * @param enhancementType 增强处理类型 + * @return 处理后的图片路径 + */ + public String enhanceImage(String inputPath,String basePath, String enhancementType) { + try { + // 构建输出路径: 原路径-enhance-处理类型.jpg + String outputPath = inputPath.replaceAll("\\.(?=[^.]*$)", "-enhance-" + enhancementType + "."); + if (!outputPath.toLowerCase().endsWith(".jpg")) { + outputPath = outputPath.substring(0, outputPath.lastIndexOf(".")) + ".jpg"; + } + + // 读取原图 + BufferedImage originalImage = ImageIO.read(new File(basePath +inputPath)); + BufferedImage enhancedImage = originalImage; + + // 根据处理类型应用不同的增强效果 + switch (enhancementType) { + case "brightness": + enhancedImage = adjustBrightness(originalImage); + break; + case "contrast": + enhancedImage = adjustContrast(originalImage); + break; + case "saturation": + enhancedImage = adjustSaturation(originalImage); + break; + case "hue": + enhancedImage = adjustHue(originalImage); + break; + case "grayscale": + enhancedImage = randomGrayscale(originalImage); + break; + default: + // 未识别的处理类型,返回原图 + break; + } + + // 保存处理后的图像 + ImageIO.write(enhancedImage, "jpg", new File(basePath + outputPath)); + + return outputPath; + } catch (Exception e) { + throw new RuntimeException("图像增强处理失败: " + e.getMessage(), e); + } + } + +} diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/MarkInfo/AnnotationData.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/MarkInfo/AnnotationData.java index 864fc70..f2f8ac2 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/MarkInfo/AnnotationData.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/MarkInfo/AnnotationData.java @@ -1,6 +1,7 @@ // AnnotationData.java package cn.iocoder.yudao.module.annotation.service.MarkInfo; +import cn.hutool.json.JSONUtil; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import lombok.AllArgsConstructor; import lombok.Data; @@ -59,4 +60,15 @@ public class AnnotationData { } } + public String toString() { + return JSONUtil.toJsonStr(this); + } + + static public AnnotationData fromString(String json) { + if (json == null || json.isEmpty()) { + return null; + } + return JSONUtil.toBean(json, AnnotationData.class); + } + } diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/datas/DatasService.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/datas/DatasService.java index 2671f8f..fc5294e 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/datas/DatasService.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/datas/DatasService.java @@ -1,6 +1,7 @@ package cn.iocoder.yudao.module.annotation.service.datas; import cn.iocoder.yudao.framework.common.pojo.PageResult; +import cn.iocoder.yudao.module.annotation.controller.admin.datas.vo.DatasEnhance; import cn.iocoder.yudao.module.annotation.controller.admin.datas.vo.DatasPageReqVO; import cn.iocoder.yudao.module.annotation.controller.admin.datas.vo.DatasSaveReqVO; import cn.iocoder.yudao.module.annotation.dal.dataobject.datas.DatasDO; @@ -52,4 +53,7 @@ public interface DatasService extends IService { */ PageResult getDatasPage(DatasPageReqVO pageReqVO); + void refreshDatas(DatasDO datasDO); + + void enhance(DatasEnhance datasEnhance); } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/datas/DatasServiceImpl.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/datas/DatasServiceImpl.java index df8ffed..01e7804 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/datas/DatasServiceImpl.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/datas/DatasServiceImpl.java @@ -2,12 +2,19 @@ package cn.iocoder.yudao.module.annotation.service.datas; import cn.iocoder.yudao.framework.common.pojo.PageResult; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; +import cn.iocoder.yudao.module.annotation.controller.admin.datas.vo.DatasEnhance; import cn.iocoder.yudao.module.annotation.controller.admin.datas.vo.DatasPageReqVO; import cn.iocoder.yudao.module.annotation.controller.admin.datas.vo.DatasSaveReqVO; import cn.iocoder.yudao.module.annotation.dal.dataobject.datas.DatasDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkDO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkInfoDO; import cn.iocoder.yudao.module.annotation.dal.mysql.datas.DatasMapper; -import cn.iocoder.yudao.module.annotation.dal.mysql.mark.MarkMapper; +import cn.iocoder.yudao.module.annotation.dal.mysql.mark.MarkInfoMapper; +import cn.iocoder.yudao.module.annotation.service.ImageService; +import cn.iocoder.yudao.module.annotation.service.mark.MarkService; +import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO; +import cn.iocoder.yudao.module.system.service.dict.DictDataService; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import jakarta.annotation.Resource; import org.springframework.stereotype.Service; @@ -16,6 +23,8 @@ import org.springframework.validation.annotation.Validated; import java.io.File; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; //import static cn.iocoder.yudao.module.annotation.enums.ErrorCodeConstants.*; /** @@ -31,14 +40,19 @@ public class DatasServiceImpl extends ServiceImpl impleme private DatasMapper datasMapper; @Resource - private MarkMapper markMapper; + private MarkService markService; + @Resource + private DictDataService dictDataService; /** - * 检查路径是否存在,并获取其中所有图片文件名称 + * 递归检查路径是否存在,并获取其中所有图片文件名称(包括子目录中的图片) * @param path 要检查的路径字符串 * @return 图片文件名称列表 */ - public List getImageNamesFromPath(String path) { + public List getImageNamesFromPath(String superiorPath) { + DictDataDO basePath = dictDataService.parseDictData("visual_annotation_conf","base_path"); + String path = basePath.getValue() + superiorPath; + List imageNames = new ArrayList<>(); // 检查路径是否为空 @@ -47,23 +61,67 @@ public class DatasServiceImpl extends ServiceImpl impleme } File directory = new File(path); + File rootDirectory = directory; // 保存根目录引用 // 检查路径是否存在且为目录 if (!directory.exists() || !directory.isDirectory()) { return imageNames; } - // 获取目录中的所有文件 + // 递归遍历目录,传递根目录用于计算相对路径 + traverseDirectory(directory, rootDirectory, imageNames); + imageNames = imageNames.stream() + .map(v->{return superiorPath+"/"+v;}) + .collect(Collectors.toList()); + + return imageNames; + } + + /** + * 递归遍历目录查找图片文件 + * @param directory 要遍历的目录 + * @param rootDirectory 根目录 + * @param imageNames 图片文件名称收集列表 + */ + private void traverseDirectory(File directory, File rootDirectory, List imageNames) { File[] files = directory.listFiles(); if (files != null) { for (File file : files) { - if (isImageFile(file)) { - imageNames.add(file.getName()); + if (file.isDirectory()) { + // 递归处理子目录,保持根目录不变 + traverseDirectory(file, rootDirectory, imageNames); + } else if (isImageFile(file)) { + // 添加图片文件(包含相对于根目录的路径) + String relativePath = getRelativePath(file, rootDirectory); + imageNames.add(relativePath); } } } + } - return imageNames; + + /** + * 获取文件相对于根目录的路径 + * @param file 文件 + * @param rootDirectory 根目录 + * @return 相对路径 + */ + private String getRelativePath(File file, File rootDirectory) { + // 使用 Path API 来正确计算相对路径 + try { + return rootDirectory.toPath().relativize(file.toPath()).toString(); + } catch (IllegalArgumentException e) { + // 如果无法计算相对路径,则返回文件名 + return file.getName(); + } + } + + + + public static void main(String[] args) { + DatasServiceImpl datasService = new DatasServiceImpl(); + List imageNames = datasService.getImageNamesFromPath("D:\\PycharmProjects\\yolo\\runs\\detect\\11"); + System.out.println(imageNames); } /** @@ -95,12 +153,12 @@ public class DatasServiceImpl extends ServiceImpl impleme datas.setCount((long) imageNames.size()); datas.setProgress(0.0); for (String imageName : imageNames){ - markMapper.insert(new MarkDO() + markService.save(new MarkDO() .setDataId(datas.getId()) - .setPath(imageName) - .setAnnotation( new ArrayList<>())); + .setPath(imageName)); } datasMapper.insert(datas); + // 返回 return datas.getId(); } @@ -137,5 +195,120 @@ public class DatasServiceImpl extends ServiceImpl impleme public PageResult getDatasPage(DatasPageReqVO pageReqVO) { return datasMapper.selectPage(pageReqVO); } + @Override + public void refreshDatas(DatasDO datasDO) { + // 1. 获取路径中的所有图片文件 + List currentImageNames = getImageNamesFromPath(datasDO.getPath()); + + // 2. 获取数据库中已有的标记记录 + List existingMarks = markService.list(new QueryWrapper().eq("data_id", datasDO.getId())); + + // 3. 找出需要删除的图片(数据库中有但文件系统中没有) + List existingImageNames = existingMarks.stream() + .map(MarkDO::getPath) + .collect(Collectors.toList()); + + List imagesToDelete = existingImageNames.stream() + .filter(imageName -> !currentImageNames.contains(imageName)) + .collect(Collectors.toList()); + + // 删除不存在的标记记录 + if (!imagesToDelete.isEmpty()) { + markService.remove(new QueryWrapper() + .eq("data_id", datasDO.getId()) + .in("path", imagesToDelete)); + } + + // 4. 找出需要新增的图片(文件系统中有但数据库中没有) + List imagesToAdd = currentImageNames.stream() + .filter(imageName -> !existingImageNames.contains(imageName)) + .collect(Collectors.toList()); + + // 新增图片标记记录 + for (String imageName : imagesToAdd) { + markService.save(new MarkDO() + .setDataId(datasDO.getId()) + .setPath(imageName) + .setStatus(0)); // 默认未标注状态 + } + + // 5. 更新数据集统计信息 + List updatedMarks = markService.list(new QueryWrapper().eq("data_id", datasDO.getId())); + + // 计算进度 + long totalImages = updatedMarks.size(); + long completedImages = updatedMarks.stream() + .filter(mark -> mark.getStatus() != null && mark.getStatus() == 1) + .count(); + + double progress = totalImages > 0 ? (completedImages * 100.0 / totalImages) : 0.0; + + // 判断是否全部完成 + String status = (completedImages == totalImages && totalImages > 0) ? "2" : "1"; // 2表示完成,1表示进行中 + + // 更新数据集 + datasDO.setCount(totalImages); + datasDO.setProgress(progress); + datasDO.setStatus(status); + datasDO.setCreator( null); + datasDO.setUpdater(null); + datasMapper.updateById(datasDO); + } + @Resource + ImageService imageService; + @Resource + MarkInfoMapper markInfoMapper; + + @Override + public void enhance(DatasEnhance datasEnhance) { + + DictDataDO basePath = dictDataService.parseDictData("visual_annotation_conf","base_path"); +// 遍历图片,根据选择的类型,新增图片,并且新增他自身的标记记录 + List markDOList = markService.list(new QueryWrapper() + .eq("data_id", datasEnhance.getId())); + + List> futures = new ArrayList<>(); + + for (int i = 0; i < markDOList.size(); i++) { + final int index = i; + final MarkDO markDO = markDOList.get(i); + + CompletableFuture future = CompletableFuture.runAsync(() -> { + try { + String path = imageService.enhanceImage( + markDO.getPath(), + basePath.getValue(), + datasEnhance.getEnhancements()[index % datasEnhance.getEnhancements().length] + ); + + List markInfoDOList = markInfoMapper.selectList( + new QueryWrapper().eq("mark_id", markDO.getId()) + ); + + MarkDO newMarkDO = BeanUtils.toBean(markDO, MarkDO.class); + + newMarkDO.setId(null); + newMarkDO.setPath(path); + markService.save(newMarkDO); + + for (MarkInfoDO markInfoDO : markInfoDOList) { + MarkInfoDO newMarkInfoDO = BeanUtils.toBean(markInfoDO, MarkInfoDO.class); + newMarkInfoDO.setId(null); + newMarkInfoDO.setMarkId(newMarkDO.getId()); + markInfoMapper.insert(newMarkInfoDO); + } + } catch (Exception e) { + log.error("图像增强处理失败,图片路径: "+markDO.getPath(),e); + } + }); + + futures.add(future); + } + + // 等待所有异步任务完成 + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + + + } } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/mark/MarkServiceImpl.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/mark/MarkServiceImpl.java index dfac620..3005fc5 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/mark/MarkServiceImpl.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/mark/MarkServiceImpl.java @@ -29,6 +29,7 @@ public class MarkServiceImpl extends ServiceImpl implements // 插入 MarkDO mark = BeanUtils.toBean(createReqVO, MarkDO.class); markMapper.insert(mark); + // 返回 return mark.getId(); } @@ -39,7 +40,7 @@ public class MarkServiceImpl extends ServiceImpl implements // validateMarkExists(updateReqVO.getId()); // 更新 MarkDO updateObj = BeanUtils.toBean(updateReqVO, MarkDO.class); - markMapper.updateById(updateObj); + markMapper.updateByIdSetStatus(updateObj); } @Override diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvService.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvService.java new file mode 100644 index 0000000..0e6458e --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvService.java @@ -0,0 +1,48 @@ +package cn.iocoder.yudao.module.annotation.service.python; + +import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo; + +/** + * Python虚拟环境服务接口 + * + * @author 管理员 + */ +public interface PythonVirtualEnvService { + + /** + * 进入指定的Python虚拟环境 + * + * @param pythonPath Python可执行文件路径 + * @param envName 虚拟环境名称 + * @return 虚拟环境信息 + */ + PythonVirtualEnvInfo activateVirtualEnv(String pythonPath, String envName); + + /** + * 检测Python和虚拟环境 + * + * @param pythonPath Python可执行文件路径 + * @param envName 虚拟环境名称(可为null,表示检测Python本身) + * @return 检测结果 + */ + PythonVirtualEnvInfo detectPythonAndVirtualEnv(String pythonPath, String envName); + + /** + * 创建Python虚拟环境 + * + * @param pythonPath Python可执行文件路径 + * @param envName 虚拟环境名称 + * @return 创建结果 + */ + PythonVirtualEnvInfo createVirtualEnvironment(String pythonPath, String envName); + + /** + * 在虚拟环境中安装YOLO + * + * @param pythonPath Python可执行文件路径 + * @param envName 虚拟环境名称 + * @return 安装结果 + */ + PythonVirtualEnvInfo installYolo(String pythonPath, String envName); + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvServiceImpl.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvServiceImpl.java new file mode 100644 index 0000000..7c32119 --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvServiceImpl.java @@ -0,0 +1,524 @@ +package cn.iocoder.yudao.module.annotation.service.python; + +import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; + +import java.io.BufferedReader; +import java.io.File; +import java.io.InputStreamReader; + +/** + * Python虚拟环境服务实现类 + * + * @author 管理员 + */ +@Slf4j +@Service +public class PythonVirtualEnvServiceImpl implements PythonVirtualEnvService { + + @Override + public PythonVirtualEnvInfo activateVirtualEnv(String pythonPath, String envName) { + PythonVirtualEnvInfo info = new PythonVirtualEnvInfo(); + + try { + // 获取实际的Python可执行文件路径 + String actualPythonPath = getActualPythonExecutable(pythonPath); + if (actualPythonPath == null) { + info.setPythonExists(false); + info.setErrorMessage("Python文件不存在或无法找到可执行文件: " + pythonPath); + return info; + } + + // 检测Python版本 + String version = getPythonVersion(pythonPath); + if (version == null) { + info.setPythonExists(false); + info.setErrorMessage("无法获取Python版本"); + return info; + } + + info.setPythonExists(true); + info.setPythonVersion(version); + info.setPythonPath(actualPythonPath); + + // 检测虚拟环境 + detectVirtualEnvironment(info, envName); + + } catch (Exception e) { + log.error("激活虚拟环境失败", e); + info.setErrorMessage("激活虚拟环境异常: " + e.getMessage()); + } + + return info; + } + + @Override + public PythonVirtualEnvInfo createVirtualEnvironment(String pythonPath, String envName) { + PythonVirtualEnvInfo info = new PythonVirtualEnvInfo(); + + try { + // 获取实际的Python可执行文件路径 + String actualPythonPath = getActualPythonExecutable(pythonPath); + if (actualPythonPath == null) { + info.setPythonExists(false); + info.setErrorMessage("Python文件不存在或无法找到可执行文件: " + pythonPath); + return info; + } + + // 获取Python目录 + File pythonFile = new File(actualPythonPath); + String pythonDir = pythonFile.getParent(); + + // 构建虚拟环境路径 + String venvPath = pythonDir + File.separator + envName; + File venvDir = new File(venvPath); + + // 检查虚拟环境是否已存在 + if (venvDir.exists()) { + log.info("虚拟环境已存在: {}", venvPath); + return detectPythonAndVirtualEnv(pythonPath, envName); + } + + // 创建虚拟环境 + String os = System.getProperty("os.name").toLowerCase(); + String command; + + if (os.contains("win")) { + command = actualPythonPath + " -m venv " + venvPath; + } else { + command = actualPythonPath + " -m venv " + venvPath; + } + + ProcessBuilder processBuilder = new ProcessBuilder(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + } else { + processBuilder.command("/bin/bash", "-c", command); + } + Process process = processBuilder.start(); + int exitCode = process.waitFor(); + + if (exitCode == 0) { + log.info("虚拟环境创建成功: {}", venvPath); + // 重新检测环境 + return detectPythonAndVirtualEnv(pythonPath, envName); + } else { + info.setErrorMessage("创建虚拟环境失败,退出码: " + exitCode); + } + + } catch (Exception e) { + log.error("创建虚拟环境失败", e); + info.setErrorMessage("创建虚拟环境异常: " + e.getMessage()); + } + + return info; + } + + @Override + public PythonVirtualEnvInfo installYolo(String pythonPath, String envName) { + PythonVirtualEnvInfo info = detectPythonAndVirtualEnv(pythonPath, envName); + + if (!info.getVirtualEnvExists()) { + info.setErrorMessage("虚拟环境不存在,无法安装YOLO"); + return info; + } + + try { + // 在虚拟环境中安装YOLO,使用清华源 + String os = System.getProperty("os.name").toLowerCase(); + String command; + + if (os.contains("win")) { + // Windows: 使用cmd激活虚拟环境后安装 + command = info.getActivateCommand() + " && pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -U ultralytics"; + } else { + // Linux/Mac: 使用bash激活虚拟环境后安装 + command = info.getActivateCommand() + " pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -U ultralytics"; + } + + log.info("在虚拟环境中安装YOLO命令: {}", command); + ProcessBuilder processBuilder = new ProcessBuilder(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + } else { + processBuilder.command("/bin/bash", "-c", command); + } + Process process = processBuilder.start(); + + // 读取输出 + BufferedReader inputReader = new BufferedReader(new InputStreamReader(process.getInputStream())); + BufferedReader errorReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + + StringBuilder output = new StringBuilder(); + StringBuilder error = new StringBuilder(); + + String line; + while ((line = inputReader.readLine()) != null) { + output.append(line).append("\n"); + } + while ((line = errorReader.readLine()) != null) { + error.append(line).append("\n"); + } + + int exitCode = process.waitFor(); + inputReader.close(); + errorReader.close(); + + if (exitCode == 0) { + log.info("YOLO安装成功: {}", output.toString()); + // 重新检测环境,确认YOLO已安装 + return detectPythonAndVirtualEnv(pythonPath, envName); + } else { + info.setErrorMessage("YOLO安装失败,退出码: " + exitCode + ", 错误信息: " + error); + } + + } catch (Exception e) { + log.error("安装YOLO失败", e); + info.setErrorMessage("安装YOLO异常: " + e.getMessage()); + } + + return info; + } + @Override + public PythonVirtualEnvInfo detectPythonAndVirtualEnv(String pythonPath, String envName) { + PythonVirtualEnvInfo info = new PythonVirtualEnvInfo(); + + try { + // 获取实际的Python可执行文件路径 + String actualPythonPath = getActualPythonExecutable(pythonPath); + if (actualPythonPath == null) { + info.setPythonExists(false); + info.setErrorMessage("Python文件不存在或无法找到可执行文件: " + pythonPath); + return info; + } + + String version = getPythonVersion(pythonPath); + info.setPythonExists(true); + info.setPythonVersion(version); + info.setPythonPath(actualPythonPath); + + // 如果指定了虚拟环境名称,则检测虚拟环境 + if (envName != null && !envName.trim().isEmpty()) { + detectVirtualEnvironment(info, envName); + } + + } catch (Exception e) { + log.error("检测Python和虚拟环境失败", e); + info.setErrorMessage("检测异常: " + e.getMessage()); + } + + return info; + } + + private String getPythonVersion(String pythonPath) { + try { + // 检查并修正Python路径 + String actualPythonPath = getActualPythonExecutable(pythonPath); + if (actualPythonPath == null) { + return null; + } + + ProcessBuilder processBuilder = new ProcessBuilder(); + String os = System.getProperty("os.name").toLowerCase(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", actualPythonPath + " --version"); + } else { + processBuilder.command("/bin/bash", "-c", actualPythonPath + " --version"); + } + Process process = processBuilder.start(); + + // Python版本信息可能在标准输出或错误输出中,都检查一下 + BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream())); + BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + + String stdoutLine = stdoutReader.readLine(); + String stderrLine = stderrReader.readLine(); + + stdoutReader.close(); + stderrReader.close(); + + // 检查两个输出流 + String versionLine = null; + if (stdoutLine != null && (stdoutLine.toLowerCase().contains("python") || stdoutLine.matches("\\d+\\.\\d+\\.\\d+.*"))) { + versionLine = stdoutLine; + } else if (stderrLine != null && (stderrLine.toLowerCase().contains("python") || stderrLine.matches("\\d+\\.\\d+\\.\\d+.*"))) { + versionLine = stderrLine; + } + + if (versionLine != null) { + return versionLine.trim(); + } + } catch (Exception e) { + log.warn("获取Python版本失败: " + pythonPath, e); + } + return null; + } + + /** + * 获取实际的Python可执行文件路径 + */ + private String getActualPythonExecutable(String pythonPath) { + File file = new File(pythonPath); + + // 如果是目录,尝试找到Python可执行文件 + if (file.isDirectory()) { + String os = System.getProperty("os.name").toLowerCase(); + String pythonExeName = os.contains("win") ? "python.exe" : "python3"; + File pythonExe = new File(file, pythonExeName); + + if (pythonExe.exists()) { + return pythonExe.getAbsolutePath(); + } + + // 在Windows上还尝试python3.exe + if (os.contains("win")) { + File python3Exe = new File(file, "python3.exe"); + if (python3Exe.exists()) { + return python3Exe.getAbsolutePath(); + } + } + + return null; + } + + // 如果是文件,直接返回 + return file.exists() ? pythonPath : null; + } + + private void detectVirtualEnvironment(PythonVirtualEnvInfo info, String envName) { + String os = System.getProperty("os.name").toLowerCase(); + + try { + // 1. 检测conda环境 +// if (isCondaInstalled()) { +// String condaEnvPath = getCondaEnvPath(envName); +// if (condaEnvPath != null) { +// info.setVirtualEnvExists(true); +// info.setVirtualEnvType("conda"); +// info.setVirtualEnvPath(condaEnvPath); +// info.setVirtualEnvPythonPath(getCondaEnvPythonPath(condaEnvPath)); +// info.setActivateCommand("conda run -n " + envName); +// +// // 检测YOLO +// detectYoloInEnv(info); +// return; +// } +// } + + // 2. 检测venv环境 - 专门检测Python目录下的虚拟环境 + String pythonDir = info.getPythonPath(); + if (pythonDir != null) { + // 如果是具体的python.exe路径,获取其所在目录 + File pythonFile = new File(pythonDir); + if (pythonFile.isFile()) { + pythonDir = pythonFile.getParent(); + } + + // 构建虚拟环境路径:Python目录/envName + String venvPath = pythonDir + File.separator +"venv"+File.separator + envName; + File venvDir = new File(venvPath); + + if (venvDir.exists() && venvDir.isDirectory()) { + // 检查是否是有效的虚拟环境 + String venvPythonPath = getVenvEnvPythonPath(venvPath, os); + if (venvPythonPath != null && new File(venvPythonPath).exists()) { + info.setVirtualEnvExists(true); + info.setVirtualEnvType("venv"); + info.setVirtualEnvPath(venvPath); + info.setVirtualEnvPythonPath(venvPythonPath); + info.setActivateCommand(os.contains("win") ? + venvPath + "\\Scripts\\activate && " : + "source " + venvPath + "/bin/activate && "); + + // 检测YOLO + detectYoloInEnv(info); + return; + } + } + } + + // 3. 检测virtualenv环境 +// String virtualenvPath = getVirtualenvEnvPath(envName); +// if (virtualenvPath != null) { +// info.setVirtualEnvExists(true); +// info.setVirtualEnvType("virtualenv"); +// info.setVirtualEnvPath(virtualenvPath); +// info.setVirtualEnvPythonPath(getVenvEnvPythonPath(virtualenvPath, os)); +// info.setActivateCommand(os.contains("win") ? +// virtualenvPath + "\\Scripts\\activate" : +// "source " + virtualenvPath + "/bin/activate"); +// +// // 检测YOLO +// detectYoloInEnv(info); +// return; +// } + + info.setVirtualEnvExists(false); + info.setErrorMessage("未找到虚拟环境: " + envName); + + } catch (Exception e) { + log.error("检测虚拟环境失败: " + envName, e); + info.setVirtualEnvExists(false); + info.setErrorMessage("检测虚拟环境异常: " + e.getMessage()); + } + } + + private boolean isCondaInstalled() { + try { + String os = System.getProperty("os.name").toLowerCase(); + String command = "conda --version"; + + ProcessBuilder processBuilder = new ProcessBuilder(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + } else { + processBuilder.command("/bin/bash", "-c", command); + } + Process process = processBuilder.start(); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + String line = reader.readLine(); + reader.close(); + return line != null && line.toLowerCase().contains("conda"); + } catch (Exception e) { + return false; + } + } + + private String getCondaEnvPath(String envName) { + try { + String os = System.getProperty("os.name").toLowerCase(); + String command = "conda env list"; + + ProcessBuilder processBuilder = new ProcessBuilder(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + } else { + processBuilder.command("/bin/bash", "-c", command); + } + Process process = processBuilder.start(); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + String line; + while ((line = reader.readLine()) != null) { + if (!line.trim().isEmpty() && !line.startsWith("#") && line.contains(envName)) { + String[] parts = line.split("\\s+"); + if (parts.length >= 2) { + reader.close(); + return parts[1].trim(); + } + } + } + reader.close(); + } catch (Exception e) { + log.warn("获取conda环境路径失败: " + envName, e); + } + return null; + } + + private String getCondaEnvPythonPath(String condaEnvPath) { + String os = System.getProperty("os.name").toLowerCase(); + return os.contains("win") ? + condaEnvPath + "\\python.exe" : + condaEnvPath + "/bin/python"; + } + + private String getVenvEnvPath(String pythonPath, String envName) { + // 检查 pythonPath\venv 目录下是否存在 envName 文件夹 + File venvBaseDir = new File(pythonPath, "venv"); + if (venvBaseDir.exists() && venvBaseDir.isDirectory()) { + File targetEnvDir = new File(venvBaseDir, envName); + if (targetEnvDir.exists() && targetEnvDir.isDirectory()) { + return targetEnvDir.getAbsolutePath(); + } + } + + // 检查当前目录下的常见虚拟环境目录 + String[] commonDirs = {envName, "venv", ".venv", "env", ".env"}; + for (String dir : commonDirs) { + File venvDir = new File(dir); + if (venvDir.exists() && venvDir.isDirectory()) { + return venvDir.getAbsolutePath(); + } + } + // 检查当前目录下的常见虚拟环境目录 + + + // 检查用户主目录下的虚拟环境 + String userHome = System.getProperty("user.home"); + File virtualenvsDir = new File(userHome, ".virtualenvs"); + if (virtualenvsDir.exists()) { + File envDir = new File(virtualenvsDir, envName); + if (envDir.exists()) { + return envDir.getAbsolutePath(); + } + } + + return null; + } + + private String getVirtualenvEnvPath(String envName) { + // 主要检查.virtualenvs目录 + String userHome = System.getProperty("user.home"); + File virtualenvsDir = new File(userHome, ".virtualenvs"); + if (virtualenvsDir.exists()) { + File envDir = new File(virtualenvsDir, envName); + if (envDir.exists()) { + return envDir.getAbsolutePath(); + } + } + return null; + } + + private String getVenvEnvPythonPath(String venvPath, String os) { + return os.contains("win") ? + venvPath + "\\Scripts\\python.exe" : + venvPath + "/bin/python"; + } + + private void detectYoloInEnv(PythonVirtualEnvInfo info) { + if (info.getVirtualEnvPythonPath() == null) { + return; + } + + try { + // 在虚拟环境中检测ultralytics包 + String os = System.getProperty("os.name").toLowerCase(); + String command; + + if ("conda".equals(info.getVirtualEnvType())) { + command = "conda run -n " + new File(info.getVirtualEnvPath()).getName() + " python -c \"import ultralytics; print(ultralytics.__version__)\""; + } else if (os.contains("win")) { + command = info.getActivateCommand() + " python -c \"import ultralytics; print(ultralytics.__version__)\""; + } else { + command = info.getActivateCommand() + " python -c \"import ultralytics; print(ultralytics.__version__)\""; + } + + ProcessBuilder processBuilder = new ProcessBuilder(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + } else { + processBuilder.command("/bin/bash", "-c", command); + } + Process process = processBuilder.start(); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + String line = reader.readLine(); + reader.close(); + + if (line != null && !line.trim().isEmpty()) { + info.setYoloInstalled(true); + info.setYoloVersion(line.trim()); + } else { + info.setYoloInstalled(false); + } + + } catch (Exception e) { + log.warn("检测YOLO安装状态失败", e); + info.setYoloInstalled(false); + } + } + + public static void main(String[] args) { + + } + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/vo/PythonVirtualEnvInfo.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/vo/PythonVirtualEnvInfo.java new file mode 100644 index 0000000..7bf9064 --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/vo/PythonVirtualEnvInfo.java @@ -0,0 +1,68 @@ +package cn.iocoder.yudao.module.annotation.service.python.vo; + +import lombok.Data; + +/** + * Python虚拟环境信息 + * + * @author 管理员 + */ +@Data +public class PythonVirtualEnvInfo { + + /** + * Python是否存在 + */ + private Boolean pythonExists; + + /** + * Python版本 + */ + private String pythonVersion; + + /** + * Python路径 + */ + private String pythonPath; + + /** + * 虚拟环境是否存在 + */ + private Boolean virtualEnvExists; + + /** + * 虚拟环境类型 (venv/virtualenv/conda) + */ + private String virtualEnvType; + + /** + * 虚拟环境路径 + */ + private String virtualEnvPath; + + /** + * 虚拟环境中Python路径 + */ + private String virtualEnvPythonPath; + + /** + * 是否安装了YOLO + */ + private Boolean yoloInstalled; + + /** + * YOLO版本 + */ + private String yoloVersion; + + /** + * 激活命令(用于后续操作) + */ + private String activateCommand; + + /** + * 错误信息 + */ + private String errorMessage; + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java index e548cd1..2203928 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java @@ -1,18 +1,22 @@ package cn.iocoder.yudao.module.annotation.service.train; -import java.util.*; -import jakarta.validation.*; -import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.*; -import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.common.pojo.PageParam; +import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.SystemInfoRespVO; +import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.TrainPageReqVO; +import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.TrainSaveReqVO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; +import com.baomidou.mybatisplus.extension.service.IService; +import jakarta.validation.Valid; +import org.springframework.web.multipart.MultipartFile; + +import java.util.List; /** * 训练 Service 接口 * * @author 管理员 */ -public interface TrainService { +public interface TrainService extends IService { /** * 创建训练 @@ -52,4 +56,26 @@ public interface TrainService { */ PageResult getTrainPage(TrainPageReqVO pageReqVO); + /** + * 获取系统信息(CPU、GPU、Python环境) + * + * @return 系统信息 + */ + SystemInfoRespVO getSystemInfo(); + + public boolean init(TrainSaveReqVO createReqVO); + + List testImages(); + + /** + * 上传测试图片并保存到指定文件夹 + * + * @param files 上传的图片文件数组 + * @return 图片路径列表 + */ + List testImagesInstall(MultipartFile[] files); + + void testImagesClean(); + + String testRecognition(String image,String trainId); } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java index f5af9d2..7759ce0 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java @@ -1,42 +1,634 @@ package cn.iocoder.yudao.module.annotation.service.train; -import org.springframework.stereotype.Service; -import jakarta.annotation.Resource; -import org.springframework.validation.annotation.Validated; -import org.springframework.transaction.annotation.Transactional; - -import java.util.*; -import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.*; -import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; - +import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.SystemInfoRespVO; +import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.TrainPageReqVO; +import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.TrainSaveReqVO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.datas.DatasDO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkDO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkInfoDO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO; import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper; +import cn.iocoder.yudao.module.annotation.dal.mysql.trainresult.TrainResultMapper; +import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; +import cn.iocoder.yudao.module.annotation.service.MarkInfo.AnnotationData; +import cn.iocoder.yudao.module.annotation.service.MarkInfo.MarkInfoService; +import cn.iocoder.yudao.module.annotation.service.datas.DatasService; +import cn.iocoder.yudao.module.annotation.service.mark.MarkService; +import cn.iocoder.yudao.module.annotation.service.types.TypesService; +import cn.iocoder.yudao.module.annotation.service.yolo.YoloOperationService; +import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO; +import cn.iocoder.yudao.module.system.service.dict.DictDataService; +import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; +import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; +import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Service; +import org.springframework.validation.annotation.Validated; +import org.springframework.web.multipart.MultipartFile; -import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; -//import static cn.iocoder.yudao.module.annotation.enums.ErrorCodeConstants.*; +import javax.imageio.ImageIO; +import java.awt.image.BufferedImage; +import java.io.*; +import java.nio.file.Files; +import java.nio.file.StandardCopyOption; +import java.util.*; +import java.util.stream.Collectors; /** * 训练 Service 实现类 * * @author 管理员 */ +@Slf4j @Service @Validated -public class TrainServiceImpl implements TrainService { +public class TrainServiceImpl extends ServiceImpl implements TrainService { @Resource private TrainMapper trainMapper; + @Resource + private MarkService markService; + + @Resource + private MarkInfoService markInfoService; + + @Resource + private TypesService typesService; + + @Resource + private DictDataService dictDataService; + + @Resource + private DatasService datasService; + + @Override public Integer createTrain(TrainSaveReqVO createReqVO) { + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); // 插入 TrainDO train = BeanUtils.toBean(createReqVO, TrainDO.class); + DatasDO datas = datasService.getDatas(createReqVO.getDataId()); + train.setType(datas.getType()); + if(train.getType().equals("1")){ + train.setModelPath(dictDataMap.get("detect_path").getValue()); + }else if(train.getType().equals("2")){ + train.setModelPath(dictDataMap.get("classify_path").getValue()); + } + trainMapper.insert(train); + + String yoloDatasetPath = createYoloDatasetPath(train, dictDataMap.get("training_address").getValue()); + train.setPath(yoloDatasetPath); + trainMapper.updateById(train); + // 返回 return train.getId(); } + @Resource + YoloOperationService yoloOperationService; + public boolean init(TrainSaveReqVO createReqVO) { + try { + TrainDO train = trainMapper.selectById(createReqVO.getId()); + DatasDO datas = datasService.getDatas(createReqVO.getDataId()); + // 获取配置和基础数据 + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + List markList = markService.list(new QueryWrapper() + .eq("data_id", createReqVO.getDataId())); + List typesList = typesService.list(new QueryWrapper() + .eq("data_id", createReqVO.getDataId())); + + if (markList.isEmpty() || typesList.isEmpty()) { + log.warn("项目 {} 没有标注数据或类别数据", createReqVO.getDataId()); + return false; + } + + // 确定任务类型(检测/分类) + String taskType = getTaskType( datas.getType()); + + // 创建YOLO数据集目录结构 + String yoloDatasetPath = train.getPath(); + + // 生成类别映射文件 + generateClassesFile(typesList, yoloDatasetPath); + + // 根据任务类型生成数据 + if ("detect".equals(taskType)) { + generateDetectionDataset(dictDataMap,markList, typesList, createReqVO, yoloDatasetPath); + } else if ("classify".equals(taskType)) { + generateClassificationDataset(dictDataMap,markList, typesList, createReqVO, yoloDatasetPath); + } + + // 生成YOLO配置文件 + generateYamlConfig(typesList, taskType, yoloDatasetPath); + update( new UpdateWrapper().eq("id", createReqVO.getId()).set("train_type", 6)); + yoloOperationService.generateTrainPythonScript(train, yoloDatasetPath); + log.info("YOLO数据集生成完成: {}", yoloDatasetPath); + return true; + + } catch (Exception e) { + log.error("生成YOLO数据集失败", e); + return false; + } + } + + @Override + public List testImages() { + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); +// 遍历置顶文件夹下的所有文件,获取文件名称 + File dir = new File(dictDataMap.get("base_path").getValue() + "test/"); + File[] files = dir.listFiles(); + List fileNames = new ArrayList<>(); + if (files != null) { + for (File file : files) { +// 判断是否是图片 + if (file.getName().endsWith(".jpg") || file.getName().endsWith(".png")||file.getName().endsWith(".jpeg")||file.getName().endsWith(".bmp")) { + fileNames.add(dictDataMap.get("base_url").getValue() + "test/" + file.getName()); + } + } + } + return fileNames; + } + + @Override + public List testImagesInstall(MultipartFile[] files) { + if (files == null || files.length == 0) { + log.warn("没有上传任何文件"); + return new ArrayList<>(); + } + + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + String basePath = dictDataMap.get("base_path").getValue(); + String baseUrl = dictDataMap.get("base_url").getValue(); + + // 目标文件夹(install_test_images) + File targetDir = new File(basePath + "test/"); + if (!targetDir.exists()) { + targetDir.mkdirs(); + log.info("创建目标文件夹: {}", targetDir.getAbsolutePath()); + } + + List savedImagePaths = new ArrayList<>(); + + for (MultipartFile file : files) { + try { + String originalFilename = file.getOriginalFilename(); + if (originalFilename == null || originalFilename.trim().isEmpty()) { + log.warn("文件名为空,跳过"); + continue; + } + + // 检查文件类型是否为图片 + if (!isImageFile(originalFilename)) { + log.warn("非图片文件,跳过: {}", originalFilename); + continue; + } + + // 生成唯一文件名(避免重复) + String fileExtension = getFileExtension(originalFilename); + String uniqueFileName = generateUniqueFileName(originalFilename, fileExtension); + File targetFile = new File(targetDir, uniqueFileName); + + // 保存文件 + file.transferTo(targetFile); + log.info("保存图片: {} -> {}", originalFilename, targetFile.getAbsolutePath()); + + // 添加返回的URL路径 + savedImagePaths.add(baseUrl + "test/" + uniqueFileName); + + } catch (IOException e) { + log.error("保存图片失败: {}", file.getOriginalFilename(), e); + } + } + + log.info("testImagesInstall 完成,共保存 {} 张图片到 install_test_images 文件夹", savedImagePaths.size()); + return savedImagePaths; + } + public void testImagesClean(){ + + + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); + String basePath = dictDataMap.get("base_path").getValue(); +// 清空install_test_images文件夹下的所有文件 + File dir = new File(basePath + "test/"); + File[] files = dir.listFiles(); + if (files != null) { + for (File file : files) { + file.delete(); + } + } + } + @Resource + private TrainResultMapper trainResultMapper; + + public String testRecognition(String image,String trainId){ + TrainDO train = trainMapper.selectById(trainId); + TrainResultDO trainResultDO = trainResultMapper.selectOne(new QueryWrapper() + .eq("train_id", trainId) + .orderByDesc("create_time") + .last("limit 1")); + + Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); +// 删除前端信息 + image = image.replace(dictDataMap.get("base_url").getValue()+"test/",""); +// yoloOperationService.classify() + YoloConfig yoloConfig = new YoloConfig(); + yoloConfig.setModelPath(trainResultDO.getPath()+"/weights/best.pt"); + yoloConfig.setInputPath(dictDataMap.get("base_path").getValue() + "test/" + image); + yoloConfig.setOutputPath(dictDataMap.get("base_path").getValue() + "test/run/"); + YoloConfig yoloConfigResult = yoloOperationService.predict(dictDataMap.get("python_path").getValue(),dictDataMap.get("python_venv").getValue(),yoloConfig); + String url = replacePathPrefix(yoloConfigResult.getOutputPath(), dictDataMap.get("base_path").getValue(), dictDataMap.get("base_url").getValue()) + "/" + image; + return url; + } + /** + * 智能路径替换方法 + * 当outputPath的前缀与basePath相同时,将前缀替换为baseUrl + * 正确处理斜杠和反斜杠的差异 + */ + private String replacePathPrefix(String outputPath, String basePath, String baseUrl) { + if (outputPath == null || basePath == null || baseUrl == null) { + return outputPath; + } + + // 标准化路径分隔符,统一使用正斜杠 + String normalizedOutputPath = outputPath.replace('\\', '/'); + String normalizedBasePath = basePath.replace('\\', '/'); + + // 确保basePath以正斜杠结尾,以便精确匹配 + if (!normalizedBasePath.endsWith("/")) { + normalizedBasePath += "/"; + } + + // 检查outputPath是否以basePath开头 + if (normalizedOutputPath.startsWith(normalizedBasePath)) { + // 提取相对路径部分 + String relativePath = normalizedOutputPath.substring(normalizedBasePath.length()); + + // 确保baseUrl以正斜杠结尾 + if (!baseUrl.endsWith("/")) { + baseUrl += "/"; + } + + return baseUrl + relativePath; + } + + // 如果前缀不匹配,返回原始路径(但标准化斜杠) + return normalizedOutputPath; + } + + /** + * 检查是否为图片文件 + */ + private boolean isImageFile(String filename) { + String lowerCaseFilename = filename.toLowerCase(); + return lowerCaseFilename.endsWith(".jpg") || + lowerCaseFilename.endsWith(".jpeg") || + lowerCaseFilename.endsWith(".png") || + lowerCaseFilename.endsWith(".bmp") || + lowerCaseFilename.endsWith(".gif"); + } + + /** + * 获取文件扩展名 + */ + private String getFileExtension(String filename) { + int lastDotIndex = filename.lastIndexOf('.'); + return lastDotIndex != -1 ? filename.substring(lastDotIndex + 1).toLowerCase() : ""; + } + + /** + * 生成唯一文件名 + */ + private String generateUniqueFileName(String originalFilename, String extension) { + String nameWithoutExtension = originalFilename.substring(0, originalFilename.lastIndexOf('.')); + String timestamp = String.valueOf(System.currentTimeMillis()); + String randomSuffix = String.valueOf((int)(Math.random() * 1000)); + + if (!extension.isEmpty()) { + return nameWithoutExtension + "_" + timestamp + "_" + randomSuffix + "." + extension; + } else { + return nameWithoutExtension + "_" + timestamp + "_" + randomSuffix; + } + } + + + /** + * 获取任务类型 + */ + private String getTaskType( String type) { + // 优先使用传入的type参数 + if (type != null && "1".equals(type)) { + return "detect"; + }else if(type != null && "2".equals(type)){ + return "classify"; + } + + // 默认为检测任务 + return "detect"; + } + + /** + * 创建YOLO数据集目录结构 + */ + private String createYoloDatasetPath(TrainDO train, String customPath) { + String basePath = customPath != null ? customPath : System.getProperty("user.dir") + "/yolo_datasets"; + String datasetPath = basePath + "/datas_" + train.getId(); + + File datasetDir = new File(datasetPath); + if (!datasetDir.exists()) { + datasetDir.mkdirs(); + } + + // 创建YOLO标准目录结构 + new File(datasetPath, "/images/train").mkdirs(); + new File(datasetPath, "/images/val").mkdirs(); + new File(datasetPath, "/images/test").mkdirs(); + new File(datasetPath, "/labels/train").mkdirs(); + new File(datasetPath, "/labels/val").mkdirs(); + new File(datasetPath, "/labels/test").mkdirs(); + + // 分类任务需要额外的类别目录 + new File(datasetPath, "/images/train/classify").mkdirs(); + new File(datasetPath, "/images/val/classify").mkdirs(); + new File(datasetPath, "/images/test/classify").mkdirs(); + + return datasetPath; + } + + /** + * 生成classes.txt文件 + */ + private void generateClassesFile(List typesList, String datasetPath) { + try (PrintWriter writer = new PrintWriter(new FileWriter(datasetPath + "/classes.txt"))) { + // 按id排序(如果没有index字段,使用id排序) + typesList.sort(Comparator.comparing(TypesDO::getId)); + for (TypesDO type : typesList) { + writer.println(type.getName()); + } + } catch (IOException e) { + log.error("生成classes.txt失败", e); + } + } + + /** + * 生成检测任务数据集 + */ + private void generateDetectionDataset(Map dictDataMap,List markList, List typesList, + TrainSaveReqVO createReqVO, String datasetPath) { + // 按id排序建立类别ID映射(如果没有index字段,使用id排序) + typesList.sort(Comparator.comparing(TypesDO::getId)); + Map classIdMap = new HashMap<>(); + for (int i = 0; i < typesList.size(); i++) { + classIdMap.put(typesList.get(i).getId(), i); + } + + // 划分训练集、验证集、测试集 + List trainMarks = new ArrayList<>(); + List valMarks = new ArrayList<>(); + List testMarks = new ArrayList<>(); + + splitDataset(markList, createReqVO, trainMarks, valMarks, testMarks); + + // 处理训练集 + processDetectionDataset(dictDataMap,trainMarks, classIdMap, datasetPath + "/images/train", + datasetPath + "/labels/train"); + + // 处理验证集 + processDetectionDataset(dictDataMap,valMarks, classIdMap, datasetPath + "/images/val", + datasetPath + "/labels/val"); + + // 处理测试集 + processDetectionDataset(dictDataMap,testMarks, classIdMap, datasetPath + "/images/test", + datasetPath + "/labels/test"); + } + + /** + * 处理检测数据集 + */ + private void processDetectionDataset(Map dictDataMap,List marks, Map classIdMap, + String imageDir, String labelDir) { + for (MarkDO mark : marks) { + try { + // 复制图片 + String imageFileName = new File(dictDataMap.get("base_path").getValue()+mark.getPath()).getName(); + File sourceImage = new File(dictDataMap.get("base_path").getValue()+mark.getPath()); + File targetImage = new File(imageDir, imageFileName); + + if (sourceImage.exists()) { + Files.copy(sourceImage.toPath(), targetImage.toPath(), StandardCopyOption.REPLACE_EXISTING); + } + + // 获取图片尺寸 + BufferedImage image = ImageIO.read(sourceImage); + int imageWidth = image.getWidth(); + int imageHeight = image.getHeight(); + + // 生成YOLO格式的标注文件 + String labelFileName = imageFileName.replaceAll("\\.[^.]+$", ".txt"); + File labelFile = new File(labelDir, labelFileName); + + List markInfoList = markInfoService.list(new QueryWrapper() + .eq("mark_id", mark.getId())); + + try (PrintWriter writer = new PrintWriter(new FileWriter(labelFile))) { + for (MarkInfoDO markInfo : markInfoList) { + Integer classIndex = classIdMap.get(markInfo.getClassId()); + if (classIndex != null) { + // 解析标注数据 + AnnotationData annotationData = parseAnnotationData(markInfo); + if (annotationData != null) { + String yoloLine = convertToYOLOFormat(annotationData, classIndex, + imageWidth, imageHeight); + if (yoloLine != null) { + writer.println(yoloLine); + } + } + } + } + } + + } catch (Exception e) { + log.error("处理检测数据失败: {}", mark.getPath(), e); + } + } + } + + /** + * 生成分类任务数据集 + */ + private void generateClassificationDataset(Map dictDataMap,List markList, List typesList, + TrainSaveReqVO createReqVO, String datasetPath) { + // 建立类别映射 + Map classNameMap = new HashMap<>(); + for (TypesDO type : typesList) { + classNameMap.put(type.getId(), type.getName()); + } + + // 划分数据集 + List trainMarks = new ArrayList<>(); + List valMarks = new ArrayList<>(); + List testMarks = new ArrayList<>(); + + splitDataset(markList, createReqVO, trainMarks, valMarks, testMarks); + + // 处理各数据集 + processClassificationDataset(dictDataMap,trainMarks, classNameMap, + datasetPath + "/images/train/classify"); + processClassificationDataset(dictDataMap,valMarks, classNameMap, + datasetPath + "/images/val/classify"); + processClassificationDataset(dictDataMap,testMarks, classNameMap, + datasetPath + "/images/test/classify"); + } + + /** + * 处理分类数据集 + */ + private void processClassificationDataset(Map dictDataMap,List marks, Map classNameMap, + String targetDir) { + for (MarkDO mark : marks) { + try { + // 获取图片的主要分类(假设一张图片只有一个主要类别) + List markInfoList = markInfoService.list(new QueryWrapper() + .eq("data_id", mark.getDataId()).eq("mark_id", mark.getId())); + + if (!markInfoList.isEmpty()) { + String className = classNameMap.get(markInfoList.get(0).getClassId()); + if (className != null) { + // 创建类别子目录 + File classDir = new File(targetDir, className); + classDir.mkdirs(); + + // 复制图片到对应类别目录 + String imageFileName = new File(mark.getPath()).getName(); + File sourceImage = new File(mark.getPath()); + File targetImage = new File(classDir, imageFileName); + + if (sourceImage.exists()) { + Files.copy(sourceImage.toPath(), targetImage.toPath(), + StandardCopyOption.REPLACE_EXISTING); + } + } + } + } catch (Exception e) { + log.error("处理分类数据失败: {}", mark.getPath(), e); + } + } + } + + /** + * 数据集划分 + */ + private void splitDataset(List markList, TrainSaveReqVO createReqVO, + List trainMarks, List valMarks, List testMarks) { + int total = markList.size(); + int trainCount = (int) (total * (createReqVO.getTrain() != null ? createReqVO.getTrain() / 100.0 : 0.8)); + int valCount = (int) (total * (createReqVO.getVal() != null ? createReqVO.getVal() / 100.0 : 0.1)); + + // 随机打乱 + Collections.shuffle(markList); + + trainMarks.addAll(markList.subList(0, trainCount)); + valMarks.addAll(markList.subList(trainCount, Math.min(trainCount + valCount, total))); + testMarks.addAll(markList.subList(trainCount + valCount, total)); + } + + /** + * 解析标注数据 + */ + private AnnotationData parseAnnotationData(MarkInfoDO markInfo) { + if (markInfo.getAnnotationData() != null) { + return markInfo.getAnnotationData(); + } + if (markInfo.getAnnotationDataString() != null && !markInfo.getAnnotationDataString().isEmpty()) { + return AnnotationData.fromString(markInfo.getAnnotationDataString()); + } + + // 如果没有结构化数据,使用基本字段 + AnnotationData data = new AnnotationData(); + AnnotationData.Target target = new AnnotationData.Target(); + AnnotationData.Target.Selector selector = new AnnotationData.Target.Selector(); + AnnotationData.Target.Selector.Geometry geometry = new AnnotationData.Target.Selector.Geometry(); + + selector.setGeometry(geometry); + target.setSelector(selector); + data.setTarget(target); + + return data; + } + + /** + * 转换为YOLO格式 + */ + private String convertToYOLOFormat(AnnotationData annotationData, Integer classIndex, + int imageWidth, int imageHeight) { + try { + AnnotationData.Target.Selector.Geometry geometry = annotationData.getTarget().getSelector().getGeometry(); + + // 获取边界框信息 + double x = geometry.getX() != null ? geometry.getX() : 0; + double y = geometry.getY() != null ? geometry.getY() : 0; + double w = geometry.getW() != null ? geometry.getW() : 0; + double h = geometry.getH() != null ? geometry.getH() : 0; + + // 如果使用基本字段 + if (x == 0 && y == 0 && w == 0 && h == 0) { + // 这里可以添加从MarkInfoDO基本字段获取坐标的逻辑 + // 暂时返回null,表示无法转换 + return null; + } + + // 转换为YOLO格式(相对坐标和宽高) + double centerX = (x + w / 2) / imageWidth; + double centerY = (y + h / 2) / imageHeight; + double width = w / imageWidth; + double height = h / imageHeight; + + // 确保坐标在0-1范围内 + centerX = Math.max(0, Math.min(1, centerX)); + centerY = Math.max(0, Math.min(1, centerY)); + width = Math.max(0, Math.min(1, width)); + height = Math.max(0, Math.min(1, height)); + + return String.format("%d %.6f %.6f %.6f %.6f", classIndex, centerX, centerY, width, height); + + } catch (Exception e) { + log.error("转换YOLO格式失败", e); + return null; + } + } + + /** + * 生成YAML配置文件 + */ + private void generateYamlConfig(List typesList, String taskType, String datasetPath) { + try (PrintWriter writer = new PrintWriter(new FileWriter(datasetPath + "/dataset.yaml"))) { + writer.println("# YOLO Dataset Configuration"); + writer.println("path: " + datasetPath); + writer.println("train: images/train"); + writer.println("val: images/val"); + writer.println("test: images/test"); + + writer.println(); + writer.println("# Classes"); + writer.println("nc: " + typesList.size()); + writer.println("names:"); + + // 按id排序输出类别名称(如果没有index字段,使用id) + typesList.sort(Comparator.comparing(TypesDO::getId)); + for (int i = 0; i < typesList.size(); i++) { + TypesDO type = typesList.get(i); + writer.println(" " + i + ": " + type.getName()); + } + + } catch (IOException e) { + log.error("生成YAML配置文件失败", e); + } + } @Override public void updateTrain(TrainSaveReqVO updateReqVO) { @@ -71,4 +663,260 @@ public class TrainServiceImpl implements TrainService { return trainMapper.selectPage(pageReqVO); } + @Override + public SystemInfoRespVO getSystemInfo() { + SystemInfoRespVO systemInfo = new SystemInfoRespVO(); + + // 获取CPU信息 + systemInfo.setCpu(getCpuInfo()); + + // 获取GPU信息 + systemInfo.setGpus(getGpuInfo()); + + // 获取Python信息 + systemInfo.setPython(getPythonInfo()); + + return systemInfo; + } + + private SystemInfoRespVO.CpuInfo getCpuInfo() { + SystemInfoRespVO.CpuInfo cpuInfo = new SystemInfoRespVO.CpuInfo(); + + try { + // 获取CPU型号 + String os = System.getProperty("os.name").toLowerCase(); + if (os.contains("win")) { + // Windows系统 + Process process = Runtime.getRuntime().exec("wmic cpu get name"); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + String line; + while ((line = reader.readLine()) != null) { + if (!line.trim().isEmpty() && !line.contains("Name")) { + cpuInfo.setModel(line.trim()); + break; + } + } + reader.close(); + + // 获取CPU核心数 + process = Runtime.getRuntime().exec("wmic cpu get numberofcores"); + reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + while ((line = reader.readLine()) != null) { + if (!line.trim().isEmpty() && !line.contains("NumberOfCores")) { + cpuInfo.setCores(Integer.parseInt(line.trim())); + break; + } + } + reader.close(); + } else { + // Linux系统 + Process process = Runtime.getRuntime().exec("cat /proc/cpuinfo | grep 'model name' | head -1"); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + String line = reader.readLine(); + if (line != null) { + cpuInfo.setModel(line.split(":")[1].trim()); + } + reader.close(); + + // 获取CPU核心数 + process = Runtime.getRuntime().exec("nproc"); + reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + line = reader.readLine(); + if (line != null) { + cpuInfo.setCores(Integer.parseInt(line.trim())); + } + reader.close(); + } + + // 获取CPU使用率(简化实现) + cpuInfo.setUsage(0.0); + + } catch (Exception e) { + log.error("获取CPU信息失败", e); + cpuInfo.setModel("Unknown"); + cpuInfo.setCores(0); + cpuInfo.setUsage(0.0); + } + + return cpuInfo; + } + + private List getGpuInfo() { + List gpuList = new ArrayList<>(); + + try { + // 使用nvidia-smi获取GPU信息 + Process process = Runtime.getRuntime().exec("nvidia-smi --query-gpu=index,name,memory.total,memory.used,memory.free,utilization.gpu,temperature.gpu --format=csv,noheader,nounits"); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + String line; + + while ((line = reader.readLine()) != null) { + if (!line.trim().isEmpty()) { + String[] parts = line.split(",\\s*"); + if (parts.length >= 7) { + SystemInfoRespVO.GpuInfo gpuInfo = new SystemInfoRespVO.GpuInfo(); + gpuInfo.setId(Integer.parseInt(parts[0].trim())); + gpuInfo.setName(parts[1].trim()); + gpuInfo.setMemoryTotal(Long.parseLong(parts[2].trim())); + gpuInfo.setMemoryUsed(Long.parseLong(parts[3].trim())); + gpuInfo.setMemoryFree(Long.parseLong(parts[4].trim())); + gpuInfo.setUsage(Double.parseDouble(parts[5].trim())); + gpuInfo.setTemperature(Double.parseDouble(parts[6].trim())); + gpuList.add(gpuInfo); + } + } + } + reader.close(); + + } catch (Exception e) { + log.error("获取GPU信息失败", e); + // 如果nvidia-smi不可用,返回空列表 + } + + return gpuList; + } + + public static void main(String[] args) { + File dir = new File("D:/data"); + File[] files = dir.listFiles(); + List fileNames = new ArrayList<>(); + if (files != null) { + for (File file : files) { + fileNames.add(file.getName()); + } + } + System.out.println(fileNames); + } + private SystemInfoRespVO.PythonInfo getPythonInfo() { + SystemInfoRespVO.PythonInfo pythonInfo = new SystemInfoRespVO.PythonInfo(); + + try { + // 尝试多种Python命令 + String[] pythonCommands = {"python", "python3", "py"}; + String pythonCmd = null; + String version = null; + + for (String cmd : pythonCommands) { + try { + Process process = Runtime.getRuntime().exec(cmd + " --version"); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + String line = reader.readLine(); + if (line != null && line.toLowerCase().contains("python")) { + pythonCmd = cmd; + version = line.trim(); + reader.close(); + break; + } + reader.close(); + } catch (Exception e) { + // 继续尝试下一个命令 + } + } + + if (pythonCmd != null) { + pythonInfo.setInstalled(true); + pythonInfo.setVersion(version); + + // 获取Python路径 + try { + Process process = Runtime.getRuntime().exec("where " + pythonCmd); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + String line = reader.readLine(); + if (line != null) { + pythonInfo.setPath(line.trim()); + } + reader.close(); + } catch (Exception e) { + log.warn("获取Python路径失败", e); + } + + // 检查各种虚拟环境支持 + checkVirtualEnvironmentSupport(pythonCmd, pythonInfo); + + // 获取虚拟环境列表 +// pythonInfo.setVirtualEnvs(getVirtualEnvironments(pythonCmd)); + + // 检查yolo虚拟环境 + try { + Process process = Runtime.getRuntime().exec("conda env list"); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + boolean yoloExists = false; + String line; + while ((line = reader.readLine()) != null) { + if (line.contains("yolo")) { + yoloExists = true; + break; + } + } + pythonInfo.setYoloEnvExists(yoloExists); + reader.close(); + + if (yoloExists) { + // 使用conda直接在yolo环境中执行Python命令 + process = Runtime.getRuntime().exec("conda run -n yolo python -c \"import ultralytics; print(ultralytics.__version__)\""); + reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + line = reader.readLine(); + if (line != null) { + pythonInfo.setYoloInstalled(true); + pythonInfo.setYoloVersion(line.trim()); + } else { + pythonInfo.setYoloInstalled(false); + } + reader.close(); + } + } catch (Exception e) { + pythonInfo.setYoloEnvExists(false); + pythonInfo.setYoloInstalled(false); + log.warn("检查YOLO环境失败", e); + } + } else { + pythonInfo.setInstalled(false); + } + + } catch (Exception e) { + log.error("获取Python信息失败", e); + pythonInfo.setInstalled(false); +// pythonInfo.setVirtualEnvInstalled(false); + pythonInfo.setYoloEnvExists(false); + pythonInfo.setYoloInstalled(false); + } + + return pythonInfo; + } + + private void checkVirtualEnvironmentSupport(String pythonCmd, SystemInfoRespVO.PythonInfo pythonInfo) { + // 检查venv支持(Python 3.3+内置) + try { + Process process = Runtime.getRuntime().exec(pythonCmd + " -m venv --help"); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream())); + String output = reader.lines().collect(Collectors.joining()); + pythonInfo.setVenvSupported(!output.isEmpty()); + reader.close(); + } catch (Exception e) { + pythonInfo.setVenvSupported(false); + } + + // 检查virtualenv + try { + Process process = Runtime.getRuntime().exec(pythonCmd + " -m virtualenv --version"); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + String line = reader.readLine(); + pythonInfo.setVirtualenvInstalled(line != null && !line.isEmpty()); + reader.close(); + } catch (Exception e) { + pythonInfo.setVirtualenvInstalled(false); + } + + // 检查conda + try { + Process process = Runtime.getRuntime().exec("conda --version"); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + String line = reader.readLine(); + pythonInfo.setCondaInstalled(line != null && line.toLowerCase().contains("conda")); + reader.close(); + } catch (Exception e) { + pythonInfo.setCondaInstalled(false); + } + } + } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/traininfo/TrainInfoService.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/traininfo/TrainInfoService.java new file mode 100644 index 0000000..f97bacb --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/traininfo/TrainInfoService.java @@ -0,0 +1,70 @@ +package cn.iocoder.yudao.module.annotation.service.traininfo; + +import cn.iocoder.yudao.framework.common.pojo.PageResult; +import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoPageReqVO; +import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoRespVO; +import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO; +import jakarta.validation.Valid; + +import java.util.List; + +/** + * 训练信息 Service 接口 + * + * @author 管理员 + */ +public interface TrainInfoService { + + /** + * 创建训练信息 + * + * @param createReqVO 创建信息 + * @return 编号 + */ + Integer createTrainInfo(@Valid TrainInfoSaveReqVO createReqVO); + + /** + * 更新训练信息 + * + * @param updateReqVO 更新信息 + */ + void updateTrainInfo(@Valid TrainInfoSaveReqVO updateReqVO); + + /** + * 删除训练信息 + * + * @param id 编号 + */ + void deleteTrainInfo(Integer id); + + /** + * 获得训练信息 + * + * @param id 编号 + * @return 训练信息 + */ + TrainInfoDO getTrainInfo(Integer id); + + /** + * 获得训练信息分页 + * + * @param pageReqVO 分页查询 + * @return 训练信息分页 + */ + PageResult getTrainInfoPage(TrainInfoPageReqVO pageReqVO); + + /** + * 保存训练轮次信息 + * + * @param trainId 训练ID + * @param round 轮次 + * @param roundTotal 总轮次 + * @param info 轮次信息 + * @param rate 识别率 + * @return 编号 + */ + Integer saveTrainRoundInfo(Integer trainId, Integer round, Integer roundTotal, String info, String rate); + + List getTrainInfoList(Integer id); +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/traininfo/TrainInfoServiceImpl.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/traininfo/TrainInfoServiceImpl.java new file mode 100644 index 0000000..9bafc8b --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/traininfo/TrainInfoServiceImpl.java @@ -0,0 +1,92 @@ +package cn.iocoder.yudao.module.annotation.service.traininfo; + +import cn.iocoder.yudao.framework.common.pojo.PageResult; +import cn.iocoder.yudao.framework.common.util.object.BeanUtils; +import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; +import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoPageReqVO; +import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoRespVO; +import cn.iocoder.yudao.module.annotation.controller.admin.traininfo.vo.TrainInfoSaveReqVO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.traininfo.TrainInfoDO; +import cn.iocoder.yudao.module.annotation.dal.mysql.traininfo.TrainInfoMapper; +import jakarta.annotation.Resource; +import org.springframework.stereotype.Service; +import org.springframework.validation.annotation.Validated; + +import java.util.List; + +/** + * 训练信息 Service 实现类 + * + * @author 管理员 + */ +@Service +@Validated +public class TrainInfoServiceImpl implements TrainInfoService { + + @Resource + private TrainInfoMapper trainInfoMapper; + + @Override + public Integer createTrainInfo(TrainInfoSaveReqVO createReqVO) { + // 插入 + TrainInfoDO data = BeanUtils.toBean(createReqVO, TrainInfoDO.class); + trainInfoMapper.insert(data); + // 返回 + return data.getId(); + } + + @Override + public void updateTrainInfo(TrainInfoSaveReqVO updateReqVO) { + // 校验存在 +// validateTrainInfoExists(updateReqVO.getId()); + // 更新 + TrainInfoDO updateObj = BeanUtils.toBean(updateReqVO, TrainInfoDO.class); + trainInfoMapper.updateById(updateObj); + } + + @Override + public void deleteTrainInfo(Integer id) { + // 校验存在 + validateTrainInfoExists(id); + // 删除 + trainInfoMapper.deleteById(id); + } + + private void validateTrainInfoExists(Integer id) { + if (trainInfoMapper.selectById(id) == null) { +// throw exception("训练信息不存在"); + } + } + + @Override + public TrainInfoDO getTrainInfo(Integer id) { + return trainInfoMapper.selectById(id); + } + + @Override + public PageResult getTrainInfoPage(TrainInfoPageReqVO pageReqVO) { + return trainInfoMapper.selectPage(pageReqVO); + } + + @Override + public Integer saveTrainRoundInfo(Integer trainId, Integer round, Integer roundTotal, String info, String rate) { + TrainInfoDO trainInfo = TrainInfoDO.builder() + .trainId(trainId) + .round(round) + .roundTotal(roundTotal) + .info(info) + .rate(rate) + .build(); + trainInfoMapper.insert(trainInfo); + return trainInfo.getId(); + } + @Override + public List getTrainInfoList(Integer trainId) { + List trainInfoList = trainInfoMapper.selectList(new LambdaQueryWrapperX() + .eq(TrainInfoDO::getTrainId, trainId) + .orderByAsc(TrainInfoDO::getRound) + .last("limit 15")); + return BeanUtils.toBean(trainInfoList, TrainInfoRespVO.class); + } + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/trainresult/TrainResultService.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/trainresult/TrainResultService.java index aea2d27..b68dd11 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/trainresult/TrainResultService.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/trainresult/TrainResultService.java @@ -1,11 +1,10 @@ package cn.iocoder.yudao.module.annotation.service.trainresult; -import java.util.*; -import jakarta.validation.*; -import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.*; -import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.common.pojo.PageParam; +import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultPageReqVO; +import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; +import jakarta.validation.Valid; /** * 识别结果 Service 接口 @@ -52,4 +51,5 @@ public interface TrainResultService { */ PageResult getTrainResultPage(TrainResultPageReqVO pageReqVO); + TrainResultDO getStatus(Integer trainId); } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/trainresult/TrainResultServiceImpl.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/trainresult/TrainResultServiceImpl.java index 51609a1..d87bd84 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/trainresult/TrainResultServiceImpl.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/trainresult/TrainResultServiceImpl.java @@ -1,20 +1,15 @@ package cn.iocoder.yudao.module.annotation.service.trainresult; -import org.springframework.stereotype.Service; -import jakarta.annotation.Resource; -import org.springframework.validation.annotation.Validated; -import org.springframework.transaction.annotation.Transactional; - -import java.util.*; -import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.*; -import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; import cn.iocoder.yudao.framework.common.pojo.PageResult; -import cn.iocoder.yudao.framework.common.pojo.PageParam; import cn.iocoder.yudao.framework.common.util.object.BeanUtils; - +import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX; +import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultPageReqVO; +import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; import cn.iocoder.yudao.module.annotation.dal.mysql.trainresult.TrainResultMapper; - -import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception; +import jakarta.annotation.Resource; +import org.springframework.stereotype.Service; +import org.springframework.validation.annotation.Validated; //import static cn.iocoder.yudao.module.annotation.enums.ErrorCodeConstants.*; /** @@ -71,4 +66,10 @@ public class TrainResultServiceImpl implements TrainResultService { return trainResultMapper.selectPage(pageReqVO); } + + public TrainResultDO getStatus(Integer trainId){ + return trainResultMapper.selectOne(new LambdaQueryWrapperX() + .eq(TrainResultDO::getTrainId, trainId) + .orderByDesc(TrainResultDO::getId).last("limit 1")); + } } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationService.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationService.java new file mode 100644 index 0000000..7496681 --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationService.java @@ -0,0 +1,76 @@ +package cn.iocoder.yudao.module.annotation.service.yolo; + +import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; +import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; + +import java.util.concurrent.CompletableFuture; + +/** + * YOLO操作服务接口 + * + * @author 管理员 + */ +public interface YoloOperationService { + + /** + * YOLO训练(异步) + * + * @param pythonPath Python路径 + * @param envName 虚拟环境名称 + * @param config 训练配置 + * @return 异步训练结果 + */ + CompletableFuture trainAsync(String pythonPath, String envName, YoloConfig config); + + /** + * YOLO验证 + * + * @param pythonPath Python路径 + * @param envName 虚拟环境名称 + * @param config 验证配置 + * @return 验证结果配置 + */ + YoloConfig validate(String pythonPath, String envName, YoloConfig config); + + /** + * YOLO预测 + * + * @param pythonPath Python路径 + * @param envName 虚拟环境名称 + * @param config 预测配置 + * @return 预测结果配置 + */ + YoloConfig predict(String pythonPath, String envName, YoloConfig config); + + /** + * YOLO分类 + * + * @param pythonPath Python路径 + * @param envName 虚拟环境名称 + * @param config 分类配置 + * @return 分类结果配置 + */ + YoloConfig classify(String pythonPath, String envName, YoloConfig config); + + /** + * 导出YOLO模型 + * + * @param pythonPath Python路径 + * @param envName 虚拟环境名称 + * @param config 导出配置 + * @return 导出结果配置 + */ + YoloConfig export(String pythonPath, String envName, YoloConfig config); + + /** + * 追踪任务 + * + * @param pythonPath Python路径 + * @param envName 虚拟环境名称 + * @param config 追踪配置 + * @return 追踪结果配置 + */ + YoloConfig track(String pythonPath, String envName, YoloConfig config); + + void generateTrainPythonScript(TrainDO train, String yoloDatasetPath); +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationServiceImpl.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationServiceImpl.java new file mode 100644 index 0000000..aabb9df --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationServiceImpl.java @@ -0,0 +1,1598 @@ +package cn.iocoder.yudao.module.annotation.service.yolo; + +import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; +import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; +import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; +import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService; +import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo; +import cn.iocoder.yudao.module.annotation.service.train.TrainService; +import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService; +import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService; +import cn.iocoder.yudao.module.system.util.RedisUtil; +import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Lazy; +import org.springframework.stereotype.Service; + +import java.io.BufferedReader; +import java.io.File; +import java.io.InputStreamReader; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +/** + * YOLO操作服务实现类 + * + * @author 管理员 + */ +@Slf4j +@Service +public class YoloOperationServiceImpl implements YoloOperationService { + + @Resource + private PythonVirtualEnvService pythonVirtualEnvService; + + @Resource + private TrainInfoService trainInfoService; + + @Resource + private TrainResultService trainResultService; + + @Resource + private YoloServiceClient yoloServiceClient; + + @Value("${yolo.service.enabled:false}") + private boolean useSocketService; + + // 异步执行器 + private final Executor asyncExecutor = Executors.newCachedThreadPool(); + + @Override + public CompletableFuture trainAsync(String pythonPath, String envName, YoloConfig config) { + return CompletableFuture.supplyAsync(() -> { + config.setStatus("running"); + config.setProgress(0); + config.setTaskId(UUID.randomUUID().toString()); + + // 生成训练ID + Integer trainId = config.getTrainId(); + + // 记录上一轮的epoch,用于检测epoch完成 + int[] lastSavedEpoch = {0}; + + try { + // 检测虚拟环境 + PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName); + if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) { + config.setStatus("failed"); + config.setErrorMessage("虚拟环境不存在或未安装YOLO"); + return config; + } + + // 构建训练命令 + String command = buildTrainCommand(envInfo, config); + + log.info("开始YOLO训练,命令: {}", command); + + // 执行训练 + ProcessBuilder processBuilder = new ProcessBuilder(); + // 在Windows上使用cmd,在Linux/Mac上使用bash + String os = System.getProperty("os.name").toLowerCase(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + // 设置环境变量确保正确处理中文路径 + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + processBuilder.environment().put("LANG", "zh_CN.UTF-8"); + } else { + processBuilder.command("/bin/bash", "-c", command); + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + } + Process process = processBuilder.start(); + BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream())); + BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); +// BufferedReader stderrReader2 = new BufferedReader(new InputStreamReader(process.getOutputStream())); + + StringBuilder fullOutput = new StringBuilder(); + + // 使用单独的线程读取stdout和stderr,避免死锁 + String stdoutLine, stderrLine = null; + while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null ) { + if (stdoutLine != null) { + log.info("YOLO训练日志: {}", stdoutLine); + fullOutput.append(stdoutLine).append("\n"); + + // 详细解析YOLO训练进度和结果 + manageTrainingLogInRedis(trainId, stdoutLine); + config.setLogMessage(stdoutLine); + + // 检测 "Results saved to" 并解析路径和识别率 + detectAndParseResultsSaved(stdoutLine, config, trainId); + } + if (stderrLine != null) { + log.info("YOLO训练日志: {}", stderrLine); + fullOutput.append(stderrLine).append("\n"); + + manageTrainingLogInRedis(trainId, stderrLine); + + config.setLogMessage(stderrLine); + + // 检测 "Results saved to" 并解析路径和识别率 + detectAndParseResultsSaved(stderrLine, config, trainId); + } + } + + int exitCode = process.waitFor(); + stdoutReader.close(); + stderrReader.close(); + + if (exitCode == 0) { + config.setStatus("completed"); + config.setProgress(100); + config.setLogMessage("训练完成"); + // 解析最终训练结果摘要 + parseFinalResults(fullOutput.toString(), config); + // 保存最终训练结果 + saveFinalTrainResult(trainId, config, fullOutput.toString()); + } else { + config.setStatus("failed"); + config.setErrorMessage("训练失败,退出码: " + exitCode); + } + + } catch (Exception e) { + log.error("YOLO训练异常", e); + config.setStatus("failed"); + config.setErrorMessage("训练异常: " + e.getMessage()); + } + return config; + }, asyncExecutor); + } + + @Override + public YoloConfig validate(String pythonPath, String envName, YoloConfig config) { + // 优先尝试使用socket服务(如果启用且可用) + if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) { + log.info("使用socket服务执行YOLO验证"); + return yoloServiceClient.validateViaSocket( + config.getModelPath(), + config.getValDatasetPath(), + config + ); + } + + // 回退到原有的ProcessBuilder方式 + log.info("使用传统方式执行YOLO验证"); + try { + PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName); + if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) { + config.setStatus("failed"); + config.setErrorMessage("虚拟环境不存在或未安装YOLO"); + return config; + } + + String command = buildValidateCommand(envInfo, config); + log.info("开始YOLO验证,命令: {}", command); + + ProcessBuilder processBuilder = new ProcessBuilder(); + String os = System.getProperty("os.name").toLowerCase(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + // 设置环境变量确保正确处理中文路径 + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + processBuilder.environment().put("LANG", "zh_CN.UTF-8"); + } else { + processBuilder.command("/bin/bash", "-c", command); + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + } + Process process = processBuilder.start(); + BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream())); + BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + + StringBuilder output = new StringBuilder(); + String line; + + // 同时读取标准输出和错误输出 + String stdoutLine, stderrLine = null; + while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null) { + if (stdoutLine != null) { + output.append(stdoutLine).append("\n"); + log.info("YOLO验证日志: {}", stdoutLine); + } + if (stderrLine != null) { + output.append(stderrLine).append("\n"); + log.info("YOLO验证日志: {}", stderrLine); + } + } + + int exitCode = process.waitFor(); + stdoutReader.close(); + stderrReader.close(); + + if (exitCode == 0) { + config.setStatus("completed"); + config.setLogMessage("验证完成"); + config.setMetrics(output.toString()); + } else { + config.setStatus("failed"); + config.setErrorMessage("验证失败,退出码: " + exitCode); + config.setLogMessage(output.toString()); + } + + } catch (Exception e) { + log.error("YOLO验证异常", e); + config.setStatus("failed"); + config.setErrorMessage("验证异常: " + e.getMessage()); + } + + return config; + } + + @Override + public YoloConfig predict(String pythonPath, String envName, YoloConfig config) { + // 优先尝试使用socket服务(如果启用且可用) + if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) { + log.info("使用socket服务执行YOLO预测"); + return yoloServiceClient.predictViaSocket( + config.getModelPath(), + config.getInputPath(), + config.getOutputPath(), + config + ); + } + + // 回退到原有的ProcessBuilder方式 + log.info("使用传统方式执行YOLO预测"); + try { + PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName); + if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) { + config.setStatus("failed"); + config.setErrorMessage("虚拟环境不存在或未安装YOLO"); + return config; + } + + String command = buildPredictCommand(envInfo, config); + log.info("开始YOLO预测,命令: {}", command); + + ProcessBuilder processBuilder = new ProcessBuilder(); + String os = System.getProperty("os.name").toLowerCase(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + // 设置环境变量确保正确处理中文路径 + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + processBuilder.environment().put("LANG", "zh_CN.UTF-8"); + } else { + processBuilder.command("/bin/bash", "-c", command); + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + } + Process process = processBuilder.start(); + BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream())); + BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + + StringBuilder output = new StringBuilder(); + String line; + + // 同时读取标准输出和错误输出 + String stdoutLine, stderrLine = null; + while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null) { + if (stdoutLine != null) { + output.append(stdoutLine).append("\n"); + log.info("YOLO预测日志: {}", stdoutLine); + if (stdoutLine.contains("Results saved to")) { + + stdoutLine = stdoutLine.replace("\u001B","").replace("[1m","").replace("[0m",""); + config.setOutputPath(stdoutLine.split(" ")[3]); + } + } + if (stderrLine != null) { + output.append(stderrLine).append("\n"); + log.info("YOLO预测日志: {}", stderrLine); + + if (stderrLine.contains("Results saved to")) { + stderrLine = stderrLine.replace("\u001B","").replace("[1m","").replace("[0m",""); + config.setOutputPath(stderrLine.split(" ")[3]); + } + } + } + + int exitCode = process.waitFor(); + stdoutReader.close(); + stderrReader.close(); + + if (exitCode == 0) { + config.setStatus("completed"); + config.setLogMessage("预测完成"); + config.setMetrics(output.toString()); + } else { + config.setStatus("failed"); + config.setErrorMessage("预测失败,退出码: " + exitCode); + config.setLogMessage(output.toString()); + } + + } catch (Exception e) { + log.error("YOLO预测异常", e); + config.setStatus("failed"); + config.setErrorMessage("预测异常: " + e.getMessage()); + } + + return config; + } + + @Override + public YoloConfig classify(String pythonPath, String envName, YoloConfig config) { + config.setTaskType("classify"); + return predict(pythonPath, envName, config); // 分类和预测使用相同的逻辑,只是任务类型不同 + } + + @Override + public YoloConfig export(String pythonPath, String envName, YoloConfig config) { + try { + PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName); + if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) { + config.setStatus("failed"); + config.setErrorMessage("虚拟环境不存在或未安装YOLO"); + return config; + } + + String command = buildExportCommand(envInfo, config); + log.info("开始YOLO导出,命令: {}", command); + + ProcessBuilder processBuilder = new ProcessBuilder(); + String os = System.getProperty("os.name").toLowerCase(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + // 设置环境变量确保正确处理中文路径 + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + processBuilder.environment().put("LANG", "zh_CN.UTF-8"); + } else { + processBuilder.command("/bin/bash", "-c", command); + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + } + Process process = processBuilder.start(); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + + StringBuilder output = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + output.append(line).append("\n"); + log.info("YOLO导出日志: {}", line); + } + + int exitCode = process.waitFor(); + reader.close(); + + if (exitCode == 0) { + config.setStatus("completed"); + config.setLogMessage("导出完成"); + config.setMetrics(output.toString()); + } else { + config.setStatus("failed"); + config.setErrorMessage("导出失败,退出码: " + exitCode); + config.setLogMessage(output.toString()); + } + + } catch (Exception e) { + log.error("YOLO导出异常", e); + config.setStatus("failed"); + config.setErrorMessage("导出异常: " + e.getMessage()); + } + + return config; + } + + @Override + public YoloConfig track(String pythonPath, String envName, YoloConfig config) { + try { + PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName); + if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) { + config.setStatus("failed"); + config.setErrorMessage("虚拟环境不存在或未安装YOLO"); + return config; + } + + String command = buildTrackCommand(envInfo, config); + log.info("开始YOLO追踪,命令: {}", command); + + ProcessBuilder processBuilder = new ProcessBuilder(); + String os = System.getProperty("os.name").toLowerCase(); + if (os.contains("win")) { + processBuilder.command("cmd", "/c", command); + // 设置环境变量确保正确处理中文路径 + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + processBuilder.environment().put("LANG", "zh_CN.UTF-8"); + } else { + processBuilder.command("/bin/bash", "-c", command); + processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); + } + Process process = processBuilder.start(); + BufferedReader reader = new BufferedReader(new InputStreamReader(process.getErrorStream())); + + StringBuilder output = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + output.append(line).append("\n"); + log.info("YOLO追踪日志: {}", line); + } + + int exitCode = process.waitFor(); + reader.close(); + + if (exitCode == 0) { + config.setStatus("completed"); + config.setLogMessage("追踪完成"); + config.setMetrics(output.toString()); + } else { + config.setStatus("failed"); + config.setErrorMessage("追踪失败,退出码: " + exitCode); + config.setLogMessage(output.toString()); + } + + } catch (Exception e) { + log.error("YOLO追踪异常", e); + config.setStatus("failed"); + config.setErrorMessage("追踪异常: " + e.getMessage()); + } + + return config; + } + + private String buildTrainCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { + StringBuilder command = new StringBuilder(); + + if ("conda".equals(envInfo.getVirtualEnvType())) { + command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" "); + } else { + command.append(normalizePath(envInfo.getVirtualEnvPythonPath())).append(" "); + } + + // 规范化路径,处理双斜杠和中文路径问题 + String datasetPath = normalizePath(config.getDatasetPath()); + String outputPath = normalizePath(config.getOutputPath()); + String modelPath = config.getModelPath() != null ? normalizePath(config.getModelPath()) : null; + + // 调试日志:检查模型路径 + log.info("模型路径配置 - modelPath: {}, modelName: {}", config.getModelPath(), config.getModelName()); + + + +// command.append("-c \""); + command.append(datasetPath).append("/train.py "); +// command.append("from ultralytics import YOLO; "); +// +// if (modelPath != null && !modelPath.trim().isEmpty()) { +// command.append("model = YOLO('").append(escapePythonString(modelPath)).append("'); "); +// } else { +// command.append("model = YOLO('").append(config.getModelName()).append("'); "); +// } +// +// // 使用raw string来处理中文路径 +// command.append("model.train(data=r'").append(escapePythonString(datasetPath+"\\dataset.yaml")).append("', "); +// command.append("epochs=").append(config.getEpochs()).append(", "); +// command.append("batch=").append(config.getBatchSize()).append(", "); +// command.append("imgsz=").append(config.getImageSize()).append(", "); +// +// // 取消注释设备和学习率配置 +//// command.append("device=").append(config.getDevice()).append(", "); +//// command.append("lr0=").append(config.getLearningRate()).append(", "); +//// command.append("workers=").append(config.getWorkers()).append(", "); +//// +// if (config.getSavePeriod() > 0) { +// command.append("save_period=").append(config.getSavePeriod()).append(", "); +// } +// +// command.append("project=r'").append(escapePythonString(outputPath)).append("')\""); +// + return command.toString(); + } + + private String buildValidateCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { + StringBuilder command = new StringBuilder(); + + if ("conda".equals(envInfo.getVirtualEnvType())) { + command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" "); + } else { + command.append(envInfo.getVirtualEnvPythonPath()).append(" "); + } + + + + String datasetPath = normalizePath(config.getValDatasetPath() != null ? + config.getValDatasetPath() : config.getDatasetPath()); + String modelPath = normalizePath(config.getModelPath()); + + command.append("-c \""); + command.append("import sys; sys.path.insert(0, '.'); "); + command.append("from ultralytics import YOLO; "); + command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); "); + command.append("results = model.val(data=r'").append(escapePythonString(datasetPath)).append("', "); + command.append("conf=").append(config.getConfThresh()).append(", "); + command.append("iou=").append(config.getIouThresh()).append(", "); + command.append("imgsz=").append(config.getImageSize()).append(", "); + command.append("device='").append(config.getDevice()).append("')\""); + + return command.toString(); + } + private String buildPredictCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { + StringBuilder command = new StringBuilder(); + + if ("conda".equals(envInfo.getVirtualEnvType())) { + command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" "); + } else { + command.append(envInfo.getVirtualEnvPythonPath()).append(" "); + } + + String inputPath = normalizePath(config.getInputPath()); + String outputPath = normalizePath(config.getOutputPath()); + String modelPath = normalizePath(config.getModelPath()); + + command.append("-c \""); + command.append("import sys; sys.path.insert(0, '.'); "); + command.append("from ultralytics import YOLO; "); + command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); "); + command.append("results = model.predict(source=r'").append(escapePythonString(inputPath)).append("', "); + // 在这里我们将 save=true 改为 save=True + command.append("save=").append(config.isSave() ? "True" : "False").append(", "); + command.append("project='").append(escapePythonString(outputPath)).append("')\""); + + return command.toString(); + } + + + private String buildExportCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { + StringBuilder command = new StringBuilder(); + + if ("conda".equals(envInfo.getVirtualEnvType())) { + command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" "); + } else { + command.append(envInfo.getVirtualEnvPythonPath()).append(" "); + } + + command.append("-c \""); + command.append("from ultralytics import YOLO; "); + command.append("model = YOLO('").append(config.getModelPath()).append("'); "); + command.append("model.export(format='").append(config.getFormat()).append("', "); + command.append("imgsz=").append(config.getImgsz()).append(", "); + command.append("optimize=").append(config.isOptimize()).append(")\""); + + return command.toString(); + } + + private String buildTrackCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { + StringBuilder command = new StringBuilder(); + + if ("conda".equals(envInfo.getVirtualEnvType())) { + command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" "); + } else { + command.append(envInfo.getVirtualEnvPythonPath()).append(" "); + } + + command.append("-c \""); + command.append("from ultralytics import YOLO; "); + command.append("model = YOLO('").append(config.getModelPath()).append("'); "); + command.append("results = model.track(source='").append(config.getInputPath()).append("', "); + command.append("conf=").append(config.getTrackConf()).append(", "); + command.append("iou=").append(config.getTrackIou()).append(", "); + command.append("tracker='").append(config.getTracker()).append("', "); + command.append("device='").append(config.getDevice()).append("', "); + command.append("save=").append(config.isSave()).append(")\""); + + return command.toString(); + } + + + /** + * 解析最终训练结果 + */ + private void parseFinalResults(String fullOutput, YoloConfig config) { + try { + String[] lines = fullOutput.split("\n"); + StringBuilder summary = new StringBuilder(); + summary.append("训练完成!最终结果:\n"); + + boolean foundResults = false; + + for (String line : lines) { + String trimmedLine = line.trim(); + + // 查找最终的验证结果 + if (trimmedLine.contains("Class") || trimmedLine.contains("Images") || + trimmedLine.contains("Instances") || trimmedLine.contains("Box")) { + // 这些是表格标题,跳过 + continue; + } + + // 查找最终的mAP结果 + if (trimmedLine.contains("mAP50") || trimmedLine.contains("mAP50-95")) { + summary.append("• ").append(trimmedLine).append("\n"); + foundResults = true; + } + + // 查找精度和召回率 + if (trimmedLine.contains("precision:") || trimmedLine.contains("recall:")) { + summary.append("• ").append(trimmedLine).append("\n"); + foundResults = true; + } + + // 查找模型保存路径 + if (trimmedLine.contains("Results saved to") || trimmedLine.contains("Saved to")) { + summary.append("• ").append(trimmedLine).append("\n"); + foundResults = true; + } + + // 查找训练时间 + if (trimmedLine.contains("Training completed in") || trimmedLine.contains("Time")) { + summary.append("• ").append(trimmedLine).append("\n"); + foundResults = true; + } + } + + if (!foundResults) { + summary.append("详细结果请查看日志输出和训练输出目录。\n"); + } + + // 添加训练参数摘要 + summary.append("\n训练参数摘要:\n"); + summary.append(String.format("• 模型: %s\n", config.getModelName())); + summary.append(String.format("• 训练轮数: %d\n", config.getEpochs())); + summary.append(String.format("• 批次大小: %d\n", config.getBatchSize())); + summary.append(String.format("• 学习率: %.4f\n", config.getLearningRate())); + summary.append(String.format("• 图像尺寸: %d\n", config.getImageSize())); + summary.append(String.format("• 设备: %s\n", config.getDevice())); + + // 如果解析到了指标,添加最终指标摘要 + if (config.getPrecision() > 0 || config.getRecall() > 0 || config.getMap() > 0) { + summary.append("\n最终性能指标:\n"); + if (config.getPrecision() > 0) { + summary.append(String.format("• 精度(Precision): %.4f\n", config.getPrecision())); + } + if (config.getRecall() > 0) { + summary.append(String.format("• 召回率(Recall): %.4f\n", config.getRecall())); + } + if (config.getMap() > 0) { + summary.append(String.format("• mAP: %.4f\n", config.getMap())); + } + } + + config.setTrainingSummary(summary.toString()); + config.setMetrics(summary.toString()); + + } catch (Exception e) { + log.error("解析最终训练结果时出错", e); + config.setTrainingSummary("训练完成,但解析详细结果时出错。请检查日志和输出文件。"); + } + } + + /** + * 生成训练ID + */ + private Integer generateTrainId(YoloConfig config) { + // 这里可以根据业务需求生成训练ID + // 简单实现:使用当前时间戳的后8位 + return (int) (System.currentTimeMillis() % 100000000); + } + + /** + * 保存训练轮次信息 + */ + private void saveTrainEpochInfo(Integer trainId, Integer currentEpoch, Integer totalEpochs, YoloConfig config) { + try { + // 构建训练信息 + StringBuilder epochInfo = new StringBuilder(); + epochInfo.append("Epoch ").append(currentEpoch).append("/").append(totalEpochs).append(" 完成\n"); + + if (config.getLoss() > 0) { + epochInfo.append("损失值: ").append(String.format("%.4f", config.getLoss())).append("\n"); + } + if (config.getPrecision() > 0) { + epochInfo.append("精度: ").append(String.format("%.4f", config.getPrecision())).append("\n"); + } + if (config.getRecall() > 0) { + epochInfo.append("召回率: ").append(String.format("%.4f", config.getRecall())).append("\n"); + } + if (config.getMap() > 0) { + epochInfo.append("mAP: ").append(String.format("%.4f", config.getMap())).append("\n"); + } + + // 构建识别率字符串 + StringBuilder rateStr = new StringBuilder(); + if (config.getPrecision() > 0) { + rateStr.append("precision:").append(String.format("%.4f", config.getPrecision())); + } + if (config.getRecall() > 0) { + if (rateStr.length() > 0) rateStr.append(","); + rateStr.append("recall:").append(String.format("%.4f", config.getRecall())); + } + if (config.getMap() > 0) { + if (rateStr.length() > 0) rateStr.append(","); + rateStr.append("map:").append(String.format("%.4f", config.getMap())); + } + + // 保存到数据库 + trainInfoService.saveTrainRoundInfo( + trainId, + currentEpoch, + totalEpochs, + epochInfo.toString(), + rateStr.toString() + ); + + log.info("已保存训练轮次信息: trainId={}, epoch={}", trainId, currentEpoch); + + } catch (Exception e) { + log.error("保存训练轮次信息失败: trainId={}, epoch={}", trainId, currentEpoch, e); + // 不影响训练继续进行 + } + } + + /** + * 保存最终训练结果 + */ + private void saveFinalTrainResult(Integer trainId, YoloConfig config, String fullOutput) { + try { + // 构建结果路径(通常是训练输出目录) + String resultPath = config.getOutputPath(); + if (resultPath == null || resultPath.trim().isEmpty()) { + resultPath = "runs/detect/train"; // YOLO默认输出路径 + } + + // 构建最终识别率 + StringBuilder finalRate = new StringBuilder(); + if (config.getPrecision() > 0) { + finalRate.append("precision:").append(String.format("%.4f", config.getPrecision())); + } + if (config.getRecall() > 0) { + if (finalRate.length() > 0) finalRate.append(","); + finalRate.append("recall:").append(String.format("%.4f", config.getRecall())); + } + if (config.getMap() > 0) { + if (finalRate.length() > 0) finalRate.append(","); + finalRate.append("map:").append(String.format("%.4f", config.getMap())); + } + + // 创建TrainResultDO并保存 + TrainResultDO trainResult = TrainResultDO.builder() + .trainId(trainId) + .path(resultPath) + .rate(finalRate.toString()) + .dataId(null) // 可以根据需要设置数据集ID + .build(); + + // 由于TrainResultService可能没有直接的DO保存方法,我们可能需要使用VO + // 这里假设需要转换为SaveReqVO,如果没有这个方法,需要添加 + saveTrainResultDO(trainResult); + + log.info("已保存最终训练结果: trainId={}, path={}", trainId, resultPath); + + } catch (Exception e) { + log.error("保存最终训练结果失败: trainId={}", trainId, e); + // 不影响训练完成状态 + } + } + @Resource + @Lazy + TrainService trainService; + + /** + * 保存TrainResultDO到数据库 + */ + private void saveTrainResultDO(TrainResultDO trainResult) { + try { + // 创建SaveReqVO + TrainResultSaveReqVO saveReqVO = + new TrainResultSaveReqVO(); + + // 设置值 + saveReqVO.setTrainId(trainResult.getTrainId()); + saveReqVO.setPath(trainResult.getPath()); + saveReqVO.setRate(trainResult.getRate()); + saveReqVO.setDataId(trainResult.getDataId()); + + // 保存 + trainResultService.createTrainResult(saveReqVO); + trainService.update(new UpdateWrapper().eq("id", trainResult.getTrainId()) + .set("train_type", 4)); + + } catch (Exception e) { + log.error("保存训练结果失败: {}", trainResult, e); + // 不影响训练完成状态 + } + } + + /** + * 规范化路径,处理Windows/Linux路径分隔符、双斜杠等问题 + * + * @param path 原始路径 + * @return 规范化后的路径 + */ + private String normalizePath(String path) { + if (path == null || path.trim().isEmpty()) { + return path; + } + + try { + String normalized = path.trim(); + + // 在Windows系统中,特别处理中文路径和特殊字符 + String os = System.getProperty("os.name").toLowerCase(); + if (os.contains("win")) { + // 对于Windows系统,需要确保路径能够正确处理中文 + // 将反斜杠转换为正斜杠,但保持路径编码正确 + normalized = normalized.replace("\\", "/"); + + // 处理双斜杠问题 + while (normalized.contains("//")) { + normalized = normalized.replace("//", "/"); + } + + // 确保路径以正确的格式存在,避免额外的斜杠 + if (normalized.startsWith("/") && normalized.length() > 1 && normalized.charAt(1) != '/') { + // 这是网络路径或绝对路径,保持原样 + } + + } else { + // 对于Linux/Mac系统,使用File类进行标准化 + File file = new File(normalized); + normalized = file.getCanonicalPath(); + } + + log.debug("路径规范化: {} -> {}", path, normalized); + return normalized; + + } catch (Exception e) { + log.warn("路径规范化失败,使用原始路径: {}", path, e); + return path; + } + } + + /** + * 转义Python字符串中的特殊字符 + * + * @param input 原始字符串 + * @return 转义后的字符串 + */ + private String escapePythonString(String input) { + if (input == null) { + return ""; + } + + // 首先规范化路径 + String normalized = normalizePath(input); + + // 转义Python字符串中的特殊字符 + StringBuilder escaped = new StringBuilder(); + + for (int i = 0; i < normalized.length(); i++) { + char c = normalized.charAt(i); + + switch (c) { + case '\\': + // 反斜杠需要转义 + escaped.append("\\\\"); + break; + case '"': + // 双引号需要转义(因为我们用双引号包围Python字符串) + escaped.append("\\\""); + break; + case '\'': + // 单引号也需要转义 + escaped.append("\\\'"); + break; + case '\n': + // 换行符 + escaped.append("\\n"); + break; + case '\r': + // 回车符 + escaped.append("\\r"); + break; + case '\t': + // 制表符 + escaped.append("\\t"); + break; + case '\b': + // 退格符 + escaped.append("\\b"); + break; + case '\f': + // 换页符 + escaped.append("\\f"); + break; + default: + // 对于其他控制字符(ASCII < 32),进行转义 + if (c < 32) { + escaped.append(String.format("\\x%02x", (int) c)); + } else { + escaped.append(c); + } + break; + } + } + + log.debug("Python字符串转义: {} -> {}", input, escaped.toString()); + return escaped.toString(); + } + + @Resource + private RedisUtil redisUtil; + /** + * 检测并解析 "Results saved to" 日志,提取保存路径和识别率 + * + * @param logLine 日志行 + * @param config YOLO配置对象 + * @param trainId 训练ID + */ + private void detectAndParseResultsSaved(String logLine, YoloConfig config, Integer trainId) { + try { + if (logLine == null || logLine.trim().isEmpty()) { + return; + } + + String trimmedLine = logLine.trim(); + + // 检测 "Results saved to" 或 "Saved to" 关键字 + if (trimmedLine.contains("Results saved to ") || trimmedLine.contains("Saved to ")) { + // 提取保存路径 + String savedPath = extractSavedPath(trimmedLine); + savedPath = savedPath.replace("\u001B","").replace("[1m","").replace("[0m",""); + if (savedPath != null && !savedPath.trim().isEmpty()) { + log.info("检测到YOLO结果保存路径: {}", savedPath); + + // 保存路径到配置中 + config.setOutputPath(savedPath); + + // 尝试从保存路径中获取识别率信息 + parseAccuracyFromSavedPath(savedPath, config, trainId); + + // 保存结果路径到数据库 + saveResultsPath(trainId, savedPath, config); + } + } + + } catch (Exception e) { + log.error("检测和解析Results saved to时出错", e); + } + } + + /** + * 从日志行中提取保存路径 + * + * @param logLine 包含 "Results saved to" 的日志行 + * @return 提取的路径,如果提取失败返回null + */ + private String extractSavedPath(String logLine) { + try { + // 处理不同格式的日志输出 + String[] patterns = { + "Results saved to (.*?)\\s*$", + "Saved to (.*?)\\s*$", + "Results saved to (.*?)(?:\\s+\\(\\d+.*?\\))?$", // 带统计信息的格式 + "Saved to (.*?)(?:\\s+\\(\\d+.*?\\))?$" + }; + + for (String pattern : patterns) { + java.util.regex.Pattern p = java.util.regex.Pattern.compile(pattern); + java.util.regex.Matcher m = p.matcher(logLine); + if (m.find()) { + String path = m.group(1).trim(); + // 移除可能的引号 + if (path.startsWith("\"") && path.endsWith("\"")) { + path = path.substring(1, path.length() - 1); + } + if (path.startsWith("'") && path.endsWith("'")) { + path = path.substring(1, path.length() - 1); + } + return path; + } + } + + // 如果正则匹配失败,尝试简单的字符串分割 + String[] parts = logLine.split("Results saved to|Saved to"); + if (parts.length > 1) { + String path = parts[1].trim(); + // 移除可能的引号 + if (path.startsWith("\"") && path.endsWith("\"")) { + path = path.substring(1, path.length() - 1); + } + if (path.startsWith("'") && path.endsWith("'")) { + path = path.substring(1, path.length() - 1); + } + return path; + } + + } catch (Exception e) { + log.warn("提取保存路径失败: {}", logLine, e); + } + return null; + } + + /** + * 从保存路径中解析识别率信息 + * YOLO训练完成后,通常会在保存路径中生成包含性能指标的文件 + * + * @param savedPath YOLO结果保存路径 + * @param config YOLO配置对象 + * @param trainId 训练ID + */ + private void parseAccuracyFromSavedPath(String savedPath, YoloConfig config, Integer trainId) { + try { + log.info("开始从保存路径解析识别率: {}", savedPath); + + java.io.File saveDir = new java.io.File(savedPath); + if (!saveDir.exists() || !saveDir.isDirectory()) { + log.warn("保存路径不存在或不是目录: {}", savedPath); + return; + } + + // 查找可能包含识别率的文件 + java.io.File[] files = saveDir.listFiles(); + if (files == null) { + log.warn("无法读取保存目录内容: {}", savedPath); + return; + } + + // 优先查找results.csv文件(YOLO通常会在训练完成后生成) + for (java.io.File file : files) { + if (file.getName().equals("results.csv")) { + parseResultsCsv(file, config, trainId); + return; + } + } + + // 如果没有results.csv,查找其他可能的文件 + for (java.io.File file : files) { + String fileName = file.getName().toLowerCase(); + + // 查找train_batch或val_batch的图片文件,文件名中可能包含性能指标 + if (fileName.contains("val_batch") && fileName.endsWith(".jpg")) { + log.info("找到验证批次图片,但无法直接从中提取数值指标"); + } + + // 查找可能的日志文件 + if (fileName.endsWith(".log") || fileName.endsWith(".txt")) { + parseLogFileForAccuracy(file, config, trainId); + } + } + + // 如果没有找到结果文件,尝试解析路径中的训练信息 + parseTrainingInfoFromPath(savedPath, config); + + } catch (Exception e) { + log.error("从保存路径解析识别率时出错: {}", savedPath, e); + } + } + + /** + * 解析results.csv文件获取识别率 + * 支持不同任务类型:检测(detect)、分类(classify)、分割(segment) + * + * @param csvFile results.csv文件 + * @param config YOLO配置对象 + * @param trainId 训练ID + */ + private void parseResultsCsv(java.io.File csvFile, YoloConfig config, Integer trainId) { + try { + log.info("解析results.csv文件: {} (任务类型: {})", csvFile.getAbsolutePath(), config.getTaskType()); + + java.util.List lines = java.nio.file.Files.readAllLines(csvFile.toPath()); + if (lines.isEmpty()) { + return; + } + + // 读取标题行和最后一行(最新的结果) + String headerLine = lines.get(0); + String lastLine = lines.get(lines.size() - 1); + String[] headers = headerLine.split(","); + String[] values = lastLine.split(","); + + log.debug("CSV标题行: {}", headerLine); + log.debug("CSV数据行: {}", lastLine); + + // 根据任务类型和标题解析不同的指标 + String taskType = config.getTaskType(); + if (taskType == null) { + taskType = "detect"; // 默认为检测任务 + } + + double precision = 0.0, recall = 0.0, map = 0.0, accuracy = 0.0, f1Score = 0.0; + + try { + if ("classify".equals(taskType) || "classification".equals(taskType)) { + // 分类任务:解析 accuracy, top1_acc, top5_acc, loss 等 + for (int i = 0; i < headers.length && i < values.length; i++) { + String header = headers[i].trim().toLowerCase(); + String value = values[i].trim(); + + try { + if (header.contains("accuracy") || header.contains("acc")) { + accuracy = Double.parseDouble(value); + config.setPrecision(accuracy); // 分类任务用accuracy作为precision + log.info("分类任务解析到准确率(Accuracy): {}", accuracy); + } else if (header.contains("top1") || header.contains("top_1")) { + double top1Acc = Double.parseDouble(value); + log.info("分类任务解析到Top-1准确率: {}", top1Acc); + } else if (header.contains("top5") || header.contains("top_5")) { + double top5Acc = Double.parseDouble(value); + log.info("分类任务解析到Top-5准确率: {}", top5Acc); + } else if (header.contains("loss")) { + // 记录损失值 + log.debug("分类任务解析到损失值: {}", value); + } + } catch (NumberFormatException e) { + log.debug("跳过无法解析的数值: {}", value); + } + } + + // 保存分类任务的准确率到数据库 + if (accuracy > 0) { + saveClassificationAccuracyToDatabase(trainId, accuracy); + } + + } else if ("detect".equals(taskType) || "segment".equals(taskType)) { + // 检测或分割任务:解析 precision, recall, mAP50, mAP50-95 等 + for (int i = 0; i < headers.length && i < values.length; i++) { + String header = headers[i].trim().toLowerCase(); + String value = values[i].trim(); + + try { + if (header.contains("precision")) { + precision = Double.parseDouble(value); + config.setPrecision(precision); + } else if (header.contains("recall")) { + recall = Double.parseDouble(value); + config.setRecall(recall); + } else if (header.contains("map50-95") || header.contains("map50_95")) { + map = Double.parseDouble(value); + config.setMap(map); + } else if (header.contains("map50")) { + double map50 = Double.parseDouble(value); + log.debug("解析到mAP50: {}", map50); + // 如果没有mAP50-95,使用mAP50作为map值 + if (map == 0.0) { + map = map50; + config.setMap(map); + } + } + } catch (NumberFormatException e) { + log.debug("跳过无法解析的数值: {}", value); + } + } + + log.info("检测/分割任务解析到识别率 - Precision: {}, Recall: {}, mAP: {}", + precision, recall, map); + + // 保存检测/分割任务的识别率到数据库 + if (precision > 0 || recall > 0 || map > 0) { + saveAccuracyToDatabase(trainId, precision, recall, map); + } + + } else { + // 其他未知任务类型,尝试通用解析 + log.warn("未知任务类型: {}, 尝试通用解析", taskType); + parseGenericResults(headers, values, config); + } + + } catch (Exception e) { + log.warn("解析results.csv数值时出错: {}", lastLine, e); + } + + } catch (Exception e) { + log.error("读取results.csv文件时出错", e); + } + } + + /** + * 通用解析results.csv,适用于未知任务类型 + * + * @param headers CSV标题 + * @param values CSV数值 + * @param config YOLO配置对象 + */ + private void parseGenericResults(String[] headers, String[] values, YoloConfig config) { + try { + double precision = 0.0, recall = 0.0, map = 0.0, accuracy = 0.0; + + for (int i = 0; i < headers.length && i < values.length; i++) { + String header = headers[i].trim().toLowerCase(); + String value = values[i].trim(); + + try { + if (header.contains("precision")) { + precision = Double.parseDouble(value); + config.setPrecision(precision); + } else if (header.contains("recall")) { + recall = Double.parseDouble(value); + config.setRecall(recall); + } else if (header.contains("map")) { + map = Double.parseDouble(value); + config.setMap(map); + } else if (header.contains("accuracy") || header.contains("acc")) { + accuracy = Double.parseDouble(value); + config.setPrecision(accuracy); // 用accuracy替代precision + } + } catch (NumberFormatException e) { + log.debug("通用解析跳过无法解析的数值: {}", value); + } + } + + log.info("通用解析结果 - Precision/Accuracy: {}, Recall: {}, mAP: {}", + Math.max(precision, accuracy), recall, map); + + } catch (Exception e) { + log.error("通用解析失败", e); + } + } + + /** + * 保存分类任务的准确率到数据库 + * + * @param trainId 训练ID + * @param accuracy 准确率 + */ + private void saveClassificationAccuracyToDatabase(Integer trainId, double accuracy) { + try { + // 构建分类任务的准确率字符串 + String rateStr = String.format("accuracy:%.4f", accuracy); + + // 更新数据库中的准确率 + cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO trainResult = + cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO.builder() + .trainId(trainId) + .rate(rateStr) + .build(); + + saveTrainResultDO(trainResult); + + log.info("已保存分类任务准确率到数据库 - 训练ID: {}, 准确率: {}", trainId, rateStr); + + } catch (Exception e) { + log.error("保存分类任务准确率到数据库失败", e); + } + } + + /** + * 从日志文件中解析识别率 + * + * @param logFile 日志文件 + * @param config YOLO配置对象 + * @param trainId 训练ID + */ + private void parseLogFileForAccuracy(java.io.File logFile, YoloConfig config, Integer trainId) { + try { + log.info("尝试从日志文件解析识别率: {}", logFile.getAbsolutePath()); + + java.util.List lines = java.nio.file.Files.readAllLines(logFile.toPath()); + + // 反向遍历,查找最后的验证结果 + for (int i = lines.size() - 1; i >= 0; i--) { + String line = lines.get(i); + + // 查找包含精度、召回率、mAP等指标的行 + if (line.contains("mAP50") || line.contains("mAP50-95") || + line.contains("precision") || line.contains("recall")) { + + parseMetricsLine(line, config, trainId); + break; + } + } + + } catch (Exception e) { + log.error("读取日志文件时出错", e); + } + } + + /** + * 解析包含指标的行 + * + * @param line 包含指标的行 + * @param config YOLO配置对象 + * @param trainId 训练ID + */ + private void parseMetricsLine(String line, YoloConfig config, Integer trainId) { + try { + log.info("解析指标行: {}", line); + + // 使用正则表达式提取数值 + java.util.regex.Pattern precisionPattern = java.util.regex.Pattern.compile("precision[:\\s=]+([0-9.]+)"); + java.util.regex.Pattern recallPattern = java.util.regex.Pattern.compile("recall[:\\s=]+([0-9.]+)"); + java.util.regex.Pattern mapPattern = java.util.regex.Pattern.compile("mAP50-95[:\\s=]+([0-9.]+)"); + java.util.regex.Pattern map50Pattern = java.util.regex.Pattern.compile("mAP50[:\\s=]+([0-9.]+)"); + + java.util.regex.Matcher m; + + if ((m = precisionPattern.matcher(line)).find()) { + config.setPrecision(Double.parseDouble(m.group(1))); + } + if ((m = recallPattern.matcher(line)).find()) { + config.setRecall(Double.parseDouble(m.group(1))); + } + if ((m = mapPattern.matcher(line)).find()) { + config.setMap(Double.parseDouble(m.group(1))); + } + if ((m = map50Pattern.matcher(line)).find()) { + // 如果没有mAP50-95,使用mAP50作为mAP值 + if (config.getMap() == 0.0) { + config.setMap(Double.parseDouble(m.group(1))); + } + } + + // 保存识别率到数据库 + if (config.getPrecision() > 0 || config.getRecall() > 0 || config.getMap() > 0) { + saveAccuracyToDatabase(trainId, config.getPrecision(), config.getRecall(), config.getMap()); + log.info("成功解析识别率 - Precision: {}, Recall: {}, mAP: {}", + config.getPrecision(), config.getRecall(), config.getMap()); + } + + } catch (Exception e) { + log.error("解析指标行时出错", e); + } + } + + /** + * 从路径信息中解析训练信息(备用方法) + * + * @param savedPath 保存路径 + * @param config YOLO配置对象 + */ + private void parseTrainingInfoFromPath(String savedPath, YoloConfig config) { + try { + // 从路径中可能包含的训练信息,如train18等 + java.io.File pathFile = new java.io.File(savedPath); + String pathName = pathFile.getName(); + + if (pathName.startsWith("train") && pathName.length() > 5) { + try { + int trainNumber = Integer.parseInt(pathName.substring(5)); + log.info("从路径解析到训练编号: {}", trainNumber); + } catch (NumberFormatException e) { + log.debug("无法从路径解析训练编号: {}", pathName); + } + } + + } catch (Exception e) { + log.warn("从路径解析训练信息时出错", e); + } + } + + /** + * 保存结果路径到数据库 + * + * @param trainId 训练ID + * @param savedPath 保存路径 + * @param config YOLO配置对象 + */ + private void saveResultsPath(Integer trainId, String savedPath, YoloConfig config) { + try { + // 更新TrainResultDO + TrainResultDO trainResult = TrainResultDO.builder() + .trainId(trainId) + .path(savedPath) + .build(); + + saveTrainResultDO(trainResult); + TrainDO train = trainService.getTrain(trainId); + train.setOutPath(savedPath); + trainService.updateById(train); + log.info("已保存结果路径到数据库 - 训练ID: {}, 路径: {}", trainId, savedPath); + + } catch (Exception e) { + log.error("保存结果路径到数据库失败", e); + } + } + + /** + * 保存识别率到数据库 + * + * @param trainId 训练ID + * @param precision 精度 + * @param recall 召回率 + * @param map mAP值 + */ + private void saveAccuracyToDatabase(Integer trainId, double precision, double recall, double map) { + try { + // 构建识别率字符串 + StringBuilder rateStr = new StringBuilder(); + if (precision > 0) { + rateStr.append("precision:").append(String.format("%.4f", precision)); + } + if (recall > 0) { + if (rateStr.length() > 0) rateStr.append(","); + rateStr.append("recall:").append(String.format("%.4f", recall)); + } + if (map > 0) { + if (rateStr.length() > 0) rateStr.append(","); + rateStr.append("map:").append(String.format("%.4f", map)); + } + + // 更新数据库中的识别率 + TrainResultDO trainResult = + cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO.builder() + .trainId(trainId) + .rate(rateStr.toString()) + .build(); + + saveTrainResultDO(trainResult); + + log.info("已保存识别率到数据库 - 训练ID: {}, 识别率: {}", trainId, rateStr.toString()); + + } catch (Exception e) { + log.error("保存识别率到数据库失败", e); + } + } + + /** + * 管理训练日志在Redis中的存储 + * 当日志包含%时,进行特殊处理:删除最后一条,插入新的,最多保存50条,保存5天 + * + * @param trainId 训练ID + * @param logLine 日志行 + */ + private void manageTrainingLogInRedis(Integer trainId, String logLine) { + try { + if (logLine == null || logLine.trim().isEmpty()) { + return; + } + + String redisKey = "yolo:training_log:" + trainId; + boolean isProgressLog = logLine.contains("%"); + + if (isProgressLog) { + // 如果是进度日志(包含%),先删除之前的进度日志 + removeProgressLogs(redisKey); + } + + // 添加新的日志行到列表末尾 + redisUtil.lSet(redisKey, logLine, 5 * 24 * 60 * 60); // 5天过期时间(秒) + + // 检查列表长度,如果超过50条,删除第一条 + long newSize = redisUtil.lGetListSize(redisKey); + if (newSize > 50) { + redisUtil.lRemove(redisKey, 1, redisUtil.lGetIndex(redisKey, 0)); + } + + log.debug("训练日志已存储到Redis - 训练ID: {}, 日志类型: {}, 当前日志数量: {}", + trainId, isProgressLog ? "进度日志" : "普通日志", redisUtil.lGetListSize(redisKey)); + + } catch (Exception e) { + log.error("管理训练日志Redis操作失败 - 训练ID: {}", trainId, e); + } + } + + /** + * 删除Redis中所有包含%的进度日志 + * + * @param redisKey Redis键 + */ + private void removeProgressLogs(String redisKey) { + try { + long listSize = redisUtil.lGetListSize(redisKey); + if (listSize == 0) { + return; + } + + // 获取所有日志 + java.util.List allLogs = redisUtil.lGet(redisKey, 0, -1); + if (allLogs == null || allLogs.isEmpty()) { + return; + } + + // 找出并删除最后一条包含%的日志(保留100%的) + for (int i = allLogs.size() - 1; i >= 0; i--) { + Object logObj = allLogs.get(i); + if (logObj != null) { + String logStr = logObj.toString(); + if (logStr.contains("%")) { + // 检查是否是100%的日志,如果是则跳过 + if (logStr.contains("100%")) { + break; + } + redisUtil.lRemove(redisKey, 1, logStr); + log.debug("删除旧的进度日志: {}", logStr); + break; // 只删除最后一条非100%的进度日志 + } + } + } + + } catch (Exception e) { + log.error("删除进度日志时出错 - Redis键: {}", redisKey, e); + } + } + + /** + * 生成仅训练的Python脚本文件 + * + * @return 是否生成成功 + */ + public void generateTrainPythonScript(TrainDO train, String yoloDatasetPath) { + try { + StringBuilder script = new StringBuilder(); + + // 文件头部和导入 + script.append("# -*- coding: utf-8 -*-"); + script.append("\n"); + script.append("import sys"); + script.append("\n"); + script.append("import matplotlib.pyplot as plt"); + script.append("\n"); + script.append("from matplotlib import rcParams"); + script.append("\n"); + script.append("import platform"); + script.append("\n"); + + // 字体设置函数 + script.append("# 自动选择字体"); + script.append("\n"); + script.append("def set_font():"); + script.append("\n"); + script.append(" system_platform = platform.system()"); + script.append("\n"); + script.append(" if system_platform == \"Windows\":"); + script.append("\n"); + script.append(" rcParams['font.sans-serif'] = ['Microsoft YaHei'] # Windows 下使用微软雅黑"); + script.append("\n"); + script.append(" else:"); + script.append("\n"); + script.append(" rcParams['font.sans-serif'] = ['DejaVu Sans'] # Linux 下使用 DejaVu Sans"); + script.append("\n"); + script.append(" rcParams['axes.unicode_minus'] = False # 解决负号显示问题"); + script.append("\n"); + + // 训练函数 + script.append("# 训练函数"); + script.append("\n"); + script.append("def main():"); + script.append("\n"); + script.append(" set_font()"); + script.append("\n"); + script.append(" sys.path.insert(0, '.')"); + script.append("\n"); + script.append(" from ultralytics import YOLO"); + script.append("\n"); + + if (train.getModelPath() != null && !train.getModelPath().trim().isEmpty()) { + script.append(" model = YOLO('").append(escapePythonString(train.getModelPath())).append("')"); + } else { + script.append(" model = YOLO('yolo11n.pt')"); + } + script.append("\n"); + script.append(" "); + script.append("\n"); + script.append(" # 训练模型"); + script.append("\n"); + script.append(" model.train("); + script.append("\n"); + script.append(" data=r'").append(escapePythonString(train.getPath() + "/dataset.yaml")).append("',"); + script.append("\n"); + script.append(" epochs=").append(train.getRound()).append(","); + script.append("\n"); + script.append(" batch=").append(train.getSize()).append(","); + script.append("\n"); + script.append(" imgsz=").append(train.getImageSize()).append(","); + script.append("\n"); + script.append(" project=r'").append(escapePythonString(train.getPath())).append("'"); + script.append("\n"); + script.append(" )"); + script.append("\n"); + + // 程序入口 + script.append("if __name__ == '__main__':"); + script.append("\n"); + script.append(" main()"); + script.append("\n"); + + // 写入文件 + try (java.io.FileWriter writer = new java.io.FileWriter(train.getPath() + "/train.py", java.nio.charset.StandardCharsets.UTF_8)) { + writer.write(script.toString()); + } + + log.info("训练Python脚本生成成功: {}", train.getPath()); + + } catch (Exception e) { + log.error("生成训练Python脚本失败", e); + } + } + + +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloServiceClient.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloServiceClient.java new file mode 100644 index 0000000..ef4590d --- /dev/null +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloServiceClient.java @@ -0,0 +1,219 @@ +package cn.iocoder.yudao.module.annotation.service.yolo; + +import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +import java.io.*; +import java.net.Socket; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +/** + * YOLO服务客户端 - 通过socket与Python长连接服务通信 + * + * @author 管理员 + */ +@Slf4j +@Component +public class YoloServiceClient { + + private static final String DEFAULT_HOST = "localhost"; + private static final int DEFAULT_PORT = 9999; + private static final int CONNECTION_TIMEOUT = 30000; // 30秒超时 + + private final ObjectMapper objectMapper = new ObjectMapper(); + + /** + * 通过socket服务执行预测 + */ + public YoloConfig predictViaSocket(String modelPath, String inputPath, String outputPath, YoloConfig config) { + try { + Map params = new HashMap<>(); + params.put("model_path", modelPath); + params.put("input_path", inputPath); + if (outputPath != null) { + params.put("output_path", outputPath); + } + + // 添加其他YOLO配置参数 + addYoloConfigParams(params, config); + + Map request = new HashMap<>(); + request.put("command", "predict"); + request.put("params", params); + + String response = sendRequest(request); + JsonNode responseJson = objectMapper.readTree(response); + + if ("success".equals(responseJson.get("status").asText())) { + config.setStatus("completed"); + config.setLogMessage(responseJson.get("message").asText()); + if (responseJson.has("output_path")) { + config.setOutputPath(responseJson.get("output_path").asText()); + } + return config; + } else { + config.setStatus("failed"); + config.setErrorMessage(responseJson.get("message").asText()); + return config; + } + + } catch (Exception e) { + log.error("Socket预测失败", e); + config.setStatus("failed"); + config.setErrorMessage("Socket预测失败: " + e.getMessage()); + return config; + } + } + + /** + * 通过socket服务执行训练 + */ + public YoloConfig trainViaSocket(String modelPath, String dataPath, YoloConfig config) { + try { + Map params = new HashMap<>(); + params.put("model_path", modelPath); + params.put("data_path", dataPath); + params.put("epochs", config.getEpochs()); + + // 添加其他训练配置参数 + addYoloConfigParams(params, config); + + Map request = new HashMap<>(); + request.put("command", "train"); + request.put("params", params); + + String response = sendRequest(request); + JsonNode responseJson = objectMapper.readTree(response); + + if ("success".equals(responseJson.get("status").asText())) { + config.setStatus("completed"); + config.setLogMessage(responseJson.get("message").asText()); + if (responseJson.has("results_path")) { + config.setOutputPath(responseJson.get("results_path").asText()); + } + return config; + } else { + config.setStatus("failed"); + config.setErrorMessage(responseJson.get("message").asText()); + return config; + } + + } catch (Exception e) { + log.error("Socket训练失败", e); + config.setStatus("failed"); + config.setErrorMessage("Socket训练失败: " + e.getMessage()); + return config; + } + } + + /** + * 通过socket服务执行验证 + */ + public YoloConfig validateViaSocket(String modelPath, String dataPath, YoloConfig config) { + try { + Map params = new HashMap<>(); + params.put("model_path", modelPath); + if (dataPath != null) { + params.put("data_path", dataPath); + } + + // 添加其他验证配置参数 + addYoloConfigParams(params, config); + + Map request = new HashMap<>(); + request.put("command", "validate"); + request.put("params", params); + + String response = sendRequest(request); + JsonNode responseJson = objectMapper.readTree(response); + + if ("success".equals(responseJson.get("status").asText())) { + config.setStatus("completed"); + config.setLogMessage(responseJson.get("message").asText()); + if (responseJson.has("metrics")) { + config.setMetrics(responseJson.get("metrics").asText()); + } + return config; + } else { + config.setStatus("failed"); + config.setErrorMessage(responseJson.get("message").asText()); + return config; + } + + } catch (Exception e) { + log.error("Socket验证失败", e); + config.setStatus("failed"); + config.setErrorMessage("Socket验证失败: " + e.getMessage()); + return config; + } + } + + /** + * 添加YOLO配置参数到请求中 + */ + private void addYoloConfigParams(Map params, YoloConfig config) { + if (config.getImgsz() > 0) { + params.put("imgsz", config.getImgsz()); + } + if (config.getConfThresh() > 0) { + params.put("conf", config.getConfThresh()); + } + if (config.getIouThresh() > 0) { + params.put("iou", config.getIouThresh()); + } + if (config.getDevice() != null) { + params.put("device", config.getDevice()); + } + if (config.getBatchSize() > 0) { + params.put("batch", config.getBatchSize()); + } + params.put("save", config.isSave()); + params.put("save_txt", config.isSaveTxt()); + params.put("save_crops", config.isSaveCrops()); + } + + /** + * 发送请求到socket服务 + */ + private String sendRequest(Map request) throws IOException { + try (Socket socket = new Socket(DEFAULT_HOST, DEFAULT_PORT); + PrintWriter out = new PrintWriter(socket.getOutputStream(), true); + BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()))) { + + // 设置超时 + socket.setSoTimeout(CONNECTION_TIMEOUT); + + // 发送请求 + String requestJson = objectMapper.writeValueAsString(request); + out.println(requestJson); + + // 读取响应 + StringBuilder response = new StringBuilder(); + String line; + while ((line = in.readLine()) != null) { + response.append(line); + } + + return response.toString(); + } + } + + /** + * 检查服务是否可用 + */ + public boolean isServiceAvailable() { + try { + try (Socket socket = new Socket()) { + socket.connect(new java.net.InetSocketAddress(DEFAULT_HOST, DEFAULT_PORT), 3000); + return true; + } + } catch (IOException e) { + return false; + } + } +} \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/resources/mapper/mark/MarkMapper.xml b/yudao-module-annotation/yudao-module-annotation-biz/src/main/resources/mapper/mark/MarkMapper.xml index 2050639..e9c6920 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/resources/mapper/mark/MarkMapper.xml +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/resources/mapper/mark/MarkMapper.xml @@ -9,4 +9,9 @@ 文档可见:https://www.iocoder.cn/MyBatis/x-plugins/ --> + + UPDATE annotation_mark + SET status = #{status} + WHERE id = #{id} + \ No newline at end of file diff --git a/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/util/RedisUtil.java b/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/util/RedisUtil.java index d11c26b..f3361fa 100644 --- a/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/util/RedisUtil.java +++ b/yudao-module-system/yudao-module-system-biz/src/main/java/cn/iocoder/yudao/module/system/util/RedisUtil.java @@ -693,6 +693,17 @@ public class RedisUtil { return false; } } + public boolean lPush(String key, Object value, long time) { + try { + redisTemplate.opsForList().rightPush(key, value); + if (time > 0) + expire(key, time); + return true; + } catch (Exception e) { + log.error(key, e); + return false; + } + } /** * 将list放入缓存 diff --git a/yudao-server/src/main/resources/application-dev.yaml b/yudao-server/src/main/resources/application-dev.yaml index d31f72a..d47d4b5 100644 --- a/yudao-server/src/main/resources/application-dev.yaml +++ b/yudao-server/src/main/resources/application-dev.yaml @@ -47,10 +47,9 @@ spring: datasource: master: name: sy - url: jdbc:mysql://121.43.244.209:3306/${spring.datasource.dynamic.datasource.master.name}?useSSL=false&serverTimezone=Asia/Shanghai&allowPublicKeyRetrieval=true&nullCatalogMeansCurrent=true # MySQL Connector/J 8.X 连接的示例 - + url: jdbc:mysql://192.168.1.21:3306/${spring.datasource.dynamic.datasource.master.name}?useSSL=false&serverTimezone=Asia/Shanghai&allowPublicKeyRetrieval=true&nullCatalogMeansCurrent=true # MySQL Connector/J 8.X 连接的示例 username: root - password: Sy+1234567 + password: upright # Redis 配置。Redisson 默认的配置足够使用,一般不需要进行调优 data: redis: @@ -62,7 +61,7 @@ spring: max-wait: 5000ms # 增加获取连接的最大等待时间 shutdown-timeout: 1000ms # 增加关闭超时时间 timeout: 5000ms # 增加连接超时时间 - host: 121.43.244.209 # 地址 + host: 192.168.1.21 # 地址 port: 6379 # 端口 database: 0 # 数据库索引 # password: dev # 密码,建议生产环境开启