[test] added tests for copy

This commit is contained in:
Philippe Tillet
2019-09-23 12:07:24 -04:00
parent 001973630e
commit 856e7baa04
14 changed files with 449 additions and 170 deletions

View File

@@ -156,6 +156,8 @@ private:
Constant* llvm_constant(ir::constant *cst, LLVMContext &ctx);
Value* llvm_alloc_const(ir::alloc_const *v, Module *module, Builder &builder);
ArrayType* llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx);
Function* llvm_fn(ir::function *fn, Builder& builder, Module &dst);
Value* alloc_shared(Builder &builder, Module& dst);
// grid construction
void create_grids(std::vector<ir::value *> &grids,
@@ -167,7 +169,7 @@ private:
void init_strided_scan_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_hmma_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_axes(ir::value *i, Builder &builder, Value *u_thread_id, Value *u_warp_id);
void init_grids(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
void init_layouts(ir::function *fn, Builder &builder, Value *sh_mem_ptr);
// lower scalar instruction
void lower_instruction(ir::instruction *src, Builder &builder);

View File

@@ -14,7 +14,7 @@ namespace ir {
}
namespace codegen{
namespace analysis{
namespace transform{
class cts {
public:

View File

@@ -573,51 +573,31 @@ inline int32_t ceil(int32_t num, int32_t div){
return (num + div - 1)/div;
}
// Derives a per-axis warp decomposition from per-axis block sizes.
//   bs    : per-axis block size
//   order : order[i] is the axis of `nw` written at step i of the accumulation
//   nw    : (out) resized to bs.size(); per-axis warp count
//   ws    : (out) resized to bs.size(); ws[i] = bs[i] / nw[i]
// NOTE(review): the running thread count is multiplied by bs[i] while the result
// is stored at nw[order[i]] — confirm that bs[i] (and not bs[order[i]]) is intended.
inline void to_warps(const std::vector<unsigned> &bs, const std::vector<int>& order, std::vector<unsigned> &nw, std::vector<unsigned> &ws){
  static const size_t warp_size = 32;
  size_t nthreads = 1;
  size_t nwarps = 1;
  const size_t rank = bs.size();
  nw.resize(rank);
  ws.resize(rank);
  // accumulate threads step by step and carve out the warps not yet accounted for
  for(size_t step = 0; step < rank; ++step){
    const int axis = order[step];
    nthreads *= bs[step];
    nw[axis] = ceil(nthreads, nwarps*warp_size);
    nwarps *= nw[axis];
  }
  // what is left of each axis once it has been divided among its warps
  for(size_t axis = 0; axis < rank; ++axis)
    ws[axis] = bs[axis] / nw[axis];
}
// For every dimension k of the tile value `v`, builds the list of index values
// owned by the current thread and records it in axes_ as a distributed_axis.
//   v           : tile value whose axes are initialised
//   builder     : IR builder used to emit the index arithmetic
//   u_thread_id : thread id operand (combined below as u_warp_id * 32 + u_thread_id)
//   u_warp_id   : warp id operand
// NOTE(review): several statements appear here in two forms (the per-dimension
// thread id, scaled_thread_id, per_block and the axes_ assignment); presumably
// only one form of each is meant to be live — confirm before relying on this text.
void selection::init_strided_scan_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id, Value *u_warp_id) {
auto order = tiles_->order(v);
const auto& shapes = v->get_type()->get_tile_shapes();
size_t dim = shapes.size();
std::vector<unsigned> contiguous(dim);
std::vector<unsigned> block_size(dim);
std::vector<unsigned> warp_size(dim);
std::vector<unsigned> n_warps(dim);
// per-dimension parameters queried from tiles_: nts -> contiguous, mts -> block_size
for(unsigned i = 0; i < shapes.size(); i++){
contiguous[i] = tiles_->nts(v, i);
block_size[i] = tiles_->mts(v, i);
}
to_warps(block_size, order, n_warps, warp_size);
std::vector<Value*> thread_id_in_warp = delinearize(u_thread_id, order, warp_size, builder);
std::vector<Value*> warp_id = delinearize(u_warp_id, order, n_warps, builder);
// linear id u_warp_id * 32 + u_thread_id, then split per dimension according to block_size
Value* full_thread_id = builder.CreateAdd(builder.CreateMul(u_warp_id, builder.getInt32(32)), u_thread_id);
std::vector<Value*> thread_id = delinearize(full_thread_id, order, block_size, builder);
// Create axes
for(unsigned k = 0; k < dim; k++) {
std::string str_k = std::to_string(k);
Value *warp_size_k = builder.getInt32(warp_size[k]);
Value *contiguous_k = builder.getInt32(contiguous[k]);
Value *thread_id = builder.CreateAdd(thread_id_in_warp[k], builder.CreateMul(warp_id[k], warp_size_k));
Value *scaled_thread_id = builder.CreateMul(thread_id, contiguous_k);
unsigned per_block = contiguous[k] * warp_size[k] * n_warps[k];
Value *scaled_thread_id = builder.CreateMul(thread_id[k], contiguous_k);
unsigned per_block = contiguous[k] * block_size[k];
// number of index values emitted per thread along dimension k
unsigned per_thread = contiguous[k] * shapes[k] / per_block;
std::vector<Value*> idx_list(per_thread);
for(unsigned n = 0 ; n < per_thread; n++){
// groups of contiguous[k] consecutive indices, successive groups per_block apart
unsigned offset = n / contiguous[k] * per_block + n % contiguous[k];
idx_list[n] = builder.CreateAdd(scaled_thread_id, builder.getInt32(offset), "idx_" + str_k + "_" + std::to_string(n));
}
axes_[a_axes_->get_id(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id};
axes_[a_axes_->get_id(v, k)] = distributed_axis{contiguous[k], idx_list, thread_id[k]};
}
}
@@ -825,7 +805,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder,
create_distributed_tile(v, builder);
}
void selection::init_grids(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
void selection::init_layouts(ir::function *fn, IRBuilder<> &builder, Value *sh_mem_ptr){
// fetch linear ID
Module *mod = builder.GetInsertBlock()->getParent()->getParent();
Value *warp_size = builder.getInt32(32);
@@ -1454,84 +1434,83 @@ ArrayType* selection::llvm_linearized_tile_type(ir::type *ty, LLVMContext &ctx)
return ArrayType::get(llvm_type(ty->get_scalar_ty(), ctx), size);
}
// Creates the LLVM function corresponding to `fn` inside module `dst`.
//   fn      : source IR function
//   builder : IR builder; on return its insertion point is the LLVM block
//             mapped to fn's first basic block
//   dst     : destination LLVM module
// Side effects: registers fn's arguments and basic blocks in vmap_, marks the
// function as a kernel through tgt_, and appends a "maxntidx" entry to the
// module's "nvvm.annotations" named metadata.
// Returns the newly created function.
Function* selection::llvm_fn(ir::function *fn, IRBuilder<>& builder, Module& dst) {
  LLVMContext &ctx = builder.getContext();
  FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), ctx);
  FunctionType *dst_fn_ty = fn_ty;
  // non-GPU targets receive three extra trailing i32 parameters
  if(!tgt_->is_gpu()){
    Type *dst_fn_ret_ty = fn_ty->getReturnType();
    std::vector<Type*> dst_fn_args_ty;
    for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
      dst_fn_args_ty.push_back(fn_ty->getParamType(i));
    dst_fn_args_ty.push_back(builder.getInt32Ty());
    dst_fn_args_ty.push_back(builder.getInt32Ty());
    dst_fn_args_ty.push_back(builder.getInt32Ty());
    dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
  }
  Function *ret = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
  // set attributes
  for(auto attr_pair: fn->attrs()){
    unsigned id = attr_pair.first;
    for(ir::attribute attr: attr_pair.second)
      if(attr.is_llvm_attr())
        ret->addAttribute(id, llvm_attr(ctx, attr));
  }
  // set metadata
  tgt_->set_kernel(builder, ctx, &dst, ret);
  Metadata *md_args[] = {
    ValueAsMetadata::get(ret),
    MDString::get(ctx, "maxntidx"),
    ValueAsMetadata::get(builder.getInt32(num_warps_*32))
  };
  dst.getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(ctx, md_args));
  // map parameters
  for(unsigned i = 0; i < fn->args().size(); i++)
    vmap_[fn->args()[i]] = &*(ret->arg_begin() + i);
  // create blocks
  for(ir::basic_block *block: fn->blocks()) {
    BasicBlock *dst_block = BasicBlock::Create(ctx, block->get_name(), ret);
    vmap_[block] = dst_block;
  }
  builder.SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
  // the function is declared to return Function*: flowing off the end without
  // a return statement would be undefined behaviour
  return ret;
}
// Declares the module-level byte array "__shared_ptr" in address space 3 and
// returns it cast to an i8 pointer in that address space.
// Returns nullptr when the target is not a GPU or when alloc_ reports a size of 0.
// NOTE(review): address space 3 is presumably the GPU shared-memory space — confirm.
Value* selection::alloc_shared(IRBuilder<> &builder, Module& dst) {
  LLVMContext &ctx = builder.getContext();
  // nothing to declare on non-GPU targets
  if(!tgt_->is_gpu())
    return nullptr;
  // nothing to declare when no bytes were allocated
  const unsigned n_bytes = alloc_->allocated_size();
  if(n_bytes == 0)
    return nullptr;
  Type *byte_ty = Type::getInt8Ty(ctx);
  ArrayType *buffer_ty = ArrayType::get(byte_ty, n_bytes);
  Type *byte_ptr_ty = PointerType::get(byte_ty, 3);
  GlobalVariable *buffer =
      new GlobalVariable(dst, buffer_ty, false, GlobalVariable::ExternalLinkage,
                         nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
  return builder.CreateBitCast(buffer, byte_ptr_ty);
}
void selection::run(ir::module &src, Module &dst) {
vmap_.clear();
tmap_.clear();
LLVMContext &dst_ctx = dst.getContext();
IRBuilder<> dst_builder(dst_ctx);
for(ir::alloc_const *x: src.allocs()) {
// constant memory
for(ir::alloc_const *x: src.allocs())
vmap_[x] = llvm_alloc_const(x, &dst, dst_builder);
}
// iterate over functions
for(ir::function *fn: src.get_function_list()) {
// create LLVM function
FunctionType *fn_ty = (FunctionType*)llvm_type(fn->get_fn_type(), dst_ctx);
FunctionType *dst_fn_ty = fn_ty;
if(!tgt_->is_gpu()){
Type *dst_fn_ret_ty = fn_ty->getReturnType();
std::vector<Type*> dst_fn_args_ty;
for(unsigned i = 0; i < fn_ty->getNumParams(); i++)
dst_fn_args_ty.push_back(fn_ty->getParamType(i));
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
dst_fn_args_ty.push_back(dst_builder.getInt32Ty());
dst_fn_ty = FunctionType::get(dst_fn_ret_ty, dst_fn_args_ty, false);
}
// grid indices
fn->get_fn_type()->get_return_ty();
Function *dst_fn = Function::Create(dst_fn_ty, Function::ExternalLinkage, fn->get_name(), &dst);
// set attributes
for(auto attr_pair: fn->attrs()){
unsigned id = attr_pair.first;
for(ir::attribute attr: attr_pair.second)
if(attr.is_llvm_attr()){
dst_fn->addAttribute(id, llvm_attr(dst_ctx, attr));
}
}
tgt_->set_kernel(dst_builder, dst_ctx, &dst, dst_fn);
// set metadata
Metadata *md_args[] = {
ValueAsMetadata::get(dst_fn),
MDString::get(dst_ctx, "maxntidx"),
ValueAsMetadata::get(dst_builder.getInt32(num_warps_*32))
};
dst.getOrInsertNamedMetadata("nvvm.annotations")->addOperand(MDNode::get(dst_ctx, md_args));
// map parameters
for(unsigned i = 0; i < fn->args().size(); i++)
vmap_[fn->args()[i]] = &*(dst_fn->arg_begin() + i);
// create blocks
for(ir::basic_block *block: fn->blocks()) {
BasicBlock *dst_block = BasicBlock::Create(dst_ctx, block->get_name(), dst_fn);
vmap_[block] = dst_block;
}
dst_builder.SetInsertPoint((BasicBlock*)vmap_[fn->blocks()[0]]);
llvm_fn(fn, dst_builder, dst);
// allocate shared memory
Value *sh_mem_ptr = nullptr;
if(tgt_->is_gpu())
if(unsigned alloc_size = alloc_->allocated_size()){
Type *int_8_ty = Type::getInt8Ty(dst_ctx);
ArrayType *array_ty = ArrayType::get(int_8_ty, alloc_size);
Type *ptr_ty = PointerType::get(int_8_ty, 3);
GlobalVariable *sh_mem_array =
new GlobalVariable(dst, array_ty, false, GlobalVariable::ExternalLinkage,
nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3);
sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty);
}
sh_mem_ptr_ = sh_mem_ptr;
// create grids
init_grids(fn, dst_builder, sh_mem_ptr);
// iterate through block
sh_mem_ptr_ = alloc_shared(dst_builder, dst);
// initialize layouts
init_layouts(fn, dst_builder, sh_mem_ptr_);
// generate LLVM-IR code
std::map<ir::basic_block*, BasicBlock*> last_block;
for(ir::basic_block *block: fn->blocks()) {
BasicBlock *parent = (BasicBlock*)vmap_[block];
@@ -1547,7 +1526,7 @@ void selection::run(ir::module &src, Module &dst) {
last_block[block] = dst_builder.GetInsertBlock();
}
}
// finalize double-buffering
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list()) {
if(liveness_->has_double(inst)) {
@@ -1574,7 +1553,7 @@ void selection::run(ir::module &src, Module &dst) {
}
}
// add phi operands
// finalize phi
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *inst: block->get_inst_list())
if(auto *phi = dynamic_cast<ir::phi_node*>(inst)){

View File

@@ -9,9 +9,8 @@
#include "triton/ir/type.h"
namespace triton {
namespace codegen{
namespace analysis{
namespace transform{
// run pass on module
void add_copy(ir::instruction *parent, ir::value *x, ir::builder &builder) {

View File

@@ -199,7 +199,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
llvm::LLVMContext ctx;
std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
// create passes
codegen::analysis::cts cts;
codegen::transform::cts cts;
codegen::analysis::align align;
codegen::analysis::liveness shmem_liveness;
codegen::analysis::axes axes;

View File

@@ -1,65 +1,35 @@
#include <cstring>
#include <sstream>
#include <cstdio>
#include <iostream>
#include <tuple>
#include "copy.h"
#include "triton/driver/backend.h"
#include "triton/driver/stream.h"
#include "triton/tools/bench.hpp"
#include "triton/external/half.hpp"
#include "triton/runtime/function.h"
#include "src/copy.h"
#include "util.h"
#include "cuda/cublas.h"
// Benchmarks the 2D copy kernel on an M x N float matrix.
//   stream  : stream providing the context for the buffers and the launches
//   M, N    : matrix extents, also forwarded to the kernel as arguments
//   order_x : ROWMAJOR selects strides ("M", "1") for X, anything else ("1", "N")
//   order_y : same, for Y
// Returns a vector holding one bandwidth figure: 2 * M * N * sizeof(float) bytes
// divided by the measured time, scaled by 1e-9.
std::vector<double> do_bench(drv::stream* stream, int32_t M, int32_t N, order_t order_x, order_t order_y){
  typedef float NumericT;
  std::string ty = "float";
  size_t dt_nbytes = sizeof(NumericT);
  drv::context* context = stream->context();
  // widen before multiplying: M*N evaluated in int32_t overflows for large matrices
  const size_t nbytes = static_cast<size_t>(M) * static_cast<size_t>(N) * dt_nbytes;
  // create inputs
  auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, nbytes));
  auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, nbytes));
  // create options
  rt::function::options_space_t opt;
  opt.defines.push_back({"TYPE", {ty}});
  opt.defines.push_back({"STRIDE_XM", {(order_x == ROWMAJOR)?"M":"1"}});
  opt.defines.push_back({"STRIDE_XN", {(order_x == ROWMAJOR)?"1":"N"}});
  opt.defines.push_back({"STRIDE_YM", {(order_y == ROWMAJOR)?"M":"1"}});
  opt.defines.push_back({"STRIDE_YN", {(order_y == ROWMAJOR)?"1":"N"}});
  opt.defines.push_back({"TM", {"32"}});
  opt.defines.push_back({"TN", {"32"}});
  opt.num_warps = {4};
  // create function
  rt::function function(src::copy2d, opt);
  // benchmark available libraries
  std::vector<double> result;
  // factor 2: every byte is counted once for X and once for Y
  auto gbps = [&](double ns) { return 2.0 * static_cast<double>(nbytes) / (ns * 1e-9) * 1e-9; };
  // triton
  double triton_ns = triton::tools::bench([&]() { function({&*dx, &*dy, M, N}, grid2d(M, N), stream);}, stream);
  result.push_back(gbps(triton_ns));
  // done
  return result;
}
int main() {
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
triton::driver::stream* stream = triton::driver::stream::create(context);
// shapes to benchmark
typedef std::tuple<int, int, order_t, order_t> config_t;
typedef std::tuple<std::vector<int>, std::vector<int>, std::vector<int>> config_t;
std::vector<config_t> configs = {
{4096, 4096, ROWMAJOR, ROWMAJOR},
{4096, 4096, COLMAJOR, ROWMAJOR},
{4096, 4096, ROWMAJOR, COLMAJOR},
{4096, 4096, COLMAJOR, COLMAJOR},
{{4096*4096}, {0}, {0}},
{{4096, 4096}, {0, 1}, {1, 0}},
{{4096, 4096}, {0, 1}, {1, 0}},
{{4096, 4096}, {1, 0}, {0, 1}},
{{4096, 4096}, {0, 1}, {0, 1}},
{{256, 256, 256}, {0, 1, 2}, {0, 1, 2}},
{{256, 256, 256}, {0, 1, 2}, {0, 2, 1}},
{{256, 256, 256}, {1, 0, 2}, {1, 2, 0}},
{{256, 256, 256}, {1, 2, 0}, {1, 0, 2}},
{{256, 256, 256}, {2, 0, 1}, {0, 1, 2}},
{{256, 256, 256}, {2, 1, 0}, {0, 2, 1}}
};
// does the work
int32_t M, N;
order_t ord_x, ord_y;
std::vector<int32_t> shape;
std::vector<int32_t> ord_x, ord_y;
for(const auto& c: configs){
std::tie(M, N, ord_x, ord_y) = c;
std::cout << "// " << M << ", " << N << ", " << ord_x << ", " << ord_y << std::flush;
for(auto perf: do_bench(stream, M, N, ord_x, ord_y))
std::tie(shape, ord_x, ord_y) = c;
std::cout << "// " << c << std::flush;
for(auto perf: bench_copy_nd(stream, shape, ord_x, ord_y))
std::cout << ", " << perf << std::flush;
std::cout << std::endl;
}

142
tests/common/copy.h Normal file
View File

@@ -0,0 +1,142 @@
#include "src/copy.h"
#include "triton/driver/stream.h"
#include "triton/runtime/function.h"
#include "triton/tools/bench.hpp"
#include "util.h"
// Linear offset of multi-index `idx`: the sum over d of idx[d] * strides[d].
// `strides` must hold at least idx.size() entries; an empty index yields 0.
// Declared inline because this definition lives in a header and would otherwise
// be emitted once per translation unit that includes it.
inline int32_t off(const std::vector<int32_t>& idx, const std::vector<int32_t>& strides) {
  int32_t res = 0;
  for(size_t d = 0; d < idx.size(); d++)
    res += idx[d] * strides[d];
  return res;
}
// Selects what triton_copy_nd does with the compiled kernel:
// BENCH times it and appends one bandwidth figure to its `bench` argument,
// TEST runs it once and compares the result with cc_copy_nd.
enum run_mode_t {
BENCH,
TEST
};
// Host reference for the N-d strided copy: for every multi-index in `shape`,
// y[offset under y_order] = x[offset under x_order].
//   x       : source elements
//   y       : destination; must already be large enough for every offset written
//   shape   : extent of each axis (any rank; rank 0 or an empty axis copies nothing)
//   x_order : axis permutation for x; x_order[0] is the axis with stride 1, each
//             following axis has the previous stride times the previous extent
//   y_order : same, for y
template<class T>
void cc_copy_nd(const std::vector<T>& x, std::vector<T>& y,
                const std::vector<int32_t>& shape,
                const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order) {
  const size_t rank = shape.size();
  if(rank == 0)
    return;
  for(int32_t extent: shape)
    if(extent <= 0)
      return;
  // strides implied by an axis order
  auto make_strides = [&shape, rank](const std::vector<int32_t>& order) {
    std::vector<int32_t> strides(rank);
    for(size_t d = 0; d < rank; d++)
      strides[order[d]] = (d == 0) ? 1 : (strides[order[d-1]] * shape[order[d-1]]);
    return strides;
  };
  const std::vector<int32_t> x_strides = make_strides(x_order);
  const std::vector<int32_t> y_strides = make_strides(y_order);
  // walk every multi-index like an odometer (last axis fastest); offsets are
  // computed in place so no temporary vector is allocated per element
  std::vector<int32_t> idx(rank, 0);
  while(true){
    int32_t x_off = 0;
    int32_t y_off = 0;
    for(size_t d = 0; d < rank; d++){
      x_off += idx[d] * x_strides[d];
      y_off += idx[d] * y_strides[d];
    }
    y[y_off] = x[x_off];
    // advance; an axis that reaches its extent wraps to 0 and carries upward
    size_t d = rank;
    while(d > 0 && ++idx[d-1] == shape[d-1]){
      idx[d-1] = 0;
      --d;
    }
    // every axis wrapped: the whole index space has been visited
    if(d == 0)
      return;
  }
}
// Builds the copy kernel of rank shape.size() from src::copy_nd and either
// benchmarks it or checks it against cc_copy_nd.
//   stream  : supplies the context for the device buffers and runs launches/transfers
//   shape   : extent of each axis; appended to the kernel arguments after X and Y
//   x_order : axis order of X; the axis x_order[0] gets stride "1"
//   y_order : same, for Y
//   TS      : per-axis list of tile-size strings; tile_nd(rank) is used when empty
//   mode    : BENCH or TEST
//   bench   : (out) one value appended in BENCH mode, untouched otherwise
//   test    : (out) assigned from testing::diff(hy, ry) in TEST mode, untouched otherwise
// NOTE(review): rank is assumed to be 1..3 — shapename and ts have three entries,
// src::copy_nd is indexed with rank - 1, and `rank - 1` wraps around for an empty shape.
// NOTE(review): this definition sits in a header without `inline`; including the
// header from a second translation unit would presumably clash at link time — confirm.
void triton_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order,
std::vector<std::vector<std::string>> TS,
run_mode_t mode, std::vector<double>& bench, bool &test) {
typedef float NumericT;
std::string ty = "float";
size_t dtsize = sizeof(NumericT);
drv::context* context = stream->context();
// rank
size_t rank = shape.size();
// size
size_t size = 1;
for(int32_t d: shape)
size *= d;
// names of the kernel's extent parameters, used to spell strides symbolically
std::vector<std::string> shapename = {"S0", "S1", "S2"};
// strides for x
// x_strides[d] is the stride expression of axis x_order[d]: "1", then the
// previous expression times the extent name of the previous axis
std::vector<std::string> x_strides = {"1"};
for(size_t d = 0; d < rank - 1; d++)
x_strides.push_back(x_strides[d] + " * " + shapename[x_order[d]]);
// strides for y
std::vector<std::string> y_strides = {"1"};
for(size_t d = 0; d < rank - 1; d++)
y_strides.push_back(y_strides[d] + " * " + shapename[y_order[d]]);
// create inputs
auto dx = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
auto dy = std::unique_ptr<drv::buffer>(drv::buffer::create(context, size*dtsize));
// create options
rt::function::options_space_t opt;
// macros
opt.defines.push_back({"TYPE", {ty}});
// STRIDE_XS<axis> / STRIDE_YS<axis> carry the stride expression of each axis
for(size_t d = 0; d < rank; d++)
opt.defines.push_back({"STRIDE_XS" + std::to_string(x_order[d]), {x_strides[d]}});
for(size_t d = 0; d < rank; d++)
opt.defines.push_back({"STRIDE_YS" + std::to_string(y_order[d]), {y_strides[d]}});
if(TS.empty())
TS = tile_nd(rank);
for(size_t d = 0; d < rank; d++)
opt.defines.push_back({"TS" + std::to_string(d), TS[d]});
opt.num_warps = {4};
// kernel
rt::function function(src::copy_nd[rank - 1], opt);
std::vector<rt::arg> args = {&*dx, &*dy};
for(int32_t d: shape)
args.push_back(d);
// grid_nd receives `shape` and `ts` by reference; both stay alive until this function returns
std::vector<std::string> ts = {"TS0", "TS1", "TS2"};
auto grid = grid_nd(shape, ts);
// metrics
if(mode == BENCH){
// 2 * size * dtsize bytes over the measured time, scaled by 1e-9
auto gbps = [&](double ns) { return 2 * size * dtsize / (ns * 1e-9) * 1e-9; };
double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
bench.push_back(gbps(triton_ns));
}
// test triton
if(mode == TEST){
// hx: random input, hy: kernel output read back, ry: host reference output
std::vector<NumericT> hx(size);
std::vector<NumericT> hy(size);
std::vector<NumericT> ry(size);
for(size_t i = 0; i < hx.size(); i++)
hx[i] = static_cast<NumericT>((float)rand()/RAND_MAX);
stream->write(&*dx, true, 0, hx);
function(args, grid, stream);
stream->synchronize();
stream->read(&*dy, true, 0, hy);
cc_copy_nd(hx, ry, shape, x_order, y_order);
// NOTE(review): whether testing::diff returns true on a match is not visible here — confirm
test = testing::diff(hy, ry);
}
}
// Runs triton_copy_nd in BENCH mode with an empty tile-size list and returns
// the bandwidth figures it collected.
// Declared inline because this definition lives in a header.
inline std::vector<double> bench_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
                                         const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order) {
  std::vector<double> bench;
  // not written in BENCH mode; initialised so no indeterminate value is handed over
  bool test = false;
  triton_copy_nd(stream, shape, x_order, y_order, {}, BENCH, bench, test);
  return bench;
}
// Runs triton_copy_nd in TEST mode with exactly one tile size per axis and
// returns the comparison result it reports.
//   TS : tile size of each axis; each value becomes a single-candidate list
// Declared inline because this definition lives in a header.
inline bool test_copy_nd(drv::stream* stream, const std::vector<int32_t>& shape,
                         const std::vector<int32_t>& TS,
                         const std::vector<int32_t>& x_order, const std::vector<int32_t>& y_order) {
  std::vector<double> bench;
  // defined value even if the callee leaves the flag untouched
  bool test = false;
  std::vector<std::vector<std::string>> TSS;
  TSS.reserve(TS.size());
  for(int32_t d: TS)
    TSS.push_back({std::to_string(d)});
  triton_copy_nd(stream, shape, x_order, y_order, TSS, TEST, bench, test);
  return test;
}

View File

@@ -1,33 +1,66 @@
#ifndef _TRITON_TEST_SRC_COPY_H_
#define _TRITON_TEST_SRC_COPY_H_
// Kernel sources for the copy tests, stored as text. TYPE, the TS* tile sizes and
// the STRIDE_* factors are not defined in these strings and have to be provided
// when a kernel is compiled.
namespace src {
// 1D copy: each program instance copies TS0 consecutive elements from X to Y.
// Y is the store destination, so it is annotated __writeonly like in copy2d/copy3d.
// NOTE(review): the text carries two sets of declarations (N/TN/rm and S0/TS0/rs0);
// presumably only one set is meant to remain — confirm.
const char *copy1d =
R"(
void copy1d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __writeonly __aligned(16),
int N) {
int ridm = get_program_id(0);
int rm[TN] = ridm * TN + 0 ... TN;
TYPE* px[TN] = X + rm;
TYPE* py[TN] = Y + rm;
int S0) {
int pid0 = get_program_id(0);
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
TYPE* px[TS0] = X + rs0;
TYPE* py[TS0] = Y + rs0;
*py = *px;
}
)";
// 2D copy: element offsets are the per-axis ranges scaled by the STRIDE_* factors.
// NOTE(review): the text carries two sets of declarations (M/N/TM/TN and
// S0/S1/TS0/TS1); presumably only one set is meant to remain — confirm.
const char *copy2d =
R"(
void copy2d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __writeonly __aligned(16),
int M __multipleof(8),
int N __multipleof(8)) {
int ridm = get_program_id(0);
int ridn = get_program_id(1);
int rm[TM] = ridm * TM + 0 ... TM;
int rn[TN] = ridn * TN + 0 ... TN;
TYPE* px[TM, TN] = X + rm[:, newaxis] * STRIDE_XM + rn[newaxis, :] * STRIDE_XN;
TYPE* py[TM, TN] = Y + rm[:, newaxis] * STRIDE_YM + rn[newaxis, :] * STRIDE_YN;
int S0 __multipleof(8),
int S1 __multipleof(8)) {
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
int rs1[TS1] = pid1 * TS1 + 0 ... TS1;
TYPE* px[TS0, TS1] = X + rs0[:, newaxis] * STRIDE_XS0 + rs1[newaxis, :] * STRIDE_XS1;
TYPE* py[TS0, TS1] = Y + rs0[:, newaxis] * STRIDE_YS0 + rs1[newaxis, :] * STRIDE_YS1;
*py = *px;
}
)";
// 3D copy: same scheme as copy2d with a third program id, range and stride.
const char *copy3d =
R"(
void copy3d(TYPE * X __noalias __readonly __aligned(16),
TYPE * Y __noalias __writeonly __aligned(16),
int S0 __multipleof(8),
int S1 __multipleof(8),
int S2 __multipleof(8)) {
// program id
int pid0 = get_program_id(0);
int pid1 = get_program_id(1);
int pid2 = get_program_id(2);
// ranges
int rs0[TS0] = pid0 * TS0 + 0 ... TS0;
int rs1[TS1] = pid1 * TS1 + 0 ... TS1;
int rs2[TS2] = pid2 * TS2 + 0 ... TS2;
// X pointers
TYPE* px[TS0, TS1, TS2] = X + rs0[:, newaxis, newaxis] * STRIDE_XS0
+ rs1[newaxis, :, newaxis] * STRIDE_XS1
+ rs2[newaxis, newaxis, :] * STRIDE_XS2;
// Y pointers
TYPE* py[TS0, TS1, TS2] = Y + rs0[:, newaxis, newaxis] * STRIDE_YS0
+ rs1[newaxis, :, newaxis] * STRIDE_YS1
+ rs2[newaxis, newaxis, :] * STRIDE_YS2;
*py = *px;
}
)";
// Kernel source by rank: copy_nd[rank - 1] for rank 1, 2 and 3.
const char* copy_nd[] = {copy1d, copy2d, copy3d};
}
#endif

