[GENERAL] Various bugfixes
This commit is contained in:
committed by
Philippe Tillet
parent
50587bbf4b
commit
8f8d36c7a4
@@ -362,6 +362,30 @@ void generator::visit_unmasked_load_inst(ir::unmasked_load_inst* x) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
void generator::visit_masked_load_inst(ir::masked_load_inst* x) {
|
||||||
|
if(!x->get_type()->is_tile_ty()){
|
||||||
|
Value *ptr = vmap_.at(x->get_pointer_operand());
|
||||||
|
Value *mask = vmap_.at(x->get_mask_operand());
|
||||||
|
BasicBlock *current_bb = builder_->GetInsertBlock();
|
||||||
|
Function *parent = builder_->GetInsertBlock()->getParent();
|
||||||
|
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
|
||||||
|
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
|
||||||
|
builder_->CreateCondBr(mask, mask_then_bb, mask_done_bb);
|
||||||
|
builder_->SetInsertPoint(mask_then_bb);
|
||||||
|
Value *result_then = builder_->CreateLoad(ptr);
|
||||||
|
builder_->CreateBr(mask_done_bb);
|
||||||
|
builder_->SetInsertPoint(mask_done_bb);
|
||||||
|
Value *result = nullptr;
|
||||||
|
if(x->get_false_value_operand()){
|
||||||
|
Value *result_false = vmap_.at(x->get_false_value_operand());
|
||||||
|
result = builder_->CreatePHI(result_then->getType(), 2);
|
||||||
|
((PHINode*)result)->addIncoming(result_then, mask_then_bb);
|
||||||
|
((PHINode*)result)->addIncoming(result_false, current_bb);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
result = result_then;
|
||||||
|
vmap_[x] = result;
|
||||||
|
return;
|
||||||
|
}
|
||||||
// find vector size
|
// find vector size
|
||||||
ir::value *ptr = x->get_pointer_operand();
|
ir::value *ptr = x->get_pointer_operand();
|
||||||
auto order = layouts_->get(ptr)->get_order();
|
auto order = layouts_->get(ptr)->get_order();
|
||||||
@@ -677,6 +701,8 @@ void generator::visit_atomic_exch_inst(ir::atomic_exch_inst* xchg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
||||||
|
|
||||||
|
|
||||||
if(add->get_type()->is_tile_ty()){
|
if(add->get_type()->is_tile_ty()){
|
||||||
ir::value* ptr = add->get_operand(0);
|
ir::value* ptr = add->get_operand(0);
|
||||||
ir::value* val = add->get_operand(1);
|
ir::value* val = add->get_operand(1);
|
||||||
@@ -684,21 +710,36 @@ void generator::visit_atomic_add_inst(ir::atomic_add_inst* add) {
|
|||||||
distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
|
distributed_tile* ptrs = (distributed_tile*)tmap_.at(ptr);
|
||||||
distributed_tile* vals = (distributed_tile*)tmap_.at(val);
|
distributed_tile* vals = (distributed_tile*)tmap_.at(val);
|
||||||
distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
|
distributed_tile* msks = (distributed_tile*)tmap_.at(msk);
|
||||||
|
|
||||||
for_each(ptr, [&](indices_t idx){
|
for_each(ptr, [&](indices_t idx){
|
||||||
Value *rmw_ptr = ptrs->get_value(idx);
|
Value *rmw_ptr = ptrs->get_value(idx);
|
||||||
Value *rmw_val = vals->get_value(idx);
|
Value *rmw_val = vals->get_value(idx);
|
||||||
Value *rmw_msk = msks->get_value(idx);
|
Value *rmw_msk = msks->get_value(idx);
|
||||||
BasicBlock *current_bb = builder_->GetInsertBlock();
|
// num bytes
|
||||||
Function *parent = builder_->GetInsertBlock()->getParent();
|
Type* ty = rmw_val->getType();
|
||||||
BasicBlock *mask_then_bb = BasicBlock::Create(*ctx_, "mask_then", parent);
|
size_t nbits = ty->getScalarSizeInBits();
|
||||||
BasicBlock *mask_done_bb = BasicBlock::Create(*ctx_, "mask_done", parent);
|
// extract pointer offset
|
||||||
builder_->CreateCondBr(rmw_msk, mask_then_bb, mask_done_bb);
|
std::string offset = "";
|
||||||
builder_->SetInsertPoint(mask_then_bb);
|
if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(rmw_ptr))
|
||||||
builder_->CreateAtomicRMW(AtomicRMWInst::FAdd, rmw_ptr, rmw_val,
|
if(gep->getNumIndices() == 1)
|
||||||
AtomicOrdering::Unordered,
|
if(ConstantInt *cst = dyn_cast<ConstantInt>(gep->idx_begin())){
|
||||||
SyncScope::System);
|
offset = " + " + std::to_string(cst->getValue().getSExtValue()*nbits/8);
|
||||||
builder_->CreateBr(mask_done_bb);
|
rmw_ptr = gep->getPointerOperand();
|
||||||
builder_->SetInsertPoint(mask_done_bb);
|
}
|
||||||
|
rmw_ptr = builder_->CreateBitCast(rmw_ptr, ty->getPointerTo(1));
|
||||||
|
// asm argument type
|
||||||
|
std::vector<Type*> arg_ty = {rmw_msk->getType(), rmw_ptr->getType(), rmw_val->getType()};
|
||||||
|
// asm function type
|
||||||
|
FunctionType *fn_ty = FunctionType::get(ty, arg_ty, false);
|
||||||
|
// asm string
|
||||||
|
std::string mod = nbits == 32 ? "" : ".noftz";
|
||||||
|
std::string asm_str = "@$0 atom.global.sys.add" + mod + ".f" + std::to_string(nbits) + " $1, [$2" + offset + "], $3;";
|
||||||
|
std::string ty_id = nbits == 32 ? "f" : "h";
|
||||||
|
std::string constraint = "b,=" + ty_id + ",l," + ty_id;
|
||||||
|
// create inline asm
|
||||||
|
InlineAsm *iasm = InlineAsm::get(fn_ty, asm_str, constraint, true);
|
||||||
|
// call asm
|
||||||
|
builder_->CreateCall(iasm, {rmw_msk, rmw_ptr, rmw_val});
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
@@ -803,6 +844,7 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
|
|||||||
indices_t idx_b = {builder_->CreateAdd(offset_b_k, _K), current_offset_b_i};
|
indices_t idx_b = {builder_->CreateAdd(offset_b_k, _K), current_offset_b_i};
|
||||||
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
|
idx_a.insert(idx_a.end(), x.first.begin(), x.first.end());
|
||||||
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
|
idx_b.insert(idx_b.end(), x.first.begin(), x.first.end());
|
||||||
|
|
||||||
Value *ha = TA->get_value(idx_a);
|
Value *ha = TA->get_value(idx_a);
|
||||||
Value *hb = TB->get_value(idx_b);
|
Value *hb = TB->get_value(idx_b);
|
||||||
for(unsigned ii = 0; ii < hmma->pack_size_0_; ii++)
|
for(unsigned ii = 0; ii < hmma->pack_size_0_; ii++)
|
||||||
|
@@ -255,7 +255,6 @@ cu_module::cu_module(driver::context * context, std::unique_ptr<llvm::Module> ll
|
|||||||
|
|
||||||
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
|
||||||
cu_context::context_switcher ctx(*context);
|
cu_context::context_switcher ctx(*context);
|
||||||
// std::cout << source << std::endl;
|
|
||||||
// JIT compile source-code
|
// JIT compile source-code
|
||||||
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
|
||||||
unsigned int errbufsize = 8096;
|
unsigned int errbufsize = 8096;
|
||||||
@@ -264,10 +263,11 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
|
|||||||
try{
|
try{
|
||||||
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
|
dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
|
||||||
}catch(exception::cuda::base const &){
|
}catch(exception::cuda::base const &){
|
||||||
#ifdef TRITON_LOG_PTX_ERROR
|
//#ifdef TRITON_LOG_PTX_ERROR
|
||||||
std::cerr << "Compilation Failed! Log: " << std::endl;
|
std::cout << source << std::endl;
|
||||||
|
std::cerr << "It appears that Triton produced invalid PTX code:" << std::endl;
|
||||||
std::cerr << errbuf << std::endl;
|
std::cerr << errbuf << std::endl;
|
||||||
#endif
|
//#endif
|
||||||
throw;
|
throw;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -231,7 +231,7 @@ void Generator::VisitConditionalOp(ConditionalOp* condOp) {
|
|||||||
VisitExpr(condOp->exprFalse_);
|
VisitExpr(condOp->exprFalse_);
|
||||||
ir::value* false_val = ret_;
|
ir::value* false_val = ret_;
|
||||||
if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
|
if(ir::unmasked_load_inst* ld = dynamic_cast<ir::unmasked_load_inst*>(true_val)) {
|
||||||
if(!false_val->get_type()->is_tile_ty())
|
if(true_val->get_type()->is_tile_ty() && !false_val->get_type()->is_tile_ty())
|
||||||
false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
|
false_val = bld_->create_splat(false_val, cond->get_type()->get_tile_shapes());
|
||||||
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
|
ir::value* new_ld = bld_->create_masked_load(ld->get_pointer_operand(),
|
||||||
cond,
|
cond,
|
||||||
|
@@ -238,8 +238,8 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module,
|
|||||||
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
if(allocation.allocated_size() > context->device()->max_shared_memory())
|
||||||
throw std::runtime_error("using too much shared memory");
|
throw std::runtime_error("using too much shared memory");
|
||||||
barriers.run(module);
|
barriers.run(module);
|
||||||
|
//ir::print(module, std::cout);
|
||||||
isel.visit(module, *llvm);
|
isel.visit(module, *llvm);
|
||||||
// ir::print(module, std::cout);
|
|
||||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@@ -364,6 +364,7 @@ std::string function::preheader() {
|
|||||||
|
|
||||||
DECLARATION(float, 64, 64);
|
DECLARATION(float, 64, 64);
|
||||||
DECLARATION(half , 64, 64);
|
DECLARATION(half , 64, 64);
|
||||||
|
DECLARATION(half , 128, 128);
|
||||||
|
|
||||||
extern int atomic_cas(int*, int, int);
|
extern int atomic_cas(int*, int, int);
|
||||||
extern int atomic_xchg(int*, int);
|
extern int atomic_xchg(int*, int);
|
||||||
|
@@ -3,16 +3,16 @@ import triton
|
|||||||
|
|
||||||
class _dot(torch.autograd.Function):
|
class _dot(torch.autograd.Function):
|
||||||
src = """
|
src = """
|
||||||
__global__ void dot(TYPE *A __noalias __readonly __aligned(16),
|
__global__ void dot(TYPE * A __noalias __readonly __aligned(16),
|
||||||
TYPE *B __noalias __readonly __aligned(16),
|
TYPE * B __noalias __readonly __aligned(16),
|
||||||
TYPE *C __noalias __aligned(16),
|
TYPE * C __noalias __aligned(16),
|
||||||
float alpha,
|
float alpha,
|
||||||
int M __retune,
|
int M __retune,
|
||||||
int N __retune,
|
int N __retune,
|
||||||
int K __retune,
|
int K __retune __multipleof(16),
|
||||||
int lda __multipleof(8),
|
int lda __multipleof(8),
|
||||||
int ldb __multipleof(8),
|
int ldb __multipleof(8),
|
||||||
int ldc __multipleof(8)) {
|
int ldc __multipleof(8)) {
|
||||||
// prologue
|
// prologue
|
||||||
int ridx = get_program_id(0);
|
int ridx = get_program_id(0);
|
||||||
int ridy = get_program_id(1);
|
int ridy = get_program_id(1);
|
||||||
@@ -95,11 +95,12 @@ class _dot(torch.autograd.Function):
|
|||||||
if dtype not in _dot.kernel:
|
if dtype not in _dot.kernel:
|
||||||
defines = {
|
defines = {
|
||||||
'TYPE' : dtype,
|
'TYPE' : dtype,
|
||||||
|
'SHAPE_A': 'TM, TK', 'SHAPE_B': 'TK, TN',
|
||||||
'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
|
'STRIDE_AM': 'lda', 'STRIDE_AK': '1',
|
||||||
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
|
'STRIDE_BN': '1', 'STRIDE_BK': 'ldb',
|
||||||
'TM' : [64, 128],
|
'TM' : [128],
|
||||||
'TN' : [64, 128],
|
'TN' : [128],
|
||||||
'TK' : [8, 16],
|
'TK' : [16],
|
||||||
'TZ' : [1]
|
'TZ' : [1]
|
||||||
}
|
}
|
||||||
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
|
_dot.kernel[dtype] = triton.kernel(_dot.src, num_warps=[4], defines=defines)
|
||||||
@@ -120,7 +121,7 @@ dot = _dot.apply
|
|||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
|
||||||
M, N, K = 2048, 2048, 2048
|
M, N, K = 4096, 4096, 4096
|
||||||
a = torch.rand((M, K)).cuda().half()
|
a = torch.rand((M, K)).cuda().half()
|
||||||
b = torch.rand((K, N)).cuda().half()
|
b = torch.rand((K, N)).cuda().half()
|
||||||
|
|
||||||
@@ -130,4 +131,5 @@ b = torch.rand((K, N)).cuda().half()
|
|||||||
zc = torch.matmul(a,b)
|
zc = torch.matmul(a,b)
|
||||||
zc_ = dot(a,b)
|
zc_ = dot(a,b)
|
||||||
|
|
||||||
|
|
||||||
print(torch.allclose(zc, zc_))
|
print(torch.allclose(zc, zc_))
|
||||||
|
@@ -51,11 +51,6 @@ std::string get_fn_ptx(const map_key_t& key, const rt::function::options_t& opt)
|
|||||||
return id_fn_map[key]->ptx(&stream, opt);
|
return id_fn_map[key]->ptx(&stream, opt);
|
||||||
}
|
}
|
||||||
|
|
||||||
void register_cst(const map_key_t& key, const std::string& name, pybind11::buffer& data) {
|
|
||||||
pybind11::buffer_info info = data.request();
|
|
||||||
id_fn_map[key]->set_cst(name, info.ptr, info.size*info.itemsize);
|
|
||||||
}
|
|
||||||
|
|
||||||
void cleanup() {
|
void cleanup() {
|
||||||
id_grid_map.clear();
|
id_grid_map.clear();
|
||||||
id_fn_map.clear();
|
id_fn_map.clear();
|
||||||
@@ -134,7 +129,6 @@ PYBIND11_MODULE(libtriton, m) {
|
|||||||
m.def("register_grid", &register_grid);
|
m.def("register_grid", &register_grid);
|
||||||
m.def("delete_grid", &delete_grid);
|
m.def("delete_grid", &delete_grid);
|
||||||
m.def("register_fn", &register_fn);
|
m.def("register_fn", &register_fn);
|
||||||
m.def("register_cst", &register_cst);
|
|
||||||
m.def("delete_fn", &delete_fn);
|
m.def("delete_fn", &delete_fn);
|
||||||
m.def("make_op_id", &make_op_id);
|
m.def("make_op_id", &make_op_id);
|
||||||
m.def("cleanup", &cleanup);
|
m.def("cleanup", &cleanup);
|
||||||
|
@@ -31,19 +31,25 @@ CUstream torch_get_cuda_stream(int64_t dev_id) {
|
|||||||
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
|
return (CUstream)at::cuda::getCurrentCUDAStream(dev_id).stream();
|
||||||
}
|
}
|
||||||
|
|
||||||
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args){
|
void launch_kernel(int64_t op_id, int64_t dev_id, const std::string& args,
|
||||||
|
const std::vector<std::string>& constant_names, const std::vector<torch::Tensor>& constant_vals){
|
||||||
|
rt::function* fn = id_fn_map.at({op_id, dev_id}).get();
|
||||||
|
for(size_t n = 0; n < constant_names.size(); n++){
|
||||||
|
const torch::Tensor& x = constant_vals[n];
|
||||||
|
fn->set_cst(constant_names[n], (char*)x.data_ptr(), x.numel()*x.element_size());
|
||||||
|
}
|
||||||
if(dev_id == -1){
|
if(dev_id == -1){
|
||||||
if(!host_stream){
|
if(!host_stream){
|
||||||
host_device.reset(new drv::host_device());
|
host_device.reset(new drv::host_device());
|
||||||
host_context.reset(drv::context::create(&*host_device));
|
host_context.reset(drv::context::create(&*host_device));
|
||||||
host_stream.reset(drv::stream::create(&*host_context));
|
host_stream.reset(drv::stream::create(&*host_context));
|
||||||
}
|
}
|
||||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &*host_stream);
|
||||||
}
|
}
|
||||||
else{
|
else{
|
||||||
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
|
triton::driver::cu_stream stream(torch_get_cuda_stream(dev_id), false);
|
||||||
triton::driver::context* ctx = stream.context();
|
triton::driver::context* ctx = stream.context();
|
||||||
(*id_fn_map.at({op_id, dev_id}))((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
(*fn)((void**)args.c_str(), args.size(), *id_grid_map.at({op_id, dev_id}), &stream);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -63,9 +63,6 @@ class kernel:
|
|||||||
size = sum([sizes[x] for x in arg_types])
|
size = sum([sizes[x] for x in arg_types])
|
||||||
self.tys = ''.join([codes[x] for x in arg_types])
|
self.tys = ''.join([codes[x] for x in arg_types])
|
||||||
|
|
||||||
def set_constant(self, device, name, value):
|
|
||||||
libtriton.register_cst((self.op_id, device), name, value)
|
|
||||||
|
|
||||||
def ptx(self, device, **kwargs):
|
def ptx(self, device, **kwargs):
|
||||||
dev_id = device.index
|
dev_id = device.index
|
||||||
libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
|
libtriton.register_fn((self.op_id, dev_id), self.src, self.opt)
|
||||||
@@ -103,5 +100,7 @@ class kernel:
|
|||||||
if 'autotune_buf' in kwargs:
|
if 'autotune_buf' in kwargs:
|
||||||
pass
|
pass
|
||||||
# launch
|
# launch
|
||||||
params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
|
params = pack(self.tys, *[x.data_ptr() if isinstance(x, torch.Tensor) else x for x in args])
|
||||||
torch.ops.triton.launch_kernel(self.op_id, device, params)
|
names = list(kwargs['constants'].keys()) if 'constants' in kwargs else []
|
||||||
|
constants = list(kwargs['constants'].values()) if 'constants' in kwargs else []
|
||||||
|
torch.ops.triton.launch_kernel(self.op_id, device, params, names, constants)
|
@@ -9,7 +9,7 @@ int main() {
|
|||||||
// shapes to benchmark
|
// shapes to benchmark
|
||||||
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
|
typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
|
||||||
std::vector<config_t> configs;
|
std::vector<config_t> configs;
|
||||||
for(auto ord: std::vector<std::vector<int>>{{0, 1}})
|
for(auto ord: std::vector<std::vector<int>>{{1, 0}})
|
||||||
for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){
|
for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){
|
||||||
std::vector<config_t> tmp = {
|
std::vector<config_t> tmp = {
|
||||||
// config_t{ord, x[0], x[1], 128, 128, 128},
|
// config_t{ord, x[0], x[1], 128, 128, 128},
|
||||||
@@ -21,7 +21,7 @@ int main() {
|
|||||||
// config_t{ord, x[0], x[1], 1280, 1280, 1280},
|
// config_t{ord, x[0], x[1], 1280, 1280, 1280},
|
||||||
// config_t{ord, x[0], x[1], 1536, 1536, 1536},
|
// config_t{ord, x[0], x[1], 1536, 1536, 1536},
|
||||||
// config_t{ord, x[0], x[1], 2048, 2048, 2048},
|
// config_t{ord, x[0], x[1], 2048, 2048, 2048},
|
||||||
config_t{ord, x[0], x[1], 8192, 8192, 8192},
|
config_t{ord, x[0], x[1], 4096, 4096, 4096},
|
||||||
|
|
||||||
// config_t{ord, x[0], x[1], 256, 16, 256},
|
// config_t{ord, x[0], x[1], 256, 16, 256},
|
||||||
// config_t{ord, x[0], x[1], 512, 16, 512},
|
// config_t{ord, x[0], x[1], 512, 16, 512},
|
||||||
|
@@ -147,7 +147,7 @@ inline cublasGemmAlgo_t cublasGemmFastest(
|
|||||||
M, N, K,
|
M, N, K,
|
||||||
alpha, (const void*)A, cudt, lda,
|
alpha, (const void*)A, cudt, lda,
|
||||||
(const void*)B, cudt, ldb,
|
(const void*)B, cudt, ldb,
|
||||||
beta, (void*)C, cudt, ldc, cudt,
|
beta, (void*)C, cudt, ldc, CUDA_R_32F,
|
||||||
a); }, stream);
|
a); }, stream);
|
||||||
if(status != CUBLAS_STATUS_SUCCESS)
|
if(status != CUBLAS_STATUS_SUCCESS)
|
||||||
nanosec = INFINITY;
|
nanosec = INFINITY;
|
||||||
@@ -216,6 +216,6 @@ inline void cublasGemm(cublasDataType_t dtype,
|
|||||||
cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K,
|
cublasStatus_t status = cublas::cublasGemmEx(handle, opa, opb, M, N, K,
|
||||||
alpha, (const void*)*A->cu(), dtype, lda,
|
alpha, (const void*)*A->cu(), dtype, lda,
|
||||||
(const void*)*B->cu(), dtype, ldb,
|
(const void*)*B->cu(), dtype, ldb,
|
||||||
beta, (void*)*C->cu(), dtype, ldc, dtype, algo);
|
beta, (void*)*C->cu(), dtype, ldc, CUDA_R_32F, algo);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -152,16 +152,16 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
|
|||||||
bench.push_back(tflops(triton_ns));
|
bench.push_back(tflops(triton_ns));
|
||||||
|
|
||||||
// cublas
|
// cublas
|
||||||
// if(cublas::cublasinit()){
|
if(cublas::cublasinit()){
|
||||||
// T alpha(static_cast<double>(1));
|
T alpha(static_cast<double>(1));
|
||||||
// T beta(static_cast<double>(0));
|
T beta(static_cast<double>(0));
|
||||||
// cublasGemmAlgo_t fastest;
|
cublasGemmAlgo_t fastest;
|
||||||
// cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
|
cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
|
||||||
// double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
|
double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
|
||||||
// &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
|
&alpha, &*da, lda, &*db, ldb, &beta, &*dc,
|
||||||
// ldc, nullptr, fastest); }, stream);
|
ldc, nullptr, fastest); }, stream);
|
||||||
// bench.push_back(tflops(cublas_ms));
|
bench.push_back(tflops(cublas_ms));
|
||||||
// }
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// test triton
|
// test triton
|
||||||
|
Reference in New Issue
Block a user