Home || Visual Search || Applications || Architecture || Important Messages || OGL || Src

template<class ArrayT>
ArrayT* Impala::Core::Matrix::MatNorm2DistSSE ( ArrayT *  aT,
ArrayT *  b 
) [inline]

Definition at line 121 of file MatNorm2DistSSE.h.

References ILOG_DEBUG, ILOG_ERROR, ILOG_VAR, Impala::IntAlignUp(), MatE(), MatMul(), MatNrCol(), MatNrRow(), MatPadZeros(), POINTER_ALIGNED, SIMD_MatSquareAndSumAxis0(), SIMD_MatSquareAndSumAxis1(), Impala::Timer::SplitTime(), and Impala::Timer::SplitTimeStr().

Referenced by MatNorm2DistInternal().

00122 {   // Computes sqrt(abs(aa + bb - 2*ab)) element-wise with SSE and returns it in ab (see the formula comments below).
00123     ILOG_VAR(Core.Matrix.MatNorm2DistSSE);
00124     if (MatNrCol(aT) != MatNrRow(b)) {
00125         ILOG_ERROR("MatNorm2DistSSE operands: dimensionality problem");
00126     }
00127     // NOTE(review): this check and the two alignment checks below only log; no return follows, so execution appears to continue -- confirm ILOG_ERROR does not abort.
00128     if(!POINTER_ALIGNED(MatE(aT, 0, 0)))
00129     {
00130         ILOG_ERROR("Pointer alignment error aT");
00131     }
00132     if(!POINTER_ALIGNED(MatE(b, 0, 0)))
00133     {
00134         ILOG_ERROR("Pointer alignment error b");
00135     }
00136     // The timer is only used to report split times in the debug log lines below.
00137     Timer timer;
00138     int alignCount = 4;
00139     if(sizeof(typename ArrayT::StorType) == 8)
00140         alignCount = 2;
00141     // alignCount matches the lane count used below: 2 when StorType is 8 bytes (__m128d path), otherwise 4 (__m128 path).
00142     // NOTE(review): aT is presumably the already-transposed a (see the disabled MatTranspose call below), hence Axis1 here -- confirm.
00143     // aa = sum(multiply(a,a), axis=0)      # aa = sum(a.*a,1)
00144     ArrayT* aaT = SIMD_MatSquareAndSumAxis1(aT);
00145 
00146     ILOG_DEBUG("aaT: " << timer.SplitTimeStr());
00147 
00148     // bb = sum(multiply(b,b), axis=0)      # bb = sum(b.*b,1)
00149     // this operation presumably also pads the number of columns inside bb to a
00150     // multiple of alignCount (2 or 4), which the SSE loops below rely on -- TODO confirm
00151     ArrayT* bb = SIMD_MatSquareAndSumAxis0(b, alignCount);
00152 
00153     ILOG_DEBUG("bb: " << timer.SplitTimeStr());
00154 
00155     // ab = dot(transpose(a),b)             # ab = a'*b
00156     //ArrayT* aT = MatTranspose(a);
00157     //ILOG_DEBUG("transpose: " << timer.SplitTimeStr());
00158 
00159     ILOG_DEBUG("aT " << MatNrRow(aT) << " " << MatNrCol(aT));
00160     ILOG_DEBUG("b  " << MatNrRow(b) << " " << MatNrCol(b));
00161     // b is padded to a column count rounded up to alignCount before the multiply; the padded copy is deleted right after.
00162     ArrayT* paddedB = MatPadZeros(b, MatNrRow(b), IntAlignUp(MatNrCol(b), alignCount));
00163     ArrayT* ab = MatMul(aT, paddedB);
00164     delete paddedB;
00165 
00166     ILOG_DEBUG("matmul: " << timer.SplitTimeStr());
00167 
00168     // return sqrt(abs(transpose(repmat(aa,bb.shape[0],1)) + repmat(bb,aa.shape[0],1) - 2*ab))
00169     // #d = sqrt(abs(repmat(aa',[1 size(bb,2)]) + repmat(bb,[size(aa,2) 1]) - 2*ab));
00170     // The loops below evaluate that formula in place on ab, one row of ab per row i of aaT.
00171     //ILOG_DEBUG("aaT " << MatNrRow(aaT) << " " << MatNrCol(aaT));
00172     //ILOG_DEBUG("bb  " << MatNrRow(bb) << " " << MatNrCol(bb));
00173     // sizeof() is a compile-time constant but this is an ordinary if, so both branches must compile for every ArrayT.
00174     if(sizeof(typename ArrayT::StorType) == 8)
00175     {
00176         #pragma omp parallel for
00177         for (int i=0 ; i<MatNrRow(aaT) ; i++)
00178         {
00179             const __m128d tmp = _mm_set1_pd(*MatE(aaT, i, 0));  // broadcast aaT(i,0) to both lanes
00180             const __m128d minusTwo = _mm_set1_pd(-2.0);
00181             const UInt64 mask = 0x7FFFFFFFFFFFFFFFL;  // every bit set except the top (sign) bit
00182             const __m128d absMask = _mm_set1_pd(*(double*)(&mask));  // same bit pattern viewed as a double; AND-ing with it clears the sign bit
00183             __m128d* baseBB = (__m128d*)(MatE(bb, 0, 0));  // bb is always read from its row 0
00184             __m128d* baseAB = (__m128d*)(MatE(ab, i, 0));  // row i of ab, overwritten in place; NOTE(review): assumes each row start is 16-byte aligned -- confirm
00185             const int SSELength = MatNrCol(bb) / 2;  // number of 2-lane vectors processed per row; a remainder column would be skipped
00186             for(int j = 0; j < SSELength; j++)
00187             {
00188                 __m128d intermediate = _mm_add_pd(_mm_add_pd(baseBB[j], tmp), _mm_mul_pd(minusTwo, baseAB[j]));  // bb + aa(i) - 2*ab
00189                 baseAB[j] = _mm_sqrt_pd(_mm_and_pd(intermediate, absMask));  // sqrt(abs(...)) stored back into ab
00190             }
00191         }
00192     }
00193     else
00194     {
00195         #pragma omp parallel for
00196         for (int i=0 ; i<MatNrRow(aaT) ; i++)
00197         {
00198             const __m128 tmp = _mm_set1_ps(*MatE(aaT, i, 0));  // broadcast aaT(i,0) to all four lanes
00199             const __m128 minusTwo = _mm_set1_ps(-2.0);
00200             const UInt32 mask = 0x7FFFFFFF;  // every bit set except the top (sign) bit
00201             const __m128 absMask = _mm_set1_ps(*(Real32*)(&mask));  // same bit pattern viewed as Real32; AND-ing with it clears the sign bit
00202             __m128* baseBB = (__m128*)(MatE(bb, 0, 0));  // bb is always read from its row 0
00203             __m128* baseAB = (__m128*)(MatE(ab, i, 0));  // row i of ab, overwritten in place; NOTE(review): assumes each row start is 16-byte aligned -- confirm
00204             const int SSELength = MatNrCol(bb) / 4;  // number of 4-lane vectors processed per row; remainder columns would be skipped
00205             for(int j = 0; j < SSELength; j++)
00206             {
00207                 __m128 intermediate = _mm_add_ps(_mm_add_ps(baseBB[j], tmp), _mm_mul_ps(minusTwo, baseAB[j]));  // bb + aa(i) - 2*ab
00208                 baseAB[j] = _mm_sqrt_ps(_mm_and_ps(intermediate, absMask));  // sqrt(abs(...)) stored back into ab
00209             }
00210         }
00211     }
00212     ILOG_DEBUG("sqrt(abs(aa + bb - 2ab)): " << timer.SplitTimeStr());
00213     delete bb;
00214     delete aaT;
00215     // aaT, bb and paddedB have been deleted by this point; ab is returned and presumably owned by the caller -- confirm.
00216     ILOG_DEBUG("delete: " << timer.SplitTimeStr());
00217 //WriteRaw(a, "matrix_a.raw", &Util::Database::GetInstance(), 1);
00218 //WriteRaw(b, "matrix_b.raw", &Util::Database::GetInstance(), 1);
00219 //WriteRaw(repAAT, "matrix_c.raw", &Util::Database::GetInstance(), 1);
00220     //ILOG_INFO("cpu-matnorm2dist-total: " << timer.SplitTime());
00221     ILOG_DEBUG(timer.SplitTime() << " (cpu-matnorm2dist-total)");
00222     return ab;
00223 }

Here is the call graph for this function:


Generated on Thu Jan 13 09:20:15 2011 for ImpalaSrc by  doxygen 1.5.1