Source code for qcportal.client_base

from __future__ import annotations

import hashlib
import logging
import os
import random
import time
from typing import Any, Dict, Optional, Union, TypeVar, Type, overload, Tuple, Iterable

import jwt
import pydantic
import requests
import urllib3.exceptions
import yaml
from packaging.version import parse as parse_version
from tqdm import tqdm

from . import __version__
from .exceptions import AuthenticationFailure
from .serialization import serialize, deserialize

# Exception types that PortalClientBase._send_request treats as connection
# failures: the request is retried with exponential backoff and, once the retry
# budget is exhausted, re-raised as a ConnectionRefusedError. ping() reports
# False when one of these occurs.
AllowedConnectionExceptions = (
    ConnectionError,  # builtin (socket-level) connection errors
    requests.exceptions.Timeout,  # requests: connect/read timeouts
    requests.exceptions.ConnectionError,  # requests: connection failures
    urllib3.exceptions.TimeoutError,  # urllib3-level timeouts
)

# Generic placeholders for the body / url-parameter / response model types
# accepted by PortalClientBase.make_request
_T, _U, _V = TypeVar("_T"), TypeVar("_U"), TypeVar("_V")


# Message attached to the ConnectionRefusedError raised when the SSL handshake fails
_ssl_error_msg = (
    "\n\nSSL handshake failed. This is likely caused by a failure to retrieve 3rd party SSL certificates.\n"
    "If you trust the server you are connecting to, try 'PortalClient(... verify=False)'"
)

# Message (formatted with the server address) used when the server cannot be reached
_connection_error_msg = "\n\nCould not connect to server {}, please check the address and try again."


def pretty_print_request(req):
    """Print a request's method, URL and headers (one per line) between separator lines.

    Used for debugging when ``debug_requests`` is enabled on a client.
    """
    separator = "----------------------"
    header_text = "\n".join([f"{name}: {value}" for name, value in req.headers.items()])
    for line in (separator, f"{req.method} {req.url}", header_text, separator):
        print(line)
def pretty_print_response(res):
    """Print a response's URL, status code, body size and headers between separator lines.

    Used for debugging when ``debug_requests`` is enabled on a client.
    """
    separator = "----------------------"
    summary = [
        separator,
        f"RESPONSE {res.url} -> {res.status_code}",
        f"Content: {len(res.content)} bytes",
        "\n".join([f"{name}: {value}" for name, value in res.headers.items()]),
        separator,
    ]
    for line in summary:
        print(line)
class PortalRequestError(Exception):
    """Error raised when the server answers a request with a non-200 status.

    Attributes
    ----------
    msg
        Human-readable error message (also the exception's only argument)
    status_code
        HTTP status code returned by the server
    details
        Extra error information returned by the server
    """

    def __init__(self, msg: str, status_code: int, details: Dict[str, Any]):
        super().__init__(msg)
        self.msg, self.status_code, self.details = msg, status_code, details

    def __str__(self):
        text = f"{self.msg} (HTTP status {self.status_code})"
        return text
