# Conflicts:
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/train/TrainController.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/traininfo/TrainInfoController.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/controller/admin/yolo/YoloOperationController.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/dataobject/train/TrainDO.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/dal/yolo/YoloConfig.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/python/PythonVirtualEnvServiceImpl.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainService.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/train/TrainServiceImpl.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationService.java
#	yudao-module-annotation/yudao-module-annotation-biz/src/main/java/cn/iocoder/yudao/module/annotation/service/yolo/YoloOperationServiceImpl.java
#	yudao-server/src/main/resources/application-dev.yaml
master
LAPTOP-S9HJSOEB\昊天 3 months ago
commit 4363ea079a

@ -81,7 +81,7 @@ public class TrainController {
@Operation(summary = "进行训练") @Operation(summary = "进行训练")
// @Parameter(name = "id", description = "编号", required = true, example = "1024") // @Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('annotation:train:query')") @PreAuthorize("@ss.hasPermission('annotation:train:query')")
public CommonResult<String> testRecognition(@RequestBody RecognitionResult recognitionReqVO) { public CommonResult<String> testRecognition(@RequestBody RecognitionResult recognitionReqVO) {
String train = trainService.testRecognition(recognitionReqVO.getImageUrl(),recognitionReqVO.getTrainId()); String train = trainService.testRecognition(recognitionReqVO.getImageUrl(),recognitionReqVO.getTrainId());
return success( train); return success( train);

@ -69,6 +69,7 @@ public class TrainInfoController {
TrainInfoDO trainInfo = trainInfoService.getTrainInfo(id); TrainInfoDO trainInfo = trainInfoService.getTrainInfo(id);
return success(BeanUtils.toBean(trainInfo, TrainInfoRespVO.class)); return success(BeanUtils.toBean(trainInfo, TrainInfoRespVO.class));
} }
@Resource @Resource
RedisUtil redisUtil; RedisUtil redisUtil;
@ -76,6 +77,7 @@ public class TrainInfoController {
@Operation(summary = "获得识别结果") @Operation(summary = "获得识别结果")
@Parameter(name = "id", description = "编号", required = true, example = "1024") @Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('annotation:train-Info:query')") @PreAuthorize("@ss.hasPermission('annotation:train-Info:query')")
public CommonResult<List<String>> getTrainInfoList(@RequestParam("trainId") Integer id) { public CommonResult<List<String>> getTrainInfoList(@RequestParam("trainId") Integer id) {
List<Object> result = redisUtil.lGet("yolo:training_log:" + id, 0, -1); List<Object> result = redisUtil.lGet("yolo:training_log:" + id, 0, -1);
List<String> result1 = result.stream().map(Object::toString).toList(); List<String> result1 = result.stream().map(Object::toString).toList();

@ -87,6 +87,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id); config.setTrainId(id);
config.setTrainId(id); config.setTrainId(id);
@ -105,6 +106,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id); config.setTrainId(id);
YoloConfig result = yoloOperationService.validate( YoloConfig result = yoloOperationService.validate(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
@ -120,6 +122,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id); config.setTrainId(id);
YoloConfig result = yoloOperationService.predict( YoloConfig result = yoloOperationService.predict(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
@ -135,6 +138,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id); config.setTrainId(id);
YoloConfig result = yoloOperationService.classify( YoloConfig result = yoloOperationService.classify(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
@ -150,6 +154,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id); config.setTrainId(id);
YoloConfig result = yoloOperationService.export( YoloConfig result = yoloOperationService.export(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);
@ -165,6 +170,7 @@ public class YoloOperationController {
TrainDO trainDO = trainService.getTrain(id); TrainDO trainDO = trainService.getTrain(id);
// trainDO.setPath(dictDataMap.get("training_address").getValue()); // trainDO.setPath(dictDataMap.get("training_address").getValue());
YoloConfig config = trainDO.getYolofig(dictDataMap); YoloConfig config = trainDO.getYolofig(dictDataMap);
config.setTrainId(id); config.setTrainId(id);
YoloConfig result = yoloOperationService.track( YoloConfig result = yoloOperationService.track(
dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config); dictDataMap.get("python_path").getValue(), dictDataMap.get("python_venv").getValue(), config);

@ -75,54 +75,63 @@ public class TrainDO extends BaseDO {
private String type; private String type;
private String outPath; private String outPath;
public YoloConfig getYolofig(Map<String, DictDataDO> dictDataMap) { public YoloConfig getYolofig(Map<String, DictDataDO> dictDataMap) {
YoloConfig config = new YoloConfig(); YoloConfig config = new YoloConfig();
// 设置数据相关配置 // 设置数据相关配置
config.setDatasetPath(this.path); config.setDatasetPath(this.path);
// 设置训练参数 // 设置训练参数
if (this.size != null) { if (this.size != null) {
config.setEpochs(this.size); config.setEpochs(this.size);
} }
if (this.round != null) { if (this.round != null) {
config.setBatchSize(this.round); config.setBatchSize(this.round);
} if (this.round != null) {
if (this.imageSize != null) { config.setEpochs(this.round);
config.setImageSize(this.imageSize); }
} if (this.size != null) {
config.setBatchSize(this.size);
// 设置数据集比例YOLO中需要转换 }
// 注意YOLO通常通过数据集配置文件来设置train/val/test比例 if (this.imageSize != null) {
// 这里我们可以将比例信息保存到配置中,供后续处理使用 config.setImageSize(this.imageSize);
}
// 设置预训练模型路径
if (this.type.equals("1")) { // 设置数据集比例YOLO中需要转换
config.setModelPath(dictDataMap.get("detect_path").getValue()); // 注意YOLO通常通过数据集配置文件来设置train/val/test比例
config.setPretrained(true); // 这里我们可以将比例信息保存到配置中,供后续处理使用
}else {
config.setModelPath(dictDataMap.get("classify_path").getValue()); // 设置预训练模型路径
config.setPretrained(false); if (this.type.equals("1")) {
} config.setModelPath(dictDataMap.get("detect_path").getValue());
config.setPretrained(true);
// 设置GPU/CPU设备 } else {
config.setModelPath(dictDataMap.get("classify_path").getValue());
config.setPretrained(false);
}
// 设置GPU/CPU设备
// 设置输出路径默认在数据集路径下的runs/train目录
if (this.path != null) {
config.setOutputPath(this.path + "/runs/train");
}
// 设置其他默认参数
// 设置输出路径默认在数据集路径下的runs/train目录 config.setTaskType(this.type.equals("1") ? "detect" : "classify");
if (this.path != null) { config.setModelName(this.modelPath);
config.setOutputPath(this.path + "/runs/train"); config.setModelPath(this.modelPath);
config.setLearningRate(0.01);
config.setConfThresh(0.25);
config.setIouThresh(0.45);
config.setSave(true);
config.setSaveTxt(true);
return config;
} }
// 设置其他默认参数
config.setTaskType(this.type.equals("1") ? "detect" : "classify");
config.setModelName(this.modelPath);
config.setModelPath(this.modelPath);
config.setLearningRate(0.01);
config.setConfThresh(0.25);
config.setIouThresh(0.45);
config.setSave(true);
config.setSaveTxt(true);
return config; return config;
} }
} }

@ -205,8 +205,8 @@ public class YoloConfig {
/** /**
* mAP, precision, recall * mAP, precision, recall
*/ */
private String metrics; private String metrics;
/** /**
* ID * ID
*/ */

@ -316,6 +316,7 @@ public class PythonVirtualEnvServiceImpl implements PythonVirtualEnvService {
} }
// 构建虚拟环境路径Python目录/envName // 构建虚拟环境路径Python目录/envName
String venvPath = pythonDir + File.separator +"venv"+File.separator + envName; String venvPath = pythonDir + File.separator +"venv"+File.separator + envName;
File venvDir = new File(venvPath); File venvDir = new File(venvPath);

@ -7,6 +7,7 @@ import cn.iocoder.yudao.module.annotation.controller.admin.train.vo.TrainSaveReq
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import com.baomidou.mybatisplus.extension.service.IService; import com.baomidou.mybatisplus.extension.service.IService;
import jakarta.validation.Valid; import jakarta.validation.Valid;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import java.util.List; import java.util.List;
@ -65,6 +66,7 @@ public interface TrainService extends IService<TrainDO> {
public boolean init(TrainSaveReqVO createReqVO); public boolean init(TrainSaveReqVO createReqVO);
List<String> testImages(); List<String> testImages();
/** /**

@ -9,16 +9,20 @@ import cn.iocoder.yudao.module.annotation.dal.dataobject.datas.DatasDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkInfoDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.mark.MarkInfoDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO;
import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper; import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper;
import cn.iocoder.yudao.module.annotation.dal.mysql.trainresult.TrainResultMapper; import cn.iocoder.yudao.module.annotation.dal.mysql.trainresult.TrainResultMapper;
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
import cn.iocoder.yudao.module.annotation.dal.dataobject.types.TypesDO;
import cn.iocoder.yudao.module.annotation.dal.mysql.train.TrainMapper;
import cn.iocoder.yudao.module.annotation.service.MarkInfo.AnnotationData; import cn.iocoder.yudao.module.annotation.service.MarkInfo.AnnotationData;
import cn.iocoder.yudao.module.annotation.service.MarkInfo.MarkInfoService; import cn.iocoder.yudao.module.annotation.service.MarkInfo.MarkInfoService;
import cn.iocoder.yudao.module.annotation.service.datas.DatasService; import cn.iocoder.yudao.module.annotation.service.datas.DatasService;
import cn.iocoder.yudao.module.annotation.service.mark.MarkService; import cn.iocoder.yudao.module.annotation.service.mark.MarkService;
import cn.iocoder.yudao.module.annotation.service.types.TypesService; import cn.iocoder.yudao.module.annotation.service.types.TypesService;
import cn.iocoder.yudao.module.annotation.service.yolo.YoloOperationService; import cn.iocoder.yudao.module.annotation.service.yolo.YoloOperationService;
import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO; import cn.iocoder.yudao.module.system.dal.dataobject.dict.DictDataDO;
import cn.iocoder.yudao.module.system.service.dict.DictDataService; import cn.iocoder.yudao.module.system.service.dict.DictDataService;
@ -29,6 +33,7 @@ import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated; import org.springframework.validation.annotation.Validated;
import org.springframework.web.multipart.MultipartFile; import org.springframework.web.multipart.MultipartFile;
import javax.imageio.ImageIO; import javax.imageio.ImageIO;
@ -90,6 +95,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
// 返回 // 返回
return train.getId(); return train.getId();
} }
@Resource @Resource
YoloOperationService yoloOperationService; YoloOperationService yoloOperationService;
public boolean init(TrainSaveReqVO createReqVO) { public boolean init(TrainSaveReqVO createReqVO) {
@ -127,7 +133,9 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
// 生成YOLO配置文件 // 生成YOLO配置文件
generateYamlConfig(typesList, taskType, yoloDatasetPath); generateYamlConfig(typesList, taskType, yoloDatasetPath);
update( new UpdateWrapper<TrainDO>().eq("id", createReqVO.getId()).set("train_type", 6)); update( new UpdateWrapper<TrainDO>().eq("id", createReqVO.getId()).set("train_type", 6));
yoloOperationService.generateTrainPythonScript(train, yoloDatasetPath); yoloOperationService.generateTrainPythonScript(train, yoloDatasetPath);
log.info("YOLO数据集生成完成: {}", yoloDatasetPath); log.info("YOLO数据集生成完成: {}", yoloDatasetPath);
return true; return true;
@ -137,6 +145,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
} }
} }
@Override @Override
public List<String> testImages() { public List<String> testImages() {
Map<String, DictDataDO> dictDataMap = dictDataService.getDictDataList("visual_annotation_conf"); Map<String, DictDataDO> dictDataMap = dictDataService.getDictDataList("visual_annotation_conf");
@ -337,6 +346,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
*/ */
private String createYoloDatasetPath(TrainDO train, String customPath) { private String createYoloDatasetPath(TrainDO train, String customPath) {
String basePath = customPath != null ? customPath : System.getProperty("user.dir") + "/yolo_datasets"; String basePath = customPath != null ? customPath : System.getProperty("user.dir") + "/yolo_datasets";
String datasetPath = basePath + "/datas_" + train.getId(); String datasetPath = basePath + "/datas_" + train.getId();
File datasetDir = new File(datasetPath); File datasetDir = new File(datasetPath);
@ -433,7 +443,8 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
File labelFile = new File(labelDir, labelFileName); File labelFile = new File(labelDir, labelFileName);
List<MarkInfoDO> markInfoList = markInfoService.list(new QueryWrapper<MarkInfoDO>() List<MarkInfoDO> markInfoList = markInfoService.list(new QueryWrapper<MarkInfoDO>()
.eq("mark_id", mark.getId()));
.eq("data_id", mark.getDataId()).eq("mark_id", mark.getId()));
try (PrintWriter writer = new PrintWriter(new FileWriter(labelFile))) { try (PrintWriter writer = new PrintWriter(new FileWriter(labelFile))) {
for (MarkInfoDO markInfo : markInfoList) { for (MarkInfoDO markInfo : markInfoList) {
@ -776,6 +787,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
return gpuList; return gpuList;
} }
public static void main(String[] args) { public static void main(String[] args) {
File dir = new File("D:/data"); File dir = new File("D:/data");
File[] files = dir.listFiles(); File[] files = dir.listFiles();
@ -786,6 +798,7 @@ public class TrainServiceImpl extends ServiceImpl<TrainMapper, TrainDO> implemen
} }
} }
System.out.println(fileNames); System.out.println(fileNames);
} }
private SystemInfoRespVO.PythonInfo getPythonInfo() { private SystemInfoRespVO.PythonInfo getPythonInfo() {
SystemInfoRespVO.PythonInfo pythonInfo = new SystemInfoRespVO.PythonInfo(); SystemInfoRespVO.PythonInfo pythonInfo = new SystemInfoRespVO.PythonInfo();

@ -1,11 +1,13 @@
package cn.iocoder.yudao.module.annotation.service.yolo; package cn.iocoder.yudao.module.annotation.service.yolo;
import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO; import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO;
import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO; import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO;
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig; import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig;
import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService; import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService;
import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo; import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo;
import cn.iocoder.yudao.module.annotation.service.train.TrainService; import cn.iocoder.yudao.module.annotation.service.train.TrainService;
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService; import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService; import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService;
@ -15,6 +17,10 @@ import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy; import org.springframework.context.annotation.Lazy;
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService;
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.io.BufferedReader; import java.io.BufferedReader;
@ -42,6 +48,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
@Resource @Resource
private TrainResultService trainResultService; private TrainResultService trainResultService;
@Resource @Resource
private YoloServiceClient yoloServiceClient; private YoloServiceClient yoloServiceClient;
@ -51,17 +58,16 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 异步执行器 // 异步执行器
private final Executor asyncExecutor = Executors.newCachedThreadPool(); private final Executor asyncExecutor = Executors.newCachedThreadPool();
@Override @Override
public CompletableFuture<YoloConfig> trainAsync(String pythonPath, String envName, YoloConfig config) { public CompletableFuture<YoloConfig> trainAsync(String pythonPath, String envName, YoloConfig config) {
return CompletableFuture.supplyAsync(() -> { return CompletableFuture.supplyAsync(() -> {
config.setStatus("running"); config.setStatus("running");
config.setProgress(0); config.setProgress(0);
config.setTaskId(UUID.randomUUID().toString()); config.setTaskId(UUID.randomUUID().toString());
// 生成训练ID // 生成训练ID
Integer trainId = config.getTrainId(); Integer trainId = config.getTrainId();
// 记录上一轮的epoch用于检测epoch完成 // 记录上一轮的epoch用于检测epoch完成
int[] lastSavedEpoch = {0}; int[] lastSavedEpoch = {0};
@ -76,16 +82,13 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 构建训练命令 // 构建训练命令
String command = buildTrainCommand(envInfo, config); String command = buildTrainCommand(envInfo, config);
log.info("开始YOLO训练命令: {}", command); log.info("开始YOLO训练命令: {}", command);
// 执行训练 // 执行训练
ProcessBuilder processBuilder = new ProcessBuilder(); ProcessBuilder processBuilder = new ProcessBuilder();
// 在Windows上使用cmd在Linux/Mac上使用bash
String os = System.getProperty("os.name").toLowerCase(); String os = System.getProperty("os.name").toLowerCase();
if (os.contains("win")) { if (os.contains("win")) {
processBuilder.command("cmd", "/c", command); processBuilder.command("cmd", "/c", command);
// 设置环境变量确保正确处理中文路径
processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
processBuilder.environment().put("LANG", "zh_CN.UTF-8"); processBuilder.environment().put("LANG", "zh_CN.UTF-8");
} else { } else {
@ -93,36 +96,35 @@ public class YoloOperationServiceImpl implements YoloOperationService {
processBuilder.environment().put("PYTHONIOENCODING", "utf-8"); processBuilder.environment().put("PYTHONIOENCODING", "utf-8");
} }
Process process = processBuilder.start(); Process process = processBuilder.start();
BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream())); BufferedReader stdoutReader = new BufferedReader(new InputStreamReader(process.getInputStream()));
BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream())); BufferedReader stderrReader = new BufferedReader(new InputStreamReader(process.getErrorStream()));
// BufferedReader stderrReader2 = new BufferedReader(new InputStreamReader(process.getOutputStream()));
StringBuilder fullOutput = new StringBuilder(); StringBuilder fullOutput = new StringBuilder();
String stdoutLine;
// 使用单独的线程读取stdout和stderr避免死锁 String stderrLine = null;
String stdoutLine, stderrLine = null;
while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null ) { // 同时读取标准输出和错误输出
while ((stdoutLine = stdoutReader.readLine()) != null || (stderrLine = stderrReader.readLine()) != null) {
if (stdoutLine != null) { if (stdoutLine != null) {
log.info("YOLO训练日志: {}", stdoutLine); log.info("YOLO训练日志: {}", stdoutLine);
fullOutput.append(stdoutLine).append("\n"); fullOutput.append(stdoutLine).append("\n");
// 详细解析YOLO训练进度和结果 // 详细解析YOLO训练进度和结果
manageTrainingLogInRedis(trainId, stdoutLine); manageTrainingLogInRedis(trainId, stdoutLine);
config.setLogMessage(stdoutLine); config.setLogMessage(stdoutLine);
// 检测 "Results saved to" 并解析路径和识别率
detectAndParseResultsSaved(stdoutLine, config, trainId); detectAndParseResultsSaved(stdoutLine, config, trainId);
parseTrainingProgress(stdoutLine, config, trainId, lastSavedEpoch);
} }
if (stderrLine != null) { if (stderrLine != null) {
log.info("YOLO训练日志: {}", stderrLine); log.info("YOLO训练日志: {}", stderrLine);
fullOutput.append(stderrLine).append("\n"); fullOutput.append(stderrLine).append("\n");
manageTrainingLogInRedis(trainId, stderrLine); manageTrainingLogInRedis(trainId, stderrLine);
config.setLogMessage(stderrLine); config.setLogMessage(stderrLine);
// 检测 "Results saved to" 并解析路径和识别率
detectAndParseResultsSaved(stderrLine, config, trainId); detectAndParseResultsSaved(stderrLine, config, trainId);
parseTrainingProgress(stderrLine, config, trainId, lastSavedEpoch);
} }
} }
@ -148,12 +150,14 @@ public class YoloOperationServiceImpl implements YoloOperationService {
config.setStatus("failed"); config.setStatus("failed");
config.setErrorMessage("训练异常: " + e.getMessage()); config.setErrorMessage("训练异常: " + e.getMessage());
} }
return config; return config;
}, asyncExecutor); }, asyncExecutor);
} }
@Override @Override
public YoloConfig validate(String pythonPath, String envName, YoloConfig config) { public YoloConfig validate(String pythonPath, String envName, YoloConfig config) {
// 优先尝试使用socket服务如果启用且可用 // 优先尝试使用socket服务如果启用且可用
if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) { if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) {
log.info("使用socket服务执行YOLO验证"); log.info("使用socket服务执行YOLO验证");
@ -233,6 +237,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
@Override @Override
public YoloConfig predict(String pythonPath, String envName, YoloConfig config) { public YoloConfig predict(String pythonPath, String envName, YoloConfig config) {
// 优先尝试使用socket服务如果启用且可用 // 优先尝试使用socket服务如果启用且可用
if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) { if (useSocketService && yoloServiceClient != null && yoloServiceClient.isServiceAvailable()) {
log.info("使用socket服务执行YOLO预测"); log.info("使用socket服务执行YOLO预测");
@ -281,6 +286,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
if (stdoutLine != null) { if (stdoutLine != null) {
output.append(stdoutLine).append("\n"); output.append(stdoutLine).append("\n");
log.info("YOLO预测日志: {}", stdoutLine); log.info("YOLO预测日志: {}", stdoutLine);
if (stdoutLine.contains("Results saved to")) { if (stdoutLine.contains("Results saved to")) {
stdoutLine = stdoutLine.replace("\u001B","").replace("[1m","").replace("[0m",""); stdoutLine = stdoutLine.replace("\u001B","").replace("[1m","").replace("[0m","");
@ -291,6 +297,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
output.append(stderrLine).append("\n"); output.append(stderrLine).append("\n");
log.info("YOLO预测日志: {}", stderrLine); log.info("YOLO预测日志: {}", stderrLine);
if (stderrLine.contains("Results saved to")) { if (stderrLine.contains("Results saved to")) {
stderrLine = stderrLine.replace("\u001B","").replace("[1m","").replace("[0m",""); stderrLine = stderrLine.replace("\u001B","").replace("[1m","").replace("[0m","");
config.setOutputPath(stderrLine.split(" ")[3]); config.setOutputPath(stderrLine.split(" ")[3]);
@ -445,13 +452,16 @@ public class YoloOperationServiceImpl implements YoloOperationService {
if ("conda".equals(envInfo.getVirtualEnvType())) { if ("conda".equals(envInfo.getVirtualEnvType())) {
command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" "); command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" ");
} else { } else {
command.append(normalizePath(envInfo.getVirtualEnvPythonPath())).append(" "); command.append(normalizePath(envInfo.getVirtualEnvPythonPath())).append(" ");
command.append(envInfo.getVirtualEnvPythonPath()).append(" ");
} }
// 规范化路径,处理双斜杠和中文路径问题 // 规范化路径,处理双斜杠和中文路径问题
String datasetPath = normalizePath(config.getDatasetPath()); String datasetPath = normalizePath(config.getDatasetPath());
String outputPath = normalizePath(config.getOutputPath()); String outputPath = normalizePath(config.getOutputPath());
String modelPath = config.getModelPath() != null ? normalizePath(config.getModelPath()) : null; String modelPath = config.getModelPath() != null ? normalizePath(config.getModelPath()) : null;
// 调试日志:检查模型路径 // 调试日志:检查模型路径
log.info("模型路径配置 - modelPath: {}, modelName: {}", config.getModelPath(), config.getModelName()); log.info("模型路径配置 - modelPath: {}, modelName: {}", config.getModelPath(), config.getModelName());
@ -485,6 +495,36 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// //
// command.append("project=r'").append(escapePythonString(outputPath)).append("')\""); // command.append("project=r'").append(escapePythonString(outputPath)).append("')\"");
// //
command.append("-c \"");
command.append("import sys; sys.path.insert(0, '.'); ");
command.append("from ultralytics import YOLO; ");
if (modelPath != null && !modelPath.trim().isEmpty()) {
command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); ");
} else {
command.append("model = YOLO('").append(config.getModelName()).append("'); ");
}
// 使用raw string来处理中文路径
command.append("model.train(data=r'").append(escapePythonString(datasetPath)).append("', ");
command.append("epochs=").append(config.getEpochs()).append(", ");
command.append("batch=").append(config.getBatchSize()).append(", ");
command.append("imgsz=").append(config.getImageSize()).append(", ");
// 取消注释设备和学习率配置
// command.append("device=").append(config.getDevice()).append(", ");
// command.append("lr0=").append(config.getLearningRate()).append(", ");
// command.append("workers=").append(config.getWorkers()).append(", ");
//
if (config.getSavePeriod() > 0) {
command.append("save_period=").append(config.getSavePeriod()).append(", ");
}
command.append("project=r'").append(escapePythonString(outputPath)).append("')\"");
return command.toString(); return command.toString();
} }
@ -515,15 +555,17 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return command.toString(); return command.toString();
} }
private String buildPredictCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { private String buildPredictCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) {
StringBuilder command = new StringBuilder(); StringBuilder command = new StringBuilder();
if ("conda".equals(envInfo.getVirtualEnvType())) { if ("conda".equals(envInfo.getVirtualEnvType())) {
command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" "); command.append("conda run -n ").append(new File(envInfo.getVirtualEnvPath()).getName()).append(" ");
} else { } else {
command.append(envInfo.getVirtualEnvPythonPath()).append(" "); command.append(envInfo.getVirtualEnvPythonPath()).append(" ");
} }
String inputPath = normalizePath(config.getInputPath()); String inputPath = normalizePath(config.getInputPath());
String outputPath = normalizePath(config.getOutputPath()); String outputPath = normalizePath(config.getOutputPath());
String modelPath = normalizePath(config.getModelPath()); String modelPath = normalizePath(config.getModelPath());
@ -533,6 +575,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
command.append("from ultralytics import YOLO; "); command.append("from ultralytics import YOLO; ");
command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); "); command.append("model = YOLO(r'").append(escapePythonString(modelPath)).append("'); ");
command.append("results = model.predict(source=r'").append(escapePythonString(inputPath)).append("', "); command.append("results = model.predict(source=r'").append(escapePythonString(inputPath)).append("', ");
// 在这里我们将 save=true 改为 save=True // 在这里我们将 save=true 改为 save=True
command.append("save=").append(config.isSave() ? "True" : "False").append(", "); command.append("save=").append(config.isSave() ? "True" : "False").append(", ");
command.append("project='").append(escapePythonString(outputPath)).append("')\""); command.append("project='").append(escapePythonString(outputPath)).append("')\"");
@ -541,6 +584,8 @@ public class YoloOperationServiceImpl implements YoloOperationService {
} }
private String buildExportCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) { private String buildExportCommand(PythonVirtualEnvInfo envInfo, YoloConfig config) {
StringBuilder command = new StringBuilder(); StringBuilder command = new StringBuilder();
@ -582,6 +627,106 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return command.toString(); return command.toString();
} }
/**
*
*/
private void parseTrainingProgress(String line, YoloConfig config, Integer trainId, int[] lastSavedEpoch) {
try {
String trimmedLine = line.trim();
// 解析Epoch信息: "Epoch 1/100" 或 "1/100"
if (trimmedLine.contains("Epoch") || (trimmedLine.matches("\\d+/\\d+"))) {
String[] parts = trimmedLine.split("\\s+");
for (String part : parts) {
if (part.contains("/")) {
String[] epochParts = part.split("/");
if (epochParts.length == 2) {
int currentEpoch = Integer.parseInt(epochParts[0]);
int totalEpochs = Integer.parseInt(epochParts[1]);
config.setCurrentEpoch(currentEpoch);
// 计算总体进度epoch进度 + batch进度
int epochProgress = (int) ((double) currentEpoch / totalEpochs * 90); // 90%用于epoch10%用于最终处理
int totalProgress = Math.min(epochProgress + (config.getBatchProgress() / 10), 95);
config.setProgress(totalProgress);
// 当epoch完成时保存训练信息
if (currentEpoch > lastSavedEpoch[0] && currentEpoch <= config.getEpochs()) {
saveTrainEpochInfo(trainId, currentEpoch, config.getEpochs(), config);
lastSavedEpoch[0] = currentEpoch;
}
}
}
}
}
// 解析批次信息: "1/100" 或 "batch 1/100"
if (trimmedLine.contains("batch") || trimmedLine.matches("\\d+/\\d+.*loss")) {
if (trimmedLine.contains("batch")) {
String[] parts = trimmedLine.split("\\s+");
for (int i = 0; i < parts.length; i++) {
if (parts[i].equals("batch") && i + 1 < parts.length && parts[i + 1].contains("/")) {
String[] batchParts = parts[i + 1].split("/");
if (batchParts.length == 2) {
int currentBatch = Integer.parseInt(batchParts[0]);
int totalBatches = Integer.parseInt(batchParts[1]);
config.setCurrentBatch(currentBatch);
config.setTotalBatches(totalBatches);
config.setBatchProgress((int) ((double) currentBatch / totalBatches * 100));
}
}
}
}
}
// 解析损失值: "loss: 0.1234", "box_loss: 0.1", "cls_loss: 0.05", "obj_loss: 0.02"
if (trimmedLine.contains("loss")) {
String[] lossParts = trimmedLine.split(",");
for (String lossPart : lossParts) {
lossPart = lossPart.trim();
if (lossPart.startsWith("loss:")) {
String value = lossPart.substring(5).trim();
config.setLoss(Double.parseDouble(value));
} else if (lossPart.startsWith("box_loss:")) {
// 可以添加更详细的损失记录
} else if (lossPart.startsWith("cls_loss:")) {
// 可以添加更详细的损失记录
} else if (lossPart.startsWith("obj_loss:")) {
// 可以添加更详细的损失记录
}
}
}
// 解析训练指标: "metrics/precision(B): 0.85", "metrics/recall(B): 0.78", "metrics/mAP50(B): 0.82", "metrics/mAP50-95(B): 0.65"
if (trimmedLine.contains("metrics/")) {
String[] metricParts = trimmedLine.split(",");
for (String metricPart : metricParts) {
metricPart = metricPart.trim();
if (metricPart.startsWith("metrics/precision")) {
String value = metricPart.split(":")[1].trim();
config.setPrecision(Double.parseDouble(value));
} else if (metricPart.startsWith("metrics/recall")) {
String value = metricPart.split(":")[1].trim();
config.setRecall(Double.parseDouble(value));
} else if (metricPart.startsWith("metrics/mAP50-95") || metricPart.startsWith("metrics/mAP")) {
String value = metricPart.split(":")[1].trim();
config.setMap(Double.parseDouble(value));
}
}
}
// 解析学习率: "lr: 0.001"
if (trimmedLine.contains("lr:")) {
// 可以记录学习率变化
}
} catch (Exception e) {
log.debug("解析训练进度时出错: {}", line, e);
// 忽略解析错误,继续处理下一行
}
}
/** /**
* *
@ -771,9 +916,12 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 不影响训练完成状态 // 不影响训练完成状态
} }
} }
@Resource @Resource
@Lazy @Lazy
TrainService trainService; TrainService trainService;
/** /**
* TrainResultDO * TrainResultDO
@ -781,8 +929,11 @@ public class YoloOperationServiceImpl implements YoloOperationService {
private void saveTrainResultDO(TrainResultDO trainResult) { private void saveTrainResultDO(TrainResultDO trainResult) {
try { try {
// 创建SaveReqVO // 创建SaveReqVO
TrainResultSaveReqVO saveReqVO = TrainResultSaveReqVO saveReqVO =
new TrainResultSaveReqVO(); new TrainResultSaveReqVO();
// 设置值 // 设置值
saveReqVO.setTrainId(trainResult.getTrainId()); saveReqVO.setTrainId(trainResult.getTrainId());
@ -792,8 +943,11 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 保存 // 保存
trainResultService.createTrainResult(saveReqVO); trainResultService.createTrainResult(saveReqVO);
trainService.update(new UpdateWrapper<TrainDO>().eq("id", trainResult.getTrainId()) trainService.update(new UpdateWrapper<TrainDO>().eq("id", trainResult.getTrainId())
.set("train_type", 4)); .set("train_type", 4));
} catch (Exception e) { } catch (Exception e) {
log.error("保存训练结果失败: {}", trainResult, e); log.error("保存训练结果失败: {}", trainResult, e);
@ -915,6 +1069,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return escaped.toString(); return escaped.toString();
} }
@Resource @Resource
private RedisUtil redisUtil; private RedisUtil redisUtil;
/** /**
@ -1595,4 +1750,6 @@ public class YoloOperationServiceImpl implements YoloOperationService {
} }
} }

@ -48,6 +48,7 @@ spring:
master: master:
name: sy name: sy
url: jdbc:mysql://192.168.1.21:3306/${spring.datasource.dynamic.datasource.master.name}?useSSL=false&serverTimezone=Asia/Shanghai&allowPublicKeyRetrieval=true&nullCatalogMeansCurrent=true # MySQL Connector/J 8.X 连接的示例 url: jdbc:mysql://192.168.1.21:3306/${spring.datasource.dynamic.datasource.master.name}?useSSL=false&serverTimezone=Asia/Shanghai&allowPublicKeyRetrieval=true&nullCatalogMeansCurrent=true # MySQL Connector/J 8.X 连接的示例
username: root username: root
password: upright password: upright
# Redis 配置。Redisson 默认的配置足够使用,一般不需要进行调优 # Redis 配置。Redisson 默认的配置足够使用,一般不需要进行调优

Loading…
Cancel
Save