gemm_mixed_simple< do_trans_A, do_trans_B, use_alpha, use_beta > Class Template Reference
[Gemm_mixed]

Matrix multiplication where the matrices have different element types. Simple version (no caching). Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes). More...

#include <gemm_mixed.hpp>

List of all members.

Static Public Member Functions

template<typename out_eT , typename in_eT1 , typename in_eT2 >
static arma_hot void apply (Mat< out_eT > &C, const Mat< in_eT1 > &A, const Mat< in_eT2 > &B, const out_eT alpha=out_eT(1), const out_eT beta=out_eT(0))

Detailed Description

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
class gemm_mixed_simple< do_trans_A, do_trans_B, use_alpha, use_beta >

Matrix multiplication where the matrices have different element types. Simple version (no caching). Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes).

Definition at line 213 of file gemm_mixed.hpp.


Member Function Documentation

template<const bool do_trans_A = false, const bool do_trans_B = false, const bool use_alpha = false, const bool use_beta = false>
template<typename out_eT , typename in_eT1 , typename in_eT2 >
static arma_hot void gemm_mixed_simple< do_trans_A, do_trans_B, use_alpha, use_beta >::apply ( Mat< out_eT > &  C,
const Mat< in_eT1 > &  A,
const Mat< in_eT2 > &  B,
const out_eT  alpha = out_eT(1),
const out_eT  beta = out_eT(0) 
) [inline, static]

Definition at line 223 of file gemm_mixed.hpp.

References Mat< eT >::at(), Mat< eT >::colptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.

