[python] added basic tensorflow support

This commit is contained in:
Philippe Tillet
2019-08-17 18:18:26 -07:00
parent 078f0052fe
commit b4a9ed9663
24 changed files with 341 additions and 173 deletions

160
examples/cpp/cuda.h Normal file
View File

@@ -0,0 +1,160 @@
/* Copyright 2015-2017 Philippe Tillet
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <algorithm>
#include <vector>
#include <cassert>
#include "cublas_v2.h"
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/driver/context.h"
#include "triton/tools/bench.hpp"
enum cublasStrategy_t{
CUBLAS_PREFER_FASTEST,
CUBLAS_HEURISTICS
};
enum DType{
HALF_TYPE,
FLOAT_TYPE,
DOUBLE_TYPE,
};
inline size_t size_of(DType dtype){
switch (dtype) {
case HALF_TYPE: return 2;
case FLOAT_TYPE: return 4;
case DOUBLE_TYPE: return 8;
default: throw;
}
}
std::vector<cublasGemmAlgo_t> gather_all_algos() {
std::vector<cublasGemmAlgo_t> result;
// non-tensor ops
for(int i = -1; i < 24; i++)
result.push_back((cublasGemmAlgo_t)i);
// tensor ops
for(int i = 99; i < 116; i++)
result.push_back((cublasGemmAlgo_t)i);
return result;
}
static const std::vector<cublasGemmAlgo_t> algorithms = gather_all_algos();
static const std::map<DType, cudaDataType> cu_dtype = {
{HALF_TYPE, CUDA_R_16F},
{FLOAT_TYPE, CUDA_R_32F},
{DOUBLE_TYPE, CUDA_R_64F}
};
static const std::map<char, cublasOperation_t> cu_op = {
{false, CUBLAS_OP_N},
{true, CUBLAS_OP_T}
};
inline cublasGemmAlgo_t cublasGemmFastest(
triton::driver::stream* stream,
cublasHandle_t handle, cudaDataType cudt,
cublasOperation_t AT, cublasOperation_t BT,
int32_t M, int32_t N, int32_t K,
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
void* beta, CUdeviceptr C, int32_t ldc) {
// cache to avoid re-benchmarking
typedef std::tuple<cudaDataType_t,
cublasOperation_t, cublasOperation_t,
int32_t, int32_t, int32_t> key_t;
static std::map<key_t, cublasGemmAlgo_t> cache;
key_t key(cudt, AT, BT, M, N, K);
// benchmark algorithms if necessary
if(cache.find(key) == cache.end()){
std::vector<double> times;
for(cublasGemmAlgo_t a: algorithms) {
cublasStatus_t status;
double nanosec = triton::tools::bench([&](){ status = cublasGemmEx(handle, AT, BT,
M, N, K,
alpha, (const void*)A, cudt, lda,
(const void*)B, cudt, ldb,
beta, (void*)C, cudt, ldc, cudt,
a); }, stream);
if(status != CUBLAS_STATUS_SUCCESS)
nanosec = INFINITY;
}
size_t argmin = std::min_element(times.begin(), times.end()) - times.begin();
assert(times[argmin] != INFINITY);
cache.insert({key, algorithms[argmin]});
}
// return best algorithm
return cache.at(key);
}
/* Wrapper for cublasGemmEx */
inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
void* beta, CUdeviceptr C, int32_t ldc, cublasGemmAlgo_t algo)
{
cublasStatus_t status = cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, algo);
if(status != CUBLAS_STATUS_SUCCESS){
std::cout << status;
exit(EXIT_FAILURE);
}
}
/* Get cuBLAS handle */
cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
static std::map<CUstream, cublasHandle_t> cache;
CUstream key = *stream->cu();
// create handle if necessary
if(cache.find(key) == cache.end()) {
cublasHandle_t handle;
if(cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS)
throw std::runtime_error("Error: could not create cuBLAS handle");
cublasSetStream_v2(handle, key);
cache.insert({key, handle});
}
// return handle for the stream
return cache.at(key);
}
/* Simplified API for default GEMM */
inline void cublasGemm(DType dtype, triton::driver::stream* stream, bool AT, bool BT,
int32_t M, int32_t N, int32_t K,
void* alpha, triton::driver::buffer* A, int32_t lda,
triton::driver::buffer* B, int32_t ldb,
void* beta, triton::driver::buffer* C, int32_t ldc,
cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) {
triton::driver::cu_context::context_switcher scope(*stream->context());
static cublasHandle_t handle = cublasGetHandle(stream);
if(dtype == HALF_TYPE)
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
cublasStatus_t status;
if(fastest)
*fastest = cublasGemmFastest(stream, handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc);
else
status = cublasGemmEx(handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc, algo);
}

