from __future__ import annotations
import logging
import os
import random
import time
from typing import (
    Any,
    Dict,
    Optional,
    Union,
    TypeVar,
    Type,
)
import jwt
import urllib3.exceptions
try:
    import pydantic.v1 as pydantic
except ImportError:
    import pydantic
import requests
from typing import Tuple, Iterable
import yaml
import hashlib
from packaging.version import parse as parse_version
from tqdm import tqdm
from . import __version__
from .exceptions import AuthenticationFailure
from .serialization import serialize, deserialize
# Exception types treated as transient network failures. When retries are
# allowed, PortalClientBase._send_request retries on these (with exponential
# backoff) before giving up.
AllowedConnectionExceptions = (
    ConnectionError,  # the builtin OSError subclass (not requests' ConnectionError)
    requests.exceptions.Timeout,
    requests.exceptions.ConnectionError,
    urllib3.exceptions.TimeoutError,
)

# Generic placeholders used in make_request's signature:
#   _T -> request body model, _U -> URL-parameter model, _V -> response model
_T, _U, _V = TypeVar("_T"), TypeVar("_U"), TypeVar("_V")

# Hints used as the messages of ConnectionRefusedError raised by the client
_ssl_error_msg = (
    "\n\nSSL handshake failed. "
    "This is likely caused by a failure to retrieve 3rd party SSL certificates.\n"
    "If you trust the server you are connecting to, try 'PortalClient(... verify=False)'"
)
_connection_error_msg = (
    "\n\nCould not connect to server {}, "  # formatted with the server address
    "please check the address and try again."
)
[docs]
def pretty_print_request(req):
    """Print a human-readable dump of an outgoing (prepared) request.

    The output is the HTTP method and URL followed by one ``name: value`` line
    per header, framed by separator lines. Used as a debugging aid when
    ``PortalClientBase.debug_requests`` is enabled.
    """
    rule = "----------------------"
    print(rule)
    print(f"{req.method} {req.url}")
    header_lines = [f"{name}: {value}" for name, value in req.headers.items()]
    print("\n".join(header_lines))
    print(rule)
[docs]
def pretty_print_response(res):
    """Print a human-readable dump of a received response.

    The output is the URL and HTTP status code, the body length in bytes, and
    one ``name: value`` line per response header, framed by separator lines.
    Used as a debugging aid when ``PortalClientBase.debug_requests`` is enabled.
    """
    rule = "----------------------"
    print(rule)
    print(f"RESPONSE {res.url} -> {res.status_code}")
    body_length = len(res.content)
    print(f"Content: {body_length} bytes")
    header_lines = [f"{name}: {value}" for name, value in res.headers.items()]
    print("\n".join(header_lines))
    print(rule)
[docs]
class PortalRequestError(Exception):
    """Raised when the server answers a request with an unsuccessful HTTP status.

    Attributes
    ----------
    msg
        Human-readable error message (also the exception's single ``args`` entry)
    status_code
        HTTP status code returned by the server
    details
        Error details for the failed request (typically the decoded JSON error body)
    """

    def __init__(self, msg: str, status_code: int, details: Dict[str, Any]):
        super().__init__(msg)
        self.msg = msg
        self.status_code = status_code
        self.details = details

    def __reduce__(self):
        # Default exception pickling re-creates the instance as cls(*self.args), i.e. with
        # only `msg`, which fails because __init__ requires three arguments. Supply all the
        # constructor arguments (plus any extra instance attributes) so the error survives
        # a pickle round-trip, e.g. when passed between processes.
        return (self.__class__, (self.msg, self.status_code, self.details), self.__dict__)

    def __str__(self):
        return f"{self.msg} (HTTP status {self.status_code})"