00230     {  // Body of apply(): stores op(A)*op(B) into C, where op() is identity or transpose as selected by do_trans_A / do_trans_B.
00231     arma_extra_debug_sigprint();
00232     // alpha scales the product only when use_alpha is true; beta times the existing contents of C is added only when use_beta is true.
00233     const u32 A_n_rows = A.n_rows;
00234     const u32 A_n_cols = A.n_cols;
00235     
00236     const u32 B_n_rows = B.n_rows;
00237     const u32 B_n_cols = B.n_cols;
00238     // Every branch condition below tests template parameters only, so one of the four loop nests applies per instantiation.
00239     if( (do_trans_A == false) && (do_trans_B == false) )  // C = A * B
00240       {
00241       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00242         {
00243         for(u32 col_B = 0; col_B < B_n_cols; ++col_B)
00244           {
00245           const in_eT2* B_coldata = B.colptr(col_B);
00246           // Dot product of row row_A of A with column col_B of B, accumulated in out_eT.
00247           out_eT acc = out_eT(0);
00248           for(u32 i = 0; i < B_n_rows; ++i)
00249             {
00250             const out_eT val1 = upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i));  // both operands pass through upgrade_val before the multiply (defined elsewhere; presumably promotes to a common type - confirm)
00251             const out_eT val2 = upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00252             acc += val1 * val2;
00253             //acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00254             }
00255           // Store the result; C is never resized here, so it must already be A_n_rows x B_n_cols.
00256           if( (use_alpha == false) && (use_beta == false) )
00257             {
00258             C.at(row_A,col_B) = acc;
00259             }
00260           else
00261           if( (use_alpha == true) && (use_beta == false) )
00262             {
00263             C.at(row_A,col_B) = alpha * acc;
00264             }
00265           else
00266           if( (use_alpha == false) && (use_beta == true) )
00267             {
00268             C.at(row_A,col_B) = acc + beta*C.at(row_A,col_B);  // the right-hand side reads the existing C element before it is overwritten
00269             }
00270           else
00271           if( (use_alpha == true) && (use_beta == true) )
00272             {
00273             C.at(row_A,col_B) = alpha*acc + beta*C.at(row_A,col_B);
00274             }
00275           }
00276         }
00277       }
00278     else
00279     if( (do_trans_A == true) && (do_trans_B == false) )  // C = trans(A) * B
00280       {
00281       for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00282         {
00283         // col_A is interpreted as row_A when storing the results in matrix C
00284         // A is read down column col_A through colptr(), so no transposed copy of A is formed.
00285         const in_eT1* A_coldata = A.colptr(col_A);
00286         
00287         for(u32 col_B=0; col_B < B_n_cols; ++col_B)
00288           {
00289           const in_eT2* B_coldata = B.colptr(col_B);
00290           
00291           out_eT acc = out_eT(0);
00292           for(u32 i=0; i < B_n_rows; ++i)  // inner length is taken from B_n_rows; A_n_rows is not checked against it here
00293             {
00294             acc += upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]) * upgrade_val<in_eT1,in_eT2>::apply(B_coldata[i]);
00295             }
00296         
00297           if( (use_alpha == false) && (use_beta == false) )
00298             {
00299             C.at(col_A,col_B) = acc;
00300             }
00301           else
00302           if( (use_alpha == true) && (use_beta == false) )
00303             {
00304             C.at(col_A,col_B) = alpha * acc;
00305             }
00306           else
00307           if( (use_alpha == false) && (use_beta == true) )
00308             {
00309             C.at(col_A,col_B) = acc + beta*C.at(col_A,col_B);
00310             }
00311           else
00312           if( (use_alpha == true) && (use_beta == true) )
00313             {
00314             C.at(col_A,col_B) = alpha*acc + beta*C.at(col_A,col_B);
00315             }
00316           
00317           }
00318         }
00319       }
00320     else
00321     if( (do_trans_A == false) && (do_trans_B == true) )  // C = A * trans(B)
00322       {
00323       for(u32 row_A = 0; row_A < A_n_rows; ++row_A)
00324         {
00325         for(u32 row_B = 0; row_B < B_n_rows; ++row_B)
00326           {
00327           out_eT acc = out_eT(0);
00328           for(u32 i = 0; i < B_n_cols; ++i)  // A.at(row_A,i) is indexed up to B_n_cols; no check against A_n_cols here
00329             {
00330             acc += upgrade_val<in_eT1,in_eT2>::apply(A.at(row_A,i)) * upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i));
00331             }
00332           
00333           if( (use_alpha == false) && (use_beta == false) )
00334             {
00335             C.at(row_A,row_B) = acc;
00336             }
00337           else
00338           if( (use_alpha == true) && (use_beta == false) )
00339             {
00340             C.at(row_A,row_B) = alpha * acc;
00341             }
00342           else
00343           if( (use_alpha == false) && (use_beta == true) )
00344             {
00345             C.at(row_A,row_B) = acc + beta*C.at(row_A,row_B);
00346             }
00347           else
00348           if( (use_alpha == true) && (use_beta == true) )
00349             {
00350             C.at(row_A,row_B) = alpha*acc + beta*C.at(row_A,row_B);
00351             }
00352           }
00353         }
00354       }
00355     else
00356     if( (do_trans_A == true) && (do_trans_B == true) )  // C = trans(A) * trans(B)
00357       {
00358       for(u32 row_B=0; row_B < B_n_rows; ++row_B)
00359         {
00360         // Loop order in this branch is row_B outer, col_A inner; the element written is C(col_A,row_B).
00361         for(u32 col_A=0; col_A < A_n_cols; ++col_A)
00362           {
00363           const in_eT1* A_coldata = A.colptr(col_A);
00364           
00365           out_eT acc = out_eT(0);
00366           for(u32 i=0; i < A_n_rows; ++i)  // B.at(row_B,i) is indexed up to A_n_rows; no check against B_n_cols here
00367             {
00368             acc += upgrade_val<in_eT1,in_eT2>::apply(B.at(row_B,i)) * upgrade_val<in_eT1,in_eT2>::apply(A_coldata[i]);
00369             }
00370         
00371           if( (use_alpha == false) && (use_beta == false) )
00372             {
00373             C.at(col_A,row_B) = acc;
00374             }
00375           else
00376           if( (use_alpha == true) && (use_beta == false) )
00377             {
00378             C.at(col_A,row_B) = alpha * acc;
00379             }
00380           else
00381           if( (use_alpha == false) && (use_beta == true) )
00382             {
00383             C.at(col_A,row_B) = acc + beta*C.at(col_A,row_B);
00384             }
00385           else
00386           if( (use_alpha == true) && (use_beta == true) )
00387             {
00388             C.at(col_A,row_B) = alpha*acc + beta*C.at(col_A,row_B);
00389             }
00390           
00391           }
00392         }
00393       
00394       }
00395     }
00395     }