
Computer Vision with JavaFX and DJL

Computer Vision is not a new topic in Computer Science, but in recent years it has received a boost from the use of neural networks for classification, object detection, instance segmentation, and pose estimation.

We can use these algorithms in Java through a deep learning library such as DL4J, as we have already done on this blog with handwritten digit recognition and object detection in a JavaFX application.

However, deep learning is popular mostly among Python developers (one reason is that Python wraps native libraries easily; this should become easier in Java after Project Panama), so there are many more pre-trained models for the TensorFlow and PyTorch libraries. It is possible to import them into Java, but that requires more work than simply reusing a pre-trained model.

Fortunately there is a new library called Deep Java Library (DJL) that offers a good set of pre-trained models. It supports Jupyter notebooks, which makes it easy to try the library's APIs. Another DJL feature is that it is designed to wrap existing libraries, so it works on top of Keras, TensorFlow, MXNet, and others.

In this post we will test some of DJL's computer vision models from a JavaFX application. Let's start by capturing the webcam from JavaFX, then feed this input to a pre-trained model.


Capturing the Webcam


As input data for the neural network, we capture a webcam image. The project that worked without any issue with JavaFX on my Fedora 34 machine was webcam-capture. I used its JavaFX sample code and it just worked. See my workspace captured from the webcam:


Using Pre-trained Neural Networks

I started with a Maven project that only used webcam-capture. Then I added the DJL Maven dependencies, and Eclipse let me import the DJL classes. Later I also added the ML engine dependencies needed to run my models; this is how my final pom.xml looks:


<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>org.fxapps</groupId>
  <artifactId>webcam-example-javafx</artifactId>
  <version>0.0.1-SNAPSHOT</version>
  <name>WebCam capture from a JavaFX application</name>
  <properties>
    <maven.compiler.source>16</maven.compiler.source>
    <maven.compiler.target>16</maven.compiler.target>
    <version.webcam.capture>0.3.12</version.webcam.capture>
    <version.javafx>17-ea+16</version.javafx>
    <djl.version>0.12.0</djl.version>
  </properties>
  <dependencyManagement>
    <dependencies>
      <dependency>
        <groupId>ai.djl</groupId>
        <artifactId>bom</artifactId>
        <version>${djl.version}</version>
        <type>pom</type>
        <scope>import</scope>
      </dependency>
    </dependencies>
  </dependencyManagement>
  <dependencies>
    <!-- WebCam Capture -->
    <dependency>
      <groupId>com.github.sarxos</groupId>
      <artifactId>webcam-capture</artifactId>
      <version>${version.webcam.capture}</version>
    </dependency>
    <!-- OpenJFX -->
    <dependency>
      <groupId>org.openjfx</groupId>
      <artifactId>javafx-controls</artifactId>
      <version>${version.javafx}</version>
    </dependency>
    <dependency>
      <groupId>org.openjfx</groupId>
      <artifactId>javafx-graphics</artifactId>
      <version>${version.javafx}</version>
    </dependency>
    <dependency>
      <groupId>org.openjfx</groupId>
      <artifactId>javafx-media</artifactId>
      <version>${version.javafx}</version>
    </dependency>
    <dependency>
      <groupId>org.openjfx</groupId>
      <artifactId>javafx-swing</artifactId>
      <version>${version.javafx}</version>
    </dependency>
    <!-- ================ From DJL Examples ================ -->
    <dependency>
      <groupId>commons-cli</groupId>
      <artifactId>commons-cli</artifactId>
      <version>1.4</version>
    </dependency>
    <dependency>
      <groupId>org.apache.logging.log4j</groupId>
      <artifactId>log4j-slf4j-impl</artifactId>
      <version>2.12.1</version>
    </dependency>
    <dependency>
      <groupId>com.google.code.gson</groupId>
      <artifactId>gson</artifactId>
      <version>2.8.5</version>
    </dependency>
    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>api</artifactId>
    </dependency>
    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>basicdataset</artifactId>
    </dependency>
    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>model-zoo</artifactId>
    </dependency>
    <!-- MXNet -->
    <dependency>
      <groupId>ai.djl.mxnet</groupId>
      <artifactId>mxnet-model-zoo</artifactId>
    </dependency>
    <dependency>
      <groupId>ai.djl.mxnet</groupId>
      <artifactId>mxnet-engine</artifactId>
    </dependency>
    <dependency>
      <groupId>ai.djl.mxnet</groupId>
      <artifactId>mxnet-native-auto</artifactId>
      <scope>runtime</scope>
    </dependency>
    <!-- PyTorch -->
    <dependency>
      <groupId>ai.djl.pytorch</groupId>
      <artifactId>pytorch-engine</artifactId>
    </dependency>
    <dependency>
      <groupId>ai.djl.pytorch</groupId>
      <artifactId>pytorch-native-auto</artifactId>
    </dependency>
    <!-- TensorFlow -->
    <dependency>
      <groupId>ai.djl.tensorflow</groupId>
      <artifactId>tensorflow-model-zoo</artifactId>
    </dependency>
    <dependency>
      <groupId>ai.djl.tensorflow</groupId>
      <artifactId>tensorflow-engine</artifactId>
      <scope>runtime</scope>
    </dependency>
    <dependency>
      <groupId>ai.djl.tensorflow</groupId>
      <artifactId>tensorflow-native-cpu</artifactId>
      <classifier>linux-x86_64</classifier>
      <scope>runtime</scope>
    </dependency>
    <dependency>
      <groupId>ai.djl.tensorflow</groupId>
      <artifactId>tensorflow-native-auto</artifactId>
      <scope>runtime</scope>
    </dependency>
  </dependencies>
</project>

Back to the code: what I wanted was to grab the image from the webcam, feed it into a pre-trained model, get the result, and render it in JavaFX instead of the raw webcam image. To allow users to switch between the pre-trained models, we added a combo box to the user interface.


To abstract the model, we created an abstract class called MLModel. It wraps the model call so we can focus on rendering the image; the UI knows nothing about any specific model, which makes it easy to add and remove models. MLModel grabs the predictions from the ML algorithm and draws on the image according to the result type:

package org.fxapps.predict;

import java.awt.image.BufferedImage;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;

import org.fxapps.predict.model.PoseEstimationPrediction;

public abstract class MLModel<T> {

    public abstract String getName();

    public T predict(BufferedImage image) {
        var img = ImageFactory.getInstance().fromImage(image);
        return predict(img);
    }

    protected abstract T predict(Image image);

    public BufferedImage predictAndDraw(BufferedImage image) {
        var img = ImageFactory.getInstance().fromImage(image);
        T result = predict(img);
        if (result instanceof Joints joints) {
            img.drawJoints(joints);
        }
        if (result instanceof PoseEstimationPrediction poseEstimation) {
            poseEstimation.personRect().ifPresent(rect -> {
                img.getSubimage((int) rect.getX(),
                                (int) rect.getY(),
                                (int) rect.getWidth(),
                                (int) rect.getHeight())
                   .drawJoints(poseEstimation.joints());
            });
        }
        if (result instanceof DetectedObjects objects) {
            img.drawBoundingBoxes(objects);
        }
        // This does not work on Android because BufferedImage is hardcoded here;
        // replacing the use of BufferedImage would make it work on Android as well.
        return (BufferedImage) img.getWrappedImage();
    }

    @Override
    public String toString() {
        return getName();
    }
}

In the UI we have a list of implementations, one of which is selected when the combo box changes. We could use other mechanisms to discover the models, such as the Java service provider (ServiceLoader) API, but for this code we decided to keep it simple.
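As a sketch of the ServiceLoader alternative mentioned above: each implementation would be listed in a `META-INF/services` file, and the UI could discover them at startup instead of hardcoding the list. The `NamedModel` interface below is a hypothetical stand-in for `MLModel<?>` to keep the sketch self-contained.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.ServiceLoader;

public class ModelDiscovery {

    // Hypothetical service interface standing in for MLModel<?>.
    // Real implementations would be registered in a file named
    // META-INF/services/org.fxapps.predict.MLModel on the classpath.
    public interface NamedModel {
        String getName();
    }

    public static List<String> discoverModelNames() {
        List<String> names = new ArrayList<>();
        // ServiceLoader scans the classpath for registered providers
        for (NamedModel model : ServiceLoader.load(NamedModel.class)) {
            names.add(model.getName());
        }
        return names;
    }

    public static void main(String[] args) {
        // With no providers registered, the list is simply empty
        System.out.println(discoverModelNames());
    }
}
```

The combo box would then be populated from `discoverModelNames()` (or the loaded instances directly), so adding a model becomes a matter of adding a JAR with a provider entry rather than editing UI code.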


// Objects to allow users to select an ML model
ObjectProperty<MLModel<?>> selectedModel = new SimpleObjectProperty<>();
List<MLModel<?>> models = List.of(new ObjectDetectionMLModel(),
                                  new InstanceSegmentationMLModel(),
                                  new PoseEstimationMLModel());
AtomicBoolean runningPrediction = new AtomicBoolean();

// Creating the combo box and binding the selected model to what the user selects
var modelOptions = new ComboBox<MLModel<?>>();
selectedModel.bind(modelOptions.getSelectionModel().selectedItemProperty());
modelOptions.getItems().addAll(models);
modelOptions.getSelectionModel().select(0);

// Later, when there is a new image from the webcam, we skip the prediction
// while another one is running. The code below runs on a thread separate
// from the JavaFX application thread.
while (!stopCamera) {
    try {
        if ((img = webCam.getImage()) != null) {
            if (!runningPrediction.get()) {
                runningPrediction.set(true);
                img = selectedModel.get().predictAndDraw(img);
                runningPrediction.set(false);
            }
            img.flush();
            ref.set(SwingFXUtils.toFXImage(img, ref.get()));
            Platform.runLater(() -> imageProperty.set(ref.get()));
        }
    } catch (Exception e) {
        e.printStackTrace();
    }
}
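A side note on the `runningPrediction` flag above: it is read and then set in two separate steps, which is fine while only one thread runs predictions, but would race if a second thread ever joined. `AtomicBoolean.compareAndSet` makes the check-and-acquire a single atomic operation. A minimal sketch (the `PredictionGate` and `tryPredict` names are hypothetical, not part of the original code):

```java
import java.util.concurrent.atomic.AtomicBoolean;

public class PredictionGate {

    private final AtomicBoolean runningPrediction = new AtomicBoolean(false);

    // Returns true only for the single caller that wins the flag;
    // everyone else skips this frame instead of queueing up.
    public boolean tryPredict(Runnable prediction) {
        if (!runningPrediction.compareAndSet(false, true)) {
            return false; // another prediction is already in flight
        }
        try {
            prediction.run();
            return true;
        } finally {
            // finally guarantees the flag is released even if prediction throws
            runningPrediction.set(false);
        }
    }
}
```

The `try/finally` also fixes a subtle issue in the loop above: if `predictAndDraw` throws, the flag would stay `true` and no further predictions would ever run.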


Finally it was time to implement the models themselves. First we created an MLModel implementation for object detection. It returns a result of type DetectedObjects, and we are able to configure the pre-trained model as we want: using the class ai.djl.repository.zoo.Criteria we can select the engine, the training data, and other parameters. In our case we selected, from the TensorFlow engine, a model with a mobilenet_v2 backbone. A video is on my Twitter.


package org.fxapps.predict;

import java.io.IOException;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

public class ObjectDetectionMLModel extends MLModel<DetectedObjects> {

    private Predictor<Image, DetectedObjects> predictor;

    public ObjectDetectionMLModel() {
        try {
            predictor = buildCriteria().loadModel().newPredictor();
        } catch (ModelNotFoundException | MalformedModelException | IOException e) {
            throw new RuntimeException("Not able to load object detection model", e);
        }
    }

    protected Criteria<Image, DetectedObjects> buildCriteria() {
        return Criteria.builder()
                .optEngine("TensorFlow")
                .optApplication(Application.CV.OBJECT_DETECTION)
                .setTypes(Image.class, DetectedObjects.class)
                .optFilter("backbone", "mobilenet_v2")
                .optArgument("threshold", "0.2")
                .optProgress(new ProgressBar())
                .build();
    }

    @Override
    public String getName() {
        return "Object Detection";
    }

    @Override
    public DetectedObjects predict(Image image) {
        try {
            return predictor.predict(image);
        } catch (TranslateException e) {
            throw new RuntimeException("Not able to detect objects", e);
        }
    }
}
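The pose model below filters the detections this predictor returns down to the first "person" hit. The same kind of filtering can be sketched with plain Java streams; the `Detection` record here is a hypothetical stand-in for DJL's `DetectedObjects.DetectedObject`, used only to keep the example self-contained.

```java
import java.util.List;
import java.util.Optional;

public class DetectionFilter {

    // Hypothetical stand-in for ai.djl.modality.cv.output.DetectedObjects.DetectedObject
    public record Detection(String className, double probability) {}

    // Picks the most confident detection of the wanted class above a threshold
    public static Optional<Detection> best(List<Detection> detections,
                                           String wantedClass,
                                           double threshold) {
        return detections.stream()
                .filter(d -> wantedClass.equalsIgnoreCase(d.className()))
                .filter(d -> d.probability() >= threshold)
                .max((a, b) -> Double.compare(a.probability(), b.probability()));
    }

    public static void main(String[] args) {
        var detections = List.of(new Detection("person", 0.35),
                                 new Detection("dog", 0.90),
                                 new Detection("person", 0.80));
        System.out.println(best(detections, "person", 0.2).orElseThrow());
    }
}
```

Note one design choice: taking the `max` by probability is slightly more robust than `findFirst`, which depends on the order the engine happens to return the boxes in.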




For pose estimation we had to run two models internally: the first to extract the person from the input image, and the other to do the pose estimation itself. We also had to calculate the person's bounding box in pixels relative to the input image. See a video of this model in action.

package org.fxapps.predict;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.Optional;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

import org.fxapps.predict.model.PoseEstimationPrediction;

public final class PoseEstimationMLModel extends MLModel<PoseEstimationPrediction> {

    private static final Joints EMPTY_JOINTS = new Joints(Collections.emptyList());

    private Predictor<Image, DetectedObjects> objectsPredictor;
    private Predictor<Image, Joints> posePredictor;

    public PoseEstimationMLModel() {
        super();
        try {
            // more precision:
            // objectsPredictor = Criteria.builder()
            //         .optApplication(Application.CV.OBJECT_DETECTION)
            //         .setTypes(Image.class, DetectedObjects.class)
            //         .optFilter("size", "512")
            //         .optFilter("backbone", "resnet50")
            //         .optFilter("flavor", "v1")
            //         .optFilter("dataset", "voc")
            //         .optProgress(new ProgressBar())
            //         .build()
            //         .loadModel()
            //         .newPredictor();
            // faster:
            objectsPredictor = Criteria.builder()
                    .optEngine("TensorFlow")
                    .optApplication(Application.CV.OBJECT_DETECTION)
                    .setTypes(Image.class, DetectedObjects.class)
                    .optFilter("backbone", "mobilenet_v2")
                    .optArgument("threshold", "0.1")
                    .optProgress(new ProgressBar())
                    .build()
                    .loadModel()
                    .newPredictor();
        } catch (ModelNotFoundException | MalformedModelException | IOException e) {
            throw new RuntimeException("Not able to load object detector model.", e);
        }
        try {
            posePredictor = Criteria.builder()
                    .optApplication(Application.CV.POSE_ESTIMATION)
                    .setTypes(Image.class, Joints.class)
                    .optFilter("backbone", "resnet18")
                    .optFilter("flavor", "v1b")
                    .optFilter("dataset", "imagenet")
                    .build()
                    .loadModel()
                    .newPredictor();
        } catch (ModelNotFoundException | MalformedModelException | IOException e) {
            throw new RuntimeException("Not able to load pose estimation model.", e);
        }
    }

    @Override
    public String getName() {
        return "Pose Estimation";
    }

    @Override
    protected PoseEstimationPrediction predict(Image image) {
        try {
            var personPosOp = retrievePerson(image);
            var personJoints = personPosOp.map(rect -> predictPose(image, rect))
                                          .orElse(EMPTY_JOINTS);
            return new PoseEstimationPrediction(personPosOp, personJoints);
        } catch (TranslateException e) {
            e.printStackTrace();
            return new PoseEstimationPrediction(Optional.empty(), EMPTY_JOINTS);
        }
    }

    private Optional<Rectangle> retrievePerson(Image img) throws TranslateException {
        var detectedBoxes = objectsPredictor.predict(img);
        // use this to draw the predicted objects:
        // img.drawBoundingBoxes(detectedBoxes);
        return detectedBoxes.items()
                .stream()
                .map(i -> (DetectedObjects.DetectedObject) i)
                .filter(item -> "person".equalsIgnoreCase(item.getClassName()))
                .findFirst()
                .map(box -> extractPersonBounds(img, box));
    }

    private Rectangle extractPersonBounds(Image img, DetectedObjects.DetectedObject box) {
        var rect = box.getBoundingBox().getBounds();
        int width = img.getWidth();
        int height = img.getHeight();
        int personX = (int) (rect.getX() * width);
        int personY = (int) (rect.getY() * height);
        int personWidth = (int) (rect.getWidth() * width);
        int personHeight = (int) (rect.getHeight() * height);
        if (personX > personWidth) {
            personX = personWidth;
        }
        if (personY > personHeight) {
            personY = personHeight;
        }
        if (personX < 0) {
            personX = 0;
        }
        if (personY < 0) {
            personY = 0;
        }
        int rectXBound = personX + personWidth;
        int rectYBound = personY + personHeight;
        if (rectXBound > img.getWidth()) {
            personWidth = personWidth - (rectXBound - img.getWidth());
        }
        if (rectYBound > img.getHeight()) {
            personHeight = personHeight - (rectYBound - img.getHeight());
        }
        return new Rectangle(personX, personY, personWidth, personHeight);
    }

    private Joints predictPose(Image image, Rectangle rect) {
        try {
            var personImage = image.getSubimage((int) rect.getX(),
                                                (int) rect.getY(),
                                                (int) rect.getWidth(),
                                                (int) rect.getHeight());
            return posePredictor.predict(personImage);
        } catch (TranslateException e) {
            e.printStackTrace();
            return EMPTY_JOINTS;
        }
    }

    // use this for debug purposes
    protected static void saveJointsImage(Image img, Joints joints) {
        Path outputDir = Paths.get("build/output");
        try {
            Files.createDirectories(outputDir);
            img.drawJoints(joints);
            Path imagePath = outputDir.resolve("joints.png");
            // Must use PNG format because you can't save as JPG with an alpha channel
            img.save(Files.newOutputStream(imagePath), "png");
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
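The detection model reports bounding boxes in coordinates normalized to 0..1, so extractPersonBounds above scales them by the image size and clamps the result to stay inside the frame. That scale-and-clamp arithmetic can be checked in isolation; the `toPixelRect` helper below is a hypothetical, simplified version of the same idea, not the exact clamping logic from the class above.

```java
public class BoundsMath {

    public record PixelRect(int x, int y, int width, int height) {}

    // Scales a normalized (0..1) box to pixel coordinates and clamps it
    // so the resulting rectangle never leaves the image.
    public static PixelRect toPixelRect(double nx, double ny, double nw, double nh,
                                        int imgWidth, int imgHeight) {
        int x = Math.max(0, (int) (nx * imgWidth));
        int y = Math.max(0, (int) (ny * imgHeight));
        int w = (int) (nw * imgWidth);
        int h = (int) (nh * imgHeight);
        // Shrink the box if it spills past the right or bottom edge
        if (x + w > imgWidth) {
            w = imgWidth - x;
        }
        if (y + h > imgHeight) {
            h = imgHeight - y;
        }
        return new PixelRect(x, y, w, h);
    }

    public static void main(String[] args) {
        // A box covering the right half of a 640x480 frame, slightly oversized
        System.out.println(toPixelRect(0.5, 0.0, 0.6, 1.0, 640, 480));
    }
}
```

Clamping matters because `Image.getSubimage` fails if any part of the requested rectangle falls outside the source image, which happens easily with boxes detected near the frame edges.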




Finally, we also have instance segmentation, but it was so slow that I will not discuss it here. You are free to run the application and test it yourself.



Conclusion

Java is a good alternative for building machine learning applications! DJL and its great model zoo make it easier to reuse models trained with other libraries. Next steps would be to use other families of pre-trained models, such as NLP, and to create more useful applications, like augmented reality!

Full code is on GitHub.

