Source code for qcportal.models.records

"""
A model for Compute Records
"""

import abc
import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

import numpy as np
import qcelemental as qcel
from pydantic import Field, constr, validator

from ..visualization import scatter_plot
from .common_models import DriverEnum, ObjectId, ProtoModel, QCSpecification
from .model_utils import hash_dictionary, prepare_basis, recursive_normalizer

if TYPE_CHECKING:  # pragma: no cover
    from qcelemental.models import OptimizationInput, ResultInput

    from .common_models import KeywordSet, Molecule

# Public names exported by ``from ... import *``; each name is listed exactly once.
__all__ = ["OptimizationRecord", "ResultRecord", "RecordBase"]


class RecordStatusEnum(str, Enum):
    """
    The state of a record object. The states which are available are a finite set.
    """

    # NOTE: the docstring above is consumed at runtime -- ``RecordBase.status`` uses
    # ``str(RecordStatusEnum.__doc__)`` as its field description -- so edit its text
    # with the same care as any other user-facing string.
    #
    # The enum mixes in ``str``; each member's value is the upper-case form of its name.

    complete = "COMPLETE"
    incomplete = "INCOMPLETE"
    running = "RUNNING"
    error = "ERROR"


class RecordBase(ProtoModel, abc.ABC):
    """
    A BaseRecord object for Result and Procedure records. Contains all basic
    fields common to the all records.
    """

    # Classdata
    # Subclasses set this to the names of the fields that identify them; it is combined
    # with "procedure" and "program" in ``get_hash_fields``.
    _hash_indices: Set[str]

    # Helper data
    # Both helper fields are always stripped from serialized output (see ``dict``).
    client: Any = Field(None, description="The client object which the records are fetched from.")
    cache: Dict[str, Any] = Field(
        {},
        description="Object cache from expensive queries. It should be very rare that this needs to be set manually "
        "by the user.",
    )

    # Base identification
    id: ObjectId = Field(
        None, description="Id of the object on the database. This is assigned automatically by the database."
    )
    hash_index: Optional[str] = Field(
        None, description="Hash of this object used to detect duplication and collisions in the database."
    )
    procedure: str = Field(..., description="Name of the procedure which this Record targets.")
    program: str = Field(
        ...,
        description="The quantum chemistry program which carries out the individual quantum chemistry calculations.",
    )
    version: int = Field(..., description="The version of this record object describes.")
    protocols: Optional[Dict[str, Any]] = Field(
        None, description="Protocols that change the data stored in top level fields."
    )

    # Extra fields
    extras: Dict[str, Any] = Field({}, description="Extra information to associate with this record.")
    stdout: Optional[ObjectId] = Field(
        None,
        description="The Id of the stdout data stored in the database which was used to generate this record from the "
        "various programs which were called in the process.",
    )
    stderr: Optional[ObjectId] = Field(
        None,
        description="The Id of the stderr data stored in the database which was used to generate this record from the "
        "various programs which were called in the process.",
    )
    error: Optional[ObjectId] = Field(
        None,
        description="The Id of the error data stored in the database in the event that an error was generated in the "
        "process of carrying out the process this record targets. If no errors were raised, this field "
        "will be empty.",
    )

    # Compute status
    manager_name: Optional[str] = Field(None, description="Name of the Queue Manager which generated this record.")
    status: RecordStatusEnum = Field(RecordStatusEnum.incomplete, description=str(RecordStatusEnum.__doc__))
    modified_on: datetime.datetime = Field(None, description="Last time the data this record points to was modified.")
    created_on: datetime.datetime = Field(None, description="Time the data this record points to was first created.")

    # Carry-ons
    provenance: Optional[qcel.models.Provenance] = Field(
        None,
        description="Provenance information tied to the creation of this record. This includes things such as every "
        "program which was involved in generating the data for this record.",
    )

    class Config(ProtoModel.Config):
        # When True, ``__init__`` fills in ``hash_index`` if the caller did not supply one.
        build_hash_index = True

    @validator("program")
    def check_program(cls, v):
        """Store the program name in lower case."""
        return v.lower()

    def __init__(self, **data):
        """Build the record, filling in timestamps and the hash index when absent.

        Parameters
        ----------
        **data
            Field values for the model. ``modified_on`` and ``created_on`` default to
            the current UTC time when they are not provided.
        """

        # Read the clock once so that a freshly created record starts out with identical
        # ``created_on`` and ``modified_on`` stamps. Caller-supplied values always win.
        now = datetime.datetime.utcnow()
        data.setdefault("modified_on", now)
        data.setdefault("created_on", now)

        super().__init__(**data)

        # Set hash index if not present. Written through ``__dict__`` so the assignment
        # does not go through the model's attribute-setting machinery.
        if self.Config.build_hash_index and (self.hash_index is None):
            self.__dict__["hash_index"] = self.get_hash_index()

    def __repr_args__(self):
        """Return the (name, value) pairs shown in the repr: the record id and status."""

        return [("id", f"{self.id}"), ("status", f"{self.status}")]

    ### Serialization helpers

    @classmethod
    def get_hash_fields(cls) -> Set[str]:
        """Provides the names of the fields to be used in the hash
        that uniquely defines this object.

        Returns
        -------
        Set[str]
            The class-specific ``_hash_indices`` plus ``"procedure"`` and ``"program"``.

        """
        return cls._hash_indices | {"procedure", "program"}

    def get_hash_index(self) -> str:
        """Builds (or rebuilds) the hash of this
        object using the internally known hash fields.

        Returns
        -------
        str
            The objects unique hash index.
        """
        # Only the hash fields are serialized, so unrelated fields never affect the hash.
        data = self.dict(include=self.get_hash_fields(), encoding="json")

        return hash_dictionary(data)

    def dict(self, *args, **kwargs):
        """Serialize the record, always leaving out the ``client`` and ``cache`` helpers.

        Any ``exclude`` set passed by the caller is extended rather than replaced; all
        other arguments are forwarded unchanged to the parent implementation.
        """
        kwargs["exclude"] = (kwargs.pop("exclude", None) or set()) | {"client", "cache"}
        return super().dict(*args, **kwargs)

    ### Checkers

    def check_client(self, noraise: bool = False) -> bool:
        """Checks whether this object owns a FractalClient or not.
        This is often done so that objects pulled from a server using
        a FractalClient still possess a connection to the server so that
        additional data related to this object can be queried.

        Parameters
        ----------
        noraise : bool, optional
            Does not raise an error if this is True and instead returns
            a boolean depending if a client exists or not.

        Returns
        -------
        bool
            If True, the object owns a connection to a server. False otherwise.

        Raises
        ------
        ValueError
            If this object does not own a client and ``noraise`` is False.
        """
        if self.client is None:
            if noraise:
                return False

            raise ValueError("Requested method requires a client, but client was '{}'.".format(self.client))

        return True

    ### KVStore Getters

    def _kvstore_getter(self, field_name):
        """
        Internal KVStore getting object

        Looks up the KVStore id held in ``field_name``, fetches the entry through the
        client on first use and serves it from ``self.cache`` afterwards. Returns
        ``None`` when the field holds no id. A client is required even in that case,
        because ``check_client`` runs first.
        """
        self.check_client()

        oid = self.__dict__[field_name]
        if oid is None:
            return None

        if field_name not in self.cache:
            # Decompress here, rather than later
            # that way, it is decompressed in the cache
            kv = self.client.query_kvstore([oid])[oid]

            # Error entries are read as JSON; stdout/stderr are read as plain strings.
            if field_name == "error":
                self.cache[field_name] = kv.get_json()
            else:
                self.cache[field_name] = kv.get_string()

        return self.cache[field_name]

    def get_stdout(self) -> Optional[str]:
        """Pulls the stdout from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[str]
            The requested stdout, none if no stdout present.
        """
        return self._kvstore_getter("stdout")

    def get_stderr(self) -> Optional[str]:
        """Pulls the stderr from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[str]
            The requested stderr, none if no stderr present.
        """

        return self._kvstore_getter("stderr")

    def get_error(self) -> Optional[qcel.models.ComputeError]:
        """Pulls the error from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[qcel.models.ComputeError]
            The requested compute error, none if no error present.
        """
        value = self._kvstore_getter("error")
        if value:
            return qcel.models.ComputeError(**value)
        else:
            # Falsy payloads (e.g. ``None``) are handed back untouched.
            return value


