Source code for qcportal.utils

from __future__ import annotations

import concurrent.futures
import datetime
import functools
import io
import itertools
import json
import logging
import math
import random
import re
import time
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from hashlib import sha256
from typing import Optional, Union, Sequence, List, TypeVar, Any, Dict, Generator, Iterable, Callable, Set, Tuple

import numpy as np

from qcportal.serialization import _JSONEncoder

_T = TypeVar("_T")


def make_list(obj: Optional[Union[_T, Sequence[_T], Set[_T]]]) -> Optional[List[_T]]:
    """
    Returns obj wrapped in a list if it is not already a list or another iterable type

    This will also work with sets
    """

    # NOTE - you might be tempted to change this to work with Iterable rather than Sequence. However,
    # pydantic models and dicts and stuff are sequences, too, which we usually just want to return
    # within a list

    if isinstance(obj, list):
        return obj
    if obj is None:
        return None
    # Be careful. strings are sequences
    if isinstance(obj, str):
        return [obj]
    if isinstance(obj, set):
        return list(obj)
    if not isinstance(obj, Sequence):
        return [obj]
    return list(obj)
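

# Example usage (illustrative sketch, not part of the module itself):
#
#   >>> make_list("water")
#   ['water']
#   >>> make_list({"water"})
#   ['water']
#   >>> make_list(None) is None
#   True
#   >>> make_list(42)
#   [42]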


def chunk_iterable(it: Iterable[_T], chunk_size: int) -> Generator[List[_T], None, None]:
    """
    Split an iterable (such as a list) into batches/chunks
    """

    if chunk_size < 1:
        raise ValueError("chunk size must be >= 1")

    i = iter(it)
    batch = list(itertools.islice(i, chunk_size))

    while batch:
        yield batch
        batch = list(itertools.islice(i, chunk_size))
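

# Example usage (illustrative sketch, not part of the module itself):
#
#   >>> list(chunk_iterable(range(7), 3))
#   [[0, 1, 2], [3, 4, 5], [6]]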


def chunk_iterable_time(
    it: Iterable[_T], chunk_time: float, max_chunk_size: int, initial_chunk_size: int
) -> Generator[List[_T], None, None]:
    """
    Split an iterable into chunks, trying to keep a constant time per chunk

    This function keeps track of the time it takes to process each chunk and tries to keep
    the time per chunk as close to 'chunk_time' as possible, increasing or decreasing the
    chunk size as needed (up to 'max_chunk_size').

    The first chunk will be of size 'initial_chunk_size' (assuming there are enough elements
    in the iterable to fill it).
    """

    if chunk_time <= 0:
        raise ValueError("chunk_time must be > 0")
    if max_chunk_size < 1:
        raise ValueError("max_chunk_size must be >= 1")
    if initial_chunk_size < 1 or initial_chunk_size > max_chunk_size:
        raise ValueError("initial_chunk_size must be >= 1 and <= max_chunk_size")

    i = iter(it)
    batch = list(itertools.islice(i, initial_chunk_size))

    while batch:
        # Time how long it takes the caller to process this chunk
        start = time.time()
        yield batch
        end = time.time()

        # How many elements could we fit in the desired chunk_time
        time_per_element = (end - start) / len(batch)
        chunk_size = math.floor(chunk_time / time_per_element)

        # Clamp to a valid size
        chunk_size = max(1, min(chunk_size, max_chunk_size))

        # Get the next chunk
        batch = list(itertools.islice(i, chunk_size))
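

# Example usage (illustrative sketch; `record_ids` and `handle_batch` are hypothetical,
# and the chunk sizes chosen depend on how long each loop body takes):
#
#   for batch in chunk_iterable_time(record_ids, chunk_time=1.0, max_chunk_size=500, initial_chunk_size=10):
#       handle_batch(batch)  # the time spent here drives the size of the next chunk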


def process_chunk_iterable(
    fn: Callable[[Iterable[_T]], Any],
    it: Iterable[_T],
    chunk_time: float,
    max_chunk_size: int,
    initial_chunk_size: int,
    max_workers: int = 1,
    *,
    keep_order: bool = False,
) -> Generator[List[_T], None, None]:
    """
    Process an iterable in chunks, trying to keep a constant time per chunk

    This function keeps track of the time it takes to process each chunk and tries to keep
    the time per chunk as close to 'chunk_time' as possible, increasing or decreasing the
    chunk size as needed (up to 'max_chunk_size').

    The first chunk will be of size 'initial_chunk_size' (assuming there are enough elements
    in the iterable to fill it).

    This function returns the results as chunks (lists) of the original iterable.
    If 'keep_order' is True, the results will be returned in the same order as the original
    iterable. If 'keep_order' is False, the results will be returned in the order they are
    completed.
    """

    # NOTE: You might think that we should spin up another thread to handle all the processing and submission
    # to the thread pool. However, if the user takes a long time processing the chunk (returned via yield) on
    # their end then this would effectively just process all the data and hold that in the cache. This might be
    # undesirable if the user is trying to process a large amount of data. Also, the effect is largely the same
    # in terms of timing.
    # So this function more or less tries to pre-process enough so that the user is never waiting, striking a
    # balance between downloading all the data and doing things completely serially.

    if chunk_time <= 0.0:
        raise ValueError("chunk_time must be > 0.0")
    if max_chunk_size < 1:
        raise ValueError("max_chunk_size must be >= 1")
    if initial_chunk_size < 1 or initial_chunk_size > max_chunk_size:
        raise ValueError("initial_chunk_size must be >= 1 and <= max_chunk_size")
    if max_workers < 1:
        raise ValueError("max_workers must be >= 1")

    pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)

    # Get initial chunks to be submitted to the pool
    i = iter(it)
    chunks = [list(itertools.islice(i, initial_chunk_size)) for _ in range(max_workers)]

    # Remove empty chunks
    chunks = [b for b in chunks if b]

    # Wrap the provided function so that we get timing and chunk id
    def _process(chunk, chunk_id):
        start = time.time()
        ret = fn(chunk)
        end = time.time()
        return (end - start) / len(chunk), chunk_id, ret

    # chunk id we should submit next
    cur_chunk_idx = 0

    # Current chunk id we are returning (if order is kept)
    cur_ret_chunk_id = 0

    # Dictionary keeping the results (indexed by chunk id)
    results_cache = {}

    # Submit the given function with the given chunks to the thread pool
    futures = [pool.submit(_process, chunk, cur_chunk_idx + i) for i, chunk in enumerate(chunks)]
    cur_chunk_idx += len(chunks)

    while True:
        if len(futures) == 0:
            break

        # Wait for any of the futures
        done, not_done = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)

        # Get the results of the completed futures
        average_per_element = 0.0

        for future in done:
            avg_time, chunk_idx, ret = future.result()
            average_per_element += avg_time  # Average per element of the iterable

            assert chunk_idx not in results_cache
            results_cache[chunk_idx] = ret

        if len(done) != 0:
            # compute the next chunk size
            time_per_element = average_per_element / len(done)  # Average of the averages

            # How many elements could we fit in the desired chunk_time
            chunk_size = math.floor(chunk_time / time_per_element)

            # Clamp to a valid size
            chunk_size = max(1, min(chunk_size, max_chunk_size))

            # next chunks
            chunks = [list(itertools.islice(i, chunk_size)) for _ in range(len(done))]

            # Remove empty chunks
            chunks = [b for b in chunks if b]

            # Submit to the thread pool
            futures = list(not_done) + [
                pool.submit(_process, chunk, cur_chunk_idx + i) for i, chunk in enumerate(chunks)
            ]
            cur_chunk_idx += len(chunks)

        done_results = list(results_cache.keys())
        if keep_order:
            while cur_ret_chunk_id in done_results:
                yield results_cache[cur_ret_chunk_id]
                del results_cache[cur_ret_chunk_id]
                cur_ret_chunk_id += 1
        else:
            for k in done_results:
                yield results_cache[k]
                del results_cache[k]

    assert len(results_cache) == 0
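

# Example usage (illustrative sketch; `fetch_records` and `all_ids` are hypothetical -
# `fetch_records` takes a chunk of ids and returns the corresponding list of results):
#
#   for result_chunk in process_chunk_iterable(
#       fetch_records, all_ids, chunk_time=2.0, max_chunk_size=1000, initial_chunk_size=25, max_workers=2
#   ):
#       for result in result_chunk:
#           ...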


def process_iterable(
    fn: Callable[[Iterable[_T]], Any],
    it: Iterable[_T],
    chunk_time: float,
    max_chunk_size: int,
    initial_chunk_size: int,
    max_workers: int = 1,
    *,
    keep_order: bool = False,
) -> Generator[List[_T], None, None]:
    """
    Similar to process_chunk_iterable, but returns individual elements rather than chunks
    """

    for chunk in process_chunk_iterable(
        fn, it, chunk_time, max_chunk_size, initial_chunk_size, max_workers, keep_order=keep_order
    ):
        yield from chunk
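

# Example usage (illustrative sketch, using the same hypothetical `fetch_records` and
# `all_ids` as above, but receiving one result at a time):
#
#   for result in process_iterable(fetch_records, all_ids, chunk_time=2.0, max_chunk_size=1000, initial_chunk_size=25):
#       ...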


def seconds_to_hms(seconds: Union[float, int]) -> str:
    """
    Converts a number of seconds to a string representing hh:mm:ss
    """

    if isinstance(seconds, float):
        fraction = seconds % 1
        seconds = int(seconds)
    else:
        fraction = None

    hours, seconds = divmod(seconds, 3600)
    minutes, seconds = divmod(seconds, 60)

    if fraction is None:
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
    else:
        return f"{hours:02d}:{minutes:02d}:{seconds+fraction:05.2f}"
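

# Example usage (illustrative sketch, not part of the module itself):
#
#   >>> seconds_to_hms(3725)
#   '01:02:05'
#   >>> seconds_to_hms(3725.5)
#   '01:02:05.50'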


def duration_to_seconds(s: Union[int, str, float]) -> int:
    """
    Parses a duration given as dd:hh:mm:ss or 1d2h3m4s into an integer number of seconds
    """

    # Is already an int
    if isinstance(s, int):
        return s

    # Is a float but represents an integer
    if isinstance(s, float):
        if s.is_integer():
            return int(s)
        else:
            raise ValueError(f"Invalid duration format: {s} - cannot represent fractional seconds")

    # Plain number of seconds (as a string)
    if s.isdigit():
        return int(s)

    # A string representing a float - only allowed if it represents an integer
    try:
        f = float(s)
    except ValueError:
        pass
    else:
        if f.is_integer():
            return int(f)
        raise ValueError(f"Invalid duration format: {s} - cannot represent fractional seconds")

    # Handle dd:hh:mm:ss format
    if ":" in s:
        parts = list(map(int, s.split(":")))
        while len(parts) < 4:  # Pad missing parts with zeros
            parts.insert(0, 0)

        days, hours, minutes, seconds = parts
        return days * 86400 + hours * 3600 + minutes * 60 + seconds

    # Handle format like 3d4h7m10s
    pattern = re.compile(r"(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?")
    match = pattern.fullmatch(s)

    if not match:
        raise ValueError(f"Invalid duration format: {s}")

    days, hours, minutes, seconds = map(lambda x: int(x) if x else 0, match.groups())
    return days * 86400 + hours * 3600 + minutes * 60 + seconds
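

# Example usage (illustrative sketch, not part of the module itself):
#
#   >>> duration_to_seconds("90")
#   90
#   >>> duration_to_seconds("02:03:04")
#   7384
#   >>> duration_to_seconds("1d2h3m4s")
#   93784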


def recursive_normalizer(value: Any, digits: int = 10, lowercase: bool = True) -> Any:
    """
    Prepare a structure for hashing by lowercasing all values and rounding all floats
    """
    if isinstance(value, (int, type(None))):
        pass

    elif isinstance(value, str):
        if lowercase:
            value = value.lower()

    elif isinstance(value, list):
        value = [recursive_normalizer(x, digits, lowercase) for x in value]

    elif isinstance(value, tuple):
        value = tuple(recursive_normalizer(x, digits, lowercase) for x in value)

    elif isinstance(value, dict):
        ret = {}
        for k, v in value.items():
            if lowercase:
                k = k.lower()
            ret[k] = recursive_normalizer(v, digits, lowercase)
        value = ret

    elif isinstance(value, np.ndarray):
        if digits:
            # Round array
            value = np.around(value, digits)
            # Flip zeros
            value[np.abs(value) < 5 ** (-(digits + 1))] = 0

    elif isinstance(value, float):
        if digits:
            value = round(value, digits)
            if value == -0.0:
                value = 0
            if value == 0.0:
                value = 0

    else:
        raise TypeError(
            f"Invalid type in recursive normalizer ({type(value)}), only simple Python types are allowed."
        )

    return value
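

# Example usage (illustrative sketch, not part of the module itself):
#
#   >>> recursive_normalizer({"Method": "B3LYP", "Tol": 1.000000000004})
#   {'method': 'b3lyp', 'tol': 1.0}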


def calculate_limit(max_limit: int, given_limit: Optional[int]) -> int:
    """Get the allowed limit on results to return for a particular type of object

    If 'given_limit' is given (ie, by the user), this will return min(given_limit, max_limit),
    where max_limit is the value set for the table/type of object
    """

    if given_limit is None:
        return max_limit

    return min(given_limit, max_limit)


def hash_dict(d: Dict[str, Any]) -> str:
    j = json.dumps(d, ensure_ascii=True, sort_keys=True, cls=_JSONEncoder).encode("utf-8")
    return sha256(j).hexdigest()
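

# Example usage (illustrative sketch, not part of the module itself). Keys are sorted
# before serialization, so key order does not affect the hash:
#
#   >>> hash_dict({"a": 1, "b": 2}) == hash_dict({"b": 2, "a": 1})
#   True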


@contextmanager
def capture_all_output(top_logger: str):
    """Captures all output, including stdout, stderr, and logging"""

    stdout_io = io.StringIO()
    stderr_io = io.StringIO()

    logger = logging.getLogger(top_logger)
    old_handlers = logger.handlers.copy()
    old_prop = logger.propagate

    logger.handlers.clear()
    logger.propagate = False

    # Make logging go to the string io
    handler = logging.StreamHandler(stdout_io)
    handler.terminator = ""
    logger.addHandler(handler)

    # Also redirect stdout/stderr to the string io objects
    with redirect_stdout(stdout_io) as rdout, redirect_stderr(stderr_io) as rderr:
        yield rdout, rderr

    logger.handlers.clear()
    logger.handlers = old_handlers
    logger.propagate = old_prop
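

# Example usage (illustrative sketch; "qcportal" here is just an example logger name):
#
#   with capture_all_output("qcportal") as (out, err):
#       print("some noisy output")
#   assert "some noisy output" in out.getvalue()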


def now_at_utc() -> datetime.datetime:
    """Get the current time as a timezone-aware datetime object"""

    # Note that the utcnow() function is deprecated, and does not result in a
    # timezone-aware datetime object
    return datetime.datetime.now(datetime.timezone.utc)


@functools.lru_cache
def _is_included(
    key: str, include: Optional[Tuple[str, ...]], exclude: Optional[Tuple[str, ...]], default: bool
) -> bool:
    if exclude is None:
        exclude = []

    if include is not None:
        in_include = ("*" in include and default) or "**" in include or key in include
    else:
        in_include = default

    in_exclude = key in exclude

    return in_include and not in_exclude


def is_included(key: str, include: Optional[Iterable[str]], exclude: Optional[Iterable[str]], default: bool) -> bool:
    """
    Determine if a field should be included given the include and exclude lists

    Handles "*" and "**" as well
    """

    if include is not None:
        include = tuple(sorted(include))
    if exclude is not None:
        exclude = tuple(sorted(exclude))

    return _is_included(key, include, exclude, default)
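

# Example usage (illustrative sketch, not part of the module itself):
#
#   >>> is_included("comments", ["*"], None, True)   # '*' keeps fields included by default
#   True
#   >>> is_included("comments", ["*"], None, False)  # but does not add default-excluded fields
#   False
#   >>> is_included("comments", None, ["comments"], True)
#   False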


def update_nested_dict(d: Dict[str, Any], u: Dict[str, Any]):
    for k, v in u.items():
        if isinstance(v, dict):
            d[k] = update_nested_dict(d.get(k, {}), v)
        else:
            d[k] = v

    return d
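

# Example usage (illustrative sketch, not part of the module itself); nested dictionaries
# are merged rather than replaced:
#
#   >>> update_nested_dict({"keywords": {"maxiter": 100}}, {"keywords": {"damping": 0.5}})
#   {'keywords': {'maxiter': 100, 'damping': 0.5}}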


def apply_jitter(t: Union[int, float], jitter_fraction: float) -> float:
    f = random.uniform(-jitter_fraction, jitter_fraction)
    return max(t * (1 + f), 0.0)
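

# Example usage (illustrative sketch, not part of the module itself); the result is scaled
# by a random factor within +/- jitter_fraction and clamped to be non-negative:
#
#   >>> delay = apply_jitter(10.0, 0.1)
#   >>> 9.0 <= delay <= 11.0
#   True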