|
|
|
@ -0,0 +1,705 @@
|
|
|
|
|
|
|
|
package com.example.lxcameraapi.service.IndustrialCamera.algorithm;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import com.example.lxcameraapi.conf.AppConfig;
|
|
|
|
|
|
|
|
import com.example.lxcameraapi.service.IndustrialCamera.yolo.BoundingBox;
|
|
|
|
|
|
|
|
import com.example.lxcameraapi.service.IndustrialCamera.yolo.ClassifyEntity;
|
|
|
|
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import ai.onnxruntime.*;
|
|
|
|
|
|
|
|
import org.opencv.core.Mat;
|
|
|
|
|
|
|
|
import org.opencv.imgcodecs.Imgcodecs;
|
|
|
|
|
|
|
|
import org.springframework.http.ResponseEntity;
|
|
|
|
|
|
|
|
import org.springframework.stereotype.Service;
|
|
|
|
|
|
|
|
import org.springframework.web.multipart.MultipartFile;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import javax.annotation.PostConstruct;
|
|
|
|
|
|
|
|
import javax.annotation.Resource;
|
|
|
|
|
|
|
|
import javax.imageio.ImageIO;
|
|
|
|
|
|
|
|
import java.awt.*;
|
|
|
|
|
|
|
|
import java.awt.image.BufferedImage;
|
|
|
|
|
|
|
|
import java.io.File;
|
|
|
|
|
|
|
|
import java.io.IOException;
|
|
|
|
|
|
|
|
import java.net.URL;
|
|
|
|
|
|
|
|
import java.nio.FloatBuffer;
|
|
|
|
|
|
|
|
import java.nio.file.Files;
|
|
|
|
|
|
|
|
import java.nio.file.Path;
|
|
|
|
|
|
|
|
import java.util.*;
|
|
|
|
|
|
|
|
import java.util.List;
|
|
|
|
|
|
|
|
import java.util.concurrent.ConcurrentHashMap;
|
|
|
|
|
|
|
|
import java.util.regex.Matcher;
|
|
|
|
|
|
|
|
import java.util.regex.Pattern;
|
|
|
|
|
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Slf4j
|
|
|
|
|
|
|
|
@Service
|
|
|
|
|
|
|
|
public class ONNXServiceNew {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@Resource
|
|
|
|
|
|
|
|
AppConfig configProperties;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static final Map<String, OrtSession> ortSessions = new ConcurrentHashMap<>();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static final Map<String, AppConfig.YoloModelConfig> ortMap = new ConcurrentHashMap<>();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static final OrtEnvironment environment = OrtEnvironment.getEnvironment();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public static void main(String[] args) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
System.load(new File(System.getProperty("user.dir")+"\\libs\\opencv\\opencv_java480.dll").getAbsolutePath());
|
|
|
|
|
|
|
|
ortMap.put("a", new AppConfig.YoloModelConfig());
|
|
|
|
|
|
|
|
ortMap.get("a").setImageSize(2048);
|
|
|
|
|
|
|
|
ortMap.get("a").setConfThreshold(0.5f);
|
|
|
|
|
|
|
|
ortMap.get("a").setNames(new String[]{"0"});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int imageSize = 2048;
|
|
|
|
|
|
|
|
String imagePath = "D:\\data\\1776157002220.png";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
String modelPath = "D:\\data\\best.onnx";
|
|
|
|
|
|
|
|
// List<String> name = Arrays.asList("0143", "0153", "0173", "0177", "0191", "0253", "0256", "0266", "0268", "0286", "0302", "0304", "0305", "0307", "0320", "0326", "0336", "0339", "0343", "0352", "0458", "0461", "0462", "0473", "0477", "0486", "0490", "0492", "0930", "1101", "1102", "1104", "1262", "1269", "1302", "1308", "1359", "1366", "1622", "1625", "1919", "1976", "20", "2165", "2188", "2210", "2224", "2445", "2476", "2611", "2730", "2731", "2910", "2914", "2943", "3027", "3028", "3029", "3212", "3226", "3344", "3501", "3509", "3538", "3725", "3741", "3751", "3754", "3763", "3766", "3808");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
OrtSession.SessionOptions sessionOptions = null;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 使用gpu,需要本机按钻过cuda,并修改pom.xml,不安装也能运行本程序
|
|
|
|
|
|
|
|
// sessionOptions.addCUDA(0);
|
|
|
|
|
|
|
|
// 实际项目中,视频识别必须开启GPU,并且要防止队列堆积
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
OrtSession session = null;
|
|
|
|
|
|
|
|
ONNXServiceNew onnxServiceNew = new ONNXServiceNew();
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
sessionOptions = new OrtSession.SessionOptions();
|
|
|
|
|
|
|
|
session = environment.createSession(modelPath, sessionOptions);
|
|
|
|
|
|
|
|
ortSessions.put("a", session);
|
|
|
|
|
|
|
|
List<BoundingBox> boxes = onnxServiceNew.detect(imagePath, "a");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// classifyEntity.setIndex(predictedClassId);
|
|
|
|
|
|
|
|
// classifyEntity.setName(config.getNames()[predictedClassId]);
|
|
|
|
|
|
|
|
// classifyEntity.setConfidence(outputData[predictedClassId]);
|
|
|
|
|
|
|
|
// classifyEntity.setClassProb(String.format("%.4f", outputData[i]));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// float[] output = (float[]) result.get(0).getValue();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
System.out.println(boxes);
|
|
|
|
|
|
|
|
// System.out.println(name.get(predictedClassId));
|
|
|
|
|
|
|
|
} catch (OrtException e) {
|
|
|
|
|
|
|
|
throw new RuntimeException(e);
|
|
|
|
|
|
|
|
} finally {
|
|
|
|
|
|
|
|
if (sessionOptions != null) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sessionOptions.close();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// // 处理图像
|
|
|
|
|
|
|
|
// float[] imageData = new float[0];
|
|
|
|
|
|
|
|
// try {
|
|
|
|
|
|
|
|
// imageData = processImageFromURL(imagePath, imageSize);
|
|
|
|
|
|
|
|
// } catch (IOException ex) {
|
|
|
|
|
|
|
|
// throw new RuntimeException(ex);
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// // 构建输入张量
|
|
|
|
|
|
|
|
// long[] shape = new long[]{1, 3, imageSize, imageSize}; // batch_size, channels, height, width
|
|
|
|
|
|
|
|
// OnnxTensor inputTensor = null;
|
|
|
|
|
|
|
|
// try {
|
|
|
|
|
|
|
|
// inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
|
|
|
|
|
|
|
|
// stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// // 执行推理
|
|
|
|
|
|
|
|
//// OrtSession.Result result = session.run(stringOnnxTensorHashMap);
|
|
|
|
|
|
|
|
////
|
|
|
|
|
|
|
|
//// float[][] outputData = ((float[][][])result.get(0).getValue())[0];
|
|
|
|
|
|
|
|
////
|
|
|
|
|
|
|
|
//// outputData = transposeMatrix(outputData);
|
|
|
|
|
|
|
|
////
|
|
|
|
|
|
|
|
////// float[] output = (float[]) result.get(0).getValue();
|
|
|
|
|
|
|
|
////
|
|
|
|
|
|
|
|
//// // 解析推理结果
|
|
|
|
|
|
|
|
//// List<BoundingBox> detections = parseOutput(outputData,0.4);
|
|
|
|
|
|
|
|
////
|
|
|
|
|
|
|
|
//// List<BoundingBox> filteredDetections = nonMaxSuppression(detections, 0.5f); // IoU 阈值设为 0.5
|
|
|
|
|
|
|
|
//// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
|
|
|
|
|
|
|
|
//// System.out.println(filteredDetections);
|
|
|
|
|
|
|
|
//// int mostFrequentClassId = getMaxClassIndex(filteredDetections);
|
|
|
|
|
|
|
|
//// System.out.println(calculateLayers(filteredDetections, 200,mostFrequentClassId));
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// // 执行推理
|
|
|
|
|
|
|
|
// OrtSession.Result result = session.run(stringOnnxTensorHashMap);
|
|
|
|
|
|
|
|
//// 获取第一个输出(大多数情况下只有一个输出)
|
|
|
|
|
|
|
|
// float[] outputData = ((float[][]) result.get(0).getValue())[0];
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
//// 获取最大概率对应的类别 ID
|
|
|
|
|
|
|
|
// int predictedClassId = argmax(outputData, 0.8);
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
//// float[] output = (float[]) result.get(0).getValue();
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
//// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
|
|
|
|
|
|
|
|
// ClassifyEntity classifyEntity = new ClassifyEntity();
|
|
|
|
|
|
|
|
// classifyEntity.setIndex(predictedClassId);
|
|
|
|
|
|
|
|
// } catch (OrtException e) {
|
|
|
|
|
|
|
|
// throw new RuntimeException(e);
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public static int getMaxClassIndex(List<BoundingBox> boxes) {
|
|
|
|
|
|
|
|
return boxes.stream()
|
|
|
|
|
|
|
|
.collect(Collectors.groupingBy(BoundingBox::getIndex, Collectors.counting()))
|
|
|
|
|
|
|
|
.entrySet().stream()
|
|
|
|
|
|
|
|
.max(Map.Entry.comparingByValue())
|
|
|
|
|
|
|
|
.map(Map.Entry::getKey)
|
|
|
|
|
|
|
|
.orElse(-1);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public static Set<Integer> getMinTypeClassIndex(List<BoundingBox> boxes) {
|
|
|
|
|
|
|
|
return boxes.stream()
|
|
|
|
|
|
|
|
.map(BoundingBox::getIndex)
|
|
|
|
|
|
|
|
.collect(Collectors.toSet());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
|
|
* 计算层数,当两个 y 值之差小于 threshold 时视为同一层
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
public static int calculateLayers(List<BoundingBox> boxes, double threshold, int mostFrequentClassId) {
|
|
|
|
|
|
|
|
if (boxes == null || boxes.isEmpty()) {
|
|
|
|
|
|
|
|
return 0;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
//出现次数做多的classId
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
boxes = boxes.stream()
|
|
|
|
|
|
|
|
.filter(box -> box.getIndex() == 0)
|
|
|
|
|
|
|
|
.collect(Collectors.toList());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 检查过滤后的boxes是否为空
|
|
|
|
|
|
|
|
if (boxes.isEmpty()) {
|
|
|
|
|
|
|
|
return 0;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
List<Double> yValues = new ArrayList<>();
|
|
|
|
|
|
|
|
for (BoundingBox box : boxes) {
|
|
|
|
|
|
|
|
yValues.add(box.getY());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
Collections.sort(yValues); // 先排序
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int layers = 1; // 至少有一层
|
|
|
|
|
|
|
|
double currentBase = yValues.get(0);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int i = 1; i < yValues.size(); i++) {
|
|
|
|
|
|
|
|
if (Math.abs(yValues.get(i) - currentBase) > threshold) {
|
|
|
|
|
|
|
|
layers++;
|
|
|
|
|
|
|
|
currentBase = yValues.get(i);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return layers;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public static float[][] transposeMatrix(float[][] m) {
|
|
|
|
|
|
|
|
float[][] temp = new float[m[0].length][m.length];
|
|
|
|
|
|
|
|
for (int i = 0; i < m.length; i++)
|
|
|
|
|
|
|
|
for (int j = 0; j < m[0].length; j++)
|
|
|
|
|
|
|
|
temp[j][i] = m[i][j];
|
|
|
|
|
|
|
|
return temp;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// 图像文件处理
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// // 处理图片URL
|
|
|
|
|
|
|
|
// private static float[] processImageFromURL(String imageUrl, int imageSize) throws IOException {
|
|
|
|
|
|
|
|
// File file = new File(imageUrl);
|
|
|
|
|
|
|
|
// BufferedImage bufferedImage = ImageIO.read(file);
|
|
|
|
|
|
|
|
// // 假设我们需要将图片缩放为 imageSize*imageSize,并转换为一个符合 YOLO 输入要求的 float 数组
|
|
|
|
|
|
|
|
// BufferedImage resizedImage = new BufferedImage(imageSize, imageSize, BufferedImage.TYPE_INT_RGB);
|
|
|
|
|
|
|
|
// resizedImage.getGraphics().drawImage(bufferedImage, 0, 0, imageSize, imageSize, null);
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// // 将图片转换为 float 数组(假设 RGB 格式)
|
|
|
|
|
|
|
|
// float[] imageData = new float[imageSize * imageSize * 3];
|
|
|
|
|
|
|
|
// int index = 0;
|
|
|
|
|
|
|
|
// int[] rData = new int[imageSize * imageSize];
|
|
|
|
|
|
|
|
// int[] gData = new int[imageSize * imageSize];
|
|
|
|
|
|
|
|
// int[] bData = new int[imageSize * imageSize];
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// int idx = 0;
|
|
|
|
|
|
|
|
// for (int y = 0; y < imageSize; y++) {
|
|
|
|
|
|
|
|
// for (int x = 0; x < imageSize; x++) {
|
|
|
|
|
|
|
|
// int rgb = resizedImage.getRGB(x, y);
|
|
|
|
|
|
|
|
// rData[idx] = (rgb >> 16) & 0xFF;
|
|
|
|
|
|
|
|
// gData[idx] = (rgb >> 8) & 0xFF;
|
|
|
|
|
|
|
|
// bData[idx] = rgb & 0xFF;
|
|
|
|
|
|
|
|
// idx++;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
//// 填充为 CHW
|
|
|
|
|
|
|
|
// idx = 0;
|
|
|
|
|
|
|
|
// for (int i = 0; i < imageSize * imageSize; i++) {
|
|
|
|
|
|
|
|
// imageData[idx++] = rData[i] / 255.0f;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// for (int i = 0; i < imageSize * imageSize; i++) {
|
|
|
|
|
|
|
|
// imageData[idx++] = gData[i] / 255.0f;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// for (int i = 0; i < imageSize * imageSize; i++) {
|
|
|
|
|
|
|
|
// imageData[idx++] = bData[i] / 255.0f;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// return imageData;
|
|
|
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
// 处理图片URL
|
|
|
|
|
|
|
|
private static float[] processImageFromURL(String imageUrl, int imageSize) throws IOException {
|
|
|
|
|
|
|
|
// 从URL读取图片
|
|
|
|
|
|
|
|
BufferedImage originalImage;
|
|
|
|
|
|
|
|
if (imageUrl.startsWith("http")) {
|
|
|
|
|
|
|
|
URL url = new URL(imageUrl);
|
|
|
|
|
|
|
|
originalImage = ImageIO.read(url);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
File file = new File(imageUrl);
|
|
|
|
|
|
|
|
originalImage = ImageIO.read(file);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 创建目标尺寸的图片
|
|
|
|
|
|
|
|
BufferedImage resizedImage = new BufferedImage(imageSize, imageSize, BufferedImage.TYPE_INT_RGB);
|
|
|
|
|
|
|
|
Graphics2D g2d = resizedImage.createGraphics();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 设置背景为黑色
|
|
|
|
|
|
|
|
g2d.setColor(Color.BLACK);
|
|
|
|
|
|
|
|
g2d.fillRect(0, 0, imageSize, imageSize);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 计算缩放比例,保持宽高比
|
|
|
|
|
|
|
|
int originalWidth = originalImage.getWidth();
|
|
|
|
|
|
|
|
int originalHeight = originalImage.getHeight();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 按照长边计算缩放比例
|
|
|
|
|
|
|
|
double scale = Math.min((double) imageSize / originalWidth, (double) imageSize / originalHeight);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 计算缩放后的尺寸
|
|
|
|
|
|
|
|
int scaledWidth = (int) (originalWidth * scale);
|
|
|
|
|
|
|
|
int scaledHeight = (int) (originalHeight * scale);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 计算居中位置
|
|
|
|
|
|
|
|
int x = (imageSize - scaledWidth) / 2;
|
|
|
|
|
|
|
|
int y = (imageSize - scaledHeight) / 2;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 绘制缩放后的图片
|
|
|
|
|
|
|
|
g2d.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
|
|
|
|
|
|
|
|
g2d.drawImage(originalImage, x, y, scaledWidth, scaledHeight, null);
|
|
|
|
|
|
|
|
g2d.dispose();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 将图片转换为 float 数组(RGB 格式,归一化到 0-1)
|
|
|
|
|
|
|
|
float[] imageData = new float[imageSize * imageSize * 3];
|
|
|
|
|
|
|
|
int index = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 按照CHW格式排列数据(通道-高度-宽度)
|
|
|
|
|
|
|
|
// 先处理所有像素的红色通道
|
|
|
|
|
|
|
|
for (int i = 0; i < imageSize; i++) {
|
|
|
|
|
|
|
|
for (int j = 0; j < imageSize; j++) {
|
|
|
|
|
|
|
|
int rgb = resizedImage.getRGB(j, i);
|
|
|
|
|
|
|
|
imageData[index++] = ((rgb >> 16) & 0xFF) / 255.0f; // R
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 再处理绿色通道
|
|
|
|
|
|
|
|
for (int i = 0; i < imageSize; i++) {
|
|
|
|
|
|
|
|
for (int j = 0; j < imageSize; j++) {
|
|
|
|
|
|
|
|
int rgb = resizedImage.getRGB(j, i);
|
|
|
|
|
|
|
|
imageData[index++] = ((rgb >> 8) & 0xFF) / 255.0f; // G
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 最后处理蓝色通道
|
|
|
|
|
|
|
|
for (int i = 0; i < imageSize; i++) {
|
|
|
|
|
|
|
|
for (int j = 0; j < imageSize; j++) {
|
|
|
|
|
|
|
|
int rgb = resizedImage.getRGB(j, i);
|
|
|
|
|
|
|
|
imageData[index++] = (rgb & 0xFF) / 255.0f; // B
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return imageData;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public static List<BoundingBox> nonMaxSuppression(List<BoundingBox> boxes, float iouThreshold) {
|
|
|
|
|
|
|
|
if (boxes.isEmpty()) return new ArrayList<>();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 按置信度从高到低排序
|
|
|
|
|
|
|
|
boxes.sort((a, b) -> Float.compare((float) b.getConfidence(), (float) a.getConfidence()));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
List<BoundingBox> result = new ArrayList<>();
|
|
|
|
|
|
|
|
while (!boxes.isEmpty()) {
|
|
|
|
|
|
|
|
BoundingBox best = boxes.remove(0);
|
|
|
|
|
|
|
|
result.add(best);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 过滤掉和 best IoU 太大的框
|
|
|
|
|
|
|
|
boxes.removeIf(box -> computeIoU(best, box) > iouThreshold);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return result;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public static float computeIoU(BoundingBox box1, BoundingBox box2) {
|
|
|
|
|
|
|
|
double x1 = Math.max(box1.getX() - box1.getW() / 2, box2.getX() - box2.getW() / 2);
|
|
|
|
|
|
|
|
double y1 = Math.max(box1.getY() - box1.getH() / 2, box2.getY() - box2.getH() / 2);
|
|
|
|
|
|
|
|
double x2 = Math.min(box1.getX() + box1.getW() / 2, box2.getX() + box2.getW() / 2);
|
|
|
|
|
|
|
|
double y2 = Math.min(box1.getY() + box1.getH() / 2, box2.getY() + box2.getH() / 2);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
double w = Math.max(0, x2 - x1);
|
|
|
|
|
|
|
|
double h = Math.max(0, y2 - y1);
|
|
|
|
|
|
|
|
double intersection = w * h;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
double area1 = (box1.getW() * box1.getH());
|
|
|
|
|
|
|
|
double area2 = (box2.getW() * box2.getH());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return (float) (intersection / (area1 + area2 - intersection + 1e-6));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@PostConstruct
|
|
|
|
|
|
|
|
public void initSession() {
|
|
|
|
|
|
|
|
OrtSession.SessionOptions sessionOptions = null;
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
sessionOptions = new OrtSession.SessionOptions();
|
|
|
|
|
|
|
|
// 使用gpu,需要本机按钻过cuda,并修改pom.xml,不安装也能运行本程序
|
|
|
|
|
|
|
|
// sessionOptions.addCUDA(0);
|
|
|
|
|
|
|
|
// 实际项目中,视频识别必须开启GPU,并且要防止队列堆积
|
|
|
|
|
|
|
|
// 大模型
|
|
|
|
|
|
|
|
OrtSession session = null;
|
|
|
|
|
|
|
|
if (configProperties.getYoloModelConfig().getModelPath() != null) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
session = environment.createSession(configProperties.getYoloModelConfig().getModelPath(), sessionOptions);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
OrtSession finalSession = session;
|
|
|
|
|
|
|
|
session.getInputInfo().keySet().forEach(x -> {
|
|
|
|
|
|
|
|
System.out.println("input name = " + x);
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
TensorInfo tensorInfo = (TensorInfo) finalSession.getInputInfo().get(x).getInfo();
|
|
|
|
|
|
|
|
long[] shape = tensorInfo.getShape();
|
|
|
|
|
|
|
|
// 假设形状是 [batch, channels, height, width]
|
|
|
|
|
|
|
|
if (shape.length >= 4) {
|
|
|
|
|
|
|
|
configProperties.getYoloModelConfig().setImageSize((int)shape[3]); // width
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
System.out.println(tensorInfo.toString());
|
|
|
|
|
|
|
|
} catch (OrtException e) {
|
|
|
|
|
|
|
|
throw new RuntimeException(e);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
ortSessions.put("model", session);
|
|
|
|
|
|
|
|
ortMap.put("model", configProperties.getYoloModelConfig());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (configProperties.getYoloModelConfig().getModelMap() != null) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
OrtSession.SessionOptions finalSessionOptions = sessionOptions;
|
|
|
|
|
|
|
|
configProperties.getYoloModelConfig().getModelMap().forEach((k, v) -> {
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
// 判断v文件是否存在
|
|
|
|
|
|
|
|
if (!new File(v.getModelPath()).exists()) {
|
|
|
|
|
|
|
|
log.error("文件不存在:{}", v);
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
OrtSession session1 = environment.createSession(v.getModelPath(), finalSessionOptions);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
session1.getInputInfo().keySet().forEach(x -> {
|
|
|
|
|
|
|
|
System.out.println("input name = " + x);
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
TensorInfo tensorInfo = (TensorInfo) session1.getInputInfo().get(x).getInfo();
|
|
|
|
|
|
|
|
long[] shape = tensorInfo.getShape();
|
|
|
|
|
|
|
|
// 假设形状是 [batch, channels, height, width]
|
|
|
|
|
|
|
|
if (shape.length >= 4) {
|
|
|
|
|
|
|
|
v.setImageSize((int)shape[3]); // width
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} catch (OrtException e) {
|
|
|
|
|
|
|
|
throw new RuntimeException(e);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
ortSessions.put(k, session1);
|
|
|
|
|
|
|
|
ortMap.put(k, v);
|
|
|
|
|
|
|
|
} catch (OrtException e) {
|
|
|
|
|
|
|
|
throw new RuntimeException(e);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} catch (OrtException e) {
|
|
|
|
|
|
|
|
throw new RuntimeException(e);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// sessionOptions 不在这里关闭,它的生命周期应该与 Session 一致
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private OrtSession getSession(String type) {
|
|
|
|
|
|
|
|
if (type != null && !type.equals("")) {
|
|
|
|
|
|
|
|
return ortSessions.get(type);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return ortSessions.get("model");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private AppConfig.YoloModelConfig getConfig(String type) {
|
|
|
|
|
|
|
|
if (type != null && !type.equals("")) {
|
|
|
|
|
|
|
|
return ortMap.get(type);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return ortMap.get("model");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
|
|
* yolo11计算图片有多少个识别到的模版
|
|
|
|
|
|
|
|
*
|
|
|
|
|
|
|
|
* @param imagePath 输入的图片位置
|
|
|
|
|
|
|
|
* @return
|
|
|
|
|
|
|
|
* @throws OrtException
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
public List<ClassifyEntity> classify(String imagePath, String type) throws OrtException {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
OrtSession session = getSession(type);
|
|
|
|
|
|
|
|
AppConfig.YoloModelConfig config = getConfig(type);
|
|
|
|
|
|
|
|
// 处理图像
|
|
|
|
|
|
|
|
float[] imageData = new float[0];
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
imageData = processImageFromURL(imagePath, config.getImageSize());
|
|
|
|
|
|
|
|
} catch (IOException ex) {
|
|
|
|
|
|
|
|
throw new RuntimeException(ex);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 构建输入张量
|
|
|
|
|
|
|
|
long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width
|
|
|
|
|
|
|
|
OnnxTensor inputTensor = null;
|
|
|
|
|
|
|
|
OrtSession.Result result = null;
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
|
|
|
|
|
|
|
|
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 执行推理
|
|
|
|
|
|
|
|
result = session.run(stringOnnxTensorHashMap);
|
|
|
|
|
|
|
|
// 获取第一个输出(大多数情况下只有一个输出)
|
|
|
|
|
|
|
|
float[] outputData = ((float[][]) result.get(0).getValue())[0];
|
|
|
|
|
|
|
|
List<ClassifyEntity> classifyEntities = new ArrayList<>();
|
|
|
|
|
|
|
|
for (int i = 0; i < outputData.length; i += 5) {
|
|
|
|
|
|
|
|
ClassifyEntity classifyEntity = new ClassifyEntity();
|
|
|
|
|
|
|
|
classifyEntity.setIndex(i);
|
|
|
|
|
|
|
|
classifyEntity.setName(config.getNames()[i]);
|
|
|
|
|
|
|
|
classifyEntity.setConfidence(outputData[i]);
|
|
|
|
|
|
|
|
// classifyEntity.setClassProb(String.format("%.4f", outputData[i]));
|
|
|
|
|
|
|
|
classifyEntities.add(classifyEntity);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 获取最大概率对应的类别 ID
|
|
|
|
|
|
|
|
// int predictedClassId = argmax(outputData, configProperties.getYoloModelConfig().getConfThreshold());
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// float[] output = (float[]) result.get(0).getValue();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return classifyEntities;
|
|
|
|
|
|
|
|
} finally {
|
|
|
|
|
|
|
|
if (inputTensor != null) {
|
|
|
|
|
|
|
|
inputTensor.close();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (result != null) {
|
|
|
|
|
|
|
|
result.close();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
|
|
* yolo11计算图片有多少个识别到的模版
|
|
|
|
|
|
|
|
*
|
|
|
|
|
|
|
|
* @param imagePath 输入的图片位置
|
|
|
|
|
|
|
|
* @return
|
|
|
|
|
|
|
|
* @throws OrtException
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
public ClassifyEntity classifyOne(String imagePath, String type) throws OrtException {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
OrtSession session = getSession(type);
|
|
|
|
|
|
|
|
AppConfig.YoloModelConfig config = getConfig(type);
|
|
|
|
|
|
|
|
// 处理图像
|
|
|
|
|
|
|
|
float[] imageData = new float[0];
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
imageData = processImageFromURL(imagePath, config.getImageSize());
|
|
|
|
|
|
|
|
} catch (IOException ex) {
|
|
|
|
|
|
|
|
throw new RuntimeException(ex);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 构建输入张量
|
|
|
|
|
|
|
|
long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width
|
|
|
|
|
|
|
|
OnnxTensor inputTensor = null;
|
|
|
|
|
|
|
|
OrtSession.Result result = null;
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
|
|
|
|
|
|
|
|
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 执行推理
|
|
|
|
|
|
|
|
result = session.run(stringOnnxTensorHashMap);
|
|
|
|
|
|
|
|
// 获取第一个输出(大多数情况下只有一个输出)
|
|
|
|
|
|
|
|
float[] outputData = ((float[][]) result.get(0).getValue())[0];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 获取最大概率对应的类别 ID
|
|
|
|
|
|
|
|
int predictedClassId = argmax(outputData, configProperties.getYoloModelConfig().getConfThreshold());
|
|
|
|
|
|
|
|
ClassifyEntity classifyEntity = new ClassifyEntity();
|
|
|
|
|
|
|
|
classifyEntity.setIndex(predictedClassId);
|
|
|
|
|
|
|
|
if (predictedClassId!=-1){
|
|
|
|
|
|
|
|
log.info("识别模版:{}",config.getNames()[predictedClassId]);
|
|
|
|
|
|
|
|
classifyEntity.setName(config.getNames()[predictedClassId]);
|
|
|
|
|
|
|
|
classifyEntity.setConfidence(outputData[predictedClassId]);
|
|
|
|
|
|
|
|
}else {
|
|
|
|
|
|
|
|
log.info("未识别模版");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// classifyEntity.setClassProb(String.format("%.4f", outputData[i]));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// float[] output = (float[]) result.get(0).getValue();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// DrawBoundingBox.drawBoundingBoxesOnImage(imagePath, "D:\\data\\2024-04-24_08-55-42-492data.BMP.jpg", filteredDetections);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return classifyEntity;
|
|
|
|
|
|
|
|
} finally {
|
|
|
|
|
|
|
|
if (inputTensor != null) {
|
|
|
|
|
|
|
|
inputTensor.close();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (result != null) {
|
|
|
|
|
|
|
|
result.close();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
|
|
* yolo11计算图片有多少个识别到的模版
|
|
|
|
|
|
|
|
*
|
|
|
|
|
|
|
|
* @param imagePath 输入的图片位置
|
|
|
|
|
|
|
|
* @return
|
|
|
|
|
|
|
|
* @throws OrtException
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
public List<BoundingBox> detect(String imagePath, String type) throws OrtException {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
OrtSession session = getSession(type);
|
|
|
|
|
|
|
|
AppConfig.YoloModelConfig config = getConfig(type);
|
|
|
|
|
|
|
|
// 处理图像
|
|
|
|
|
|
|
|
float[] imageData = new float[0];
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
imageData = processImageFromURL(imagePath, config.getImageSize());
|
|
|
|
|
|
|
|
} catch (IOException ex) {
|
|
|
|
|
|
|
|
log.error("处理图像时出错: {}", ex.getMessage(), ex);
|
|
|
|
|
|
|
|
return new ArrayList<>();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 构建输入张量
|
|
|
|
|
|
|
|
long[] shape = new long[]{1, 3, config.getImageSize(), config.getImageSize()}; // batch_size, channels, height, width
|
|
|
|
|
|
|
|
OnnxTensor inputTensor = null;
|
|
|
|
|
|
|
|
OrtSession.Result result = null;
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
inputTensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(imageData), shape);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
|
|
|
|
|
|
|
|
stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), inputTensor);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 执行推理
|
|
|
|
|
|
|
|
result = session.run(stringOnnxTensorHashMap);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
float[][] outputData = ((float[][][]) result.get(0).getValue())[0];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
outputData = transposeMatrix(outputData);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// float[] output = (float[]) result.get(0).getValue();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 解析推理结果
|
|
|
|
|
|
|
|
List<BoundingBox> detections = parseOutput(outputData, config.getConfThreshold(), config);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
List<BoundingBox> filteredDetections = nonMaxSuppression(detections, 0.5f); // IoU 阈值设为 0.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
System.out.println(filteredDetections);
|
|
|
|
|
|
|
|
return filteredDetections;
|
|
|
|
|
|
|
|
} finally {
|
|
|
|
|
|
|
|
if (inputTensor != null) {
|
|
|
|
|
|
|
|
inputTensor.close();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (result != null) {
|
|
|
|
|
|
|
|
result.close();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
private static List<BoundingBox> parseOutput(float[][] outputData, double confThreshold, AppConfig.YoloModelConfig config) {
|
|
|
|
|
|
|
|
List<BoundingBox> detections = new ArrayList<>();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (float[] bbox : outputData) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 4, bbox.length);
|
|
|
|
|
|
|
|
int label = argmax(conditionalProbabilities, confThreshold);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 如果没有找到满足阈值的类别,跳过这个检测框
|
|
|
|
|
|
|
|
if (label == -1) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
float conf = conditionalProbabilities[label];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bbox[4] = conf;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// xywh to (x1, y1, x2, y2)
|
|
|
|
|
|
|
|
float x = bbox[0];
|
|
|
|
|
|
|
|
float y = bbox[1];
|
|
|
|
|
|
|
|
float w = bbox[2];
|
|
|
|
|
|
|
|
float h = bbox[3];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bbox[0] = x - w * 0.5f;
|
|
|
|
|
|
|
|
bbox[1] = y - h * 0.5f;
|
|
|
|
|
|
|
|
bbox[2] = x + w * 0.5f;
|
|
|
|
|
|
|
|
bbox[3] = y + h * 0.5f;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) continue;
|
|
|
|
|
|
|
|
BoundingBox detection = new BoundingBox();
|
|
|
|
|
|
|
|
detection.setX(x);
|
|
|
|
|
|
|
|
detection.setY(y);
|
|
|
|
|
|
|
|
detection.setW(w);
|
|
|
|
|
|
|
|
detection.setH(h);
|
|
|
|
|
|
|
|
detection.setConfidence(conf);
|
|
|
|
|
|
|
|
detection.setIndex(label);
|
|
|
|
|
|
|
|
detection.setName(config.getNames()[label]);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
detections.add(detection);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return detections;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
//返回最大值的索引
|
|
|
|
|
|
|
|
public static int argmax(float[] a, double threshold) {
|
|
|
|
|
|
|
|
float re = -Float.MAX_VALUE;
|
|
|
|
|
|
|
|
int arg = -1;
|
|
|
|
|
|
|
|
for (int i = 0; i < a.length; i++) {
|
|
|
|
|
|
|
|
if (a[i] < threshold) continue;
|
|
|
|
|
|
|
|
if (a[i] >= re) {
|
|
|
|
|
|
|
|
re = a[i];
|
|
|
|
|
|
|
|
arg = i;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return arg;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
|
|
|
* 销毁时清理资源
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
@jakarta.annotation.PreDestroy
|
|
|
|
|
|
|
|
public void destroy() {
|
|
|
|
|
|
|
|
log.info("开始清理 ONNX 资源...");
|
|
|
|
|
|
|
|
for (Map.Entry<String, OrtSession> entry : ortSessions.entrySet()) {
|
|
|
|
|
|
|
|
try {
|
|
|
|
|
|
|
|
entry.getValue().close();
|
|
|
|
|
|
|
|
log.info("已关闭 Session: {}", entry.getKey());
|
|
|
|
|
|
|
|
} catch (OrtException e) {
|
|
|
|
|
|
|
|
log.error("关闭 Session 时出错: {}", entry.getKey(), e);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
ortSessions.clear();
|
|
|
|
|
|
|
|
ortMap.clear();
|
|
|
|
|
|
|
|
log.info("ONNX 资源清理完成");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|