[BACKEND] Added Int8 mma (#440)
@@ -268,6 +268,7 @@ public:
   void set_mma_vec(int mma_vec) { mma_vec_ = mma_vec; }
   int get_mma_vec() { return mma_vec_;}
   int get_mma_strided() { return mma_strided_; }
+  bool allow_swizzle() const { return allow_swizzle_; }
   data_layout* get_arg_layout() { return arg_layout_; }

 private:
@@ -281,6 +282,7 @@ private:
   data_layout* arg_layout_;
   int mma_vec_;
   int mma_strided_;
+  bool allow_swizzle_ = true;
   target *tgt_;
 };

@@ -33,7 +33,9 @@ inline bool is_hmma_c(ir::value *v, int sm){
     result = (a_ty->get_scalar_ty()->is_fp16_ty() && b_ty->get_scalar_ty()->is_fp16_ty()) ||
              (a_ty->get_scalar_ty()->is_bf16_ty() && b_ty->get_scalar_ty()->is_bf16_ty()) ||
              (a_ty->get_scalar_ty()->is_fp32_ty() && b_ty->get_scalar_ty()->is_fp32_ty() &&
-              x->allow_tf32() && sm >= 80);
+              x->allow_tf32() && sm >= 80) ||
+             (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8) &&
+              sm >= 80);
   }
   return result;
 }
@@ -63,7 +65,7 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
         return mma_type;
       }
     } else if (c_ty->get_scalar_ty()->is_integer_ty(32)) {
-      throw std::runtime_error("integer tensor cores are not yet supported");
+      // throw std::runtime_error("integer tensor cores are not yet supported");
       // // integer tensor cores
       // if (a_ty->get_scalar_ty()->is_integer_ty(1) && b_ty->get_scalar_ty()->is_integer_ty(1)) {
       //   mma_type = mma_layout::INT32_INT1_INT1_INT32;
@@ -73,10 +75,10 @@ static mma_layout::TensorCoreType get_mma_type(ir::value *v) {
       //   mma_type = mma_layout::INT32_INT4_INT4_INT32;
      //   return mma_type;
      // }
-      // if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
-      //   mma_type = mma_layout::INT32_INT8_INT8_INT32;
-      //   return mma_type;
-      // }
+      if (a_ty->get_scalar_ty()->is_integer_ty(8) && b_ty->get_scalar_ty()->is_integer_ty(8)) {
+        mma_type = mma_layout::INT32_INT8_INT8_INT32;
+        return mma_type;
+      }
     }
   }
   return mma_layout::NOT_APPLICABLE;
@@ -444,11 +446,21 @@ shared_layout::shared_layout(data_layout *arg,
     std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_a_));
     mma_vec_ = order_[0] == 1 ? mat_shape[2] : mat_shape[0]; // k : m
     mma_strided_ = order_[0] == 1 ? mat_shape[0] : mat_shape[2];
+
+    // for now, disable swizzle when using lds.8
+    if (get_mma_type(hmma_dot_a_) == mma_layout::INT32_INT8_INT8_INT32)
+      if (order_[0] == 0) // need transpose
+        allow_swizzle_ = false;
   } else if (hmma_dot_b_) {
     assert(order_.size() == 2);
     std::vector<int> mat_shape = mma_layout::mma_mat_shape_.at(get_mma_type(hmma_dot_b_));
     mma_vec_ = order_[0] == 1 ? mat_shape[1] : mat_shape[2]; // n : k
     mma_strided_ = order_[0] == 1 ? mat_shape[2] : mat_shape[1];
+
+    // for now, disable swizzle when using lds.8
+    if (get_mma_type(hmma_dot_b_) == mma_layout::INT32_INT8_INT8_INT32)
+      if (order_[0] == 1) // need transpose
+        allow_swizzle_ = false;
   }

   // size
@@ -41,9 +41,15 @@ void swizzle::run(ir::module &) {
       vec_[layout] = 2*layouts_->get(mma_dot_b)->to_mma()->rep(1);
     }
     else {
+      if (!layout->allow_swizzle()) {
+        per_phase_[layout] = 1;
+        max_phase_[layout] = 1;
+        vec_[layout] = 1;
+      } else {
         per_phase_[layout] = std::max<int>(128 / (in_layout->mts(ord[0])*in_layout->nts(ord[0])*dtsize), 1);
         max_phase_[layout] = layout->get_mma_strided() / per_phase_[layout];
         vec_[layout] = layout->get_mma_vec();
+      }
     }
   }
 }
@@ -1,78 +0,0 @@
-#pragma once
-
-#include <numeric>
-#include <sstream>
-#include <iomanip>
-#include "triton/codegen/selection/generator.h"
-#include "triton/codegen/target.h"
-#include "triton/codegen/analysis/axes.h"
-#include "triton/codegen/analysis/allocation.h"
-#include "triton/codegen/analysis/align.h"
-#include "triton/codegen/analysis/swizzle.h"
-#include "triton/codegen/transform/coalesce.h"
-#include "triton/ir/context.h"
-#include "triton/ir/module.h"
-#include "triton/ir/function.h"
-#include "triton/ir/type.h"
-#include "llvm/IR/Module.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/IntrinsicsNVPTX.h"
-#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/Attributes.h"
-#include "llvm/IR/InlineAsm.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
-
-namespace triton::codegen {
-// types
-#define void_ty builder_->getVoidTy()
-#define f16_ty builder_->getHalfTy()
-#define bf16_ty builder_->getBFloatTy()
-#define f32_ty builder_->getFloatTy()
-#define i8_ty builder_->getInt8Ty()
-#define i32_ty builder_->getInt32Ty()
-#define vec_ty(type, num_el) VectorType::get(type, num_el, false)
-#define ptr_ty(...) PointerType::get(__VA_ARGS__)
-// constants
-#define i32(...) builder_->getInt32(__VA_ARGS__)
-// ops
-#define and_(...) builder_->CreateAnd(__VA_ARGS__)
-#define atomic_cmp_xchg(...) builder_->CreateAtomicCmpXchg(__VA_ARGS__)
-#define atomic_rmw(...) builder_->CreateAtomicRMW(__VA_ARGS__)
-#define bin_op(...) builder_->CreateBinOp(__VA_ARGS__)
-#define bit_cast(...) builder_->CreateBitCast(__VA_ARGS__)
-#define br(...) builder_->CreateBr(__VA_ARGS__)
-#define call(...) builder_->CreateCall(__VA_ARGS__)
-#define cast(...) builder_->CreateCast(__VA_ARGS__)
-#define cond_br(...) builder_->CreateCondBr(__VA_ARGS__)
-#define exact_udiv(...) builder_->CreateExactUDiv(__VA_ARGS__)
-#define extract_elt(...) builder_->CreateExtractElement(__VA_ARGS__)
-#define extract_val(...) builder_->CreateExtractValue(__VA_ARGS__)
-#define fadd(...) builder_->CreateFAdd(__VA_ARGS__)
-#define fcmp(...) builder_->CreateFCmp(__VA_ARGS__)
-#define fmul(...) builder_->CreateFMul(__VA_ARGS__)
-#define fpcast(...) builder_->CreateFPCast(__VA_ARGS__)
-#define fsub(...) builder_->CreateFSub(__VA_ARGS__)
-#define icmp(...) builder_->CreateICmp(__VA_ARGS__)
-#define icmp_eq(...) builder_->CreateICmpEQ(__VA_ARGS__)
-#define icmp_sge(...) builder_->CreateICmpSGE(__VA_ARGS__)
-#define icmp_sle(...) builder_->CreateICmpSLE(__VA_ARGS__)
-#define icmp_ult(...) builder_->CreateICmpULT(__VA_ARGS__)
-#define insert_elt(...) builder_->CreateInsertElement(__VA_ARGS__)
-#define intrinsic(...) builder_->CreateIntrinsic(__VA_ARGS__)
-#define load(...) builder_->CreateLoad(__VA_ARGS__)
-#define lshr(...) builder_->CreateLShr(__VA_ARGS__)
-#define max_num(...) builder_->CreateMaxNum(__VA_ARGS__)
-#define min_num(...) builder_->CreateMinNum(__VA_ARGS__)
-#define neg(...) builder_->CreateNeg(__VA_ARGS__)
-#define phi(...) builder_->CreatePHI(__VA_ARGS__)
-#define ret(...) builder_->CreateRet(__VA_ARGS__)
-#define select(...) builder_->CreateSelect(__VA_ARGS__)
-#define store(...) builder_->CreateStore(__VA_ARGS__)
-#define sub(...) builder_->CreateSub(__VA_ARGS__)
-#define shl(...) builder_->CreateShl(__VA_ARGS__)
-#define udiv(...) builder_->CreateUDiv(__VA_ARGS__)
-#define urem(...) builder_->CreateURem(__VA_ARGS__)
-#define splat(...) builder_->CreateVectorSplat(__VA_ARGS__)
-#define xor_(...) builder_->CreateXor(__VA_ARGS__)
-
-} // namespace triton::codegen
@@ -1,6 +1,7 @@
 #include <numeric>
 #include <sstream>
 #include <iomanip>
+#include <stdexcept>
 #include "triton/codegen/selection/generator.h"
 #include "triton/codegen/target.h"
 #include "triton/codegen/analysis/axes.h"
@@ -1355,9 +1356,6 @@ public:
     need_trans_ = k_order_ != order_[0];
     can_use_ldmatrix_ = dtsize == 2 || (!need_trans_);

-    // std::cout << can_use_ldmatrix_ << std::endl;
-    // std::cout << need_trans_ << std::endl;
-
     // we need more pointers at the fast-changing axis,
     if (can_use_ldmatrix_)
       num_ptr_ = tile_shape[order[0]] / (order[0] == k_order? 1 : wpt) / instr_shape[order[0]];
@@ -1365,6 +1363,9 @@ public:
       num_ptr_ = tile_shape[order[0]] / wpt / mat_shape[order[0]];
     num_ptr_ = std::max<int>(num_ptr_, 2);

+    // special rule for i8/u8, 4 ptrs for each matrix
+    if (!can_use_ldmatrix_ && dtsize_ == 1)
+      num_ptr_ *= 4;

     // load_v4 stride (in num of mats)
     int load_stride_in_mat[2];
@@ -1445,6 +1446,46 @@ public:
         }
       }
       return offs;
       // throw std::runtime_error("not implemented");
+    } else if (dtsize_ == 1 && need_trans_) {
+      // load i8/u8 matrices with lds8
+      Value *c_off_in_mat = udiv(lane, i32(4)); //
+      Value *s_off_in_mat = mul(urem(lane, i32(4)), i32(4)); // each thread load 4 cols
+
+      // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
+      std::vector<Value*> offs(num_ptr_);
+      for (int mat = 0; mat < 4; ++mat) { // loads 4 mats each time
+        int k_mat_arr_int = (k_order_ == 1) ? mat/2 : mat%2;
+        int nk_mat_arr_int = (k_order_ == 1) ? mat%2 : mat/2;
+        if (k_mat_arr_int > 0) // we don't need pointers for k
+          continue;
+        Value *k_mat_arr = i32(k_mat_arr_int);
+        Value *nk_mat_arr = i32(nk_mat_arr_int);
+        // physical offset (before swizzling)
+        Value *c_mat_off = add(mul(warp_off, i32(warp_off_stride_)),
+                               mul(nk_mat_arr, i32(mat_arr_stride_)));
+        Value *s_mat_off = k_mat_arr; // always 0?
+
+        for (int loadx4_off = 0; loadx4_off < num_ptr_/8; ++loadx4_off) {
+          for (int elem_off = 0; elem_off < 4; ++elem_off) {
+            int ptr_off = loadx4_off*8 + nk_mat_arr_int*4 + elem_off;
+
+            Value *c_mat_off_i = add(c_mat_off, i32(loadx4_off*p_load_stride_in_mat_*(k_order_ == 1?1:2)));
+            Value *s_off_in_mat_elem = add(s_off_in_mat, i32(elem_off));
+
+            // disable swizzling ...
+            // Value *phase = urem(udiv(s_off_in_mat, i32(per_phase_)), i32(max_phase_));
+            // c_mat_off_i = xor_(c_mat_off_i, phase);
+
+            Value *c_off = add(c_off_in_mat, mul(c_mat_off_i, i32(c_mat_shape_)));
+            Value *s_off = add(s_off_in_mat_elem, mul(s_mat_off, i32(s_mat_shape_)));
+            // To prevent out-of-bound access when the tile is too small
+            c_off = urem(c_off, i32(tile_shape_[order_[0]]));
+            s_off = urem(s_off, i32(tile_shape_[order_[1]]));
+            offs[ptr_off] = add(c_off, mul(s_off, i32(s_stride_)));
+          }
+        }
+      }
+      return offs;
     } else
       throw std::runtime_error("invalid smem load config");
   }
@@ -1461,8 +1502,10 @@ public:
     int ptr_idx = -1;
     if (can_use_ldmatrix_)
       ptr_idx = mat_idx[order_[0]] / (instr_shape_[order_[0]] / mat_shape_[order_[0]]);
-    else // tf32 & trans
+    else if (dtsize_ == 4 && need_trans_) // tf32 & trans
       ptr_idx = mat_idx[order_[0]];
+    else // i8 & trans
+      ptr_idx = mat_idx[order_[0]] * 4;

     auto get_ptr = [&](int idx) -> Value* {
       Value *ptr = nullptr;
@@ -1495,9 +1538,7 @@ public:
                 extract_val(res_v4, std::vector<unsigned>{1}),
                 extract_val(res_v4, std::vector<unsigned>{2}),
                 extract_val(res_v4, std::vector<unsigned>{3})};
-      } else {
-        // assert(false && "should not be here");
-        assert(dtsize_ == 4 && need_trans_);
+      } else if (dtsize_ == 4 && need_trans_) { // use lds.32 to load tf32 matrices
         Value *ptr2 = get_ptr(ptr_idx+1);
         assert(s_mat_stride_ == 1);
         int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
@@ -1521,7 +1562,96 @@ public:
           prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(elem3);
         }
         return {elem0, elem1, elem2, elem3};
-      }
+      } else if (dtsize_ == 1 && need_trans_) { // use lds.8 to load i8/u8 matrices
+        Value *ptr00 = get_ptr(ptr_idx);
+        Value *ptr01 = get_ptr(ptr_idx+1);
+        Value *ptr02 = get_ptr(ptr_idx+2);
+        Value *ptr03 = get_ptr(ptr_idx+3);
+
+        Value *ptr10 = get_ptr(ptr_idx+4);
+        Value *ptr11 = get_ptr(ptr_idx+5);
+        Value *ptr12 = get_ptr(ptr_idx+6);
+        Value *ptr13 = get_ptr(ptr_idx+7);
+
+        assert(s_mat_stride_ == 1);
+        int s_offset_elem = mat_idx[order_[1]] * (s_mat_stride_*s_mat_shape_) * s_stride_;
+        int s_offset_arr_elem = 1 * (s_mat_stride_*s_mat_shape_) * s_stride_;
+
+        Value *i8v4_elems[4];
+        Value *i32_elems[4];
+        for (int i=0; i<4; ++i)
+          i8v4_elems[i] = UndefValue::get(vec_ty(i8_ty, 4));
+
+        Value *elem00, *elem01, *elem02, *elem03;
+        Value *elem10, *elem11, *elem12, *elem13;
+        Value *elem20, *elem21, *elem22, *elem23;
+        Value *elem30, *elem31, *elem32, *elem33;
+        Value *i8_elems[4*4];
+        if (k_order_ == 1) { //
+          i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
+          i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
+          i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
+          i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
+
+          assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
+
+          i8_elems[1*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
+          i8_elems[1*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
+          i8_elems[1*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
+          i8_elems[1*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
+
+          i8_elems[2*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[2*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[2*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[2*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
+
+          i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
+
+          for (int m=0; m<4; ++m) {
+            for (int e=0; e<4; ++e)
+              i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
+            i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
+          }
+        } else { // for b (k first)
+          i8_elems[0*4 + 0] = load(gep(ptr00, i32(s_offset_elem)));
+          i8_elems[0*4 + 1] = load(gep(ptr01, i32(s_offset_elem)));
+          i8_elems[0*4 + 2] = load(gep(ptr02, i32(s_offset_elem)));
+          i8_elems[0*4 + 3] = load(gep(ptr03, i32(s_offset_elem)));
+
+          assert(i8_elems[0*4 + 0]->getType()->isIntegerTy(8));
+
+          i8_elems[2*4 + 0] = load(gep(ptr10, i32(s_offset_elem)));
+          i8_elems[2*4 + 1] = load(gep(ptr11, i32(s_offset_elem)));
+          i8_elems[2*4 + 2] = load(gep(ptr12, i32(s_offset_elem)));
+          i8_elems[2*4 + 3] = load(gep(ptr13, i32(s_offset_elem)));
+
+          i8_elems[1*4 + 0] = load(gep(ptr00, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[1*4 + 1] = load(gep(ptr01, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[1*4 + 2] = load(gep(ptr02, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[1*4 + 3] = load(gep(ptr03, i32(s_offset_elem + s_offset_arr_elem)));
+
+          i8_elems[3*4 + 0] = load(gep(ptr10, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[3*4 + 1] = load(gep(ptr11, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[3*4 + 2] = load(gep(ptr12, i32(s_offset_elem + s_offset_arr_elem)));
+          i8_elems[3*4 + 3] = load(gep(ptr13, i32(s_offset_elem + s_offset_arr_elem)));
+
+          for (int m=0; m<4; ++m) {
+            for (int e=0; e<4; ++e)
+              i8v4_elems[m] = insert_elt(i8v4_elems[m], i8_elems[m*4 + e], e);
+            i32_elems[m] = bit_cast(i8v4_elems[m], i32_ty);
+          }
+        }
+        if (k == 0 && inc == 1 && is_prefetch) {
+          for (int m = 0; m < 4; ++m)
+            for (int e = 0; e < 4; ++e)
+              prefetch_latch_to_bb_[pn->get_incoming_value(1)].push_back(i8_elems[m*4 + e]);
+        }
+        return {i32_elems[0], i32_elems[1], i32_elems[2], i32_elems[3]};
+      } else
+        throw std::runtime_error("invalid smem load");
     }

     int get_num_ptr() const { return num_ptr_; }
@@ -1596,12 +1726,18 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
   const int num_rep_n = shapes[1] / layout->shape_per_cta(1);
   const int num_rep_k = std::max<int>(NK/mma_instr_k, 1);

+  // floating point types
   Type *fp32_ty = f32_ty;
   Type *fp16x2_ty = vec_ty(f16_ty, 2);
   Type *bf16x2_ty = vec_ty(bf16_ty, 2);
   Type *fp16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty});
   Type *bf16x2_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{bf16x2_ty, bf16x2_ty, bf16x2_ty, bf16x2_ty});
   Type *fp32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{fp32_ty, fp32_ty, fp32_ty, fp32_ty});
+  // integer types
+  Type *i8x4_ty = vec_ty(i8_ty, 4);
+  Type *i8x4_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty});
+  Type *i32_pack4_ty = StructType::get(*ctx_, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty});
+

   FunctionType *ldmatrix_ty = nullptr;
   FunctionType *mma_ty = nullptr;
@@ -1630,6 +1766,16 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
     smem_ptr_ty = ptr_ty(fp32_ty, 3);
     ldmatrix_ty = FunctionType::get(fp32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
     phi_ty = fp32_ty;
+  } else if (A_ir_ty->is_integer_ty(8) && B_ir_ty->is_integer_ty(8)) {
+    // FIXME: We should use i8 here (but nvptx will generate extra casts when using i8)
+    mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
+    smem_ptr_ty = ptr_ty(i8_ty, 3);
+    ldmatrix_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
+    phi_ty = i32_ty;
+    // mma_ty = FunctionType::get(i32_pack4_ty, std::vector<llvm::Type*>{i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i8x4_ty, i32_ty, i32_ty, i32_ty, i32_ty}, false);
+    // smem_ptr_ty = ptr_ty(i8_ty, 3);
+    // ldmatrix_ty = FunctionType::get(i8x4_pack4_ty, std::vector<llvm::Type*>{smem_ptr_ty}, false);
+    // phi_ty = i8x4_ty;
   } else
     throw std::runtime_error("mma16816 data type not supported");

@@ -1690,7 +1836,7 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
                                              " {$4, $5, $6, $7},"
                                              " {$8, $9},"
                                              " {$10, $11, $12, $13};",
-                                             "=f,=f,=f,=f,r,r,r,r,r,r,0,1,2,3", true);
+                                             "=r,=r,=r,=r,r,r,r,r,r,r,0,1,2,3", true);

   // create mma & unpack result, m, n, k are offsets in mat
   auto call_mma = [&](unsigned m, unsigned n, unsigned k) {
@@ -1715,12 +1861,12 @@ void generator::visit_mma16816(ir::dot_inst* C, ir::value *A, ir::value *B, ir::
   ir::phi_node* phiB = dynamic_cast<ir::phi_node*>(B);

   auto register_lds2 =
-    [&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int n, int k, int inc, Value* val, bool is_prefetch) {
+    [&](std::map<std::pair<unsigned, unsigned>, Value*>& vals, int mn, int k, int inc, Value* val, bool is_prefetch) {
       if (k < 2 && is_prefetch) {
         ir::basic_block* inc_block = phiA->get_incoming_block(inc);
-        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{n, k}], val, inc_block));
+        lazy_phi_incs_.push_back(std::make_tuple((PHINode*)vals[{mn, k}], val, inc_block));
       } else
-        vals[{n, k}] = val;
+        vals[{mn, k}] = val;
   };

   auto load_a = [&](int m, int k, int inc, bool is_prefetch) {
@@ -1922,7 +2068,10 @@ void generator::visit_dot_inst(ir::dot_inst* dot) {
|
|||||||
return visit_mma884(dot, A, B, D, NK);
|
return visit_mma884(dot, A, B, D, NK);
|
||||||
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
|
if(!is_outer && is_mma && tgt_->as_nvidia()->sm() >= 80)
|
||||||
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
|
return visit_mma16816(dot, A, B, D, NK); // rename it as visit_mma_v2()?
|
||||||
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
if (dot->get_type()->get_scalar_ty()->is_fp32_ty() &&
|
||||||
|
A->get_type()->get_scalar_ty()->is_fp32_ty())
|
||||||
|
return visit_fmadot(dot, A, B, D, NK, c_ty, f_mul_add);
|
||||||
|
throw std::runtime_error("dot has invalid operand type");
|
||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_trans_inst(ir::trans_inst* trans) {
|
void generator::visit_trans_inst(ir::trans_inst* trans) {
|
||||||
|
@@ -61,7 +61,8 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
   // dot(a, b, c) + d -> dot(a, b, c + d)
   // d + dot(a, b, c) -> dot(a, b, c + d)
   auto add = dynamic_cast<ir::binary_operator*>(value);
-  if(add && add->get_op() == ir::binary_op_t::FAdd) {
+  if(add && (add->get_op() == ir::binary_op_t::FAdd || add->get_op() == ir::binary_op_t::Add)) {
+    bool is_int_dot = add->get_op() == ir::binary_op_t::Add;
     ir::value *lhs = add->get_operand(0);
     ir::value *rhs = add->get_operand(1);
     ir::dot_inst *lhs_dot = dynamic_cast<ir::dot_inst*>(lhs);
@@ -72,11 +73,17 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
     ir::value *other = (dot == lhs) ? rhs : lhs;
     ir::value *acc = dot->get_operand(2);
     ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(acc);
-    ir::constant_fp *_0 = nullptr;
+    ir::constant *_0 = nullptr;
     if(splat)
-      _0 = dynamic_cast<ir::constant_fp*>(splat->get_operand(0));
-    if(!(_0 && _0->get_value() == 0.0))
+      _0 = dynamic_cast<ir::constant*>(splat->get_operand(0));
+    if(!_0)
       return false;
+    if (auto *fp_0 = dynamic_cast<ir::constant_fp*>(_0))
+      if (fp_0->get_value() != 0.0)
+        return false;
+    if (auto *int_0 = dynamic_cast<ir::constant_int*>(_0))
+      if (int_0->get_value() != 0)
+        return false;
     ir::value *a = dot->get_operand(0);
     ir::value *b = dot->get_operand(1);
     builder.set_insert_point(add);
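The two peephole hunks above extend the `dot(a, b, 0) + d -> dot(a, b, c + d)` rewrite from floating-point FAdd to integer Add, so an int8 `tl.dot` whose result is added to a zero-initialized int32 accumulator can be folded into the dot itself. A minimal sketch of the kernel pattern this targets, assuming this repo's `triton.language` Python API; the kernel name, block sizes, and the unmasked loads are illustrative, not part of the commit:

```python
import triton
import triton.language as tl


@triton.jit
def int8_matmul_kernel(X, Y, Z, K,
                       stride_xm, stride_xk,
                       stride_yk, stride_yn,
                       stride_zm, stride_zn,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    off_k = tl.arange(0, BLOCK_K)
    x_ptrs = X + off_m[:, None] * stride_xm + off_k[None, :] * stride_xk
    y_ptrs = Y + off_k[:, None] * stride_yk + off_n[None, :] * stride_yn
    # int32 accumulator: the IR sees add(acc, dot(x, y, splat(0))), and after this change
    # the peephole folds the zero-splat accumulator away for the integer case as well
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int32)
    for k in range(0, K, BLOCK_K):
        x = tl.load(x_ptrs)   # int8 tile of X
        y = tl.load(y_ptrs)   # int8 tile of Y
        acc += tl.dot(x, y)   # int8 x int8 -> int32 tensor-core dot on sm >= 80
        x_ptrs += BLOCK_K * stride_xk
        y_ptrs += BLOCK_K * stride_yk
    z_ptrs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn
    tl.store(z_ptrs, acc)
```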
@@ -33,6 +33,9 @@ void prefetch::run(ir::module &mod) {
     if (!(dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp16_ty() ||
           dot->get_operand(0)->get_type()->get_scalar_ty()->is_bf16_ty() ||
           (dot->get_operand(0)->get_type()->get_scalar_ty()->is_fp32_ty() && dot->allow_tf32()
+           && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80) ||
+          (dot->get_operand(0)->get_type()->get_scalar_ty()->is_integer_ty(8)
+           && dot->get_operand(1)->get_type()->get_scalar_ty()->is_integer_ty(8)
            && tgt_->as_nvidia() && tgt_->as_nvidia()->sm() >= 80)
          )
        )
@@ -46,6 +46,7 @@
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
 #include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Scalar.h"

 // begin AMD stuff
 #include "llvm/Support/FileSystem.h"
@@ -121,9 +122,12 @@ std::string llir_to_ptx(llvm::Module* module, int cc, int version){
   init_llvm();
   // verify and store llvm
   llvm::legacy::PassManager pm;
-  // module->print(llvm::outs(), nullptr);
   pm.add(llvm::createVerifierPass());
+  // pm.add(llvm::createDeadCodeEliminationPass());
+  // pm.add(llvm::createEarlyCSEPass());
   pm.run(*module);
+  // module->print(llvm::outs(), nullptr);
+
   // create machine
   module->setTargetTriple(triple);
   std::string error;
@@ -726,7 +726,11 @@ ir::value *dispatch::atomic_xchg(ir::value* ptr, ir::value *val, ir::value *mask
 //===----------------------------------------------------------------------===//

 ir::value *dispatch::dot(ir::value *lhs, ir::value *rhs, ir::constant_int *allow_tf32, ir::builder *builder) {
-  ir::value *_0 = builder->get_float32(0);
+  ir::value *_0 = nullptr;
+  if (lhs->get_type()->is_int_or_tileint_ty())
+    _0 = builder->get_int32(0);
+  else
+    _0 = builder->get_float32(0);
   unsigned M = lhs->get_type()->get_block_shapes()[0];
   unsigned N = rhs->get_type()->get_block_shapes()[1];
   _0 = builder->create_splat(_0, {M, N});
@@ -77,7 +77,7 @@ class CMakeBuild(build_ext):

     def build_extension(self, ext):
         llvm_include_dir, llvm_library_dir = get_llvm()
-        # self.debug = True
+        self.debug = True
         extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
         # create build directories
         build_suffix = 'debug' if self.debug else 'release'
@@ -661,11 +661,20 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
 # ---------------


-@pytest.mark.parametrize("epilogue, allow_tf32",
-                         [(epilogue, allow_tf32)
+@pytest.mark.parametrize("epilogue, allow_tf32, dtype",
+                         [(epilogue, allow_tf32, dtype)
                           for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols']
-                          for allow_tf32 in [True, False]])
-def test_dot(epilogue, allow_tf32, device='cuda'):
+                          for allow_tf32 in [True, False]
+                          for dtype in ['float32', 'int8']
+                          if not (allow_tf32 and (dtype == 'int8'))])
+def test_dot(epilogue, allow_tf32, dtype, device='cuda'):
+    cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
+    if cc < 80:
+        if dtype == 'int8':
+            pytest.skip("Only test int8 on devices with sm >= 80")
+        elif dtype == 'float32' and allow_tf32:
+            pytest.skip("Only test tf32 on devices with sm >= 80")
+
     # triton kernel
     @triton.jit
     def kernel(X, stride_xm, stride_xk,
@@ -693,18 +702,15 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
     # input
     M, N, K = 64, 64, 32
     rs = RandomState(17)
-    x = numpy_random((M, K), dtype_str='float32', rs=rs)
-    y = numpy_random((K, N), dtype_str='float32', rs=rs)
+    x = numpy_random((M, K), dtype_str=dtype, rs=rs)
+    y = numpy_random((K, N), dtype_str=dtype, rs=rs)
     if allow_tf32:
-        cc = _triton.runtime.cc(_triton.runtime.backend.CUDA, torch.cuda.current_device())
-        if cc < 80:
-            pytest.skip("Only test tf32 on devices with sm >= 80")
         x = (x.view('uint32') & np.uint32(0xffffe000)).view('float32')
         y = (y.view('uint32') & np.uint32(0xffffe000)).view('float32')
     x_tri = to_triton(x, device=device)
     y_tri = to_triton(y, device=device)
     # triton result
-    z = numpy_random((M, N), dtype_str='float32', rs=rs)
+    z = numpy_random((M, N), dtype_str=dtype, rs=rs)
     z_tri = to_triton(z, device=device)
     if epilogue == 'trans':
         z_tri = torch.as_strided(z_tri, (M, N), z_tri.stride()[::-1])
@@ -732,8 +738,10 @@ def test_dot(epilogue, allow_tf32, device='cuda'):
     assert 'st.global.v4' in ptx
     if allow_tf32:
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' in ptx
-    else:
+    elif dtype == 'float32':
         assert 'mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32' not in ptx
+    elif dtype == 'int8':
+        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx


 def test_dot_without_load():
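The new `test_dot` parametrization above exercises the int8 path end to end and asserts that the generated PTX contains `mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32`. A hedged launch-and-check sketch for the kernel outlined after the peephole hunks; the sizes mirror the test's 64x64x32 problem, and the numpy reference plus the divisibility assumption are illustrative, not from the commit:

```python
import numpy as np
import torch

M, N, K = 64, 64, 32
x = torch.randint(-128, 128, (M, K), dtype=torch.int8, device='cuda')
y = torch.randint(-128, 128, (K, N), dtype=torch.int8, device='cuda')
z = torch.empty((M, N), dtype=torch.int32, device='cuda')

# one program per (BLOCK_M, BLOCK_N) tile; assumes M, N, K are multiples of the block sizes
grid = (M // 64, N // 64)
int8_matmul_kernel[grid](x, y, z, K,
                         x.stride(0), x.stride(1),
                         y.stride(0), y.stride(1),
                         z.stride(0), z.stride(1),
                         BLOCK_M=64, BLOCK_N=64, BLOCK_K=32)

# int32 reference on the CPU (CUDA matmul does not cover int8 inputs);
# with |values| <= 128 and K = 32 the int32 accumulator cannot saturate
z_ref = x.cpu().numpy().astype(np.int32) @ y.cpu().numpy().astype(np.int32)
assert (z.cpu().numpy() == z_ref).all()
```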