[BACKEND] Add bf16 & tf32 mma supports (on A100) (#426)

This commit is contained in:
daadaada
2022-01-12 02:20:31 +08:00
committed by GitHub
parent efdabe6073
commit 94a2e10fe5
17 changed files with 717 additions and 263 deletions

View File

@@ -109,6 +109,63 @@ protected:
};
class mma_layout: public distributed_layout {
public:
enum TensorCoreType : uint8_t {
// floating-point tensor core instr
FP32_FP16_FP16_FP32 = 0, // default
FP32_BF16_BF16_FP32,
FP32_TF32_TF32_FP32,
// integer tensor core instr
INT32_INT1_INT1_INT32, // Not implemented
INT32_INT4_INT4_INT32, // Not implemented
INT32_INT8_INT8_INT32, // Not implemented
//
NOT_APPLICABLE,
};
// Used on nvidia GPUs with sm >= 80
inline static const std::map<TensorCoreType, std::vector<int>> mma_instr_shape_ = {
{FP32_FP16_FP16_FP32, {16, 8, 16}},
{FP32_BF16_BF16_FP32, {16, 8, 16}},
{FP32_TF32_TF32_FP32, {16, 8, 8}},
{INT32_INT1_INT1_INT32, {16, 8, 256}},
{INT32_INT4_INT4_INT32, {16, 8, 64}},
{INT32_INT8_INT8_INT32, {16, 8, 32}},
};
// shape of matrices loaded by ldmatrix (m-n-k, for mxk & kxn matrices)
inline static const std::map<TensorCoreType, std::vector<int>> mma_mat_shape_ = {
{FP32_FP16_FP16_FP32, {8, 8, 8}},
{FP32_BF16_BF16_FP32, {8, 8, 8}},
{FP32_TF32_TF32_FP32, {8, 8, 4}},
{INT32_INT1_INT1_INT32, {8, 8, 64}},
{INT32_INT4_INT4_INT32, {8, 8, 32}},
{INT32_INT8_INT8_INT32, {8, 8, 16}},
};
inline static const std::map<TensorCoreType, std::string> mma_instr_ptx_ = {
{FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"},
{FP32_BF16_BF16_FP32, "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"},
{FP32_TF32_TF32_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32"},
{INT32_INT1_INT1_INT32, "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc"},
{INT32_INT4_INT4_INT32, "mma.sync.aligned.m16n8k64.row.col.satfinite.s32.s4.s4.s32"},
{INT32_INT8_INT8_INT32, "mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32"},
};
// vector length per ldmatrix (16*8/elelment_size_in_bits)
inline static const std::map<TensorCoreType, int> mma_instr_vec_ = {
{FP32_FP16_FP16_FP32, 8},
{FP32_BF16_BF16_FP32, 8},
{FP32_TF32_TF32_FP32, 4},
{INT32_INT1_INT1_INT32, 128},
{INT32_INT4_INT4_INT32, 32},
{INT32_INT8_INT8_INT32, 16},
};
public:
mma_layout(size_t num_warps,
const std::vector<int>& axes,
@@ -116,7 +173,8 @@ public:
const std::vector<ir::value *> &values,
analysis::align* align, target *tgt,
shared_layout* layout_a,
shared_layout* layout_b);
shared_layout* layout_b,
ir::value *dot);
void accept(layout_visitor* vst) { vst->visit_layout_mma(this); }
// accessor
int fpw(size_t k) { return fpw_.at(k); }
@@ -124,6 +182,16 @@ public:
int spw(size_t k) { return spw_.at(k); }
int rep(size_t k) { return rep_.at(k); }
// helpers for generator.cc
std::string get_ptx_instr() const { return mma_instr_ptx_.at(tensor_core_type_); }
std::vector<int> get_mma_instr_shape() const { return mma_instr_shape_.at(tensor_core_type_); }
std::vector<int> get_mma_mat_shape() const { return mma_mat_shape_.at(tensor_core_type_); }
int get_vec_a() const { return mma_instr_vec_.at(tensor_core_type_); }
int get_vec_b() const { return mma_instr_vec_.at(tensor_core_type_); }
// setter
void set_tensor_core_type(TensorCoreType type) { tensor_core_type_ = type; }
private:
// fragment per warp
std::vector<int> fpw_;
@@ -135,6 +203,8 @@ private:
std::vector<int> spt_;
// repetitions
std::vector<int> rep_;
TensorCoreType tensor_core_type_ = FP32_FP16_FP16_FP32;
};
struct scanline_layout: public distributed_layout {
@@ -182,7 +252,7 @@ public:
const std::vector<unsigned>& shapes,
const std::vector<ir::value *> &values_,
ir::type *ty,
analysis::align* align);
analysis::align* align, target *tgt);
void accept(layout_visitor* vst) { vst->visit_layout_shared(this); }
// accessors
size_t get_size() { return size_; }
@@ -197,6 +267,7 @@ public:
ir::value* hmma_dot_b() { return hmma_dot_b_; }
void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
int get_mma_vec() { return mma_vec_;}
int get_mma_strided() { return mma_strided_; }
data_layout* get_arg_layout() { return arg_layout_; }
private:
@@ -209,6 +280,8 @@ private:
ir::value* hmma_dot_b_;
data_layout* arg_layout_;
int mma_vec_;
int mma_strided_;
target *tgt_;
};

View File

@@ -154,7 +154,7 @@ public:
value *create_cos(value* arg);
value *create_sin(value* arg);
value *create_log(value* arg);
value *create_dot(value *A, value *B, value *C);
value *create_dot(value *A, value *B, value *C, bool allow_tf32);
value *create_trans(value *A, const std::vector<int> &perm = {});
value *create_sqrt(value *A);
value *create_reduce(value *A, reduce_inst::op_t op, unsigned axis);

View File

@@ -80,7 +80,7 @@ struct dispatch{
static ir::value *atomic_xchg(ir::value* ptr, ir::value *val, ir::value *msk, ir::builder *builder);
// linear algebra
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::builder *builder);
static ir::value *dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder);
// indexing
static ir::value *where(ir::value* condition, ir::value *x, ir::value *y, ir::builder *builder);

View File

@@ -742,26 +742,29 @@ public:
};
private:
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next);
dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
std::string repr_impl() const { return "dot"; }
bool is_prefetched_ = false;
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
public:
bool is_prefetched() const { return is_prefetched_; }
void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
bool allow_tf32() const { return allow_tf32_; }
public:
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
_TRITON_DEFINE_CLONE(dot_inst)
_TRITON_DEFINE_ACCEPT(dot_inst)
private:
bool is_prefetched_ = false;
bool allow_tf32_ = false;
DataType C_type_ = DataType::FP32;
DataType A_type_ = DataType::FP16;
DataType B_type_ = DataType::FP16;
};
//class outer_inst: public builtin_inst {

View File

