from typing import TypeVar, Union, Callable, NamedTuple, Optional
from threading import Lock
import functools
import cloudpickle
import os
import inspect
import fnmatch
from .log import get_logger
from .util import to_awaitable
# Module-level logger, named after this module.
logger = get_logger(__name__)
# Held by _get_checkpoint / _set_checkpoint / del_checkpoint while they touch
# the checkpoint state below.
_lock = Lock()
# Path of the checkpoint file; stays None until set_checkpoint_file() is called.
_checkpoint_file: Optional[str] = None
# In-memory copy of the checkpoint entries (key -> record dict); filled by
# _load_checkpoint() and written back to disk by _dump_checkpoint().
_checkpoint_data: Optional[dict] = None
class FnInfo(NamedTuple):
    """Description of one invocation of a function wrapped by ``apply_checkpoint``.

    Instances are handed to a user-supplied key function so that it can derive
    the checkpoint key from the call being made.
    """

    # ``__name__`` of the wrapped function.
    fn_name: str
    # Positional arguments of the call.
    args: tuple
    # Keyword arguments of the call.
    kwargs: dict
    # ``"<filename>:<lineno>"`` of the place where ``apply_checkpoint`` was invoked.
    call_site: str
# Signature of a user-supplied key function: it receives the FnInfo of a call
# and returns the checkpoint key (a string) for that call.
KeyFn = Callable[[FnInfo], str]
# Sentinel returned by _get_checkpoint() when no usable entry exists. A
# dedicated object is used so a stored return value of None still counts as a hit.
EMPTY = object()
def set_checkpoint_file(path: str):
    """Register ``path`` as the checkpoint file and load its entries.

    The path can be set only once; a later call raises instead of switching
    to another file.

    Args:
        path: Location of the checkpoint file. If the file does not exist yet
            the checkpoint starts out empty.

    Raises:
        RuntimeError: If a checkpoint file has already been set.
    """
    global _checkpoint_file
    already_set = _checkpoint_file is not None
    if already_set:
        message = "checkpoint path has been set to {}".format(_checkpoint_file)
        raise RuntimeError(message)
    _checkpoint_file = path
    # Load eagerly so that problems with the file surface right here.
    _load_checkpoint()
def apply_checkpoint(key_fn: Union[str, 'KeyFn'], disable=False):
    """Build a decorator that caches a function's return value in the checkpoint file.

    On each call of the decorated function a checkpoint key is determined. If
    an entry for that key exists, the stored value is returned and the
    function is not executed; otherwise the function runs and its result is
    stored under the key. When the function returns an awaitable, the awaited
    result is what gets stored.

    Checkpointing is bypassed (the function is simply called) when ``disable``
    is true or when no checkpoint file has been set via
    ``set_checkpoint_file``. In that case ``key_fn`` is not evaluated.

    Note: This checkpoint implementation doesn't support multiprocess.
    To support multiple process we need to have a dedicated background process
    to read/write checkpoint, which will require message queue (e.g. nanomsg
    or nng) to implement it.

    Args:
        key_fn: Either the checkpoint key itself (a string) or a callable that
            receives the ``FnInfo`` of the call and returns the key.
        disable: If true, the decorated function always runs and nothing is
            read from or written to the checkpoint.

    Returns:
        A decorator to apply to the function that should be checkpointed.

    Example:
        >>> set_checkpoint_file('/tmp/test.ckpt')
        >>> task_fn = lambda a, b: a + b
        >>> apply_checkpoint('task_1+2')(task_fn)(1, 2)
        3
    """
    # The call site is fixed when the decorator is created, so format it once
    # here rather than on every call of the wrapped function.
    caller = inspect.stack()[1]
    call_site = f'{caller.filename}:{caller.lineno}'
    T = TypeVar('T', bound=Callable)

    def _checkpoint(fn: T) -> T:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Bypass first: no key is needed when checkpointing is off, so the
            # key function must not be invoked (it may be costly or may fail).
            if disable or _checkpoint_file is None:
                return fn(*args, **kwargs)

            fn_info = FnInfo(
                # Callables such as functools.partial have no __name__.
                fn_name=getattr(fn, '__name__', repr(fn)),
                args=args,
                kwargs=kwargs,
                call_site=call_site,
            )
            key = key_fn if isinstance(key_fn, str) else key_fn(fn_info)

            cached = _get_checkpoint(key)
            if cached is not EMPTY:
                return cached

            ret = fn(*args, **kwargs)
            if inspect.isawaitable(ret):
                # Store the awaited result, not the awaitable, and only once
                # the caller actually awaits it.
                async def _store_when_done():
                    result = await ret
                    _set_checkpoint(key, result, fn_info, True)
                    return result
                return _store_when_done()

            _set_checkpoint(key, ret, fn_info, False)
            return ret
        return wrapper  # type: ignore
    return _checkpoint
def _load_checkpoint():
    """Populate ``_checkpoint_data`` from the checkpoint file, once.

    Does nothing when the data is already in memory. If the checkpoint file
    does not exist, the data starts out as an empty dict.
    """
    global _checkpoint_data
    if _checkpoint_data is not None:
        return
    assert _checkpoint_file is not None, '_checkpoint_path should not be None!'
    if not os.path.exists(_checkpoint_file):
        _checkpoint_data = dict()
        return
    # NOTE(security): unpickling can execute arbitrary code; only point the
    # checkpoint path at files you trust.
    with open(_checkpoint_file, 'rb') as ckpt:
        _checkpoint_data = cloudpickle.load(ckpt)
