Classes | |
struct | depth_lhs< glue_type, T1 > |
Template metaprogram depth_lhs calculates the number of Glue<Tx,Ty, glue_type> instances in the left-hand-side argument of Glue<Tx,Ty, glue_type>, i.e. it recursively expands each Tx until the type of Tx is no longer "Glue<..,.., glue_type>" (i.e. until the "glue_type" changes). More... | |
struct | depth_lhs< glue_type, Glue< T1, T2, glue_type > > |
struct | glue_times_redirect< N > |
struct | glue_times_redirect< 3 > |
struct | glue_times_redirect< 4 > |
class | glue_times |
Class which implements the immediate multiplication of two or more matrices. More... | |
class | glue_times_diag |
Functions | |
template<typename T1 , typename T2 > | |
static void | glue_times_redirect< N >::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X) |
template<typename T1 , typename T2 , typename T3 > | |
static void | glue_times_redirect< 3 >::apply (Mat< typename T1::elem_type > &out, const Glue< Glue< T1, T2, glue_times >, T3, glue_times > &X) |
template<typename T1 , typename T2 , typename T3 , typename T4 > | |
static void | glue_times_redirect< 4 >::apply (Mat< typename T1::elem_type > &out, const Glue< Glue< Glue< T1, T2, glue_times >, T3, glue_times >, T4, glue_times > &X) |
template<typename T1 , typename T2 > | |
static void | glue_times::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X) |
template<typename T1 > | |
static void | glue_times::apply_inplace (Mat< typename T1::elem_type > &out, const T1 &X) |
template<typename T1 , typename T2 > | |
static arma_hot void | glue_times::apply_inplace_plus (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times > &X, const s32 sign) |
template<typename eT > | |
static arma_inline u32 | glue_times::mul_storage_cost (const Mat< eT > &A, const Mat< eT > &B, const bool do_trans_A, const bool do_trans_B) |
template<typename eT > | |
static arma_hot void | glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_scalar_times) |
template<typename eT > | |
static void | glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const Mat< eT > &C, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_trans_C, const bool do_scalar_times) |
template<typename eT > | |
static void | glue_times::apply (Mat< eT > &out, const Mat< eT > &A, const Mat< eT > &B, const Mat< eT > &C, const Mat< eT > &D, const eT val, const bool do_trans_A, const bool do_trans_B, const bool do_trans_C, const bool do_trans_D, const bool do_scalar_times) |
template<typename T1 , typename T2 > | |
static arma_hot void | glue_times_diag::apply (Mat< typename T1::elem_type > &out, const Glue< T1, T2, glue_times_diag > &X) |
void glue_times_redirect< N >::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 26 of file glue_times_meat.hpp.
References Glue< T1, T2, glue_type >::A, Glue< T1, T2, glue_type >::B, partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.
Referenced by glue_times_redirect< 4 >::apply(), and glue_times_redirect< 3 >::apply().
00027 { 00028 arma_extra_debug_sigprint(); 00029 00030 typedef typename T1::elem_type eT; 00031 00032 const partial_unwrap_check<T1> tmp1(X.A, out); 00033 const partial_unwrap_check<T2> tmp2(X.B, out); 00034 00035 const Mat<eT>& A = tmp1.M; 00036 const Mat<eT>& B = tmp2.M; 00037 00038 const bool do_trans_A = tmp1.do_trans; 00039 const bool do_trans_B = tmp2.do_trans; 00040 00041 const bool use_alpha = tmp1.do_times | tmp2.do_times; 00042 const eT alpha = use_alpha ? (tmp1.val * tmp2.val) : eT(0); 00043 00044 glue_times::apply(out, A, B, alpha, do_trans_A, do_trans_B, use_alpha); 00045 }
void glue_times_redirect< 3 >::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< Glue< T1, T2, glue_times >, T3, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 52 of file glue_times_meat.hpp.
References glue_times::apply(), partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.
00053 { 00054 arma_extra_debug_sigprint(); 00055 00056 typedef typename T1::elem_type eT; 00057 00058 // there is exactly 3 objects 00059 // hence we can safely expand X as X.A.A, X.A.B and X.B 00060 00061 const partial_unwrap_check<T1> tmp1(X.A.A, out); 00062 const partial_unwrap_check<T2> tmp2(X.A.B, out); 00063 const partial_unwrap_check<T3> tmp3(X.B, out); 00064 00065 const Mat<eT>& A = tmp1.M; 00066 const Mat<eT>& B = tmp2.M; 00067 const Mat<eT>& C = tmp3.M; 00068 00069 const bool do_trans_A = tmp1.do_trans; 00070 const bool do_trans_B = tmp2.do_trans; 00071 const bool do_trans_C = tmp3.do_trans; 00072 00073 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times; 00074 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val) : eT(0); 00075 00076 glue_times::apply(out, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); 00077 }
void glue_times_redirect< 4 >::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< Glue< Glue< T1, T2, glue_times >, T3, glue_times >, T4, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 84 of file glue_times_meat.hpp.
References glue_times::apply(), partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, and partial_unwrap_check< T1 >::val.
00085 { 00086 arma_extra_debug_sigprint(); 00087 00088 typedef typename T1::elem_type eT; 00089 00090 // there is exactly 4 objects 00091 // hence we can safely expand X as X.A.A.A, X.A.A.B, X.A.B and X.B 00092 00093 const partial_unwrap_check<T1> tmp1(X.A.A.A, out); 00094 const partial_unwrap_check<T2> tmp2(X.A.A.B, out); 00095 const partial_unwrap_check<T3> tmp3(X.A.B, out); 00096 const partial_unwrap_check<T4> tmp4(X.B, out); 00097 00098 const Mat<eT>& A = tmp1.M; 00099 const Mat<eT>& B = tmp2.M; 00100 const Mat<eT>& C = tmp3.M; 00101 const Mat<eT>& D = tmp4.M; 00102 00103 const bool do_trans_A = tmp1.do_trans; 00104 const bool do_trans_B = tmp2.do_trans; 00105 const bool do_trans_C = tmp3.do_trans; 00106 const bool do_trans_D = tmp4.do_trans; 00107 00108 const bool use_alpha = tmp1.do_times | tmp2.do_times | tmp3.do_times | tmp4.do_times; 00109 const eT alpha = use_alpha ? (tmp1.val * tmp2.val * tmp3.val * tmp4.val) : eT(0); 00110 00111 glue_times::apply(out, A, B, C, D, alpha, do_trans_A, do_trans_B, do_trans_C, do_trans_D, use_alpha); 00112 }
void glue_times::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times > & | X | |||
) | [inline, static, inherited] |
Definition at line 119 of file glue_times_meat.hpp.
Referenced by apply(), apply_inplace(), and apply_inplace_plus().
00120 { 00121 arma_extra_debug_sigprint(); 00122 00123 typedef typename T1::elem_type eT; 00124 00125 const s32 N_mat = 1 + depth_lhs< glue_times, Glue<T1,T2,glue_times> >::num; 00126 00127 arma_extra_debug_print(arma_boost::format("N_mat = %d") % N_mat); 00128 00129 glue_times_redirect<N_mat>::apply(out, X); 00130 }
void glue_times::apply_inplace | ( | Mat< typename T1::elem_type > & | out, | |
const T1 & | X | |||
) | [inline, static, inherited] |
Definition at line 137 of file glue_times_meat.hpp.
References apply(), Mat< eT >::at(), Mat< eT >::colptr(), unwrap_check< T1 >::M, podarray< eT >::memptr(), Mat< eT >::n_cols, and Mat< eT >::n_rows.
Referenced by Mat< eT >::operator*=().
00138 { 00139 arma_extra_debug_sigprint(); 00140 00141 typedef typename T1::elem_type eT; 00142 00143 const unwrap_check<T1> tmp(X, out); 00144 const Mat<eT>& B = tmp.M; 00145 00146 arma_debug_assert_mul_size(out, B, "matrix multiply"); 00147 00148 if(out.n_cols == B.n_cols) 00149 { 00150 podarray<eT> tmp(out.n_cols); 00151 eT* tmp_rowdata = tmp.memptr(); 00152 00153 for(u32 out_row=0; out_row < out.n_rows; ++out_row) 00154 { 00155 for(u32 out_col=0; out_col < out.n_cols; ++out_col) 00156 { 00157 tmp_rowdata[out_col] = out.at(out_row,out_col); 00158 } 00159 00160 for(u32 B_col=0; B_col < B.n_cols; ++B_col) 00161 { 00162 const eT* B_coldata = B.colptr(B_col); 00163 00164 eT val = eT(0); 00165 for(u32 i=0; i < B.n_rows; ++i) 00166 { 00167 val += tmp_rowdata[i] * B_coldata[i]; 00168 } 00169 00170 out.at(out_row,B_col) = val; 00171 } 00172 } 00173 00174 } 00175 else 00176 { 00177 const Mat<eT> tmp(out); 00178 glue_times::apply(out, tmp, B, eT(1), false, false, false); 00179 } 00180 00181 }
arma_hot void glue_times::apply_inplace_plus | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times > & | X, | |||
const s32 | sign | |||
) | [inline, static, inherited] |
Definition at line 189 of file glue_times_meat.hpp.
References Glue< T1, T2, glue_type >::A, apply(), arma_assert_same_size(), Glue< T1, T2, glue_type >::B, partial_unwrap_check< T1 >::do_times, partial_unwrap_check< T1 >::do_trans, partial_unwrap_check< T1 >::M, Mat< eT >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and partial_unwrap_check< T1 >::val.
Referenced by Mat< eT >::operator+=(), and Mat< eT >::operator-=().
00190 { 00191 arma_extra_debug_sigprint(); 00192 00193 typedef typename T1::elem_type eT; 00194 00195 const partial_unwrap_check<T1> tmp1(X.A, out); 00196 const partial_unwrap_check<T2> tmp2(X.B, out); 00197 00198 const Mat<eT>& A = tmp1.M; 00199 const Mat<eT>& B = tmp2.M; 00200 const eT alpha = tmp1.val * tmp2.val * ( (sign > s32(0)) ? eT(1) : eT(-1) ); 00201 00202 const bool do_trans_A = tmp1.do_trans; 00203 const bool do_trans_B = tmp2.do_trans; 00204 const bool use_alpha = tmp1.do_times | tmp2.do_times | (sign < s32(0)); 00205 00206 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply"); 00207 00208 const u32 result_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; 00209 const u32 result_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; 00210 00211 arma_assert_same_size(out.n_rows, out.n_cols, result_n_rows, result_n_cols, "matrix addition"); 00212 00213 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) 00214 { 00215 if(A.n_rows == 1) 00216 { 00217 gemv<true, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00218 } 00219 else 00220 if(B.n_cols == 1) 00221 { 00222 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00223 } 00224 else 00225 { 00226 gemm<false, false, false, true>::apply(out, A, B, alpha, eT(1)); 00227 } 00228 } 00229 else 00230 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) 00231 { 00232 if(A.n_rows == 1) 00233 { 00234 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00235 } 00236 else 00237 if(B.n_cols == 1) 00238 { 00239 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00240 } 00241 else 00242 { 00243 gemm<false, false, true, true>::apply(out, A, B, alpha, eT(1)); 00244 } 00245 } 00246 else 00247 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) 00248 { 00249 if(A.n_cols == 1) 00250 { 00251 gemv<true, false, true>::apply(out.memptr(), B, 
A.memptr(), alpha, eT(1)); 00252 } 00253 else 00254 if(B.n_cols == 1) 00255 { 00256 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00257 } 00258 else 00259 { 00260 gemm<true, false, false, true>::apply(out, A, B, alpha, eT(1)); 00261 } 00262 } 00263 else 00264 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) 00265 { 00266 if(A.n_cols == 1) 00267 { 00268 gemv<true, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00269 } 00270 else 00271 if(B.n_cols == 1) 00272 { 00273 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00274 } 00275 else 00276 { 00277 gemm<true, false, true, true>::apply(out, A, B, alpha, eT(1)); 00278 } 00279 } 00280 else 00281 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) 00282 { 00283 if(A.n_rows == 1) 00284 { 00285 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00286 } 00287 else 00288 if(B.n_rows == 1) 00289 { 00290 gemv<false, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00291 } 00292 else 00293 { 00294 gemm<false, true, false, true>::apply(out, A, B, alpha, eT(1)); 00295 } 00296 } 00297 else 00298 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) 00299 { 00300 if(A.n_rows == 1) 00301 { 00302 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00303 } 00304 else 00305 if(B.n_rows == 1) 00306 { 00307 gemv<false, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00308 } 00309 else 00310 { 00311 gemm<false, true, true, true>::apply(out, A, B, alpha, eT(1)); 00312 } 00313 } 00314 else 00315 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) 00316 { 00317 if(A.n_cols == 1) 00318 { 00319 gemv<false, false, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00320 } 00321 else 00322 if(B.n_rows == 1) 00323 { 00324 gemv<true, false, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00325 } 
00326 else 00327 { 00328 gemm<true, true, false, true>::apply(out, A, B, alpha, eT(1)); 00329 } 00330 } 00331 else 00332 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) 00333 { 00334 if(A.n_cols == 1) 00335 { 00336 gemv<false, true, true>::apply(out.memptr(), B, A.memptr(), alpha, eT(1)); 00337 } 00338 else 00339 if(B.n_rows == 1) 00340 { 00341 gemv<true, true, true>::apply(out.memptr(), A, B.memptr(), alpha, eT(1)); 00342 } 00343 else 00344 { 00345 gemm<true, true, true, true>::apply(out, A, B, alpha, eT(1)); 00346 } 00347 } 00348 00349 00350 }
arma_inline u32 glue_times::mul_storage_cost | ( | const Mat< eT > & | A, | |
const Mat< eT > & | B, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B | |||
) | [inline, static, inherited] |
Definition at line 357 of file glue_times_meat.hpp.
References Mat< eT >::n_cols, and Mat< eT >::n_rows.
Referenced by apply().
arma_hot void glue_times::apply | ( | Mat< eT > & | out, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const eT | val, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B, | |||
const bool | do_scalar_times | |||
) | [inline, static, inherited] |
Definition at line 372 of file glue_times_meat.hpp.
References gemm< do_trans_A, do_trans_B, use_alpha, use_beta >::apply(), gemv< do_trans_A, use_alpha, use_beta >::apply(), Mat< eT >::memptr(), Mat< eT >::n_cols, Mat< eT >::n_rows, and Mat< eT >::set_size().
00381 { 00382 arma_extra_debug_sigprint(); 00383 00384 arma_debug_assert_mul_size(A, B, do_trans_A, do_trans_B, "matrix multiply"); 00385 00386 const u32 final_n_rows = (do_trans_A == false) ? A.n_rows : A.n_cols; 00387 const u32 final_n_cols = (do_trans_B == false) ? B.n_cols : B.n_rows; 00388 00389 out.set_size(final_n_rows, final_n_cols); 00390 00391 // TODO: thoroughly test all combinations 00392 00393 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) 00394 { 00395 if(A.n_rows == 1) 00396 { 00397 gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); 00398 } 00399 else 00400 if(B.n_cols == 1) 00401 { 00402 gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); 00403 } 00404 else 00405 { 00406 gemm<false, false, false, false>::apply(out, A, B); 00407 } 00408 } 00409 else 00410 if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == true) ) 00411 { 00412 if(A.n_rows == 1) 00413 { 00414 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00415 } 00416 else 00417 if(B.n_cols == 1) 00418 { 00419 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); 00420 } 00421 else 00422 { 00423 gemm<false, false, true, false>::apply(out, A, B, alpha); 00424 } 00425 } 00426 else 00427 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == false) ) 00428 { 00429 if(A.n_cols == 1) 00430 { 00431 gemv<true, false, false>::apply(out.memptr(), B, A.memptr()); 00432 } 00433 else 00434 if(B.n_cols == 1) 00435 { 00436 gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); 00437 } 00438 else 00439 { 00440 gemm<true, false, false, false>::apply(out, A, B); 00441 } 00442 } 00443 else 00444 if( (do_trans_A == true) && (do_trans_B == false) && (use_alpha == true) ) 00445 { 00446 if(A.n_cols == 1) 00447 { 00448 gemv<true, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00449 } 00450 else 00451 if(B.n_cols == 1) 00452 { 00453 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), 
alpha); 00454 } 00455 else 00456 { 00457 gemm<true, false, true, false>::apply(out, A, B, alpha); 00458 } 00459 } 00460 else 00461 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == false) ) 00462 { 00463 if(A.n_rows == 1) 00464 { 00465 gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); 00466 } 00467 else 00468 if(B.n_rows == 1) 00469 { 00470 gemv<false, false, false>::apply(out.memptr(), A, B.memptr()); 00471 } 00472 else 00473 { 00474 gemm<false, true, false, false>::apply(out, A, B); 00475 } 00476 } 00477 else 00478 if( (do_trans_A == false) && (do_trans_B == true) && (use_alpha == true) ) 00479 { 00480 if(A.n_rows == 1) 00481 { 00482 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00483 } 00484 else 00485 if(B.n_rows == 1) 00486 { 00487 gemv<false, true, false>::apply(out.memptr(), A, B.memptr(), alpha); 00488 } 00489 else 00490 { 00491 gemm<false, true, true, false>::apply(out, A, B, alpha); 00492 } 00493 } 00494 else 00495 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == false) ) 00496 { 00497 if(A.n_cols == 1) 00498 { 00499 gemv<false, false, false>::apply(out.memptr(), B, A.memptr()); 00500 } 00501 else 00502 if(B.n_rows == 1) 00503 { 00504 gemv<true, false, false>::apply(out.memptr(), A, B.memptr()); 00505 } 00506 else 00507 { 00508 gemm<true, true, false, false>::apply(out, A, B); 00509 } 00510 } 00511 else 00512 if( (do_trans_A == true) && (do_trans_B == true) && (use_alpha == true) ) 00513 { 00514 if(A.n_cols == 1) 00515 { 00516 gemv<false, true, false>::apply(out.memptr(), B, A.memptr(), alpha); 00517 } 00518 else 00519 if(B.n_rows == 1) 00520 { 00521 gemv<true, true, false>::apply(out.memptr(), A, B.memptr(), alpha); 00522 } 00523 else 00524 { 00525 gemm<true, true, true, false>::apply(out, A, B, alpha); 00526 } 00527 } 00528 }
void glue_times::apply | ( | Mat< eT > & | out, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const Mat< eT > & | C, | |||
const eT | val, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B, | |||
const bool | do_trans_C, | |||
const bool | do_scalar_times | |||
) | [inline, static, inherited] |
Definition at line 536 of file glue_times_meat.hpp.
References apply(), and mul_storage_cost().
00547 { 00548 arma_extra_debug_sigprint(); 00549 00550 Mat<eT> tmp; 00551 00552 if( glue_times::mul_storage_cost(A, B, do_trans_A, do_trans_B) <= glue_times::mul_storage_cost(B, C, do_trans_B, do_trans_C) ) 00553 { 00554 // out = (A*B)*C 00555 glue_times::apply(tmp, A, B, alpha, do_trans_A, do_trans_B, use_alpha); 00556 glue_times::apply(out, tmp, C, eT(0), false, do_trans_C, false ); 00557 } 00558 else 00559 { 00560 // out = A*(B*C) 00561 glue_times::apply(tmp, B, C, alpha, do_trans_B, do_trans_C, use_alpha); 00562 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false ); 00563 } 00564 }
void glue_times::apply | ( | Mat< eT > & | out, | |
const Mat< eT > & | A, | |||
const Mat< eT > & | B, | |||
const Mat< eT > & | C, | |||
const Mat< eT > & | D, | |||
const eT | val, | |||
const bool | do_trans_A, | |||
const bool | do_trans_B, | |||
const bool | do_trans_C, | |||
const bool | do_trans_D, | |||
const bool | do_scalar_times | |||
) | [inline, static, inherited] |
Definition at line 572 of file glue_times_meat.hpp.
References apply(), and mul_storage_cost().
00585 { 00586 arma_extra_debug_sigprint(); 00587 00588 Mat<eT> tmp; 00589 00590 if( glue_times::mul_storage_cost(A, C, do_trans_A, do_trans_C) <= glue_times::mul_storage_cost(B, D, do_trans_B, do_trans_D) ) 00591 { 00592 // out = (A*B*C)*D 00593 glue_times::apply(tmp, A, B, C, alpha, do_trans_A, do_trans_B, do_trans_C, use_alpha); 00594 00595 glue_times::apply(out, tmp, D, eT(0), false, do_trans_D, false); 00596 } 00597 else 00598 { 00599 // out = A*(B*C*D) 00600 glue_times::apply(tmp, B, C, D, alpha, do_trans_B, do_trans_C, do_trans_D, use_alpha); 00601 00602 glue_times::apply(out, A, tmp, eT(0), do_trans_A, false, false); 00603 } 00604 }
arma_hot void glue_times_diag::apply | ( | Mat< typename T1::elem_type > & | out, | |
const Glue< T1, T2, glue_times_diag > & | X | |||
) | [inline, static, inherited] |
Definition at line 616 of file glue_times_meat.hpp.
References Glue< T1, T2, glue_type >::A, Mat< eT >::at(), Glue< T1, T2, glue_type >::B, Mat< eT >::colptr(), strip_diagmat< T1 >::do_diagmat, unwrap_check< T1 >::M, strip_diagmat< T1 >::M, Mat< eT >::n_cols, diagmat_proxy_check< T1 >::n_elem, Mat< eT >::n_rows, Mat< eT >::set_size(), and Mat< eT >::zeros().
00617 { 00618 arma_extra_debug_sigprint(); 00619 00620 typedef typename T1::elem_type eT; 00621 00622 const strip_diagmat<T1> S1(X.A); 00623 const strip_diagmat<T2> S2(X.B); 00624 00625 typedef typename strip_diagmat<T1>::stored_type T1_stripped; 00626 typedef typename strip_diagmat<T2>::stored_type T2_stripped; 00627 00628 if( (S1.do_diagmat == true) && (S2.do_diagmat == false) ) 00629 { 00630 const diagmat_proxy_check<T1_stripped> A(S1.M, out); 00631 00632 const unwrap_check<T2> tmp(X.B, out); 00633 const Mat<eT>& B = tmp.M; 00634 00635 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_rows, B.n_cols, "matrix multiply"); 00636 00637 out.set_size(A.n_elem, B.n_cols); 00638 00639 for(u32 col=0; col<B.n_cols; ++col) 00640 { 00641 eT* out_coldata = out.colptr(col); 00642 const eT* B_coldata = B.colptr(col); 00643 00644 for(u32 row=0; row<B.n_rows; ++row) 00645 { 00646 out_coldata[row] = A[row] * B_coldata[row]; 00647 } 00648 } 00649 } 00650 else 00651 if( (S1.do_diagmat == false) && (S2.do_diagmat == true) ) 00652 { 00653 const unwrap_check<T1> tmp(X.A, out); 00654 const Mat<eT>& A = tmp.M; 00655 00656 const diagmat_proxy_check<T2_stripped> B(S2.M, out); 00657 00658 arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_elem, B.n_elem, "matrix multiply"); 00659 00660 out.set_size(A.n_rows, B.n_elem); 00661 00662 for(u32 col=0; col<A.n_cols; ++col) 00663 { 00664 const eT val = B[col]; 00665 00666 eT* out_coldata = out.colptr(col); 00667 const eT* A_coldata = A.colptr(col); 00668 00669 for(u32 row=0; row<A.n_rows; ++row) 00670 { 00671 out_coldata[row] = A_coldata[row] * val; 00672 } 00673 } 00674 } 00675 else 00676 if( (S1.do_diagmat == true) && (S2.do_diagmat == true) ) 00677 { 00678 const diagmat_proxy_check<T1_stripped> A(S1.M, out); 00679 const diagmat_proxy_check<T2_stripped> B(S2.M, out); 00680 00681 arma_debug_assert_mul_size(A.n_elem, A.n_elem, B.n_elem, B.n_elem, "matrix multiply"); 00682 00683 out.zeros(A.n_elem, A.n_elem); 00684 00685 for(u32 i=0; i<A.n_elem; 
++i) 00686 { 00687 out.at(i,i) = A[i] * B[i]; 00688 } 00689 } 00690 }