Recurrent Networks in DL4J
 
 2019-3-8
   
 
Editor's note:

This article comes from deeplearning4j.org. It describes the training features specific to recurrent networks and how to use them in practice in DeepLearning4J.

The Basics: Data and Network Configuration

DL4J currently supports the following types of recurrent neural network:
* GravesLSTM (long short-term memory)
* BidirectionalGravesLSTM (bidirectional long short-term memory)
* BaseRecurrent

Javadoc is available for each of these: GravesLSTM, BidirectionalGravesLSTM, BaseRecurrent.

Data for RNNs

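In a standard feed-forward network (a multi-layer perceptron, or DL4J's 'DenseLayer'), the input and output data are two-dimensional. That is, the "shape" of the data can be described as [numExamples, inputSize]: the data fed into a feed-forward network has 'numExamples' rows (examples), and each row has 'inputSize' columns. A single example would have shape [1, inputSize], although in practice several examples are usually processed together for computational and optimization efficiency. Similarly, the output data of a standard feed-forward network is also two-dimensional, with shape [numExamples, outputSize].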

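RNN data, by contrast, are time series. These data have three dimensions, with an additional dimension for time. Input data therefore has shape [numExamples, inputSize, timeSeriesLength], and output data has shape [numExamples, outputSize, timeSeriesLength]. In terms of the layout of the data in the INDArray, the value at (i, j, k) is the jth value at the kth time step of the ith example in the minibatch. The figure below shows this data layout.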
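A small plain-Java sketch makes the index convention concrete. It does not use ND4J. A nested double[][][] array, indexed as [example][value][timeStep], stands in for an INDArray of shape [numExamples, inputSize, timeSeriesLength]. The class and method names are hypothetical. Each entry stores i*100 + j*10 + k, so you can read the value at (i, j, k) directly from the number.

```java
import java.util.Arrays;

// Plain-Java stand-in for an RNN minibatch of shape [numExamples, inputSize, timeSeriesLength].
// This is NOT ND4J: data[i][j][k] plays the role of the INDArray value at (i, j, k).
final class RnnLayoutDemo {

    // Fills a hypothetical minibatch so that the value at (i, j, k) is i*100 + j*10 + k.
    static double[][][] build(int numExamples, int inputSize, int timeSeriesLength) {
        double[][][] data = new double[numExamples][inputSize][timeSeriesLength];
        for (int i = 0; i < numExamples; i++) {
            for (int j = 0; j < inputSize; j++) {
                for (int k = 0; k < timeSeriesLength; k++) {
                    data[i][j][k] = i * 100 + j * 10 + k;
                }
            }
        }
        return data;
    }

    // Returns time step k for every example as a 2D [numExamples][inputSize] slice,
    // i.e. the same shape a feed-forward layer would see for one time step.
    static double[][] timeStep(double[][][] data, int k) {
        double[][] slice = new double[data.length][];
        for (int i = 0; i < data.length; i++) {
            slice[i] = new double[data[i].length];
            for (int j = 0; j < data[i].length; j++) {
                slice[i][j] = data[i][j][k];
            }
        }
        return slice;
    }

    public static void main(String[] args) {
        double[][][] batch = build(2, 3, 4); // 2 examples, 3 values per step, 4 time steps
        System.out.println(batch[0][1][2]);  // value j=1 of example i=0 at time step k=2: 12.0
        System.out.println(Arrays.deepToString(timeStep(batch, 2)));
    }
}
```

Slicing out a single time step gives back the familiar 2D feed-forward shape, [numExamples, inputSize].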

RnnOutputLayer

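RnnOutputLayer is the final layer in many recurrent network systems, for both regression and classification tasks. It handles score calculation and error calculation (prediction vs. actual) for a given loss function. Functionally, it is very similar to the "standard" OutputLayer class used with feed-forward networks. The difference is that both the output of an RnnOutputLayer and its labels/targets are 3D time series.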

You configure an RnnOutputLayer the same way as any other layer. For example, to make the third layer (index 2) of a MultiLayerNetwork an RnnOutputLayer for classification:

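// Third layer (index 2): an RnnOutputLayer for classification (softmax + multi-class cross-entropy)
.layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT)
        .activation("softmax") // newer DL4J versions take Activation.SOFTMAX instead of a string
        .weightInit(WeightInit.XAVIER)
        .nIn(prevLayerSize).nOut(nOut).build())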

For practical use of RnnOutputLayer, see the examples listed at the end of this page.

RNN Training Features

Truncated Back Propagation Through Time

Training neural networks (including RNNs) can be quite computationally demanding. For recurrent networks this is especially true with long sequences, that is, training data with many time steps.

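Truncated back propagation through time (BPTT) reduces the computational cost of each parameter update in a recurrent network. In short, it lets us train a network faster for the same amount of computation, because parameter updates happen more often. We recommend truncated BPTT when input sequences are long (typically more than a few hundred time steps).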

Consider training a recurrent network on a time series that is 12 time steps long. We need a 12-step forward pass, then an error calculation (predicted vs. actual), then a 12-step backward pass:

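For 12 time steps, as in the figure above, this is not a problem. Now suppose the input time series were 10,000 or more time steps long. With standard back propagation through time, every single parameter update would require a forward and a backward pass over all 10,000 time steps. That is clearly very computationally demanding.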

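In practice, truncated BPTT splits the forward and backward passes into a sequence of smaller forward/backward operations over shorter segments. The user sets the length of these segments. For example, with a truncated BPTT length of 4 time steps, learning proceeds as shown in the figure below: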

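The overall complexity of truncated BPTT and standard BPTT is roughly the same, since both cover the same number of time steps in their forward/backward passes. With truncated BPTT, however, that same amount of work yields 3 parameter updates instead of 1 (12 time steps split into segments of 4). The cost is not exactly identical, because each parameter update adds a small amount of overhead.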

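The downside of truncated BPTT is that the dependencies it learns can be shorter than those learned with full BPTT. The reason is easy to see. Take the truncated BPTT of length 4 in the figure above, and suppose that at time step 10 the network needs information from time step 0 to make an accurate prediction. Standard BPTT can learn this, because gradients flow backward through the unrolled network all the way from time step 10 to time step 0. Truncated BPTT cannot: the gradients from time step 10 do not flow back far enough to produce the parameter updates needed to store that information. This tradeoff is usually worth it, and with an appropriate truncation length, truncated BPTT works well in practice.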

Using truncated BPTT in DL4J is quite simple. Add the following code to your network configuration, at the end, before the final .build():

.backpropType(BackpropType.TruncatedBPTT)
.tBPTTForwardLength(100)  // time steps per forward segment
.tBPTTBackwardLength(100) // time steps the gradients flow back per segment

With this snippet, any network training (that is, calls to MultiLayerNetwork.fit()) uses truncated BPTT, with forward and backward segments of 100 time steps.

Some things to note:

By default (if no backprop type is set manually), DL4J uses BackpropType.Standard, that is, full BPTT.

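The tBPTTForwardLength and tBPTTBackwardLength options set the length of the truncated BPTT passes. A segment length of about 50 to 200 time steps is typical, but the right value depends on the application. The forward and backward lengths are usually the same. tBPTTBackwardLength may sometimes be shorter, but never longer.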

The truncated BPTT length must be less than or equal to the total time series length.
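The truncated BPTT arithmetic and the length constraints in these notes can be sketched in plain Java. This is an illustration, not DL4J code. The class TbpttSketch and its methods are hypothetical. Each [start, end) segment simply stands for one forward/backward pass followed by one parameter update.

```java
import java.util.ArrayList;
import java.util.List;

// Arithmetic sketch of truncated BPTT segmentation (hypothetical helper, not DL4J code).
final class TbpttSketch {

    // Splits a series into consecutive [start, end) segments of at most forwardLength steps.
    // Each segment stands for one forward/backward pass followed by one parameter update.
    static List<int[]> segments(int seriesLength, int forwardLength) {
        if (forwardLength <= 0) {
            throw new IllegalArgumentException("forwardLength must be positive");
        }
        List<int[]> result = new ArrayList<>();
        for (int start = 0; start < seriesLength; start += forwardLength) {
            result.add(new int[] {start, Math.min(start + forwardLength, seriesLength)});
        }
        return result;
    }

    // Checks the constraints from the notes: backward length <= forward length,
    // and the truncation length <= total time series length.
    static boolean validConfig(int seriesLength, int forwardLength, int backwardLength) {
        return forwardLength > 0
                && backwardLength > 0
                && backwardLength <= forwardLength
                && forwardLength <= seriesLength;
    }

    public static void main(String[] args) {
        System.out.println(segments(12, 4).size());       // 3 updates instead of 1
        System.out.println(segments(10_000, 100).size()); // 100 updates instead of 1
        System.out.println(validConfig(10_000, 100, 100));
    }
}
```

With the numbers from this section, a 12-step series and a length of 4 give 3 segments, and a 10,000-step series with a length of 100 gives 100 segments. The total number of time steps processed stays the same in both cases.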

Masking: One-to-Many, Many-to-One, and Sequence Classification

DL4J supports a number of RNN training features based on padding and masking. Padding and masking let us train on one-to-many and many-to-one data, as well as on variable-length time series within the same minibatch.

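Suppose we want to train a recurrent network whose inputs or outputs do not occur at every time step. The figure below shows examples of this, each for a single example. DL4J supports training networks in all of these situations.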

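Without masking and padding, only the many-to-many case (the leftmost in the figure above) is possible. That case requires that (a) all examples have the same length and (b) every example has both an input and an output at every time step.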

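The idea behind padding is simple. Consider two time series of 50 and 100 time steps in the same minibatch. The training data is a rectangular array, so we pad the shorter time series (both input and output) with zeros. That makes the input and output the same length, which is 100 time steps in this example.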

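Of course, padding alone would cause problems during training, so we also use a masking mechanism. The idea behind masking is also simple. Two additional arrays record, for each time step and example, whether the input or output is actually present or is just padding.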

Recall that with RNNs, minibatch data has 3 dimensions: shape [miniBatchSize, inputSize, timeSeriesLength] for the input and [miniBatchSize, outputSize, timeSeriesLength] for the output. The mask arrays, by contrast, are 2-dimensional. They have shape [miniBatchSize, timeSeriesLength] for both input and output, with a value of 0 ("absent") or 1 ("present") for each time step and example. The input mask and the output mask are stored in separate arrays.
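Here is a minimal plain-Java sketch of padding plus masking. To keep the feature array 2D, it assumes one value per time step (inputSize = 1). In DL4J the feature array would be 3D, [miniBatchSize, inputSize, timeSeriesLength], while the mask keeps the 2D shape [miniBatchSize, timeSeriesLength]. The class and method names are hypothetical. In this sketch, data start at time step 0 and padding is added at the end.

```java
import java.util.Arrays;

// Zero padding plus masking for a minibatch of variable-length series (hypothetical helper).
// Assumes one value per time step, so features are 2D here: [miniBatchSize][timeSeriesLength].
final class PadAndMask {

    static final class Padded {
        final double[][] features; // zero-padded values
        final double[][] mask;     // 1.0 = present, 0.0 = padding

        Padded(double[][] features, double[][] mask) {
            this.features = features;
            this.mask = mask;
        }
    }

    static Padded pad(double[][] series) {
        int maxLength = 0;
        for (double[] s : series) {
            maxLength = Math.max(maxLength, s.length);
        }
        // Java initializes new double arrays to 0.0, which provides the zero padding.
        double[][] features = new double[series.length][maxLength];
        double[][] mask = new double[series.length][maxLength];
        for (int i = 0; i < series.length; i++) {
            for (int t = 0; t < series[i].length; t++) {
                features[i][t] = series[i][t]; // real data starts at time step 0
                mask[i][t] = 1.0;
            }
        }
        return new Padded(features, mask);
    }

    public static void main(String[] args) {
        Padded p = pad(new double[][] {{0.5, 0.7, 0.9}, {0.2}});
        System.out.println(Arrays.deepToString(p.features)); // [[0.5, 0.7, 0.9], [0.2, 0.0, 0.0]]
        System.out.println(Arrays.deepToString(p.mask));     // [[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]
    }
}
```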

For a single example, the input and output mask arrays look like this:

In the cases where masking is not required, we could equivalently use a mask array of all 1s, which gives the same result as having no mask array at all. RNN training can also use zero, one, or two mask arrays. For example, a many-to-one setup may have only an output mask array.

In practice, these mask arrays are usually created during data import and stored in the DataSet object. For example, the SequenceRecordReaderDataSetIterator (discussed later) creates them. If a DataSet contains mask arrays, MultiLayerNetwork uses them automatically during training. If no mask arrays are present, no masking is applied.

Evaluation and Scoring with Masking

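Mask arrays also matter for scoring and evaluation, for example when measuring the accuracy of an RNN classifier. Take the many-to-one case: each example has only a single output, and any evaluation must account for that.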

During evaluation, the (output) mask array can be passed to the following method:

Evaluation.evalTimeSeries(INDArray labels, INDArray predicted, INDArray outputMask)

Here, labels are the actual outputs (a 3D time series), predicted is the network's prediction (a 3D time series with the same shape as labels), and outputMask is the 2D mask array for the output. Evaluation does not need the input mask array.

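Score calculation also uses the mask arrays, through the MultiLayerNetwork.score(DataSet) method. As mentioned earlier, if a DataSet contains an output mask array, it is used automatically when calculating the network's score (the loss function: mean squared error, negative log likelihood, etc.).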
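A plain-Java accuracy calculation over class indices shows how an output mask changes evaluation. This is a sketch of the idea behind masked evaluation, not the implementation of Evaluation.evalTimeSeries. The class name, the integer label encoding and the example values are assumptions made for illustration.

```java
// Accuracy computed only where the output mask is 1 (hypothetical helper, not DL4J's Evaluation).
// labels and predicted hold class indices per [example][timeStep].
final class MaskedAccuracy {

    static double accuracy(int[][] labels, int[][] predicted, double[][] outputMask) {
        int counted = 0;
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            for (int t = 0; t < labels[i].length; t++) {
                if (outputMask[i][t] == 0.0) {
                    continue; // no output at this step (e.g. padding): ignore it
                }
                counted++;
                if (labels[i][t] == predicted[i][t]) {
                    correct++;
                }
            }
        }
        return counted == 0 ? Double.NaN : (double) correct / counted;
    }

    public static void main(String[] args) {
        // Many-to-one: only the last time step carries a real label (padded labels are 0 here).
        int[][] labels = {{0, 0, 1}, {0, 0, 2}};
        int[][] predicted = {{1, 1, 1}, {2, 2, 0}};
        double[][] lastStepOnly = {{0, 0, 1}, {0, 0, 1}};
        double[][] allOnes = {{1, 1, 1}, {1, 1, 1}};
        System.out.println(accuracy(labels, predicted, lastStepOnly)); // 0.5
        System.out.println(accuracy(labels, predicted, allOnes));
    }
}
```

With a last-step-only mask, the accuracy is 0.5. An all-ones mask, which is equivalent to no mask, also counts the padded steps and drops the accuracy to 1/6.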

Masking and Sequence Classification After Training

ÐòÁзÖÀàÊÇÑÚĤµÄ³£¼ûÓÃ;֮һ¡£Ö®ËùÒÔ²ÉÓÃÑÚĤ£¬ÊÇÒòΪѭ»·ÍøÂçµÄÊäÈëÊÇÐòÁУ¨Ê±¼äÐòÁУ©£¬¶øÎÒÃÇÖ»ÐèҪΪÕû¸öÐòÁмÓÒ»¸ö±êÇ©£¨¶ø²»ÊÇΪÐòÁÐÖеÄÿ¸öʱ¼ä²½¶¼¼ÓÒ»¸ö±êÇ©£©¡£

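However, by design, an RNN's output sequence has the same length as its input sequence. Masking solves this when training a network for sequence classification: we place the label for the whole sequence at the last time step. In essence, the mask tells the network that label data is present only at that last time step.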

Now suppose the network is trained, and we want the prediction at the last time step from the time series output array. How do we get it?

There are two cases to consider. The first is a single example. Here we don't need the mask arrays and can simply take the last time step of the output array:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;
INDArray timeSeriesFeatures = ...;
INDArray timeSeriesOutput = myNetwork.output(timeSeriesFeatures);
int timeSeriesLength = (int) timeSeriesOutput.size(2); // size of the time dimension
INDArray lastTimeStepProbabilities = timeSeriesOutput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(timeSeriesLength - 1));

Assuming this is a classification problem (regression works the same way), the last line above gives the probabilities at the last time step, which are the class probabilities for the sequence.

The more complex case is a minibatch (features array) containing several examples of different lengths. If all examples have the same length, the process above still applies.

When lengths vary, we must get the last time step for each example separately. If the data pipeline provides the time series length of each example, this is straightforward. Iterate over the examples, and in the code above replace timeSeriesLength with that example's length.

If the lengths are not available directly, we must extract them from the mask array.

If we have a labels mask array, which holds a one-hot vector for each time series such as [0,0,0,1,0]:

INDArray labelsMaskArray = ...;
INDArray lastTimeStepIndices = Nd4j.argMax(labelsMaskArray, 1); // column index of the single 1 in each row

If we have only the features mask, one quick and dirty approach is:

INDArray featuresMaskArray = ...;
int longestTimeSeries = (int) featuresMaskArray.size(1);
INDArray linspace = Nd4j.linspace(1, longestTimeSeries, longestTimeSeries); // [1, 2, ..., longestTimeSeries]
INDArray temp = featuresMaskArray.mulRowVector(linspace); // scale each example's mask row by 1..n
INDArray lastTimeStepIndices = Nd4j.argMax(temp, 1);

Here is how this works. Take a features mask like [1,1,1,1,0], from which we want the position of the last non-zero element. We map [1,1,1,1,0] to [1,2,3,4,0] and take the index of the largest element, which is the last time step.
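The same trick, multiplying by 1..n and then taking the argmax, looks like this in plain Java without ND4J. The class name is hypothetical. Like argmax, the loop keeps the first maximum it finds.

```java
import java.util.Arrays;

// Last non-masked time step per example via "multiply by 1..n, then argmax" (hypothetical helper).
// featuresMask: [miniBatchSize][longestTimeSeries], 1 = present, 0 = padding.
final class LastStepFromMask {

    static int[] lastTimeStepIndices(double[][] featuresMask) {
        int[] result = new int[featuresMask.length];
        for (int i = 0; i < featuresMask.length; i++) {
            double best = Double.NEGATIVE_INFINITY;
            for (int t = 0; t < featuresMask[i].length; t++) {
                double weighted = featuresMask[i][t] * (t + 1); // [1,1,1,1,0] -> [1,2,3,4,0]
                if (weighted > best) { // strict '>' keeps the first maximum, like argmax
                    best = weighted;
                    result[i] = t;
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] mask = {{1, 1, 1, 1, 0}, {1, 1, 0, 0, 0}};
        System.out.println(Arrays.toString(lastTimeStepIndices(mask))); // [3, 1]
    }
}
```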

In either case, the next step can be:

int numExamples = (int) timeSeriesFeatures.size(0);
for (int i = 0; i < numExamples; i++) {
    int thisTimeSeriesLastIndex = lastTimeStepIndices.getInt(i);
    INDArray thisExampleProbabilities = timeSeriesOutput.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(thisTimeSeriesLastIndex));
    // thisExampleProbabilities now holds example i's outputs at its last real time step
}

Combining RNN Layers with Other Layer Types

RNN layers in DL4J can be combined with other layer types. For example, one network can combine DenseLayer and GravesLSTM layers, or combine convolutional (CNN) layers with GravesLSTM layers for video.

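Of course, DenseLayer and convolutional layers do not handle time series data, because they expect a different type of input. Layer preprocessors solve this; examples are the CnnToRnnPreProcessor and FeedForwardToRnnPreProcessor classes. In most cases, the DL4J configuration system adds the required preprocessors automatically. You can also add preprocessors manually, which overrides the automatic preprocessor for that layer.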

ÀýÈ磬ÈçÐèÔÚµÚ1ºÍµÚ2²ãÖ®¼äÌí¼ÓÔ¤´¦ÀíÆ÷£¬¿ÉÔÚÍøÂçÅäÖÃÖÐÌí¼ÓÏÂÁдúÂ룺.inputPreProcessor(2, new RnnToFeedForwardPreProcessor()).

Test Time: Predictions One Step at a Time

As with other neural networks, RNNs can generate predictions with the MultiLayerNetwork.output() and MultiLayerNetwork.feedForward() methods. These methods are useful in many situations. Their limitation is that they always start from scratch when generating time series predictions.

Consider a real-time system that must make predictions based on a long history. The output/feedForward methods are impractical here, because every call runs a full forward pass over the entire history. If we want a prediction for a single time step at every time step, these methods are (a) very costly and (b) wasteful, since they repeat the same calculations over and over.

For these situations, MultiLayerNetwork provides four main methods:

rnnTimeStep(INDArray)

rnnClearPreviousState()

rnnGetPreviousState(int layer)

rnnSetPreviousState(int layer, Map<String,INDArray> state)

The rnnTimeStep() method makes forward passes (predictions) efficient, one or a few steps at a time. Unlike the output/feedForward methods, rnnTimeStep keeps track of the internal state of the RNN layers between calls. The output of rnnTimeStep should be identical to that of output/feedForward for each time step, whether all predictions are made at once (output/feedForward) or one or a few steps at a time (rnnTimeStep). The only difference is the computational cost.

In short, the MultiLayerNetwork.rnnTimeStep() method does two things:

Generates output/predictions (a forward pass), using the previously stored state (if any)

Updates the stored state with the activations of the last time step, ready for the next call to rnnTimeStep

ÀýÈ磬¼ÙÉèÎÒÃÇÐèÒªÓÃÒ»¸öRNNÀ´Ô¤²âһСʱºóµÄÌìÆø×´¿ö£¨¼Ù¶¨ÊäÈëÊÇǰ100¸öСʱµÄÌìÆøÊý¾Ý£©¡£ Èç¹û²ÉÓÃoutput·½·¨£¬ÄÇôÎÒÃÇÐèÒªËÍÈëÈ«²¿100¸öСʱµÄÊý¾Ý£¬²ÅÄÜÔ¤²â³öµÚ101¸öСʱµÄÌìÆø¡£¶øÔ¤²âµÚ102¸öСʱµÄÌìÆøÊ±£¬ÎÒÃÇÓÖÐèÒªËÍÈë100£¨»ò101£©¸öСʱµÄÊý¾Ý£»µÚ103¸öСʱ¼°Ö®ºóµÄÔ¤²âͬÀí¡£

»òÕߣ¬ÎÒÃÇ¿ÉÒÔʹÓÃrnnTimeStep·½·¨¡£µ±È»£¬ÔÚ½øÐеÚÒ»´ÎÔ¤²âʱ£¬ÎÒÃÇÈÔÐèҪʹÓÃÈ«²¿100¸öСʱµÄÀúÊ·Êý¾Ý£¬½øÐÐÍêÕûµÄÕýÏò´«µÝ£º

On the first call to rnnTimeStep, the only practical difference is that the activations/state of the last time step are stored, shown in orange in the figure. On the second call, however, this stored state is used to make the second prediction:

There are a few important differences here:

In the second image (the second call to rnnTimeStep), the input data is a single time step instead of the full history.

The forward pass therefore covers a single time step, instead of hundreds or more.

When rnnTimeStep returns, the internal state has been updated automatically. The prediction for time step 103 is therefore made the same way as for time step 102, and so on.

However, before starting predictions on a new, entirely separate time series, you must manually clear the stored state with MultiLayerNetwork.rnnClearPreviousState(). This is important. The method resets the internal state of all recurrent layers in the network.

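To store or set the internal state of the RNN for prediction, use the rnnGetPreviousState and rnnSetPreviousState methods on each layer individually. This is useful for serialization (saving and loading a network), for example, because the internal state produced by rnnTimeStep is not saved by default and must be saved and loaded separately. These get/set methods return and accept a map keyed by activation type. In an LSTM model, for example, both the output activations and the memory cell state must be stored.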

Some other points to note:

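rnnTimeStep can handle multiple independent examples/predictions at the same time. In the weather example above, one network could make predictions for several locations at once. This works the same way as training and the forward pass/output methods: multiple rows (dimension 0 of the input data) hold multiple examples.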

If no history/stored state is set (initially, or after a call to rnnClearPreviousState), a default initialization of zeros is used, just as during training. rnnTimeStep can also process any number of time steps at once, not only one. Note, however:

For a single time step prediction, the data is 2-dimensional, with shape [numExamples, nIn], and the output is also 2-dimensional, with shape [numExamples, nOut]

For a multiple time step prediction, the data is 3-dimensional, with shape [numExamples, nIn, numTimeSteps], and the output has shape [numExamples, nOut, numTimeSteps]. As before, the activations of the final time step are stored.

The number of examples cannot change between calls to rnnTimeStep. If the first call uses 3 examples, for instance, every later call must also use 3 examples. After the internal state is reset with rnnClearPreviousState(), the next call to rnnTimeStep may use any number of examples.

rnnTimeStep does not change the parameters. It is meant to be used only after the network has been trained.

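rnnTimeStep works with networks that contain a single RNN layer or stacked/multiple RNN layers. It also works with networks that combine RNNs with other layer types, such as convolutional or dense layers.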

The RnnOutputLayer layer type has no recurrent connections, so it has no internal state.
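The plain-Java sketch below illustrates the rnnTimeStep contract with a toy single-unit recurrent layer, h_t = tanh(wIn * x_t + wRec * h_(t-1)). The class ToyRnn and its methods are hypothetical and only echo DL4J's naming:

- output() always starts from a zero state.
- timeStep() uses and then updates a stored state, one value per example, and rejects a change in the number of examples.
- clearPreviousState() resets the stored state.

Stepping through a sequence with timeStep() reproduces the per-step outputs of output().

```java
// Toy single-unit recurrent layer: h_t = tanh(wIn * x_t + wRec * h_(t-1)).
// Hypothetical class that only echoes DL4J's method names; it is not DL4J code.
final class ToyRnn {
    private final double wIn;
    private final double wRec;
    private double[] state; // stored activation per example; null means "no stored state"

    ToyRnn(double wIn, double wRec) {
        this.wIn = wIn;
        this.wRec = wRec;
    }

    // Full forward pass over [numExamples][numTimeSteps], always starting from a zero state.
    // In this toy, output() never reads or changes the stored state.
    double[][] output(double[][] input) {
        double[][] out = new double[input.length][];
        for (int i = 0; i < input.length; i++) {
            out[i] = new double[input[i].length];
            double h = 0.0;
            for (int t = 0; t < input[i].length; t++) {
                h = Math.tanh(wIn * input[i][t] + wRec * h);
                out[i][t] = h;
            }
        }
        return out;
    }

    // One time step for [numExamples] inputs: uses the stored state, then updates it.
    double[] timeStep(double[] input) {
        if (state == null) {
            state = new double[input.length]; // default initial state: zeros
        } else if (state.length != input.length) {
            throw new IllegalStateException("number of examples changed; clear the state first");
        }
        double[] out = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            state[i] = Math.tanh(wIn * input[i] + wRec * state[i]);
            out[i] = state[i];
        }
        return out;
    }

    // Forget the stored state before starting an unrelated time series.
    void clearPreviousState() {
        state = null;
    }

    public static void main(String[] args) {
        ToyRnn net = new ToyRnn(0.8, 0.5);
        double[][] history = {{0.1, -0.4, 0.9}, {1.0, 0.0, -1.0}}; // 2 independent series
        double[][] full = net.output(history);
        double[] last = null;
        for (int t = 0; t < 3; t++) {
            last = net.timeStep(new double[] {history[0][t], history[1][t]});
        }
        // The stepwise result matches the last column of the full forward pass.
        System.out.println(last[0] == full[0][2] && last[1] == full[1][2]);
        net.clearPreviousState(); // required before predicting on a new, separate series
    }
}
```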

Importing Time Series Data

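Data import for RNNs is complicated because there are several kinds of data we might want to use: one-to-many, many-to-one, variable-length time series, and so on. This section describes the data import mechanisms currently implemented in DL4J.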

The methods described here use the SequenceRecordReaderDataSetIterator class together with DataVec's CSVSequenceRecordReader class. This approach currently loads delimited (tab- or comma-separated) data from files, with each time series in a separate file. It also supports:

Variable-length time series input

One-to-many and many-to-one data loading (with input and labels in different files)

For classification, conversion of labels from an index to a one-hot representation (for example, from "2" to [0,0,1,0])

Skipping a fixed, specified number of lines at the start of each data file (for example, comment or header lines)

In all cases, each line in a data file represents one time step.
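Two of the jobs listed above can be pictured in plain Java: reading one delimited file as a sequence (one line per time step, after skipping header lines) and converting a class index to a one-hot vector. This sketch is not DataVec's implementation. The class and method names are hypothetical, and the file contents are passed in as a list of lines. Skipping blank lines is an extra assumption of this sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

// Hypothetical helpers illustrating two jobs of the reader/iterator pair; not DataVec code.
final class SequenceCsvSketch {

    // One file's lines -> one sequence: each remaining line is one time step.
    // Skips the first skipLines lines (e.g. a header); also ignores blank lines (sketch assumption).
    static List<double[]> readSequence(List<String> lines, int skipLines, String delimiter) {
        List<double[]> steps = new ArrayList<>();
        for (int n = skipLines; n < lines.size(); n++) {
            String line = lines.get(n).trim();
            if (line.isEmpty()) {
                continue;
            }
            String[] parts = line.split(Pattern.quote(delimiter));
            double[] values = new double[parts.length];
            for (int c = 0; c < parts.length; c++) {
                values[c] = Double.parseDouble(parts[c].trim());
            }
            steps.add(values);
        }
        return steps;
    }

    // Class index -> one-hot vector, e.g. 2 with 4 classes -> [0, 0, 1, 0].
    static double[] oneHot(int classIndex, int numPossibleLabels) {
        if (classIndex < 0 || classIndex >= numPossibleLabels) {
            throw new IllegalArgumentException("class index out of range: " + classIndex);
        }
        double[] v = new double[numPossibleLabels];
        v[classIndex] = 1.0;
        return v;
    }

    public static void main(String[] args) {
        List<String> file = Arrays.asList("feature1,feature2", "0.1,0.2", "0.3,0.4", "0.5,0.6");
        System.out.println(readSequence(file, 1, ",").size()); // 3 time steps
        System.out.println(Arrays.toString(oneHot(2, 4)));     // [0.0, 0.0, 1.0, 0.0]
    }
}
```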

 

Example 1: Time Series of Equal Length, Input and Labels in Separate Files

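Suppose the training data contains 10 time series stored in 20 files: 10 files hold the input of each time series and 10 hold the outputs/labels. For now, assume all 20 files contain the same number of time steps (that is, the same number of rows).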

To use the SequenceRecordReaderDataSetIterator and CSVSequenceRecordReader approach, first create two CSVSequenceRecordReader objects, one for the input and one for the labels:

SequenceRecordReader featureReader = new CSVSequenceRecordReader(1, ",");
SequenceRecordReader labelReader = new CSVSequenceRecordReader(1, ",");

This constructor takes the number of lines to skip (here, 1) and the delimiter (here, a comma).

Next, initialize the two readers by telling them where to get their data, using an InputSplit object. Suppose our time series are numbered, with file names such as "myInput_0.csv", "myInput_1.csv", ..., "myLabels_0.csv", and so on. One approach is to use NumberedFileInputSplit:

featureReader.initialize(new NumberedFileInputSplit("/path/to/data/myInput_%d.csv", 0, 9));
labelReader.initialize(new NumberedFileInputSplit("/path/to/data/myLabels_%d.csv", 0, 9));

Here, "%d" is replaced by each number in turn, from 0 to 9 inclusive.

Finally, create the SequenceRecordReaderDataSetIterator:

DataSetIterator iter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, miniBatchSize, numPossibleLabels, regression);

This DataSetIterator can then be passed to MultiLayerNetwork.fit() to train the network.

The miniBatchSize argument sets the number of examples (time series) in each minibatch. For example, with 10 files and a miniBatchSize of 5, we get 2 minibatches (DataSet objects) of 5 time series each.
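A few lines of plain Java can check the file naming scheme and the minibatch count from this example. The helper names are hypothetical. expand() mimics how a "%d" pattern is filled with each index in an inclusive range. numMinibatches() assumes, for this sketch, that a final partial minibatch is produced when the number of series is not a multiple of miniBatchSize.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

// Hypothetical helpers for the file naming and minibatch arithmetic in Example 1.
final class Example1Arithmetic {

    // Fills a "%d" pattern with every index from minIdx to maxIdx, both inclusive.
    static List<String> expand(String pattern, int minIdx, int maxIdx) {
        List<String> paths = new ArrayList<>();
        for (int i = minIdx; i <= maxIdx; i++) {
            paths.add(String.format(Locale.ROOT, pattern, i));
        }
        return paths;
    }

    // Number of minibatches for numSeries series; in this sketch a final partial batch counts too.
    static int numMinibatches(int numSeries, int miniBatchSize) {
        return (numSeries + miniBatchSize - 1) / miniBatchSize;
    }

    public static void main(String[] args) {
        List<String> inputs = expand("/path/to/data/myInput_%d.csv", 0, 9);
        System.out.println(inputs.size() + " files, first " + inputs.get(0));
        System.out.println(numMinibatches(inputs.size(), 5) + " minibatches"); // 2 minibatches
    }
}
```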

Note:

For classification problems, numPossibleLabels is the number of classes in your data set. Use regression = false.

Label data: one value per line, as a class index

Label data is converted to a one-hot representation automatically

For regression problems, numPossibleLabels is not used (any value will do). Use regression = true.

The input and labels can have any number of values (unlike classification, there can be any number of outputs)

With regression = true, the labels are not processed in any way

Example 2: Time Series of Equal Length, Input and Labels in the Same File

Building on the previous example, suppose the input data and labels are not in separate files but in the same file. Each time series is still in its own file.

As of DL4J 0.4-rc3.8, this approach is limited to a single output column (either a class index or a single real-valued regression output).

In this case, create and initialize a single reader. As before, we skip one header line, specify comma-separated data, and assume the data files are named "myData_0.csv", ..., "myData_9.csv":

SequenceRecordReader reader = new CSVSequenceRecordReader(1, ",");
reader.initialize(new NumberedFileInputSplit("/path/to/data/myData_%d.csv", 0, 9));
DataSetIterator iterClassification = new SequenceRecordReaderDataSetIterator(reader, miniBatchSize, numPossibleLabels, labelIndex, false);

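miniBatchSize and numPossibleLabels are the same as in the previous example. labelIndex specifies which column holds the labels. If the labels are in the fifth column, for example, use labelIndex = 4, since columns are indexed from 0 to numColumns-1.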
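The plain-Java sketch below shows what labelIndex means for one time-step row in the single-file case. It assumes, as the description reads, that the label column is removed from the feature vector. Check the Javadoc for your DL4J version. The class name is hypothetical.

```java
import java.util.Arrays;

// One time-step row in the single-file case: features plus a label column at labelIndex (0-based).
// Hypothetical helper; assumes the label column is excluded from the features.
final class LabelColumnSplit {

    static double[] features(double[] row, int labelIndex) {
        double[] f = new double[row.length - 1];
        int k = 0;
        for (int c = 0; c < row.length; c++) {
            if (c != labelIndex) {
                f[k++] = row[c];
            }
        }
        return f;
    }

    static double label(double[] row, int labelIndex) {
        return row[labelIndex];
    }

    public static void main(String[] args) {
        double[] row = {0.1, 0.2, 0.3, 0.4, 2}; // label (class index 2) in the fifth column
        int labelIndex = 4;
        System.out.println(Arrays.toString(features(row, labelIndex))); // [0.1, 0.2, 0.3, 0.4]
        System.out.println(label(row, labelIndex));                    // 2.0
    }
}
```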

For regression with a single output value, we use:

DataSetIterator iterRegression = new SequenceRecordReaderDataSetIterator(reader, miniBatchSize, -1, labelIndex, true);

As noted earlier, the numPossibleLabels argument is not used for regression.

Example 3: Time Series of Different Lengths (Many-to-Many)

Building on the previous two examples, suppose that within each example the input and labels have the same length, but the lengths differ between time series.

We can use the same approach (CSVSequenceRecordReader and SequenceRecordReaderDataSetIterator), but with a different constructor:

DataSetIterator variableLengthIter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader,
        miniBatchSize, numPossibleLabels, regression,
        SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

The arguments are the same as in the previous example, except for the added AlignmentMode.ALIGN_END. This alignment mode tells the SequenceRecordReaderDataSetIterator to expect two things:

The time series may have different lengths

The input and labels must be aligned separately for each example, so that their final values fall on the same time step.

If the features and labels always have the same length (as assumed in Example 3), the two alignment modes (AlignmentMode.ALIGN_END and AlignmentMode.ALIGN_START) give identical results. The alignment modes are explained in the next section.

Note also that variable-length time series always start at time step 0 in the data arrays. Any required padding is added after the end of the time series.

Unlike in Examples 1 and 2, the DataSet objects produced by variableLengthIter above also include input and label mask arrays, as described earlier.

Example 4: Many-to-One and One-to-Many Data

The AlignmentMode functionality from Example 3 can also be used to build a many-to-one RNN sequence classifier. Assume that:

The input and labels are in separate delimited files

Each labels file contains a single row (time step): a class index for classification, or one or more numbers for regression

The input lengths may (optionally) differ between examples

The same approach as in Example 3 handles this case:

DataSetIterator variableLengthIter = new SequenceRecordReaderDataSetIterator(featureReader, labelReader,
        miniBatchSize, numPossibleLabels, regression,
        SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

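Alignment modes are fairly simple: they specify whether the shorter time series is padded at its start or at its end. The figure below shows how this works, along with the mask arrays described earlier on this page: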

The one-to-many case (like the previous case, but with only one input) is handled with AlignmentMode.ALIGN_START.

When the training data contains time series of different lengths, the labels and inputs are aligned separately for each example, and the shorter time series are then padded as required.
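The masks that alignment produces summarize its behavior. The plain-Java sketch below is one reading of this section:

- Within each example, the shorter of features/labels is placed at the start (ALIGN_START) or at the end (ALIGN_END) of the longer one.
- Every example starts at time step 0 and is padded at the end.

This is not DL4J's implementation. Its Mode enum only mirrors the names in SequenceRecordReaderDataSetIterator.AlignmentMode, and the class name is hypothetical.

```java
import java.util.Arrays;

// Masks produced by start/end alignment as described in this section (hypothetical sketch).
// Per example: the shorter of features/labels goes at the start (ALIGN_START) or the end
// (ALIGN_END) of the longer one; examples start at step 0 and are padded at the end.
final class AlignmentSketch {

    enum Mode { ALIGN_START, ALIGN_END } // mirrors the names only; not the DL4J enum

    // Returns {featuresMask, labelsMask}, each of shape [numExamples][maxLength].
    static double[][][] masks(int[] featureLengths, int[] labelLengths, Mode mode) {
        int n = featureLengths.length;
        int maxLength = 0;
        for (int i = 0; i < n; i++) {
            maxLength = Math.max(maxLength, Math.max(featureLengths[i], labelLengths[i]));
        }
        double[][] featuresMask = new double[n][maxLength];
        double[][] labelsMask = new double[n][maxLength];
        for (int i = 0; i < n; i++) {
            int exampleLength = Math.max(featureLengths[i], labelLengths[i]);
            int featureStart = mode == Mode.ALIGN_END ? exampleLength - featureLengths[i] : 0;
            int labelStart = mode == Mode.ALIGN_END ? exampleLength - labelLengths[i] : 0;
            for (int t = 0; t < featureLengths[i]; t++) {
                featuresMask[i][featureStart + t] = 1.0;
            }
            for (int t = 0; t < labelLengths[i]; t++) {
                labelsMask[i][labelStart + t] = 1.0;
            }
        }
        return new double[][][] {featuresMask, labelsMask};
    }

    public static void main(String[] args) {
        // Many-to-one: inputs of length 4 and 2, one label each.
        double[][][] m = masks(new int[] {4, 2}, new int[] {1, 1}, Mode.ALIGN_END);
        System.out.println(Arrays.deepToString(m[1])); // [[0.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0]]
        // One-to-many: one input, three labels.
        double[][][] s = masks(new int[] {1}, new int[] {3}, Mode.ALIGN_START);
        System.out.println(Arrays.deepToString(s[0])); // [[1.0, 0.0, 0.0]]
    }
}
```

When features and labels have equal lengths in every example, both modes produce the same masks, which matches the note in Example 3.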

Alternative: Implementing a Custom DataSetIterator

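Sometimes you need a data import setup that does not fit the usual scenarios. One option is to implement a custom DataSetIterator. DataSetIterator is simply an interface for iterating over DataSet objects. Each DataSet wraps the input and target INDArrays, plus the optional input and label mask arrays.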

Note, however, that this approach is quite low level. A custom DataSetIterator must manually create the input and label INDArrays, as well as the input and label mask arrays if required. In exchange, it gives you a great deal of flexibility in how data is loaded.

For this approach in practice, see the iterators used in the character-level text example and in the Word2Vec movie review sentiment example.

Note: when you create a custom DataSetIterator, all arrays (input features, labels, and any mask arrays) should be created in 'f' (Fortran) order. See the ND4J user guide for details on array order. In practice, this means using the Nd4j.create methods that let you specify the order, for example Nd4j.create(new int[]{numExamples, inputSize, timeSeriesLength},'f'). Arrays in 'c' order will also work, but performance suffers because some operations must first copy them to 'f' order.
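To see what 'f' order means for a time series array of shape [numExamples, inputSize, timeSeriesLength], the plain-Java sketch below compares the linear memory offset of element (i, j, k) in 'f' (column-major) order and in 'c' (row-major) order. In 'f' order the first index varies fastest. As a result, all values for one time step k occupy one contiguous block of numExamples * inputSize entries. The helper names are hypothetical; ND4J computes these offsets internally.

```java
// Linear memory offsets for shape [numExamples, inputSize, timeSeriesLength] (hypothetical helpers).
final class ArrayOrderSketch {

    // 'f' (Fortran, column-major): the first index varies fastest.
    static int offsetF(int i, int j, int k, int numExamples, int inputSize) {
        return i + numExamples * (j + inputSize * k);
    }

    // 'c' (row-major): the last index varies fastest.
    static int offsetC(int i, int j, int k, int inputSize, int timeSeriesLength) {
        return (i * inputSize + j) * timeSeriesLength + k;
    }

    public static void main(String[] args) {
        int numExamples = 2, inputSize = 3, timeSeriesLength = 4;
        // In 'f' order, time step k = 1 occupies offsets 6..11 (one contiguous block of 2 * 3 values).
        System.out.println(offsetF(0, 0, 1, numExamples, inputSize)); // 6
        System.out.println(offsetF(1, 2, 1, numExamples, inputSize)); // 11
        // In 'c' order, the values of time step k = 1 are spread out with a stride of timeSeriesLength.
        System.out.println(offsetC(0, 0, 1, inputSize, timeSeriesLength)); // 1
        System.out.println(offsetC(1, 2, 1, inputSize, timeSeriesLength)); // 21
    }
}
```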

Examples

DL4J currently provides the following recurrent network examples:

A basic video frame classification example, which imports videos (.mp4 format) and classifies the shapes present in each frame

A word2vec sequence classification example, which uses pre-trained word vectors and a recurrent neural network to classify movie reviews as either positive or negative.

 
   