[PYTHON][EXAMPLES] Tentative support for einsum with transpositions
This commit is contained in:
@@ -430,7 +430,6 @@ public:
|
|||||||
void AddrOpTypeChecking();
|
void AddrOpTypeChecking();
|
||||||
void DerefOpTypeChecking();
|
void DerefOpTypeChecking();
|
||||||
void ReduceOpTypeChecking();
|
void ReduceOpTypeChecking();
|
||||||
void TransOpTypeChecking();
|
|
||||||
void UnaryArithmOpTypeChecking();
|
void UnaryArithmOpTypeChecking();
|
||||||
void CastOpTypeChecking();
|
void CastOpTypeChecking();
|
||||||
|
|
||||||
@@ -448,6 +447,28 @@ protected:
|
|||||||
Expr* operand_;
|
Expr* operand_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class TransOp: public Expr {
|
||||||
|
friend class Generator;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using PermInt = std::vector<int>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
static TransOp* New(const PermInt& perm, Expr* operand);
|
||||||
|
const PermInt& getPerm() const { return perm_; }
|
||||||
|
void Accept(Visitor* v);
|
||||||
|
bool IsLVal() { return false; }
|
||||||
|
void TypeChecking();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
TransOp(const PermInt& perm, Expr* operand)
|
||||||
|
: Expr(operand->Tok(), nullptr), operand_(operand), perm_(perm) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Expr* operand_;
|
||||||
|
PermInt perm_;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
// cond ? true : false
|
// cond ? true : false
|
||||||
class ConditionalOp : public Expr {
|
class ConditionalOp : public Expr {
|
||||||
|
@@ -58,6 +58,7 @@ public:
|
|||||||
// Expression
|
// Expression
|
||||||
void VisitBinaryOp(BinaryOp* binaryOp);
|
void VisitBinaryOp(BinaryOp* binaryOp);
|
||||||
void VisitUnaryOp(UnaryOp* unaryOp);
|
void VisitUnaryOp(UnaryOp* unaryOp);
|
||||||
|
void VisitTransOp(TransOp* transOp);
|
||||||
void VisitConditionalOp(ConditionalOp* condOp);
|
void VisitConditionalOp(ConditionalOp* condOp);
|
||||||
void VisitFuncCall(FuncCall* funcCall);
|
void VisitFuncCall(FuncCall* funcCall);
|
||||||
void VisitObject(Object* obj);
|
void VisitObject(Object* obj);
|
||||||
@@ -130,6 +131,7 @@ public:
|
|||||||
|
|
||||||
void VisitConditionalOp(ConditionalOp*) { should_not_happen(); }
|
void VisitConditionalOp(ConditionalOp*) { should_not_happen(); }
|
||||||
void VisitFuncCall(FuncCall*) { should_not_happen(); }
|
void VisitFuncCall(FuncCall*) { should_not_happen(); }
|
||||||
|
void VisitTransOp(TransOp*) { should_not_happen(); }
|
||||||
void VisitEnumerator(Enumerator*) { should_not_happen(); }
|
void VisitEnumerator(Enumerator*) { should_not_happen(); }
|
||||||
void VisitConstant(Constant*) { should_not_happen(); }
|
void VisitConstant(Constant*) { should_not_happen(); }
|
||||||
void VisitTempVar(TempVar*) { should_not_happen(); }
|
void VisitTempVar(TempVar*) { should_not_happen(); }
|
||||||
|
@@ -30,6 +30,9 @@ public:
|
|||||||
virtual void VisitIdentifier(Identifier* ident) {
|
virtual void VisitIdentifier(Identifier* ident) {
|
||||||
Error(ident, "expect constant expression");
|
Error(ident, "expect constant expression");
|
||||||
}
|
}
|
||||||
|
// A transposition node is rejected here: it is reported with the same
// "expect constant expression" diagnostic as the neighbouring handlers.
virtual void VisitTransOp(TransOp* trans) {
  Error(trans, "expect constant expression");
}
|
||||||
virtual void VisitObject(Object* obj) {
|
virtual void VisitObject(Object* obj) {
|
||||||
Error(obj, "expect constant expression");
|
Error(obj, "expect constant expression");
|
||||||
}
|
}
|
||||||
@@ -83,6 +86,9 @@ public:
|
|||||||
virtual void VisitFuncCall(FuncCall* funcCall) {
|
virtual void VisitFuncCall(FuncCall* funcCall) {
|
||||||
Error(funcCall, "expect constant expression");
|
Error(funcCall, "expect constant expression");
|
||||||
}
|
}
|
||||||
|
// Transposition nodes are not accepted by this visitor; report the
// "expect constant expression" diagnostic on the offending node.
virtual void VisitTransOp(TransOp* trans) {
  Error(trans, "expect constant expression");
}
|
||||||
virtual void VisitEnumerator(Enumerator* enumer) {
|
virtual void VisitEnumerator(Enumerator* enumer) {
|
||||||
addr_.offset_ = enumer->Val();
|
addr_.offset_ = enumer->Val();
|
||||||
}
|
}
|
||||||
|
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
class BinaryOp;
|
class BinaryOp;
|
||||||
class UnaryOp;
|
class UnaryOp;
|
||||||
|
class TransOp;
|
||||||
class ConditionalOp;
|
class ConditionalOp;
|
||||||
class FuncCall;
|
class FuncCall;
|
||||||
class Identifier;
|
class Identifier;
|
||||||
@@ -31,6 +32,7 @@ public:
|
|||||||
virtual ~Visitor() {}
|
virtual ~Visitor() {}
|
||||||
virtual void VisitBinaryOp(BinaryOp* binary) = 0;
|
virtual void VisitBinaryOp(BinaryOp* binary) = 0;
|
||||||
virtual void VisitUnaryOp(UnaryOp* unary) = 0;
|
virtual void VisitUnaryOp(UnaryOp* unary) = 0;
|
||||||
|
virtual void VisitTransOp(TransOp* trans) = 0;
|
||||||
virtual void VisitConditionalOp(ConditionalOp* cond) = 0;
|
virtual void VisitConditionalOp(ConditionalOp* cond) = 0;
|
||||||
virtual void VisitFuncCall(FuncCall* funcCall) = 0;
|
virtual void VisitFuncCall(FuncCall* funcCall) = 0;
|
||||||
virtual void VisitEnumerator(Enumerator* enumer) = 0;
|
virtual void VisitEnumerator(Enumerator* enumer) = 0;
|
||||||
|
@@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
|
|
||||||
static MemPoolImp<BinaryOp> binaryOpPool;
|
static MemPoolImp<BinaryOp> binaryOpPool;
|
||||||
|
static MemPoolImp<TransOp> transOpPool;
|
||||||
static MemPoolImp<ConditionalOp> conditionalOpPool;
|
static MemPoolImp<ConditionalOp> conditionalOpPool;
|
||||||
static MemPoolImp<FuncCall> funcCallPool;
|
static MemPoolImp<FuncCall> funcCallPool;
|
||||||
static MemPoolImp<Declaration> initializationPool;
|
static MemPoolImp<Declaration> initializationPool;
|
||||||
@@ -78,6 +79,9 @@ void UnaryOp::Accept(Visitor* v) {
|
|||||||
v->VisitUnaryOp(this);
|
v->VisitUnaryOp(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visitor dispatch: hands this node to the visitor's TransOp handler.
void TransOp::Accept(Visitor* v) {
  v->VisitTransOp(this);
}
|
||||||
|
|
||||||
void ConditionalOp::Accept(Visitor* v) {
|
void ConditionalOp::Accept(Visitor* v) {
|
||||||
v->VisitConditionalOp(this);
|
v->VisitConditionalOp(this);
|
||||||
@@ -645,9 +649,6 @@ void UnaryOp::TypeChecking() {
|
|||||||
case Token::CAST:
|
case Token::CAST:
|
||||||
return CastOpTypeChecking();
|
return CastOpTypeChecking();
|
||||||
|
|
||||||
case '^':
|
|
||||||
return TransOpTypeChecking();
|
|
||||||
|
|
||||||
case Token::REDUCE:
|
case Token::REDUCE:
|
||||||
return ReduceOpTypeChecking();
|
return ReduceOpTypeChecking();
|
||||||
|
|
||||||
@@ -702,15 +703,6 @@ void UnaryOp::ReduceOpTypeChecking() {
|
|||||||
type_ = TileType::New(shape, tileType->Derived());
|
type_ = TileType::New(shape, tileType->Derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
void UnaryOp::TransOpTypeChecking() {
|
|
||||||
auto tileType = operand_->Type()->ToTile();
|
|
||||||
if(!tileType)
|
|
||||||
Error(this, "tile expected for transposition operator '^'");
|
|
||||||
auto shape = tileType->Shape();
|
|
||||||
std::rotate(shape.begin(), shape.begin() + 1, shape.end());
|
|
||||||
type_ = TileType::New(shape, tileType->Derived());
|
|
||||||
}
|
|
||||||
|
|
||||||
void UnaryOp::UnaryArithmOpTypeChecking() {
|
void UnaryOp::UnaryArithmOpTypeChecking() {
|
||||||
auto scalType = TryExtractScalarType(this, operand_);
|
auto scalType = TryExtractScalarType(this, operand_);
|
||||||
if (Token::PLUS == op_ || Token::MINUS == op_) {
|
if (Token::PLUS == op_ || Token::MINUS == op_) {
|
||||||
@@ -769,6 +761,29 @@ void UnaryOp::CastOpTypeChecking() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Transposition Operator
|
||||||
|
*/
|
||||||
|
void TransOp::TypeChecking() {
|
||||||
|
auto tileType = operand_->Type()->ToTile();
|
||||||
|
if(!tileType)
|
||||||
|
Error(this, "tile expected for transposition operator '^'");
|
||||||
|
auto opShape = tileType->Shape();
|
||||||
|
if(perm_.size() != opShape.size())
|
||||||
|
Error(this, "invalid permutations");
|
||||||
|
// permutate input shape
|
||||||
|
TileType::ShapeInt resShape(opShape.size());
|
||||||
|
for(int d = 0; d < opShape.size(); d++)
|
||||||
|
resShape[d] = opShape[perm_[d]];
|
||||||
|
type_ = TileType::New(resShape, tileType->Derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Factory for transposition nodes.
//
// Constructs the node in storage obtained from transOpPool, records that
// pool on the node, runs type checking, and returns the node.
TransOp* TransOp::New(const PermInt& perm, Expr* operand) {
  TransOp* node = new (transOpPool.Alloc()) TransOp(perm, operand);
  node->pool_ = &transOpPool;
  // Type-check before handing the node to the caller.
  node->TypeChecking();
  return node;
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Conditional Operator
|
* Conditional Operator
|
||||||
|
@@ -185,7 +185,6 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
|||||||
case '~': return set_ret(bld_->create_neg(arg));
|
case '~': return set_ret(bld_->create_neg(arg));
|
||||||
case '!': return set_ret(bld_->create_not(arg));
|
case '!': return set_ret(bld_->create_not(arg));
|
||||||
case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
case Token::CAST: return set_ret(GenCastOp(arg, GenIRType(unary->Type(), *ctx_)));
|
||||||
case '^': return set_ret(bld_->create_trans(arg));
|
|
||||||
case Token::REDUCE: {
|
case Token::REDUCE: {
|
||||||
int ax, tag;
|
int ax, tag;
|
||||||
UnaryOp::decodeRed(unary->info_, ax, tag);
|
UnaryOp::decodeRed(unary->info_, ax, tag);
|
||||||
@@ -198,6 +197,12 @@ void Generator::VisitUnaryOp(UnaryOp* unary) {
|
|||||||
return error_not_implemented();
|
return error_not_implemented();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Emits IR for a transposition: visits the operand, picks up the value the
// visit left in ret_, and stores the builder's transposed value (using the
// node's permutation) as the new result.
void Generator::VisitTransOp(TransOp *trans) {
  Visit(trans->operand_);
  // Capture the operand's value before set_ret() replaces the result.
  ir::value* operandValue = ret_;
  set_ret(bld_->create_trans(operandValue, trans->getPerm()));
}
|
||||||
|
|
||||||
void Generator::VisitConditionalOp(ConditionalOp* condOp) {
|
void Generator::VisitConditionalOp(ConditionalOp* condOp) {
|
||||||
// auto &instructions = bld_->get_insert_block()->get_inst_list();
|
// auto &instructions = bld_->get_insert_block()->get_inst_list();
|
||||||
VisitExpr(condOp->cond_);
|
VisitExpr(condOp->cond_);
|
||||||
|
@@ -451,6 +451,7 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
|||||||
QualType lhsQual = lhsTile->Derived();
|
QualType lhsQual = lhsTile->Derived();
|
||||||
// create ret shape
|
// create ret shape
|
||||||
TileType::ShapeInt shape;
|
TileType::ShapeInt shape;
|
||||||
|
TileType::ShapeInt axVec;
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
const Token* tok;
|
const Token* tok;
|
||||||
std::vector<std::pair<int, int>> redInfo;
|
std::vector<std::pair<int, int>> redInfo;
|
||||||
@@ -475,14 +476,37 @@ Expr* Parser::ParseSubScripting(Expr* lhs) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case '^':{
|
||||||
|
Expr* expr = ParseConditionalExpr();
|
||||||
|
EnsureInteger(expr);
|
||||||
|
int ax = Evaluator<long>().Eval(expr);
|
||||||
|
axVec.push_back(ax);
|
||||||
|
if(ax < 0 || ax >= lhsShape.size())
|
||||||
|
Error(tok, "unknown axis %d in transposition", ax);
|
||||||
|
shape.push_back(lhsShape[ax]);
|
||||||
|
i++;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
Error(tok, "Unexpected subscript symbol encountered at dimension %d", i);
|
Error(tok, "Unexpected subscript symbol encountered at dimension %d", i);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}while(ts_.Try(','));
|
}while(ts_.Try(','));
|
||||||
ts_.Expect(']');
|
ts_.Expect(']');
|
||||||
|
|
||||||
|
// transposition mode
|
||||||
|
std::set<int> axSet(axVec.begin(), axVec.end());
|
||||||
|
if(!axSet.empty()){
|
||||||
|
if(axSet.size()!=lhsShape.size())
|
||||||
|
Error(tok, "transposition must address all axes of input array");
|
||||||
|
return TransOp::New(axVec, lhs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// broadcasting mode
|
||||||
if(lhsShape.size() > i)
|
if(lhsShape.size() > i)
|
||||||
Error(tok, "broadcasting not using all operand axes");
|
Error(tok, "broadcasting not using all operand axes");
|
||||||
|
|
||||||
// create ret tile
|
// create ret tile
|
||||||
Expr* res = lhs;
|
Expr* res = lhs;
|
||||||
for(auto r: redInfo){
|
for(auto r: redInfo){
|
||||||
@@ -553,7 +577,15 @@ Expr* Parser::ParseUnaryExpr() {
|
|||||||
case '-': return ParseUnaryOp(tok, Token::MINUS);
|
case '-': return ParseUnaryOp(tok, Token::MINUS);
|
||||||
case '~': return ParseUnaryOp(tok, '~');
|
case '~': return ParseUnaryOp(tok, '~');
|
||||||
case '!': return ParseUnaryOp(tok, '!');
|
case '!': return ParseUnaryOp(tok, '!');
|
||||||
case '^': return ParseUnaryOp(tok, Token::XOR);
|
case '^': {
|
||||||
|
auto operand = ParseCastExpr();
|
||||||
|
TileType::ShapeInt shape = operand->Type()->ToTile()->Shape();
|
||||||
|
TransOp::PermInt perm(shape.size());
|
||||||
|
for(int d = 0; d < shape.size(); d++)
|
||||||
|
perm[d] = d;
|
||||||
|
std::rotate(perm.begin(), perm.begin() + 1, perm.end());
|
||||||
|
return TransOp::New(perm, operand);
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
ts_.PutBack();
|
ts_.PutBack();
|
||||||
return ParsePostfixExpr();
|
return ParsePostfixExpr();
|
||||||
|
@@ -8,41 +8,36 @@ class _einsum(triton.function):
|
|||||||
int std_A0, int std_B0, int std_C0,
|
int std_A0, int std_B0, int std_C0,
|
||||||
int std_A1, int std_B1, int std_C1) {
|
int std_A1, int std_B1, int std_C1) {
|
||||||
// program id
|
// program id
|
||||||
int pid0 = get_program_id(0);
|
int pgm = get_program_id(0);
|
||||||
int pid1 = get_program_id(1);
|
int pgn = get_program_id(1);
|
||||||
int pid2 = get_program_id(2);
|
int pgb = get_program_id(2);
|
||||||
// range
|
// range
|
||||||
int rma[TM] = pid0 * TM + 0 ... TM;
|
int rm[TM] = pgm * TM + 0 ... TM;
|
||||||
int rnb[TN] = pid1 * TN + 0 ... TN;
|
int rn[TN] = pgn * TN + 0 ... TN;
|
||||||
int rka[TK] = 0 ... TK;
|
int rb[TB] = pgb * TB + 0 ... TB;
|
||||||
int rkb[TK] = 0 ... TK;
|
int rk[TK] = 0 ... TK;
|
||||||
int rba[TB] = pid2 * TB + 0 ... TB;
|
|
||||||
int rbb[TB] = pid2 * TB + 0 ... TB;
|
|
||||||
// accumulator
|
// accumulator
|
||||||
TYPE c[TM, TN, TB] = 0;
|
TYPE c[TM, TN, TB] = 0;
|
||||||
// pointers to a
|
// pointers to a
|
||||||
TYPE *pa[TM, TK, TB] = A + rka[newaxis, :, newaxis] * 1
|
TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK
|
||||||
+ rma[:, newaxis, newaxis] * std_A1
|
+ rm[BROADCAST_AM] * STRIDE_AM
|
||||||
+ rba[newaxis, newaxis, :] * std_A0;
|
+ rb[newaxis, newaxis, :] * std_A0;
|
||||||
// pointers to b
|
// pointers to b
|
||||||
TYPE *pb[TK, TN, TB] = B + rkb[:, newaxis, newaxis] * 1
|
TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK
|
||||||
+ rnb[newaxis, :, newaxis] * std_B1
|
+ rn[BROADCAST_BN] * STRIDE_BN
|
||||||
+ rbb[newaxis, newaxis, :] * std_B0;
|
+ rb[newaxis, newaxis, :] * std_B0;
|
||||||
// accumulation
|
// accumulation
|
||||||
for(int k = dim_K; k > 0; k -= TK) {
|
for(int k = dim_K; k > 0; k -= TK) {
|
||||||
TYPE a[TM, TK, TB] = *pa;
|
TYPE a[SHAPE_A] = *pa;
|
||||||
TYPE b[TK, TN, TB] = *pb;
|
TYPE b[SHAPE_B] = *pb;
|
||||||
c += a @ b;
|
c += a @ b;
|
||||||
pa += TK;
|
pa += TK;
|
||||||
pb += TK;
|
pb += TK;
|
||||||
}
|
}
|
||||||
// write-back
|
// write-back
|
||||||
int rmc[TM] = pid0 * TM + 0 ... TM;
|
TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1
|
||||||
int rnc[TN] = pid1 * TN + 0 ... TN;
|
+ rn[newaxis, :, newaxis] * 1
|
||||||
int rbc[TB] = pid2 * TB + 0 ... TB;
|
+ rb[newaxis, newaxis, :] * std_C0;
|
||||||
TYPE *pc[TM, TN, TB] = C + rmc[:, newaxis, newaxis] * std_C1
|
|
||||||
+ rnc[newaxis, :, newaxis] * 1
|
|
||||||
+ rbc[newaxis, newaxis, :] * std_C0;
|
|
||||||
*pc = c;
|
*pc = c;
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
@@ -138,12 +133,25 @@ class _einsum(triton.function):
|
|||||||
grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
|
grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')),
|
||||||
triton.cdiv(bmnk[2], opt.d('TN')),
|
triton.cdiv(bmnk[2], opt.d('TN')),
|
||||||
triton.cdiv(bmnk[0], opt.d('TB'))]
|
triton.cdiv(bmnk[0], opt.d('TB'))]
|
||||||
#print(std0, std1)
|
macros = {# handle A transposition
|
||||||
|
'USE_A' : 'a[^1, ^0, ^2]' if trans_a else 'a',
|
||||||
|
'STRIDE_AK' : 'std_A1' if trans_a else '1',
|
||||||
|
'STRIDE_AM' : '1' if trans_a else 'std_A1',
|
||||||
|
'BROADCAST_AK': ':, newaxis, newaxis' if trans_a else 'newaxis, :, newaxis',
|
||||||
|
'BROADCAST_AM': 'newaxis, :, newaxis' if trans_a else ':, newaxis, newaxis',
|
||||||
|
'SHAPE_A' : 'TK, TM, TB' if trans_a else 'TM, TK, TB',
|
||||||
|
# handle B transposition
|
||||||
|
'USE_B' : 'b[^1, ^0, ^2]' if not trans_b else 'b',
|
||||||
|
'STRIDE_BK' : 'std_B1' if not trans_b else '1',
|
||||||
|
'STRIDE_BN' : '1' if not trans_b else 'std_B1',
|
||||||
|
'BROADCAST_BK': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
|
||||||
|
'BROADCAST_BN': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
|
||||||
|
'SHAPE_B' : 'TN, TK, TB' if not trans_b else 'TK, TN, TB'}
|
||||||
return _einsum.kernel(a, b, c,
|
return _einsum.kernel(a, b, c,
|
||||||
bmnk[1], bmnk[2], bmnk[3],
|
bmnk[1], bmnk[2], bmnk[3],
|
||||||
std0[0], std0[1], std0[2],
|
std0[0], std0[1], std0[2],
|
||||||
std1[0], std1[1], std1[2],
|
std1[0], std1[1], std1[2],
|
||||||
grid,
|
grid, **macros,
|
||||||
TYPE='float', TM=32, TN=32, TK=8, TB=1)
|
TYPE='float', TM=32, TN=32, TK=8, TB=1)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -86,14 +86,14 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
|||||||
// macros
|
// macros
|
||||||
rt::function::options_space_t opt;
|
rt::function::options_space_t opt;
|
||||||
// A access patterns
|
// A access patterns
|
||||||
opt.defines.push_back({"USEA", {AT? "^a" : "a" }});
|
opt.defines.push_back({"USEA", {AT? "a[^1, ^0]" : "a" }});
|
||||||
opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }});
|
opt.defines.push_back({"BROADCAST_AK", {AT? ":, newaxis" : "newaxis, :" }});
|
||||||
opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }});
|
opt.defines.push_back({"BROADCAST_AM", {AT? "newaxis, :" : ":, newaxis" }});
|
||||||
opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }});
|
opt.defines.push_back({"SHAPE_A", {AT? "TK, TM" : "TM, TK" }});
|
||||||
opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
|
opt.defines.push_back({"STRIDE_AK", {AT? sa[a_order[0]] : sa[a_order[1]] }});
|
||||||
opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
|
opt.defines.push_back({"STRIDE_AM", {AT? sa[a_order[1]] : sa[a_order[0]] }});
|
||||||
// B access patterns
|
// B access patterns
|
||||||
opt.defines.push_back({"USEB", {BT? "^b" : "b" }});
|
opt.defines.push_back({"USEB", {BT? "b[^1, ^0]" : "b" }});
|
||||||
opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }});
|
opt.defines.push_back({"BROADCAST_BK", {BT? "newaxis, :" : ":, newaxis" }});
|
||||||
opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }});
|
opt.defines.push_back({"BROADCAST_BN", {BT? ":, newaxis" : "newaxis, :" }});
|
||||||
opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }});
|
opt.defines.push_back({"SHAPE_B", {BT? "TN, TK" : "TK, TN" }});
|
||||||
|
Reference in New Issue
Block a user