Source code for ai2_kit.core.util

from ruamel.yaml import YAML, ScalarNode, SequenceNode
from pathlib import Path
from typing import Tuple, List, TypeVar, Union, Iterable
from dataclasses import field
from itertools import zip_longest
import asyncio

import shortuuid
import hashlib
import base64
import copy
import os
import random
import json
import glob

from .log import get_logger

logger = get_logger(__name__)

EMPTY = object()


def load_json(path: Union[Path, str], encoding: str = 'utf-8'):
    """Read and parse a JSON file.

    :param path: file path, as str or Path
    :param encoding: text encoding used to read the file
    :return: the deserialized JSON object
    """
    target = Path(path) if isinstance(path, str) else path
    with open(target, 'r', encoding=encoding) as fp:
        return json.load(fp)
def load_text(path: Union[Path, str], encoding: str = 'utf-8'):
    """Read a whole text file and return its content as a string.

    :param path: file path, as str or Path
    :param encoding: text encoding used to read the file
    """
    target = Path(path) if isinstance(path, str) else path
    with open(target, 'r', encoding=encoding) as fp:
        return fp.read()
def parse_path_list(path_list_str: Union[str, List[str]], to_abs: bool = False):
    """
    Parse a colon-separated, environment-variable style path list.

    :param path_list_str: a 'a:b:c' style string, or an already-split list of paths
    :param to_abs: if True, expand user home prefixes ('~', '~/x', '~user/x') in each path
    :return: list of paths

    NOTE(review): despite the name, `to_abs` only expands '~' prefixes and does
    not call os.path.abspath — confirm whether absolute conversion was intended.
    """
    if isinstance(path_list_str, str):
        path_list = path_list_str.split(':')
    else:
        path_list = path_list_str
    if to_abs:
        # fix: the original only expanded paths starting with '~/', missing bare
        # '~' and '~user/...'; expanduser is a no-op on all other paths, so
        # calling it unconditionally is backward compatible.
        path_list = [os.path.expanduser(path) for path in path_list]
    return path_list
def wait_for_change(widget, attribute):
    """Return a Future that resolves with the next new value of a widget attribute.

    Observes the Jupyter widget until the first change event fires, then
    detaches the observer (one-shot).
    """
    future = asyncio.Future()

    def _on_change(change):
        # deliver the new value, then stop listening
        future.set_result(change.new)
        widget.unobserve(_on_change, attribute)

    widget.observe(_on_change, attribute)
    return future
def default_mutable_field(obj):
    """Build a dataclass field whose default is a fresh shallow copy of obj.

    Avoids the shared-mutable-default pitfall: every instance gets its own copy.
    """
    return field(default_factory=lambda: copy.copy(obj))
def get_yaml():
    """Build a safe-mode YAML parser with the project's custom tags registered."""
    parser = YAML(typ='safe')
    for tag_cls in (JoinTag, LoadTextTag, LoadYamlTag):
        tag_cls.register(parser)
    return parser
def load_yaml_file(path: Union[Path, str]):
    """Load a single YAML file using the tag-aware parser from get_yaml()."""
    target = Path(path) if isinstance(path, str) else path
    return get_yaml().load(target)
def load_yaml_files(*paths: Union[Path, str], quiet: bool = False, purge_anonymous=True):
    """
    Load and merge multiple YAML files, left to right.

    :param paths: YAML file paths; keys in later files override earlier ones
    :param quiet: suppress both the progress message and merge override messages
    :param purge_anonymous: drop keys starting with '.' (anonymous entries) after merging
    :return: the merged dict
    """
    d = {}
    for path in paths:
        # fix: this progress message previously ignored the quiet flag
        if not quiet:
            print('load yaml file: ', path)
        d = merge_dict(d, load_yaml_file(Path(path)), quiet=quiet)  # type: ignore
    if purge_anonymous:
        dict_remove_dot_keys(d)
    return d
def nested_set(d: dict, keys: List[str], value):
    """Set d[k0][k1]...[kn] = value, creating intermediate dicts as needed."""
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value
def s_uuid():
    """Generate a short, URL-safe unique identifier (via shortuuid)."""
    return shortuuid.uuid()
def sort_unique_str_list(l: List[str]) -> List[str]:
    """Deduplicate a list of strings and return them in sorted order."""
    unique = set(l)
    return sorted(unique)
# Generic type parameter shared by the container helpers below.
T = TypeVar('T')
def flatten(l: List[List[T]]) -> List[T]:
    """Concatenate a list of lists into a single flat list."""
    result: List[T] = []
    for sub in l:
        result.extend(sub)
    return result
