template<class ArrayT>
Definition at line 81 of file MatNorm2Dist.h. References Impala::Core::Array::Abs(), Impala::Core::Array::Add(), ILOG_DEBUG, ILOG_ERROR, ILOG_VAR, MatMul(), MatNrCol(), MatNrRow(), MatReplicateMatrix(), MatSumAxis0(), MatSumAxis1(), Impala::Core::Array::Mul(), Impala::Core::Array::MulVal(), Impala::Timer::SplitTime(), Impala::Timer::SplitTimeStr(), and Impala::Core::Array::Sqrt(). Referenced by MatNorm2Dist(), and MatNorm2DistTransposed(). 00082 { 00083 ILOG_VAR(Core.Matrix.MatNorm2Dist); 00084 if (MatNrCol(aT) != MatNrRow(b)) { 00085 ILOG_ERROR("MatNorm2DistInternal operands: dimensionality problem"); 00086 } 00087 00088 Timer timer; 00089 00090 // aa = sum(multiply(a,a), axis=0) # aa = sum(a.*a,1) 00091 ArrayT* tempAT = 0; 00092 Mul(tempAT, aT, aT); // elementwise multiplication 00093 ArrayT* aaT = MatSumAxis1(tempAT); 00094 delete tempAT; 00095 00096 ILOG_DEBUG("aaT: " << timer.SplitTimeStr()); 00097 00098 // bb = sum(multiply(b,b), axis=0) # bb = sum(b.*b,1) 00099 ArrayT* tempB = 0; 00100 Mul(tempB, b, b); // elementwise multiplication 00101 ArrayT* bb = MatSumAxis0(tempB); 00102 delete tempB; 00103 00104 ILOG_DEBUG("bb: " << timer.SplitTimeStr()); 00105 00106 // ab = dot(transpose(a),b) # ab = a'*b 00107 //ArrayT* aT = MatTranspose(a); 00108 //ILOG_DEBUG("transpose: " << timer.SplitTimeStr()); 00109 00110 ILOG_DEBUG("aT " << MatNrRow(aT) << " " << MatNrCol(aT)); 00111 ILOG_DEBUG("b " << MatNrRow(b) << " " << MatNrCol(b)); 00112 00113 ArrayT* ab = MatMul(aT, b); 00114 00115 ILOG_DEBUG("matmul: " << timer.SplitTimeStr()); 00116 00117 MulVal(ab, ab, -2.0); // -2*ab 00118 00119 ILOG_DEBUG("ab: " << timer.SplitTimeStr()); 00120 00121 // return sqrt(abs(transpose(repmat(aa,bb.shape[0],1)) + repmat(bb,aa.shape[0],1) - 2*ab)) 00122 // #d = sqrt(abs(repmat(aa',[1 size(bb,2)]) + repmat(bb,[size(aa,2) 1]) - 2*ab)); 00123 //ArrayT* aaT = MatTranspose(aa); 00124 //ILOG_DEBUG("transpose: " << timer.SplitTimeStr()); 00125 00126 ArrayT* repAAT = MatReplicateMatrix(aaT, 1, MatNrCol(bb)); 00127 ILOG_DEBUG("repmatAAT: " << timer.SplitTimeStr()); 00128 ArrayT* repBB = MatReplicateMatrix(bb, MatNrRow(aaT), 1); 00129 ILOG_DEBUG("repmatBB: " << timer.SplitTimeStr()); 00130 delete bb; 00131 delete aaT; 00132 00133 ILOG_DEBUG("delete: " << timer.SplitTimeStr()); 00134 00135 Add(repAAT, repAAT, repBB); 00136 delete repBB; 00137 Add(repAAT, repAAT, ab); 00138 ILOG_DEBUG("add: " << timer.SplitTimeStr()); 00139 00140 delete ab; 00141 Abs(repAAT, repAAT); 00142 ILOG_DEBUG("abs: " << timer.SplitTimeStr()); 00143 00144 Sqrt(repAAT, repAAT); 00145 00146 ILOG_DEBUG("sqrt: " << timer.SplitTimeStr()); 00147 //WriteRaw(a, "matrix_a.raw", &Util::Database::GetInstance(), true); 00148 //WriteRaw(b, "matrix_b.raw", &Util::Database::GetInstance(), true); 00149 //WriteRaw(repAAT, "matrix_c.raw", &Util::Database::GetInstance(), true); 00150 //ILOG_INFO("cpu-matnorm2dist-total: " << timer.SplitTime()); 00151 ILOG_DEBUG(timer.SplitTime() << " (cpu-matnorm2dist-total)"); 00152 return repAAT; 00153 }
Here is the call graph for this function:
|