from __future__ import annotations
import logging
import os
import random
import time
from typing import (
Any,
Dict,
Optional,
Union,
TypeVar,
Type,
)
import jwt
try:
import pydantic.v1 as pydantic
except ImportError:
import pydantic
import requests
from typing import Tuple
import yaml
import hashlib
from packaging.version import parse as parse_version
from . import __version__
from .exceptions import AuthenticationFailure
from .serialization import serialize, deserialize
_T = TypeVar("_T")
_U = TypeVar("_U")
_V = TypeVar("_V")
_ssl_error_msg = (
"\n\nSSL handshake failed. This is likely caused by a failure to retrieve 3rd party SSL certificates.\n"
"If you trust the server you are connecting to, try 'PortalClient(... verify=False)'"
)
_connection_error_msg = "\n\nCould not connect to server {}, please check the address and try again."
[docs]
def pretty_print_request(req):
print("----------------------")
print(f"{req.method} {req.url}")
print("\n".join(f"{k}: {v}" for k, v in req.headers.items()))
print("----------------------")
[docs]
def pretty_print_response(res):
print("----------------------")
print(f"RESPONSE {res.url} -> {res.status_code}")
print(f"Content: {len(res.content)} bytes")
print("\n".join(f"{k}: {v}" for k, v in res.headers.items()))
print("----------------------")
[docs]
class PortalRequestError(Exception):
def __init__(self, msg: str, status_code: int, details: Dict[str, Any]):
Exception.__init__(self, msg)
self.msg = msg
self.status_code = status_code
self.details = details
def __str__(self):
return f"{self.msg} (HTTP status {self.status_code})"
[docs]
class PortalClientBase:
def __init__(
self,
address: str,
username: Optional[str] = None,
password: Optional[str] = None,
verify: bool = True,
show_motd: bool = True,
*,
information_endpoint: str = "api/v1/information",
) -> None:
"""Initializes a PortalClient instance from an address and verification information.
Parameters
----------
address
The IP and port of the FractalServer instance ("192.168.1.1:8888")
username
The username to authenticate with.
password
The password to authenticate with.
verify
Verifies the SSL connection with a third party server. This may be False if a
FractalServer was not provided a SSL certificate and defaults back to self-signed
SSL keys.
show_motd
If a Message-of-the-Day is available, display it
"""
self._logger = logging.getLogger("PortalClientBase")
# For developer use and debugging
self.debug_requests = False
# Where we get the server information from
self._information_endpoint = information_endpoint.strip("/")
if not address.startswith("http://") and not address.startswith("https://"):
address = "https://" + address
# If we are `http`, ignore all SSL directives
if not address.startswith("https"):
self._verify = True
if not address.endswith("/"):
address += "/"
self.address = address
self.username = username
self._verify = verify
# A persistent session
# This results in significant speedup (~65% faster in my test)
# https://docs.python-requests.org/en/master/user/advanced/#session-objects
self._req_session = requests.Session()
self._req_session.headers.update({"User-Agent": f"qcportal/{__version__}"})
self.encoding = "application/json"
self.timeout = 60
# Handling retries of requests
self.retry_max = 5
self.retry_delay = 0.5
self.retry_backoff = 2
self.retry_jitter_fraction = 0.05
# Processing/downloading in threads
# Number of threads to use when fetching from the server
self.n_download_threads = 2
# Target time for how long a request should take (in seconds)
# Chunk size will be adjusted to try to reach this target time
self.download_target_time = 0.50
# If no 3rd party verification, quiet urllib
if self._verify is False:
from urllib3.exceptions import InsecureRequestWarning
requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning)
if username is not None and password is not None:
self._username = username
self._password = password
self._get_JWT_token()
else:
self._username = None
self._password = None
self._jwt_access_exp = None
self._jwt_refresh_exp = None
# Try to connect and pull the server info
self.server_info = self.get_server_information()
self.server_name = self.server_info["name"]
self.api_limits = self.server_info["api_limits"]
server_version = parse_version(self.server_info["version"])
client_version = parse_version(__version__)
if client_version > server_version:
self._logger.warning(
"WARNING: This client version is newer than the server version. This may work if the "
"versions are close, but expect exceptions and errors if attempting things the server "
"does not support. "
f"client version: {str(__version__)}, server version: {str(self.server_info['version'])}"
)
motd = self.server_info.get("motd", "")
if show_motd and motd:
print("*" * 10 + "- Message-of-the-Day from the server -" + "*" * 10)
print()
print(motd)
print()
print("*" * 14 + "- End of Message-of-the-Day -" + "*" * 15)
[docs]
@classmethod
def from_file(cls, server_name: Optional[str] = None, config_path: Optional[str] = None):
"""Creates a new client given information in a file.
If no path is passed in, the current working directory and finally ~/.qca
are searched for "qcportal_config.yaml"
Parameters
----------
server_name
Name/alias of the server in the yaml file
config_path
Full path to a configuration file, or a directory containing "qcportal_config.yaml".
"""
# Search canonical paths
if config_path is None:
test_paths = [os.getcwd(), os.path.join(os.path.expanduser("~"), ".qca")]
for path in test_paths:
local_path = os.path.join(path, "qcportal_config.yaml")
if os.path.exists(local_path):
config_path = local_path
break
if config_path is None:
raise FileNotFoundError(
"Could not find `qcportal_config.yaml` in the following paths:\n {}".format(
", ".join(test_paths)
)
)
else:
config_path = os.path.join(os.path.expanduser(config_path))
# Gave folder, not file
if os.path.isdir(config_path):
config_path = os.path.join(config_path, "qcportal_config.yaml")
with open(config_path, "r") as handle:
data = yaml.load(handle, Loader=yaml.SafeLoader)
if server_name is not None:
data = data.get(server_name)
if data is None:
raise RuntimeError(f"Server '{server_name}' does not exist in the configuration file")
if "address" not in data:
raise KeyError("Config file must at least contain an address field.")
return cls(**data)
[docs]
@classmethod
def from_env(cls):
"""Creates a new client given information stored in environment variables
The environment variables are:
* QCPORTAL_ADDRESS (required)
* QCPORTAL_USERNAME (optional)
* QCPORTAL_PASSWORD (optional)
* QCPORTAL_VERIFY (optional, defaults to True)
* QCPORTAL_CACHE_DIR (optional)
"""
address = os.environ.get("QCPORTAL_ADDRESS", None)
username = os.environ.get("QCPORTAL_USERNAME", None)
password = os.environ.get("QCPORTAL_PASSWORD", None)
verify = os.environ.get("QCPORTAL_VERIFY", True)
cache_dir = os.environ.get("QCPORTAL_CACHE_DIR", None)
if address is None:
raise KeyError("Required environment variable 'QCPORTAL_ADDRESS' not found")
data = {"address": address}
if username is not None:
data["username"] = username
if password is not None:
data["password"] = password
if cache_dir is not None:
data["cache_dir"] = cache_dir
data["verify"] = verify
return cls(**data)
@property
def encoding(self) -> str:
return self._encoding
@encoding.setter
def encoding(self, encoding: str):
self._encoding = encoding
enc_headers = {"Content-Type": encoding, "Accept": encoding}
self._req_session.headers.update(enc_headers)
def _send_request(self, req: requests.Request, allow_retries: bool = True) -> requests.Response:
"""
Sends a prepared request, optionally retrying on errors
Parameters
----------
req
A prepared request to send
allow_retries
If true, attempts to retry on certain kinds of errors
Returns
-------
:
The response returned from the request
"""
prep_req = self._req_session.prepare_request(req)
if self.debug_requests:
pretty_print_request(prep_req)
if not allow_retries:
ret = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False)
if self.debug_requests:
pretty_print_response(ret)
if ret.is_redirect:
raise RuntimeError("Redirection is not allowed")
return ret
retry_count = 0
try:
while True:
try:
ret = self._req_session.send(
prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False
)
break
except requests.exceptions.SSLError:
raise ConnectionRefusedError(_ssl_error_msg) from None
except (requests.exceptions.ConnectionError, requests.exceptions.ConnectTimeout) as e:
if retry_count >= self.retry_max:
raise
# eg, if jitter fraction is 0.05, then multiply by something on the range 0.95 to 1.05
jitter = random.uniform(1.0 - self.retry_jitter_fraction, 1.0 + self.retry_jitter_fraction)
time_to_wait = self.retry_delay * (self.retry_backoff**retry_count) * jitter
retry_count += 1
self._logger.warning(
f"Connection error for {prep_req.url}: {str(e)} - retrying in {time_to_wait:.2f} seconds "
f"[{retry_count}/{self.retry_max}]"
)
time.sleep(time_to_wait)
except requests.exceptions.SSLError:
raise ConnectionRefusedError(_ssl_error_msg) from None
except requests.exceptions.ConnectionError:
raise ConnectionRefusedError(_connection_error_msg.format(self.address)) from None
if self.debug_requests:
pretty_print_response(ret)
if ret.is_redirect:
raise RuntimeError("Redirection is not allowed")
return ret
def _get_JWT_token(self) -> None:
full_uri = self.address + "auth/v1/login"
json = {"username": self._username, "password": self._password}
req = requests.Request(method="POST", url=full_uri, json=json)
ret = self._send_request(req)
if ret.status_code == 200:
ret_json = ret.json()
self._jwt_refresh_token = ret_json["refresh_token"]
self._jwt_access_token = ret_json["access_token"]
self._req_session.headers.update({"Authorization": f"Bearer {self._jwt_access_token}"})
# Store the expiration time of the access and refresh tokens
# (these are unix epoch timestamps)
decoded_access_token = jwt.decode(
self._jwt_access_token, algorithms=["HS256"], options={"verify_signature": False}
)
decoded_refresh_token = jwt.decode(
self._jwt_refresh_token, algorithms=["HS256"], options={"verify_signature": False}
)
self._jwt_access_exp = decoded_access_token["exp"]
self._jwt_refresh_exp = decoded_refresh_token["exp"]
else:
try:
msg = ret.json()["msg"]
except:
msg = ret.reason
raise AuthenticationFailure(msg)
def _refresh_JWT_token(self) -> None:
full_uri = self.address + "auth/v1/refresh"
headers = {"Authorization": f"Bearer {self._jwt_refresh_token}"}
req = requests.Request(method="POST", url=full_uri, headers=headers)
ret = self._send_request(req)
if ret.status_code == 200:
ret_json = ret.json()
self._jwt_access_token = ret_json["access_token"]
self._req_session.headers.update({"Authorization": f"Bearer {self._jwt_access_token}"})
# Store the expiration time of the access and refresh tokens
# (these are unix epoch timestamps)
decoded_access_token = jwt.decode(
self._jwt_access_token, algorithms=["HS256"], options={"verify_signature": False}
)
self._jwt_access_exp = decoded_access_token["exp"]
elif ret.status_code == 401 and "Token has expired" in ret.json()["msg"]:
# If the refresh token has expired, try to log in again
self._get_JWT_token()
elif ret.status_code == 401 and f" is disabled" in ret.json()["msg"]:
raise AuthenticationFailure("User account has been disabled")
elif ret.status_code == 401 and f" does not exist" in ret.json()["msg"]:
raise AuthenticationFailure("User account no longer exists")
else: # shouldn't happen unless user is blacklisted or something
print(ret, ret.text)
raise ConnectionRefusedError("Unable to refresh JWT authorization token! This is a server issue!!")
def _request(
self,
method: str,
endpoint: str,
*,
body: Optional[Union[bytes, str]] = None,
url_params: Optional[Dict[str, Any]] = None,
internal_retry: Optional[bool] = True,
allow_retries: bool = True,
additional_headers: Optional[Dict[str, Any]] = None,
) -> requests.Response:
# If refresh token has expired, log in again
if self._jwt_refresh_exp and self._jwt_refresh_exp < time.time():
self._get_JWT_token()
# If only the JWT token is expired, automatically renew it
if self._jwt_access_exp and self._jwt_access_exp < time.time():
self._refresh_JWT_token()
full_uri = self.address + endpoint
req = requests.Request(
method=method.upper(), url=full_uri, data=body, params=url_params, headers=additional_headers
)
r = self._send_request(req, allow_retries=allow_retries)
# If JWT token expired, automatically renew it and retry once. This should have been caught above,
# but can happen in rare instances where the token expires between the time we check it and the time
# we use it.
if internal_retry and (r.status_code == 401) and "Token has expired" in r.json()["msg"]:
self._refresh_JWT_token()
return self._request(method, endpoint, body=body, url_params=url_params, internal_retry=False)
if r.status_code != 200:
try:
# For many errors returned by our code, the error details are returned as json
# with the error message stored under "msg"
details = r.json()
except:
# If this error comes from, ie, the web server or something else, then
# we have to use 'reason'
details = {"msg": r.reason}
raise PortalRequestError(f"Request failed: {details['msg']}", r.status_code, details)
return r
[docs]
def make_request(
self,
method: str,
endpoint: str,
response_model: Optional[Type[_V]],
*,
body_model: Optional[Type[_T]] = None,
url_params_model: Optional[Type[_U]] = None,
body: Optional[Union[_T, Dict[str, Any]]] = None,
url_params: Optional[Union[_U, Dict[str, Any]]] = None,
allow_retries: bool = True,
additional_headers: Optional[Dict[str, Any]] = None,
) -> _V:
# If body_model or url_params_model are None, then use the type given
if body_model is None and body is not None:
body_model = type(body)
if url_params_model is None and url_params is not None:
url_params_model = type(url_params)
serialized_body = None
if body_model is not None:
parsed_body = pydantic.parse_obj_as(body_model, body)
serialized_body = serialize(parsed_body, self.encoding)
parsed_url_params = None
if url_params_model is not None:
parsed_url_params = pydantic.parse_obj_as(url_params_model, url_params)
if isinstance(parsed_url_params, pydantic.BaseModel):
parsed_url_params = parsed_url_params.dict()
r = self._request(
method,
endpoint,
body=serialized_body,
url_params=parsed_url_params,
allow_retries=allow_retries,
additional_headers=additional_headers,
)
d = deserialize(r.content, r.headers["Content-Type"])
if response_model is None:
return None
else:
return pydantic.parse_obj_as(response_model, d)
[docs]
def download_file(self, endpoint: str, destination_path: str, overwrite: bool = False) -> Tuple[int, str]:
sha256 = hashlib.sha256()
file_size = 0
# Remove if overwrite=True. This allows for any processes still using the old file to keep using it
# (at least on linux)
if os.path.exists(destination_path):
if overwrite:
os.remove(destination_path)
else:
raise RuntimeError(f"File already exists at {destination_path}. To overwrite, use `overwrite=True`")
full_uri = self.address + endpoint
response = self._req_session.get(full_uri, stream=True, allow_redirects=False)
if response.is_redirect:
# send again, but using a plain requests object
# that way, we don't pass the JWT to someone else
new_location = response.headers["Location"]
response = requests.get(new_location, stream=True, allow_redirects=True)
response.raise_for_status()
with open(destination_path, "wb") as f:
for chunk in response.iter_content(chunk_size=None):
if chunk:
f.write(chunk)
sha256.update(chunk)
file_size += len(chunk)
return file_size, sha256.hexdigest()
[docs]
def ping(self) -> bool:
"""
Pings the server to see if it is up
Returns
-------
:
True if the server is up and responded to the ping. False otherwise
"""
uri = f"{self.address}/api/v1/ping"
try:
r = requests.get(uri)
return r.json()["success"]
except requests.exceptions.ConnectionError:
return False