def format_env_string(s: str) -> str:
    """Interpolate environment variables into s via str.format placeholders.

    Example: '{HOME}/data' becomes '/home/user/data'.
    """
    return s.format(**os.environ)
def list_split(l: List[T], n: int) -> List[List[T]]:
    """Partition l into n contiguous chunks of near-equal length.

    The first (len(l) % n) chunks get one extra element.
    Ref: https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length
    """
    base, extra = divmod(len(l), n)
    chunks: List[List[T]] = []
    start = 0
    for i in range(n):
        end = start + base + (1 if i < extra else 0)
        chunks.append(l[start:end])
        start = end
    return chunks
def short_hash(s: str) -> str:
    """Return a short, filesystem-safe hash of s.

    SHA-1 digest, base64url-encoded (the urlsafe alphabet avoids '/', which
    would break file paths), with the last two characters — including the
    '=' padding — stripped.
    """
    digest = hashlib.sha1(s.encode('utf-8')).digest()
    encoded = base64.urlsafe_b64encode(digest).decode('utf-8')
    return encoded[:-2]
async def to_awaitable(value: T) -> T:
    """Wrap a plain value in a coroutine so it can be awaited."""
    return value
class JoinTag:
    """YAML tag `!join`: concatenate the items of a sequence into one string."""
    yaml_tag = u'!join'

    @classmethod
    def from_yaml(cls, constructor, node):
        parts = constructor.construct_sequence(node)
        return ''.join(str(part) for part in parts)

    @classmethod
    def to_yaml(cls, dumper, data):
        # serialization back to YAML is intentionally not supported
        ...

    @classmethod
    def register(cls, yaml: YAML):
        yaml.register_class(cls)
class LoadTextTag:
    """YAML tag `!load_text`: replace the node with the content of a text file.

    The node may be a scalar path or a sequence of path segments to join.
    """
    yaml_tag = u'!load_text'

    @classmethod
    def from_yaml(cls, constructor, node):
        path = _yaml_get_path_node(node, constructor)
        with open(path, 'r') as fp:
            return fp.read()

    @classmethod
    def to_yaml(cls, dumper, data):
        # serialization back to YAML is intentionally not supported
        ...

    @classmethod
    def register(cls, yaml: YAML):
        yaml.register_class(cls)
class LoadYamlTag:
    """YAML tag `!load_yaml`: replace the node with the parsed content of another YAML file.

    The node may be a scalar path or a sequence of path segments to join.
    """
    yaml_tag = u'!load_yaml'

    @classmethod
    def from_yaml(cls, constructor, node):
        path = _yaml_get_path_node(node, constructor)
        parser = get_yaml()
        with open(path, 'r') as fp:
            return parser.load(fp)

    @classmethod
    def to_yaml(cls, dumper, data):
        # serialization back to YAML is intentionally not supported
        ...

    @classmethod
    def register(cls, yaml: YAML):
        yaml.register_class(cls)
def _yaml_get_path_node(node, constructor):
    """Resolve a tag node to a file path.

    A scalar node is used as-is; a sequence node is joined with os.path.join.
    Raises ValueError for any other node type.
    """
    if isinstance(node, ScalarNode):
        return constructor.construct_scalar(node)
    if isinstance(node, SequenceNode):
        segments = constructor.construct_sequence(node)
        return os.path.join(*segments)
    raise ValueError(f'Unknown node type {type(node)}')
def dict_remove_dot_keys(d):
    """Recursively delete keys that start with '.' (anonymous entries), in place."""
    for key in list(d):  # snapshot keys: we mutate d while iterating
        if key.startswith('.'):
            del d[key]
        elif isinstance(d[key], dict):
            dict_remove_dot_keys(d[key])
