package cn.iocoder.yudao.module.annotation.service.yolo; import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService; import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo; import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService; import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService; import cn.iocoder.yudao.module.system.util.RedisUtil; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; import java.io.BufferedReader; import java.io.File; import java.io.InputStreamReader; import java.util.UUID; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; /** * YOLO操作服务实现类 * * @author 管理员 */ @Slf4j @Service public class YoloOperationServiceImpl implements YoloOperationService { @Resource private PythonVirtualEnvService pythonVirtualEnvService; @Resource private TrainInfoService trainInfoService; @Resource private TrainResultService trainResultService; // 异步执行器 private final Executor asyncExecutor = Executors.newCachedThreadPool(); @Override public CompletableFuture trainAsync(String pythonPath, String envName, YoloConfig config) { return CompletableFuture.supplyAsync(() -> { config.setStatus("running"); config.setProgress(0); config.setTaskId(UUID.randomUUID().toString()); // 生成训练ID Integer trainId = config.getTrainId(); // 记录上一轮的epoch,用于检测epoch完成 int[] lastSavedEpoch = {0}; try { // 检测虚拟环境 PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName); if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) { config.setStatus("failed"); config.setErrorMessage("虚拟环境不存在或未安装YOLO"); return config; } // 构建训练命令 String command = buildTrainCommand(envInfo, config); log.info("开始YOLO训练,命令: {}", command); // 执行训练 ProcessBuilder processBuilder = new ProcessBuilder(); // 在Windows上使用cmd,在Linux/Mac上使用bash String os = System.getProperty("os.name").toLowerCase(); if (os.contains("win")) { processBuilder.command("cmd", "/c", command); // 设置环境变量确保正确处理中文路径 processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); processBuilder.environment().put("LANG", "zh_CN.UTF-8"); } else { processBuilder.command("/bin/bash", "-c", command); processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); } Process process = processBuilder.start(); BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream())); BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); // BufferedReader stderrReader2 = new BufferedReader(new InputStreamReader(process.getOutputStream())); StringBuilder fullOutput = new StringBuilder(); // 修复:使用单独的线程读取stdout和stderr,避免死锁 CompletableFuture stdoutFuture = CompletableFuture.runAsync(() -> { try { String line; while ((line = stdoutReader.readLine()) != null) { log.info("YOLO训练日志: {}", line); synchronized (fullOutput) { fullOutput.append(line).append("\n"); } // 详细解析YOLO训练进度和结果 manageTrainingLogInRedis(trainId, line); config.setLogMessage(line); // 检测 "Results saved to" 并解析路径和识别率 detectAndParseResultsSaved(line, config, trainId); } } catch (Exception e) { log.error("读取stdout时发生错误", e); } }, asyncExecutor); CompletableFuture stderrFuture = CompletableFuture.runAsync(() -> { try { String line; while ((line = stderrReader.readLine()) != null) { log.info("YOLO训练日志: {}", line); synchronized (fullOutput) { fullOutput.append(line).append("\n"); } manageTrainingLogInRedis(trainId, line); config.setLogMessage(line); // 检测 "Results saved to" 并解析路径和识别率 detectAndParseResultsSaved(line, config, trainId); } } catch (Exception e) { log.error("读取stderr时发生错误", e); } }, asyncExecutor); // 等待两个线程完成,但设置超时避免无限等待 try { CompletableFuture.allOf(stdoutFuture, stderrFuture).get(30, TimeUnit.MINUTES); } catch (java.util.concurrent.TimeoutException e) { log.error("读取YOLO输出超时,强制终止进程"); process.destroyForcibly(); config.setStatus("failed"); config.setErrorMessage("训练超时"); return config; } int exitCode = process.waitFor(); stdoutReader.close(); stderrReader.close(); if (exitCode == 0) { config.setStatus("completed"); config.setProgress(100); config.setLogMessage("训练完成"); // 解析最终训练结果摘要 parseFinalResults(fullOutput.toString(), config); // 保存最终训练结果 saveFinalTrainResult(trainId, config, fullOutput.toString()); } else { config.setStatus("failed"); config.setErrorMessage("训练失败,退出码: " + exitCode); } } catch (Exception e) { log.error("YOLO训练异常", e); config.setStatus("failed"); config.setErrorMessage("训练异常: " + e.getMessage()); } return config; }, asyncExecutor); } // 这里需要实现其他方法和缺失的方法,由于文件太长,我只修改了关键部分 // 实际使用时需要保留原有的其他方法实现 // 以下方法占位符,实际实现从原文件复制 private String buildTrainCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { return ""; } private void manageTrainingLogInRedis(Integer trainId, String line) {} private void detectAndParseResultsSaved(String line, YoloConfig config, Integer trainId) {} private void parseFinalResults(String output, YoloConfig config) {} private void saveFinalTrainResult(Integer trainId, YoloConfig config, String output) {} @Override public YoloConfig validate(String pythonPath, String envName, YoloConfig config) { // 实现从原文件复制 return config; } @Override public YoloConfig predict(String pythonPath, String envName, YoloConfig config) { // 实现从原文件复制 return config; } @Override public YoloConfig classify(String pythonPath, String envName, YoloConfig config) { // 实现从原文件复制 return config; } @Override public YoloConfig export(String pythonPath, String envName, YoloConfig config) { // 实现从原文件复制 return config; } @Override public YoloConfig track(String pythonPath, String envName, YoloConfig config) { // 实现从原文件复制 return config; } @Override public void generateTrainPythonScript(TrainDO train, String yoloDatasetPath) { // 实现从原文件复制 } }