View File

@@ -4,6 +4,7 @@
#define _TRITON_TESTS_UTIL_H
#include <iomanip>
#include <cmath>
#include "triton/runtime/function.h"
namespace drv = triton::driver;
@@ -26,6 +27,30 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
};
}
// Returns a grid function for an N-d launch: for each axis d the grid extent is
// ceil(shape[d], T) where T is the integer option named ts[d].
//   shape : extent of each axis
//   ts    : option name holding the tile size of each axis; needs at least
//           shape.size() entries
// The callable is returned to the caller, so it keeps its own copies of `shape`
// and `ts` instead of references that would dangle once the arguments go away.
inline rt::function::grid_fn_ty grid_nd(const std::vector<int32_t> &shape,
                                        const std::vector<std::string>& ts) {
  return [shape, ts](const rt::function::options_t& x) {
    rt::grid_t ret;
    for(size_t d = 0; d < shape.size(); d++)
      ret.push_back(ceil(shape[d], x.D<int>(ts[d])));
    return ret;
  };
}
// Candidate tile sizes (as strings) for each axis of a copy of the given rank.
// Ranks 1, 2 and 3 have candidates; rank 0 yields an empty list. Ranks above 3
// trip the assertion.
inline std::vector<std::vector<std::string>> tile_nd(size_t rank) {
  assert(rank <= 3);
  typedef std::vector<std::string> candidates_t;
  switch(rank){
    case 1:
      return {candidates_t{"128", "256", "512", "1024"}};
    case 2:
      return {candidates_t{"16", "32", "64"},
              candidates_t{"16", "32", "64"}};
    case 3:
      return {candidates_t{"4", "16", "32"},
              candidates_t{"4", "16", "32"},
              candidates_t{"4", "16", "32"}};
    default:
      return {};
  }
}
enum order_t {
ROWMAJOR,
COLMAJOR
@@ -44,17 +69,30 @@ struct gen_seq<0, Is...> : seq<Is...>{};
// Streams every element of tuple `t` to `os`, with ", " before each element
// except the one at index 0. The index pack Is selects the elements; expanding
// it inside a braced array initialiser keeps the writes in left-to-right order.
// NOTE(review): two swallow statements are present, the first of which also
// applies setfill(' ') / setw(3) to each element. As written every element would
// be streamed twice — confirm which statement is meant to remain.
template<class Ch, class Tr, class Tuple, std::size_t... Is>
void print_tuple(std::basic_ostream<Ch,Tr>& os, Tuple const& t, seq<Is...>){
using swallow = int[];
(void)swallow{0, (void(os << (Is == 0? "" : ", ") << std::setfill(' ') << std::setw(3) << std::get<Is>(t)), 0)...};
(void)swallow{0, (void(os << (Is == 0? "" : ", ") << std::get<Is>(t)), 0)...};
}
} // aux::
// Stream insertion for std::tuple: hands the tuple to aux::print_tuple together
// with the index sequence aux::gen_seq<sizeof...(Args)>.
// NOTE(review): the trailing `return os;` cannot be reached after
// `return os << ")";` — presumably only one of the two returns (and, with the
// second one, no surrounding parentheses) is meant to remain; confirm.
template<class Ch, class Tr, class... Args>
auto operator<<(std::basic_ostream<Ch, Tr>& os, std::tuple<Args...> const& t)
-> std::basic_ostream<Ch, Tr>&
{
os << "(";
aux::print_tuple(os, t, aux::gen_seq<sizeof...(Args)>());
return os << ")";
return os;
}
// Stream insertion for std::vector: writes "{e0, e1, ...}", i.e. the elements
// separated by ", " inside braces; an empty vector is written as "{}".
template<class Ch, class Tr, class T>
auto operator<<(std::basic_ostream<Ch, Tr>& os, std::vector<T> const& t)
-> std::basic_ostream<Ch, Tr>&
{
  os << "{";
  for(auto it = t.begin(); it != t.end(); ++it) {
    // separator goes in front of every element but the first
    if(it != t.begin())
      os << ", ";
    os << *it;
  }
  os << "}";
  return os;
}