class ResultRecord(RecordBase):
    """
    A ResultRecord for all single ("single" procedure) quantum chemistry result data.
    """

    # Classdata
    _hash_indices = {"driver", "method", "basis", "molecule", "keywords", "program"}

    # Version data
    version: int = Field(1, description="Version of the ResultRecord Model which this data was created with.")
    procedure: constr(strip_whitespace=True, regex="single") = Field(
        "single", description='Procedure is fixed as "single" because this is single quantum chemistry result.'
    )

    # Input data
    driver: DriverEnum = Field(..., description=str(DriverEnum.__doc__))
    method: str = Field(..., description="The quantum chemistry method the driver runs with.")
    molecule: ObjectId = Field(
        ..., description="The Id of the molecule in the Database which the result is computed on."
    )
    basis: Optional[str] = Field(
        None,
        description="The quantum chemistry basis set to evaluate (e.g., 6-31g, cc-pVDZ, ...). Can be ``None`` for "
        "methods without basis sets.",
    )
    keywords: Optional[ObjectId] = Field(
        None,
        description="The Id of the :class:`KeywordSet` which was passed into the quantum chemistry program that "
        "performed this calculation.",
    )
    protocols: Optional[qcel.models.results.ResultProtocols] = Field(
        qcel.models.results.ResultProtocols(), description=""
    )

    # Output data
    return_result: Union[float, qcel.models.types.Array[float], Dict[str, Any]] = Field(
        None, description="The primary result of the calculation, output is a function of the specified ``driver``."
    )
    properties: qcel.models.ResultProperties = Field(
        None, description="Additional data and results computed as part of the ``return_result``."
    )
    wavefunction: Optional[Dict[str, Any]] = Field(None, description="Wavefunction data generated by the Result.")
    wavefunction_data_id: Optional[ObjectId] = Field(None, description="The id of the wavefunction")

    class Config(RecordBase.Config):
        """A hash index is not used for ResultRecords as they can be
        uniquely determined with queryable keys.
        """

        build_hash_index = False

    @validator("method")
    def check_method(cls, v):
        """Methods should have a lower string to match the database."""
        return v.lower()

    @validator("basis")
    def check_basis(cls, v):
        """Normalize the basis name through ``prepare_basis``."""
        return prepare_basis(v)

    def get_wavefunction(self, key: Union[str, List[str]]) -> Any:
        """
        Pulls down the Wavefunction data associated with the computation.

        Parameters
        ----------
        key : Union[str, List[str]]
            One key or a list of keys to fetch. Keys are lower-cased and translated
            through the record's ``wavefunction["return_map"]`` before lookup.

        Returns
        -------
        Any
            The value itself when ``key`` is a single string, otherwise a dictionary
            of lower-cased requested key to value.

        Raises
        ------
        AttributeError
            If this record carries no wavefunction data.
        KeyError
            If a requested key is neither cached, listed in
            ``wavefunction["available"]``, nor one of ``"basis"``/``"restricted"``.
        ValueError
            If data has to be fetched but the record has no client.
        """

        if self.wavefunction is None:
            raise AttributeError("This Record was not computed with Wavefunction data.")

        single_return = False
        if isinstance(key, str):
            key = [key]
            single_return = True

        keys = [x.lower() for x in key]

        self.cache.setdefault("wavefunction", {})

        # Work on the translated (stored) names; only those not yet cached are fetched.
        mapped_keys = {self.wavefunction["return_map"].get(x, x) for x in keys}
        missing = mapped_keys - self.cache["wavefunction"].keys()

        unknown = missing - set(self.wavefunction["available"] + ["basis", "restricted"])
        if unknown:
            raise KeyError(
                f"Wavefunction Key(s) `{unknown}` not understood, available keys are: {self.wavefunction['available']}"
            )

        if missing:
            # A server round-trip is needed: fail with the same clear error as the other
            # getters instead of an attribute lookup on ``None``. Fully cached requests
            # never reach this point and therefore still work without a client.
            self.check_client()

            # Translate a return value
            proj = [self.wavefunction["return_map"].get(x, x) for x in missing]

            self.cache["wavefunction"].update(
                self.client.custom_query(
                    "wavefunctionstore", None, {"id": self.wavefunction_data_id}, meta={"include": proj}
                )
            )

            # The basis arrives as plain data; promote it to a model once, inside the cache.
            if "basis" in missing:
                self.cache["wavefunction"]["basis"] = qcel.models.BasisSet(**self.cache["wavefunction"]["basis"])

        # Remap once more
        ret = {}
        for k in keys:
            mkey = self.wavefunction["return_map"].get(k, k)
            ret[k] = self.cache["wavefunction"][mkey]

        if single_return:
            return ret[keys[0]]
        else:
            return ret

    def get_molecule(self) -> "Molecule":
        """
        Pulls the Result's Molecule from the connected database.

        The molecule is fetched once and served from ``self.cache`` afterwards.

        Returns
        -------
        Molecule
            The requested Molecule, or ``None`` if this record has no molecule id.

        Raises
        ------
        ValueError
            If the record has no client.
        """
        self.check_client()

        if self.molecule is None:
            return None

        if "molecule" not in self.cache:
            self.cache["molecule"] = self.client.query_molecules(id=self.molecule)[0]

        return self.cache["molecule"]