[docs]
class PortalClientBase:
    def __init__(
        self,
        address: str,
        username: Optional[str] = None,
        password: Optional[str] = None,
        verify: bool = True,
        show_motd: bool = True,
        *,
        information_endpoint: str = "api/v1/information",
    ) -> None:
        """Initializes a PortalClient instance from an address and verification information.

        Parameters
        ----------
        address
            The IP and port of the FractalServer instance ("192.168.1.1:8888").
            If no scheme is given, ``https://`` is assumed.
        username
            The username to authenticate with.
        password
            The password to authenticate with. A login is only attempted when both
            username and password are given.
        verify
            Verifies the SSL connection with a third party server. This may be False if a
            FractalServer was not provided a SSL certificate and defaults back to self-signed
            SSL keys. Ignored (treated as True) for ``http://`` addresses.
        show_motd
            If a Message-of-the-Day is available, display it
        information_endpoint
            Endpoint (relative to the address) from which the server information is retrieved

        Raises
        ------
        AuthenticationFailure
            If credentials were given and the login was rejected
        """
        self._logger = logging.getLogger("PortalClientBase")

        # For developer use and debugging
        self.debug_requests = False

        # Where we get the server information from
        self._information_endpoint = information_endpoint.strip("/")

        if not address.startswith("http://") and not address.startswith("https://"):
            address = "https://" + address

        # If we are `http`, ignore all SSL directives. Override the argument itself:
        # self._verify is assigned from `verify` below, so setting the attribute here
        # would be overwritten.
        if not address.startswith("https"):
            verify = True

        if not address.endswith("/"):
            address += "/"

        self.address = address
        self.username = username
        self._verify = verify

        # A persistent session
        # This results in significant speedup (~65% faster in my test)
        # https://docs.python-requests.org/en/master/user/advanced/#session-objects
        self._req_session = requests.Session()
        self._req_session.headers.update({"User-Agent": f"qcportal/{__version__}"})

        # Must come after the session exists: the setter also updates the session's Accept header
        self.encoding = "application/json"
        self.timeout = 60

        # Handling retries of requests
        self.retry_max = 5
        self.retry_delay = 0.5
        self.retry_backoff = 2
        self.retry_jitter_fraction = 0.05

        # Processing/downloading in threads
        # Number of threads to use when fetching from the server
        self.n_download_threads = 2

        # Target time for how long a request should take (in seconds)
        # Chunk size will be adjusted to try to reach this target time
        self.download_target_time = 0.50

        # If no 3rd party verification, quiet urllib
        if self._verify is False:
            requests.packages.urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

        if username is not None and password is not None:
            self._username = username
            self._password = password
            self._get_JWT_token()
        else:
            self._username = None
            self._password = None
            self._jwt_access_exp = None
            self._jwt_refresh_exp = None

        # Try to connect and pull the server info
        self.server_info = self.get_server_information()
        self.server_name = self.server_info["name"]
        self.api_limits = self.server_info["api_limits"]

        server_version = parse_version(self.server_info["version"])
        client_version = parse_version(__version__)

        if client_version > server_version:
            self._logger.warning(
                "WARNING: This client version is newer than the server version. This may work if the "
                "versions are close, but expect exceptions and errors if attempting things the server "
                "does not support. "
                f"client version: {str(__version__)}, server version: {str(self.server_info['version'])}"
            )

        motd = self.server_info.get("motd", "")
        if show_motd and motd:
            print("*" * 10 + "- Message-of-the-Day from the server -" + "*" * 10)
            print()
            print(motd)
            print()
            print("*" * 14 + "- End of Message-of-the-Day -" + "*" * 15)
