Editor's recommendation:
This article comes from CSDN. It covers the dataset, loading the training data for model training, and visualizing the model's detections.
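On this fine day when YOLO v3 was released, Deeplearning4j finally got a new release, 1.0.0-alpha, which adds a TinyYOLO model to the model zoo so that you can train it on your own data for object detection.
It has to be said: with YOLO v3 bringing big gains in both speed and accuracy, DL4J only now adding TinyYOLO feels a bit like enlisting in the Nationalist army in 1949, that is, joining just as the cause has already been overtaken.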
1. Task and Data
The data comes from https://github.com/cosmicad/dataset. The goal is to detect and localize red blood cells in images.
The dataset consists of two parts:
Images: JPEGImages
Labels: Annotations
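1.1 Images
Sample images from the dataset are shown in the figure:

All images in the dataset are in .jpg format. There are 410 images in total; section 2.2 splits them 80/20 into training and test sets.
1.2 Labels
As shown in the figure, every image has a corresponding XML file that serves as its training label.

Each label file follows the PASCAL VOC annotation format. Its contents look like this: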
```xml
<annotation verified="no">
    <folder>RBC</folder>
    <!-- the corresponding image -->
    <filename>BloodImage_00000</filename>
    <!-- path (not important) -->
    <path>/Users/cosmic/WBC_CLASSIFICATION_ANNO/RBC/BloodImage_00000.jpg</path>
    <!-- data source (not important) -->
    <source>
        <database>Unknown</database>
    </source>
    <!-- image width, height and number of channels -->
    <size>
        <width>640</width>
        <height>480</height>
        <depth>3</depth>
    </size>
    <!-- whether used for segmentation (0 or 1 makes no difference for object detection) -->
    <segmented>0</segmented>
    <!-- an object to detect -->
    <object>
        <!-- class label of the object; Chinese names are allowed -->
        <name>RBC</name>
        <!-- shooting angle -->
        <pose>Unspecified</pose>
        <!-- whether the object is truncated (0 = complete) -->
        <truncated>0</truncated>
        <!-- whether the object is hard to recognize (0 = easy) -->
        <difficult>0</difficult>
        <!-- bounding box: top-left (xmin, ymin) and bottom-right (xmax, ymax) corners in pixels -->
        <bndbox>
            <xmin>216</xmin>
            <ymin>359</ymin>
            <xmax>316</xmax>
            <ymax>464</ymax>
        </bndbox>
    </object>
    <!-- ... to detect several objects, add more <object></object> elements -->
</annotation>
```
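To make the format concrete, here is a minimal sketch that reads the fields relevant to detection (class name and bounding-box corners) from a VOC annotation using only the JDK's built-in DOM parser. The class and method names are illustrative; the training code in section 2 uses DL4J's VocLabelProvider for this instead, and this sketch is not how that class is implemented.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

public class VocAnnotationSketch {

    /** One labeled box: class name plus top-left and bottom-right pixel corners. */
    public static final class Box {
        public final String name;
        public final int xmin, ymin, xmax, ymax;

        Box(String name, int xmin, int ymin, int xmax, int ymax) {
            this.name = name;
            this.xmin = xmin;
            this.ymin = ymin;
            this.xmax = xmax;
            this.ymax = ymax;
        }
    }

    /** Returns the trimmed text of the first descendant element with the given tag. */
    private static String text(Element parent, String tag) {
        return parent.getElementsByTagName(tag).item(0).getTextContent().trim();
    }

    /** Parses every <object> element of a PASCAL VOC annotation string. */
    public static List<Box> parse(String xml) throws Exception {
        Document doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                .parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
        List<Box> boxes = new ArrayList<>();
        NodeList objects = doc.getElementsByTagName("object");
        for (int i = 0; i < objects.getLength(); i++) {
            Element obj = (Element) objects.item(i);
            Element bb = (Element) obj.getElementsByTagName("bndbox").item(0);
            boxes.add(new Box(text(obj, "name"),
                    Integer.parseInt(text(bb, "xmin")), Integer.parseInt(text(bb, "ymin")),
                    Integer.parseInt(text(bb, "xmax")), Integer.parseInt(text(bb, "ymax"))));
        }
        return boxes;
    }

    public static void main(String[] args) throws Exception {
        String xml = "<annotation><size><width>640</width><height>480</height></size>"
                + "<object><name>RBC</name><bndbox><xmin>216</xmin><ymin>359</ymin>"
                + "<xmax>316</xmax><ymax>464</ymax></bndbox></object></annotation>";
        for (Box b : parse(xml)) {
            // prints: RBC 216,359 -> 316,464
            System.out.println(b.name + " " + b.xmin + "," + b.ymin + " -> " + b.xmax + "," + b.ymax);
        }
    }
}
```

Each additional <object> element in the file simply yields one more Box in the returned list.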
1.3 How to Build Your Own Dataset
BBox-Label-Tool: https://github.com/puzzledqs/BBox-Label-Tool
精灵标注 (a Chinese annotation tool): http://jl.shenjian.io/
2. Model Training
2.1 Predefined Training Parameters
```java
// parameters matching the pretrained TinyYOLO model
int width = 416;
int height = 416;
int nChannels = 3;
int gridWidth = 13;
int gridHeight = 13;
```
The code above defines:
the width and height of the input image and its number of channels;
the size of the grid that the YOLO model divides the image into, here 13 x 13.
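As a sketch of what the grid means: 416 / 13 = 32, so each cell covers a 32 x 32 pixel patch of the network input, and in YOLO the cell containing an object's center is the one responsible for predicting it. The helper below uses hypothetical names and assumes the box coordinates have already been rescaled from the original 640 x 480 image to the 416 x 416 input.

```java
public class GridCellSketch {
    // Values from the parameter block above: 416x416 input, 13x13 grid.
    static final int WIDTH = 416, HEIGHT = 416;
    static final int GRID_W = 13, GRID_H = 13;

    /** Pixels covered by one grid cell along x (416 / 13 = 32). */
    public static int cellSizeX() {
        return WIDTH / GRID_W;
    }

    /**
     * Returns {column, row} of the cell containing the box center, for a box given
     * in pixels of the 416x416 network input. Indices are clamped to the grid.
     */
    public static int[] responsibleCell(double xmin, double ymin, double xmax, double ymax) {
        double cx = (xmin + xmax) / 2.0;
        double cy = (ymin + ymax) / 2.0;
        int col = Math.min(GRID_W - 1, (int) Math.floor(cx * GRID_W / WIDTH));
        int row = Math.min(GRID_H - 1, (int) Math.floor(cy * GRID_H / HEIGHT));
        return new int[] {col, row};
    }

    public static void main(String[] args) {
        System.out.println(cellSizeX()); // 32
        int[] cell = responsibleCell(100, 100, 164, 164); // center (132, 132)
        System.out.println(cell[0] + "," + cell[1]); // 4,4
    }
}
```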
```java
// number of classes: only red blood cells (RBC)
int nClasses = 1;
```
This sets the number of classes to detect. We only detect one kind of object, red blood cells, so the value is 1.
```java
// parameters for the Yolo2OutputLayer
int nBoxes = 5;
double lambdaNoObj = 0.5;
double lambdaCoord = 5.0;
double[][] priorBoxes = { { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } };
double detectionThreshold = 0.3;
```
These are parameters of the model's output layer (Yolo2OutputLayer).
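For context: nBoxes is the number of anchor (prior) boxes each grid cell predicts; lambdaCoord and lambdaNoObj weight the coordinate loss and the no-object confidence loss (5.0 and 0.5 are the values used in the original YOLO paper); detectionThreshold is the minimum confidence for a prediction to be reported. Yolo2OutputLayer takes the priors as [width, height] rows in grid-cell units, so with 32-pixel cells each {2, 2} prior describes a 64 x 64 pixel box on the 416 x 416 input. The sketch below, with illustrative names, just performs that conversion.

```java
public class PriorBoxSketch {
    /** Converts [width, height] priors from grid-cell units to input pixels. */
    public static double[][] priorsInPixels(double[][] priors, int inputW, int inputH,
                                            int gridW, int gridH) {
        double cellW = (double) inputW / gridW;
        double cellH = (double) inputH / gridH;
        double[][] out = new double[priors.length][2];
        for (int i = 0; i < priors.length; i++) {
            out[i][0] = priors[i][0] * cellW;
            out[i][1] = priors[i][1] * cellH;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] priorBoxes = { {2, 2}, {2, 2}, {2, 2}, {2, 2}, {2, 2} };
        for (double[] p : priorsInPixels(priorBoxes, 416, 416, 13, 13)) {
            System.out.println(p[0] + " x " + p[1]); // 64.0 x 64.0 for every prior
        }
    }
}
```

Because all five priors here are identical, the five boxes per cell all start from the same 64 x 64 shape.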
```java
// parameters for the training phase
int batchSize = 2;
int nEpochs = 50;
double learningRate = 1e-3;
double lrMomentum = 0.9;
```
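These are the hyperparameters used during training:
batchSize is 2, mainly because I train on a CPU and my computer has only 8 GB of RAM. With a more powerful machine you can choose a larger value, which may give better training results.
nEpochs is 50, so the training data is iterated over for 50 epochs.
learningRate, the learning rate, is 1e-3.
lrMomentum is the momentum (0.9) used by the Nesterovs updater.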
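The momentum term is easiest to see on a toy problem. This is a minimal scalar sketch of Nesterov momentum on f(x) = x², where the gradient is evaluated at the look-ahead point x + momentum * v. It is not DL4J's Nesterovs source (which may compute an equivalent, rearranged form), and it uses lr = 0.1 rather than 1e-3 only so that the toy example converges within a few hundred steps.

```java
public class NesterovSketch {
    /**
     * Minimizes f(x) = x^2 (gradient 2x) with Nesterov momentum:
     *   v <- momentum * v - lr * grad(x + momentum * v)
     *   x <- x + v
     */
    public static double minimize(double x0, double lr, double momentum, int steps) {
        double x = x0;
        double v = 0.0;
        for (int i = 0; i < steps; i++) {
            double lookAhead = x + momentum * v; // where momentum is about to carry us
            double grad = 2.0 * lookAhead;       // gradient of x^2 at the look-ahead point
            v = momentum * v - lr * grad;
            x = x + v;
        }
        return x;
    }

    public static void main(String[] args) {
        // starts at x = 1 and approaches the minimum at x = 0
        System.out.println(minimize(1.0, 0.1, 0.9, 200));
    }
}
```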
2.2 Loading the Data
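```java
// seed and rng are used by the following snippets; 123 is an arbitrary choice
int seed = 123;
Random rng = new Random(seed);
// the dataset lives under the resources folder, so resolve it from the classpath
String dataDir = new ClassPathResource("/datasets").getFile().getPath();
File imageDir = new File(dataDir, "JPEGImages");
```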
In this project the data is stored in the resources folder, so we resolve it through the classpath; here we mainly need the image directory.
```java
log.info("Load data...");
// keep only images that have a matching annotation file:
// .../JPEGImages/xxx.jpg -> .../Annotations/xxx.xml
RandomPathFilter pathFilter = new RandomPathFilter(rng) {
    @Override
    protected boolean accept(String name) {
        name = name.replace("/JPEGImages/", "/Annotations/").replace(".jpg", ".xml");
        try {
            return new File(new URI(name)).exists();
        } catch (URISyntaxException ex) {
            throw new RuntimeException(ex);
        }
    }
};
// 80% of the images for training, 20% for testing
InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng)
        .sample(pathFilter, 0.8, 0.2);
InputSplit trainData = data[0];
InputSplit testData = data[1];
```
¶ÁȡѵÁ·Êý¾Ý£¬²¢ÇÒ½«Êý¾Ý»®·ÖΪѵÁ·¼¯ºÍ²âÊÔ¼¯¡£
```java
ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(
        height, width, nChannels, gridHeight, gridWidth, new VocLabelProvider(dataDir));
recordReaderTrain.initialize(trainData);

ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(
        height, width, nChannels, gridHeight, gridWidth, new VocLabelProvider(dataDir));
recordReaderTest.initialize(testData);

// ObjectDetectionRecordReader performs regression, so we need to specify it here
// (labelIndexFrom = 1, labelIndexTo = 1, regression = true)
RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(
        recordReaderTrain, batchSize, 1, 1, true);
train.setPreProcessor(new ImagePreProcessingScaler(0, 1));

RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(
        recordReaderTest, 1, 1, 1, true);
test.setPreProcessor(new ImagePreProcessingScaler(0, 1));
```
This builds iterators over the training set (batch size batchSize) and the test set (batch size 1), and attaches a preprocessor that scales the image pixel values into the range 0 to 1.
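Assuming 8-bit images, ImagePreProcessingScaler(0, 1) amounts to a linear rescale of each pixel value from 0..255 into [0, 1]. A minimal sketch of that mapping, with illustrative names:

```java
import java.util.Arrays;

public class PixelScalingSketch {
    /** Maps 8-bit pixel values (0..255) linearly into [min, max]. */
    public static double[] scale(int[] pixels, double min, double max) {
        double[] out = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            out[i] = pixels[i] / 255.0 * (max - min) + min;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] scaled = scale(new int[] {0, 51, 255}, 0, 1);
        System.out.println(Arrays.toString(scaled)); // [0.0, 0.2, 1.0]
    }
}
```

The visualization step in 2.5 reverses this by multiplying by 255 before drawing.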
2.3 Building the Model
```java
ComputationGraph model;
String modelFilename = "model_rbc.zip";
// downloads the pretrained TinyYOLO weights on first use
ComputationGraph pretrained = (ComputationGraph) new TinyYOLO().initPretrained();
INDArray priors = Nd4j.create(priorBoxes);
```
The pretrained model is first downloaded from the internet and stored in the .deeplearning4j directory under your home directory, as shown in the figure:

Next, use a fine-tuning configuration to override the pretrained model's training settings:
```java
FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
        .seed(seed)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
        .gradientNormalizationThreshold(1.0)
        // a configuration has a single updater; a second .updater(...) call would override this one
        .updater(new Nesterovs.Builder().learningRate(learningRate).momentum(lrMomentum).build())
        .activation(Activation.IDENTITY)
        .trainingWorkspaceMode(WorkspaceMode.SEPARATE)
        .inferenceWorkspaceMode(WorkspaceMode.SEPARATE)
        .build();
```
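The code above does the following:
uses stochastic gradient descent as the optimization algorithm;
uses RenormalizeL2PerLayer gradient normalization, which rescales each layer's gradient by its L2 norm to keep update magnitudes stable and, in particular, to guard against exploding gradients;
uses the Nesterovs updater with the configured learning rate and momentum;
sets the workspace mode (SEPARATE) for training and inference.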
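RenormalizeL2PerLayer divides each layer's gradient by that layer's L2 norm, so the update direction keeps its shape but has unit length. A minimal sketch of that operation on a plain array, with illustrative names (DL4J's implementation works on INDArrays and may differ in edge-case handling):

```java
public class L2RenormSketch {
    /** Divides a layer's gradient vector by its L2 norm; an all-zero gradient is returned unchanged. */
    public static double[] renormalizeL2(double[] grad) {
        double sumSq = 0.0;
        for (double g : grad) {
            sumSq += g * g;
        }
        double norm = Math.sqrt(sumSq);
        double[] out = grad.clone();
        if (norm > 0.0) {
            for (int i = 0; i < out.length; i++) {
                out[i] /= norm;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] g = renormalizeL2(new double[] {30.0, 40.0}); // L2 norm is 50
        System.out.println(g[0] + ", " + g[1]); // 0.6, 0.8
    }
}
```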
Then use transfer learning to modify the model architecture:
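```java
model = new TransferLearning.GraphBuilder(pretrained)
        .fineTuneConfiguration(fineTuneConf)
        // remove the pretrained final 1x1 convolution, whose depth was sized for the original classes
        .removeVertexKeepConnections("conv2d_9")
        // new 1x1 convolution producing nBoxes * (5 + nClasses) channels per grid cell
        .addLayer("convolution2d_9",
                new ConvolutionLayer.Builder(1, 1)
                        .nIn(1024)
                        .nOut(nBoxes * (5 + nClasses))
                        .stride(1, 1)
                        .convolutionMode(ConvolutionMode.Same)
                        .weightInit(WeightInit.UNIFORM)
                        .hasBias(false)
                        .activation(Activation.IDENTITY)
                        .build(),
                "leaky_re_lu_8")
        // new YOLOv2 output layer with our loss weights and priors
        // (in 1.0.0-alpha the builder method is spelled "lambbaNoObj")
        .addLayer("outputs",
                new Yolo2OutputLayer.Builder()
                        .lambbaNoObj(lambdaNoObj)
                        .lambdaCoord(lambdaCoord)
                        .boundingBoxPriors(priors)
                        .build(),
                "convolution2d_9")
        .setOutputs("outputs")
        .build();
```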
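The key change adapts the detection head to our number of classes: the new final 1x1 convolution outputs nBoxes * (5 + nClasses) channels, and the new Yolo2OutputLayer uses our own priors and loss weights.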
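The new convolution's depth follows from the YOLOv2 output encoding: for each of the nBoxes priors, a cell predicts 4 box coordinates, 1 objectness confidence and nClasses class scores, so here 5 * (5 + 1) = 30 channels (a 20-class model would need 5 * 25 = 125). The sketch below, with illustrative names, assumes the values of each box are stored contiguously along the channel axis; treat that exact layout as an assumption and check the Yolo2OutputLayer documentation before relying on it.

```java
public class YoloChannelSketch {
    /** Total output channels of the final 1x1 convolution. */
    public static int outputChannels(int nBoxes, int nClasses) {
        return nBoxes * (5 + nClasses);
    }

    /**
     * Channel index of value k (0-3: x, y, w, h; 4: confidence; 5 and up: class scores)
     * for prior box b, assuming each box's values are contiguous (an assumption).
     */
    public static int channelIndex(int b, int k, int nClasses) {
        return b * (5 + nClasses) + k;
    }

    public static void main(String[] args) {
        int nBoxes = 5, nClasses = 1;
        System.out.println(outputChannels(nBoxes, nClasses)); // 30
        // confidence of the last box, then its single RBC class score
        System.out.println(channelIndex(4, 4, nClasses)); // 28
        System.out.println(channelIndex(4, 5, nClasses)); // 29
    }
}
```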
2.4 Training the Model
```java
// log the score after every iteration
model.setListeners(new ScoreIterationListener(1));
for (int i = 0; i < nEpochs; i++) {
    train.reset();
    while (train.hasNext()) {
        model.fit(train.next());
    }
    log.info("*** Completed epoch {} ***", i);
}
// true = also save the updater state so that training can be resumed later
ModelSerializer.writeModel(model, modelFilename, true);
```
After training finishes, the model is serialized and saved locally as model_rbc.zip.
2.5 Visualizing the Detections
```java
// static imports used below (JavaCPP OpenCV presets):
//   import static org.bytedeco.javacpp.opencv_core.*;     // Mat, Size, Point, Scalar, CV_8U
//   import static org.bytedeco.javacpp.opencv_imgproc.*;  // resize, rectangle, putText, FONT_HERSHEY_DUPLEX

// visualize results on the test set
NativeImageLoader imageLoader = new NativeImageLoader();
CanvasFrame frame = new CanvasFrame("RedBloodCellDetection");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout =
        (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
List<String> labels = train.getLabels();
test.setCollectMetaData(true);
while (test.hasNext() && frame.isVisible()) {
    org.nd4j.linalg.dataset.DataSet ds = test.next();
    RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
    INDArray features = ds.getFeatures();
    INDArray results = model.outputSingle(features);
    List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
    File file = new File(metadata.getURI());
    log.info(file.getName() + ": " + objs);

    // features were scaled to 0..1, so multiply by 255 to get 8-bit pixels back
    Mat mat = imageLoader.asMat(features);
    Mat convertedMat = new Mat();
    mat.convertTo(convertedMat, CV_8U, 255, 0);
    // display at twice the original image size
    int w = metadata.getOrigW() * 2;
    int h = metadata.getOrigH() * 2;
    Mat image = new Mat();
    resize(convertedMat, image, new Size(w, h));
    for (DetectedObject obj : objs) {
        // corners are in grid units (0..13); convert them to display pixels
        double[] xy1 = obj.getTopLeftXY();
        double[] xy2 = obj.getBottomRightXY();
        String label = labels.get(obj.getPredictedClass());
        int x1 = (int) Math.round(w * xy1[0] / gridWidth);
        int y1 = (int) Math.round(h * xy1[1] / gridHeight);
        int x2 = (int) Math.round(w * xy2[0] / gridWidth);
        int y2 = (int) Math.round(h * xy2[1] / gridHeight);
        rectangle(image, new Point(x1, y1), new Point(x2, y2), Scalar.RED);
        putText(image, label, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, Scalar.GREEN);
    }
    frame.setTitle(file.getName() + " - RedBloodCellDetection");
    frame.setCanvasSize(w, h);
    frame.showImage(converter.convert(image));
    frame.waitKey();
}
frame.dispose();
```
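The drawing loop relies on DetectedObject returning corner coordinates in grid units (0 to 13), which it converts to pixels of the doubled-size display image with size * coord / gridSize. A small sketch of that conversion, using a made-up detection and illustrative names:

```java
public class GridToPixelSketch {
    /** Converts a coordinate in grid units (0..gridSize) to pixels of an image of the given size. */
    public static int toPixels(double gridCoord, int imageSize, int gridSize) {
        return (int) Math.round(imageSize * gridCoord / gridSize);
    }

    public static void main(String[] args) {
        int w = 640 * 2, h = 480 * 2; // display size: the original 640x480 image doubled
        int gridWidth = 13, gridHeight = 13;
        double[] topLeft = {6.5, 3.25}; // hypothetical detection corner, in grid units
        System.out.println(toPixels(topLeft[0], w, gridWidth) + ", "
                + toPixels(topLeft[1], h, gridHeight)); // 640, 240
    }
}
```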
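3. Results

The dataset is small and the number of training epochs is low, so the results leave room for improvement; if you are interested, try continuing the training yourself.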
4. Source Code
The code is on GitHub and can be downloaded from: https://github.com/sjsdfg/dl4j-tutorials
It is in the objectdetection package and can be run directly.