from __future__ import annotations
from collections.abc import Iterable
from copy import deepcopy
from enum import Enum
from typing import Literal, Any
from pydantic import BaseModel, Field, field_validator, ConfigDict
from qcportal.base_models import RestModelBase
from qcportal.cache import get_records_with_cache
from qcportal.common_types import LowerStr
from qcportal.exceptions import NoClientError
from qcportal.molecules import Molecule
from qcportal.qcschema_v1 import QCInputSpecification, OptimizationResult
from qcportal.record_models import (
BaseRecord,
RecordAddBodyBase,
RecordQueryFilters,
RecordStatusEnum,
compare_base_records,
)
from qcportal.singlepoint import (
SinglepointProtocols,
SinglepointRecord,
QCSpecification,
SinglepointDriver,
compare_singlepoint_records,
)
from qcportal.utils import is_included
class TrajectoryProtocolEnum(str, Enum):
    """
    Which gradient evaluations to keep in an optimization trajectory.
    """

    # NOTE: the docstring above is reused verbatim as a pydantic field
    # description elsewhere in this module, so keep its wording stable.

    # Each member's value is the plain string used when (de)serializing.
    all = "all"  # presumably: keep every step of the trajectory
    initial_and_final = "initial_and_final"  # presumably: keep only first and last steps
    final = "final"  # presumably: keep only the last step
    none = "none"  # presumably: keep no trajectory steps
class OptimizationProtocols(BaseModel):
    """
    Protocols regarding the manipulation of optimization output data.
    """

    # Which trajectory entries to retain. Defaults to keeping all of them.
    # The field description is taken from the enum's own docstring.
    trajectory: TrajectoryProtocolEnum = Field(
        TrajectoryProtocolEnum.all, description=str(TrajectoryProtocolEnum.__doc__)
    )
class OptimizationSpecification(BaseModel):
    """
    An OptimizationSpecification as stored on the server

    This is the same as the input specification, with a few ids added
    """

    # Unknown fields are rejected rather than silently ignored
    model_config = ConfigDict(extra="forbid")

    program: LowerStr = Field(..., description="The program to use for an optimization")

    # Specification of the underlying QC computation. Its driver and protocols
    # are always overridden by ``force_qcspec`` below.
    qc_specification: QCSpecification

    # Free-form keywords for the optimization program
    keywords: dict[str, Any] = Field({})

    protocols: OptimizationProtocols = Field(default_factory=OptimizationProtocols)

    @field_validator("qc_specification", mode="before")
    @classmethod
    def force_qcspec(cls, v):
        """
        Force the driver and protocols of the embedded QC specification.

        Runs before pydantic's own validation of the field. Whatever the
        caller supplied, the ``driver`` is set to ``SinglepointDriver.deferred``
        and ``protocols`` is reset to a default ``SinglepointProtocols``.

        Parameters
        ----------
        v
            Raw input for the field: either a ``QCSpecification`` instance or
            a dictionary of its fields.

        Returns
        -------
        :
            A dictionary with ``driver`` and ``protocols`` overridden. The
            caller's object is never modified.
        """
        if isinstance(v, QCSpecification):
            # model_dump() already produces a fresh dictionary
            v = v.model_dump()
        elif isinstance(v, dict):
            # Shallow copy so the overrides below do not leak into the
            # dictionary the caller passed in
            v = dict(v)

        v["driver"] = SinglepointDriver.deferred
        v["protocols"] = SinglepointProtocols()
        return v
