Pythonのdataclassesでhashableなインスタンスを作るクラスを定義する

Pythonでコレクションを扱っていると、ハッシュ可能なオブジェクトの使用を求められることがままある。

例えば、辞書のキーはハッシュ可能なオブジェクトでないといけないし、集合の要素もハッシュ可能でなければならない。 これはコレクション内部のアルゴリズムにおいて、オブジェクトのハッシュ値を使っているからだろう。

時々、自分で定義したクラスのインスタンスを辞書や集合の中で使いたくなるときがある。そういうとき、そのクラスのインスタンスはハッシュ可能である必要がある。

そこで、ハッシュ可能なインスタンスを生成するクラスは、どのように定義するのが適切なのかについて考えたい。  

組み込みのclass

Pythonの組み込みのclassでクラスを定義すると、そのインスタンスはハッシュ可能である。
つまり、__hash__メソッドを持っていて、ハッシュ値を計算できる。  

class Company:
    def __init__(self, cid, name):
        self.cid = cid
        self.name = name

toyota = Company(1, "Toyota")
hash(toyota) # -9223371913295946368

 
じゃあ、組み込みのclass使っておけばOKなのだろうか。 そういうわけにもいかないと思われる。

なぜなら、Pythonの組み込みのclassのミュータブルな性質が、「プログラム中においてハッシュ値はイミュータブルでなければならない」という要件と相容れないからだ。

公式ドキュメントを読むと、辞書や集合の実装が、ハッシュ可能なオブジェクトのハッシュ値がイミュータブルであることを要求していると書いてある。 しかし、Pythonの組み込みのclassインスタンスのフィールドの値は簡単に書き換えることができ、そのインスタンスハッシュ値がイミュータブルであることは全然保証されていない。

要約すると、ハッシュ可能なオブジェクトはイミュータブルでなければならないが、Pythonの組み込みのclassインスタンスはハッシュ可能であるにもかかわらず、ミュータブルになりがちである。 そういうわけで、ユーザー定義クラスをハッシュ可能にしたいならば、組み込みのclassを使わない方がいいような気がする。

dataclasses

次にdataclassesモジュールを使う方法がある。
これはバージョン3.7で新たにリリースされた機能らしく、使えるようになったのは割と最近である。

このモジュールを使って新たにクラスを定義するときには、いくつかのパラメータを指定することができる。 そのパラメータの一つに、frozenがある。 デフォルトではfrozenFalseだが、これをTrueにしてやると、イミュータブルなインスタンスを生成するクラスを定義することができる。

@dataclass(frozen=True)
class Company:
    company_id: int
    name: str

そして、frozenなクラスから生成されるインスタンスは、ハッシュ可能である。
frozen=Trueにすると、適切な__hash__メソッドを自動で生成してくれている。
実際にハッシュ可能であるかを確かめてみよう。

@dataclass(frozen=True)
class Company:
    company_id: int
    name: str

toyota = Company(1, "Toyota")

hash(toyota)  #-124679679567354462
set([toyota]) # 集合の要素にできる
{toyota: 1}    # 辞書のキーにできる

確かに、Companyのインスタンスはハッシュ可能である。

試しに、上記のfrozenパラメータをFalseにしてみると、「toyotaはハッシュ不可能なオブジェクトだ」というエラーが発生する。 つまり、dataclassを使って普通に定義したクラスのインスタンスは、ミュータブルであるがハッシュ可能ではない。

これは、dataclassを使ってクラスを定義しておけば、ミュータブルでハッシュ可能なオブジェクトを作らずに済むということだろう。(変なことをしない限り)

とはいえ、先ほどのfrozenなクラスのインスタンスも完全にイミュータブルであるわけではない。
実はフィールドに辞書がセットされている場合、その辞書の一部の値を書き換えることはできる。

@dataclass(frozen=True)
class Company:
    company_id: int
    name: str
    child: dict

toyota = Company(1, "Toyota", {"company_id": 5, "name": "SUBALU"})
toyota.child["name"] = "SUBARU"

