template<class ArrayT>
Definition at line 121 of file MatNorm2DistSSE.h. References ILOG_DEBUG, ILOG_ERROR, ILOG_VAR, Impala::IntAlignUp(), MatE(), MatMul(), MatNrCol(), MatNrRow(), MatPadZeros(), POINTER_ALIGNED, SIMD_MatSquareAndSumAxis0(), SIMD_MatSquareAndSumAxis1(), Impala::Timer::SplitTime(), and Impala::Timer::SplitTimeStr(). Referenced by MatNorm2DistInternal(). 00122 { 00123 ILOG_VAR(Core.Matrix.MatNorm2DistSSE); 00124 if (MatNrCol(aT) != MatNrRow(b)) { 00125 ILOG_ERROR("MatNorm2DistSSE operands: dimensionality problem"); 00126 } 00127 00128 if(!POINTER_ALIGNED(MatE(aT, 0, 0))) 00129 { 00130 ILOG_ERROR("Pointer alignment error aT"); 00131 } 00132 if(!POINTER_ALIGNED(MatE(b, 0, 0))) 00133 { 00134 ILOG_ERROR("Pointer alignment error b"); 00135 } 00136 00137 Timer timer; 00138 int alignCount = 4; 00139 if(sizeof(typename ArrayT::StorType) == 8) 00140 alignCount = 2; 00141 00142 00143 // aa = sum(multiply(a,a), axis=0) # aa = sum(a.*a,1) 00144 ArrayT* aaT = SIMD_MatSquareAndSumAxis1(aT); 00145 00146 ILOG_DEBUG("aaT: " << timer.SplitTimeStr()); 00147 00148 // bb = sum(multiply(b,b), axis=0) # bb = sum(b.*b,1) 00149 // this operation ensures that later on the number of columns is always 00150 // a multiple of 2 inside bb 00151 ArrayT* bb = SIMD_MatSquareAndSumAxis0(b, alignCount); 00152 00153 ILOG_DEBUG("bb: " << timer.SplitTimeStr()); 00154 00155 // ab = dot(transpose(a),b) # ab = a'*b 00156 //ArrayT* aT = MatTranspose(a); 00157 //ILOG_DEBUG("transpose: " << timer.SplitTimeStr()); 00158 00159 ILOG_DEBUG("aT " << MatNrRow(aT) << " " << MatNrCol(aT)); 00160 ILOG_DEBUG("b " << MatNrRow(b) << " " << MatNrCol(b)); 00161 00162 ArrayT* paddedB = MatPadZeros(b, MatNrRow(b), IntAlignUp(MatNrCol(b), alignCount)); 00163 ArrayT* ab = MatMul(aT, paddedB); 00164 delete paddedB; 00165 00166 ILOG_DEBUG("matmul: " << timer.SplitTimeStr()); 00167 00168 // return sqrt(abs(transpose(repmat(aa,bb.shape[0],1)) + repmat(bb,aa.shape[0],1) - 2*ab)) 00169 // #d = sqrt(abs(repmat(aa',[1 size(bb,2)]) + 
repmat(bb,[size(aa,2) 1]) - 2*ab)); 00170 00171 //ILOG_DEBUG("aaT " << MatNrRow(aaT) << " " << MatNrCol(aaT)); 00172 //ILOG_DEBUG("bb " << MatNrRow(bb) << " " << MatNrCol(bb)); 00173 00174 if(sizeof(typename ArrayT::StorType) == 8) 00175 { 00176 #pragma omp parallel for 00177 for (int i=0 ; i<MatNrRow(aaT) ; i++) 00178 { 00179 const __m128d tmp = _mm_set1_pd(*MatE(aaT, i, 0)); 00180 const __m128d minusTwo = _mm_set1_pd(-2.0); 00181 const UInt64 mask = 0x7FFFFFFFFFFFFFFFL; 00182 const __m128d absMask = _mm_set1_pd(*(double*)(&mask)); 00183 __m128d* baseBB = (__m128d*)(MatE(bb, 0, 0)); 00184 __m128d* baseAB = (__m128d*)(MatE(ab, i, 0)); 00185 const int SSELength = MatNrCol(bb) / 2; 00186 for(int j = 0; j < SSELength; j++) 00187 { 00188 __m128d intermediate = _mm_add_pd(_mm_add_pd(baseBB[j], tmp), _mm_mul_pd(minusTwo, baseAB[j])); 00189 baseAB[j] = _mm_sqrt_pd(_mm_and_pd(intermediate, absMask)); 00190 } 00191 } 00192 } 00193 else 00194 { 00195 #pragma omp parallel for 00196 for (int i=0 ; i<MatNrRow(aaT) ; i++) 00197 { 00198 const __m128 tmp = _mm_set1_ps(*MatE(aaT, i, 0)); 00199 const __m128 minusTwo = _mm_set1_ps(-2.0); 00200 const UInt32 mask = 0x7FFFFFFF; 00201 const __m128 absMask = _mm_set1_ps(*(Real32*)(&mask)); 00202 __m128* baseBB = (__m128*)(MatE(bb, 0, 0)); 00203 __m128* baseAB = (__m128*)(MatE(ab, i, 0)); 00204 const int SSELength = MatNrCol(bb) / 4; 00205 for(int j = 0; j < SSELength; j++) 00206 { 00207 __m128 intermediate = _mm_add_ps(_mm_add_ps(baseBB[j], tmp), _mm_mul_ps(minusTwo, baseAB[j])); 00208 baseAB[j] = _mm_sqrt_ps(_mm_and_ps(intermediate, absMask)); 00209 } 00210 } 00211 } 00212 ILOG_DEBUG("sqrt(abs(aa + bb - 2ab)): " << timer.SplitTimeStr()); 00213 delete bb; 00214 delete aaT; 00215 00216 ILOG_DEBUG("delete: " << timer.SplitTimeStr()); 00217 //WriteRaw(a, "matrix_a.raw", &Util::Database::GetInstance(), 1); 00218 //WriteRaw(b, "matrix_b.raw", &Util::Database::GetInstance(), 1); 00219 //WriteRaw(repAAT, "matrix_c.raw", 
&Util::Database::GetInstance(), 1); 00220 //ILOG_INFO("cpu-matnorm2dist-total: " << timer.SplitTime()); 00221 ILOG_DEBUG(timer.SplitTime() << " (cpu-matnorm2dist-total)"); 00222 return ab; 00223 }
The generated documentation also included a call graph for this function; the graph image is not reproduced in this text version.