"""
Caching for the PortalClient
"""
from __future__ import annotations
import datetime
import os
from typing import TYPE_CHECKING, Optional, TypeVar, Type, Any, List, Iterable, Tuple, Sequence
from urllib.parse import urlparse
import apsw
import msgpack
from .utils import chunk_iterable
try:
    import pydantic.v1 as pydantic
    from pydantic.v1 import BaseModel, Extra, validator, PrivateAttr, Field
    from pydantic.v1.json import pydantic_encoder
except ImportError:
    import pydantic
    from pydantic import BaseModel, Extra, validator, PrivateAttr, Field
    from pydantic.json import pydantic_encoder
import zstandard
from .serialization import serialize, deserialize, _msgpack_encode
if TYPE_CHECKING:
    from qcportal.record_models import RecordStatusEnum
    from qcportal.client import PortalClient
_DATASET_T = TypeVar("_DATASET_T")
_RECORD_T = TypeVar("_RECORD_T")
_query_chunk_size = 125
def _msgpack_encode_cache(obj: Any) -> Any:
    # Similar to msgpack_encode in serialization, however
    # does NOT include child records
    # Importing BaseRecord is a circular dependency :(
    if hasattr(obj, "get_cache_dict"):
        return obj.get_cache_dict(exclude_unset=True, by_alias=True)
    else:
        return _msgpack_encode(obj)
[docs]
def compress_for_cache(data: Any) -> bytes:
    serialized_data = msgpack.dumps(data, default=_msgpack_encode_cache, use_bin_type=True)
    compressed_data = zstandard.compress(serialized_data, level=1)
    return compressed_data 
[docs]
def decompress_from_cache(data: bytes, value_type) -> Any:
    decompressed_data = zstandard.decompress(data)
    deserialized_data = deserialize(decompressed_data, "msgpack")
    return pydantic.parse_obj_as(value_type, deserialized_data) 
[docs]
class RecordCache:
    def __init__(self, cache_uri: str, read_only: bool):
        self.cache_uri = cache_uri
        self.read_only = read_only
        if self.read_only:
            self._conn = apsw.Connection(self.cache_uri, flags=apsw.SQLITE_OPEN_READONLY | apsw.SQLITE_OPEN_URI)
        else:
            self._conn = apsw.Connection(
                self.cache_uri, flags=apsw.SQLITE_OPEN_READWRITE | apsw.SQLITE_OPEN_CREATE | apsw.SQLITE_OPEN_URI
            )
        self._conn.pragma("foreign_keys", "ON")
        if not read_only:
            self._create_tables()
    def __str__(self):
        return f"<{self.__class__.__name__} path={self.cache_uri} {'ro' if self.read_only else 'rw'}>"
    def _assert_writable(self):
        assert not self.read_only, "This cache is read-only"
    def _create_tables(self):
        self._assert_writable()
        self._conn.execute(
            """
            CREATE TABLE IF NOT EXISTS records (
                id INTEGER PRIMARY KEY,
                status TEXT NOT NULL,
                modified_on DECIMAL NOT NULL,
                record BLOB NOT NULL
            )
            """
        )
        self._conn.execute("CREATE INDEX IF NOT EXISTS records_status ON records (status)")
[docs]
    def get_record(self, record_id: int, record_type: Type[_RECORD_T]) -> Optional[_RECORD_T]:
        stmt = "SELECT record FROM records WHERE id = ?"
        record_data = self._conn.execute(stmt, (record_id,)).fetchone()
        if record_data is None:
            return None
        record = decompress_from_cache(record_data[1], record_type)
        record._record_cache = self
        return record 
[docs]
    def get_records(self, record_ids: Iterable[int], record_type: Type[_RECORD_T]) -> List[_RECORD_T]:
        all_records = []
        for record_id_batch in chunk_iterable(record_ids, _query_chunk_size):
            id_params = ",".join("?" * len(record_id_batch))
            stmt = f"SELECT record FROM records WHERE id IN ({id_params})"
            rdata = self._conn.execute(stmt, record_id_batch).fetchall()
            for compressed_record in rdata:
                record = decompress_from_cache(compressed_record[0], record_type)
                record._record_cache = self
                all_records.append(record)
        return all_records 
