Source code for qcportal.manybody.record_models

from __future__ import annotations

from enum import Enum
from typing import List, Union, Optional, Dict, Any, Iterable

try:
    from pydantic.v1 import BaseModel, Extra, validator, constr, Field
except ImportError:
    from pydantic import BaseModel, Extra, validator, constr, Field
from typing_extensions import Literal

from qcportal.base_models import RestModelBase
from qcportal.cache import get_records_with_cache
from qcportal.molecules import Molecule
from qcportal.record_models import BaseRecord, RecordAddBodyBase, RecordQueryFilters, compare_base_records
from qcportal.singlepoint.record_models import (
    QCSpecification,
    SinglepointRecord,
    compare_singlepoint_records,
)


[docs] class BSSECorrectionEnum(str, Enum): nocp = "nocp" cp = "cp" vmfc = "vmfc"
[docs] class ManybodyKeywords(BaseModel): class Config: extra = Extra.forbid return_total_data: bool = False
[docs] class ManybodySpecification(BaseModel): class Config: extra = Extra.forbid program: constr(to_lower=True) = "qcmanybody" levels: Dict[Union[int, Literal["supersystem"]], QCSpecification] bsse_correction: List[BSSECorrectionEnum] keywords: ManybodyKeywords = Field(ManybodyKeywords()) protocols: Dict[str, Any] = Field(default_factory=dict)
[docs] class ManybodyInput(RestModelBase): record_type: Literal["manybody"] = "manybody" specification: ManybodySpecification initial_molecule: Union[int, Molecule]
[docs] class ManybodyMultiInput(RestModelBase): specification: ManybodySpecification initial_molecules: List[Union[int, Molecule]]
[docs] class ManybodyAddBody(RecordAddBodyBase, ManybodyMultiInput): pass
[docs] class ManybodyQueryFilters(RecordQueryFilters): program: Optional[List[str]] = None qc_program: Optional[List[constr(to_lower=True)]] = None qc_method: Optional[List[constr(to_lower=True)]] = None qc_basis: Optional[List[Optional[constr(to_lower=True)]]] = None initial_molecule_id: Optional[List[int]] = None
[docs] class ManybodyClusterMeta(BaseModel): class Config: extra = Extra.forbid molecule_id: int mc_level: str fragments: List[int] basis: List[int] singlepoint_id: Optional[int] molecule: Optional[Molecule] = None
[docs] class ManybodyCluster(ManybodyClusterMeta): singlepoint_record: Optional[SinglepointRecord] = None
[docs] class ManybodyRecord(BaseRecord): record_type: Literal["manybody"] = "manybody" specification: ManybodySpecification initial_molecule_id: int ###################################################### # Fields not always included when fetching the record ###################################################### initial_molecule_: Optional[Molecule] = Field(None, alias="initial_molecule") clusters_meta_: Optional[List[ManybodyClusterMeta]] = Field(None, alias="clusters") ############################################## # Fields with child records # (generally not received from the server) ############################################## cluster_records_: Optional[List[ManybodyCluster]] = Field(None, alias="cluster_records")
[docs] def propagate_client(self, client, base_url_prefix: Optional[str]): BaseRecord.propagate_client(self, client, base_url_prefix) if self.cluster_records_ is not None: for cluster in self.cluster_records_: if cluster.singlepoint_record: cluster.singlepoint_record.propagate_client(client, base_url_prefix)
@classmethod def _fetch_children_multi( cls, client, record_cache, records: Iterable[ManybodyRecord], include: Iterable[str], force_fetch: bool = False ): # Should be checked by the calling function assert records assert all(isinstance(x, ManybodyRecord) for x in records) base_url_prefix = next(iter(records))._base_url_prefix assert all(r._base_url_prefix == base_url_prefix for r in records) if "clusters" in include or "**" in include: # collect all singlepoint ids for all manybody records sp_ids = set() for r in records: if r.clusters_meta_: sp_ids.update(x.singlepoint_id for x in r.clusters_meta_ if x.singlepoint_id is not None) sp_ids = list(sp_ids) sp_recs = get_records_with_cache( client, base_url_prefix, record_cache, SinglepointRecord, sp_ids, include=include, force_fetch=force_fetch, ) sp_map = {x.id: x for x in sp_recs} for r in records: if r.clusters_meta_ is None: r.cluster_records_ = None else: r.cluster_records_ = [] for cm in r.clusters_meta_: cluster = ManybodyCluster(**cm.dict()) if cluster.singlepoint_id is not None: cluster.singlepoint_record = sp_map[cluster.singlepoint_id] r.cluster_records_.append(cluster) r.propagate_client(client, base_url_prefix) def _fetch_initial_molecule(self): self._assert_online() self.initial_molecule_ = self._client.get_molecules([self.initial_molecule_id])[0] def _fetch_clusters(self): if self.clusters_meta_ is None: self._assert_online() self.clusters_meta_ = self._client.make_request( "get", f"api/v1/records/manybody/{self.id}/clusters", List[ManybodyClusterMeta], ) self.fetch_children(["clusters"])
[docs] def get_cache_dict(self, **kwargs) -> Dict[str, Any]: return self.dict(exclude={"cluster_records_"}, **kwargs)
@property def initial_molecule(self) -> Molecule: if self.initial_molecule_ is None: self._fetch_initial_molecule() return self.initial_molecule_ @property def clusters(self) -> List[ManybodyCluster]: if self.cluster_records_ is None: self._fetch_clusters() return self.cluster_records_
[docs] def compare_manybody_records(record_1: ManybodyRecord, record_2: ManybodyRecord): compare_base_records(record_1, record_2) assert record_1.initial_molecule.get_hash() == record_2.initial_molecule.get_hash() assert (record_1.cluster_records_ is None) == (record_2.cluster_records_ is None) if record_1.cluster_records_ is not None: assert len(record_1.cluster_records_) == len(record_2.cluster_records_) # Record ids can be different, so sort by level, fragments, basis cluster_1 = sorted(record_1.cluster_records_, key=lambda x: (x.mc_level, x.fragments, x.basis)) cluster_2 = sorted(record_2.cluster_records_, key=lambda x: (x.mc_level, x.fragments, x.basis)) for r1, r2 in zip(cluster_1, cluster_2): assert r1.mc_level == r2.mc_level assert r1.fragments == r2.fragments assert r1.basis == r2.basis assert r1.molecule.get_hash() == r2.molecule.get_hash() compare_singlepoint_records(r1.singlepoint_record, r2.singlepoint_record)