[docs]
    @classmethod
    def from_file(cls, server_name: Optional[str] = None, config_path: Optional[str] = None):
        """Creates a new client given information in a file.

        If no path is passed in, the current working directory and finally ~/.qca
        are searched for "qcportal_config.yaml"

        Parameters
        ----------
        server_name
            Name/alias of the server in the yaml file. If None, the top level of the
            file is used as the constructor arguments.
        config_path
            Full path to a configuration file, or a directory containing "qcportal_config.yaml".
            A leading ``~`` is expanded.

        Returns
        -------
        :
            A new client constructed from the keyword arguments found in the file

        Raises
        ------
        FileNotFoundError
            If no config_path was given and no configuration file exists in the searched paths
        RuntimeError
            If server_name is given but not present in the configuration file
        KeyError
            If the selected configuration is not a mapping containing an ``address`` field
        """
        # Search canonical paths
        if config_path is None:
            test_paths = [os.getcwd(), os.path.join(os.path.expanduser("~"), ".qca")]
            for path in test_paths:
                local_path = os.path.join(path, "qcportal_config.yaml")
                if os.path.exists(local_path):
                    config_path = local_path
                    break

            if config_path is None:
                raise FileNotFoundError(
                    "Could not find `qcportal_config.yaml` in the following paths:\n    {}".format(
                        ", ".join(test_paths)
                    )
                )
        else:
            config_path = os.path.expanduser(config_path)

            # Gave folder, not file
            if os.path.isdir(config_path):
                config_path = os.path.join(config_path, "qcportal_config.yaml")

        with open(config_path, "r") as handle:
            # SafeLoader only constructs plain YAML types, never arbitrary Python objects
            data = yaml.load(handle, Loader=yaml.SafeLoader)

        if server_name is not None:
            # An empty file loads as None; anything other than a mapping has no named servers
            data = data.get(server_name) if isinstance(data, dict) else None
            if data is None:
                raise RuntimeError(f"Server '{server_name}' does not exist in the configuration file")

        # Guard against empty/non-mapping configs so the user gets this message rather than
        # an AttributeError/TypeError (or a substring match when the entry is a plain string)
        if not isinstance(data, dict) or "address" not in data:
            raise KeyError("Config file must at least contain an address field.")

        return cls(**data)