[docs]
    def get_existing_records(self, record_ids: Iterable[int]) -> List[int]:
        ret = []
        for record_id_batch in chunk_iterable(record_ids, _query_chunk_size):
            record_id_params = ",".join("?" * len(record_id_batch))
            stmt = f"SELECT id FROM records WHERE id IN ({record_id_params})"
            r = self._conn.execute(stmt, record_id_batch).fetchall()
            ret.extend(r)
        return ret 
[docs]
    def update_records(self, records: Iterable[_RECORD_T]):
        self._assert_writable()
        with self._conn:
            for record_batch in chunk_iterable(records, 10):
                n_batch = len(record_batch)
                values_params = ",".join(["(?, ?, ?, ?)"] * n_batch)
                all_params = []
                for r in record_batch:
                    all_params.extend((r.id, r.status, r.modified_on.timestamp(), compress_for_cache(r)))
                stmt = f"REPLACE INTO records (id, status, modified_on, record) VALUES {values_params}"
                self._conn.execute(stmt, all_params)
        for r in records:
            r._record_cache = self
            r._cache_dirty = False 
[docs]
    def writeback_record(self, record):
        self._assert_writable()
        compressed_record = compress_for_cache(record)
        # Only update if timestamp is same or newer, and if this record is larger
        # than what is stored already
        stmt = f"""INSERT OR REPLACE INTO records (id, status, modified_on, record)
                   SELECT ?, ?, ?, ? WHERE NOT EXISTS (SELECT 1 FROM records WHERE id = ?
                   AND (modified_on > ? OR (modified_on = ? and length(record) > ?)))"""
        ts = record.modified_on.timestamp()
        row_data = (record.id, record.status, ts, compressed_record, record.id, ts, ts, len(compressed_record))
        self._conn.execute(stmt, row_data) 
[docs]
    def delete_record(self, record_id: int):
        self._assert_writable()
        stmt = "DELETE FROM records WHERE id=?"
        self._conn.execute(stmt, (record_id,)) 
[docs]
    def delete_records(self, record_ids: Iterable[int]):
        self._assert_writable()
        for record_id_batch in chunk_iterable(record_ids, _query_chunk_size):
            record_id_params = ",".join("?" * len(record_id_batch))
            stmt = f"DELETE FROM records WHERE id IN ({record_id_params})"
            self._conn.execute(stmt, record_id_batch) 
 
[docs]
class DatasetCache(RecordCache):
    def __init__(self, cache_uri: str, read_only: bool, dataset_type: Type[_DATASET_T]):
        self._entry_type = dataset_type._entry_type
        self._specification_type = dataset_type._specification_type
        self._record_type = dataset_type._record_type
        RecordCache.__init__(self, cache_uri=cache_uri, read_only=read_only)
    def _create_tables(self):
        RecordCache._create_tables(self)
        self._conn.execute(
            """
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                value BLOB NOT NULL
            )
        """
        )
        self._conn.execute(
            """
            CREATE TABLE IF NOT EXISTS dataset_entries (
                name TEXT PRIMARY KEY,
                entry BLOB NOT NULL
            )
        """
        )
        self._conn.execute(
            """
            CREATE TABLE IF NOT EXISTS dataset_specifications (
                name TEXT PRIMARY KEY,
                specification BLOB NOT NULL
            )
        """
        )
        self._conn.execute(
            """
            CREATE TABLE IF NOT EXISTS dataset_records (
                entry_name TEXT NOT NULL,
                specification_name TEXT NOT NULL,
                record_id INTEGER NOT NULL,
                PRIMARY KEY (entry_name, specification_name),
                FOREIGN KEY (entry_name) REFERENCES dataset_entries(name) ON DELETE CASCADE ON UPDATE CASCADE,
                FOREIGN KEY (specification_name) REFERENCES dataset_specifications(name) ON DELETE CASCADE ON UPDATE CASCADE
            )
        """
        )
        self._conn.execute("CREATE INDEX IF NOT EXISTS dataset_records_entry_name ON dataset_records (entry_name)")
        self._conn.execute(
            "CREATE INDEX IF NOT EXISTS dataset_records_specification_name ON dataset_records (specification_name)"
        )
        self._conn.execute("CREATE INDEX IF NOT EXISTS dataset_records_record_id ON dataset_records (record_id)")
