本記事では、Pythonのデコレータ(関数の上の行に付与する@から始まる文言)の使用方法について、関数の時間計測やTry構文を例に説明します。
1. 時間計測
デコレータは@hoge
のように宣言された下の行から始まる関数を引数に取り、その関数を用いた処理をする関数を返します。
これにより、引数の関数の実行前後に処理を追加することができます。
デコレータの最も単純な例として、以下の時間計測が挙げられます。
import time # デコレータの本体 def timeit(func): def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) print(f'Elapsed time for "{func.__name__}" was {time.time() - start:.2f} sec.') return result return wrapper @timeit def sum_n(n): s = 0 for i in range(n): for j in range(n): for k in range(n): s += i return s # 実行 sum_n(500)
Elapsed time for "sum_n" was 2.51 sec.
31187500000
2. functoolsのwrapsを使用
前節の例では、こちらで説明されているような問題が発生します。
その問題を回避するために、以下のようにしてfunctools
のwraps
を使用することが推奨されています。
from functools import wraps # この行を追加 import time def timeit(func): @wraps(func) # この行を追加 def wrapper(*args, **kwargs): start = time.time() result = func(*args,**kwargs) print(f'Elapsed time for "{func.__name__}" was {time.time() - start:.2f} sec.') return result return wrapper @timeit def sum_n(n): s = 0 for i in range(n): for j in range(n): for k in range(n): s += i return s sum_n(500)
Elapsed time for "sum_n" was 2.54 sec.
31187500000
3. 時間計測(自動的に単位を変換)
コピペ用の時間計測のデコレータを下記に掲載しておきます。
測定時間の単位が自動的に判定されて出力されるようにしてあります。
from functools import wraps import time import datetime def get_d_h_m_s_us(sec): td = datetime.timedelta(seconds=sec) m, s = divmod(td.seconds, 60) h, m = divmod(m, 60) return td.days, h, m, s, td.microseconds def timeit(func): @wraps(func) def wrapper(*args, **kwargs): start = time.time() result = func(*args,**kwargs) t = time.time() - start d, h, m, s, us = get_d_h_m_s_us(t) if d != 0: print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.') elif h != 0: print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.') elif m != 0: print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.') else: print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.') return result return wrapper @timeit def sum_n(n): s = 0 for i in range(n): for j in range(n): for k in range(n): s += i return s sum_n(300)
Elapsed time for "sum_n" was 0.498 sec.
4036500000
4. 複数のデコレータを使う: 単純に重ねる
以下のように、デコレータは単純に重ねて使用できます。
# 2つ目のデコレータ: 標準出力の前に水平線を出力して見やすくする def separator(func): @wraps(func) def wrapper(*args, **kwargs): sep = '-' * 100 print(sep) res = func(*args, **kwargs) return res return wrapper # 時間計測のデコレータ: 前節と同じ from functools import wraps import time import datetime def get_d_h_m_s_us(sec): td = datetime.timedelta(seconds=sec) m, s = divmod(td.seconds, 60) h, m = divmod(m, 60) return td.days, h, m, s, td.microseconds def timeit(func): @wraps(func) def wrapper(*args, **kwargs): start = time.time() result = func(*args,**kwargs) t = time.time() - start d, h, m, s, us = get_d_h_m_s_us(t) if d != 0: print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.') elif h != 0: print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.') elif m != 0: print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.') else: print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.') return result return wrapper # デコレータを重ねて使用 @separator @timeit def sum_n(n): s = 0 for i in range(n): for j in range(n): for k in range(n): s += i return s sum_n(300)
----------------------------------------------------------------------------------------------------
Elapsed time for "sum_n" was 0.493 sec.
4036500000
ただし、順序を間違えると狙った動作にならないので、注意が必要です。
5. 複数のデコレータを使う: ネストさせる
前節のように単純にデコレータを重ねても問題ありませんが、必ずセットで利用する場合は以下のように他方をもう一方の中で定義するとコードの見通しがよくなります。
def separator(func): @wraps(func) def wrapper(*args, **kwargs): sep = '-' * 100 print(sep) res = func(*args, **kwargs) return res return wrapper from functools import wraps import time import datetime def get_d_h_m_s_us(sec): td = datetime.timedelta(seconds=sec) m, s = divmod(td.seconds, 60) h, m = divmod(m, 60) return td.days, h, m, s, td.microseconds def timeit(func): @separator # ここに他方のデコレータを挟む @wraps(func) def wrapper(*args, **kwargs): start = time.time() result = func(*args,**kwargs) t = time.time() - start d, h, m, s, us = get_d_h_m_s_us(t) if d != 0: print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.') elif h != 0: print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.') elif m != 0: print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.') else: print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.') return result return wrapper @timeit def sum_n(n): s = 0 for i in range(n): for j in range(n): for k in range(n): s += i return s sum_n(300)
----------------------------------------------------------------------------------------------------
Elapsed time for "sum_n" was 0.493 sec.
4036500000
6. デコレータに引数を渡す
以下のようにすれば、デコレータに引数を渡すことも可能です。
# 引数ありバージョンのデコレータの定義 def separator(level): def _separator(func): @wraps(func) def wrapper(*args, **kwargs): sep1 = '=' * 100 sep2 = '-' * 100 if level == 1: print(sep1) res = func(*args, **kwargs) elif level == 2: print(sep2) res = func(*args, **kwargs) else: print(f'Invalid level.') res = func(*args, **kwargs) return res return wrapper return _separator from functools import wraps import time import datetime def get_d_h_m_s_us(sec): td = datetime.timedelta(seconds=sec) m, s = divmod(td.seconds, 60) h, m = divmod(m, 60) return td.days, h, m, s, td.microseconds def timeit(func): @separator(1) @wraps(func) def wrapper(*args, **kwargs): start = time.time() result = func(*args,**kwargs) t = time.time() - start d, h, m, s, us = get_d_h_m_s_us(t) if d != 0: print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.') elif h != 0: print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.') elif m != 0: print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.') else: print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.') return result return wrapper @timeit def sum_n(n): s = 0 for i in range(n): for j in range(n): for k in range(n): s += i return s sum_n(300)
====================================================================================================
Elapsed time for "sum_n" was 0.501 sec.
4036500000
7. デコレータでtry構文を付与する
デコレータのよくある使用例には、時間計測の他にTry & Exceptの構文を付与する例があります。
以下のように実装することで、@tryanderror
で修飾した関数をtry
構文の中に記述した場合と同等になります。
# Exception時に呼び出す関数 def print_error_msg(msg, func, args, kwargs, e): print(f"{msg}") print(f">>> func: {func.__name__}") print(f">>> args: {args}") print(f">>> kwargs: {kwargs}") print(f'>>> error type: {type(e)}') print(f'>>> error msg: {str(e.args)}') # try構文を付与するデコレータの定義 def tryanderror(func): @wraps(func) def wrapper(*args, **kwargs): try: res = func(*args, **kwargs) return res except Exception as e: print_error_msg("Unexpected error occurred.", func, args, kwargs, e) return wrapper def separator(level): def _separator(func): @tryanderror @wraps(func) def wrapper(*args, **kwargs): sep1 = '=' * 100 sep2 = '-' * 100 if level == 1: print(sep1) res = func(*args, **kwargs) elif level == 2: print(sep2) res = func(*args, **kwargs) else: print(f'Invalid level.') res = func(*args, **kwargs) return res return wrapper return _separator from functools import wraps import time import datetime def get_d_h_m_s_us(sec): td = datetime.timedelta(seconds=sec) m, s = divmod(td.seconds, 60) h, m = divmod(m, 60) return td.days, h, m, s, td.microseconds def timeit(func): @separator(1) @tryanderror @wraps(func) def wrapper(*args, **kwargs): start = time.time() result = func(*args,**kwargs) t = time.time() - start d, h, m, s, us = get_d_h_m_s_us(t) if d != 0: print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.') elif h != 0: print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.') elif m != 0: print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.') else: print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.') return result return wrapper @tryanderror @timeit def sum_n(n): s = 0 for i in range(n): for j in range(n): for k in range(n): s += i return s sum_n('a')
====================================================================================================
Unexpected error occurred.
>>> func: sum_n
>>> args: ('a',)
>>> kwargs: {}
>>> error type: <class 'TypeError'>
>>> error msg: ("'str' object cannot be interpreted as an integer",)
8. クラス内の関数をデコレート
クラス内の関数(メソッド)であっても、外部で定義したデコレータをそのまま利用できます。
def print_error_msg(msg, func, args, kwargs, e): print(f"{msg}") print(f">>> func: {func.__name__}") print(f">>> args: {args}") print(f">>> kwargs: {kwargs}") print(f'>>> error type: {type(e)}') print(f'>>> error msg: {str(e.args)}') def tryanderror(func): @wraps(func) def wrapper(*args, **kwargs): try: res = func(*args, **kwargs) return res except Exception as e: print_error_msg("Unexpected error occurred.", func, args, kwargs, e) return wrapper def separator(level): def _separator(func): @tryanderror @wraps(func) def wrapper(*args, **kwargs): sep1 = '=' * 100 sep2 = '-' * 100 if level == 1: print(sep1) res = func(*args, **kwargs) elif level == 2: print(sep2) res = func(*args, **kwargs) else: print(f'Invalid level.') res = func(*args, **kwargs) return res return wrapper return _separator from functools import wraps import time import datetime def get_d_h_m_s_us(sec): td = datetime.timedelta(seconds=sec) m, s = divmod(td.seconds, 60) h, m = divmod(m, 60) return td.days, h, m, s, td.microseconds def timeit(func): @separator(1) @tryanderror @wraps(func) def wrapper(*args, **kwargs): start = time.time() result = func(*args,**kwargs) t = time.time() - start d, h, m, s, us = get_d_h_m_s_us(t) if d != 0: print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.') elif h != 0: print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.') elif m != 0: print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.') else: print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.') return result return wrapper # デコレータつきのクラスの定義 class SumN(object): def __init__(self): self.value = 0 @tryanderror @timeit def add_sum1(self, n): for i in range(n): self.value += i @tryanderror @timeit def add_sum2(self, n): for i in range(n): for j in range(n): self.value += i @tryanderror @timeit def add_sum3(self, n): for i in range(n): for j in range(n): for k in range(n): self.value += i @separator(2) def reset(self): self.value = 0 print(f'Reset.') # 実行 SN = SumN() SN.add_sum1(300) SN.add_sum2(300) SN.add_sum3(300) SN.reset()
====================================================================================================
Elapsed time for "add_sum1" was 0.60 sec.
====================================================================================================
Elapsed time for "add_sum2" was 0.945 sec.
====================================================================================================
Elapsed time for "add_sum3" was 0.605 sec.
----------------------------------------------------------------------------------------------------
Reset.
9. クラスのメソッドでデコレータを定義
クラス内のメソッドとしてデコレータを定義することも可能です。
この場合、@classmethod
などを利用するなど少々定義方法や使用方法が異なることに注意します。
以下が例になります。
# クラス外におけるデコレータの定義 def print_error_msg(msg, func, args, kwargs, e): print(f"{msg}") print(f">>> func: {func.__name__}") print(f">>> args: {args}") print(f">>> kwargs: {kwargs}") print(f'>>> error type: {type(e)}') print(f'>>> error msg: {str(e.args)}') def tryanderror(func): @wraps(func) def wrapper(*args, **kwargs): try: res = func(*args, **kwargs) return res except Exception as e: print_error_msg("Unexpected error occurred.", func, args, kwargs, e) return wrapper def separator(level): def _separator(func): @tryanderror @wraps(func) def wrapper(*args, **kwargs): sep1 = '=' * 100 sep2 = '-' * 100 if level == 1: print(sep1) res = func(*args, **kwargs) elif level == 2: print(sep2) res = func(*args, **kwargs) else: print(f'Invalid level.') res = func(*args, **kwargs) return res return wrapper return _separator # クラスとして時間計測のデコレータを定義 class Time(object): def __init__(self): pass @classmethod # デコレータにするメソッドには@classmethodによる修飾を付与 def timeit(cls, func): # 第一引数にはcls, 第二引数に関数を設定 @separator(1) # 外部のデコレータ(引数あり)を通常通り使用することも可能 @tryanderror # 外部のデコレータ(引数なし)を通常通り使用することも可能 @wraps(func) def wrapper(*args, **kwargs): start = time.time() result = func(*args,**kwargs) t = time.time() - start d, h, m, s, us = get_d_h_m_s_us(t) if d != 0: print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.') elif h != 0: print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.') elif m != 0: print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.') else: print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.') return result return wrapper # クラス内で定義したデコレータを利用するため、上記のTimeクラスを継承 class SumN(Time): def __init__(self): self.value = 0 @tryanderror # 外部で定義した通常のデコレータはそのまま使用可能 @Time.timeit # クラスのメソッドとして定義したデコレータの使用方法は通常と異なる def add_sum1(self, n): for i in range(n): self.value += i @tryanderror @Time.timeit def add_sum2(self, n): for i in range(n): for j in range(n): self.value += i @tryanderror @Time.timeit def add_sum3(self, n): for i in range(n): for j in range(n): for k in range(n): self.value += i @separator(2) def reset(self): self.value = 0 print(f'Reset.') SN = SumN() SN.add_sum1(300) SN.add_sum2(300) SN.add_sum3(300) SN.reset()
====================================================================================================
Elapsed time for "add_sum1" was 0.60 sec.
====================================================================================================
Elapsed time for "add_sum2" was 0.102 sec.
====================================================================================================
Elapsed time for "add_sum3" was 0.624 sec.
----------------------------------------------------------------------------------------------------
Reset.
10.クラスのメソッドで引数つきのデコレータを定義
前節のクラス内デコレータは引数がない場合でした。
ここでは、引数付きのクラス内デコレータの定義方法を示します。
クラス内で定義したデコレータを同じクラス内の別のメソッドに即座に適用することはできないため、以下のようにして適宜別クラスとして定義して継承することで利用します。
# 最初にtry構文用のデコレータをDebugクラスとして定義 class Debug(object): def __init__(self): pass def print_error_msg(self, msg, func, args, kwargs, e): print(f"{msg}") print(f">>> func: {func.__name__}") print(f">>> args: {args}") print(f">>> kwargs: {kwargs}") print(f'>>> error type: {type(e)}') print(f'>>> error msg: {str(e.args)}') @classmethod def tryanderror(cls, func): @wraps(func) def wrapper(*args, **kwargs): try: res = func(*args, **kwargs) return res except Exception as e: print_error_msg("Unexpected error occurred.", func, args, kwargs, e) return wrapper # 次にSeparatorを「引数あり」のクラスとして定義 # この際、Debugクラスを継承 class Separator(Debug): def __init__(self): pass @classmethod def separator(cls, level): def _separator(func): @Debug.tryanderror @wraps(func) def wrapper_separator(*args, **kwargs): sep1 = '=' * 100 sep2 = '-' * 100 if level == 1: print(sep1) res = func(*args, **kwargs) elif level == 2: print(sep2) res = func(*args, **kwargs) else: print(f'Invalid level.') res = func(*args, **kwargs) return res return wrapper_separator return _separator # 最後に時間計測用のクラスをSeparatorクラスを継承して定義 class Time(Separator): def __init__(self): pass @classmethod def timeit(cls, func): @Separator.separator(1) @Debug.tryanderror @wraps(func) def wrapper(*args, **kwargs): start = time.time() result = func(*args,**kwargs) t = time.time() - start d, h, m, s, us = get_d_h_m_s_us(t) if d != 0: print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.') elif h != 0: print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.') elif m != 0: print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.') else: print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.') return result return wrapper # Timeクラスを継承して個々のクラスを定義 class SumN(Time): def __init__(self): self.value = 0 @Debug.tryanderror @Time.timeit def add_sum1(self, n): for i in range(n): self.value += i @Debug.tryanderror @Time.timeit def add_sum2(self, n): for i in range(n): for j in range(n): self.value += i @Debug.tryanderror @Time.timeit def add_sum3(self, n): for i in range(n): for j in range(n): for k in range(n): self.value += i @Separator.separator(2) def reset(self): self.value = 0 print(f'Reset.') # 実行 SN = SumN() SN.add_sum1(300) SN.add_sum2(300) SN.add_sum3(300) SN.reset()
====================================================================================================
Elapsed time for "add_sum1" was 0.43 sec.
====================================================================================================
Elapsed time for "add_sum2" was 0.967 sec.
====================================================================================================
Elapsed time for "add_sum3" was 0.606 sec.
----------------------------------------------------------------------------------------------------
Reset.
11. まとめ
本記事では、Pythonのデコレータについて基本的な使い方を時間計測を例にして紹介しました。
また、複数のデコレータを適用する場合、デコレータに引数を渡す方法、クラスのメソッドに対してデコレータを適用する方法、クラスのメソッドをデコレータとして定義して利用する方法についても具体例を示して紹介しました。
以上になります。
コメント