class PortalClientBase:
    """Low-level client for a QCFractal server.

    Handles address normalization, JWT login/refresh, request retries with backoff,
    (de)serialization of request and response bodies, and file downloads.
    """

    def __init__(
        self,
        address: str,
        username: Optional[str] = None,
        password: Optional[str] = None,
        verify: bool = True,
        show_motd: bool = True,
        *,
        information_endpoint: str = "api/v1/information",
    ) -> None:
        """Initializes a PortalClient instance from an address and verification information.

        Parameters
        ----------
        address
            The IP and port of the FractalServer instance ("192.168.1.1:8888").
            If no scheme is given, "https://" is assumed.
        username
            The username to authenticate with.
        password
            The password to authenticate with.
        verify
            Verifies the SSL connection with a third party server. This may be False if a
            FractalServer was not provided a SSL certificate and defaults back to self-signed
            SSL keys. Ignored (forced to True) for plain ``http://`` addresses.
        show_motd
            If a Message-of-the-Day is available, display it
        information_endpoint
            Endpoint (relative to ``address``) from which server information is retrieved
        """

        self._logger = logging.getLogger("PortalClientBase")

        # For developer use and debugging
        self.debug_requests = False

        # Where we get the server information from
        self._information_endpoint = information_endpoint.strip("/")

        if not address.startswith("http://") and not address.startswith("https://"):
            address = "https://" + address

        if not address.endswith("/"):
            address += "/"

        self.address = address
        self.username = username
        self.user_id = None
        self._verify = verify

        # If we are `http`, ignore all SSL directives. This must come *after* storing
        # `verify`; otherwise the override is immediately overwritten.
        if not address.startswith("https"):
            self._verify = True

        # A persistent session
        # This results in significant speedup (~65% faster in my test)
        # https://docs.python-requests.org/en/master/user/advanced/#session-objects
        self._req_session = requests.Session()
        self._req_session.headers.update({"User-Agent": f"qcportal/{__version__}"})

        self.encoding = "application/json"
        self.timeout = 60

        # Handling retries of requests
        self.retry_max = 5
        self.retry_delay = 0.5
        self.retry_backoff = 2
        self.retry_jitter_fraction = 0.05

        # Processing/downloading in threads
        # Number of threads to use when fetching from the server
        self.n_download_threads = 2

        # Target time for how long a request should take (in seconds)
        # Chunk size will be adjusted to try to reach this target time
        self.download_target_time = 0.50

        # If no 3rd party verification, quiet urllib
        if self._verify is False:
            requests.packages.urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

        if username is not None and password is not None:
            self._username = username
            self._password = password
            self._get_JWT_token()
        else:
            self._username = None
            self._password = None
            self._jwt_access_exp = None
            self._jwt_refresh_exp = None

        # Try to connect and pull the server info
        self.server_info = self.get_server_information()
        self.server_name = self.server_info["name"]
        self.api_limits = self.server_info["api_limits"]

        server_version = parse_version(self.server_info["version"])
        client_version = parse_version(__version__)

        if client_version > server_version:
            self._logger.warning(
                "WARNING: This client version is newer than the server version. This may work if the "
                "versions are close, but expect exceptions and errors if attempting things the server "
                "does not support. "
                f"client version: {str(__version__)}, server version: {str(self.server_info['version'])}"
            )

        motd = self.server_info.get("motd", "")
        if show_motd and motd:
            print("*" * 10 + "- Message-of-the-Day from the server -" + "*" * 10)
            print()
            print(motd)
            print()
            print("*" * 14 + "- End of Message-of-the-Day -" + "*" * 15)

    @classmethod
    def from_file(cls, server_name: Optional[str] = None, config_path: Optional[str] = None):
        """Creates a new client given information in a file.

        If no path is passed in, the current working directory and finally ~/.qca are
        searched for "qcportal_config.yaml"

        Parameters
        ----------
        server_name
            Name/alias of the server in the yaml file
        config_path
            Full path to a configuration file, or a directory containing "qcportal_config.yaml".

        Raises
        ------
        FileNotFoundError
            If no configuration file could be found
        RuntimeError
            If ``server_name`` is given but not present in the file
        KeyError
            If the selected configuration does not contain an ``address``
        """

        # Search canonical paths
        if config_path is None:
            test_paths = [os.getcwd(), os.path.join(os.path.expanduser("~"), ".qca")]

            for path in test_paths:
                local_path = os.path.join(path, "qcportal_config.yaml")
                if os.path.exists(local_path):
                    config_path = local_path
                    break

            if config_path is None:
                raise FileNotFoundError(
                    "Could not find `qcportal_config.yaml` in the following paths:\n {}".format(", ".join(test_paths))
                )

        else:
            config_path = os.path.expanduser(config_path)

            # Gave folder, not file
            if os.path.isdir(config_path):
                config_path = os.path.join(config_path, "qcportal_config.yaml")

        with open(config_path, "r") as handle:
            # SafeLoader: the config file must not be able to construct arbitrary objects
            data = yaml.load(handle, Loader=yaml.SafeLoader)

        if server_name is not None:
            data = data.get(server_name) if isinstance(data, dict) else None
            if data is None:
                raise RuntimeError(f"Server '{server_name}' does not exist in the configuration file")

        # An empty file (None) or a non-mapping document cannot describe a server
        if not isinstance(data, dict) or "address" not in data:
            raise KeyError("Config file must at least contain an address field.")

        return cls(**data)

    @classmethod
    def from_env(cls):
        """Creates a new client given information stored in environment variables

        The environment variables are:
          * QCPORTAL_ADDRESS (required)
          * QCPORTAL_USERNAME (optional)
          * QCPORTAL_PASSWORD (optional)
          * QCPORTAL_VERIFY (optional, defaults to True). "false"/"0"/"no"/"off" disable
            verification and "true"/"1"/"yes"/"on" enable it (case-insensitive). Any other
            value is passed through unchanged (e.g. a path to a CA bundle).
          * QCPORTAL_CACHE_DIR (optional)
        """

        address = os.environ.get("QCPORTAL_ADDRESS", None)
        username = os.environ.get("QCPORTAL_USERNAME", None)
        password = os.environ.get("QCPORTAL_PASSWORD", None)
        verify_str = os.environ.get("QCPORTAL_VERIFY", None)
        cache_dir = os.environ.get("QCPORTAL_CACHE_DIR", None)

        if address is None:
            raise KeyError("Required environment variable 'QCPORTAL_ADDRESS' not found")

        # Environment variables are always strings, and a non-empty string like "False" is
        # truthy - so convert boolean-looking values explicitly.
        verify = True
        if verify_str is not None:
            lowered = verify_str.strip().lower()
            if lowered in ("0", "false", "no", "off"):
                verify = False
            elif lowered in ("", "1", "true", "yes", "on"):
                verify = True
            else:
                verify = verify_str

        data = {"address": address}
        if username is not None:
            data["username"] = username
        if password is not None:
            data["password"] = password
        if cache_dir is not None:
            data["cache_dir"] = cache_dir
        data["verify"] = verify

        return cls(**data)

    @property
    def encoding(self) -> str:
        return self._encoding

    @encoding.setter
    def encoding(self, encoding: str):
        # Also advertise the encoding we accept on every request of the session
        self._encoding = encoding
        enc_headers = {"Accept": encoding}
        self._req_session.headers.update(enc_headers)

    @staticmethod
    def _get_error_msg(response) -> Optional[str]:
        """Return the ``"msg"`` field of a JSON response body, or None if it is unavailable.

        None is returned when the body is not JSON, is not a mapping, or has no ``"msg"``.
        """
        try:
            msg = response.json()["msg"]
        except (ValueError, KeyError, TypeError, IndexError):
            return None
        return None if msg is None else str(msg)

    @staticmethod
    def _rewind_file_data(file_data) -> None:
        """Seek file objects in multipart ``file_data`` back to offset 0 so they can be re-sent.

        Entries are ``(field, payload)`` where payload is either a file-like object or a
        tuple whose second element is the content (file-like objects, or raw bytes which
        need no rewinding). Assumes file objects were originally positioned at offset 0.
        """
        for _, payload in file_data:
            content = payload[1] if isinstance(payload, tuple) else payload
            if hasattr(content, "seek"):
                content.seek(0)

    def _check_token_expiration(self) -> None:
        """Log in again if the refresh token expired, or refresh the access token if it expired."""
        # If refresh token has expired, log in again
        if self._jwt_refresh_exp and self._jwt_refresh_exp < time.time():
            self._get_JWT_token()

        # If only the JWT token is expired, automatically renew it
        if self._jwt_access_exp and self._jwt_access_exp < time.time():
            self._refresh_JWT_token()

    def _send_request(self, req: requests.Request, allow_retries: bool = True) -> requests.Response:
        """
        Sends a prepared request, optionally retrying on errors

        Parameters
        ----------
        req
            A prepared request to send
        allow_retries
            If true, attempts to retry on certain kinds of errors

        Returns
        -------
        :
            The response returned from the request

        Raises
        ------
        ConnectionRefusedError
            (With retries enabled) on SSL failure, or when the retry budget is exhausted
        RuntimeError
            If the server answered with a redirect
        """

        prep_req = self._req_session.prepare_request(req)

        if self.debug_requests:
            pretty_print_request(prep_req)

        if not allow_retries:
            ret = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False)
            if self.debug_requests:
                pretty_print_response(ret)
            if ret.is_redirect:
                raise RuntimeError("Redirection is not allowed")
            return ret

        retry_count = 0
        try:
            while True:
                try:
                    ret = self._req_session.send(
                        prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False
                    )
                    break
                except requests.exceptions.SSLError:
                    # SSL failures are not transient: don't retry. Re-raise the original so the
                    # outer handler attaches the SSL hint. (Raising ConnectionRefusedError here
                    # would be swallowed by the outer `except AllowedConnectionExceptions`, since
                    # it subclasses the builtin ConnectionError, losing the SSL message.)
                    raise
                except AllowedConnectionExceptions as e:
                    if retry_count >= self.retry_max:
                        raise

                    # eg, if jitter fraction is 0.05, then multiply by something on the range 0.95 to 1.05
                    jitter = random.uniform(1.0 - self.retry_jitter_fraction, 1.0 + self.retry_jitter_fraction)
                    time_to_wait = self.retry_delay * (self.retry_backoff**retry_count) * jitter
                    retry_count += 1

                    self._logger.warning(
                        f"Connection error for {prep_req.url}: {str(e)} - retrying in {time_to_wait:.2f} seconds "
                        f"[{retry_count}/{self.retry_max}]"
                    )
                    time.sleep(time_to_wait)

        except requests.exceptions.SSLError:
            raise ConnectionRefusedError(_ssl_error_msg) from None
        except AllowedConnectionExceptions:
            raise ConnectionRefusedError(_connection_error_msg.format(self.address)) from None

        if self.debug_requests:
            pretty_print_response(ret)
        if ret.is_redirect:
            raise RuntimeError("Redirection is not allowed")

        return ret

    def _get_JWT_token(self) -> None:
        """Log in with the stored username/password and store the returned JWTs.

        Sets the session's Authorization header, the token expiration times and ``user_id``.

        Raises
        ------
        AuthenticationFailure
            If the server does not accept the login
        """
        full_uri = self.address + "auth/v1/login"
        login_data = {"username": self._username, "password": self._password}
        req = requests.Request(method="POST", url=full_uri, json=login_data)
        ret = self._send_request(req)

        if ret.status_code != 200:
            msg = self._get_error_msg(ret)
            raise AuthenticationFailure(ret.reason if msg is None else msg)

        ret_json = ret.json()
        self._jwt_refresh_token = ret_json["refresh_token"]
        self._jwt_access_token = ret_json["access_token"]
        self._req_session.headers.update({"Authorization": f"Bearer {self._jwt_access_token}"})

        # Store the expiration time of the access and refresh tokens
        # (these are unix epoch timestamps). The signature is not verified here;
        # we only read the claims of tokens the server just handed us.
        decoded_access_token = jwt.decode(
            self._jwt_access_token, algorithms=["HS256"], options={"verify_signature": False}
        )
        decoded_refresh_token = jwt.decode(
            self._jwt_refresh_token, algorithms=["HS256"], options={"verify_signature": False}
        )

        self._jwt_access_exp = decoded_access_token["exp"]
        self._jwt_refresh_exp = decoded_refresh_token["exp"]
        self.user_id = int(decoded_access_token["sub"])  # "identity" "subject"

    def _refresh_JWT_token(self) -> None:
        """Obtain a new access token using the refresh token.

        Falls back to a full login if the refresh token itself has expired.

        Raises
        ------
        AuthenticationFailure
            If the user account is disabled or no longer exists
        ConnectionRefusedError
            For any other refresh failure
        """
        full_uri = self.address + "auth/v1/refresh"
        headers = {"Authorization": f"Bearer {self._jwt_refresh_token}"}
        req = requests.Request(method="POST", url=full_uri, headers=headers)
        ret = self._send_request(req)

        if ret.status_code == 200:
            ret_json = ret.json()
            self._jwt_access_token = ret_json["access_token"]
            self._req_session.headers.update({"Authorization": f"Bearer {self._jwt_access_token}"})

            # Store the expiration time of the access and refresh tokens
            # (these are unix epoch timestamps)
            decoded_access_token = jwt.decode(
                self._jwt_access_token, algorithms=["HS256"], options={"verify_signature": False}
            )
            self._jwt_access_exp = decoded_access_token["exp"]
            return

        # Parse the error message once; a non-JSON body yields an empty message
        msg = self._get_error_msg(ret) or ""

        if ret.status_code == 401 and "Token has expired" in msg:
            # If the refresh token has expired, try to log in again
            self._get_JWT_token()
        elif ret.status_code == 401 and " is disabled" in msg:
            raise AuthenticationFailure("User account has been disabled")
        elif ret.status_code == 401 and " does not exist" in msg:
            raise AuthenticationFailure("User account no longer exists")
        else:
            # shouldn't happen unless user is blacklisted or something
            self._logger.error("Unable to refresh JWT token: %s %s", ret, ret.text)
            raise ConnectionRefusedError("Unable to refresh JWT authorization token! This is a server issue!!")

    def _request(
        self,
        method: str,
        endpoint: str,
        *,
        body: Optional[Union[bytes, str]] = None,
        url_params: Optional[Dict[str, Any]] = None,
        file_data: Optional[Iterable[Tuple[str, Any]]] = None,
        internal_retry: Optional[bool] = True,
        allow_retries: bool = True,
        additional_headers: Optional[Dict[str, Any]] = None,
    ) -> requests.Response:
        """Send an (already serialized) request to ``endpoint``, handling JWT expiration.

        Raises
        ------
        PortalRequestError
            If the server responds with a status other than 200
        """

        self._check_token_expiration()

        # Materialize so the same file data can be re-sent if we retry below
        if file_data is not None:
            file_data = list(file_data)

        full_uri = self.address + endpoint

        # Let requests handle content-type if doing multipart
        # but specify our encoding otherwise
        headers = {}
        if file_data is None:
            headers["Content-Type"] = self.encoding

        if additional_headers is not None:
            headers.update(additional_headers)

        req = requests.Request(
            method=method.upper(), url=full_uri, data=body, params=url_params, files=file_data, headers=headers
        )
        r = self._send_request(req, allow_retries=allow_retries)

        # If JWT token expired, automatically renew it and retry once. This should have been caught above,
        # but can happen in rare instances where the token expires between the time we check it and the time
        # we use it.
        if internal_retry and r.status_code == 401 and "Token has expired" in (self._get_error_msg(r) or ""):
            self._refresh_JWT_token()

            # Uploaded files were consumed by the first attempt
            if file_data is not None:
                self._rewind_file_data(file_data)

            return self._request(
                method,
                endpoint,
                body=body,
                url_params=url_params,
                file_data=file_data,
                internal_retry=False,
                allow_retries=allow_retries,
                additional_headers=additional_headers,
            )

        if r.status_code != 200:
            # For many errors returned by our code, the error details are returned as json
            # with the error message stored under "msg"
            try:
                details = r.json()
            except ValueError:
                details = None

            # If this error comes from, ie, the web server or something else, then
            # we have to use 'reason'
            if not isinstance(details, dict):
                details = {"msg": r.reason}
            details.setdefault("msg", r.reason)

            raise PortalRequestError(f"Request failed: {details['msg']}", r.status_code, details)

        return r

    # Overload for giving a response model
    @overload
    def make_request(
        self,
        method: str,
        endpoint: str,
        response_model: Type[_V],
        *,
        body_model: Optional[Type[_T]] = None,
        url_params_model: Optional[Type[_U]] = None,
        body: Optional[Union[_T, Dict[str, Any]]] = None,
        url_params: Optional[Union[_U, Dict[str, Any]]] = None,
        upload_files: Optional[Iterable[Tuple[str, str]]] = None,
        allow_retries: bool = True,
        additional_headers: Optional[Dict[str, Any]] = None,
    ) -> _V:
        ...

    # Overload for no response model
    @overload
    def make_request(
        self,
        method: str,
        endpoint: str,
        response_model: None,
        *,
        body_model: Optional[Type[_T]] = None,
        url_params_model: Optional[Type[_U]] = None,
        body: Optional[Union[_T, Dict[str, Any]]] = None,
        url_params: Optional[Union[_U, Dict[str, Any]]] = None,
        upload_files: Optional[Iterable[Tuple[str, str]]] = None,
        allow_retries: bool = True,
        additional_headers: Optional[Dict[str, Any]] = None,
    ) -> None:
        ...

    def make_request(
        self,
        method: str,
        endpoint: str,
        response_model: Type[_V] | None,
        *,
        body_model: Optional[Type[_T]] = None,
        url_params_model: Optional[Type[_U]] = None,
        body: Optional[Union[_T, Dict[str, Any]]] = None,
        url_params: Optional[Union[_U, Dict[str, Any]]] = None,
        upload_files: Optional[Iterable[Tuple[str, str]]] = None,
        allow_retries: bool = True,
        additional_headers: Optional[Dict[str, Any]] = None,
    ) -> _V | None:
        """Validate/serialize the body and URL parameters, send the request, and deserialize the reply.

        ``upload_files`` is an iterable of ``(file name, file path)`` pairs sent as a multipart
        upload. In that case the serialized body is sent as the ``body_data`` part. The
        opened files are always closed before returning.
        """

        # If body_model or url_params_model are None, then use the type given
        if body_model is None and body is not None:
            body_model = type(body)
        if url_params_model is None and url_params is not None:
            url_params_model = type(url_params)

        serialized_body = None
        if body_model is not None:
            parsed_body = pydantic.TypeAdapter(body_model).validate_python(body)
            serialized_body = serialize(parsed_body, self.encoding)

        parsed_url_params = None
        if url_params_model is not None:
            parsed_url_params = pydantic.TypeAdapter(url_params_model).validate_python(url_params)
            if isinstance(parsed_url_params, pydantic.BaseModel):
                parsed_url_params = parsed_url_params.model_dump()

        opened_files = []
        try:
            if upload_files is not None:
                # Yes, a list of tuples. We always use the "files" key, and doing it this way
                # allows for multiple files to be uploaded in a single request.
                file_data = []
                for fname, fpath in upload_files:
                    fh = open(fpath, "rb")
                    opened_files.append(fh)
                    file_data.append(("files", (fname, fh)))

                # We must also send the serialized body as part of the multipart upload
                if serialized_body is not None:
                    file_data.append(("body_data", ("body_data", serialized_body, self.encoding)))
                    serialized_body = None
            else:
                file_data = None

            # A body is sent either directly or as part of the multipart upload, never both
            assert (serialized_body is None) or (file_data is None)

            r = self._request(
                method,
                endpoint,
                body=serialized_body,
                url_params=parsed_url_params,
                file_data=file_data,
                allow_retries=allow_retries,
                additional_headers=additional_headers,
            )
        finally:
            # Don't leak file handles, even if the request fails
            for fh in opened_files:
                fh.close()

        return deserialize(r.content, r.headers["Content-Type"], response_model)

    def download_file(
        self,
        endpoint: str,
        destination_path: str,
        overwrite: bool = False,
        expected_size: Optional[int] = None,
        show_progress: bool = False,
    ) -> Tuple[int, str]:
        """
        Download a file with optional progress bar

        Parameters
        ----------
        endpoint
            API endpoint to download from
        destination_path
            Where to save the file
        overwrite
            Whether to overwrite existing files
        expected_size
            Expected size of the file in bytes (used for progress bar if enabled)
        show_progress
            Whether to show a progress bar during download

        Returns
        -------
        :
            The number of bytes written and the hex SHA-256 digest of the content

        Raises
        ------
        RuntimeError
            If the destination exists and ``overwrite`` is False
        requests.exceptions.HTTPError
            If the download fails with an HTTP error status
        """

        # Remove if overwrite=True. This allows for any processes still using the old file to keep using it
        # (at least on linux)
        if os.path.exists(destination_path):
            if overwrite:
                os.remove(destination_path)
            else:
                raise RuntimeError(f"File already exists at {destination_path}. To overwrite, use `overwrite=True`")

        # Make sure the JWT sent with this request is still valid
        self._check_token_expiration()

        full_uri = self.address + endpoint
        response = self._req_session.get(
            full_uri, stream=True, allow_redirects=False, verify=self._verify, timeout=self.timeout
        )

        try:
            if response.is_redirect:
                # send again, but using a plain requests object
                # that way, we don't pass the JWT to someone else
                new_location = response.headers["Location"]
                response.close()
                response = requests.get(new_location, stream=True, allow_redirects=True, timeout=self.timeout)

            response.raise_for_status()

            sha256 = hashlib.sha256()
            file_size = 0

            # A finite chunk size lets the progress bar update during the download; with
            # chunk_size=None it may only be updated after the entire download is complete.
            # 4*1024*1024 is 4MB
            chunk_size = 4 * 1024 * 1024 if show_progress else None

            pbar = None
            if show_progress:
                # Show progress bar with expected_size (which can be None)
                pbar = tqdm(
                    total=expected_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    desc=f"Downloading to {os.path.basename(destination_path)}",
                    miniters=1,  # Update after each iteration
                    mininterval=0.1,  # Allow updates as frequently as every 0.1 seconds
                )

            try:
                with open(destination_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        if not chunk:
                            continue
                        f.write(chunk)
                        sha256.update(chunk)
                        file_size += len(chunk)
                        if pbar is not None:
                            pbar.update(len(chunk))
            except BaseException:
                # Don't leave a truncated file that could be mistaken for a complete download
                if os.path.exists(destination_path):
                    os.remove(destination_path)
                raise
            finally:
                if pbar is not None:
                    pbar.close()
        finally:
            # Release the streamed connection back to the pool
            response.close()

        return file_size, sha256.hexdigest()

    def ping(self) -> bool:
        """
        Pings the server to see if it is up

        Returns
        -------
        :
            True if the server is up and responded to the ping. False otherwise
        """

        # self.address always ends with "/"
        uri = self.address + "api/v1/ping"

        try:
            r = requests.get(uri, verify=self._verify, timeout=self.timeout)
            return r.json()["success"]
        except AllowedConnectionExceptions:
            return False
        except (ValueError, KeyError, TypeError):
            # Something answered, but not with a valid ping response
            return False

    def get_server_information(self) -> Dict[str, Any]:
        """Request general information about the server

        Returns
        -------
        :
            Server information.
        """

        # Request the info, and store here for later use
        # TODO - this fallback is temporary - remove in a future version
        try:
            return self.make_request("get", self._information_endpoint, Dict[str, Any])
        except PortalRequestError:
            # Retrying the very same endpoint would not help
            if self._information_endpoint == "api/v1/information":
                raise
            return self.make_request("get", "api/v1/information", Dict[str, Any])