
Recognizing Handwritten digits from a JavaFX application using Deeplearning4j





We have already talked about TensorFlow and JavaFX on this blog, but the TensorFlow Java API is still incomplete. A mature and well-documented alternative is Deeplearning4j.


In this example we load the trained model into our application and create a canvas to draw on. When Enter is pressed, the canvas image is resized and sent to the trained Deeplearning4j model for recognition (a right click clears the canvas).






Guessing the digit is the "hello world" of deep neural networks. An artificial neuron roughly mimics a neuron of the human brain: it has weights, which control when the neuron is activated. A neural network consists of many neurons linked to each other and organized in layers. What we do is provide our neural network with labeled data and adjust the weights of its neurons until it correctly predicts the values for that data; this is called training.
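To make weights and training concrete, here is a minimal sketch of a single artificial neuron (a classic perceptron) in plain Java, with no Deeplearning4j involved. The class name, the learning rate of 0.1 and the logical AND training data are illustrative assumptions. Deeplearning4j trains real networks with backpropagation and gradient-based optimizers, not with this simple update rule, but the idea of nudging weights until the labeled examples are predicted correctly is the same.

```java
class PerceptronSketch {
    double[] weights = new double[2]; // one weight per input, starting at 0
    double bias = 0.0;
    double learningRate = 0.1;        // hypothetical value

    // The neuron "fires" (outputs 1) only when the weighted sum of its inputs is above 0
    int activate(double[] x) {
        double sum = bias;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * x[i];
        }
        return sum > 0 ? 1 : 0;
    }

    // Training: whenever a labeled example is predicted wrongly, move the weights toward the right answer
    void train(double[][] inputs, int[] labels, int epochs) {
        for (int e = 0; e < epochs; e++) {
            for (int n = 0; n < inputs.length; n++) {
                int error = labels[n] - activate(inputs[n]);
                for (int i = 0; i < weights.length; i++) {
                    weights[i] += learningRate * error * inputs[n][i];
                }
                bias += learningRate * error;
            }
        }
    }

    public static void main(String[] args) {
        double[][] inputs = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        int[] labels = {0, 0, 0, 1}; // logical AND as toy labeled data
        PerceptronSketch neuron = new PerceptronSketch();
        neuron.train(inputs, labels, 20);
        for (double[] x : inputs) {
            System.out.println(neuron.activate(x));
        }
    }
}
```

Running `main` prints 0, 0, 0 and 1: after a few passes over the data the weights separate the single positive example from the others.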




Once it is trained, we test the neural network against labeled data to measure its precision (in our case, 97.5%!). For both training and testing we use the famous MNIST database of handwritten digits.
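To show what such a score means, the sketch below compares predicted digits with the true labels. It is plain Java with hypothetical data for eight images, not the project's evaluation code; in practice Deeplearning4j computes these numbers for you (its `Evaluation` class reports accuracy, precision, recall and F1). Note that accuracy and precision are different metrics: accuracy counts all correct predictions, while precision is computed per digit, over the images that were predicted as that digit.

```java
class EvaluationSketch {
    // Accuracy: fraction of all examples whose predicted digit equals the true label
    static double accuracy(int[] predicted, int[] actual) {
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (predicted[i] == actual[i]) {
                correct++;
            }
        }
        return (double) correct / actual.length;
    }

    // Precision for one digit: of the examples predicted as that digit, the fraction that really are it
    static double precision(int[] predicted, int[] actual, int digit) {
        int truePositives = 0;
        int falsePositives = 0;
        for (int i = 0; i < actual.length; i++) {
            if (predicted[i] == digit) {
                if (actual[i] == digit) {
                    truePositives++;
                } else {
                    falsePositives++;
                }
            }
        }
        int predictedAsDigit = truePositives + falsePositives;
        // Convention used in this sketch: 0 when the digit was never predicted
        return predictedAsDigit == 0 ? 0.0 : (double) truePositives / predictedAsDigit;
    }

    public static void main(String[] args) {
        // Hypothetical labels for 8 test images; the 7th image is a 4 misread as a 9
        int[] actual = {7, 2, 1, 0, 4, 1, 4, 9};
        int[] predicted = {7, 2, 1, 0, 4, 1, 9, 9};
        System.out.println(accuracy(predicted, actual));     // 0.875
        System.out.println(precision(predicted, actual, 9)); // 0.5
    }
}
```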


Because it has hidden layers between the input layer (where we feed our data) and the output layer (where we get our predictions), we call it a deep neural network. There are many other concepts and types of neural networks; I encourage you to watch some videos about the subject on YouTube.
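The layered structure can be illustrated with a forward pass through one hidden layer and one output layer. This is a toy in plain Java: the 3 inputs stand in for the 784 pixels of a 28x28 MNIST image, the 3 outputs stand in for the 10 digits, and all weights are made-up numbers. ReLU and softmax are common choices for MNIST networks, but they are assumptions here, not necessarily what the trained model uses.

```java
import java.util.Arrays;

class TinyDeepNet {
    // Hypothetical weights: 3 inputs -> 2 hidden neurons -> 3 output classes
    static final double[][] W1 = {{0.2, -0.4}, {0.7, 0.1}, {-0.3, 0.9}};
    static final double[] B1 = {0.0, 0.1};
    static final double[][] W2 = {{1.0, -1.0, 0.5}, {-0.5, 1.0, 0.2}};
    static final double[] B2 = {0.0, 0.0, 0.0};

    // Fully connected layer: out[j] = b[j] + sum over i of in[i] * w[i][j]
    static double[] dense(double[] in, double[][] w, double[] b) {
        double[] out = b.clone();
        for (int i = 0; i < in.length; i++) {
            for (int j = 0; j < out.length; j++) {
                out[j] += in[i] * w[i][j];
            }
        }
        return out;
    }

    // ReLU: a hidden neuron passes its value on only when the weighted sum is positive
    static double[] relu(double[] v) {
        return Arrays.stream(v).map(x -> Math.max(0.0, x)).toArray();
    }

    // Softmax: turns the output scores into probabilities that sum to 1
    static double[] softmax(double[] v) {
        double max = Arrays.stream(v).max().getAsDouble(); // subtract max for numerical stability
        double[] exp = Arrays.stream(v).map(x -> Math.exp(x - max)).toArray();
        double total = Arrays.stream(exp).sum();
        return Arrays.stream(exp).map(x -> x / total).toArray();
    }

    static int argmax(double[] v) {
        int best = 0;
        for (int i = 1; i < v.length; i++) {
            if (v[i] > v[best]) {
                best = i;
            }
        }
        return best;
    }

    // Input layer -> hidden layer -> output layer
    static double[] forward(double[] input) {
        double[] hidden = relu(dense(input, W1, B1));
        return softmax(dense(hidden, W2, B2));
    }

    public static void main(String[] args) {
        double[] probs = forward(new double[] {1.0, 0.0, 0.5});
        System.out.println(Arrays.toString(probs) + " -> class " + argmax(probs));
    }
}
```

With these weights the hidden layer outputs 0.05 and 0.15, and the network picks class 1. Training would adjust `W1`, `B1`, `W2` and `B2`; the forward pass itself stays the same.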



And if this is the first time you are reading about this stuff, be aware that it won't be the last!

If you try the code, you may find that it is not as precise as this web application, for example. The reason is that I didn't preprocess the image carefully before sending it for prediction: we just resize it to 28x28 pixels, as required by our trained model.
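One possible improvement is to mimic how the MNIST digits were prepared: they were size-normalized to fit a 20x20 box, preserving their aspect ratio, and centered in the 28x28 image. The sketch below is an assumption, not code from the project: it crops the bounding box of the stroke pixels (the brightness threshold of 32 is arbitrary), scales it so its longer side is 20 pixels, and centers it by its bounding box, whereas MNIST centered digits by their center of mass. It uses only `java.awt`.

```java
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;

class MnistStylePreprocessor {

    static final int INK_THRESHOLD = 32; // hypothetical: brightness above this counts as stroke

    static BufferedImage toMnistStyle(BufferedImage src) {
        // 1. Find the bounding box of the white strokes on the black background
        int minX = Integer.MAX_VALUE, minY = Integer.MAX_VALUE, maxX = -1, maxY = -1;
        for (int y = 0; y < src.getHeight(); y++) {
            for (int x = 0; x < src.getWidth(); x++) {
                int rgb = src.getRGB(x, y);
                int brightness = (((rgb >> 16) & 0xFF) + ((rgb >> 8) & 0xFF) + (rgb & 0xFF)) / 3;
                if (brightness > INK_THRESHOLD) {
                    minX = Math.min(minX, x);
                    minY = Math.min(minY, y);
                    maxX = Math.max(maxX, x);
                    maxY = Math.max(maxY, y);
                }
            }
        }
        BufferedImage out = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY); // starts all black
        if (maxX < 0) {
            return out; // nothing was drawn
        }
        // 2. Scale the box so its longer side becomes 20 pixels, keeping the aspect ratio
        int boxW = maxX - minX + 1;
        int boxH = maxY - minY + 1;
        double scale = 20.0 / Math.max(boxW, boxH);
        int w = Math.max(1, (int) Math.round(boxW * scale));
        int h = Math.max(1, (int) Math.round(boxH * scale));
        // 3. Center the scaled box inside the 28x28 image
        int dx = (28 - w) / 2;
        int dy = (28 - h) / 2;
        Graphics2D g = out.createGraphics();
        g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        // Copy the source rectangle [minX, maxX] x [minY, maxY] into the centered destination rectangle
        g.drawImage(src, dx, dy, dx + w, dy + h, minX, minY, maxX + 1, maxY + 1, null);
        g.dispose();
        return out;
    }
}
```

In the application, `getScaledImage` could return `MnistStylePreprocessor.toMnistStyle(SwingFXUtils.fromFXImage(writableImage, null))` instead of the plain resize; how much this helps the trained model has not been measured here.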

The code of the JavaFX application is below. The full project, including the Java training code (based on the Deeplearning4j examples), is on my GitHub.

```java
package org.fxapps.deeplearning;

import java.awt.Graphics;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

import javafx.application.Application;
import javafx.embed.swing.SwingFXUtils;
import javafx.geometry.Pos;
import javafx.scene.Scene;
import javafx.scene.canvas.Canvas;
import javafx.scene.canvas.GraphicsContext;
import javafx.scene.control.Label;
import javafx.scene.image.ImageView;
import javafx.scene.image.WritableImage;
import javafx.scene.input.KeyCode;
import javafx.scene.input.MouseButton;
import javafx.scene.layout.HBox;
import javafx.scene.layout.VBox;
import javafx.scene.paint.Color;
import javafx.scene.shape.StrokeLineCap;
import javafx.stage.Stage;

public class MnistTestFXApp extends Application {

    private static final int CANVAS_WIDTH = 150;
    private static final int CANVAS_HEIGHT = 150;

    private NativeImageLoader loader;
    private MultiLayerNetwork model;
    private Label lblResult;

    public static void main(String[] args) {
        launch(args);
    }

    @Override
    public void start(Stage stage) throws Exception {
        Canvas canvas = new Canvas(CANVAS_WIDTH, CANVAS_HEIGHT);
        ImageView imgView = new ImageView();
        GraphicsContext ctx = canvas.getGraphicsContext2D();

        // Restore the trained model and create a loader for 28x28, single-channel images
        model = ModelSerializer.restoreMultiLayerNetwork(new File("minist-model.zip"));
        loader = new NativeImageLoader(28, 28, 1, true);

        imgView.setFitHeight(100);
        imgView.setFitWidth(100);
        ctx.setLineWidth(10);
        ctx.setLineCap(StrokeLineCap.SQUARE);
        lblResult = new Label();

        HBox hbBottom = new HBox(10, imgView, lblResult);
        VBox root = new VBox(5, canvas, hbBottom);
        hbBottom.setAlignment(Pos.CENTER);
        root.setAlignment(Pos.CENTER);

        Scene scene = new Scene(root, 520, 300);
        stage.setScene(scene);
        stage.show();
        stage.setTitle("Handwritten digits recognition");

        // Draw white strokes on the black canvas while the mouse is dragged
        canvas.setOnMousePressed(e -> {
            ctx.setStroke(Color.WHITE);
            ctx.beginPath();
            ctx.moveTo(e.getX(), e.getY());
            ctx.stroke();
        });
        canvas.setOnMouseDragged(e -> {
            ctx.setStroke(Color.WHITE);
            ctx.lineTo(e.getX(), e.getY());
            ctx.stroke();
        });

        // A right click clears the canvas
        canvas.setOnMouseClicked(e -> {
            if (e.getButton() == MouseButton.SECONDARY) {
                clear(ctx);
            }
        });

        // ENTER scales the drawing down to 28x28, shows it and asks the model for a prediction
        canvas.setOnKeyReleased(e -> {
            if (e.getCode() == KeyCode.ENTER) {
                BufferedImage scaledImg = getScaledImage(canvas);
                imgView.setImage(SwingFXUtils.toFXImage(scaledImg, null));
                try {
                    predictImage(scaledImg);
                } catch (Exception e1) {
                    e1.printStackTrace();
                }
            }
        });

        clear(ctx);
        canvas.requestFocus();
    }

    private BufferedImage getScaledImage(Canvas canvas) {
        // for a better recognition we should improve this part of how we retrieve the image from the canvas
        WritableImage writableImage = new WritableImage(CANVAS_WIDTH, CANVAS_HEIGHT);
        canvas.snapshot(null, writableImage);
        Image tmp = SwingFXUtils.fromFXImage(writableImage, null).getScaledInstance(28, 28, Image.SCALE_SMOOTH);
        BufferedImage scaledImg = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        Graphics graphics = scaledImg.getGraphics();
        graphics.drawImage(tmp, 0, 0, null);
        graphics.dispose();
        return scaledImg;
    }

    // Paints the whole canvas black, so digits are white on black, like the MNIST images
    private void clear(GraphicsContext ctx) {
        ctx.setFill(Color.BLACK);
        ctx.fillRect(0, 0, CANVAS_WIDTH, CANVAS_HEIGHT);
    }

    private void predictImage(BufferedImage img) throws IOException {
        // Scale pixel values from [0, 255] to [0, 1]
        ImagePreProcessingScaler imagePreProcessingScaler = new ImagePreProcessingScaler(0, 1);
        INDArray image = loader.asRowVector(img);
        imagePreProcessingScaler.transform(image);
        // One output per digit (0-9); predict() returns the index of the highest one
        INDArray output = model.output(image);
        lblResult.setText("Prediction: " + model.predict(image)[0] + "\n " + output);
    }
}
```
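For readers new to ND4J, the input preparation and the label lookup in `predictImage` can be mimicked in plain Java: `asRowVector` turns the 28x28 image into a single row of 784 values, `ImagePreProcessingScaler(0, 1)` maps pixel values from 0-255 to 0-1, and `predict` returns the index of the largest output. This is an approximation for illustration only; the row-major pixel order is an assumption about the layout, and the class and method names are hypothetical.

```java
import java.awt.image.BufferedImage;
import java.awt.image.Raster;

class InputOutputSketch {
    // Flattens a grayscale image row by row and scales each pixel from [0, 255] to [0, 1]
    static double[] toScaledRowVector(BufferedImage img) {
        Raster raster = img.getRaster();
        int width = img.getWidth();
        double[] row = new double[width * img.getHeight()];
        for (int y = 0; y < img.getHeight(); y++) {
            for (int x = 0; x < width; x++) {
                row[y * width + x] = raster.getSample(x, y, 0) / 255.0;
            }
        }
        return row;
    }

    // The predicted digit is the index of the largest of the 10 outputs
    static int argmax(double[] outputs) {
        int best = 0;
        for (int i = 1; i < outputs.length; i++) {
            if (outputs[i] > outputs[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY);
        img.getRaster().setSample(3, 0, 0, 255); // one white pixel in the first row
        double[] row = toScaledRowVector(img);
        System.out.println(row.length + " " + row[3]); // 784 1.0
        double[] outputs = {0.01, 0.02, 0.90, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01};
        System.out.println(argmax(outputs)); // 2
    }
}
```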
