Source code for qcportal.optimization.record_models

from __future__ import annotations

from collections.abc import Iterable
from copy import deepcopy
from enum import Enum
from typing import Literal, Any

from pydantic import BaseModel, Field, field_validator, ConfigDict

from qcportal.base_models import RestModelBase
from qcportal.cache import get_records_with_cache
from qcportal.common_types import LowerStr
from qcportal.exceptions import NoClientError
from qcportal.molecules import Molecule
from qcportal.qcschema_v1 import QCInputSpecification, OptimizationResult
from qcportal.record_models import (
    BaseRecord,
    RecordAddBodyBase,
    RecordQueryFilters,
    RecordStatusEnum,
    compare_base_records,
)
from qcportal.singlepoint import (
    SinglepointProtocols,
    SinglepointRecord,
    QCSpecification,
    SinglepointDriver,
    compare_singlepoint_records,
)
from qcportal.utils import is_included


class TrajectoryProtocolEnum(str, Enum):
    """
    Which gradient evaluations to keep in an optimization trajectory.
    """

    # NOTE: the docstring above is consumed at runtime -- OptimizationProtocols
    # passes ``str(TrajectoryProtocolEnum.__doc__)`` as a field description --
    # so rewording it changes that description.
    #
    # The ``str`` mixin makes every member compare equal to its plain string
    # value. Each member's value is identical to its name.

    all = "all"  # presumably: keep every evaluation (name shadows the builtin inside this class body only)
    initial_and_final = "initial_and_final"  # presumably: keep only the first and last evaluation
    final = "final"  # presumably: keep only the last evaluation
    none = "none"  # presumably: keep no evaluations
class OptimizationProtocols(BaseModel):
    """
    Protocols governing how an optimization's output data is handled.
    """

    # Which trajectory entries to keep. Defaults to ``all``; the description
    # text is taken verbatim from the enum's own docstring.
    trajectory: TrajectoryProtocolEnum = Field(
        default=TrajectoryProtocolEnum.all,
        description=str(TrajectoryProtocolEnum.__doc__),
    )
class OptimizationSpecification(BaseModel):
    """
    An OptimizationSpecification as stored on the server

    This is the same as the input specification, with a few ids added

    Unknown fields are rejected (``extra="forbid"``). Whatever is supplied for
    ``qc_specification``, its ``driver`` and ``protocols`` entries are
    overridden by :meth:`force_qcspec` before validation.
    """

    model_config = ConfigDict(extra="forbid")

    program: LowerStr = Field(..., description="The program to use for an optimization")
    qc_specification: QCSpecification
    keywords: dict[str, Any] = Field({})
    protocols: OptimizationProtocols = Field(default_factory=OptimizationProtocols)

    @field_validator("qc_specification", mode="before")
    @classmethod
    def force_qcspec(cls, v):
        """
        Force the driver and protocols of the embedded QC specification.

        Runs before field validation. The driver is set to
        ``SinglepointDriver.deferred`` and the protocols to a default
        ``SinglepointProtocols()``, regardless of what the caller supplied.

        Parameters
        ----------
        v
            The raw value given for ``qc_specification``: either a
            ``QCSpecification`` instance or a dictionary of its fields.

        Returns
        -------
        :
            A dictionary with ``driver`` and ``protocols`` overridden. The
            object passed in by the caller is never modified.
        """
        if isinstance(v, QCSpecification):
            # model_dump() already produces a fresh dictionary
            v = v.model_dump()
        elif isinstance(v, dict):
            # Shallow copy so that the overrides below do not leak into the
            # caller's dictionary (only top-level keys are assigned)
            v = dict(v)

        v["driver"] = SinglepointDriver.deferred
        v["protocols"] = SinglepointProtocols()
        return v
