Backend: Many bug fixes in dot() for better shape handling
This commit is contained in:
@@ -18,6 +18,8 @@ void test_row_wise_reduction(T epsilon, simple_vector_base<T> & cy, simple_matri
|
||||
simple_vector<T> bufy(M);
|
||||
simple_vector<T> bufx(N);
|
||||
|
||||
T alpha = 4.2, beta = 1.8;
|
||||
|
||||
ad::driver::CommandQueue queue = ad::driver::queues[y.context()][0];
|
||||
|
||||
T yi = 0, xi = 0;
|
||||
@@ -32,6 +34,7 @@ void test_row_wise_reduction(T epsilon, simple_vector_base<T> & cy, simple_matri
|
||||
ASSIGNMENT;\
|
||||
}\
|
||||
GPU_REDUCTION;\
|
||||
queue.synchronize();\
|
||||
ad::copy(RES, BUF.data());\
|
||||
if(diff(CRES, BUF, epsilon))\
|
||||
{\
|
||||
@@ -47,24 +50,24 @@ void test_row_wise_reduction(T epsilon, simple_vector_base<T> & cy, simple_matri
|
||||
cl_command_queue clqueue = (*queue.handle().cl)();
|
||||
|
||||
|
||||
TEST_OPERATION("GEMV(ROW, NoTrans)", M, N, yi+=cA(i,j)*cx[j], cy[i] = yi,
|
||||
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasRowMajor, clblasTrans, N, M, 1, CHANDLE(A), OFF(A), LD(A),
|
||||
CHANDLE(x), x.start()[0], x.stride()[0], 0, CHANDLE(y), y.start()[0], y.stride()[0],
|
||||
TEST_OPERATION("GEMV(ROW, NoTrans)", M, N, yi+=cA(i,j)*cx[j], cy[i] = alpha*yi + beta*cy[i],
|
||||
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasRowMajor, clblasTrans, N, M, alpha, CHANDLE(A), OFF(A), LD(A),
|
||||
CHANDLE(x), x.start()[0], x.stride()[0], beta, CHANDLE(y), y.start()[0], y.stride()[0],
|
||||
1, &clqueue, 0, NULL, NULL), y, bufy, cy);
|
||||
|
||||
TEST_OPERATION("GEMV(ROW, Trans)", N, M, xi+=cA(j,i)*cy[j], cx[i] = xi,
|
||||
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasRowMajor, clblasNoTrans, M, N, 1, CHANDLE(A), OFF(A), LD(A),
|
||||
CHANDLE(y), y.start()[0], y.stride()[0], 0, CHANDLE(x), x.start()[0], x.stride()[0],
|
||||
TEST_OPERATION("GEMV(ROW, Trans)", N, M, xi+=cA(j,i)*cy[j], cx[i] = alpha*xi + beta*cx[i],
|
||||
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasRowMajor, clblasNoTrans, N, M, alpha, CHANDLE(A), OFF(A), LD(A),
|
||||
CHANDLE(y), y.start()[0], y.stride()[0], beta, CHANDLE(x), x.start()[0], x.stride()[0],
|
||||
1, &clqueue, 0, NULL, NULL), x, bufx, cx);
|
||||
|
||||
TEST_OPERATION("GEMV(COL, NoTrans)", M, N, yi+=cA(i,j)*cx[j], cy[i] = yi,
|
||||
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasColumnMajor, clblasNoTrans, M, N, 1, CHANDLE(A), OFF(A), LD(A),
|
||||
TEST_OPERATION("GEMV(COL, NoTrans)", M, N, yi+=cA(i,j)*cx[j], cy[i] = alpha*yi + beta*cy[i],
|
||||
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasColumnMajor, clblasNoTrans, M, N, alpha, CHANDLE(A), OFF(A), LD(A),
|
||||
CHANDLE(x), x.start()[0], x.stride()[0], 0, CHANDLE(y), y.start()[0], y.stride()[0],
|
||||
1, &clqueue, 0, NULL, NULL), y, bufy, cy);
|
||||
|
||||
TEST_OPERATION("GEMV(COL, Trans)", N, M, xi+=cA(j,i)*cy[j], cx[i] = xi,
|
||||
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasColumnMajor, clblasTrans, N, M, 1, CHANDLE(A), OFF(A), LD(A),
|
||||
CHANDLE(y), y.start()[0], y.stride()[0], 0, CHANDLE(x), x.start()[0], x.stride()[0],
|
||||
TEST_OPERATION("GEMV(COL, Trans)", N, M, xi+=cA(j,i)*cy[j], cx[i] = alpha*xi + beta*cx[i],
|
||||
BLAS<T>::F(clblasSgemv, clblasDgemv)(clblasColumnMajor, clblasTrans, M, N, alpha, CHANDLE(A), OFF(A), LD(A),
|
||||
CHANDLE(y), y.start()[0], y.stride()[0], beta, CHANDLE(x), x.start()[0], x.stride()[0],
|
||||
1, &clqueue, 0, NULL, NULL), x, bufx, cx);
|
||||
}
|
||||
else
|
||||
@@ -102,6 +105,7 @@ void test_impl(T epsilon, ad::driver::Context const & ctx)
|
||||
|
||||
int main()
|
||||
{
|
||||
clblasSetup();
|
||||
auto data = ad::driver::queues.contexts();
|
||||
for(const auto & elem : data)
|
||||
{
|
||||
@@ -114,5 +118,6 @@ int main()
|
||||
test_impl<double>(1e-9, elem.first);
|
||||
std::cout << "---" << std::endl;
|
||||
}
|
||||
clblasTeardown();
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
Reference in New Issue
Block a user