def _dump_checkpoint():
    """Write the in-memory checkpoint data to the checkpoint file.

    The data is pickled into a temporary sibling file (``<path>.tmp``) which
    then replaces the checkpoint file. Opening the checkpoint file directly
    for writing would truncate it first, so a value that fails to pickle (or
    an interruption mid-write) would destroy all previously saved entries.

    Raises:
        Whatever ``open``, ``cloudpickle.dump`` or ``os.replace`` raise; the
        existing checkpoint file is left untouched in that case.
    """
    assert _checkpoint_data is not None, '_checkpoint_data should not be None!'
    target = os.fspath(_checkpoint_file)  # type: ignore
    tmp_path = target + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            cloudpickle.dump(_checkpoint_data, f)
        os.replace(tmp_path, target)
    except BaseException:
        # Do not leave a partial temp file behind; keep the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def _get_checkpoint(key: str):
    """Look up the stored return value for ``key``.

    Args:
        key: Checkpoint key to look up.

    Returns:
        ``EMPTY`` when there is no entry for ``key`` or when the lookup fails
        for any reason (the failure is logged, not raised). Otherwise the
        stored return value; if the entry was recorded as awaitable, the
        value is passed through ``to_awaitable`` before being returned.
    """
    try:
        with _lock:
            _load_checkpoint()
            assert _checkpoint_data is not None
            value = _checkpoint_data.get(key, None)
        if value is None:
            return EMPTY
        logger.info(f"Hit checkpoint: {key}")
        if value['is_awaitable']:
            return to_awaitable(value['return'])
        return value['return']
    except Exception as e:
        # Best effort: a broken checkpoint must not break the wrapped call.
        # The exception is formatted into the message; passing it as an extra
        # positional argument makes stdlib logging try ``msg % args`` and fail.
        logger.error(f"Fail to get checkpoint: {key}: {e!r}")
        return EMPTY
def _set_checkpoint(key: str, value, info: FnInfo, is_awaitable: bool = False):
    """Store ``value`` under ``key`` and persist the checkpoint file.

    Failures are logged rather than raised. If persisting fails, the
    in-memory entry for ``key`` is restored to what it was before, so that a
    value that cannot be pickled does not make every later dump fail as well.

    Args:
        key: Checkpoint key to store the value under.
        value: Return value of the checkpointed call.
        info: Call description; only its function name and call site are kept.
        is_awaitable: Whether the original call returned an awaitable, i.e.
            whether a later hit should be handed back as an awaitable.
    """
    try:
        with _lock:
            assert _checkpoint_data is not None
            previous = _checkpoint_data.get(key, EMPTY)
            # args, kwargs may contain unpicklable objects, so they are not stored
            _checkpoint_data[key] = {
                'return': value,
                'is_awaitable': is_awaitable,
                'info': {
                    'fn_name': info.fn_name,
                    'call_site': info.call_site,
                }
            }
            try:
                _dump_checkpoint()
            except Exception:
                # Undo the in-memory change so the bad entry is not retried
                # (and does not fail) on every subsequent dump.
                if previous is EMPTY:
                    _checkpoint_data.pop(key, None)
                else:
                    _checkpoint_data[key] = previous
                raise
    except Exception as e:
        # Formatted into the message; an extra positional argument would be
        # treated as a %-format argument by stdlib logging and fail.
        logger.error(f'Fail to set checkpoint: {key}: {e!r}')
def del_checkpoint(key: str):
    """Remove the entry stored under ``key`` and persist the checkpoint file.

    A key without an entry is ignored, but the file is still rewritten.
    Failures are logged rather than raised.

    Args:
        key: Checkpoint key to delete.
    """
    try:
        with _lock:
            _load_checkpoint()
            assert _checkpoint_data is not None
            _checkpoint_data.pop(key, None)
            _dump_checkpoint()
    except Exception as e:
        # Formatted into the message; an extra positional argument would be
        # treated as a %-format argument by stdlib logging and fail.
        logger.error(f'Fail to delete checkpoint: {key}: {e!r}')
class CheckpointCmd:
    """checkpoint command line interface"""

    def load(self, file):
        """Use ``file`` as the checkpoint file.

        Returns ``self`` so that another command can be chained onto the call.
        """
        set_checkpoint_file(file)
        return self

    def ls(self, verbose=False):
        '''list all the checkpoint entries in the checkpoint file

        With ``verbose`` each entry is printed as a block showing its key,
        call site and function name; otherwise only the key is printed.
        '''
        assert _checkpoint_data is not None
        for key, value in _checkpoint_data.items():
            if not verbose:
                print(key)
                continue
            print(
                '=' * 80,
                f'Key: \t{key}',
                f'Call Site: \t{value["info"]["call_site"]}',
                f'Function: \t{value["info"]["fn_name"]}',
                sep='\n',
            )

    def rm(self, glob_pattern: str, yes=False, exclude: Optional[str]=None):
        """remove checkpoint entries with the given pattern

        Args:
            glob_pattern: fnmatch-style pattern selecting the keys to remove.
            yes: Skip the interactive confirmation when true.
            exclude: Optional fnmatch-style pattern; matching keys are kept.
        """
        assert _checkpoint_data is not None
        targets = fnmatch.filter(_checkpoint_data.keys(), glob_pattern)
        if exclude is not None:
            targets = [name for name in targets if not fnmatch.fnmatch(name, exclude)]
        for key in targets:
            confirmed = yes
            if not confirmed:
                # Ask per entry; anything other than "y" keeps the entry.
                print(f"Delete checkpoint {key}? [y/n]")
                confirmed = input().lower() == 'y'
            if confirmed:
                del _checkpoint_data[key]
                print(f"Delete checkpoint {key}")
        # Persist once, after all selected entries have been processed.
        _dump_checkpoint()