[general] major overhaul of triton-c/triton-ir/triton-jit:

- Added alloc const
- Added atomics
- Pruning tuning space
- Added example for dot/conv/shift
- Bugfixes
This commit is contained in:
Philippe Tillet
2019-04-25 16:17:36 -04:00
parent 0c607c9392
commit 3413aad582
50 changed files with 2051 additions and 570 deletions

View File

@@ -95,55 +95,75 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
throw std::runtime_error("unreachable");
}
void node::implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty) {
ir::value *tmp = ir::undef_value::get(ty);
implicit_broadcast(mod, arg, tmp);
}
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
ir::builder &builder = mod->get_builder();
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) {
ir::type *lhs_ty = lhs->get_type();
ir::type *rhs_ty = rhs->get_type();
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
// Both are scalar
ir::type *res_ty = nullptr;
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
return;
// One argument is scalar
if(lhs_ty->is_tile_ty() ^ rhs_ty->is_tile_ty()){
auto &shapes = lhs_ty->is_tile_ty()?lhs_ty->get_tile_shapes():rhs_ty->get_tile_shapes();
auto &scalar = lhs_ty->is_tile_ty()?rhs:lhs;
scalar = builder.create_splat(scalar, shapes);
else if(lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
res_ty = lhs_ty;
else if(!lhs_ty->is_tile_ty() && rhs_ty->is_tile_ty())
res_ty = rhs_ty;
else{
auto lhs_shapes = lhs_ty->get_tile_shapes();
auto rhs_shapes = rhs_ty->get_tile_shapes();
size_t lhs_size = lhs_shapes.size();
size_t rhs_size = rhs_shapes.size();
size_t res_size = std::max(lhs_size, rhs_size);
ir::type::tile_shapes_t res_shapes(res_size);
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
for(int i = 0; i < res_size; i++){
if(i >= res_size - lhs_size && i >= res_size - rhs_size)
res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i];
else if(i >= res_size - lhs_size)
res_shapes[i] = lhs_shapes[i];
else if(i >= res_size - rhs_size)
res_shapes[i] = rhs_shapes[i];
}
res_ty = ir::tile_type::get(lhs_ty->get_scalar_ty(), res_shapes);
}
implicit_broadcast(mod, res_ty, rhs);
implicit_broadcast(mod, res_ty, lhs);
}
void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){
ir::builder &builder = mod->get_builder();
ir::type *src_ty = src->get_type();
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
// Both are scalar
if(!ty->is_tile_ty() && !src_ty->is_tile_ty())
return;
// Broadcast scalar
if(ty->is_tile_ty() && !src_ty->is_tile_ty()){
src = builder.create_splat(src, ty->get_tile_shapes());
return;
}
// Downcast tile
if(!ty->is_tile_ty() && src_ty->is_tile_ty()){
for(ir::constant *shape: src_ty->get_tile_shapes())
if(shape != one)
throw std::runtime_error("cannot downcast");
src = builder.create_downcast(src);
return;
}
// Both are arrays
auto lhs_shapes = lhs->get_type()->get_tile_shapes();
auto rhs_shapes = rhs->get_type()->get_tile_shapes();
if(lhs_shapes == rhs_shapes)
return;
int lhs_dim = lhs_shapes.size();
int rhs_dim = rhs_shapes.size();
auto &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
auto &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
size_t ndim = longest.size();
int off = longest.size() - shortest.size();
for(int i = longest.size() - 1; i>= 0; i--){
if(shortest[off + i] != longest[i] && shortest[off + i] != one && longest[i] != one)
throw std::runtime_error("cannot broadcast");
}
auto dst_shapes = ty->get_tile_shapes();
auto src_shapes = src_ty->get_tile_shapes();
int dst_dim = dst_shapes.size();
int src_dim = src_shapes.size();
// Pad
int off = dst_dim - src_dim;
for(size_t i = 0; i < off; i++)
shortest.insert(shortest.begin(), one);
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
src_shapes.insert(src_shapes.begin(), one);
if(off > 0)
target = builder.create_reshape(target, shortest);
src = builder.create_reshape(src, src_shapes);
// Broadcast
ir::type::tile_shapes_t shapes(ndim);
for(size_t i = 0; i < ndim; i++)
shapes[i] = shortest[i]==one?longest[i]:shortest[i];
if(shapes != lhs_shapes)
lhs = builder.create_broadcast(lhs, shapes);
if(shapes != rhs_shapes)
rhs = builder.create_broadcast(rhs, shapes);
for(int i = dst_dim - 1; i>= 0; i--)
if(dst_shapes[i] != src_shapes[i] && dst_shapes[i] != one && src_shapes[i] != one)
throw std::runtime_error("cannot broadcast");
if(dst_shapes != src_shapes)
src = builder.create_broadcast(src, dst_shapes);
}
/* Helper */
@@ -336,7 +356,9 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
return builder.create_cond_br(cond, loop_bb, next_bb);
});
init_->codegen(mod);
builder.create_br(loop_bb);
ir::value *cond = stop_->codegen(mod);
builder.create_cond_br(cond, loop_bb, next_bb);
// builder.create_br(loop_bb);
builder.set_insert_point(loop_bb);
if(!is_terminator(statements_->codegen(mod)))
mod->get_continue_fn()();
@@ -378,6 +400,7 @@ ir::value* selection_statement::codegen(ir::module* mod) const{
builder.create_br(endif_bb);
}
// Endif
mod->seal_block(endif_bb);
builder.set_insert_point(endif_bb);
return nullptr;
}
@@ -422,7 +445,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
else if(expr_){
value = expr_->codegen(mod);
value = explicit_cast(mod->get_builder(), value, ty);
implicit_broadcast(mod, value, ty);
implicit_broadcast(mod, ty, value);
}
value->set_name(name);
mod->set_value(name, value);
@@ -543,6 +566,19 @@ ir::value* get_global_range::codegen(ir::module *mod) const {
return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
}
// get_range_id
ir::value* get_range_id::codegen(ir::module *mod) const {
return mod->get_builder().create_get_range_id(axis_->value());
}
// atomic cas
ir::value* atomic_cas::codegen(ir::module *mod) const {
ir::value *ptr = ptr_->codegen(mod);
ir::value *cmp = cmp_->codegen(mod);
ir::value *val = val_->codegen(mod);
return mod->get_builder().create_atomic_cas(ptr, cmp, val);
}
// matmul
ir::value* matmul_expression::codegen(ir::module *mod) const {
ir::value *A = A_->codegen(mod);
@@ -554,10 +590,37 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N});
// ir::value *tmp = ir::undef_value::get(tile_ty);
// implicit_broadcast(mod, tmp, C);
return mod->get_builder().create_matmul(A, B, C);
return mod->get_builder().create_dot(A, B, C);
}
// min
ir::value* min_expression::codegen(ir::module *mod) const {
ir::value* cmp = binary_operator(LT, (node*)x_, (node*)y_).codegen(mod);
ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
return mod->get_builder().create_select(cmp, x, y);
}
// max
ir::value* max_expression::codegen(ir::module *mod) const {
ir::value* cmp = binary_operator(GT, (node*)x_, (node*)y_).codegen(mod);
ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
return mod->get_builder().create_select(cmp, x, y);
}
// select
ir::value* select_expression::codegen(ir::module *mod) const {
ir::value* pred = pred_->codegen(mod);
ir::value* if_value = if_value_->codegen(mod);
ir::value* else_value = else_value_->codegen(mod);
return mod->get_builder().create_select(pred, if_value, else_value);
}
// Trans
ir::value* trans_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_trans(arg_->codegen(mod));
}
/* Postfix expression */
ir::value* indexing_expression::codegen(ir::module *mod) const{
@@ -573,6 +636,7 @@ ir::value* indexing_expression::codegen(ir::module *mod) const{
return mod->get_builder().create_reshape(in, out_shapes);
}
/* Unary operator */
ir::value *unary_operator::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{
ir::type *atype = arg->get_type();
@@ -666,7 +730,7 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{
if(auto *x = dynamic_cast<const named_expression*>(lvalue_)){
ir::type *ty = mod->get_scope().types.at(x->id()->name());
rvalue = explicit_cast(mod->get_builder(), rvalue, ty);
implicit_broadcast(mod, rvalue, ty);
implicit_broadcast(mod, ty, rvalue);
mod->set_value(x->id()->name(), rvalue);
}
else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){