[docs]
    @classmethod
    def from_env(cls):
        """Creates a new client given information stored in environment variables
        The environment variables are:
          * QCPORTAL_ADDRESS (required)
          * QCPORTAL_USERNAME (optional)
          * QCPORTAL_PASSWORD (optional)
          * QCPORTAL_VERIFY (optional, defaults to True)
          * QCPORTAL_CACHE_DIR (optional)
        """
        address = os.environ.get("QCPORTAL_ADDRESS", None)
        username = os.environ.get("QCPORTAL_USERNAME", None)
        password = os.environ.get("QCPORTAL_PASSWORD", None)
        verify = os.environ.get("QCPORTAL_VERIFY", True)
        cache_dir = os.environ.get("QCPORTAL_CACHE_DIR", None)
        if address is None:
            raise KeyError("Required environment variable 'QCPORTAL_ADDRESS' not found")
        data = {"address": address}
        if username is not None:
            data["username"] = username
        if password is not None:
            data["password"] = password
        if cache_dir is not None:
            data["cache_dir"] = cache_dir
        data["verify"] = verify
        return cls(**data) 
    @property
    def encoding(self) -> str:
        return self._encoding
    @encoding.setter
    def encoding(self, encoding: str):
        self._encoding = encoding
        enc_headers = {"Accept": encoding}
        self._req_session.headers.update(enc_headers)
    def _send_request(self, req: requests.Request, allow_retries: bool = True) -> requests.Response:
        """
        Sends a prepared request, optionally retrying on errors
        Parameters
        ----------
        req
            A prepared request to send
        allow_retries
            If true, attempts to retry on certain kinds of errors
        Returns
        -------
        :
            The response returned from the request
        """
        prep_req = self._req_session.prepare_request(req)
        if self.debug_requests:
            pretty_print_request(prep_req)
        if not allow_retries:
            ret = self._req_session.send(prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False)
            if self.debug_requests:
                pretty_print_response(ret)
            if ret.is_redirect:
                raise RuntimeError("Redirection is not allowed")
            return ret
        retry_count = 0
        try:
            while True:
                try:
                    ret = self._req_session.send(
                        prep_req, verify=self._verify, timeout=self.timeout, allow_redirects=False
                    )
                    break
                except requests.exceptions.SSLError:
                    raise ConnectionRefusedError(_ssl_error_msg) from None
                except AllowedConnectionExceptions as e:
                    if retry_count >= self.retry_max:
                        raise
                    # eg, if jitter fraction is 0.05, then multiply by something on the range 0.95 to 1.05
                    jitter = random.uniform(1.0 - self.retry_jitter_fraction, 1.0 + self.retry_jitter_fraction)
                    time_to_wait = self.retry_delay * (self.retry_backoff**retry_count) * jitter
                    retry_count += 1
                    self._logger.warning(
                        f"Connection error for {prep_req.url}: {str(e)} - retrying in {time_to_wait:.2f} seconds "
                        f"[{retry_count}/{self.retry_max}]"
                    )
                    time.sleep(time_to_wait)
        except requests.exceptions.SSLError:
            raise ConnectionRefusedError(_ssl_error_msg) from None
        except AllowedConnectionExceptions:
            raise ConnectionRefusedError(_connection_error_msg.format(self.address)) from None
        if self.debug_requests:
            pretty_print_response(ret)
        if ret.is_redirect:
            raise RuntimeError("Redirection is not allowed")
        return ret
    def _get_JWT_token(self) -> None:
        """Log in with the stored username/password and install the returned JWTs.

        On success the access token is placed in the session's ``Authorization`` header.
        The expiration times of the access and refresh tokens (their ``exp`` claims) are
        stored so that callers can renew them before they expire.

        Raises
        ------
        AuthenticationFailure
            If the server answers the login with any status other than 200
        """
        full_uri = self.address + "auth/v1/login"
        login_body = {"username": self._username, "password": self._password}
        req = requests.Request(method="POST", url=full_uri, json=login_body)
        ret = self._send_request(req)

        if ret.status_code != 200:
            try:
                msg = ret.json()["msg"]
            except (ValueError, KeyError, TypeError):
                # Body was not JSON (e.g. an error page from a proxy) or carried no "msg" field
                msg = ret.reason
            raise AuthenticationFailure(msg)

        ret_json = ret.json()
        self._jwt_refresh_token = ret_json["refresh_token"]
        self._jwt_access_token = ret_json["access_token"]
        self._req_session.headers.update({"Authorization": f"Bearer {self._jwt_access_token}"})

        # Store the expiration time of the access and refresh tokens
        # (these are unix epoch timestamps). The signature is not verified: the
        # tokens are only inspected locally to read their "exp" claim.
        decoded_access_token = jwt.decode(
            self._jwt_access_token, algorithms=["HS256"], options={"verify_signature": False}
        )
        decoded_refresh_token = jwt.decode(
            self._jwt_refresh_token, algorithms=["HS256"], options={"verify_signature": False}
        )
        self._jwt_access_exp = decoded_access_token["exp"]
        self._jwt_refresh_exp = decoded_refresh_token["exp"]
    def _refresh_JWT_token(self) -> None:
        """Obtain a new access token using the stored refresh token.

        If the server reports that the refresh token itself has expired, a full login
        (``_get_JWT_token``) is performed instead.

        Raises
        ------
        AuthenticationFailure
            If the user account has been disabled or no longer exists (or, via the fallback
            login, if the stored credentials are rejected)
        ConnectionRefusedError
            For any other refresh failure
        """
        full_uri = self.address + "auth/v1/refresh"
        headers = {"Authorization": f"Bearer {self._jwt_refresh_token}"}
        req = requests.Request(method="POST", url=full_uri, headers=headers)
        ret = self._send_request(req)

        if ret.status_code == 200:
            ret_json = ret.json()
            self._jwt_access_token = ret_json["access_token"]
            self._req_session.headers.update({"Authorization": f"Bearer {self._jwt_access_token}"})

            # Store the expiration time of the access token
            # (a unix epoch timestamp; signature not verified, only "exp" is read)
            decoded_access_token = jwt.decode(
                self._jwt_access_token, algorithms=["HS256"], options={"verify_signature": False}
            )
            self._jwt_access_exp = decoded_access_token["exp"]
            return

        # Parse the error message once. A body that is not JSON (e.g. from a proxy) or lacks
        # "msg" falls through to the generic failure below instead of raising a decode error.
        try:
            msg = str(ret.json()["msg"])
        except (ValueError, KeyError, TypeError):
            msg = ""

        if ret.status_code == 401:
            if "Token has expired" in msg:
                # If the refresh token has expired, try to log in again
                self._get_JWT_token()
                return
            if " is disabled" in msg:
                raise AuthenticationFailure("User account has been disabled")
            if " does not exist" in msg:
                raise AuthenticationFailure("User account no longer exists")

        # shouldn't happen unless user is blacklisted or something
        self._logger.error(f"Unable to refresh JWT token: {ret} {ret.text}")
        raise ConnectionRefusedError("Unable to refresh JWT authorization token! This is a server issue!!")
    def _request(
        self,
        method: str,
        endpoint: str,
        *,
        body: Optional[Union[bytes, str]] = None,
        url_params: Optional[Dict[str, Any]] = None,
        file_data: Optional[Iterable[Tuple[str, Any]]] = None,
        internal_retry: Optional[bool] = True,
        allow_retries: bool = True,
        additional_headers: Optional[Dict[str, Any]] = None,
    ) -> requests.Response:
        """Send a request to ``self.address + endpoint`` and check that it succeeded.

        Expired JWTs are renewed before sending. If the server still reports an expired token,
        the token is refreshed and the request is retried once with identical arguments.

        Parameters
        ----------
        method
            HTTP method (case-insensitive)
        endpoint
            Path appended to the server address
        body
            Already-serialized request body
        url_params
            URL query parameters
        file_data
            Multipart file entries, in the format of the ``files`` argument of requests.
            When given, requests sets the Content-Type itself.
        internal_retry
            Whether a single retry after an expired-token 401 is allowed
        allow_retries
            Whether to retry on connection errors (see ``_send_request``)
        additional_headers
            Extra HTTP headers for this request

        Returns
        -------
        :
            The successful (HTTP 200) response

        Raises
        ------
        PortalRequestError
            If the server returns any other status
        """
        # If refresh token has expired, log in again
        if self._jwt_refresh_exp and self._jwt_refresh_exp < time.time():
            self._get_JWT_token()

        # If only the JWT token is expired, automatically renew it
        if self._jwt_access_exp and self._jwt_access_exp < time.time():
            self._refresh_JWT_token()

        full_uri = self.address + endpoint

        # Let requests handle content-type if doing multipart
        # but specify our encoding otherwise
        headers = {} if file_data is not None else {"Content-Type": self.encoding}
        if additional_headers is not None:
            headers.update(additional_headers)

        req = requests.Request(
            method=method.upper(), url=full_uri, data=body, params=url_params, files=file_data, headers=headers
        )
        r = self._send_request(req, allow_retries=allow_retries)

        if r.status_code == 200:
            return r

        # For many errors returned by our code, the error details are returned as json with the
        # error message stored under "msg". If this error comes from, ie, the web server or
        # something else, then we have to use 'reason'.
        try:
            details = r.json()
        except ValueError:
            details = None
        if not isinstance(details, dict):
            details = {"msg": r.reason}
        msg = details.get("msg", r.reason)

        # If JWT token expired, automatically renew it and retry once. This should have been caught above,
        # but can happen in rare instances where the token expires between the time we check it and the time
        # we use it.
        if internal_retry and r.status_code == 401 and "Token has expired" in str(msg):
            self._refresh_JWT_token()

            if file_data is not None:
                # Preparing the first request read the upload file objects to the end;
                # rewind them so the retry sends the same content
                entries = file_data.items() if isinstance(file_data, dict) else file_data
                for _, file_spec in entries:
                    # Either a file object or a (filename, fileobj, ...) tuple
                    file_obj = file_spec[1] if isinstance(file_spec, tuple) and len(file_spec) > 1 else file_spec
                    if hasattr(file_obj, "seek"):
                        try:
                            file_obj.seek(0)
                        except (OSError, ValueError):
                            pass  # not rewindable (e.g. a pipe or closed file); retry sends what remains

            return self._request(
                method,
                endpoint,
                body=body,
                url_params=url_params,
                file_data=file_data,
                internal_retry=False,
                allow_retries=allow_retries,
                additional_headers=additional_headers,
            )

        raise PortalRequestError(f"Request failed: {msg}", r.status_code, details)
