New python binding
This commit is contained in:
@@ -174,10 +174,15 @@ add_subdirectory(lib)
|
|||||||
|
|
||||||
add_library(triton SHARED ${PYTHON_SRC})
|
add_library(triton SHARED ${PYTHON_SRC})
|
||||||
|
|
||||||
|
find_package(PythonLibs REQUIRED)
|
||||||
|
|
||||||
target_link_libraries(triton
|
target_link_libraries(triton
|
||||||
TritonIR
|
TritonIR
|
||||||
TritonDriver
|
TritonDriver
|
||||||
TritonCodeGen
|
# TritonCodeGen
|
||||||
|
|
||||||
|
MLIRCAPIIR
|
||||||
|
${PYTHON_LIBRARIES}
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
|
||||||
|
@@ -91,6 +91,25 @@ def TT_FpToFp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape, NoSideEffect,
|
|||||||
//
|
//
|
||||||
// Load/Store Ops
|
// Load/Store Ops
|
||||||
//
|
//
|
||||||
|
def TT_CacheModifierAttr : I32EnumAttr<
|
||||||
|
"CacheModifier", "",
|
||||||
|
[
|
||||||
|
I32EnumAttrCase<"NONE", 1, "none">,
|
||||||
|
I32EnumAttrCase<"CA", 2, "ca">,
|
||||||
|
I32EnumAttrCase<"CG", 3, "cg">,
|
||||||
|
]> {
|
||||||
|
let cppNamespace = "::mlir::triton";
|
||||||
|
}
|
||||||
|
def TT_EvictionPolicyAttr : I32EnumAttr<
|
||||||
|
"EvictionPolicy", "",
|
||||||
|
[
|
||||||
|
I32EnumAttrCase<"NORMAL", 1, "normal">,
|
||||||
|
I32EnumAttrCase<"EVICT_FIRST", 2, "evict_first">,
|
||||||
|
I32EnumAttrCase<"EVICT_LAST", 3, "evict_last">
|
||||||
|
]> {
|
||||||
|
let cppNamespace = "::mlir::triton";
|
||||||
|
}
|
||||||
|
|
||||||
def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
|
def TT_LoadOp : TT_Op<"load", [SameOperandsAndResultShape]> {
|
||||||
let summary = "load";
|
let summary = "load";
|
||||||
|
|
||||||
@@ -157,10 +176,13 @@ def TT_RedOpAttr : I32EnumAttr<
|
|||||||
/*name*/"RedOp", /*summary*/"",
|
/*name*/"RedOp", /*summary*/"",
|
||||||
/*case*/
|
/*case*/
|
||||||
[
|
[
|
||||||
I32EnumAttrCase</*sym*/"SUM", 1, /*str*/"sum">,
|
I32EnumAttrCase</*sym*/"ADD", 1, /*str*/"add">,
|
||||||
I32EnumAttrCase<"MAX", 2, "max">,
|
I32EnumAttrCase<"MAX", 2, "max">,
|
||||||
I32EnumAttrCase<"MIN", 3, "min">,
|
I32EnumAttrCase<"MIN", 3, "min">,
|
||||||
I32EnumAttrCase<"XOR_SUM", 4, "xor_sum">
|
I32EnumAttrCase<"FADD", 4, "fadd">,
|
||||||
|
I32EnumAttrCase<"FMAX", 5, "fmax">,
|
||||||
|
I32EnumAttrCase<"FMIN", 6, "fmin">,
|
||||||
|
I32EnumAttrCase<"XOR", 7, "xor">
|
||||||
]> {
|
]> {
|
||||||
let cppNamespace = "::mlir::triton";
|
let cppNamespace = "::mlir::triton";
|
||||||
}
|
}
|
||||||
@@ -179,10 +201,11 @@ def TT_AtomicRMWAttr : I32EnumAttr<
|
|||||||
I32EnumAttrCase<"OR", 2, "or">,
|
I32EnumAttrCase<"OR", 2, "or">,
|
||||||
I32EnumAttrCase<"XOR", 3, "xor">,
|
I32EnumAttrCase<"XOR", 3, "xor">,
|
||||||
I32EnumAttrCase<"ADD", 4, "add">,
|
I32EnumAttrCase<"ADD", 4, "add">,
|
||||||
I32EnumAttrCase<"MAX", 5, "max">,
|
I32EnumAttrCase<"FADD", 5, "fadd">,
|
||||||
I32EnumAttrCase<"MIN", 6, "min">,
|
I32EnumAttrCase<"MAX", 6, "max">,
|
||||||
I32EnumAttrCase<"UMAX", 7, "umax">,
|
I32EnumAttrCase<"MIN", 7, "min">,
|
||||||
I32EnumAttrCase<"UMIN", 8, "umin">
|
I32EnumAttrCase<"UMAX", 8, "umax">,
|
||||||
|
I32EnumAttrCase<"UMIN", 9, "umin">
|
||||||
]> {
|
]> {
|
||||||
let cppNamespace = "::mlir::triton";
|
let cppNamespace = "::mlir::triton";
|
||||||
}
|
}
|
||||||
|
@@ -18,17 +18,3 @@ add_mlir_dialect_library(TritonIR
|
|||||||
|
|
||||||
MLIRTensor
|
MLIRTensor
|
||||||
)
|
)
|
||||||
|
|
||||||
# add_library(TritonIR
|
|
||||||
# Dialect.cpp
|
|
||||||
# Ops.cpp
|
|
||||||
# Types.cpp
|
|
||||||
# )
|
|
||||||
|
|
||||||
# target_link_libraries(TritonIR PUBLIC
|
|
||||||
# MLIRIR
|
|
||||||
# MLIRArithmetic
|
|
||||||
# MLIRControlFlow
|
|
||||||
# MLIRFunc
|
|
||||||
# MLIRTensor
|
|
||||||
# )
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@@ -395,30 +395,30 @@ class constexpr:
|
|||||||
|
|
||||||
|
|
||||||
class tensor:
|
class tensor:
|
||||||
# infer dtype from ir type
|
# # infer dtype from ir type
|
||||||
@staticmethod
|
# @staticmethod
|
||||||
def _to_dtype(ir_type):
|
# def _to_dtype(ir_type):
|
||||||
# block type
|
# # block type
|
||||||
if ir_type.is_block():
|
# if ir_type.is_block():
|
||||||
scalar_ty = tensor._to_dtype(ir_type.scalar)
|
# scalar_ty = tensor._to_dtype(ir_type.scalar)
|
||||||
return block_type(scalar_ty, ir_type.get_block_shapes())
|
# return block_type(scalar_ty, ir_type.get_block_shapes())
|
||||||
# pointer type
|
# # pointer type
|
||||||
if ir_type.is_ptr():
|
# if ir_type.is_ptr():
|
||||||
element_ty = tensor._to_dtype(ir_type.element)
|
# element_ty = tensor._to_dtype(ir_type.element)
|
||||||
return pointer_type(element_ty)
|
# return pointer_type(element_ty)
|
||||||
# primitive type
|
# # primitive type
|
||||||
if ir_type.is_void(): return void
|
# if ir_type.is_void(): return void
|
||||||
if ir_type.is_int1(): return int1
|
# if ir_type.is_int1(): return int1
|
||||||
if ir_type.is_int8(): return int8
|
# if ir_type.is_int8(): return int8
|
||||||
if ir_type.is_int16(): return int16
|
# if ir_type.is_int16(): return int16
|
||||||
if ir_type.is_int32(): return int32
|
# if ir_type.is_int32(): return int32
|
||||||
if ir_type.is_int64(): return int64
|
# if ir_type.is_int64(): return int64
|
||||||
if ir_type.is_fp8(): return float8
|
# if ir_type.is_fp8(): return float8
|
||||||
if ir_type.is_fp16(): return float16
|
# if ir_type.is_fp16(): return float16
|
||||||
if ir_type.is_bf16(): return bfloat16
|
# if ir_type.is_bf16(): return bfloat16
|
||||||
if ir_type.is_fp32(): return float32
|
# if ir_type.is_fp32(): return float32
|
||||||
if ir_type.is_fp64(): return float64
|
# if ir_type.is_fp64(): return float64
|
||||||
raise ValueError(f"Unsupported type {ir_type.repr()}")
|
# raise ValueError(f"Unsupported type {ir_type.repr()}")
|
||||||
|
|
||||||
def __init__(self, handle, type: dtype):
|
def __init__(self, handle, type: dtype):
|
||||||
# IR handle
|
# IR handle
|
||||||
|
Reference in New Issue
Block a user