# Conflicts:
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/TrainInfoController.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/yolo/YoloOperationController.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvServiceImpl.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationService.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationServiceImpl.java
#	yudao-server/src/main/resources/application-dev.yaml
master
LAPTOP-S9HJSOEB\昊天 3 months ago
commit 4363ea079a

@ -81,7 +81,7 @@ public class TrainController {
@Operation(summary = "进行训练")
// @Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('annotation:train:query')")
public CommonResult<String> testRecognition(@RequestBody RecognitionResult recognitionReqVO) {
public CommonResult<String> testRecognition(@RequestBody RecognitionResult recognitionReqVO) {
String train = trainService.testRecognition(recognitionReqVO.getImageUrl(),recognitionReqVO.getTrainId());
return success( train);

@ -69,6 +69,7 @@ public class TrainInfoController {
TrainInfoDO trainInfo = trainInfoService.getTrainInfo(id);
return success(BeanUtils.toBean(trainInfo, TrainInfoRespVO.class));
}
@Resource
RedisUtil redisUtil;
@ -76,6 +77,7 @@ public class TrainInfoController {
@Operation(summary = "获得识别结果")
@Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('annotation:train-Info:query')")
public CommonResult<List<String>> getTrainInfoList(@RequestParam("trainId") Integer id) {
List<Object> result = redisUtil.lGet("yolo:training_log:" + id, 0, -1);
List<String> result1 = result.stream().map(Object::toString).toList();

@ -87,6 +87,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
config.setTrainId(id);
@ -105,6 +106,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
YoloConfig result = yoloOperationService.validate(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
@ -120,6 +122,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
YoloConfig result = yoloOperationService.predict(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
@ -135,6 +138,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
YoloConfig result = yoloOperationService.classify(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
@ -150,6 +154,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
YoloConfig result = yoloOperationService.export(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
@ -165,6 +170,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id);
YoloConfig result = yoloOperationService.track(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);

@ -75,54 +75,63 @@ public class TrainDO extends BaseDO {
private String type;
private String outPath;
public YoloConfig getYolofig(Map<String, DictDataDO> dictDataMap) {
YoloConfig config = new YoloConfig();
// 设置数据相关配置
config.setDatasetPath(this.path);
// 设置训练参数
if (this.size != null) {
config.setEpochs(this.size);
}
if (this.round != null) {
config.setBatchSize(this.round);
}
if (this.imageSize != null) {
config.setImageSize(this.imageSize);
}
// 设置数据集比例YOLO中需要转换
// 注意YOLO通常通过数据集配置文件来设置train/val/test比例
// 这里我们可以将比例信息保存到配置中,供后续处理使用
// 设置预训练模型路径
if (this.type.equals("1")) {
config.setModelPath(dictDataMap.get("detect_path").getValue());
config.setPretrained(true);
}else {
config.setModelPath(dictDataMap.get("classify_path").getValue());
config.setPretrained(false);
}
// 设置GPU/CPU设备
if (this.round != null) {
config.setEpochs(this.round);
}
if (this.size != null) {
config.setBatchSize(this.size);
}
if (this.imageSize != null) {
config.setImageSize(this.imageSize);
}
// 设置数据集比例YOLO中需要转换
// 注意YOLO通常通过数据集配置文件来设置train/val/test比例
// 这里我们可以将比例信息保存到配置中,供后续处理使用
// 设置预训练模型路径
if (this.type.equals("1")) {
config.setModelPath(dictDataMap.get("detect_path").getValue());
config.setPretrained(true);
} else {
config.setModelPath(dictDataMap.get("classify_path").getValue());
config.setPretrained(false);
}
// 设置GPU/CPU设备
// 设置输出路径默认在数据集路径下的runs/train目录
if (this.path != null) {
config.setOutputPath(this.path + "/runs/train");
}
// 设置输出路径默认在数据集路径下的runs/train目录
if (this.path != null) {
config.setOutputPath(this.path + "/runs/train");
// 设置其他默认参数
config.setTaskType(this.type.equals("1") ? "detect" : "classify");
config.setModelName(this.modelPath);
config.setModelPath(this.modelPath);
config.setLearningRate(0.01);
config.setConfThresh(0.25);
config.setIouThresh(0.45);
config.setSave(true);
config.setSaveTxt(true);
return config;
}
// 设置其他默认参数
config.setTaskType(this.type.equals("1") ? "detect" : "classify");
config.setModelName(this.modelPath);
config.setModelPath(this.modelPath);
config.setLearningRate(0.01);
config.setConfThresh(0.25);
config.setIouThresh(0.45);
config.setSave(true);
config.setSaveTxt(true);
return config;
}
}

@ -205,8 +205,8 @@ public class YoloConfig {
/**
* mAP, precision, recall
*/
private String metrics;
/**
* ID
*/

@ -316,6 +316,7 @@ public class PythonVirtualEnvServiceImpl implements PythonVirtualEnvService {
}
// 构建虚拟环境路径Python目录/envName
String venvPath = pythonDir + File.separator +"venv"+File.separator + envName;
File venvDir = new File(venvPath);

@ -7,6 +7,7 @@ import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.TrainSaveReq
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import com.baomidou.mybatisplus.extension.service.IService;
import jakarta.validation.Valid;
import org.springframework.web.multipart.MultipartFile;
import java.util.List;
@ -65,6 +66,7 @@ public interface TrainService extends IService<TrainDO> {
public boolean init(TrainSaveReqVO createReqVO);
List<String> testImages();
/**

@ -9,16 +9,20 @@ import cn.iocoder.yudao.module.annotation.dal.dataobject.datas.DatasDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkInfoDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO;
import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper;
import cn.iocoder.yudao.module.annotation.dal.mysql.trainresult.TrainResultMapper;
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO;
import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper;
import cn.iocoder.yudao.module.annotation.service.MarkInfo.AnnotationData;
import cn.iocoder.yudao.module.annotation.service.MarkInfo.MarkInfoService;
import cn.iocoder.yudao.module.annotation.service.datas.DatasService;
import cn.iocoder.yudao.module.annotation.service.mark.MarkService;
import cn.iocoder.yudao.module.annotation.service.types.TypesService;
import cn.iocoder.yudao.module.annotation.service.yolo.YoloOperationService;
import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO;
import cn.iocoder.yudao.module.system.service.dict.DictDataService;
@ -29,6 +33,7 @@ import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO;
@ -90,6 +95,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
// 返回
return train.getId();
}
@Resource
YoloOperationService yoloOperationService;
public boolean init(TrainSaveReqVO createReqVO) {
@ -127,7 +133,9 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
// 生成YOLO配置文件
generateYamlConfig(typesList, taskType, yoloDatasetPath);
update( new UpdateWrapper<TrainDO>().eq("id", createReqVO.getId()).set("train_type", 6));
yoloOperationService.generateTrainPythonScript(train, yoloDatasetPath);
log.info("YOLO数据集生成完成: {}", yoloDatasetPath);
return true;
@ -137,6 +145,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
}
}
@Override
public List<String> testImages() {
Map<String, DictDataDO> dictDataMap = dictDataService.getDictDataList("visual_annotation_conf");
@ -337,6 +346,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
*/
private String createYoloDatasetPath(TrainDO train, String customPath) {
String basePath = customPath != null ? customPath : System.getProperty("user.dir") + "/yolo_datasets";
String datasetPath = basePath + "/datas_" + train.getId();
File datasetDir = new File(datasetPath);
@ -433,7 +443,8 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
File labelFile = new File(labelDir, labelFileName);
List<MarkInfoDO> markInfoList = markInfoService.list(new QueryWrapper<MarkInfoDO>()
.eq("mark_id", mark.getId()));
.eq("data_id", mark.getDataId()).eq("mark_id", mark.getId()));
try (PrintWriter writer = new PrintWriter(new FileWriter(labelFile))) {
for (MarkInfoDO markInfo : markInfoList) {
@ -776,6 +787,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
return gpuList;
}
public static void main(String[] args) {
File dir = new File("D:/data");
File[] files = dir.listFiles();
@ -786,6 +798,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
}
}
System.out.println(fileNames);
}
private SystemInfoRespVO.PythonInfo getPythonInfo() {
SystemInfoRespVO.PythonInfo pythonInfo = new SystemInfoRespVO.PythonInfo();

@ -1,11 +1,13 @@
package cn.iocoder.yudao.module.annotation.service.yolo;
import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO;
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService;
import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo;
import cn.iocoder.yudao.module.annotation.service.train.TrainService;
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService;
@ -15,6 +17,10 @@ import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.io.BufferedReader;
@ -42,6 +48,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
@Resource
private TrainResultService trainResultService;
@Resource
private YoloServiceClient yoloServiceClient;
@ -51,17 +58,16 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 异步执行器
private final Executor asyncExecutor = Executors.newCachedThreadPool();
@Override
public CompletableFuture<YoloConfig> trainAsync(String pythonPath, String envName, YoloConfig config) {
return CompletableFuture.supplyAsync(() -> {
config.setStatus("running");
config.setProgress(0);
config.setTaskId(UUID.randomUUID().toString());
// 生成训练ID
Integer trainId = config.getTrainId();
// 记录上一轮的epoch用于检测epoch完成
int[] lastSavedEpoch = {0};
@ -76,16 +82,13 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 构建训练命令
String command = buildTrainCommand(envInfo, config);
log.info("开始YOLO训练命令: {}", command);
// 执行训练
ProcessBuilder processBuilder = new ProcessBuilder();
// 在Windows上使用cmd在Linux/Mac上使用bash
String os = System.getProperty("os.name").toLowerCase();
if (os.contains("win")) {
processBuilder.command("cmd", "/c", command);
// 设置环境变量确保正确处理中文路径
processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
processBuilder.environment().put("LANG", "zh_CN.UTF-8");
} else {
@ -93,36 +96,35 @@ public class YoloOperationServiceImpl implements YoloOperationService {
processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
}
Process process = processBuilder.start();
BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream()));
BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
// BufferedReader stderrReader2 = new BufferedReader(new InputStreamReader(process.getOutputStream()));
StringBuilder fullOutput = new StringBuilder();
// 使用单独的线程读取stdout和stderr避免死锁
String stdoutLine, stderrLine = null;
while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null ) {
String stdoutLine;
String stderrLine = null;
// 同时读取标准输出和错误输出
while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null) {
if (stdoutLine != null) {
log.info("YOLO训练日志: {}", stdoutLine);
fullOutput.append(stdoutLine).append("\n");
// 详细解析YOLO训练进度和结果
manageTrainingLogInRedis(trainId, stdoutLine);
config.setLogMessage(stdoutLine);
// 检测 "Results saved to" 并解析路径和识别率
detectAndParseResultsSaved(stdoutLine, config, trainId);
parseTrainingProgress(stdoutLine, config, trainId, lastSavedEpoch);
}
if (stderrLine != null) {
log.info("YOLO训练日志: {}", stderrLine);
fullOutput.append(stderrLine).append("\n");
manageTrainingLogInRedis(trainId, stderrLine);
config.setLogMessage(stderrLine);
// 检测 "Results saved to" 并解析路径和识别率
detectAndParseResultsSaved(stderrLine, config, trainId);
parseTrainingProgress(stderrLine, config, trainId, lastSavedEpoch);
}
}
@ -148,12 +150,14 @@ public class YoloOperationServiceImpl implements YoloOperationService {
config.setStatus("failed");
config.setErrorMessage("训练异常: " + e.getMessage());
}
return config;
}, asyncExecutor);
}
@Override
public YoloConfig validate(String pythonPath, String envName, YoloConfig config) {
// 优先尝试使用socket服务如果启用且可用
if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) {
log.info("使用socket服务执行YOLO验证");
@ -233,6 +237,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
@Override
public YoloConfig predict(String pythonPath, String envName, YoloConfig config) {
// 优先尝试使用socket服务如果启用且可用
if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) {
log.info("使用socket服务执行YOLO预测");
@ -281,6 +286,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
if (stdoutLine != null) {
output.append(stdoutLine).append("\n");
log.info("YOLO预测日志: {}", stdoutLine);
if (stdoutLine.contains("Results saved to")) {
stdoutLine = stdoutLine.replace("\u001B","").replace("[1m","").replace("[0m","");
@ -291,6 +297,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
output.append(stderrLine).append("\n");
log.info("YOLO预测日志: {}", stderrLine);
if (stderrLine.contains("Results saved to")) {
stderrLine = stderrLine.replace("\u001B","").replace("[1m","").replace("[0m","");
config.setOutputPath(stderrLine.split(" ")[3]);
@ -445,13 +452,16 @@ public class YoloOperationServiceImpl implements YoloOperationService {
if ("conda".equals(envInfo.getVirtualEnvType())) {
command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" ");
} else {
command.append(normalizePath(envInfo.getVirtualEnvPythonPath())).append(" ");
command.append(envInfo.getVirtualEnvPythonPath()).append(" ");
}
// 规范化路径,处理双斜杠和中文路径问题
String datasetPath = normalizePath(config.getDatasetPath());
String outputPath = normalizePath(config.getOutputPath());
String modelPath = config.getModelPath() != null ? normalizePath(config.getModelPath()) : null;
// 调试日志:检查模型路径
log.info("模型路径配置 - modelPath: {}, modelName: {}", config.getModelPath(), config.getModelName());
@ -485,6 +495,36 @@ public class YoloOperationServiceImpl implements YoloOperationService {
//
// command.append("project=r'").append(escapePythonString(outputPath)).append("')\"");
//
command.append("-c \"");
command.append("import sys; sys.path.insert(0, '.'); ");
command.append("from ultralytics import YOLO; ");
if (modelPath != null && !modelPath.trim().isEmpty()) {
command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); ");
} else {
command.append("model = YOLO('").append(config.getModelName()).append("'); ");
}
// 使用raw string来处理中文路径
command.append("model.train(data=r'").append(escapePythonString(datasetPath)).append("', ");
command.append("epochs=").append(config.getEpochs()).append(", ");
command.append("batch=").append(config.getBatchSize()).append(", ");
command.append("imgsz=").append(config.getImageSize()).append(", ");
// 取消注释设备和学习率配置
// command.append("device=").append(config.getDevice()).append(", ");
// command.append("lr0=").append(config.getLearningRate()).append(", ");
// command.append("workers=").append(config.getWorkers()).append(", ");
//
if (config.getSavePeriod() > 0) {
command.append("save_period=").append(config.getSavePeriod()).append(", ");
}
command.append("project=r'").append(escapePythonString(outputPath)).append("')\"");
return command.toString();
}
@ -515,15 +555,17 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return command.toString();
}
private String buildPredictCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) {
StringBuilder command = new StringBuilder();
if ("conda".equals(envInfo.getVirtualEnvType())) {
command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" ");
} else {
command.append(envInfo.getVirtualEnvPythonPath()).append(" ");
}
String inputPath = normalizePath(config.getInputPath());
String outputPath = normalizePath(config.getOutputPath());
String modelPath = normalizePath(config.getModelPath());
@ -533,6 +575,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
command.append("from ultralytics import YOLO; ");
command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); ");
command.append("results = model.predict(source=r'").append(escapePythonString(inputPath)).append("', ");
// 在这里我们将 save=true 改为 save=True
command.append("save=").append(config.isSave() ? "True" : "False").append(", ");
command.append("project='").append(escapePythonString(outputPath)).append("')\"");
@ -541,6 +584,8 @@ public class YoloOperationServiceImpl implements YoloOperationService {
}
private String buildExportCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) {
StringBuilder command = new StringBuilder();
@ -582,6 +627,106 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return command.toString();
}
/**
*
*/
private void parseTrainingProgress(String line, YoloConfig config, Integer trainId, int[] lastSavedEpoch) {
try {
String trimmedLine = line.trim();
// 解析Epoch信息: "Epoch 1/100" 或 "1/100"
if (trimmedLine.contains("Epoch") || (trimmedLine.matches("\\d+/\\d+"))) {
String[] parts = trimmedLine.split("\\s+");
for (String part : parts) {
if (part.contains("/")) {
String[] epochParts = part.split("/");
if (epochParts.length == 2) {
int currentEpoch = Integer.parseInt(epochParts[0]);
int totalEpochs = Integer.parseInt(epochParts[1]);
config.setCurrentEpoch(currentEpoch);
// 计算总体进度epoch进度 + batch进度
int epochProgress = (int) ((double) currentEpoch / totalEpochs * 90); // 90%用于epoch10%用于最终处理
int totalProgress = Math.min(epochProgress + (config.getBatchProgress() / 10), 95);
config.setProgress(totalProgress);
// 当epoch完成时保存训练信息
if (currentEpoch > lastSavedEpoch[0] && currentEpoch <= config.getEpochs()) {
saveTrainEpochInfo(trainId, currentEpoch, config.getEpochs(), config);
lastSavedEpoch[0] = currentEpoch;
}
}
}
}
}
// 解析批次信息: "1/100" 或 "batch 1/100"
if (trimmedLine.contains("batch") || trimmedLine.matches("\\d+/\\d+.*loss")) {
if (trimmedLine.contains("batch")) {
String[] parts = trimmedLine.split("\\s+");
for (int i = 0; i < parts.length; i++) {
if (parts[i].equals("batch") && i + 1 < parts.length && parts[i + 1].contains("/")) {
String[] batchParts = parts[i + 1].split("/");
if (batchParts.length == 2) {
int currentBatch = Integer.parseInt(batchParts[0]);
int totalBatches = Integer.parseInt(batchParts[1]);
config.setCurrentBatch(currentBatch);
config.setTotalBatches(totalBatches);
config.setBatchProgress((int) ((double) currentBatch / totalBatches * 100));
}
}
}
}
}
// 解析损失值: "loss: 0.1234", "box_loss: 0.1", "cls_loss: 0.05", "obj_loss: 0.02"
if (trimmedLine.contains("loss")) {
String[] lossParts = trimmedLine.split(",");
for (String lossPart : lossParts) {
lossPart = lossPart.trim();
if (lossPart.startsWith("loss:")) {
String value = lossPart.substring(5).trim();
config.setLoss(Double.parseDouble(value));
} else if (lossPart.startsWith("box_loss:")) {
// 可以添加更详细的损失记录
} else if (lossPart.startsWith("cls_loss:")) {
// 可以添加更详细的损失记录
} else if (lossPart.startsWith("obj_loss:")) {
// 可以添加更详细的损失记录
}
}
}
// 解析训练指标: "metrics/precision(B): 0.85", "metrics/recall(B): 0.78", "metrics/mAP50(B): 0.82", "metrics/mAP50-95(B): 0.65"
if (trimmedLine.contains("metrics/")) {
String[] metricParts = trimmedLine.split(",");
for (String metricPart : metricParts) {
metricPart = metricPart.trim();
if (metricPart.startsWith("metrics/precision")) {
String value = metricPart.split(":")[1].trim();
config.setPrecision(Double.parseDouble(value));
} else if (metricPart.startsWith("metrics/recall")) {
String value = metricPart.split(":")[1].trim();
config.setRecall(Double.parseDouble(value));
} else if (metricPart.startsWith("metrics/mAP50-95") || metricPart.startsWith("metrics/mAP")) {
String value = metricPart.split(":")[1].trim();
config.setMap(Double.parseDouble(value));
}
}
}
// 解析学习率: "lr: 0.001"
if (trimmedLine.contains("lr:")) {
// 可以记录学习率变化
}
} catch (Exception e) {
log.debug("解析训练进度时出错: {}", line, e);
// 忽略解析错误,继续处理下一行
}
}
/**
*
@ -771,9 +916,12 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 不影响训练完成状态
}
}
@Resource
@Lazy
TrainService trainService;
/**
* TrainResultDO
@ -781,8 +929,11 @@ public class YoloOperationServiceImpl implements YoloOperationService {
private void saveTrainResultDO(TrainResultDO trainResult) {
try {
// 创建SaveReqVO
TrainResultSaveReqVO saveReqVO =
new TrainResultSaveReqVO();
// 设置值
saveReqVO.setTrainId(trainResult.getTrainId());
@ -792,8 +943,11 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 保存
trainResultService.createTrainResult(saveReqVO);
trainService.update(new UpdateWrapper<TrainDO>().eq("id", trainResult.getTrainId())
.set("train_type", 4));
} catch (Exception e) {
log.error("保存训练结果失败: {}", trainResult, e);
@ -915,6 +1069,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return escaped.toString();
}
@Resource
private RedisUtil redisUtil;
/**
@ -1595,4 +1750,6 @@ public class YoloOperationServiceImpl implements YoloOperationService {
}
}

@ -48,6 +48,7 @@ spring:
master:
name: sy
url: jdbc:mysql://192.168.1.21:3306/${spring.datasource.dynamic.datasource.master.name}?useSSL=false&serverTimezone=Asia/Shanghai&allowPublicKeyRetrieval=true&nullCatalogMeansCurrent=true # MySQL Connector/J 8.X 连接的示例
username: root
password: upright
# Redis 配置。Redisson 默认的配置足够使用,一般不需要进行调优

Loading…
Cancel
Save