[BACKEND] Added Int8 mma (#440)

This commit is contained in:
daadaada
2022-01-28 01:12:44 +08:00
committed by GitHub
parent 3a23c1dd33
commit 59d371c6eb
11 changed files with 232 additions and 115 deletions

View File

@@ -268,6 +268,7 @@ public:
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
int get_mma_vec() { return mma_vec_;}
int get_mma_strided() { return mma_strided_; }
bool allow_swizzle() const { return allow_swizzle_; }
data_layout* get_arg_layout() { return arg_layout_; }
private:
@@ -281,6 +282,7 @@ private:
data_layout* arg_layout_;
int mma_vec_;
int mma_strided_;
bool allow_swizzle_ = true;
target *tgt_;
};

View File

@@ -33,7 +33,9 @@ inline bool is_hmma_c(ir::value *v, int sm){
result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
(a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
x->allow_tf32() && sm >= 80);
x->allow_tf32() && sm >= 80) ||
(a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
sm >= 80);
}
return result;
}
@@ -63,7 +65,7 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
return mma_type;
}
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
throw std::runtime_error("integer tensor cores are not yet supported");
// throw std::runtime_error("integer tensor cores are not yet supported");
// // integer tensor cores
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
@@ -73,10 +75,10 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
// return mma_type;
// }
// if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
// mma_type = mma_layout::INT32_INT8_INT8_INT32;
// return mma_type;
// }
if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
mma_type = mma_layout::INT32_INT8_INT8_INT32;
return mma_type;
}
}
}
return mma_layout::NOT_APPLICABLE;
@@ -444,11 +446,21 @@ shared_layout::shared_layout(data_layout *arg,
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 0) // need transpose
allow_swizzle_ = false;
} else if (hmma_dot_b_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
// for now, disable swizzle when using lds.8
if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
if (order_[0] == 1) // need transpose
allow_swizzle_ = false;
}
// size

View File

@@ -41,9 +41,15 @@ void swizzle::run(ir::module &) {
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
}
else {
if (!layout->allow_swizzle()) {
per_phase_[layout] = 1;
max_phase_[layout] = 1;
vec_[layout] = 1;
} else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}
}
}
}

View File

@@ -1,78 +0,0 @@
#pragma once
#include <numeric>
#include <sstream>
#include <iomanip>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
namespace triton::codegen {
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getBFloatTy()
#define f32_ty builder_->getFloatTy()
#define i8_ty builder_->getInt8Ty()
#define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
// constants
#define i32(...) builder_->getInt32(__VA_ARGS__)
// ops
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
#define br(...) builder_->CreateBr(__VA_ARGS__)
#define call(...) builder_->CreateCall(__VA_ARGS__)
#define cast(...) builder_->CreateCast(__VA_ARGS__)
#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__)
#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__)
#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__)
#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
#define load(...) builder_->CreateLoad(__VA_ARGS__)
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
#define ret(...) builder_->CreateRet(__VA_ARGS__)
#define select(...) builder_->CreateSelect(__VA_ARGS__)
#define store(...) builder_->CreateStore(__VA_ARGS__)
#define sub(...) builder_->CreateSub(__VA_ARGS__)
#define shl(...) builder_->CreateShl(__VA_ARGS__)
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
#define urem(...) builder_->CreateURem(__VA_ARGS__)
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...) builder_->CreateXor(__VA_ARGS__)
} // namespace triton::codegen

View File

