[lang] added support for batched matrix multiplication
This commit is contained in:
@@ -96,7 +96,6 @@ void allocation::run(ir::module &mod) {
|
|||||||
offsets_[x] = starts[x] + colors[x] * Adj;
|
offsets_[x] = starts[x] + colors[x] * Adj;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Save maximum size of induced memory space
|
// Save maximum size of induced memory space
|
||||||
allocated_size_ = 0;
|
allocated_size_ = 0;
|
||||||
for(layout_t* x: V)
|
for(layout_t* x: V)
|
||||||
|
@@ -77,11 +77,6 @@ void axes::update_graph_dot(ir::instruction *i) {
|
|||||||
// add edges between result and accumulator
|
// add edges between result and accumulator
|
||||||
for(unsigned d = 0; d < shapes.size(); d++)
|
for(unsigned d = 0; d < shapes.size(); d++)
|
||||||
graph_.add_edge({dot, d}, {D, d});
|
graph_.add_edge({dot, d}, {D, d});
|
||||||
// add edge for batch dimension
|
|
||||||
for(unsigned d = 2; d < shapes.size(); d++){
|
|
||||||
graph_.add_edge({dot, d}, {A, d});
|
|
||||||
graph_.add_edge({dot, d}, {B, d});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void axes::update_graph_elementwise(ir::instruction *i) {
|
void axes::update_graph_elementwise(ir::instruction *i) {
|
||||||
|
@@ -300,6 +300,10 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
|
|||||||
}
|
}
|
||||||
std::vector<int> col = {0, 1};
|
std::vector<int> col = {0, 1};
|
||||||
std::vector<int> row = {1, 0};
|
std::vector<int> row = {1, 0};
|
||||||
|
for(size_t s = 2; s < shapes.size(); s++){
|
||||||
|
col.push_back(s);
|
||||||
|
row.push_back(s);
|
||||||
|
}
|
||||||
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
|
bool is_nonhmma_dot_a = dot_a && !hmma_dot_a;
|
||||||
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
|
bool is_nonhmma_dot_b = dot_b && !hmma_dot_b;
|
||||||
if(is_nonhmma_dot_a)
|
if(is_nonhmma_dot_a)
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
#include <climits>
|
#include <climits>
|
||||||
|
#include <iostream>
|
||||||
#include "triton/codegen/analysis/liveness.h"
|
#include "triton/codegen/analysis/liveness.h"
|
||||||
#include "triton/codegen/analysis/layout.h"
|
#include "triton/codegen/analysis/layout.h"
|
||||||
#include "triton/ir/function.h"
|
#include "triton/ir/function.h"
|
||||||
@@ -37,12 +38,13 @@ void liveness::run(ir::module &mod) {
|
|||||||
}
|
}
|
||||||
// compute intervals
|
// compute intervals
|
||||||
unsigned start = INT32_MAX;
|
unsigned start = INT32_MAX;
|
||||||
|
for(ir::value *v: layout->values)
|
||||||
|
if(indices.find(v) != indices.end())
|
||||||
|
start = std::min(start, indices.at(v));
|
||||||
unsigned end = 0;
|
unsigned end = 0;
|
||||||
for(ir::user *u: users)
|
for(ir::user *u: users)
|
||||||
if(indices.find(u) != indices.end()){
|
if(indices.find(u) != indices.end())
|
||||||
start = std::min(start, indices.at(u));
|
|
||||||
end = std::max(end, indices.at(u));
|
end = std::max(end, indices.at(u));
|
||||||
}
|
|
||||||
intervals_[layout] = segment{start, end};
|
intervals_[layout] = segment{start, end};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <iostream>
|
||||||
#include "triton/ir/basic_block.h"
|
#include "triton/ir/basic_block.h"
|
||||||
#include "triton/ir/module.h"
|
#include "triton/ir/module.h"
|
||||||
#include "triton/ir/type.h"
|
#include "triton/ir/type.h"
|
||||||
|
@@ -471,12 +471,20 @@ void BinaryOp::MatmulOpTypeChecking() {
|
|||||||
auto rhsShape = rhsType->Shape();
|
auto rhsShape = rhsType->Shape();
|
||||||
size_t lhsRank = lhsShape.size();
|
size_t lhsRank = lhsShape.size();
|
||||||
size_t rhsRank = rhsShape.size();
|
size_t rhsRank = rhsShape.size();
|
||||||
if(lhsRank != 2 || rhsRank != 2)
|
if(lhsRank != rhsRank)
|
||||||
Error(this, "matrix multiplication operands must have rank 2");
|
Error(this, "matrix multiplication operands have incompatible rank"
|
||||||
|
"%d and %d", lhsRank, rhsRank);
|
||||||
|
for(int d = 2; d < lhsRank; d++)
|
||||||
|
if(lhsShape[d] != rhsShape[d])
|
||||||
|
Error(this, "matrix multiplication operands have incompatible batch dimension"
|
||||||
|
"%d and %d for axis %d", lhsShape[d], rhsShape[d], d);
|
||||||
if(lhsShape[1] != rhsShape[0])
|
if(lhsShape[1] != rhsShape[0])
|
||||||
Error(this, "matrix multiplication operands have incompatible inner dimension"
|
Error(this, "matrix multiplication operands have incompatible inner dimension"
|
||||||
" %d and %d", lhsShape[1], rhsShape[0]);
|
" %d and %d", lhsShape[1], rhsShape[0]);
|
||||||
|
// ret shape
|
||||||
TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]};
|
TileType::ShapeInt retShape = {lhsShape[0], rhsShape[1]};
|
||||||
|
for(int d = 2; d < lhsRank; d++)
|
||||||
|
retShape.push_back(lhsShape[d]);
|
||||||
QualType retType = lhsType->Derived();
|
QualType retType = lhsType->Derived();
|
||||||
if(retType != rhsType->Derived())
|
if(retType != rhsType->Derived())
|
||||||
Error(this, "matrix multiplication operands have incompatible data types");
|
Error(this, "matrix multiplication operands have incompatible data types");
|
||||||
|
Reference in New Issue
Block a user