[Triton-MLIR] Fix some typos (#874)

Fix some typos
Author: Chenggang Zhao
Date: 2022-11-14 10:15:53 +08:00
Committer: GitHub
Parent: f40c63fb03
Commit: 516a241234
16 changed files with 47 additions and 47 deletions

@@ -172,7 +172,7 @@ void init_triton_ir(py::module &&m) {
if (mlir::Operation *definingOp = self.getDefiningOp())
definingOp->setAttr(name, attr);
else {
-/* issue an warning */
+/* issue a warning */
}
})
.def("replace_all_uses_with",
@@ -180,7 +180,7 @@ void init_triton_ir(py::module &&m) {
self.replaceAllUsesWith(newValue);
});
-py::class_<mlir::BlockArgument, mlir::Value>(m, "block_arguement");
+py::class_<mlir::BlockArgument, mlir::Value>(m, "block_argument");
py::class_<mlir::Region>(m, "region")
.def("get_parent_region", &mlir::Region::getParentRegion, ret::reference)
@@ -288,7 +288,7 @@ void init_triton_ir(py::module &&m) {
py::class_<mlir::scf::WhileOp, mlir::OpState>(m, "WhileOp")
.def("get_before", &mlir::scf::WhileOp::getBefore, ret::reference)
.def("get_after", &mlir::scf::WhileOp::getAfter, ret::reference);
-py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "CondtionOp");
+py::class_<mlir::scf::ConditionOp, mlir::OpState>(m, "ConditionOp");
// dynamic_attr is used to transfer ownership of the MLIR context to the
// module
@@ -423,7 +423,7 @@ void init_triton_ir(py::module &&m) {
.def("get_bool_attr", &mlir::OpBuilder::getBoolAttr)
.def("get_int32_attr", &mlir::OpBuilder::getI32IntegerAttr)
// Use arith.ConstantOp to create constants
-// // Constants
+// Constants
.def("get_int1",
[](mlir::OpBuilder &self, bool v) -> mlir::Value {
auto loc = self.getUnknownLoc();
@@ -588,14 +588,14 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::WhileOp>(loc, retTypes, initArgs);
})
.def("create_condtion_op",
.def("create_condition_op",
[](mlir::OpBuilder &self, mlir::Value &cond,
std::vector<mlir::Value> &args) -> mlir::scf::ConditionOp {
auto loc = self.getUnknownLoc();
return self.create<mlir::scf::ConditionOp>(loc, cond, args);
})
-// miscellious
+// miscellaneous
.def("create_make_range",
[](mlir::OpBuilder &self, int start, int end) -> mlir::Value {
auto loc = self.getUnknownLoc();
@@ -976,15 +976,15 @@ void init_triton_ir(py::module &&m) {
auto loc = self.getUnknownLoc();
return self.create<mlir::arith::OrIOp>(loc, lhs, rhs);
})
-// // Input/Output
+// Input/Output
.def("create_load",
[](mlir::OpBuilder &self, mlir::Value &ptrs,
-mlir::triton::CacheModifier cacheModifer,
+mlir::triton::CacheModifier cacheModifier,
mlir::triton::EvictionPolicy evictionPolicy,
bool isVolatile) -> mlir::Value {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::LoadOp>(
-loc, ptrs, cacheModifer, evictionPolicy, isVolatile);
+loc, ptrs, cacheModifier, evictionPolicy, isVolatile);
})
.def("create_store",
[](mlir::OpBuilder &self, mlir::Value &ptrs,

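For orientation, the create_load binding above is the builder call that the user-facing tl.load reaches: a cache modifier, an eviction policy and a volatile flag are threaded from Python keyword arguments down to mlir::triton::LoadOp. A minimal sketch of that user-facing side (kernel and tensor names are illustrative, and the cache_modifier/volatile keyword spellings are assumed from the public tl.load API rather than shown in this diff):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # cache_modifier / volatile map onto the cacheModifier / isVolatile
    # arguments of the create_load binding shown above.
    x = tl.load(src_ptr + offs, mask=mask, cache_modifier=".ca", volatile=True)
    tl.store(dst_ptr + offs, x, mask=mask)

src = torch.randn(1024, device='cuda')
dst = torch.empty_like(src)
copy_kernel[(triton.cdiv(1024, 256),)](src, dst, 1024, BLOCK=256)
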
@@ -131,7 +131,7 @@ def vecadd_no_scf_tester(num_warps, block_size, shape):
def vecadd_fcmp_no_scf_tester(num_warps, block_size, shape):
'''
-vecadd tester with float comparation as load/store mask.
+vecadd tester with float comparison as load/store mask.
'''
@triton.jit
def kernel(x_ptr,

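The docstring change above describes the test's idea: the load/store mask comes from a floating-point comparison rather than an integer bounds check. A hedged sketch of that pattern (kernel name, block size and the particular comparison are illustrative, not taken from the test):

import torch
import triton
import triton.language as tl

@triton.jit
def vecadd_fcmp(x_ptr, y_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    y = tl.load(y_ptr + offs)
    fmask = x > 0.0                      # float comparison used as the mask
    tl.store(out_ptr + offs, x + y, mask=fmask)

x = torch.randn(128, device='cuda')
y = torch.randn(128, device='cuda')
out = torch.zeros(128, device='cuda')
vecadd_fcmp[(1,)](x, y, out, BLOCK=128)
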
@@ -523,8 +523,8 @@ class CodeGenerator(ast.NodeVisitor):
[ty.to_ir(self.builder) for ty in ret_types])
cond_block.merge_block_before(before_block)
self.builder.set_insertion_point_to_end(before_block)
-# create CondtionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
-self.builder.create_condtion_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
+# create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ...
+self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))])
# merge the loop body
after_block = self.builder.create_block_with_parent(while_op.get_after(),
[ty.to_ir(self.builder) for ty in ret_types])
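
The block above wires a Python while loop into scf.while: the condition is evaluated in the "before" region and handed to scf.condition, and the loop body becomes the "after" region. A hedged sketch of a kernel that would take this path (names and the decrement payload are illustrative, not taken from this commit):

import torch
import triton
import triton.language as tl

@triton.jit
def add_count_times(data_ptr, count_ptr):
    count = tl.load(count_ptr)
    while count > 0:                 # lowered through scf.while / scf.condition
        tl.store(data_ptr, tl.load(data_ptr) + 1.0)
        count = count - 1

data = torch.zeros(1, device='cuda')
count = torch.tensor([5], device='cuda', dtype=torch.int32)
add_count_times[(1,)](data, count)   # data ends up at 5.0
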
@@ -910,7 +910,7 @@ def llir_to_ptx(mod: Any, compute_capability: int = None, ptx_version: int = Non
:param mod: a TritonGPU dialect module
:return:
- PTX code
-- shared memory alloaction size
+- shared memory allocation size
'''
if compute_capability is None:
device = torch.cuda.current_device()
@@ -1194,7 +1194,7 @@ class CacheManager:
os.rename(filepath + ".tmp", filepath)
-# utilties for generating and compiling C wrappers
+# Utilities for generating and compiling C wrappers
@functools.lru_cache()

@@ -768,7 +768,7 @@ def dot(input, other, allow_tf32=True, trans_a=False, trans_b=False, _builder=No
"""
Returns the matrix product of two blocks.
-The two blocks must be two dimensionals and have compatible inner dimensions.
+The two blocks must be two-dimensional and have compatible inner dimensions.
:param input: The first tensor to be multiplied.
:type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`}
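
A short usage sketch of the requirement stated above, multiplying one [M, K] tile by one [K, N] tile (kernel name, tile sizes and the row-major addressing are illustrative):

import torch
import triton
import triton.language as tl

@triton.jit
def single_tile_dot(a_ptr, b_ptr, c_ptr,
                    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])   # [M, K]
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])   # [K, N]
    c = tl.dot(a, b)                                             # [M, N], fp32 accumulator
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

a = torch.randn(16, 16, device='cuda', dtype=torch.float16)
b = torch.randn(16, 16, device='cuda', dtype=torch.float16)
c = torch.empty(16, 16, device='cuda', dtype=torch.float32)
single_tile_dot[(1,)](a, b, c, M=16, N=16, K=16)
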
@@ -1172,7 +1172,7 @@ def ravel(x):
@triton.jit
def swizzle2d(i, j, size_i, size_j, size_g):
"""
-transformes indices of a row-major size_i*size_j matrix into those
+Transforms indices of a row-major size_i*size_j matrix into those
of one where indices are row major for each group of size_j rows.
For example, for size_i = size_j = 4 and size_g = 2, it will transform
[[0 , 1 , 2 , 3 ],

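The docstring above describes a grouped row-major reordering; below is a plain-Python re-derivation of that transform, written from the description rather than copied from the library (the helper name swizzle2d_ref and the host-side loop are illustrative):

def swizzle2d_ref(i, j, size_i, size_j, size_g):
    ij = i * size_j + j                          # linear index, row-major
    group_id = ij // (size_g * size_j)           # which group of size_g rows
    off_i = group_id * size_g                    # first row of that group
    size_g = min(size_i - off_i, size_g)         # last group may be smaller
    new_i = off_i + ij % size_g                  # consecutive elements walk down the group's rows
    new_j = (ij % (size_g * size_j)) // size_g   # then advance one column
    return new_i, new_j

# For size_i = size_j = 4 and size_g = 2 this reproduces the reordering
# sketched in the docstring.
size = 4
out = [[None] * size for _ in range(size)]
for i in range(size):
    for j in range(size):
        ni, nj = swizzle2d_ref(i, j, size, size, 2)
        out[ni][nj] = i * size + j
print(out)   # [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]
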
@@ -7,12 +7,12 @@ from triton._C.libtriton.triton import ir
# Create custom exception that prints message "hello"
-class IncompatibleTypeErrorimpl(Exception):
+class IncompatibleTypeErrorImpl(Exception):
def __init__(self, type_a, type_b):
self.type_a = type_a
self.type_b = type_b
self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__()
-super(IncompatibleTypeErrorimpl, self).__init__(self.message)
+super(IncompatibleTypeErrorImpl, self).__init__(self.message)
# ===----------------------------------------------------------------------===##
@@ -88,13 +88,13 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t
def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None:
if type_a.is_ptr():
if not allow_ptr_a:
-raise IncompatibleTypeErrorimpl(type_a, type_b)
+raise IncompatibleTypeErrorImpl(type_a, type_b)
# T* + U* with T != U
if type_b.is_ptr() and (type_a != type_b):
-raise IncompatibleTypeErrorimpl(type_a, type_b)
+raise IncompatibleTypeErrorImpl(type_a, type_b)
# T* + float
if type_b.is_floating():
-raise IncompatibleTypeErrorimpl(type_a, type_b)
+raise IncompatibleTypeErrorImpl(type_a, type_b)
def binary_op_type_checking_impl(lhs: tl.tensor,
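
A hedged illustration of what check_ptr_type_impl accepts and rejects (kernel names are illustrative, and the exact exception surfaced to the user may be wrapped by the compiler front end):

import torch
import triton
import triton.language as tl

@triton.jit
def ok_kernel(x_ptr):
    offs = tl.arange(0, 8)
    x = tl.load(x_ptr + offs)        # T* + int offset: allowed
    tl.store(x_ptr + offs, x + 1.0)

# @triton.jit
# def bad_kernel(x_ptr):
#     y = x_ptr + 1.5                # T* + float: raises IncompatibleTypeErrorImpl

x = torch.zeros(8, device='cuda')
ok_kernel[(1,)](x)                    # x becomes all ones
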
@@ -223,7 +223,7 @@ def fdiv(input: tl.tensor,
input_scalar_ty = input.type.scalar
other_scalar_ty = other.type.scalar
if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating():
raise ValueError("both operands of fdiv must have floating poscalar type")
raise ValueError("both operands of fdiv must have floating scalar type")
input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True)
ret = builder.create_fdiv(input.handle, other.handle)
return tl.tensor(ret, input.type)
@@ -262,7 +262,7 @@ def bitwise_op_type_checking_impl(input: tl.tensor,
input_sca_ty = input.type.scalar
other_sca_ty = other.type.scalar
if not input_sca_ty.is_int() or not other_sca_ty.is_int():
-raise IncompatibleTypeErrorimpl(input_sca_ty, other_sca_ty)
+raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty)
ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty)
if ret_sca_ty != input_sca_ty:
input = cast(input, ret_sca_ty, builder)

@@ -63,7 +63,7 @@ def add(x: torch.Tensor, y: torch.Tensor):
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
# NOTE:
# - each torch.tensor object is implicitly converted into a pointer to its first element.
-# - `triton.jit`'ed functions can be index with a launch grid to obtain a callable GPU kernel
+# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
# - don't forget to pass meta-parameters as keywords arguments
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
# We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still

@@ -80,7 +80,7 @@ def softmax_kernel(
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
# Subtract maximum for numerical stability
row_minus_max = row - tl.max(row, axis=0)
-# Note that exponentials in Triton are fast but approximate (i.e., think __expf in CUDA)
+# Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
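
A small numerical aside (not part of the tutorial) on why the maximum is subtracted first: softmax is invariant to shifting the row, but the shift keeps exp() out of overflow territory:

import torch

row = torch.tensor([1000.0, 1001.0, 1002.0])
naive = torch.exp(row) / torch.exp(row).sum()   # exp overflows to inf -> nan
shifted = torch.exp(row - row.max())
stable = shifted / shifted.sum()                # tensor([0.0900, 0.2447, 0.6652])
print(naive, stable)
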
@@ -188,4 +188,4 @@ benchmark.run(show_plots=True, print_data=True)
#
# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here.
# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**.
-# Note however that the PyTorch `softmax` operation is more general and will works on tensors of any shape.
+# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.

@@ -15,7 +15,7 @@ import triton.language as tl
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
-TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
+TMP, L, M, # NOTE: TMP is a scratchpad buffer to work around a compiler bug
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,