@@ -1,6 +1,7 @@
#include <numeric>
#include <sstream>
#include <iomanip>
#include <stdexcept>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
@@ -1355,9 +1356,6 @@ public:
need_trans_ = k_order_ != order_[0];
can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);
// std::cout << can_use_ldmatrix_ << std::endl;
// std::cout << need_trans_ << std::endl;
// we need more pointers at the fast-changing axis,
if (can_use_ldmatrix_)
num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
@@ -1365,6 +1363,9 @@ public:
num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
num_ptr_ = std::max<int>(num_ptr_, 2);
// special rule for i8/u8, 4 ptrs for each matrix
if (!can_use_ldmatrix_ && dtsize_ == 1)
num_ptr_ *= 4;
// load_v4 stride (in num of mats)
int load_stride_in_mat[2];
@@ -1445,6 +1446,46 @@ public:
}
return offs;
// throw std::runtime_error("not implemented");
} else if (dtsize_ == 1 && need_trans_) {
// load i8/u8 matrices with lds8
Value *c_off_in_mat = udiv(lane, i32(4)); //
Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread load 4 cols
// Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
std::vector<Value*> offs(num_ptr_);
for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
if (k_mat_arr_int > 0) // we don't need pointers for k
continue;
Value *k_mat_arr = i32(k_mat_arr_int);
Value *nk_mat_arr = i32(nk_mat_arr_int);
// physical offset (before swizzling)
Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
Value *s_mat_off = k_mat_arr; // always 0?
for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) {
for (int elem_off = 0; elem_off < 4; ++elem_off) {
int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off;
Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off));
// disable swizzling ...
// Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
// c_mat_off_i = xor_(c_mat_off_i, phase);
Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_)));
// To prevent out-of-bound access when the tile is too small
c_off = urem(c_off, i32(tile_shape_[order_[0]]));
s_off = urem(s_off, i32(tile_shape_[order_[1]]));
offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_)));
}
}
}
return offs;
} else
throw std::runtime_error("invalid smem load config");
}
@@ -1461,8 +1502,10 @@ public:
int ptr_idx = -1;
if (can_use_ldmatrix_)
ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
else // tf32 & trans
else if (dtsize_ == 4 && need_trans_) // tf32 & trans
ptr_idx = mat_idx[order_[0]];
else // i8 & trans
ptr_idx = mat_idx[order_[0]] * 4;
auto get_ptr = [&](int idx) -> Value* {
Value *ptr = nullptr;
@@ -1495,9 +1538,7 @@ public:
extract_val(res_v4, std::vector<unsigned>{1}),
extract_val(res_v4, std::vector<unsigned>{2}),
extract_val(res_v4, std::vector<unsigned>{3})};
} else {
// assert(false && "should not be here");
assert(dtsize_ == 4 && need_trans_);
} else if (dtsize_ == 4 && need_trans_) { // use lds.32 to load tf32 matrices
Value *ptr2 = get_ptr(ptr_idx+1);
assert(s_mat_stride_ == 1);
int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
@@ -1521,7 +1562,96 @@ public:
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
}
return {elem0, elem1, elem2, elem3};
}
} else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices
Value *ptr00 = get_ptr(ptr_idx);
Value *ptr01 = get_ptr(ptr_idx+1);
Value *ptr02 = get_ptr(ptr_idx+2);
Value *ptr03 = get_ptr(ptr_idx+3);
Value *ptr10 = get_ptr(ptr_idx+4);
Value *ptr11 = get_ptr(ptr_idx+5);
Value *ptr12 = get_ptr(ptr_idx+6);
Value *ptr13 = get_ptr(ptr_idx+7);
assert(s_mat_stride_ == 1);
int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
Value *i8v4_elems[4];
Value *i32_elems[4];
for (int i=0; i<4; ++i)
i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4));
Value *elem00, *elem01, *elem02, *elem03;
Value *elem10, *elem11, *elem12, *elem13;
Value *elem20, *elem21, *elem22, *elem23;
Value *elem30, *elem31, *elem32, *elem33;
Value *i8_elems[4*4];
if (k_order_ == 1) { //
i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[2*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
for (int m=0; m<4; ++m) {
for (int e=0; e<4; ++e)
i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
}
} else { // for b (k first)
i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
for (int m=0; m<4; ++m) {
for (int e=0; e<4; ++e)
i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
}
}
if (k == 0 && inc == 1 && is_prefetch) {
for (int m = 0; m < 4; ++m)
for (int e = 0; e < 4; ++e)
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]);
}
return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]};
} else
throw std::runtime_error("invalid smem load");
}
int get_num_ptr() const { return num_ptr_; }
@@ -1596,12 +1726,18 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);
// floating point types
Type *fp32_ty = f32_ty;
Type *fp16x2_ty = vec_ty(f16_ty, 2);
Type *bf16x2_ty = vec_ty(bf16_ty, 2);
Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
// integer types
Type *i8x4_ty = vec_ty(i8_ty, 4);
Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty});
Type *i32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty});
FunctionType *ldmatrix_ty = nullptr;
FunctionType *mma_ty = nullptr;
@@ -1630,6 +1766,16 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
smem_ptr_ty = ptr_ty(fp32_ty, 3);
ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp32_ty;
} else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) {
// FIXME: We should use i8 here (but nvptx will generate extra casts when using i8)
mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
smem_ptr_ty = ptr_ty(i8_ty, 3);
ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = i32_ty;
// mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
// smem_ptr_ty = ptr_ty(i8_ty, 3);
// ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
// phi_ty = i8x4_ty;
} else
throw std::runtime_error("mma16816 data type not supported");
@@ -1690,7 +1836,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
" {$4, $5, $6, $7},"
" {$8, $9},"
" {$10, $11, $12, $13};",
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
"=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);
// create mma & unpack result, m, n, k are offsets in mat
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
@@ -1715,12 +1861,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
auto register_lds2 =
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) {
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
if (k < 2 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
} else
vals[{n, k}] = val;
vals[{mn, k}] = val;
};
auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
@@ -1922,7 +2068,10 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
return visit_mma884(dot, A, B, D, NK);
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
A->get_type()->get_scalar_ty()->is_fp32_ty())
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
throw std::runtime_error("dot has invalid operand type");
}
void generator::visit_trans_inst(ir::trans_inst* trans) {

View File

@@ -61,7 +61,8 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
// dot(a, b, c) + d -> dot(a, b, c + d)
// d + dot(a, b, c) -> dot(a, b, c + d)
auto add = dynamic_cast<ir::binary_operator*>(value);
if(add && add->get_op() == ir::binary_op_t::FAdd) {
if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
ir::value *lhs = add->get_operand(0);
ir::value *rhs = add->get_operand(1);
ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
@@ -72,11 +73,17 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
ir::value *other = (dot == lhs) ? rhs : lhs;
ir::value *acc = dot->get_operand(2);
ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
ir::constant_fp *_0 = nullptr;
ir::constant *_0 = nullptr;
if(splat)
_0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0));
if(!(_0 && _0->get_value() == 0.0))
_0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
if(!_0)
return false;
if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
if (fp_0->get_value() != 0.0)
return false;
if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
if (int_0->get_value() != 0)
return false;
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
builder.set_insert_point(add);

