1.准备好模型文件(yolov5s.torchscript.pt)和对象分类文件(coco.names),并将二者放到同一文件夹下
2.准备pom.xml文件,添加以下配置和依赖
<properties>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
    <djl.version>0.15.0-SNAPSHOT</djl.version>
    <exec.mainClass>ai.djl.examples.inference.ObjectDetection</exec.mainClass>
</properties>
<repositories>
    <!-- Sonatype snapshot repository: required because djl.version above is a -SNAPSHOT build. -->
    <repository>
        <id>djl.ai</id>
        <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
    </repository>
</repositories>
<dependencyManagement>
    <dependencies>
        <!-- DJL bill of materials: aligns the versions of all ai.djl artifacts. -->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>bom</artifactId>
            <version>${djl.version}</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
    <dependency>
        <groupId>commons-cli</groupId>
        <artifactId>commons-cli</artifactId>
        <version>1.4</version>
    </dependency>
    <!--
      log4j-slf4j-impl transitively depends on log4j-core. Version 2.12.1 is affected by
      Log4Shell (CVE-2021-44228) and its follow-up CVEs, so a patched release is used instead.
      NOTE(review): spring-boot-starter-web already brings Logback and log4j-to-slf4j; adding
      log4j-slf4j-impl gives SLF4J a second binding and can conflict at startup. Consider
      removing this dependency, or excluding spring-boot-starter-logging - confirm which
      logging backend the application should use.
    -->
    <dependency>
        <groupId>org.apache.logging.log4j</groupId>
        <artifactId>log4j-slf4j-impl</artifactId>
        <version>2.17.1</version>
    </dependency>
    <!-- Declared once only; it was previously listed twice, which Maven reports as a duplicate. -->
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>basicdataset</artifactId>
        <version>${djl.version}</version>
    </dependency>
    <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>api</artifactId>
        <version>${djl.version}</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-model-zoo</artifactId>
        <version>${djl.version}</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-engine</artifactId>
        <version>${djl.version}</version>
        <scope>runtime</scope>
    </dependency>
    <dependency>
        <groupId>ai.djl.pytorch</groupId>
        <artifactId>pytorch-native-auto</artifactId>
        <scope>runtime</scope>
        <version>1.9.1</version>
    </dependency>
    <dependency>
        <groupId>org.testng</groupId>
        <artifactId>testng</artifactId>
        <version>6.8.1</version>
        <scope>test</scope>
    </dependency>
</dependencies>
3.重写官方文件YoloV5Translator.java
修改以下部分即可;也可以在DetectedObjects输出后再对其进行处理(见下方的DetectedObjectUtil),否则输出的图片上无法画出检测框。
/**
 * HTTP endpoint that runs YOLOv5 object detection on the sample image and logs the result.
 *
 * <p>If model loading or inference fails, the failure is logged together with its cause and
 * nothing else is logged (previously the stack trace went to stderr and a {@code null} result
 * was logged afterwards).
 */
@PostMapping("/ObjectDetectionHbj")
public void ObjectDetection() {
    DetectedObjects detection;
    try {
        detection = predict();
    } catch (IOException | ModelException | TranslateException e) {
        // Keep the exception as the last argument so SLF4J logs the full stack trace.
        log.error("Object detection failed", e);
        return;
    }
    log.info("{}", detection);
}
/**
 * Runs YOLOv5 object detection on the default sample image with the default local model.
 *
 * <p>Equivalent to calling {@code predict(Path, Path)} with {@code src/main/resources/kana1.jpg}
 * and {@code D:\work\git\model\yolov5s\}.
 *
 * @return the detected objects
 * @throws IOException if the image cannot be read or the annotated image cannot be written
 * @throws ModelException if the model cannot be loaded
 * @throws TranslateException if pre- or post-processing fails
 */
public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
    return predict(
            Paths.get("src/main/resources/kana1.jpg"),
            Paths.get("D:\\work\\git\\model\\yolov5s\\"));
}

/**
 * Runs YOLOv5 object detection on {@code imageFile} using the TorchScript model in {@code modelDir}.
 *
 * <p>{@code modelDir} must contain {@code yolov5s.torchscript.pt} and the class-name list
 * {@code coco.names}. The image with the detections drawn on it is also saved via
 * {@code saveBoundingBoxImage}.
 *
 * @param imageFile the image to run detection on
 * @param modelDir directory holding the model file and {@code coco.names}
 * @return the detected objects
 * @throws IOException if the image cannot be read or the annotated image cannot be written
 * @throws ModelException if the model cannot be loaded
 * @throws TranslateException if pre- or post-processing fails
 */