[docs]
    def make_request(
        self,
        method: str,
        endpoint: str,
        response_model: Optional[Type[_V]],
        *,
        body_model: Optional[Type[_T]] = None,
        url_params_model: Optional[Type[_U]] = None,
        body: Optional[Union[_T, Dict[str, Any]]] = None,
        url_params: Optional[Union[_U, Dict[str, Any]]] = None,
        upload_files: Optional[Iterable[Tuple[str, str]]] = None,
        allow_retries: bool = True,
        additional_headers: Optional[Dict[str, Any]] = None,
    ) -> _V:
        """Validate and serialize a request, send it, and parse the response into a model.

        Parameters
        ----------
        method
            HTTP method
        endpoint
            Path appended to the server address
        response_model
            Type the deserialized response is parsed into (with pydantic). If None, None is returned.
        body_model
            Type the body is validated against; defaults to ``type(body)``
        url_params_model
            Type the URL parameters are validated against; defaults to ``type(url_params)``
        body
            Request body, serialized using ``self.encoding``
        url_params
            URL query parameters
        upload_files
            (filename, local path) pairs of files to upload. When given, the request is sent as
            multipart and the serialized body is included as a "body_data" part.
        allow_retries
            Whether to retry on connection errors
        additional_headers
            Extra HTTP headers for this request

        Returns
        -------
        :
            The response parsed into ``response_model``, or None if ``response_model`` is None

        Raises
        ------
        PortalRequestError
            If the server does not return HTTP 200
        """
        # If body_model or url_params_model are None, then use the type given
        if body_model is None and body is not None:
            body_model = type(body)
        if url_params_model is None and url_params is not None:
            url_params_model = type(url_params)

        serialized_body = None
        if body_model is not None:
            parsed_body = pydantic.parse_obj_as(body_model, body)
            serialized_body = serialize(parsed_body, self.encoding)

        parsed_url_params = None
        if url_params_model is not None:
            parsed_url_params = pydantic.parse_obj_as(url_params_model, url_params)
        if isinstance(parsed_url_params, pydantic.BaseModel):
            parsed_url_params = parsed_url_params.dict()

        # Handles of files opened for upload; always closed once the request finishes or fails
        opened_files = []
        try:
            file_data = None
            if upload_files is not None:
                # Yes, a list of tuples. We always use the "files" key, and doing it this way
                # allows for multiple files to be uploaded in a single request.
                file_data = []
                for fname, fpath in upload_files:
                    fh = open(fpath, "rb")
                    opened_files.append(fh)
                    file_data.append(("files", (fname, fh)))

                # We must also send the serialized body as part of the multipart upload
                if serialized_body is not None:
                    file_data.append(("body_data", ("body_data", serialized_body, self.encoding)))
                    serialized_body = None

            assert (serialized_body is None) or (file_data is None)  # Just to check my logic

            r = self._request(
                method,
                endpoint,
                body=serialized_body,
                url_params=parsed_url_params,
                file_data=file_data,
                allow_retries=allow_retries,
                additional_headers=additional_headers,
            )
        finally:
            for fh in opened_files:
                fh.close()

        d = deserialize(r.content, r.headers["Content-Type"])

        if response_model is None:
            return None
        return pydantic.parse_obj_as(response_model, d)
