[intermediate representation] added ternary_inst
This commit is contained in:
@@ -60,11 +60,27 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
|
|||||||
int1 checkc1[TN];\
|
int1 checkc1[TN];\
|
||||||
int1 checkc[TM, TN];\
|
int1 checkc[TM, TN];\
|
||||||
for(k = K; k > 0; k = k - TK){\
|
for(k = K; k > 0; k = k - TK){\
|
||||||
|
int1 checka[TM, TK] = (k > bound);\
|
||||||
|
int1 checkb[TN, TK] = (k > bound);\
|
||||||
|
int1 checka0[TM];\
|
||||||
|
int1 checka1[TK];\
|
||||||
|
int1 checkb0[TN];\
|
||||||
|
int1 checkb1[TK];\
|
||||||
C = dot(a, b, C);\
|
C = dot(a, b, C);\
|
||||||
pa = pa + TK*M;\
|
pa = pa + TK*M;\
|
||||||
pb = pb + TK*K;\
|
pb = pb + TK*K;\
|
||||||
a = *pa;\
|
@checka a = *pa;\
|
||||||
b = *pb;\
|
@checkb b = *pb;\
|
||||||
|
if(k > bound)\
|
||||||
|
continue;\
|
||||||
|
checka0 = rxa < M;\
|
||||||
|
checka1 = rka < k;\
|
||||||
|
checkb0 = ryb < N;\
|
||||||
|
checkb1 = rkb < k;\
|
||||||
|
checka = checka0[:, newaxis] && checka1[newaxis, :];\
|
||||||
|
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
|
||||||
|
@checka a = *pa;\
|
||||||
|
@checkb b = *pb;\
|
||||||
}\
|
}\
|
||||||
checkc0 = rxc < M;\
|
checkc0 = rxc < M;\
|
||||||
checkc1 = ryc < N;\
|
checkc1 = ryc < N;\
|
||||||
@@ -243,14 +259,13 @@ int main() {
|
|||||||
|
|
||||||
|
|
||||||
// run passes
|
// run passes
|
||||||
triton::ir::print(module, std::cout);
|
|
||||||
exit(EXIT_FAILURE);
|
|
||||||
buffer_info.run(module);
|
buffer_info.run(module);
|
||||||
shared.run(module);
|
shared.run(module);
|
||||||
liveness.run(module);
|
liveness.run(module);
|
||||||
allocation.run();
|
allocation.run();
|
||||||
barriers.run(module);
|
barriers.run(module);
|
||||||
vectorize.run(module);
|
vectorize.run(module);
|
||||||
|
triton::ir::print(module, std::cout);
|
||||||
selection.run(module, llvm_module);
|
selection.run(module, llvm_module);
|
||||||
|
|
||||||
// llvm source
|
// llvm source
|
||||||
@@ -260,7 +275,7 @@ int main() {
|
|||||||
manager.run(llvm_module);
|
manager.run(llvm_module);
|
||||||
|
|
||||||
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
|
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
|
||||||
std::cout << src << std::endl;
|
// std::cout << src << std::endl;
|
||||||
|
|
||||||
// compile machine code
|
// compile machine code
|
||||||
CUdevice cu_device;
|
CUdevice cu_device;
|
||||||
@@ -308,6 +323,9 @@ int main() {
|
|||||||
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));
|
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));
|
||||||
simple_gemm(rc, a, b, M, N, K);
|
simple_gemm(rc, a, b, M, N, K);
|
||||||
for(size_t i = 0; i < M*N; i++)
|
for(size_t i = 0; i < M*N; i++)
|
||||||
if(std::abs(c[i] - rc[i])/std::max(c[i], rc[i]) > 1e-4)
|
if(std::abs(c[i] - rc[i])/std::max(c[i], rc[i]) > 1e-4){
|
||||||
std::cout << i << " " << c[i] << " " << rc[i] << std::endl;
|
std::cout << i << " " << c[i] << " " << rc[i] << std::endl;
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
std::cout << "Pass!" << std::endl;
|
||||||
}
|
}
|
||||||
|
@@ -226,7 +226,7 @@ logical_or_expression
|
|||||||
/* Conditional */
|
/* Conditional */
|
||||||
conditional_expression
|
conditional_expression
|
||||||
: logical_or_expression { $$ = $1; }
|
: logical_or_expression { $$ = $1; }
|
||||||
| logical_or_expression '?' conditional_expression ':' conditional_expression { $$ = new conditional_expression($1, $2, $3); }
|
| logical_or_expression '?' conditional_expression ':' conditional_expression { $$ = new conditional_expression($1, $3, $5); }
|
||||||
;
|
;
|
||||||
|
|
||||||
/* Assignment */
|
/* Assignment */
|
||||||
|
@@ -50,6 +50,8 @@ public:
|
|||||||
value* create_br(basic_block *dest);
|
value* create_br(basic_block *dest);
|
||||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||||
value* create_ret_void();
|
value* create_ret_void();
|
||||||
|
// Tile-level control flow
|
||||||
|
value *create_ternary(value *cond, value *true_value, value *false_value, const std::string &name = "");
|
||||||
// Cast instructions
|
// Cast instructions
|
||||||
value *create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name = "");
|
value *create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name = "");
|
||||||
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||||
|
@@ -289,6 +289,23 @@ private:
|
|||||||
public:
|
public:
|
||||||
basic_block *get_dest() { return (basic_block*)get_operand(0); }
|
basic_block *get_dest() { return (basic_block*)get_operand(0); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// ternary
|
||||||
|
class ternary_inst: public instruction {
|
||||||
|
private:
|
||||||
|
std::string repr_impl() const { return "ternary"; }
|
||||||
|
ternary_inst(value *cond, value *true_value, value *false_value,
|
||||||
|
const std::string &name, instruction *next);
|
||||||
|
|
||||||
|
public:
|
||||||
|
value *get_cond() { return get_operand(0); }
|
||||||
|
value *get_true_value() { return get_operand(1); }
|
||||||
|
value *get_false_value() { return get_operand(2); }
|
||||||
|
static ternary_inst* create(value *cond, value *true_value, value *false_value,
|
||||||
|
const std::string &name = "", instruction *next = nullptr);
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// getelementptr_inst classes
|
// getelementptr_inst classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -681,7 +681,6 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
|||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
Instruction *i = (Instruction*)llvm_value(src, builder);
|
Instruction *i = (Instruction*)llvm_value(src, builder);
|
||||||
std::cout << "instruction: " << src->get_name() << " " << src->has_tile_result_or_op() << std::endl;
|
|
||||||
vmap_[src] = i;
|
vmap_[src] = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -797,7 +796,6 @@ void selection::run(ir::module &src, Module &dst){
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
std::cout << phi->get_name() << " " << inc_val->get_name() << std::endl;
|
|
||||||
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
|
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
|
||||||
Value *llvm_inc_val = vmap_.at(inc_val);
|
Value *llvm_inc_val = vmap_.at(inc_val);
|
||||||
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
|
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
|
||||||
|
@@ -67,6 +67,16 @@ value *builder::create_ret_void() {
|
|||||||
return insert(return_inst::create(ctx_));
|
return insert(return_inst::create(ctx_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// tile-level control-flow instructions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
value *builder::create_ternary(value *cond, value *true_value, value *false_value, const std::string &name){
|
||||||
|
return insert(ternary_inst::create(cond, true_value, false_value, name));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// cast instructions
|
// cast instructions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -311,6 +311,21 @@ cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, v
|
|||||||
set_operand(2, cond);
|
set_operand(2, cond);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ternary_inst
|
||||||
|
ternary_inst::ternary_inst(value *cond, value *true_value, value *false_value, const std::string &name, instruction *next)
|
||||||
|
: instruction(true_value->get_type(), 3) {
|
||||||
|
assert(true_value->get_type() == false_value->get_type());
|
||||||
|
set_operand(0, cond);
|
||||||
|
set_operand(1, true_value);
|
||||||
|
set_operand(2, false_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
ternary_inst *ternary_inst::create(value *cond, value *true_value, value *false_value,
|
||||||
|
const std::string &name, instruction *next) {
|
||||||
|
return new ternary_inst(cond, true_value, false_value, name, next);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// getelementptr_inst classes
|
// getelementptr_inst classes
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@@ -155,7 +155,8 @@ ir::type *module::get_type(const std::string &name) {
|
|||||||
void module::seal_block(ir::basic_block *block){
|
void module::seal_block(ir::basic_block *block){
|
||||||
for(auto &x: incomplete_phis_[block]){
|
for(auto &x: incomplete_phis_[block]){
|
||||||
add_phi_operands(x.first, x.second);
|
add_phi_operands(x.first, x.second);
|
||||||
set_value(x.first, try_remove_trivial_phis(x.second));
|
if(get_value(x.first) == x.second)
|
||||||
|
set_value(x.first, try_remove_trivial_phis(x.second));
|
||||||
}
|
}
|
||||||
sealed_blocks_.insert(block);
|
sealed_blocks_.insert(block);
|
||||||
incomplete_phis_[block].clear();
|
incomplete_phis_[block].clear();
|
||||||
|
Reference in New Issue
Block a user