[code generation] added masked loads
This commit is contained in:
@@ -36,7 +36,7 @@ extern translation_unit *ast_root;
|
||||
|
||||
const char src[] =
|
||||
"\
|
||||
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
|
||||
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
|
||||
int32 rxa[16] = get_global_range[16](0);\
|
||||
int32 ryb[16] = get_global_range[16](1);\
|
||||
int32 rka[8] = 0 ... 8;\
|
||||
@@ -50,15 +50,17 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
|
||||
fp32* pc[16, 16] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
|
||||
fp32 a[16, 8] = *pa;\
|
||||
fp32 b[16, 8] = *pb;\
|
||||
int1 checkc0[16] = (rxc < M);\
|
||||
int1 checkc1[16] = (ryc < N);\
|
||||
int1 checkc0[16] = rxc < M;\
|
||||
int1 checkc1[16] = ryc < N;\
|
||||
int1 checkc[16, 16] = checkc0[:, newaxis] && checkc1[newaxis, :];\
|
||||
for(k = K; k > 0; k = k - 8){\
|
||||
int1 sanitya[16, 8] = (k >= bound);\
|
||||
int1 sanityb[16, 8] = (k >= bound);\
|
||||
C = dot(a, b, C);\
|
||||
pa = pa + 8*M;\
|
||||
pb = pb + 8*K;\
|
||||
a = *pa;\
|
||||
b = *pb;\
|
||||
@sanitya a = *pa;\
|
||||
@sanityb b = *pb;\
|
||||
}\
|
||||
@checkc *pc = C;\
|
||||
}\
|
||||
@@ -201,6 +203,8 @@ int main() {
|
||||
for(auto &e: x.second)
|
||||
std::cout << e << std::endl;
|
||||
}
|
||||
if(errors.size())
|
||||
exit(EXIT_FAILURE);
|
||||
|
||||
// run passes
|
||||
shared.run(module);
|
||||
@@ -213,7 +217,7 @@ int main() {
|
||||
|
||||
// llvm source
|
||||
llvm::legacy::PassManager manager;
|
||||
// manager.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
manager.add(llvm::createPrintModulePass(llvm::outs()));
|
||||
manager.add(llvm::createVerifierPass(true));
|
||||
manager.run(llvm_module);
|
||||
|
||||
@@ -233,6 +237,7 @@ int main() {
|
||||
// Allocate buffers
|
||||
typedef float numeric_t;
|
||||
size_t M = 128, N = 128, K = 128;
|
||||
size_t bound = 8;
|
||||
std::vector<numeric_t> c(M*N);
|
||||
std::vector<numeric_t> rc(M*N);
|
||||
std::vector<numeric_t> a(M*K);
|
||||
@@ -252,13 +257,13 @@ int main() {
|
||||
checkCudaErrors(cuMemcpyHtoD(d_b, b.data(), sizeof(numeric_t) * b.size()));
|
||||
checkCudaErrors(cuMemcpyHtoD(d_c, c.data(), sizeof(numeric_t) * c.size()));
|
||||
// Launch kernel
|
||||
void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K};
|
||||
void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K, &bound};
|
||||
int num_regs;
|
||||
cuFuncGetAttribute(&num_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, cu_kernel);
|
||||
unsigned TM = 16;
|
||||
unsigned TN = 16;
|
||||
unsigned nthreads = 32;
|
||||
checkCudaErrors(cuLaunchKernel(cu_kernel, M/TM, N/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
|
||||
checkCudaErrors(cuLaunchKernel(cu_kernel, (M + TM - 1)/TM, (N + TN - 1)/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
|
||||
checkCudaErrors(cuStreamSynchronize(cu_stream));
|
||||
// Write back
|
||||
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));
|
||||
|
Reference in New Issue
Block a user