[docs]
    def entry_exists(self, name: str) -> bool:
        stmt = "SELECT 1 FROM dataset_entries WHERE name=?"
        return self._conn.execute(stmt, (name,)).fetchone() is not None 
[docs]
    def get_entry_names(self) -> List[str]:
        stmt = "SELECT name FROM dataset_entries"
        return [x[0] for x in self._conn.execute(stmt).fetchall()] 
[docs]
    def get_entry(self, name: str) -> Optional[BaseModel]:
        stmt = "SELECT entry FROM dataset_entries WHERE name=?"
        entry_data = self._conn.execute(stmt, (name,)).fetchone()
        if entry_data is None:
            return None
        return decompress_from_cache(entry_data[0], self._entry_type) 
[docs]
    def get_entries(self, names: Iterable[str]) -> List[BaseModel]:
        all_entries = []
        for names_batch in chunk_iterable(names, _query_chunk_size):
            name_param = ",".join("?" * len(names_batch))
            stmt = f"""SELECT entry FROM dataset_entries WHERE name IN ({name_param})"""
            entry_data = self._conn.execute(stmt, (*names_batch,)).fetchall()
            all_entries.extend(decompress_from_cache(x[0], self._entry_type) for x in entry_data)
        return all_entries 
[docs]
    def update_entries(self, entries: Iterable[BaseModel]):
        self._assert_writable()
        assert all(isinstance(e, self._entry_type) for e in entries)
        with self._conn:
            for entry_batch in chunk_iterable(entries, 50):
                n_batch = len(entry_batch)
                values_params = ",".join(["(?, ?)"] * n_batch)
                all_params = []
                for e in entry_batch:
                    all_params.extend((e.name, compress_for_cache(e)))
                stmt = f"REPLACE INTO dataset_entries (name, entry) VALUES {values_params}"
                self._conn.execute(stmt, all_params) 
[docs]
    def rename_entry(self, old_name: str, new_name: str):
        self._assert_writable()
        entry = self.get_entry(old_name)
        if entry is None:  # does not exist
            return
        entry.name = new_name
        stmt = "UPDATE dataset_entries SET name=?, entry=? WHERE name=?"
        self._conn.execute(stmt, (new_name, compress_for_cache(entry), old_name)) 
[docs]
    def delete_entry(self, name):
        self._assert_writable()
        stmt = "DELETE FROM dataset_entries WHERE name=?"
        self._conn.execute(stmt, (name,)) 
[docs]
    def specification_exists(self, name: str) -> bool:
        stmt = "SELECT 1 FROM dataset_specifications WHERE name=?"
        return self._conn.execute(stmt, (name,)).fetchone() is not None 
[docs]
    def get_specification_names(self) -> List[str]:
        stmt = "SELECT name FROM dataset_specifications"
        return [x[0] for x in self._conn.execute(stmt).fetchall()] 
[docs]
    def get_specification(self, name: str):
        stmt = "SELECT specification FROM dataset_specifications WHERE name=?"
        spec_data = self._conn.execute(stmt, (name,)).fetchone()
        if spec_data is None:
            return None
        return decompress_from_cache(spec_data[0], self._specification_type) 
[docs]
    def get_all_specifications(self) -> List[BaseModel]:
        stmt = "SELECT specification FROM dataset_specifications"
        spec_data = self._conn.execute(stmt).fetchall()
        return [decompress_from_cache(x[0], self._specification_type) for x in spec_data] 
