# This example downloads an external archive of images and uses JavaCNN to
# identify which images contain faces.
print "Download and unzip files with images"
from jhplot import *
print Web.get("http://datamelt.org/examples/data/mitcbcl_pgm_set2.zip")
print IO.unzip("mitcbcl_pgm_set2.zip")
NMax=50 # Total runs. Reduce this number to get results faster
from org.ea.javacnn.data import DataBlock,OutputDefinition,TrainResult
from org.ea.javacnn.layers import DropoutLayer,FullyConnectedLayer,InputLayer,LocalResponseNormalizationLayer
from org.ea.javacnn.layers import ConvolutionLayer,RectifiedLinearUnitsLayer,PoolingLayer
from org.ea.javacnn.losslayers import SoftMaxLayer
from org.ea.javacnn.readers import ImageReader,MnistReader,PGMReader,Reader
from org.ea.javacnn.trainers import AdaGradTrainer,Trainer
from org.ea.javacnn import JavaCNN
from java.util import ArrayList,Arrays
from java.lang import System
layers = ArrayList(); de = OutputDefinition()
print "Total number of runs=", NMax
print "Reading train sample.."
mr = PGMReader("mitcbcl_pgm_set2/train/")
print "Total number of trainning images=",mr.size()," Nr of types=",mr.numOfClasses()
print "Read test sample .."
mrTest = PGMReader("mitcbcl_pgm_set2/test/")
print "Total number of test images=",mrTest.size()," Nr of types=",mrTest.numOfClasses()
modelName = "model.ser" # save NN to this file
# Network definition. The layers are constructed top to bottom in exactly the
# order they are appended, because every constructor that takes `de` shares
# the same OutputDefinition object.
# NOTE(review): the numeric constructor arguments are passed through as-is;
# their meaning (filter size, filter count, stride, padding, ...) should be
# confirmed against the JavaCNN layer documentation.
networkStack = [
    InputLayer(de, mr.getSizeX(), mr.getSizeY(), 1),
    ConvolutionLayer(de, 5, 32, 1, 2),           # uses different filters
    RectifiedLinearUnitsLayer(),                 # applies the non-saturating activation function
    PoolingLayer(de, 2, 2, 0),                   # creates a smaller zoomed out version
    ConvolutionLayer(de, 5, 64, 1, 2),
    RectifiedLinearUnitsLayer(),
    PoolingLayer(de, 2, 2, 0),
    FullyConnectedLayer(de, 1024),
    LocalResponseNormalizationLayer(),
    DropoutLayer(de),
    FullyConnectedLayer(de, mr.numOfClasses()),  # one output per class in the training sample
    SoftMaxLayer(de),
]
for networkLayer in networkStack:
    layers.add(networkLayer)
print "Training.."
net = JavaCNN(layers)
trainer = AdaGradTrainer(net, 20, 0.001)
from jarray import zeros
numberDistribution,correctPredictions = zeros(10, "i"),zeros(10, "i")
start = System.currentTimeMillis()
db = DataBlock(mr.getSizeX(), mr.getSizeY(), 1, 0)
for j in range(NMax):
loss = 0
for i in range(mr.size()):
db.addImageData(mr.readNextImage(), mr.getMaxvalue())
tr = trainer.train(db, mr.readNextLabel())
loss = loss + tr.getLoss()
if (i != 0 and i % 500 == 0):
print "Nr of images: ",i," Loss: ",(loss/float(i))
print "Loss: ", (loss / float(mr.size())), " for run=",j
mr.reset()
print 'Wait.. Calculating predictions for labels=', mr.getLabels()
Arrays.fill(correctPredictions, 0)
Arrays.fill(numberDistribution, 0)
for i in range(mrTest.size()):
db.addImageData(mrTest.readNextImage(), mr.getMaxvalue())
net.forward(db, False)
correct = mrTest.readNextLabel()
prediction = net.getPrediction()
if(correct == prediction): correctPredictions[correct] +=1
numberDistribution[correct] +=1
mrTest.reset()
print " -> Testing time: ",int(0.001*(System.currentTimeMillis() - start))," s"
print " -> Current run:",j
print net.getPredictions(correctPredictions, numberDistribution, mrTest.size(), mrTest.numOfClasses())
print " -> Save current state to ",modelName
net.saveModel(modelName)
print "Read trained network from ",modelName," and make the final test"
cnn =net.loadModel(modelName)
Arrays.fill(correctPredictions, 0)
Arrays.fill(numberDistribution, 0)
for i in range(mrTest.size()):
db.addImageData(mrTest.readNextImage(), mr.getMaxvalue())
net.forward(db, False)
correct = mrTest.readNextLabel()
prediction = net.getPrediction()
if(correct == prediction): correctPredictions[correct] +=1
numberDistribution[correct] +=1
print "Final test:"
print net.getPredictions(correctPredictions, numberDistribution, mrTest.size(), mrTest.numOfClasses())