View File

@@ -88,8 +88,8 @@ void matmul(restrict read_only align(16) )" + a_ty + R"( *A,
restrict read_only align(16) )" + c_ty + R"( *C,
int M, int N, int K,
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc) {
int ridx = get_range_id(0);
int ridy = get_range_id(1);
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rxa[TM] = ridx * TM + (0 ... TM);
int ryb[TN] = ridy * TN + (0 ... TN);
int rka[TK] = 0 ... TK;

View File

@@ -169,7 +169,7 @@ private:
void lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_reshape(ir::reshape_inst* x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);

View File

@@ -126,7 +126,7 @@ public:
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
// Built-in instruction
value *create_get_range_id(unsigned axis, const std::string &name = "");
value *create_get_program_id(unsigned axis, const std::string &name = "");
value *create_get_num_program(unsigned axis, const std::string &name = "");
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");

View File

@@ -496,10 +496,10 @@ protected:
using instruction::instruction;
};
class get_range_id_inst: public builtin_inst {
class get_program_id_inst: public builtin_inst {
private:
get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_range_id(" + std::to_string(axis_) + ")"; }
get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; }
public:
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
@@ -668,23 +668,23 @@ public:
};
// On NVIDIA, implementation is such that
// constant_range = nv_dynamic_range_idx + nv_static_range_idx
// so as to enable re-association on nv_static_range_idx which is constant
class nv_dynamic_range_idx_inst: public instruction {
// constant_range = nv_dynamic_program_idx + nv_static_program_idx
// so as to enable re-association on nv_static_program_idx which is constant
class nv_dynamic_program_idx_inst: public instruction {
private:
nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next);
std::string repr_impl() const { return "nv_dynamic_range_idx"; }
nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next);
std::string repr_impl() const { return "nv_dynamic_program_idx"; }
public:
static nv_dynamic_range_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);
static nv_dynamic_program_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);
};
class nv_static_range_idx: public constant {
class nv_static_program_idx: public constant {
private:
nv_static_range_idx(constant_range *range);
nv_static_program_idx(constant_range *range);
public:
static nv_static_range_idx *get(constant_range* range);
static nv_static_program_idx *get(constant_range* range);
constant_range* get_range() const;
private:

View File

@@ -71,9 +71,9 @@ private:
const constant* size_;
};
class get_range_id_expression: public builtin_expression{
class get_program_id_expression: public builtin_expression{
public:
get_range_id_expression(node *axis): axis_((constant*)axis) { }
get_program_id_expression(node *axis): axis_((constant*)axis) { }
ir::value* codegen(ir::module *) const;
private:

View File

@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
%token IF ELSE FOR CONTINUE WHILE
%token NEWAXIS ELLIPSIS AT
%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE
%token GET_NUM_PROGRAM GET_PROGRAM_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE
%start translation_unit
%%
@@ -120,7 +120,7 @@ identifier
/* Built-in */
builtin_expression
: GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
: GET_PROGRAM_ID '(' constant ')' { $$ = new get_program_id_expression($3); }
| GET_NUM_PROGRAM '(' constant ')' { $$ = new get_num_program_expression($3); }
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }

View File

@@ -43,7 +43,7 @@ using triton::lang::return_void;
"float" { return return_impl(FP32, yytext); }
"double" { return return_impl(FP64, yytext); }
"..." { return return_impl(ELLIPSIS, yytext); }
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
"get_program_id" { return return_impl(GET_PROGRAM_ID, yytext); }
"get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }
"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
"__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); }

View File

@@ -227,10 +227,10 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
if(auto *x = dynamic_cast<ir::constant_range*>(v)){
return cache(x->get_first()->get_value());
}
if(auto *x = dynamic_cast<ir::nv_dynamic_range_idx_inst*>(v)){
if(auto *x = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(v)){
return cache(128);
}
if(auto *x = dynamic_cast<ir::nv_static_range_idx*>(v)){
if(auto *x = dynamic_cast<ir::nv_static_program_idx*>(v)){
return cache(x->get_range()->get_first()->get_value());
}
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){

View File

@@ -411,7 +411,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
Value *else_value = value(ii->get_operand(2));
return builder.Insert(SelectInst::Create(pred, if_value, else_value));
}
if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
if(ir::get_program_id_inst* ii = dynamic_cast<ir::get_program_id_inst*>(inst)){
Value *result = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
return (Instruction*)result;
}
@@ -837,7 +837,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
T->set_value(idx, idx[0]);
});
}
if(is_inserted && dynamic_cast<ir::nv_static_range_idx*>(v)){
if(is_inserted && dynamic_cast<ir::nv_static_program_idx*>(v)){
T->for_each([&](indices_t idx){
assert(idx.size() == 1);
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
@@ -996,7 +996,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
}
}
void selection::lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
void selection::lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
distributed_tile* result = (distributed_tile*)tmap_.at(x);
result->for_each([&](indices_t idx){
assert(idx.size() == 1);
@@ -1418,8 +1418,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
lower_downcast(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::reduce_inst*>(ins))
lower_reduce(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins))
lower_dynamic_range_idx(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(ins))
lower_dynamic_program_idx(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::reshape_inst*>(ins))
lower_reshape(x, ctx, fn, builder);
else if(auto *x = dynamic_cast<ir::splat_inst*>(ins))

View File

@@ -164,7 +164,7 @@ reassociate::reassociate(analysis::tune* params)
void reassociate::run(ir::module &mod) {
ir::builder &builder = mod.get_builder();
// constant_range -> nv_dynamic_range_idx + nv_static_range_idx
// constant_range -> nv_dynamic_program_idx + nv_static_program_idx
for(ir::function *fn: mod.get_function_list()){
std::vector<ir::constant_range*> ranges;
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
@@ -178,8 +178,8 @@ void reassociate::run(ir::module &mod) {
builder.set_insert_point(rpo.front()->get_first_non_phi());
for(ir::constant_range* old_range: ranges){
ir::value* dyn_range = builder.insert(ir::nv_dynamic_range_idx_inst::create(old_range->get_type()));
ir::value* static_range = ir::nv_static_range_idx::get(old_range);
ir::value* dyn_range = builder.insert(ir::nv_dynamic_program_idx_inst::create(old_range->get_type()));
ir::value* static_range = ir::nv_static_program_idx::get(old_range);
ir::value* new_range = builder.create_add(dyn_range, static_range);
old_range->replace_all_uses_with(new_range);
params_->copy(dyn_range, old_range);

View File

@@ -82,7 +82,7 @@ void batchnorm_forward(float *Y, float *M, float *V,
int rx[TM] = 0 ... TM;
float *px[TM];
float x[TM] = 0;
int c = get_range_id(1);
int c = get_program_id(1);
float g = *(G + c);
float b = *(B + c);
@@ -177,7 +177,7 @@ void batchnorm_backward(float *DX, float *DG, float *DB,
restrict read_only float *V,
int DHWN, float rcpDHWN, float epsilon) {
int rx[TM] = 0 ... TM;
int c = get_range_id(1);
int c = get_program_id(1);
int offset = c*DHWN;
float g = *(G + c);
float mean = *(M + c);

View File

@@ -114,11 +114,11 @@ std::string dot::triton_c_src_ydx() const {
int lda, int ldb, int ldc,
int N, int* lut,
int* locks, int nlocks) {
int ridx = get_range_id(0);
int ridx = get_program_id(0);
float acc[TM, TN] = 0;
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
int *header = lut + get_range_id(1) * 4;
int *header = lut + get_program_id(1) * 4;
int offset = *(header + 0);
int K = *(header + 1);
int column = *(header + 2);
@@ -191,7 +191,7 @@ std::string dot::triton_c_src_dw() const {
int lda, int ldb, int ldc,
int N, int* lut,
int* locks, int nlocks) {
int ridx = get_range_id(0);
int ridx = get_program_id(0);
float acc[TM, TN] = 0;
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;

View File

@@ -686,8 +686,8 @@ if(b_lut_){
float* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
bool checkc0[TM] = rxc < M;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
int ridx = get_range_id(0);
int ridy = get_range_id(1);
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1) == 1);
int *pcount = plock + grid0*grid1;

View File

@@ -116,8 +116,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
int M, int N, int K,
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc,
int bound, int *locks, int grid0, int grid1) {
int ridx = get_range_id(0);
int ridy = get_range_id(1);
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rxa[TM] = ridx * TM + (0 ... TM);
int ryb[TN] = ridy * TN + (0 ... TN);
int rka[TK] = 0 ... TK;

View File

@@ -354,9 +354,9 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
int BH, int BW,
int CH, int CW,
int* locks, int grid0, int grid1, int grid2) {
int ridx = get_range_id(0);
int ridy = get_range_id(1);
int rz = get_range_id(2);
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rz = get_program_id(2);
int rxa[TM] = ridx*TM + (0 ... TM);
int ryb[TN] = ridy*TN + (0 ... TN);
int rka[TK] = 0 ... TK;

View File

@@ -300,8 +300,8 @@ value *builder::create_downcast(value *arg, const std::string &name) {
// built-in instructions
//===----------------------------------------------------------------------===//
value *builder::create_get_range_id(unsigned axis, const std::string &name) {
return insert(get_range_id_inst::create(ctx_, axis, name));
value *builder::create_get_program_id(unsigned axis, const std::string &name) {
return insert(get_program_id_inst::create(ctx_, axis, name));
}
value *builder::create_get_num_program(unsigned axis, const std::string &name) {

View File

@@ -664,14 +664,14 @@ instruction* select_inst::create(value *pred, value *if_value, value *else_value
//===----------------------------------------------------------------------===//
// get_range_id
get_range_id_inst::get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
// get_program_id
get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
: builtin_inst(ty, 0, 1, name, next), axis_(axis){
}
instruction* get_range_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
return new get_range_id_inst(type::get_int32_ty(ctx), axis, name, next);
instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
return new get_program_id_inst(type::get_int32_ty(ctx), axis, name, next);
}
// get_num_program
@@ -745,25 +745,25 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
return new barrier_inst(ctx, name, next);
}
// nv_dynamic_range_idx
nv_dynamic_range_idx_inst::nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next)
// nv_dynamic_program_idx
nv_dynamic_program_idx_inst::nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next)
: instruction(ty, 0, 1, name, next) { }
nv_dynamic_range_idx_inst* nv_dynamic_range_idx_inst::create(type *ty, const std::string &name, instruction *next) {
return new nv_dynamic_range_idx_inst(ty, name, next);
nv_dynamic_program_idx_inst* nv_dynamic_program_idx_inst::create(type *ty, const std::string &name, instruction *next) {
return new nv_dynamic_program_idx_inst(ty, name, next);
}
// nv_static_range_idx
nv_static_range_idx::nv_static_range_idx(constant_range *range)
// nv_static_program_idx
nv_static_program_idx::nv_static_program_idx(constant_range *range)
: constant(range->get_type(), 0), range_(range) { }
constant_range* nv_static_range_idx::get_range() const
constant_range* nv_static_program_idx::get_range() const
{ return range_; }
nv_static_range_idx* nv_static_range_idx::get(constant_range* range) {
static std::map<constant_range*, nv_static_range_idx*> cache;
nv_static_program_idx* nv_static_program_idx::get(constant_range* range) {
static std::map<constant_range*, nv_static_program_idx*> cache;
if(cache.find(range) == cache.end())
cache.insert({range, new nv_static_range_idx(range)});
cache.insert({range, new nv_static_program_idx(range)});
return cache.at(range);
}

View File

@@ -115,9 +115,9 @@ ir::value* alloc_const_expression::codegen(ir::module *mod) const {
return res;
}
// get_range_id
ir::value* get_range_id_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_get_range_id(axis_->value());
// get_program_id
ir::value* get_program_id_expression::codegen(ir::module *mod) const {
return mod->get_builder().create_get_program_id(axis_->value());
}
// get_num_program

View File

@@ -1,14 +1,6 @@
import libtriton
import triton
import tensorflow as tf
import distutils
import distutils.log
import setuptools.command.build_ext
import setuptools
import numpy as np
import os
import tempfile
import shutil
import hashlib
src = """
const tunable int TM = {128};
@@ -20,8 +12,8 @@ void matmul(restrict read_only align(16) half *A,
restrict read_only align(16) half *C,
int M, int N, int K,
multiple_of(8) int lda, multiple_of(8) int ldb, int ldc) {
int ridx = get_range_id(0);
int ridy = get_range_id(1);
int ridx = get_program_id(0);
int ridy = get_program_id(1);
int rxa[TM] = ridx * TM + (0 ... TM);
int ryb[TN] = ridy * TN + (0 ... TN);
int rka[TK] = 0 ... TK;
@@ -40,7 +32,7 @@ void matmul(restrict read_only align(16) half *A,
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
half* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
half* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis]*ldc;
half c[TM, TN] = xc;
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
@@ -49,100 +41,10 @@ void matmul(restrict read_only align(16) half *A,
}
"""
extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
def make_bindings(src, outputs, grids):
return libtriton.make_tensorflow_src(src, outputs, grids)
def make_cache_path(src):
md5 = hashlib.sha1(src.encode())
hexhash = md5.hexdigest()
home = os.path.expanduser('~')
cacheroot = os.path.join(home, '.triton', 'cache')
cachepath = os.path.join(cacheroot, str(hexhash))
if not os.path.exists(cachepath):
os.makedirs(cachepath)
print(cachepath)
return cachepath
def write_bindings(src, root):
cpp = os.path.join(root, 'tensorflow.cpp')
so = os.path.join(root, 'tensorflow.so')
recompile = False
# recompile if .so does not exist
if not os.path.exists(cpp) or not os.path.exists(so):
recompile = True
# recompile if cpp was modified after .so
elif max(cpp, so, key=os.path.getctime) == cpp:
recompile = True
# write cpp file
if recompile:
with open(cpp, 'w+') as handle:
handle.writelines(src)
# return path of cpp file
return cpp
def build(src, path):
# include directories
triton_include_dirs = ['/home/philippe/development/triton/include']
tensorflow_include_dirs = [tf.sysconfig.get_include()]
cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
# library directories
triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
tensorflow_library_dirs = [tf.sysconfig.get_lib()]
library_dirs = triton_library_dirs + tensorflow_library_dirs
# libraries
libraries = ['tensorflow_framework', 'triton']
# extra arguments
extra_compile_args = []
extra_link_args = []
# create extension module
ext = setuptools.Extension(
name = 'test',
language = 'c++',
sources = [src],
include_dirs = include_dirs,
extra_compile_args = extra_compile_args,
extra_link_args = extra_link_args,
library_dirs = library_dirs,
libraries = libraries
)
# build extension module
args = ['build_ext']
tmp = tempfile.mkdtemp()
args.append('--build-temp=' + tmp)
args.append('--build-lib=' + path)
args.append('-q')
args = dict(
name = 'test',
ext_modules = [ext],
script_args = args,
)
setuptools.setup(**args)
shutil.rmtree(tmp)
def make_tensorflow_op(src, outputs, grids):
bindings = make_bindings(src, outputs, grids)
cache_path = make_cache_path(bindings)
cpp = write_bindings(bindings, cache_path)
build(cpp, cache_path)
result = tf.load_op_library(os.path.join(cache_path, 'test.cpython-36m-x86_64-linux-gnu.so'))
return result
library_dir = os.path.dirname(os.path.realpath(__file__))
module = make_tensorflow_op(src, ['C'], ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN'])
print(module.matmul)
class dot:
def __init__(self):
trans_a = True
trans_b = False
self.matmul = triton.make_tensorflow_op(src, ['C'], ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN'])
def __call__(self, a, b):
shape_a = tf.shape(a)
@@ -152,17 +54,17 @@ class dot:
N = shape_b[0]
lda = M
ldb = K
ldc = M
c = extra_ops.alloc_empty(tf.stack([M, N]))
return module.matmul(a, b, c, M, N, K, lda, ldb, ldc)
ldc = N
c = triton.empty([M, N])
return self.matmul.matmul(a, b, c, M, N, K, lda, ldb, ldc)
dot_nt = dot()
dot_tn = dot()
def run_dot():
M, N, K = 128, 128, 128
a = tf.placeholder(tf.float16, shape=[M, K])
b = tf.placeholder(tf.float16, shape=[N, K])
# c = tf.matmul(a, b, transpose_a=True)
c = dot_nt(a, b)
c = dot_tn(a, b)
# Reference
ha = np.random.rand(M, K).astype(np.float16)
hb = np.random.rand(N, K).astype(np.float16)
@@ -172,7 +74,7 @@ def run_dot():
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
# Test
hresult = np.dot(ha.T, hb).T
hresult = np.dot(ha.T, hb)
dif = np.abs(result - hresult)
np.savetxt('dif.dat', dif, '%2.4f')
print(hresult)

View File

@@ -18,6 +18,7 @@ class CMakeExtension(Extension):
class CMakeBuild(build_ext):
def run(self):
try:
out = subprocess.check_output(['cmake', '--version'])
@@ -80,6 +81,7 @@ setup(
author_email='ptillet@g.harvard.edu',
description='A language and compiler for custom Deep Learning operations',
long_description='',
packages=['triton'],
ext_modules=[CMakeExtension('triton')],
cmdclass=dict(build_ext=CMakeBuild),
zip_safe=False,

View File

@@ -177,8 +177,8 @@ result += R"(
std::regex regex("#([a-zA-Z]([a-zA-Z]|[0-9])*)");
std::vector<std::string> grids;
for(size_t i = macros.size(); i < 3; i++)
std::vector<std::string> grids = macros;
for(size_t i = grids.size(); i < 3; i++)
grids.push_back("1");
std::string grid = "rt::grid_t{";
for(size_t i = 0; i < grids.size(); i++){

View File

@@ -0,0 +1 @@
from .ops import *

103
python/triton/ops.py Normal file
View File

@@ -0,0 +1,103 @@
# import for cache
import os
import tempfile
import shutil
import hashlib
import sysconfig
import sys
# import for just-in-time compilation
import distutils
import setuptools.command.build_ext
import setuptools
# triton
import libtriton
# frameworks
import tensorflow as tf
extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
def make_bindings(src, outputs, grids):
return libtriton.make_tensorflow_src(src, outputs, grids)
def make_cache_path(src):
md5 = hashlib.sha1(src.encode())
hexhash = md5.hexdigest()
home = os.path.expanduser('~')
cacheroot = os.path.join(home, '.triton', 'cache')
cachepath = os.path.join(cacheroot, str(hexhash))
if not os.path.exists(cachepath):
os.makedirs(cachepath)
return cachepath
def write_bindings(src, root):
cpp = os.path.join(root, 'tensorflow.cpp')
suffix = sysconfig.get_config_var('EXT_SUFFIX')
so = os.path.join(root, 'tensorflow{suffix}'.format(suffix=suffix))
recompile = False
# recompile if .so does not exist
if not os.path.exists(cpp) or not os.path.exists(so):
recompile = True
# recompile if cpp was modified after .so
elif max(cpp, so, key=os.path.getctime) == cpp:
recompile = True
# write cpp file
if recompile:
with open(cpp, 'w+') as handle:
handle.writelines(src)
# return path of cpp file
return (cpp, so)
def build(src, path):
# include directories
triton_include_dirs = ['/home/philippe/development/triton/include']
tensorflow_include_dirs = [tf.sysconfig.get_include()]
cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
# library directories
triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
tensorflow_library_dirs = [tf.sysconfig.get_lib()]
library_dirs = triton_library_dirs + tensorflow_library_dirs
# libraries
libraries = ['tensorflow_framework', 'triton']
# extra arguments
extra_compile_args = []
extra_link_args = []
# dependences
depends = [os.path.realpath(libtriton.__file__)]
# create extension module
ext = setuptools.Extension(
name = 'tensorflow',
language = 'c++',
sources = [src],
include_dirs = include_dirs,
extra_compile_args = extra_compile_args,
extra_link_args = extra_link_args,
library_dirs = library_dirs,
libraries = libraries,
depends = depends
)
# build extension module
args = ['build_ext']
tmp = tempfile.mkdtemp()
args.append('--build-temp=' + tmp)
args.append('--build-lib=' + path)
args.append('-q')
args = dict(
name = 'tensorflow',
ext_modules = [ext],
script_args = args,
)
setuptools.setup(**args)
shutil.rmtree(tmp)
def make_tensorflow_op(src, outputs, grids):
bindings = make_bindings(src, outputs, grids)
cache_path = make_cache_path(bindings)
cpp, so = write_bindings(bindings, cache_path)
build(cpp, cache_path)
result = tf.load_op_library(so)
return result
def empty(shapes):
return extra_ops.alloc_empty(tf.stack(shapes))