[docs]
    def get_specifications(self, names: Iterable[str]) -> List[BaseModel]:
        name_param = ",".join("?" * len(names))
        stmt = f"""SELECT specification FROM dataset_specifications WHERE name IN ({name_param})"""
        entry_data = self._conn.execute(stmt, (*names,)).fetchall()
        return [decompress_from_cache(x[0], self._specification_type) for x in entry_data] 
[docs]
    def update_specifications(self, specifications: Iterable[BaseModel]):
        self._assert_writable()
        assert all(isinstance(s, self._specification_type) for s in specifications)
        with self._conn:
            for specification_batch in chunk_iterable(specifications, 50):
                n_batch = len(specification_batch)
                values_params = ",".join(["(?, ?)"] * n_batch)
                all_params = []
                for s in specification_batch:
                    all_params.extend((s.name, compress_for_cache(s)))
                stmt = f"REPLACE INTO dataset_specifications (name, specification) VALUES {values_params}"
                self._conn.execute(stmt, all_params) 
[docs]
    def rename_specification(self, old_name: str, new_name: str):
        self._assert_writable()
        specification = self.get_specification(old_name)
        if specification is None:  # does not exist
            return
        specification.name = new_name
        stmt = "UPDATE dataset_specifications SET name=?, specification=? WHERE name=?"
        self._conn.execute(stmt, (new_name, compress_for_cache(specification), old_name)) 
[docs]
    def delete_specification(self, name):
        self._assert_writable()
        stmt = "DELETE FROM dataset_specifications WHERE name=?"
        self._conn.execute(stmt, (name,)) 
[docs]
    def dataset_record_exists(self, entry_name: str, specification_name: str) -> bool:
        stmt = "SELECT 1 FROM dataset_records WHERE entry_name=? and specification_name=?"
        return self._conn.execute(stmt, (entry_name, specification_name)).fetchone() is not None 
[docs]
    def get_dataset_record(self, entry_name: str, specification_name: str) -> Optional[_RECORD_T]:
        stmt = """SELECT r.record FROM records r
                  INNER JOIN dataset_records dr ON r.id = dr.record_id
                  WHERE dr.entry_name=? and dr.specification_name=?"""
        record_data = self._conn.execute(stmt, (entry_name, specification_name)).fetchone()
        if record_data is None:
            return None
        record = decompress_from_cache(record_data[0], self._record_type)
        record._record_cache = self
        return record 
[docs]
    def get_dataset_records(
        self,
        entry_names: Iterable[str],
        specification_names: Iterable[str],
        status: Optional[Iterable[RecordStatusEnum]] = None,
    ) -> List[Tuple[str, str, _RECORD_T]]:
        specification_params = ",".join("?" * len(specification_names))
        all_records = []
        for entry_names_batch in chunk_iterable(entry_names, _query_chunk_size):
            entry_params = ",".join("?" * len(entry_names_batch))
            stmt = f"""SELECT dr.entry_name, dr.specification_name, r.record
                       FROM dataset_records dr
                       INNER JOIN records r ON r.id = dr.record_id
                       WHERE dr.entry_name IN ({entry_params})
                       AND dr.specification_name IN ({specification_params})"""
            all_params = (*entry_names_batch, *specification_names)
            if status:
                status_params = ",".join("?" * len(status))
                stmt = stmt + f"AND r.status IN ({status_params})"
                all_params = (*all_params, *status)
            rdata = self._conn.execute(stmt, all_params).fetchall()
            for ename, sname, compressed_record in rdata:
                record = decompress_from_cache(compressed_record, self._record_type)
                record._record_cache = self
                all_records.append((ename, sname, record))
        return all_records 
[docs]
    def update_dataset_records(self, record_info: Iterable[Tuple[str, str, int]]):
        self._assert_writable()
        with self._conn:
            for info_batch in chunk_iterable(record_info, 10):
                n_batch = len(info_batch)
                values_params = ",".join(["(?, ?, ?)"] * n_batch)
                all_params = []
                for e, s, rid in info_batch:
                    all_params.extend((e, s, rid))
                stmt = f"""REPLACE INTO dataset_records (entry_name, specification_name, record_id)
                          VALUES {values_params}"""
                self._conn.execute(stmt, all_params) 