class OptimizationRecord(BaseRecord):
    """
    Record of an optimization calculation.

    Holds the optimization specification, the ids of the initial and final
    molecules, the energies, and (optionally) the molecules themselves and
    the singlepoint records that make up the trajectory.
    """

    record_type: Literal["optimization"] = "optimization"
    specification: OptimizationSpecification
    initial_molecule_id: int
    final_molecule_id: int | None
    energies: list[float] | None

    ######################################################
    # Fields not always included when fetching the record
    ######################################################
    initial_molecule_: Molecule | None = Field(None, alias="initial_molecule")
    final_molecule_: Molecule | None = Field(None, alias="final_molecule")
    trajectory_ids_: list[int] | None = Field(None, alias="trajectory_ids")

    ##############################################
    # Fields with child records
    # (generally not received from the server)
    ##############################################
    trajectory_records_: list[SinglepointRecord] | None = Field(None, alias="trajectory_records")

    @classmethod
    def _fetch_children_multi(
        cls,
        client,
        record_cache,
        records: Iterable[OptimizationRecord],
        include: Iterable[str],
        force_fetch: bool = False,
    ):
        """
        Fetch the trajectory singlepoint records for several optimizations at once.

        If "trajectory" is selected by ``include``, the singlepoint ids of all
        given records are collected and retrieved with a single call to
        ``get_records_with_cache``. The results are then attached to each
        record as ``trajectory_records_``. Records whose ``trajectory_ids_``
        is None get ``trajectory_records_ = None``.

        Parameters
        ----------
        client
            Client passed through to ``get_records_with_cache``
        record_cache
            Cache passed through to ``get_records_with_cache``
        records
            Optimization records to populate. Must be non-empty and all share
            the same ``_base_url_prefix`` (asserted here; the caller is
            expected to have checked this). It is iterated several times.
        include
            Names of the optional fields/children to fetch
        force_fetch
            Passed through to ``get_records_with_cache``
        """
        # Should be checked by the calling function
        assert records
        assert all(isinstance(x, OptimizationRecord) for x in records)

        base_url_prefix = next(iter(records))._base_url_prefix
        assert all(r._base_url_prefix == base_url_prefix for r in records)

        if is_included("trajectory", include, None, False):
            # collect all singlepoint ids for all optimizations
            # (a set, so ids shared between records are only requested once)
            sp_ids = set()
            for r in records:
                if r.trajectory_ids_:
                    sp_ids.update(r.trajectory_ids_)

            sp_records = get_records_with_cache(
                client,
                base_url_prefix,
                record_cache,
                SinglepointRecord,
                list(sp_ids),
                include=include,
                force_fetch=force_fetch,
            )
            sp_map = {sp.id: sp for sp in sp_records}

            for r in records:
                if r.trajectory_ids_ is None:
                    r.trajectory_records_ = None
                else:
                    # Preserve the order given by trajectory_ids_
                    r.trajectory_records_ = [sp_map[x] for x in r.trajectory_ids_]

                # Hand the record's client down to the newly attached children
                r.propagate_client(r._client, base_url_prefix)

    def propagate_client(self, client, base_url_prefix: str | None):
        """
        Set the client on this record and on any loaded trajectory records.
        """
        BaseRecord.propagate_client(self, client, base_url_prefix)

        if self.trajectory_records_ is not None:
            for sp in self.trajectory_records_:
                sp.propagate_client(client, base_url_prefix)

    def _fetch_initial_molecule(self):
        """
        Retrieve the initial molecule through the client and store it on this record.
        """
        self._assert_online()
        self.initial_molecule_ = self._client.get_molecules([self.initial_molecule_id])[0]

    def _fetch_final_molecule(self):
        """
        Retrieve the final molecule through the client, if this record has one.

        Does nothing (and does not require a client) when ``final_molecule_id`` is None.
        """
        if self.final_molecule_id is not None:
            self._assert_online()
            self.final_molecule_ = self._client.get_molecules([self.final_molecule_id])[0]

    def _fetch_trajectory(self):
        """
        Load the trajectory records, requesting the trajectory ids first if needed.
        """
        if self.trajectory_ids_ is None:
            self._assert_online()
            self.trajectory_ids_ = self._client.make_request(
                "get",
                f"api/v1/records/optimization/{self.id}/trajectory",
                list[int],
            )

        self.fetch_children(["trajectory"])

    def get_cache_dict(self, **kwargs) -> dict[str, Any]:
        """
        Dump this record to a dictionary, leaving out the trajectory records.

        Any keyword arguments are forwarded to ``model_dump``.
        """
        return self.model_dump(exclude={"trajectory_records_"}, **kwargs)

    @property
    def initial_molecule(self) -> Molecule:
        """
        The initial molecule, fetched on first access if not already present.
        """
        if self.initial_molecule_ is None:
            self._fetch_initial_molecule()
        return self.initial_molecule_

    @property
    def final_molecule(self) -> Molecule | None:
        """
        The final molecule, fetched on first access if not already present.

        None if the record has no ``final_molecule_id``.
        """
        if self.final_molecule_ is None:
            self._fetch_final_molecule()
        return self.final_molecule_

    @property
    def trajectory(self) -> list[SinglepointRecord] | None:
        """
        The singlepoint records of the trajectory, fetched on first access if not already present.
        """
        if self.trajectory_records_ is None:
            self._fetch_trajectory()
        return self.trajectory_records_

    def trajectory_element(self, trajectory_index: int) -> SinglepointRecord:
        """
        Return a single element of the trajectory.

        If the trajectory records are already loaded, the element is taken
        from them. Otherwise only the requested singlepoint record is
        retrieved (the trajectory ids are requested first if needed); the
        retrieved record is not stored in ``trajectory_records_``.

        Parameters
        ----------
        trajectory_index
            Index into the trajectory (used directly as a list index)

        Returns
        -------
        :
            The singlepoint record at the given position

        Raises
        ------
        RuntimeError
            If the trajectory ids for this record cannot be obtained
        """
        if self.trajectory_records_ is not None:
            return self.trajectory_records_[trajectory_index]
        else:
            self._assert_online()

            if self.trajectory_ids_ is None:
                self.trajectory_ids_ = self._client.make_request(
                    "get",
                    f"api/v1/records/optimization/{self.id}/trajectory",
                    list[int],
                )

            if self.trajectory_ids_ is not None:
                traj_id = self.trajectory_ids_[trajectory_index]
                sp_rec = self._get_child_records([traj_id], SinglepointRecord)[0]
                sp_rec.propagate_client(self._client, self._base_url_prefix)
                return sp_rec
            else:
                raise RuntimeError(f"Cannot find trajectory for record {self.id}")

    def to_qcschema_result(self) -> OptimizationResult:
        """
        Convert this record into a QCSchema ``OptimizationResult``.

        Returns
        -------
        :
            The result object built from this record's data

        Raises
        ------
        RuntimeError
            If the record is not complete, or if required data is missing
            and the record has no client to fetch it with
        """
        if self.status != RecordStatusEnum.complete:
            raise RuntimeError(f"Cannot create QCSchema result from record with status {self.status}")

        try:
            # Work on a copy so the record's own extras are left untouched
            extras = deepcopy(self.extras)
            extras["_qcfractal_modified_on"] = self.compute_history[0].modified_on

            # Read the property once; accessing it may trigger a fetch
            trajectory_records = self.trajectory
            if trajectory_records is not None:
                trajectory = [x.to_qcschema_result() for x in trajectory_records]
            else:
                trajectory = None

            # TODO - correct?
            new_keywords = deepcopy(self.specification.keywords)
            new_keywords["program"] = self.specification.qc_specification.program

            return OptimizationResult(
                initial_molecule=self.initial_molecule,
                final_molecule=self.final_molecule,
                trajectory=trajectory,
                energies=self.energies,
                keywords=new_keywords,
                input_specification=QCInputSpecification(
                    driver=SinglepointDriver.gradient,  # forced
                    model=dict(
                        method=self.specification.qc_specification.method,
                        basis=self.specification.qc_specification.basis,
                    ),
                    keywords=self.specification.qc_specification.keywords,
                ),
                protocols=self.specification.protocols.model_dump(),
                extras=extras,
                stdout=self.stdout,
                provenance=self.provenance.model_dump(),
                success=True,  # Status has been checked above
            )
        except NoClientError as err:
            # Keep the original exception as the explicit cause
            raise RuntimeError(
                "Record does not contain the required data for a QCSchema result, and this record is "
                "not connected to a client. If fetching records, use include=['**']. "
                "If this is from a dataset view, use include=['**'] and include_children=True "
                "when creating the view"
            ) from err