print(toyota)
# Company(company_id=1, name='Toyota', child={'company_id': 5, 'name': 'SUBARU'})

frozenなクラスも「浅く」イミュータブルなだけであるようだ。
しかし、この場合でも、ちゃんとインスタンス自体はハッシュ不可能にしてくれている。

dataclassesは最近追加された機能だが、構文的にも他の言語でいうところのクラスに近くて良いと思う。 組み込みのクラスを使う場合と比較して、オーバーヘッドがどうなのかというところはちょっと気になるところではあるが、特にパフォーマンスが問題にならない場面では積極的に使わない手はないと思っている。

Pythonの集合同士の差分を取るときに陥った落とし穴を記録しておく

WEBから取ってきたデータと既にDBに入っているデータの差分を取りたいと思うことがある。

辞書

最初は、辞書でデータを扱っていたとしよう。

companies_in_web = [{"cid": 0, "name": "Toyota"}, {"cid": 1, "name": "Nissan"}, {"cid": 2, "name": "Honda"}]
companies_in_db =  [{"cid": 0, "name": "Toyota"}, {"cid": 1, "name": "Nissan"}, {"cid": 4, "name": "Mazda"}]

 
そして、下のような感じで辞書が入ったリストを集合に変換して差分を取りたい。
取れるんじゃないかと思っていた。  

diff = set(companies_in_web) - set(companies_in_db)

 
しかし、残念ながら、これはTypeErrorを吐く。
Pythonのset型の要素はハッシュ可能である必要があり、辞書はハッシュ可能ではないからだ。

https://docs.python.org/ja/3/library/stdtypes.html#set  

# TypeError: unhashable type: 'dict'

ユーザー定義クラス

それでは、辞書ではなく、こういうデータ構造ならどうだろうか。

class Company:

    def __init__(self, cid, name):
        self.cid = cid
        self.name = name

    def __repr__(self):
        return f"Company: cid={self.cid}, name={self.name}"

companies_in_web = [Company(0, "Toyota"), Company(1, "Nissan"), Company(2, "Honda")]
companies_in_db =  [Company(0, "Toyota"), Company(1, "Nissan"), Company(3, "Mazda")]

それでは、差分を取ってみよう。

diff = set(companies_in_web) - set(companies_in_db)
# {Company: cid=0, name=Toyota, Company: cid=1, name=Nissan, Company: cid=2, name=Honda}

 
残念ながら、これは期待した結果ではない。
Company(2, "Honda")だけが差分抽出されてほしいのに、ToyotaNissanも抽出されている。  

ユーザー定義クラス・改

それでは、先ほどのクラスに__eq____hash__関数を実装してみよう。
(追記:ミュータブルなオブジェクトに__hash__を実装するのは適切なのか...)

class Company:

    def __init__(self, cid, name):
        self.cid = cid
        self.name = name

    def __repr__(self):
        return f"Company: cid={self.cid}, name={self.name}"

    def __eq__(self, other):
        if isinstance(other, Company):
            return self.cid == other.cid and self.name == other.name

    def __hash__(self):
        return hash((self.cid, self.name))

companies_in_web = [Company(0, "Toyota"), Company(1, "Nissan"), Company(2, "Honda")]
companies_in_db =  [Company(0, "Toyota"), Company(1, "Nissan"), Company(3, "Mazda")]

差分をとってみる。

diff = set(companies_in_web) - set(companies_in_db)
# {Company: cid=2, name=Honda}

ようやく期待通りのものが抽出された。

集合同士の差分を取る際には、当然要素同士の比較が行われているはずだが、
どうやら__eq____hash__が実装されているか否かによってその比較の挙動が違うようである。

今日はこのような確認だけしておき、次回の記事でオブジェクトの同一性・同値性の詳細について書いていきたい。

Shadow Tacticsという死にゲー

一体、何回死んだだろうか。

