Source code for qcportal.optimization.record_models

from __future__ import annotations

from copy import deepcopy
from enum import Enum
from typing import Iterable
from typing import Optional, Union, Any, List, Dict

from pydantic.v1 import BaseModel, Field, constr, validator, Extra
from qcelemental.models import Molecule
from qcelemental.models.procedures import (
    OptimizationResult,
    QCInputSpecification,
)
from typing_extensions import Literal

from qcportal.base_models import RestModelBase
from qcportal.cache import get_records_with_cache
from qcportal.record_models import (
    BaseRecord,
    RecordAddBodyBase,
    RecordQueryFilters,
    RecordStatusEnum,
    compare_base_records,
)
from qcportal.singlepoint import (
    SinglepointProtocols,
    SinglepointRecord,
    QCSpecification,
    SinglepointDriver,
    compare_singlepoint_records,
)
from qcportal.utils import is_included


class TrajectoryProtocolEnum(str, Enum):
    """
    Which gradient evaluations to keep in an optimization trajectory.
    """

    # NOTE: the docstring above is read at runtime (it is used as a Field
    # description elsewhere in this module), so keep its text stable.

    # Members are listed from "keep everything" down to "keep nothing";
    # each value is the string form of the member name.
    all = "all"
    initial_and_final = "initial_and_final"
    final = "final"
    none = "none"
class OptimizationProtocols(BaseModel):
    """
    Protocols controlling how the output data of an optimization is handled.
    """

    # Which trajectory entries to keep. Defaults to keeping all of them; the
    # field description is taken from the enum's own docstring.
    trajectory: TrajectoryProtocolEnum = Field(
        default=TrajectoryProtocolEnum.all,
        description=str(TrajectoryProtocolEnum.__doc__),
    )
class OptimizationSpecification(BaseModel):
    """
    An OptimizationSpecification as stored on the server

    This is the same as the input specification, with a few ids added
    """

    class Config:
        # Unknown fields are rejected rather than silently ignored
        extra = Extra.forbid

    # Name of the optimization program; lowercased by the constr type
    program: constr(to_lower=True) = Field(..., description="The program to use for an optimization")

    # Specification of the underlying QC computations (normalized by force_qcspec)
    qc_specification: QCSpecification

    # Free-form keywords for the optimization
    keywords: Dict[str, Any] = Field({})

    # Output-handling protocols for the optimization
    protocols: OptimizationProtocols = Field(OptimizationProtocols())

    @validator("qc_specification", pre=True)
    def force_qcspec(cls, v):
        """
        Normalize the QC specification before it is validated.

        Whatever the caller supplied, the driver is forced to
        ``SinglepointDriver.deferred`` and the protocols are reset to a
        default ``SinglepointProtocols()``.

        Parameters
        ----------
        v
            The raw ``qc_specification`` value: a ``QCSpecification`` instance
            or a dictionary of its fields.

        Returns
        -------
        :
            A dictionary with ``driver`` and ``protocols`` overridden.
        """

        if isinstance(v, QCSpecification):
            v = v.dict()
        elif isinstance(v, dict):
            # Work on a shallow copy so the overrides below do not leak back
            # into the dictionary the caller passed in
            v = dict(v)

        v["driver"] = SinglepointDriver.deferred
        v["protocols"] = SinglepointProtocols()
        return v
