Pythonでユーザー定義クラスのオブジェクトを等価比較する

Pythonで独自に定義したユーザー定義クラスのオブジェクト同士を等価比較する方法を整理します。

はじめに

おはよう。@bioerrorlogです。

Pythonでは、ユーザー定義クラスは等価演算子で比較することが出来ません。

例えば、次のようなクラス/pytestコードを定義した場合、オブジェクト比較部分のassertでエラーが発生します。

class MyClass:
    def __init__(self, amount):
        self.amount = amount


def test_equality():
    obj_1 = MyClass(5)
    obj_2 = MyClass(5)
    assert obj_1 == obj_2 # これが出来ない


pytest実行結果:

=============================== FAILURES ================================
_____________________________ test_equality _____________________________

    def test_equality():
        obj_1 = MyClass(5)
        obj_2 = MyClass(5)
>       assert obj_1 == obj_2
E       assert <equality.MyClass object at 0x0000014F05071848> == <equality.MyClass object at 0x0000014F05071C88>

equality.py:14: AssertionError
======================== short test summary info ========================
FAILED equality.py::test_equality - assert <equality.MyClass object at ...
=========================== 1 failed in 0.03s ===========================


今回は、上のように独自に定義したオブジェクト同士を比較できるようにする方法を記録します。

動作環境

Python 3.7.6 で動作確認しています。

Pythonでユーザー定義クラスのオブジェクトを等価比較する

ユーザー定義クラスのオブジェクトを等価比較できるようにするには、特殊メソッド__eq__を利用します。

__eq__メソッドは、等価演算子==が使われたときに呼び出されるものです。
例えば x==y が実行されたときには x.__eq__(y)が呼び出されます。
※参考: 3. Data model — Python 3.9.6 documentation

以下、やり方を見ていきます。

__eq__メソッドでインスタンス変数を比較する

まず、__eq__メソッドでインスタンス変数を比較することでオブジェクトを等価比較できるようになります。

実装例:

class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return self.amount == other.amount # インスタンス変数を比較

def test_equality():
    obj_1 = MyClass(5)
    obj_2 = MyClass(5)
    assert obj_1 == obj_2 # このテストは通る


ただ、このやり方だとインスタンス変数が変更されるたびに__eq__メソッドを更新する必要が出てきます。 そこで次は__dict__を利用する方法を紹介します。

__eq__メソッドで__dict__を比較する

__dict__は、インスタンス変数がdictで格納されている特殊属性です。 これを__eq__メソッドで比較することで、インスタンス変数をひとつひとつ比較せずにも等価比較が可能です。

実装例:

class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return self.__dict__ == other.__dict__ # インスタンス変数を比較

def test_equality():
    obj_1 = MyClass(5)
    obj_2 = MyClass(5)
    assert obj_1 == obj_2 # このテストは通る

同一クラスかを比較する

上に挙げたやり方では、インスタンス変数を比較しているのみなのでクラスの種類までは比較できていません。 なので、仮に全く同じインスタンス変数を持つ別々のクラス同士を比較した場合も、同一のものとして判定されてしまいます:

class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

class DummyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return self.__dict__ == other.__dict__


def test_equality():
    obj_1 = MyClass(5)
    obj_2 = DummyClass(5)
    assert obj_1 == obj_2 # このテストが通る


異なるクラスのオブジェクトを異なるものとして判定させるためには、__eq__メソッド内でクラス__class__を比較させます。 __class__は、クラス名が格納されている特殊属性です。

実装例:

class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__) and
            self.__dict__ == other.__dict__
        )


class DummyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__) and
            self.__dict__ == other.__dict__
        )


def test_equality():
    obj_1 = MyClass(5)
    obj_2 = MyClass(5)
    assert obj_1 == obj_2 # このテストは通る
    
    obj_3 = DummyClass(5)
    assert obj_1 == obj_3 # このテストは通らない

同一クラスでない時はNotImplementedを投げる

同一クラスでない時には、等価演算がサポートされていない意を示すNotImplementedを投げる実装もできます。

実装例:

class MyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__

class DummyClass:
    def __init__(self, amount):
        self.amount = amount

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.__dict__ == other.__dict__


def test_equality():
    obj_1 = MyClass(5)
    obj_2 = MyClass(5)
    assert obj_1 == obj_2 # このテストは通る
    
    obj_3 = DummyClass(5)
    assert obj_1 == obj_3 # このテストは通らない

おわりに

以上、Pythonでユーザー定義クラスのオブジェクト同士を等価比較する方法をまとめました。

今回取り上げた問題は、Kent BeckのTDD本をPythonでやり直しているときに遭遇したものです。

Java/JUnitではassertEquals()ですぐオブジェクトを比較できたものが、Pythonだとちょっとした工夫が必要なことに気付き、備忘録を書きました。

どなたかの参考になれば幸いです。

[関連記事]

www.bioerrorlog.work

www.bioerrorlog.work

参考

3. Data model — Python 3.9.6 documentation

Built-in Constants — Python 3.9.6 documentation

python - Check if two objects have equal content in Pytest - Stack Overflow

Is there a way to check if two object contain the same values in each of their variables in python? - Stack Overflow

Assert custom objects are equal in Python unit test - Gems

aws-cli/__init__.py at develop · aws/aws-cli · GitHub

aws-cli/test_alias.py at develop · aws/aws-cli · GitHub