読者です 読者をやめる 読者になる 読者になる

きくらげ観察日記

好きなことを、適当に。

Pythonで代数的データ型とパターンマッチ

PythonHaskellっぽい代数的データ型とパターンマッチをできるようにするためのメタクラスです。
まあこれも余り用途はなさそうですけど。

def AlgTypeMeta(*names):
    class _AlgTypeMeta(type):
        def __new__(metacls, name, bases, methods):
            def init(self, *args):
                len_args = len(args)
                len_names = len(names)
                if len_args != len_names:
                    raise TypeError('%s takes %d positional arguments but %d were given'
                                   % (name, len_names, len_args))
                self._values = args
                for prop, value in zip(names, args):
                    setattr(self, prop, value)
            methods['__init__'] = init

            def get_values(self):
                return self._values
            methods['get_values'] = get_values

            return super().__new__(metacls, name, bases, methods)
    return _AlgTypeMeta

def match(obj, pat):
    return pat[obj.__class__](*obj.get_values())

このメタクラスを使うと、以下のように代数的データ型のようなクラスを定義することができます。

class Point2D(metaclass=AlgTypeMeta('x', 'y')):
    pass
class Point3D(metaclass=AlgTypeMeta('x', 'y', 'z')):
    pass

AlgTypeMetaメタクラスを利用することによって、Point2D.__init__(self, x, y), Point3D.__init__(self, x, y, z)が自動生成されます。

>>> p = Point2D(3, 2)
>>> p3 = Point3D(1, 3, 5)

また、先ほどのソースの下の方で定義されているmatch関数を使うと、Haskellのパターンマッチのようなことを行うことができます。

from math import sqrt

def norm(p):
    return match(p, {
        Point2D: lambda x, y: sqrt(x**2 + y**2),
        Point3D: lambda x, y, z: sqrt(x**2 + y**2 + z**2)
    })

実行例:

def norm(p):
    return match(p, {
        Point2D: lambda x, y: sqrt(x**2 + y**2),
        Point3D: lambda x, y, z: sqrt(x**2 + y**2 + z**2)
    })