Source code for qcportal.utils

from __future__ import annotations

import collections
import concurrent.futures
import datetime
import functools
import io
import itertools
import json
import logging
import math
import random
import re
import time
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from hashlib import sha256
from typing import Optional, Union, Sequence, List, TypeVar, Any, Dict, Generator, Iterable, Callable, Set, Tuple

import numpy as np

from qcportal.serialization import _JSONEncoder

_T = TypeVar("_T")


def make_list(obj: Optional[Union[_T, Sequence[_T], Set[_T]]]) -> Optional[List[_T]]:
    """
    Returns a list containing obj if obj is not already a list or other handled iterable type

    This will also work with sets
    """

    # NOTE - you might be tempted to change this to work with Iterable rather than Sequence. However,
    # pydantic models and dicts and stuff are sequences, too, which we usually just want to return
    # within a list

    if isinstance(obj, list):
        return obj
    if obj is None:
        return None
    # Be careful. strings are sequences
    if isinstance(obj, str):
        return [obj]
    if isinstance(obj, set):
        return list(obj)
    if not isinstance(obj, Sequence):
        return [obj]
    return list(obj)
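# Illustrative usage (example comments added for clarity; not part of the original module):
#   >>> make_list("hello")
#   ['hello']
#   >>> make_list((1, 2))
#   [1, 2]
#   >>> make_list(None) is None
#   True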
def chunk_iterable(it: Iterable[_T], chunk_size: int) -> Generator[List[_T], None, None]:
    """
    Split an iterable (such as a list) into batches/chunks
    """

    if chunk_size < 1:
        raise ValueError("chunk size must be >= 1")

    i = iter(it)
    batch = list(itertools.islice(i, chunk_size))
    while batch:
        yield batch
        batch = list(itertools.islice(i, chunk_size))
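# Illustrative usage (not part of the original module):
#   >>> list(chunk_iterable(range(7), 3))
#   [[0, 1, 2], [3, 4, 5], [6]]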
def chunk_iterable_time(
    it: Iterable[_T], chunk_time: float, max_chunk_size: int, initial_chunk_size: int
) -> Generator[List[_T], None, None]:
    """
    Split an iterable into chunks, trying to keep a constant time per chunk

    This function keeps track of the time it takes to process each chunk and tries to keep the time
    per chunk as close to 'chunk_time' as possible, increasing or decreasing the chunk size as needed
    (up to 'max_chunk_size').

    The first chunk will be of size 'initial_chunk_size' (assuming there are enough elements in the
    iterable to fill it).
    """

    if chunk_time <= 0:
        raise ValueError("chunk_time must be > 0")
    if max_chunk_size < 1:
        raise ValueError("max_chunk_size must be >= 1")
    if initial_chunk_size < 1 or initial_chunk_size > max_chunk_size:
        raise ValueError("initial_chunk_size must be >= 1 and <= max_chunk_size")

    i = iter(it)
    batch = list(itertools.islice(i, initial_chunk_size))

    while batch:
        # Time how long it takes the caller to process this chunk
        start = time.time()
        yield batch
        end = time.time()

        # How many elements could we fit in the desired chunk_time
        time_per_element = (end - start) / len(batch)
        chunk_size = int(chunk_time / time_per_element)

        # Clamp to a valid size
        chunk_size = max(1, min(chunk_size, max_chunk_size))

        # Get the next chunk
        batch = list(itertools.islice(i, chunk_size))
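# Illustrative sketch (not part of the original module). 'records' and 'handle_batch' are hypothetical
# names; each yielded chunk is sized so that processing it takes roughly one second:
#
#   for batch in chunk_iterable_time(records, chunk_time=1.0, max_chunk_size=500, initial_chunk_size=10):
#       handle_batch(batch)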
def process_chunk_iterable(
    fn: Callable[[Iterable[_T]], Any],
    it: Iterable[_T],
    chunk_time: float,
    max_chunk_size: int,
    initial_chunk_size: int,
    max_workers: int = 1,
    *,
    keep_order: bool = False,
) -> Generator[List[_T], None, None]:
    """
    Process an iterable in chunks, trying to keep a constant time per chunk

    This function keeps track of the time it takes to process each chunk and tries to keep the time
    per chunk as close to 'chunk_time' as possible, increasing or decreasing the chunk size as needed
    (up to 'max_chunk_size').

    The first chunk will be of size 'initial_chunk_size' (assuming there are enough elements in the
    iterable to fill it).

    This function returns the results as chunks (lists) of the original iterable.

    If 'keep_order' is True, the results will be returned in the same order as the original iterable.
    If 'keep_order' is False, the results will be returned in the order they are completed.
    """

    # NOTE: You might think that we should spin up another thread to handle all the processing and submission
    # to the thread pool. However, if the user takes a long time processing the chunk (returned via yield) on
    # their end, then this would effectively just process all the data and hold that in the cache. This might be
    # undesirable if the user is trying to process a large amount of data. Also, the effect is largely the same
    # in terms of timing.
    # So this function more or less tries to pre-process enough so that the user is never waiting, striking a
    # balance between downloading all the data and doing things completely serially.

    if chunk_time <= 0.0:
        raise ValueError("chunk_time must be > 0.0")
    if max_chunk_size < 1:
        raise ValueError("max_chunk_size must be >= 1")
    if initial_chunk_size < 1 or initial_chunk_size > max_chunk_size:
        raise ValueError("initial_chunk_size must be >= 1 and <= max_chunk_size")
    if max_workers < 1:
        raise ValueError("max_workers must be >= 1")

    pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

    # Get initial chunks to be submitted to the pool
    i = iter(it)
    chunks = [list(itertools.islice(i, initial_chunk_size)) for _ in range(max_workers)]

    # Remove empty chunks
    chunks = [b for b in chunks if b]

    # Wrap the provided function so that we get timing and chunk id
    def _process(chunk, chunk_id):
        start = time.time()
        ret = fn(chunk)
        end = time.time()
        return (end - start) / len(chunk), chunk_id, ret

    # chunk id we should submit next
    cur_chunk_idx = 0

    # Current chunk id we are returning (if order is kept)
    cur_ret_chunk_id = 0

    # Dictionary keeping the results (indexed by chunk id)
    results_cache = {}

    # Submit the given function with the given chunks to the thread pool
    futures = [pool.submit(_process, chunk, cur_chunk_idx + i) for i, chunk in enumerate(chunks)]
    cur_chunk_idx += len(chunks)

    while True:
        if len(futures) == 0:
            break

        # Wait for any of the futures
        done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)

        # Collect the results of the completed futures
        average_per_element = 0.0
        for future in done:
            avg_time, chunk_idx, ret = future.result()
            average_per_element += avg_time  # Average time per element of the iterable

            assert chunk_idx not in results_cache
            results_cache[chunk_idx] = ret

        if len(done) != 0:
            # compute the next chunk size
            time_per_element = average_per_element / len(done)  # Average of the averages

            # How many elements could we fit in the desired chunk_time
            chunk_size = int(chunk_time / time_per_element)

            # Clamp to a valid size
            chunk_size = max(1, min(chunk_size, max_chunk_size))

            # next chunks
            chunks = [list(itertools.islice(i, chunk_size)) for _ in range(len(done))]

            # Remove empty chunks
            chunks = [b for b in chunks if b]

            # Submit to the thread pool
            futures = list(not_done) + [
                pool.submit(_process, chunk, cur_chunk_idx + i) for i, chunk in enumerate(chunks)
            ]
            cur_chunk_idx += len(chunks)

        done_results = list(results_cache.keys())
        if keep_order:
            while cur_ret_chunk_id in done_results:
                yield results_cache[cur_ret_chunk_id]
                del results_cache[cur_ret_chunk_id]
                cur_ret_chunk_id += 1
        else:
            for k in done_results:
                yield results_cache[k]
                del results_cache[k]

    assert len(results_cache) == 0
def process_iterable(
    fn: Callable[[Iterable[_T]], Any],
    it: Iterable[_T],
    chunk_time: float,
    max_chunk_size: int,
    initial_chunk_size: int,
    max_workers: int = 1,
    *,
    keep_order: bool = False,
) -> Generator[List[_T], None, None]:
    """
    Similar to process_chunk_iterable, but yields individual elements rather than chunks
    """

    for chunk in process_chunk_iterable(
        fn, it, chunk_time, max_chunk_size, initial_chunk_size, max_workers, keep_order=keep_order
    ):
        yield from chunk
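# Illustrative sketch (not part of the original module). 'fetch_records' and 'all_ids' are hypothetical
# names; fetch_records accepts a list of ids and returns a list of results. process_chunk_iterable would
# yield the returned lists as-is, while process_iterable flattens them into individual results:
#
#   for record in process_iterable(
#       fetch_records, all_ids, chunk_time=2.0, max_chunk_size=100, initial_chunk_size=10, max_workers=2
#   ):
#       ...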
def seconds_to_hms(seconds: Union[float, int]) -> str:
    """
    Converts a number of seconds (integer or float) to a string representing hh:mm:ss
    """

    if isinstance(seconds, float):
        fraction = seconds % 1
        seconds = int(seconds)
    else:
        fraction = None

    hours, seconds = divmod(seconds, 3600)
    minutes, seconds = divmod(seconds, 60)

    if fraction is None:
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
    else:
        return f"{hours:02d}:{minutes:02d}:{seconds + fraction:05.2f}"
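# Illustrative usage (not part of the original module):
#   >>> seconds_to_hms(3725)
#   '01:02:05'
#   >>> seconds_to_hms(90.25)
#   '00:01:30.25'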
def duration_to_seconds(s: Union[int, str, float]) -> int:
    """
    Parses a duration given as an int, float, or a string in dd:hh:mm:ss or 1d2h3m4s format
    to an integer number of seconds
    """

    # Is already an int
    if isinstance(s, int):
        return s

    # Is a float but represents an integer
    if isinstance(s, float):
        if s.is_integer():
            return int(s)
        else:
            raise ValueError(f"Invalid duration format: {s} - cannot represent fractional seconds")

    # Plain number of seconds (as a string)
    if s.isdigit():
        return int(s)

    # A float as a string (eg, "12.0"). Use try/except/else so the fractional-seconds error
    # below is not swallowed by the except clause
    try:
        f = float(s)
    except ValueError:
        pass
    else:
        if f.is_integer():
            return int(f)
        raise ValueError(f"Invalid duration format: {s} - cannot represent fractional seconds")

    # Handle dd:hh:mm:ss format
    if ":" in s:
        parts = list(map(int, s.split(":")))
        while len(parts) < 4:  # Pad missing parts with zeros
            parts.insert(0, 0)
        days, hours, minutes, seconds = parts
        return days * 86400 + hours * 3600 + minutes * 60 + seconds

    # Handle format like 3d4h7m10s
    pattern = re.compile(r"(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?")
    match = pattern.fullmatch(s)
    if not match:
        raise ValueError(f"Invalid duration format: {s}")

    days, hours, minutes, seconds = map(lambda x: int(x) if x else 0, match.groups())
    return days * 86400 + hours * 3600 + minutes * 60 + seconds
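# Illustrative usage (not part of the original module):
#   >>> duration_to_seconds("90")
#   90
#   >>> duration_to_seconds("02:30:00")
#   9000
#   >>> duration_to_seconds("1d2h3m4s")
#   93784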
def recursive_normalizer(value: Any, digits: int = 10, lowercase: bool = True) -> Any:
    """
    Prepare a structure for hashing by lowercasing all strings and rounding all floats
    """

    if isinstance(value, (int, type(None))):
        pass

    elif isinstance(value, str):
        if lowercase:
            value = value.lower()

    elif isinstance(value, list):
        value = [recursive_normalizer(x, digits, lowercase) for x in value]

    elif isinstance(value, tuple):
        value = tuple(recursive_normalizer(x, digits, lowercase) for x in value)

    elif isinstance(value, dict):
        ret = {}
        for k, v in value.items():
            if lowercase:
                k = k.lower()
            ret[k] = recursive_normalizer(v, digits, lowercase)
        value = ret

    elif isinstance(value, np.ndarray):
        if digits:
            # Round array
            value = np.around(value, digits)
            # Flip zeros
            value[np.abs(value) < 5 ** (-(digits + 1))] = 0

    elif isinstance(value, float):
        if digits:
            value = round(value, digits)
            if value == -0.0:
                value = 0
            if value == 0.0:
                value = 0

    else:
        raise TypeError(
            f"Invalid type in recursive normalizer ({type(value)}), only simple Python types are allowed."
        )

    return value
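# Illustrative usage (not part of the original module):
#   >>> recursive_normalizer({"Basis": "6-31G*", "Energy": -0.0})
#   {'basis': '6-31g*', 'energy': 0}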
def calculate_limit(max_limit: int, given_limit: Optional[int]) -> int:
    """Get the allowed limit on results to return for a particular table or type of object

    If 'given_limit' is given (ie, by the user), this will return min(given_limit, max_limit), where
    max_limit is the set value for the table/type of object
    """

    if given_limit is None:
        return max_limit

    return min(given_limit, max_limit)
def hash_dict(d: Dict[str, Any]) -> str:
    j = json.dumps(d, ensure_ascii=True, sort_keys=True, cls=_JSONEncoder).encode("utf-8")
    return sha256(j).hexdigest()
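# Illustrative usage (not part of the original module). Keys are sorted before hashing, so
# dictionaries that differ only in key order hash identically:
#   >>> hash_dict({"a": 1, "b": 2}) == hash_dict({"b": 2, "a": 1})
#   True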
@contextmanager
def capture_all_output(top_logger: str):
    """Captures all output, including stdout, stderr, and logging"""

    stdout_io = io.StringIO()
    stderr_io = io.StringIO()

    logger = logging.getLogger(top_logger)
    old_handlers = logger.handlers.copy()
    old_prop = logger.propagate

    logger.handlers.clear()
    logger.propagate = False

    # Make logging go to the string io
    handler = logging.StreamHandler(stdout_io)
    handler.terminator = ""
    logger.addHandler(handler)

    try:
        # Also redirect stdout/stderr to the string io objects
        with redirect_stdout(stdout_io) as rdout, redirect_stderr(stderr_io) as rderr:
            yield rdout, rderr
    finally:
        # Restore the original logger configuration, even if the body raised
        logger.handlers.clear()
        logger.handlers = old_handlers
        logger.propagate = old_prop
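# Illustrative sketch (not part of the original module). Output printed, or logged under the given
# logger name, inside the block ends up in the returned StringIO objects:
#
#   with capture_all_output("qcportal") as (out, err):
#       print("hello")
#   assert "hello" in out.getvalue()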
def now_at_utc() -> datetime.datetime:
    """Get the current time as a timezone-aware datetime object"""

    # Note that the utcnow() function is deprecated, and does not result in a
    # timezone-aware datetime object
    return datetime.datetime.now(datetime.timezone.utc)
@functools.lru_cache
def _is_included(
    key: str, include: Optional[Tuple[str, ...]], exclude: Optional[Tuple[str, ...]], default: bool
) -> bool:
    if exclude is None:
        exclude = []

    if include is not None:
        in_include = ("*" in include and default) or "**" in include or key in include
    else:
        in_include = default

    in_exclude = key in exclude

    return in_include and not in_exclude
def is_included(key: str, include: Optional[Iterable[str]], exclude: Optional[Iterable[str]], default: bool) -> bool:
    """
    Determine if a field should be included given the include and exclude lists

    Handles "*" and "**" as well
    """

    if include is not None:
        include = tuple(sorted(include))
    if exclude is not None:
        exclude = tuple(sorted(exclude))

    return _is_included(key, include, exclude, default)
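# Illustrative usage (not part of the original module). "**" forces inclusion regardless of the
# default, while "*" only includes fields that are included by default:
#   >>> is_included("comments", None, None, True)
#   True
#   >>> is_included("comments", ("*",), None, False)
#   False
#   >>> is_included("comments", ("**",), None, False)
#   True
#   >>> is_included("comments", None, ("comments",), True)
#   False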
def update_nested_dict(d: Dict[str, Any], u: Dict[str, Any]) -> Dict[str, Any]:
    """Recursively update dictionary d with the contents of u, merging nested dictionaries"""

    for k, v in u.items():
        if isinstance(v, dict):
            d[k] = update_nested_dict(d.get(k, {}), v)
        else:
            d[k] = v

    return d
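# Illustrative usage (not part of the original module):
#   >>> update_nested_dict({"view": {"enabled": True}, "name": "a"}, {"view": {"page_size": 10}})
#   {'view': {'enabled': True, 'page_size': 10}, 'name': 'a'}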
def apply_jitter(t: Union[int, float], jitter_fraction: float) -> float:
    """Applies a random jitter of up to +/- jitter_fraction to t, clamping the result at zero"""

    f = random.uniform(-jitter_fraction, jitter_fraction)
    return max(t * (1 + f), 0.0)
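# Illustrative usage (not part of the original module):
#   >>> 9.0 <= apply_jitter(10, 0.1) <= 11.0
#   True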
def time_based_cache(seconds: int = 10, maxsize: Optional[int] = None):
    """Decorator that caches a function's results for a limited number of seconds"""

    def decorator(func):
        cache = collections.OrderedDict()

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = (args, frozenset(kwargs.items()))
            now = time.time()

            # Clean up old items
            expiration_time = now - seconds
            keys_to_delete = [k for k, (timestamp, _) in cache.items() if timestamp < expiration_time]
            for k in keys_to_delete:
                del cache[k]

            # Return from cache if valid
            if key in cache:
                return cache[key][1]

            # Compute and store result
            result = func(*args, **kwargs)
            cache[key] = (now, result)

            # Enforce max size (a maxsize of None means the cache is unbounded)
            if maxsize is not None and len(cache) > maxsize:
                cache.popitem(last=False)  # Remove oldest

            return result

        return wrapper

    return decorator
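# Illustrative sketch (not part of the original module). 'get_server_info' and 'expensive_lookup' are
# hypothetical; repeated calls with the same arguments within 30 seconds return the cached value:
#
#   @time_based_cache(seconds=30, maxsize=100)
#   def get_server_info(address):
#       return expensive_lookup(address)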