diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java index 401e2d5..e6995a6 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java @@ -81,7 +81,7 @@ public class TrainController { @Operation(summary = "进行训练") // @Parameter(name = "id", description = "编号", required = true, example = "1024") @PreAuthorize("@ss.hasPermission('annotation:train:query')") - public CommonResult testRecognition(@RequestBody RecognitionResult recognitionReqVO) { + public CommonResult testRecognition(@RequestBody RecognitionResult recognitionReqVO) { String train = trainService.testRecognition(recognitionReqVO.getImageUrl(),recognitionReqVO.getTrainId()); return success( train); diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/TrainInfoController.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/TrainInfoController.java index c3b06e6..ce31b47 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/TrainInfoController.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/TrainInfoController.java @@ -69,6 +69,7 @@ public class TrainInfoController { TrainInfoDO trainInfo = trainInfoService.getTrainInfo(id); return success(BeanUtils.toBean(trainInfo, TrainInfoRespVO.class)); } + @Resource RedisUtil redisUtil; @@ -76,6 +77,7 @@ public class TrainInfoController { @Operation(summary = "获得识别结果") @Parameter(name = "id", description = "编号", required = true, example = "1024") @PreAuthorize("@ss.hasPermission('annotation:train-Info:query')") + public CommonResult> getTrainInfoList(@RequestParam("trainId") Integer id) { List result = redisUtil.lGet("yolo:training_log:" + id, 0, -1); List result1 = result.stream().map(Object::toString).toList(); diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/yolo/YoloOperationController.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/yolo/YoloOperationController.java index 29a57f6..e45a888 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/yolo/YoloOperationController.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/yolo/YoloOperationController.java @@ -87,6 +87,7 @@ public class YoloOperationController { TrainDO trainDO = trainService.getTrain(id); // trainDO.setPath(dictDataMap.get("training_address").getValue()); YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); config.setTrainId(id); @@ -105,6 +106,7 @@ public class YoloOperationController { TrainDO trainDO = trainService.getTrain(id); // trainDO.setPath(dictDataMap.get("training_address").getValue()); YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); YoloConfig result = yoloOperationService.validate( dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); @@ -120,6 +122,7 @@ public class YoloOperationController { TrainDO trainDO = trainService.getTrain(id); // trainDO.setPath(dictDataMap.get("training_address").getValue()); YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); YoloConfig result = yoloOperationService.predict( dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); @@ -135,6 +138,7 @@ public class YoloOperationController { TrainDO trainDO = trainService.getTrain(id); // trainDO.setPath(dictDataMap.get("training_address").getValue()); YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); YoloConfig result = yoloOperationService.classify( dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); @@ -150,6 +154,7 @@ public class YoloOperationController { TrainDO trainDO = trainService.getTrain(id); // trainDO.setPath(dictDataMap.get("training_address").getValue()); YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); YoloConfig result = yoloOperationService.export( dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); @@ -165,6 +170,7 @@ public class YoloOperationController { TrainDO trainDO = trainService.getTrain(id); // trainDO.setPath(dictDataMap.get("training_address").getValue()); YoloConfig config = trainDO.getYolofig(dictDataMap); + config.setTrainId(id); YoloConfig result = yoloOperationService.track( dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java index 010197d..29e2a86 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java @@ -75,54 +75,63 @@ public class TrainDO extends BaseDO { private String type; + private String outPath; public YoloConfig getYolofig(Map dictDataMap) { YoloConfig config = new YoloConfig(); // 设置数据相关配置 config.setDatasetPath(this.path); - + // 设置训练参数 + if (this.size != null) { config.setEpochs(this.size); } if (this.round != null) { config.setBatchSize(this.round); - } - if (this.imageSize != null) { - config.setImageSize(this.imageSize); - } - - // 设置数据集比例(YOLO中需要转换) - // 注意:YOLO通常通过数据集配置文件来设置train/val/test比例 - // 这里我们可以将比例信息保存到配置中,供后续处理使用 - - // 设置预训练模型路径 - if (this.type.equals("1")) { - config.setModelPath(dictDataMap.get("detect_path").getValue()); - config.setPretrained(true); - }else { - config.setModelPath(dictDataMap.get("classify_path").getValue()); - config.setPretrained(false); - } - - // 设置GPU/CPU设备 + if (this.round != null) { + config.setEpochs(this.round); + } + if (this.size != null) { + config.setBatchSize(this.size); + } + if (this.imageSize != null) { + config.setImageSize(this.imageSize); + } + + // 设置数据集比例(YOLO中需要转换) + // 注意:YOLO通常通过数据集配置文件来设置train/val/test比例 + // 这里我们可以将比例信息保存到配置中,供后续处理使用 + + // 设置预训练模型路径 + if (this.type.equals("1")) { + config.setModelPath(dictDataMap.get("detect_path").getValue()); + config.setPretrained(true); + } else { + config.setModelPath(dictDataMap.get("classify_path").getValue()); + config.setPretrained(false); + } + + // 设置GPU/CPU设备 + + + // 设置输出路径(默认在数据集路径下的runs/train目录) + if (this.path != null) { + config.setOutputPath(this.path + "/runs/train"); + } - - // 设置输出路径(默认在数据集路径下的runs/train目录) - if (this.path != null) { - config.setOutputPath(this.path + "/runs/train"); + // 设置其他默认参数 + config.setTaskType(this.type.equals("1") ? "detect" : "classify"); + config.setModelName(this.modelPath); + config.setModelPath(this.modelPath); + config.setLearningRate(0.01); + config.setConfThresh(0.25); + config.setIouThresh(0.45); + config.setSave(true); + config.setSaveTxt(true); + return config; } - - // 设置其他默认参数 - config.setTaskType(this.type.equals("1") ? "detect" : "classify"); - config.setModelName(this.modelPath); - config.setModelPath(this.modelPath); - config.setLearningRate(0.01); - config.setConfThresh(0.25); - config.setIouThresh(0.45); - config.setSave(true); - config.setSaveTxt(true); return config; } } \ No newline at end of file diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java index 2fe0b18..a26a9b3 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java @@ -205,8 +205,8 @@ public class YoloConfig { /** * 训练结果指标(如mAP, precision, recall等) */ + private String metrics; - /** * 任务ID(用于异步任务追踪) */ diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvServiceImpl.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvServiceImpl.java index 7c32119..465f49c 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvServiceImpl.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvServiceImpl.java @@ -316,6 +316,7 @@ public class PythonVirtualEnvServiceImpl implements PythonVirtualEnvService { } // 构建虚拟环境路径:Python目录/envName + String venvPath = pythonDir + File.separator +"venv"+File.separator + envName; File venvDir = new File(venvPath); diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java index 2203928..6bfd7ac 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java @@ -7,6 +7,7 @@ import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.TrainSaveReq import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import com.baomidou.mybatisplus.extension.service.IService; import jakarta.validation.Valid; + import org.springframework.web.multipart.MultipartFile; import java.util.List; @@ -65,6 +66,7 @@ public interface TrainService extends IService { public boolean init(TrainSaveReqVO createReqVO); + List testImages(); /** diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java index 7759ce0..6be368d 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java @@ -9,16 +9,20 @@ import cn.iocoder.yudao.module.annotation.dal.dataobject.datas.DatasDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkInfoDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; + import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO; import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper; import cn.iocoder.yudao.module.annotation.dal.mysql.trainresult.TrainResultMapper; import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; +import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO; +import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper; import cn.iocoder.yudao.module.annotation.service.MarkInfo.AnnotationData; import cn.iocoder.yudao.module.annotation.service.MarkInfo.MarkInfoService; import cn.iocoder.yudao.module.annotation.service.datas.DatasService; import cn.iocoder.yudao.module.annotation.service.mark.MarkService; import cn.iocoder.yudao.module.annotation.service.types.TypesService; + import cn.iocoder.yudao.module.annotation.service.yolo.YoloOperationService; import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO; import cn.iocoder.yudao.module.system.service.dict.DictDataService; @@ -29,6 +33,7 @@ import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import org.springframework.validation.annotation.Validated; + import org.springframework.web.multipart.MultipartFile; import javax.imageio.ImageIO; @@ -90,6 +95,7 @@ public class TrainServiceImpl extends ServiceImpl implemen // 返回 return train.getId(); } + @Resource YoloOperationService yoloOperationService; public boolean init(TrainSaveReqVO createReqVO) { @@ -127,7 +133,9 @@ public class TrainServiceImpl extends ServiceImpl implemen // 生成YOLO配置文件 generateYamlConfig(typesList, taskType, yoloDatasetPath); update( new UpdateWrapper().eq("id", createReqVO.getId()).set("train_type", 6)); + yoloOperationService.generateTrainPythonScript(train, yoloDatasetPath); + log.info("YOLO数据集生成完成: {}", yoloDatasetPath); return true; @@ -137,6 +145,7 @@ public class TrainServiceImpl extends ServiceImpl implemen } } + @Override public List testImages() { Map dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); @@ -337,6 +346,7 @@ public class TrainServiceImpl extends ServiceImpl implemen */ private String createYoloDatasetPath(TrainDO train, String customPath) { String basePath = customPath != null ? customPath : System.getProperty("user.dir") + "/yolo_datasets"; + String datasetPath = basePath + "/datas_" + train.getId(); File datasetDir = new File(datasetPath); @@ -433,7 +443,8 @@ public class TrainServiceImpl extends ServiceImpl implemen File labelFile = new File(labelDir, labelFileName); List markInfoList = markInfoService.list(new QueryWrapper() - .eq("mark_id", mark.getId())); + + .eq("data_id", mark.getDataId()).eq("mark_id", mark.getId())); try (PrintWriter writer = new PrintWriter(new FileWriter(labelFile))) { for (MarkInfoDO markInfo : markInfoList) { @@ -776,6 +787,7 @@ public class TrainServiceImpl extends ServiceImpl implemen return gpuList; } + public static void main(String[] args) { File dir = new File("D:/data"); File[] files = dir.listFiles(); @@ -786,6 +798,7 @@ public class TrainServiceImpl extends ServiceImpl implemen } } System.out.println(fileNames); + } private SystemInfoRespVO.PythonInfo getPythonInfo() { SystemInfoRespVO.PythonInfo pythonInfo = new SystemInfoRespVO.PythonInfo(); diff --git a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationServiceImpl.java b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationServiceImpl.java index aabb9df..156ed6f 100644 --- a/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationServiceImpl.java +++ b/yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationServiceImpl.java @@ -1,11 +1,13 @@ package cn.iocoder.yudao.module.annotation.service.yolo; + import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO; import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService; import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo; + import cn.iocoder.yudao.module.annotation.service.train.TrainService; import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService; import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService; @@ -15,6 +17,10 @@ import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Lazy; +import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService; +import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import java.io.BufferedReader; @@ -42,6 +48,7 @@ public class YoloOperationServiceImpl implements YoloOperationService { @Resource private TrainResultService trainResultService; + @Resource private YoloServiceClient yoloServiceClient; @@ -51,17 +58,16 @@ public class YoloOperationServiceImpl implements YoloOperationService { // 异步执行器 private final Executor asyncExecutor = Executors.newCachedThreadPool(); - @Override public CompletableFuture trainAsync(String pythonPath, String envName, YoloConfig config) { return CompletableFuture.supplyAsync(() -> { config.setStatus("running"); config.setProgress(0); config.setTaskId(UUID.randomUUID().toString()); - + // 生成训练ID Integer trainId = config.getTrainId(); - + // 记录上一轮的epoch,用于检测epoch完成 int[] lastSavedEpoch = {0}; @@ -76,16 +82,13 @@ public class YoloOperationServiceImpl implements YoloOperationService { // 构建训练命令 String command = buildTrainCommand(envInfo, config); - log.info("开始YOLO训练,命令: {}", command); // 执行训练 ProcessBuilder processBuilder = new ProcessBuilder(); - // 在Windows上使用cmd,在Linux/Mac上使用bash String os = System.getProperty("os.name").toLowerCase(); if (os.contains("win")) { processBuilder.command("cmd", "/c", command); - // 设置环境变量确保正确处理中文路径 processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); processBuilder.environment().put("LANG", "zh_CN.UTF-8"); } else { @@ -93,36 +96,35 @@ public class YoloOperationServiceImpl implements YoloOperationService { processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); } Process process = processBuilder.start(); + BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream())); BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); -// BufferedReader stderrReader2 = new BufferedReader(new InputStreamReader(process.getOutputStream())); StringBuilder fullOutput = new StringBuilder(); - - // 使用单独的线程读取stdout和stderr,避免死锁 - String stdoutLine, stderrLine = null; - while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null ) { + String stdoutLine; + String stderrLine = null; + + // 同时读取标准输出和错误输出 + while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null) { if (stdoutLine != null) { log.info("YOLO训练日志: {}", stdoutLine); fullOutput.append(stdoutLine).append("\n"); - + // 详细解析YOLO训练进度和结果 manageTrainingLogInRedis(trainId, stdoutLine); config.setLogMessage(stdoutLine); - - // 检测 "Results saved to" 并解析路径和识别率 detectAndParseResultsSaved(stdoutLine, config, trainId); + parseTrainingProgress(stdoutLine, config, trainId, lastSavedEpoch); } + if (stderrLine != null) { log.info("YOLO训练日志: {}", stderrLine); fullOutput.append(stderrLine).append("\n"); manageTrainingLogInRedis(trainId, stderrLine); - config.setLogMessage(stderrLine); - - // 检测 "Results saved to" 并解析路径和识别率 detectAndParseResultsSaved(stderrLine, config, trainId); + parseTrainingProgress(stderrLine, config, trainId, lastSavedEpoch); } } @@ -148,12 +150,14 @@ public class YoloOperationServiceImpl implements YoloOperationService { config.setStatus("failed"); config.setErrorMessage("训练异常: " + e.getMessage()); } + return config; }, asyncExecutor); } @Override public YoloConfig validate(String pythonPath, String envName, YoloConfig config) { + // 优先尝试使用socket服务(如果启用且可用) if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) { log.info("使用socket服务执行YOLO验证"); @@ -233,6 +237,7 @@ public class YoloOperationServiceImpl implements YoloOperationService { @Override public YoloConfig predict(String pythonPath, String envName, YoloConfig config) { + // 优先尝试使用socket服务(如果启用且可用) if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) { log.info("使用socket服务执行YOLO预测"); @@ -281,6 +286,7 @@ public class YoloOperationServiceImpl implements YoloOperationService { if (stdoutLine != null) { output.append(stdoutLine).append("\n"); log.info("YOLO预测日志: {}", stdoutLine); + if (stdoutLine.contains("Results saved to")) { stdoutLine = stdoutLine.replace("\u001B","").replace("[1m","").replace("[0m",""); @@ -291,6 +297,7 @@ public class YoloOperationServiceImpl implements YoloOperationService { output.append(stderrLine).append("\n"); log.info("YOLO预测日志: {}", stderrLine); + if (stderrLine.contains("Results saved to")) { stderrLine = stderrLine.replace("\u001B","").replace("[1m","").replace("[0m",""); config.setOutputPath(stderrLine.split(" ")[3]); @@ -445,13 +452,16 @@ public class YoloOperationServiceImpl implements YoloOperationService { if ("conda".equals(envInfo.getVirtualEnvType())) { command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" "); } else { + command.append(normalizePath(envInfo.getVirtualEnvPythonPath())).append(" "); + command.append(envInfo.getVirtualEnvPythonPath()).append(" "); } // 规范化路径,处理双斜杠和中文路径问题 String datasetPath = normalizePath(config.getDatasetPath()); String outputPath = normalizePath(config.getOutputPath()); String modelPath = config.getModelPath() != null ? normalizePath(config.getModelPath()) : null; + // 调试日志:检查模型路径 log.info("模型路径配置 - modelPath: {}, modelName: {}", config.getModelPath(), config.getModelName()); @@ -485,6 +495,36 @@ public class YoloOperationServiceImpl implements YoloOperationService { // // command.append("project=r'").append(escapePythonString(outputPath)).append("')\""); // + + + + command.append("-c \""); + command.append("import sys; sys.path.insert(0, '.'); "); + command.append("from ultralytics import YOLO; "); + + if (modelPath != null && !modelPath.trim().isEmpty()) { + command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); "); + } else { + command.append("model = YOLO('").append(config.getModelName()).append("'); "); + } + + // 使用raw string来处理中文路径 + command.append("model.train(data=r'").append(escapePythonString(datasetPath)).append("', "); + command.append("epochs=").append(config.getEpochs()).append(", "); + command.append("batch=").append(config.getBatchSize()).append(", "); + command.append("imgsz=").append(config.getImageSize()).append(", "); + + // 取消注释设备和学习率配置 +// command.append("device=").append(config.getDevice()).append(", "); +// command.append("lr0=").append(config.getLearningRate()).append(", "); +// command.append("workers=").append(config.getWorkers()).append(", "); +// + if (config.getSavePeriod() > 0) { + command.append("save_period=").append(config.getSavePeriod()).append(", "); + } + + command.append("project=r'").append(escapePythonString(outputPath)).append("')\""); + return command.toString(); } @@ -515,15 +555,17 @@ public class YoloOperationServiceImpl implements YoloOperationService { return command.toString(); } + private String buildPredictCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { StringBuilder command = new StringBuilder(); - + if ("conda".equals(envInfo.getVirtualEnvType())) { command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" "); } else { command.append(envInfo.getVirtualEnvPythonPath()).append(" "); } + String inputPath = normalizePath(config.getInputPath()); String outputPath = normalizePath(config.getOutputPath()); String modelPath = normalizePath(config.getModelPath()); @@ -533,6 +575,7 @@ public class YoloOperationServiceImpl implements YoloOperationService { command.append("from ultralytics import YOLO; "); command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); "); command.append("results = model.predict(source=r'").append(escapePythonString(inputPath)).append("', "); + // 在这里我们将 save=true 改为 save=True command.append("save=").append(config.isSave() ? "True" : "False").append(", "); command.append("project='").append(escapePythonString(outputPath)).append("')\""); @@ -541,6 +584,8 @@ public class YoloOperationServiceImpl implements YoloOperationService { } + + private String buildExportCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { StringBuilder command = new StringBuilder(); @@ -582,6 +627,106 @@ public class YoloOperationServiceImpl implements YoloOperationService { return command.toString(); } + + /** + * 解析训练进度详细信息 + */ + private void parseTrainingProgress(String line, YoloConfig config, Integer trainId, int[] lastSavedEpoch) { + try { + String trimmedLine = line.trim(); + + // 解析Epoch信息: "Epoch 1/100" 或 "1/100" + if (trimmedLine.contains("Epoch") || (trimmedLine.matches("\\d+/\\d+"))) { + String[] parts = trimmedLine.split("\\s+"); + for (String part : parts) { + if (part.contains("/")) { + String[] epochParts = part.split("/"); + if (epochParts.length == 2) { + int currentEpoch = Integer.parseInt(epochParts[0]); + int totalEpochs = Integer.parseInt(epochParts[1]); + config.setCurrentEpoch(currentEpoch); + // 计算总体进度:epoch进度 + batch进度 + int epochProgress = (int) ((double) currentEpoch / totalEpochs * 90); // 90%用于epoch,10%用于最终处理 + int totalProgress = Math.min(epochProgress + (config.getBatchProgress() / 10), 95); + config.setProgress(totalProgress); + + // 当epoch完成时保存训练信息 + if (currentEpoch > lastSavedEpoch[0] && currentEpoch <= config.getEpochs()) { + saveTrainEpochInfo(trainId, currentEpoch, config.getEpochs(), config); + lastSavedEpoch[0] = currentEpoch; + } + } + } + } + } + + // 解析批次信息: "1/100" 或 "batch 1/100" + if (trimmedLine.contains("batch") || trimmedLine.matches("\\d+/\\d+.*loss")) { + if (trimmedLine.contains("batch")) { + String[] parts = trimmedLine.split("\\s+"); + for (int i = 0; i < parts.length; i++) { + if (parts[i].equals("batch") && i + 1 < parts.length && parts[i + 1].contains("/")) { + String[] batchParts = parts[i + 1].split("/"); + if (batchParts.length == 2) { + int currentBatch = Integer.parseInt(batchParts[0]); + int totalBatches = Integer.parseInt(batchParts[1]); + config.setCurrentBatch(currentBatch); + config.setTotalBatches(totalBatches); + config.setBatchProgress((int) ((double) currentBatch / totalBatches * 100)); + } + } + } + } + } + + // 解析损失值: "loss: 0.1234", "box_loss: 0.1", "cls_loss: 0.05", "obj_loss: 0.02" + if (trimmedLine.contains("loss")) { + String[] lossParts = trimmedLine.split(","); + for (String lossPart : lossParts) { + lossPart = lossPart.trim(); + if (lossPart.startsWith("loss:")) { + String value = lossPart.substring(5).trim(); + config.setLoss(Double.parseDouble(value)); + } else if (lossPart.startsWith("box_loss:")) { + // 可以添加更详细的损失记录 + } else if (lossPart.startsWith("cls_loss:")) { + // 可以添加更详细的损失记录 + } else if (lossPart.startsWith("obj_loss:")) { + // 可以添加更详细的损失记录 + } + } + } + + // 解析训练指标: "metrics/precision(B): 0.85", "metrics/recall(B): 0.78", "metrics/mAP50(B): 0.82", "metrics/mAP50-95(B): 0.65" + if (trimmedLine.contains("metrics/")) { + String[] metricParts = trimmedLine.split(","); + for (String metricPart : metricParts) { + metricPart = metricPart.trim(); + if (metricPart.startsWith("metrics/precision")) { + String value = metricPart.split(":")[1].trim(); + config.setPrecision(Double.parseDouble(value)); + } else if (metricPart.startsWith("metrics/recall")) { + String value = metricPart.split(":")[1].trim(); + config.setRecall(Double.parseDouble(value)); + } else if (metricPart.startsWith("metrics/mAP50-95") || metricPart.startsWith("metrics/mAP")) { + String value = metricPart.split(":")[1].trim(); + config.setMap(Double.parseDouble(value)); + } + } + } + + // 解析学习率: "lr: 0.001" + if (trimmedLine.contains("lr:")) { + // 可以记录学习率变化 + } + + } catch (Exception e) { + log.debug("解析训练进度时出错: {}", line, e); + // 忽略解析错误,继续处理下一行 + } + } + + /** * 解析最终训练结果 @@ -771,9 +916,12 @@ public class YoloOperationServiceImpl implements YoloOperationService { // 不影响训练完成状态 } } + @Resource @Lazy TrainService trainService; + + /** * 保存TrainResultDO到数据库 @@ -781,8 +929,11 @@ public class YoloOperationServiceImpl implements YoloOperationService { private void saveTrainResultDO(TrainResultDO trainResult) { try { // 创建SaveReqVO + TrainResultSaveReqVO saveReqVO = new TrainResultSaveReqVO(); + + // 设置值 saveReqVO.setTrainId(trainResult.getTrainId()); @@ -792,8 +943,11 @@ public class YoloOperationServiceImpl implements YoloOperationService { // 保存 trainResultService.createTrainResult(saveReqVO); + trainService.update(new UpdateWrapper().eq("id", trainResult.getTrainId()) .set("train_type", 4)); + + } catch (Exception e) { log.error("保存训练结果失败: {}", trainResult, e); @@ -915,6 +1069,7 @@ public class YoloOperationServiceImpl implements YoloOperationService { return escaped.toString(); } + @Resource private RedisUtil redisUtil; /** @@ -1595,4 +1750,6 @@ public class YoloOperationServiceImpl implements YoloOperationService { } + + } \ No newline at end of file diff --git a/yudao-server/src/main/resources/application-dev.yaml b/yudao-server/src/main/resources/application-dev.yaml index d47d4b5..22f5ec2 100644 --- a/yudao-server/src/main/resources/application-dev.yaml +++ b/yudao-server/src/main/resources/application-dev.yaml @@ -48,6 +48,7 @@ spring: master: name: sy url: jdbc:mysql://192.168.1.21:3306/${spring.datasource.dynamic.datasource.master.name}?useSSL=false&serverTimezone=Asia/Shanghai&allowPublicKeyRetrieval=true&nullCatalogMeansCurrent=true # MySQL Connector/J 8.X 连接的示例 + username: root password: upright # Redis 配置。Redisson 默认的配置足够使用,一般不需要进行调优