@@ -69,7 +69,7 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
6969 , bwd_top_diff(), bwd_bottom_diff()
7070 , BatchNormFwd_pd(), BatchNormBwd_pd()
7171 , scaleshift_memory(), bwd_scaleshift_diff_memory()
72- , output_memory(), bwd_bottom_diff_memory(), inplace_buffer_memory()
72+ , output_memory(), bwd_bottom_diff_memory()
7373 , input_primitive(), bwd_top_diff_primitive()
7474 {
7575 PERFORMANCE_EVENT_ID_RESET (perf_id_fw_);
@@ -95,12 +95,10 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
9595 void InitBatchNormBwd (const vector<Blob<Dtype>*>& top,
9696 const vector<bool >& propagate_down,
9797 const vector<Blob<Dtype>*>& bottom);
98- void InitBatchNormFwdPrimitive (int stats_batch_idx, bool inplace );
99- void InitBatchNormBwdPrimitive (int stats_batch_idx, bool inplace );
98+ void InitBatchNormFwdPrimitive (int stats_batch_idx);
99+ void InitBatchNormBwdPrimitive (int stats_batch_idx);
100100 template <bool diff> shared_ptr<memory> GetStatsBatchMemory (
101101 shared_ptr<MKLDNNMemoryDescriptor<Dtype, diff> > mkldnn_data, int idx);
102- template <bool diff> shared_ptr<memory> GetStatsBatchMemoryInplace (
103- shared_ptr<MKLDNNMemoryDescriptor<Dtype, diff> > mkldnn_data, int idx, shared_ptr<memory > buffer_memory);
104102 void InitStatsBatchVars (int batch_size);
105103 shared_ptr<MKLDNNData<Dtype> > fwd_top_data, fwd_bottom_data;
106104 shared_ptr<MKLDNNDiff<Dtype> > bwd_top_diff, bwd_bottom_diff;
@@ -112,8 +110,8 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
112110
113111 shared_ptr<memory> scaleshift_memory, bwd_scaleshift_diff_memory;
114112 shared_ptr<memory> output_memory, bwd_bottom_diff_memory;
115- shared_ptr<memory> inplace_buffer_memory;
116- vector<shared_ptr<memory> > input_stats, output_stats, top_diff_stats, bottom_diff_stats, input_inplace_buffer ;
113+
114+ vector<shared_ptr<memory> > input_stats, output_stats, top_diff_stats, bottom_diff_stats;
117115
118116 shared_ptr<primitive> input_primitive, bwd_top_diff_primitive;
119117
@@ -124,6 +122,7 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
124122 int stats_batch_size_;
125123 shared_ptr<Blob<Dtype> > scaleshift_blob_;
126124 shared_ptr<Blob<Dtype> > scaleshift_acc_;
125+ Blob<Dtype> inplace_buffer;
127126
128127 PERFORMANCE_EVENT_ID_DECL (perf_id_fw_);
129128 PERFORMANCE_EVENT_ID_DECL (perf_id_bw_);
@@ -224,7 +223,7 @@ class MKLDNNInnerProductLayer : public MKLDNNLayer<Dtype> , public InnerProductL
224223 , bwdd_top_diff_primitive, bwdd_weights_data_primitive
225224 , bwdw_top_diff_primitive, bwdw_bottom_data_primitive;
226225 int32_t w_, h_;
227-
226+
228227 /* In case of (iter_size > 1) we need additional buffers */
229228 shared_ptr<MKLDNNDiff<Dtype> > bwdw_weights_diff_iter, bwdw_bias_diff_iter;
230229 shared_ptr<memory> bwdw_weights_diff_memory_iter, bwdw_bias_diff_memory_iter;
@@ -322,13 +321,14 @@ class MKLDNNPoolingLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
322321 ,const vector<Blob<Dtype>*>& bottom);
323322 virtual void Backward_gpu (const vector<Blob<Dtype>*>& top, const vector<bool >& propagate_down
324323 ,const vector<Blob<Dtype>*>& bottom);
324+ virtual void compute_output_shape (const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
325325
326326private:
327327 void InitPoolingFwd (const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
328328 void InitPoolingBwd (const vector<Blob<Dtype>*>& bottom
329329 , const vector<bool >& propagate_down
330330 , const vector<Blob<Dtype>*>& top);
331-
331+
332332 shared_ptr<MKLDNNData<Dtype>> fwd_bottom_data, fwd_top_data;
333333 shared_ptr<MKLDNNDiff<Dtype>> bwd_top_diff, bwd_bottom_diff;
334334 shared_ptr<pooling_forward::primitive_desc> poolingFwd_pd;
@@ -408,7 +408,7 @@ class MKLDNNConcatLayer : public MKLDNNLayer<Dtype> , public Layer<Dtype> {
408408 : MKLDNNLayer<Dtype>(), Layer<Dtype>(param),
409409 concatFwd_pd(), fwd_output_memory(),
410410 bwd_reorder_input_memory(), bwd_reorder_output_memory(),
411- fwd_top_data(), fwd_bottom_data(), split_channels () {
411+ fwd_top_data(), fwd_bottom_data(), split_dims () {
412412 PERFORMANCE_EVENT_ID_RESET (perf_id_fw_);
413413 PERFORMANCE_EVENT_ID_RESET (perf_id_bw_);
414414 }
@@ -440,7 +440,7 @@ class MKLDNNConcatLayer : public MKLDNNLayer<Dtype> , public Layer<Dtype> {
440440 shared_ptr<MKLDNNDiff<Dtype> > bwd_top_diff;
441441 vector<shared_ptr<MKLDNNDiff<Dtype> > > bwd_bottom_diff;
442442 vector<MKLDNNPrimitive<Dtype> > reorders;
443- vector<int > split_channels ;
443+ vector<int > split_dims ;
444444
445445 int32_t num_, width_, height_, channels_, num_concats_;
446446 int concat_dimension;
0 commit comments