00001 #ifndef Impala_Core_Matrix_MatMul_h
00002 #define Impala_Core_Matrix_MatMul_h
00003
00004 #include "Core/Matrix/Mat.h"
00005 #include "Core/Matrix/MatFunc.h"
00006 #include "Core/Matrix/MatTranspose.h"
00007
00008
00009 #ifdef GSL_CBLAS
00010 #include <gsl/gsl_cblas.h>
00011 #endif
00012
00013 #ifdef MKL_CBLAS
00014 #include <mkl_cblas.h>
00015 #endif
00016
00017 #ifdef ATLAS_CBLAS
00018 extern "C"{
00019 #include <cblas.h>
00020 }
00021 #endif
00022
00023 #ifdef ACML_CBLAS
00025 //ACML defines the signatures of CBlas functions differently
00026
00027
00028
00029 #include <acml.h>
00030
00031 #define CblasTrans 'T'
00032 #define CblasNoTrans 'N'
00033
00034 #define SGEMM(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13) \
00035 sgemm(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13)
00036 #define DGEMM(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13) \
00037 dgemm(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13)
00038
00039 #else
00041 //For other CBlas interfaces that suit the cblas.h definitions,
00042
00043
00044 #define SGEMM(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13) \
00045 cblas_sgemm(CblasRowMajor,X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13)
00046
00047 #define DGEMM(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13) \
00048 cblas_dgemm(CblasRowMajor,X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13)
00049 #endif
00050
00052
00053 #ifndef GSL_CBLAS
00054 #ifndef MKL_CBLAS
00055 #ifndef ATLAS_CBLAS
00056 #ifndef ACML_CBLAS
00057 #define BASE_CBLAS
00058 #endif
00059 #endif
00060 #endif
00061 #endif
00062
00063
00064
00065
00066
00067
00068
00069 #ifdef ENABLE_PERFORMANCE_TEST
00070 #include "sys/time.h"
00071 #endif
00072
00073
00074
00075
00076
00077
00078 #ifdef ENABLE_PRECISION_TEST
00079 #include "Core/Array/Sub.h"
00080 #include "Core/Array/PixSum.h"
00081 #include "Core/Array/PrintData.h"
00082 #endif
00083
00084 namespace Impala
00085 {
00086 namespace Core
00087 {
00088 namespace Matrix
00089 {
00090
00091 template<class ArrayT>
00092 inline ArrayT*
00093 MatMul(ArrayT* m, ArrayT* a)
00094 {
00095 ArrayT* res;
00096 ILOG_VAR(Core.Matrix.MatMul);
00097
00098 #ifdef ENABLE_PERFORMANCE_TIMING
00099 timeval start,end,dur;
00100 gettimeofday(&start,NULL);
00101 #endif
00102
00103 if (MatNrCol(m) != MatNrRow(a))
00104 {
00105 ILOG_ERROR("nonconformant MatMul operands.");
00106 return 0;
00107 }
00108 res = MatCreate<ArrayT>(MatNrRow(m), MatNrCol(a));
00109
00110 #ifdef BASE_CBLAS
00111 Mat* aa = MatTranspose(a);
00112 for (int i=0 ; i<MatNrRow(m) ; i++)
00113 {
00114 double* rowM = MatE(m, i, 0);
00115 for (int j=0 ; j<MatNrCol(a) ; j++)
00116 {
00117 double* rowAA = MatE(aa, j, 0);
00118 MatStorType sum = 0;
00119 for (int k=0 ; k<MatNrCol(m) ; k++)
00120 {
00121 sum += rowM[k] * rowAA[k];
00122 }
00123 *MatE(res, i, j) = sum;
00124 }
00125 }
00126 delete aa;
00127 #else
00128 DGEMM(CblasNoTrans, CblasNoTrans,
00129 MatNrRow(m), MatNrCol(a), MatNrRow(a),
00130 1.0, m->PB(), MatNrCol(m),
00131 a->PB(), MatNrCol(a),
00132 0.0, res->PB(), MatNrCol(res));
00133 #ifdef ENABLE_PRECISION_TEST
00134 ArrayT* res2;
00135 res2 = MatCreate<ArrayT>(MatNrRow(m), MatNrCol(a));
00136 Mat* aa = MatTranspose(a);
00137 for (int i=0 ; i<MatNrRow(m) ; i++)
00138 {
00139 double* rowM = MatE(m, i, 0);
00140 for (int j=0 ; j<MatNrCol(a) ; j++)
00141 {
00142 double* rowAA = MatE(aa, j, 0);
00143 MatStorType sum = 0;
00144 for (int k=0 ; k<MatNrCol(m) ; k++)
00145 {
00146 sum += rowM[k] * rowAA[k];
00147 }
00148 *MatE(res2, i, j) = sum;
00149 }
00150 }
00151 delete aa;
00152 ArrayT* r=0;
00153 Sub(r,res,res2);
00154 Impala::Real64 diff=PixSum(r);
00155 if(diff!=0){
00156 ILOG_WARN("The Arrays are different!");
00157 PrintData(r);
00158
00159 }
00160 else{
00161 ILOG_INFO("Arrays are same!");
00162 }
00163 #endif
00164
00165 #endif
00166
00167 #ifdef ENABLE_PERFORMANCE_TIMING
00168 gettimeofday(&end,NULL);
00169 long long int sec = end.tv_sec - start.tv_sec;
00170 long long int usec= end.tv_usec - start.tv_usec;
00171
00172 ILOG_DEBUG("MatMul : "<<sec*1000000+usec<<" usec");
00173 #endif
00174
00175 return res;
00176
00177 }
00178
00179 }
00180 }
00181 }
00182
00183 #endif