[feature] basic tensor core utilization works
This commit is contained in:
@@ -54,26 +54,11 @@ void matmul(restrict read_only fp16 *A, restrict read_only fp16 *B,
|
|||||||
}
|
}
|
||||||
int32 rxc[TM] = get_global_range[TM](0);
|
int32 rxc[TM] = get_global_range[TM](0);
|
||||||
int32 ryc[TN] = get_global_range[TN](1);
|
int32 ryc[TN] = get_global_range[TN](1);
|
||||||
int32 ridx = get_range_id(0);
|
|
||||||
int32 ridy = get_range_id(1);
|
|
||||||
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||||
int32 *plock = locks + ridx + ridy*grid0;
|
|
||||||
while(__atomic_cas(plock, 0, 1));
|
|
||||||
int32 *pcount = plock + grid0*grid1;
|
|
||||||
int32 count = *pcount;
|
|
||||||
int32 countp1 = select(count == GZ - 1, 0, count + 1);
|
|
||||||
int1 checkc0[TM] = rxc < M;
|
int1 checkc0[TM] = rxc < M;
|
||||||
int1 checkc1[TN] = ryc < N;
|
int1 checkc1[TN] = ryc < N;
|
||||||
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||||
if(count == 0) {
|
|
||||||
@checkc *pc = c;
|
@checkc *pc = c;
|
||||||
*pcount = countp1;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
@checkc *pc = c + *pc;
|
|
||||||
*pcount = countp1;
|
|
||||||
}
|
|
||||||
__atomic_cas(plock, 1, 0);
|
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
|
@@ -6,14 +6,14 @@ data_files_path = tf.resource_loader.get_data_files_path()
|
|||||||
library_dir = '/home/philippe/development/triton/build/examples/python/tensorflow'
|
library_dir = '/home/philippe/development/triton/build/examples/python/tensorflow'
|
||||||
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
|
module = tf.load_op_library(os.path.join(library_dir, 'libtf_blocksparse.so'))
|
||||||
|
|
||||||
M, N, K = 16, 16, 16
|
M, N, K = 256, 256, 256
|
||||||
a = tf.placeholder(tf.float16, shape=[M, K])
|
a = tf.placeholder(tf.float16, shape=[M, K])
|
||||||
b = tf.placeholder(tf.float16, shape=[N, K])
|
b = tf.placeholder(tf.float16, shape=[N, K])
|
||||||
locks = tf.placeholder(tf.int32, shape=[4096])
|
locks = tf.placeholder(tf.int32, shape=[4096])
|
||||||
c = module.dot(a, b, locks)
|
c = module.dot(a, b, locks)
|
||||||
# Reference
|
# Reference
|
||||||
ha = np.ones((M, K)).astype(np.float16)
|
ha = np.random.rand(M, K).astype(np.float16)
|
||||||
hb = np.ones((N, K)).astype(np.float16)
|
hb = np.random.rand(N, K).astype(np.float16)
|
||||||
hresult = np.dot(hb.T, ha)
|
hresult = np.dot(hb.T, ha)
|
||||||
|
|
||||||
# Run
|
# Run
|
||||||
@@ -22,4 +22,7 @@ sess.run(tf.global_variables_initializer())
|
|||||||
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
|
result = sess.run([c], feed_dict = {locks: np.zeros(4096),
|
||||||
a: ha,
|
a: ha,
|
||||||
b: hb})
|
b: hb})
|
||||||
print(result - hresult)
|
print(result)
|
||||||
|
print(hresult)
|
||||||
|
#print(result - hresult)
|
||||||
|
print(np.max(np.abs(result - hresult)))
|
||||||
|
@@ -57,6 +57,7 @@ private:
|
|||||||
public:
|
public:
|
||||||
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr);
|
shared_tile(llvm::Type* ty, const shapes_t &shapes, llvm::Value* ptr, llvm::IRBuilder<> &builder, llvm::Value* offset = nullptr);
|
||||||
void set_vector_size(unsigned vector_size);
|
void set_vector_size(unsigned vector_size);
|
||||||
|
void set_return_mode(bool return_vector);
|
||||||
void set_value(indices_t, llvm::Value *);
|
void set_value(indices_t, llvm::Value *);
|
||||||
llvm::Value* get_value(indices_t idx);
|
llvm::Value* get_value(indices_t idx);
|
||||||
llvm::Value* get_pointer() { return ptr_; }
|
llvm::Value* get_pointer() { return ptr_; }
|
||||||
@@ -64,6 +65,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::Value *ptr_;
|
llvm::Value *ptr_;
|
||||||
|
bool return_vector_;
|
||||||
llvm::Value *offset_;
|
llvm::Value *offset_;
|
||||||
llvm::IRBuilder<> &builder_;
|
llvm::IRBuilder<> &builder_;
|
||||||
std::map<indices_t, llvm::Value*> ptr_cache_;
|
std::map<indices_t, llvm::Value*> ptr_cache_;
|
||||||
|
@@ -129,6 +129,7 @@ Value* shared_tile::shared_offset(indices_t idx) {
|
|||||||
|
|
||||||
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
|
shared_tile::shared_tile(Type *ty, const shapes_t &shapes, Value *ptr, llvm::IRBuilder<> &builder, Value *offset):
|
||||||
tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){
|
tile(ty, shapes), ptr_(ptr), builder_(builder), offset_(offset), vector_size_(1){
|
||||||
|
return_vector_ = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void shared_tile::set_value(indices_t idx, Value *value) {
|
void shared_tile::set_value(indices_t idx, Value *value) {
|
||||||
@@ -142,12 +143,18 @@ void shared_tile::set_vector_size(unsigned vector_size) {
|
|||||||
vector_size_ = vector_size;
|
vector_size_ = vector_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sets the return_vector_ flag consulted by get_value(): while the flag is
// false and vector_size_ > 1, get_value() extracts a single element from the
// loaded vector; when the flag is true that extraction step is skipped.
void shared_tile::set_return_mode(bool return_vector){
|
||||||
|
return_vector_ = return_vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
Value* shared_tile::get_value(indices_t idx) {
|
Value* shared_tile::get_value(indices_t idx) {
|
||||||
indices_t non_cst_idx, cst_idx;
|
indices_t non_cst_idx, cst_idx;
|
||||||
extract_constant(idx, non_cst_idx, cst_idx);
|
extract_constant(idx, non_cst_idx, cst_idx);
|
||||||
Value *&base_ptr = ptr_cache_[non_cst_idx];
|
Value *&base_ptr = ptr_cache_[non_cst_idx];
|
||||||
if(base_ptr == nullptr){
|
if(base_ptr == nullptr){
|
||||||
base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx));
|
base_ptr = builder_.CreateGEP(ptr_, shared_offset(non_cst_idx));
|
||||||
|
// base_ptr = builder_.CreateBitCast(base_ptr, load_ptr_->getType());
|
||||||
if(vector_size_ > 1){
|
if(vector_size_ > 1){
|
||||||
Type *vec_ty = VectorType::get(base_ptr->getType()->getPointerElementType(), vector_size_);
|
Type *vec_ty = VectorType::get(base_ptr->getType()->getPointerElementType(), vector_size_);
|
||||||
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
|
Type *vec_ptr_ty = PointerType::get(vec_ty, base_ptr->getType()->getPointerAddressSpace());
|
||||||
@@ -160,7 +167,7 @@ Value* shared_tile::get_value(indices_t idx) {
|
|||||||
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
|
div = builder_.CreateUDiv(offset, builder_.getInt32(vector_size_));
|
||||||
Value *ptr = builder_.CreateGEP(base_ptr, div);
|
Value *ptr = builder_.CreateGEP(base_ptr, div);
|
||||||
Value *result = builder_.CreateLoad(ptr);
|
Value *result = builder_.CreateLoad(ptr);
|
||||||
if(vector_size_ > 1) {
|
if(return_vector_ == false && vector_size_ > 1) {
|
||||||
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
|
Value *rem = builder_.CreateURem(offset, builder_.getInt32(vector_size_));
|
||||||
result = builder_.CreateExtractElement(result, rem);
|
result = builder_.CreateExtractElement(result, rem);
|
||||||
}
|
}
|
||||||
@@ -479,19 +486,19 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
|
|||||||
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
|
builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), _4)));
|
||||||
// idx_i
|
// idx_j
|
||||||
std::vector<Value*> idx_j;
|
std::vector<Value*> idx_j;
|
||||||
for(unsigned i = 0; i < 2; i++){
|
for(unsigned j = 0; j < 2; j++){
|
||||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4)));
|
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*4)));
|
||||||
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(i*4 + 1)));
|
idx_j.push_back(builder.CreateAdd(offset_j, builder.getInt32(j*4 + 1)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// idx_j
|
// idx_i
|
||||||
std::vector<Value*> idx_i;
|
std::vector<Value*> idx_i;
|
||||||
for(unsigned j = 0; j < 2; j++){
|
for(unsigned i = 0; i < 2; i++){
|
||||||
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(j*2)));
|
idx_i.push_back(builder.CreateAdd(offset_i, builder.getInt32(i*2)));
|
||||||
}
|
}
|
||||||
|
|
||||||
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_j};
|
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i};
|
||||||
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_i};
|
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -855,6 +862,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
TA->set_vector_size(2);
|
||||||
|
TB->set_vector_size(2);
|
||||||
|
TA->set_return_mode(true);
|
||||||
|
TB->set_return_mode(true);
|
||||||
|
Value *_0 = builder.getInt32(0);
|
||||||
Value *_1 = builder.getInt32(1);
|
Value *_1 = builder.getInt32(1);
|
||||||
Value *_2 = builder.getInt32(2);
|
Value *_2 = builder.getInt32(2);
|
||||||
Value *_3 = builder.getInt32(3);
|
Value *_3 = builder.getInt32(3);
|
||||||
@@ -864,47 +876,62 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
|||||||
BasicBlock *current = builder.GetInsertBlock();
|
BasicBlock *current = builder.GetInsertBlock();
|
||||||
Module *module = current->getModule();
|
Module *module = current->getModule();
|
||||||
Value *tid = tgt_->get_local_id(module, builder, 0);
|
Value *tid = tgt_->get_local_id(module, builder, 0);
|
||||||
// offset_a_i = (tid & 3)
|
// offset_a_i = (tid & 4)*2 + (tid & 16)/4;
|
||||||
// offset_a_j = (tid & 4)*2 + (tid & 16)/4;
|
// offset_a_k = (tid & 3)
|
||||||
Value *offset_a_i = builder.CreateAnd(tid, _3);
|
Value *offset_a_i = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4), _2),
|
||||||
Value *offset_a_k = builder.CreateAdd(builder.CreateMul(builder.CreateAnd(tid, _4),
|
|
||||||
_2),
|
|
||||||
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
||||||
_4));
|
_4));
|
||||||
// offset_b_i = (tid & 3)
|
Value *offset_a_k = builder.CreateAnd(tid, _3);
|
||||||
// offset_b_j = (tid & 8)*1 + (tid & 16)/4
|
|
||||||
Value *offset_b_i = builder.CreateAnd(tid, _3);
|
// offset_b_i = (tid & 8)*1 + (tid & 16)/4
|
||||||
Value *offset_b_k = builder.CreateAdd(builder.CreateAnd(tid, _8),
|
// offset_b_k = (tid & 3)
|
||||||
|
Value *offset_b_i = builder.CreateAdd(builder.CreateAnd(tid, _8),
|
||||||
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
builder.CreateUDiv(builder.CreateAnd(tid, _16),
|
||||||
_4));
|
_4));
|
||||||
Value *ha0 = TA->get_value({offset_a_i, offset_a_k});
|
Value *offset_b_k = builder.CreateAnd(tid, _3);
|
||||||
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _1), offset_a_k});
|
|
||||||
Value *hb0 = TB->get_value({offset_b_i, offset_b_k});
|
|
||||||
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _1), offset_b_k});
|
|
||||||
std::vector<Value *> fc;
|
std::vector<Value *> fc;
|
||||||
result->for_each([&](indices_t idx){
|
result->for_each([&](indices_t idx){
|
||||||
fc.push_back(result->get_value(idx));
|
fc.push_back(result->get_value(idx));
|
||||||
});
|
});
|
||||||
|
|
||||||
Type *void_ty = builder.getVoidTy();
|
|
||||||
Type *int32_ty = builder.getInt32Ty();
|
|
||||||
Type *fp32_ty = builder.getFloatTy();
|
Type *fp32_ty = builder.getFloatTy();
|
||||||
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
|
||||||
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
|
||||||
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {int32_ty, int32_ty, int32_ty, int32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
FunctionType *mma_ty = FunctionType::get(fp32_pack8_ty, {fp16x2_ty, fp16x2_ty, fp16x2_ty, fp16x2_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty}, false);
|
||||||
|
|
||||||
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
|
InlineAsm *mma_fn = InlineAsm::get(mma_ty, " mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32 "
|
||||||
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
"{$0, $1, $2, $3, $4, $5, $6, $7}, "
|
||||||
"{$8, $9}, "
|
"{$8, $9}, "
|
||||||
"{$10, $11}, "
|
"{$10, $11}, "
|
||||||
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
"{$0, $1, $2, $3, $4, $5, $6, $7};", "=f,=f,=f,=f,=f,=f,=f,=f,r,r,r,r,0,1,2,3,4,5,6,7", false);
|
||||||
Value *nc = builder.CreateCall(mma_fn, {builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), builder.getInt32(0), fc[0], fc[1], fc[2], fc[3], fc[4], fc[5], fc[6], fc[7]});
|
|
||||||
std::cout << mma_fn->getFunctionType()->getFunctionNumParams() << std::endl;
|
for(unsigned K = 0; K < NK; K += 4){
|
||||||
|
Value *_K = builder.getInt32(K);
|
||||||
|
Value *ha0 = TA->get_value({offset_a_i, builder.CreateAdd(offset_a_k, _K)});
|
||||||
|
Value *ha1 = TA->get_value({builder.CreateAdd(offset_a_i, _2), builder.CreateAdd(offset_a_k, _K)});
|
||||||
|
Value *hb0 = TB->get_value({offset_b_i, builder.CreateAdd(offset_b_k, _K)});
|
||||||
|
Value *hb1 = TB->get_value({builder.CreateAdd(offset_b_i, _2), builder.CreateAdd(offset_b_k, _K)});
|
||||||
|
Value *nc = builder.CreateCall(mma_fn, {ha0, ha1, hb0, hb1, fc[0], fc[2], fc[1], fc[3], fc[4], fc[6], fc[5], fc[7]});
|
||||||
|
fc[0] = builder.CreateExtractValue(nc, {0});
|
||||||
|
fc[2] = builder.CreateExtractValue(nc, {1});
|
||||||
|
fc[1] = builder.CreateExtractValue(nc, {2});
|
||||||
|
fc[3] = builder.CreateExtractValue(nc, {3});
|
||||||
|
fc[4] = builder.CreateExtractValue(nc, {4});
|
||||||
|
fc[6] = builder.CreateExtractValue(nc, {5});
|
||||||
|
fc[5] = builder.CreateExtractValue(nc, {6});
|
||||||
|
fc[7] = builder.CreateExtractValue(nc, {7});
|
||||||
|
}
|
||||||
|
|
||||||
|
// write back
|
||||||
unsigned i = 0;
|
unsigned i = 0;
|
||||||
result->for_each([&](indices_t idx){
|
result->for_each([&](indices_t idx){
|
||||||
result->set_value(idx, builder.CreateExtractValue(nc, {i++}));
|
result->set_value(idx, fc[i++]);
|
||||||
});
|
});
|
||||||
std::cout << "haha" << std::endl;
|
|
||||||
|
TA->set_return_mode(false);
|
||||||
|
TB->set_return_mode(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
|
Reference in New Issue
Block a user