[intermediate representation] added ternary_inst

This commit is contained in:
Philippe Tillet
2019-02-26 14:20:58 -05:00
parent 68dea75aa0
commit 017702590b
8 changed files with 71 additions and 10 deletions

View File

@@ -60,11 +60,27 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
int1 checkc1[TN];\
int1 checkc[TM, TN];\
for(k = K; k > 0; k = k - TK){\
int1 checka[TM, TK] = (k > bound);\
int1 checkb[TN, TK] = (k > bound);\
int1 checka0[TM];\
int1 checka1[TK];\
int1 checkb0[TN];\
int1 checkb1[TK];\
C = dot(a, b, C);\
pa = pa + TK*M;\
pb = pb + TK*K;\
a = *pa;\
b = *pb;\
@checka a = *pa;\
@checkb b = *pb;\
if(k > bound)\
continue;\
checka0 = rxa < M;\
checka1 = rka < k;\
checkb0 = ryb < N;\
checkb1 = rkb < k;\
checka = checka0[:, newaxis] && checka1[newaxis, :];\
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
@checka a = *pa;\
@checkb b = *pb;\
}\
checkc0 = rxc < M;\
checkc1 = ryc < N;\
@@ -243,14 +259,13 @@ int main() {
// run passes
triton::ir::print(module, std::cout);
exit(EXIT_FAILURE);
buffer_info.run(module);
shared.run(module);
liveness.run(module);
allocation.run();
barriers.run(module);
vectorize.run(module);
triton::ir::print(module, std::cout);
selection.run(module, llvm_module);
// llvm source
@@ -260,7 +275,7 @@ int main() {
manager.run(llvm_module);
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
std::cout << src << std::endl;
// std::cout << src << std::endl;
// compile machine code
CUdevice cu_device;
@@ -308,6 +323,9 @@ int main() {
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));
simple_gemm(rc, a, b, M, N, K);
for(size_t i = 0; i < M*N; i++)
if(std::abs(c[i] - rc[i])/std::max(c[i], rc[i]) > 1e-4)
if(std::abs(c[i] - rc[i])/std::max(c[i], rc[i]) > 1e-4){
std::cout << i << " " << c[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
}

View File

@@ -226,7 +226,7 @@ logical_or_expression
/* Conditional */
conditional_expression
: logical_or_expression { $$ = $1; }
| logical_or_expression '?' conditional_expression ':' conditional_expression { $$ = new conditional_expression($1, $2, $3); }
| logical_or_expression '?' conditional_expression ':' conditional_expression { $$ = new conditional_expression($1, $3, $5); }
;
/* Assignment */

View File

@@ -50,6 +50,8 @@ public:
value* create_br(basic_block *dest);
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
value* create_ret_void();
// Tile-level control flow
value *create_ternary(value *cond, value *true_value, value *false_value, const std::string &name = "");
// Cast instructions
value *create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name = "");
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");

View File

@@ -289,6 +289,23 @@ private:
public:
basic_block *get_dest() { return (basic_block*)get_operand(0); }
};
// ternary
class ternary_inst: public instruction {
private:
std::string repr_impl() const { return "ternary"; }
ternary_inst(value *cond, value *true_value, value *false_value,
const std::string &name, instruction *next);
public:
value *get_cond() { return get_operand(0); }
value *get_true_value() { return get_operand(1); }
value *get_false_value() { return get_operand(2); }
static ternary_inst* create(value *cond, value *true_value, value *false_value,
const std::string &name = "", instruction *next = nullptr);
};
//===----------------------------------------------------------------------===//
// getelementptr_inst classes
//===----------------------------------------------------------------------===//

View File

@@ -681,7 +681,6 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
}
else {
Instruction *i = (Instruction*)llvm_value(src, builder);
std::cout << "instruction: " << src->get_name() << " " << src->has_tile_result_or_op() << std::endl;
vmap_[src] = i;
}
}
@@ -797,7 +796,6 @@ void selection::run(ir::module &src, Module &dst){
});
}
else {
std::cout << phi->get_name() << " " << inc_val->get_name() << std::endl;
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
Value *llvm_inc_val = vmap_.at(inc_val);
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);

View File

@@ -67,6 +67,16 @@ value *builder::create_ret_void() {
return insert(return_inst::create(ctx_));
}
//===----------------------------------------------------------------------===//
// tile-level control-flow instructions
//===----------------------------------------------------------------------===//
value *builder::create_ternary(value *cond, value *true_value, value *false_value, const std::string &name){
return insert(ternary_inst::create(cond, true_value, false_value, name));
}
//===----------------------------------------------------------------------===//
// cast instructions
//===----------------------------------------------------------------------===//

View File

@@ -311,6 +311,21 @@ cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, v
set_operand(2, cond);
}
// ternary_inst
ternary_inst::ternary_inst(value *cond, value *true_value, value *false_value, const std::string &name, instruction *next)
: instruction(true_value->get_type(), 3) {
assert(true_value->get_type() == false_value->get_type());
set_operand(0, cond);
set_operand(1, true_value);
set_operand(2, false_value);
}
ternary_inst *ternary_inst::create(value *cond, value *true_value, value *false_value,
const std::string &name, instruction *next) {
return new ternary_inst(cond, true_value, false_value, name, next);
}
//===----------------------------------------------------------------------===//
// getelementptr_inst classes
//===----------------------------------------------------------------------===//

View File

@@ -155,7 +155,8 @@ ir::type *module::get_type(const std::string &name) {
void module::seal_block(ir::basic_block *block){
for(auto &x: incomplete_phis_[block]){
add_phi_operands(x.first, x.second);
set_value(x.first, try_remove_trivial_phis(x.second));
if(get_value(x.first) == x.second)
set_value(x.first, try_remove_trivial_phis(x.second));
}
sealed_blocks_.insert(block);
incomplete_phis_[block].clear();