class OptimizationRecord(BaseRecord):
    """
    Record of an optimization calculation.

    The molecules and the trajectory are optional fields: when they were not
    included with the record, the corresponding properties
    (:attr:`initial_molecule`, :attr:`final_molecule`, :attr:`trajectory`)
    fetch them through the attached client on first access.
    """

    record_type: Literal["optimization"] = "optimization"
    specification: OptimizationSpecification
    initial_molecule_id: int
    final_molecule_id: int | None
    energies: list[float] | None

    ######################################################
    # Fields not always included when fetching the record
    ######################################################
    initial_molecule_: Molecule | None = Field(None, alias="initial_molecule")
    final_molecule_: Molecule | None = Field(None, alias="final_molecule")
    trajectory_ids_: list[int] | None = Field(None, alias="trajectory_ids")

    ##############################################
    # Fields with child records
    # (generally not received from the server)
    ##############################################
    trajectory_records_: list[SinglepointRecord] | None = Field(None, alias="trajectory_records")

    @classmethod
    def _fetch_children_multi(
        cls,
        client,
        record_cache,
        records: Iterable[OptimizationRecord],
        include: Iterable[str],
        force_fetch: bool = False,
    ):
        """
        Fetch the trajectory singlepoint records for several optimizations at once.

        All trajectory ids of all given records are collected and retrieved
        with a single ``get_records_with_cache`` call, then attached to each
        record in the order given by its ``trajectory_ids_``. Nothing is
        fetched unless ``is_included`` reports "trajectory" as requested by
        ``include``.

        Parameters
        ----------
        client
            Client passed through to ``get_records_with_cache``
        record_cache
            Cache passed through to ``get_records_with_cache``
        records
            Optimization records to populate. Must be non-empty and all share
            the same ``_base_url_prefix``.
        include
            Names of the optional fields requested
        force_fetch
            Passed through to ``get_records_with_cache``
        """

        # The records are walked several times below; materialize them once so
        # that a one-shot iterator is not exhausted by the first pass
        records = list(records)

        # Should be checked by the calling function
        assert records
        assert all(isinstance(x, OptimizationRecord) for x in records)

        base_url_prefix = records[0]._base_url_prefix
        assert all(r._base_url_prefix == base_url_prefix for r in records)

        if is_included("trajectory", include, None, False):
            # collect all singlepoint ids for all optimizations
            sp_ids = set()
            for r in records:
                if r.trajectory_ids_:
                    sp_ids.update(r.trajectory_ids_)

            sp_ids = list(sp_ids)
            sp_records = get_records_with_cache(
                client,
                base_url_prefix,
                record_cache,
                SinglepointRecord,
                sp_ids,
                include=include,
                force_fetch=force_fetch,
            )
            sp_map = {r.id: r for r in sp_records}

            for r in records:
                if r.trajectory_ids_ is None:
                    r.trajectory_records_ = None
                else:
                    r.trajectory_records_ = [sp_map[x] for x in r.trajectory_ids_]

                # Hand the record's own client down to the newly attached children
                r.propagate_client(r._client, base_url_prefix)

    def propagate_client(self, client, base_url_prefix: str | None):
        """
        Attach a client to this record and to any trajectory records it holds.
        """
        BaseRecord.propagate_client(self, client, base_url_prefix)

        if self.trajectory_records_ is not None:
            for sp in self.trajectory_records_:
                sp.propagate_client(client, base_url_prefix)

    def _fetch_initial_molecule(self):
        """Retrieve the initial molecule through the client."""
        self._assert_online()
        self.initial_molecule_ = self._client.get_molecules([self.initial_molecule_id])[0]

    def _fetch_final_molecule(self):
        """Retrieve the final molecule through the client, if the record has one."""
        if self.final_molecule_id is not None:
            self._assert_online()
            self.final_molecule_ = self._client.get_molecules([self.final_molecule_id])[0]

    def _fetch_trajectory_ids(self):
        """Request the ids of this record's trajectory through the client."""
        self._assert_online()
        self.trajectory_ids_ = self._client.make_request(
            "get",
            f"api/v1/records/optimization/{self.id}/trajectory",
            list[int],
        )

    def _fetch_trajectory(self):
        """Retrieve the trajectory ids (if not yet known), then the trajectory records."""
        if self.trajectory_ids_ is None:
            self._fetch_trajectory_ids()

        self.fetch_children(["trajectory"])

    def get_cache_dict(self, **kwargs) -> dict[str, Any]:
        """
        Dump this record to a dictionary, leaving out the trajectory child records.

        Keyword arguments are forwarded to ``model_dump``.
        """
        return self.model_dump(exclude={"trajectory_records_"}, **kwargs)

    @property
    def initial_molecule(self) -> Molecule:
        """The initial molecule, fetched on first access if not already present."""
        if self.initial_molecule_ is None:
            self._fetch_initial_molecule()
        return self.initial_molecule_

    @property
    def final_molecule(self) -> Molecule | None:
        """The final molecule, or None if the record has no final molecule id."""
        if self.final_molecule_ is None:
            self._fetch_final_molecule()
        return self.final_molecule_

    @property
    def trajectory(self) -> list[SinglepointRecord] | None:
        """The trajectory records, fetched on first access if not already present."""
        if self.trajectory_records_ is None:
            self._fetch_trajectory()
        return self.trajectory_records_

    def trajectory_element(self, trajectory_index: int) -> SinglepointRecord:
        """
        Return a single element of the trajectory.

        If the trajectory records are already attached, the element is taken
        from them. Otherwise only the requested singlepoint record is
        retrieved, which requires the record to be connected to a client.

        Parameters
        ----------
        trajectory_index
            Index into the trajectory (used as a regular list index)

        Raises
        ------
        RuntimeError
            If no trajectory ids are available for this record
        """
        if self.trajectory_records_ is not None:
            return self.trajectory_records_[trajectory_index]

        self._assert_online()

        if self.trajectory_ids_ is None:
            self._fetch_trajectory_ids()

        if self.trajectory_ids_ is None:
            raise RuntimeError(f"Cannot find trajectory for record {self.id}")

        traj_id = self.trajectory_ids_[trajectory_index]
        sp_rec = self._get_child_records([traj_id], SinglepointRecord)[0]
        sp_rec.propagate_client(self._client, self._base_url_prefix)
        return sp_rec

    def to_qcschema_result(self) -> OptimizationResult:
        """
        Build a QCSchema ``OptimizationResult`` from this record.

        Raises
        ------
        RuntimeError
            If the record is not complete, or if required data is missing and
            the record is not connected to a client that could fetch it.
        """
        if self.status != RecordStatusEnum.complete:
            raise RuntimeError(f"Cannot create QCSchema result from record with status {self.status}")

        # Anything below that has to be fetched raises NoClientError when
        # the record is offline, so the whole construction is guarded
        try:
            extras = deepcopy(self.extras)
            extras["_qcfractal_modified_on"] = self.compute_history[0].modified_on

            traj_records = self.trajectory
            if traj_records is not None:
                trajectory = [x.to_qcschema_result() for x in traj_records]
            else:
                trajectory = None

            # TODO - correct?
            new_keywords = deepcopy(self.specification.keywords)
            new_keywords["program"] = self.specification.qc_specification.program

            return OptimizationResult(
                initial_molecule=self.initial_molecule,
                final_molecule=self.final_molecule,
                trajectory=trajectory,
                energies=self.energies,
                keywords=new_keywords,
                input_specification=QCInputSpecification(
                    driver=SinglepointDriver.gradient,  # forced
                    model=dict(
                        method=self.specification.qc_specification.method,
                        basis=self.specification.qc_specification.basis,
                    ),
                    keywords=self.specification.qc_specification.keywords,
                ),
                protocols=self.specification.protocols.model_dump(),
                extras=extras,
                stdout=self.stdout,
                provenance=self.provenance.model_dump(),
                success=True,  # Status has been checked above
            )
        except NoClientError as err:
            # Chain explicitly so the traceback shows what could not be fetched
            raise RuntimeError(
                "Record does not contain the required data for a QCSchema result, and this record is "
                "not connected to a client. If fetching records, use include=['**']. "
                "If this is from a dataset view, use include=['**'] and include_children=True "
                "when creating the view"
            ) from err