def __export_remote_functions():
    """Self-contained helpers intended for remote execution.

    All definitions live inside this closure so cloudpickle can ship them to
    remote workers without dragging module-level globals along.
    Ref: https://stackoverflow.com/questions/75292769
    """
    # fix: use a local sentinel instead of the module-level EMPTY — referencing
    # a module global would defeat the closure's cloudpickle self-containment.
    _EMPTY = object()

    def merge_dict(lo: dict, ro: dict, path=None, ignore_none=True, quiet=False):
        """
        Recursively merge ro into lo; lo is mutated and returned.

        Note: lists (and other non-dict values) are replaced, not merged.
        None values in ro are skipped when ignore_none is True.
        """
        if path is None:
            path = []
        for key, value in ro.items():
            if ignore_none and value is None:
                continue
            if key in lo:
                current_path = path + [str(key)]
                if isinstance(lo[key], dict) and isinstance(value, dict):
                    merge_dict(lo[key], value, path=current_path,
                               ignore_none=ignore_none, quiet=quiet)
                else:
                    if not quiet:
                        print('.'.join(current_path) + ' has been overridden')
                    lo[key] = value
            else:
                lo[key] = value
        return lo

    def dict_nested_get(d: dict, keys: List[str], default=_EMPTY):
        """Get a value from a nested dict; raises KeyError unless a default is given."""
        for key in keys:
            if key not in d and default is not _EMPTY:
                return default
            d = d[key]
        return d

    def dict_nested_set(d: dict, keys: List[str], value):
        """Set a value in a nested dict; intermediate keys must already exist."""
        for key in keys[:-1]:
            d = d[key]
        d[keys[-1]] = value

    def list_even_sample(l, size):
        """Sample `size` items at an even stride; return l unchanged if size is non-positive or too large."""
        if size <= 0 or size > len(l):
            return l
        interval = len(l) / size  # sampling stride
        return [l[int(i * interval)] for i in range(size)]

    def list_random_sample(l, size, seed=None):
        """Deterministic random sample; the seed defaults to len(l).

        fix: uses a private Random instance so the *global* random state is no
        longer reseeded as a side effect (output is identical to the former
        random.seed + random.sample, both drive the same Mersenne Twister).
        """
        if seed is None:
            seed = len(l)
        return random.Random(seed).sample(l, size)

    def list_sample(l, size, method='even', **kwargs):
        """Sample a list using 'even', 'random' or 'truncate'; raises ValueError otherwise."""
        if method == 'even':
            return list_even_sample(l, size)
        elif method == 'random':
            return list_random_sample(l, size, **kwargs)
        elif method == 'truncate':
            return l[:size]
        else:
            raise ValueError(f'Unknown sample method {method}')

    def flat_evenly(list_of_lists):
        """
        Flatten a list of lists, interleaving elements round-robin.

        >>> flat_evenly([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        [1, 4, 7, 2, 5, 8, 3, 6, 9]

        NOTE: None elements in the input are dropped, since None is also the
        zip_longest fill value.
        Ref: https://stackoverflow.com/questions/76751171/how-to-flat-a-list-of-lists-and-ensure-the-output-result-distributed-evenly-in-p
        """
        return [e for tup in zip_longest(*list_of_lists) for e in tup if e is not None]

    def limit(it, size=-1):
        """Yield at most `size` items from it; a non-positive size means no limit."""
        if size <= 0:
            yield from it
        else:
            for i, x in enumerate(it):
                if i >= size:
                    break
                yield x

    def dump_json(obj, path: str):
        """Dump obj to a JSON file; non-serializable values become placeholder strings."""
        default = lambda o: f"<<non-serializable: {type(o).__qualname__}>>"
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(obj, f, indent=2, default=default)

    def dump_text(text: str, path: str, **kwargs):
        """Write text to path; extra kwargs are forwarded to open()."""
        with open(path, 'w', **kwargs) as f:
            f.write(text)

    def flush_stdio():
        """Flush stdout and stderr (e.g. before handing off to remote logging)."""
        import sys
        sys.stdout.flush()
        sys.stderr.flush()

    def ensure_dir(path: str):
        """Create the parent directory of path if it does not exist."""
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

    def expand_globs(patterns: Iterable[str], raise_invalid=False) -> List[str]:
        """
        Expand glob patterns in paths

        :param patterns: list of paths or glob patterns
        :param raise_invalid: if True, raise an error if no file is found for a glob pattern
        :return: sorted, deduplicated list of expanded paths
        """
        paths = []
        for pattern in patterns:
            result = glob.glob(pattern, recursive=True) if '*' in pattern else [pattern]
            if raise_invalid and len(result) == 0:
                raise FileNotFoundError(f'No file found for {pattern}')
            paths += result
        # fix: inlined sorted(set(...)) instead of calling the module-level
        # sort_unique_str_list, keeping this closure cloudpickle-self-contained
        # (identical result).
        return sorted(set(paths))

    # export functions
    return (
        merge_dict, dict_nested_get, dict_nested_set,
        list_even_sample, list_random_sample, list_sample,
        flat_evenly, limit,
        dump_json, dump_text,
        flush_stdio, ensure_dir, expand_globs,
    )

(
    merge_dict, dict_nested_get, dict_nested_set,
    list_even_sample, list_random_sample, list_sample,
    flat_evenly, limit,
    dump_json, dump_text,
    flush_stdio, ensure_dir, expand_globs,
) = __export_remote_functions()