整体修改

普洱算法
LAPTOP-S9HJSOEB\昊天 1 week ago
parent 521a2efcfd
commit 0558000e06

@ -77,7 +77,11 @@
<artifactId>hutool-all</artifactId>
<version>5.8.16</version>
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.16.0</version>
</dependency>
<dependency>
<groupId>MvCameraControlWrapper</groupId>
<artifactId>MvCameraControlWrapper</artifactId>

@ -175,6 +175,65 @@ public class HikController {
return pojo;
}
// 识别
@GetMapping("/distinguishOnnx")
@ResponseBody
public Pojo distinguishOnnx(String streetNumber,int direction, String category) throws IOException {
String picPath = appConfig.getPicPath();
String path =streetNumber+"\\"+direction+"_"+ UUID.randomUUID() +".jpg";
Mat img2 = null;
for (AppConfig.Camera camera : appConfig.getHikCamera()){
if (camera.getStreetId().equals(streetNumber)&& camera.getDirection()==direction) {
boolean i = hikSaveImage.saveImage(camera.getIp(),picPath+ path, "ip");
img2 = Imgcodecs.imread(picPath+path);;
img2 = Calibration.performCalibration(img2, camera.getId(), appConfig.getActualRectangularRatio());
}
}
// 进行标定,保存图片
Imgcodecs.imwrite(picPath+path+".jpg", Objects.requireNonNull(img2));
// 指定文件夹路径
String folderPath = appConfig.templatePath+category;
// 创建File对象
File folder = new File(folderPath);
boolean flag = false;
// 判断是否是文件夹并且存在
if (folder.exists() && folder.isDirectory()) {
// 获取文件夹中的所有文件和子文件夹
File[] files = folder.listFiles();
int i= 0;
// 判断文件夹中是否有文件
if (files != null && files.length > 0) {
for (File file : files) {
i++;
// 判断是否是文件
if (file.isFile()) {
// 获取文件名并进行判断
Mat img1 = Imgcodecs.imread(file.getAbsolutePath());
if(FeatureMatchingExample.matchTemplate(img1, img2,picPath+path+".jpg.jpg",appConfig.getThreshold())){
flag = true;
}else {
img2 = Imgcodecs.imread(picPath+path+".jpg.jpg");
}
}
}
} else {
System.out.println("The folder is empty.");
}
} else {
System.out.println("The specified path is not a valid directory or does not exist.");
}
Pojo pojo = new Pojo();
pojo.setResult(flag);
pojo.setDeterminePath(path+".jpg.jpg");
return pojo;
}
// 标定
@GetMapping("/task")
@ResponseBody

