diff --git a/pom.xml b/pom.xml index 00e3152..6f544ce 100644 --- a/pom.xml +++ b/pom.xml @@ -77,7 +77,11 @@ hutool-all 5.8.16 - + + com.microsoft.onnxruntime + onnxruntime + 1.16.0 + MvCameraControlWrapper MvCameraControlWrapper diff --git a/src/main/java/com/example/lxcameraapi/controller/HikController.java b/src/main/java/com/example/lxcameraapi/controller/HikController.java index 2ae4462..bc5f724 100644 --- a/src/main/java/com/example/lxcameraapi/controller/HikController.java +++ b/src/main/java/com/example/lxcameraapi/controller/HikController.java @@ -175,6 +175,65 @@ public class HikController { return pojo; } + // 识别 + @GetMapping("/distinguishOnnx") + @ResponseBody + public Pojo distinguishOnnx(String streetNumber,int direction, String category) throws IOException { + String picPath = appConfig.getPicPath(); + String path =streetNumber+"\\"+direction+"_"+ UUID.randomUUID() +".jpg"; + Mat img2 = null; + for (AppConfig.Camera camera : appConfig.getHikCamera()){ + if (camera.getStreetId().equals(streetNumber)&& camera.getDirection()==direction) { + boolean i = hikSaveImage.saveImage(camera.getIp(),picPath+ path, "ip"); + + img2 = Imgcodecs.imread(picPath+path);; + img2 = Calibration.performCalibration(img2, camera.getId(), appConfig.getActualRectangularRatio()); + } + } + + + // 进行标定,保存图片 + Imgcodecs.imwrite(picPath+path+".jpg", Objects.requireNonNull(img2)); + + // 指定文件夹路径 + String folderPath = appConfig.templatePath+category; + + // 创建File对象 + File folder = new File(folderPath); + boolean flag = false; + // 判断是否是文件夹并且存在 + if (folder.exists() && folder.isDirectory()) { + // 获取文件夹中的所有文件和子文件夹 + File[] files = folder.listFiles(); + int i= 0; + // 判断文件夹中是否有文件 + if (files != null && files.length > 0) { + for (File file : files) { + i++; + // 判断是否是文件 + if (file.isFile()) { + // 获取文件名并进行判断 + Mat img1 = Imgcodecs.imread(file.getAbsolutePath()); + if(FeatureMatchingExample.matchTemplate(img1, img2,picPath+path+".jpg.jpg",appConfig.getThreshold())){ + flag = true; + + }else { + img2 = Imgcodecs.imread(picPath+path+".jpg.jpg"); + } + } + } + } else { + System.out.println("The folder is empty."); + } + } else { + System.out.println("The specified path is not a valid directory or does not exist."); + } + Pojo pojo = new Pojo(); + pojo.setResult(flag); + pojo.setDeterminePath(path+".jpg.jpg"); + return pojo; + } + // 标定 @GetMapping("/task") @ResponseBody diff --git a/src/main/java/com/example/lxcameraapi/service/IndustrialCamera/algorithm/ONNXServiceNew.java b/src/main/java/com/example/lxcameraapi/service/IndustrialCamera/algorithm/ONNXServiceNew.java new file mode 100644 index 0000000..20209c9 --- /dev/null +++ b/src/main/java/com/example/lxcameraapi/service/IndustrialCamera/algorithm/ONNXServiceNew.java @@ -0,0 +1,705 @@ +package com.example.lxcameraapi.service.IndustrialCamera.algorithm; + +import com.example.lxcameraapi.conf.AppConfig; +import com.example.lxcameraapi.service.IndustrialCamera.yolo.BoundingBox; +import com.example.lxcameraapi.service.IndustrialCamera.yolo.ClassifyEntity; +import lombok.extern.slf4j.Slf4j; + +import ai.onnxruntime.*; +import org.opencv.core.Mat; +import org.opencv.imgcodecs.Imgcodecs; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Service; +import org.springframework.web.multipart.MultipartFile; + +import javax.annotation.PostConstruct; +import javax.annotation.Resource; +import javax.imageio.ImageIO; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.File; +import java.io.IOException; +import java.net.URL; +import java.nio.FloatBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +@Slf4j +@Service +public class ONNXServiceNew { + + + @Resource + AppConfig configProperties; + + + private static final Map ortSessions = new ConcurrentHashMap<>(); + + private static final Map ortMap = new ConcurrentHashMap<>(); + + private static final OrtEnvironment environment = OrtEnvironment.getEnvironment(); + + public static void main(String[] args) { + + System.load(new File(System.getProperty("user.dir")+"\\libs\\opencv\\opencv_java480.dll").getAbsolutePath()); + ortMap.put("a", new AppConfig.YoloModelConfig()); + ortMap.get("a").setImageSize(2048); + ortMap.get("a").setConfThreshold(0.5f); + ortMap.get("a").setNames(new String[]{"0"}); + + int imageSize = 2048; + String imagePath = "D:\\data\\1776157002220.png"; + + String modelPath = "D:\\data\\best.onnx"; +// List name = Arrays.asList("0143", "0153", "0173", "0177", "0191", "0253", "0256", "0266", "0268", "0286", "0302", "0304", "0305", "0307", "0320", "0326", "0336", "0339", "0343", "0352", "0458", "0461", "0462", "0473", "0477", "0486", "0490", "0492", "0930", "1101", "1102", "1104", "1262", "1269", "1302", "1308", "1359", "1366", "1622", "1625", "1919", "1976", "20", "2165", "2188", "2210", "2224", "2445", "2476", "2611", "2730", "2731", "2910", "2914", "2943", "3027", "3028", "3029", "3212", "3226", "3344", "3501", "3509", "3538", "3725", "3741", "3751", "3754", "3763", "3766", "3808"); + + OrtSession.SessionOptions sessionOptions = null; + + // 使用gpu,需要本机按钻过cuda,并修改pom.xml,不安装也能运行本程序 + // sessionOptions.addCUDA(0); + // 实际项目中,视频识别必须开启GPU,并且要防止队列堆积 + + OrtSession session = null; + ONNXServiceNew onnxServiceNew = new ONNXServiceNew(); + try { + sessionOptions = new OrtSession.SessionOptions(); + session = environment.createSession(modelPath, sessionOptions); + ortSessions.put("a", session); + List boxes = onnxServiceNew.detect(imagePath, "a"); + +// classifyEntity.setIndex(predictedClassId); +// classifyEntity.setName(config.getNames()[predictedClassId]); +// classifyEntity.setConfidence(outputData[predictedClassId]); +// classifyEntity.setClassProb(String.format("%.4f", outputData[i])); + + +// float[] output = (float[]) result.get(0).getValue(); + +// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections); + + System.out.println(boxes); +// System.out.println(name.get(predictedClassId)); + } catch (OrtException e) { + throw new RuntimeException(e); + } finally { + if (sessionOptions != null) { + + sessionOptions.close(); + } + } + +// // 处理图像 +// float[] imageData = new float[0]; +// try { +// imageData = processImageFromURL(imagePath, imageSize); +// } catch (IOException ex) { +// throw new RuntimeException(ex); +// } +// +// // 构建输入张量 +// long[] shape = new long[]{1, 3, imageSize, imageSize}; // batch_size, channels, height, width +// OnnxTensor inputTensor = null; +// try { +// inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape); +// +// HashMap stringOnnxTensorHashMap = new HashMap<>(); +// stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor); +// +// // 执行推理 +//// OrtSession.Result result = session.run(stringOnnxTensorHashMap); +//// +//// float[][] outputData = ((float[][][])result.get(0).getValue())[0]; +//// +//// outputData = transposeMatrix(outputData); +//// +////// float[] output = (float[]) result.get(0).getValue(); +//// +//// // 解析推理结果 +//// List detections = parseOutput(outputData,0.4); +//// +//// List filteredDetections = nonMaxSuppression(detections, 0.5f); // IoU 阈值设为 0.5 +//// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections); +//// System.out.println(filteredDetections); +//// int mostFrequentClassId = getMaxClassIndex(filteredDetections); +//// System.out.println(calculateLayers(filteredDetections, 200,mostFrequentClassId)); +// +// // 执行推理 +// OrtSession.Result result = session.run(stringOnnxTensorHashMap); +//// 获取第一个输出(大多数情况下只有一个输出) +// float[] outputData = ((float[][]) result.get(0).getValue())[0]; +// +//// 获取最大概率对应的类别 ID +// int predictedClassId = argmax(outputData, 0.8); +// +//// float[] output = (float[]) result.get(0).getValue(); +// +//// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections); +// ClassifyEntity classifyEntity = new ClassifyEntity(); +// classifyEntity.setIndex(predictedClassId); +// } catch (OrtException e) { +// throw new RuntimeException(e); +// } + + } + + public static int getMaxClassIndex(List boxes) { + return boxes.stream() + .collect(Collectors.groupingBy(BoundingBox::getIndex, Collectors.counting())) + .entrySet().stream() + .max(Map.Entry.comparingByValue()) + .map(Map.Entry::getKey) + .orElse(-1); + } + + public static Set getMinTypeClassIndex(List boxes) { + return boxes.stream() + .map(BoundingBox::getIndex) + .collect(Collectors.toSet()); + } + + /** + * 计算层数,当两个 y 值之差小于 threshold 时视为同一层 + */ + public static int calculateLayers(List boxes, double threshold, int mostFrequentClassId) { + if (boxes == null || boxes.isEmpty()) { + return 0; + } +//出现次数做多的classId + + boxes = boxes.stream() + .filter(box -> box.getIndex() == 0) + .collect(Collectors.toList()); + + // 检查过滤后的boxes是否为空 + if (boxes.isEmpty()) { + return 0; + } + List yValues = new ArrayList<>(); + for (BoundingBox box : boxes) { + yValues.add(box.getY()); + } + Collections.sort(yValues); // 先排序 + + int layers = 1; // 至少有一层 + double currentBase = yValues.get(0); + + for (int i = 1; i < yValues.size(); i++) { + if (Math.abs(yValues.get(i) - currentBase) > threshold) { + layers++; + currentBase = yValues.get(i); + } + } + + return layers; + } + + public static float[][] transposeMatrix(float[][] m) { + float[][] temp = new float[m[0].length][m.length]; + for (int i = 0; i < m.length; i++) + for (int j = 0; j < m[0].length; j++) + temp[j][i] = m[i][j]; + return temp; + } + // 图像文件处理 +// +// // 处理图片URL +// private static float[] processImageFromURL(String imageUrl, int imageSize) throws IOException { +// File file = new File(imageUrl); +// BufferedImage bufferedImage = ImageIO.read(file); +// // 假设我们需要将图片缩放为 imageSize*imageSize,并转换为一个符合 YOLO 输入要求的 float 数组 +// BufferedImage resizedImage = new BufferedImage(imageSize, imageSize, BufferedImage.TYPE_INT_RGB); +// resizedImage.getGraphics().drawImage(bufferedImage, 0, 0, imageSize, imageSize, null); +// +// // 将图片转换为 float 数组(假设 RGB 格式) +// float[] imageData = new float[imageSize * imageSize * 3]; +// int index = 0; +// int[] rData = new int[imageSize * imageSize]; +// int[] gData = new int[imageSize * imageSize]; +// int[] bData = new int[imageSize * imageSize]; +// +// int idx = 0; +// for (int y = 0; y < imageSize; y++) { +// for (int x = 0; x < imageSize; x++) { +// int rgb = resizedImage.getRGB(x, y); +// rData[idx] = (rgb >> 16) & 0xFF; +// gData[idx] = (rgb >> 8) & 0xFF; +// bData[idx] = rgb & 0xFF; +// idx++; +// } +// } +//// 填充为 CHW +// idx = 0; +// for (int i = 0; i < imageSize * imageSize; i++) { +// imageData[idx++] = rData[i] / 255.0f; +// } +// for (int i = 0; i < imageSize * imageSize; i++) { +// imageData[idx++] = gData[i] / 255.0f; +// } +// for (int i = 0; i < imageSize * imageSize; i++) { +// imageData[idx++] = bData[i] / 255.0f; +// } +// return imageData; +// } +// 处理图片URL + private static float[] processImageFromURL(String imageUrl, int imageSize) throws IOException { + // 从URL读取图片 + BufferedImage originalImage; + if (imageUrl.startsWith("http")) { + URL url = new URL(imageUrl); + originalImage = ImageIO.read(url); + } else { + File file = new File(imageUrl); + originalImage = ImageIO.read(file); + } + + // 创建目标尺寸的图片 + BufferedImage resizedImage = new BufferedImage(imageSize, imageSize, BufferedImage.TYPE_INT_RGB); + Graphics2D g2d = resizedImage.createGraphics(); + + // 设置背景为黑色 + g2d.setColor(Color.BLACK); + g2d.fillRect(0, 0, imageSize, imageSize); + + // 计算缩放比例,保持宽高比 + int originalWidth = originalImage.getWidth(); + int originalHeight = originalImage.getHeight(); + + // 按照长边计算缩放比例 + double scale = Math.min((double) imageSize / originalWidth, (double) imageSize / originalHeight); + + // 计算缩放后的尺寸 + int scaledWidth = (int) (originalWidth * scale); + int scaledHeight = (int) (originalHeight * scale); + + // 计算居中位置 + int x = (imageSize - scaledWidth) / 2; + int y = (imageSize - scaledHeight) / 2; + + // 绘制缩放后的图片 + g2d.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR); + g2d.drawImage(originalImage, x, y, scaledWidth, scaledHeight, null); + g2d.dispose(); + + // 将图片转换为 float 数组(RGB 格式,归一化到 0-1) + float[] imageData = new float[imageSize * imageSize * 3]; + int index = 0; + + // 按照CHW格式排列数据(通道-高度-宽度) + // 先处理所有像素的红色通道 + for (int i = 0; i < imageSize; i++) { + for (int j = 0; j < imageSize; j++) { + int rgb = resizedImage.getRGB(j, i); + imageData[index++] = ((rgb >> 16) & 0xFF) / 255.0f; // R + } + } + + // 再处理绿色通道 + for (int i = 0; i < imageSize; i++) { + for (int j = 0; j < imageSize; j++) { + int rgb = resizedImage.getRGB(j, i); + imageData[index++] = ((rgb >> 8) & 0xFF) / 255.0f; // G + } + } + + // 最后处理蓝色通道 + for (int i = 0; i < imageSize; i++) { + for (int j = 0; j < imageSize; j++) { + int rgb = resizedImage.getRGB(j, i); + imageData[index++] = (rgb & 0xFF) / 255.0f; // B + } + } + + return imageData; + } + + public static List nonMaxSuppression(List boxes, float iouThreshold) { + if (boxes.isEmpty()) return new ArrayList<>(); + + // 按置信度从高到低排序 + boxes.sort((a, b) -> Float.compare((float) b.getConfidence(), (float) a.getConfidence())); + + List result = new ArrayList<>(); + while (!boxes.isEmpty()) { + BoundingBox best = boxes.remove(0); + result.add(best); + + // 过滤掉和 best IoU 太大的框 + boxes.removeIf(box -> computeIoU(best, box) > iouThreshold); + } + + return result; + } + + public static float computeIoU(BoundingBox box1, BoundingBox box2) { + double x1 = Math.max(box1.getX() - box1.getW() / 2, box2.getX() - box2.getW() / 2); + double y1 = Math.max(box1.getY() - box1.getH() / 2, box2.getY() - box2.getH() / 2); + double x2 = Math.min(box1.getX() + box1.getW() / 2, box2.getX() + box2.getW() / 2); + double y2 = Math.min(box1.getY() + box1.getH() / 2, box2.getY() + box2.getH() / 2); + + double w = Math.max(0, x2 - x1); + double h = Math.max(0, y2 - y1); + double intersection = w * h; + + double area1 = (box1.getW() * box1.getH()); + double area2 = (box2.getW() * box2.getH()); + + return (float) (intersection / (area1 + area2 - intersection + 1e-6)); + } + + + + @PostConstruct + public void initSession() { + OrtSession.SessionOptions sessionOptions = null; + try { + sessionOptions = new OrtSession.SessionOptions(); + // 使用gpu,需要本机按钻过cuda,并修改pom.xml,不安装也能运行本程序 + // sessionOptions.addCUDA(0); + // 实际项目中,视频识别必须开启GPU,并且要防止队列堆积 +// 大模型 + OrtSession session = null; + if (configProperties.getYoloModelConfig().getModelPath() != null) { + + session = environment.createSession(configProperties.getYoloModelConfig().getModelPath(), sessionOptions); + + OrtSession finalSession = session; + session.getInputInfo().keySet().forEach(x -> { + System.out.println("input name = " + x); + try { + TensorInfo tensorInfo = (TensorInfo) finalSession.getInputInfo().get(x).getInfo(); + long[] shape = tensorInfo.getShape(); + // 假设形状是 [batch, channels, height, width] + if (shape.length >= 4) { + configProperties.getYoloModelConfig().setImageSize((int)shape[3]); // width + } + System.out.println(tensorInfo.toString()); + } catch (OrtException e) { + throw new RuntimeException(e); + } + + }); + ortSessions.put("model", session); + ortMap.put("model", configProperties.getYoloModelConfig()); + + } + + if (configProperties.getYoloModelConfig().getModelMap() != null) { + + OrtSession.SessionOptions finalSessionOptions = sessionOptions; + configProperties.getYoloModelConfig().getModelMap().forEach((k, v) -> { + try { +// 判断v文件是否存在 + if (!new File(v.getModelPath()).exists()) { + log.error("文件不存在:{}", v); + return; + } + + OrtSession session1 = environment.createSession(v.getModelPath(), finalSessionOptions); + + session1.getInputInfo().keySet().forEach(x -> { + System.out.println("input name = " + x); + try { + TensorInfo tensorInfo = (TensorInfo) session1.getInputInfo().get(x).getInfo(); + long[] shape = tensorInfo.getShape(); + // 假设形状是 [batch, channels, height, width] + if (shape.length >= 4) { + v.setImageSize((int)shape[3]); // width + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + + }); + ortSessions.put(k, session1); + ortMap.put(k, v); + } catch (OrtException e) { + throw new RuntimeException(e); + } + }); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + // sessionOptions 不在这里关闭,它的生命周期应该与 Session 一致 + } + + private OrtSession getSession(String type) { + if (type != null && !type.equals("")) { + return ortSessions.get(type); + } + return ortSessions.get("model"); + } + + private AppConfig.YoloModelConfig getConfig(String type) { + if (type != null && !type.equals("")) { + return ortMap.get(type); + } + return ortMap.get("model"); + } + + /** + * yolo11计算图片有多少个识别到的模版 + * + * @param imagePath 输入的图片位置 + * @return + * @throws OrtException + */ + public List classify(String imagePath, String type) throws OrtException { + + OrtSession session = getSession(type); + AppConfig.YoloModelConfig config = getConfig(type); + // 处理图像 + float[] imageData = new float[0]; + try { + imageData = processImageFromURL(imagePath, config.getImageSize()); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + + // 构建输入张量 + long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width + OnnxTensor inputTensor = null; + OrtSession.Result result = null; + try { + inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape); + + HashMap stringOnnxTensorHashMap = new HashMap<>(); + stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor); + + // 执行推理 + result = session.run(stringOnnxTensorHashMap); + // 获取第一个输出(大多数情况下只有一个输出) + float[] outputData = ((float[][]) result.get(0).getValue())[0]; + List classifyEntities = new ArrayList<>(); + for (int i = 0; i < outputData.length; i += 5) { + ClassifyEntity classifyEntity = new ClassifyEntity(); + classifyEntity.setIndex(i); + classifyEntity.setName(config.getNames()[i]); + classifyEntity.setConfidence(outputData[i]); + // classifyEntity.setClassProb(String.format("%.4f", outputData[i])); + classifyEntities.add(classifyEntity); + } + + // 获取最大概率对应的类别 ID + // int predictedClassId = argmax(outputData, configProperties.getYoloModelConfig().getConfThreshold()); + + // float[] output = (float[]) result.get(0).getValue(); + + // DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections); + + return classifyEntities; + } finally { + if (inputTensor != null) { + inputTensor.close(); + } + if (result != null) { + result.close(); + } + } + } + + + /** + * yolo11计算图片有多少个识别到的模版 + * + * @param imagePath 输入的图片位置 + * @return + * @throws OrtException + */ + public ClassifyEntity classifyOne(String imagePath, String type) throws OrtException { + + OrtSession session = getSession(type); + AppConfig.YoloModelConfig config = getConfig(type); + // 处理图像 + float[] imageData = new float[0]; + try { + imageData = processImageFromURL(imagePath, config.getImageSize()); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + + // 构建输入张量 + long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width + OnnxTensor inputTensor = null; + OrtSession.Result result = null; + try { + inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape); + + HashMap stringOnnxTensorHashMap = new HashMap<>(); + stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor); + + // 执行推理 + result = session.run(stringOnnxTensorHashMap); + // 获取第一个输出(大多数情况下只有一个输出) + float[] outputData = ((float[][]) result.get(0).getValue())[0]; + + // 获取最大概率对应的类别 ID + int predictedClassId = argmax(outputData, configProperties.getYoloModelConfig().getConfThreshold()); + ClassifyEntity classifyEntity = new ClassifyEntity(); + classifyEntity.setIndex(predictedClassId); + if (predictedClassId!=-1){ + log.info("识别模版:{}",config.getNames()[predictedClassId]); + classifyEntity.setName(config.getNames()[predictedClassId]); + classifyEntity.setConfidence(outputData[predictedClassId]); + }else { + log.info("未识别模版"); + } + // classifyEntity.setClassProb(String.format("%.4f", outputData[i])); + + + // float[] output = (float[]) result.get(0).getValue(); + + // DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections); + + return classifyEntity; + } finally { + if (inputTensor != null) { + inputTensor.close(); + } + if (result != null) { + result.close(); + } + } + } + + /** + * yolo11计算图片有多少个识别到的模版 + * + * @param imagePath 输入的图片位置 + * @return + * @throws OrtException + */ + public List detect(String imagePath, String type) throws OrtException { + + OrtSession session = getSession(type); + AppConfig.YoloModelConfig config = getConfig(type); + // 处理图像 + float[] imageData = new float[0]; + try { + imageData = processImageFromURL(imagePath, config.getImageSize()); + } catch (IOException ex) { + log.error("处理图像时出错: {}", ex.getMessage(), ex); + return new ArrayList<>(); + } + + // 构建输入张量 + long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width + OnnxTensor inputTensor = null; + OrtSession.Result result = null; + try { + inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape); + + HashMap stringOnnxTensorHashMap = new HashMap<>(); + stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor); + + // 执行推理 + result = session.run(stringOnnxTensorHashMap); + + float[][] outputData = ((float[][][]) result.get(0).getValue())[0]; + + outputData = transposeMatrix(outputData); + + // float[] output = (float[]) result.get(0).getValue(); + + // 解析推理结果 + List detections = parseOutput(outputData, config.getConfThreshold(), config); + + + List filteredDetections = nonMaxSuppression(detections, 0.5f); // IoU 阈值设为 0.5 + + System.out.println(filteredDetections); + return filteredDetections; + } finally { + if (inputTensor != null) { + inputTensor.close(); + } + if (result != null) { + result.close(); + } + } + } + + private static List parseOutput(float[][] outputData, double confThreshold, AppConfig.YoloModelConfig config) { + List detections = new ArrayList<>(); + + for (float[] bbox : outputData) { + + + float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 4, bbox.length); + int label = argmax(conditionalProbabilities, confThreshold); + + // 如果没有找到满足阈值的类别,跳过这个检测框 + if (label == -1) { + continue; + } + + float conf = conditionalProbabilities[label]; + + bbox[4] = conf; + + // xywh to (x1, y1, x2, y2) + float x = bbox[0]; + float y = bbox[1]; + float w = bbox[2]; + float h = bbox[3]; + + bbox[0] = x - w * 0.5f; + bbox[1] = y - h * 0.5f; + bbox[2] = x + w * 0.5f; + bbox[3] = y + h * 0.5f; + + if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) continue; + BoundingBox detection = new BoundingBox(); + detection.setX(x); + detection.setY(y); + detection.setW(w); + detection.setH(h); + detection.setConfidence(conf); + detection.setIndex(label); + detection.setName(config.getNames()[label]); + + detections.add(detection); + } + + return detections; + } + + //返回最大值的索引 + public static int argmax(float[] a, double threshold) { + float re = -Float.MAX_VALUE; + int arg = -1; + for (int i = 0; i < a.length; i++) { + if (a[i] < threshold) continue; + if (a[i] >= re) { + re = a[i]; + arg = i; + } + } + return arg; + } + + /** + * 销毁时清理资源 + */ + @jakarta.annotation.PreDestroy + public void destroy() { + log.info("开始清理 ONNX 资源..."); + for (Map.Entry entry : ortSessions.entrySet()) { + try { + entry.getValue().close(); + log.info("已关闭 Session: {}", entry.getKey()); + } catch (OrtException e) { + log.error("关闭 Session 时出错: {}", entry.getKey(), e); + } + } + ortSessions.clear(); + ortMap.clear(); + log.info("ONNX 资源清理完成"); + } +}