@@ -23,19 +23,65 @@ inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
return std::min(std::max(x, lo), hi);
}
inline bool is_hmma_c(ir::value *v){
inline bool is_hmma_c(ir::value *v, int sm){
bool result = false;
if(auto *x = dynamic_cast<ir::dot_inst*>(v)){
ir::value *a = x->get_operand(0);
ir::type *a_ty = a->get_type();
ir::value *b = x->get_operand(1);
ir::type *b_ty = b->get_type();
result = a_ty->get_scalar_ty()->is_fp16_ty() &&
b_ty->get_scalar_ty()->is_fp16_ty();
result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
(a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
(a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
x->allow_tf32() && sm >= 80);
}
return result;
}
static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
mma_layout::TensorCoreType mma_type;
if (auto* dot = dynamic_cast<ir::dot_inst*>(v)) {
ir::value* a = dot->get_operand(0);
ir::value* b = dot->get_operand(1);
ir::type* a_ty = a->get_type();
ir::type* b_ty = b->get_type();
ir::type* c_ty = v->get_type();
if (c_ty->get_scalar_ty()->is_fp32_ty()) {
// floating point tensor cores
if (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) {
mma_type = mma_layout::FP32_FP16_FP16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) {
mma_type = mma_layout::FP32_BF16_BF16_FP32;
return mma_type;
}
if (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty()
&& dot->allow_tf32()) {
mma_type = mma_layout::FP32_TF32_TF32_FP32;
return mma_type;
}
} else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
throw std::runtime_error("integer tensor cores are not yet supported");
// // integer tensor cores
// if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
// mma_type = mma_layout::INT32_INT1_INT1_INT32;
// return mma_type;
// }
// if (a_ty->get_scalar_ty()->is_integer_ty(4) && b_ty->get_scalar_ty()->is_integer_ty(4)) {
// mma_type = mma_layout::INT32_INT4_INT4_INT32;
// return mma_type;
// }
// if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
// mma_type = mma_layout::INT32_INT8_INT8_INT32;
// return mma_type;
// }
}
}
return mma_layout::NOT_APPLICABLE;
}
inline void extract_io_use(ir::value *v, std::set<ir::value*>& result) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::io_inst*>(u);
@@ -52,13 +98,14 @@ inline void extract_dot_use(ir::value *v, ir::value*& result, size_t n) {
}
}
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n) {
inline void extract_hmma_dot_use(ir::value *v, ir::value*& result, size_t n, int sm) {
for(ir::user* u: v->get_users()){
auto i = dynamic_cast<ir::dot_inst*>(u);
if(i && is_hmma_c(i) && i->get_operand(n) == v)
if(i && is_hmma_c(i, sm) && i->get_operand(n) == v) {
result = i;
}
}
}
inline bool is_trans(ir::value *v) {
@@ -142,7 +189,9 @@ mma_layout::mma_layout(size_t num_warps,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
analysis::align* align, target* tgt,
shared_layout *layout_a, shared_layout *layout_b): distributed_layout(MMA, axes, shape, values, align) {
shared_layout *layout_a, shared_layout *layout_b,
ir::value *dot): distributed_layout(MMA, axes, shape, values, align) {
tensor_core_type_ = get_mma_type(dot);
/* fragments per warp */
// try to make things as square as possible to maximize data re-use
if(tgt->as_nvidia()->sm() < 80){
@@ -159,9 +208,9 @@ mma_layout::mma_layout(size_t num_warps,
spw_ = {fpw_[0]*4*rep_[0], fpw_[1]*4*rep_[1], 1};
}
else{
fpw_ = {1, 1, 1};
spw_ = {16, 8, 1};
rep_ = {2, 2, 1};
// fpw_ = {1, 1, 1};
spw_ = mma_instr_shape_.at(tensor_core_type_); // e.g., {16, 8, 16} for f32.f16.f16.f32
// rep_ = {2, 2, 1};
}
order_ = {0, 1};
@@ -356,7 +405,8 @@ shared_layout::shared_layout(data_layout *arg,
const std::vector<unsigned>& shape,
const std::vector<ir::value *> &values,
ir::type *ty,
analysis::align* align): data_layout(SHARED, axes, shape, values, align), ty_(ty) {
analysis::align* align, target *tgt)
: data_layout(SHARED, axes, shape, values, align), ty_(ty), tgt_(tgt) {
size_ = 0;
arg_layout_ = arg;
@@ -382,12 +432,25 @@ shared_layout::shared_layout(data_layout *arg,
for(ir::value* v: values){
extract_dot_use(v, dot_a, 0);
extract_dot_use(v, dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, 0);
extract_hmma_dot_use(v, hmma_dot_b, 1);
extract_hmma_dot_use(v, hmma_dot_a, /*op*/0, tgt_->as_nvidia()->sm());
extract_hmma_dot_use(v, hmma_dot_b, /*op*/1, tgt_->as_nvidia()->sm());
}
hmma_dot_a_ = hmma_dot_a;
hmma_dot_b_ = hmma_dot_b;
// Update mma_vec
if (hmma_dot_a_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
} else if (hmma_dot_b_) {
assert(order_.size() == 2);
std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
}
// size
size_ = ty_->get_primitive_size_in_bits() / 8;
for(auto s: shape_)
@@ -451,7 +514,8 @@ void layouts::make_graph(ir::instruction *i) {
void layouts::create(size_t id, const std::vector<ir::value*>& values) {
// if(layouts_.find(id) != layouts_.end())
// return;
auto it_hmma_c = std::find_if(values.begin(), values.end(), &is_hmma_c);
auto it_hmma_c = std::find_if(values.begin(), values.end(),
[&](ir::value* v){ return is_hmma_c(v, tgt_->as_nvidia()->sm()); });
auto cmp = [](ir::value* x, ir::value *y) {
std::pair<int, int> xx = {x->get_type()->get_tile_rank(), x->get_type()->get_tile_num_elements()};
std::pair<int, int> yy = {y->get_type()->get_tile_rank(), y->get_type()->get_tile_num_elements()};
@@ -473,13 +537,16 @@ void layouts::create(size_t id, const std::vector<ir::value*>& values) {
ir::value *b = dot->get_operand(1);
create(groups_.at(a), values_.at(groups_.at(a)));
create(groups_.at(b), values_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_, (shared_layout*)layouts_.at(groups_.at(a)), (shared_layout*)layouts_.at(groups_.at(b)));
layouts_[id] = new mma_layout(num_warps_, axes, shapes, values, align_, tgt_,
(shared_layout*)layouts_.at(groups_.at(a)),
(shared_layout*)layouts_.at(groups_.at(b)),
dot);
}
else if(it_cts != values.end()){
ir::instruction *cts = (ir::instruction*)*it_cts;
ir::value *arg = cts->get_operand(0);
create(groups_.at(arg), values_.at(groups_.at(arg)));
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(get(arg), axes, shapes, values, largest->get_type()->get_scalar_ty(), align_, tgt_);
}
else{
layouts_[id] = new scanline_layout(num_warps_, axes, shapes, values, align_, tgt_);
@@ -516,7 +583,7 @@ void layouts::run(ir::module &mod) {
scanline_layout *layout = get(arg)->to_scanline();
shapes[axis] = layout->mts(axis);
// create layout
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(layout, axes_->get(arg), shapes, {red}, red->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[red] = id;
}
if(auto *val = dynamic_cast<ir::cvt_layout_inst*>(i)){
@@ -529,12 +596,12 @@ void layouts::run(ir::module &mod) {
shape[k] = std::max(in_layout->shape_per_cta(k),
out_layout->shape_per_cta(k));
}
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(out_layout, axes_->get(val), shape, {val}, val->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[val] = id;
}
if(auto *atom = dynamic_cast<ir::atomic_inst*>(i)){
id++;
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_);
layouts_[id] = new shared_layout(nullptr, {}, {1}, {atom}, atom->get_type()->get_scalar_ty(), align_, tgt_);
tmp_[atom] = id;
}
});

View File

@@ -19,6 +19,7 @@ void swizzle::run(ir::module &) {
continue;
ir::value* mma_dot_a = layout->hmma_dot_a();
ir::value* mma_dot_b = layout->hmma_dot_b();
if(!mma_dot_a && !mma_dot_b){
per_phase_[layout] = 1;
max_phase_[layout] = 1;
@@ -41,8 +42,8 @@ void swizzle::run(ir::module &) {
}
else {
per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
max_phase_[layout] = 8 / per_phase_[layout];
vec_[layout] = 8;
max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
vec_[layout] = layout->get_mma_vec();
}
}
}

View File

@@ -85,7 +85,6 @@ std::unique_ptr<llvm::Module> add_passes_to_emit_bin(ir::module &ir, llvm::LLVMC
allocation.run(ir);
prefetch_s.run(ir);
barriers.run(ir);
// ir.print(std::cout);
isel.visit(ir, *llvm);
shared_static = allocation.allocated_size();
return llvm;

View File

@@ -0,0 +1,78 @@
#pragma once
#include <numeric>
#include <sstream>
#include <iomanip>
#include "triton/codegen/selection/generator.h"
#include "triton/codegen/target.h"
#include "triton/codegen/analysis/axes.h"
#include "triton/codegen/analysis/allocation.h"
#include "triton/codegen/analysis/align.h"
#include "triton/codegen/analysis/swizzle.h"
#include "triton/codegen/transform/coalesce.h"
#include "triton/ir/context.h"
#include "triton/ir/module.h"
#include "triton/ir/function.h"
#include "triton/ir/type.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
namespace triton::codegen {
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getBFloatTy()
#define f32_ty builder_->getFloatTy()
#define i8_ty builder_->getInt8Ty()
#define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
// constants
#define i32(...) builder_->getInt32(__VA_ARGS__)
// ops
#define and_(...) builder_->CreateAnd(__VA_ARGS__)
#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
#define br(...) builder_->CreateBr(__VA_ARGS__)
#define call(...) builder_->CreateCall(__VA_ARGS__)
#define cast(...) builder_->CreateCast(__VA_ARGS__)
#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__)
#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__)
#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__)
#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
#define load(...) builder_->CreateLoad(__VA_ARGS__)
#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
#define neg(...) builder_->CreateNeg(__VA_ARGS__)
#define phi(...) builder_->CreatePHI(__VA_ARGS__)
#define ret(...) builder_->CreateRet(__VA_ARGS__)
#define select(...) builder_->CreateSelect(__VA_ARGS__)
#define store(...) builder_->CreateStore(__VA_ARGS__)
#define sub(...) builder_->CreateSub(__VA_ARGS__)
#define shl(...) builder_->CreateShl(__VA_ARGS__)
#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
#define urem(...) builder_->CreateURem(__VA_ARGS__)
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...) builder_->CreateXor(__VA_ARGS__)
} // namespace triton::codegen

View File

@@ -81,12 +81,13 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
// return (*builder_)->CreateGEP(ty, ptr, vals, name);
//}
// types
#define void_ty builder_->getVoidTy()
#define f16_ty builder_->getHalfTy()
#define bf16_ty builder_->getBFloatTy()
#define f32_ty builder_->getFloatTy()
#define i8_ty builder_->getInt8Ty()
#define i16_ty builder_->getInt16Ty()
#define i32_ty builder_->getInt32Ty()
#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
#define ptr_ty(...) PointerType::get(__VA_ARGS__)
@@ -133,7 +134,6 @@ Value* geper::operator()(Value *ptr, Value* off, const std::string& name){
#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
#define xor_(...) builder_->CreateXor(__VA_ARGS__)
/**
* \brief Convert Triton-IR Type to LLVM-IR Type
*/
@@ -162,7 +162,7 @@ Type *generator::cvt(ir::type *ty) {
case ir::type::VoidTyID: return Type::getVoidTy(*ctx_);
case ir::type::FP8TyID: return Type::getInt8Ty(*ctx_);
case ir::type::FP16TyID: return Type::getHalfTy(*ctx_);
case ir::type::BF16TyID: return Type::getInt16Ty(*ctx_);
case ir::type::BF16TyID: return Type::getBFloatTy(*ctx_);
case ir::type::FP32TyID: return Type::getFloatTy(*ctx_);
case ir::type::FP64TyID: return Type::getDoubleTy(*ctx_);
case ir::type::LabelTyID: return Type::getLabelTy(*ctx_);
@@ -457,19 +457,25 @@ std::tuple<Value*, Value*, Value*, Value*> generator::fp8x4_to_fp16x4(Value *in0
}
Value* generator::bf16_to_fp32(Value *in0){
Value *ret = UndefValue::get(vec_ty(builder_->getInt16Ty(), 2));
ret = insert_elt(ret, in0, (uint64_t)1);
ret = insert_elt(ret, builder_->getInt16(0), (uint64_t)0);
return bit_cast(ret, builder_->getFloatTy());
if (tgt_->as_nvidia()->sm() >= 80) {
InlineAsm *ptx = InlineAsm::get(FunctionType::get(f32_ty, {bf16_ty}, false),
"cvt.rn.f32.bf16 $0, $1;", "=r,h", false);
return call(ptx, {in0});
} else {
Value *ret = UndefValue::get(vec_ty(i16_ty, 2));
ret = insert_elt(ret, bit_cast(in0, i16_ty), (uint64_t)1);
ret = insert_elt(ret, bit_cast(builder_->getInt16(0), i16_ty), (uint64_t)0);
return bit_cast(ret, f32_ty);
}
}
Value* generator::fp32_to_bf16(Value *in0){
if(tgt_->as_nvidia()->sm() >= 80){
InlineAsm *ptx = InlineAsm::get(FunctionType::get(builder_->getInt16Ty(), {builder_->getFloatTy()}, false),
InlineAsm *ptx = InlineAsm::get(FunctionType::get(bf16_ty, {f32_ty}, false),
"cvt.rn.bf16.f32 $0, $1;", "=h,r", false);
return call(ptx, {in0});
}
return extract_elt(bit_cast(in0, vec_ty(builder_->getInt16Ty(), 2)), (uint64_t)1);
return extract_elt(bit_cast(in0, vec_ty(i16_ty, 2)), (uint64_t)1);
}
/**
@@ -514,8 +520,12 @@ void generator::visit_cast_inst(ir::cast_inst* x) {
if(ret_sca_ty->is_bf16_ty() || op_sca_ty->is_bf16_ty()){
// FP32 -> BF16
if(op_sca_ty->is_fp32_ty())
for(size_t i = 0; i < x_idxs.size(); i++)
vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
// for(size_t i = 0; i < x_idxs.size(); i++)
// vals_[x][x_idxs[i + 0]] = fp32_to_bf16(vals_[op][op_idxs[i + 0]]);
for (indices_t idx: idxs_.at(x)) {
Value *arg = vals_[x->get_operand(0)][idx];
vals_[x][idx] = fp32_to_bf16(arg); // cast(cvt(x->get_op()), arg, ty);
}
// BF16 -> FP32
if(ret_sca_ty->is_fp32_ty())
for(size_t i = 0; i < x_idxs.size(); i++)
@@ -678,6 +688,7 @@ void generator::visit_load_inst(ir::load_inst* x){
// ---
std::vector<Type*> ret_tys(n_words, IntegerType::get(*ctx_, width));
Type* ret_ty = ret_tys.size() > 1 ? StructType::get(*ctx_, ret_tys) : ret_tys[0];
// ret_ty->print(llvm::outs());
std::vector<Type*> arg_tys = {pred->getType(), ptr->getType()};
for(Value *v: others)
arg_tys.push_back(v->getType());
@@ -747,15 +758,19 @@ void generator::visit_store_inst(ir::store_inst * x){
}
auto idxs = idxs_.at(val_op);
Type *ty = cvt(val_op->get_type()->get_scalar_ty());
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;
for(size_t i = 0; i < idxs.size(); i += vec){
auto idx = idxs[i];
// pointer
Value *ptr = vals_[ptr_op][idx];
ptr = bit_cast(ptr, vec_ty(ty, vec)->getPointerTo(1));
// vectorize
Type *v_ty = vec_ty(ty, vec);
ptr = bit_cast(ptr, v_ty->getPointerTo(1));
// value
Value* val = UndefValue::get(vec_ty(ty, vec));
Value* val = UndefValue::get(v_ty);
for(size_t ii = 0; ii < vec; ii++)
val = insert_elt(val, vals_.at(val_op)[idxs[i + ii]], ii);
val = insert_elt(val, bit_cast(vals_.at(val_op)[idxs[i + ii]], ty), ii);
if(mx){
Value *msk = vals_[mx->get_mask_operand()][idx];
Instruction *no_op = intrinsic(Intrinsic::donothing, {}, {});
@@ -1317,6 +1332,229 @@ void generator::visit_mma884(ir::dot_inst* C, ir::value *A, ir::value *B, ir::va
vals_[C][idxs_[C][i]] = acc[i];
}
namespace {
class mma16816_smem_loader {
public:
mma16816_smem_loader(int wpt, std::vector<int> order, int k_order,
std::vector<unsigned> tile_shape,
std::vector<int> instr_shape, std::vector<int> mat_shape,
int per_phase, int max_phase, int dtsize, Builder *builder,
adder add, multiplier mul, geper gep)
: wpt_(wpt), order_(order), k_order_(k_order), tile_shape_(tile_shape),
instr_shape_(instr_shape), mat_shape_(mat_shape),
per_phase_(per_phase), max_phase_(max_phase), dtsize_(dtsize), builder_(builder),
add(add), mul(mul), gep(gep) {
// compute compile-time constant variables & types
c_mat_shape_ = mat_shape[order[0]];
s_mat_shape_ = mat_shape[order[1]];
c_stride_ = tile_shape[order[1]];
s_stride_ = tile_shape[order[0]];
// rule: k must be the fast-changing axis
need_trans_ = k_order_ != order_[0];
can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);
// std::cout << can_use_ldmatrix_ << std::endl;
// std::cout << need_trans_ << std::endl;
// we need more pointers at the fast-changing axis,
if (can_use_ldmatrix_)
num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
else // warning: this only works for tf32 & need transpose
num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
num_ptr_ = std::max<int>(num_ptr_, 2);
// load_v4 stride (in num of mats)
int load_stride_in_mat[2];
load_stride_in_mat[k_order] = 2; // instr_shape[k_order] / mat_shape[k_order], always 2
load_stride_in_mat[k_order^1] = wpt * (instr_shape[k_order^1] / mat_shape[k_order^1]);
p_load_stride_in_mat_ = load_stride_in_mat[order[0]];
// stride in mat, used by load_v4
s_mat_stride_ = load_stride_in_mat[order[1]] / (instr_shape[order[1]]/mat_shape[order[1]]);
}
std::vector<Value*> compute_offs(Value *warp_off, Value *lane) {
// TODO: this needs to be moved to constructor (and extracted to arr_order)
mat_arr_stride_ = (k_order_ == 1) ? 1 : wpt_;
warp_off_stride_ = instr_shape_[k_order_^1] / mat_shape_[k_order_^1];
// start matrix logic offset (rename it as base_mat_off?)
Value *mat_off[2] = {nullptr, nullptr};
if (can_use_ldmatrix_) {
// c: lane idx inside a group (a group is a collection of 8 contiguous threads)
// s: group idx (0,1,2,3) inside a warp
Value *c = urem(lane, i32(8));
Value *s = udiv(lane, i32(8));
// We can decompose s => s_0, s_1...
Value *s0 = urem(s, i32(2));
Value *s1 = udiv(s, i32(2));
// We use different orders for a & b for better performance.
Value *k_mat_arr = (k_order_ == 1) ? s1 : s0;
Value *nk_mat_arr = (k_order_ == 1) ? s0 : s1;
mat_off[k_order_^1] = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
mat_off[k_order_] = k_mat_arr;
// physical offset (before swizzling)
Value *c_mat_off = mat_off[order_[0]];
Value *s_mat_off = mat_off[order_[1]];
// offset inside a matrix
Value *s_off_in_mat = c;
std::vector<Value*> offs(num_ptr_);
Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
// pre-compute strided offset
Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
for (int i=0; i < num_ptr_; ++i) {
Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_));
c_mat_off_i = xor_(c_mat_off_i, phase); // smem swizzle
offs[i] = add(mul(c_mat_off_i, i32(c_mat_shape_)), mul(s_off, i32(s_stride_)));
}
return offs;
} else if (dtsize_ == 4 && need_trans_) {
// load tf32 matrices with lds32
Value *c_off_in_mat = udiv(lane, i32(4)); // 4 = mat_shape[order[1]]
Value *s_off_in_mat = urem(lane, i32(4)); //
Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
std::vector<Value*> offs(num_ptr_);
for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
if (k_mat_arr_int > 0) // we don't need pointers for k
continue;
Value *k_mat_arr = i32(k_mat_arr_int);
Value *nk_mat_arr = i32(nk_mat_arr_int);
// physical offset (before swizzling)
Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
mul(nk_mat_arr, i32(mat_arr_stride_)));
Value *s_mat_off = k_mat_arr; // always 0?
Value *s_off = add(s_off_in_mat, mul(s_mat_off, i32(s_mat_shape_)));
// FIXME: (k_order_ == 1?) is really dirty hack
for (int i = 0; i < num_ptr_/2; ++i) {
Value *c_mat_off_i = add(c_mat_off, i32(i*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
c_mat_off_i = xor_(c_mat_off_i, phase);
Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
// TODO: move this out of the loop
c_off = urem(c_off, i32(tile_shape_[order_[0]]));
s_off = urem(s_off, i32(tile_shape_[order_[1]]));
offs[2*i + nk_mat_arr_int] = add(c_off, mul(s_off, i32(s_stride_)));
}
}
return offs;
// throw std::runtime_error("not implemented");
} else
throw std::runtime_error("invalid smem load config");
}
std::tuple<Value*, Value*, Value*, Value*>
load_x4(int mat0, int mat1, int inc, bool is_prefetch, ir::phi_node *pn,
Value *pre_ptr, Value *next_ptr, std::vector<Value*> &off, std::vector<Value*> &ptrs,
FunctionType *ldmatrix_ty, Type *smem_ptr_ty,
std::map<ir::value*, std::vector<Value*>> &prefetch_latch_to_bb_) {
assert(mat0 % 2 == 0 && mat1 % 2 == 0 && "smem matrix load must be aligned");
int mat_idx[2] = {mat0, mat1};
int k = mat_idx[k_order_];
int ptr_idx = -1;
if (can_use_ldmatrix_)
ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
else // tf32 & trans
ptr_idx = mat_idx[order_[0]];
auto get_ptr = [&](int idx) -> Value* {
Value *ptr = nullptr;
if (k == 0 && is_prefetch) {
if (inc == 0)
ptr = bit_cast(gep(pre_ptr, off.at(idx)), smem_ptr_ty);
else
ptr = bit_cast(gep(next_ptr, off.at(idx)), smem_ptr_ty);
} else
ptr = ptrs.at(idx);
return ptr;
};
Value *ptr = get_ptr(ptr_idx);
Value *res_v4 = nullptr;
if (can_use_ldmatrix_) {
std::string trans = need_trans_ ? ".trans" : "";
// the offset (in byte) on the strided axis is a constant
int s_offset = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_ * dtsize_;
InlineAsm *ld_fn = InlineAsm::get(ldmatrix_ty,
"ldmatrix.sync.aligned.m8n8.x4" + trans + ".shared.b16 "
"{$0, $1, $2, $3}, "
"[$4 + " + std::to_string(s_offset) + "];",
"=r,=r,=r,=r,r", true);
assert(ptr);
res_v4 = call(ldmatrix_ty, ld_fn, {ptr});
if (k == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(res_v4);
return {extract_val(res_v4, std::vector<unsigned>{0}),
extract_val(res_v4, std::vector<unsigned>{1}),
extract_val(res_v4, std::vector<unsigned>{2}),
extract_val(res_v4, std::vector<unsigned>{3})};
} else {
// assert(false && "should not be here");
assert(dtsize_ == 4 && need_trans_);
Value *ptr2 = get_ptr(ptr_idx+1);
assert(s_mat_stride_ == 1);
int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
Value *elem0, *elem1, *elem2, *elem3;
if (k_order_ == 1) {
elem0 = load(gep(ptr, i32(s_offset_elem)));
elem1 = load(gep(ptr2, i32(s_offset_elem)));
elem2 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
} else { // for b (k first)
elem0 = load(gep(ptr, i32(s_offset_elem)));
elem2 = load(gep(ptr2, i32(s_offset_elem)));
elem1 = load(gep(ptr, i32(s_offset_elem + s_offset_arr_elem)));
elem3 = load(gep(ptr2, i32(s_offset_elem + s_offset_arr_elem)));
}
if (k == 0 && inc == 1 && is_prefetch) {
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem0);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem1);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem2);
prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
}
return {elem0, elem1, elem2, elem3};
}
}
int get_num_ptr() const { return num_ptr_; }
private:
int wpt_;
std::vector<int> order_;
int k_order_;
std::vector<unsigned> tile_shape_;
std::vector<int> instr_shape_;
std::vector<int> mat_shape_;
int per_phase_, max_phase_;
int dtsize_;
// generated
int c_mat_shape_, s_mat_shape_;
int c_stride_, s_stride_;
// p_: on the pointer axis
int p_load_stride_in_mat_;
int s_mat_stride_;
// stride when moving to next not-k mat
int warp_off_stride_;
int mat_arr_stride_; // matrix arrangement (inside a load) stride
bool need_trans_, can_use_ldmatrix_;
int num_ptr_;
Builder *builder_;
adder add;
multiplier mul;
geper gep;
};
}
/**
* \brief Code Generation for `mma.16816` (A100)
*/
@@ -1338,35 +1576,65 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
analysis::shared_layout* layout_b = (analysis::shared_layout*)layouts_->get(C->get_operand(1));
bool is_a_row = ord_a[0] == 1;
bool is_b_row = ord_b[0] == 1;
std::string a_trans = is_a_row ? "" : ".trans";
std::string b_trans = is_b_row ? ".trans" : "";
int stride_a_m = is_a_row ? shape_a[1] : 1;
int stride_a_k = is_a_row ? 1 : shape_a[0];
int stride_b_n = is_b_row ? 1 : shape_b[0];
int stride_b_k = is_b_row ? shape_b[1] : 1;
int stride_a0 = is_a_row ? stride_a_k : stride_a_m;
int stride_a1 = is_a_row ? stride_a_m : stride_a_k;
int stride_b0 = is_b_row ? stride_b_n : stride_b_k;
int stride_b1 = is_b_row ? stride_b_k : stride_b_n;
int lda = is_a_row ? stride_a_m : stride_a_k;
int ldb = is_b_row ? stride_b_k : stride_b_n;
int per_phase_a = swizzle_->get_per_phase(layout_a);
int max_phase_a = swizzle_->get_max_phase(layout_a);
int per_phase_b = swizzle_->get_per_phase(layout_b);
int max_phase_b = swizzle_->get_max_phase(layout_b);
int num_ptr_a = 8;
int num_ptr_b = 8;
int vec_a = 8;
int vec_b = 8;
std::vector<int> mma_instr_shape = layout->get_mma_instr_shape();
const int mma_instr_m = mma_instr_shape[0];
const int mma_instr_n = mma_instr_shape[1];
const int mma_instr_k = mma_instr_shape[2];
std::vector<int> mat_shape = layout->get_mma_mat_shape();
const int mat_shape_m = mat_shape[0];
const int mat_shape_n = mat_shape[1];
const int mat_shape_k = mat_shape[2];
const int per_phase_a = swizzle_->get_per_phase(layout_a);
const int max_phase_a = swizzle_->get_max_phase(layout_a);
const int per_phase_b = swizzle_->get_per_phase(layout_b);
const int max_phase_b = swizzle_->get_max_phase(layout_b);
const int num_rep_m = shapes[0] / layout->shape_per_cta(0);
const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);
Type *fp32_ty = f32_ty;
Type *fp16x2_ty = vec_ty(f16_ty, 2);
Type *bf16x2_ty = vec_ty(bf16_ty, 2);
Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
FunctionType *ld_x4_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{ptr_ty(f16_ty, 3)}, false);
FunctionType *ldmatrix_ty = nullptr;
FunctionType *mma_ty = nullptr;
Type *phi_ty = nullptr;
Type *smem_ptr_ty = nullptr;
ir::type *A_ir_ty = A->get_type()->get_scalar_ty();
ir::type *B_ir_ty = B->get_type()->get_scalar_ty();
if (A_ir_ty->is_fp16_ty() && B_ir_ty->is_fp16_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(f16_ty, 3);
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp16x2_ty;
} else if (A_ir_ty->is_bf16_ty() && B_ir_ty->is_bf16_ty()) {
// FIXME: We should use bf16 here.
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(f16_ty, 3);
ldmatrix_ty = FunctionType::get(fp16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp16x2_ty;
// mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
// smem_ptr_ty = ptr_ty(bf16_ty, 3);
// ldmatrix_ty = FunctionType::get(bf16x2_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
// phi_ty = bf16x2_ty;
} else if (A_ir_ty->is_fp32_ty() && B_ir_ty->is_fp32_ty()) {
mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
smem_ptr_ty = ptr_ty(fp32_ty, 3);
ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
phi_ty = fp32_ty;
} else
throw std::runtime_error("mma16816 data type not supported");
// left-hand-side values
std::map<std::pair<unsigned, unsigned>, std::pair<Value*, Value*>> ha;
std::map<std::pair<unsigned, unsigned>, Value*> ha;
std::map<std::pair<unsigned, unsigned>, Value*> hb;
BasicBlock* CurrBB = builder_->GetInsertBlock();
@@ -1377,78 +1645,65 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
Value* thread = tgt_->get_local_id(mod_, *builder_, 0);
Value *lane = urem(thread, i32(32));
Value *warp = udiv(thread, i32(32));
Value *warp12 = udiv(warp, i32(layout->wpt(0)));
Value *warp0 = urem(warp, i32(layout->wpt(0)));
Value *warp1 = urem(warp12, i32(layout->wpt(1)));
Value *warp_mn = udiv(warp, i32(layout->wpt(0)));
Value *warp_m = urem(warp, i32(layout->wpt(0)));
Value *warp_n = urem(warp_mn, i32(layout->wpt(1)));
std::vector<Value *>& fc = fcs.begin()->second;
Value *tidr8 = urem(lane, i32(8));
Value *phase_a = urem(udiv(tidr8, i32(per_phase_a)), i32(max_phase_a));
Value* off_a0 = mul(tidr8, i32(lda));
Value *off_am = mul(add(urem(udiv(lane, i32(8)), i32(2)), mul(warp0, i32(2))), i32(8));
Value *off_ak = mul(udiv(lane, i32(16)), i32(8));
off_am = urem(off_am, i32(shape_a[0]));
off_ak = urem(off_ak, i32(shape_a[1]));
off_a0 = add(off_a0, is_a_row ? off_ak : off_am);
Value* off_a1 = is_a_row ? off_am : off_ak;
std::vector<Value*> off_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++){
Value* off_a0i = add(off_a0, i32(i*16*(is_a_row?1:layout->wpt(0))));
off_a0i = exact_udiv(off_a0i, i32(vec_a));
off_a0i = xor_(off_a0i, phase_a);
off_a0i = mul(off_a0i, i32(vec_a));
off_a[i] = add(mul(off_a0i, i32(stride_a0)), mul(off_a1, i32(stride_a1)));
}
size_t dtsize_a = A->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t dtsize_b = B->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
Value *phase_b = urem(udiv(tidr8, i32(per_phase_b)), i32(max_phase_b));
Value* off_b0 = mul(tidr8, i32(ldb));
Value *off_bn = mul(add(mul(udiv(lane, i32(16)), i32(layout->wpt(1))), mul(warp1, i32(1))), i32(8));
Value *off_bk = mul(urem(udiv(lane, i32(8)), i32(2)), i32(8));
off_bn = urem(off_bn, i32(shape_b[1]));
off_bk = urem(off_bk, i32(shape_b[0]));
off_b0 = add(off_b0, is_b_row ? off_bn : off_bk);
Value* off_b1 = is_b_row ? off_bk : off_bn;
std::vector<Value*> off_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++){
Value* off_b0i = add(off_b0, i32(i*(is_b_row?8*layout->wpt(1):16)));
off_b0i = exact_udiv(off_b0i, i32(vec_b));
off_b0i = xor_(off_b0i, phase_b);
off_b0i = mul(off_b0i, i32(vec_b));
off_b[i] = add(mul(off_b0i, i32(stride_b0)), mul(off_b1, i32(stride_b1)));
}
// | -> k (row-major), since we have ldmatrix.trans, we only need to change stride
// v (s0_0(0), s1_0(2), | *num_rep_k
// m s0_1(1), s1_1(3)) | (stride in num of matrices(mat_stride_ak): 2)
// -----------
// *num_rep_m (stride in num of matrices(mat_stride_am): 2*layout->wpt(0))
mma16816_smem_loader a_loader(layout->wpt(0), ord_a, /*k_order*/1, shape_a,
{mma_instr_m, mma_instr_k}, {mat_shape_m, mat_shape_k},
per_phase_a, max_phase_a, dtsize_a, builder_, add, mul, gep);
std::vector<Value*> off_a = a_loader.compute_offs(warp_m, lane);
int num_ptr_a = a_loader.get_num_ptr();
// | -> n (col-major)
// v (s0_0(0), | (stride: wpt(1)) | s1_0(2) | *num_rep_n
// k s0_1(1), | | s1_1(3)) | (stride in num of matrices(mat_stride_bn): wpt(1))
// -----------
// *num_rep_k (stride in num of matrices(mat_stride_bk): 2)
mma16816_smem_loader b_loader(layout->wpt(1), ord_b, /*k_order*/0, shape_b,
{mma_instr_k, mma_instr_n}, {mat_shape_k, mat_shape_n},
per_phase_b, max_phase_b, dtsize_b, builder_, add, mul, gep);
std::vector<Value*> off_b = b_loader.compute_offs(warp_n, lane);
int num_ptr_b = b_loader.get_num_ptr();
builder_->SetInsertPoint(CurrBB);
// A pointer
std::vector<Value*> ptrs_a(num_ptr_a);
for(int i = 0; i < num_ptr_a; i++)
ptrs_a[i] = gep(shmems_[A], {off_a[i]});
ptrs_a[i] = bit_cast(gep(shmems_[A], {off_a[i]}), smem_ptr_ty);
// B pointer
std::vector<Value*> ptrs_b(num_ptr_b);
for(int i = 0; i < num_ptr_b; i++)
ptrs_b[i] = gep(shmems_[B], {off_b[i]});
ptrs_b[i] = bit_cast(gep(shmems_[B], {off_b[i]}), smem_ptr_ty);
FunctionType *mma_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
InlineAsm *mma_fn = InlineAsm::get(mma_ty, "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
InlineAsm *mma_fn = InlineAsm::get(mma_ty, layout->get_ptx_instr() +
" {$0, $1, $2, $3},"
" {$4, $5, $6, $7},"
" {$8, $9},"
" {$10, $11, $12, $13};",
"=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
unsigned num_rep_0 = shapes[0] / layout->shape_per_cta(0);
unsigned num_rep_1 = shapes[1] / layout->shape_per_cta(1);
// create mma & unpack result
auto call_mma = [&](unsigned m, unsigned n, unsigned K) {
unsigned cols_per_thread = num_rep_0 * 2;
// create mma & unpack result, m, n, k are offsets in mat
auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
unsigned cols_per_thread = num_rep_m * 2;
std::vector<size_t> idx = {
(m*2 + 0) + (n*2 + 0)*cols_per_thread,
(m*2 + 0) + (n*2 + 1)*cols_per_thread,
(m*2 + 1) + (n*2 + 0)*cols_per_thread,
(m*2 + 1) + (n*2 + 1)*cols_per_thread
(m + 0) + (n*2 + 0)*cols_per_thread,
(m + 0) + (n*2 + 1)*cols_per_thread,
(m + 1) + (n*2 + 0)*cols_per_thread,
(m + 1) + (n*2 + 1)*cols_per_thread
};
Value *nc = call(mma_ty, mma_fn, {ha[{m, K}].first, ha[{m, K}].second,ha[{m, K+8}].first, ha[{m, K+8}].second,
hb[{n, K}], hb[{n, K+8}],
Value *nc = call(mma_ty, mma_fn,
{ha[{m, k}], ha[{m+1, k}], ha[{m, k+1}], ha[{m+1, k+1}],
hb[{n, k}], hb[{n, k+1}],
fc[idx[0]], fc[idx[1]], fc[idx[2]], fc[idx[3]]});
fc[idx[0]] = extract_val(nc, std::vector<unsigned>{0});
fc[idx[1]] = extract_val(nc, std::vector<unsigned>{1});
@@ -1459,130 +1714,82 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
ir::phi_node* phiA = dynamic_cast<ir::phi_node*>(A);
ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);
auto register_lds =
[&](decltype(ha)& vals, int m, int K, int inc, Value* val0, Value *val1, bool is_prefetch) {
if (K <= 8 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].first, val0, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}].second, val1, inc_block));
} else
vals[{m, K}] = {val0, val1};
};
auto register_lds2 =
[&](decltype(hb)& vals, int m, int K, int inc, Value* val, bool is_prefetch) {
if (K <= 8 && is_prefetch) {
[&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) {
if (k < 2 && is_prefetch) {
ir::basic_block* inc_block = phiA->get_incoming_block(inc);
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{m, K}], val, inc_block));
lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block));
} else
vals[{m, K}] = val;
vals[{n, k}] = val;
};
auto load_a = [&](int m, int K, int inc, bool is_prefetch) {
int offidx = (is_a_row ? K/16 : m) % num_ptr_a;
Value* ptra;
if(K == 0 && is_prefetch){
if(inc == 0)
ptra = gep(shared_pre_ptr_[layout_a], off_a[offidx]);
else
ptra = gep(shared_next_ptr_[layout_a], off_a[offidx]);
}
else
ptra = ptrs_a[offidx];
int step_am = is_a_row ? m : m / (num_ptr_a)*(num_ptr_a);
int step_ak = is_a_row ? K / (num_ptr_a*16)*(num_ptr_a*16) : K;
InlineAsm *ld_a0_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + a_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " +
std::to_string(2*step_am*16*layout->wpt(0)*stride_a_m + 2*step_ak*stride_a_k) + "];",
"=r,=r,=r,=r,r", true);
Value *haa = call(ld_x4_ty, ld_a0_fn, {ptra});
if(K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiA->get_incoming_value(1)].push_back(haa);
Value *ha0 = extract_val(haa, std::vector<unsigned>{0});
Value *ha1 = extract_val(haa, std::vector<unsigned>{1});
Value *ha2 = extract_val(haa, std::vector<unsigned>{2});
Value *ha3 = extract_val(haa, std::vector<unsigned>{3});
register_lds(ha, m, K, inc, ha0, ha1, is_prefetch);
register_lds(ha, m, K + 8, inc, ha2, ha3, is_prefetch);
auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
auto [ha0, ha1, ha2, ha3] = a_loader.load_x4(m, k, inc, is_prefetch, phiA, shared_pre_ptr_[layout_a],
shared_next_ptr_[layout_a], off_a, ptrs_a,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(ha, m, k, inc, ha0, is_prefetch);
register_lds2(ha, m+1, k, inc, ha1, is_prefetch);
register_lds2(ha, m, k+1, inc, ha2, is_prefetch);
register_lds2(ha, m+1, k+1, inc, ha3, is_prefetch);
};
auto load_b = [&](int n, int K, int inc, bool is_prefetch) {
int offidx = (is_b_row ? n : K/16) % num_ptr_b;
Value* ptrb;
if(K == 0 && is_prefetch){
if(inc == 0)
ptrb = gep(shared_pre_ptr_[layout_b], off_b[offidx]);
else
ptrb = gep(shared_next_ptr_[layout_b], off_b[offidx]);
}
else
ptrb = ptrs_b[offidx];
int step_bn = is_b_row ? n / (num_ptr_b)*(num_ptr_b) : n;
int step_bk = is_b_row ? K : K / (num_ptr_b*8)*(num_ptr_b*8);
InlineAsm *ld_b_fn = InlineAsm::get(ld_x4_ty, "ldmatrix.sync.aligned.m8n8.x4" + b_trans + ".shared.b16 "
"{$0, $1, $2, $3}, [$4 + " +
std::to_string(2*step_bn*8*layout->wpt(1)*stride_b_n + 2*step_bk*stride_b_k) + "];",
"=r,=r,=r,=r,r", true);
Value *hbb = call(ld_x4_ty, ld_b_fn, {ptrb});
if(K == 0 && inc == 1 && is_prefetch)
prefetch_latch_to_bb_[phiB->get_incoming_value(1)].push_back(hbb);
Value *hb0 = extract_val(hbb, std::vector<unsigned>{0});
Value *hb1 = extract_val(hbb, std::vector<unsigned>{1});
Value *hb2 = extract_val(hbb, std::vector<unsigned>{2});
Value *hb3 = extract_val(hbb, std::vector<unsigned>{3});
register_lds2(hb, n, K, inc, hb0, is_prefetch);
register_lds2(hb, n+1, K, inc, hb2, is_prefetch);
register_lds2(hb, n, K+8, inc, hb1, is_prefetch);
register_lds2(hb, n+1, K+8, inc, hb3, is_prefetch);
auto load_b = [&](int n, int k, int inc, bool is_prefetch) {
auto [hb0, hb1, hb2, hb3] = b_loader.load_x4(k, n, inc, is_prefetch, phiB, shared_pre_ptr_[layout_b],
shared_next_ptr_[layout_b], off_b, ptrs_b,
ldmatrix_ty, smem_ptr_ty, prefetch_latch_to_bb_);
register_lds2(hb, n, k, inc, hb0, is_prefetch);
register_lds2(hb, n+1, k, inc, hb2, is_prefetch);
register_lds2(hb, n, k+1, inc, hb1, is_prefetch);
register_lds2(hb, n+1, k+1, inc, hb3, is_prefetch);
};
if (C->is_prefetched()) {
// create phis
builder_->SetInsertPoint(CurrBB->getFirstNonPHI());
for(unsigned m = 0; m < num_rep_0; m++){
ha[{m, 0}].first = phi(fp16x2_ty, 2);
ha[{m, 0}].second = phi(fp16x2_ty, 2);
ha[{m, 8}].first = phi(fp16x2_ty, 2);
ha[{m, 8}].second = phi(fp16x2_ty, 2);
for(unsigned m = 0; m < num_rep_m; m++){
ha[{2*m, 0}] = phi(phi_ty, 2);
ha[{2*m+1, 0}] = phi(phi_ty, 2);
ha[{2*m, 1}] = phi(phi_ty, 2);
ha[{2*m+1, 1}] = phi(phi_ty, 2);
}
for(unsigned n = 0; n < num_rep_1; n+=2){
hb[{n, 0}] = phi(fp16x2_ty, 2);
hb[{n+1, 0}] = phi(fp16x2_ty, 2);
hb[{n, 8}] = phi(fp16x2_ty, 2);
hb[{n+1, 8}] = phi(fp16x2_ty, 2);
for(unsigned n = 0; n < num_rep_n; n+=2){
hb[{n, 0}] = phi(phi_ty, 2);
hb[{n+1, 0}] = phi(phi_ty, 2);
hb[{n, 1}] = phi(phi_ty, 2);
hb[{n+1, 1}] = phi(phi_ty, 2);
}
// insert prefetched lds at the end of loop header
builder_->SetInsertPoint(bbs_[phiA->get_incoming_block(0)]->getTerminator());
for(unsigned m = 0; m < num_rep_0; m++)
load_a(m, 0, 0, true);
for(unsigned n = 0; n < num_rep_1; n+=2)
for(unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 0, 0, true);
for(unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 0, 0, true);
// update accumulators
builder_->SetInsertPoint(CurrBB);
for(unsigned K = 0; K < NK; K += 16){
int NEXTK = (K + 16) % NK;
for(unsigned k = 0; k < num_rep_k; ++k){ // stride of instr in mat is 2
int next_k = (k + 1) % num_rep_k;
// prefetch A
for(unsigned m = 0; m < num_rep_0; m++)
load_a(m, NEXTK, 1, true);
for(unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 2*next_k, 1, true);
// prefetch B
for(unsigned n = 0; n < num_rep_1; n+=2)
load_b(n, NEXTK, 1, true);
for(unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 2*next_k, 1, true);
// tensor core ops
for(unsigned m = 0; m < num_rep_0; m++)
for(unsigned n = 0; n < num_rep_1; n++){
call_mma(m, n, K);
for(unsigned m = 0; m < num_rep_m; m++)
for(unsigned n = 0; n < num_rep_n; n++){
call_mma(2*m, n, 2*k);
}
}
}
else{
for(unsigned K = 0; K < NK; K += 16)
for(unsigned m = 0; m < num_rep_0; m++)
for(unsigned n = 0; n < num_rep_1; n++){
if(ha.find({m, K}) == ha.end())
load_a(m, K, 0, false);
if(hb.find({n, K})==hb.end())
load_b(n, K, 0, false);
call_mma(m, n, K);
for (unsigned k = 0; k < num_rep_k; k++) {
for (unsigned m = 0; m < num_rep_m; m++)
load_a(2*m, 2*k, 0, /*is_prefetch*/false);
for (unsigned n = 0; n < num_rep_n; n+=2)
load_b(n, 2*k, 0, /*is_prefetch*/false);
for (unsigned m = 0; m < num_rep_m; m++)
for (unsigned n = 0; n < num_rep_n; n++)
call_mma(2*m, n, 2*k);
}
}
// write back
@@ -1714,7 +1921,7 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() < 80)
return visit_mma884(dot, A, B, D, NK);
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
return visit_mma16816(dot, A, B, D, NK);
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
}
@@ -1752,13 +1959,13 @@ inline Value* generator::shfl_sync(Value* acc, int32_t i){
InlineAsm *shfl = InlineAsm::get(FunctionType::get(ty, {ty, i32_ty}, false), asm_str, "=f,f,r", false);
if(ty->getPrimitiveSizeInBits() <= 32)
return call(shfl, {acc, i32(i)});
acc = builder_->CreateBitCast(acc, vec_ty(f32_ty, 2));
acc = bit_cast(acc, vec_ty(f32_ty, 2));
Value* acc0 = builder_->CreateExtractElement(acc, i32(0));
Value* acc1 = builder_->CreateExtractElement(acc, i32(1));
Value* ret = UndefValue::get(vec_ty(f32_ty, 2));
ret = insert_elt(ret, shfl_sync(acc0, i), i32(0));
ret = insert_elt(ret, shfl_sync(acc1, i), i32(1));
return builder_->CreateBitCast(ret, ty);
return bit_cast(ret, ty);
}
/**
@@ -1936,6 +2143,10 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
ir::block_type::block_shapes_t shape = out->get_type()->get_block_shapes();
// pointer to temporary shared memory
Type *ty = cvt(out->get_type()->get_scalar_ty());
if (ty->isBFloatTy()) // llvm11-nvptx cannot select bf16 store
ty = f16_ty;
// Orders
analysis::distributed_layout* in_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(in));
analysis::distributed_layout* out_layout = dynamic_cast<analysis::distributed_layout*>(layouts_->get(out));
@@ -1976,7 +2187,7 @@ void generator::visit_layout_convert(ir::value *out, ir::value *in){
// stash value to shared mem
indices_t idxs = {in_ax[0][i*max_ii + ii],
in_ax[1][j*max_jj + jj]};
store(vals_[in][idxs], ptr);
store(bit_cast(vals_[in][idxs], ty), ptr);
}
add_barrier();
max_ii = out_ax[0].size()/n_reps[0];

View File

@@ -80,7 +80,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
ir::value *a = dot->get_operand(0);
ir::value *b = dot->get_operand(1);
builder.set_insert_point(add);
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->get_name()));
ir::value * new_dot = builder.insert(ir::dot_inst::create_nn(a, b, other, dot->allow_tf32(), dot->get_name()));
add->replace_all_uses_with(new_dot);
return true;
}

View File

@@ -29,8 +29,13 @@ void prefetch::run(ir::module &mod) {
std::vector<ir::dot_inst*> to_prefetch;
ir::for_each_instruction(mod, [&](ir::instruction *i) {
if (auto *dot = dynamic_cast<ir::dot_inst*>(i)) {
// Now only do prefetching when dot is fp16
if (dot->get_operand(0)->get_type()->get_scalar_ty()->get_type_id() != ir::type::FP16TyID)
// Now only do prefetching when dot is using tensor cores
if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
&& tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
)
)
return;
auto *a = dynamic_cast<ir::phi_node*>(dot->get_operand(0));
auto *b = dynamic_cast<ir::phi_node*>(dot->get_operand(1));

View File

@@ -121,6 +121,7 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
init_llvm();
// verify and store llvm
llvm::legacy::PassManager pm;
// module->print(llvm::outs(), nullptr);
pm.add(llvm::createVerifierPass());
pm.run(*module);
// create machine

View File

@@ -367,8 +367,8 @@ value *builder::create_log(value *arg){
return insert(log_inst::create(arg));
}
value *builder::create_dot(value *A, value *B, value *C) {
return insert(dot_inst::create_nn(A, B, C));
value *builder::create_dot(value *A, value *B, value *C, bool allow_tf32) {
return insert(dot_inst::create_nn(A, B, C, allow_tf32));
}
value *builder::create_trans(value *A, const std::vector<int>& perm) {

View File

@@ -718,12 +718,13 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask
// Linear Algebra
//===----------------------------------------------------------------------===//
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::builder *builder) {
ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
ir::value *_0 = builder->get_float32(0);
unsigned M = lhs->get_type()->get_block_shapes()[0];
unsigned N = rhs->get_type()->get_block_shapes()[1];
_0 = builder->create_splat(_0, {M, N});
return builder->create_dot(lhs, rhs, _0);
bool _allow_tf32 = allow_tf32->get_value() != 0;
return builder->create_dot(lhs, rhs, _0, _allow_tf32);
}

View File

@@ -577,40 +577,41 @@ instruction* downcast_inst::create(value *arg, const std::string &name, instruct
// matmul_inst classes
//===----------------------------------------------------------------------===//
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT,
dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32,
const std::string &name, instruction *next)
: builtin_inst(C->get_type(), INST_DOT, 3, name, next) {
set_operand(0, A);
set_operand(1, B);
set_operand(2, C);
allow_tf32_ = allow_tf32;
}
instruction *dot_inst::create(value *A, value *B, value *C,
bool AT, bool BT,
bool AT, bool BT, bool allow_tf32,
const std::string &name, instruction *next) {
TransT OPA = AT ? Trans : NoTrans;
TransT OPB = BT ? Trans : NoTrans;
return new dot_inst(A, B, C, OPA, OPB, name, next);
return new dot_inst(A, B, C, OPA, OPB, allow_tf32, name, next);
}
instruction *dot_inst::create_nn(value *A, value *B, value *C,
instruction *dot_inst::create_nn(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, NoTrans, NoTrans, name, next);
return new dot_inst(A, B, C, NoTrans, NoTrans, allow_tf32, name, next);
}
instruction *dot_inst::create_nt(value *A, value *B, value *C,
instruction *dot_inst::create_nt(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, NoTrans, Trans, name, next);
return new dot_inst(A, B, C, NoTrans, Trans, allow_tf32, name, next);
}
instruction *dot_inst::create_tn(value *A, value *B, value *C,
instruction *dot_inst::create_tn(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, Trans, NoTrans, name, next);
return new dot_inst(A, B, C, Trans, NoTrans, allow_tf32, name, next);
}
instruction *dot_inst::create_tt(value *A, value *B, value *C,
instruction *dot_inst::create_tt(value *A, value *B, value *C, bool allow_tf32,
const std::string &name, instruction *next) {
return new dot_inst(A, B, C, Trans, Trans, name, next);
return new dot_inst(A, B, C, Trans, Trans, allow_tf32, name, next);
}
//===----------------------------------------------------------------------===//

View File

@@ -10,6 +10,7 @@ import torch
from numpy.random import RandomState
import triton
import triton._C.libtriton.triton as _triton
import triton.language as tl
from triton.code_gen import TensorWrapper, reinterpret
@@ -660,22 +661,26 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
# ---------------
@pytest.mark.parametrize("epilogue", ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols'])
def test_dot(epilogue, device='cuda'):
@pytest.mark.parametrize("epilogue, allow_tf32",
[(epilogue, allow_tf32)
for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
for allow_tf32 in [True, False]])
def test_dot(epilogue, allow_tf32, device='cuda'):
# triton kernel
@triton.jit
def kernel(X, stride_xm, stride_xk,
Y, stride_yk, stride_yn,
Z, stride_zm, stride_zn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr):
ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr,
ALLOW_TF32: tl.constexpr):
off_m = tl.arange(0, BLOCK_M)
off_n = tl.arange(0, BLOCK_N)
off_k = tl.arange(0, BLOCK_K)
Xs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
Ys = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
z = tl.dot(tl.load(Xs), tl.load(Ys))
z = tl.dot(tl.load(Xs), tl.load(Ys), allow_tf32=ALLOW_TF32)
if ADD_MATRIX:
z += tl.load(Zs)
if ADD_ROWS:
@@ -690,6 +695,12 @@ def test_dot(epilogue, device='cuda'):
rs = RandomState(17)
x = numpy_random((M, K), dtype_str='float32', rs=rs)
y = numpy_random((K, N), dtype_str='float32', rs=rs)
if allow_tf32:
cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
if cc < 80:
pytest.skip("Only test tf32 on devices with sm >= 80")
x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
x_tri = to_triton(x, device=device)
y_tri = to_triton(y, device=device)
# triton result
@@ -703,7 +714,8 @@ def test_dot(epilogue, device='cuda'):
BLOCK_M=M, BLOCK_K=K, BLOCK_N=N,
ADD_MATRIX=epilogue == 'add-matrix',
ADD_ROWS=epilogue == 'add-rows',
ADD_COLS=epilogue == 'add-cols')
ADD_COLS=epilogue == 'add-cols',
ALLOW_TF32=allow_tf32)
# torch result
z_ref = np.matmul(x, y)
if epilogue == 'add-matrix':
@@ -718,6 +730,8 @@ def test_dot(epilogue, device='cuda'):
ptx = pgm.asm['ptx']
assert 'ld.global.v4' in ptx
assert 'st.global.v4' in ptx
if allow_tf32:
assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
def test_dot_without_load():

View File

@@ -524,18 +524,18 @@ def reshape(input, shape, _builder=None):
@builtin
def dot(input, other, _builder=None):
def dot(input, other, allow_tf32=True, _builder=None):
"""
Returns the matrix product of two blocks.
The two blocks must be two dimensionals and have compatible inner dimensions.
:param input: The first block to be multiplied.
:type input: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
:type input: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
:param other: The second block to be multiplied.
:type other: 2D block of scalar-type in {:code:`float16`, :code:`float32`}
:type other: 2D block of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
"""
return frontend.dot(input, other, _builder)
return frontend.dot(input, other, allow_tf32, _builder)
# -----------------------