バレないと思って脇を横切ろうとしたところを発見され、なぶり殺しにされる。
見られてないと高をくくってやった殺しを目撃され、逆襲を受ける。
気づきもしなかった町人にチクられて、思わぬリンチを食らう。

数えきれないほど赤に染まった画面を見た。

f:id:nowaai:20200220223452p:plain

とはいえ、このゲームは、そこまで「死」を引きずらない。
それは、見た目に反して、このゲームがパズルゲームに近いからなのだと思う。

このゲームを進めるには、目の前の敵を殺すか、タイミングよく脇をすり抜けるか、どちらかである。
多くの場面で「殺す」ことを選択しなければならないが、目の前の殺すべき敵は必ず別の敵が見ている。
目の前の敵を殺すには、まず彼を視界にとらえている別の敵を殺す必要がある。
その別の敵を殺すには、彼を視界にとらえているまた別の敵を...

このゲームのパズルとは、このような認知ゾーンの依存関係を解きほぐしていくパズルである。

f:id:nowaai:20200220224002p:plain

ただ、本当にこのゲームがドライなパズルでしかなかったとしたら、すぐに飽きてしまっただろう。
なかなか飽きが来ないのは、各ダンジョンがリアルに江戸時代を再現しているからだと思う。

雪で白んだ屋敷町、雨の陰鬱な農村、紅葉鮮やかな城下町、緑豊かな山中の寺。
四季折々のダンジョンがプレイヤーを待ち受ける。
ダンジョン内のオブジェクトは一つ一つがきめ細やかで、非常に再現性が高い。
思わず見入ってしまう時もある。

f:id:nowaai:20200220224048p:plain

このような画像だけを見て始めたりすると、最初はこのゲームの本質がパズルにあるとは思わない。
ましてや、ステルスなどという紹介でMGSを頭に思い浮かべながら、このゲームを始めた自分みたいなプレイヤーは。

しかし、最初のダンジョン、自軍支援のための城下町潜入プロジェクトですぐに気づくことになるだろう。 初っ端から何度も何度も死ぬ羽目になるのだから。そして、ゲームを進めるごとにそのパズルの難易度は上がっていく。

本当にこれクリアできるのかと思う場面が何度もあり、諦めてしまいそうになるが、寝て起きると、解決策が閃いてクリアできたりする。そういう試行錯誤を繰り返しつつ、オレは今日も江戸時代パズルに立ち向かっていく...

Rの自作パッケージ内で依存パッケージを読み込む方法(Rで顔文字演算子)

今回はRで顔文字演算子を作ることができる話...ではなく、Rの自作パッケージ内でどうやってパッケージを読み込むかについての話である。

Rでパッケージを自作していると、自作パッケージ内で依存パッケージをどのように読み込むかという問題に当たる。

例えば、自作しているパッケージでは、ガツガツmagrittrdplyrを使っている。
これらは、Rのbaseのデータフレーム操作の関数よりもコードをすっきり書けるのでよい。

特にmagrittrのパイプが素晴らしい。

library(dplyr)
library(magrittr)

iris %>%
  filter(Species != "setosa") %>%
  select(-starts_with("Sepal")) %>%
  mutate(petal_area = Petal.Length * Petal.Width * 0.5) %>%
  group_by(Species) %>%
  summarise_all(funs(mean))

しかし、ここで問題なのは、

library(dplyr)
library(magrittr)

である。

普通に使い捨てのスクリプトを書いているときは、これで問題ない。
しかし、自作のパッケージの中で、library(x)を使ってパッケージを読み込むのはNGであるようだ。

http://r-pkgs.had.co.nz/namespace.html

そもそも、library(x)は何をしているのか。これはパッケージを「アタッチ」している。アタッチとは、サーチパスにパッケージの名前を入れることであり、Rの何らかの関数が呼び出されるとき、まずはこのサーチパスの中に入っている名前空間から関数が検索される。 そして、サーチパスに入っているパッケージについては、わざわざその名前空間を指定しなくても、関数を呼び出すことができる。

