[intermediate representation] added ternary_inst
This commit is contained in:
@@ -60,11 +60,27 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
|
||||
int1 checkc1[TN];\
|
||||
int1 checkc[TM, TN];\
|
||||
for(k = K; k > 0; k = k - TK){\
|
||||
int1 checka[TM, TK] = (k > bound);\
|
||||
int1 checkb[TN, TK] = (k > bound);\
|
||||
int1 checka0[TM];\
|
||||
int1 checka1[TK];\
|
||||
int1 checkb0[TN];\
|
||||
int1 checkb1[TK];\
|
||||
C = dot(a, b, C);\
|
||||
pa = pa + TK*M;\
|
||||
pb = pb + TK*K;\
|
||||
a = *pa;\
|
||||
b = *pb;\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
if(k > bound)\
|
||||
continue;\
|
||||
checka0 = rxa < M;\
|
||||
checka1 = rka < k;\
|
||||
checkb0 = ryb < N;\
|
||||
checkb1 = rkb < k;\
|
||||
checka = checka0[:, newaxis] && checka1[newaxis, :];\
|
||||
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
}\
|
||||
checkc0 = rxc < M;\
|
||||
checkc1 = ryc < N;\
|
||||
@@ -243,14 +259,13 @@ int main() {
|
||||
|
||||
|
||||
// run passes
|
||||
triton::ir::print(module, std::cout);
|
||||
exit(EXIT_FAILURE);
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
vectorize.run(module);
|
||||
triton::ir::print(module, std::cout);
|
||||
selection.run(module, llvm_module);
|
||||
|
||||
// llvm source
|
||||
@@ -260,7 +275,7 @@ int main() {
|
||||
manager.run(llvm_module);
|
||||
|
||||
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
|
||||
std::cout << src << std::endl;
|
||||
// std::cout << src << std::endl;
|
||||
|
||||
// compile machine code
|
||||
CUdevice cu_device;
|
||||
@@ -308,6 +323,9 @@ int main() {
|
||||
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));
|
||||
simple_gemm(rc, a, b, M, N, K);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
if(std::abs(c[i] - rc[i])/std::max(c[i], rc[i]) > 1e-4)
|
||||
if(std::abs(c[i] - rc[i])/std::max(c[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << c[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
Reference in New Issue
Block a user