@@ -25,7 +25,7 @@ bool ConvolutionKernelBase::Validate(const Params& p) const {
25
25
return true ;
26
26
}
27
27
28
- JitConstants ConvolutionKernelBase::GetJitConstants (const convolution_params& params, const DispatchData& dispatchData) const {
28
+ JitConstants ConvolutionKernelBase::GetJitConstants (const convolution_params& params, const DispatchData& dispatchData, bool LoopUnrollParams ) const {
29
29
JitConstants mem_consts = WeightBiasKernelBase::GetJitConstants (params);
30
30
mem_consts.Merge (GetFusedPrimitivesJitConstants (params, dispatchData));
31
31
const auto & padding = params.padding_begin ;
@@ -83,19 +83,21 @@ JitConstants ConvolutionKernelBase::GetJitConstants(const convolution_params& pa
83
83
}
84
84
}
85
85
86
- std::vector<uint32_t > unrollLoopParams{params.filterSize .x ,
87
- params.filterSize .y ,
88
- (uint32_t )dispatchData.gemmStyle .globalWorkSizeDX ,
89
- (uint32_t )dispatchData.gemmStyle .globalWorkSizeDY ,
90
- (uint32_t )dispatchData.gemmStyle .globalWorkSizeDZ ,
91
- (uint32_t )dispatchData.gemmStyle .subBlockDimM ,
92
- (uint32_t )dispatchData.gemmStyle .subBlockDimK ,
93
- (uint32_t )dispatchData.gemmStyle .subBlockDimN };
86
+ if (LoopUnrollParams) {
87
+ std::vector<uint32_t > unrollLoopParams{params.filterSize .x ,
88
+ params.filterSize .y ,
89
+ (uint32_t )dispatchData.gemmStyle .globalWorkSizeDX ,
90
+ (uint32_t )dispatchData.gemmStyle .globalWorkSizeDY ,
91
+ (uint32_t )dispatchData.gemmStyle .globalWorkSizeDZ ,
92
+ (uint32_t )dispatchData.gemmStyle .subBlockDimM ,
93
+ (uint32_t )dispatchData.gemmStyle .subBlockDimK ,
94
+ (uint32_t )dispatchData.gemmStyle .subBlockDimN };
94
95
95
- auto loopCount = *std::max_element (unrollLoopParams.begin (), unrollLoopParams.end ());
96
+ auto loopCount = *std::max_element (unrollLoopParams.begin (), unrollLoopParams.end ());
96
97
97
- JitConstants mem_consts_loop = MakeLoopUnrollParamsJitConstants (loopCount);
98
- mem_consts.Merge (mem_consts_loop);
98
+ JitConstants mem_consts_loop = MakeLoopUnrollParamsJitConstants (loopCount);
99
+ mem_consts.Merge (mem_consts_loop);
100
+ }
99
101
100
102
return mem_consts;
101
103
}
0 commit comments