class OptimizationQueryFilters(RecordQueryFilters):
    """
    Filters available when querying optimization records.

    Every filter is optional and defaults to ``None``.
    """

    program: list[str] | None = None
    qc_program: list[LowerStr] | None = None
    qc_method: list[LowerStr] | None = None
    qc_basis: list[LowerStr | None] | None = None
    initial_molecule_id: list[int] | None = None
    final_molecule_id: list[int] | None = None

    @field_validator("qc_basis")
    @classmethod
    def _convert_basis(cls, v):
        """
        Replace ``None`` entries of the basis filter with empty strings.

        A filter of ``None`` (no filtering on basis) is passed through as is.
        Lowercasing is not done here; per the original note it is handled by
        the field's string type.
        """
        if v is None:
            return None

        return ["" if basis is None else basis for basis in v]
class OptimizationInput(RestModelBase):
    """
    Input describing a single optimization: a specification plus one initial molecule.
    """

    record_type: Literal["optimization"] = "optimization"
    specification: OptimizationSpecification
    # Either a full Molecule or an integer (presumably the id of an existing molecule)
    initial_molecule: int | Molecule
class OptimizationMultiInput(RestModelBase):
    """
    Input describing several optimizations that share one specification.
    """

    specification: OptimizationSpecification
    # Each entry is either a full Molecule or an integer (presumably the id of an existing molecule)
    initial_molecules: list[int | Molecule]
class OptimizationAddBody(RecordAddBodyBase, OptimizationMultiInput):
    """
    Combination of the fields of RecordAddBodyBase and OptimizationMultiInput.

    Declares no fields of its own. Presumably the request body used when
    adding optimization records -- confirm against the client/server code.
    """

    pass
def compare_optimization_records(record_1: OptimizationRecord, record_2: OptimizationRecord):
    """
    Assert that two optimization records are equivalent.

    Checks, in order: the base record data, the initial molecule, the final
    molecule (only when ``record_1`` has a final molecule id), the energies,
    and finally the trajectories element by element.

    Raises
    ------
    AssertionError
        On the first difference found
    """
    compare_base_records(record_1, record_2)

    assert record_1.initial_molecule == record_2.initial_molecule

    if record_1.final_molecule_id is not None:
        assert record_1.final_molecule == record_2.final_molecule

    assert record_1.energies == record_2.energies

    traj_1 = record_1.trajectory
    traj_2 = record_2.trajectory

    # Either both records have a trajectory or neither does
    assert (traj_1 is None) == (traj_2 is None)
    if traj_1 is None:
        return

    assert len(traj_1) == len(traj_2)
    for sp_1, sp_2 in zip(traj_1, traj_2):
        compare_singlepoint_records(sp_1, sp_2)