[BACKEND] Added Int8 mma (#440)
This commit is contained in:
@@ -268,6 +268,7 @@ public:
|
||||
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
|
||||
int get_mma_vec() { return mma_vec_;}
|
||||
int get_mma_strided() { return mma_strided_; }
|
||||
bool allow_swizzle() const { return allow_swizzle_; }
|
||||
data_layout* get_arg_layout() { return arg_layout_; }
|
||||
|
||||
private:
|
||||
@@ -281,6 +282,7 @@ private:
|
||||
data_layout* arg_layout_;
|
||||
int mma_vec_;
|
||||
int mma_strided_;
|
||||
bool allow_swizzle_ = true;
|
||||
target *tgt_;
|
||||
};
|
||||
|
||||
|
@@ -33,7 +33,9 @@ inline bool is_hmma_c(ir::value *v, int sm){
|
||||
result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
|
||||
(a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
|
||||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
|
||||
x->allow_tf32() && sm >= 80);
|
||||
x->allow_tf32() && sm >= 80) ||
|
||||
(a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
|
||||
sm >= 80);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@@ -63,7 +65,7 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
|
||||
return mma_type;
|
||||
}
|
||||
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
|
||||
throw std::runtime_error("integer tensor cores are not yet supported");
|
||||
// throw std::runtime_error("integer tensor cores are not yet supported");
|
||||
// // integer tensor cores
|
||||
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
|
||||
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
|
||||
@@ -73,10 +75,10 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
|
||||
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
|
||||
// return mma_type;
|
||||
// }
|
||||
// if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
|
||||
// mma_type = mma_layout::INT32_INT8_INT8_INT32;
|
||||
// return mma_type;
|
||||
// }
|
||||
if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
|
||||
mma_type = mma_layout::INT32_INT8_INT8_INT32;
|
||||
return mma_type;
|
||||
}
|
||||
}
|
||||
}
|
||||
return mma_layout::NOT_APPLICABLE;
|
||||
@@ -444,11 +446,21 @@ shared_layout::shared_layout(data_layout *arg,
|
||||
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
|
||||
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
|
||||
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
|
||||
|
||||
// for now, disable swizzle when using lds.8
|
||||
if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
|
||||
if (order_[0] == 0) // need transpose
|
||||
allow_swizzle_ = false;
|
||||
} else if (hmma_dot_b_) {
|
||||
assert(order_.size() == 2);
|
||||
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
|
||||
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
|
||||
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
|
||||
|
||||
// for now, disable swizzle when using lds.8
|
||||
if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
|
||||
if (order_[0] == 1) // need transpose
|
||||
allow_swizzle_ = false;
|
||||
}
|
||||
|
||||
// size
|
||||
|
@@ -41,9 +41,15 @@ void swizzle::run(ir::module &) {
|
||||
vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
|
||||
}
|
||||
else {
|
||||
if (!layout->allow_swizzle()) {
|
||||
per_phase_[layout] = 1;
|
||||
max_phase_[layout] = 1;
|
||||
vec_[layout] = 1;
|
||||
} else {
|
||||
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
|
||||
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
|
||||
vec_[layout] = layout->get_mma_vec();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,78 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
#include "triton/codegen/analysis/allocation.h"
|
||||
#include "triton/codegen/analysis/align.h"
|
||||
#include "triton/codegen/analysis/swizzle.h"
|
||||
#include "triton/codegen/transform/coalesce.h"
|
||||
#include "triton/ir/context.h"
|
||||
#include "triton/ir/module.h"
|
||||
#include "triton/ir/function.h"
|
||||
#include "triton/ir/type.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/IntrinsicsNVPTX.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/Attributes.h"
|
||||
#include "llvm/IR/InlineAsm.h"
|
||||
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
||||
|
||||
namespace triton::codegen {
|
||||
// types
|
||||
#define void_ty builder_->getVoidTy()
|
||||
#define f16_ty builder_->getHalfTy()
|
||||
#define bf16_ty builder_->getBFloatTy()
|
||||
#define f32_ty builder_->getFloatTy()
|
||||
#define i8_ty builder_->getInt8Ty()
|
||||
#define i32_ty builder_->getInt32Ty()
|
||||
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
|
||||
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
|
||||
// constants
|
||||
#define i32(...) builder_->getInt32(__VA_ARGS__)
|
||||
// ops
|
||||
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
|
||||
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
|
||||
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
|
||||
#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
|
||||
#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
|
||||
#define br(...) builder_->CreateBr(__VA_ARGS__)
|
||||
#define call(...) builder_->CreateCall(__VA_ARGS__)
|
||||
#define cast(...) builder_->CreateCast(__VA_ARGS__)
|
||||
#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__)
|
||||
#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__)
|
||||
#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__)
|
||||
#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
|
||||
#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
|
||||
#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
|
||||
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
|
||||
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
|
||||
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
|
||||
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
|
||||
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
|
||||
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
|
||||
#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
|
||||
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
|
||||
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
|
||||
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
|
||||
#define load(...) builder_->CreateLoad(__VA_ARGS__)
|
||||
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
|
||||
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
|
||||
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
|
||||
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
|
||||
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
|
||||
#define ret(...) builder_->CreateRet(__VA_ARGS__)
|
||||
#define select(...) builder_->CreateSelect(__VA_ARGS__)
|
||||
#define store(...) builder_->CreateStore(__VA_ARGS__)
|
||||
#define sub(...) builder_->CreateSub(__VA_ARGS__)
|
||||
#define shl(...) builder_->CreateShl(__VA_ARGS__)
|
||||
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
|
||||
#define urem(...) builder_->CreateURem(__VA_ARGS__)
|
||||
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
|
||||
#define xor_(...) builder_->CreateXor(__VA_ARGS__)
|
||||
|
||||
} // namespace triton::codegen
|
@@ -1,6 +1,7 @@
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
#include <stdexcept>
|
||||
#include "triton/codegen/selection/generator.h"
|
||||
#include "triton/codegen/target.h"
|
||||
#include "triton/codegen/analysis/axes.h"
|
||||
@@ -1355,9 +1356,6 @@ public:
|
||||
need_trans_ = k_order_ != order_[0];
|
||||
can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);
|
||||
|
||||
// std::cout << can_use_ldmatrix_ << std::endl;
|
||||
// std::cout << need_trans_ << std::endl;
|
||||
|
||||
// we need more pointers at the fast-changing axis,
|
||||
if (can_use_ldmatrix_)
|
||||
num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
|
||||
@@ -1365,6 +1363,9 @@ public:
|
||||
num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
|
||||
num_ptr_ = std::max<int>(num_ptr_, 2);
|
||||
|
||||
// special rule for i8/u8, 4 ptrs for each matrix
|
||||
if (!can_use_ldmatrix_ && dtsize_ == 1)
|
||||
num_ptr_ *= 4;
|
||||
|
||||
// load_v4 stride (in num of mats)
|
||||
int load_stride_in_mat[2];
|
||||
@@ -1445,6 +1446,46 @@ public:
|
||||
}
|
||||
return offs;
|
||||
// throw std::runtime_error("not implemented");
|
||||
} else if (dtsize_ == 1 && need_trans_) {
|
||||
// load i8/u8 matrices with lds8
|
||||
Value *c_off_in_mat = udiv(lane, i32(4)); //
|
||||
Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread load 4 cols
|
||||
|
||||
// Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
|
||||
std::vector<Value*> offs(num_ptr_);
|
||||
for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
|
||||
int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
|
||||
int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
|
||||
if (k_mat_arr_int > 0) // we don't need pointers for k
|
||||
continue;
|
||||
Value *k_mat_arr = i32(k_mat_arr_int);
|
||||
Value *nk_mat_arr = i32(nk_mat_arr_int);
|
||||
// physical offset (before swizzling)
|
||||
Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
|
||||
mul(nk_mat_arr, i32(mat_arr_stride_)));
|
||||
Value *s_mat_off = k_mat_arr; // always 0?
|
||||
|
||||
for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) {
|
||||
for (int elem_off = 0; elem_off < 4; ++elem_off) {
|
||||
int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off;
|
||||
|
||||
Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
|
||||
Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off));
|
||||
|
||||
// disable swizzling ...
|
||||
// Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
|
||||
// c_mat_off_i = xor_(c_mat_off_i, phase);
|
||||
|
||||
Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
|
||||
Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_)));
|
||||
// To prevent out-of-bound access when the tile is too small
|
||||
c_off = urem(c_off, i32(tile_shape_[order_[0]]));
|
||||
s_off = urem(s_off, i32(tile_shape_[order_[1]]));
|
||||
offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_)));
|
||||
}
|
||||
}
|
||||
}
|
||||
return offs;
|
||||
} else
|
||||
throw std::runtime_error("invalid smem load config");
|
||||
}
|
||||
@@ -1461,8 +1502,10 @@ public:
|
||||
int ptr_idx = -1;
|
||||
if (can_use_ldmatrix_)
|
||||
ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
|
||||
else // tf32 & trans
|
||||
else if (dtsize_ == 4 && need_trans_) // tf32 & trans
|
||||
ptr_idx = mat_idx[order_[0]];
|
||||
else // i8 & trans
|
||||
ptr_idx = mat_idx[order_[0]] * 4;
|
||||
|
||||
auto get_ptr = [&](int idx) -> Value* {
|
||||
Value *ptr = nullptr;
|
||||
@@ -1495,9 +1538,7 @@ public:
|
||||
extract_val(res_v4, std::vector<unsigned>{1}),
|
||||
extract_val(res_v4, std::vector<unsigned>{2}),
|
||||
extract_val(res_v4, std::vector<unsigned>{3})};
|
||||
} else {
|
||||
// assert(false && "should not be here");
|
||||
assert(dtsize_ == 4 && need_trans_);
|
||||
} else if (dtsize_ == 4 && need_trans_) { // use lds.32 to load tf32 matrices
|
||||
Value *ptr2 = get_ptr(ptr_idx+1);
|
||||
assert(s_mat_stride_ == 1);
|
||||
int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
|
||||
@@ -1521,7 +1562,96 @@ public:
|
||||
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
|
||||
}
|
||||
return {elem0, elem1, elem2, elem3};
|
||||
}
|
||||
} else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices
|
||||
Value *ptr00 = get_ptr(ptr_idx);
|
||||
Value *ptr01 = get_ptr(ptr_idx+1);
|
||||
Value *ptr02 = get_ptr(ptr_idx+2);
|
||||
Value *ptr03 = get_ptr(ptr_idx+3);
|
||||
|
||||
Value *ptr10 = get_ptr(ptr_idx+4);
|
||||
Value *ptr11 = get_ptr(ptr_idx+5);
|
||||
Value *ptr12 = get_ptr(ptr_idx+6);
|
||||
Value *ptr13 = get_ptr(ptr_idx+7);
|
||||
|
||||
assert(s_mat_stride_ == 1);
|
||||
int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
|
||||
int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
|
||||
|
||||
Value *i8v4_elems[4];
|
||||
Value *i32_elems[4];
|
||||
for (int i=0; i<4; ++i)
|
||||
i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4));
|
||||
|
||||
Value *elem00, *elem01, *elem02, *elem03;
|
||||
Value *elem10, *elem11, *elem12, *elem13;
|
||||
Value *elem20, *elem21, *elem22, *elem23;
|
||||
Value *elem30, *elem31, *elem32, *elem33;
|
||||
Value *i8_elems[4*4];
|
||||
if (k_order_ == 1) { //
|
||||
i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
|
||||
i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
|
||||
i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
|
||||
i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
|
||||
|
||||
assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
|
||||
|
||||
i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
|
||||
i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
|
||||
i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
|
||||
i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
|
||||
|
||||
i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[2*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
|
||||
i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
|
||||
for (int m=0; m<4; ++m) {
|
||||
for (int e=0; e<4; ++e)
|
||||
i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
|
||||
i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
|
||||
}
|
||||
} else { // for b (k first)
|
||||
i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
|
||||
i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
|
||||
i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
|
||||
i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
|
||||
|
||||
assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
|
||||
|
||||
i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
|
||||
i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
|
||||
i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
|
||||
i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
|
||||
|
||||
i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
|
||||
i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
|
||||
|
||||
for (int m=0; m<4; ++m) {
|
||||
for (int e=0; e<4; ++e)
|
||||
i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
|
||||
i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
|
||||
}
|
||||
}
|
||||
if (k == 0 && inc == 1 && is_prefetch) {
|
||||
for (int m = 0; m < 4; ++m)
|
||||
for (int e = 0; e < 4; ++e)
|
||||
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]);
|
||||
}
|
||||
return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]};
|
||||
} else
|
||||
throw std::runtime_error("invalid smem load");
|
||||
}
|
||||
|
||||
int get_num_ptr() const { return num_ptr_; }
|
||||
@@ -1596,12 +1726,18 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
|
||||
const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);
|
||||
|
||||
// floating point types
|
||||
Type *fp32_ty = f32_ty;
|
||||
Type *fp16x2_ty = vec_ty(f16_ty, 2);
|
||||
Type *bf16x2_ty = vec_ty(bf16_ty, 2);
|
||||
Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
|
||||
Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
|
||||
Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
||||
// integer types
|
||||
Type *i8x4_ty = vec_ty(i8_ty, 4);
|
||||
Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty});
|
||||
Type *i32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty});
|
||||
|
||||
|
||||
FunctionType *ldmatrix_ty = nullptr;
|
||||
FunctionType *mma_ty = nullptr;
|
||||
@@ -1630,6 +1766,16 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
smem_ptr_ty = ptr_ty(fp32_ty, 3);
|
||||
ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
|
||||
phi_ty = fp32_ty;
|
||||
} else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) {
|
||||
// FIXME: We should use i8 here (but nvptx will generate extra casts when using i8)
|
||||
mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
|
||||
smem_ptr_ty = ptr_ty(i8_ty, 3);
|
||||
ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
|
||||
phi_ty = i32_ty;
|
||||
// mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
|
||||
// smem_ptr_ty = ptr_ty(i8_ty, 3);
|
||||
// ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
|
||||
// phi_ty = i8x4_ty;
|
||||
} else
|
||||
throw std::runtime_error("mma16816 data type not supported");
|
||||
|
||||
@@ -1690,7 +1836,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
" {$4, $5, $6, $7},"
|
||||
" {$8, $9},"
|
||||
" {$10, $11, $12, $13};",
|
||||
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
|
||||
"=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);
|
||||
|
||||
// create mma & unpack result, m, n, k are offsets in mat
|
||||
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
|
||||
@@ -1715,12 +1861,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
|
||||
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
|
||||
|
||||
auto register_lds2 =
|
||||
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) {
|
||||
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
|
||||
if (k < 2 && is_prefetch) {
|
||||
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block));
|
||||
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
|
||||
} else
|
||||
vals[{n, k}] = val;
|
||||
vals[{mn, k}] = val;
|
||||
};
|
||||
|
||||
auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
|
||||
@@ -1922,7 +2068,10 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
||||
return visit_mma884(dot, A, B, D, NK);
|
||||
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
|
||||
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
|
||||
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
||||
if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
|
||||
A->get_type()->get_scalar_ty()->is_fp32_ty())
|
||||
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
||||
throw std::runtime_error("dot has invalid operand type");
|
||||
}
|
||||
|
||||
void generator::visit_trans_inst(ir::trans_inst* trans) {
|
||||
|
@@ -61,7 +61,8 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
// dot(a, b, c) + d -> dot(a, b, c + d)
|
||||
// d + dot(a, b, c) -> dot(a, b, c + d)
|
||||
auto add = dynamic_cast<ir::binary_operator*>(value);
|
||||
if(add && add->get_op() == ir::binary_op_t::FAdd) {
|
||||
if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
|
||||
bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
|
||||
ir::value *lhs = add->get_operand(0);
|
||||
ir::value *rhs = add->get_operand(1);
|
||||
ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
|
||||
@@ -72,11 +73,17 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
|
||||
ir::value *other = (dot == lhs) ? rhs : lhs;
|
||||
ir::value *acc = dot->get_operand(2);
|
||||
ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
|
||||
ir::constant_fp *_0 = nullptr;
|
||||
ir::constant *_0 = nullptr;
|
||||
if(splat)
|
||||
_0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0));
|
||||
if(!(_0 && _0->get_value() == 0.0))
|
||||
_0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
|
||||
if(!_0)
|
||||
return false;
|
||||
if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
|
||||
if (fp_0->get_value() != 0.0)
|
||||
return false;
|
||||
if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
|
||||
if (int_0->get_value() != 0)
|
||||
return false;
|
||||
ir::value *a = dot->get_operand(0);
|
||||
ir::value *b = dot->get_operand(1);
|
||||
builder.set_insert_point(add);
|
||||
|
@@ -33,6 +33,9 @@ void prefetch::run(ir::module &mod) {
|
||||
if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
|
||||
dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
|
||||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
|
||||
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
|
||||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
|
||||
&& dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
|
||||
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
|
||||
)
|
||||
)
|
||||
|
@@ -46,6 +46,7 @@
|
||||
#include "llvm/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include "llvm/Transforms/Scalar.h"
|
||||
|
||||
// begin AMD stuff
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
@@ -121,9 +122,12 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
|
||||
init_llvm();
|
||||
// verify and store llvm
|
||||
llvm::legacy::PassManager pm;
|
||||
// module->print(llvm::outs(), nullptr);
|
||||
pm.add(llvm::createVerifierPass());
|
||||
// pm.add(llvm::createDeadCodeEliminationPass());
|
||||
// pm.add(llvm::createEarlyCSEPass());
|
||||
pm.run(*module);
|
||||
// module->print(llvm::outs(), nullptr);
|
||||
|
||||
// create machine
|
||||
module->setTargetTriple(triple);
|
||||
std::string error;
|
||||
|
@@ -726,7 +726,11 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
|
||||
ir::value *_0 = builder->get_float32(0);
|
||||
ir::value *_0 = nullptr;
|
||||
if (lhs->get_type()->is_int_or_tileint_ty())
|
||||
_0 = builder->get_int32(0);
|
||||
else
|
||||
_0 = builder->get_float32(0);
|
||||
unsigned M = lhs->get_type()->get_block_shapes()[0];
|
||||
unsigned N = rhs->get_type()->get_block_shapes()[1];
|
||||
_0 = builder->create_splat(_0, {M, N});
|
||||
|
@@ -77,7 +77,7 @@ class CMakeBuild(build_ext):
|
||||
|
||||
def build_extension(self, ext):
|
||||
llvm_include_dir, llvm_library_dir = get_llvm()
|
||||
# self.debug = True
|
||||
self.debug = True
|
||||
extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
|
||||
# create build directories
|
||||
build_suffix = 'debug' if self.debug else 'release'
|
||||
|
@@ -661,11 +661,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
|
||||
# ---------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("epilogue, allow_tf32",
|
||||
[(epilogue, allow_tf32)
|
||||
@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
|
||||
[(epilogue, allow_tf32, dtype)
|
||||
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
|
||||
for allow_tf32 in [True, False]])
|
||||
def test_dot(epilogue, allow_tf32, device='cuda'):
|
||||
for allow_tf32 in [True, False]
|
||||
for dtype in ['float32', 'int8']
|
||||
if not (allow_tf32 and (dtype == 'int8'))])
|
||||
def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80:
|
||||
if dtype == 'int8':
|
||||
pytest.skip("Only test int8 on devices with sm >= 80")
|
||||
elif dtype == 'float32' and allow_tf32:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
|
||||
# triton kernel
|
||||
@triton.jit
|
||||
def kernel(X, stride_xm, stride_xk,
|
||||
@@ -693,18 +702,15 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
|
||||
# input
|
||||
M, N, K = 64, 64, 32
|
||||
rs = RandomState(17)
|
||||
x = numpy_random((M, K), dtype_str='float32', rs=rs)
|
||||
y = numpy_random((K, N), dtype_str='float32', rs=rs)
|
||||
x = numpy_random((M, K), dtype_str=dtype, rs=rs)
|
||||
y = numpy_random((K, N), dtype_str=dtype, rs=rs)
|
||||
if allow_tf32:
|
||||
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
|
||||
if cc < 80:
|
||||
pytest.skip("Only test tf32 on devices with sm >= 80")
|
||||
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
|
||||
x_tri = to_triton(x, device=device)
|
||||
y_tri = to_triton(y, device=device)
|
||||
# triton result
|
||||
z = numpy_random((M, N), dtype_str='float32', rs=rs)
|
||||
z = numpy_random((M, N), dtype_str=dtype, rs=rs)
|
||||
z_tri = to_triton(z, device=device)
|
||||
if epilogue == 'trans':
|
||||
z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
|
||||
@@ -732,8 +738,10 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
|
||||
assert 'st.global.v4' in ptx
|
||||
if allow_tf32:
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
|
||||
else:
|
||||
elif dtype == 'float32':
|
||||
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
|
||||
elif dtype == 'int8':
|
||||
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx
|
||||
|
||||
|
||||
def test_dot_without_load():
|
||||
|
Reference in New Issue
Block a user