library(dplyr)

filter(iris, Species != "setosa")

たしかに、これによってタイピング量を減らすことはできるのだが、パッケージの名前空間を丸ごとサーチパスに入れるため、場合によっては、予期せぬ関数名の衝突を起こす可能性がある。そのため、パッケージレベルではlibrary(x)の使用はNGとなっている。(個人とか会社のチーム内で使うとか、そういうレベルのパッケージならいいかもしれないが...)

そのため、パッケージ内では、基本的に

dplyr::filter()

のように名前空間を指定して、関数を呼び出すのが基本となる。

そうすると、困ることが一点だけある。
あの美しいパイプラインを実現してくれていた、magrittrのパイプ演算子%>%)である。
実はこのパイプ演算子名前空間を指定して使おうとすると、二項演算子ではなく関数として認識されてしまう。 (まあ、名前空間つきで使えたとしても、そんな冗長な二項演算子使う人もいないだろうが...)

magrittr::`%>%`(ihs, rhs)

それでは、パッケージ内では%>%演算子として使えないということになってしまう。
ところがどっこい、Rには裏技があったようだ。 「スペシャ演算子定義」である。

https://cran.r-project.org/doc/manuals/r-release/R-lang.html#Special-operators

Rでは、下のように%で囲むと、お手軽に新しい演算子を定義することができる。
そして、%の中にはどんな文字列を入れてやってもよい。まさにやりたい放題できるというわけだ。
ちなみに、これは既存の演算子オーバーロードという生易しいやつではない。
(言語仕様的に適切なのだろうか...)

例えば、+エイリアスを作ることできる。

`%+%` <- function(a, b) { return(a + b) }
1 %+% 2 # 3

同じようにすれば、magrittr%>%問題も万事解決する。

# library(magrittr)

# 演算子エイリアス
`%>%` <- magrittr::`%>%`

# 使える
iris %>% dplyr::filter(Species != "setosa")

これを使えば、Rでは顔文字演算子なども簡単に作れたりしてしまう。

`%アワワヽ(´Д`;≡;´Д`)ノアワワ%` <- function(a, b){ print("詰み") }

"バグ" %アワワヽ(´Д`;≡;´Д`)ノアワワ% "障害"

# 詰み

カリー化した関数に型ヒントを与える

Pythonの型ヒントに慣れるために、カリー化した関数に型ヒントを与えるというタスクをやってみる。

ちなみに、カリー化とは、複数の引数をとる関数を、一つの引数だけをとる関数のチェーンに変換する操作のことである。

以下、簡単なカリー化の例(Python)。

# 引数2つの場合
add_2num = lambda x, y: x + y 
cadd_2num = lambda x: lambda y: x + y # カリー化後

# 引数3つの場合
add_3num = lambda x, y, z: x + y + z
cadd_3num = lambda x: lambda y: lambda z: x + y + z # カリー化後

カリー化した関数は、下記のように使用できる。

add_2num(1,2) # 3
cadd_2num(1)(2) # 3

add_3num(1,2,3) # 6
cadd_3num(1)(2)(3) # 6

increment = cadd_2num(1) # 部分適用
increment(2) # 3

さらに進んで、カリー化した関数に型ヒントを与えてみる。

from typing import Callable

# 引数2つの場合
add_2num: Callable[[float, float], float] = lambda x, y: x + y
cadd_2num: Callable[[float], Callable[[float], float]] = lambda x: lambda y: x + y

# 引数3つの場合
add_3num: Callable[[float, float, float], float] = lambda x, y, z: x + y + z
cadd_3num: Callable[[float], Callable[[float], Callable[[float], float]]] = lambda x: lambda y: lambda z: x + y + z

cadd_2num(1)(2) # 3
cadd_2num(1)('2') # TypeError: intとstringを+演算しようとしてエラー
cadd_2num('1')('2') # 12 エラーは発生しない

mypyで型チェックをしてみる。

mypy main.py

