From b4a9ed9663bccfe5a6af35392bbae8be0cb56130 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Sat, 17 Aug 2019 18:18:26 -0700
Subject: [PATCH] [python] added basic tensorflow support

---
 examples/cpp/cuda.h                          | 160 +++++++++++++++++++
 examples/cpp/dot.cpp                         |   4 +-
 include/triton/codegen/selection/selection.h |   2 +-
 include/triton/ir/builder.h                  |   2 +-
 include/triton/ir/instructions.h             |  24 +--
 include/triton/lang/expression.h             |   4 +-
 include/triton/lang/parser.y                 |   4 +-
 include/triton/lang/scanner.l                |   2 +-
 lib/codegen/analysis/alignment.cpp           |   4 +-
 lib/codegen/selection/selection.cpp          |  10 +-
 lib/codegen/transform/reassociate.cpp        |   6 +-
 lib/dnn/batchnorm.cpp                        |   4 +-
 lib/dnn/blocksparse/dot.cpp                  |   6 +-
 lib/dnn/conv.cpp                             |   4 +-
 lib/dnn/dot.cpp                              |   4 +-
 lib/dnn/shift.cpp                            |   6 +-
 lib/ir/builder.cpp                           |   4 +-
 lib/ir/instructions.cpp                      |  28 ++--
 lib/lang/expression.cpp                      |   6 +-
 python/examples/dot.py                       | 120 ++------------
 python/setup.py                              |   2 +
 python/src/tensorflow.cpp                    |   4 +-
 python/triton/__init__.py                    |   1 +
 python/triton/ops.py                         | 103 ++++++++++++
 24 files changed, 341 insertions(+), 173 deletions(-)
 create mode 100644 examples/cpp/cuda.h
 create mode 100644 python/triton/__init__.py
 create mode 100644 python/triton/ops.py

diff --git a/examples/cpp/cuda.h b/examples/cpp/cuda.h
new file mode 100644
index 000000000..5f03870f5
--- /dev/null
+++ b/examples/cpp/cuda.h
@@ -0,0 +1,160 @@
+/* Copyright 2015-2017 Philippe Tillet
+*
+* Permission is hereby granted, free of charge, to any person obtaining
+* a copy of this software and associated documentation files
+* (the "Software"), to deal in the Software without restriction,
+* including without limitation the rights to use, copy, modify, merge,
+* publish, distribute, sublicense, and/or sell copies of the Software,
+* and to permit persons to whom the Software is furnished to do so,
+* subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be
+* included in all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+*/ + +#include +#include +#include +#include "cublas_v2.h" +#include "triton/driver/buffer.h" +#include "triton/driver/stream.h" +#include "triton/driver/context.h" +#include "triton/tools/bench.hpp" + +enum cublasStrategy_t{ + CUBLAS_PREFER_FASTEST, + CUBLAS_HEURISTICS +}; + +enum DType{ + HALF_TYPE, + FLOAT_TYPE, + DOUBLE_TYPE, +}; + +inline size_t size_of(DType dtype){ + switch (dtype) { + case HALF_TYPE: return 2; + case FLOAT_TYPE: return 4; + case DOUBLE_TYPE: return 8; + default: throw; + } +} + +std::vector gather_all_algos() { + std::vector result; + // non-tensor ops + for(int i = -1; i < 24; i++) + result.push_back((cublasGemmAlgo_t)i); + // tensor ops + for(int i = 99; i < 116; i++) + result.push_back((cublasGemmAlgo_t)i); + return result; +} + +static const std::vector algorithms = gather_all_algos(); + +static const std::map cu_dtype = { + {HALF_TYPE, CUDA_R_16F}, + {FLOAT_TYPE, CUDA_R_32F}, + {DOUBLE_TYPE, CUDA_R_64F} +}; + +static const std::map cu_op = { + {false, CUBLAS_OP_N}, + {true, CUBLAS_OP_T} +}; + +inline cublasGemmAlgo_t cublasGemmFastest( + triton::driver::stream* stream, + cublasHandle_t handle, cudaDataType cudt, + cublasOperation_t AT, cublasOperation_t BT, + int32_t M, int32_t N, int32_t K, + void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb, + void* beta, CUdeviceptr C, int32_t ldc) { + + // cache to avoid re-benchmarking + typedef std::tuple key_t; + static std::map cache; + key_t key(cudt, AT, BT, M, N, K); + // benchmark algorithms if necessary + if(cache.find(key) == cache.end()){ + std::vector times; + for(cublasGemmAlgo_t a: algorithms) { + cublasStatus_t status; + double nanosec = triton::tools::bench([&](){ status = cublasGemmEx(handle, AT, BT, + M, N, K, + alpha, (const void*)A, cudt, lda, + (const void*)B, cudt, ldb, + beta, (void*)C, cudt, ldc, cudt, + a); }, stream); + if(status != CUBLAS_STATUS_SUCCESS) + nanosec = INFINITY; + } + size_t argmin = std::min_element(times.begin(), times.end()) - times.begin(); + assert(times[argmin] != INFINITY); + cache.insert({key, algorithms[argmin]}); + } + + // return best algorithm + return cache.at(key); +} + +/* Wrapper for cublasGemmEx */ +inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K, + void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb, + void* beta, CUdeviceptr C, int32_t ldc, cublasGemmAlgo_t algo) +{ + cublasStatus_t status = cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, algo); + if(status != CUBLAS_STATUS_SUCCESS){ + std::cout << status; + exit(EXIT_FAILURE); + } +} + + +/* Get cuBLAS handle */ +cublasHandle_t cublasGetHandle(triton::driver::stream* stream) { + static std::map cache; + CUstream key = *stream->cu(); + + // create handle if necessary + if(cache.find(key) == cache.end()) { + cublasHandle_t handle; + if(cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS) + throw std::runtime_error("Error: could not create cuBLAS handle"); + cublasSetStream_v2(handle, key); + cache.insert({key, handle}); + } + + // return handle for the stream + return cache.at(key); +} + +/* Simplified API for default GEMM */ +inline void cublasGemm(DType dtype, triton::driver::stream* stream, bool AT, bool BT, + int32_t M, int32_t N, int32_t K, + void* alpha, triton::driver::buffer* A, int32_t lda, + triton::driver::buffer* B, int32_t ldb, + void* beta, triton::driver::buffer* C, int32_t ldc, + 
cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) { + triton::driver::cu_context::context_switcher scope(*stream->context()); + static cublasHandle_t handle = cublasGetHandle(stream); + if(dtype == HALF_TYPE) + cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH); + cublasStatus_t status; + if(fastest) + *fastest = cublasGemmFastest(stream, handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc); + else + status = cublasGemmEx(handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc, algo); +} diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index 90287f719..e592da570 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -88,8 +88,8 @@ void matmul(restrict read_only align(16) )" + a_ty + R"( *A, restrict read_only align(16) )" + c_ty + R"( *C, int M, int N, int K, )" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc) { - int ridx = get_range_id(0); - int ridy = get_range_id(1); + int ridx = get_program_id(0); + int ridy = get_program_id(1); int rxa[TM] = ridx * TM + (0 ... TM); int ryb[TN] = ridy * TN + (0 ... TN); int rka[TK] = 0 ... TK; diff --git a/include/triton/codegen/selection/selection.h b/include/triton/codegen/selection/selection.h index 3b871dce0..433633cff 100644 --- a/include/triton/codegen/selection/selection.h +++ b/include/triton/codegen/selection/selection.h @@ -169,7 +169,7 @@ private: void lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); void lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); void lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); - void lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); + void lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); void lower_reshape(ir::reshape_inst* x, LLVMContext &ctx, Function *fn, Builder &builder); void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder); diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index a104cc6b4..4f5f4f45b 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -126,7 +126,7 @@ public: value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); // Built-in instruction - value *create_get_range_id(unsigned axis, const std::string &name = ""); + value *create_get_program_id(unsigned axis, const std::string &name = ""); value *create_get_num_program(unsigned axis, const std::string &name = ""); value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = ""); value *create_atomic_exch(value *ptr, value *val, const std::string &name = ""); diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index e9791e2a1..446dd871b 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -496,10 +496,10 @@ protected: using instruction::instruction; }; -class get_range_id_inst: public builtin_inst { +class get_program_id_inst: public builtin_inst { private: - get_range_id_inst(type *ty, unsigned axis, const 
std::string &name, instruction *next); - std::string repr_impl() const { return "get_range_id(" + std::to_string(axis_) + ")"; } + get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next); + std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; } public: static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr); @@ -668,23 +668,23 @@ public: }; // On NVIDIA, implementation is such that -// constant_range = nv_dynamic_range_idx + nv_static_range_idx -// so as to enable re-association on nv_static_range_idx which is constant -class nv_dynamic_range_idx_inst: public instruction { +// constant_range = nv_dynamic_program_idx + nv_static_program_idx +// so as to enable re-association on nv_static_program_idx which is constant +class nv_dynamic_program_idx_inst: public instruction { private: - nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next); - std::string repr_impl() const { return "nv_dynamic_range_idx"; } + nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next); + std::string repr_impl() const { return "nv_dynamic_program_idx"; } public: - static nv_dynamic_range_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr); + static nv_dynamic_program_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr); }; -class nv_static_range_idx: public constant { +class nv_static_program_idx: public constant { private: - nv_static_range_idx(constant_range *range); + nv_static_program_idx(constant_range *range); public: - static nv_static_range_idx *get(constant_range* range); + static nv_static_program_idx *get(constant_range* range); constant_range* get_range() const; private: diff --git a/include/triton/lang/expression.h b/include/triton/lang/expression.h index 6823e8988..9d65de5c0 100644 --- a/include/triton/lang/expression.h +++ b/include/triton/lang/expression.h @@ -71,9 +71,9 @@ private: const constant* size_; }; -class get_range_id_expression: public builtin_expression{ +class get_program_id_expression: public builtin_expression{ public: - get_range_id_expression(node *axis): axis_((constant*)axis) { } + get_program_id_expression(node *axis): axis_((constant*)axis) { } ir::value* codegen(ir::module *) const; private: diff --git a/include/triton/lang/parser.y b/include/triton/lang/parser.y index c44a619e8..d67a89562 100644 --- a/include/triton/lang/parser.y +++ b/include/triton/lang/parser.y @@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64 %token IF ELSE FOR CONTINUE WHILE %token NEWAXIS ELLIPSIS AT -%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE +%token GET_NUM_PROGRAM GET_PROGRAM_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE %start translation_unit %% @@ -120,7 +120,7 @@ identifier /* Built-in */ builtin_expression - : GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); } + : GET_PROGRAM_ID '(' constant ')' { $$ = new get_program_id_expression($3); } | GET_NUM_PROGRAM '(' constant ')' { $$ = new get_num_program_expression($3); } | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } | SQRT '(' expression ')' { $$ = new 
sqrt_expression($3); } diff --git a/include/triton/lang/scanner.l b/include/triton/lang/scanner.l index 1aaf40a57..6062a51ad 100644 --- a/include/triton/lang/scanner.l +++ b/include/triton/lang/scanner.l @@ -43,7 +43,7 @@ using triton::lang::return_void; "float" { return return_impl(FP32, yytext); } "double" { return return_impl(FP64, yytext); } "..." { return return_impl(ELLIPSIS, yytext); } -"get_range_id" { return return_impl(GET_RANGE_ID, yytext); } +"get_program_id" { return return_impl(GET_PROGRAM_ID, yytext); } "get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); } "__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); } "__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); } diff --git a/lib/codegen/analysis/alignment.cpp b/lib/codegen/analysis/alignment.cpp index 3ed74f7a3..a602c87ca 100644 --- a/lib/codegen/analysis/alignment.cpp +++ b/lib/codegen/analysis/alignment.cpp @@ -227,10 +227,10 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){ if(auto *x = dynamic_cast(v)){ return cache(x->get_first()->get_value()); } - if(auto *x = dynamic_cast(v)){ + if(auto *x = dynamic_cast(v)){ return cache(128); } - if(auto *x = dynamic_cast(v)){ + if(auto *x = dynamic_cast(v)){ return cache(x->get_range()->get_first()->get_value()); } if(auto *x = dynamic_cast(v)){ diff --git a/lib/codegen/selection/selection.cpp b/lib/codegen/selection/selection.cpp index 0ca17f9e0..166b423bb 100644 --- a/lib/codegen/selection/selection.cpp +++ b/lib/codegen/selection/selection.cpp @@ -411,7 +411,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::functionget_operand(2)); return builder.Insert(SelectInst::Create(pred, if_value, else_value)); } - if(ir::get_range_id_inst* ii = dynamic_cast(inst)){ + if(ir::get_program_id_inst* ii = dynamic_cast(inst)){ Value *result = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis()); return (Instruction*)result; } @@ -837,7 +837,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, T->set_value(idx, idx[0]); }); } - if(is_inserted && dynamic_cast(v)){ + if(is_inserted && dynamic_cast(v)){ T->for_each([&](indices_t idx){ assert(idx.size() == 1); BinaryOperator *bin_add = dyn_cast(idx[0]); @@ -996,7 +996,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, } } -void selection::lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { +void selection::lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) { distributed_tile* result = (distributed_tile*)tmap_.at(x); result->for_each([&](indices_t idx){ assert(idx.size() == 1); @@ -1418,8 +1418,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & lower_downcast(x, ctx, fn, builder); else if(auto *x = dynamic_cast(ins)) lower_reduce(x, ctx, fn, builder); - else if(auto *x = dynamic_cast(ins)) - lower_dynamic_range_idx(x, ctx, fn, builder); + else if(auto *x = dynamic_cast(ins)) + lower_dynamic_program_idx(x, ctx, fn, builder); else if(auto *x = dynamic_cast(ins)) lower_reshape(x, ctx, fn, builder); else if(auto *x = dynamic_cast(ins)) diff --git a/lib/codegen/transform/reassociate.cpp b/lib/codegen/transform/reassociate.cpp index 6893a7a10..c411ccf12 100644 --- a/lib/codegen/transform/reassociate.cpp +++ b/lib/codegen/transform/reassociate.cpp @@ -164,7 +164,7 @@ reassociate::reassociate(analysis::tune* params) void reassociate::run(ir::module &mod) { 
ir::builder &builder = mod.get_builder(); - // constant_range -> nv_dynamic_range_idx + nv_static_range_idx + // constant_range -> nv_dynamic_program_idx + nv_static_program_idx for(ir::function *fn: mod.get_function_list()){ std::vector ranges; std::vector rpo = ir::cfg::reverse_post_order(fn); @@ -178,8 +178,8 @@ void reassociate::run(ir::module &mod) { builder.set_insert_point(rpo.front()->get_first_non_phi()); for(ir::constant_range* old_range: ranges){ - ir::value* dyn_range = builder.insert(ir::nv_dynamic_range_idx_inst::create(old_range->get_type())); - ir::value* static_range = ir::nv_static_range_idx::get(old_range); + ir::value* dyn_range = builder.insert(ir::nv_dynamic_program_idx_inst::create(old_range->get_type())); + ir::value* static_range = ir::nv_static_program_idx::get(old_range); ir::value* new_range = builder.create_add(dyn_range, static_range); old_range->replace_all_uses_with(new_range); params_->copy(dyn_range, old_range); diff --git a/lib/dnn/batchnorm.cpp b/lib/dnn/batchnorm.cpp index e5143755e..fe785afdd 100644 --- a/lib/dnn/batchnorm.cpp +++ b/lib/dnn/batchnorm.cpp @@ -82,7 +82,7 @@ void batchnorm_forward(float *Y, float *M, float *V, int rx[TM] = 0 ... TM; float *px[TM]; float x[TM] = 0; - int c = get_range_id(1); + int c = get_program_id(1); float g = *(G + c); float b = *(B + c); @@ -177,7 +177,7 @@ void batchnorm_backward(float *DX, float *DG, float *DB, restrict read_only float *V, int DHWN, float rcpDHWN, float epsilon) { int rx[TM] = 0 ... TM; - int c = get_range_id(1); + int c = get_program_id(1); int offset = c*DHWN; float g = *(G + c); float mean = *(M + c); diff --git a/lib/dnn/blocksparse/dot.cpp b/lib/dnn/blocksparse/dot.cpp index 97823e309..b155f9c89 100644 --- a/lib/dnn/blocksparse/dot.cpp +++ b/lib/dnn/blocksparse/dot.cpp @@ -114,11 +114,11 @@ std::string dot::triton_c_src_ydx() const { int lda, int ldb, int ldc, int N, int* lut, int* locks, int nlocks) { - int ridx = get_range_id(0); + int ridx = get_program_id(0); float acc[TM, TN] = 0; int rka[TK] = 0 ... TK; int rkb[TK] = 0 ... TK; - int *header = lut + get_range_id(1) * 4; + int *header = lut + get_program_id(1) * 4; int offset = *(header + 0); int K = *(header + 1); int column = *(header + 2); @@ -191,7 +191,7 @@ std::string dot::triton_c_src_dw() const { int lda, int ldb, int ldc, int N, int* lut, int* locks, int nlocks) { - int ridx = get_range_id(0); + int ridx = get_program_id(0); float acc[TM, TN] = 0; int rka[TK] = 0 ... TK; int rkb[TK] = 0 ... 
TK; diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp index 63503c70f..381691ff0 100644 --- a/lib/dnn/conv.cpp +++ b/lib/dnn/conv.cpp @@ -686,8 +686,8 @@ if(b_lut_){ float* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis]; bool checkc0[TM] = rxc < M; bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - int ridx = get_range_id(0); - int ridy = get_range_id(1); + int ridx = get_program_id(0); + int ridy = get_program_id(1); int *plock = locks + ridx + ridy*grid0; while(__atomic_cas(plock, 0, 1) == 1); int *pcount = plock + grid0*grid1; diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp index 4ea355170..f3d35a2f0 100644 --- a/lib/dnn/dot.cpp +++ b/lib/dnn/dot.cpp @@ -116,8 +116,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A, int M, int N, int K, )" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc, int bound, int *locks, int grid0, int grid1) { - int ridx = get_range_id(0); - int ridy = get_range_id(1); + int ridx = get_program_id(0); + int ridy = get_program_id(1); int rxa[TM] = ridx * TM + (0 ... TM); int ryb[TN] = ridy * TN + (0 ... TN); int rka[TK] = 0 ... TK; diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index 5b50a73b4..93ae57cd4 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -354,9 +354,9 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A, int BH, int BW, int CH, int CW, int* locks, int grid0, int grid1, int grid2) { - int ridx = get_range_id(0); - int ridy = get_range_id(1); - int rz = get_range_id(2); + int ridx = get_program_id(0); + int ridy = get_program_id(1); + int rz = get_program_id(2); int rxa[TM] = ridx*TM + (0 ... TM); int ryb[TN] = ridy*TN + (0 ... TN); int rka[TK] = 0 ... TK; diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp index ef2d81abf..9fe444dd1 100644 --- a/lib/ir/builder.cpp +++ b/lib/ir/builder.cpp @@ -300,8 +300,8 @@ value *builder::create_downcast(value *arg, const std::string &name) { // built-in instructions //===----------------------------------------------------------------------===// -value *builder::create_get_range_id(unsigned axis, const std::string &name) { - return insert(get_range_id_inst::create(ctx_, axis, name)); +value *builder::create_get_program_id(unsigned axis, const std::string &name) { + return insert(get_program_id_inst::create(ctx_, axis, name)); } value *builder::create_get_num_program(unsigned axis, const std::string &name) { diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index 074b55bb8..85b6eee5c 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -664,14 +664,14 @@ instruction* select_inst::create(value *pred, value *if_value, value *else_value //===----------------------------------------------------------------------===// -// get_range_id -get_range_id_inst::get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next) +// get_program_id +get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next) : builtin_inst(ty, 0, 1, name, next), axis_(axis){ } -instruction* get_range_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) { - return new get_range_id_inst(type::get_int32_ty(ctx), axis, name, next); +instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) { + return new get_program_id_inst(type::get_int32_ty(ctx), axis, name, next); } // get_num_program @@ -745,25 +745,25 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string 
&name, instru return new barrier_inst(ctx, name, next); } -// nv_dynamic_range_idx -nv_dynamic_range_idx_inst::nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next) +// nv_dynamic_program_idx +nv_dynamic_program_idx_inst::nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next) : instruction(ty, 0, 1, name, next) { } -nv_dynamic_range_idx_inst* nv_dynamic_range_idx_inst::create(type *ty, const std::string &name, instruction *next) { - return new nv_dynamic_range_idx_inst(ty, name, next); +nv_dynamic_program_idx_inst* nv_dynamic_program_idx_inst::create(type *ty, const std::string &name, instruction *next) { + return new nv_dynamic_program_idx_inst(ty, name, next); } -// nv_static_range_idx -nv_static_range_idx::nv_static_range_idx(constant_range *range) +// nv_static_program_idx +nv_static_program_idx::nv_static_program_idx(constant_range *range) : constant(range->get_type(), 0), range_(range) { } -constant_range* nv_static_range_idx::get_range() const +constant_range* nv_static_program_idx::get_range() const { return range_; } -nv_static_range_idx* nv_static_range_idx::get(constant_range* range) { - static std::map cache; +nv_static_program_idx* nv_static_program_idx::get(constant_range* range) { + static std::map cache; if(cache.find(range) == cache.end()) - cache.insert({range, new nv_static_range_idx(range)}); + cache.insert({range, new nv_static_program_idx(range)}); return cache.at(range); } diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp index acbfaf6f6..8d5288e8b 100644 --- a/lib/lang/expression.cpp +++ b/lib/lang/expression.cpp @@ -115,9 +115,9 @@ ir::value* alloc_const_expression::codegen(ir::module *mod) const { return res; } -// get_range_id -ir::value* get_range_id_expression::codegen(ir::module *mod) const { - return mod->get_builder().create_get_range_id(axis_->value()); +// get_program_id +ir::value* get_program_id_expression::codegen(ir::module *mod) const { + return mod->get_builder().create_get_program_id(axis_->value()); } // get_num_program diff --git a/python/examples/dot.py b/python/examples/dot.py index 6c79e846c..75fe931bc 100644 --- a/python/examples/dot.py +++ b/python/examples/dot.py @@ -1,14 +1,6 @@ -import libtriton +import triton import tensorflow as tf -import distutils -import distutils.log -import setuptools.command.build_ext -import setuptools import numpy as np -import os -import tempfile -import shutil -import hashlib src = """ const tunable int TM = {128}; @@ -20,8 +12,8 @@ void matmul(restrict read_only align(16) half *A, restrict read_only align(16) half *C, int M, int N, int K, multiple_of(8) int lda, multiple_of(8) int ldb, int ldc) { - int ridx = get_range_id(0); - int ridy = get_range_id(1); + int ridx = get_program_id(0); + int ridy = get_program_id(1); int rxa[TM] = ridx * TM + (0 ... TM); int ryb[TN] = ridy * TN + (0 ... TN); int rka[TK] = 0 ... TK; @@ -40,7 +32,7 @@ void matmul(restrict read_only align(16) half *A, } int rxc[TM] = ridx * TM + (0 ... TM); int ryc[TN] = ridy * TN + (0 ... 
TN); - half* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; + half* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis]*ldc; half c[TM, TN] = xc; bool checkc0[TM] = rxc < M; bool checkc1[TN] = ryc < N; @@ -49,100 +41,10 @@ void matmul(restrict read_only align(16) half *A, } """ - -extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so') - - -def make_bindings(src, outputs, grids): - return libtriton.make_tensorflow_src(src, outputs, grids) - -def make_cache_path(src): - md5 = hashlib.sha1(src.encode()) - hexhash = md5.hexdigest() - home = os.path.expanduser('~') - cacheroot = os.path.join(home, '.triton', 'cache') - cachepath = os.path.join(cacheroot, str(hexhash)) - if not os.path.exists(cachepath): - os.makedirs(cachepath) - print(cachepath) - return cachepath - -def write_bindings(src, root): - cpp = os.path.join(root, 'tensorflow.cpp') - so = os.path.join(root, 'tensorflow.so') - recompile = False - # recompile if .so does not exist - if not os.path.exists(cpp) or not os.path.exists(so): - recompile = True - # recompile if cpp was modified after .so - elif max(cpp, so, key=os.path.getctime) == cpp: - recompile = True - # write cpp file - if recompile: - with open(cpp, 'w+') as handle: - handle.writelines(src) - # return path of cpp file - return cpp - -def build(src, path): - # include directories - triton_include_dirs = ['/home/philippe/development/triton/include'] - tensorflow_include_dirs = [tf.sysconfig.get_include()] - cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/'] - include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs - # library directories - triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))] - tensorflow_library_dirs = [tf.sysconfig.get_lib()] - library_dirs = triton_library_dirs + tensorflow_library_dirs - # libraries - libraries = ['tensorflow_framework', 'triton'] - # extra arguments - extra_compile_args = [] - extra_link_args = [] - # create extension module - ext = setuptools.Extension( - name = 'test', - language = 'c++', - sources = [src], - include_dirs = include_dirs, - extra_compile_args = extra_compile_args, - extra_link_args = extra_link_args, - library_dirs = library_dirs, - libraries = libraries - ) - # build extension module - args = ['build_ext'] - tmp = tempfile.mkdtemp() - args.append('--build-temp=' + tmp) - args.append('--build-lib=' + path) - args.append('-q') - args = dict( - name = 'test', - ext_modules = [ext], - script_args = args, - ) - setuptools.setup(**args) - shutil.rmtree(tmp) - -def make_tensorflow_op(src, outputs, grids): - bindings = make_bindings(src, outputs, grids) - cache_path = make_cache_path(bindings) - cpp = write_bindings(bindings, cache_path) - build(cpp, cache_path) - result = tf.load_op_library(os.path.join(cache_path, 'test.cpython-36m-x86_64-linux-gnu.so')) - return result - - -library_dir = os.path.dirname(os.path.realpath(__file__)) -module = make_tensorflow_op(src, ['C'], ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN']) -print(module.matmul) - - class dot: def __init__(self): - trans_a = True - trans_b = False + self.matmul = triton.make_tensorflow_op(src, ['C'], ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN']) def __call__(self, a, b): shape_a = tf.shape(a) @@ -152,17 +54,17 @@ class dot: N = shape_b[0] lda = M ldb = K - ldc = M - c = extra_ops.alloc_empty(tf.stack([M, N])) - return module.matmul(a, b, c, M, N, K, lda, ldb, ldc) + ldc = N + c = triton.empty([M, N]) + 
return self.matmul.matmul(a, b, c, M, N, K, lda, ldb, ldc) -dot_nt = dot() +dot_tn = dot() def run_dot(): M, N, K = 128, 128, 128 a = tf.placeholder(tf.float16, shape=[M, K]) b = tf.placeholder(tf.float16, shape=[N, K]) # c = tf.matmul(a, b, transpose_a=True) - c = dot_nt(a, b) + c = dot_tn(a, b) # Reference ha = np.random.rand(M, K).astype(np.float16) hb = np.random.rand(N, K).astype(np.float16) @@ -172,7 +74,7 @@ def run_dot(): result = sess.run([c], feed_dict = {a: ha, b: hb})[0] # Test - hresult = np.dot(ha.T, hb).T + hresult = np.dot(ha.T, hb) dif = np.abs(result - hresult) np.savetxt('dif.dat', dif, '%2.4f') print(hresult) diff --git a/python/setup.py b/python/setup.py index 3d98218ac..aeba8b5a6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -18,6 +18,7 @@ class CMakeExtension(Extension): class CMakeBuild(build_ext): + def run(self): try: out = subprocess.check_output(['cmake', '--version']) @@ -80,6 +81,7 @@ setup( author_email='ptillet@g.harvard.edu', description='A language and compiler for custom Deep Learning operations', long_description='', + packages=['triton'], ext_modules=[CMakeExtension('triton')], cmdclass=dict(build_ext=CMakeBuild), zip_safe=False, diff --git a/python/src/tensorflow.cpp b/python/src/tensorflow.cpp index 40810fc75..0e98f6636 100644 --- a/python/src/tensorflow.cpp +++ b/python/src/tensorflow.cpp @@ -177,8 +177,8 @@ result += R"( std::regex regex("#([a-zA-Z]([a-zA-Z]|[0-9])*)"); -std::vector grids; -for(size_t i = macros.size(); i < 3; i++) +std::vector grids = macros; +for(size_t i = grids.size(); i < 3; i++) grids.push_back("1"); std::string grid = "rt::grid_t{"; for(size_t i = 0; i < grids.size(); i++){ diff --git a/python/triton/__init__.py b/python/triton/__init__.py new file mode 100644 index 000000000..18dff0a49 --- /dev/null +++ b/python/triton/__init__.py @@ -0,0 +1 @@ +from .ops import * \ No newline at end of file diff --git a/python/triton/ops.py b/python/triton/ops.py new file mode 100644 index 000000000..ea782ad08 --- /dev/null +++ b/python/triton/ops.py @@ -0,0 +1,103 @@ +# import for cache +import os +import tempfile +import shutil +import hashlib +import sysconfig +import sys +# import for just-in-time compilation +import distutils +import setuptools.command.build_ext +import setuptools +# triton +import libtriton +# frameworks +import tensorflow as tf + +extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so') + + +def make_bindings(src, outputs, grids): + return libtriton.make_tensorflow_src(src, outputs, grids) + +def make_cache_path(src): + md5 = hashlib.sha1(src.encode()) + hexhash = md5.hexdigest() + home = os.path.expanduser('~') + cacheroot = os.path.join(home, '.triton', 'cache') + cachepath = os.path.join(cacheroot, str(hexhash)) + if not os.path.exists(cachepath): + os.makedirs(cachepath) + return cachepath + +def write_bindings(src, root): + cpp = os.path.join(root, 'tensorflow.cpp') + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(root, 'tensorflow{suffix}'.format(suffix=suffix)) + recompile = False + # recompile if .so does not exist + if not os.path.exists(cpp) or not os.path.exists(so): + recompile = True + # recompile if cpp was modified after .so + elif max(cpp, so, key=os.path.getctime) == cpp: + recompile = True + # write cpp file + if recompile: + with open(cpp, 'w+') as handle: + handle.writelines(src) + # return path of cpp file + return (cpp, so) + +def build(src, path): + # include directories + triton_include_dirs = 
['/home/philippe/development/triton/include'] + tensorflow_include_dirs = [tf.sysconfig.get_include()] + cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/'] + include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs + # library directories + triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))] + tensorflow_library_dirs = [tf.sysconfig.get_lib()] + library_dirs = triton_library_dirs + tensorflow_library_dirs + # libraries + libraries = ['tensorflow_framework', 'triton'] + # extra arguments + extra_compile_args = [] + extra_link_args = [] + # dependences + depends = [os.path.realpath(libtriton.__file__)] + # create extension module + ext = setuptools.Extension( + name = 'tensorflow', + language = 'c++', + sources = [src], + include_dirs = include_dirs, + extra_compile_args = extra_compile_args, + extra_link_args = extra_link_args, + library_dirs = library_dirs, + libraries = libraries, + depends = depends + ) + # build extension module + args = ['build_ext'] + tmp = tempfile.mkdtemp() + args.append('--build-temp=' + tmp) + args.append('--build-lib=' + path) + args.append('-q') + args = dict( + name = 'tensorflow', + ext_modules = [ext], + script_args = args, + ) + setuptools.setup(**args) + shutil.rmtree(tmp) + +def make_tensorflow_op(src, outputs, grids): + bindings = make_bindings(src, outputs, grids) + cache_path = make_cache_path(bindings) + cpp, so = write_bindings(bindings, cache_path) + build(cpp, cache_path) + result = tf.load_op_library(so) + return result + +def empty(shapes): + return extra_ops.alloc_empty(tf.stack(shapes))
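
Reviewer note: the snippet below is a minimal usage sketch of the new Python entry points added by this patch (triton.make_tensorflow_op and triton.empty), condensed from python/examples/dot.py. It assumes TensorFlow 1.x (placeholders and sessions), as the example does, and `src` is a placeholder for the Triton-C matmul kernel string defined in that file; it is not an additional API introduced here.

    import numpy as np
    import tensorflow as tf
    import triton

    # placeholder: the Triton-C matmul kernel string from python/examples/dot.py
    src = "..."

    # JIT-compile the kernel into a TensorFlow op; outputs and launch-grid
    # expressions mirror the ones used in dot.py
    module = triton.make_tensorflow_op(src, ['C'],
                                       ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN'])

    M, N, K = 128, 128, 128
    a = tf.placeholder(tf.float16, shape=[M, K])
    b = tf.placeholder(tf.float16, shape=[N, K])
    c = triton.empty([M, N])                        # uninitialized [M, N] output
    c = module.matmul(a, b, c, M, N, K, M, K, N)    # lda = M, ldb = K, ldc = N

    ha = np.random.rand(M, K).astype(np.float16)
    hb = np.random.rand(N, K).astype(np.float16)
    with tf.Session() as sess:
        print(sess.run(c, feed_dict={a: ha, b: hb}))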
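
On the grid expressions: python/src/tensorflow.cpp substitutes every '#NAME' token (regex "#([a-zA-Z]([a-zA-Z]|[0-9])*)") with the value selected for that tunable parameter, and now pads missing grid dimensions with "1" up to three axes instead of ignoring user-provided macros. A grid string such as '(M + #TM - 1)/#TM' is therefore just a ceiling division of the problem size by the tile size; a small illustration (not generated code):

    # Launch-grid arithmetic encoded by '(M + #TM - 1)/#TM'
    def ceil_div(m, tm):
        return (m + tm - 1) // tm

    assert ceil_div(300, 128) == 3   # programs 0..2 along this axis
    assert ceil_div(128, 128) == 1   # exactly one program when M == TM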
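
On the rename: get_range_id becomes get_program_id throughout the scanner, parser, IR builder, instructions, and codegen, and each kernel instance uses it to derive the block of C it owns, e.g. rxa[TM] = ridx * TM + (0 ... TM). A rough Python sketch of that index arithmetic, for illustration only (tile_indices is a hypothetical helper, not part of this patch):

    import numpy as np

    # Which rows/columns of C does program (ridx, ridy) own, for TM = TN = 128?
    def tile_indices(ridx, ridy, TM=128, TN=128):
        rxc = ridx * TM + np.arange(TM)   # rows of C, like rxc[TM] in the kernel
        ryc = ridy * TN + np.arange(TN)   # cols of C, like ryc[TN] in the kernel
        return rxc, ryc

    rows, cols = tile_indices(1, 0)
    assert rows[0] == 128 and cols[0] == 0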