[docs]
    def delete_dataset_record(self, entry_name: str, specification_name: str):
        self._assert_writable()
        stmt = "DELETE FROM dataset_records WHERE entry_name=? AND specification_name=?"
        self._conn.execute(stmt, (entry_name, specification_name)) 
[docs]
    def delete_dataset_records(
        self, entry_names: Optional[Iterable[str]], specification_names: Optional[Iterable[str]]
    ):
        self._assert_writable()
        all_params = []
        conds = []
        stmt = f"DELETE FROM dataset_records "
        if entry_names is not None:
            entry_params = ",".join("?" * len(entry_names))
            all_params.extend(entry_names)
            conds.append(f" entry_name IN ({entry_params})")
        if specification_names is not None:
            specification_params = ",".join("?" * len(specification_names))
            all_params.extend(specification_names)
            conds.append(f" specification_name IN ({specification_params})")
        if conds:
            stmt += " WHERE " + " AND ".join(conds)
        self._conn.execute(stmt, all_params) 
[docs]
    def get_dataset_record_info(
        self,
        entry_names: Iterable[str],
        specification_names: Iterable[str],
        status: Optional[Iterable[RecordStatusEnum]],
    ) -> List[Tuple[str, str, int, RecordStatusEnum, datetime.datetime]]:
        all_info = []
        specification_params = ",".join("?" * len(specification_names))
        if status is not None:
            status_params = ",".join("?" * len(status))
        else:
            status_params = None
        for entry_names_batch in chunk_iterable(entry_names, _query_chunk_size):
            entry_params = ",".join("?" * len(entry_names_batch))
            stmt = f"""SELECT dr.entry_name, dr.specification_name, dr.record_id, r.status, r.modified_on
                       FROM dataset_records dr
                       INNER JOIN records r ON r.id = dr.record_id
                       WHERE dr.entry_name IN ({entry_params})
                       AND dr.specification_name IN ({specification_params})"""
            if status_params is not None:
                stmt = stmt + f"AND status IN ({status_params})"
                all_params = (*entry_names_batch, *specification_names, *status)
            else:
                all_params = (*entry_names_batch, *specification_names)
            rinfo = self._conn.execute(stmt, all_params).fetchall()
            all_info.extend(
                (e, s, id, status, datetime.datetime.fromtimestamp(modified_on, tz=datetime.timezone.utc))
                for e, s, id, status, modified_on in rinfo
            )
        return all_info 
[docs]
    def get_existing_dataset_records(
        self, entry_names: Iterable[str], specification_names: Iterable[str]
    ) -> List[Tuple[str, str, int]]:
        specification_params = ",".join("?" * len(specification_names))
        ret = []
        for entry_names_batch in chunk_iterable(entry_names, _query_chunk_size):
            entry_params = ",".join("?" * len(entry_names_batch))
            stmt = f"""SELECT entry_name, specification_name, record_id
                       FROM dataset_records
                       WHERE entry_name IN ({entry_params})
                       AND specification_name IN ({specification_params})"""
            r = self._conn.execute(stmt, (*entry_names_batch, *specification_names)).fetchall()
            ret.extend(r)
        return ret 
 
[docs]
class PortalCache:
    def __init__(self, server_uri: str, cache_dir: Optional[str], max_size: int):
        parsed_url = urlparse(server_uri)
        # Should work as a reasonable fingerprint?
        self.server_fingerprint = f"{parsed_url.hostname}_{parsed_url.port}"
        if cache_dir:
            # _shared_memory shouldn't be used, so we don't set it and wait for errors
            self._is_disk = True
            self.cache_dir = os.path.join(os.path.abspath(cache_dir), self.server_fingerprint)
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            self._is_disk = False
            self.cache_dir = None
[docs]
    def get_cache_path(self, cache_name: str) -> str:
        if not self._is_disk:
            raise RuntimeError("Cannot get path to cache for memory-only cache")
        return os.path.join(self.cache_dir, f"{cache_name}.sqlite") 
