[python] added basic tensorflow support
examples/cpp/cuda.h (new file, 160 lines)
@@ -0,0 +1,160 @@
/* Copyright 2015-2017 Philippe Tillet
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files
 * (the "Software"), to deal in the Software without restriction,
 * including without limitation the rights to use, copy, modify, merge,
 * publish, distribute, sublicense, and/or sell copies of the Software,
 * and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

#include <algorithm>
#include <vector>
#include <cassert>
#include "cublas_v2.h"
#include "triton/driver/buffer.h"
#include "triton/driver/stream.h"
#include "triton/driver/context.h"
#include "triton/tools/bench.hpp"

enum cublasStrategy_t{
  CUBLAS_PREFER_FASTEST,
  CUBLAS_HEURISTICS
};

enum DType{
  HALF_TYPE,
  FLOAT_TYPE,
  DOUBLE_TYPE,
};

inline size_t size_of(DType dtype){
  switch (dtype) {
    case HALF_TYPE: return 2;
    case FLOAT_TYPE: return 4;
    case DOUBLE_TYPE: return 8;
    default: throw;
  }
}

std::vector<cublasGemmAlgo_t> gather_all_algos() {
  std::vector<cublasGemmAlgo_t> result;
  // non-tensor ops
  for(int i = -1; i < 24; i++)
    result.push_back((cublasGemmAlgo_t)i);
  // tensor ops
  for(int i = 99; i < 116; i++)
    result.push_back((cublasGemmAlgo_t)i);
  return result;
}

static const std::vector<cublasGemmAlgo_t> algorithms = gather_all_algos();

static const std::map<DType, cudaDataType> cu_dtype = {
  {HALF_TYPE, CUDA_R_16F},
  {FLOAT_TYPE, CUDA_R_32F},
  {DOUBLE_TYPE, CUDA_R_64F}
};

static const std::map<char, cublasOperation_t> cu_op = {
  {false, CUBLAS_OP_N},
  {true, CUBLAS_OP_T}
};

inline cublasGemmAlgo_t cublasGemmFastest(
    triton::driver::stream* stream,
    cublasHandle_t handle, cudaDataType cudt,
    cublasOperation_t AT, cublasOperation_t BT,
    int32_t M, int32_t N, int32_t K,
    void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
    void* beta, CUdeviceptr C, int32_t ldc) {

  // cache to avoid re-benchmarking
  typedef std::tuple<cudaDataType_t,
                     cublasOperation_t, cublasOperation_t,
                     int32_t, int32_t, int32_t> key_t;
  static std::map<key_t, cublasGemmAlgo_t> cache;
  key_t key(cudt, AT, BT, M, N, K);
  // benchmark algorithms if necessary
  if(cache.find(key) == cache.end()){
    std::vector<double> times;
    for(cublasGemmAlgo_t a: algorithms) {
      cublasStatus_t status;
      double nanosec = triton::tools::bench([&](){ status = cublasGemmEx(handle, AT, BT,
                                                                         M, N, K,
                                                                         alpha, (const void*)A, cudt, lda,
                                                                         (const void*)B, cudt, ldb,
                                                                         beta, (void*)C, cudt, ldc, cudt,
                                                                         a); }, stream);
      if(status != CUBLAS_STATUS_SUCCESS)
        nanosec = INFINITY;
    }
    size_t argmin = std::min_element(times.begin(), times.end()) - times.begin();
    assert(times[argmin] != INFINITY);
    cache.insert({key, algorithms[argmin]});
  }

  // return best algorithm
  return cache.at(key);
}

/* Wrapper for cublasGemmEx */
inline cublasStatus_t cublasGemmEx(cublasHandle_t handle, cudaDataType cudt, cublasOperation_t AT, cublasOperation_t BT, int32_t M, int32_t N, int32_t K,
                                   void* alpha, CUdeviceptr A, int32_t lda, CUdeviceptr B, int32_t ldb,
                                   void* beta, CUdeviceptr C, int32_t ldc, cublasGemmAlgo_t algo)
{
  cublasStatus_t status = cublasGemmEx(handle, AT, BT, M, N, K, alpha, (const void*)A, cudt, lda, (const void*)B, cudt, ldb, beta, (void*)C, cudt, ldc, cudt, algo);
  if(status != CUBLAS_STATUS_SUCCESS){
    std::cout << status;
    exit(EXIT_FAILURE);
  }
}

/* Get cuBLAS handle */
cublasHandle_t cublasGetHandle(triton::driver::stream* stream) {
  static std::map<CUstream, cublasHandle_t> cache;
  CUstream key = *stream->cu();

  // create handle if necessary
  if(cache.find(key) == cache.end()) {
    cublasHandle_t handle;
    if(cublasCreate_v2(&handle) != CUBLAS_STATUS_SUCCESS)
      throw std::runtime_error("Error: could not create cuBLAS handle");
    cublasSetStream_v2(handle, key);
    cache.insert({key, handle});
  }

  // return handle for the stream
  return cache.at(key);
}

/* Simplified API for default GEMM */
inline void cublasGemm(DType dtype, triton::driver::stream* stream, bool AT, bool BT,
                       int32_t M, int32_t N, int32_t K,
                       void* alpha, triton::driver::buffer* A, int32_t lda,
                       triton::driver::buffer* B, int32_t ldb,
                       void* beta, triton::driver::buffer* C, int32_t ldc,
                       cublasGemmAlgo_t* fastest = NULL, cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT) {
  triton::driver::cu_context::context_switcher scope(*stream->context());
  static cublasHandle_t handle = cublasGetHandle(stream);
  if(dtype == HALF_TYPE)
    cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
  cublasStatus_t status;
  if(fastest)
    *fastest = cublasGemmFastest(stream, handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc);
  else
    status = cublasGemmEx(handle, cu_dtype.at(dtype), cu_op.at(AT), cu_op.at(BT), M, N, K, alpha, *A->cu(), lda, *B->cu(), ldb, beta, *C->cu(), ldc, algo);
}
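For context, the snippet below is a minimal usage sketch of the simplified cublasGemm() wrapper defined in this header. It is not part of the commit; the driver objects (stream, dA, dB, dC) and the column-major leading dimensions are assumptions made only for illustration.

#include "cuda.h"

void gemm_example(triton::driver::stream* stream,
                  triton::driver::buffer* dA, triton::driver::buffer* dB,
                  triton::driver::buffer* dC,
                  int32_t M, int32_t N, int32_t K) {
  float alpha = 1.0f, beta = 0.0f;
  cublasGemmAlgo_t fastest;
  // first call: benchmark every algorithm for this shape and remember the winner
  cublasGemm(FLOAT_TYPE, stream, /*AT=*/false, /*BT=*/false, M, N, K,
             &alpha, dA, M, dB, K, &beta, dC, M, &fastest);
  // later calls: reuse the selected algorithm directly
  cublasGemm(FLOAT_TYPE, stream, false, false, M, N, K,
             &alpha, dA, M, dB, K, &beta, dC, M, NULL, fastest);
}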
@@ -88,8 +88,8 @@ void matmul(restrict read_only align(16) )" + a_ty + R"( *A,
 restrict read_only align(16) )" + c_ty + R"( *C,
 int M, int N, int K,
 )" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc) {
-int ridx = get_range_id(0);
-int ridy = get_range_id(1);
+int ridx = get_program_id(0);
+int ridy = get_program_id(1);
 int rxa[TM] = ridx * TM + (0 ... TM);
 int ryb[TN] = ridy * TN + (0 ... TN);
 int rka[TK] = 0 ... TK;
@@ -169,7 +169,7 @@ private:
 void lower_store(ir::store_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
 void lower_downcast(ir::downcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
 void lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
-void lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
+void lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
 void lower_reshape(ir::reshape_inst* x, LLVMContext &ctx, Function *fn, Builder &builder);
 void lower_splat(ir::splat_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
 void lower_broadcast(ir::broadcast_inst *x, LLVMContext &ctx, Function *fn, Builder &builder);
@@ -126,7 +126,7 @@ public:
 value *create_reshape(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
 value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = "");
 // Built-in instruction
-value *create_get_range_id(unsigned axis, const std::string &name = "");
+value *create_get_program_id(unsigned axis, const std::string &name = "");
 value *create_get_num_program(unsigned axis, const std::string &name = "");
 value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = "");
 value *create_atomic_exch(value *ptr, value *val, const std::string &name = "");
@@ -496,10 +496,10 @@ protected:
 using instruction::instruction;
 };

-class get_range_id_inst: public builtin_inst {
+class get_program_id_inst: public builtin_inst {
 private:
-get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
-std::string repr_impl() const { return "get_range_id(" + std::to_string(axis_) + ")"; }
+get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
+std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; }

 public:
 static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
@@ -668,23 +668,23 @@ public:
 };

 // On NVIDIA, implementation is such that
-// constant_range = nv_dynamic_range_idx + nv_static_range_idx
-// so as to enable re-association on nv_static_range_idx which is constant
-class nv_dynamic_range_idx_inst: public instruction {
+// constant_range = nv_dynamic_program_idx + nv_static_program_idx
+// so as to enable re-association on nv_static_program_idx which is constant
+class nv_dynamic_program_idx_inst: public instruction {
 private:
-nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next);
-std::string repr_impl() const { return "nv_dynamic_range_idx"; }
+nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next);
+std::string repr_impl() const { return "nv_dynamic_program_idx"; }

 public:
-static nv_dynamic_range_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);
+static nv_dynamic_program_idx_inst* create(type *ty, const std::string &name = "", instruction *next = nullptr);
 };

-class nv_static_range_idx: public constant {
+class nv_static_program_idx: public constant {
 private:
-nv_static_range_idx(constant_range *range);
+nv_static_program_idx(constant_range *range);

 public:
-static nv_static_range_idx *get(constant_range* range);
+static nv_static_program_idx *get(constant_range* range);
 constant_range* get_range() const;

 private:
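The comment above describes splitting a constant_range into nv_dynamic_program_idx + nv_static_program_idx. The snippet below is an illustrative sketch, not part of the commit, of why that split helps: once an index is expressed as a dynamic part plus a compile-time constant part, address arithmetic can be re-associated so the constant part is folded separately. Plain integers stand in for the IR values.

#include <cassert>

int main() {
  int ptr = 1000;              // stands in for a base pointer
  int stride = 4;
  int dynamic = 32;            // nv_dynamic_program_idx: only known at runtime
  for(int s = 0; s < 8; s++) { // nv_static_program_idx: a constant range 0 ... 8
    int direct        = ptr + (dynamic + s) * stride;
    int reassociated  = (ptr + s * stride) + dynamic * stride; // constant part groups with the pointer
    assert(direct == reassociated);
  }
}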
@@ -71,9 +71,9 @@ private:
 const constant* size_;
 };

-class get_range_id_expression: public builtin_expression{
+class get_program_id_expression: public builtin_expression{
 public:
-get_range_id_expression(node *axis): axis_((constant*)axis) { }
+get_program_id_expression(node *axis): axis_((constant*)axis) { }
 ir::value* codegen(ir::module *) const;

 private:
@@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;}
 %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP16 FP32 FP64
 %token IF ELSE FOR CONTINUE WHILE
 %token NEWAXIS ELLIPSIS AT
-%token GET_NUM_PROGRAM GET_RANGE_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE
+%token GET_NUM_PROGRAM GET_PROGRAM_ID DOT SQRT REDUCE_SUM TRANS MAX MIN SELECT ATOMIC_CAS ATOMIC_EXCH ATOMIC_ADD ALLOC_CONST RESHAPE

 %start translation_unit
 %%
@@ -120,7 +120,7 @@ identifier

 /* Built-in */
 builtin_expression
-: GET_RANGE_ID '(' constant ')' { $$ = new get_range_id_expression($3); }
+: GET_PROGRAM_ID '(' constant ')' { $$ = new get_program_id_expression($3); }
 | GET_NUM_PROGRAM '(' constant ')' { $$ = new get_num_program_expression($3); }
 | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); }
 | SQRT '(' expression ')' { $$ = new sqrt_expression($3); }
@@ -43,7 +43,7 @@ using triton::lang::return_void;
 "float" { return return_impl(FP32, yytext); }
 "double" { return return_impl(FP64, yytext); }
 "..." { return return_impl(ELLIPSIS, yytext); }
-"get_range_id" { return return_impl(GET_RANGE_ID, yytext); }
+"get_program_id" { return return_impl(GET_PROGRAM_ID, yytext); }
 "get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }
 "__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); }
 "__atomic_exch" { return return_impl(ATOMIC_EXCH, yytext); }
@@ -227,10 +227,10 @@ unsigned alignment_info::populate_starting_multiple(ir::value *v){
 if(auto *x = dynamic_cast<ir::constant_range*>(v)){
 return cache(x->get_first()->get_value());
 }
-if(auto *x = dynamic_cast<ir::nv_dynamic_range_idx_inst*>(v)){
+if(auto *x = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(v)){
 return cache(128);
 }
-if(auto *x = dynamic_cast<ir::nv_static_range_idx*>(v)){
+if(auto *x = dynamic_cast<ir::nv_static_program_idx*>(v)){
 return cache(x->get_range()->get_first()->get_value());
 }
 if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)){
@@ -411,7 +411,7 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
 Value *else_value = value(ii->get_operand(2));
 return builder.Insert(SelectInst::Create(pred, if_value, else_value));
 }
-if(ir::get_range_id_inst* ii = dynamic_cast<ir::get_range_id_inst*>(inst)){
+if(ir::get_program_id_inst* ii = dynamic_cast<ir::get_program_id_inst*>(inst)){
 Value *result = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis());
 return (Instruction*)result;
 }
@@ -837,7 +837,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
 T->set_value(idx, idx[0]);
 });
 }
-if(is_inserted && dynamic_cast<ir::nv_static_range_idx*>(v)){
+if(is_inserted && dynamic_cast<ir::nv_static_program_idx*>(v)){
 T->for_each([&](indices_t idx){
 assert(idx.size() == 1);
 BinaryOperator *bin_add = dyn_cast<BinaryOperator>(idx[0]);
@@ -996,7 +996,7 @@ void selection::lower_reduce(ir::reduce_inst *x, LLVMContext &ctx, Function *fn,
 }
 }

-void selection::lower_dynamic_range_idx(ir::nv_dynamic_range_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
+void selection::lower_dynamic_program_idx(ir::nv_dynamic_program_idx_inst *x, LLVMContext &ctx, Function *fn, IRBuilder<> &builder) {
 distributed_tile* result = (distributed_tile*)tmap_.at(x);
 result->for_each([&](indices_t idx){
 assert(idx.size() == 1);
@@ -1418,8 +1418,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
 lower_downcast(x, ctx, fn, builder);
 else if(auto *x = dynamic_cast<ir::reduce_inst*>(ins))
 lower_reduce(x, ctx, fn, builder);
-else if(auto *x = dynamic_cast<ir::nv_dynamic_range_idx_inst*>(ins))
-lower_dynamic_range_idx(x, ctx, fn, builder);
+else if(auto *x = dynamic_cast<ir::nv_dynamic_program_idx_inst*>(ins))
+lower_dynamic_program_idx(x, ctx, fn, builder);
 else if(auto *x = dynamic_cast<ir::reshape_inst*>(ins))
 lower_reshape(x, ctx, fn, builder);
 else if(auto *x = dynamic_cast<ir::splat_inst*>(ins))
@@ -164,7 +164,7 @@ reassociate::reassociate(analysis::tune* params)
 void reassociate::run(ir::module &mod) {
 ir::builder &builder = mod.get_builder();

-// constant_range -> nv_dynamic_range_idx + nv_static_range_idx
+// constant_range -> nv_dynamic_program_idx + nv_static_program_idx
 for(ir::function *fn: mod.get_function_list()){
 std::vector<ir::constant_range*> ranges;
 std::vector<ir::basic_block*> rpo = ir::cfg::reverse_post_order(fn);
@@ -178,8 +178,8 @@ void reassociate::run(ir::module &mod) {

 builder.set_insert_point(rpo.front()->get_first_non_phi());
 for(ir::constant_range* old_range: ranges){
-ir::value* dyn_range = builder.insert(ir::nv_dynamic_range_idx_inst::create(old_range->get_type()));
-ir::value* static_range = ir::nv_static_range_idx::get(old_range);
+ir::value* dyn_range = builder.insert(ir::nv_dynamic_program_idx_inst::create(old_range->get_type()));
+ir::value* static_range = ir::nv_static_program_idx::get(old_range);
 ir::value* new_range = builder.create_add(dyn_range, static_range);
 old_range->replace_all_uses_with(new_range);
 params_->copy(dyn_range, old_range);
@@ -82,7 +82,7 @@ void batchnorm_forward(float *Y, float *M, float *V,
 int rx[TM] = 0 ... TM;
 float *px[TM];
 float x[TM] = 0;
-int c = get_range_id(1);
+int c = get_program_id(1);
 float g = *(G + c);
 float b = *(B + c);

@@ -177,7 +177,7 @@ void batchnorm_backward(float *DX, float *DG, float *DB,
 restrict read_only float *V,
 int DHWN, float rcpDHWN, float epsilon) {
 int rx[TM] = 0 ... TM;
-int c = get_range_id(1);
+int c = get_program_id(1);
 int offset = c*DHWN;
 float g = *(G + c);
 float mean = *(M + c);
@@ -114,11 +114,11 @@ std::string dot::triton_c_src_ydx() const {
 int lda, int ldb, int ldc,
 int N, int* lut,
 int* locks, int nlocks) {
-int ridx = get_range_id(0);
+int ridx = get_program_id(0);
 float acc[TM, TN] = 0;
 int rka[TK] = 0 ... TK;
 int rkb[TK] = 0 ... TK;
-int *header = lut + get_range_id(1) * 4;
+int *header = lut + get_program_id(1) * 4;
 int offset = *(header + 0);
 int K = *(header + 1);
 int column = *(header + 2);
@@ -191,7 +191,7 @@ std::string dot::triton_c_src_dw() const {
 int lda, int ldb, int ldc,
 int N, int* lut,
 int* locks, int nlocks) {
-int ridx = get_range_id(0);
+int ridx = get_program_id(0);
 float acc[TM, TN] = 0;
 int rka[TK] = 0 ... TK;
 int rkb[TK] = 0 ... TK;
@@ -686,8 +686,8 @@ if(b_lut_){
 float* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
 bool checkc0[TM] = rxc < M;
 bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-int ridx = get_range_id(0);
-int ridy = get_range_id(1);
+int ridx = get_program_id(0);
+int ridy = get_program_id(1);
 int *plock = locks + ridx + ridy*grid0;
 while(__atomic_cas(plock, 0, 1) == 1);
 int *pcount = plock + grid0*grid1;
@@ -116,8 +116,8 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
 int M, int N, int K,
 )" + align_lda_str + R"( int lda, )" + align_ldb_str + R"(" int ldb, int ldc,
 int bound, int *locks, int grid0, int grid1) {
-int ridx = get_range_id(0);
-int ridy = get_range_id(1);
+int ridx = get_program_id(0);
+int ridy = get_program_id(1);
 int rxa[TM] = ridx * TM + (0 ... TM);
 int ryb[TN] = ridy * TN + (0 ... TN);
 int rka[TK] = 0 ... TK;
@@ -354,9 +354,9 @@ void shift(restrict read_only align(16) )" + a_ty_ + R"( *A,
 int BH, int BW,
 int CH, int CW,
 int* locks, int grid0, int grid1, int grid2) {
-int ridx = get_range_id(0);
-int ridy = get_range_id(1);
-int rz = get_range_id(2);
+int ridx = get_program_id(0);
+int ridy = get_program_id(1);
+int rz = get_program_id(2);
 int rxa[TM] = ridx*TM + (0 ... TM);
 int ryb[TN] = ridy*TN + (0 ... TN);
 int rka[TK] = 0 ... TK;
@@ -300,8 +300,8 @@ value *builder::create_downcast(value *arg, const std::string &name) {
 // built-in instructions
 //===----------------------------------------------------------------------===//

-value *builder::create_get_range_id(unsigned axis, const std::string &name) {
-return insert(get_range_id_inst::create(ctx_, axis, name));
+value *builder::create_get_program_id(unsigned axis, const std::string &name) {
+return insert(get_program_id_inst::create(ctx_, axis, name));
 }

 value *builder::create_get_num_program(unsigned axis, const std::string &name) {
@@ -664,14 +664,14 @@ instruction* select_inst::create(value *pred, value *if_value, value *else_value
 //===----------------------------------------------------------------------===//


-// get_range_id
-get_range_id_inst::get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
+// get_program_id
+get_program_id_inst::get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next)
 : builtin_inst(ty, 0, 1, name, next), axis_(axis){

 }

-instruction* get_range_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
-return new get_range_id_inst(type::get_int32_ty(ctx), axis, name, next);
+instruction* get_program_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) {
+return new get_program_id_inst(type::get_int32_ty(ctx), axis, name, next);
 }

 // get_num_program
@@ -745,25 +745,25 @@ barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instru
 return new barrier_inst(ctx, name, next);
 }

-// nv_dynamic_range_idx
-nv_dynamic_range_idx_inst::nv_dynamic_range_idx_inst(type *ty, const std::string &name, instruction *next)
+// nv_dynamic_program_idx
+nv_dynamic_program_idx_inst::nv_dynamic_program_idx_inst(type *ty, const std::string &name, instruction *next)
 : instruction(ty, 0, 1, name, next) { }

-nv_dynamic_range_idx_inst* nv_dynamic_range_idx_inst::create(type *ty, const std::string &name, instruction *next) {
-return new nv_dynamic_range_idx_inst(ty, name, next);
+nv_dynamic_program_idx_inst* nv_dynamic_program_idx_inst::create(type *ty, const std::string &name, instruction *next) {
+return new nv_dynamic_program_idx_inst(ty, name, next);
 }

-// nv_static_range_idx
-nv_static_range_idx::nv_static_range_idx(constant_range *range)
+// nv_static_program_idx
+nv_static_program_idx::nv_static_program_idx(constant_range *range)
 : constant(range->get_type(), 0), range_(range) { }

-constant_range* nv_static_range_idx::get_range() const
+constant_range* nv_static_program_idx::get_range() const
 { return range_; }

-nv_static_range_idx* nv_static_range_idx::get(constant_range* range) {
-static std::map<constant_range*, nv_static_range_idx*> cache;
+nv_static_program_idx* nv_static_program_idx::get(constant_range* range) {
+static std::map<constant_range*, nv_static_program_idx*> cache;
 if(cache.find(range) == cache.end())
-cache.insert({range, new nv_static_range_idx(range)});
+cache.insert({range, new nv_static_program_idx(range)});
 return cache.at(range);
 }

@@ -115,9 +115,9 @@ ir::value* alloc_const_expression::codegen(ir::module *mod) const {
 return res;
 }

-// get_range_id
-ir::value* get_range_id_expression::codegen(ir::module *mod) const {
-return mod->get_builder().create_get_range_id(axis_->value());
+// get_program_id
+ir::value* get_program_id_expression::codegen(ir::module *mod) const {
+return mod->get_builder().create_get_program_id(axis_->value());
 }

 // get_num_program
@@ -1,14 +1,6 @@
-import libtriton
+import triton
 import tensorflow as tf
-import distutils
-import distutils.log
-import setuptools.command.build_ext
-import setuptools
 import numpy as np
-import os
-import tempfile
-import shutil
-import hashlib

 src = """
 const tunable int TM = {128};
@@ -20,8 +12,8 @@ void matmul(restrict read_only align(16) half *A,
 restrict read_only align(16) half *C,
 int M, int N, int K,
 multiple_of(8) int lda, multiple_of(8) int ldb, int ldc) {
-int ridx = get_range_id(0);
-int ridy = get_range_id(1);
+int ridx = get_program_id(0);
+int ridy = get_program_id(1);
 int rxa[TM] = ridx * TM + (0 ... TM);
 int ryb[TN] = ridy * TN + (0 ... TN);
 int rka[TK] = 0 ... TK;
@@ -40,7 +32,7 @@ void matmul(restrict read_only align(16) half *A,
 }
 int rxc[TM] = ridx * TM + (0 ... TM);
 int ryc[TN] = ridy * TN + (0 ... TN);
-half* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
+half* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis]*ldc;
 half c[TM, TN] = xc;
 bool checkc0[TM] = rxc < M;
 bool checkc1[TN] = ryc < N;
@@ -49,100 +41,10 @@ void matmul(restrict read_only align(16) half *A,
 }
 """

-
-extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')
-
-
-def make_bindings(src, outputs, grids):
-  return libtriton.make_tensorflow_src(src, outputs, grids)
-
-def make_cache_path(src):
-  md5 = hashlib.sha1(src.encode())
-  hexhash = md5.hexdigest()
-  home = os.path.expanduser('~')
-  cacheroot = os.path.join(home, '.triton', 'cache')
-  cachepath = os.path.join(cacheroot, str(hexhash))
-  if not os.path.exists(cachepath):
-    os.makedirs(cachepath)
-  print(cachepath)
-  return cachepath
-
-def write_bindings(src, root):
-  cpp = os.path.join(root, 'tensorflow.cpp')
-  so = os.path.join(root, 'tensorflow.so')
-  recompile = False
-  # recompile if .so does not exist
-  if not os.path.exists(cpp) or not os.path.exists(so):
-    recompile = True
-  # recompile if cpp was modified after .so
-  elif max(cpp, so, key=os.path.getctime) == cpp:
-    recompile = True
-  # write cpp file
-  if recompile:
-    with open(cpp, 'w+') as handle:
-      handle.writelines(src)
-  # return path of cpp file
-  return cpp
-
-def build(src, path):
-  # include directories
-  triton_include_dirs = ['/home/philippe/development/triton/include']
-  tensorflow_include_dirs = [tf.sysconfig.get_include()]
-  cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
-  include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
-  # library directories
-  triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
-  tensorflow_library_dirs = [tf.sysconfig.get_lib()]
-  library_dirs = triton_library_dirs + tensorflow_library_dirs
-  # libraries
-  libraries = ['tensorflow_framework', 'triton']
-  # extra arguments
-  extra_compile_args = []
-  extra_link_args = []
-  # create extension module
-  ext = setuptools.Extension(
-    name = 'test',
-    language = 'c++',
-    sources = [src],
-    include_dirs = include_dirs,
-    extra_compile_args = extra_compile_args,
-    extra_link_args = extra_link_args,
-    library_dirs = library_dirs,
-    libraries = libraries
-  )
-  # build extension module
-  args = ['build_ext']
-  tmp = tempfile.mkdtemp()
-  args.append('--build-temp=' + tmp)
-  args.append('--build-lib=' + path)
-  args.append('-q')
-  args = dict(
-    name = 'test',
-    ext_modules = [ext],
-    script_args = args,
-  )
-  setuptools.setup(**args)
-  shutil.rmtree(tmp)
-
-def make_tensorflow_op(src, outputs, grids):
-  bindings = make_bindings(src, outputs, grids)
-  cache_path = make_cache_path(bindings)
-  cpp = write_bindings(bindings, cache_path)
-  build(cpp, cache_path)
-  result = tf.load_op_library(os.path.join(cache_path, 'test.cpython-36m-x86_64-linux-gnu.so'))
-  return result
-
-
-library_dir = os.path.dirname(os.path.realpath(__file__))
-module = make_tensorflow_op(src, ['C'], ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN'])
-print(module.matmul)

 class dot:

   def __init__(self):
-    trans_a = True
-    trans_b = False
+    self.matmul = triton.make_tensorflow_op(src, ['C'], ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN'])

   def __call__(self, a, b):
     shape_a = tf.shape(a)
@@ -152,17 +54,17 @@ class dot:
     N = shape_b[0]
     lda = M
     ldb = K
-    ldc = M
-    c = extra_ops.alloc_empty(tf.stack([M, N]))
-    return module.matmul(a, b, c, M, N, K, lda, ldb, ldc)
+    ldc = N
+    c = triton.empty([M, N])
+    return self.matmul.matmul(a, b, c, M, N, K, lda, ldb, ldc)

-dot_nt = dot()
+dot_tn = dot()
 def run_dot():
   M, N, K = 128, 128, 128
   a = tf.placeholder(tf.float16, shape=[M, K])
   b = tf.placeholder(tf.float16, shape=[N, K])
   # c = tf.matmul(a, b, transpose_a=True)
-  c = dot_nt(a, b)
+  c = dot_tn(a, b)
   # Reference
   ha = np.random.rand(M, K).astype(np.float16)
   hb = np.random.rand(N, K).astype(np.float16)
|
   result = sess.run([c], feed_dict = {a: ha,
                                       b: hb})[0]
   # Test
-  hresult = np.dot(ha.T, hb).T
+  hresult = np.dot(ha.T, hb)
   dif = np.abs(result - hresult)
   np.savetxt('dif.dat', dif, '%2.4f')
   print(hresult)
@@ -18,6 +18,7 @@ class CMakeExtension(Extension):


 class CMakeBuild(build_ext):

     def run(self):
         try:
             out = subprocess.check_output(['cmake', '--version'])
@@ -80,6 +81,7 @@ setup(
     author_email='ptillet@g.harvard.edu',
     description='A language and compiler for custom Deep Learning operations',
     long_description='',
+    packages=['triton'],
     ext_modules=[CMakeExtension('triton')],
     cmdclass=dict(build_ext=CMakeBuild),
     zip_safe=False,
@@ -177,8 +177,8 @@ result += R"(


 std::regex regex("#([a-zA-Z]([a-zA-Z]|[0-9])*)");
-std::vector<std::string> grids;
-for(size_t i = macros.size(); i < 3; i++)
+std::vector<std::string> grids = macros;
+for(size_t i = grids.size(); i < 3; i++)
 grids.push_back("1");
 std::string grid = "rt::grid_t{";
 for(size_t i = 0; i < grids.size(); i++){
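To make the effect of this fix concrete, here is a small self-contained sketch, not taken from the commit, of the padding behaviour after the change: the launch grid now starts from the user-provided macros and only the missing trailing dimensions are filled with "1" (previously the macro entries were dropped and only the padding survived).

#include <cassert>
#include <string>
#include <vector>

std::vector<std::string> pad_grid(std::vector<std::string> macros) {
  std::vector<std::string> grids = macros;   // keep the provided grid dimensions
  for(size_t i = grids.size(); i < 3; i++)   // pad up to three dimensions
    grids.push_back("1");
  return grids;
}

int main() {
  std::vector<std::string> g = pad_grid({"(M + #TM - 1)/#TM", "(N + #TN - 1)/#TN"});
  assert(g.size() == 3 && g[2] == "1");      // two macros plus one padding entry
}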
python/triton/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .ops import *
python/triton/ops.py (new file, 103 lines)
@@ -0,0 +1,103 @@
# import for cache
import os
import tempfile
import shutil
import hashlib
import sysconfig
import sys
# import for just-in-time compilation
import distutils
import setuptools.command.build_ext
import setuptools
# triton
import libtriton
# frameworks
import tensorflow as tf

extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so')


def make_bindings(src, outputs, grids):
  return libtriton.make_tensorflow_src(src, outputs, grids)

def make_cache_path(src):
  md5 = hashlib.sha1(src.encode())
  hexhash = md5.hexdigest()
  home = os.path.expanduser('~')
  cacheroot = os.path.join(home, '.triton', 'cache')
  cachepath = os.path.join(cacheroot, str(hexhash))
  if not os.path.exists(cachepath):
    os.makedirs(cachepath)
  return cachepath

def write_bindings(src, root):
  cpp = os.path.join(root, 'tensorflow.cpp')
  suffix = sysconfig.get_config_var('EXT_SUFFIX')
  so = os.path.join(root, 'tensorflow{suffix}'.format(suffix=suffix))
  recompile = False
  # recompile if .so does not exist
  if not os.path.exists(cpp) or not os.path.exists(so):
    recompile = True
  # recompile if cpp was modified after .so
  elif max(cpp, so, key=os.path.getctime) == cpp:
    recompile = True
  # write cpp file
  if recompile:
    with open(cpp, 'w+') as handle:
      handle.writelines(src)
  # return path of cpp file
  return (cpp, so)

def build(src, path):
  # include directories
  triton_include_dirs = ['/home/philippe/development/triton/include']
  tensorflow_include_dirs = [tf.sysconfig.get_include()]
  cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/']
  include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs
  # library directories
  triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))]
  tensorflow_library_dirs = [tf.sysconfig.get_lib()]
  library_dirs = triton_library_dirs + tensorflow_library_dirs
  # libraries
  libraries = ['tensorflow_framework', 'triton']
  # extra arguments
  extra_compile_args = []
  extra_link_args = []
  # dependences
  depends = [os.path.realpath(libtriton.__file__)]
  # create extension module
  ext = setuptools.Extension(
    name = 'tensorflow',
    language = 'c++',
    sources = [src],
    include_dirs = include_dirs,
    extra_compile_args = extra_compile_args,
    extra_link_args = extra_link_args,
    library_dirs = library_dirs,
    libraries = libraries,
    depends = depends
  )
  # build extension module
  args = ['build_ext']
  tmp = tempfile.mkdtemp()
  args.append('--build-temp=' + tmp)
  args.append('--build-lib=' + path)
  args.append('-q')
  args = dict(
    name = 'tensorflow',
    ext_modules = [ext],
    script_args = args,
  )
  setuptools.setup(**args)
  shutil.rmtree(tmp)

def make_tensorflow_op(src, outputs, grids):
  bindings = make_bindings(src, outputs, grids)
  cache_path = make_cache_path(bindings)
  cpp, so = write_bindings(bindings, cache_path)
  build(cpp, cache_path)
  result = tf.load_op_library(so)
  return result

def empty(shapes):
  return extra_ops.alloc_empty(tf.stack(shapes))