class OptimizationRecord(BaseRecord):
    """
    Record of a geometry optimization.

    Molecules, trajectory ids and trajectory (singlepoint) records are
    optional fields: when they are missing they are fetched on first access
    through the attached client.
    """

    record_type: Literal["optimization"] = "optimization"
    specification: OptimizationSpecification
    initial_molecule_id: int
    final_molecule_id: Optional[int]
    energies: Optional[List[float]]

    ######################################################
    # Fields not always included when fetching the record
    ######################################################
    initial_molecule_: Optional[Molecule] = Field(None, alias="initial_molecule")
    final_molecule_: Optional[Molecule] = Field(None, alias="final_molecule")
    trajectory_ids_: Optional[List[int]] = Field(None, alias="trajectory_ids")

    ##############################################
    # Fields with child records
    # (generally not received from the server)
    ##############################################
    trajectory_records_: Optional[List[SinglepointRecord]] = Field(None, alias="trajectory_records")

    @classmethod
    def _fetch_children_multi(
        cls,
        client,
        record_cache,
        records: Iterable[OptimizationRecord],
        include: Iterable[str],
        force_fetch: bool = False,
    ):
        """
        Fetch the trajectory singlepoint records for a batch of optimizations.

        All trajectory ids of all given records are gathered and looked up
        with a single call to ``get_records_with_cache``; the results are then
        attached to each record as ``trajectory_records_``. Nothing is done
        unless "trajectory" is selected by ``include``.
        """

        # Should be checked by the calling function
        assert records
        assert all(isinstance(x, OptimizationRecord) for x in records)

        base_url_prefix = next(iter(records))._base_url_prefix
        assert all(r._base_url_prefix == base_url_prefix for r in records)

        if not is_included("trajectory", include, None, False):
            return

        # Collect all singlepoint ids for all optimizations (duplicates removed)
        unique_ids = set()
        for opt_rec in records:
            if opt_rec.trajectory_ids_:
                unique_ids.update(opt_rec.trajectory_ids_)

        fetched = get_records_with_cache(
            client,
            base_url_prefix,
            record_cache,
            SinglepointRecord,
            list(unique_ids),
            include=include,
            force_fetch=force_fetch,
        )
        sp_by_id = {sp.id: sp for sp in fetched}

        for opt_rec in records:
            if opt_rec.trajectory_ids_ is None:
                opt_rec.trajectory_records_ = None
                continue

            # Keep the records in the order given by the trajectory ids
            opt_rec.trajectory_records_ = [sp_by_id[sp_id] for sp_id in opt_rec.trajectory_ids_]
            opt_rec.propagate_client(opt_rec._client, base_url_prefix)

    def propagate_client(self, client, base_url_prefix: Optional[str]):
        """
        Attach a client (and base url prefix) to this record and to every
        trajectory record it currently holds.
        """

        BaseRecord.propagate_client(self, client, base_url_prefix)

        for sp_rec in self.trajectory_records_ or ():
            sp_rec.propagate_client(client, base_url_prefix)

    def _fetch_initial_molecule(self):
        # Requires a client; stores the molecule on the record
        self._assert_online()
        molecules = self._client.get_molecules([self.initial_molecule_id])
        self.initial_molecule_ = molecules[0]

    def _fetch_final_molecule(self):
        # Nothing to fetch if the record has no final molecule id
        if self.final_molecule_id is None:
            return

        self._assert_online()
        molecules = self._client.get_molecules([self.final_molecule_id])
        self.final_molecule_ = molecules[0]

    def _request_trajectory_ids(self):
        # Ask the server for the ids of the trajectory singlepoint records
        return self._client.make_request(
            "get",
            f"api/v1/records/optimization/{self.id}/trajectory",
            List[int],
        )

    def _fetch_trajectory(self):
        # Get the ids first (if we do not have them), then the records themselves
        if self.trajectory_ids_ is None:
            self._assert_online()
            self.trajectory_ids_ = self._request_trajectory_ids()

        self.fetch_children(["trajectory"])

    def get_cache_dict(self, **kwargs) -> Dict[str, Any]:
        """
        Dictionary form of this record for caching.

        The trajectory child records are left out; all keyword arguments are
        passed on to ``dict()``.
        """

        return self.dict(exclude={"trajectory_records_"}, **kwargs)

    @property
    def initial_molecule(self) -> Molecule:
        """The initial molecule, fetched on first access if not already present."""
        if self.initial_molecule_ is None:
            self._fetch_initial_molecule()
        return self.initial_molecule_

    @property
    def final_molecule(self) -> Optional[Molecule]:
        """The final molecule, fetched on first access; None if the record has no final molecule id."""
        if self.final_molecule_ is None:
            self._fetch_final_molecule()
        return self.final_molecule_

    @property
    def trajectory(self) -> Optional[List[SinglepointRecord]]:
        """The trajectory singlepoint records, fetched on first access if not already present."""
        if self.trajectory_records_ is None:
            self._fetch_trajectory()
        return self.trajectory_records_

    def trajectory_element(self, trajectory_index: int) -> SinglepointRecord:
        """
        Get a single element of the trajectory.

        If the trajectory records are already held by this record, the element
        is taken from them. Otherwise only the requested singlepoint record is
        retrieved (which requires a client).

        Parameters
        ----------
        trajectory_index
            Index into the trajectory

        Returns
        -------
        :
            The singlepoint record at that position of the trajectory

        Raises
        ------
        RuntimeError
            If no trajectory ids are available for this record
        """

        if self.trajectory_records_ is not None:
            return self.trajectory_records_[trajectory_index]

        self._assert_online()

        if self.trajectory_ids_ is None:
            self.trajectory_ids_ = self._request_trajectory_ids()

        if self.trajectory_ids_ is None:
            raise RuntimeError(f"Cannot find trajectory for record {self.id}")

        wanted_id = self.trajectory_ids_[trajectory_index]
        element = self._get_child_records([wanted_id], SinglepointRecord)[0]
        element.propagate_client(self._client, self._base_url_prefix)
        return element

    def to_qcschema_result(self) -> OptimizationResult:
        """
        Convert this record to a QCSchema ``OptimizationResult``.

        Raises
        ------
        RuntimeError
            If the record is not complete
        """

        if self.status != RecordStatusEnum.complete:
            raise RuntimeError(f"Cannot create QCSchema result from record with status {self.status}")

        extras = deepcopy(self.extras)
        extras["_qcfractal_modified_on"] = self.compute_history[0].modified_on

        sp_records = self.trajectory
        trajectory = None if sp_records is None else [sp.to_qcschema_result() for sp in sp_records]

        qc_spec = self.specification.qc_specification

        # TODO - correct?
        new_keywords = deepcopy(self.specification.keywords)
        new_keywords["program"] = qc_spec.program

        # Accessing these may trigger a fetch from the server
        initial_molecule = self.initial_molecule
        final_molecule = self.final_molecule

        input_specification = QCInputSpecification(
            driver=SinglepointDriver.gradient,  # forced
            model=dict(method=qc_spec.method, basis=qc_spec.basis),
            keywords=qc_spec.keywords,
        )

        return OptimizationResult(
            initial_molecule=initial_molecule,
            final_molecule=final_molecule,
            trajectory=trajectory,
            energies=self.energies,
            keywords=new_keywords,
            input_specification=input_specification,
            protocols=self.specification.protocols,
            extras=extras,
            stdout=self.stdout,
            provenance=self.provenance,
            success=True,  # Status has been checked above
        )