@ -0,0 +1,705 @@
package com.example.lxcameraapi.service.IndustrialCamera.algorithm;
import com.example.lxcameraapi.conf.AppConfig;
import com.example.lxcameraapi.service.IndustrialCamera.yolo.BoundingBox;
import com.example.lxcameraapi.service.IndustrialCamera.yolo.ClassifyEntity;
import lombok.extern.slf4j.Slf4j;
import ai.onnxruntime.*;
import org.opencv.core.Mat;
import org.opencv.imgcodecs.Imgcodecs;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@Slf4j
@Service
public class ONNXServiceNew {
@Resource
AppConfig configProperties;
private static final Map<String, OrtSession> ortSessions = new ConcurrentHashMap<>();
private static final Map<String, AppConfig.YoloModelConfig> ortMap = new ConcurrentHashMap<>();
private static final OrtEnvironment environment = OrtEnvironment.getEnvironment();
public static void main(String[] args) {
System.load(new File(System.getProperty("user.dir")+"\\libs\\opencv\\opencv_java480.dll").getAbsolutePath());
ortMap.put("a", new AppConfig.YoloModelConfig());
ortMap.get("a").setImageSize(2048);
ortMap.get("a").setConfThreshold(0.5f);
ortMap.get("a").setNames(new String[]{"0"});
int imageSize = 2048;
String imagePath = "D:\\data\\1776157002220.png";
String modelPath = "D:\\data\\best.onnx";
// List<String> name = Arrays.asList("0143", "0153", "0173", "0177", "0191", "0253", "0256", "0266", "0268", "0286", "0302", "0304", "0305", "0307", "0320", "0326", "0336", "0339", "0343", "0352", "0458", "0461", "0462", "0473", "0477", "0486", "0490", "0492", "0930", "1101", "1102", "1104", "1262", "1269", "1302", "1308", "1359", "1366", "1622", "1625", "1919", "1976", "20", "2165", "2188", "2210", "2224", "2445", "2476", "2611", "2730", "2731", "2910", "2914", "2943", "3027", "3028", "3029", "3212", "3226", "3344", "3501", "3509", "3538", "3725", "3741", "3751", "3754", "3763", "3766", "3808");
OrtSession.SessionOptions sessionOptions = null;
// 使用gpu,需要本机按钻过cuda并修改pom.xml不安装也能运行本程序
// sessionOptions.addCUDA(0);
// 实际项目中视频识别必须开启GPU并且要防止队列堆积
OrtSession session = null;
ONNXServiceNew onnxServiceNew = new ONNXServiceNew();
try {
sessionOptions = new OrtSession.SessionOptions();
session = environment.createSession(modelPath, sessionOptions);
ortSessions.put("a", session);
List<BoundingBox> boxes = onnxServiceNew.detect(imagePath, "a");
// classifyEntity.setIndex(predictedClassId);
// classifyEntity.setName(config.getNames()[predictedClassId]);
// classifyEntity.setConfidence(outputData[predictedClassId]);
// classifyEntity.setClassProb(String.format("%.4f", outputData[i]));
// float[] output = (float[]) result.get(0).getValue();
// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
System.out.println(boxes);
// System.out.println(name.get(predictedClassId));
} catch (OrtException e) {
throw new RuntimeException(e);
} finally {
if (sessionOptions != null) {
sessionOptions.close();
}
}
// // 处理图像
// float[] imageData = new float[0];
// try {
// imageData = processImageFromURL(imagePath, imageSize);
// } catch (IOException ex) {
// throw new RuntimeException(ex);
// }
//
// // 构建输入张量
// long[] shape = new long[]{1, 3, imageSize, imageSize}; // batch_size, channels, height, width
// OnnxTensor inputTensor = null;
// try {
// inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
//
// HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
// stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
//
// // 执行推理
//// OrtSession.Result result = session.run(stringOnnxTensorHashMap);
////
//// float[][] outputData = ((float[][][])result.get(0).getValue())[0];
////
//// outputData = transposeMatrix(outputData);
////
////// float[] output = (float[]) result.get(0).getValue();
////
//// // 解析推理结果
//// List<BoundingBox> detections = parseOutput(outputData,0.4);
////
//// List<BoundingBox> filteredDetections = nonMaxSuppression(detections, 0.5f); // IoU 阈值设为 0.5
//// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
//// System.out.println(filteredDetections);
//// int mostFrequentClassId = getMaxClassIndex(filteredDetections);
//// System.out.println(calculateLayers(filteredDetections, 200,mostFrequentClassId));
//
// // 执行推理
// OrtSession.Result result = session.run(stringOnnxTensorHashMap);
//// 获取第一个输出(大多数情况下只有一个输出)
// float[] outputData = ((float[][]) result.get(0).getValue())[0];
//
//// 获取最大概率对应的类别 ID
// int predictedClassId = argmax(outputData, 0.8);
//
//// float[] output = (float[]) result.get(0).getValue();
//
//// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
// ClassifyEntity classifyEntity = new ClassifyEntity();
// classifyEntity.setIndex(predictedClassId);
// } catch (OrtException e) {
// throw new RuntimeException(e);
// }
}
public static int getMaxClassIndex(List<BoundingBox> boxes) {
return boxes.stream()
.collect(Collectors.groupingBy(BoundingBox::getIndex, Collectors.counting()))
.entrySet().stream()
.max(Map.Entry.comparingByValue())
.map(Map.Entry::getKey)
.orElse(-1);
}
public static Set<Integer> getMinTypeClassIndex(List<BoundingBox> boxes) {
return boxes.stream()
.map(BoundingBox::getIndex)
.collect(Collectors.toSet());
}
/**
* y threshold
*/
public static int calculateLayers(List<BoundingBox> boxes, double threshold, int mostFrequentClassId) {
if (boxes == null || boxes.isEmpty()) {
return 0;
}
//出现次数做多的classId
boxes = boxes.stream()
.filter(box -> box.getIndex() == 0)
.collect(Collectors.toList());
// 检查过滤后的boxes是否为空
if (boxes.isEmpty()) {
return 0;
}
List<Double> yValues = new ArrayList<>();
for (BoundingBox box : boxes) {
yValues.add(box.getY());
}
Collections.sort(yValues); // 先排序
int layers = 1; // 至少有一层
double currentBase = yValues.get(0);
for (int i = 1; i < yValues.size(); i++) {
if (Math.abs(yValues.get(i) - currentBase) > threshold) {
layers++;
currentBase = yValues.get(i);
}
}
return layers;
}
public static float[][] transposeMatrix(float[][] m) {
float[][] temp = new float[m[0].length][m.length];
for (int i = 0; i < m.length; i++)
for (int j = 0; j < m[0].length; j++)
temp[j][i] = m[i][j];
return temp;
}
// 图像文件处理
//
// // 处理图片URL
// private static float[] processImageFromURL(String imageUrl, int imageSize) throws IOException {
// File file = new File(imageUrl);
// BufferedImage bufferedImage = ImageIO.read(file);
// // 假设我们需要将图片缩放为 imageSize*imageSize并转换为一个符合 YOLO 输入要求的 float 数组
// BufferedImage resizedImage = new BufferedImage(imageSize, imageSize, BufferedImage.TYPE_INT_RGB);
// resizedImage.getGraphics().drawImage(bufferedImage, 0, 0, imageSize, imageSize, null);
//
// // 将图片转换为 float 数组(假设 RGB 格式)
// float[] imageData = new float[imageSize * imageSize * 3];
// int index = 0;
// int[] rData = new int[imageSize * imageSize];
// int[] gData = new int[imageSize * imageSize];
// int[] bData = new int[imageSize * imageSize];
//
// int idx = 0;
// for (int y = 0; y < imageSize; y++) {
// for (int x = 0; x < imageSize; x++) {
// int rgb = resizedImage.getRGB(x, y);
// rData[idx] = (rgb >> 16) & 0xFF;
// gData[idx] = (rgb >> 8) & 0xFF;
// bData[idx] = rgb & 0xFF;
// idx++;
// }
// }
//// 填充为 CHW
// idx = 0;
// for (int i = 0; i < imageSize * imageSize; i++) {
// imageData[idx++] = rData[i] / 255.0f;
// }
// for (int i = 0; i < imageSize * imageSize; i++) {
// imageData[idx++] = gData[i] / 255.0f;
// }
// for (int i = 0; i < imageSize * imageSize; i++) {
// imageData[idx++] = bData[i] / 255.0f;
// }
// return imageData;
// }
// 处理图片URL
private static float[] processImageFromURL(String imageUrl, int imageSize) throws IOException {
// 从URL读取图片
BufferedImage originalImage;
if (imageUrl.startsWith("http")) {
URL url = new URL(imageUrl);
originalImage = ImageIO.read(url);
} else {
File file = new File(imageUrl);
originalImage = ImageIO.read(file);
}
// 创建目标尺寸的图片
BufferedImage resizedImage = new BufferedImage(imageSize, imageSize, BufferedImage.TYPE_INT_RGB);
Graphics2D g2d = resizedImage.createGraphics();
// 设置背景为黑色
g2d.setColor(Color.BLACK);
g2d.fillRect(0, 0, imageSize, imageSize);
// 计算缩放比例,保持宽高比
int originalWidth = originalImage.getWidth();
int originalHeight = originalImage.getHeight();
// 按照长边计算缩放比例
double scale = Math.min((double) imageSize / originalWidth, (double) imageSize / originalHeight);
// 计算缩放后的尺寸
int scaledWidth = (int) (originalWidth * scale);
int scaledHeight = (int) (originalHeight * scale);
// 计算居中位置
int x = (imageSize - scaledWidth) / 2;
int y = (imageSize - scaledHeight) / 2;
// 绘制缩放后的图片
g2d.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
g2d.drawImage(originalImage, x, y, scaledWidth, scaledHeight, null);
g2d.dispose();
// 将图片转换为 float 数组RGB 格式,归一化到 0-1
float[] imageData = new float[imageSize * imageSize * 3];
int index = 0;
// 按照CHW格式排列数据通道-高度-宽度)
// 先处理所有像素的红色通道
for (int i = 0; i < imageSize; i++) {
for (int j = 0; j < imageSize; j++) {
int rgb = resizedImage.getRGB(j, i);
imageData[index++] = ((rgb >> 16) & 0xFF) / 255.0f; // R
}
}
// 再处理绿色通道
for (int i = 0; i < imageSize; i++) {
for (int j = 0; j < imageSize; j++) {
int rgb = resizedImage.getRGB(j, i);
imageData[index++] = ((rgb >> 8) & 0xFF) / 255.0f; // G
}
}
// 最后处理蓝色通道
for (int i = 0; i < imageSize; i++) {
for (int j = 0; j < imageSize; j++) {
int rgb = resizedImage.getRGB(j, i);
imageData[index++] = (rgb & 0xFF) / 255.0f; // B
}
}
return imageData;
}
public static List<BoundingBox> nonMaxSuppression(List<BoundingBox> boxes, float iouThreshold) {
if (boxes.isEmpty()) return new ArrayList<>();
// 按置信度从高到低排序
boxes.sort((a, b) -> Float.compare((float) b.getConfidence(), (float) a.getConfidence()));
List<BoundingBox> result = new ArrayList<>();
while (!boxes.isEmpty()) {
BoundingBox best = boxes.remove(0);
result.add(best);
// 过滤掉和 best IoU 太大的框
boxes.removeIf(box -> computeIoU(best, box) > iouThreshold);
}
return result;
}
public static float computeIoU(BoundingBox box1, BoundingBox box2) {
double x1 = Math.max(box1.getX() - box1.getW() / 2, box2.getX() - box2.getW() / 2);
double y1 = Math.max(box1.getY() - box1.getH() / 2, box2.getY() - box2.getH() / 2);
double x2 = Math.min(box1.getX() + box1.getW() / 2, box2.getX() + box2.getW() / 2);
double y2 = Math.min(box1.getY() + box1.getH() / 2, box2.getY() + box2.getH() / 2);
double w = Math.max(0, x2 - x1);
double h = Math.max(0, y2 - y1);
double intersection = w * h;
double area1 = (box1.getW() * box1.getH());
double area2 = (box2.getW() * box2.getH());
return (float) (intersection / (area1 + area2 - intersection + 1e-6));
}
@PostConstruct
public void initSession() {
OrtSession.SessionOptions sessionOptions = null;
try {
sessionOptions = new OrtSession.SessionOptions();
// 使用gpu,需要本机按钻过cuda并修改pom.xml不安装也能运行本程序
// sessionOptions.addCUDA(0);
// 实际项目中视频识别必须开启GPU并且要防止队列堆积
// 大模型
OrtSession session = null;
if (configProperties.getYoloModelConfig().getModelPath() != null) {
session = environment.createSession(configProperties.getYoloModelConfig().getModelPath(), sessionOptions);
OrtSession finalSession = session;
session.getInputInfo().keySet().forEach(x -> {
System.out.println("input name = " + x);
try {
TensorInfo tensorInfo = (TensorInfo) finalSession.getInputInfo().get(x).getInfo();
long[] shape = tensorInfo.getShape();
// 假设形状是 [batch, channels, height, width]
if (shape.length >= 4) {
configProperties.getYoloModelConfig().setImageSize((int)shape[3]); // width
}
System.out.println(tensorInfo.toString());
} catch (OrtException e) {
throw new RuntimeException(e);
}
});
ortSessions.put("model", session);
ortMap.put("model", configProperties.getYoloModelConfig());
}
if (configProperties.getYoloModelConfig().getModelMap() != null) {
OrtSession.SessionOptions finalSessionOptions = sessionOptions;
configProperties.getYoloModelConfig().getModelMap().forEach((k, v) -> {
try {
// 判断v文件是否存在
if (!new File(v.getModelPath()).exists()) {
log.error("文件不存在:{}", v);
return;
}
OrtSession session1 = environment.createSession(v.getModelPath(), finalSessionOptions);
session1.getInputInfo().keySet().forEach(x -> {
System.out.println("input name = " + x);
try {
TensorInfo tensorInfo = (TensorInfo) session1.getInputInfo().get(x).getInfo();
long[] shape = tensorInfo.getShape();
// 假设形状是 [batch, channels, height, width]
if (shape.length >= 4) {
v.setImageSize((int)shape[3]); // width
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
});
ortSessions.put(k, session1);
ortMap.put(k, v);
} catch (OrtException e) {
throw new RuntimeException(e);
}
});
}
} catch (OrtException e) {
throw new RuntimeException(e);
}
// sessionOptions 不在这里关闭,它的生命周期应该与 Session 一致
}
private OrtSession getSession(String type) {
if (type != null && !type.equals("")) {
return ortSessions.get(type);
}
return ortSessions.get("model");
}
private AppConfig.YoloModelConfig getConfig(String type) {
if (type != null && !type.equals("")) {
return ortMap.get(type);
}
return ortMap.get("model");
}
/**
* yolo11
*
* @param imagePath
* @return
* @throws OrtException
*/
public List<ClassifyEntity> classify(String imagePath, String type) throws OrtException {
OrtSession session = getSession(type);
AppConfig.YoloModelConfig config = getConfig(type);
// 处理图像
float[] imageData = new float[0];
try {
imageData = processImageFromURL(imagePath, config.getImageSize());
} catch (IOException ex) {
throw new RuntimeException(ex);
}
// 构建输入张量
long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width
OnnxTensor inputTensor = null;
OrtSession.Result result = null;
try {
inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
// 执行推理
result = session.run(stringOnnxTensorHashMap);
// 获取第一个输出(大多数情况下只有一个输出)
float[] outputData = ((float[][]) result.get(0).getValue())[0];
List<ClassifyEntity> classifyEntities = new ArrayList<>();
for (int i = 0; i < outputData.length; i += 5) {
ClassifyEntity classifyEntity = new ClassifyEntity();
classifyEntity.setIndex(i);
classifyEntity.setName(config.getNames()[i]);
classifyEntity.setConfidence(outputData[i]);
// classifyEntity.setClassProb(String.format("%.4f", outputData[i]));
classifyEntities.add(classifyEntity);
}
// 获取最大概率对应的类别 ID
// int predictedClassId = argmax(outputData, configProperties.getYoloModelConfig().getConfThreshold());
// float[] output = (float[]) result.get(0).getValue();
// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
return classifyEntities;
} finally {
if (inputTensor != null) {
inputTensor.close();
}
if (result != null) {
result.close();
}
}
}
/**
* yolo11
*
* @param imagePath
* @return
* @throws OrtException
*/
public ClassifyEntity classifyOne(String imagePath, String type) throws OrtException {
OrtSession session = getSession(type);
AppConfig.YoloModelConfig config = getConfig(type);
// 处理图像
float[] imageData = new float[0];
try {
imageData = processImageFromURL(imagePath, config.getImageSize());
} catch (IOException ex) {
throw new RuntimeException(ex);
}
// 构建输入张量
long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width
OnnxTensor inputTensor = null;
OrtSession.Result result = null;
try {
inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
// 执行推理
result = session.run(stringOnnxTensorHashMap);
// 获取第一个输出(大多数情况下只有一个输出)
float[] outputData = ((float[][]) result.get(0).getValue())[0];
// 获取最大概率对应的类别 ID
int predictedClassId = argmax(outputData, configProperties.getYoloModelConfig().getConfThreshold());
ClassifyEntity classifyEntity = new ClassifyEntity();
classifyEntity.setIndex(predictedClassId);
if (predictedClassId!=-1){
log.info("识别模版:{}",config.getNames()[predictedClassId]);
classifyEntity.setName(config.getNames()[predictedClassId]);
classifyEntity.setConfidence(outputData[predictedClassId]);
}else {
log.info("未识别模版");
}
// classifyEntity.setClassProb(String.format("%.4f", outputData[i]));
// float[] output = (float[]) result.get(0).getValue();
// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
return classifyEntity;
} finally {
if (inputTensor != null) {
inputTensor.close();
}
if (result != null) {
result.close();
}
}
}
/**
* yolo11
*
* @param imagePath
* @return
* @throws OrtException
*/
public List<BoundingBox> detect(String imagePath, String type) throws OrtException {
OrtSession session = getSession(type);
AppConfig.YoloModelConfig config = getConfig(type);
// 处理图像
float[] imageData = new float[0];
try {
imageData = processImageFromURL(imagePath, config.getImageSize());
} catch (IOException ex) {
log.error("处理图像时出错: {}", ex.getMessage(), ex);
return new ArrayList<>();
}
// 构建输入张量
long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width
OnnxTensor inputTensor = null;
OrtSession.Result result = null;
try {
inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
// 执行推理
result = session.run(stringOnnxTensorHashMap);
float[][] outputData = ((float[][][]) result.get(0).getValue())[0];
outputData = transposeMatrix(outputData);
// float[] output = (float[]) result.get(0).getValue();
// 解析推理结果
List<BoundingBox> detections = parseOutput(outputData, config.getConfThreshold(), config);
List<BoundingBox> filteredDetections = nonMaxSuppression(detections, 0.5f); // IoU 阈值设为 0.5
System.out.println(filteredDetections);
return filteredDetections;
} finally {
if (inputTensor != null) {
inputTensor.close();
}
if (result != null) {
result.close();
}
}
}
private static List<BoundingBox> parseOutput(float[][] outputData, double confThreshold, AppConfig.YoloModelConfig config) {
List<BoundingBox> detections = new ArrayList<>();
for (float[] bbox : outputData) {
float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 4, bbox.length);
int label = argmax(conditionalProbabilities, confThreshold);
// 如果没有找到满足阈值的类别,跳过这个检测框
if (label == -1) {
continue;
}
float conf = conditionalProbabilities[label];
bbox[4] = conf;
// xywh to (x1, y1, x2, y2)
float x = bbox[0];
float y = bbox[1];
float w = bbox[2];
float h = bbox[3];
bbox[0] = x - w * 0.5f;
bbox[1] = y - h * 0.5f;
bbox[2] = x + w * 0.5f;
bbox[3] = y + h * 0.5f;
if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) continue;
BoundingBox detection = new BoundingBox();
detection.setX(x);
detection.setY(y);
detection.setW(w);
detection.setH(h);
detection.setConfidence(conf);
detection.setIndex(label);
detection.setName(config.getNames()[label]);
detections.add(detection);
}
return detections;
}
//返回最大值的索引
public static int argmax(float[] a, double threshold) {
float re = -Float.MAX_VALUE;
int arg = -1;
for (int i = 0; i < a.length; i++) {
if (a[i] < threshold) continue;
if (a[i] >= re) {
re = a[i];
arg = i;
}
}
return arg;
}
/**
*
*/
@jakarta.annotation.PreDestroy
public void destroy() {
log.info("开始清理 ONNX 资源...");
for (Map.Entry<String, OrtSession> entry : ortSessions.entrySet()) {
try {
entry.getValue().close();
log.info("已关闭 Session: {}", entry.getKey());
} catch (OrtException e) {
log.error("关闭 Session 时出错: {}", entry.getKey(), e);
}
}
ortSessions.clear();
ortMap.clear();
log.info("ONNX 资源清理完成");
}
}
Loading…
Cancel
Save