[docs]
    def download_file(
        self,
        endpoint: str,
        destination_path: str,
        overwrite: bool = False,
        expected_size: Optional[int] = None,
        show_progress: bool = False,
    ) -> Tuple[int, str]:
        """
        Download a file with optional progress bar

        Parameters
        ----------
        endpoint
            API endpoint to download from
        destination_path
            Where to save the file
        overwrite
            Whether to overwrite existing files
        expected_size
            Expected size of the file in bytes (used for progress bar if enabled)
        show_progress
            Whether to show a progress bar during download

        Returns
        -------
        :
            The number of bytes written and the hex-encoded SHA-256 digest of the file contents

        Raises
        ------
        RuntimeError
            If a file already exists at destination_path and overwrite is False
        requests.exceptions.HTTPError
            If the server (or the redirect target) responds with an error status
        """
        # Fail before touching the network if we are not allowed to replace an existing file
        if os.path.exists(destination_path) and not overwrite:
            raise RuntimeError(f"File already exists at {destination_path}. To overwrite, use `overwrite=True`")

        # This request does not go through _request, so renew expired credentials here the same way
        if self._jwt_refresh_exp and self._jwt_refresh_exp < time.time():
            self._get_JWT_token()
        if self._jwt_access_exp and self._jwt_access_exp < time.time():
            self._refresh_JWT_token()

        full_uri = self.address + endpoint
        response = self._req_session.get(
            full_uri, stream=True, allow_redirects=False, verify=self._verify, timeout=self.timeout
        )

        if response.is_redirect:
            # send again, but using a plain requests object
            # that way, we don't pass the JWT to someone else
            new_location = response.headers["Location"]
            response.close()
            response = requests.get(new_location, stream=True, allow_redirects=True, timeout=self.timeout)

        # Closing the streamed response releases its connection
        with response:
            response.raise_for_status()

            # Remove (rather than truncate) any existing file only once the request has succeeded.
            # This allows for any processes still using the old file to keep using it (at least
            # on linux), and keeps the old file if the request fails.
            if os.path.exists(destination_path):
                os.remove(destination_path)

            # With a progress bar, read in 4MB chunks so it is updated during the download;
            # with chunk_size=None it may only be updated after the entire download is complete
            chunk_size = 4 * 1024 * 1024 if show_progress else None
            progress = None
            if show_progress:
                progress = tqdm(
                    total=expected_size,  # may be None (unknown size)
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                    desc=f"Downloading to {os.path.basename(destination_path)}",
                    miniters=1,  # Update after each iteration
                    mininterval=0.1,  # Allow updates as frequently as every 0.1 seconds
                )

            sha256 = hashlib.sha256()
            file_size = 0
            try:
                with open(destination_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        if not chunk:
                            continue
                        f.write(chunk)
                        sha256.update(chunk)
                        file_size += len(chunk)
                        if progress is not None:
                            progress.update(len(chunk))
            except BaseException:
                # Don't leave a truncated file behind that could be mistaken for a complete download
                try:
                    os.remove(destination_path)
                except OSError:
                    pass
                raise
            finally:
                if progress is not None:
                    progress.close()

        return file_size, sha256.hexdigest()
[docs]
    def ping(self) -> bool:
        """
        Pings the server to see if it is up

        The ping uses this client's SSL-verification and timeout settings, so a server with a
        self-signed certificate is reported correctly when ``verify=False``.

        Returns
        -------
        :
            True if the server is up and responded to the ping. False otherwise
        """
        # self.address always ends with "/" (see __init__), so no extra separator is needed
        uri = self.address + "api/v1/ping"
        try:
            # A plain request rather than the session, so no session headers (such as the
            # JWT Authorization header) are sent
            r = requests.get(uri, verify=self._verify, timeout=self.timeout)
        except AllowedConnectionExceptions:
            return False

        try:
            return r.json()["success"]
        except (ValueError, KeyError, TypeError):
            # Something answered, but not with a ping response
            return False