
Brazilian Coins Classification Using Deep Learning and Java

Learning deep learning is exciting. The courses are engaging, and you can reproduce the results on your own machine using popular libraries such as Keras for Python and DeepLearning4j for Java.
These libraries even have utility methods that download famous datasets for you. You just copy and paste the code and run it; after training (probably a few hours on a CPU) you have a model that is ready to classify new data.
I did that. I did that many times. I also created small neural nets from scratch, but I still had not solved a real problem with this technology or collected my own dataset.
In this post I will share how I used DeepLearning4j to classify Brazilian coins. I will also try to share how I failed many times before getting my current model with 77% accuracy.

Brazilian coin for one real

Collecting the data set


We need a large dataset to train a neural network. A common rule of thumb is at least 1,000 images per class. I started with much less than that: about 50 images per class, which I got from Google Images. Of course the training went very badly, so I collected more images.
Taking pictures of coins is not the most pleasant task in the world, so I asked my wife and co-workers for help. I even created a small application so people could send me images more easily. Today I have 200 images per class in my training set, separated into directories where the parent directory is the label for its images. With this layout I can use the DeepLearning4j classes I already mentioned on this blog.

ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();
File coinsTrainRootDir = Paths.get("/path/to/my/mycooldataset/train").toFile();
InputSplit trainData = new FileSplit(coinsTrainRootDir, BaseImageLoader.ALLOWED_FORMATS, new Random());
// ImageRecordReader takes height first, then width (no difference for square images)
ImageRecordReader trainReader = new ImageRecordReader(IMG_HEIGHT, IMG_WIDTH, 3, labelGenerator);
trainReader.initialize(trainData);
I also have about 40 images per class in my test set. I resized all images to 448x448: I found no neural network trained on images larger than that, and I did not want to fill my file system with large photos taken with my mobile camera (~3 MB per photo).
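With a directory-per-label layout it is easy to check how many images each class really has before training. The following is a minimal sketch using only the JDK; the class and method names are hypothetical and not part of the original project, and it counts every regular file in each label directory without checking the image format.

```java
import java.io.IOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper, not part of the original project.
final class DatasetStats {

    /**
     * Counts the regular files inside each direct subdirectory of root.
     * The subdirectory name is treated as the label.
     */
    static Map<String, Integer> countImagesPerLabel(Path root) throws IOException {
        Map<String, Integer> counts = new TreeMap<>();
        try (DirectoryStream<Path> labelDirs = Files.newDirectoryStream(root)) {
            for (Path labelDir : labelDirs) {
                if (!Files.isDirectory(labelDir)) {
                    continue; // ignore stray files next to the label directories
                }
                int n = 0;
                try (DirectoryStream<Path> files = Files.newDirectoryStream(labelDir)) {
                    for (Path file : files) {
                        if (Files.isRegularFile(file)) {
                            n++;
                        }
                    }
                }
                counts.put(labelDir.getFileName().toString(), n);
            }
        }
        return counts;
    }

    public static void main(String[] args) throws IOException {
        if (args.length == 0) {
            System.out.println("usage: DatasetStats <dataset-root-dir>");
            return;
        }
        countImagesPerLabel(Paths.get(args[0]))
                .forEach((label, n) -> System.out.println(label + ": " + n));
    }
}
```

Running it once against the train directory and once against the test directory shows quickly whether any class is under-represented.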



While collecting the dataset I was also testing and trying to build a good model, and I used the DL4J library to generate images for me. Because my dataset was so small, I used DL4J to apply transformations (crop, resize, rotate, etc.) to the original images, generating new ones so I had more data to train my neural network.
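To give a rough idea of what one such transformation does, here is a plain-JDK sketch that rotates an image by 90 degrees clockwise. This is not DL4J's transform API, only an illustration of the idea with hypothetical names: a rotated coin is still the same coin, so each rotation yields one more labelled training image.

```java
import java.awt.image.BufferedImage;

// Hypothetical illustration; the original project used DL4J's own transforms.
final class RotateAugmentation {

    /** Returns a new image equal to src rotated 90 degrees clockwise. */
    static BufferedImage rotate90Clockwise(BufferedImage src) {
        int w = src.getWidth();
        int h = src.getHeight();
        // Width and height swap after a quarter turn.
        BufferedImage out = new BufferedImage(h, w, BufferedImage.TYPE_INT_RGB);
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                // The left column of the source becomes the top row of the result.
                out.setRGB(h - 1 - y, x, src.getRGB(x, y));
            }
        }
        return out;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(2, 1, BufferedImage.TYPE_INT_RGB);
        img.setRGB(0, 0, 0xFF0000); // red on the left
        img.setRGB(1, 0, 0x0000FF); // blue on the right
        BufferedImage rotated = rotate90Clockwise(img);
        System.out.println(rotated.getWidth() + "x" + rotated.getHeight());
    }
}
```

Applying the method one, two, and three times to each original gives three extra images per photo without taking a single new picture.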

Training a neural network


First I tried to create my own neural networks and train them from scratch (with random weights). I chose Convolutional Neural Networks (CNNs) to identify the patterns in a coin and classify it: a few classical CNN layers and a fully connected layer, just as I had learned on the internet.

That was a bad idea: training took too long and I never got accuracy above 50%. So I took some good neural network architectures already available on the internet and wrote them using DeepLearning4j (DL4J).
The DL4J documentation is full of good examples, including examples for famous networks such as VGG16.

Convolutional Neural Networks. Source: https://www.topbots.com/14-design-patterns-improve-convolutional-neural-network-cnn-architecture/


After grabbing the neural network code from the internet, I noticed that training was still taking too long and no good result was coming out of my training sessions. In other words, I waited hours for a result that was far from what I expected. This was a good moment to delete my project and forget about it. It was clear that it would take time to build a good dataset, and experience to choose the hyper-parameters for my neural network correctly (which would require more time with the dataset and more inspection of the neural network).

Using pre-trained models with DeepLearning4j


DL4J 0.9 comes with great APIs for using pre-trained neural networks. Using the Model Zoo API you can get a fresh, well-known neural network architecture (VGG16, ResNet50, GoogLeNet, etc.) and train it against your data, or get a pre-trained model. There are CNN models that were trained on ImageNet; in other words, you can take what the neural network already learned from ImageNet and use it to classify your own images!

Once you get a model from the zoo, you can use the Transfer Learning API to replace the last layer with a layer that has as many outputs as there are labels in your dataset (see the code in the next section of this article).

There are a few architectures to choose from. I chose ResNet50 initialized with ImageNet weights. I would have chosen GoogLeNet or VGG16, but the ResNet50 model is only about 90 MB, while VGG16, for example, is 500 MB! In my first test, with a few images, I got 70% accuracy, which was amazing!


The code


Following everything we discussed, this is what happens in the code:

1) Loaded the dataset and used the parent directory as the label for the images it contains

ParentPathLabelGenerator labelGenerator = new ParentPathLabelGenerator();
ImageTransform[] transforms = getTransforms();
File coinsTrainRootDir = Paths.get("/home/wsiqueir/moedas/train").toFile();
File coinsTestRootDir = Paths.get("/home/wsiqueir/moedas/test").toFile();
InputSplit trainData = new FileSplit(coinsTrainRootDir, BaseImageLoader.ALLOWED_FORMATS, new Random());
InputSplit testData = new FileSplit(coinsTestRootDir, BaseImageLoader.ALLOWED_FORMATS, new Random());
// ImageRecordReader takes height first, then width (no difference for square images)
ImageRecordReader trainReader = new ImageRecordReader(IMG_HEIGHT, IMG_WIDTH, CHANNELS, labelGenerator);
ImageRecordReader testReader = new ImageRecordReader(IMG_HEIGHT, IMG_WIDTH, CHANNELS, labelGenerator);
System.out.println("initializing");
trainReader.initialize(trainData);
testReader.initialize(testData);
System.out.println(trainReader.getLabels());
System.out.println(testReader.getLabels());
DataSetIterator coinTrainDataSet = new RecordReaderDataSetIterator(trainReader, MINI_BATCH_SIZE, 1, N_LABELS);
DataSetIterator coinTestDataSet = new RecordReaderDataSetIterator(testReader, MINI_BATCH_SIZE, 1, N_LABELS);
2) Got the ResNet50 model from the zoo, whose weights were already adjusted over thousands of images from the ImageNet dataset, which means it already knows a lot of image patterns. I used the Transfer Learning API to replace only the last layer with a new dense output layer that has 5 outputs (the number of classes in my dataset)

ZooModel<?> zooModel = new ResNet50(N_LABELS, SEED, 1);
ComputationGraph initializedZooModel = (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);
// Training settings applied to the layers that are not frozen
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
        .learningRate(0.0001)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(Updater.NESTEROVS)
        .seed(123)
        .build();
System.out.println(initializedZooModel.summary());
ComputationGraph modelTransfer = new TransferLearning.GraphBuilder(initializedZooModel)
        .fineTuneConfiguration(fineTuneConf)
        // Freeze everything up to and including "flatten_3"
        .setFeatureExtractor("flatten_3")
        // Drop the original 1000-class ImageNet output layer...
        .removeVertexKeepConnections("fc1000")
        // ...and add a new output layer with one output per coin class
        .addLayer("fc1000",
                new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(2048).nOut(N_LABELS)
                        .weightInit(WeightInit.XAVIER)
                        .activation(Activation.SOFTMAX).build(), "flatten_3")
        .build();
Once I ran this code I could see the number of weights that needed to be updated. The other layers were frozen.

The ResNet50 layers were frozen except the last layer, which was added by us
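The number of weights left to train can be checked by hand. The sketch below assumes the new layer is a standard fully connected layer with one bias per output, and uses the 2048 inputs and 5 outputs from the code above; the class and method names are hypothetical.

```java
// Hypothetical helper for a back-of-the-envelope parameter count.
final class TrainableParams {

    /** Weights (nIn * nOut) plus one bias per output of a fully connected layer. */
    static long denseLayerParams(int nIn, int nOut) {
        return (long) nIn * nOut + nOut;
    }

    public static void main(String[] args) {
        // 2048 inputs coming from the frozen part of the network, 5 coin classes
        System.out.println(denseLayerParams(2048, 5)); // prints 10245
    }
}
```

Roughly ten thousand trainable parameters is tiny compared with the millions in the frozen part of ResNet50, which is why a few hundred images per class can be enough to get a usable model.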


3) Trained the model using the original dataset and the transformed images:


log.info("Training with original data");
for (int i = 0; i < numEpochs; i++) {
    log.info("Epoch " + i);
    modelTransfer.fit(coinTrainDataSet);
}
log.info("Training with transformed data");
for (int i = 0; i < transformedDataEpochs; i++) {
    log.info("Epoch " + i + " (transformed data)");
    for (int j = 0; j < transforms.length; j++) {
        ImageTransform imageTransform = transforms[j];
        log.info("Epoch " + i + " (transform " + imageTransform + ")");
        trainReader.initialize(trainData, imageTransform);
        coinTrainDataSet = new RecordReaderDataSetIterator(trainReader, MINI_BATCH_SIZE, 1, N_LABELS);
        modelTransfer.fit(coinTrainDataSet);
    }
}
4) Finally, evaluated the model and exported it to disk. In my real application I can simply load it and use it to predict the label of new images

log.info("Evaluate model....");
Evaluation eval = new Evaluation(N_LABELS);
while (coinTestDataSet.hasNext()) {
    DataSet next = coinTestDataSet.next();
    INDArray[] output = modelTransfer.output(next.getFeatureMatrix());
    for (int i = 0; i < output.length; i++) {
        eval.eval(next.getLabels(), output[i]);
    }
}
String stats = eval.stats();
log.info(stats);
log.info("****************Example finished********************");
// Files.write(byte[], File) matches the signature of Guava's com.google.common.io.Files
Files.write(stats.getBytes(), Paths.get("latest-output.txt").toFile());
File file = new File("brazilian-coin-model.zip");
ModelSerializer.writeModel(modelTransfer, file, true);
In my last training run I got 79% accuracy. I consider it a good result, so it is time to build things with it!
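For readers new to the metric, accuracy here is simply the fraction of test images whose most probable class matches the real label. The following is a minimal plain-Java sketch of that calculation, assuming each row of outputs holds the softmax probabilities for one image; it only illustrates the idea, it is not how DL4J's Evaluation class is implemented, and all names are hypothetical.

```java
// Hypothetical illustration of the accuracy metric.
final class AccuracySketch {

    /** Index of the largest value, i.e. the predicted class. */
    static int argMax(double[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; i++) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of rows whose predicted class equals the actual label. */
    static double accuracy(double[][] outputs, int[] actualLabels) {
        if (outputs.length != actualLabels.length) {
            throw new IllegalArgumentException("one label per output row is required");
        }
        if (outputs.length == 0) {
            return 0.0;
        }
        int correct = 0;
        for (int i = 0; i < outputs.length; i++) {
            if (argMax(outputs[i]) == actualLabels[i]) {
                correct++;
            }
        }
        return (double) correct / outputs.length;
    }

    public static void main(String[] args) {
        double[][] outputs = {
            {0.1, 0.9}, {0.8, 0.2}, {0.3, 0.7}, {0.6, 0.4}
        };
        int[] actual = {1, 0, 0, 0}; // the third image is misclassified
        System.out.println(accuracy(outputs, actual));
    }
}
```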

The results of this adventure
The full code can be found on my GitHub. It does not follow the best Java coding practices, but the goal here is to export the model to be used in real applications.

Applications


The model is still short of what I wanted (85% accuracy), but we can already create interesting applications:

* A free offline mobile application to help blind people identify their coins;
* A Telegram bot that receives an image and returns the classification;
* A mobile app that automatically adds up the coin values in a picture and returns how much money it shows, so no human has to count coins;

I will try to cover these 3 applications on this blog!
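As a small taste of the coin-counting idea, the sketch below adds up the values of already classified coins. It assumes the picture has been split into one image per coin and each image has been classified, neither of which is covered here; the label names ("5", "10", "25", "50", "100") are hypothetical directory names standing for the value in centavos, not necessarily the ones in my dataset.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the coin-counting application.
final class CoinCounter {

    // Assumed label names mapped to their value in centavos (100 = one real).
    private static final Map<String, Integer> VALUE_IN_CENTAVOS = new HashMap<>();
    static {
        VALUE_IN_CENTAVOS.put("5", 5);
        VALUE_IN_CENTAVOS.put("10", 10);
        VALUE_IN_CENTAVOS.put("25", 25);
        VALUE_IN_CENTAVOS.put("50", 50);
        VALUE_IN_CENTAVOS.put("100", 100);
    }

    /** Sums the values of the predicted labels, one label per detected coin. */
    static int totalInCentavos(List<String> predictedLabels) {
        int total = 0;
        for (String label : predictedLabels) {
            Integer value = VALUE_IN_CENTAVOS.get(label);
            if (value == null) {
                throw new IllegalArgumentException("unknown label: " + label);
            }
            total += value;
        }
        return total;
    }

    public static void main(String[] args) {
        int total = totalInCentavos(Arrays.asList("100", "50", "25", "25"));
        System.out.println("R$ " + (total / 100) + "." + String.format("%02d", total % 100));
    }
}
```

Keeping the amounts as integer centavos avoids floating-point rounding when many coins are added together.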

We should also consider that we live in amazing times, where AI is becoming an important part of our lives, but the lack of public datasets may not be a good thing. Whoever owns the data owns the future, and I hope to create more public datasets so people can create their own applications.