View File

@@ -1,4 +1,4 @@
foreach(PROG dot)
foreach(PROG dot copy1d copy2d copy3d)
set(TARGET unit_${PROG})
add_executable(${TARGET} ${PROG}.cc)
set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})

30
tests/unit/copy1d.cc Normal file
View File

@@ -0,0 +1,30 @@
#include <iostream>
#include <tuple>
#include "copy.h"
#include "triton/driver/backend.h"
// Unit test driver for the 1D copy kernel: runs test_copy_nd on every
// configuration, prints one line per configuration and exits with status 0
// only when all of them pass.
int main() {
  // initialize default compute device
  auto context = triton::driver::backend::contexts::get_default();
  triton::driver::stream* stream = triton::driver::stream::create(context);
  // configurations to test: (shape, tile sizes, order of x, order of y)
  typedef std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, std::vector<int>> config_t;
  std::vector<config_t> configs = {
    // {{65536}, {32}, {0}, {0}},
    {{65536}, {128}, {0}, {0}},
    {{65536}, {512}, {0}, {0}},
    {{65536}, {1024}, {0}, {0}},
  };
  // does the work
  std::vector<int32_t> shape, tile;
  std::vector<int32_t> ord_x, ord_y;
  bool result = true;
  for(const auto& c: configs){
    std::tie(shape, tile, ord_x, ord_y) = c;
    bool pass = test_copy_nd(stream, shape, tile, ord_x, ord_y);
    result = result && pass;
    std::cout << "// " << c << ", " << pass << std::endl;
  }
  // `return result;` would exit with status 1 (failure) when every test passed
  return result ? 0 : 1;
}

