【Python】デコレータ【時間計測・Try】

Python

本記事では、Pythonのデコレータ(関数の上の行に付与する@から始まる文言)の使用方法について、関数の時間計測やTry構文を例に説明します。

1. 時間計測

デコレータは@hogeのように宣言された下の行から始まる関数を引数に取り、その関数を用いた処理をする関数を返します。

これにより、引数の関数の実行前後に処理を追加することができます。

デコレータの最も単純な例として、以下の時間計測が挙げられます。

import time

# デコレータの本体
def timeit(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        print(f'Elapsed time for "{func.__name__}" was {time.time() -  start:.2f} sec.')
        return result
    return wrapper
  


@timeit
def sum_n(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += i
    return s
  
  
 # 実行
sum_n(500)
Elapsed time for "sum_n" was 2.51 sec.
31187500000

2. functoolsのwrapsを使用

前節の例では、こちらで説明されているような問題が発生します。

その問題を回避するために、以下のようにしてfunctoolswrapsを使用することが推奨されています。

from functools import wraps # この行を追加
import time

def timeit(func):
    @wraps(func) # この行を追加
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args,**kwargs)
        print(f'Elapsed time for "{func.__name__}" was {time.time() -  start:.2f} sec.')
        return result
    return wrapper
  
  
@timeit
def sum_n(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += i
    return s

  
sum_n(500)
Elapsed time for "sum_n" was 2.54 sec.
31187500000

3. 時間計測(自動的に単位を変換)

コピペ用の時間計測のデコレータを下記に掲載しておきます。

測定時間の単位が自動的に判定されて出力されるようにしてあります。

from functools import wraps
import time
import datetime


def get_d_h_m_s_us(sec):
    td = datetime.timedelta(seconds=sec)
    m, s = divmod(td.seconds, 60)
    h, m = divmod(m, 60)
    return td.days, h, m, s, td.microseconds


def timeit(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args,**kwargs)
        t =  time.time() - start
        d, h, m, s, us = get_d_h_m_s_us(t)
        if d != 0:
            print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif h != 0:
            print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif m != 0:
            print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.')
        else:
            print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.')
        return result
    return wrapper
  
  
@timeit
def sum_n(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += i
    return s
  
  
sum_n(300)
Elapsed time for "sum_n" was 0.498 sec.
4036500000

4. 複数のデコレータを使う: 単純に重ねる

以下のように、デコレータは単純に重ねて使用できます。

# 2つ目のデコレータ: 標準出力の前に水平線を出力して見やすくする
def separator(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        sep = '-' * 100
        print(sep)
        res = func(*args, **kwargs)
        return res
    return wrapper
  
  
# 時間計測のデコレータ: 前節と同じ
from functools import wraps
import time
import datetime
def get_d_h_m_s_us(sec):
    td = datetime.timedelta(seconds=sec)
    m, s = divmod(td.seconds, 60)
    h, m = divmod(m, 60)
    return td.days, h, m, s, td.microseconds


def timeit(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args,**kwargs)
        t =  time.time() - start
        d, h, m, s, us = get_d_h_m_s_us(t)
        if d != 0:
            print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif h != 0:
            print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif m != 0:
            print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.')
        else:
            print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.')
        return result
    return wrapper
  

# デコレータを重ねて使用
@separator
@timeit
def sum_n(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += i
    return s
  
  
sum_n(300)
----------------------------------------------------------------------------------------------------
Elapsed time for "sum_n" was 0.493 sec.
4036500000

ただし、順序を間違えると狙った動作にならないので、注意が必要です。

5. 複数のデコレータを使う: ネストさせる

前節のように単純にデコレータを重ねても問題ありませんが、必ずセットで利用する場合は以下のように他方をもう一方の中で定義するとコードの見通しがよくなります。

def separator(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        sep = '-' * 100
        print(sep)
        res = func(*args, **kwargs)
        return res
    return wrapper
  
  
from functools import wraps
import time
import datetime
def get_d_h_m_s_us(sec):
    td = datetime.timedelta(seconds=sec)
    m, s = divmod(td.seconds, 60)
    h, m = divmod(m, 60)
    return td.days, h, m, s, td.microseconds

def timeit(func):
    @separator # ここに他方のデコレータを挟む
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args,**kwargs)
        t =  time.time() - start
        d, h, m, s, us = get_d_h_m_s_us(t)
        if d != 0:
            print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif h != 0:
            print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif m != 0:
            print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.')
        else:
            print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.')
        return result
    return wrapper
  
  
@timeit
def sum_n(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += i
    return s
  
  
sum_n(300)
----------------------------------------------------------------------------------------------------
Elapsed time for "sum_n" was 0.493 sec.
4036500000

6. デコレータに引数を渡す

以下のようにすれば、デコレータに引数を渡すことも可能です。

# 引数ありバージョンのデコレータの定義
def separator(level):
    def _separator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            sep1 = '=' * 100
            sep2 = '-' * 100
            if level == 1:
                print(sep1)
                res = func(*args, **kwargs)
            elif level == 2:
                print(sep2)
                res = func(*args, **kwargs)
            else:
                print(f'Invalid level.')
                res = func(*args, **kwargs)
            return res
        return wrapper
    return _separator
  
  
from functools import wraps
import time
import datetime
def get_d_h_m_s_us(sec):
    td = datetime.timedelta(seconds=sec)
    m, s = divmod(td.seconds, 60)
    h, m = divmod(m, 60)
    return td.days, h, m, s, td.microseconds


def timeit(func):
    @separator(1)
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args,**kwargs)
        t =  time.time() - start
        d, h, m, s, us = get_d_h_m_s_us(t)
        if d != 0:
            print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif h != 0:
            print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif m != 0:
            print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.')
        else:
            print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.')
        return result
    return wrapper
  
  
@timeit
def sum_n(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += i
    return s
  
  
sum_n(300)
====================================================================================================
Elapsed time for "sum_n" was 0.501 sec.
4036500000

7. デコレータでtry構文を付与する

デコレータのよくある使用例には、時間計測の他にTry & Exceptの構文を付与する例があります。

以下のように実装することで、@tryanderrorで修飾した関数をtry構文の中に記述した場合と同等になります。

# Exception時に呼び出す関数
def print_error_msg(msg, func, args, kwargs, e):
    print(f"{msg}")
    print(f">>> func: {func.__name__}")
    print(f">>> args: {args}")
    print(f">>> kwargs: {kwargs}")
    print(f'>>> error type: {type(e)}')
    print(f'>>> error msg: {str(e.args)}')                    


# try構文を付与するデコレータの定義
def tryanderror(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            res = func(*args, **kwargs)
            return res
        except Exception as e:
            print_error_msg("Unexpected error occurred.", func, args, kwargs, e)
    return wrapper
  
  
def separator(level):
    def _separator(func):
        @tryanderror
        @wraps(func)
        def wrapper(*args, **kwargs):
            sep1 = '=' * 100
            sep2 = '-' * 100
            if level == 1:
                print(sep1)
                res = func(*args, **kwargs)
            elif level == 2:
                print(sep2)
                res = func(*args, **kwargs)
            else:
                print(f'Invalid level.')
                res = func(*args, **kwargs)                
            return res
        return wrapper
    return _separator
  
  
from functools import wraps
import time
import datetime
def get_d_h_m_s_us(sec):
    td = datetime.timedelta(seconds=sec)
    m, s = divmod(td.seconds, 60)
    h, m = divmod(m, 60)
    return td.days, h, m, s, td.microseconds


def timeit(func):
    @separator(1)
    @tryanderror
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args,**kwargs)
        t =  time.time() - start
        d, h, m, s, us = get_d_h_m_s_us(t)
        if d != 0:
            print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif h != 0:
            print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif m != 0:
            print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.')
        else:
            print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.')
        return result
    return wrapper
  
  
@tryanderror
@timeit
def sum_n(n):
    s = 0
    for i in range(n):
        for j in range(n):
            for k in range(n):
                s += i
    return s
  
  
sum_n('a')
====================================================================================================
Unexpected error occurred.
>>> func: sum_n
>>> args: ('a',)
>>> kwargs: {}
>>> error type: <class 'TypeError'>
>>> error msg: ("'str' object cannot be interpreted as an integer",)

8. クラス内の関数をデコレート

クラス内の関数(メソッド)であっても、外部で定義したデコレータをそのまま利用できます。

def print_error_msg(msg, func, args, kwargs, e):
    print(f"{msg}")
    print(f">>> func: {func.__name__}")
    print(f">>> args: {args}")
    print(f">>> kwargs: {kwargs}")
    print(f'>>> error type: {type(e)}')
    print(f'>>> error msg: {str(e.args)}')                    

    
def tryanderror(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            res = func(*args, **kwargs)
            return res
        except Exception as e:
            print_error_msg("Unexpected error occurred.", func, args, kwargs, e)
    return wrapper


def separator(level):
    def _separator(func):
        @tryanderror
        @wraps(func)
        def wrapper(*args, **kwargs):
            sep1 = '=' * 100
            sep2 = '-' * 100
            if level == 1:
                print(sep1)
                res = func(*args, **kwargs)
            elif level == 2:
                print(sep2)
                res = func(*args, **kwargs)
            else:
                print(f'Invalid level.')
                res = func(*args, **kwargs)                
            return res
        return wrapper
    return _separator


from functools import wraps
import time
import datetime


def get_d_h_m_s_us(sec):
    td = datetime.timedelta(seconds=sec)
    m, s = divmod(td.seconds, 60)
    h, m = divmod(m, 60)
    return td.days, h, m, s, td.microseconds


def timeit(func):
    @separator(1)
    @tryanderror
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args,**kwargs)
        t =  time.time() - start
        d, h, m, s, us = get_d_h_m_s_us(t)
        if d != 0:
            print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif h != 0:
            print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.')
        elif m != 0:
            print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.')
        else:
            print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.')
        return result
    return wrapper
  
  
# デコレータつきのクラスの定義
class SumN(object):
    
    def __init__(self):
        self.value = 0
        
     
	@tryanderror
    @timeit
    def add_sum1(self, n):
        for i in range(n):
            self.value += i
            
    
    @tryanderror
    @timeit
    def add_sum2(self, n):
        for i in range(n):
            for j in range(n):
                self.value += i
            
    
    @tryanderror
    @timeit
    def add_sum3(self, n):
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    self.value += i
            
    
    @separator(2)
    def reset(self):
        self.value = 0
        print(f'Reset.')
        
        
# 実行
SN = SumN()
SN.add_sum1(300)
SN.add_sum2(300)
SN.add_sum3(300)
SN.reset()
====================================================================================================
Elapsed time for "add_sum1" was 0.60 sec.
====================================================================================================
Elapsed time for "add_sum2" was 0.945 sec.
====================================================================================================
Elapsed time for "add_sum3" was 0.605 sec.
----------------------------------------------------------------------------------------------------
Reset.

9. クラスのメソッドでデコレータを定義

クラス内のメソッドとしてデコレータを定義することも可能です。

この場合、@classmethodなどを利用するなど少々定義方法や使用方法が異なることに注意します。

以下が例になります。

# クラス外におけるデコレータの定義
def print_error_msg(msg, func, args, kwargs, e):
    print(f"{msg}")
    print(f">>> func: {func.__name__}")
    print(f">>> args: {args}")
    print(f">>> kwargs: {kwargs}")
    print(f'>>> error type: {type(e)}')
    print(f'>>> error msg: {str(e.args)}')                    

    
def tryanderror(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            res = func(*args, **kwargs)
            return res
        except Exception as e:
            print_error_msg("Unexpected error occurred.", func, args, kwargs, e)
    return wrapper


def separator(level):
    def _separator(func):
        @tryanderror
        @wraps(func)
        def wrapper(*args, **kwargs):
            sep1 = '=' * 100
            sep2 = '-' * 100
            if level == 1:
                print(sep1)
                res = func(*args, **kwargs)
            elif level == 2:
                print(sep2)
                res = func(*args, **kwargs)
            else:
                print(f'Invalid level.')
                res = func(*args, **kwargs)                
            return res
        return wrapper
    return _separator
  

# クラスとして時間計測のデコレータを定義  
class Time(object):
    def __init__(self):
        pass
        
    
    @classmethod # デコレータにするメソッドには@classmethodによる修飾を付与
    def timeit(cls, func): # 第一引数にはcls, 第二引数に関数を設定
        @separator(1) # 外部のデコレータ(引数あり)を通常通り使用することも可能
        @tryanderror # 外部のデコレータ(引数なし)を通常通り使用することも可能
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.time()
            result = func(*args,**kwargs)
            t =  time.time() - start
            d, h, m, s, us = get_d_h_m_s_us(t)
            if d != 0:
                print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.')
            elif h != 0:
                print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.')
            elif m != 0:
                print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.')
            else:
                print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.')
            return result
        return wrapper
      

# クラス内で定義したデコレータを利用するため、上記のTimeクラスを継承
class SumN(Time):
    
    def __init__(self):
        self.value = 0
        
     
    @tryanderror # 外部で定義した通常のデコレータはそのまま使用可能
    @Time.timeit # クラスのメソッドとして定義したデコレータの使用方法は通常と異なる
    def add_sum1(self, n):
        for i in range(n):
            self.value += i
            
    
    @tryanderror
    @Time.timeit
    def add_sum2(self, n):
        for i in range(n):
            for j in range(n):
                self.value += i
            
    
    @tryanderror
    @Time.timeit
    def add_sum3(self, n):
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    self.value += i
            
    
    @separator(2)
    def reset(self):
        self.value = 0
        print(f'Reset.')
        
        
SN = SumN()
SN.add_sum1(300)
SN.add_sum2(300)
SN.add_sum3(300)
SN.reset()
====================================================================================================
Elapsed time for "add_sum1" was 0.60 sec.
====================================================================================================
Elapsed time for "add_sum2" was 0.102 sec.
====================================================================================================
Elapsed time for "add_sum3" was 0.624 sec.
----------------------------------------------------------------------------------------------------
Reset.

10.クラスのメソッドで引数つきのデコレータを定義

前節のクラス内デコレータは引数がない場合でした。

ここでは、引数付きのクラス内デコレータの定義方法を示します。

クラス内で定義したデコレータを同じクラス内の別のメソッドに即座に適用することはできないため、以下のようにして適宜別クラスとして定義して継承することで利用します。

# 最初にtry構文用のデコレータをDebugクラスとして定義
class Debug(object):
    def __init__(self):
        pass
    
    
    def print_error_msg(self, msg, func, args, kwargs, e):
        print(f"{msg}")
        print(f">>> func: {func.__name__}")
        print(f">>> args: {args}")
        print(f">>> kwargs: {kwargs}")
        print(f'>>> error type: {type(e)}')
        print(f'>>> error msg: {str(e.args)}')                    


    @classmethod
    def tryanderror(cls, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                res = func(*args, **kwargs)
                return res
            except Exception as e:
                print_error_msg("Unexpected error occurred.", func, args, kwargs, e)
        return wrapper
      
      
# 次にSeparatorを「引数あり」のクラスとして定義
# この際、Debugクラスを継承
class Separator(Debug):
    def __init__(self):
        pass
    
    
    @classmethod
    def separator(cls, level):
        def _separator(func):
            @Debug.tryanderror
            @wraps(func)
            def wrapper_separator(*args, **kwargs):
                sep1 = '=' * 100
                sep2 = '-' * 100
                if level == 1:
                    print(sep1)
                    res = func(*args, **kwargs)
                elif level == 2:
                    print(sep2)
                    res = func(*args, **kwargs)
                else:
                    print(f'Invalid level.')
                    res = func(*args, **kwargs)
                return res
            return wrapper_separator
        return _separator
      
      
# 最後に時間計測用のクラスをSeparatorクラスを継承して定義
class Time(Separator):
    def __init__(self):
        pass
        
    
    @classmethod
    def timeit(cls, func):
        @Separator.separator(1)
        @Debug.tryanderror
        @wraps(func)
        def wrapper(*args, **kwargs):
            start = time.time()
            result = func(*args,**kwargs)
            t =  time.time() - start
            d, h, m, s, us = get_d_h_m_s_us(t)
            if d != 0:
                print(f'Elapsed time for "{func.__name__}" was {d} day {h} hr {m} min {s}.{str(us)[:3]} sec.')
            elif h != 0:
                print(f'Elapsed time for "{func.__name__}" was {h} hr {m} min {s}.{str(us)[:3]} sec.')
            elif m != 0:
                print(f'Elapsed time for "{func.__name__}" was {m} min {s}.{str(us)[:3]} sec.')
            else:
                print(f'Elapsed time for "{func.__name__}" was {s}.{str(us)[:3]} sec.')
            return result
        return wrapper
      
      
# Timeクラスを継承して個々のクラスを定義
class SumN(Time):
    
    def __init__(self):
        self.value = 0
        
     
    @Debug.tryanderror
    @Time.timeit
    def add_sum1(self, n):
        for i in range(n):
            self.value += i
            
    
    @Debug.tryanderror
    @Time.timeit
    def add_sum2(self, n):
        for i in range(n):
            for j in range(n):
                self.value += i
            
    
    @Debug.tryanderror
    @Time.timeit
    def add_sum3(self, n):
        for i in range(n):
            for j in range(n):
                for k in range(n):
                    self.value += i
            
    
    @Separator.separator(2)
    def reset(self):
        self.value = 0
        print(f'Reset.')
        

# 実行        
SN = SumN()
SN.add_sum1(300)
SN.add_sum2(300)
SN.add_sum3(300)
SN.reset()
====================================================================================================
Elapsed time for "add_sum1" was 0.43 sec.
====================================================================================================
Elapsed time for "add_sum2" was 0.967 sec.
====================================================================================================
Elapsed time for "add_sum3" was 0.606 sec.
----------------------------------------------------------------------------------------------------
Reset.

11. まとめ

本記事では、Pythonのデコレータについて基本的な使い方を時間計測を例にして紹介しました。

また、複数のデコレータを適用する場合、デコレータに引数を渡す方法、クラスのメソッドに対してデコレータを適用する方法、クラスのメソッドをデコレータとして定義して利用する方法についても具体例を示して紹介しました。

以上になります。

コメント