class OptimizationRecord(RecordBase):
    """
    A OptimizationRecord for all optimization procedure data.
    """

    # Class data
    _hash_indices = {"initial_molecule", "keywords", "qc_spec"}

    # Version data
    version: int = Field(1, description="Version of the OptimizationRecord Model which this data was created with.")
    procedure: constr(strip_whitespace=True, regex="optimization") = Field(
        "optimization", description='A fixed string indication this is a record for an "Optimization".'
    )
    schema_version: int = Field(1, description="The version number of QCSchema under which this record conforms to.")

    # Input data
    initial_molecule: ObjectId = Field(
        ..., description="The Id of the molecule which was passed in as the reference for this Optimization."
    )
    qc_spec: QCSpecification = Field(
        ..., description="The specification of the quantum chemistry calculation to run at each point."
    )
    keywords: Dict[str, Any] = Field(
        {},
        description="The keyword options which were passed into the Optimization program. "
        "Note: These are a dictionary and not a :class:`KeywordSet` object.",
    )
    protocols: Optional[qcel.models.procedures.OptimizationProtocols] = Field(
        qcel.models.procedures.OptimizationProtocols(), description=""
    )
    # Automatting issue currently
    # description=str(qcel.models.procedures.OptimizationProtocols.__doc__))

    # Results
    energies: List[float] = Field(None, description="The ordered list of energies at each step of the Optimization.")
    final_molecule: ObjectId = Field(
        None, description="The ``ObjectId`` of the final, optimized Molecule the Optimization procedure converged to."
    )
    # The first fragment ends with a space so the two sentences do not run together
    # when the adjacent string literals are concatenated.
    trajectory: List[ObjectId] = Field(
        None,
        description="The list of Molecule Id's the Optimization procedure generated at each step of the optimization. "
        "``initial_molecule`` will be the first index, and ``final_molecule`` will be the last index.",
    )

    class Config(RecordBase.Config):
        pass

    @validator("keywords")
    def check_keywords(cls, v):
        """Pass non-``None`` keyword dictionaries through ``recursive_normalizer``."""
        if v is not None:
            v = recursive_normalizer(v)
        return v

    ## Standard function

    def get_final_energy(self) -> float:
        """The final energy of the geometry optimization.

        Returns
        -------
        float
            The optimization molecular energy (the last entry of ``energies``).
        """
        return self.energies[-1]

    def get_trajectory(self) -> List[ResultRecord]:
        """Returns the Result records for each gradient evaluation in the trajectory.

        The records are fetched once and served from ``self.cache`` afterwards.

        Returns
        -------
        List['ResultRecord']
            A ordered list of Result record gradient computations.

        Raises
        ------
        ValueError
            If the records have to be fetched but the record has no client.
        """

        if "trajectory" not in self.cache:
            # Only the uncached path talks to the server, so only it needs a client.
            self.check_client()

            # Index the results by id so they can be put back in trajectory order.
            result = {x.id: x for x in self.client.query_results(id=self.trajectory)}

            self.cache["trajectory"] = [result[x] for x in self.trajectory]

        return self.cache["trajectory"]

    def get_molecular_trajectory(self) -> List["Molecule"]:
        """Returns the Molecule at each gradient evaluation in the trajectory.

        The molecules are fetched once and served from ``self.cache`` afterwards.

        Returns
        -------
        List['Molecule']
            A ordered list of Molecules in the trajectory.

        Raises
        ------
        ValueError
            If the molecules have to be fetched but the record has no client.
        """

        if "molecular_trajectory" not in self.cache:
            self.check_client()

            mol_ids = [x.molecule for x in self.get_trajectory()]

            # Index the molecules by id so they can be put back in trajectory order.
            mols = {x.id: x for x in self.client.query_molecules(id=mol_ids)}
            self.cache["molecular_trajectory"] = [mols[x] for x in mol_ids]

        return self.cache["molecular_trajectory"]

    def get_initial_molecule(self) -> "Molecule":
        """Returns the initial molecule

        Returns
        -------
        Molecule
            The initial molecule

        Raises
        ------
        ValueError
            If the record has no client.
        """
        self.check_client()

        ret = self.client.query_molecules(id=[self.initial_molecule])
        return ret[0]

    def get_final_molecule(self) -> "Molecule":
        """Returns the optimized molecule

        Returns
        -------
        Molecule
            The optimized molecule

        Raises
        ------
        ValueError
            If the record has no client.
        """
        self.check_client()

        ret = self.client.query_molecules(id=[self.final_molecule])
        return ret[0]

    ## Show functions

    def show_history(
        self, units: str = "kcal/mol", digits: int = 3, relative: bool = True, return_figure: Optional[bool] = None
    ) -> "plotly.Figure":
        """Plots the energy of the trajectory the optimization took.

        Parameters
        ----------
        units : str, optional
            Units to display the trajectory in.
        digits : int, optional
            The number of valid digits to show.
        relative : bool, optional
            If True, all energies are shifted by the lowest energy in the trajectory. Otherwise provides raw energies.
        return_figure : Optional[bool], optional
            If True, return the raw plotly figure. If False, returns a hosted iPlot. If None, return a iPlot display in Jupyter notebook and a raw plotly figure in all other circumstances.

        Returns
        -------
        plotly.Figure
            The requested figure.
        """
        # The stored energies are converted from Hartree into the requested units.
        cf = qcel.constants.conversion_factor("hartree", units)

        energies = np.array(self.energies)
        if relative:
            energies = energies - np.min(energies)

        # Steps are numbered from 1 on the x axis.
        trace = {"mode": "lines+markers", "x": list(range(1, len(energies) + 1)), "y": np.around(energies * cf, digits)}

        if relative:
            ylabel = f"Relative Energy [{units}]"
        else:
            ylabel = f"Absolute Energy [{units}]"

        custom_layout = {
            "title": "Geometry Optimization",
            "yaxis": {"title": ylabel, "zeroline": True},
            "xaxis": {
                "title": "Optimization Step",
                # "zeroline": False,
                "range": [min(trace["x"]), max(trace["x"])],
            },
        }

        return scatter_plot([trace], custom_layout=custom_layout, return_figure=return_figure)