[docs]
    def get_cache_uri(self, cache_name: str) -> str:
        if self._is_disk:
            file_path = self.get_cache_path(cache_name)
            uri = f"file:{file_path}"
        else:
            uri = ":memory:"
        return uri 
[docs]
    def get_dataset_cache_path(self, dataset_id: int) -> str:
        return self.get_cache_path(f"dataset_{dataset_id}") 
[docs]
    def get_dataset_cache_uri(self, dataset_id: int) -> str:
        return self.get_cache_uri(f"dataset_{dataset_id}") 
[docs]
    def get_dataset_cache(self, dataset_id: int, dataset_type: Type[_DATASET_T]) -> DatasetCache:
        uri = self.get_dataset_cache_uri(dataset_id)
        # If you are asking this for a dataset cache, it should be writable
        return DatasetCache(uri, False, dataset_type) 
    @property
    def is_disk(self) -> bool:
        return self._is_disk
[docs]
    def vacuum(self, cache_name: Optional[str] = None):
        if self._is_disk:
            # TODO
            return 
 
[docs]
def get_records_with_cache(
    client: Optional[PortalClient],
    base_url_prefix: str,
    record_cache: Optional[RecordCache],
    record_type: Type[_RECORD_T],
    record_ids: Sequence[int],
    include: Optional[Iterable[str]] = None,
    force_fetch: bool = False,
) -> List[_RECORD_T]:
    """
    Helper function for obtaining child records either from the cache or from the server
    The records are returned in the same order as the `record_ids` parameter.
    If records are missing from the cache, and client is None, and exception is raised.
    Newly-fetched records will not be immediately written to the cache. Instead, they will be attached to this cache
    and will be written back to the cache when the record object is destructed.
    If `include` is specified, additional fields will be fetched from the server. However, if the records are in the
    cache already, they may be missing those fields. In that case, the additional information may be fetched
    from the server. If a client is not provided, an exception will be raised.
    This function will fetch the children of the records if enough information
    is fetched of the parent record. This is handled by the various fetch_children_multi
    class functions of the record types.
    Parameters
    ----------
    client
        The client to use for fetching records from the server.
        If `None`, the function will only use the cache
    base_url_prefix
        Prefix of all the URLs for fetching records
    record_cache
        The cache to use for fetching records from the cache.
        If `None`, the function will only use the client
    record_type
        The type of record to fetch
    record_ids
        Single ID or sequence/list of records to obtain
    include
        Additional fields to include in the returned record (if fetching from the client)
    force_fetch
        If `True`, the function will fetch all records from the server,
        regardless of whether they are in the cache
    Returns
    -------
    :
        List of records in the same order as the input `record_ids`.
    """
    if record_cache is None or force_fetch:
        existing_records = []
        records_tofetch = set(record_ids)
    else:
        existing_records = record_cache.get_records(record_ids, record_type)
        records_tofetch = set(record_ids) - {x.id for x in existing_records}
        for r in existing_records:
            r.propagate_client(client, base_url_prefix)
    if records_tofetch:
        if client is None:
            raise RuntimeError("Need to fetch some records, but not connected to a client")
        recs = client._fetch_records(base_url_prefix, record_type, list(records_tofetch), include=include)
        # Set up for the writeback on change, but write the record as-is for now
        if record_cache is not None:
            record_cache.update_records(recs)
        for r in recs:
            r._record_cache = record_cache
            r._cache_dirty = False
        existing_records += recs
    # Fetch all children as well
    record_type.fetch_children_multi(existing_records, include=include, force_fetch=force_fetch)
    # Return everything in the same order as the input
    all_recs = {r.id: r for r in existing_records}
    ret = [all_recs.get(rid, None) for rid in record_ids]
    if any(x is None for x in ret):
        missing_ids = set(record_ids) - set(all_recs.keys())
        raise RuntimeError(
            f"Not all records found either in the cache or on the server. Missing records: {missing_ids}"
        )
    return ret