class OptimizationQueryFilters(RecordQueryFilters):
    """
    Filters available when querying optimization records.

    Every filter is optional; None means the filter is not set.
    """

    program: list[str] | None = None
    qc_program: list[LowerStr] | None = None
    qc_method: list[LowerStr] | None = None
    qc_basis: list[LowerStr | None] | None = None
    initial_molecule_id: list[int] | None = None
    final_molecule_id: list[int] | None = None

    @field_validator("qc_basis")
    @classmethod
    def _convert_basis(cls, v):
        """
        Replace None entries in the basis list with empty strings.

        A value of None for the whole field (filter not set) is returned unchanged.
        """
        # Convert None entries to the empty string
        # Lowercasing is presumably handled by the LowerStr field type
        if v is not None:
            return ["" if x is None else x for x in v]
        else:
            return None
# NOTE(review): OptimizationMultiInput is neither defined nor imported in the
# visible part of this module - confirm where it comes from.
class OptimizationAddBody(RecordAddBodyBase, OptimizationMultiInput):
    # Adds no fields of its own; everything comes from the two base classes
    pass
def compare_optimization_records(record_1: OptimizationRecord, record_2: OptimizationRecord):
    """
    Assert that two optimization records are equivalent.

    Checks the base record data, the initial molecule, the final molecule
    (only when ``record_1`` has a final molecule id), the energies, and each
    element of the trajectory.

    Raises
    ------
    AssertionError
        If any of the compared pieces differ
    """
    compare_base_records(record_1, record_2)

    assert record_1.initial_molecule == record_2.initial_molecule

    if record_1.final_molecule_id is not None:
        assert record_1.final_molecule == record_2.final_molecule

    assert record_1.energies == record_2.energies

    # Read each property once; accessing it may trigger a fetch
    trajectory_1 = record_1.trajectory
    trajectory_2 = record_2.trajectory

    assert (trajectory_1 is None) == (trajectory_2 is None)

    if trajectory_1 is not None:
        assert len(trajectory_1) == len(trajectory_2)
        for t1, t2 in zip(trajectory_1, trajectory_2):
            compare_singlepoint_records(t1, t2)