// Matrix multiplication over Cforall dimension-typed arrays, plus a small
// self-test driver that is compiled only when RUNIT is defined.

#include <fstream.hfa>                  // sout, nl, nonl, sep, wd
#include <collections/array.hfa>        // array( T, dims... )
// NOTE(review): the include targets were lost from this copy of the file; the
// array header path is assumed (older libcfa releases shipped it as
// <containers/array.hfa>) -- confirm against the installed libcfa.

extern "C" {	// don't mangle the function name `mul`

// res = lhs * rhs, where lhs is M x P, rhs is P x N and res is M x N.
// The dimension parameters M, N and P are bound from the argument types.
// Every element of res is overwritten (set to 0, then accumulated into in
// place), so res must not alias lhs or rhs.
forall( [M], [N], [P] )
void mul( array(float, M, N) & res, array(float, M, P) & lhs, array(float, P, N) & rhs ) {
	for ( i; M ) for ( j; N ) {
		res[i, j] = 0.0;
		// inner product of row i of lhs with column j of rhs
		for ( k; P ) res[i, j] += lhs[i, k] * rhs[k, j];
	}
}

} // extern "C"

#ifdef RUNIT

// Set every element of mat to 0.
forall( [R], [C] )
static void zero( array(float, R, C) & mat ) {
	for ( i; R ) for ( j; C ) mat[i, j] = 0.0;
}

// Fill mat with easily recognised values: element (i, j) holds
// (i + 1) + 0.1 * (j + 1), i.e. it reads as "row.column" (1.1, 1.2, ...)
// while the column count stays below 10.
forall( [R], [C] )
static void fill( array(float, R, C) & mat ) {
	for ( i; R ) for ( j; C ) mat[i, j] = 1.0 * (i + 1) + 0.1 * (j + 1);
}

// Print mat to sout, one row per line; each element is written with
// width 2 followed by the stream separator.
forall( [R], [C] )
static void print( array(float, R, C) & mat ) {
	for ( i; R ) {
		for ( j; C ) sout | wd(2, mat[i, j]) | sep | nonl;
		sout | nl;
	}
}

#define EXPSZ_M 2
#define EXPSZ_N 3
#define EXPSZ_P 4

// Multiply a filled 2x4 matrix by a filled 4x3 matrix and print
// "lhs * rhs = res".
int main() {
	array( float, EXPSZ_M, EXPSZ_N ) res;
	zero( res );
	array( float, EXPSZ_M, EXPSZ_P ) lhs;
	fill( lhs );
	array( float, EXPSZ_P, EXPSZ_N ) rhs;
	fill( rhs );

	mul( res, lhs, rhs );

	print( lhs );
	sout | "*";
	print( rhs );
	sout | "=";
	print( res );
}

#endif