[codegen] worked around a bug that seems to come from nvptx/ptxas by simplifying multiplications by 1:
- the generated LLVM-IR looked correct
- the illegal addressing disappeared when running under cuda-memcheck
- the illegal addressing disappeared when using nvptx-short-pointer
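Background for the workaround: the reworked kernel below expands macros such as STRIDE_AM to the literal 1, so the IR contains multiplications by a splatted constant 1 that ptxas appears to miscompile; the new rewrite_mult pass (see the peephole hunks below) replaces such a multiplication with its other operand. As a minimal, self-contained sketch of that simplification over a toy expression type (illustrative only; the Expr struct and simplify_mul are assumptions, not Triton's ir:: API):

    // Illustrative sketch: fold "x * 1" back to "x" on a toy expression node.
    // The real pass does this on Triton IR via dynamic_cast to splat/constant_int.
    #include <cassert>

    struct Expr {
      enum Kind { Value, Const, Mul } kind;
      long imm = 0;        // constant payload when kind == Const
      Expr *lhs = nullptr; // operands when kind == Mul
      Expr *rhs = nullptr;
    };

    // Returns the surviving operand when the other one is the constant 1,
    // otherwise returns the multiplication unchanged.
    Expr *simplify_mul(Expr *e) {
      if (e->kind != Expr::Mul)
        return e;
      if (e->lhs->kind == Expr::Const && e->lhs->imm == 1)
        return e->rhs;
      if (e->rhs->kind == Expr::Const && e->rhs->imm == 1)
        return e->lhs;
      return e;
    }

    int main() {
      Expr x{Expr::Value};
      Expr one{Expr::Const, 1};
      Expr mul{Expr::Mul, 0, &x, &one};
      assert(simplify_mul(&mul) == &x); // x * 1 folds back to x
      return 0;
    }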
@@ -23,6 +23,7 @@ private:
   bool rewrite_dot_fp32(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
   bool rewrite_dot_hmma(ir::dot_inst *dot, ir::builder& builder, bool trans_a, bool trans_b, ir::value *A, ir::value *B, ir::value *D);
   bool rewrite_dot(ir::instruction *value, ir::builder& builder);
+  bool rewrite_mult(ir::instruction *value, ir::builder& builder);
   bool rewrite_unit_red(ir::instruction *value, ir::builder& builder);
   bool rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder);

@@ -95,7 +95,7 @@ void grids::init_c_graph(ir::instruction *v) {
   }
   // Splat
   else if(dynamic_cast<ir::splat_inst*>(v)){
-
+    return;
   }
   // Trans
   else if(auto *x = dynamic_cast<ir::trans_inst*>(v)){
@@ -469,21 +469,21 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function<Value*(ir
     return (Instruction*)res;
   }
   if(ir::atomic_add_inst* ii = dynamic_cast<ir::atomic_add_inst*>(inst)){
-    Value *ptr = value(ii->get_operand(0));
-    Value *val = value(ii->get_operand(1));
-    Value *atom_f_add = nullptr;
-    if(val->getType()->isFloatTy())
-      atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
-    else if(val->getType()->isHalfTy()){
-      Type *fp16 = Type::getHalfTy(ctx);
+    // Value *ptr = value(ii->get_operand(0));
+    // Value *val = value(ii->get_operand(1));
+    // Value *atom_f_add = nullptr;
+    // if(val->getType()->isFloatTy())
+    //   atom_f_add = Intrinsic::getDeclaration(builder.GetInsertBlock()->getModule(), Intrinsic::nvvm_atomic_load_add_f32, {ptr->getType()});
+    // else if(val->getType()->isHalfTy()){
+    //   Type *fp16 = Type::getHalfTy(ctx);

-      FunctionType *atom_ty = FunctionType::get(fp16, {fp16->getPointerTo(), fp16}, false);
-      atom_f_add = InlineAsm::get(atom_ty, " atom.relaxed.global.gpu.add.noftz.f16 $0, [$1], $2;", "=h,l,h", true);
-    }
-    if(atom_f_add == nullptr)
-      throw std::runtime_error("unsupported atomic add");
-    Value *res = builder.CreateCall(atom_f_add, {ptr, val});
-    return (Instruction*)res;
+    // FunctionType *atom_ty = FunctionType::get(fp16, {fp16->getPointerTo(), fp16}, false);
+    // atom_f_add = InlineAsm::get(atom_ty, " atom.relaxed.global.gpu.add.noftz.f16 $0, [$1], $2;", "=h,l,h", true);
+    // }
+    // if(atom_f_add == nullptr)
+    throw std::runtime_error("unsupported");
+    // Value *res = builder.CreateCall(atom_f_add, {ptr, val});
+    // return (Instruction*)res;
   }
   if(ir::sqrt_inst* ii = dynamic_cast<ir::sqrt_inst*>(inst)){
     Value *val = value(ii->get_operand(0));
@@ -169,6 +169,7 @@ bool peephole::rewrite_dot(ir::instruction *value, ir::builder& builder){
     return false;
   ir::value *a = dot->get_operand(0);
   ir::value *b = dot->get_operand(1);
+  builder.set_insert_point(add);
   ir::value * new_dot = builder.insert(ir::dot_inst::create(a, b, other,
                                                             dot->is_a_trans(), dot->is_b_trans(),
                                                             dot->get_name()));
@@ -212,6 +213,30 @@ bool peephole::rewrite_unit_red(ir::instruction *value, ir::builder& builder){
   return false;
 }

+bool peephole::rewrite_mult(ir::instruction *value, ir::builder& builder) {
+  auto binop = dynamic_cast<ir::binary_operator*>(value);
+  if(binop && binop->get_op() == ir::binary_op_t::Mul) {
+    ir::value *lhs = binop->get_operand(0);
+    ir::value *rhs = binop->get_operand(1);
+    ir::constant_int *_1_lhs = nullptr;
+    if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(lhs))
+      _1_lhs = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
+    ir::constant_int *_1_rhs = nullptr;
+    if(ir::splat_inst *splat = dynamic_cast<ir::splat_inst*>(rhs))
+      _1_rhs = dynamic_cast<ir::constant_int*>(splat->get_operand(0));
+    if(_1_lhs){
+      binop->replace_all_uses_with(rhs);
+      return true;
+    }
+    else if(_1_rhs){
+      binop->replace_all_uses_with(lhs);
+      return true;
+    }
+  }
+  return false;
+}
+
+
 bool peephole::rewrite_gep_ptr_min_off_plus_off(ir::instruction *value, ir::builder& builder) {
   auto x = dynamic_cast<ir::getelementptr_inst*>(value);
   if(!x)
@@ -250,8 +275,9 @@ void peephole::run(ir::module &mod) {
       if(seen.find(i) != seen.end())
        continue;
       bool was_modified = rewrite_dot(i, builder);
-      if(was_modified)
+      if(was_modified){
        seen.insert(i);
+      }
     }
   }while(seen.size() != n_seen);

@@ -265,6 +291,7 @@ void peephole::run(ir::module &mod) {
       if(seen.find(i) != seen.end())
        continue;
       bool was_modified = false;
+      was_modified = was_modified || rewrite_mult(i, builder);
       was_modified = was_modified || rewrite_trans_phi(i, builder);
       was_modified = was_modified || rewrite_unit_red(i, builder);
       was_modified = was_modified || rewrite_gep_ptr_min_off_plus_off(i, builder);
@@ -218,29 +218,24 @@ ocl_module::ocl_module(driver::context * context, llvm::Module* src): module(con
 /* ------------------------ */

 std::string cu_module::compile_llvm_module(llvm::Module* module) {
-  // set data layout
-  std::string layout = "e";
-  bool is_64bit = true;
-  bool use_short_pointers = true;
-  if (!is_64bit)
-    layout += "-p:32:32";
-  else if (use_short_pointers)
-    layout += "-p3:32:32-p4:32:32-p5:32:32";
-  layout += "-i64:64-i128:128-v16:16-v32:32-n16:32:64";
-  // create
-  llvm::SmallVector<char, 0> buffer;
-  module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_70", layout, buffer, "", Assembly);
-  std::string result(buffer.begin(), buffer.end());
-  size_t start_replace = result.find(".version");
-  size_t end_replace = result.find('\n', start_replace);
-  assert(start_replace != std::string::npos);
-  result.replace(start_replace, end_replace - start_replace, ".version 6.4");
-  return result;
+  // options
+  auto options = llvm::cl::getRegisteredOptions();
+  static_cast<llvm::cl::opt<bool>*>(options["nvptx-short-ptr"])->setValue(true);
+  // create
+  llvm::SmallVector<char, 0> buffer;
+  module::compile_llvm_module(module, "nvptx64-nvidia-cuda", "sm_70", "", buffer, "", Assembly);
+  std::string result(buffer.begin(), buffer.end());
+  size_t start_replace = result.find(".version");
+  size_t end_replace = result.find('\n', start_replace);
+  assert(start_replace != std::string::npos);
+  result.replace(start_replace, end_replace - start_replace, ".version 6.4");
+  return result;
 }

 cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }

 cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
+  // std::cout << source << std::endl;
   cu_context::context_switcher ctx_switch(*context);
   // JIT compile source-code
   CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};
@@ -49,8 +49,13 @@ void print(module &mod, std::ostream& os) {
     size_t num_ops = inst->get_num_operands();
     if(num_ops > 0)
       os << " ";;
-    for(unsigned i = 0; i < num_ops; i++)
-      os << get_name(ops[i], cnt++) << (i < num_ops - 1?", ":"");
+    for(unsigned i = 0; i < num_ops; i++){
+      if(auto *x = dynamic_cast<ir::constant_int*>(ops[i]))
+        os << x->get_value();
+      else
+        os << get_name(ops[i], cnt++);
+      os << (i < num_ops - 1?", ":"");
+    }
     os << ";" << std::endl;
   }
   os << std::endl;
@@ -217,6 +217,7 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   dce.run(module);
   vectorize.run(module);
   dce.run(module);
+  // ir::print(module, std::cout);
   // generate llvm code
   llvm::LLVMContext ctx;
   std::unique_ptr<llvm::Module> llvm(new llvm::Module(module.get_name(), ctx));
@@ -1,23 +1,43 @@
-import triton
 import tensorflow as tf
+import triton
 import numpy as np

 src = """
 #if AT == 1
 #define USEA ^a
+#define STRIDE_AK lda
+#define STRIDE_AM 1
+#define BROADCAST_AK :, newaxis
+#define BROADCAST_AM newaxis, :
+#define SHAPE_A TK, TM
 #else
 #define USEA a
+#define STRIDE_AK 1
+#define STRIDE_AM lda
+#define BROADCAST_AK newaxis, :
+#define BROADCAST_AM :, newaxis
+#define SHAPE_A TM, TK
 #endif

 #if BT == 1
 #define USEB ^b
+#define STRIDE_BK 1
+#define STRIDE_BN ldb
+#define BROADCAST_BK newaxis, :
+#define BROADCAST_BN :, newaxis
+#define SHAPE_B TN, TK
 #else
 #define USEB b
+#define STRIDE_BK ldb
+#define STRIDE_BN 1
+#define BROADCAST_BK :, newaxis
+#define BROADCAST_BN newaxis, :
+#define SHAPE_B TK, TN
 #endif

-void dot(TYPE * A __noalias __readonly __aligned(16),
-         TYPE * B __noalias __readonly __aligned(16),
-         TYPE * C __noalias __readonly __aligned(16),
+void dot(TYPE * A,
+         TYPE * B,
+         TYPE * C,
          int M, int N, int K,
          int lda __multipleof(8),
          int ldb __multipleof(8),
@@ -31,42 +51,20 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
   int rka[TK] = 0 ... TK;
   int rkb[TK] = 0 ... TK;
   float xc[TM, TN] = 0;
-
-  /* pointers for A */
-#if AT == 1
-  TYPE* pa[TK, TM] = A + rka[:, newaxis]*lda + rxa[newaxis, :];
-  TYPE a[TK, TM] = *pa;
-#else
-  TYPE* pa[TM, TK] = A + rka[newaxis, :] + rxa[:, newaxis]*lda;
-  TYPE a[TM, TK] = *pa;
-#endif
-
-  /* pointers for B */
-#if BT == 1
-  TYPE* pb[TN, TK] = B + rkb[newaxis, :] + ryb[:, newaxis]*ldb;
-  TYPE b[TN, TK] = *pb;
-#else
-  TYPE* pb[TK, TN] = B + rkb[:, newaxis]*ldb + ryb[newaxis, :];
-  TYPE b[TK, TN] = *pb;
-#endif
-
+  /* pointers for operands */
+  TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM;
+  TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN;
+  /* prefetches operands */
+  TYPE a[SHAPE_A] = *pa;
+  TYPE b[SHAPE_B] = *pb;
   /* reduction loop */
   for(int k = K; k > 0; k = k - TK){
     xc = USEA @ USEB + xc;
-#if AT == 1
-    pa = pa + TK*lda;
-#else
-    pa = pa + TK;
-#endif
-#if BT == 1
-    pb = pb + TK;
-#else
-    pb = pb + TK*ldb;
-#endif
+    pa = pa + TK * STRIDE_AK;
+    pb = pb + TK * STRIDE_BK;
     a = *pa;
     b = *pb;
   }

   /* epilogue */
   int rxc[TM] = ridx * TM + (0 ... TM);
   int ryc[TN] = ridy * TN + (0 ... TN);
@@ -75,7 +73,7 @@ void dot(TYPE * A __noalias __readonly __aligned(16),
   bool checkc0[TM] = rxc < M;
   bool checkc1[TN] = ryc < N;
   bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-  *?(checkc) pc = c;
+  *pc = c;
 }
 """

@@ -112,10 +110,12 @@ class dot_op:
                               AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
                               TM = [128], TN = [ 128], TK = [32])

-dot_nt = dot_op(False, True)
-dot_nn = dot_op(False, False)
-dot_tn = dot_op(True, False)
-dot_tt = dot_op(True, True)
+def dot(a, b, trans_a = False, trans_b = False):
+  if (trans_a, trans_b) not in dot.ops:
+    dot.ops[trans_a, trans_b] = dot_op(trans_a, trans_b)
+  return dot.ops[trans_a, trans_b](a, b)
+dot.ops = dict()

 # @triton.register_gradient(dot_op)
 # def _dot_grad(op, dy):
@@ -127,9 +127,7 @@ def run_dot():
   M, N, K = 128, 128, 128
   a = tf.placeholder(tf.float16, shape=[M, K])
   b = tf.placeholder(tf.float16, shape=[N, K])
-  # c = tf.matmul(a, b, transpose_a=True)
-  c = dot_nt(a, b)
-  # grads = tf.gradients(c, [a])
+  c = dot(a, b, trans_a = False, trans_b = True)
   # Reference
   ha = np.random.rand(M, K).astype(np.float16)
   hb = np.random.rand(K, N).astype(np.float16)
@@ -142,8 +140,6 @@ def run_dot():
   hresult = np.dot(ha, hb.T)
   dif = np.abs(result - hresult)
   np.savetxt('dif.dat', dif, '%2.4f')
-  print(hresult)
-  print(result)
   print("dif: %f" % np.max(dif))

 run_dot()
@@ -105,7 +105,8 @@ def _build(src, path, framework):
   if framework == tensorflow_id:
     _import_tensorflow()
     library_dirs += [tensorflow.sysconfig.get_lib()]
-    include_dirs += [tensorflow.sysconfig.get_lib()]
+    include_dirs += [tensorflow.sysconfig.get_include()]
+    include_dirs += ['/usr/local/cuda/include/']
     libraries += ['tensorflow_framework']
   elif framework == torch_id:
     _import_torch()
@@ -215,7 +216,7 @@ class op:
     self.fw_grids = dict()
     self.src = src
     self.outputs = outputs
-    self.framework = _find_framework(None)
+    self.framework = _find_framework(framework)

   def __call__(self, *args, **kwargs):
     # create a new op when defines are different