# curried.py:12: error: Argument 1 has incompatible type "str"; expected "float"
# curried.py:13: error: Argument 1 has incompatible type "str"; expected "float"

cadd_2num('1')('2')の2番目の引数の不正は指摘してくれないのは、何故なのか...

参考

https://ja.wikipedia.org/wiki/%E3%82%AB%E3%83%AA%E3%83%BC%E5%8C%96 https://www.ibm.com/developerworks/jp/java/library/j-jn9/index.html

Pythonでの雑でガバっとした例外処理

定期的に動かすので、ある程度きちんとプログラムを書かないといけないけど、そんなに丁寧に例外処理を書いている余裕はないぞというケースがある。

そういう時、実行スクリプトでは、このように雑に例外処理を行うことが多い。 こうしておけば、とりあえずどんな例外もログに吐き出すことができ、未知の例外に遭遇した時はエラーメッセージを見ながら、適宜コードを修正することができる。

import traceback


try:
    a()
    b()
    c()
    d()
    e()
except:
    msg = traceback.format_exc()
    mail_logger.error(msg)

ところでDB周りの業務をやっていると、SQLAlchemyを使って、下のようにsession_scope関数を定義し、色んなところで使いまわす。

from contextlib import contextmanager


@contextmanager
def session_scope():
    session = Session()
    try:
        yield session
        session.commit()
    except:
        session.rollback()
    finally:
        session.close()

# session_scopeはこのように使う
def insert(something):
    with session_scope() as session:
        session.add(something)
# メインの処理はこのような感じ
try:
    something = Something()
    insert(something)
except:
    msg = traceback.format_exc()
    mail_logger.error(msg)

しかし、このsession_scopeを上記のようなtry句内で使うと、DB操作で例外が発生した時に、session_scope内のexceptfinnalyでプログラムが終了し、肝心の最上位のexcept句に処理が到達しない。 こうなると、ログが吐き出されないので、困ったことになる。

こういうときは、session_scope内のexcept句の中で例外を再スローしてやるとよい。 そうすると、きちんと処理が最上位のexceptにまで到達し、DB操作時に例外が発生した時の状況がログに書き込まれる。

from exceptions import DBHandleException


@contextmanager
def session_scope():
    session = Session()
    try:
        yield session
        session.commit()
    except:
        session.rollback()
        # 例外を再スローする
        raise DBHandleException
    finally:
        session.close()

Pythonのloggingで独自ハンドラを作りたいとき

Pythonのloggingを利用していると、独自のハンドラを作りたくなるときがある。
そういう時には、すでにloggingに定義されているハンドラのコードを参考にすることができる。

例えば、logging.handlers.SMTPHandlerは下記のように定義されている。

https://github.com/python/cpython/blob/master/Lib/logging/handlers.py

# ソースコード内のコメントは削除
class SMTPHandler(logging.Handler):

    def __init__(self, mailhost, fromaddr, toaddrs, subject,
                 credentials=None, secure=None, timeout=5.0):

        logging.Handler.__init__(self)
        if isinstance(mailhost, (list, tuple)):
            self.mailhost, self.mailport = mailhost
        else:
            self.mailhost, self.mailport = mailhost, None
        if isinstance(credentials, (list, tuple)):
            self.username, self.password = credentials
        else:
            self.username = None
        self.fromaddr = fromaddr
        if isinstance(toaddrs, str):
            toaddrs = [toaddrs]
        self.toaddrs = toaddrs
        self.subject = subject
        self.secure = secure
        self.timeout = timeout

    def getSubject(self, record):
        return self.subject

    def emit(self, record):
        try:
            import smtplib
            from email.message import EmailMessage
            import email.utils

            port = self.mailport
            if not port:
                port = smtplib.SMTP_PORT
            smtp = smtplib.SMTP(self.mailhost, port, timeout=self.timeout)
            msg = EmailMessage()
            msg['From'] = self.fromaddr
            msg['To'] = ','.join(self.toaddrs)
            msg['Subject'] = self.getSubject(record)
            msg['Date'] = email.utils.localtime()
            msg.set_content(self.format(record))
            if self.username:
                if self.secure is not None:
                    smtp.ehlo()
                    smtp.starttls(*self.secure)
                    smtp.ehlo()
                smtp.login(self.username, self.password)
            smtp.send_message(msg)
            smtp.quit()
        except Exception:
            self.handleError(record)

