[BACKEND] Various bug fixes; making reductions faster (#533)
@@ -224,7 +224,7 @@ struct scanline_layout: public distributed_layout {
 int nts(size_t k) { return nts_.at(k); }
 int contig_per_thread(size_t k) { return nts_.at(k); }

-int per_thread(size_t k) { return nts(k) * shape_[k] / shape_per_cta(k);}
+int per_thread(size_t k) { return contig_per_thread(k) * shape_[k] / shape_per_cta(k);}
 public:
 // micro tile size. The size of a tile held by a thread block.
 std::vector<int> mts_;

@@ -319,8 +319,8 @@ std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator*
 }
 if(x->is_int_add_sub()){
 unsigned lvalue = 1, rvalue = 1;
-lvalue = gcd(rhs_max_contiguous[d], lhs_starting_multiple[d]);
-rvalue = gcd(lhs_max_contiguous[d], rhs_starting_multiple[d]);
+lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
+rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
 value = std::max(lvalue, rvalue);
 }
 result.push_back(value);

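For context on the hunk above: for an integer `lhs + rhs`, the sum can only stay contiguous for as long as the other operand stays constant, so the bound the align pass appears to propagate is `gcd(max_contiguous of one side, constancy length of the other)`; the fix swaps the starting multiple for the constancy length (`num_cst`). A minimal, hedged illustration of that rule (standalone host code, illustrative values, not Triton's API):

```cpp
#include <cstdio>
#include <numeric>   // std::gcd (C++17)
#include <vector>

int main() {
  // a[i] = i           -> contiguous in aligned runs of 8
  // b[i] = (i/4)*100   -> constant in aligned runs of 4
  std::vector<int> a, b, sum;
  for (int i = 0; i < 8; ++i) {
    a.push_back(i);
    b.push_back((i / 4) * 100);
    sum.push_back(a[i] + b[i]);   // 0 1 2 3 104 105 106 107
  }
  // Within any aligned run of gcd(8, 4) = 4 indices, `a` is contiguous and
  // `b` is constant, so `a + b` is still contiguous there: that gcd is the
  // contiguity bound recorded for the sum.
  printf("predicted contiguity: %d\n", std::gcd(8, 4));
  for (int v : sum) printf("%d ", v);
  printf("\n");
}
```
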
@@ -209,14 +209,15 @@ mma_layout::mma_layout(size_t num_warps,
 rep_ = {2*pack_size_0, 2*pack_size_1, 1};
 spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
 contig_per_thread_ = {1, 1};
+order_ = {0, 1};
 }
 else{
 // fpw_ = {1, 1, 1};
 spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
 contig_per_thread_ = {1, 2};
+order_ = {1, 0};
 // rep_ = {2, 2, 1};
 }
-order_ = {0, 1};

 /* warps per tile */
 wpt_ = {1, 1, 1};

@@ -616,8 +617,9 @@ void layouts::run(ir::module &mod) {
 unsigned axis = red->get_axis();
 // shape
 auto shapes = arg->get_type()->get_block_shapes();
-scanline_layout *layout = get(arg)->to_scanline();
-shapes[axis] = layout->mts(axis);
+distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(get(arg));
+shapes[axis] = layout->shape_per_cta(axis) / layout->contig_per_thread(axis);
+
 // create layout
 layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
 tmp_[red] = id;

@@ -88,6 +88,7 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
 allocation.run(ir);
 prefetch_s.run(ir);
 barriers.run(ir);
+// ir.print(std::cout);
 isel.visit(ir, *llvm);
 shared_static = allocation.allocated_size();
 return llvm;

@@ -88,6 +88,7 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
 #define f16_ty builder_->getHalfTy()
 #define bf16_ty builder_->getBFloatTy()
 #define f32_ty builder_->getFloatTy()
+#define i1_ty builder_->getInt1Ty()
 #define i8_ty builder_->getInt8Ty()
 #define i16_ty builder_->getInt16Ty()
 #define i32_ty builder_->getInt32Ty()

@@ -736,6 +737,9 @@ void generator::visit_uncond_branch_inst(ir::uncond_branch_inst* br) {
 * \brief Code Generation for a (synchronous) `load`
 */
 void generator::visit_load_inst(ir::load_inst* x){
+BasicBlock *current = builder_->GetInsertBlock();
+Module *module = current->getModule();
+Value *tid = tgt_->get_local_id(module, *builder_, 0);
 ir::value *op = x->get_pointer_operand();
 ir::masked_load_inst *mx = dynamic_cast<ir::masked_load_inst*>(x);
 Type* ty = cvt(op->get_type()->get_scalar_ty()->get_pointer_element_ty());

@@ -775,6 +779,9 @@ void generator::visit_load_inst(ir::load_inst* x){
 in_off = 0;
 }
 Value *pred = mx ? vals_[mx->get_mask_operand()][idx] : builder_->getTrue();
+// if(!op->get_type()->is_block_ty()){
+// pred = builder_->CreateAnd(pred, icmp_eq(tid, i32(0)));
+// }
 Value *other = mx ? vals_[mx->get_false_value_operand()][idx] : nullptr;
 size_t nbits = dtsize*8;
 // pack sub-words (< 32/64bits) into words

@@ -878,6 +885,18 @@ void generator::visit_load_inst(ir::load_inst* x){


 Value *_ret = call(inlineAsm, args);
+// if(!op->get_type()->is_block_ty()){
+// Value* cond = icmp_eq(tid, i32(0));
+// Value* shptr = bit_cast(shmem_, ptr_ty(_ret->getType(), 3));
+// Instruction* bar = add_barrier();
+// Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, bar, false);
+// builder_->SetInsertPoint(term);
+// store(_ret, shptr);
+// builder_->SetInsertPoint(bar->getParent());
+// _ret = load(shptr);
+// add_barrier();
+// }
+
 // ---
 // extract and store return values
 // ---

@@ -2033,12 +2052,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::

 // create mma & unpack result, m, n, k are offsets in mat
 auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
-unsigned cols_per_thread = num_rep_m * 2;
+unsigned cols_per_thread = num_rep_n * 2;
 std::vector<size_t> idx = {
-(m + 0) + (n*2 + 0)*cols_per_thread,
-(m + 0) + (n*2 + 1)*cols_per_thread,
-(m + 1) + (n*2 + 0)*cols_per_thread,
-(m + 1) + (n*2 + 1)*cols_per_thread
+(m + 0)*cols_per_thread + (n*2 + 0),
+(m + 0)*cols_per_thread + (n*2 + 1),
+(m + 1)*cols_per_thread + (n*2 + 0),
+(m + 1)*cols_per_thread + (n*2 + 1)
 };
 Value *nc = call(mma_ty, mma_fn,
 {ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}],

@@ -2316,62 +2335,93 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
 void generator::visit_reducend_inst_fast(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral){
 //
 ir::value *arg = x->get_operand(0);
-analysis::scanline_layout* layout = layouts_->get(arg)->to_scanline();
+analysis::distributed_layout* layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
 std::vector<unsigned> shapes = layout->get_shape();
-std::vector<int> order = layout->get_order();
-unsigned mts = layout->mts(order[0]);
-unsigned nts = layout->nts(order[0]);
-unsigned col_per_thread = shapes[order[0]] / mts;
-auto idxs = idxs_.at(arg);
-size_t n_elts = idxs.size();
+Type* sca_ty = cvt(arg->get_type()->get_scalar_ty());
+size_t n_bits = sca_ty->getPrimitiveSizeInBits();
+std::string n_bits_str = std::to_string(n_bits);
+std::string cst = (n_bits == 64) ? "l" : "r";

+FunctionType *st_shared_ty = FunctionType::get(void_ty, {i1_ty, ptr_ty(sca_ty, 3), sca_ty}, false);
+InlineAsm *st_shared = InlineAsm::get(st_shared_ty, "@$0 st.shared.b" + n_bits_str + " [$1], $2;", "b," + cst + "," + cst, true);
+FunctionType *ld_shared_ty = FunctionType::get(sca_ty, {i1_ty, ptr_ty(sca_ty, 3)}, false);
+InlineAsm *ld_shared = InlineAsm::get(ld_shared_ty, "@$1 ld.shared.b" + n_bits_str + " $0, [$2];", "=" + cst + ",b," + cst, true);
+
+Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
+Value* warp = udiv(thread, i32(32));
+Value* lane = urem(thread, i32(32));
+
+unsigned shuffle_width = 0;
+unsigned warps_per_inner = 0;
+auto arg_vals = vals_.at(arg);
+std::vector<indices_t> arg_idxs = idxs_.at(arg);
+size_t n_elts = arg_idxs.size();
+unsigned col_per_thread;
+Value* warp_i;
+Value* warp_j;
+if(analysis::scanline_layout* scanline = layout->to_scanline()){
+std::vector<int> order = layout->get_order();
+unsigned mts = scanline->mts(order[0]);
+shuffle_width = std::min<int>(mts, 32);
+warps_per_inner = std::max<int>(mts/32, 1);
+col_per_thread = shapes[order[0]] / mts;
+warp_i = udiv(warp, i32(warps_per_inner));
+warp_j = urem(warp, i32(warps_per_inner));
+}
+else if(layout->to_mma()){
+shuffle_width = 4;
+warps_per_inner = layout->to_mma()->wpt(1);
+col_per_thread = 16;
+warp_i = axes_.at(a_axes_->get(arg, 0)).thread_id;
+warp_j = axes_.at(a_axes_->get(arg, 1)).thread_id;
+}
+
+// unsigned col_per_thread = 2 * shapes[order[0]] / layout->shape_per_cta(order[0]);
 //
 Type *ret_ty = cvt(x->get_type()->get_scalar_ty());
 unsigned addr_space = shmem_->getType()->getPointerAddressSpace();
 Value *base = bit_cast(shmem_, ptr_ty(ret_ty, addr_space));
-Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
-Value* warp = udiv(thread, i32(32));
-Value* lane = urem(thread, i32(32));
-size_t warps_per_inner = std::max<int>(mts/32, 1);
-Value* warp_i = udiv(warp, i32(warps_per_inner));
-unsigned row_per_thread = std::max<int>(32/mts, 1);
+// preds
+Value* is_lane0 = icmp_eq(lane, i32(0));
+Value* is_warp0 = icmp_eq(warp, i32(0));
+Value* is_thread0 = icmp_eq(thread, i32(0));
+Value* lane_j = urem(lane, i32(shuffle_width));
+Value* first_lane_in_col = icmp_eq(lane_j, i32(0));
+add_barrier();
+// compute partial sum for each warp, and store to shared memory
 for(size_t i = 0; i < n_elts/col_per_thread; i++){
 Value* acc;
 // reduce within thread
 for(size_t j = 0; j < col_per_thread; j++){
-Value* val = vals_[arg][idxs[i*col_per_thread + j]];
+Value* val = arg_vals[arg_idxs[i*col_per_thread + j]];
+// acc = (j == 0) ? val : do_acc(acc, val);
 acc = (j == 0) ? val : do_acc(acc, val);
 }
 // reduce within warp
-for(int k = std::min<int>(mts, 32)/2 ; k > 0; k >>= 1)
+for(int k = shuffle_width/2 ; k > 0; k >>= 1)
 acc = do_acc(acc, shfl_sync(acc, k));
-// store warp result in shared memory
-Value* ret = acc;
-if(mts >= 32){
-add_barrier();
-store(neutral, gep(base, lane));
-add_barrier();
-store(acc, gep(base, warp));
-add_barrier();
-// reduce across warps
-Value *cond = icmp_eq(warp, i32(0));
-Instruction *barrier = add_barrier();
-builder_->SetInsertPoint(barrier->getParent());
-Instruction* dummy = builder_->CreateRet(nullptr);
-Instruction *term = llvm::SplitBlockAndInsertIfThen(cond, barrier, false);
-dummy->removeFromParent();
-builder_->SetInsertPoint(term);
-ret = load(gep(base, thread));
-for(int k = (mts/32)/2; k > 0; k >>= 1){
-Value *current = shfl_sync(ret, k);
-ret = do_acc(ret, current);
-}
-store(ret, gep(base, thread));
-builder_->SetInsertPoint(barrier->getParent());
-ret = load(gep(base, warp));
-}
-vals_[x][idxs_[x][i]] = ret;
+// store partial result to shared memory
+auto x_idxs = idxs_[x][i];
+Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
+Value* st_off = add(mul(x_idx, i32(warps_per_inner)), warp_j);
+call(st_shared, {icmp_eq(lane_j, i32(0)), gep(base, st_off), acc});
 }
+add_barrier();
+// at this point, partial accumulator synchronized in shared memory
+// Just need to reduce `warp_per_inner` numbers in shared memory
+for(size_t i = 0; i < n_elts/col_per_thread; i++){
+auto x_idxs = idxs_[x][i];
+Value* x_idx = x_idxs.empty() ? builder_->getInt32(0) : x_idxs[0];
+Value* ld_off = add(mul(x_idx, i32(warps_per_inner)), urem(lane_j, i32(warps_per_inner)));
+Value* acc = call(ld_shared, {builder_->getInt1(true), gep(base, ld_off)});
+for(int k = warps_per_inner/2; k > 0; k >>= 1)
+acc = do_acc(acc, shfl_sync(acc, k));
+vals_[x][idxs_[x][i]] = acc;
+}
+// add_barrier();
 }

 void generator::visit_reducend_inst(ir::reduce_inst* x, std::function<Value*(Value*,Value*)> do_acc, Value *neutral) {

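The rewritten fast path above follows a two-level pattern: each thread folds its own elements, the lanes of a warp combine their values with shuffles, one partial per warp/column is written to shared memory through inline `st.shared`/`ld.shared` PTX, and a final shuffle pass folds the per-warp partials. A rough standalone CUDA sketch of that general pattern, not the code this commit generates (the kernel name and shapes here are illustrative assumptions):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Reduce a value across the 32 lanes of a warp with shuffles.
__device__ float warp_reduce_sum(float v) {
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_xor_sync(0xffffffff, v, offset);  // butterfly: every lane ends with the sum
  return v;
}

__global__ void block_reduce_sum(const float* x, float* out, int n) {
  __shared__ float partial[32];                   // one slot per warp
  int tid  = blockIdx.x * blockDim.x + threadIdx.x;
  int lane = threadIdx.x % 32;
  int warp = threadIdx.x / 32;

  float acc = (tid < n) ? x[tid] : 0.f;           // per-thread partial
  acc = warp_reduce_sum(acc);                     // reduce within warp
  if (lane == 0) partial[warp] = acc;             // one partial per warp to shared memory
  __syncthreads();

  if (warp == 0) {                                // final pass: fold the per-warp partials
    int num_warps = (blockDim.x + 31) / 32;
    acc = (lane < num_warps) ? partial[lane] : 0.f;
    acc = warp_reduce_sum(acc);
    if (lane == 0) out[blockIdx.x] = acc;
  }
}

int main() {
  const int n = 1024;
  float *x, *out;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) x[i] = 1.f;
  block_reduce_sum<<<1, n>>>(x, out, n);
  cudaDeviceSynchronize();
  printf("sum = %f (expected %d)\n", *out, n);
  cudaFree(x);
  cudaFree(out);
  return 0;
}
```
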
@@ -2471,8 +2521,12 @@ void generator::visit_reduce_inst(ir::reduce_inst* x) {
 default: throw std::runtime_error("unreachable");
 }
 ir::value *arg = x->get_operand(0);
+int cc = tgt_->as_nvidia()->sm();
 analysis::scanline_layout* scanline = layouts_->get(x->get_operand(0))->to_scanline();
-if(scanline && scanline->get_order()[0] == x->get_axis())
+analysis::mma_layout* mma = layouts_->get(x->get_operand(0))->to_mma();
+bool is_coalesced_scanline = scanline && (scanline->get_order()[0] == x->get_axis());
+bool is_a100_mma = mma && (cc >= 80) && (x->get_axis() == 1);
+if(is_coalesced_scanline || is_a100_mma)
 visit_reducend_inst_fast(x, do_acc, neutral);
 else
 visit_reducend_inst(x, do_acc, neutral);

@@ -2665,12 +2719,12 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
 unsigned in_vec = 1;
 ir::value *arg = cts->get_operand(0);
 analysis::shared_layout* out_layout = layouts_->get(cts)->to_shared();
-analysis::scanline_layout* in_layout = layouts_->get(arg)->to_scanline();
+analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(arg));
 auto out_order = out_layout->get_order();
 auto in_order = in_layout->get_order();
 // tiles
 if(out_order == in_order)
-in_vec = in_layout->nts(in_order[0]);
+in_vec = in_layout->contig_per_thread(in_order[0]);
 int out_vec = swizzle_->get_vec(out_layout);
 int min_vec = std::min<int>(out_vec, in_vec);
 int s = std::max<int>(out_vec / in_vec, 1);

@@ -2678,8 +2732,11 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
 int per_phase = swizzle_->get_per_phase(out_layout);
 int max_phase = swizzle_->get_max_phase(out_layout);
 //
-int in_ld = in_layout->get_shape()[in_order[0]] / in_layout->mts(in_order[0]);
-int n_shared_1 = std::max<int>(per_phase*max_phase / in_layout->mts(in_order[1]), 1);
+int mts_0 = in_layout->shape_per_cta(in_order[0]) / in_layout->contig_per_thread(in_order[0]);
+int mts_1 = in_layout->shape_per_cta(in_order[1]) / in_layout->contig_per_thread(in_order[1]);
+
+int in_ld = in_layout->get_shape()[in_order[0]] / mts_0;
+int n_shared_1 = std::max<int>(per_phase*max_phase / mts_1, 1);
 int n_shared_0 = std::max<int>(in_vec / out_vec, 1);

 BasicBlock* CurrBB = builder_->GetInsertBlock();

@@ -2700,8 +2757,8 @@ void generator::visit_copy_to_shared_inst(ir::copy_to_shared_inst* cts) {
 // input ptr info
 int id_0 = id % (in_ld/min_vec);
 int id_1 = id / (in_ld/min_vec);
-int off_0 = id_0 / n_shared_0 * n_shared_0 * in_layout->mts(in_order[0]);
-int off_1 = id_1 / n_shared_1 * n_shared_1 * in_layout->mts(in_order[1]);
+int off_0 = id_0 / n_shared_0 * n_shared_0 * mts_0;
+int off_1 = id_1 / n_shared_1 * n_shared_1 * mts_1;
 int off = (off_1*shapes[in_order[0]] + off_0);
 std::pair<int, int> key = {id_1 % n_shared_1, id_0 % n_shared_0};
 if(ptrs.find(key) == ptrs.end()){

@@ -3026,8 +3083,7 @@ void generator::visit_layout_mma(analysis::mma_layout* layout) {
 else{
 /* warp offset */
 Value *warp_0 = urem(warp, i32(layout->wpt(0)));
-Value *warp_12 = udiv(warp, i32(layout->wpt(0)));
-Value *warp_1 = urem(warp_12, i32(layout->wpt(1)));
+Value *warp_1 = urem(udiv(warp, i32(layout->wpt(0))), i32(layout->wpt(1)));
 Value *off_warp_m = mul(warp_0, i32(layout->spw(0)));
 Value *off_warp_n = mul(warp_1, i32(layout->spw(1)));
 Value *off_lane_m = urem(lane, _16);

@@ -3152,7 +3208,9 @@ void generator::visit_basic_block(ir::basic_block * block) {
 BasicBlock *parent = bbs_[block];
 builder_->SetInsertPoint(parent);
 for(ir::instruction *i: block->get_inst_list()){
+// i->print(std::cout);
 visit_value(i);
+// std::cout << "done" << std::endl;
 }
 // Update ir bb -> llvm bb mapping
 bbs_[block] = builder_->GetInsertBlock();

@@ -52,6 +52,7 @@ coalesce::coalesce(analysis::align* align, analysis::layouts *layouts)
 //}

 void coalesce::run(ir::module &mod) {
+std::set<analysis::data_layout*> invalidated;
 ir::builder& builder = mod.get_builder();
 // add layout conversion instructions
 for(ir::function *fn: mod.get_function_list())

@@ -61,12 +62,29 @@ void coalesce::run(ir::module &mod) {
 if(dynamic_cast<ir::store_inst*>(i) || dynamic_cast<ir::atomic_rmw_inst*>(i))
 if(ir::value* op = i->get_operand(1))
 if(op->get_type()->is_block_ty())
+if(op->get_type()->get_tile_rank() == 2)
+if(invalidated.find(layout_->get(op)) == invalidated.end())
 if(layout_->get(op)->to_mma()){
 ir::instruction* new_op = ir::cvt_layout_inst::create(op);
 builder.set_insert_point(i);
 builder.insert(new_op);
 i->replace_uses_of_with(op, new_op);
 }
+// coalesce before copy_to_shared
+// It's dirty, but the backend is being rewritten from scratch. :)
+if(dynamic_cast<ir::copy_to_shared_inst*>(i))
+if(ir::value* op = i->get_operand(0))
+if(op->get_type()->is_block_ty())
+if(op->get_type()->get_tile_rank() == 2)
+if(invalidated.find(layout_->get(op)) == invalidated.end())
+if(layout_->get(op)->to_mma()){
+ir::instruction* new_op = ir::cvt_layout_inst::create(op);
+builder.set_insert_point(i);
+builder.insert(new_op);
+op->replace_all_uses_with(new_op);
+new_op->replace_uses_of_with(new_op, op);
+invalidated.insert(layout_->get(op));
+}
 // uncoalesce after load
 if(auto x = dynamic_cast<ir::load_inst*>(i))
 if(x->get_type()->is_block_ty())

@@ -120,6 +138,7 @@ void coalesce::run(ir::module &mod) {
 }
 if(in_contig.size() <= 1 || out_contig==in_contig)
 continue;
+std::cout << "3!!" << std::endl;
 builder.set_insert_point_after(val_inst);
 auto new_val = builder.insert(ir::cvt_layout_inst::create(val_inst));
 x->replace_uses_of_with(val_inst, new_val);

@@ -79,7 +79,7 @@ class CMakeBuild(build_ext):

     def build_extension(self, ext):
         llvm_include_dir, llvm_library_dir = get_llvm()
-        # self.debug = True
+        self.debug = True
         extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
         # create build directories
         build_suffix = 'debug' if self.debug else 'release'

@@ -698,6 +698,7 @@ def test_reduce1d(dtype_str, shape, device='cuda'):

     rs = RandomState(17)
     x = numpy_random((shape,), dtype_str=dtype_str, rs=rs)
+    x[:] = 1
     # numpy result
     z_ref = np.sum(x).astype(getattr(np, dtype_str))
     # triton result

@@ -1132,3 +1133,25 @@ def test_constexpr_shape():
     x_tri = to_triton(np.empty((256, ), dtype=np.int32))
     kernel[(1,)](x_tri)
     np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256))
+
+# -------------
+# test if
+# -------------
+
+
+def test_if():
+
+    @triton.jit
+    def kernel(Cond, XTrue, XFalse, Ret):
+        pid = tl.program_id(0)
+        cond = tl.load(Cond)
+        if pid % 2:
+            tl.store(Ret, tl.load(XTrue))
+        else:
+            tl.store(Ret, tl.load(XFalse))
+
+    cond = torch.ones(1, dtype=torch.int32, device='cuda')
+    x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
+    x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
+    ret = torch.empty(1, dtype=torch.float32, device='cuda')
+    kernel[(1,)](cond, x_true, x_false, ret)

@@ -63,7 +63,7 @@ def mangle_ty(ty):
 def mangle_fn(name, arg_tys, constants):
     # doesn't mangle ret type, which must be a function of arg tys
     mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys])
-    key = lambda x: x.__name__ if isinstance(x, JITFunction) else repr(x)
+    key = lambda x: x.cache_key if isinstance(x, JITFunction) else repr(x)
     mangled_constants = '_'.join([f'{i}c{key(constants[i])}' for i in sorted(constants)])
     mangled_constants = mangled_constants.replace('.', '_d_')
     mangled_constants = mangled_constants.replace("'", '_sq_')

@@ -32,6 +32,8 @@ def _to_tensor(x, builder):
         return _to_tensor(x.value, builder)
     elif isinstance(x, tensor):
         return x
+    elif x is None:
+        return None
     assert False, f'cannot convert {x} to tensor'


@@ -559,7 +559,7 @@ def cast(input: tl.tensor,
          dst_ty: tl.dtype,
          builder: ir.builder) -> tl.tensor:
     src_ty = input.type
-    if src_ty.is_block():
+    if src_ty.is_block() and not dst_ty.is_block():
         dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
     if src_ty == dst_ty:
         return input

@@ -252,6 +252,7 @@ def matmul_kernel(
 # we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
 @triton.jit
 def leaky_relu(x):
+    x = x + 1
     return tl.where(x >= 0, x, 0.01 * x)


@@ -296,7 +297,7 @@ def matmul(a, b, activation=None):
 torch.manual_seed(0)
 a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
 b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
-triton_output = matmul(a, b, activation=None)
+triton_output = matmul(a, b, activation=leaky_relu)
 torch_output = torch.matmul(a, b)
 print(f"triton_output={triton_output}")
 print(f"torch_output={torch_output}")

@@ -305,6 +306,8 @@ if triton.testing.allclose(triton_output, torch_output):
 else:
     print("❌ Triton and Torch differ")

+print(matmul_kernel.cache_key)
+exit()
 # %%
 # Benchmark
 # --------------