pythonで等価性判定

pythonでの等価性判定のまとめ。
前提となるpythonやライブラリのバージョンは以下。

  • python 3.7.11
  • numpy 1.19.5
  • pandas 1.1.5

文字列、数値、真偽値

等価演算子「==」で判定。

print(1 == 1)        # True
print(1 == 0)        # False
print('1' == '1')    # True
print(1 == '1')      # False
print(True == True)  # True

bool型はint型のサブクラスで、True、Falseは1、0と等価。

print(1 == True)  # True
print(0 == False) # True

浮動小数点誤差を許容する場合には、等価演算子を利用するとFalseになるので。math.iscloseを利用する。

a = 0.3
b= 0.1 + 0.1 + 0.1
print(a, b)   # 0.3, 0.30000000000000004
print(a == b) # False
print(math.isclose(a, b)) # True

None

is 演算子で判定。

a = None
print(a is None) # True

is 演算子での判定は、オブジェクトidの比較と同じで、同一性判定と同じ。

print(id(a) == id(None)) # True

Noneの判定で、「==」を用いないのは、「==」は__eq__ メソッドを定義することで、クラス固有の等価判定を実装でき、思わぬ結果になることがあるため。
また、is 演算子の方が高速らしい。

参考:== と is の違い

非数(nan)

非数 nanは、float型の値で、mathとnumpyで定数定義されている。
等価演算子で比較することはできない。

import math
import numpy as np

a = float('nan')
print(a == math.nan)  # False
print(a == np.nan)    # False

isnan関数で判定する。

print(math.isnan(a))  # True
print(np.isnan(a))    # True

isnan関数は数値以外の引数はエラーになるので注意(javascriptのisNaNと同じノリで使用はできない)。

print(math.isnan(1))    # False
print(math.isnan('1'))  # TypeError

リスト(要素がプリミティブ値のみの場合)

リストの判定は、含まれる要素のパターンによって判定方法が異なってくる。
現時点での自分が考えるやり方は以下。

  • 基本は、等価演算子「==」
  • 要素が数値のみで、nan判定必要な場合は、numpy.array_equal(equal_nan=True)
  • 要素が数値のみで、浮動小数点誤差を考慮必要な場合は、numpy.allclose
  • 要素が数値以外を含み、nan判定必要な場合は、自作で頑張る(遅いけど)

なお、ここでのリストは、タイトルにある通り、要素はプリミティブ値のみの場合を考える。

等価演算子で判定する

list1 = [1, 2, 3]
list2 = [1, 2, 3]
print(list1 == list2)  # True

非数nan がリストに含まれる場合は、「==」では判定できない。
for文で要素の中身を比較して頑張ろうとすると、以下になる。

def list_equal(list1:list, list2:list) -> bool:
    isnan = lambda x: type(x) is float and math.isnan(x)

    if len(list1) != len(list2):
        return False
    for x, y in zip(list1, list2):
        if isnan(x):
            if not isnan(y):
                return False
        elif x != y:
            return False
    return True

l1 = [1, 2, '3', float('nan')]
l2 = [1, 2, '3', np.nan]
print(l1 == l2)            # False
print(list_equal(l1, l2))  # True

numpyの関数で判定する

numpyの関数だと、numpy.array_equalnumpy.allclose で判定できる。ただし、numpy.allcloseは等価性判定ではなく近似判定(浮動小数点floatの誤差を考慮するときに使用)。

どちらの関数も、デフォルトではnanの等価性判定オプション equal_nan が False になっている(array_equal は v1.19.0 でこのオプションが追加された)。
なので、nanが含まれるようなリストであれば、このオプションを True にする必要がある。

l1 = [1, 2, 3, float('nan')]
l2 = [1, 2, 3, float('nan')]
print(np.array_equal(l1, l2, equal_nan=True))  # True
print(np.allclose(l1, l2, equal_nan=True))     # True

l3 = [0.3, float('nan')]
l4 = [0.1+0.1+0.1, float('nan')]
print(np.array_equal(l3, l4, equal_nan=True))  # False <- 浮動小数点誤差のため
print(np.allclose(l3, l4, equal_nan=True))     # True

注意が必要なのは、上記のnumpyの関数だと、リストの要素に数値型以外が含まれていると予期せぬことが起こること。

l1 = [1, 2, float('nan'), '3']
l2 = [1, '2', float('nan'), 3]
print(np.array_equal(l1, l2))                   # True
print(np.array_equal(l1, l2, equal_nan=True))   # TypeError

上記のように、equal_nan=False の場合は、要素の型が違ってるのに等価性判定がTrueになってるし、equal_nan=True の場合は TypeError になってしまう。

これは、ソースを見れば原因は分かる。
まず、関数の中で引数を asarray でnumpy配列にしているが、dtypeを指定していない、かつ、文字列が含まれているため、要素がすべて文字列に変換される。そのため、equal_nan=False の場合は、等価だと判定される。
また、equal_nan=True の場合は nanを特定するために、isnan関数を使用しているが、文字列の場合はTypeErrorになる。

数値以外が要素に含まれているのを考慮すると、下記のように関数を修正。

# 数値以外の要素を考慮版の array_equal
def np_array_equal(a1, a2, equal_nan=False):
    try:
        a1, a2 = np.asarray(a1, dtype=object), np.asarray(a2, dtype=object)
    except Exception:
        return False
    if a1.shape != a2.shape:
        return False
    if not equal_nan:
        return bool(np.asarray(a1 == a2).all())

    isnan = lambda x: type(x) is float and np.isnan(x)
    np_isnan = np.frompyfunc(isnan, 1, 1)
    a1nan, a2nan = np_isnan(a1), np_isnan(a2)
    if not (a1nan == a2nan).all():
      return False
    return bool((a1[~a1nan.astype(bool)] == a2[~a1nan.astype(bool)]).all())


