[python] basic support for pytorch seems to be working
This commit is contained in:
@@ -34,6 +34,5 @@ def run_torch():
|
||||
b = th.randn(K, N).cuda()
|
||||
th_c = th.matmul(a, b)
|
||||
tr_c = triton.ops.dot(a, b)
|
||||
print(c)
|
||||
|
||||
run_torch()
|
@@ -5,27 +5,20 @@ class OpContext(object):
|
||||
def save_for_backward(self, *tensors):
|
||||
self.to_save = tensors
|
||||
|
||||
def mark_dirty(self, *args):
|
||||
self.dirty_tensors = args
|
||||
|
||||
@property
|
||||
def saved_tensors(self):
|
||||
return self.to_save
|
||||
|
||||
|
||||
class function_meta(type):
|
||||
|
||||
def __init__(cls, name, bases, attrs):
|
||||
cls.contexts = dict()
|
||||
cls.registered = False
|
||||
cls.framework = None
|
||||
return super(function_meta, cls).__init__(name, bases, attrs)
|
||||
|
||||
class function(metaclass = function_meta):
|
||||
|
||||
def __init__(self, framework = None):
|
||||
self.framework = _find_framework(framework)
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
@@ -35,14 +28,25 @@ class function(metaclass = function_meta):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def apply(cls, *args, **kwargs):
|
||||
# call forward
|
||||
def apply_torch(cls, *args, **kwargs):
|
||||
fw._import_torch()
|
||||
class TorchFunction(fw.torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, *targs, **tkwargs):
|
||||
return cls.forward(ctx, *targs, **tkwargs)
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return cls.backward(ctx, grad_output)
|
||||
return TorchFunction.apply(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def apply_tensorflow(cls, *args, **kwargs):
|
||||
fw._import_tensorflow()
|
||||
ctx = OpContext()
|
||||
result = cls.forward(ctx, *args, **kwargs)
|
||||
id = result.op.get_attr('id')
|
||||
cls.contexts[id] = ctx
|
||||
# register backward
|
||||
fw._import_tensorflow()
|
||||
name = result.op.op_def.name
|
||||
if not cls.registered:
|
||||
@fw.tensorflow.RegisterGradient(name)
|
||||
@@ -52,3 +56,11 @@ class function(metaclass = function_meta):
|
||||
cls.registered = True
|
||||
# return result tensor
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def apply(cls, *args, **kwargs):
|
||||
cls.framework = fw._find_framework(cls.framework)
|
||||
if cls.framework == fw.tensorflow_id:
|
||||
return cls.apply_tensorflow(*args, **kwargs)
|
||||
else:
|
||||
return cls.apply_torch(*args, **kwargs)
|
||||
|
Reference in New Issue
Block a user