Backend: A lot of bugfixes in dot() for handling shapes better

This commit is contained in:
Philippe Tillet
2015-06-30 17:55:57 -04:00
parent e7cabf65ac
commit cf2dba43ef
12 changed files with 108 additions and 73 deletions

View File

@@ -18,6 +18,8 @@ void test_row_wise_reduction(T epsilon, simple_vector_base<T> & cy, simple_matri
simple_vector<T> bufy(M);
simple_vector<T> bufx(N);
T alpha = 4.2, beta = 1.8;
ad::driver::CommandQueue queue = ad::driver::queues[y.context()][0];
T yi = 0, xi = 0;
@@ -32,6 +34,7 @@ void test_row_wise_reduction(T epsilon, simple_vector_base<T> & cy, simple_matri
ASSIGNMENT;\
}\
GPU_REDUCTION;\
queue.synchronize();\
ad::copy(RES, BUF.data());\
if(diff(CRES, BUF, epsilon))\
{\
@@ -47,24 +50,24 @@ void test_row_wise_reduction(T epsilon, simple_vector_base<T> & cy, simple_matri
cl_command_queue clqueue = (*queue.handle().cl)();
TEST_OPERATION("GEMV(ROW, NoTrans)", M, N, yi+=cA(i,j)*cx[j], cy[i] = yi,
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasRowMajor, clblasTrans, N, M, 1, CHANDLE(A), OFF(A), LD(A),
CHANDLE(x), x.start()[0], x.stride()[0], 0, CHANDLE(y), y.start()[0], y.stride()[0],
TEST_OPERATION("GEMV(ROW, NoTrans)", M, N, yi+=cA(i,j)*cx[j], cy[i] = alpha*yi + beta*cy[i],
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasRowMajor, clblasTrans, N, M, alpha, CHANDLE(A), OFF(A), LD(A),
CHANDLE(x), x.start()[0], x.stride()[0], beta, CHANDLE(y), y.start()[0], y.stride()[0],
1, &clqueue, 0, NULL, NULL), y, bufy, cy);
TEST_OPERATION("GEMV(ROW, Trans)", N, M, xi+=cA(j,i)*cy[j], cx[i] = xi,
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasRowMajor, clblasNoTrans, M, N, 1, CHANDLE(A), OFF(A), LD(A),
CHANDLE(y), y.start()[0], y.stride()[0], 0, CHANDLE(x), x.start()[0], x.stride()[0],
TEST_OPERATION("GEMV(ROW, Trans)", N, M, xi+=cA(j,i)*cy[j], cx[i] = alpha*xi + beta*cx[i],
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasRowMajor, clblasNoTrans, N, M, alpha, CHANDLE(A), OFF(A), LD(A),
CHANDLE(y), y.start()[0], y.stride()[0], beta, CHANDLE(x), x.start()[0], x.stride()[0],
1, &clqueue, 0, NULL, NULL), x, bufx, cx);
TEST_OPERATION("GEMV(COL, NoTrans)", M, N, yi+=cA(i,j)*cx[j], cy[i] = yi,
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasColumnMajor, clblasNoTrans, M, N, 1, CHANDLE(A), OFF(A), LD(A),
TEST_OPERATION("GEMV(COL, NoTrans)", M, N, yi+=cA(i,j)*cx[j], cy[i] = alpha*yi + beta*cy[i],
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasColumnMajor, clblasNoTrans, M, N, alpha, CHANDLE(A), OFF(A), LD(A),
CHANDLE(x), x.start()[0], x.stride()[0], 0, CHANDLE(y), y.start()[0], y.stride()[0],
1, &clqueue, 0, NULL, NULL), y, bufy, cy);
TEST_OPERATION("GEMV(COL, Trans)", N, M, xi+=cA(j,i)*cy[j], cx[i] = xi,
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasColumnMajor, clblasTrans, N, M, 1, CHANDLE(A), OFF(A), LD(A),
CHANDLE(y), y.start()[0], y.stride()[0], 0, CHANDLE(x), x.start()[0], x.stride()[0],
TEST_OPERATION("GEMV(COL, Trans)", N, M, xi+=cA(j,i)*cy[j], cx[i] = alpha*xi + beta*cx[i],
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasColumnMajor, clblasTrans, M, N, alpha, CHANDLE(A), OFF(A), LD(A),
CHANDLE(y), y.start()[0], y.stride()[0], beta, CHANDLE(x), x.start()[0], x.stride()[0],
1, &clqueue, 0, NULL, NULL), x, bufx, cx);
}
else
@@ -102,6 +105,7 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
int main()
{
clblasSetup();
auto data = ad::driver::queues.contexts();
for(const auto & elem : data)
{
@@ -114,5 +118,6 @@ int main()
test_impl<double>(1e-9, elem.first);
std::cout << "---" << std::endl;
}
clblasTeardown();
return EXIT_SUCCESS;
}