from __future__ import annotations

import contextlib
import hashlib
import logging
import os
import random
import time
from typing import Any, Dict, Optional, Union, TypeVar, Type, overload, Tuple, Iterable

import jwt
import pydantic
import requests
import urllib3.exceptions
import yaml
from packaging.version import parse as parse_version
from tqdm import tqdm

from . import __version__
from .exceptions import AuthenticationFailure
from .serialization import serialize, deserialize
# Generic type variables used by the typed request helpers (body, url params, response)
_T = TypeVar("_T")
_U = TypeVar("_U")
_V = TypeVar("_V")

# Exception types that indicate a (possibly transient) failure to reach the server.
# Requests failing with one of these are retried by the client before giving up.
AllowedConnectionExceptions = (
    ConnectionError,
    requests.exceptions.Timeout,
    requests.exceptions.ConnectionError,
    urllib3.exceptions.TimeoutError,
)

# User-facing message used when the TLS handshake with the server fails
_ssl_error_msg = (
    "\n\n"
    "SSL handshake failed. This is likely caused by a failure to retrieve 3rd party SSL certificates.\n"
    "If you trust the server you are connecting to, try 'PortalClient(... verify=False)'"
)

# User-facing message used when the server cannot be reached at all; formatted with the address
_connection_error_msg = (
    "\n\n"
    "Could not connect to server {}, please check the address and try again."
)
def pretty_print_request(req):
    """Print an outgoing request's method, URL and headers to stdout (debugging aid)."""
    rule = "----------------------"
    print(rule)
    print(f"{req.method} {req.url}")
    print("\n".join(f"{name}: {value}" for name, value in req.headers.items()))
    print(rule)
def pretty_print_response(res):
    """Print a response's URL, status code, body size and headers to stdout (debugging aid)."""
    rule = "----------------------"
    print(rule)
    print(f"RESPONSE {res.url} -> {res.status_code}")
    print(f"Content: {len(res.content)} bytes")
    print("\n".join(f"{name}: {value}" for name, value in res.headers.items()))
    print(rule)
class PortalRequestError(Exception):
    """Error raised when a request to the server fails with a non-success HTTP status.

    Attributes
    ----------
    msg
        Human-readable description of the failure
    status_code
        HTTP status code returned by the server
    details
        Additional information about the error (e.g. the decoded JSON error body)
    """

    def __init__(self, msg: str, status_code: int, details: Dict[str, Any]):
        Exception.__init__(self, msg)
        self.msg = msg
        self.status_code = status_code
        self.details = details

    def __str__(self):
        return f"{self.msg} (HTTP status {self.status_code})"

    def __reduce__(self):
        # The default exception pickling re-creates the object from self.args, which only
        # contains msg. That would call __init__ without status_code/details and fail on
        # unpickling, so reconstruct from all three constructor arguments instead.
        return (type(self), (self.msg, self.status_code, self.details))
