00001 #ifndef Impala_Core_Matrix_MatNorm2DistSSE_h
00002 #define Impala_Core_Matrix_MatNorm2DistSSE_h
00003
#include "Core/Matrix/MatFunc.h"
#include "Core/Matrix/MatSum.h"
#include "Core/Matrix/MatMul.h"
#include "Core/Array/Abs.h"
#include "Core/Array/Add.h"
#include "Core/Array/Mul.h"
#include "Core/Array/MulVal.h"
#include "Core/Array/Sqrt.h"
#include "Basis/Timer.h"

#include <xmmintrin.h>
#include <emmintrin.h>
#include <stdint.h>
// Non-zero when pointer x lies on a 16-byte boundary, the alignment required
// by aligned SSE loads/stores. The argument is fully parenthesized so that an
// expression such as POINTER_ALIGNED(p + 2) performs the pointer arithmetic
// before the integer cast (the unparenthesized form cast p first and then
// added 2 bytes instead of 2 elements).
#define POINTER_ALIGNED(x) (!(((uintptr_t)(x)) & 0xF))
00017
00018 namespace Impala
00019 {
00020 namespace Core
00021 {
00022 namespace Matrix
00023 {
00024
00025 inline Mat*
00026 SIMD_MatSquareAndSumAxis1(Mat* m)
00027 {
00028 int n = MatNrRow(m);
00029 Mat* res = MatCreate<Mat>(n, 1);
00030 for (int i=0 ; i<n ; i++)
00031 {
00032 __m128d tmp = _mm_setzero_pd();
00033 __m128d* base = (__m128d*) MatE(m, i, 0);
00034 for (int j=0 ; j<IntAlignDown(MatNrCol(m), 2) / 2 ; j++)
00035 tmp = _mm_add_pd(tmp, _mm_mul_pd(base[j], base[j]));
00036
00037 __m128d shuffle = _mm_shuffle_pd(tmp, tmp, _MM_SHUFFLE2(0, 1));
00038 Real64 result = _mm_cvtsd_f64(_mm_add_pd(tmp, shuffle));
00039 for (int j=IntAlignDown(MatNrCol(m), 2); j<MatNrCol(m) ; j++)
00040 result += result + *MatE(m, i, j);
00041 *MatE(res, i, 0) = result;
00042 }
00043 return res;
00044 }
00045
00046 inline Mat32*
00047 SIMD_MatSquareAndSumAxis1(Mat32* m)
00048 {
00049 int n = MatNrRow(m);
00050 Mat32* res = MatCreate<Mat32>(n, 1);
00051 for (int i=0 ; i<n ; i++)
00052 {
00053 __m128 tmp = _mm_setzero_ps();
00054 __m128* base = (__m128*) MatE(m, i, 0);
00055 for (int j=0 ; j<IntAlignDown(MatNrCol(m), 4) / 4 ; j++)
00056 tmp = _mm_add_ps(tmp, _mm_mul_ps(base[j], base[j]));
00057
00058
00059
00060
00061
00062 __m128 shuffle, sum;
00063 shuffle = _mm_shuffle_ps(tmp, tmp, _MM_SHUFFLE(1, 0, 3, 2));
00064 sum = _mm_add_ps(tmp, shuffle) ;
00065 shuffle = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1));
00066 Real32 result = _mm_cvtss_f32(_mm_add_ps(sum, shuffle));
00067 for (int j=IntAlignDown(MatNrCol(m), 4); j<MatNrCol(m) ; j++)
00068 result += result + *MatE(m, i, j);
00069 *MatE(res, i, 0) = result;
00070 }
00071 return res;
00072 }
00073
00074
00075
00076
00077 template<class ArrayT>
00078 inline ArrayT*
00079 SIMD_MatSquareAndSumAxis0(ArrayT* m, int alignCount)
00080 {
00081 int cn = MatNrCol(m);
00082 ArrayT* res = MatCreate<ArrayT>(1, IntAlignUp(cn, alignCount));
00083 SetVal(res, res, 0.0);
00084 for (int i=0 ; i<MatNrRow(m) ; i++)
00085 for (int j=0 ; j<cn ; j++)
00086 *MatE(res, 0, j) = *MatE(res, 0, j) + (*MatE(m, i, j)) * (*MatE(m, i, j));
00087 for(int j = 0; j < IntAlignUp(cn, alignCount) - cn; j++)
00088 {
00089 *MatE(res, 0, cn + j) = *MatE(res, 0, j);
00090 }
00091 return res;
00092 }
00093
00094 template<class ArrayT>
00095 ArrayT*
00096 MatPadZeros(ArrayT* data, int rowCount, int columnCount)
00097 {
00098 ArrayT* res = MatCreate<ArrayT>(rowCount, columnCount);
00099 for(int i = 0; i < MatNrRow(data); i++)
00100 {
00101 int j = 0;
00102 for(; j < MatNrCol(data); j++)
00103 {
00104 *MatE(res, i, j) = *MatE(data, i, j);
00105 }
00106 for(; j < columnCount; j++);
00107 {
00108 *MatE(res, i, j) = 0.0;
00109 }
00110 }
00111 for(int i = MatNrRow(data); i < rowCount; i++)
00112 for(int j = 0; j < columnCount; j++)
00113 *MatE(res, i, j) = 0.0;
00114 return res;
00115 }
00116
00117
00118
00119 template<class ArrayT>
00120 inline ArrayT*
00121 MatNorm2DistSSE(ArrayT* aT, ArrayT* b)
00122 {
00123 ILOG_VAR(Core.Matrix.MatNorm2DistSSE);
00124 if (MatNrCol(aT) != MatNrRow(b)) {
00125 ILOG_ERROR("MatNorm2DistSSE operands: dimensionality problem");
00126 }
00127
00128 if(!POINTER_ALIGNED(MatE(aT, 0, 0)))
00129 {
00130 ILOG_ERROR("Pointer alignment error aT");
00131 }
00132 if(!POINTER_ALIGNED(MatE(b, 0, 0)))
00133 {
00134 ILOG_ERROR("Pointer alignment error b");
00135 }
00136
00137 Timer timer;
00138 int alignCount = 4;
00139 if(sizeof(typename ArrayT::StorType) == 8)
00140 alignCount = 2;
00141
00142
00143
00144 ArrayT* aaT = SIMD_MatSquareAndSumAxis1(aT);
00145
00146 ILOG_DEBUG("aaT: " << timer.SplitTimeStr());
00147
00148
00149
00150
00151 ArrayT* bb = SIMD_MatSquareAndSumAxis0(b, alignCount);
00152
00153 ILOG_DEBUG("bb: " << timer.SplitTimeStr());
00154
00155
00156
00157
00158
00159 ILOG_DEBUG("aT " << MatNrRow(aT) << " " << MatNrCol(aT));
00160 ILOG_DEBUG("b " << MatNrRow(b) << " " << MatNrCol(b));
00161
00162 ArrayT* paddedB = MatPadZeros(b, MatNrRow(b), IntAlignUp(MatNrCol(b), alignCount));
00163 ArrayT* ab = MatMul(aT, paddedB);
00164 delete paddedB;
00165
00166 ILOG_DEBUG("matmul: " << timer.SplitTimeStr());
00167
00168
00169
00170
00171
00172
00173
00174 if(sizeof(typename ArrayT::StorType) == 8)
00175 {
00176 #pragma omp parallel for
00177 for (int i=0 ; i<MatNrRow(aaT) ; i++)
00178 {
00179 const __m128d tmp = _mm_set1_pd(*MatE(aaT, i, 0));
00180 const __m128d minusTwo = _mm_set1_pd(-2.0);
00181 const UInt64 mask = 0x7FFFFFFFFFFFFFFFL;
00182 const __m128d absMask = _mm_set1_pd(*(double*)(&mask));
00183 __m128d* baseBB = (__m128d*)(MatE(bb, 0, 0));
00184 __m128d* baseAB = (__m128d*)(MatE(ab, i, 0));
00185 const int SSELength = MatNrCol(bb) / 2;
00186 for(int j = 0; j < SSELength; j++)
00187 {
00188 __m128d intermediate = _mm_add_pd(_mm_add_pd(baseBB[j], tmp), _mm_mul_pd(minusTwo, baseAB[j]));
00189 baseAB[j] = _mm_sqrt_pd(_mm_and_pd(intermediate, absMask));
00190 }
00191 }
00192 }
00193 else
00194 {
00195 #pragma omp parallel for
00196 for (int i=0 ; i<MatNrRow(aaT) ; i++)
00197 {
00198 const __m128 tmp = _mm_set1_ps(*MatE(aaT, i, 0));
00199 const __m128 minusTwo = _mm_set1_ps(-2.0);
00200 const UInt32 mask = 0x7FFFFFFF;
00201 const __m128 absMask = _mm_set1_ps(*(Real32*)(&mask));
00202 __m128* baseBB = (__m128*)(MatE(bb, 0, 0));
00203 __m128* baseAB = (__m128*)(MatE(ab, i, 0));
00204 const int SSELength = MatNrCol(bb) / 4;
00205 for(int j = 0; j < SSELength; j++)
00206 {
00207 __m128 intermediate = _mm_add_ps(_mm_add_ps(baseBB[j], tmp), _mm_mul_ps(minusTwo, baseAB[j]));
00208 baseAB[j] = _mm_sqrt_ps(_mm_and_ps(intermediate, absMask));
00209 }
00210 }
00211 }
00212 ILOG_DEBUG("sqrt(abs(aa + bb - 2ab)): " << timer.SplitTimeStr());
00213 delete bb;
00214 delete aaT;
00215
00216 ILOG_DEBUG("delete: " << timer.SplitTimeStr());
00217
00218
00219
00220
00221 ILOG_DEBUG(timer.SplitTime() << " (cpu-matnorm2dist-total)");
00222 return ab;
00223 }
00224
00225 }
00226 }
00227 }
00228
00229 #endif
00230