@ -1,11 +1,13 @@
package cn.iocoder.yudao.module.annotation.service.yolo ;
import cn.iocoder.yudao.module.annotation.controller.admin.trainresult.vo.TrainResultSaveReqVO ;
import cn.iocoder.yudao.module.annotation.dal.dataobject.train.TrainDO ;
import cn.iocoder.yudao.module.annotation.dal.dataobject.trainresult.TrainResultDO ;
import cn.iocoder.yudao.module.annotation.dal.yolo.YoloConfig ;
import cn.iocoder.yudao.module.annotation.service.python.PythonVirtualEnvService ;
import cn.iocoder.yudao.module.annotation.service.python.vo.PythonVirtualEnvInfo ;
import cn.iocoder.yudao.module.annotation.service.train.TrainService ;
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService ;
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService ;
@ -15,6 +17,10 @@ import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j ;
import org.springframework.beans.factory.annotation.Value ;
import org.springframework.context.annotation.Lazy ;
import cn.iocoder.yudao.module.annotation.service.traininfo.TrainInfoService ;
import cn.iocoder.yudao.module.annotation.service.trainresult.TrainResultService ;
import jakarta.annotation.Resource ;
import lombok.extern.slf4j.Slf4j ;
import org.springframework.stereotype.Service ;
import java.io.BufferedReader ;
@ -42,6 +48,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
@Resource
private TrainResultService trainResultService ;
@Resource
private YoloServiceClient yoloServiceClient ;
@ -51,17 +58,16 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 异步执行器
private final Executor asyncExecutor = Executors . newCachedThreadPool ( ) ;
@Override
public CompletableFuture < YoloConfig > trainAsync ( String pythonPath , String envName , YoloConfig config ) {
return CompletableFuture . supplyAsync ( ( ) - > {
config . setStatus ( "running" ) ;
config . setProgress ( 0 ) ;
config . setTaskId ( UUID . randomUUID ( ) . toString ( ) ) ;
// 生成训练ID
Integer trainId = config . getTrainId ( ) ;
// 记录上一轮的epoch, 用于检测epoch完成
int [ ] lastSavedEpoch = { 0 } ;
@ -76,16 +82,13 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 构建训练命令
String command = buildTrainCommand ( envInfo , config ) ;
log . info ( "开始YOLO训练, 命令: {}" , command ) ;
// 执行训练
ProcessBuilder processBuilder = new ProcessBuilder ( ) ;
// 在Windows上使用cmd, 在Linux/Mac上使用bash
String os = System . getProperty ( "os.name" ) . toLowerCase ( ) ;
if ( os . contains ( "win" ) ) {
processBuilder . command ( "cmd" , "/c" , command ) ;
// 设置环境变量确保正确处理中文路径
processBuilder . environment ( ) . put ( "PYTHONIOENCODING" , "utf-8" ) ;
processBuilder . environment ( ) . put ( "LANG" , "zh_CN.UTF-8" ) ;
} else {
@ -93,36 +96,35 @@ public class YoloOperationServiceImpl implements YoloOperationService {
processBuilder . environment ( ) . put ( "PYTHONIOENCODING" , "utf-8" ) ;
}
Process process = processBuilder . start ( ) ;
BufferedReader stdoutReader = new BufferedReader ( new InputStreamReader ( process . getInputStream ( ) ) ) ;
BufferedReader stderrReader = new BufferedReader ( new InputStreamReader ( process . getErrorStream ( ) ) ) ;
// BufferedReader stderrReader2 = new BufferedReader(new InputStreamReader(process.getOutputStream()));
StringBuilder fullOutput = new StringBuilder ( ) ;
// 使用单独的线程读取stdout和stderr, 避免死锁
String stdoutLine , stderrLine = null ;
while ( ( stdoutLine = stdoutReader . readLine ( ) ) ! = null | | ( stderrLine = stderrReader . readLine ( ) ) ! = null ) {
String stdoutLine ;
String stderrLine = null ;
// 同时读取标准输出和错误输出
while ( ( stdoutLine = stdoutReader . readLine ( ) ) ! = null | | ( stderrLine = stderrReader . readLine ( ) ) ! = null ) {
if ( stdoutLine ! = null ) {
log . info ( "YOLO训练日志: {}" , stdoutLine ) ;
fullOutput . append ( stdoutLine ) . append ( "\n" ) ;
// 详细解析YOLO训练进度和结果
manageTrainingLogInRedis ( trainId , stdoutLine ) ;
config . setLogMessage ( stdoutLine ) ;
// 检测 "Results saved to" 并解析路径和识别率
detectAndParseResultsSaved ( stdoutLine , config , trainId ) ;
parseTrainingProgress ( stdoutLine , config , trainId , lastSavedEpoch ) ;
}
if ( stderrLine ! = null ) {
log . info ( "YOLO训练日志: {}" , stderrLine ) ;
fullOutput . append ( stderrLine ) . append ( "\n" ) ;
manageTrainingLogInRedis ( trainId , stderrLine ) ;
config . setLogMessage ( stderrLine ) ;
// 检测 "Results saved to" 并解析路径和识别率
detectAndParseResultsSaved ( stderrLine , config , trainId ) ;
parseTrainingProgress ( stderrLine , config , trainId , lastSavedEpoch ) ;
}
}
@ -148,12 +150,14 @@ public class YoloOperationServiceImpl implements YoloOperationService {
config . setStatus ( "failed" ) ;
config . setErrorMessage ( "训练异常: " + e . getMessage ( ) ) ;
}
return config ;
} , asyncExecutor ) ;
}
@Override
public YoloConfig validate ( String pythonPath , String envName , YoloConfig config ) {
// 优先尝试使用socket服务( 如果启用且可用)
if ( useSocketService & & yoloServiceClient ! = null & & yoloServiceClient . isServiceAvailable ( ) ) {
log . info ( "使用socket服务执行YOLO验证" ) ;
@ -233,6 +237,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
@Override
public YoloConfig predict ( String pythonPath , String envName , YoloConfig config ) {
// 优先尝试使用socket服务( 如果启用且可用)
if ( useSocketService & & yoloServiceClient ! = null & & yoloServiceClient . isServiceAvailable ( ) ) {
log . info ( "使用socket服务执行YOLO预测" ) ;
@ -281,6 +286,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
if ( stdoutLine ! = null ) {
output . append ( stdoutLine ) . append ( "\n" ) ;
log . info ( "YOLO预测日志: {}" , stdoutLine ) ;
if ( stdoutLine . contains ( "Results saved to" ) ) {
stdoutLine = stdoutLine . replace ( "\u001B" , "" ) . replace ( "[1m" , "" ) . replace ( "[0m" , "" ) ;
@ -291,6 +297,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
output . append ( stderrLine ) . append ( "\n" ) ;
log . info ( "YOLO预测日志: {}" , stderrLine ) ;
if ( stderrLine . contains ( "Results saved to" ) ) {
stderrLine = stderrLine . replace ( "\u001B" , "" ) . replace ( "[1m" , "" ) . replace ( "[0m" , "" ) ;
config . setOutputPath ( stderrLine . split ( " " ) [ 3 ] ) ;
@ -445,13 +452,16 @@ public class YoloOperationServiceImpl implements YoloOperationService {
if ( "conda" . equals ( envInfo . getVirtualEnvType ( ) ) ) {
command . append ( "conda run -n " ) . append ( new File ( envInfo . getVirtualEnvPath ( ) ) . getName ( ) ) . append ( " " ) ;
} else {
command . append ( normalizePath ( envInfo . getVirtualEnvPythonPath ( ) ) ) . append ( " " ) ;
command . append ( envInfo . getVirtualEnvPythonPath ( ) ) . append ( " " ) ;
}
// 规范化路径,处理双斜杠和中文路径问题
String datasetPath = normalizePath ( config . getDatasetPath ( ) ) ;
String outputPath = normalizePath ( config . getOutputPath ( ) ) ;
String modelPath = config . getModelPath ( ) ! = null ? normalizePath ( config . getModelPath ( ) ) : null ;
// 调试日志:检查模型路径
log . info ( "模型路径配置 - modelPath: {}, modelName: {}" , config . getModelPath ( ) , config . getModelName ( ) ) ;
@ -485,6 +495,36 @@ public class YoloOperationServiceImpl implements YoloOperationService {
//
// command.append("project=r'").append(escapePythonString(outputPath)).append("')\"");
//
command . append ( "-c \"" ) ;
command . append ( "import sys; sys.path.insert(0, '.'); " ) ;
command . append ( "from ultralytics import YOLO; " ) ;
if ( modelPath ! = null & & ! modelPath . trim ( ) . isEmpty ( ) ) {
command . append ( "model = YOLO(r'" ) . append ( escapePythonString ( modelPath ) ) . append ( "'); " ) ;
} else {
command . append ( "model = YOLO('" ) . append ( config . getModelName ( ) ) . append ( "'); " ) ;
}
// 使用raw string来处理中文路径
command . append ( "model.train(data=r'" ) . append ( escapePythonString ( datasetPath ) ) . append ( "', " ) ;
command . append ( "epochs=" ) . append ( config . getEpochs ( ) ) . append ( ", " ) ;
command . append ( "batch=" ) . append ( config . getBatchSize ( ) ) . append ( ", " ) ;
command . append ( "imgsz=" ) . append ( config . getImageSize ( ) ) . append ( ", " ) ;
// 取消注释设备和学习率配置
// command.append("device=").append(config.getDevice()).append(", ");
// command.append("lr0=").append(config.getLearningRate()).append(", ");
// command.append("workers=").append(config.getWorkers()).append(", ");
//
if ( config . getSavePeriod ( ) > 0 ) {
command . append ( "save_period=" ) . append ( config . getSavePeriod ( ) ) . append ( ", " ) ;
}
command . append ( "project=r'" ) . append ( escapePythonString ( outputPath ) ) . append ( "')\"" ) ;
return command . toString ( ) ;
}
@ -515,15 +555,17 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return command . toString ( ) ;
}
private String buildPredictCommand ( PythonVirtualEnvInfo envInfo , YoloConfig config ) {
StringBuilder command = new StringBuilder ( ) ;
if ( "conda" . equals ( envInfo . getVirtualEnvType ( ) ) ) {
command . append ( "conda run -n " ) . append ( new File ( envInfo . getVirtualEnvPath ( ) ) . getName ( ) ) . append ( " " ) ;
} else {
command . append ( envInfo . getVirtualEnvPythonPath ( ) ) . append ( " " ) ;
}
String inputPath = normalizePath ( config . getInputPath ( ) ) ;
String outputPath = normalizePath ( config . getOutputPath ( ) ) ;
String modelPath = normalizePath ( config . getModelPath ( ) ) ;
@ -533,6 +575,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
command . append ( "from ultralytics import YOLO; " ) ;
command . append ( "model = YOLO(r'" ) . append ( escapePythonString ( modelPath ) ) . append ( "'); " ) ;
command . append ( "results = model.predict(source=r'" ) . append ( escapePythonString ( inputPath ) ) . append ( "', " ) ;
// 在这里我们将 save=true 改为 save=True
command . append ( "save=" ) . append ( config . isSave ( ) ? "True" : "False" ) . append ( ", " ) ;
command . append ( "project='" ) . append ( escapePythonString ( outputPath ) ) . append ( "')\"" ) ;
@ -541,6 +584,8 @@ public class YoloOperationServiceImpl implements YoloOperationService {
}
private String buildExportCommand ( PythonVirtualEnvInfo envInfo , YoloConfig config ) {
StringBuilder command = new StringBuilder ( ) ;
@ -582,6 +627,106 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return command . toString ( ) ;
}
/ * *
* 解 析 训 练 进 度 详 细 信 息
* /
private void parseTrainingProgress ( String line , YoloConfig config , Integer trainId , int [ ] lastSavedEpoch ) {
try {
String trimmedLine = line . trim ( ) ;
// 解析Epoch信息: "Epoch 1/100" 或 "1/100"
if ( trimmedLine . contains ( "Epoch" ) | | ( trimmedLine . matches ( "\\d+/\\d+" ) ) ) {
String [ ] parts = trimmedLine . split ( "\\s+" ) ;
for ( String part : parts ) {
if ( part . contains ( "/" ) ) {
String [ ] epochParts = part . split ( "/" ) ;
if ( epochParts . length = = 2 ) {
int currentEpoch = Integer . parseInt ( epochParts [ 0 ] ) ;
int totalEpochs = Integer . parseInt ( epochParts [ 1 ] ) ;
config . setCurrentEpoch ( currentEpoch ) ;
// 计算总体进度: epoch进度 + batch进度
int epochProgress = ( int ) ( ( double ) currentEpoch / totalEpochs * 90 ) ; // 90%用于epoch, 10%用于最终处理
int totalProgress = Math . min ( epochProgress + ( config . getBatchProgress ( ) / 10 ) , 95 ) ;
config . setProgress ( totalProgress ) ;
// 当epoch完成时保存训练信息
if ( currentEpoch > lastSavedEpoch [ 0 ] & & currentEpoch < = config . getEpochs ( ) ) {
saveTrainEpochInfo ( trainId , currentEpoch , config . getEpochs ( ) , config ) ;
lastSavedEpoch [ 0 ] = currentEpoch ;
}
}
}
}
}
// 解析批次信息: "1/100" 或 "batch 1/100"
if ( trimmedLine . contains ( "batch" ) | | trimmedLine . matches ( "\\d+/\\d+.*loss" ) ) {
if ( trimmedLine . contains ( "batch" ) ) {
String [ ] parts = trimmedLine . split ( "\\s+" ) ;
for ( int i = 0 ; i < parts . length ; i + + ) {
if ( parts [ i ] . equals ( "batch" ) & & i + 1 < parts . length & & parts [ i + 1 ] . contains ( "/" ) ) {
String [ ] batchParts = parts [ i + 1 ] . split ( "/" ) ;
if ( batchParts . length = = 2 ) {
int currentBatch = Integer . parseInt ( batchParts [ 0 ] ) ;
int totalBatches = Integer . parseInt ( batchParts [ 1 ] ) ;
config . setCurrentBatch ( currentBatch ) ;
config . setTotalBatches ( totalBatches ) ;
config . setBatchProgress ( ( int ) ( ( double ) currentBatch / totalBatches * 100 ) ) ;
}
}
}
}
}
// 解析损失值: "loss: 0.1234", "box_loss: 0.1", "cls_loss: 0.05", "obj_loss: 0.02"
if ( trimmedLine . contains ( "loss" ) ) {
String [ ] lossParts = trimmedLine . split ( "," ) ;
for ( String lossPart : lossParts ) {
lossPart = lossPart . trim ( ) ;
if ( lossPart . startsWith ( "loss:" ) ) {
String value = lossPart . substring ( 5 ) . trim ( ) ;
config . setLoss ( Double . parseDouble ( value ) ) ;
} else if ( lossPart . startsWith ( "box_loss:" ) ) {
// 可以添加更详细的损失记录
} else if ( lossPart . startsWith ( "cls_loss:" ) ) {
// 可以添加更详细的损失记录
} else if ( lossPart . startsWith ( "obj_loss:" ) ) {
// 可以添加更详细的损失记录
}
}
}
// 解析训练指标: "metrics/precision(B): 0.85", "metrics/recall(B): 0.78", "metrics/mAP50(B): 0.82", "metrics/mAP50-95(B): 0.65"
if ( trimmedLine . contains ( "metrics/" ) ) {
String [ ] metricParts = trimmedLine . split ( "," ) ;
for ( String metricPart : metricParts ) {
metricPart = metricPart . trim ( ) ;
if ( metricPart . startsWith ( "metrics/precision" ) ) {
String value = metricPart . split ( ":" ) [ 1 ] . trim ( ) ;
config . setPrecision ( Double . parseDouble ( value ) ) ;
} else if ( metricPart . startsWith ( "metrics/recall" ) ) {
String value = metricPart . split ( ":" ) [ 1 ] . trim ( ) ;
config . setRecall ( Double . parseDouble ( value ) ) ;
} else if ( metricPart . startsWith ( "metrics/mAP50-95" ) | | metricPart . startsWith ( "metrics/mAP" ) ) {
String value = metricPart . split ( ":" ) [ 1 ] . trim ( ) ;
config . setMap ( Double . parseDouble ( value ) ) ;
}
}
}
// 解析学习率: "lr: 0.001"
if ( trimmedLine . contains ( "lr:" ) ) {
// 可以记录学习率变化
}
} catch ( Exception e ) {
log . debug ( "解析训练进度时出错: {}" , line , e ) ;
// 忽略解析错误,继续处理下一行
}
}
/ * *
* 解 析 最 终 训 练 结 果
@ -771,9 +916,12 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 不影响训练完成状态
}
}
@Resource
@Lazy
TrainService trainService ;
/ * *
* 保 存 TrainResultDO 到 数 据 库
@ -781,8 +929,11 @@ public class YoloOperationServiceImpl implements YoloOperationService {
private void saveTrainResultDO ( TrainResultDO trainResult ) {
try {
// 创建SaveReqVO
TrainResultSaveReqVO saveReqVO =
new TrainResultSaveReqVO ( ) ;
// 设置值
saveReqVO . setTrainId ( trainResult . getTrainId ( ) ) ;
@ -792,8 +943,11 @@ public class YoloOperationServiceImpl implements YoloOperationService {
// 保存
trainResultService . createTrainResult ( saveReqVO ) ;
trainService . update ( new UpdateWrapper < TrainDO > ( ) . eq ( "id" , trainResult . getTrainId ( ) )
. set ( "train_type" , 4 ) ) ;
} catch ( Exception e ) {
log . error ( "保存训练结果失败: {}" , trainResult , e ) ;
@ -915,6 +1069,7 @@ public class YoloOperationServiceImpl implements YoloOperationService {
return escaped . toString ( ) ;
}
@Resource
private RedisUtil redisUtil ;
/ * *
@ -1595,4 +1750,6 @@ public class YoloOperationServiceImpl implements YoloOperationService {
}
}