[python] added basic tensorflow support
This commit is contained in:
160
examples/cpp/cuda.h
Normal file
160
examples/cpp/cuda.h
Normal file
@@ -0,0 +1,160 @@
|
||||
/* Copyright 2015-2017 Philippe Tillet
|
||||
*
|
||||
* Permission is hereby granted, free of charge, to any person obtaining
|
||||
* a copy of this software and associated documentation files
|
||||
* (the "Software"), to deal in the Software without restriction,
|
||||
* including without limitation the rights to use, copy, modify, merge,
|
||||
* publish, distribute, sublicense, and/or sell copies of the Software,
|
||||
* and to permit persons to whom the Software is furnished to do so,
|
||||
* subject to the following conditions:
|
||||
*
|
||||
* The above copyright notice and this permission notice shall be
|
||||
* included in all copies or substantial portions of the Software.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
#include "cublas_v2.h"
|
||||
#include "triton/driver/buffer.h"
|
||||
#include "triton/driver/stream.h"
|
||||
#include "triton/driver/context.h"
|
||||
#include "triton/tools/bench.hpp"
|
||||
|
||||
enum cublasStrategy_t{
|
||||
CUBLAS_PREFER_FASTEST,
|
||||
CUBLAS_HEURISTICS
|
||||
};
|
||||
|
||||
enum DType{
|
||||
HALF_TYPE,
|
||||
FLOAT_TYPE,
|
||||
DOUBLE_TYPE,
|
||||
};
|
||||
|
||||
inline size_t size_of(DType dtype){
|
||||
switch (dtype) {
|
||||
case HALF_TYPE: return 2;
|
||||
case FLOAT_TYPE: return 4;
|
||||
case DOUBLE_TYPE: return 8;
|
||||
default: throw;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<cublasGemmAlgo_t> gather_all_algos() {
|
||||
std::vector<cublasGemmAlgo_t> result;
|
||||
// non-tensor ops
|
||||
for(int i = -1; i < 24; i++)
|
||||
result.push_back((cublasGemmAlgo_t)i);
|
||||
// tensor ops
|
||||
for(int i = 99; i < 116; i++)
|
||||
result.push_back((cublasGemmAlgo_t)i);
|
||||
return result;
|
||||
}
|
||||
|
||||
static const std::vector<cublasGemmAlgo_t> algorithms = gather_all_algos();
|
||||
|
||||
static const std::map<DType, cudaDataType> cu_dtype = {
|
||||
{HALF_TYPE, CUDA_R_16F},
|
||||
{FLOAT_TYPE, CUDA_R_32F},
|
||||
{DOUBLE_TYPE, CUDA_R_64F}
|
||||
};
|
||||
|
||||
static const std::map<char, cublasOperation_t> cu_op = {
|
||||
{false, CUBLAS_OP_N},
|
||||
{true, CUBLAS_OP_T}
|
||||
};
|
||||
|
||||
inline cublasGemmAlgo_t cublasGemmFastest(
|
||||
triton::driver::stream* stream,
|
||||
cublasHandle_t handle, cudaDataType cudt,
|
||||
cublasOperation_t AT, cublasOperation_t BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
|
||||
void* beta, CUdeviceptr C, int32_t ldc) {
|
||||
|
||||
// cache to avoid re-benchmarking
|
||||
typedef std::tuple<cudaDataType_t,
|
||||
cublasOperation_t, cublasOperation_t,
|
||||
int32_t, int32_t, int32_t> key_t;
|
||||
static std::map<key_t, cublasGemmAlgo_t> cache;
|
||||
key_t key(cudt, AT, BT, M, N, K);
|
||||
// benchmark algorithms if necessary
|
||||
if(cache.find(key) == cache.end()){
|
||||
std::vector<double> times;
|
||||
for(cublasGemmAlgo_t a: algorithms) {
|
||||
cublasStatus_t status;
|
||||
double nanosec = triton::tools::bench([&](){ status = cublasGemmEx(handle, AT, BT,
|
||||
M, N, K,
|
||||
alpha, (const void*)A, cudt, lda,
|
||||
(const void*)B, cudt, ldb,
|
||||
beta, (void*)C, cudt, ldc, cudt,
|
||||
a); }, stream);
|
||||
if(status != CUBLAS_STATUS_SUCCESS)
|
||||
nanosec = INFINITY;
|
||||
}
|
||||
size_t argmin = std::min_element(times.begin(), times.end()) - times.begin();
|
||||
assert(times[argmin] != INFINITY);
|
||||
cache.insert({key, algorithms[argmin]});
|
||||
}
|
||||
|
||||
// return best algorithm
|
||||
return cache.at(key);
|
||||
}
|
||||
|
||||
/* Wrapper for cublasGemmEx */
|
||||
inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
|
||||
void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
|
||||
void* beta, CUdeviceptr C, int32_t ldc, cublasGemmAlgo_t algo)
|
||||
{
|
||||
cublasStatus_t status = cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, algo);
|
||||
if(status != CUBLAS_STATUS_SUCCESS){
|
||||
std::cout << status;
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Get cuBLAS handle */
|
||||
cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
|
||||
static std::map<CUstream, cublasHandle_t> cache;
|
||||
CUstream key = *stream->cu();
|
||||
|
||||
// create handle if necessary
|
||||
if(cache.find(key) == cache.end()) {
|
||||
cublasHandle_t handle;
|
||||
if(cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS)
|
||||
throw std::runtime_error("Error: could not create cuBLAS handle");
|
||||
cublasSetStream_v2(handle, key);
|
||||
cache.insert({key, handle});
|
||||
}
|
||||
|
||||
// return handle for the stream
|
||||
return cache.at(key);
|
||||
}
|
||||
|
||||
/* Simplified API for default GEMM */
|
||||
inline void cublasGemm(DType dtype, triton::driver::stream* stream, bool AT, bool BT,
|
||||
int32_t M, int32_t N, int32_t K,
|
||||
void* alpha, triton::driver::buffer* A, int32_t lda,
|
||||
triton::driver::buffer* B, int32_t ldb,
|
||||
void* beta, triton::driver::buffer* C, int32_t ldc,
|
||||
cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) {
|
||||
triton::driver::cu_context::context_switcher scope(*stream->context());
|
||||
static cublasHandle_t handle = cublasGetHandle(stream);
|
||||
if(dtype == HALF_TYPE)
|
||||
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
|
||||
cublasStatus_t status;
|
||||
if(fastest)
|
||||
*fastest = cublasGemmFastest(stream, handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc);
|
||||
else
|
||||
status = cublasGemmEx(handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc, algo);
|
||||
}
|
@@ -88,8 +88,8 @@ void matmul(restrict read_only align(16) )" + a_ty + R"( *A,
|
||||
restrict read_only align(16) )" + c_ty + R"( *C,
|
||||
int M, int N, int K,
|
||||
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc) {
|
||||
int ridx = get_range_id(0);
|
||||
int ridy = get_range_id(1);
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int ryb[TN] = ridy * TN + (0 ... TN);
|
||||
int rka[TK] = 0 ... TK;
|
||||
|
@@ -169,7 +169,7 @@ private:
|
||||
void lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_reshape(ir::reshape_inst* x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
|
||||
|
@@ -126,7 +126,7 @@ public:
|
||||
value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
|
||||
// Built-in instruction
|
||||
value *create_get_range_id(unsigned axis, const std::string &name = "");
|
||||
value *create_get_program_id(unsigned axis, const std::string &name = "");
|
||||
value *create_get_num_program(unsigned axis, const std::string &name = "");
|
||||
value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
|
||||
value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
|
||||
|
@@ -496,10 +496,10 @@ protected:
|
||||
using instruction::instruction;
|
||||
};
|
||||
|
||||
class get_range_id_inst: public builtin_inst {
|
||||
class get_program_id_inst: public builtin_inst {
|
||||
private:
|
||||
get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_range_id(" + std::to_string(axis_) + ")"; }
|
||||
get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; }
|
||||
|
||||
public:
|
||||
static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
|
||||
@@ -668,23 +668,23 @@ public:
|
||||
};
|
||||
|
||||
// On NVIDIA, implementation is such that
|
||||
// constant_range = nv_dynamic_range_idx + nv_static_range_idx
|
||||
// so as to enable re-association on nv_static_range_idx which is constant
|
||||
class nv_dynamic_range_idx_inst: public instruction {
|
||||
// constant_range = nv_dynamic_program_idx + nv_static_program_idx
|
||||
// so as to enable re-association on nv_static_program_idx which is constant
|
||||
class nv_dynamic_program_idx_inst: public instruction {
|
||||
private:
|
||||
nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "nv_dynamic_range_idx"; }
|
||||
nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next);
|
||||
std::string repr_impl() const { return "nv_dynamic_program_idx"; }
|
||||
|
||||
public:
|
||||
static nv_dynamic_range_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
||||
static nv_dynamic_program_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);
|
||||
};
|
||||
|
||||
class nv_static_range_idx: public constant {
|
||||
class nv_static_program_idx: public constant {
|
||||
private:
|
||||
nv_static_range_idx(constant_range *range);
|
||||
nv_static_program_idx(constant_range *range);
|
||||
|
||||
public:
|
||||
static nv_static_range_idx *get(constant_range* range);
|
||||
static nv_static_program_idx *get(constant_range* range);
|
||||
constant_range* get_range() const;
|
||||
|
||||
private:
|
||||
|
@@ -71,9 +71,9 @@ private:
|
||||
const constant* size_;
|
||||
};
|
||||
|
||||
class get_range_id_expression: public builtin_expression{
|
||||
class get_program_id_expression: public builtin_expression{
|
||||
public:
|
||||
get_range_id_expression(node *axis): axis_((constant*)axis) { }
|
||||
get_program_id_expression(node *axis): axis_((constant*)axis) { }
|
||||
ir::value* codegen(ir::module *) const;
|
||||
|
||||
private:
|
||||
|
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
|
||||
%token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
|
||||
%token IF ELSE FOR CONTINUE WHILE
|
||||
%token NEWAXIS ELLIPSIS AT
|
||||
%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE
|
||||
%token GET_NUM_PROGRAM GET_PROGRAM_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE
|
||||
|
||||
%start translation_unit
|
||||
%%
|
||||
@@ -120,7 +120,7 @@ identifier
|
||||
|
||||
/* Built-in */
|
||||
builtin_expression
|
||||
: GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
|
||||
: GET_PROGRAM_ID '(' constant ')' { $$ = new get_program_id_expression($3); }
|
||||
| GET_NUM_PROGRAM '(' constant ')' { $$ = new get_num_program_expression($3); }
|
||||
| DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
|
||||
| SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
|
||||
|
@@ -43,7 +43,7 @@ using triton::lang::return_void;
|
||||
"float" { return return_impl(FP32, yytext); }
|
||||
"double" { return return_impl(FP64, yytext); }
|
||||
"..." { return return_impl(ELLIPSIS, yytext); }
|
||||
"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
|
||||
"get_program_id" { return return_impl(GET_PROGRAM_ID, yytext); }
|
||||
"get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }
|
||||
"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
|
||||
"__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); }
|
||||
|
@@ -227,10 +227,10 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
|
||||
if(auto *x = dynamic_cast<ir::constant_range*>(v)){
|
||||
return cache(x->get_first()->get_value());
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::nv_dynamic_range_idx_inst*>(v)){
|
||||
if(auto *x = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(v)){
|
||||
return cache(128);
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::nv_static_range_idx*>(v)){
|
||||
if(auto *x = dynamic_cast<ir::nv_static_program_idx*>(v)){
|
||||
return cache(x->get_range()->get_first()->get_value());
|
||||
}
|
||||
if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
|
||||
|
@@ -411,7 +411,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
|
||||
Value *else_value = value(ii->get_operand(2));
|
||||
return builder.Insert(SelectInst::Create(pred, if_value, else_value));
|
||||
}
|
||||
if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
|
||||
if(ir::get_program_id_inst* ii = dynamic_cast<ir::get_program_id_inst*>(inst)){
|
||||
Value *result = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
|
||||
return (Instruction*)result;
|
||||
}
|
||||
@@ -837,7 +837,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
|
||||
T->set_value(idx, idx[0]);
|
||||
});
|
||||
}
|
||||
if(is_inserted && dynamic_cast<ir::nv_static_range_idx*>(v)){
|
||||
if(is_inserted && dynamic_cast<ir::nv_static_program_idx*>(v)){
|
||||
T->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
|
||||
@@ -996,7 +996,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
|
||||
}
|
||||
}
|
||||
|
||||
void selection::lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
void selection::lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
|
||||
distributed_tile* result = (distributed_tile*)tmap_.at(x);
|
||||
result->for_each([&](indices_t idx){
|
||||
assert(idx.size() == 1);
|
||||
@@ -1418,8 +1418,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
|
||||
lower_downcast(x, ctx, fn, builder);
|
||||
else if(auto *x = dynamic_cast<ir::reduce_inst*>(ins))
|
||||
lower_reduce(x, ctx, fn, builder);
|
||||
else if(auto *x = dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins))
|
||||
lower_dynamic_range_idx(x, ctx, fn, builder);
|
||||
else if(auto *x = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(ins))
|
||||
lower_dynamic_program_idx(x, ctx, fn, builder);
|
||||
else if(auto *x = dynamic_cast<ir::reshape_inst*>(ins))
|
||||
lower_reshape(x, ctx, fn, builder);
|
||||
else if(auto *x = dynamic_cast<ir::splat_inst*>(ins))
|
||||
|
@@ -164,7 +164,7 @@ reassociate::reassociate(analysis::tune* params)
|
||||
void reassociate::run(ir::module &mod) {
|
||||
ir::builder &builder = mod.get_builder();
|
||||
|
||||
// constant_range -> nv_dynamic_range_idx + nv_static_range_idx
|
||||
// constant_range -> nv_dynamic_program_idx + nv_static_program_idx
|
||||
for(ir::function *fn: mod.get_function_list()){
|
||||
std::vector<ir::constant_range*> ranges;
|
||||
std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
|
||||
@@ -178,8 +178,8 @@ void reassociate::run(ir::module &mod) {
|
||||
|
||||
builder.set_insert_point(rpo.front()->get_first_non_phi());
|
||||
for(ir::constant_range* old_range: ranges){
|
||||
ir::value* dyn_range = builder.insert(ir::nv_dynamic_range_idx_inst::create(old_range->get_type()));
|
||||
ir::value* static_range = ir::nv_static_range_idx::get(old_range);
|
||||
ir::value* dyn_range = builder.insert(ir::nv_dynamic_program_idx_inst::create(old_range->get_type()));
|
||||
ir::value* static_range = ir::nv_static_program_idx::get(old_range);
|
||||
ir::value* new_range = builder.create_add(dyn_range, static_range);
|
||||
old_range->replace_all_uses_with(new_range);
|
||||
params_->copy(dyn_range, old_range);
|
||||
|
@@ -82,7 +82,7 @@ void batchnorm_forward(float *Y, float *M, float *V,
|
||||
int rx[TM] = 0 ... TM;
|
||||
float *px[TM];
|
||||
float x[TM] = 0;
|
||||
int c = get_range_id(1);
|
||||
int c = get_program_id(1);
|
||||
float g = *(G + c);
|
||||
float b = *(B + c);
|
||||
|
||||
@@ -177,7 +177,7 @@ void batchnorm_backward(float *DX, float *DG, float *DB,
|
||||
restrict read_only float *V,
|
||||
int DHWN, float rcpDHWN, float epsilon) {
|
||||
int rx[TM] = 0 ... TM;
|
||||
int c = get_range_id(1);
|
||||
int c = get_program_id(1);
|
||||
int offset = c*DHWN;
|
||||
float g = *(G + c);
|
||||
float mean = *(M + c);
|
||||
|
@@ -114,11 +114,11 @@ std::string dot::triton_c_src_ydx() const {
|
||||
int lda, int ldb, int ldc,
|
||||
int N, int* lut,
|
||||
int* locks, int nlocks) {
|
||||
int ridx = get_range_id(0);
|
||||
int ridx = get_program_id(0);
|
||||
float acc[TM, TN] = 0;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
int *header = lut + get_range_id(1) * 4;
|
||||
int *header = lut + get_program_id(1) * 4;
|
||||
int offset = *(header + 0);
|
||||
int K = *(header + 1);
|
||||
int column = *(header + 2);
|
||||
@@ -191,7 +191,7 @@ std::string dot::triton_c_src_dw() const {
|
||||
int lda, int ldb, int ldc,
|
||||
int N, int* lut,
|
||||
int* locks, int nlocks) {
|
||||
int ridx = get_range_id(0);
|
||||
int ridx = get_program_id(0);
|
||||
float acc[TM, TN] = 0;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
|
@@ -686,8 +686,8 @@ if(b_lut_){
|
||||
float* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
int ridx = get_range_id(0);
|
||||
int ridy = get_range_id(1);
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int *plock = locks + ridx + ridy*grid0;
|
||||
while(__atomic_cas(plock, 0, 1) == 1);
|
||||
int *pcount = plock + grid0*grid1;
|
||||
|
@@ -116,8 +116,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
int M, int N, int K,
|
||||
)" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc,
|
||||
int bound, int *locks, int grid0, int grid1) {
|
||||
int ridx = get_range_id(0);
|
||||
int ridy = get_range_id(1);
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int ryb[TN] = ridy * TN + (0 ... TN);
|
||||
int rka[TK] = 0 ... TK;
|
||||
|
@@ -354,9 +354,9 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
|
||||
int BH, int BW,
|
||||
int CH, int CW,
|
||||
int* locks, int grid0, int grid1, int grid2) {
|
||||
int ridx = get_range_id(0);
|
||||
int ridy = get_range_id(1);
|
||||
int rz = get_range_id(2);
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rz = get_program_id(2);
|
||||
int rxa[TM] = ridx*TM + (0 ... TM);
|
||||
int ryb[TN] = ridy*TN + (0 ... TN);
|
||||
int rka[TK] = 0 ... TK;
|
||||
|
@@ -300,8 +300,8 @@ value *builder::create_downcast(value *arg, const std::string &name) {
|
||||
// built-in instructions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
value *builder::create_get_range_id(unsigned axis, const std::string &name) {
|
||||
return insert(get_range_id_inst::create(ctx_, axis, name));
|
||||
value *builder::create_get_program_id(unsigned axis, const std::string &name) {
|
||||
return insert(get_program_id_inst::create(ctx_, axis, name));
|
||||
}
|
||||
|
||||
value *builder::create_get_num_program(unsigned axis, const std::string &name) {
|
||||
|
@@ -664,14 +664,14 @@ instruction* select_inst::create(value *pred, value *if_value, value *else_value
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
// get_range_id
|
||||
get_range_id_inst::get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
|
||||
// get_program_id
|
||||
get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
|
||||
: builtin_inst(ty, 0, 1, name, next), axis_(axis){
|
||||
|
||||
}
|
||||
|
||||
instruction* get_range_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
|
||||
return new get_range_id_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||
instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
|
||||
return new get_program_id_inst(type::get_int32_ty(ctx), axis, name, next);
|
||||
}
|
||||
|
||||
// get_num_program
|
||||
@@ -745,25 +745,25 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
|
||||
return new barrier_inst(ctx, name, next);
|
||||
}
|
||||
|
||||
// nv_dynamic_range_idx
|
||||
nv_dynamic_range_idx_inst::nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next)
|
||||
// nv_dynamic_program_idx
|
||||
nv_dynamic_program_idx_inst::nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next)
|
||||
: instruction(ty, 0, 1, name, next) { }
|
||||
|
||||
nv_dynamic_range_idx_inst* nv_dynamic_range_idx_inst::create(type *ty, const std::string &name, instruction *next) {
|
||||
return new nv_dynamic_range_idx_inst(ty, name, next);
|
||||
nv_dynamic_program_idx_inst* nv_dynamic_program_idx_inst::create(type *ty, const std::string &name, instruction *next) {
|
||||
return new nv_dynamic_program_idx_inst(ty, name, next);
|
||||
}
|
||||
|
||||
// nv_static_range_idx
|
||||
nv_static_range_idx::nv_static_range_idx(constant_range *range)
|
||||
// nv_static_program_idx
|
||||
nv_static_program_idx::nv_static_program_idx(constant_range *range)
|
||||
: constant(range->get_type(), 0), range_(range) { }
|
||||
|
||||
constant_range* nv_static_range_idx::get_range() const
|
||||
constant_range* nv_static_program_idx::get_range() const
|
||||
{ return range_; }
|
||||
|
||||
nv_static_range_idx* nv_static_range_idx::get(constant_range* range) {
|
||||
static std::map<constant_range*, nv_static_range_idx*> cache;
|
||||
nv_static_program_idx* nv_static_program_idx::get(constant_range* range) {
|
||||
static std::map<constant_range*, nv_static_program_idx*> cache;
|
||||
if(cache.find(range) == cache.end())
|
||||
cache.insert({range, new nv_static_range_idx(range)});
|
||||
cache.insert({range, new nv_static_program_idx(range)});
|
||||
return cache.at(range);
|
||||
}
|
||||
|
||||
|
@@ -115,9 +115,9 @@ ir::value* alloc_const_expression::codegen(ir::module *mod) const {
|
||||
return res;
|
||||
}
|
||||
|
||||
// get_range_id
|
||||
ir::value* get_range_id_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_get_range_id(axis_->value());
|
||||
// get_program_id
|
||||
ir::value* get_program_id_expression::codegen(ir::module *mod) const {
|
||||
return mod->get_builder().create_get_program_id(axis_->value());
|
||||
}
|
||||
|
||||
// get_num_program
|
||||
|
@@ -1,14 +1,6 @@
|
||||
import libtriton
|
||||
import triton
|
||||
import tensorflow as tf
|
||||
import distutils
|
||||
import distutils.log
|
||||
import setuptools.command.build_ext
|
||||
import setuptools
|
||||
import numpy as np
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
import hashlib
|
||||
|
||||
src = """
|
||||
const tunable int TM = {128};
|
||||
@@ -20,8 +12,8 @@ void matmul(restrict read_only align(16) half *A,
|
||||
restrict read_only align(16) half *C,
|
||||
int M, int N, int K,
|
||||
multiple_of(8) int lda, multiple_of(8) int ldb, int ldc) {
|
||||
int ridx = get_range_id(0);
|
||||
int ridy = get_range_id(1);
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[TM] = ridx * TM + (0 ... TM);
|
||||
int ryb[TN] = ridy * TN + (0 ... TN);
|
||||
int rka[TK] = 0 ... TK;
|
||||
@@ -40,7 +32,7 @@ void matmul(restrict read_only align(16) half *A,
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
half* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
half* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis]*ldc;
|
||||
half c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
@@ -49,100 +41,10 @@ void matmul(restrict read_only align(16) half *A,
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
|
||||
|
||||
|
||||
def make_bindings(src, outputs, grids):
|
||||
return libtriton.make_tensorflow_src(src, outputs, grids)
|
||||
|
||||
def make_cache_path(src):
|
||||
md5 = hashlib.sha1(src.encode())
|
||||
hexhash = md5.hexdigest()
|
||||
home = os.path.expanduser('~')
|
||||
cacheroot = os.path.join(home, '.triton', 'cache')
|
||||
cachepath = os.path.join(cacheroot, str(hexhash))
|
||||
if not os.path.exists(cachepath):
|
||||
os.makedirs(cachepath)
|
||||
print(cachepath)
|
||||
return cachepath
|
||||
|
||||
def write_bindings(src, root):
|
||||
cpp = os.path.join(root, 'tensorflow.cpp')
|
||||
so = os.path.join(root, 'tensorflow.so')
|
||||
recompile = False
|
||||
# recompile if .so does not exist
|
||||
if not os.path.exists(cpp) or not os.path.exists(so):
|
||||
recompile = True
|
||||
# recompile if cpp was modified after .so
|
||||
elif max(cpp, so, key=os.path.getctime) == cpp:
|
||||
recompile = True
|
||||
# write cpp file
|
||||
if recompile:
|
||||
with open(cpp, 'w+') as handle:
|
||||
handle.writelines(src)
|
||||
# return path of cpp file
|
||||
return cpp
|
||||
|
||||
def build(src, path):
|
||||
# include directories
|
||||
triton_include_dirs = ['/home/philippe/development/triton/include']
|
||||
tensorflow_include_dirs = [tf.sysconfig.get_include()]
|
||||
cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
|
||||
include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
|
||||
# library directories
|
||||
triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
|
||||
tensorflow_library_dirs = [tf.sysconfig.get_lib()]
|
||||
library_dirs = triton_library_dirs + tensorflow_library_dirs
|
||||
# libraries
|
||||
libraries = ['tensorflow_framework', 'triton']
|
||||
# extra arguments
|
||||
extra_compile_args = []
|
||||
extra_link_args = []
|
||||
# create extension module
|
||||
ext = setuptools.Extension(
|
||||
name = 'test',
|
||||
language = 'c++',
|
||||
sources = [src],
|
||||
include_dirs = include_dirs,
|
||||
extra_compile_args = extra_compile_args,
|
||||
extra_link_args = extra_link_args,
|
||||
library_dirs = library_dirs,
|
||||
libraries = libraries
|
||||
)
|
||||
# build extension module
|
||||
args = ['build_ext']
|
||||
tmp = tempfile.mkdtemp()
|
||||
args.append('--build-temp=' + tmp)
|
||||
args.append('--build-lib=' + path)
|
||||
args.append('-q')
|
||||
args = dict(
|
||||
name = 'test',
|
||||
ext_modules = [ext],
|
||||
script_args = args,
|
||||
)
|
||||
setuptools.setup(**args)
|
||||
shutil.rmtree(tmp)
|
||||
|
||||
def make_tensorflow_op(src, outputs, grids):
|
||||
bindings = make_bindings(src, outputs, grids)
|
||||
cache_path = make_cache_path(bindings)
|
||||
cpp = write_bindings(bindings, cache_path)
|
||||
build(cpp, cache_path)
|
||||
result = tf.load_op_library(os.path.join(cache_path, 'test.cpython-36m-x86_64-linux-gnu.so'))
|
||||
return result
|
||||
|
||||
|
||||
library_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
module = make_tensorflow_op(src, ['C'], ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN'])
|
||||
print(module.matmul)
|
||||
|
||||
|
||||
class dot:
|
||||
|
||||
def __init__(self):
|
||||
trans_a = True
|
||||
trans_b = False
|
||||
self.matmul = triton.make_tensorflow_op(src, ['C'], ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN'])
|
||||
|
||||
def __call__(self, a, b):
|
||||
shape_a = tf.shape(a)
|
||||
@@ -152,17 +54,17 @@ class dot:
|
||||
N = shape_b[0]
|
||||
lda = M
|
||||
ldb = K
|
||||
ldc = M
|
||||
c = extra_ops.alloc_empty(tf.stack([M, N]))
|
||||
return module.matmul(a, b, c, M, N, K, lda, ldb, ldc)
|
||||
ldc = N
|
||||
c = triton.empty([M, N])
|
||||
return self.matmul.matmul(a, b, c, M, N, K, lda, ldb, ldc)
|
||||
|
||||
dot_nt = dot()
|
||||
dot_tn = dot()
|
||||
def run_dot():
|
||||
M, N, K = 128, 128, 128
|
||||
a = tf.placeholder(tf.float16, shape=[M, K])
|
||||
b = tf.placeholder(tf.float16, shape=[N, K])
|
||||
# c = tf.matmul(a, b, transpose_a=True)
|
||||
c = dot_nt(a, b)
|
||||
c = dot_tn(a, b)
|
||||
# Reference
|
||||
ha = np.random.rand(M, K).astype(np.float16)
|
||||
hb = np.random.rand(N, K).astype(np.float16)
|
||||
@@ -172,7 +74,7 @@ def run_dot():
|
||||
result = sess.run([c], feed_dict = {a: ha,
|
||||
b: hb})[0]
|
||||
# Test
|
||||
hresult = np.dot(ha.T, hb).T
|
||||
hresult = np.dot(ha.T, hb)
|
||||
dif = np.abs(result - hresult)
|
||||
np.savetxt('dif.dat', dif, '%2.4f')
|
||||
print(hresult)
|
||||
|
@@ -18,6 +18,7 @@ class CMakeExtension(Extension):
|
||||
|
||||
|
||||
class CMakeBuild(build_ext):
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
out = subprocess.check_output(['cmake', '--version'])
|
||||
@@ -80,6 +81,7 @@ setup(
|
||||
author_email='ptillet@g.harvard.edu',
|
||||
description='A language and compiler for custom Deep Learning operations',
|
||||
long_description='',
|
||||
packages=['triton'],
|
||||
ext_modules=[CMakeExtension('triton')],
|
||||
cmdclass=dict(build_ext=CMakeBuild),
|
||||
zip_safe=False,
|
||||
|
@@ -177,8 +177,8 @@ result += R"(
|
||||
|
||||
|
||||
std::regex regex("#([a-zA-Z]([a-zA-Z]|[0-9])*)");
|
||||
std::vector<std::string> grids;
|
||||
for(size_t i = macros.size(); i < 3; i++)
|
||||
std::vector<std::string> grids = macros;
|
||||
for(size_t i = grids.size(); i < 3; i++)
|
||||
grids.push_back("1");
|
||||
std::string grid = "rt::grid_t{";
|
||||
for(size_t i = 0; i < grids.size(); i++){
|
||||
|
1
python/triton/__init__.py
Normal file
1
python/triton/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .ops import *
|
103
python/triton/ops.py
Normal file
103
python/triton/ops.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# import for cache
|
||||
import os
|
||||
import tempfile
|
||||
import shutil
|
||||
import hashlib
|
||||
import sysconfig
|
||||
import sys
|
||||
# import for just-in-time compilation
|
||||
import distutils
|
||||
import setuptools.command.build_ext
|
||||
import setuptools
|
||||
# triton
|
||||
import libtriton
|
||||
# frameworks
|
||||
import tensorflow as tf
|
||||
|
||||
extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
|
||||
|
||||
|
||||
def make_bindings(src, outputs, grids):
|
||||
return libtriton.make_tensorflow_src(src, outputs, grids)
|
||||
|
||||
def make_cache_path(src):
|
||||
md5 = hashlib.sha1(src.encode())
|
||||
hexhash = md5.hexdigest()
|
||||
home = os.path.expanduser('~')
|
||||
cacheroot = os.path.join(home, '.triton', 'cache')
|
||||
cachepath = os.path.join(cacheroot, str(hexhash))
|
||||
if not os.path.exists(cachepath):
|
||||
os.makedirs(cachepath)
|
||||
return cachepath
|
||||
|
||||
def write_bindings(src, root):
|
||||
cpp = os.path.join(root, 'tensorflow.cpp')
|
||||
suffix = sysconfig.get_config_var('EXT_SUFFIX')
|
||||
so = os.path.join(root, 'tensorflow{suffix}'.format(suffix=suffix))
|
||||
recompile = False
|
||||
# recompile if .so does not exist
|
||||
if not os.path.exists(cpp) or not os.path.exists(so):
|
||||
recompile = True
|
||||
# recompile if cpp was modified after .so
|
||||
elif max(cpp, so, key=os.path.getctime) == cpp:
|
||||
recompile = True
|
||||
# write cpp file
|
||||
if recompile:
|
||||
with open(cpp, 'w+') as handle:
|
||||
handle.writelines(src)
|
||||
# return path of cpp file
|
||||
return (cpp, so)
|
||||
|
||||
def build(src, path):
|
||||
# include directories
|
||||
triton_include_dirs = ['/home/philippe/development/triton/include']
|
||||
tensorflow_include_dirs = [tf.sysconfig.get_include()]
|
||||
cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
|
||||
include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
|
||||
# library directories
|
||||
triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
|
||||
tensorflow_library_dirs = [tf.sysconfig.get_lib()]
|
||||
library_dirs = triton_library_dirs + tensorflow_library_dirs
|
||||
# libraries
|
||||
libraries = ['tensorflow_framework', 'triton']
|
||||
# extra arguments
|
||||
extra_compile_args = []
|
||||
extra_link_args = []
|
||||
# dependences
|
||||
depends = [os.path.realpath(libtriton.__file__)]
|
||||
# create extension module
|
||||
ext = setuptools.Extension(
|
||||
name = 'tensorflow',
|
||||
language = 'c++',
|
||||
sources = [src],
|
||||
include_dirs = include_dirs,
|
||||
extra_compile_args = extra_compile_args,
|
||||
extra_link_args = extra_link_args,
|
||||
library_dirs = library_dirs,
|
||||
libraries = libraries,
|
||||
depends = depends
|
||||
)
|
||||
# build extension module
|
||||
args = ['build_ext']
|
||||
tmp = tempfile.mkdtemp()
|
||||
args.append('--build-temp=' + tmp)
|
||||
args.append('--build-lib=' + path)
|
||||
args.append('-q')
|
||||
args = dict(
|
||||
name = 'tensorflow',
|
||||
ext_modules = [ext],
|
||||
script_args = args,
|
||||
)
|
||||
setuptools.setup(**args)
|
||||
shutil.rmtree(tmp)
|
||||
|
||||
def make_tensorflow_op(src, outputs, grids):
|
||||
bindings = make_bindings(src, outputs, grids)
|
||||
cache_path = make_cache_path(bindings)
|
||||
cpp, so = write_bindings(bindings, cache_path)
|
||||
build(cpp, cache_path)
|
||||
result = tf.load_op_library(so)
|
||||
return result
|
||||
|
||||
def empty(shapes):
|
||||
return extra_ops.alloc_empty(tf.stack(shapes))
|
Reference in New Issue
Block a user