class PortalClientBase:
    """Base HTTP client for a QCPortal server.

    Holds a persistent ``requests`` session and handles JWT login/renewal, retrying of
    requests on connection errors, (de)serialization of request and response bodies,
    and file downloads.
    """

    def __init__(
        self,
        address: str,
        username: Optional[str] = None,
        password: Optional[str] = None,
        verify: bool = True,
        show_motd: bool = True,
        *,
        information_endpoint: str = "api/v1/information",
    ) -> None:
        """Initializes a PortalClient instance from an address and verification information.

        Parameters
        ----------
        address
            The IP and port of the FractalServer instance ("192.168.1.1:8888").
            If no scheme is given, ``https://`` is assumed.
        username
            The username to authenticate with.
        password
            The password to authenticate with.
        verify
            Verifies the SSL connection with a third party server. This may be False if a
            FractalServer was not provided a SSL certificate and defaults back to self-signed
            SSL keys. Ignored (treated as True) for plain ``http://`` addresses.
        show_motd
            If a Message-of-the-Day is available, display it
        information_endpoint
            Endpoint from which the server information is retrieved
        """
        self._logger = logging.getLogger("PortalClientBase")

        # For developer use and debugging
        self.debug_requests = False

        # Where we get the server information from
        self._information_endpoint = information_endpoint.strip("/")

        if not address.startswith("http://") and not address.startswith("https://"):
            address = "https://" + address

        if not address.endswith("/"):
            address += "/"

        self.address = address
        self.username = username
        self.user_id = None

        # If we are `http`, ignore all SSL directives. This must be decided after the
        # user-supplied value is known; an earlier assignment would simply be overwritten.
        self._verify = verify if address.startswith("https") else True

        # A persistent session
        # This results in significant speedup (~65% faster in my test)
        # https://docs.python-requests.org/en/master/user/advanced/#session-objects
        self._req_session = requests.Session()
        self._req_session.headers.update({"User-Agent": f"qcportal/{__version__}"})

        self.encoding = "application/json"
        self.timeout = 60

        # Handling retries of requests
        self.retry_max = 5
        self.retry_delay = 0.5
        self.retry_backoff = 2
        self.retry_jitter_fraction = 0.05

        # Processing/downloading in threads
        # Number of threads to use when fetching from the server
        self.n_download_threads = 2

        # Target time for how long a request should take (in seconds)
        # Chunk size will be adjusted to try to reach this target time
        self.download_target_time = 0.50

        # If no 3rd party verification, quiet urllib
        if self._verify is False:
            requests.packages.urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

        # JWT state. Anonymous clients keep these as None, which disables the
        # expiration checks done before each request.
        self._jwt_access_token = None
        self._jwt_refresh_token = None
        self._jwt_access_exp = None
        self._jwt_refresh_exp = None

        if username is not None and password is not None:
            self._username = username
            self._password = password
            self._get_JWT_token()
        else:
            self._username = None
            self._password = None

        # Try to connect and pull the server info
        self.server_info = self.get_server_information()
        self.server_name = self.server_info["name"]
        self.api_limits = self.server_info["api_limits"]

        server_version = parse_version(self.server_info["version"])
        client_version = parse_version(__version__)

        if client_version > server_version:
            self._logger.warning(
                "WARNING: This client version is newer than the server version. This may work if the "
                "versions are close, but expect exceptions and errors if attempting things the server "
                "does not support. "
                f"client version: {str(__version__)}, server version: {str(self.server_info['version'])}"
            )

        motd = self.server_info.get("motd", "")
        if show_motd and motd:
            print("*" * 10 + "- Message-of-the-Day from the server -" + "*" * 10)
            print()
            print(motd)
            print()
            print("*" * 14 + "- End of Message-of-the-Day -" + "*" * 15)

    @classmethod
    def from_file(cls, server_name: Optional[str] = None, config_path: Optional[str] = None):
        """Creates a new client given information in a file.

        If no path is passed in, the current working directory and finally ~/.qca
        are searched for "qcportal_config.yaml"

        Parameters
        ----------
        server_name
            Name/alias of the server in the yaml file
        config_path
            Full path to a configuration file, or a directory containing "qcportal_config.yaml".

        Raises
        ------
        FileNotFoundError
            If no path is given and no configuration file exists in the search paths
        RuntimeError
            If ``server_name`` is given but is not present in the configuration file
        KeyError
            If the selected configuration is not a mapping containing an ``address`` field
        """
        # Search canonical paths
        if config_path is None:
            test_paths = [os.getcwd(), os.path.join(os.path.expanduser("~"), ".qca")]

            for path in test_paths:
                local_path = os.path.join(path, "qcportal_config.yaml")
                if os.path.exists(local_path):
                    config_path = local_path
                    break

            if config_path is None:
                raise FileNotFoundError(
                    "Could not find `qcportal_config.yaml` in the following paths:\n {}".format(
                        ", ".join(test_paths)
                    )
                )
        else:
            config_path = os.path.expanduser(config_path)

            # Gave folder, not file
            if os.path.isdir(config_path):
                config_path = os.path.join(config_path, "qcportal_config.yaml")

        with open(config_path, "r") as handle:
            data = yaml.load(handle, Loader=yaml.SafeLoader)

        if server_name is not None:
            data = data.get(server_name) if isinstance(data, dict) else None
            if data is None:
                raise RuntimeError(f"Server '{server_name}' does not exist in the configuration file")

        # An empty file loads as None; anything that isn't a mapping can't be passed as kwargs
        if not isinstance(data, dict) or "address" not in data:
            raise KeyError("Config file must at least contain an address field.")

        return cls(**data)

    @classmethod
    def from_env(cls):
        """Creates a new client given information stored in environment variables

        The environment variables are:

        * QCPORTAL_ADDRESS (required)
        * QCPORTAL_USERNAME (optional)
        * QCPORTAL_PASSWORD (optional)
        * QCPORTAL_VERIFY (optional, defaults to True). The values true/false, 1/0, yes/no and
          on/off (case-insensitive) are converted to booleans; any other non-empty value is
          passed through unchanged as ``verify``.
        * QCPORTAL_CACHE_DIR (optional)

        NOTE(review): ``cache_dir`` is passed through as a keyword argument, but
        ``PortalClientBase.__init__`` does not accept it. This presumably relies on a
        subclass constructor; confirm.
        """
        address = os.environ.get("QCPORTAL_ADDRESS", None)
        username = os.environ.get("QCPORTAL_USERNAME", None)
        password = os.environ.get("QCPORTAL_PASSWORD", None)
        verify_str = os.environ.get("QCPORTAL_VERIFY", None)
        cache_dir = os.environ.get("QCPORTAL_CACHE_DIR", None)

        if address is None:
            raise KeyError("Required environment variable 'QCPORTAL_ADDRESS' not found")

        # Environment variables are always strings. requests treats a string `verify` as a
        # CA bundle path, so "true"/"false"-like values must become real booleans.
        verify: Union[bool, str] = True
        if verify_str is not None and verify_str.strip():
            lowered = verify_str.strip().lower()
            if lowered in ("1", "true", "yes", "on"):
                verify = True
            elif lowered in ("0", "false", "no", "off"):
                verify = False
            else:
                verify = verify_str

        data = {"address": address}
        if username is not None:
            data["username"] = username
        if password is not None:
            data["password"] = password
        if cache_dir is not None:
            data["cache_dir"] = cache_dir

        data["verify"] = verify

        return cls(**data)

    @property
    def encoding(self) -> str:
        """MIME type used to serialize request bodies; also sent as the ``Accept`` header."""
        return self._encoding

    @encoding.setter
    def encoding(self, encoding: str):
        self._encoding = encoding
        enc_headers = {"Accept": encoding}
        self._req_session.headers.update(enc_headers)

    def _send_request(self, req: requests.Request, allow_retries: bool = True) -> requests.Response:
        """
        Sends a prepared request, optionally retrying on errors

        Parameters
        ----------
        req
            A prepared request to send
        allow_retries
            If true, attempts to retry on certain kinds of errors

        Returns
        -------
        :
            The response returned from the request

        Raises
        ------
        ConnectionRefusedError
            (only when retries are allowed) if the SSL handshake fails, or if the server
            cannot be reached after all retries are exhausted
        RuntimeError
            If the server responds with a redirect
        """
        prep_req = self._req_session.prepare_request(req)

        if self.debug_requests:
            pretty_print_request(prep_req)

        if allow_retries:
            ret = self._send_with_retries(prep_req)
        else:
            ret = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False)

        if self.debug_requests:
            pretty_print_response(ret)

        if ret.is_redirect:
            raise RuntimeError("Redirection is not allowed")

        return ret

    def _send_with_retries(self, prep_req: requests.PreparedRequest) -> requests.Response:
        """Send a prepared request, retrying with jittered exponential backoff on connection errors."""
        retry_count = 0
        while True:
            try:
                return self._req_session.send(
                    prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False
                )
            except requests.exceptions.SSLError:
                # Must be handled before AllowedConnectionExceptions: SSLError derives from requests'
                # ConnectionError, and retrying cannot fix a failed handshake. Raising here (rather
                # than from a nested handler) keeps the SSL hint from being replaced by the generic
                # connection message.
                raise ConnectionRefusedError(_ssl_error_msg) from None
            except AllowedConnectionExceptions as e:
                if retry_count >= self.retry_max:
                    raise ConnectionRefusedError(_connection_error_msg.format(self.address)) from None

                # eg, if jitter fraction is 0.05, then multiply by something on the range 0.95 to 1.05
                jitter = random.uniform(1.0 - self.retry_jitter_fraction, 1.0 + self.retry_jitter_fraction)
                time_to_wait = self.retry_delay * (self.retry_backoff**retry_count) * jitter

                retry_count += 1
                self._logger.warning(
                    f"Connection error for {prep_req.url}: {str(e)} - retrying in {time_to_wait:.2f} seconds "
                    f"[{retry_count}/{self.retry_max}]"
                )
                time.sleep(time_to_wait)

    @staticmethod
    def _error_msg(ret: requests.Response) -> Optional[str]:
        """Return the "msg" field of a JSON response body as a string, or None if there isn't one."""
        try:
            msg = ret.json()["msg"]
        except (ValueError, KeyError, TypeError):
            # Not JSON (eg, an error page from a proxy), or JSON without a "msg" field
            return None
        return None if msg is None else str(msg)

    def _get_JWT_token(self) -> None:
        """Log in with the stored username/password and store the returned JWTs.

        On success, the access token is installed in the session's Authorization header and
        the token expiration times and user id are recorded.

        Raises
        ------
        AuthenticationFailure
            If the server does not return status 200
        """
        full_uri = self.address + "auth/v1/login"
        json = {"username": self._username, "password": self._password}
        req = requests.Request(method="POST", url=full_uri, json=json)
        ret = self._send_request(req)

        if ret.status_code != 200:
            msg = self._error_msg(ret)
            raise AuthenticationFailure(msg if msg is not None else ret.reason)

        ret_json = ret.json()
        self._jwt_refresh_token = ret_json["refresh_token"]
        self._jwt_access_token = ret_json["access_token"]
        self._req_session.headers.update({"Authorization": f"Bearer {self._jwt_access_token}"})

        # Store the expiration time of the access and refresh tokens
        # (these are unix epoch timestamps). The signature is not verified here; the
        # tokens are only inspected for their claims.
        decoded_access_token = jwt.decode(
            self._jwt_access_token, algorithms=["HS256"], options={"verify_signature": False}
        )
        decoded_refresh_token = jwt.decode(
            self._jwt_refresh_token, algorithms=["HS256"], options={"verify_signature": False}
        )

        self._jwt_access_exp = decoded_access_token["exp"]
        self._jwt_refresh_exp = decoded_refresh_token["exp"]
        self.user_id = int(decoded_access_token["sub"])  # "identity" "subject"

    def _refresh_JWT_token(self) -> None:
        """Obtain a new access token using the stored refresh token.

        Falls back to a full login if the server reports that the refresh token has expired.

        Raises
        ------
        AuthenticationFailure
            If the user account has been disabled or no longer exists
        ConnectionRefusedError
            If the token could not be refreshed for any other reason
        """
        full_uri = self.address + "auth/v1/refresh"
        headers = {"Authorization": f"Bearer {self._jwt_refresh_token}"}
        req = requests.Request(method="POST", url=full_uri, headers=headers)
        ret = self._send_request(req)

        if ret.status_code == 200:
            ret_json = ret.json()
            self._jwt_access_token = ret_json["access_token"]
            self._req_session.headers.update({"Authorization": f"Bearer {self._jwt_access_token}"})

            # Store the expiration time of the access token (a unix epoch timestamp)
            decoded_access_token = jwt.decode(
                self._jwt_access_token, algorithms=["HS256"], options={"verify_signature": False}
            )
            self._jwt_access_exp = decoded_access_token["exp"]
            return

        # Don't assume the error body is JSON with a "msg" field
        msg = self._error_msg(ret) or ""

        if ret.status_code == 401 and "Token has expired" in msg:
            # If the refresh token has expired, try to log in again
            self._get_JWT_token()
        elif ret.status_code == 401 and " is disabled" in msg:
            raise AuthenticationFailure("User account has been disabled")
        elif ret.status_code == 401 and " does not exist" in msg:
            raise AuthenticationFailure("User account no longer exists")
        else:  # shouldn't happen unless user is blacklisted or something
            self._logger.error("Unable to refresh JWT authorization token: %s %s", ret, ret.text)
            raise ConnectionRefusedError("Unable to refresh JWT authorization token! This is a server issue!!")

    def _renew_tokens_if_expired(self) -> None:
        """Log in again if the refresh token has expired, then refresh an expired access token."""
        # If refresh token has expired, log in again
        if self._jwt_refresh_exp and self._jwt_refresh_exp < time.time():
            self._get_JWT_token()

        # If only the JWT token is expired, automatically renew it
        if self._jwt_access_exp and self._jwt_access_exp < time.time():
            self._refresh_JWT_token()

    @staticmethod
    def _rewind_files(file_data: Iterable[Tuple[str, Any]]) -> None:
        """Seek every file object in a requests-style ``files`` list back to its start."""
        for _, file_spec in file_data:
            # Entries are either a bare file object or a (filename, fileobj, ...) tuple
            fobj = file_spec[1] if isinstance(file_spec, tuple) else file_spec
            if hasattr(fobj, "seek"):
                fobj.seek(0)

    def _request(
        self,
        method: str,
        endpoint: str,
        *,
        body: Optional[Union[bytes, str]] = None,
        url_params: Optional[Dict[str, Any]] = None,
        file_data: Optional[Iterable[Tuple[str, Any]]] = None,
        internal_retry: Optional[bool] = True,
        allow_retries: bool = True,
        additional_headers: Optional[Dict[str, Any]] = None,
    ) -> requests.Response:
        """Send an already-serialized request to the server and return the raw response.

        Parameters
        ----------
        method
            HTTP method (case-insensitive)
        endpoint
            Endpoint, appended to the server address
        body
            Serialized request body
        url_params
            URL query parameters
        file_data
            requests-style ``files`` entries for a multipart upload
        internal_retry
            If True, retry once after refreshing the access token when the server reports it expired
        allow_retries
            If True, retry on connection errors
        additional_headers
            Extra HTTP headers to send

        Raises
        ------
        PortalRequestError
            If the server returns a status code other than 200
        """
        self._renew_tokens_if_expired()

        # Materialize so that the same parts can be sent again by the retry below
        if file_data is not None:
            file_data = list(file_data)

        full_uri = self.address + endpoint

        # Let requests handle content-type if doing multipart
        # but specify our encoding otherwise
        headers = {} if file_data is not None else {"Content-Type": self.encoding}
        if additional_headers is not None:
            headers.update(additional_headers)

        req = requests.Request(
            method=method.upper(), url=full_uri, data=body, params=url_params, files=file_data, headers=headers
        )
        r = self._send_request(req, allow_retries=allow_retries)

        # If JWT token expired, automatically renew it and retry once. This should have been caught above,
        # but can happen in rare instances where the token expires between the time we check it and the time
        # we use it. The retry must carry every original argument (files, headers, retry policy).
        if internal_retry and r.status_code == 401 and "Token has expired" in (self._error_msg(r) or ""):
            self._refresh_JWT_token()
            if file_data is not None:
                # requests read the file objects while encoding the first multipart body
                self._rewind_files(file_data)
            return self._request(
                method,
                endpoint,
                body=body,
                url_params=url_params,
                file_data=file_data,
                internal_retry=False,
                allow_retries=allow_retries,
                additional_headers=additional_headers,
            )

        if r.status_code != 200:
            # For many errors returned by our code, the error details are returned as json
            # with the error message stored under "msg"
            try:
                details = r.json()
            except ValueError:
                details = None

            # If this error comes from, ie, the web server or something else, then
            # we have to use 'reason'
            if not isinstance(details, dict):
                details = {"msg": r.reason}

            raise PortalRequestError(f"Request failed: {details.get('msg', r.reason)}", r.status_code, details)

        return r

    # Overload for giving a response model
    @overload
    def make_request(
        self,
        method: str,
        endpoint: str,
        response_model: Type[_V],
        *,
        body_model: Optional[Type[_T]] = None,
        url_params_model: Optional[Type[_U]] = None,
        body: Optional[Union[_T, Dict[str, Any]]] = None,
        url_params: Optional[Union[_U, Dict[str, Any]]] = None,
        upload_files: Optional[Iterable[Tuple[str, str]]] = None,
        allow_retries: bool = True,
        additional_headers: Optional[Dict[str, Any]] = None,
    ) -> _V: ...

    # Overload for no response model
    @overload
    def make_request(
        self,
        method: str,
        endpoint: str,
        response_model: None,
        *,
        body_model: Optional[Type[_T]] = None,
        url_params_model: Optional[Type[_U]] = None,
        body: Optional[Union[_T, Dict[str, Any]]] = None,
        url_params: Optional[Union[_U, Dict[str, Any]]] = None,
        upload_files: Optional[Iterable[Tuple[str, str]]] = None,
        allow_retries: bool = True,
        additional_headers: Optional[Dict[str, Any]] = None,
    ) -> None: ...

    def make_request(
        self,
        method: str,
        endpoint: str,
        response_model: Type[_V] | None,
        *,
        body_model: Optional[Type[_T]] = None,
        url_params_model: Optional[Type[_U]] = None,
        body: Optional[Union[_T, Dict[str, Any]]] = None,
        url_params: Optional[Union[_U, Dict[str, Any]]] = None,
        upload_files: Optional[Iterable[Tuple[str, str]]] = None,
        allow_retries: bool = True,
        additional_headers: Optional[Dict[str, Any]] = None,
    ) -> _V | None:
        """Validate and serialize a request, send it, and deserialize the response.

        Parameters
        ----------
        method
            HTTP method (eg, "get", "post")
        endpoint
            Endpoint, appended to the server address
        response_model
            Type passed to the deserializer for the response body (may be None)
        body_model
            Type used to validate ``body``; defaults to the type of ``body``
        url_params_model
            Type used to validate ``url_params``; defaults to the type of ``url_params``
        body
            Request body; validated and then serialized with the client's encoding
        url_params
            URL query parameters; validated, with pydantic models converted to dicts
        upload_files
            (name, path) pairs of files to upload as a multipart request. When given, the
            serialized body is sent as the "body_data" part of the upload.
        allow_retries
            If True, retry on connection errors
        additional_headers
            Extra HTTP headers to send

        Returns
        -------
        :
            The response body deserialized with ``response_model``
        """
        # If body_model or url_params_model are None, then use the type given
        if body_model is None and body is not None:
            body_model = type(body)

        if url_params_model is None and url_params is not None:
            url_params_model = type(url_params)

        serialized_body = None
        if body_model is not None:
            parsed_body = pydantic.TypeAdapter(body_model).validate_python(body)
            serialized_body = serialize(parsed_body, self.encoding)

        parsed_url_params = None
        if url_params_model is not None:
            parsed_url_params = pydantic.TypeAdapter(url_params_model).validate_python(url_params)
            if isinstance(parsed_url_params, pydantic.BaseModel):
                parsed_url_params = parsed_url_params.model_dump()

        # The ExitStack closes every uploaded file once the request is done, including when
        # opening a later file or sending the request fails.
        with contextlib.ExitStack() as open_files:
            file_data = None
            if upload_files is not None:
                # Yes, a list of tuples. We always use the "files" key, and doing it this way
                # allows for multiple files to be uploaded in a single request.
                file_data = [
                    ("files", (fname, open_files.enter_context(open(fpath, "rb")))) for fname, fpath in upload_files
                ]

                # We must also send the serialized body as part of the multipart upload
                if serialized_body is not None:
                    file_data.append(("body_data", ("body_data", serialized_body, self.encoding)))
                    serialized_body = None

            assert (serialized_body is None) or (file_data is None)  # Just to check my logic

            r = self._request(
                method,
                endpoint,
                body=serialized_body,
                url_params=parsed_url_params,
                file_data=file_data,
                allow_retries=allow_retries,
                additional_headers=additional_headers,
            )

        return deserialize(r.content, r.headers["Content-Type"], response_model)

    def download_file(
        self,
        endpoint: str,
        destination_path: str,
        overwrite: bool = False,
        expected_size: Optional[int] = None,
        show_progress: bool = False,
    ) -> Tuple[int, str]:
        """
        Download a file with optional progress bar

        Parameters
        ----------
        endpoint
            API endpoint to download from
        destination_path
            Where to save the file
        overwrite
            Whether to overwrite existing files
        expected_size
            Expected size of the file in bytes (used for progress bar if enabled)
        show_progress
            Whether to show a progress bar during download

        Returns
        -------
        :
            The number of bytes written and the hex-encoded SHA-256 digest of the downloaded data

        Raises
        ------
        RuntimeError
            If the destination already exists and ``overwrite`` is False
        requests.exceptions.HTTPError
            If the server (or the redirect target) returns an error status
        """
        # Remove if overwrite=True. This allows for any processes still using the old file to keep using it
        # (at least on linux)
        if os.path.exists(destination_path):
            if overwrite:
                os.remove(destination_path)
            else:
                raise RuntimeError(f"File already exists at {destination_path}. To overwrite, use `overwrite=True`")

        # The JWT is sent via the session headers, so make sure it is still valid
        self._renew_tokens_if_expired()

        full_uri = self.address + endpoint
        response = self._req_session.get(
            full_uri, stream=True, allow_redirects=False, verify=self._verify, timeout=self.timeout
        )

        if response.is_redirect:
            # send again, but using a plain requests object
            # that way, we don't pass the JWT to someone else
            new_location = response.headers["Location"]
            response.close()
            response = requests.get(new_location, stream=True, allow_redirects=True, timeout=self.timeout)

        sha256 = hashlib.sha256()
        file_size = 0

        # With a progress bar, read fixed 4MB chunks so the bar is updated during the download
        # (with chunk_size=None it may only be updated after the entire download is complete)
        chunk_size = 4 * 1024 * 1024 if show_progress else None

        # Get filename for display in progress bar
        filename = os.path.basename(destination_path)

        with response:
            response.raise_for_status()
            try:
                # A disabled tqdm produces no output, so one loop serves both modes
                with open(destination_path, "wb") as f, tqdm(
                    total=expected_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    desc=f"Downloading to {filename}",
                    miniters=1,  # Update after each iteration
                    mininterval=0.1,  # Allow updates as frequently as every 0.1 seconds
                    disable=not show_progress,
                ) as pbar:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        if chunk:
                            f.write(chunk)
                            sha256.update(chunk)
                            file_size += len(chunk)
                            pbar.update(len(chunk))
            except BaseException:
                # Don't leave a partially-written file behind
                if os.path.exists(destination_path):
                    os.remove(destination_path)
                raise

        return file_size, sha256.hexdigest()

    def ping(self) -> bool:
        """
        Pings the server to see if it is up

        Returns
        -------
        :
            True if the server is up and responded to the ping. False otherwise
        """
        # self.address always ends with "/", so don't add another one
        uri = self.address + "api/v1/ping"

        try:
            r = requests.get(uri, verify=self._verify, timeout=self.timeout)
            return bool(r.json()["success"])
        except AllowedConnectionExceptions:
            return False
        except (ValueError, KeyError, TypeError):
            # Something answered, but not with a valid ping response
            return False