|
|
|
|
@ -1,9 +1,22 @@
|
|
|
|
|
package cn.iocoder.yudao.module.annotation.service.yolo;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO;
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO;
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService;
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo;
|
|
|
|
|
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.service.train.TrainService;
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService;
|
|
|
|
|
import cn.iocoder.yudao.module.system.util.RedisUtil;
|
|
|
|
|
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
|
|
|
|
import jakarta.annotation.Resource;
|
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
|
import org.springframework.beans.factory.annotation.Value;
|
|
|
|
|
import org.springframework.context.annotation.Lazy;
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
|
|
|
|
|
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService;
|
|
|
|
|
import jakarta.annotation.Resource;
|
|
|
|
|
@ -36,19 +49,25 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
@Resource
|
|
|
|
|
private TrainResultService trainResultService;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Resource
|
|
|
|
|
private YoloServiceClient yoloServiceClient;
|
|
|
|
|
|
|
|
|
|
@Value("${yolo.service.enabled:false}")
|
|
|
|
|
private boolean useSocketService;
|
|
|
|
|
|
|
|
|
|
// 异步执行器
|
|
|
|
|
private final Executor asyncExecutor = Executors.newCachedThreadPool();
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public CompletableFuture<YoloConfig> trainAsync(String pythonPath, String envName, YoloConfig config) {
|
|
|
|
|
return CompletableFuture.supplyAsync(() -> {
|
|
|
|
|
config.setStatus("running");
|
|
|
|
|
config.setProgress(0);
|
|
|
|
|
config.setTaskId(UUID.randomUUID().toString());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 生成训练ID
|
|
|
|
|
Integer trainId = generateTrainId(config);
|
|
|
|
|
|
|
|
|
|
Integer trainId = config.getTrainId();
|
|
|
|
|
|
|
|
|
|
// 记录上一轮的epoch,用于检测epoch完成
|
|
|
|
|
int[] lastSavedEpoch = {0};
|
|
|
|
|
|
|
|
|
|
@ -63,16 +82,13 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
|
|
|
|
|
// 构建训练命令
|
|
|
|
|
String command = buildTrainCommand(envInfo, config);
|
|
|
|
|
|
|
|
|
|
log.info("开始YOLO训练,命令: {}", command);
|
|
|
|
|
|
|
|
|
|
// 执行训练
|
|
|
|
|
ProcessBuilder processBuilder = new ProcessBuilder();
|
|
|
|
|
// 在Windows上使用cmd,在Linux/Mac上使用bash
|
|
|
|
|
String os = System.getProperty("os.name").toLowerCase();
|
|
|
|
|
if (os.contains("win")) {
|
|
|
|
|
processBuilder.command("cmd", "/c", command);
|
|
|
|
|
// 设置环境变量确保正确处理中文路径
|
|
|
|
|
processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
|
|
|
|
|
processBuilder.environment().put("LANG", "zh_CN.UTF-8");
|
|
|
|
|
} else {
|
|
|
|
|
@ -80,32 +96,35 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
|
|
|
|
|
}
|
|
|
|
|
Process process = processBuilder.start();
|
|
|
|
|
|
|
|
|
|
BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream()));
|
|
|
|
|
BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
|
|
|
|
|
|
|
|
|
|
StringBuilder fullOutput = new StringBuilder();
|
|
|
|
|
String line;
|
|
|
|
|
|
|
|
|
|
String stdoutLine;
|
|
|
|
|
String stderrLine = null;
|
|
|
|
|
|
|
|
|
|
// 同时读取标准输出和错误输出
|
|
|
|
|
String stdoutLine, stderrLine = null;
|
|
|
|
|
while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null) {
|
|
|
|
|
if (stdoutLine != null) {
|
|
|
|
|
log.info("YOLO训练日志: {}", stdoutLine);
|
|
|
|
|
fullOutput.append(stdoutLine).append("\n");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 详细解析YOLO训练进度和结果
|
|
|
|
|
parseTrainingProgress(stdoutLine, config, trainId, lastSavedEpoch);
|
|
|
|
|
|
|
|
|
|
manageTrainingLogInRedis(trainId, stdoutLine);
|
|
|
|
|
config.setLogMessage(stdoutLine);
|
|
|
|
|
detectAndParseResultsSaved(stdoutLine, config, trainId);
|
|
|
|
|
parseTrainingProgress(stdoutLine, config, trainId, lastSavedEpoch);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (stderrLine != null) {
|
|
|
|
|
log.info("YOLO训练日志: {}", stderrLine);
|
|
|
|
|
fullOutput.append(stderrLine).append("\n");
|
|
|
|
|
|
|
|
|
|
// 详细解析YOLO训练进度和结果
|
|
|
|
|
parseTrainingProgress(stderrLine, config, trainId, lastSavedEpoch);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
manageTrainingLogInRedis(trainId, stderrLine);
|
|
|
|
|
config.setLogMessage(stderrLine);
|
|
|
|
|
detectAndParseResultsSaved(stderrLine, config, trainId);
|
|
|
|
|
parseTrainingProgress(stderrLine, config, trainId, lastSavedEpoch);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -138,6 +157,19 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public YoloConfig validate(String pythonPath, String envName, YoloConfig config) {
|
|
|
|
|
|
|
|
|
|
// 优先尝试使用socket服务(如果启用且可用)
|
|
|
|
|
if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) {
|
|
|
|
|
log.info("使用socket服务执行YOLO验证");
|
|
|
|
|
return yoloServiceClient.validateViaSocket(
|
|
|
|
|
config.getModelPath(),
|
|
|
|
|
config.getValDatasetPath(),
|
|
|
|
|
config
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 回退到原有的ProcessBuilder方式
|
|
|
|
|
log.info("使用传统方式执行YOLO验证");
|
|
|
|
|
try {
|
|
|
|
|
PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName);
|
|
|
|
|
if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) {
|
|
|
|
|
@ -205,6 +237,20 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
|
|
|
|
|
@Override
|
|
|
|
|
public YoloConfig predict(String pythonPath, String envName, YoloConfig config) {
|
|
|
|
|
|
|
|
|
|
// 优先尝试使用socket服务(如果启用且可用)
|
|
|
|
|
if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) {
|
|
|
|
|
log.info("使用socket服务执行YOLO预测");
|
|
|
|
|
return yoloServiceClient.predictViaSocket(
|
|
|
|
|
config.getModelPath(),
|
|
|
|
|
config.getInputPath(),
|
|
|
|
|
config.getOutputPath(),
|
|
|
|
|
config
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 回退到原有的ProcessBuilder方式
|
|
|
|
|
log.info("使用传统方式执行YOLO预测");
|
|
|
|
|
try {
|
|
|
|
|
PythonVirtualEnvInfo envInfo = pythonVirtualEnvService.activateVirtualEnv(pythonPath, envName);
|
|
|
|
|
if (!envInfo.getVirtualEnvExists() || !envInfo.getYoloInstalled()) {
|
|
|
|
|
@ -240,10 +286,22 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
if (stdoutLine != null) {
|
|
|
|
|
output.append(stdoutLine).append("\n");
|
|
|
|
|
log.info("YOLO预测日志: {}", stdoutLine);
|
|
|
|
|
|
|
|
|
|
if (stdoutLine.contains("Results saved to")) {
|
|
|
|
|
|
|
|
|
|
stdoutLine = stdoutLine.replace("\u001B","").replace("[1m","").replace("[0m","");
|
|
|
|
|
config.setOutputPath(stdoutLine.split(" ")[3]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (stderrLine != null) {
|
|
|
|
|
output.append(stderrLine).append("\n");
|
|
|
|
|
log.info("YOLO预测日志: {}", stderrLine);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (stderrLine.contains("Results saved to")) {
|
|
|
|
|
stderrLine = stderrLine.replace("\u001B","").replace("[1m","").replace("[0m","");
|
|
|
|
|
config.setOutputPath(stderrLine.split(" ")[3]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -394,6 +452,8 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
if ("conda".equals(envInfo.getVirtualEnvType())) {
|
|
|
|
|
command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" ");
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
|
|
command.append(normalizePath(envInfo.getVirtualEnvPythonPath())).append(" ");
|
|
|
|
|
command.append(envInfo.getVirtualEnvPythonPath()).append(" ");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -402,6 +462,40 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
String outputPath = normalizePath(config.getOutputPath());
|
|
|
|
|
String modelPath = config.getModelPath() != null ? normalizePath(config.getModelPath()) : null;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 调试日志:检查模型路径
|
|
|
|
|
log.info("模型路径配置 - modelPath: {}, modelName: {}", config.getModelPath(), config.getModelName());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// command.append("-c \"");
|
|
|
|
|
command.append(datasetPath).append("/train.py ");
|
|
|
|
|
// command.append("from ultralytics import YOLO; ");
|
|
|
|
|
//
|
|
|
|
|
// if (modelPath != null && !modelPath.trim().isEmpty()) {
|
|
|
|
|
// command.append("model = YOLO('").append(escapePythonString(modelPath)).append("'); ");
|
|
|
|
|
// } else {
|
|
|
|
|
// command.append("model = YOLO('").append(config.getModelName()).append("'); ");
|
|
|
|
|
// }
|
|
|
|
|
//
|
|
|
|
|
// // 使用raw string来处理中文路径
|
|
|
|
|
// command.append("model.train(data=r'").append(escapePythonString(datasetPath+"\\dataset.yaml")).append("', ");
|
|
|
|
|
// command.append("epochs=").append(config.getEpochs()).append(", ");
|
|
|
|
|
// command.append("batch=").append(config.getBatchSize()).append(", ");
|
|
|
|
|
// command.append("imgsz=").append(config.getImageSize()).append(", ");
|
|
|
|
|
//
|
|
|
|
|
// // 取消注释设备和学习率配置
|
|
|
|
|
//// command.append("device=").append(config.getDevice()).append(", ");
|
|
|
|
|
//// command.append("lr0=").append(config.getLearningRate()).append(", ");
|
|
|
|
|
//// command.append("workers=").append(config.getWorkers()).append(", ");
|
|
|
|
|
////
|
|
|
|
|
// if (config.getSavePeriod() > 0) {
|
|
|
|
|
// command.append("save_period=").append(config.getSavePeriod()).append(", ");
|
|
|
|
|
// }
|
|
|
|
|
//
|
|
|
|
|
// command.append("project=r'").append(escapePythonString(outputPath)).append("')\"");
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
command.append("-c \"");
|
|
|
|
|
@ -472,8 +566,8 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
String inputPath = normalizePath(config.getInputPath());
|
|
|
|
|
String outputPath = normalizePath(config.getOutputPath());
|
|
|
|
|
String modelPath = normalizePath(config.getModelPath());
|
|
|
|
|
|
|
|
|
|
command.append("-c \"");
|
|
|
|
|
@ -481,18 +575,17 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
command.append("from ultralytics import YOLO; ");
|
|
|
|
|
command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); ");
|
|
|
|
|
command.append("results = model.predict(source=r'").append(escapePythonString(inputPath)).append("', ");
|
|
|
|
|
command.append("conf=").append(config.getConfThresh()).append(", ");
|
|
|
|
|
command.append("iou=").append(config.getIouThresh()).append(", ");
|
|
|
|
|
command.append("imgsz=").append(config.getImageSize()).append(", ");
|
|
|
|
|
command.append("device='").append(config.getDevice()).append("', ");
|
|
|
|
|
command.append("save=").append(config.isSave()).append(", ");
|
|
|
|
|
command.append("show=").append(config.isShow()).append(", ");
|
|
|
|
|
command.append("save_txt=").append(config.isSaveTxt()).append(", ");
|
|
|
|
|
command.append("save_crop=").append(config.isSaveCrops()).append(")\"");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 在这里我们将 save=true 改为 save=True
|
|
|
|
|
command.append("save=").append(config.isSave() ? "True" : "False").append(", ");
|
|
|
|
|
command.append("project='").append(escapePythonString(outputPath)).append("')\"");
|
|
|
|
|
|
|
|
|
|
return command.toString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private String buildExportCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) {
|
|
|
|
|
StringBuilder command = new StringBuilder();
|
|
|
|
|
|
|
|
|
|
@ -534,6 +627,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
return command.toString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 解析训练进度详细信息
|
|
|
|
|
*/
|
|
|
|
|
@ -631,6 +725,8 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
// 忽略解析错误,继续处理下一行
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 解析最终训练结果
|
|
|
|
|
@ -820,6 +916,12 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
// 不影响训练完成状态
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@Resource
|
|
|
|
|
@Lazy
|
|
|
|
|
TrainService trainService;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 保存TrainResultDO到数据库
|
|
|
|
|
@ -827,8 +929,11 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
private void saveTrainResultDO(TrainResultDO trainResult) {
|
|
|
|
|
try {
|
|
|
|
|
// 创建SaveReqVO
|
|
|
|
|
cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO saveReqVO =
|
|
|
|
|
new cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO();
|
|
|
|
|
|
|
|
|
|
TrainResultSaveReqVO saveReqVO =
|
|
|
|
|
new TrainResultSaveReqVO();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 设置值
|
|
|
|
|
saveReqVO.setTrainId(trainResult.getTrainId());
|
|
|
|
|
@ -838,6 +943,11 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
|
|
|
|
|
// 保存
|
|
|
|
|
trainResultService.createTrainResult(saveReqVO);
|
|
|
|
|
|
|
|
|
|
trainService.update(new UpdateWrapper<TrainDO>().eq("id", trainResult.getTrainId())
|
|
|
|
|
.set("train_type", 4));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("保存训练结果失败: {}", trainResult, e);
|
|
|
|
|
@ -959,4 +1069,687 @@ public class YoloOperationServiceImpl implements YoloOperationService {
|
|
|
|
|
return escaped.toString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Resource
|
|
|
|
|
private RedisUtil redisUtil;
|
|
|
|
|
/**
|
|
|
|
|
* 检测并解析 "Results saved to" 日志,提取保存路径和识别率
|
|
|
|
|
*
|
|
|
|
|
* @param logLine 日志行
|
|
|
|
|
* @param config YOLO配置对象
|
|
|
|
|
* @param trainId 训练ID
|
|
|
|
|
*/
|
|
|
|
|
private void detectAndParseResultsSaved(String logLine, YoloConfig config, Integer trainId) {
|
|
|
|
|
try {
|
|
|
|
|
if (logLine == null || logLine.trim().isEmpty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
String trimmedLine = logLine.trim();
|
|
|
|
|
|
|
|
|
|
// 检测 "Results saved to" 或 "Saved to" 关键字
|
|
|
|
|
if (trimmedLine.contains("Results saved to ") || trimmedLine.contains("Saved to ")) {
|
|
|
|
|
// 提取保存路径
|
|
|
|
|
String savedPath = extractSavedPath(trimmedLine);
|
|
|
|
|
savedPath = savedPath.replace("\u001B","").replace("[1m","").replace("[0m","");
|
|
|
|
|
if (savedPath != null && !savedPath.trim().isEmpty()) {
|
|
|
|
|
log.info("检测到YOLO结果保存路径: {}", savedPath);
|
|
|
|
|
|
|
|
|
|
// 保存路径到配置中
|
|
|
|
|
config.setOutputPath(savedPath);
|
|
|
|
|
|
|
|
|
|
// 尝试从保存路径中获取识别率信息
|
|
|
|
|
parseAccuracyFromSavedPath(savedPath, config, trainId);
|
|
|
|
|
|
|
|
|
|
// 保存结果路径到数据库
|
|
|
|
|
saveResultsPath(trainId, savedPath, config);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("检测和解析Results saved to时出错", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 从日志行中提取保存路径
|
|
|
|
|
*
|
|
|
|
|
* @param logLine 包含 "Results saved to" 的日志行
|
|
|
|
|
* @return 提取的路径,如果提取失败返回null
|
|
|
|
|
*/
|
|
|
|
|
private String extractSavedPath(String logLine) {
|
|
|
|
|
try {
|
|
|
|
|
// 处理不同格式的日志输出
|
|
|
|
|
String[] patterns = {
|
|
|
|
|
"Results saved to (.*?)\\s*$",
|
|
|
|
|
"Saved to (.*?)\\s*$",
|
|
|
|
|
"Results saved to (.*?)(?:\\s+\\(\\d+.*?\\))?$", // 带统计信息的格式
|
|
|
|
|
"Saved to (.*?)(?:\\s+\\(\\d+.*?\\))?$"
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
for (String pattern : patterns) {
|
|
|
|
|
java.util.regex.Pattern p = java.util.regex.Pattern.compile(pattern);
|
|
|
|
|
java.util.regex.Matcher m = p.matcher(logLine);
|
|
|
|
|
if (m.find()) {
|
|
|
|
|
String path = m.group(1).trim();
|
|
|
|
|
// 移除可能的引号
|
|
|
|
|
if (path.startsWith("\"") && path.endsWith("\"")) {
|
|
|
|
|
path = path.substring(1, path.length() - 1);
|
|
|
|
|
}
|
|
|
|
|
if (path.startsWith("'") && path.endsWith("'")) {
|
|
|
|
|
path = path.substring(1, path.length() - 1);
|
|
|
|
|
}
|
|
|
|
|
return path;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 如果正则匹配失败,尝试简单的字符串分割
|
|
|
|
|
String[] parts = logLine.split("Results saved to|Saved to");
|
|
|
|
|
if (parts.length > 1) {
|
|
|
|
|
String path = parts[1].trim();
|
|
|
|
|
// 移除可能的引号
|
|
|
|
|
if (path.startsWith("\"") && path.endsWith("\"")) {
|
|
|
|
|
path = path.substring(1, path.length() - 1);
|
|
|
|
|
}
|
|
|
|
|
if (path.startsWith("'") && path.endsWith("'")) {
|
|
|
|
|
path = path.substring(1, path.length() - 1);
|
|
|
|
|
}
|
|
|
|
|
return path;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.warn("提取保存路径失败: {}", logLine, e);
|
|
|
|
|
}
|
|
|
|
|
return null;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 从保存路径中解析识别率信息
|
|
|
|
|
* YOLO训练完成后,通常会在保存路径中生成包含性能指标的文件
|
|
|
|
|
*
|
|
|
|
|
* @param savedPath YOLO结果保存路径
|
|
|
|
|
* @param config YOLO配置对象
|
|
|
|
|
* @param trainId 训练ID
|
|
|
|
|
*/
|
|
|
|
|
private void parseAccuracyFromSavedPath(String savedPath, YoloConfig config, Integer trainId) {
|
|
|
|
|
try {
|
|
|
|
|
log.info("开始从保存路径解析识别率: {}", savedPath);
|
|
|
|
|
|
|
|
|
|
java.io.File saveDir = new java.io.File(savedPath);
|
|
|
|
|
if (!saveDir.exists() || !saveDir.isDirectory()) {
|
|
|
|
|
log.warn("保存路径不存在或不是目录: {}", savedPath);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 查找可能包含识别率的文件
|
|
|
|
|
java.io.File[] files = saveDir.listFiles();
|
|
|
|
|
if (files == null) {
|
|
|
|
|
log.warn("无法读取保存目录内容: {}", savedPath);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 优先查找results.csv文件(YOLO通常会在训练完成后生成)
|
|
|
|
|
for (java.io.File file : files) {
|
|
|
|
|
if (file.getName().equals("results.csv")) {
|
|
|
|
|
parseResultsCsv(file, config, trainId);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 如果没有results.csv,查找其他可能的文件
|
|
|
|
|
for (java.io.File file : files) {
|
|
|
|
|
String fileName = file.getName().toLowerCase();
|
|
|
|
|
|
|
|
|
|
// 查找train_batch或val_batch的图片文件,文件名中可能包含性能指标
|
|
|
|
|
if (fileName.contains("val_batch") && fileName.endsWith(".jpg")) {
|
|
|
|
|
log.info("找到验证批次图片,但无法直接从中提取数值指标");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 查找可能的日志文件
|
|
|
|
|
if (fileName.endsWith(".log") || fileName.endsWith(".txt")) {
|
|
|
|
|
parseLogFileForAccuracy(file, config, trainId);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 如果没有找到结果文件,尝试解析路径中的训练信息
|
|
|
|
|
parseTrainingInfoFromPath(savedPath, config);
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("从保存路径解析识别率时出错: {}", savedPath, e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 解析results.csv文件获取识别率
|
|
|
|
|
* 支持不同任务类型:检测(detect)、分类(classify)、分割(segment)
|
|
|
|
|
*
|
|
|
|
|
* @param csvFile results.csv文件
|
|
|
|
|
* @param config YOLO配置对象
|
|
|
|
|
* @param trainId 训练ID
|
|
|
|
|
*/
|
|
|
|
|
private void parseResultsCsv(java.io.File csvFile, YoloConfig config, Integer trainId) {
|
|
|
|
|
try {
|
|
|
|
|
log.info("解析results.csv文件: {} (任务类型: {})", csvFile.getAbsolutePath(), config.getTaskType());
|
|
|
|
|
|
|
|
|
|
java.util.List<String> lines = java.nio.file.Files.readAllLines(csvFile.toPath());
|
|
|
|
|
if (lines.isEmpty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 读取标题行和最后一行(最新的结果)
|
|
|
|
|
String headerLine = lines.get(0);
|
|
|
|
|
String lastLine = lines.get(lines.size() - 1);
|
|
|
|
|
String[] headers = headerLine.split(",");
|
|
|
|
|
String[] values = lastLine.split(",");
|
|
|
|
|
|
|
|
|
|
log.debug("CSV标题行: {}", headerLine);
|
|
|
|
|
log.debug("CSV数据行: {}", lastLine);
|
|
|
|
|
|
|
|
|
|
// 根据任务类型和标题解析不同的指标
|
|
|
|
|
String taskType = config.getTaskType();
|
|
|
|
|
if (taskType == null) {
|
|
|
|
|
taskType = "detect"; // 默认为检测任务
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
double precision = 0.0, recall = 0.0, map = 0.0, accuracy = 0.0, f1Score = 0.0;
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
if ("classify".equals(taskType) || "classification".equals(taskType)) {
|
|
|
|
|
// 分类任务:解析 accuracy, top1_acc, top5_acc, loss 等
|
|
|
|
|
for (int i = 0; i < headers.length && i < values.length; i++) {
|
|
|
|
|
String header = headers[i].trim().toLowerCase();
|
|
|
|
|
String value = values[i].trim();
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
if (header.contains("accuracy") || header.contains("acc")) {
|
|
|
|
|
accuracy = Double.parseDouble(value);
|
|
|
|
|
config.setPrecision(accuracy); // 分类任务用accuracy作为precision
|
|
|
|
|
log.info("分类任务解析到准确率(Accuracy): {}", accuracy);
|
|
|
|
|
} else if (header.contains("top1") || header.contains("top_1")) {
|
|
|
|
|
double top1Acc = Double.parseDouble(value);
|
|
|
|
|
log.info("分类任务解析到Top-1准确率: {}", top1Acc);
|
|
|
|
|
} else if (header.contains("top5") || header.contains("top_5")) {
|
|
|
|
|
double top5Acc = Double.parseDouble(value);
|
|
|
|
|
log.info("分类任务解析到Top-5准确率: {}", top5Acc);
|
|
|
|
|
} else if (header.contains("loss")) {
|
|
|
|
|
// 记录损失值
|
|
|
|
|
log.debug("分类任务解析到损失值: {}", value);
|
|
|
|
|
}
|
|
|
|
|
} catch (NumberFormatException e) {
|
|
|
|
|
log.debug("跳过无法解析的数值: {}", value);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 保存分类任务的准确率到数据库
|
|
|
|
|
if (accuracy > 0) {
|
|
|
|
|
saveClassificationAccuracyToDatabase(trainId, accuracy);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} else if ("detect".equals(taskType) || "segment".equals(taskType)) {
|
|
|
|
|
// 检测或分割任务:解析 precision, recall, mAP50, mAP50-95 等
|
|
|
|
|
for (int i = 0; i < headers.length && i < values.length; i++) {
|
|
|
|
|
String header = headers[i].trim().toLowerCase();
|
|
|
|
|
String value = values[i].trim();
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
if (header.contains("precision")) {
|
|
|
|
|
precision = Double.parseDouble(value);
|
|
|
|
|
config.setPrecision(precision);
|
|
|
|
|
} else if (header.contains("recall")) {
|
|
|
|
|
recall = Double.parseDouble(value);
|
|
|
|
|
config.setRecall(recall);
|
|
|
|
|
} else if (header.contains("map50-95") || header.contains("map50_95")) {
|
|
|
|
|
map = Double.parseDouble(value);
|
|
|
|
|
config.setMap(map);
|
|
|
|
|
} else if (header.contains("map50")) {
|
|
|
|
|
double map50 = Double.parseDouble(value);
|
|
|
|
|
log.debug("解析到mAP50: {}", map50);
|
|
|
|
|
// 如果没有mAP50-95,使用mAP50作为map值
|
|
|
|
|
if (map == 0.0) {
|
|
|
|
|
map = map50;
|
|
|
|
|
config.setMap(map);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} catch (NumberFormatException e) {
|
|
|
|
|
log.debug("跳过无法解析的数值: {}", value);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
log.info("检测/分割任务解析到识别率 - Precision: {}, Recall: {}, mAP: {}",
|
|
|
|
|
precision, recall, map);
|
|
|
|
|
|
|
|
|
|
// 保存检测/分割任务的识别率到数据库
|
|
|
|
|
if (precision > 0 || recall > 0 || map > 0) {
|
|
|
|
|
saveAccuracyToDatabase(trainId, precision, recall, map);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
// 其他未知任务类型,尝试通用解析
|
|
|
|
|
log.warn("未知任务类型: {}, 尝试通用解析", taskType);
|
|
|
|
|
parseGenericResults(headers, values, config);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.warn("解析results.csv数值时出错: {}", lastLine, e);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("读取results.csv文件时出错", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 通用解析results.csv,适用于未知任务类型
|
|
|
|
|
*
|
|
|
|
|
* @param headers CSV标题
|
|
|
|
|
* @param values CSV数值
|
|
|
|
|
* @param config YOLO配置对象
|
|
|
|
|
*/
|
|
|
|
|
private void parseGenericResults(String[] headers, String[] values, YoloConfig config) {
|
|
|
|
|
try {
|
|
|
|
|
double precision = 0.0, recall = 0.0, map = 0.0, accuracy = 0.0;
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < headers.length && i < values.length; i++) {
|
|
|
|
|
String header = headers[i].trim().toLowerCase();
|
|
|
|
|
String value = values[i].trim();
|
|
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
if (header.contains("precision")) {
|
|
|
|
|
precision = Double.parseDouble(value);
|
|
|
|
|
config.setPrecision(precision);
|
|
|
|
|
} else if (header.contains("recall")) {
|
|
|
|
|
recall = Double.parseDouble(value);
|
|
|
|
|
config.setRecall(recall);
|
|
|
|
|
} else if (header.contains("map")) {
|
|
|
|
|
map = Double.parseDouble(value);
|
|
|
|
|
config.setMap(map);
|
|
|
|
|
} else if (header.contains("accuracy") || header.contains("acc")) {
|
|
|
|
|
accuracy = Double.parseDouble(value);
|
|
|
|
|
config.setPrecision(accuracy); // 用accuracy替代precision
|
|
|
|
|
}
|
|
|
|
|
} catch (NumberFormatException e) {
|
|
|
|
|
log.debug("通用解析跳过无法解析的数值: {}", value);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
log.info("通用解析结果 - Precision/Accuracy: {}, Recall: {}, mAP: {}",
|
|
|
|
|
Math.max(precision, accuracy), recall, map);
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("通用解析失败", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 保存分类任务的准确率到数据库
|
|
|
|
|
*
|
|
|
|
|
* @param trainId 训练ID
|
|
|
|
|
* @param accuracy 准确率
|
|
|
|
|
*/
|
|
|
|
|
private void saveClassificationAccuracyToDatabase(Integer trainId, double accuracy) {
|
|
|
|
|
try {
|
|
|
|
|
// 构建分类任务的准确率字符串
|
|
|
|
|
String rateStr = String.format("accuracy:%.4f", accuracy);
|
|
|
|
|
|
|
|
|
|
// 更新数据库中的准确率
|
|
|
|
|
cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO trainResult =
|
|
|
|
|
cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO.builder()
|
|
|
|
|
.trainId(trainId)
|
|
|
|
|
.rate(rateStr)
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
saveTrainResultDO(trainResult);
|
|
|
|
|
|
|
|
|
|
log.info("已保存分类任务准确率到数据库 - 训练ID: {}, 准确率: {}", trainId, rateStr);
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("保存分类任务准确率到数据库失败", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 从日志文件中解析识别率
|
|
|
|
|
*
|
|
|
|
|
* @param logFile 日志文件
|
|
|
|
|
* @param config YOLO配置对象
|
|
|
|
|
* @param trainId 训练ID
|
|
|
|
|
*/
|
|
|
|
|
private void parseLogFileForAccuracy(java.io.File logFile, YoloConfig config, Integer trainId) {
|
|
|
|
|
try {
|
|
|
|
|
log.info("尝试从日志文件解析识别率: {}", logFile.getAbsolutePath());
|
|
|
|
|
|
|
|
|
|
java.util.List<String> lines = java.nio.file.Files.readAllLines(logFile.toPath());
|
|
|
|
|
|
|
|
|
|
// 反向遍历,查找最后的验证结果
|
|
|
|
|
for (int i = lines.size() - 1; i >= 0; i--) {
|
|
|
|
|
String line = lines.get(i);
|
|
|
|
|
|
|
|
|
|
// 查找包含精度、召回率、mAP等指标的行
|
|
|
|
|
if (line.contains("mAP50") || line.contains("mAP50-95") ||
|
|
|
|
|
line.contains("precision") || line.contains("recall")) {
|
|
|
|
|
|
|
|
|
|
parseMetricsLine(line, config, trainId);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("读取日志文件时出错", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 解析包含指标的行
|
|
|
|
|
*
|
|
|
|
|
* @param line 包含指标的行
|
|
|
|
|
* @param config YOLO配置对象
|
|
|
|
|
* @param trainId 训练ID
|
|
|
|
|
*/
|
|
|
|
|
private void parseMetricsLine(String line, YoloConfig config, Integer trainId) {
|
|
|
|
|
try {
|
|
|
|
|
log.info("解析指标行: {}", line);
|
|
|
|
|
|
|
|
|
|
// 使用正则表达式提取数值
|
|
|
|
|
java.util.regex.Pattern precisionPattern = java.util.regex.Pattern.compile("precision[:\\s=]+([0-9.]+)");
|
|
|
|
|
java.util.regex.Pattern recallPattern = java.util.regex.Pattern.compile("recall[:\\s=]+([0-9.]+)");
|
|
|
|
|
java.util.regex.Pattern mapPattern = java.util.regex.Pattern.compile("mAP50-95[:\\s=]+([0-9.]+)");
|
|
|
|
|
java.util.regex.Pattern map50Pattern = java.util.regex.Pattern.compile("mAP50[:\\s=]+([0-9.]+)");
|
|
|
|
|
|
|
|
|
|
java.util.regex.Matcher m;
|
|
|
|
|
|
|
|
|
|
if ((m = precisionPattern.matcher(line)).find()) {
|
|
|
|
|
config.setPrecision(Double.parseDouble(m.group(1)));
|
|
|
|
|
}
|
|
|
|
|
if ((m = recallPattern.matcher(line)).find()) {
|
|
|
|
|
config.setRecall(Double.parseDouble(m.group(1)));
|
|
|
|
|
}
|
|
|
|
|
if ((m = mapPattern.matcher(line)).find()) {
|
|
|
|
|
config.setMap(Double.parseDouble(m.group(1)));
|
|
|
|
|
}
|
|
|
|
|
if ((m = map50Pattern.matcher(line)).find()) {
|
|
|
|
|
// 如果没有mAP50-95,使用mAP50作为mAP值
|
|
|
|
|
if (config.getMap() == 0.0) {
|
|
|
|
|
config.setMap(Double.parseDouble(m.group(1)));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 保存识别率到数据库
|
|
|
|
|
if (config.getPrecision() > 0 || config.getRecall() > 0 || config.getMap() > 0) {
|
|
|
|
|
saveAccuracyToDatabase(trainId, config.getPrecision(), config.getRecall(), config.getMap());
|
|
|
|
|
log.info("成功解析识别率 - Precision: {}, Recall: {}, mAP: {}",
|
|
|
|
|
config.getPrecision(), config.getRecall(), config.getMap());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("解析指标行时出错", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 从路径信息中解析训练信息(备用方法)
|
|
|
|
|
*
|
|
|
|
|
* @param savedPath 保存路径
|
|
|
|
|
* @param config YOLO配置对象
|
|
|
|
|
*/
|
|
|
|
|
private void parseTrainingInfoFromPath(String savedPath, YoloConfig config) {
|
|
|
|
|
try {
|
|
|
|
|
// 从路径中可能包含的训练信息,如train18等
|
|
|
|
|
java.io.File pathFile = new java.io.File(savedPath);
|
|
|
|
|
String pathName = pathFile.getName();
|
|
|
|
|
|
|
|
|
|
if (pathName.startsWith("train") && pathName.length() > 5) {
|
|
|
|
|
try {
|
|
|
|
|
int trainNumber = Integer.parseInt(pathName.substring(5));
|
|
|
|
|
log.info("从路径解析到训练编号: {}", trainNumber);
|
|
|
|
|
} catch (NumberFormatException e) {
|
|
|
|
|
log.debug("无法从路径解析训练编号: {}", pathName);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.warn("从路径解析训练信息时出错", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 保存结果路径到数据库
|
|
|
|
|
*
|
|
|
|
|
* @param trainId 训练ID
|
|
|
|
|
* @param savedPath 保存路径
|
|
|
|
|
* @param config YOLO配置对象
|
|
|
|
|
*/
|
|
|
|
|
private void saveResultsPath(Integer trainId, String savedPath, YoloConfig config) {
|
|
|
|
|
try {
|
|
|
|
|
// 更新TrainResultDO
|
|
|
|
|
TrainResultDO trainResult = TrainResultDO.builder()
|
|
|
|
|
.trainId(trainId)
|
|
|
|
|
.path(savedPath)
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
saveTrainResultDO(trainResult);
|
|
|
|
|
TrainDO train = trainService.getTrain(trainId);
|
|
|
|
|
train.setOutPath(savedPath);
|
|
|
|
|
trainService.updateById(train);
|
|
|
|
|
log.info("已保存结果路径到数据库 - 训练ID: {}, 路径: {}", trainId, savedPath);
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("保存结果路径到数据库失败", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 保存识别率到数据库
|
|
|
|
|
*
|
|
|
|
|
* @param trainId 训练ID
|
|
|
|
|
* @param precision 精度
|
|
|
|
|
* @param recall 召回率
|
|
|
|
|
* @param map mAP值
|
|
|
|
|
*/
|
|
|
|
|
private void saveAccuracyToDatabase(Integer trainId, double precision, double recall, double map) {
|
|
|
|
|
try {
|
|
|
|
|
// 构建识别率字符串
|
|
|
|
|
StringBuilder rateStr = new StringBuilder();
|
|
|
|
|
if (precision > 0) {
|
|
|
|
|
rateStr.append("precision:").append(String.format("%.4f", precision));
|
|
|
|
|
}
|
|
|
|
|
if (recall > 0) {
|
|
|
|
|
if (rateStr.length() > 0) rateStr.append(",");
|
|
|
|
|
rateStr.append("recall:").append(String.format("%.4f", recall));
|
|
|
|
|
}
|
|
|
|
|
if (map > 0) {
|
|
|
|
|
if (rateStr.length() > 0) rateStr.append(",");
|
|
|
|
|
rateStr.append("map:").append(String.format("%.4f", map));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 更新数据库中的识别率
|
|
|
|
|
TrainResultDO trainResult =
|
|
|
|
|
cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO.builder()
|
|
|
|
|
.trainId(trainId)
|
|
|
|
|
.rate(rateStr.toString())
|
|
|
|
|
.build();
|
|
|
|
|
|
|
|
|
|
saveTrainResultDO(trainResult);
|
|
|
|
|
|
|
|
|
|
log.info("已保存识别率到数据库 - 训练ID: {}, 识别率: {}", trainId, rateStr.toString());
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("保存识别率到数据库失败", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 管理训练日志在Redis中的存储
|
|
|
|
|
* 当日志包含%时,进行特殊处理:删除最后一条,插入新的,最多保存50条,保存5天
|
|
|
|
|
*
|
|
|
|
|
* @param trainId 训练ID
|
|
|
|
|
* @param logLine 日志行
|
|
|
|
|
*/
|
|
|
|
|
private void manageTrainingLogInRedis(Integer trainId, String logLine) {
|
|
|
|
|
try {
|
|
|
|
|
if (logLine == null || logLine.trim().isEmpty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
String redisKey = "yolo:training_log:" + trainId;
|
|
|
|
|
boolean isProgressLog = logLine.contains("%");
|
|
|
|
|
|
|
|
|
|
if (isProgressLog) {
|
|
|
|
|
// 如果是进度日志(包含%),先删除之前的进度日志
|
|
|
|
|
removeProgressLogs(redisKey);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 添加新的日志行到列表末尾
|
|
|
|
|
redisUtil.lSet(redisKey, logLine, 5 * 24 * 60 * 60); // 5天过期时间(秒)
|
|
|
|
|
|
|
|
|
|
// 检查列表长度,如果超过50条,删除第一条
|
|
|
|
|
long newSize = redisUtil.lGetListSize(redisKey);
|
|
|
|
|
if (newSize > 50) {
|
|
|
|
|
redisUtil.lRemove(redisKey, 1, redisUtil.lGetIndex(redisKey, 0));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
log.debug("训练日志已存储到Redis - 训练ID: {}, 日志类型: {}, 当前日志数量: {}",
|
|
|
|
|
trainId, isProgressLog ? "进度日志" : "普通日志", redisUtil.lGetListSize(redisKey));
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("管理训练日志Redis操作失败 - 训练ID: {}", trainId, e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 删除Redis中所有包含%的进度日志
|
|
|
|
|
*
|
|
|
|
|
* @param redisKey Redis键
|
|
|
|
|
*/
|
|
|
|
|
private void removeProgressLogs(String redisKey) {
|
|
|
|
|
try {
|
|
|
|
|
long listSize = redisUtil.lGetListSize(redisKey);
|
|
|
|
|
if (listSize == 0) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取所有日志
|
|
|
|
|
java.util.List<Object> allLogs = redisUtil.lGet(redisKey, 0, -1);
|
|
|
|
|
if (allLogs == null || allLogs.isEmpty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 找出并删除最后一条包含%的日志(保留100%的)
|
|
|
|
|
for (int i = allLogs.size() - 1; i >= 0; i--) {
|
|
|
|
|
Object logObj = allLogs.get(i);
|
|
|
|
|
if (logObj != null) {
|
|
|
|
|
String logStr = logObj.toString();
|
|
|
|
|
if (logStr.contains("%")) {
|
|
|
|
|
// 检查是否是100%的日志,如果是则跳过
|
|
|
|
|
if (logStr.contains("100%")) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
redisUtil.lRemove(redisKey, 1, logStr);
|
|
|
|
|
log.debug("删除旧的进度日志: {}", logStr);
|
|
|
|
|
break; // 只删除最后一条非100%的进度日志
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("删除进度日志时出错 - Redis键: {}", redisKey, e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* 生成仅训练的Python脚本文件
|
|
|
|
|
*
|
|
|
|
|
* @return 是否生成成功
|
|
|
|
|
*/
|
|
|
|
|
public void generateTrainPythonScript(TrainDO train, String yoloDatasetPath) {
|
|
|
|
|
try {
|
|
|
|
|
StringBuilder script = new StringBuilder();
|
|
|
|
|
|
|
|
|
|
// 文件头部和导入
|
|
|
|
|
script.append("# -*- coding: utf-8 -*-");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append("import sys");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append("import matplotlib.pyplot as plt");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append("from matplotlib import rcParams");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append("import platform");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
|
|
|
|
|
// 字体设置函数
|
|
|
|
|
script.append("# 自动选择字体");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append("def set_font():");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" system_platform = platform.system()");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" if system_platform == \"Windows\":");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" rcParams['font.sans-serif'] = ['Microsoft YaHei'] # Windows 下使用微软雅黑");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" else:");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" rcParams['font.sans-serif'] = ['DejaVu Sans'] # Linux 下使用 DejaVu Sans");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" rcParams['axes.unicode_minus'] = False # 解决负号显示问题");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
|
|
|
|
|
// 训练函数
|
|
|
|
|
script.append("# 训练函数");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append("def main():");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" set_font()");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" sys.path.insert(0, '.')");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" from ultralytics import YOLO");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
|
|
|
|
|
if (train.getModelPath() != null && !train.getModelPath().trim().isEmpty()) {
|
|
|
|
|
script.append(" model = YOLO('").append(escapePythonString(train.getModelPath())).append("')");
|
|
|
|
|
} else {
|
|
|
|
|
script.append(" model = YOLO('yolo11n.pt')");
|
|
|
|
|
}
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" ");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" # 训练模型");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" model.train(");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" data=r'").append(escapePythonString(train.getPath() + "/dataset.yaml")).append("',");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" epochs=").append(train.getRound()).append(",");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" batch=").append(train.getSize()).append(",");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" imgsz=").append(train.getImageSize()).append(",");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" project=r'").append(escapePythonString(train.getPath())).append("'");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" )");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
|
|
|
|
|
// 程序入口
|
|
|
|
|
script.append("if __name__ == '__main__':");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
script.append(" main()");
|
|
|
|
|
script.append("\n");
|
|
|
|
|
|
|
|
|
|
// 写入文件
|
|
|
|
|
try (java.io.FileWriter writer = new java.io.FileWriter(train.getPath() + "/train.py", java.nio.charset.StandardCharsets.UTF_8)) {
|
|
|
|
|
writer.write(script.toString());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
log.info("训练Python脚本生成成功: {}", train.getPath());
|
|
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
|
log.error("生成训练Python脚本失败", e);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
}
|