Pythonで独自に定義したユーザー定義クラスのオブジェクト同士を等価比較する方法を整理します。
はじめに
こんにちは、@bioerrorlogです。
Pythonでは、ユーザー定義クラスは等価演算子で比較することが出来ません。
例えば、次のようなクラス/pytestコードを定義した場合、オブジェクト比較部分のassertでエラーが発生します。
class MyClass: def __init__(self, amount): self.amount = amount def test_equality(): obj_1 = MyClass(5) obj_2 = MyClass(5) assert obj_1 == obj_2 # これが出来ない
pytest実行結果:
=============================== FAILURES ================================ _____________________________ test_equality _____________________________ def test_equality(): obj_1 = MyClass(5) obj_2 = MyClass(5) > assert obj_1 == obj_2 E assert <equality.MyClass object at 0x0000014F05071848> == <equality.MyClass object at 0x0000014F05071C88> equality.py:14: AssertionError ======================== short test summary info ======================== FAILED equality.py::test_equality - assert <equality.MyClass object at ... =========================== 1 failed in 0.03s ===========================
今回は、上のように独自に定義したオブジェクト同士を比較できるようにする方法を記録します。
動作環境
Python 3.7.6 で動作確認しています。
Pythonでユーザー定義クラスのオブジェクトを等価比較する
ユーザー定義クラスのオブジェクトを等価比較できるようにするには、特殊メソッド__eq__
を利用します。
__eq__
メソッドは、等価演算子==
が使われたときに呼び出されるものです。
例えば x==y
が実行されたときには x.__eq__(y)
が呼び出されます。
※参考: 3. Data model — Python 3.10.5 documentation
以下、やり方を見ていきます。
__eq__
メソッドでインスタンス変数を比較する
まず、__eq__
メソッドでインスタンス変数を比較することでオブジェクトを等価比較できるようになります。
実装例:
class MyClass: def __init__(self, amount): self.amount = amount def __eq__(self, other): return self.amount == other.amount # インスタンス変数を比較 def test_equality(): obj_1 = MyClass(5) obj_2 = MyClass(5) assert obj_1 == obj_2 # このテストは通る
ただ、このやり方だとインスタンス変数が変更されるたびに__eq__
メソッドを更新する必要が出てきます。
そこで次は__dict__
を利用する方法を紹介します。
__eq__
メソッドで__dict__
を比較する
__dict__
は、インスタンス変数がdictで格納されている特殊属性です。
これを__eq__
メソッドで比較することで、インスタンス変数をひとつひとつ比較せずにも等価比較が可能です。
実装例:
class MyClass: def __init__(self, amount): self.amount = amount def __eq__(self, other): return self.__dict__ == other.__dict__ # インスタンス変数を比較 def test_equality(): obj_1 = MyClass(5) obj_2 = MyClass(5) assert obj_1 == obj_2 # このテストは通る
同一クラスかを比較する
上に挙げたやり方では、インスタンス変数を比較しているのみなのでクラスの種類までは比較できていません。 なので、仮に全く同じインスタンス変数を持つ別々のクラス同士を比較した場合も、同一のものとして判定されてしまいます:
class MyClass: def __init__(self, amount): self.amount = amount def __eq__(self, other): return self.__dict__ == other.__dict__ class DummyClass: def __init__(self, amount): self.amount = amount def __eq__(self, other): return self.__dict__ == other.__dict__ def test_equality(): obj_1 = MyClass(5) obj_2 = DummyClass(5) assert obj_1 == obj_2 # このテストが通る
異なるクラスのオブジェクトを異なるものとして判定させるためには、__eq__
メソッド内でクラス__class__
を比較させます。
__class__
は、クラス名が格納されている特殊属性です。
実装例:
class MyClass: def __init__(self, amount): self.amount = amount def __eq__(self, other): return ( isinstance(other, self.__class__) and self.__dict__ == other.__dict__ ) class DummyClass: def __init__(self, amount): self.amount = amount def __eq__(self, other): return ( isinstance(other, self.__class__) and self.__dict__ == other.__dict__ ) def test_equality(): obj_1 = MyClass(5) obj_2 = MyClass(5) assert obj_1 == obj_2 # このテストは通る obj_3 = DummyClass(5) assert obj_1 == obj_3 # このテストは通らない
同一クラスでない時はNotImplementedを投げる
同一クラスでない時には、等価演算がサポートされていない意を示すNotImplementedを投げる実装もできます。
実装例:
class MyClass: def __init__(self, amount): self.amount = amount def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented return self.__dict__ == other.__dict__ class DummyClass: def __init__(self, amount): self.amount = amount def __eq__(self, other): if not isinstance(other, self.__class__): return NotImplemented return self.__dict__ == other.__dict__ def test_equality(): obj_1 = MyClass(5) obj_2 = MyClass(5) assert obj_1 == obj_2 # このテストは通る obj_3 = DummyClass(5) assert obj_1 == obj_3 # このテストは通らない
おわりに
以上、Pythonでユーザー定義クラスのオブジェクト同士を等価比較する方法をまとめました。
今回取り上げた問題は、Kent BeckのTDD本をPythonでやり直しているときに遭遇したものです。
Java/JUnitではassertEquals()
ですぐオブジェクトを比較できたものが、Pythonだとちょっとした工夫が必要なことに気付き、備忘録を書きました。
どなたかの参考になれば幸いです。
[関連記事]
参考
3. Data model — Python 3.10.5 documentation
Built-in Constants — Python 3.10.5 documentation
python - Check if two objects have equal content in Pytest - Stack Overflow
Assert custom objects are equal in Python unit test - Gems