@@ -168,7 +168,7 @@ Scheduling languages are, without a doubt, one of the most popular approaches fo
 Limitations
 ++++++++++++

-This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indice without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.
+This ease-of-development comes at a cost. First of all, existing systems that follow this paradigm tend to be noticeably slower than Triton on modern hardware when applicable (e.g., V100/A100 tensor cores w/ equal tile sizes). I do believe that this is not a fundamental issue of scheduling languages -- in the sense that it could probably be solved with more efforts -- but it could mean that these systems are harder to engineer. More importantly, existing scheduling languages generate loops whose bounds and increments cannot depend on surrounding loop indices without at least imposing severe constraints on possible schedules -- if not breaking the system entirely. This is problematic for sparse computations, whose iteration spaces may be irregular.

 .. table::
 :widths: 50 50
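The sparse case mentioned in the paragraph above can be made concrete with a small sketch. The following plain-Python CSR matrix-vector product (a hypothetical helper, not taken from the repository) shows a loop nest whose inner bounds depend on the surrounding loop index, which is exactly the kind of irregular iteration space that fixed-bound schedules struggle to express.

def csr_spmv(indptr, indices, data, x):
    # One output entry per row; the inner loop's bounds (indptr[row], indptr[row + 1])
    # depend on the outer index `row`, so the iteration space is data-dependent.
    y = [0.0] * (len(indptr) - 1)
    for row in range(len(indptr) - 1):
        for k in range(indptr[row], indptr[row + 1]):
            y[row] += data[k] * x[indices[k]]
    return y
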
@@ -44,7 +44,7 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
 if (pessimistic) {
 return markAllPessimisticFixpoint(op->getResults());
 }
-// Join all latice elements
+// Join all lattice elements
 ChangeResult result = ChangeResult::NoChange;
 for (Value value : op->getResults()) {
 result |= getLatticeElement(value).join(aliasInfo);
@@ -235,7 +235,7 @@ private:
 }
 }

-/// Extends the liveness range by unioning the liveness range of the aliased
+/// Extends the liveness range by unionizing the liveness range of the aliased
 /// values because each allocated buffer could be an alias of others, if block
 /// arguments are involved.
 void resolveAliasBufferLiveness(
@@ -372,7 +372,7 @@ unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {

 unsigned
 DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
-assert(0 && "DotOPerandEncodingAttr::getElemsPerThread not implemented");
+assert(0 && "DotOperandEncodingAttr::getElemsPerThread not implemented");
 return 0;
 }

@@ -24,7 +24,7 @@ struct CanonicalizePass
 // The following piece of code is a workaround to
 // very crudely remove dead code, by making an iteration
 // argument yield itself if it is not used to create
-// side-effects anywhere.
+// side effects anywhere.
 getOperation()->walk([&](scf::ForOp forOp) -> void {
 for (size_t i = 0; i < forOp.getNumResults(); ++i) {
 // condition 1: no other iter arguments depend on it
@@ -29,7 +29,7 @@ namespace {
 // convert(blocked, dot_operand) ->
 // convert(blocked, mma) + convert(mma, dot_operand)
 // if this value is itself the result of a dot operation
-// this is a heuristic to accomodate some pattern seen in fused attention
+// this is a heuristic to accommodate some pattern seen in fused attention
 // kernels.
 // TODO: replace this by something more generic, i.e. layout-aware CSE
 class DecomposeDotOperand : public mlir::RewritePattern {
@@ -81,7 +81,7 @@ public:
 auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
 auto dstType = convert.getType().cast<RankedTensorType>();
 // we don't handle conversions to DotOperandEncodingAttr
-// this is a heuristics to accomodate fused attention
+// this is a heuristics to accommodate fused attention
 // if (dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
 // return mlir::failure();
 // convert to the same layout -- we can delete
@@ -265,7 +265,7 @@ public:
 isSharedEncoding(cvt->getOperand(0)))
 return mlir::failure();
 // we don't handle conversions to DotOperandEncodingAttr
-// this is a heuristics to accomodate fused attention
+// this is a heuristics to accommodate fused attention
 auto targetType = cvt->getResultTypes()[0].cast<RankedTensorType>();
 if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
 return mlir::failure();
@@ -285,7 +285,7 @@ public:
 // we stop everything
 if (expensive_to_remat(currOp))
 break;
-// a conversion will be removed here (i.e. transfered to operands)
+// a conversion will be removed here (i.e. transferred to operands)
 numCvts -= 1;
 // done processing
 processed.insert(currOp);
@@ -110,7 +110,7 @@ Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
 }

 void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
-// Loop-invarant value. skip
+// Loop-invariant value, skip
 if (v.getParentRegion() != &forOp.getLoopBody())
 return;

@@ -125,7 +125,7 @@ void LoopPipeliner::collectDeps(Value v, int stages, DenseSet<Value> &deps) {
 collectDeps(yieldOp->getOperand(arg.getArgNumber() - 1), stages - 1, deps);
 } else { // value
 // v might be in deps, but we still need to visit v.
-// This is because v might depends on value in previous iterations
+// This is because v might depend on value in previous iterations
 deps.insert(v);
 for (Value op : v.getDefiningOp()->getOperands())
 collectDeps(op, stages, deps);
@@ -175,18 +175,18 @@ LogicalResult LoopPipeliner::initialize() {
 // other load in the prologue, which is against the point of the pipeline
 // pass)
 for (triton::LoadOp loadOp : allLoads) {
-bool isCandiate = true;
+bool isCandidate = true;
 for (triton::LoadOp other : allLoads) {
 if (loadDeps[loadOp].contains(other)) {
-isCandiate = false;
+isCandidate = false;
 break;
 }
 }

 // We only pipeline loads that have one covert_layout (to dot_op) use
 // TODO: lift this constraint in the future
-if (isCandiate && loadOp.getResult().hasOneUse()) {
+if (isCandidate && loadOp.getResult().hasOneUse()) {
-isCandiate = false;
+isCandidate = false;
 Operation *use = *loadOp.getResult().getUsers().begin();
 if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
 if (auto tensorType = convertLayout.getResult()
@@ -194,7 +194,7 @@ LogicalResult LoopPipeliner::initialize() {
 .dyn_cast<RankedTensorType>()) {
 if (auto dotOpEnc = tensorType.getEncoding()
 .dyn_cast<ttg::DotOperandEncodingAttr>()) {
-isCandiate = true;
+isCandidate = true;
 loadsMapping[loadOp] = convertLayout;
 auto ty = loadOp.getType().cast<RankedTensorType>();
 SmallVector<int64_t> bufferShape(ty.getShape().begin(),
@@ -208,9 +208,9 @@ LogicalResult LoopPipeliner::initialize() {
 }
 }
 } else
-isCandiate = false;
+isCandidate = false;

-if (isCandiate)
+if (isCandidate)
 loads.insert(loadOp);
 }

@@ -317,7 +317,7 @@ void LoopPipeliner::emitPrologue() {
 for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults())) {
 Value originalResult = op->getResult(dstIdx);
 // copy_async will update the value of its only use
-// TODO: load should no be used in the preheader?
+// TODO: load should not be used in the preheader?
 if (loads.contains(originalResult)) {
 break;
 // originalResult = loadsMapping[originalResult];
@@ -35,7 +35,7 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
 });

 //
-// materailizations
+// Materializations
 //
 // This will be called when (newArgType != origArgType)
 // This will create newArg, and map(origArg, newArg)
@@ -172,7 +172,7 @@ void init_triton_ir(py::module &&m) {
 if (mlir::Operation *definingOp = self.getDefiningOp())
 definingOp->setAttr(name, attr);
 else {
-/* issue an warning */
+/* issue a warning */
 }
 })
 .def("replace_all_uses_with",
@@ -180,7 +180,7 @@ void init_triton_ir(py::module &&m) {
 self.replaceAllUsesWith(newValue);
 });

-py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");
+py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument");

 py::class_<mlir::Region>(m, "region")
 .def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
@@ -288,7 +288,7 @@ void init_triton_ir(py::module &&m) {
 py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp")
 .def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference)
 .def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
-py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
+py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp");

 // dynamic_attr is used to transfer ownership of the MLIR context to the
 // module
@@ -423,7 +423,7 @@ void init_triton_ir(py::module &&m) {
 .def("get_bool_attr", &mlir::OpBuilder::getBoolAttr)
 .def("get_int32_attr", &mlir::OpBuilder::getI32IntegerAttr)
 // Use arith.ConstantOp to create constants
-// // Constants
+// Constants
 .def("get_int1",
 [](mlir::OpBuilder &self, bool v) -> mlir::Value {
 auto loc = self.getUnknownLoc();
@@ -588,14 +588,14 @@ void init_triton_ir(py::module &&m) {
 auto loc = self.getUnknownLoc();
 return self.create<mlir::scf::WhileOp>(loc, retTypes, initArgs);
 })
-.def("create_condtion_op",
+.def("create_condition_op",
 [](mlir::OpBuilder &self, mlir::Value &cond,
 std::vector<mlir::Value> &args) -> mlir::scf::ConditionOp {
 auto loc = self.getUnknownLoc();
 return self.create<mlir::scf::ConditionOp>(loc, cond, args);
 })

-// miscellious
+// miscellaneous
 .def("create_make_range",
 [](mlir::OpBuilder &self, int start, int end) -> mlir::Value {
 auto loc = self.getUnknownLoc();
@@ -976,15 +976,15 @@ void init_triton_ir(py::module &&m) {
 auto loc = self.getUnknownLoc();
 return self.create<mlir::arith::OrIOp>(loc, lhs, rhs);
 })
-// // Input/Output
+// Input/Output
 .def("create_load",
 [](mlir::OpBuilder &self, mlir::Value &ptrs,
-mlir::triton::CacheModifier cacheModifer,
+mlir::triton::CacheModifier cacheModifier,
 mlir::triton::EvictionPolicy evictionPolicy,
 bool isVolatile) -> mlir::Value {
 auto loc = self.getUnknownLoc();
 return self.create<mlir::triton::LoadOp>(
-loc, ptrs, cacheModifer, evictionPolicy, isVolatile);
+loc, ptrs, cacheModifier, evictionPolicy, isVolatile);
 })
 .def("create_store",
 [](mlir::OpBuilder &self, mlir::Value &ptrs,
@@ -131,7 +131,7 @@ def vecadd_no_scf_tester(num_warps, block_size, shape):

 def vecadd_fcmp_no_scf_tester(num_warps, block_size, shape):
 '''
-vecadd tester with float comparation as load/store mask.
+vecadd tester with float comparison as load/store mask.
 '''
 @triton.jit
 def kernel(x_ptr,
@@ -523,8 +523,8 @@ class CodeGenerator(ast.NodeVisitor):
 [ty.to_ir(self.builder) for ty in ret_types])
 cond_block.merge_block_before(before_block)
 self.builder.set_insertion_point_to_end(before_block)
-# create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
+# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
-self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
+self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
 # merge the loop body
 after_block = self.builder.create_block_with_parent(while_op.get_after(),
 [ty.to_ir(self.builder) for ty in ret_types])
@@ -910,7 +910,7 @@ def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = Non
 :param mod: a TritonGPU dialect module
 :return:
 - PTX code
-- shared memory alloaction size
+- shared memory allocation size
 '''
 if compute_capability is None:
 device = torch.cuda.current_device()
@@ -1194,7 +1194,7 @@ class CacheManager:
 os.rename(filepath + ".tmp", filepath)


-# utilties for generating and compiling C wrappers
+# Utilities for generating and compiling C wrappers


 @functools.lru_cache()
@@ -768,7 +768,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=No
 """
 Returns the matrix product of two blocks.

-The two blocks must be two dimensionals and have compatible inner dimensions.
+The two blocks must be two-dimensional and have compatible inner dimensions.

 :param input: The first tensor to be multiplied.
 :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
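As a hedged illustration of the constraint stated in that docstring (both operands two-dimensional, inner dimensions compatible), the sketch below loads one row-major [BLOCK_M, BLOCK_K] tile and one [BLOCK_K, BLOCK_N] tile and multiplies them with tl.dot; the kernel name and the contiguous-tile pointer arithmetic are assumptions for the example, not code from the library.

import triton
import triton.language as tl

@triton.jit
def dot_tile_sketch(a_ptr, b_ptr, c_ptr,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # a: [BLOCK_M, BLOCK_K], b: [BLOCK_K, BLOCK_N], both assumed contiguous row-major
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    c = tl.dot(a, b)  # inner dimensions agree, result is a [BLOCK_M, BLOCK_N] block
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], c)
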
@@ -1172,7 +1172,7 @@ def ravel(x):
 @triton.jit
 def swizzle2d(i, j, size_i, size_j, size_g):
 """
-transformes indices of a row-major size_i*size_j matrix into those
+Transforms indices of a row-major size_i*size_j matrix into those
 of one where indices are row major for each group of size_j rows.
 For example, for size_i = size_j = 4 and size_g = 2, it will transform
 [[0 , 1 , 2 , 3 ],
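A plain-Python reading of that docstring (an assumption for illustration, not the library's implementation) is sketched below: linear indices are re-ordered so that they run row-major within each group of size_g rows. For size_i = size_j = 4 and size_g = 2 it produces the re-indexing shown in the final comment.

def swizzle2d_ref(i, j, size_i, size_j, size_g):
    ij = i * size_j + j                  # linear row-major index
    group = ij // (size_g * size_j)      # which group of size_g rows
    off_i = group * size_g
    g = min(size_i - off_i, size_g)      # the last group may be smaller
    new_i = off_i + ij % g               # consecutive indices cycle through the g rows
    new_j = (ij % (size_g * size_j)) // g  # and advance one column every g elements
    return new_i, new_j

out = [[0] * 4 for _ in range(4)]
for i in range(4):
    for j in range(4):
        ni, nj = swizzle2d_ref(i, j, 4, 4, 2)
        out[ni][nj] = i * 4 + j
# out == [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]
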
@@ -7,12 +7,12 @@ from triton._C.libtriton.triton import ir


 # Create custom exception that prints message "hello"
-class IncompatibleTypeErrorimpl(Exception):
+class IncompatibleTypeErrorImpl(Exception):
 def __init__(self, type_a, type_b):
 self.type_a = type_a
 self.type_b = type_b
 self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__()
-super(IncompatibleTypeErrorimpl, self).__init__(self.message)
+super(IncompatibleTypeErrorImpl, self).__init__(self.message)


 # ===----------------------------------------------------------------------===##
@@ -88,13 +88,13 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
 def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
 if type_a.is_ptr():
 if not allow_ptr_a:
-raise IncompatibleTypeErrorimpl(type_a, type_b)
+raise IncompatibleTypeErrorImpl(type_a, type_b)
 # T* + U* with T != U
 if type_b.is_ptr() and (type_a != type_b):
-raise IncompatibleTypeErrorimpl(type_a, type_b)
+raise IncompatibleTypeErrorImpl(type_a, type_b)
 # T* + float
 if type_b.is_floating():
-raise IncompatibleTypeErrorimpl(type_a, type_b)
+raise IncompatibleTypeErrorImpl(type_a, type_b)


 def binary_op_type_checking_impl(lhs: tl.tensor,
@@ -223,7 +223,7 @@ def fdiv(input: tl.tensor,
 input_scalar_ty = input.type.scalar
 other_scalar_ty = other.type.scalar
 if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
-raise ValueError("both operands of fdiv must have floating poscalar type")
+raise ValueError("both operands of fdiv must have floating scalar type")
 input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
 ret = builder.create_fdiv(input.handle, other.handle)
 return tl.tensor(ret, input.type)
@@ -262,7 +262,7 @@ def bitwise_op_type_checking_impl(input: tl.tensor,
 input_sca_ty = input.type.scalar
 other_sca_ty = other.type.scalar
 if not input_sca_ty.is_int() or not other_sca_ty.is_int():
-raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty)
+raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
 ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty)
 if ret_sca_ty != input_sca_ty:
 input = cast(input, ret_sca_ty, builder)
@@ -63,7 +63,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
 grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
 # NOTE:
 # - each torch.tensor object is implicitly converted into a pointer to its first element.
-# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
+# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
 # - don't forget to pass meta-parameters as keywords arguments
 add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
 # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
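For context, a minimal sketch of the launch pattern those comments describe, using the standard triton.jit / tl.program_id / masked load-store API; the kernel name and sizes are illustrative, not the tutorial's exact code.

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel_sketch(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                            # one program instance per block
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                            # guard the final partial block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.rand(98432, device='cuda')
y = torch.rand(98432, device='cuda')
out = torch.empty_like(x)
# tensors are passed as pointers to their first element; indexing the jitted
# function with a grid returns a callable GPU kernel
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
add_kernel_sketch[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
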
@@ -80,7 +80,7 @@ def softmax_kernel(
 row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
 # Subtract maximum for numerical stability
 row_minus_max = row - tl.max(row, axis=0)
-# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
+# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
 numerator = tl.exp(row_minus_max)
 denominator = tl.sum(numerator, axis=0)
 softmax_output = numerator / denominator
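The arithmetic in those lines is the standard numerically stable softmax; a small NumPy reference (an added example, not part of the tutorial) makes the shift-by-row-max explicit.

import numpy as np

def softmax_rows_ref(x):
    # Subtracting the row maximum leaves the result unchanged mathematically
    # but keeps np.exp from overflowing for large inputs.
    shifted = x - x.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)
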
@@ -188,4 +188,4 @@ benchmark.run(show_plots=True, print_data=True)
 #
 # - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
 # - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
-# Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.
+# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.
@@ -15,7 +15,7 @@ import triton.language as tl
 @triton.jit
 def _fwd_kernel(
 Q, K, V, sm_scale,
-TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
+TMP, L, M, # NOTE: TMP is a scratchpad buffer to work around a compiler bug
 Out,
 stride_qz, stride_qh, stride_qm, stride_qk,
 stride_kz, stride_kh, stride_kn, stride_kk,