[general] major overhaul of triton-c/triton-ir/triton-jit:
- Added alloc const - Added atomics - Pruning tuning space - Added example for dot/conv/shift - Bugfixes
This commit is contained in:
@@ -95,55 +95,75 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs,
|
||||
throw std::runtime_error("unreachable");
|
||||
}
|
||||
|
||||
void node::implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty) {
|
||||
ir::value *tmp = ir::undef_value::get(ty);
|
||||
implicit_broadcast(mod, arg, tmp);
|
||||
}
|
||||
|
||||
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){
|
||||
ir::builder &builder = mod->get_builder();
|
||||
void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) {
|
||||
ir::type *lhs_ty = lhs->get_type();
|
||||
ir::type *rhs_ty = rhs->get_type();
|
||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||
// Both are scalar
|
||||
ir::type *res_ty = nullptr;
|
||||
if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
|
||||
return;
|
||||
// One argument is scalar
|
||||
if(lhs_ty->is_tile_ty() ^ rhs_ty->is_tile_ty()){
|
||||
auto &shapes = lhs_ty->is_tile_ty()?lhs_ty->get_tile_shapes():rhs_ty->get_tile_shapes();
|
||||
auto &scalar = lhs_ty->is_tile_ty()?rhs:lhs;
|
||||
scalar = builder.create_splat(scalar, shapes);
|
||||
else if(lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty())
|
||||
res_ty = lhs_ty;
|
||||
else if(!lhs_ty->is_tile_ty() && rhs_ty->is_tile_ty())
|
||||
res_ty = rhs_ty;
|
||||
else{
|
||||
auto lhs_shapes = lhs_ty->get_tile_shapes();
|
||||
auto rhs_shapes = rhs_ty->get_tile_shapes();
|
||||
size_t lhs_size = lhs_shapes.size();
|
||||
size_t rhs_size = rhs_shapes.size();
|
||||
size_t res_size = std::max(lhs_size, rhs_size);
|
||||
ir::type::tile_shapes_t res_shapes(res_size);
|
||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||
for(int i = 0; i < res_size; i++){
|
||||
if(i >= res_size - lhs_size && i >= res_size - rhs_size)
|
||||
res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i];
|
||||
else if(i >= res_size - lhs_size)
|
||||
res_shapes[i] = lhs_shapes[i];
|
||||
else if(i >= res_size - rhs_size)
|
||||
res_shapes[i] = rhs_shapes[i];
|
||||
}
|
||||
res_ty = ir::tile_type::get(lhs_ty->get_scalar_ty(), res_shapes);
|
||||
}
|
||||
implicit_broadcast(mod, res_ty, rhs);
|
||||
implicit_broadcast(mod, res_ty, lhs);
|
||||
}
|
||||
|
||||
void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){
|
||||
ir::builder &builder = mod->get_builder();
|
||||
ir::type *src_ty = src->get_type();
|
||||
ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context());
|
||||
// Both are scalar
|
||||
if(!ty->is_tile_ty() && !src_ty->is_tile_ty())
|
||||
return;
|
||||
// Broadcast scalar
|
||||
if(ty->is_tile_ty() && !src_ty->is_tile_ty()){
|
||||
src = builder.create_splat(src, ty->get_tile_shapes());
|
||||
return;
|
||||
}
|
||||
// Downcast tile
|
||||
if(!ty->is_tile_ty() && src_ty->is_tile_ty()){
|
||||
for(ir::constant *shape: src_ty->get_tile_shapes())
|
||||
if(shape != one)
|
||||
throw std::runtime_error("cannot downcast");
|
||||
src = builder.create_downcast(src);
|
||||
return;
|
||||
}
|
||||
// Both are arrays
|
||||
auto lhs_shapes = lhs->get_type()->get_tile_shapes();
|
||||
auto rhs_shapes = rhs->get_type()->get_tile_shapes();
|
||||
if(lhs_shapes == rhs_shapes)
|
||||
return;
|
||||
int lhs_dim = lhs_shapes.size();
|
||||
int rhs_dim = rhs_shapes.size();
|
||||
auto &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes;
|
||||
auto &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes;
|
||||
size_t ndim = longest.size();
|
||||
int off = longest.size() - shortest.size();
|
||||
for(int i = longest.size() - 1; i>= 0; i--){
|
||||
if(shortest[off + i] != longest[i] && shortest[off + i] != one && longest[i] != one)
|
||||
throw std::runtime_error("cannot broadcast");
|
||||
}
|
||||
auto dst_shapes = ty->get_tile_shapes();
|
||||
auto src_shapes = src_ty->get_tile_shapes();
|
||||
int dst_dim = dst_shapes.size();
|
||||
int src_dim = src_shapes.size();
|
||||
// Pad
|
||||
int off = dst_dim - src_dim;
|
||||
for(size_t i = 0; i < off; i++)
|
||||
shortest.insert(shortest.begin(), one);
|
||||
ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs;
|
||||
src_shapes.insert(src_shapes.begin(), one);
|
||||
if(off > 0)
|
||||
target = builder.create_reshape(target, shortest);
|
||||
src = builder.create_reshape(src, src_shapes);
|
||||
// Broadcast
|
||||
ir::type::tile_shapes_t shapes(ndim);
|
||||
for(size_t i = 0; i < ndim; i++)
|
||||
shapes[i] = shortest[i]==one?longest[i]:shortest[i];
|
||||
if(shapes != lhs_shapes)
|
||||
lhs = builder.create_broadcast(lhs, shapes);
|
||||
if(shapes != rhs_shapes)
|
||||
rhs = builder.create_broadcast(rhs, shapes);
|
||||
for(int i = dst_dim - 1; i>= 0; i--)
|
||||
if(dst_shapes[i] != src_shapes[i] && dst_shapes[i] != one && src_shapes[i] != one)
|
||||
throw std::runtime_error("cannot broadcast");
|
||||
if(dst_shapes != src_shapes)
|
||||
src = builder.create_broadcast(src, dst_shapes);
|
||||
}
|
||||
|
||||
/* Helper */
|
||||
@@ -336,7 +356,9 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{
|
||||
return builder.create_cond_br(cond, loop_bb, next_bb);
|
||||
});
|
||||
init_->codegen(mod);
|
||||
builder.create_br(loop_bb);
|
||||
ir::value *cond = stop_->codegen(mod);
|
||||
builder.create_cond_br(cond, loop_bb, next_bb);
|
||||
// builder.create_br(loop_bb);
|
||||
builder.set_insert_point(loop_bb);
|
||||
if(!is_terminator(statements_->codegen(mod)))
|
||||
mod->get_continue_fn()();
|
||||
@@ -378,6 +400,7 @@ ir::value* selection_statement::codegen(ir::module* mod) const{
|
||||
builder.create_br(endif_bb);
|
||||
}
|
||||
// Endif
|
||||
mod->seal_block(endif_bb);
|
||||
builder.set_insert_point(endif_bb);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -422,7 +445,7 @@ ir::value* initializer::codegen(ir::module * mod) const{
|
||||
else if(expr_){
|
||||
value = expr_->codegen(mod);
|
||||
value = explicit_cast(mod->get_builder(), value, ty);
|
||||
implicit_broadcast(mod, value, ty);
|
||||
implicit_broadcast(mod, ty, value);
|
||||
}
|
||||
value->set_name(name);
|
||||
mod->set_value(name, value);
|
||||
@@ -543,6 +566,19 @@ ir::value* get_global_range::codegen(ir::module *mod) const {
|
||||
return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod));
|
||||
}
|
||||
|
||||
// get_range_id
|
||||
ir::value* get_range_id::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_get_range_id(axis_->value());
|
||||
}
|
||||
|
||||
// atomic cas
|
||||
ir::value* atomic_cas::codegen(ir::module *mod) const {
|
||||
ir::value *ptr = ptr_->codegen(mod);
|
||||
ir::value *cmp = cmp_->codegen(mod);
|
||||
ir::value *val = val_->codegen(mod);
|
||||
return mod->get_builder().create_atomic_cas(ptr, cmp, val);
|
||||
}
|
||||
|
||||
// matmul
|
||||
ir::value* matmul_expression::codegen(ir::module *mod) const {
|
||||
ir::value *A = A_->codegen(mod);
|
||||
@@ -554,10 +590,37 @@ ir::value* matmul_expression::codegen(ir::module *mod) const {
|
||||
// ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N});
|
||||
// ir::value *tmp = ir::undef_value::get(tile_ty);
|
||||
// implicit_broadcast(mod, tmp, C);
|
||||
return mod->get_builder().create_matmul(A, B, C);
|
||||
return mod->get_builder().create_dot(A, B, C);
|
||||
}
|
||||
|
||||
// min
|
||||
ir::value* min_expression::codegen(ir::module *mod) const {
|
||||
ir::value* cmp = binary_operator(LT, (node*)x_, (node*)y_).codegen(mod);
|
||||
ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
|
||||
ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
|
||||
return mod->get_builder().create_select(cmp, x, y);
|
||||
}
|
||||
|
||||
// max
|
||||
ir::value* max_expression::codegen(ir::module *mod) const {
|
||||
ir::value* cmp = binary_operator(GT, (node*)x_, (node*)y_).codegen(mod);
|
||||
ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0);
|
||||
ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1);
|
||||
return mod->get_builder().create_select(cmp, x, y);
|
||||
}
|
||||
|
||||
// select
|
||||
ir::value* select_expression::codegen(ir::module *mod) const {
|
||||
ir::value* pred = pred_->codegen(mod);
|
||||
ir::value* if_value = if_value_->codegen(mod);
|
||||
ir::value* else_value = else_value_->codegen(mod);
|
||||
return mod->get_builder().create_select(pred, if_value, else_value);
|
||||
}
|
||||
|
||||
// Trans
|
||||
ir::value* trans_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_trans(arg_->codegen(mod));
|
||||
}
|
||||
|
||||
/* Postfix expression */
|
||||
ir::value* indexing_expression::codegen(ir::module *mod) const{
|
||||
@@ -573,6 +636,7 @@ ir::value* indexing_expression::codegen(ir::module *mod) const{
|
||||
return mod->get_builder().create_reshape(in, out_shapes);
|
||||
}
|
||||
|
||||
|
||||
/* Unary operator */
|
||||
ir::value *unary_operator::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{
|
||||
ir::type *atype = arg->get_type();
|
||||
@@ -666,7 +730,7 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{
|
||||
if(auto *x = dynamic_cast<const named_expression*>(lvalue_)){
|
||||
ir::type *ty = mod->get_scope().types.at(x->id()->name());
|
||||
rvalue = explicit_cast(mod->get_builder(), rvalue, ty);
|
||||
implicit_broadcast(mod, rvalue, ty);
|
||||
implicit_broadcast(mod, ty, rvalue);
|
||||
mod->set_value(x->id()->name(), rvalue);
|
||||
}
|
||||
else if(auto* x = dynamic_cast<const unary_operator*>(lvalue_)){
|
||||
|
Reference in New Issue
Block a user