# Source code for ai2_kit.tool.dpdata

from ai2_kit.core.util import ensure_dir, expand_globs
from ai2_kit.core.log import get_logger

import numpy as np

import dpdata
from dpdata.data_type import Axis, DataType

def __export_remote():
    """
    Build the dpdata helper functions and hand them back as a tuple.

    NOTE(review): the helpers are defined inside this factory instead of at
    module level; the name suggests they are meant to be shipped to a remote
    side -- confirm against the rest of the project.

    :return: tuple ``(register_data_types, set_fparam)``
    """

    def register_data_types():
        """
        Register the extra optional data types (fparam, aparam, efield,
        ext_efield, atomic_dipole) on both ``dpdata.System`` and
        ``dpdata.LabeledSystem``.
        """
        frames, atoms = Axis.NFRAMES, Axis.NATOMS
        extra_types = (
            DataType("fparam", np.ndarray, (frames, -1), required=False),
            DataType("aparam", np.ndarray, (frames, atoms, -1), required=False),
            DataType("efield", np.ndarray, (frames, atoms, 3), required=False),
            DataType("ext_efield", np.ndarray, (frames, 3), required=False),
            DataType("atomic_dipole", np.ndarray, (frames, -1), required=False),
        )
        # Same declarations go to the unlabeled and the labeled system class.
        for system_cls in (dpdata.System, dpdata.LabeledSystem):
            system_cls.register_data_type(*extra_types)

    def set_fparam(system, fparam):
        """
        Store ``fparam`` on ``system``, repeated once per frame.

        :param system: object exposing ``get_nframes()`` and a ``data`` mapping
        :param fparam: value to repeat; for a scalar or 1-D input the result
            has one row per frame
        :return: the same ``system`` object, modified in place
        """
        per_frame = np.tile(fparam, (system.get_nframes(), 1))
        system.data['fparam'] = per_frame
        return system

    return register_data_types, set_fparam


register_data_types, set_fparam = __export_remote()


# Logger for this module, obtained from the project's get_logger helper.
logger = get_logger(__name__)
# Called at import time, so the extra data types (fparam, aparam, efield, ...)
# are registered with dpdata as soon as this module is loaded.
register_data_types()


class DpdataHelper:
    """
    Collect dpdata systems from files, optionally filter or annotate them,
    and write them back out.

    ``read``, ``filter`` and ``set_fparam`` return ``self`` so calls can be
    chained.
    """

    def __init__(self, label: bool = True):
        """
        label: if True, read data with labels (force, energy, etc),
        else read data without labels,
        use --nolabel to disable reading labels
        """
        self._systems = []
        self._label = label

    def read(self, *file_path_or_glob: str, **kwargs):
        """
        read data from multiple paths, support glob pattern
        default format is deepmd/npy

        :param file_path_or_glob: path or glob pattern to find data files
        :param kwargs: arguments to pass to dpdata.System / dpdata.LabeledSystem
        :return: self
        :raises FileNotFoundError: if the patterns match no file
        """
        kwargs.setdefault('fmt', 'deepmd/npy')
        files = expand_globs(file_path_or_glob)
        if len(files) == 0:
            raise FileNotFoundError(f'No file found for {file_path_or_glob}')
        for file in files:
            self._read(file, **kwargs)
        return self

    def filter(self, lambda_expr: str):
        """
        filter data with lambda expression

        The expression is evaluated to a callable, which is then called with
        each system's ``data``; systems for which it returns a falsy value
        are dropped.

        :param lambda_expr: lambda expression to filter data
        :return: self
        """
        # SECURITY: eval() executes arbitrary Python code. Only pass
        # expressions from a trusted source (e.g. the person running the
        # tool), never text received from an untrusted party.
        fn = eval(lambda_expr)
        self._systems = [system for system in self._systems if fn(system.data)]
        return self

    @property
    def merge_write(self):
        """Deprecated alias of :meth:`write`; logs a warning on access."""
        # Logger.warn is a deprecated alias of Logger.warning.
        logger.warning('merge_write is deprecated, use write instead')
        return self.write

    def write(self, out_path: str, fmt='deepmd/npy', merge: bool = True):
        """
        write data to specific path, support deepmd/npy, deepmd/raw, deepmd/hdf5 formats

        :param out_path: path to write data
        :param fmt: format to write, default is deepmd/npy
        :param merge: if True, merge all data use dpdata.MultiSystems,
            else write data without merging
        :raises ValueError: if no data has been read, or fmt is not supported
        """
        # Validate everything before touching the filesystem or the systems,
        # so a bad call leaves no directory behind and changes no data.
        if len(self._systems) == 0:
            raise ValueError('No data to merge')
        writer_names = {
            'deepmd/npy': 'to_deepmd_npy',
            'deepmd/raw': 'to_deepmd_raw',
            'deepmd/hdf5': 'to_deepmd_hdf5',
        }
        if fmt not in writer_names:
            raise ValueError(f'Unknown fmt {fmt}')

        ensure_dir(out_path)
        if merge:
            systems = dpdata.MultiSystems(self._systems[0])
        else:
            # NOTE(review): this appends onto self._systems[0] itself, so it
            # appears to modify the helper's first system -- confirm intended.
            systems = self._systems[0]
        for system in self._systems[1:]:
            systems.append(system)
        getattr(systems, writer_names[fmt])(out_path)  # type: ignore

    def set_fparam(self, fparam):
        """
        Set fparam for all systems

        :param fparam: fparam to set, should be a scalar or vector
        :return: self
        """
        for system in self._systems:
            set_fparam(system, fparam)
        return self

    def _read(self, file: str, **kwargs):
        """
        Load one file and add its contents to the collected systems.

        Uses dpdata.LabeledSystem when labels were requested, else
        dpdata.System. ``extend`` iterates the loaded system, so each item it
        yields is stored separately (presumably one per frame -- confirm).
        """
        if self._label:
            self._systems.extend(dpdata.LabeledSystem(file, **kwargs))  # type: ignore
        else:
            self._systems.extend(dpdata.System(file, **kwargs))  # type: ignore