46
tests/unit/copy2d.cc Normal file
View File

@@ -0,0 +1,46 @@
#include <iostream>
#include <tuple>
#include "copy.h"
#include "triton/driver/backend.h"
// Unit test driver for the 2D copy kernel: runs test_copy_nd on every
// configuration, prints one line per configuration and exits with status 0
// only when all of them pass.
int main() {
  // initialize default compute device
  auto context = triton::driver::backend::contexts::get_default();
  triton::driver::stream* stream = triton::driver::stream::create(context);
  // configurations to test: (shape, tile sizes, order of x, order of y);
  // every pair of orders is exercised with the same four tile shapes
  typedef std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, std::vector<int>> config_t;
  std::vector<config_t> configs = {
    {{256, 256}, {16, 16}, {0, 1}, {0, 1}},
    {{256, 256}, {16, 64}, {0, 1}, {0, 1}},
    {{256, 256}, {64, 16}, {0, 1}, {0, 1}},
    {{256, 256}, {64, 64}, {0, 1}, {0, 1}},
    {{256, 256}, {16, 16}, {0, 1}, {1, 0}},
    {{256, 256}, {16, 64}, {0, 1}, {1, 0}},
    {{256, 256}, {64, 16}, {0, 1}, {1, 0}},
    {{256, 256}, {64, 64}, {0, 1}, {1, 0}},
    {{256, 256}, {16, 16}, {1, 0}, {0, 1}},
    {{256, 256}, {16, 64}, {1, 0}, {0, 1}},
    {{256, 256}, {64, 16}, {1, 0}, {0, 1}},
    {{256, 256}, {64, 64}, {1, 0}, {0, 1}},
    // was a second {64, 64} entry, which left 16x16 untested for this pair of orders
    {{256, 256}, {16, 16}, {1, 0}, {1, 0}},
    {{256, 256}, {16, 64}, {1, 0}, {1, 0}},
    {{256, 256}, {64, 16}, {1, 0}, {1, 0}},
    {{256, 256}, {64, 64}, {1, 0}, {1, 0}},
  };
  // does the work
  std::vector<int32_t> shape, tile;
  std::vector<int32_t> ord_x, ord_y;
  bool result = true;
  for(const auto& c: configs){
    std::tie(shape, tile, ord_x, ord_y) = c;
    bool pass = test_copy_nd(stream, shape, tile, ord_x, ord_y);
    result = result && pass;
    std::cout << "// " << c << ", " << pass << std::endl;
  }
  // `return result;` would exit with status 1 (failure) when every test passed
  return result ? 0 : 1;
}

