diff --git a/pom.xml b/pom.xml
index 00e3152..6f544ce 100644
--- a/pom.xml
+++ b/pom.xml
@@ -77,7 +77,11 @@
hutool-all
5.8.16
-
+
+ com.microsoft.onnxruntime
+ onnxruntime
+ 1.16.0
+
MvCameraControlWrapper
MvCameraControlWrapper
diff --git a/src/main/java/com/example/lxcameraapi/controller/HikController.java b/src/main/java/com/example/lxcameraapi/controller/HikController.java
index 2ae4462..bc5f724 100644
--- a/src/main/java/com/example/lxcameraapi/controller/HikController.java
+++ b/src/main/java/com/example/lxcameraapi/controller/HikController.java
@@ -175,6 +175,65 @@ public class HikController {
return pojo;
}
+ // 识别
+ @GetMapping("/distinguishOnnx")
+ @ResponseBody
+ public Pojo distinguishOnnx(String streetNumber,int direction, String category) throws IOException {
+ String picPath = appConfig.getPicPath();
+ String path =streetNumber+"\\"+direction+"_"+ UUID.randomUUID() +".jpg";
+ Mat img2 = null;
+ for (AppConfig.Camera camera : appConfig.getHikCamera()){
+ if (camera.getStreetId().equals(streetNumber)&& camera.getDirection()==direction) {
+ boolean i = hikSaveImage.saveImage(camera.getIp(),picPath+ path, "ip");
+
+ img2 = Imgcodecs.imread(picPath+path);;
+ img2 = Calibration.performCalibration(img2, camera.getId(), appConfig.getActualRectangularRatio());
+ }
+ }
+
+
+ // 进行标定,保存图片
+ Imgcodecs.imwrite(picPath+path+".jpg", Objects.requireNonNull(img2));
+
+ // 指定文件夹路径
+ String folderPath = appConfig.templatePath+category;
+
+ // 创建File对象
+ File folder = new File(folderPath);
+ boolean flag = false;
+ // 判断是否是文件夹并且存在
+ if (folder.exists() && folder.isDirectory()) {
+ // 获取文件夹中的所有文件和子文件夹
+ File[] files = folder.listFiles();
+ int i= 0;
+ // 判断文件夹中是否有文件
+ if (files != null && files.length > 0) {
+ for (File file : files) {
+ i++;
+ // 判断是否是文件
+ if (file.isFile()) {
+ // 获取文件名并进行判断
+ Mat img1 = Imgcodecs.imread(file.getAbsolutePath());
+ if(FeatureMatchingExample.matchTemplate(img1, img2,picPath+path+".jpg.jpg",appConfig.getThreshold())){
+ flag = true;
+
+ }else {
+ img2 = Imgcodecs.imread(picPath+path+".jpg.jpg");
+ }
+ }
+ }
+ } else {
+ System.out.println("The folder is empty.");
+ }
+ } else {
+ System.out.println("The specified path is not a valid directory or does not exist.");
+ }
+ Pojo pojo = new Pojo();
+ pojo.setResult(flag);
+ pojo.setDeterminePath(path+".jpg.jpg");
+ return pojo;
+ }
+
// 标定
@GetMapping("/task")
@ResponseBody
diff --git a/src/main/java/com/example/lxcameraapi/service/IndustrialCamera/algorithm/ONNXServiceNew.java b/src/main/java/com/example/lxcameraapi/service/IndustrialCamera/algorithm/ONNXServiceNew.java
new file mode 100644
index 0000000..20209c9
--- /dev/null
+++ b/src/main/java/com/example/lxcameraapi/service/IndustrialCamera/algorithm/ONNXServiceNew.java
@@ -0,0 +1,705 @@
+package com.example.lxcameraapi.service.IndustrialCamera.algorithm;
+
+import com.example.lxcameraapi.conf.AppConfig;
+import com.example.lxcameraapi.service.IndustrialCamera.yolo.BoundingBox;
+import com.example.lxcameraapi.service.IndustrialCamera.yolo.ClassifyEntity;
+import lombok.extern.slf4j.Slf4j;
+
+import ai.onnxruntime.*;
+import org.opencv.core.Mat;
+import org.opencv.imgcodecs.Imgcodecs;
+import org.springframework.http.ResponseEntity;
+import org.springframework.stereotype.Service;
+import org.springframework.web.multipart.MultipartFile;
+
+import javax.annotation.PostConstruct;
+import javax.annotation.Resource;
+import javax.imageio.ImageIO;
+import java.awt.*;
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.io.IOException;
+import java.net.URL;
+import java.nio.FloatBuffer;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.*;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+@Slf4j
+@Service
+public class ONNXServiceNew {
+
+
+ @Resource
+ AppConfig configProperties;
+
+
+ private static final Map ortSessions = new ConcurrentHashMap<>();
+
+ private static final Map ortMap = new ConcurrentHashMap<>();
+
+ private static final OrtEnvironment environment = OrtEnvironment.getEnvironment();
+
+ public static void main(String[] args) {
+
+ System.load(new File(System.getProperty("user.dir")+"\\libs\\opencv\\opencv_java480.dll").getAbsolutePath());
+ ortMap.put("a", new AppConfig.YoloModelConfig());
+ ortMap.get("a").setImageSize(2048);
+ ortMap.get("a").setConfThreshold(0.5f);
+ ortMap.get("a").setNames(new String[]{"0"});
+
+ int imageSize = 2048;
+ String imagePath = "D:\\data\\1776157002220.png";
+
+ String modelPath = "D:\\data\\best.onnx";
+// List name = Arrays.asList("0143", "0153", "0173", "0177", "0191", "0253", "0256", "0266", "0268", "0286", "0302", "0304", "0305", "0307", "0320", "0326", "0336", "0339", "0343", "0352", "0458", "0461", "0462", "0473", "0477", "0486", "0490", "0492", "0930", "1101", "1102", "1104", "1262", "1269", "1302", "1308", "1359", "1366", "1622", "1625", "1919", "1976", "20", "2165", "2188", "2210", "2224", "2445", "2476", "2611", "2730", "2731", "2910", "2914", "2943", "3027", "3028", "3029", "3212", "3226", "3344", "3501", "3509", "3538", "3725", "3741", "3751", "3754", "3763", "3766", "3808");
+
+ OrtSession.SessionOptions sessionOptions = null;
+
+ // 使用gpu,需要本机按钻过cuda,并修改pom.xml,不安装也能运行本程序
+ // sessionOptions.addCUDA(0);
+ // 实际项目中,视频识别必须开启GPU,并且要防止队列堆积
+
+ OrtSession session = null;
+ ONNXServiceNew onnxServiceNew = new ONNXServiceNew();
+ try {
+ sessionOptions = new OrtSession.SessionOptions();
+ session = environment.createSession(modelPath, sessionOptions);
+ ortSessions.put("a", session);
+ List boxes = onnxServiceNew.detect(imagePath, "a");
+
+// classifyEntity.setIndex(predictedClassId);
+// classifyEntity.setName(config.getNames()[predictedClassId]);
+// classifyEntity.setConfidence(outputData[predictedClassId]);
+// classifyEntity.setClassProb(String.format("%.4f", outputData[i]));
+
+
+// float[] output = (float[]) result.get(0).getValue();
+
+// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
+
+ System.out.println(boxes);
+// System.out.println(name.get(predictedClassId));
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ } finally {
+ if (sessionOptions != null) {
+
+ sessionOptions.close();
+ }
+ }
+
+// // 处理图像
+// float[] imageData = new float[0];
+// try {
+// imageData = processImageFromURL(imagePath, imageSize);
+// } catch (IOException ex) {
+// throw new RuntimeException(ex);
+// }
+//
+// // 构建输入张量
+// long[] shape = new long[]{1, 3, imageSize, imageSize}; // batch_size, channels, height, width
+// OnnxTensor inputTensor = null;
+// try {
+// inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
+//
+// HashMap stringOnnxTensorHashMap = new HashMap<>();
+// stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
+//
+// // 执行推理
+//// OrtSession.Result result = session.run(stringOnnxTensorHashMap);
+////
+//// float[][] outputData = ((float[][][])result.get(0).getValue())[0];
+////
+//// outputData = transposeMatrix(outputData);
+////
+////// float[] output = (float[]) result.get(0).getValue();
+////
+//// // 解析推理结果
+//// List detections = parseOutput(outputData,0.4);
+////
+//// List filteredDetections = nonMaxSuppression(detections, 0.5f); // IoU 阈值设为 0.5
+//// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
+//// System.out.println(filteredDetections);
+//// int mostFrequentClassId = getMaxClassIndex(filteredDetections);
+//// System.out.println(calculateLayers(filteredDetections, 200,mostFrequentClassId));
+//
+// // 执行推理
+// OrtSession.Result result = session.run(stringOnnxTensorHashMap);
+//// 获取第一个输出(大多数情况下只有一个输出)
+// float[] outputData = ((float[][]) result.get(0).getValue())[0];
+//
+//// 获取最大概率对应的类别 ID
+// int predictedClassId = argmax(outputData, 0.8);
+//
+//// float[] output = (float[]) result.get(0).getValue();
+//
+//// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
+// ClassifyEntity classifyEntity = new ClassifyEntity();
+// classifyEntity.setIndex(predictedClassId);
+// } catch (OrtException e) {
+// throw new RuntimeException(e);
+// }
+
+ }
+
+ public static int getMaxClassIndex(List boxes) {
+ return boxes.stream()
+ .collect(Collectors.groupingBy(BoundingBox::getIndex, Collectors.counting()))
+ .entrySet().stream()
+ .max(Map.Entry.comparingByValue())
+ .map(Map.Entry::getKey)
+ .orElse(-1);
+ }
+
+ public static Set getMinTypeClassIndex(List boxes) {
+ return boxes.stream()
+ .map(BoundingBox::getIndex)
+ .collect(Collectors.toSet());
+ }
+
+ /**
+ * 计算层数,当两个 y 值之差小于 threshold 时视为同一层
+ */
+ public static int calculateLayers(List boxes, double threshold, int mostFrequentClassId) {
+ if (boxes == null || boxes.isEmpty()) {
+ return 0;
+ }
+//出现次数做多的classId
+
+ boxes = boxes.stream()
+ .filter(box -> box.getIndex() == 0)
+ .collect(Collectors.toList());
+
+ // 检查过滤后的boxes是否为空
+ if (boxes.isEmpty()) {
+ return 0;
+ }
+ List yValues = new ArrayList<>();
+ for (BoundingBox box : boxes) {
+ yValues.add(box.getY());
+ }
+ Collections.sort(yValues); // 先排序
+
+ int layers = 1; // 至少有一层
+ double currentBase = yValues.get(0);
+
+ for (int i = 1; i < yValues.size(); i++) {
+ if (Math.abs(yValues.get(i) - currentBase) > threshold) {
+ layers++;
+ currentBase = yValues.get(i);
+ }
+ }
+
+ return layers;
+ }
+
+ public static float[][] transposeMatrix(float[][] m) {
+ float[][] temp = new float[m[0].length][m.length];
+ for (int i = 0; i < m.length; i++)
+ for (int j = 0; j < m[0].length; j++)
+ temp[j][i] = m[i][j];
+ return temp;
+ }
+ // 图像文件处理
+//
+// // 处理图片URL
+// private static float[] processImageFromURL(String imageUrl, int imageSize) throws IOException {
+// File file = new File(imageUrl);
+// BufferedImage bufferedImage = ImageIO.read(file);
+// // 假设我们需要将图片缩放为 imageSize*imageSize,并转换为一个符合 YOLO 输入要求的 float 数组
+// BufferedImage resizedImage = new BufferedImage(imageSize, imageSize, BufferedImage.TYPE_INT_RGB);
+// resizedImage.getGraphics().drawImage(bufferedImage, 0, 0, imageSize, imageSize, null);
+//
+// // 将图片转换为 float 数组(假设 RGB 格式)
+// float[] imageData = new float[imageSize * imageSize * 3];
+// int index = 0;
+// int[] rData = new int[imageSize * imageSize];
+// int[] gData = new int[imageSize * imageSize];
+// int[] bData = new int[imageSize * imageSize];
+//
+// int idx = 0;
+// for (int y = 0; y < imageSize; y++) {
+// for (int x = 0; x < imageSize; x++) {
+// int rgb = resizedImage.getRGB(x, y);
+// rData[idx] = (rgb >> 16) & 0xFF;
+// gData[idx] = (rgb >> 8) & 0xFF;
+// bData[idx] = rgb & 0xFF;
+// idx++;
+// }
+// }
+//// 填充为 CHW
+// idx = 0;
+// for (int i = 0; i < imageSize * imageSize; i++) {
+// imageData[idx++] = rData[i] / 255.0f;
+// }
+// for (int i = 0; i < imageSize * imageSize; i++) {
+// imageData[idx++] = gData[i] / 255.0f;
+// }
+// for (int i = 0; i < imageSize * imageSize; i++) {
+// imageData[idx++] = bData[i] / 255.0f;
+// }
+// return imageData;
+// }
+// 处理图片URL
+ private static float[] processImageFromURL(String imageUrl, int imageSize) throws IOException {
+ // 从URL读取图片
+ BufferedImage originalImage;
+ if (imageUrl.startsWith("http")) {
+ URL url = new URL(imageUrl);
+ originalImage = ImageIO.read(url);
+ } else {
+ File file = new File(imageUrl);
+ originalImage = ImageIO.read(file);
+ }
+
+ // 创建目标尺寸的图片
+ BufferedImage resizedImage = new BufferedImage(imageSize, imageSize, BufferedImage.TYPE_INT_RGB);
+ Graphics2D g2d = resizedImage.createGraphics();
+
+ // 设置背景为黑色
+ g2d.setColor(Color.BLACK);
+ g2d.fillRect(0, 0, imageSize, imageSize);
+
+ // 计算缩放比例,保持宽高比
+ int originalWidth = originalImage.getWidth();
+ int originalHeight = originalImage.getHeight();
+
+ // 按照长边计算缩放比例
+ double scale = Math.min((double) imageSize / originalWidth, (double) imageSize / originalHeight);
+
+ // 计算缩放后的尺寸
+ int scaledWidth = (int) (originalWidth * scale);
+ int scaledHeight = (int) (originalHeight * scale);
+
+ // 计算居中位置
+ int x = (imageSize - scaledWidth) / 2;
+ int y = (imageSize - scaledHeight) / 2;
+
+ // 绘制缩放后的图片
+ g2d.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
+ g2d.drawImage(originalImage, x, y, scaledWidth, scaledHeight, null);
+ g2d.dispose();
+
+ // 将图片转换为 float 数组(RGB 格式,归一化到 0-1)
+ float[] imageData = new float[imageSize * imageSize * 3];
+ int index = 0;
+
+ // 按照CHW格式排列数据(通道-高度-宽度)
+ // 先处理所有像素的红色通道
+ for (int i = 0; i < imageSize; i++) {
+ for (int j = 0; j < imageSize; j++) {
+ int rgb = resizedImage.getRGB(j, i);
+ imageData[index++] = ((rgb >> 16) & 0xFF) / 255.0f; // R
+ }
+ }
+
+ // 再处理绿色通道
+ for (int i = 0; i < imageSize; i++) {
+ for (int j = 0; j < imageSize; j++) {
+ int rgb = resizedImage.getRGB(j, i);
+ imageData[index++] = ((rgb >> 8) & 0xFF) / 255.0f; // G
+ }
+ }
+
+ // 最后处理蓝色通道
+ for (int i = 0; i < imageSize; i++) {
+ for (int j = 0; j < imageSize; j++) {
+ int rgb = resizedImage.getRGB(j, i);
+ imageData[index++] = (rgb & 0xFF) / 255.0f; // B
+ }
+ }
+
+ return imageData;
+ }
+
+ public static List nonMaxSuppression(List boxes, float iouThreshold) {
+ if (boxes.isEmpty()) return new ArrayList<>();
+
+ // 按置信度从高到低排序
+ boxes.sort((a, b) -> Float.compare((float) b.getConfidence(), (float) a.getConfidence()));
+
+ List result = new ArrayList<>();
+ while (!boxes.isEmpty()) {
+ BoundingBox best = boxes.remove(0);
+ result.add(best);
+
+ // 过滤掉和 best IoU 太大的框
+ boxes.removeIf(box -> computeIoU(best, box) > iouThreshold);
+ }
+
+ return result;
+ }
+
+ public static float computeIoU(BoundingBox box1, BoundingBox box2) {
+ double x1 = Math.max(box1.getX() - box1.getW() / 2, box2.getX() - box2.getW() / 2);
+ double y1 = Math.max(box1.getY() - box1.getH() / 2, box2.getY() - box2.getH() / 2);
+ double x2 = Math.min(box1.getX() + box1.getW() / 2, box2.getX() + box2.getW() / 2);
+ double y2 = Math.min(box1.getY() + box1.getH() / 2, box2.getY() + box2.getH() / 2);
+
+ double w = Math.max(0, x2 - x1);
+ double h = Math.max(0, y2 - y1);
+ double intersection = w * h;
+
+ double area1 = (box1.getW() * box1.getH());
+ double area2 = (box2.getW() * box2.getH());
+
+ return (float) (intersection / (area1 + area2 - intersection + 1e-6));
+ }
+
+
+
+ @PostConstruct
+ public void initSession() {
+ OrtSession.SessionOptions sessionOptions = null;
+ try {
+ sessionOptions = new OrtSession.SessionOptions();
+ // 使用gpu,需要本机按钻过cuda,并修改pom.xml,不安装也能运行本程序
+ // sessionOptions.addCUDA(0);
+ // 实际项目中,视频识别必须开启GPU,并且要防止队列堆积
+// 大模型
+ OrtSession session = null;
+ if (configProperties.getYoloModelConfig().getModelPath() != null) {
+
+ session = environment.createSession(configProperties.getYoloModelConfig().getModelPath(), sessionOptions);
+
+ OrtSession finalSession = session;
+ session.getInputInfo().keySet().forEach(x -> {
+ System.out.println("input name = " + x);
+ try {
+ TensorInfo tensorInfo = (TensorInfo) finalSession.getInputInfo().get(x).getInfo();
+ long[] shape = tensorInfo.getShape();
+ // 假设形状是 [batch, channels, height, width]
+ if (shape.length >= 4) {
+ configProperties.getYoloModelConfig().setImageSize((int)shape[3]); // width
+ }
+ System.out.println(tensorInfo.toString());
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+
+ });
+ ortSessions.put("model", session);
+ ortMap.put("model", configProperties.getYoloModelConfig());
+
+ }
+
+ if (configProperties.getYoloModelConfig().getModelMap() != null) {
+
+ OrtSession.SessionOptions finalSessionOptions = sessionOptions;
+ configProperties.getYoloModelConfig().getModelMap().forEach((k, v) -> {
+ try {
+// 判断v文件是否存在
+ if (!new File(v.getModelPath()).exists()) {
+ log.error("文件不存在:{}", v);
+ return;
+ }
+
+ OrtSession session1 = environment.createSession(v.getModelPath(), finalSessionOptions);
+
+ session1.getInputInfo().keySet().forEach(x -> {
+ System.out.println("input name = " + x);
+ try {
+ TensorInfo tensorInfo = (TensorInfo) session1.getInputInfo().get(x).getInfo();
+ long[] shape = tensorInfo.getShape();
+ // 假设形状是 [batch, channels, height, width]
+ if (shape.length >= 4) {
+ v.setImageSize((int)shape[3]); // width
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+
+ });
+ ortSessions.put(k, session1);
+ ortMap.put(k, v);
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+ } catch (OrtException e) {
+ throw new RuntimeException(e);
+ }
+ // sessionOptions 不在这里关闭,它的生命周期应该与 Session 一致
+ }
+
+ private OrtSession getSession(String type) {
+ if (type != null && !type.equals("")) {
+ return ortSessions.get(type);
+ }
+ return ortSessions.get("model");
+ }
+
+ private AppConfig.YoloModelConfig getConfig(String type) {
+ if (type != null && !type.equals("")) {
+ return ortMap.get(type);
+ }
+ return ortMap.get("model");
+ }
+
+ /**
+ * yolo11计算图片有多少个识别到的模版
+ *
+ * @param imagePath 输入的图片位置
+ * @return
+ * @throws OrtException
+ */
+ public List classify(String imagePath, String type) throws OrtException {
+
+ OrtSession session = getSession(type);
+ AppConfig.YoloModelConfig config = getConfig(type);
+ // 处理图像
+ float[] imageData = new float[0];
+ try {
+ imageData = processImageFromURL(imagePath, config.getImageSize());
+ } catch (IOException ex) {
+ throw new RuntimeException(ex);
+ }
+
+ // 构建输入张量
+ long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width
+ OnnxTensor inputTensor = null;
+ OrtSession.Result result = null;
+ try {
+ inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
+
+ HashMap stringOnnxTensorHashMap = new HashMap<>();
+ stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
+
+ // 执行推理
+ result = session.run(stringOnnxTensorHashMap);
+ // 获取第一个输出(大多数情况下只有一个输出)
+ float[] outputData = ((float[][]) result.get(0).getValue())[0];
+ List classifyEntities = new ArrayList<>();
+ for (int i = 0; i < outputData.length; i += 5) {
+ ClassifyEntity classifyEntity = new ClassifyEntity();
+ classifyEntity.setIndex(i);
+ classifyEntity.setName(config.getNames()[i]);
+ classifyEntity.setConfidence(outputData[i]);
+ // classifyEntity.setClassProb(String.format("%.4f", outputData[i]));
+ classifyEntities.add(classifyEntity);
+ }
+
+ // 获取最大概率对应的类别 ID
+ // int predictedClassId = argmax(outputData, configProperties.getYoloModelConfig().getConfThreshold());
+
+ // float[] output = (float[]) result.get(0).getValue();
+
+ // DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
+
+ return classifyEntities;
+ } finally {
+ if (inputTensor != null) {
+ inputTensor.close();
+ }
+ if (result != null) {
+ result.close();
+ }
+ }
+ }
+
+
+ /**
+ * yolo11计算图片有多少个识别到的模版
+ *
+ * @param imagePath 输入的图片位置
+ * @return
+ * @throws OrtException
+ */
+ public ClassifyEntity classifyOne(String imagePath, String type) throws OrtException {
+
+ OrtSession session = getSession(type);
+ AppConfig.YoloModelConfig config = getConfig(type);
+ // 处理图像
+ float[] imageData = new float[0];
+ try {
+ imageData = processImageFromURL(imagePath, config.getImageSize());
+ } catch (IOException ex) {
+ throw new RuntimeException(ex);
+ }
+
+ // 构建输入张量
+ long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width
+ OnnxTensor inputTensor = null;
+ OrtSession.Result result = null;
+ try {
+ inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
+
+ HashMap stringOnnxTensorHashMap = new HashMap<>();
+ stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
+
+ // 执行推理
+ result = session.run(stringOnnxTensorHashMap);
+ // 获取第一个输出(大多数情况下只有一个输出)
+ float[] outputData = ((float[][]) result.get(0).getValue())[0];
+
+ // 获取最大概率对应的类别 ID
+ int predictedClassId = argmax(outputData, configProperties.getYoloModelConfig().getConfThreshold());
+ ClassifyEntity classifyEntity = new ClassifyEntity();
+ classifyEntity.setIndex(predictedClassId);
+ if (predictedClassId!=-1){
+ log.info("识别模版:{}",config.getNames()[predictedClassId]);
+ classifyEntity.setName(config.getNames()[predictedClassId]);
+ classifyEntity.setConfidence(outputData[predictedClassId]);
+ }else {
+ log.info("未识别模版");
+ }
+ // classifyEntity.setClassProb(String.format("%.4f", outputData[i]));
+
+
+ // float[] output = (float[]) result.get(0).getValue();
+
+ // DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
+
+ return classifyEntity;
+ } finally {
+ if (inputTensor != null) {
+ inputTensor.close();
+ }
+ if (result != null) {
+ result.close();
+ }
+ }
+ }
+
+ /**
+ * yolo11计算图片有多少个识别到的模版
+ *
+ * @param imagePath 输入的图片位置
+ * @return
+ * @throws OrtException
+ */
+ public List detect(String imagePath, String type) throws OrtException {
+
+ OrtSession session = getSession(type);
+ AppConfig.YoloModelConfig config = getConfig(type);
+ // 处理图像
+ float[] imageData = new float[0];
+ try {
+ imageData = processImageFromURL(imagePath, config.getImageSize());
+ } catch (IOException ex) {
+ log.error("处理图像时出错: {}", ex.getMessage(), ex);
+ return new ArrayList<>();
+ }
+
+ // 构建输入张量
+ long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width
+ OnnxTensor inputTensor = null;
+ OrtSession.Result result = null;
+ try {
+ inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
+
+ HashMap stringOnnxTensorHashMap = new HashMap<>();
+ stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
+
+ // 执行推理
+ result = session.run(stringOnnxTensorHashMap);
+
+ float[][] outputData = ((float[][][]) result.get(0).getValue())[0];
+
+ outputData = transposeMatrix(outputData);
+
+ // float[] output = (float[]) result.get(0).getValue();
+
+ // 解析推理结果
+ List detections = parseOutput(outputData, config.getConfThreshold(), config);
+
+
+ List filteredDetections = nonMaxSuppression(detections, 0.5f); // IoU 阈值设为 0.5
+
+ System.out.println(filteredDetections);
+ return filteredDetections;
+ } finally {
+ if (inputTensor != null) {
+ inputTensor.close();
+ }
+ if (result != null) {
+ result.close();
+ }
+ }
+ }
+
+ private static List parseOutput(float[][] outputData, double confThreshold, AppConfig.YoloModelConfig config) {
+ List detections = new ArrayList<>();
+
+ for (float[] bbox : outputData) {
+
+
+ float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 4, bbox.length);
+ int label = argmax(conditionalProbabilities, confThreshold);
+
+ // 如果没有找到满足阈值的类别,跳过这个检测框
+ if (label == -1) {
+ continue;
+ }
+
+ float conf = conditionalProbabilities[label];
+
+ bbox[4] = conf;
+
+ // xywh to (x1, y1, x2, y2)
+ float x = bbox[0];
+ float y = bbox[1];
+ float w = bbox[2];
+ float h = bbox[3];
+
+ bbox[0] = x - w * 0.5f;
+ bbox[1] = y - h * 0.5f;
+ bbox[2] = x + w * 0.5f;
+ bbox[3] = y + h * 0.5f;
+
+ if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) continue;
+ BoundingBox detection = new BoundingBox();
+ detection.setX(x);
+ detection.setY(y);
+ detection.setW(w);
+ detection.setH(h);
+ detection.setConfidence(conf);
+ detection.setIndex(label);
+ detection.setName(config.getNames()[label]);
+
+ detections.add(detection);
+ }
+
+ return detections;
+ }
+
+ //返回最大值的索引
+ public static int argmax(float[] a, double threshold) {
+ float re = -Float.MAX_VALUE;
+ int arg = -1;
+ for (int i = 0; i < a.length; i++) {
+ if (a[i] < threshold) continue;
+ if (a[i] >= re) {
+ re = a[i];
+ arg = i;
+ }
+ }
+ return arg;
+ }
+
+ /**
+ * 销毁时清理资源
+ */
+ @jakarta.annotation.PreDestroy
+ public void destroy() {
+ log.info("开始清理 ONNX 资源...");
+ for (Map.Entry entry : ortSessions.entrySet()) {
+ try {
+ entry.getValue().close();
+ log.info("已关闭 Session: {}", entry.getKey());
+ } catch (OrtException e) {
+ log.error("关闭 Session 时出错: {}", entry.getKey(), e);
+ }
+ }
+ ortSessions.clear();
+ ortMap.clear();
+ log.info("ONNX 资源清理完成");
+ }
+}