[intermediate representation] added ternary_inst
This commit is contained in:
@@ -60,11 +60,27 @@ void matmul(restrict readonly fp32 *a, restrict readonly fp32 *b, fp32 *c, int32
|
||||
int1 checkc1[TN];\
|
||||
int1 checkc[TM, TN];\
|
||||
for(k = K; k > 0; k = k - TK){\
|
||||
int1 checka[TM, TK] = (k > bound);\
|
||||
int1 checkb[TN, TK] = (k > bound);\
|
||||
int1 checka0[TM];\
|
||||
int1 checka1[TK];\
|
||||
int1 checkb0[TN];\
|
||||
int1 checkb1[TK];\
|
||||
C = dot(a, b, C);\
|
||||
pa = pa + TK*M;\
|
||||
pb = pb + TK*K;\
|
||||
a = *pa;\
|
||||
b = *pb;\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
if(k > bound)\
|
||||
continue;\
|
||||
checka0 = rxa < M;\
|
||||
checka1 = rka < k;\
|
||||
checkb0 = ryb < N;\
|
||||
checkb1 = rkb < k;\
|
||||
checka = checka0[:, newaxis] && checka1[newaxis, :];\
|
||||
checkb = checkb0[:, newaxis] && checkb1[newaxis, :];\
|
||||
@checka a = *pa;\
|
||||
@checkb b = *pb;\
|
||||
}\
|
||||
checkc0 = rxc < M;\
|
||||
checkc1 = ryc < N;\
|
||||
@@ -243,14 +259,13 @@ int main() {
|
||||
|
||||
|
||||
// run passes
|
||||
triton::ir::print(module, std::cout);
|
||||
exit(EXIT_FAILURE);
|
||||
buffer_info.run(module);
|
||||
shared.run(module);
|
||||
liveness.run(module);
|
||||
allocation.run();
|
||||
barriers.run(module);
|
||||
vectorize.run(module);
|
||||
triton::ir::print(module, std::cout);
|
||||
selection.run(module, llvm_module);
|
||||
|
||||
// llvm source
|
||||
@@ -260,7 +275,7 @@ int main() {
|
||||
manager.run(llvm_module);
|
||||
|
||||
std::string src = generate_machine_code(llvm_module, "nvptx64-nvidia-cuda", compute_data_layout(true, true));
|
||||
std::cout << src << std::endl;
|
||||
// std::cout << src << std::endl;
|
||||
|
||||
// compile machine code
|
||||
CUdevice cu_device;
|
||||
@@ -308,6 +323,9 @@ int main() {
|
||||
checkCudaErrors(cuMemcpyDtoH(c.data(), d_c, sizeof(numeric_t) * c.size()));
|
||||
simple_gemm(rc, a, b, M, N, K);
|
||||
for(size_t i = 0; i < M*N; i++)
|
||||
if(std::abs(c[i] - rc[i])/std::max(c[i], rc[i]) > 1e-4)
|
||||
if(std::abs(c[i] - rc[i])/std::max(c[i], rc[i]) > 1e-4){
|
||||
std::cout << i << " " << c[i] << " " << rc[i] << std::endl;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
std::cout << "Pass!" << std::endl;
|
||||
}
|
||||
|
@@ -226,7 +226,7 @@ logical_or_expression
|
||||
/* Conditional */
|
||||
/* Ternary alternative: $1 is the condition, $3 the value after '?', and $5
   the value after ':'. ($2 and $4 are the '?' and ':' tokens themselves and
   must not be passed to the AST node.) */
conditional_expression
  : logical_or_expression { $$ = $1; }
  | logical_or_expression '?' conditional_expression ':' conditional_expression { $$ = new conditional_expression($1, $3, $5); }
  ;
|
||||
|
||||
/* Assignment */
|
||||
|
@@ -50,6 +50,8 @@ public:
|
||||
value* create_br(basic_block *dest);
|
||||
value* create_cond_br(value *cond, basic_block* if_dest, basic_block* else_dest);
|
||||
value* create_ret_void();
|
||||
// Tile-level control flow
|
||||
value *create_ternary(value *cond, value *true_value, value *false_value, const std::string &name = "");
|
||||
// Cast instructions
|
||||
value *create_cast(cast_inst::op_t op, value *v, type *dst_ty, const std::string &name = "");
|
||||
value* create_si_to_fp(value *src, type *dst_ty, const std::string &name = "");
|
||||
|
@@ -289,6 +289,23 @@ private:
|
||||
public:
|
||||
// Returns operand 0 of this instruction, cast to a basic_block pointer.
basic_block *get_dest() { return (basic_block*)get_operand(0); }
|
||||
};
|
||||
|
||||
// ternary
|
||||
class ternary_inst: public instruction {
|
||||
private:
|
||||
std::string repr_impl() const { return "ternary"; }
|
||||
ternary_inst(value *cond, value *true_value, value *false_value,
|
||||
const std::string &name, instruction *next);
|
||||
|
||||
public:
|
||||
value *get_cond() { return get_operand(0); }
|
||||
value *get_true_value() { return get_operand(1); }
|
||||
value *get_false_value() { return get_operand(2); }
|
||||
static ternary_inst* create(value *cond, value *true_value, value *false_value,
|
||||
const std::string &name = "", instruction *next = nullptr);
|
||||
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -681,7 +681,6 @@ void selection::lower_instruction(ir::instruction *src, IRBuilder<> &builder) {
|
||||
}
|
||||
else {
|
||||
Instruction *i = (Instruction*)llvm_value(src, builder);
|
||||
std::cout << "instruction: " << src->get_name() << " " << src->has_tile_result_or_op() << std::endl;
|
||||
vmap_[src] = i;
|
||||
}
|
||||
}
|
||||
@@ -797,7 +796,6 @@ void selection::run(ir::module &src, Module &dst){
|
||||
});
|
||||
}
|
||||
else {
|
||||
std::cout << phi->get_name() << " " << inc_val->get_name() << std::endl;
|
||||
PHINode *llvm_phi = (PHINode*)vmap_.at(phi);
|
||||
Value *llvm_inc_val = vmap_.at(inc_val);
|
||||
llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block);
|
||||
|
@@ -67,6 +67,16 @@ value *builder::create_ret_void() {
|
||||
return insert(return_inst::create(ctx_));
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// tile-level control-flow instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Creates a ternary_inst from (cond, true_value, false_value), hands it to
// insert(), and returns whatever insert() yields.
value *builder::create_ternary(value *cond, value *true_value, value *false_value, const std::string &name){
  auto *ternary = ternary_inst::create(cond, true_value, false_value, name);
  return insert(ternary);
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// cast instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -311,6 +311,21 @@ cond_branch_inst::cond_branch_inst(basic_block *if_dst, basic_block *else_dst, v
|
||||
set_operand(2, cond);
|
||||
}
|
||||
|
||||
// ternary_inst
|
||||
// Builds a ternary instruction whose result type is the type of `true_value`.
// Both value operands must share the same type (checked by assert).
// Operand layout: 0 = cond, 1 = true_value, 2 = false_value.
//
// `name` and `next` are forwarded to the base instruction constructor;
// previously they were accepted here but never used, so a caller-supplied
// name was silently dropped.
// NOTE(review): assumes instruction's constructor accepts
// (type, num_ops, name, next) — confirm against its declaration.
ternary_inst::ternary_inst(value *cond, value *true_value, value *false_value, const std::string &name, instruction *next)
  : instruction(true_value->get_type(), 3, name, next) {
  assert(true_value->get_type() == false_value->get_type());
  set_operand(0, cond);
  set_operand(1, true_value);
  set_operand(2, false_value);
}
|
||||
|
||||
// Factory for ternary_inst: heap-allocates a new instruction from the given
// operands, name and `next` argument, and returns the raw pointer.
ternary_inst *ternary_inst::create(value *cond,
                                   value *true_value,
                                   value *false_value,
                                   const std::string &name,
                                   instruction *next) {
  ternary_inst *created = new ternary_inst(cond, true_value, false_value, name, next);
  return created;
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// getelementptr_inst classes
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@@ -155,7 +155,8 @@ ir::type *module::get_type(const std::string &name) {
|
||||
void module::seal_block(ir::basic_block *block){
|
||||
for(auto &x: incomplete_phis_[block]){
|
||||
add_phi_operands(x.first, x.second);
|
||||
set_value(x.first, try_remove_trivial_phis(x.second));
|
||||
if(get_value(x.first) == x.second)
|
||||
set_value(x.first, try_remove_trivial_phis(x.second));
|
||||
}
|
||||
sealed_blocks_.insert(block);
|
||||
incomplete_phis_[block].clear();
|
||||
|
Reference in New Issue
Block a user