これを参考にすると、独自ハンドラを作るには、logging.Handlerクラスを継承して、必要に応じて、親クラスのメソッドをオーバーライド、もしくはメソッド追加を行えばよいということが分かる。

その際、emitメソッドだけは、子クラスでオーバーライドして実装してあげる必要がある。
もし、これを実装しなければ、必ずNotImplementedErrorが発生するようになっている。
emitは、ログの書き込み処理や送信処理を行うメソッドであるので、実装が必要なのはまあ当然である。

ちなみに、親クラスとなるlogging.Handlerは、下記のようになっている。

https://github.com/python/cpython/blob/master/Lib/logging/init.py

# ソースコード内のコメントは削除
class Handler(Filterer):

    def __init__(self, level=NOTSET):
        Filterer.__init__(self)
        self._name = None
        self.level = _checkLevel(level)
        self.formatter = None
        _addHandlerRef(self)
        self.createLock()

    def get_name(self):
        return self._name

    def set_name(self, name):
        _acquireLock()
        try:
            if self._name in _handlers:
                del _handlers[self._name]
            self._name = name
            if name:
                _handlers[name] = self
        finally:
            _releaseLock()

    name = property(get_name, set_name)

    def createLock(self):
        self.lock = threading.RLock()
        _register_at_fork_reinit_lock(self)

    def acquire(self):
        if self.lock:
            self.lock.acquire()

    def release(self):
        if self.lock:
            self.lock.release()

    def setLevel(self, level):
        self.level = _checkLevel(level)

    def format(self, record):
        if self.formatter:
            fmt = self.formatter
        else:
            fmt = _defaultFormatter
        return fmt.format(record)

    def emit(self, record):
        raise NotImplementedError('emit must be implemented '
                                  'by Handler subclasses')

    def handle(self, record):
        rv = self.filter(record)
        if rv:
            self.acquire()
            try:
                self.emit(record)
            finally:
                self.release()
        return rv

    def setFormatter(self, fmt):
        self.formatter = fmt

    def flush(self):
        pass

    def close(self):
        _acquireLock()
        try:
            if self._name and self._name in _handlers:
                del _handlers[self._name]
        finally:
            _releaseLock()

    def handleError(self, record):
        if raiseExceptions and sys.stderr:
            t, v, tb = sys.exc_info()
            try:
                sys.stderr.write('--- Logging error ---\n')
                traceback.print_exception(t, v, tb, None, sys.stderr)
                sys.stderr.write('Call stack:\n')
                frame = tb.tb_frame
                while (frame and os.path.dirname(frame.f_code.co_filename) ==
                       __path__[0]):
                    frame = frame.f_back
                if frame:
                    traceback.print_stack(frame, file=sys.stderr)
                else:
                    sys.stderr.write('Logged from file %s, line %s\n' % (
                                     record.filename, record.lineno))
                try:
                    sys.stderr.write('Message: %r\n'
                                     'Arguments: %s\n' % (record.msg,
                                                          record.args))
                except RecursionError:
                    raise
                except Exception:
                    sys.stderr.write('Unable to print the message and arguments'
                                     ' - possible formatting error.\nUse the'
                                     ' traceback above to help find the error.\n'
                                    )
            except OSError:
                pass
            finally:
                del t, v, tb

    def __repr__(self):
        level = getLevelName(self.level)
        return '<%s (%s)>' % (self.__class__.__name__, level)

自分の場合、DBのメール機能を使ってログを送りたかったので、DBMailHandlerクラスを作ったのである。