[code generation] added masked loads

This commit is contained in:
Philippe Tillet
2019-02-15 11:14:50 -05:00
parent 896e856b07
commit 5f5959dc6e
11 changed files with 128 additions and 54 deletions

View File

@@ -36,7 +36,7 @@ extern translation_unit *ast_root;
const char src[] =
"\
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K, int32 bound){\
int32 rxa[16] = get_global_range[16](0);\
int32 ryb[16] = get_global_range[16](1);\
int32 rka[8] = 0 ... 8;\
@@ -50,15 +50,17 @@ void test(fp32 *a, fp32 *b, fp32 *c, int32 M, int32 N, int32 K){\
fp32* pc[16, 16] = c + rxc[:, newaxis] + ryc[newaxis, :]*M;\
fp32 a[16, 8] = *pa;\
fp32 b[16, 8] = *pb;\
int1 checkc0[16] = (rxc < M);\
int1 checkc1[16] = (ryc < N);\
int1 checkc0[16] = rxc < M;\
int1 checkc1[16] = ryc < N;\
int1 checkc[16, 16] = checkc0[:, newaxis] && checkc1[newaxis, :];\
for(k = K; k > 0; k = k - 8){\
int1 sanitya[16, 8] = (k >= bound);\
int1 sanityb[16, 8] = (k >= bound);\
C = dot(a, b, C);\
pa = pa + 8*M;\
pb = pb + 8*K;\
a = *pa;\
b = *pb;\
@sanitya a = *pa;\
@sanityb b = *pb;\
}\
@checkc *pc = C;\
}\
@@ -201,6 +203,8 @@ int main() {
for(auto &e: x.second)
std::cout << e << std::endl;
}
if(errors.size())
exit(EXIT_FAILURE);
// run passes
shared.run(module);
@@ -213,7 +217,7 @@ int main() {
// llvm source
llvm::legacy::PassManager manager;
// manager.add(llvm::createPrintModulePass(llvm::outs()));
manager.add(llvm::createPrintModulePass(llvm::outs()));
manager.add(llvm::createVerifierPass(true));
manager.run(llvm_module);
@@ -233,6 +237,7 @@ int main() {
// Allocate buffers
typedef float numeric_t;
size_t M = 128, N = 128, K = 128;
size_t bound = 8;
std::vector<numeric_t> c(M*N);
std::vector<numeric_t> rc(M*N);
std::vector<numeric_t> a(M*K);
@@ -252,13 +257,13 @@ int main() {
checkCudaErrors(cuMemcpyHtoD(d_b, b.data(), sizeof(numeric_t) * b.size()));
checkCudaErrors(cuMemcpyHtoD(d_c, c.data(), sizeof(numeric_t) * c.size()));
// Launch kernel
void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K};
void *args[] = { &d_a, &d_b, &d_c, &M, &N, &K, &bound};
int num_regs;
cuFuncGetAttribute(&num_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, cu_kernel);
unsigned TM = 16;
unsigned TN = 16;
unsigned nthreads = 32;
checkCudaErrors(cuLaunchKernel(cu_kernel, M/TM, N/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
checkCudaErrors(cuLaunchKernel(cu_kernel, (M + TM - 1)/TM, (N + TN - 1)/TN, 1, nthreads, 1, 1, 0, cu_stream, args, NULL));
checkCudaErrors(cuStreamSynchronize(cu_stream));
// Write back
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));