38
tests/unit/copy3d.cc Normal file
View File

@@ -0,0 +1,38 @@
#include <iostream>
#include <tuple>
#include "copy.h"
#include "triton/driver/backend.h"
// Unit test driver for the 3D copy kernel: runs test_copy_nd for four tile
// shapes under every pair of axis permutations, prints one line per
// configuration and exits with status 0 only when all of them pass.
// NOTE(review): std::next_permutation is declared in <algorithm>, which is not
// among the includes visible for this file — confirm it is pulled in.
int main() {
  // initialize default compute device
  auto context = triton::driver::backend::contexts::get_default();
  triton::driver::stream* stream = triton::driver::stream::create(context);
  // configurations to test: (shape, tile sizes, order of x, order of y)
  typedef std::tuple<std::vector<int>, std::vector<int>, std::vector<int>, std::vector<int>> config_t;
  std::vector<config_t> configs;
  // both index vectors start sorted, so the do/while loops visit all permutations
  std::vector<int> x_idx = {0, 1, 2};
  do {
    std::vector<int> y_idx = {0, 1, 2};
    do {
      configs.push_back(config_t{{64, 64, 32}, {16, 4, 8}, x_idx, y_idx});
      configs.push_back(config_t{{64, 64, 32}, {8, 16, 2}, x_idx, y_idx});
      configs.push_back(config_t{{64, 64, 32}, {32, 2, 2}, x_idx, y_idx});
      configs.push_back(config_t{{64, 64, 32}, {16, 64, 4}, x_idx, y_idx});
    } while(std::next_permutation(y_idx.begin(), y_idx.end()));
  } while(std::next_permutation(x_idx.begin(), x_idx.end()));
  // testing
  std::vector<int32_t> shape, tile;
  std::vector<int32_t> ord_x, ord_y;
  bool result = true;
  for(const auto& c: configs){
    std::tie(shape, tile, ord_x, ord_y) = c;
    bool pass = test_copy_nd(stream, shape, tile, ord_x, ord_y);
    result = result && pass;
    std::cout << "// " << c << ", " << pass << std::endl;
  }
  // `return result;` would exit with status 1 (failure) when every test passed
  return result ? 0 : 1;
}

View File

@@ -32,7 +32,7 @@ static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vecto
float acc = 0;
for(size_t k = 0; k < K; k++)
acc = acc + (AT ? a[k + m*K] : a[m + k*M]) * (BT ? b[n + k*N] : b[k + n*K]);
c[m + n*M] = static_cast<T>(acc);
c[m*N + n] = static_cast<T>(acc);
}
}
@@ -120,7 +120,9 @@ int main() {
std::cout << "Testing " << c << " ... " << std::flush;
if(do_test(stream, AT, BT, M, N, K, TM, TN, TK, (size_t)nwarp))
std::cout << " Pass! " << std::endl;
else
else{
std::cout << " Fail! " << std::endl;
exit(EXIT_FAILURE);
}
}
}