Home || Visual Search || Applications || Architecture || Important Messages || OGL || Src

MatMul.h

Go to the documentation of this file.
00001 #ifndef Impala_Core_Matrix_MatMul_h
00002 #define Impala_Core_Matrix_MatMul_h
00003 
00004 #include "Core/Matrix/Mat.h"
00005 #include "Core/Matrix/MatFunc.h"
00006 #include "Core/Matrix/MatTranspose.h"
00007 
00008 
00009 #ifdef GSL_CBLAS
00010 #include <gsl/gsl_cblas.h>
00011 #endif
00012 
00013 #ifdef MKL_CBLAS
00014 #include <mkl_cblas.h>
00015 #endif
00016 
00017 #ifdef ATLAS_CBLAS
00018 extern "C"{
00019     #include <cblas.h>
00020 }
00021 #endif
00022 
// SGEMM / DGEMM: backend-neutral wrappers for single / double precision
// general matrix multiply. Every caller passes the 13 classic GEMM
// arguments in this order:
//   (transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc)
// and each backend below maps them onto its own calling convention.
#if !defined(ACML_CBLAS)
// cblas.h-style backends (GSL, MKL, ATLAS, ...): forward to the cblas_
// prefixed routine and fix the storage order to CblasRowMajor.
    #define SGEMM(TA,TB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC) \
            cblas_sgemm(CblasRowMajor,TA,TB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)

    #define DGEMM(TA,TB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC) \
            dgemm_placeholder_never_used
    #undef DGEMM
    #define DGEMM(TA,TB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC) \
            cblas_dgemm(CblasRowMajor,TA,TB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
#else
// ACML: the routines carry no cblas_ prefix, take no storage-order
// argument, and expect the transpose flags as plain characters.
// NOTE(review): earlier documentation stated that ACML "defaults to
// rowmajor". ACML's C interface mirrors the Fortran BLAS, which is
// normally column-major. Verify this: a column-major backend would
// produce wrong results with the row-major leading dimensions MatMul passes.
    #include <acml.h>

    #define CblasTrans 'T'
    #define CblasNoTrans 'N'

    #define SGEMM(TA,TB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC) \
            sgemm(TA,TB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
    #define DGEMM(TA,TB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC) \
            dgemm(TA,TB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
#endif
00050 
// No CBLAS backend (GSL, MKL, ATLAS or ACML) was selected at build time:
// define BASE_CBLAS so that MatMul uses its built-in triple-loop code
// instead of calling SGEMM / DGEMM.
#if !defined(GSL_CBLAS) && !defined(MKL_CBLAS) && \
    !defined(ATLAS_CBLAS) && !defined(ACML_CBLAS)
    #define BASE_CBLAS
#endif
00062 
00063 
00064 namespace Impala
00065 {
00066 namespace Core
00067 {
00068 namespace Matrix
00069 {
00070 
00071 template<class ArrayT>
00072 inline ArrayT*
00073 MatMul(ArrayT* m, ArrayT* a)
00074 {
00075     ArrayT* res;
00076     ILOG_VAR(Core.Matrix.MatMul);
00077 
00078     if (MatNrCol(m) != MatNrRow(a))
00079     {
00080         ILOG_ERROR("nonconformant MatMul operands.");
00081         return 0;
00082     }
00083     res = MatCreate<ArrayT>(MatNrRow(m), MatNrCol(a));
00084 
00085 #ifdef BASE_CBLAS
00086     Mat* aa = MatTranspose(a);
00087     for (int i=0 ; i<MatNrRow(m) ; i++)
00088     {
00089         double* rowM = MatE(m, i, 0);
00090         for (int j=0 ; j<MatNrCol(a) ; j++)
00091         {
00092             double* rowAA = MatE(aa, j, 0);
00093             MatStorType sum = 0;
00094             for (int k=0 ; k<MatNrCol(m) ; k++)
00095             {
00096                 sum += rowM[k] * rowAA[k];
00097             }
00098             *MatE(res, i, j) = sum;
00099         }
00100     }
00101     delete aa;
00102 #else
00103     DGEMM(CblasNoTrans, CblasNoTrans,
00104           MatNrRow(m), MatNrCol(a), MatNrRow(a),
00105           1.0, m->PB(), MatNrCol(m),
00106           a->PB(), MatNrCol(a),
00107           0.0, res->PB(), MatNrCol(res));
00108 #endif
00109 
00110     return res;
00111 }
00112 
00113 inline Mat32*
00114 MatMul(Mat32* m, Mat32* a)
00115 {
00116     Mat32* res;
00117     ILOG_VAR(Core.Matrix.MatMul);
00118 
00119     if (MatNrCol(m) != MatNrRow(a))
00120     {
00121         ILOG_ERROR("nonconformant MatMul operands.");
00122         return 0;
00123     }
00124     res = MatCreate<Mat32>(MatNrRow(m), MatNrCol(a));
00125 
00126 #ifdef BASE_CBLAS
00127     Mat32* aa = MatTranspose(a);
00128     for (int i=0 ; i<MatNrRow(m) ; i++)
00129     {
00130         Real32* rowM = MatE(m, i, 0);
00131         for (int j=0 ; j<MatNrCol(a) ; j++)
00132         {
00133             Real32* rowAA = MatE(aa, j, 0);
00134             Real32 sum = 0;
00135             for (int k=0 ; k<MatNrCol(m) ; k++)
00136             {
00137                 sum += rowM[k] * rowAA[k];
00138             }
00139             *MatE(res, i, j) = sum;
00140         }
00141     }
00142     delete aa;
00143 #else
00144     SGEMM(CblasNoTrans, CblasNoTrans,
00145           MatNrRow(m), MatNrCol(a), MatNrRow(a),
00146           1.0, m->PB(), MatNrCol(m),
00147           a->PB(), MatNrCol(a),
00148           0.0, res->PB(), MatNrCol(res));
00149 #endif
00150 
00151     return res;
00152 }
00153 
00154 } // namespace Matrix
00155 } // namespace Core
00156 } // namespace Impala
00157 
00158 #endif

Generated on Thu Jan 13 09:04:33 2011 for ImpalaSrc by  doxygen 1.5.1