DeepLearning4j: Training a YOLO Model with Java
 
 2019-3-6
   
 
Editor's recommendation:

This article originally appeared on CSDN. It covers the dataset, reading the training data during model training, and visualizing the model's detections.

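YOLO v3 has just been released, which makes this a fine day.

Deeplearning4j has finally shipped a new release, 1.0.0-alpha, which adds a TinyYOLO model to the model zoo that you can train on your own data for object detection.

I have to say, with YOLO v3 bringing large gains in both speed and accuracy, DL4J only now adding TinyYOLO feels a bit like "joining the Nationalist army in 1949" (a Chinese joke about signing up with a side just as it is losing).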

1. Task and Data

The data comes from https://github.com/cosmicad/dataset. The goal is to recognize and locate red blood cells in the images.

The dataset consists of two parts:

Images: JPEGImages

Labels: Annotations

1.1 Images

A sample image from the dataset is shown in the figure.

All images in the dataset are in .jpg format; 410 images in total are used to train the model.

1.2 Labels

As shown in the figure, every image has a corresponding XML file that serves as its training label.

Each label file follows the PASCAL VOC format. Its contents look like this:

<annotation verified="no">
<folder>RBC</folder>
<filename>BloodImage_00000</filename> <!-- the corresponding image -->
<path>/Users/cosmic/WBC_CLASSIFICATION_ANNO/RBC/BloodImage_00000.jpg</path> <!-- path (not important) -->
<source> <!-- data source (not important) -->
<database>Unknown</database>
</source>
<size> <!-- image width, height and number of channels -->
<width>640</width>
<height>480</height>
<depth>3</depth>
</size>
<segmented>0</segmented> <!-- whether the image is used for segmentation (0 or 1 makes no difference for object detection) -->
<object> <!-- an object to detect -->
<name>RBC</name> <!-- class label of the object; Chinese characters are allowed -->
<pose>Unspecified</pose> <!-- viewing angle -->
<truncated>0</truncated> <!-- whether the object is truncated (0 = fully visible) -->
<difficult>0</difficult> <!-- whether the object is hard to recognize (0 = easy) -->
<bndbox> <!-- bounding box: x/y of the top-left and bottom-right corners -->
<xmin>216</xmin>
<ymin>359</ymin>
<xmax>316</xmax>
<ymax>464</ymax>
</bndbox>
</object>
<!-- to label several objects, add one <object></object> element per object -->
</annotation>
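To see what a reader has to pull out of such a file, here is a minimal sketch that collects the `<object>` boxes using the JDK's built-in DOM parser. `VocBoxReader` is a hypothetical helper, not part of DL4J (the training code below relies on DL4J's own `VocLabelProvider`), and it assumes a well-formed annotation:

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

/** Hypothetical helper that reads the labelled boxes out of a PASCAL VOC annotation. */
final class VocBoxReader {

    /** One labelled object: class name plus corner coordinates in pixels. */
    static final class Box {
        final String name;
        final double xmin, ymin, xmax, ymax;

        Box(String name, double xmin, double ymin, double xmax, double ymax) {
            this.name = name;
            this.xmin = xmin;
            this.ymin = ymin;
            this.xmax = xmax;
            this.ymax = ymax;
        }

        @Override
        public String toString() {
            return name + "[" + xmin + "," + ymin + " -> " + xmax + "," + ymax + "]";
        }
    }

    /** Returns one Box per <object> element in the annotation. */
    static List<Box> read(String xml) {
        try {
            Document doc = DocumentBuilderFactory.newInstance()
                    .newDocumentBuilder()
                    .parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8)));
            List<Box> boxes = new ArrayList<>();
            NodeList objects = doc.getElementsByTagName("object");
            for (int i = 0; i < objects.getLength(); i++) {
                Element obj = (Element) objects.item(i);
                Element box = (Element) obj.getElementsByTagName("bndbox").item(0);
                boxes.add(new Box(text(obj, "name"),
                        number(box, "xmin"), number(box, "ymin"),
                        number(box, "xmax"), number(box, "ymax")));
            }
            return boxes;
        } catch (Exception e) {
            throw new IllegalArgumentException("Not a readable VOC annotation", e);
        }
    }

    /** Trimmed text of the first descendant element with the given tag. */
    private static String text(Element parent, String tag) {
        return parent.getElementsByTagName(tag).item(0).getTextContent().trim();
    }

    /** Numeric value of a coordinate element (parsed as double so "216.0" also works). */
    private static double number(Element parent, String tag) {
        return Double.parseDouble(text(parent, tag));
    }

    public static void main(String[] args) {
        String xml = "<annotation><object><name>RBC</name><bndbox>"
                + "<xmin>216</xmin><ymin>359</ymin><xmax>316</xmax><ymax>464</ymax>"
                + "</bndbox></object></annotation>";
        System.out.println(read(xml)); // [RBC[216.0,359.0 -> 316.0,464.0]]
    }
}
```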

1.3 How to create your own dataset

BBox-Label-Tool: https://github.com/puzzledqs/BBox-Label-Tool

精灵标注 (Jingling Annotation): http://jl.shenjian.io/

2. Model Training

2.1 Predefined training parameters

// parameters matching the pretrained TinyYOLO model
int width = 416;
int height = 416;
int nChannels = 3;
int gridWidth = 13;
int gridHeight = 13;

The code above defines:

the input width, height and number of image channels;

the grid size YOLO divides the image into, here 13 x 13.
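Since 416 / 13 = 32, each grid cell covers a 32 x 32 patch of the resized input. As far as we understand, DL4J's ObjectDetectionRecordReader expresses label boxes in these grid units, which is also why the visualization code later scales by w / gridWidth. The arithmetic is sketched below in plain Java; `GridCoords` is a hypothetical helper, and because only the fraction of the image matters, resizing 640x480 to 416x416 does not change the grid coordinate:

```java
/** Hypothetical helper: converts pixel coordinates into YOLO grid units. */
final class GridCoords {

    /** Pixel coordinate in an image of size imageSize, expressed in grid cells. */
    static double toGrid(double pixel, int imageSize, int gridSize) {
        return pixel * gridSize / imageSize;
    }

    /** Index (column or row) of the grid cell containing the pixel coordinate. */
    static int cellIndex(double pixel, int imageSize, int gridSize) {
        return (int) Math.floor(toGrid(pixel, imageSize, gridSize));
    }

    public static void main(String[] args) {
        // Center of the sample box (xmin=216, xmax=316, ymin=359, ymax=464) in a 640x480 image
        double centerX = (216 + 316) / 2.0; // 266.0
        double centerY = (359 + 464) / 2.0; // 411.5
        System.out.println(toGrid(centerX, 640, 13));    // 5.403125
        System.out.println(cellIndex(centerX, 640, 13)); // 5
        System.out.println(cellIndex(centerY, 480, 13)); // 11
        System.out.println(416 / 13);                    // 32 input pixels per cell
    }
}
```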

// number classes for the red blood cells (RBC)
int nClasses = 1;

This sets the number of classes to detect. Here we only detect red blood cells, so the value is 1.

// parameters for the Yolo2OutputLayer
int nBoxes = 5;
double lambdaNoObj = 0.5;
double lambdaCoord = 5.0;
double[][] priorBoxes = { { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 }, { 2, 2 } };
double detectionThreshold = 0.3;

These are parameters of the model's output layer.
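lambdaCoord = 5.0 and lambdaNoObj = 0.5 match the loss weights of the original YOLO paper (extra weight on box coordinates, less on cells without objects), and detectionThreshold is the confidence cut-off used later when extracting detections. The priors are given in grid units. The quick arithmetic below, a sketch with the hypothetical class `YoloOutputShape`, shows how many channels the final convolution must output and how large a {2, 2} prior is in input pixels:

```java
/** Hypothetical sketch of the output-layer arithmetic for the settings above. */
final class YoloOutputShape {

    /** Channels per grid cell: for each box 4 coordinates + 1 confidence + nClasses class scores. */
    static int outputChannels(int nBoxes, int nClasses) {
        return nBoxes * (5 + nClasses);
    }

    /** Converts a prior size given in grid cells into input pixels. */
    static double priorPixels(double priorCells, int inputSize, int gridSize) {
        return priorCells * inputSize / gridSize;
    }

    public static void main(String[] args) {
        int nBoxes = 5, nClasses = 1;
        int width = 416, gridWidth = 13, gridHeight = 13;
        System.out.println("output per image: [" + outputChannels(nBoxes, nClasses)
                + ", " + gridHeight + ", " + gridWidth + "]"); // output per image: [30, 13, 13]
        double side = priorPixels(2, width, gridWidth);
        System.out.println("prior {2, 2}: " + side + " x " + side + " px"); // prior {2, 2}: 64.0 x 64.0 px
    }
}
```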

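// parameters for the training phase
int batchSize = 2;
int nEpochs = 50;
double learningRate = 1e-3;
double lrMomentum = 0.9;
// seed for the fine-tune configuration and the random train/test split (123 is an arbitrary choice)
int seed = 123;
Random rng = new Random(seed); // java.util.Random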

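These are the training parameters:

batchSize is 2, mainly because I trained on a CPU with only 8 GB of RAM. On a more powerful machine you can choose a larger value, which may give better training results.

nEpochs is 50, i.e. the training data is iterated over 50 times.

learningRate, the learning rate, is 1e-3.

lrMomentum is the momentum coefficient used by the Nesterovs updater.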
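For intuition, Nesterov momentum can be sketched on a one-dimensional toy problem: minimizing f(x) = x^2 with the article's learningRate = 1e-3 and lrMomentum = 0.9. This uses the common "look-ahead" reformulation (as popularized by the CS231n notes); DL4J's Nesterovs updater is believed to apply an equivalent rule per parameter, but this is an illustration with a hypothetical class name, not its source code:

```java
/** Hypothetical sketch of Nesterov momentum on f(x) = x^2. */
final class NesterovSketch {

    /**
     * Runs `steps` updates of the form
     *   v_new = mu * v - lr * grad(x)
     *   x    += -mu * v + (1 + mu) * v_new
     */
    static double minimize(double x0, double lr, double mu, int steps) {
        double x = x0;
        double v = 0.0;
        for (int i = 0; i < steps; i++) {
            double grad = 2 * x;     // derivative of x^2
            double vPrev = v;
            v = mu * v - lr * grad;  // momentum accumulates past gradients
            x += -mu * vPrev + (1 + mu) * v;
        }
        return x;
    }

    public static void main(String[] args) {
        // Starts at x = 10 and ends very close to the minimum at 0
        System.out.println(minimize(10.0, 1e-3, 0.9, 5000));
    }
}
```

With mu = 0 the rule reduces to plain gradient descent, x -= lr * grad(x).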

2.2 Reading the data

String dataDir = new ClassPathResource("/datasets").getFile().getPath();
File imageDir = new File(dataDir, "JPEGImages");

In this project the data lives in the resources folder, so it is resolved through the classpath; here we mainly need the image directory.

log.info("Load data...");
RandomPathFilter pathFilter = new RandomPathFilter(rng) {
@Override
protected boolean accept(String name) {
name = name.replace("/JPEGImages/", "/Annotations/").replace(".jpg", ".xml");
try {
return new File(new URI(name)).exists();
} catch (URISyntaxException ex) {
throw new RuntimeException(ex);
}
}
};
InputSplit[] data = new FileSplit(imageDir, NativeImageLoader.ALLOWED_FORMATS, rng).sample(pathFilter, 0.8, 0.2);
InputSplit trainData = data[0];
InputSplit testData = data[1];

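Load the data and split it into a training set and a test set (80% / 20%). The path filter keeps only images that have a matching XML file in Annotations.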
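The same "keep annotated images, then split 80/20" idea can be sketched in plain Java, assuming plain string paths rather than URIs. `AnnotatedSplit` and its methods are hypothetical, and DL4J's `FileSplit.sample` may size the two parts slightly differently:

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

/** Hypothetical sketch: filter images by annotation, shuffle with a seed, split by fraction. */
final class AnnotatedSplit {

    /** Same path rewrite as the RandomPathFilter: JPEGImages/x.jpg -> Annotations/x.xml. */
    static String annotationFor(String imagePath) {
        return imagePath.replace("/JPEGImages/", "/Annotations/").replace(".jpg", ".xml");
    }

    /** Returns [train, test]; only images whose annotation exists are kept. */
    static List<List<String>> split(List<String> images, Set<String> existingAnnotations,
                                    double trainFraction, long seed) {
        List<String> kept = new ArrayList<>();
        for (String image : images) {
            if (existingAnnotations.contains(annotationFor(image))) {
                kept.add(image);
            }
        }
        Collections.shuffle(kept, new Random(seed)); // reproducible order for a fixed seed
        int nTrain = (int) Math.round(kept.size() * trainFraction);
        List<List<String>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(kept.subList(0, nTrain)));
        parts.add(new ArrayList<>(kept.subList(nTrain, kept.size())));
        return parts;
    }

    public static void main(String[] args) {
        List<String> images = new ArrayList<>();
        Set<String> annotations = new HashSet<>();
        for (int i = 0; i < 10; i++) {
            String image = "/datasets/JPEGImages/BloodImage_0000" + i + ".jpg";
            images.add(image);
            if (i < 8) { // pretend the last two images have no XML file
                annotations.add(annotationFor(image));
            }
        }
        List<List<String>> parts = split(images, annotations, 0.8, 123L);
        System.out.println("train=" + parts.get(0).size() + " test=" + parts.get(1).size()); // train=6 test=2
    }
}
```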

ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth, new VocLabelProvider(dataDir));
recordReaderTrain.initialize(trainData);
ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels, gridHeight, gridWidth,
new VocLabelProvider(dataDir));
recordReaderTest.initialize(testData);
// ObjectDetectionRecordReader performs regression, so we need to specify it here
RecordReaderDataSetIterator train = new RecordReaderDataSetIterator(recordReaderTrain, batchSize, 1, 1, true);
train.setPreProcessor(new ImagePreProcessingScaler(0, 1));
RecordReaderDataSetIterator test = new RecordReaderDataSetIterator(recordReaderTest, 1, 1, 1, true);
test.setPreProcessor(new ImagePreProcessingScaler(0, 1));

Build iterators over the training and test sets, and attach an image pre-processor to each so that pixel values are scaled into the range 0 to 1.
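ImagePreProcessingScaler(0, 1) maps 8-bit pixel values linearly onto [0, 1], which to our understanding amounts to dividing by 255; the visualization step later reverses this with `convertTo(convertedMat, CV_8U, 255, 0)`. A minimal plain-Java sketch of both directions (`PixelScaling` is a hypothetical name):

```java
import java.util.Arrays;

/** Hypothetical sketch of 0..255 <-> 0..1 pixel scaling. */
final class PixelScaling {

    /** Maps 8-bit pixel values 0..255 onto 0..1 (assumed to match ImagePreProcessingScaler(0, 1)). */
    static double[] toUnitRange(int[] pixels) {
        double[] out = new double[pixels.length];
        for (int i = 0; i < pixels.length; i++) {
            out[i] = pixels[i] / 255.0;
        }
        return out;
    }

    /** Inverse mapping before drawing: multiply by 255, round, clamp to 0..255. */
    static int[] toByteRange(double[] scaled) {
        int[] out = new int[scaled.length];
        for (int i = 0; i < scaled.length; i++) {
            long v = Math.round(scaled[i] * 255.0);
            out[i] = (int) Math.max(0, Math.min(255, v));
        }
        return out;
    }

    public static void main(String[] args) {
        int[] pixels = {0, 51, 128, 255};
        double[] scaled = toUnitRange(pixels);
        System.out.println(Arrays.toString(scaled));
        System.out.println(Arrays.toString(toByteRange(scaled))); // [0, 51, 128, 255]
    }
}
```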

2.3 Building the model

ComputationGraph model;
String modelFilename = "model_rbc.zip";
ComputationGraph pretrained = (ComputationGraph) new TinyYOLO().initPretrained();
INDArray priors = Nd4j.create(priorBoxes);

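The pretrained model is downloaded first and stored in the .deeplearning4j directory under your home directory, as shown in the figure.

Next, a FineTuneConfiguration overrides the training settings of the pretrained model: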

FineTuneConfiguration fineTuneConf = new FineTuneConfiguration.Builder()
        .seed(seed)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        // rescale each layer's gradient to unit L2 norm
        .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
        .gradientNormalizationThreshold(1.0)
        // Nesterov momentum updater with the learning rate and momentum defined above
        .updater(new Nesterovs.Builder().learningRate(learningRate).momentum(lrMomentum).build())
        .activation(Activation.IDENTITY)
        .trainingWorkspaceMode(WorkspaceMode.SEPARATE)
        .inferenceWorkspaceMode(WorkspaceMode.SEPARATE)
        .build();

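The code above does the following:

uses stochastic gradient descent as the optimization algorithm;

uses RenormalizeL2PerLayer gradient normalization, which rescales each layer's gradient to guard against vanishing and exploding gradients;

uses the Nesterovs updater, configured with the learning rate and momentum;

sets the workspace mode used for training and inference.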
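RenormalizeL2PerLayer divides each layer's gradient by its own L2 norm, so every layer's gradient has unit length before the updater scales it. As far as we know from the DL4J Javadoc, gradientNormalizationThreshold is only consulted by the clipping modes (ClipL2PerLayer, ClipL2PerParamType, ClipElementWiseAbsoluteValue), so the 1.0 above should not matter here. A plain-Java sketch on a flattened layer gradient, with a hypothetical class name:

```java
import java.util.Arrays;

/** Hypothetical sketch of per-layer L2 renormalization. */
final class RenormalizeL2 {

    /** Returns layerGradient divided by its L2 norm (unchanged if the norm is zero). */
    static double[] renormalize(double[] layerGradient) {
        double sumOfSquares = 0.0;
        for (double g : layerGradient) {
            sumOfSquares += g * g;
        }
        double norm = Math.sqrt(sumOfSquares);
        double[] out = layerGradient.clone();
        if (norm == 0.0) {
            return out; // avoid 0/0 for an all-zero gradient
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= norm;
        }
        return out;
    }

    public static void main(String[] args) {
        // weights and biases of one layer, flattened into a single vector
        System.out.println(Arrays.toString(renormalize(new double[] {3.0, 4.0}))); // [0.6, 0.8]
    }
}
```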

Then transfer learning is used to modify the model architecture:

model = new TransferLearning.GraphBuilder(pretrained)
        .fineTuneConfiguration(fineTuneConf)
        // drop the pretrained detection conv, whose output size is tied to the pretrained classes
        .removeVertexKeepConnections("conv2d_9")
        // new 1x1 conv producing nBoxes * (5 + nClasses) channels per grid cell
        .addLayer("convolution2d_9",
                new ConvolutionLayer.Builder(1, 1)
                        .nIn(1024)
                        .nOut(nBoxes * (5 + nClasses))
                        .stride(1, 1)
                        .convolutionMode(ConvolutionMode.Same)
                        .weightInit(WeightInit.UNIFORM)
                        .hasBias(false)
                        .activation(Activation.IDENTITY)
                        .build(),
                "leaky_re_lu_8")
        // YOLOv2 output layer ("lambbaNoObj" is the method's spelling in DL4J 1.0.0-alpha)
        .addLayer("outputs", new Yolo2OutputLayer.Builder()
                        .lambbaNoObj(lambdaNoObj)
                        .lambdaCoord(lambdaCoord)
                        .boundingBoxPriors(priors)
                        .build(),
                "convolution2d_9")
        .setOutputs("outputs")
        .build();

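This replaces the final convolution so that the network outputs nBoxes * (5 + nClasses) channels; in other words, it mainly configures the number of classes to recognize.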
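The new convolution2d_9 layer uses an identity activation, so its outputs are raw values that Yolo2OutputLayer interprets. Following the YOLOv2 paper's parameterization (which, to our understanding, is what Yolo2OutputLayer uses), a box center is its cell offset plus a sigmoid of the raw value, and its width and height are the prior scaled by an exponential. A sketch for a single box, in grid units; `Yolo2Decode` is hypothetical:

```java
import java.util.Arrays;

/** Hypothetical sketch of YOLOv2-style box decoding for one predicted box. */
final class Yolo2Decode {

    static double sigmoid(double t) {
        return 1.0 / (1.0 + Math.exp(-t));
    }

    /**
     * Decodes raw outputs into {centerX, centerY, width, height} in grid units.
     * (cx, cy) is the cell's column/row and (pw, ph) the prior from priorBoxes.
     */
    static double[] decode(double tx, double ty, double tw, double th,
                           int cx, int cy, double pw, double ph) {
        return new double[] {
            cx + sigmoid(tx),   // center stays inside its own cell
            cy + sigmoid(ty),
            pw * Math.exp(tw),  // size is a rescaled prior
            ph * Math.exp(th)
        };
    }

    public static void main(String[] args) {
        // all-zero raw outputs: box centered in cell (5, 11) with exactly the {2, 2} prior size
        System.out.println(Arrays.toString(decode(0, 0, 0, 0, 5, 11, 2.0, 2.0))); // [5.5, 11.5, 2.0, 2.0]
        System.out.println(sigmoid(0)); // 0.5
    }
}
```

The confidence output goes through the same sigmoid, so a raw value of 0 corresponds to an objectness of 0.5.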

2.4 Training the model

model.setListeners(new ScoreIterationListener(1));
for (int i = 0; i < nEpochs; i++) {
train.reset();
while (train.hasNext()) {
model.fit(train.next());
}
log.info("*** Completed epoch {} ***", i);
}
ModelSerializer.writeModel(model, modelFilename, true);

Once training finishes, the model is serialized and saved to a local file (model_rbc.zip).

2.5 Visualizing the detections

// visualize results on the test set
NativeImageLoader imageLoader = new NativeImageLoader();
CanvasFrame frame = new CanvasFrame("RedBloodCellDetection");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
List<String> labels = train.getLabels();
test.setCollectMetaData(true);
while (test.hasNext() && frame.isVisible()) {
org.nd4j.linalg.dataset.DataSet ds = test.next();
RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) ds.getExampleMetaData().get(0);
INDArray features = ds.getFeatures();
INDArray results = model.outputSingle(features);
List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
File file = new File(metadata.getURI());
log.info(file.getName() + ": " + objs);
Mat mat = imageLoader.asMat(features);
Mat convertedMat = new Mat();
mat.convertTo(convertedMat, CV_8U, 255, 0);
int w = metadata.getOrigW() * 2;
int h = metadata.getOrigH() * 2;
Mat image = new Mat();
resize(convertedMat, image, new Size(w, h));
for (DetectedObject obj : objs) {
double[] xy1 = obj.getTopLeftXY();
double[] xy2 = obj.getBottomRightXY();
String label = labels.get(obj.getPredictedClass());
int x1 = (int) Math.round(w * xy1[0] / gridWidth);
int y1 = (int) Math.round(h * xy1[1] / gridHeight);
int x2 = (int) Math.round(w * xy2[0] / gridWidth);
int y2 = (int) Math.round(h * xy2[1] / gridHeight);
rectangle(image, new Point(x1, y1), new Point(x2, y2), Scalar.RED);
putText(image, label, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, Scalar.GREEN);
}
frame.setTitle(new File(metadata.getURI()).getName() + " - RedBloodCellDetection");
frame.setCanvasSize(w, h);
frame.showImage(converter.convert(image));
frame.waitKey();
}
frame.dispose();
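The loop above combines two steps: getPredictedObjects keeps predictions whose confidence passes detectionThreshold, and each corner, given in grid units, is scaled by w / gridWidth and h / gridHeight to land on the doubled display image. A DL4J-free sketch of the same mapping; `Det` and `DetectionToPixels` are hypothetical stand-ins for DetectedObject, and the >= comparison is our assumption:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch: threshold detections and map grid-unit corners to display pixels. */
final class DetectionToPixels {

    /** Stand-in for a detection: corners in grid units plus a confidence score. */
    static final class Det {
        final double x1, y1, x2, y2, confidence;

        Det(double x1, double y1, double x2, double y2, double confidence) {
            this.x1 = x1;
            this.y1 = y1;
            this.x2 = x2;
            this.y2 = y2;
            this.confidence = confidence;
        }
    }

    /** Keeps detections with confidence >= threshold; returns {x1, y1, x2, y2} in pixels. */
    static List<int[]> toPixelRects(List<Det> dets, double threshold,
                                    int w, int h, int gridW, int gridH) {
        List<int[]> rects = new ArrayList<>();
        for (Det d : dets) {
            if (d.confidence < threshold) {
                continue; // below the cut-off, not drawn
            }
            rects.add(new int[] {
                (int) Math.round(w * d.x1 / gridW),
                (int) Math.round(h * d.y1 / gridH),
                (int) Math.round(w * d.x2 / gridW),
                (int) Math.round(h * d.y2 / gridH)
            });
        }
        return rects;
    }

    public static void main(String[] args) {
        // display size is twice the 640x480 original, as in the visualization loop
        int w = 640 * 2, h = 480 * 2;
        List<Det> dets = Arrays.asList(
                new Det(4.0, 10.0, 6.5, 13.0, 0.9),
                new Det(1.0, 1.0, 2.0, 2.0, 0.1));
        for (int[] r : toPixelRects(dets, 0.3, w, h, 13, 13)) {
            System.out.println(Arrays.toString(r)); // [394, 738, 640, 960]
        }
    }
}
```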

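3. Results

Because the dataset is small and the model was trained for only a few epochs, interested readers can try continuing the training themselves.

4. Source code

The code is on GitHub; download it from: https://github.com/sjsdfg/dl4j-tutorials

It is in the objectdetection package and can be run directly.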

 
   