[python] basic support for pytorch seems to be working
This commit is contained in:
@@ -34,6 +34,5 @@ def run_torch():
|
|||||||
b = th.randn(K, N).cuda()
|
b = th.randn(K, N).cuda()
|
||||||
th_c = th.matmul(a, b)
|
th_c = th.matmul(a, b)
|
||||||
tr_c = triton.ops.dot(a, b)
|
tr_c = triton.ops.dot(a, b)
|
||||||
print(c)
|
|
||||||
|
|
||||||
run_torch()
|
run_torch()
|
@@ -4,45 +4,49 @@ class OpContext(object):
|
|||||||
|
|
||||||
def save_for_backward(self, *tensors):
|
def save_for_backward(self, *tensors):
|
||||||
self.to_save = tensors
|
self.to_save = tensors
|
||||||
|
|
||||||
def mark_dirty(self, *args):
|
|
||||||
self.dirty_tensors = args
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def saved_tensors(self):
|
def saved_tensors(self):
|
||||||
return self.to_save
|
return self.to_save
|
||||||
|
|
||||||
|
|
||||||
class function_meta(type):
|
class function_meta(type):
|
||||||
|
|
||||||
def __init__(cls, name, bases, attrs):
|
def __init__(cls, name, bases, attrs):
|
||||||
cls.contexts = dict()
|
cls.contexts = dict()
|
||||||
cls.registered = False
|
cls.registered = False
|
||||||
|
cls.framework = None
|
||||||
return super(function_meta, cls).__init__(name, bases, attrs)
|
return super(function_meta, cls).__init__(name, bases, attrs)
|
||||||
|
|
||||||
class function(metaclass = function_meta):
|
class function(metaclass = function_meta):
|
||||||
|
|
||||||
def __init__(self, framework = None):
|
|
||||||
self.framework = _find_framework(framework)
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, *args, **kwargs):
|
def forward(ctx, *args, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def apply(cls, *args, **kwargs):
|
def apply_torch(cls, *args, **kwargs):
|
||||||
# call forward
|
fw._import_torch()
|
||||||
|
class TorchFunction(fw.torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, *targs, **tkwargs):
|
||||||
|
return cls.forward(ctx, *targs, **tkwargs)
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
return cls.backward(ctx, grad_output)
|
||||||
|
return TorchFunction.apply(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def apply_tensorflow(cls, *args, **kwargs):
|
||||||
|
fw._import_tensorflow()
|
||||||
ctx = OpContext()
|
ctx = OpContext()
|
||||||
result = cls.forward(ctx, *args, **kwargs)
|
result = cls.forward(ctx, *args, **kwargs)
|
||||||
id = result.op.get_attr('id')
|
id = result.op.get_attr('id')
|
||||||
cls.contexts[id] = ctx
|
cls.contexts[id] = ctx
|
||||||
# register backward
|
# register backward
|
||||||
fw._import_tensorflow()
|
|
||||||
name = result.op.op_def.name
|
name = result.op.op_def.name
|
||||||
if not cls.registered:
|
if not cls.registered:
|
||||||
@fw.tensorflow.RegisterGradient(name)
|
@fw.tensorflow.RegisterGradient(name)
|
||||||
@@ -51,4 +55,12 @@ class function(metaclass = function_meta):
|
|||||||
return cls.backward(cls.contexts[id], dy)
|
return cls.backward(cls.contexts[id], dy)
|
||||||
cls.registered = True
|
cls.registered = True
|
||||||
# return result tensor
|
# return result tensor
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def apply(cls, *args, **kwargs):
|
||||||
|
cls.framework = fw._find_framework(cls.framework)
|
||||||
|
if cls.framework == fw.tensorflow_id:
|
||||||
|
return cls.apply_tensorflow(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return cls.apply_torch(*args, **kwargs)
|
||||||
|
Reference in New Issue
Block a user