View File

@@ -33,6 +33,9 @@ void prefetch::run(ir::module &mod) {
if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
&& dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
)
)

View File

@@ -46,6 +46,7 @@
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Scalar.h"
// begin AMD stuff
#include "llvm/Support/FileSystem.h"
@@ -121,9 +122,12 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
init_llvm();
// verify and store llvm
llvm::legacy::PassManager pm;
// module->print(llvm::outs(), nullptr);
pm.add(llvm::createVerifierPass());
// pm.add(llvm::createDeadCodeEliminationPass());
// pm.add(llvm::createEarlyCSEPass());
pm.run(*module);
// module->print(llvm::outs(), nullptr);
// create machine
module->setTargetTriple(triple);
std::string error;

View File

@@ -726,7 +726,11 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask
//===----------------------------------------------------------------------===//
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
ir::value *_0 = builder->get_float32(0);
ir::value *_0 = nullptr;
if (lhs->get_type()->is_int_or_tileint_ty())
_0 = builder->get_int32(0);
else
_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});

View File

@@ -77,7 +77,7 @@ class CMakeBuild(build_ext):
def build_extension(self, ext):
llvm_include_dir, llvm_library_dir = get_llvm()
# self.debug = True
self.debug = True
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
# create build directories
build_suffix = 'debug' if self.debug else 'release'

View File

@@ -661,11 +661,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# ---------------
@pytest.mark.parametrize("epilogue, allow_tf32",
[(epilogue, allow_tf32)
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
[(epilogue, allow_tf32, dtype)
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
for allow_tf32 in [True, False]])
def test_dot(epilogue, allow_tf32, device='cuda'):
for allow_tf32 in [True, False]
for dtype in ['float32', 'int8']
if not (allow_tf32 and (dtype == 'int8'))])
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80:
if dtype == 'int8':
pytest.skip("Only test int8 on devices with sm >= 80")
elif dtype == 'float32' and allow_tf32:
pytest.skip("Only test tf32 on devices with sm >= 80")
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xk,
@@ -693,18 +702,15 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
# input
M, N, K = 64, 64, 32
rs = RandomState(17)
x = numpy_random((M, K), dtype_str='float32', rs=rs)
y = numpy_random((K, N), dtype_str='float32', rs=rs)
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
if allow_tf32:
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80:
pytest.skip("Only test tf32 on devices with sm >= 80")
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
# triton result
z = numpy_random((M, N), dtype_str='float32', rs=rs)
z = numpy_random((M, N), dtype_str=dtype, rs=rs)
z_tri = to_triton(z, device=device)
if epilogue == 'trans':
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
@@ -732,8 +738,10 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
assert 'st.global.v4' in ptx
if allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
else:
elif dtype == 'float32':
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
elif dtype == 'int8':
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
def test_dot_without_load():