class OptimizationQueryFilters(RecordQueryFilters):
    """
    Filters for querying optimization records, in addition to the filters
    inherited from RecordQueryFilters.
    """

    program: Optional[List[str]] = None
    qc_program: Optional[List[constr(to_lower=True)]] = None
    qc_method: Optional[List[constr(to_lower=True)]] = None
    qc_basis: Optional[List[Optional[constr(to_lower=True)]]] = None
    initial_molecule_id: Optional[List[int]] = None
    final_molecule_id: Optional[List[int]] = None

    @validator("qc_basis")
    def _convert_basis(cls, v):
        # A basis given as None is replaced by an empty string
        # Lowercasing is handled by constr
        if v is None:
            return None

        return [basis if basis is not None else "" for basis in v]
class OptimizationInput(RestModelBase):
    """
    Input for a single optimization: a specification plus one initial
    molecule, given either as an integer or as a Molecule object.
    """

    record_type: Literal["optimization"] = "optimization"
    specification: OptimizationSpecification
    initial_molecule: Union[int, Molecule]
class OptimizationMultiInput(RestModelBase):
    """
    Input for several optimizations sharing one specification; each initial
    molecule is given either as an integer or as a Molecule object.
    """

    specification: OptimizationSpecification
    initial_molecules: List[Union[int, Molecule]]
class OptimizationAddBody(RecordAddBodyBase, OptimizationMultiInput):
    """
    Body for adding optimization records: the fields of RecordAddBodyBase
    combined with those of OptimizationMultiInput. Adds no fields of its own.
    """
def compare_optimization_records(record_1: OptimizationRecord, record_2: OptimizationRecord):
    """
    Assert that two optimization records are equivalent.

    Checks the base record fields, the initial and final molecules (by hash),
    the energies, and - if present - the trajectory records pair by pair.
    Raises AssertionError on the first mismatch.
    """

    compare_base_records(record_1, record_2)

    # Molecules are compared by hash rather than by id
    assert record_1.initial_molecule.get_hash() == record_2.initial_molecule.get_hash()

    assert (record_1.final_molecule_id is None) == (record_2.final_molecule_id is None)
    if record_1.final_molecule_id is not None:
        assert record_1.final_molecule.get_hash() == record_2.final_molecule.get_hash()

    assert record_1.energies == record_2.energies

    # Both records must agree on whether trajectory data is present
    traj_1 = record_1.trajectory_records_
    traj_2 = record_2.trajectory_records_
    assert (traj_1 is None) == (traj_2 is None)
    assert (record_1.trajectory_ids_ is None) == (record_2.trajectory_ids_ is None)

    if traj_1 is None:
        return

    assert len(traj_1) == len(traj_2)
    for sp_1, sp_2 in zip(traj_1, traj_2):
        compare_singlepoint_records(sp_1, sp_2)