l1 = [1, 2, float('nan'), '3']
l2 = [1, '2', float('nan'), 3]
l3 = [1, 2, float('nan'), '3']

print(np_array_equal(l1, l2))                  # False
print(np_array_equal(l1, l2, equal_nan=True))  # False
print(np_array_equal(l1, l3))                  # False
print(np_array_equal(l1, l3, equal_nan=True))  # True

リスト等価性判定まとめ

結局どれで判定すればよいのか。性能面も比較すると以下となる。

判定方法特徴性能(※)
等価演算子「==」・一番高速
・nanが判定できない
・数値以外もOK
0.825 秒
numpy.array_equal・オプション指定でnanが判定可能
・数値以外が要素にあると予期せぬ挙動
8.733 秒
numpy.allclose・オプション指定でnanが判定可能
・数値以外が要素にあると予期せぬ挙動
・浮動小数点誤差を考慮可能
8.859 秒
自作 list_equal
(for文回して頑張るやつ)
・nanが判定できる
・数値以外もOK
・遅い
11.721 秒
自作 np_array_equal
(array_equalを改造したやつ)
・nanが判定できる
・数値以外もOK
・とても遅い
26.585 秒
※性能は、5千万要素のリスト同士で等価性判定したときの結果。
実行環境は、Google Colabの無料版の環境。

というわけで、節の冒頭で述べた結論となる。

numpyの配列

numpyの配列(ndarray)の場合には、リストの判定の節で記載した、numpy.array_equalnumpy.allclose で判定する。(リストの節で、これらの関数の特徴をある程度記載しているのでここでは省略。)

ちなみに、等価演算子だと、各要素の等価性判定が ndarrayで出力される。

a = np.array([ [1, 2, 3], [4, 5, 6] ])
b = np.array([ [1, 2, 3], [4, 5, 7] ])
print(a == b)  # [[ True  True  True] [ True  True False]]

辞書(要素がプリミティブ値のみの場合)

リストの場合と同様に、含まれる要素のパターンによって判定方法が異なってくる。
基本は、等価演算子「==」で良いが、nan判定が必要な場合とか、浮動小数点誤差の考慮が必要な場合は、いったんリストに変換して、リストの等価性判定で利用した関数を利用すれば良い。

等価演算子で判定

a = dict(one=1, two=2, three=3)
b = dict(one=1, two=2, three=3)
c = dict(three=3, two=2, one=1)

print(a == b)  # True
print(a == c)  # True(キー順序が違って定義しても等価と判定してくれる)

nanが含まれる場合や、浮動小数点誤差が発生する場合は、判定がFalseになる

a = dict(one=1, two=2, nan=float('nan'))
b = dict(one=1, two=2, nan=float('nan'))
print(a == b)  # False

a = dict(one=1, two=2, three=0.3)
b = dict(one=1, two=2, three=0.1+0.1+0.1)
print(a == b)  # False

nanが含まれる場合

いったんリストに変換してソートし、あとはリストの等価性判定の応用。

a = dict(one=1, two=2, nan=float('nan'))
b = dict(one=1, two=2, nan=float('nan'))

sort_a = sorted(a.items())
key_a, val_a = [x[0] for x in sort_a], [x[1] for x in sort_a]
sort_b = sorted(b.items())
key_b, val_b = [x[0] for x in sort_b], [x[1] for x in sort_b]

print(key_a == key_b and np.array_equal(val_a, val_b, equal_nan=True)) # True

他にも、浮動小数点誤差を考慮する場合などもあるが、上記のようにリストの等価性判定の応用でいけるので、ここでは省略。

一般的なリスト、辞書

リストの要素が辞書だったり、辞書の値がリストだったりする場合でも、等価演算子「==」で再帰的に判定してくれる。

a = dict(one=[1], two=[1,2], three=[1, 2, 3])
b = dict(three=[1, 2, 3], two=[1,2], one=[1])
print(a == b) # True

a = [dict(one=1), dict(one=1, two=2), dict(one=1, two=2, three=3)]
b = [dict(one=1), dict(one=1, two=2), dict(one=1, two=2, three=3)]
print(a == b) # True

nanが含まれる場合や、浮動小数点誤差を考慮する場合は、基本的にはリストの場合の応用でできると思うので、ここでは省略。

DataFrame

pandasのDataFrameの場合、等価演算子「==」だと、各要素の比較を行う。

df_a = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6], c=[7, float('nan'), 9]))
df_b = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6], c=[7, float('nan'), 9]))
df_a == df_b

上記のように、nanの場合はFalseになる。

DataFrame全体で判定を行いたい場合は、equalsメソッドを利用する。
このメソッドの場合には、nanが含まれていても、等価性判定が正常に行われる。

df_a = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6], c=[7, float('nan'), 9]))
df_b = pd.DataFrame(dict(a=[1, 2, 3], b=[4, 5, 6], c=[7, float('nan'), 9]))
df_a.equals(df_b) # True

浮動小数点誤差の考慮が必要な場合は、equalsメソッドではFalseになるので、DataFrameのvaluesをnumpy.allclose で判定する。

df_a = pd.DataFrame(dict(a=[0.1, 0.2], b=[0.3, 0.4]))
df_b = pd.DataFrame(dict(a=[0.1, 0.2], b=[0.1+0.1+0.1, 0.4]))

df_a.equals(df_b) # False
np.allclose(df_a.values, df_b.values) \
and list(df_a.index) == list(df_b.index) \
and list(df_a.columns) == list(df_b.columns) # True

以上です。