public static DetectedObjects predict(Path imageFile, Path modelDir)
        throws IOException, ModelException, TranslateException {
    Image img = ImageFactory.getInstance().fromFile(imageFile);

    Map<String, Object> arguments = new ConcurrentHashMap<>();
    arguments.put("width", 640);     // process the image at 640 px width
    arguments.put("height", 640);    // process the image at 640 px height
    arguments.put("resize", true);   // resize the input image
    arguments.put("rescale", true);  // scale pixel values into the range 0-1
    // Further translator options, currently left at their defaults:
    // arguments.put("threshold", 0.2);     // hide detections with confidence below 0.2
    // arguments.put("nmsThreshold", 0.5);

    // The class names are read from coco.names, placed next to the model file.
    Translator<Image, DetectedObjects> translator =
            YoloV5Translator.builder(arguments).optSynsetArtifactName("coco.names").build();

    Criteria<Image, DetectedObjects> criteria =
            Criteria.builder()
                    // YOLOv5 produces bounding boxes, i.e. object detection (not segmentation).
                    .optApplication(Application.CV.OBJECT_DETECTION)
                    .setTypes(Image.class, DetectedObjects.class)
                    .optDevice(Device.cpu())
                    .optModelPath(modelDir)
                    .optModelName("yolov5s.torchscript.pt")
                    .optTranslator(translator)
                    .optProgress(new ProgressBar())
                    .optEngine("PyTorch")
                    .build();

    // Both the model and the predictor hold native resources and must be closed.
    try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
            Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
        DetectedObjects detection = predictor.predict(img);
        saveBoundingBoxImage(img, detection);
        return detection;
    }
}
/**
 * Draws the detections onto {@code img} and writes it to {@code build/output/instances.png}.
 *
 * <p>The labels are first rewritten by {@code DetectedObjectUtil.me().getDetectedObjects}, so each
 * drawn label is the class name followed by its probability. {@code img} itself is drawn on.
 *
 * <p>Author: bjiang, 2021-12-31.
 *
 * @param img the image the detections were computed on; boxes are drawn onto it
 * @param detection the detection result to render
 * @throws IOException if the output directory or file cannot be created or written
 */
private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
        throws IOException {
    Path outputDir = Paths.get("build/output");
    Files.createDirectories(outputDir);

    DetectedObjects detectionNew = DetectedObjectUtil.me().getDetectedObjects(detection);
    img.drawBoundingBoxes(detectionNew);

    Path imagePath = outputDir.resolve("instances.png");
    // Close the stream so the PNG is fully flushed and the file handle is released.
    try (java.io.OutputStream out = Files.newOutputStream(imagePath)) {
        img.save(out, "png");
    }
    System.out.println("Detection result image has been saved in " + imagePath.toAbsolutePath());
}
/** Lazily created shared instance; guarded by the class lock in {@link #me()}. */
private static DetectedObjectUtil instance;

/**
 * Returns the shared {@code DetectedObjectUtil}, creating it on first use.
 *
 * <p>Synchronized because this may be reached concurrently from web request threads; without it,
 * two threads can both observe {@code null} and create separate instances.
 *
 * @return the shared instance, never {@code null}
 */
public static synchronized DetectedObjectUtil me() {
    if (instance == null) {
        instance = new DetectedObjectUtil();
    }
    return instance;
}
/**
 * Builds a copy of {@code detection} whose labels also carry the probability.
 *
 * <p>Each label becomes the class name, a space, and the probability, so the score is drawn on
 * the image together with the box. Probabilities are kept unchanged. Every bounding box is
 * re-created as a plain {@link Rectangle} with the same x, y, width and height.
 *
 * <p>Author: bjiang, 2021-12-31.
 *
 * @param detection the original detection result
 * @return a new {@code DetectedObjects} with relabelled classes and rectangle bounds
 */
public DetectedObjects getDetectedObjects(DetectedObjects detection) {
    List<String> labels = new ArrayList<>();
    List<Double> scores = new ArrayList<>();
    List<BoundingBox> boxes = new ArrayList<>();

    List<DetectedObjects.DetectedObject> found = detection.items();
    for (DetectedObjects.DetectedObject item : found) {
        double score = item.getProbability();
        Rectangle bounds = item.getBoundingBox().getBounds();

        labels.add(item.getClassName() + " " + score);
        scores.add(score);
        // Copy into a fresh Rectangle so only the rectangular bounds are drawn.
        boxes.add(new Rectangle(
                bounds.getX(), bounds.getY(), bounds.getWidth(), bounds.getHeight()));
    }
    return new DetectedObjects(labels, scores, boxes);
}
5.调用接口ObjectDetectionHbj(POST /ObjectDetectionHbj),控制台输出如下:
[
class: "person", probability: 0.71513, bounds: [x=0.175, y=0.158, width=0.775, height=0.826]
]
查看build/output目录下生成的标注图片instances.png