00001 #ifndef Impala_Core_Matrix_MatMul_h
00002 #define Impala_Core_Matrix_MatMul_h
00003
00004 #include "Core/Matrix/Mat.h"
00005 #include "Core/Matrix/MatFunc.h"
00006 #include "Core/Matrix/MatTranspose.h"
00007
00008
00009 #ifdef GSL_CBLAS
00010 #include <gsl/gsl_cblas.h>
00011 #endif
00012
00013 #ifdef MKL_CBLAS
00014 #include <mkl_cblas.h>
00015 #endif
00016
00017 #ifdef ATLAS_CBLAS
00018 extern "C"{
00019 #include <cblas.h>
00020 }
00021 #endif
00022
00023 #ifdef ACML_CBLAS
00025 //ACML defines the signatures of CBlas functions differently
00026
00027
00028
00029 #include <acml.h>
00030
00031 #define CblasTrans 'T'
00032 #define CblasNoTrans 'N'
00033
00034 #define SGEMM(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13) \
00035 sgemm(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13)
00036 #define DGEMM(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13) \
00037 dgemm(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13)
00038
00039 #else
00041 //For other CBlas interfaces that suit the cblas.h definitions,
00042
00043
00044 #define SGEMM(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13) \
00045 cblas_sgemm(CblasRowMajor,X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13)
00046
00047 #define DGEMM(X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13) \
00048 cblas_dgemm(CblasRowMajor,X1,X2,X3,X4,X5,X6,X7,X8,X9,X10,X11,X12,X13)
00049 #endif
00050
00052
00053 #ifndef GSL_CBLAS
00054 #ifndef MKL_CBLAS
00055 #ifndef ATLAS_CBLAS
00056 #ifndef ACML_CBLAS
00057 #define BASE_CBLAS
00058 #endif
00059 #endif
00060 #endif
00061 #endif
00062
00063
00064 namespace Impala
00065 {
00066 namespace Core
00067 {
00068 namespace Matrix
00069 {
00070
00071 template<class ArrayT>
00072 inline ArrayT*
00073 MatMul(ArrayT* m, ArrayT* a)
00074 {
00075 ArrayT* res;
00076 ILOG_VAR(Core.Matrix.MatMul);
00077
00078 if (MatNrCol(m) != MatNrRow(a))
00079 {
00080 ILOG_ERROR("nonconformant MatMul operands.");
00081 return 0;
00082 }
00083 res = MatCreate<ArrayT>(MatNrRow(m), MatNrCol(a));
00084
00085 #ifdef BASE_CBLAS
00086 Mat* aa = MatTranspose(a);
00087 for (int i=0 ; i<MatNrRow(m) ; i++)
00088 {
00089 double* rowM = MatE(m, i, 0);
00090 for (int j=0 ; j<MatNrCol(a) ; j++)
00091 {
00092 double* rowAA = MatE(aa, j, 0);
00093 MatStorType sum = 0;
00094 for (int k=0 ; k<MatNrCol(m) ; k++)
00095 {
00096 sum += rowM[k] * rowAA[k];
00097 }
00098 *MatE(res, i, j) = sum;
00099 }
00100 }
00101 delete aa;
00102 #else
00103 DGEMM(CblasNoTrans, CblasNoTrans,
00104 MatNrRow(m), MatNrCol(a), MatNrRow(a),
00105 1.0, m->PB(), MatNrCol(m),
00106 a->PB(), MatNrCol(a),
00107 0.0, res->PB(), MatNrCol(res));
00108 #endif
00109
00110 return res;
00111 }
00112
00113 inline Mat32*
00114 MatMul(Mat32* m, Mat32* a)
00115 {
00116 Mat32* res;
00117 ILOG_VAR(Core.Matrix.MatMul);
00118
00119 if (MatNrCol(m) != MatNrRow(a))
00120 {
00121 ILOG_ERROR("nonconformant MatMul operands.");
00122 return 0;
00123 }
00124 res = MatCreate<Mat32>(MatNrRow(m), MatNrCol(a));
00125
00126 #ifdef BASE_CBLAS
00127 Mat32* aa = MatTranspose(a);
00128 for (int i=0 ; i<MatNrRow(m) ; i++)
00129 {
00130 Real32* rowM = MatE(m, i, 0);
00131 for (int j=0 ; j<MatNrCol(a) ; j++)
00132 {
00133 Real32* rowAA = MatE(aa, j, 0);
00134 Real32 sum = 0;
00135 for (int k=0 ; k<MatNrCol(m) ; k++)
00136 {
00137 sum += rowM[k] * rowAA[k];
00138 }
00139 *MatE(res, i, j) = sum;
00140 }
00141 }
00142 delete aa;
00143 #else
00144 SGEMM(CblasNoTrans, CblasNoTrans,
00145 MatNrRow(m), MatNrCol(a), MatNrRow(a),
00146 1.0, m->PB(), MatNrCol(m),
00147 a->PB(), MatNrCol(a),
00148 0.0, res->PB(), MatNrCol(res));
00149 #endif
00150
00151 return res;
00152 }
00153
00154 }
00155 }
00156 }
00157
00158 #endif