Pythonのdataclassesでhashableなインスタンスを作るクラスを定義する
Pythonでコレクションを扱っていると、ハッシュ可能なオブジェクトの使用を求められることがままある。
例えば、辞書のキーはハッシュ可能なオブジェクトでないといけないし、集合の要素もハッシュ可能でなければならない。 これはコレクション内部のアルゴリズムにおいて、オブジェクトのハッシュ値を使っているからだろう。
時々、自分で定義したクラスのインスタンスを辞書や集合の中で使いたくなるときがある。そういうとき、そのクラスのインスタンスはハッシュ可能である必要がある。
そこで、ハッシュ可能なインスタンスを生成するクラスは、どのように定義するのが適切なのかについて考えたい。
組み込みのclass
Pythonの組み込みのclass
でクラスを定義すると、そのインスタンスはハッシュ可能である。
つまり、__hash__
メソッドを持っていて、ハッシュ値を計算できる。
class Company: def __init__(self, cid, name): self.cid = cid self.name = name toyota = Company(1, "Toyota") hash(toyota) # -9223371913295946368
じゃあ、組み込みのclass
使っておけばOKなのだろうか。 そういうわけにもいかないと思われる。
なぜなら、Pythonの組み込みのclass
のミュータブルな性質が、「プログラム中においてハッシュ値はイミュータブルでなければならない」という要件と相容れないからだ。
公式ドキュメントを読むと、辞書や集合の実装が、ハッシュ可能なオブジェクトのハッシュ値がイミュータブルであることを要求していると書いてある。
しかし、Pythonの組み込みのclass
のインスタンスのフィールドの値は簡単に書き換えることができ、そのインスタンスのハッシュ値がイミュータブルであることは全然保証されていない。
要約すると、ハッシュ可能なオブジェクトはイミュータブルでなければならないが、Pythonの組み込みのclass
のインスタンスはハッシュ可能であるにもかかわらず、ミュータブルになりがちである。
そういうわけで、ユーザー定義クラスをハッシュ可能にしたいならば、組み込みのclass
を使わない方がいいような気がする。
dataclasses
次にdataclasses
モジュールを使う方法がある。
これはバージョン3.7で新たにリリースされた機能らしく、使えるようになったのは割と最近である。
このモジュールを使って新たにクラスを定義するときには、いくつかのパラメータを指定することができる。
そのパラメータの一つに、frozen
がある。
デフォルトではfrozen
はFalse
だが、これをTrue
にしてやると、イミュータブルなインスタンスを生成するクラスを定義することができる。
@dataclass(frozen=True) class Company: company_id: int name: str
そして、frozenなクラスから生成されるインスタンスは、ハッシュ可能である。
frozen=True
にすると、適切な__hash__
メソッドを自動で生成してくれている。
実際にハッシュ可能であるかを確かめてみよう。
@dataclass(frozen=True) class Company: company_id: int name: str toyota = Company(1, "Toyota") hash(toyota) #-124679679567354462 set([toyota]) # 集合の要素にできる {toyota: 1} # 辞書のキーにできる
確かに、Companyのインスタンスはハッシュ可能である。
試しに、上記のfrozen
パラメータをFalse
にしてみると、「toyota
はハッシュ不可能なオブジェクトだ」というエラーが発生する。
つまり、dataclass
を使って普通に定義したクラスのインスタンスは、ミュータブルであるがハッシュ可能ではない。
これは、dataclass
を使ってクラスを定義しておけば、ミュータブルでハッシュ可能なオブジェクトを作らずに済むということだろう。(変なことをしない限り)
とはいえ、先ほどのfrozen
なクラスのインスタンスも完全にイミュータブルであるわけではない。
実はフィールドに辞書がセットされている場合、その辞書の一部の値を書き換えることはできる。
@dataclass(frozen=True) class Company: company_id: int name: str child: dict toyota = Company(1, "Toyota", {"company_id": 5, "name": "SUBALU"}) toyota.child["name"] = "SUBARU" print(toyota) # Company(company_id=1, name='Toyota', child={'company_id': 5, 'name': 'SUBARU'})
frozenなクラスも「浅く」イミュータブルなだけであるようだ。
しかし、この場合でも、ちゃんとインスタンス自体はハッシュ不可能にしてくれている。
dataclasses
は最近追加された機能だが、構文的にも他の言語でいうところのクラスに近くて良いと思う。
組み込みのクラスを使う場合と比較して、オーバーヘッドがどうなのかというところはちょっと気になるところではあるが、特にパフォーマンスが問題にならない場面では積極的に使わない手はないと思っている。
Pythonの集合同士の差分を取るときに陥った落とし穴を記録しておく
WEBから取ってきたデータと既にDBに入っているデータの差分を取りたいと思うことがある。
辞書
最初は、辞書でデータを扱っていたとしよう。
companies_in_web = [{"cid": 0, "name": "Toyota"}, {"cid": 1, "name": "Nissan"}, {"cid": 2, "name": "Honda"}] companies_in_db = [{"cid": 0, "name": "Toyota"}, {"cid": 1, "name": "Nissan"}, {"cid": 4, "name": "Mazda"}]
そして、下のような感じで辞書が入ったリストを集合に変換して差分を取りたい。
取れるんじゃないかと思っていた。
diff = set(companies_in_web) - set(companies_in_db)
しかし、残念ながら、これはTypeError
を吐く。
Pythonのset型の要素はハッシュ可能である必要があり、辞書はハッシュ可能ではないからだ。
https://docs.python.org/ja/3/library/stdtypes.html#set
# TypeError: unhashable type: 'dict'
ユーザー定義クラス
それでは、辞書ではなく、こういうデータ構造ならどうだろうか。
class Company: def __init__(self, cid, name): self.cid = cid self.name = name def __repr__(self): return f"Company: cid={self.cid}, name={self.name}" companies_in_web = [Company(0, "Toyota"), Company(1, "Nissan"), Company(2, "Honda")] companies_in_db = [Company(0, "Toyota"), Company(1, "Nissan"), Company(3, "Mazda")]
それでは、差分を取ってみよう。
diff = set(companies_in_web) - set(companies_in_db)
# {Company: cid=0, name=Toyota, Company: cid=1, name=Nissan, Company: cid=2, name=Honda}
残念ながら、これは期待した結果ではない。
Company(2, "Honda")
だけが差分抽出されてほしいのに、Toyota
もNissan
も抽出されている。
ユーザー定義クラス・改
それでは、先ほどのクラスに__eq__
と__hash__
関数を実装してみよう。
(追記:ミュータブルなオブジェクトに__hash__
を実装するのは適切なのか...)
class Company: def __init__(self, cid, name): self.cid = cid self.name = name def __repr__(self): return f"Company: cid={self.cid}, name={self.name}" def __eq__(self, other): if isinstance(other, Company): return self.cid == other.cid and self.name == other.name def __hash__(self): return hash((self.cid, self.name)) companies_in_web = [Company(0, "Toyota"), Company(1, "Nissan"), Company(2, "Honda")] companies_in_db = [Company(0, "Toyota"), Company(1, "Nissan"), Company(3, "Mazda")]
差分をとってみる。
diff = set(companies_in_web) - set(companies_in_db)
# {Company: cid=2, name=Honda}
ようやく期待通りのものが抽出された。
集合同士の差分を取る際には、当然要素同士の比較が行われているはずだが、
どうやら__eq__
と__hash__
が実装されているか否かによってその比較の挙動が違うようである。
今日はこのような確認だけしておき、次回の記事でオブジェクトの同一性・同値性の詳細について書いていきたい。
Shadow Tacticsという死にゲー
一体、何回死んだだろうか。
バレないと思って脇を横切ろうとしたところを発見され、なぶり殺しにされる。
見られてないと高をくくってやった殺しを目撃され、逆襲を受ける。
気づきもしなかった町人にチクられて、思わぬリンチを食らう。
数えきれないほど赤に染まった画面を見た。
とはいえ、このゲームは、そこまで「死」を引きずらない。
それは、見た目に反して、このゲームがパズルゲームに近いからなのだと思う。
このゲームを進めるには、目の前の敵を殺すか、タイミングよく脇をすり抜けるか、どちらかである。
多くの場面で「殺す」ことを選択しなければならないが、目の前の殺すべき敵は必ず別の敵が見ている。
目の前の敵を殺すには、まず彼を視界にとらえている別の敵を殺す必要がある。
その別の敵を殺すには、彼を視界にとらえているまた別の敵を...
このゲームのパズルとは、このような認知ゾーンの依存関係を解きほぐしていくパズルである。
ただ、本当にこのゲームがドライなパズルでしかなかったとしたら、すぐに飽きてしまっただろう。
なかなか飽きが来ないのは、各ダンジョンがリアルに江戸時代を再現しているからだと思う。
雪で白んだ屋敷町、雨の陰鬱な農村、紅葉鮮やかな城下町、緑豊かな山中の寺。
四季折々のダンジョンがプレイヤーを待ち受ける。
ダンジョン内のオブジェクトは一つ一つがきめ細やかで、非常に再現性が高い。
思わず見入ってしまう時もある。
このような画像だけを見て始めたりすると、最初はこのゲームの本質がパズルにあるとは思わない。
ましてや、ステルスなどという紹介でMGSを頭に思い浮かべながら、このゲームを始めた自分みたいなプレイヤーは。
しかし、最初のダンジョン、自軍支援のための城下町潜入プロジェクトですぐに気づくことになるだろう。 初っ端から何度も何度も死ぬ羽目になるのだから。そして、ゲームを進めるごとにそのパズルの難易度は上がっていく。
本当にこれクリアできるのかと思う場面が何度もあり、諦めてしまいそうになるが、寝て起きると、解決策が閃いてクリアできたりする。そういう試行錯誤を繰り返しつつ、オレは今日も江戸時代パズルに立ち向かっていく...
Rの自作パッケージ内で依存パッケージを読み込む方法(Rで顔文字演算子)
今回はRで顔文字演算子を作ることができる話...ではなく、Rの自作パッケージ内でどうやってパッケージを読み込むかについての話である。
Rでパッケージを自作していると、自作パッケージ内で依存パッケージをどのように読み込むかという問題に当たる。
例えば、自作しているパッケージでは、ガツガツmagrittr
やdplyr
を使っている。
これらは、Rのbase
のデータフレーム操作の関数よりもコードをすっきり書けるのでよい。
特にmagrittr
のパイプが素晴らしい。
library(dplyr) library(magrittr) iris %>% filter(Species != "setosa") %>% select(-starts_with("Sepal")) %>% mutate(petal_area = Petal.Length * Petal.Width * 0.5) %>% group_by(Species) %>% summarise_all(funs(mean))
しかし、ここで問題なのは、
library(dplyr) library(magrittr)
である。
普通に使い捨てのスクリプトを書いているときは、これで問題ない。
しかし、自作のパッケージの中で、library(x)
を使ってパッケージを読み込むのはNGであるようだ。
http://r-pkgs.had.co.nz/namespace.html
そもそも、library(x)
は何をしているのか。これはパッケージを「アタッチ」している。アタッチとは、サーチパスにパッケージの名前を入れることであり、Rの何らかの関数が呼び出されるとき、まずはこのサーチパスの中に入っている名前空間から関数が検索される。
そして、サーチパスに入っているパッケージについては、わざわざその名前空間を指定しなくても、関数を呼び出すことができる。
library(dplyr) filter(iris, Species != "setosa")
たしかに、これによってタイピング量を減らすことはできるのだが、パッケージの名前空間を丸ごとサーチパスに入れるため、場合によっては、予期せぬ関数名の衝突を起こす可能性がある。そのため、パッケージレベルではlibrary(x)
の使用はNGとなっている。(個人とか会社のチーム内で使うとか、そういうレベルのパッケージならいいかもしれないが...)
そのため、パッケージ内では、基本的に
dplyr::filter()
のように名前空間を指定して、関数を呼び出すのが基本となる。
そうすると、困ることが一点だけある。
あの美しいパイプラインを実現してくれていた、magrittr
のパイプ演算子(%>%
)である。
実はこのパイプ演算子、名前空間を指定して使おうとすると、二項演算子ではなく関数として認識されてしまう。
(まあ、名前空間つきで使えたとしても、そんな冗長な二項演算子使う人もいないだろうが...)
magrittr::`%>%`(ihs, rhs)
それでは、パッケージ内では%>%
を演算子として使えないということになってしまう。
ところがどっこい、Rには裏技があったようだ。 「スペシャル演算子定義」である。
https://cran.r-project.org/doc/manuals/r-release/R-lang.html#Special-operators
Rでは、下のように%
で囲むと、お手軽に新しい演算子を定義することができる。
そして、%
の中にはどんな文字列を入れてやってもよい。まさにやりたい放題できるというわけだ。
ちなみに、これは既存の演算子のオーバーロードという生易しいやつではない。
(言語仕様的に適切なのだろうか...)
例えば、+
のエイリアスを作ることできる。
`%+%` <- function(a, b) { return(a + b) } 1 %+% 2 # 3
同じようにすれば、magrittr
の%>%
問題も万事解決する。
# library(magrittr) # 演算子エイリアス `%>%` <- magrittr::`%>%` # 使える iris %>% dplyr::filter(Species != "setosa")
これを使えば、Rでは顔文字演算子なども簡単に作れたりしてしまう。
`%アワワヽ(´Д`;≡;´Д`)ノアワワ%` <- function(a, b){ print("詰み") } "バグ" %アワワヽ(´Д`;≡;´Д`)ノアワワ% "障害" # 詰み
カリー化した関数に型ヒントを与える
Pythonの型ヒントに慣れるために、カリー化した関数に型ヒントを与えるというタスクをやってみる。
ちなみに、カリー化とは、複数の引数をとる関数を、一つの引数だけをとる関数のチェーンに変換する操作のことである。
以下、簡単なカリー化の例(Python)。
# 引数2つの場合 add_2num = lambda x, y: x + y cadd_2num = lambda x: lambda y: x + y # カリー化後 # 引数3つの場合 add_3num = lambda x, y, z: x + y + z cadd_3num = lambda x: lambda y: lambda z: x + y + z # カリー化後
カリー化した関数は、下記のように使用できる。
add_2num(1,2) # 3 cadd_2num(1)(2) # 3 add_3num(1,2,3) # 6 cadd_3num(1)(2)(3) # 6 increment = cadd_2num(1) # 部分適用 increment(2) # 3
さらに進んで、カリー化した関数に型ヒントを与えてみる。
from typing import Callable # 引数2つの場合 add_2num: Callable[[float, float], float] = lambda x, y: x + y cadd_2num: Callable[[float], Callable[[float], float]] = lambda x: lambda y: x + y # 引数3つの場合 add_3num: Callable[[float, float, float], float] = lambda x, y, z: x + y + z cadd_3num: Callable[[float], Callable[[float], Callable[[float], float]]] = lambda x: lambda y: lambda z: x + y + z cadd_2num(1)(2) # 3 cadd_2num(1)('2') # TypeError: intとstringを+演算しようとしてエラー cadd_2num('1')('2') # 12 エラーは発生しない
mypyで型チェックをしてみる。
mypy main.py # curried.py:12: error: Argument 1 has incompatible type "str"; expected "float" # curried.py:13: error: Argument 1 has incompatible type "str"; expected "float"
cadd_2num('1')('2')
の2番目の引数の不正は指摘してくれないのは、何故なのか...
参考
https://ja.wikipedia.org/wiki/%E3%82%AB%E3%83%AA%E3%83%BC%E5%8C%96 https://www.ibm.com/developerworks/jp/java/library/j-jn9/index.html
Pythonでの雑でガバっとした例外処理
定期的に動かすので、ある程度きちんとプログラムを書かないといけないけど、そんなに丁寧に例外処理を書いている余裕はないぞというケースがある。
そういう時、実行スクリプトでは、このように雑に例外処理を行うことが多い。 こうしておけば、とりあえずどんな例外もログに吐き出すことができ、未知の例外に遭遇した時はエラーメッセージを見ながら、適宜コードを修正することができる。
import traceback try: a() b() c() d() e() except: msg = traceback.format_exc() mail_logger.error(msg)
ところでDB周りの業務をやっていると、SQLAlchemyを使って、下のようにsession_scope
関数を定義し、色んなところで使いまわす。
from contextlib import contextmanager @contextmanager def session_scope(): session = Session() try: yield session session.commit() except: session.rollback() finally: session.close() # session_scopeはこのように使う def insert(something): with session_scope() as session: session.add(something)
# メインの処理はこのような感じ try: something = Something() insert(something) except: msg = traceback.format_exc() mail_logger.error(msg)
しかし、このsession_scope
を上記のようなtry句内で使うと、DB操作で例外が発生した時に、session_scope
内のexcept
、finnaly
でプログラムが終了し、肝心の最上位のexcept
句に処理が到達しない。
こうなると、ログが吐き出されないので、困ったことになる。
こういうときは、session_scope
内のexcept
句の中で例外を再スローしてやるとよい。
そうすると、きちんと処理が最上位のexcept
にまで到達し、DB操作時に例外が発生した時の状況がログに書き込まれる。
from exceptions import DBHandleException @contextmanager def session_scope(): session = Session() try: yield session session.commit() except: session.rollback() # 例外を再スローする raise DBHandleException finally: session.close()
Pythonのloggingで独自ハンドラを作りたいとき
Pythonのloggingを利用していると、独自のハンドラを作りたくなるときがある。
そういう時には、すでにloggingに定義されているハンドラのコードを参考にすることができる。
例えば、logging.handlers.SMTPHandler
は下記のように定義されている。
https://github.com/python/cpython/blob/master/Lib/logging/handlers.py
# ソースコード内のコメントは削除 class SMTPHandler(logging.Handler): def __init__(self, mailhost, fromaddr, toaddrs, subject, credentials=None, secure=None, timeout=5.0): logging.Handler.__init__(self) if isinstance(mailhost, (list, tuple)): self.mailhost, self.mailport = mailhost else: self.mailhost, self.mailport = mailhost, None if isinstance(credentials, (list, tuple)): self.username, self.password = credentials else: self.username = None self.fromaddr = fromaddr if isinstance(toaddrs, str): toaddrs = [toaddrs] self.toaddrs = toaddrs self.subject = subject self.secure = secure self.timeout = timeout def getSubject(self, record): return self.subject def emit(self, record): try: import smtplib from email.message import EmailMessage import email.utils port = self.mailport if not port: port = smtplib.SMTP_PORT smtp = smtplib.SMTP(self.mailhost, port, timeout=self.timeout) msg = EmailMessage() msg['From'] = self.fromaddr msg['To'] = ','.join(self.toaddrs) msg['Subject'] = self.getSubject(record) msg['Date'] = email.utils.localtime() msg.set_content(self.format(record)) if self.username: if self.secure is not None: smtp.ehlo() smtp.starttls(*self.secure) smtp.ehlo() smtp.login(self.username, self.password) smtp.send_message(msg) smtp.quit() except Exception: self.handleError(record)
これを参考にすると、独自ハンドラを作るには、logging.Handler
クラスを継承して、必要に応じて、親クラスのメソッドをオーバーライド、もしくはメソッド追加を行えばよいということが分かる。
その際、emit
メソッドだけは、子クラスでオーバーライドして実装してあげる必要がある。
もし、これを実装しなければ、必ずNotImplementedError
が発生するようになっている。
emit
は、ログの書き込み処理や送信処理を行うメソッドであるので、実装が必要なのはまあ当然である。
ちなみに、親クラスとなるlogging.Handler
は、下記のようになっている。
https://github.com/python/cpython/blob/master/Lib/logging/init.py
# ソースコード内のコメントは削除 class Handler(Filterer): def __init__(self, level=NOTSET): Filterer.__init__(self) self._name = None self.level = _checkLevel(level) self.formatter = None _addHandlerRef(self) self.createLock() def get_name(self): return self._name def set_name(self, name): _acquireLock() try: if self._name in _handlers: del _handlers[self._name] self._name = name if name: _handlers[name] = self finally: _releaseLock() name = property(get_name, set_name) def createLock(self): self.lock = threading.RLock() _register_at_fork_reinit_lock(self) def acquire(self): if self.lock: self.lock.acquire() def release(self): if self.lock: self.lock.release() def setLevel(self, level): self.level = _checkLevel(level) def format(self, record): if self.formatter: fmt = self.formatter else: fmt = _defaultFormatter return fmt.format(record) def emit(self, record): raise NotImplementedError('emit must be implemented ' 'by Handler subclasses') def handle(self, record): rv = self.filter(record) if rv: self.acquire() try: self.emit(record) finally: self.release() return rv def setFormatter(self, fmt): self.formatter = fmt def flush(self): pass def close(self): _acquireLock() try: if self._name and self._name in _handlers: del _handlers[self._name] finally: _releaseLock() def handleError(self, record): if raiseExceptions and sys.stderr: t, v, tb = sys.exc_info() try: sys.stderr.write('--- Logging error ---\n') traceback.print_exception(t, v, tb, None, sys.stderr) sys.stderr.write('Call stack:\n') frame = tb.tb_frame while (frame and os.path.dirname(frame.f_code.co_filename) == __path__[0]): frame = frame.f_back if frame: traceback.print_stack(frame, file=sys.stderr) else: sys.stderr.write('Logged from file %s, line %s\n' % ( record.filename, record.lineno)) try: sys.stderr.write('Message: %r\n' 'Arguments: %s\n' % (record.msg, record.args)) except RecursionError: raise except Exception: sys.stderr.write('Unable to print the message and arguments' ' - possible formatting error.\nUse the' ' traceback above to help find the error.\n' ) except OSError: pass finally: del t, v, tb def __repr__(self): level = getLevelName(self.level) return '<%s (%s)>' % (self.__class__.__name__, level)
自分の場合、DBのメール機能を使ってログを送りたかったので、DBMailHandler
クラスを作ったのである。