This commit is contained in:
Philippe Tillet
2015-07-09 11:40:26 -04:00
parent 4e25e20206
commit a676b15448
3 changed files with 34 additions and 24 deletions

View File

@@ -201,20 +201,31 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
unsigned int npA = p_.mL/(A_trans_=='N'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1); unsigned int npA = p_.mL/(A_trans_=='N'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1); unsigned int npB = p_.nL/(B_trans_=='T'?p_.local_fetch_0*p_.simd_width:p_.local_fetch_1);
if (A_trans_=='N')
stream << "A += (gidx*" << p_.mL/p_.simd_width << ")" << ASTRIDE1 << " + idyT*Ald + offz*Ald;" << std::endl;
else
stream << "A += idxT" << ASTRIDE1 << " + gidx*" << p_.mL/p_.simd_width << "*Ald + offz;" << std::endl;
if(B_trans_=='T')
stream << "B += (gidy*" << p_.nL/p_.simd_width << ")" << BSTRIDE1 << " + idyT*Bld + offz*Bld;" << std::endl;
else
stream << "B += idxT" << BSTRIDE1 << " + gidy*" << p_.nL/p_.simd_width << "*Bld + offz;" << std::endl;
stream << "__global " << vdtype << "* Ai[" << npA << "];" << std::endl; stream << "__global " << vdtype << "* Ai[" << npA << "];" << std::endl;
for(unsigned int i = 0 ; i < npA ; ++i) for(unsigned int i = 0 ; i < npA ; ++i)
if (A_trans_=='N') if (A_trans_=='N')
stream << "Ai[" << i << "] = A + (gidx*" << p_.mL/p_.simd_width << ")" << ASTRIDE1 << " + idyT*Ald + offz*Ald;" << std::endl; stream << "Ai[" << i << "] = A;" << std::endl;
else else
stream << "Ai[" << i << "] = A + idxT" << ASTRIDE1 << " + gidx*" << p_.mL/p_.simd_width << "*Ald + offz;" << std::endl; stream << "Ai[" << i << "] = A;" << std::endl;
stream << "__global " << vdtype << "* Bi[" << npB << "];" << std::endl; stream << "__global " << vdtype << "* Bi[" << npB << "];" << std::endl;
for(unsigned int i = 0 ; i < npB ; ++i) for(unsigned int i = 0 ; i < npB ; ++i)
if(B_trans_=='T') if(B_trans_=='T')
stream << "Bi[" << i << "] = B + (gidy*" << p_.nL/p_.simd_width << ")" << BSTRIDE1 << " + idyT*Bld + offz*Bld;" << std::endl; stream << "Bi[" << i << "] = B;" << std::endl;
else else
stream << "Bi[" << i << "] = B + idxT" << BSTRIDE1 << " + gidy*" << p_.nL/p_.simd_width << "*Bld + offz;" << std::endl; stream << "Bi[" << i << "] = B;" << std::endl;
switch (p_.A_fetching_policy) switch (p_.A_fetching_policy)
{ {
@@ -521,6 +532,7 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
if(M==0 || N==0 || K==0) if(M==0 || N==0 || K==0)
return; return;
char gemm_name[32] = {"gemm"}; char gemm_name[32] = {"gemm"};
char reduce_name[32] = {"reduce"}; char reduce_name[32] = {"reduce"};
strcat(gemm_name, suffix); strcat(gemm_name, suffix);
@@ -555,15 +567,14 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
helper.set_arguments(alpha.dtype(), alpha.values()); helper.set_arguments(alpha.dtype(), alpha.values());
gemm.setArg(current_arg++, A.data()); gemm.setArg(current_arg++, A.data());
gemm.setSizeArg(current_arg++, A.ld()*A.stride()[1]/p_.simd_width); gemm.setSizeArg(current_arg++, A.ld()*A.stride()[1]/p_.simd_width);
gemm.setSizeArg(current_arg++, A.start()[0] + A.start()[1]*A.ld()/p_.simd_width); gemm.setSizeArg(current_arg++, (A.start()[0] + A.start()[1]*A.ld())/p_.simd_width);
gemm.setSizeArg(current_arg++, A.stride()[0]); gemm.setSizeArg(current_arg++, A.stride()[0]);
gemm.setArg(current_arg++, B.data()); gemm.setArg(current_arg++, B.data());
gemm.setSizeArg(current_arg++, B.ld()*B.stride()[1]/p_.simd_width); gemm.setSizeArg(current_arg++, B.ld()*B.stride()[1]/p_.simd_width);
gemm.setSizeArg(current_arg++, B.start()[0] + B.start()[1]*B.ld()/p_.simd_width); gemm.setSizeArg(current_arg++, (B.start()[0] + B.start()[1]*B.ld())/p_.simd_width);
gemm.setSizeArg(current_arg++, B.stride()[0]); gemm.setSizeArg(current_arg++, B.stride()[0]);
helper.set_arguments(beta.dtype(), beta.values()); helper.set_arguments(beta.dtype(), beta.values());
options.enqueue(program.context(), gemm, global, local); options.enqueue(program.context(), gemm, global, local);
@@ -674,20 +685,19 @@ mproduct_parameters::mproduct_parameters(unsigned int simd_width
execution_options_type const & options = ctr.execution_options(); execution_options_type const & options = ctr.execution_options();
// if (ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1 if (ldstrideA> 1 || ldstrideB > 1 || ldstrideC > 1
// || (p_.simd_width>1 && (ldstartA % p_.simd_width > 0 || ldstartB % p_.simd_width > 0 || pA->ld()%p_.simd_width > 0 || pB->ld()%p_.simd_width > 0))) || (p_.simd_width>1 && (ldstartA % p_.simd_width > 0 || ldstartB % p_.simd_width > 0 || pA->ld()%p_.simd_width > 0 || pB->ld()%p_.simd_width > 0)))
// { {
// fallback.enqueue_block(queue, M, N, K, create_slice(*pA, 0, M, 0, K, swap_A), create_slice(*pB, 0, K, 0, N, swap_B), fallback.enqueue_block(queue, M, N, K, *pA, *pB, *pC, alpha, beta, program, "fallback", options);
// create_slice(*pC, 0, M, 0, N, false), alpha, beta, program, "fallback", options); }
// return; else
// } {
int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth; int_t lK = K / (p_.kL*p_.depth) * p_.kL*p_.depth;
value_scalar _1(1, dtype); value_scalar _1(1, dtype);
enqueue_block(queue, M, N, lK, create_slice(*pA, 0, M, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, beta, program, suffix, options); enqueue_block(queue, M, N, lK, create_slice(*pA, 0, M, 0, lK, swap_A), create_slice(*pB, 0, lK, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, beta, program, suffix, options);
fallback.enqueue_block(queue, M, N, K - lK, create_slice(*pA, 0, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, _1, program, "fallback", options); fallback.enqueue_block(queue, M, N, K - lK, create_slice(*pA, 0, M, lK, K, swap_A), create_slice(*pB, lK, K, 0, N, swap_B), create_slice(*pC, 0, M, 0, N, false), alpha, _1, program, "fallback", options);
} }
}
// //
mproduct_nn::mproduct_nn(unsigned int simd mproduct_nn::mproduct_nn(unsigned int simd

View File

@@ -258,7 +258,7 @@ std::map<std::pair<expression_type, numeric_type>, tools::shared_ptr<base> > ini
res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_cols(1, 8, 8, 64, 8, FETCH_FROM_GLOBAL_STRIDED)); res[std::make_pair(COL_WISE_REDUCTION_TYPE, DTYPE)] = ptr_t(new mreduction_cols(1, 8, 8, 64, 8, FETCH_FROM_GLOBAL_STRIDED));
res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(MATRIX_PRODUCT_NN_TYPE, DTYPE)] = ptr_t(new mproduct_nn(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(MATRIX_PRODUCT_TN_TYPE, DTYPE)] = ptr_t(new mproduct_tn(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new mproduct_nt(4, 8, 16, 8, 1, 8, 2, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(MATRIX_PRODUCT_NT_TYPE, DTYPE)] = ptr_t(new mproduct_nt(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new mproduct_tt(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true)); res[std::make_pair(MATRIX_PRODUCT_TT_TYPE, DTYPE)] = ptr_t(new mproduct_tt(1, 8, 16, 8, 1, 8, 1, 8, FETCH_FROM_LOCAL, FETCH_FROM_LOCAL, 8, 8, true));
} }
return res; return res;

View File

@@ -106,7 +106,7 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
INIT_MATRIX(M, SUBM, 5, 1, N, SUBN, 7, 1, cC, C, ctx); INIT_MATRIX(M, SUBM, 5, 1, N, SUBN, 7, 1, cC, C, ctx);
INIT_MATRIX(M, SUBM, 8, 1, K, SUBK, 4, 1, cA, A, ctx); INIT_MATRIX(M, SUBM, 8, 1, K, SUBK, 4, 1, cA, A, ctx);
INIT_MATRIX(K, SUBK, 9, 1, N, SUBN, 6, 1, cB, B, ctx); INIT_MATRIX(K, SUBK, 9, 1, N, SUBN, 6, 1, cB, B, ctx);
// test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, clBLAS, "BLAS, FULL"); test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, clBLAS, "BLAS, FULL");
test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, clBLAS, "BLAS, SUB"); test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, clBLAS, "BLAS, SUB");
} }
@@ -114,8 +114,8 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
INIT_MATRIX(M, SUBM, 5, 2, N, SUBN, 7, 3, cC, C, ctx); INIT_MATRIX(M, SUBM, 5, 2, N, SUBN, 7, 3, cC, C, ctx);
INIT_MATRIX(M, SUBM, 8, 2, K, SUBK, 4, 3, cA, A, ctx); INIT_MATRIX(M, SUBM, 8, 2, K, SUBK, 4, 3, cA, A, ctx);
INIT_MATRIX(K, SUBK, 9, 4, N, SUBN, 6, 2, cB, B, ctx); INIT_MATRIX(K, SUBK, 9, 4, N, SUBN, 6, 2, cB, B, ctx);
// test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP, "C++, FULL"); test_impl(epsilon, cC_full, cA_full, cB_full, C_full, A_full, AT_full, B_full, BT_full, CPP, "C++, FULL");
// test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP, "C++, SUB"); test_impl(epsilon, cC_slice, cA_slice, cB_slice, C_slice, A_slice, AT_slice, B_slice, BT_slice, CPP, "C++, SUB");
} }
} }