from __future__ import annotations
from typing import List, Optional, Union, Dict, Iterable, Any
try:
from pydantic.v1 import BaseModel, Field, Extra, root_validator, constr, validator
except ImportError:
from pydantic import BaseModel, Field, Extra, root_validator, constr, validator
from typing_extensions import Literal
from qcportal.base_models import RestModelBase
from qcportal.cache import get_records_with_cache
from qcportal.molecules import Molecule
from qcportal.record_models import (
BaseRecord,
RecordAddBodyBase,
RecordQueryFilters,
RecordStatusEnum,
compare_base_records,
)
from qcportal.utils import recursive_normalizer, is_included
from ..optimization.record_models import OptimizationRecord, OptimizationSpecification, compare_optimization_records
from ..singlepoint.record_models import QCSpecification, SinglepointRecord, compare_singlepoint_records
[docs]
class NEBKeywords(BaseModel):
"""
NEBRecord options
"""
class Config:
extra = Extra.forbid
images: int = Field(
11,
description="Number of images that will be used to locate a rough transition state structure.",
gt=5,
)
spring_constant: float = Field(
1.0,
description="Spring constant in kcal/mol/Ang^2.",
)
spring_type: int = Field(
0,
description="0: Nudged Elastic Band (parallel spring force + perpendicular gradients)\n"
"1: Hybrid Elastic Band (full spring force + perpendicular gradients)\n"
"2: Plain Elastic Band (full spring force + full gradients)\n",
)
maximum_force: float = Field(
0.05,
description="Convergence criteria. Converge when maximum RMS-gradient (ev/Ang) of the chain fall below maximum_force.",
)
average_force: float = Field(
0.025,
description="Convergence criteria. Converge when average RMS-gradient (ev/Ang) of the chain fall below average_force.",
)
maximum_cycle: int = Field(100, description="Maximum iteration number for NEB calculation.")
optimize_ts: bool = Field(
False,
description="Setting it equal to true will perform a transition sate optimization starting with the guessed transition state structure from the NEB calculation result.",
)
optimize_endpoints: bool = Field(
False,
description="Setting it equal to True will optimize two end points of the initial chain before starting NEB.",
)
align: bool = Field(True, description="Align the images before starting the NEB calculation.")
epsilon: float = Field(1e-5, description="Small eigenvalue threshold for resetting Hessian.")
[docs]
@root_validator
def normalize(cls, values):
return recursive_normalizer(values)
[docs]
class NEBSpecification(BaseModel):
class Config:
extra = Extra.forbid
program: constr(to_lower=True) = "geometric"
singlepoint_specification: QCSpecification
optimization_specification: Optional[OptimizationSpecification] = None
keywords: NEBKeywords
[docs]
class NEBOptimization(BaseModel):
[docs]
class config:
extra = Extra.forbid
optimization_id: int
position: int
ts: bool
[docs]
class NEBSinglepoint(BaseModel):
class Config:
extra = Extra.forbid
singlepoint_id: int
chain_iteration: int
position: int
[docs]
class NEBAddBody(RecordAddBodyBase, NEBMultiInput):
pass
[docs]
class NEBQueryFilters(RecordQueryFilters):
program: Optional[List[str]] = "geometric"
qc_program: Optional[List[constr(to_lower=True)]] = None
qc_method: Optional[List[constr(to_lower=True)]] = None
qc_basis: Optional[List[Optional[constr(to_lower=True)]]] = None
molecule_id: Optional[List[int]] = None
@validator("qc_basis")
def _convert_basis(cls, v):
# Convert empty string to None
# Lowercasing is handled by constr
if v is not None:
return ["" if x is None else x for x in v]
else:
return None
[docs]
class NEBRecord(BaseRecord):
record_type: Literal["neb"] = "neb"
specification: NEBSpecification
######################################################
# Fields not always included when fetching the record
######################################################
initial_chain_molecule_ids_: Optional[List[int]] = Field(None, alias="initial_chain_molecule_ids")
singlepoints_: Optional[List[NEBSinglepoint]] = Field(None, alias="singlepoints")
optimizations_: Optional[Dict[str, NEBOptimization]] = Field(None, alias="optimizations")
neb_result_: Optional[Molecule] = Field(None, alias="neb_result")
initial_chain_: Optional[List[Molecule]] = Field(None, alias="initial_chain")
##############################################
# Fields with child records
# (generally not received from the server)
##############################################
optimization_records_: Optional[Dict[str, OptimizationRecord]] = Field(None, field="optimization_records")
singlepoint_records_: Optional[Dict[int, List[SinglepointRecord]]] = Field(None, field="singlepoint_records")
ts_hessian_: Optional[SinglepointRecord] = Field(None, field="ts_hessian")
[docs]
def propagate_client(self, client, base_url_prefix: Optional[str]):
BaseRecord.propagate_client(self, client, base_url_prefix)
if self.optimization_records_ is not None:
for opt in self.optimization_records_.values():
opt.propagate_client(client, base_url_prefix)
if self.singlepoint_records_ is not None:
for splist in self.singlepoint_records_.values():
for sp2 in splist:
sp2.propagate_client(client, base_url_prefix)
@classmethod
def _fetch_children_multi(
cls, client, record_cache, records: Iterable[NEBRecord], include: Iterable[str], force_fetch: bool = False
):
# Should be checked by the calling function
assert records
assert all(isinstance(x, NEBRecord) for x in records)
base_url_prefix = next(iter(records))._base_url_prefix
assert all(r._base_url_prefix == base_url_prefix for r in records)
do_sp = is_included("singlepoints", include, None, False)
do_opt = is_included("optimizations", include, None, False)
if not do_sp and not do_opt:
return
# Collect optimization and singlepoint ids for all NEB
opt_ids = set()
sp_ids = set()
for r in records:
if r.optimizations_ is not None:
opt_ids.update(x.optimization_id for x in r.optimizations_.values())
if r.singlepoints_ is not None:
sp_ids.update(x.singlepoint_id for x in r.singlepoints_)
sp_ids = list(sp_ids)
opt_ids = list(opt_ids)
if do_sp:
sp_records = get_records_with_cache(
client,
base_url_prefix,
record_cache,
SinglepointRecord,
sp_ids,
include=include,
force_fetch=force_fetch,
)
sp_map = {r.id: r for r in sp_records}
if do_opt:
opt_records = get_records_with_cache(
client,
base_url_prefix,
record_cache,
OptimizationRecord,
opt_ids,
include=include,
force_fetch=force_fetch,
)
opt_map = {r.id: r for r in opt_records}
for r in records:
if r.optimizations_ is None:
r.optimization_records_ = None
elif do_opt:
r.optimization_records_ = dict()
for opt_key, opt_info in r.optimizations_.items():
r.optimization_records_[opt_key] = opt_map[opt_info.optimization_id]
if r.singlepoints_ is None:
r.singlepoint_records_ = None
elif do_sp:
r.singlepoint_records_ = dict()
for sp_info in r.singlepoints_:
r.singlepoint_records_.setdefault(sp_info.chain_iteration, list())
r.singlepoint_records_[sp_info.chain_iteration].append(sp_map[sp_info.singlepoint_id])
if len(r.singlepoint_records_) > 0:
if len(r.singlepoint_records_[max(r.singlepoint_records_)]) == 1:
_, temp_list = r.singlepoint_records_.popitem()
r.ts_hessian_ = temp_list[0]
assert r.ts_hessian_.specification.driver == "hessian"
r.propagate_client(r._client, base_url_prefix)
def _fetch_optimizations(self):
if self.optimizations_ is None:
self._assert_online()
self.optimizations_ = self._client.make_request(
"get",
f"api/v1/records/neb/{self.id}/optimizations",
Dict[str, NEBOptimization],
)
self.fetch_children(["optimizations"])
def _fetch_singlepoints(self):
if self.singlepoints_ is None:
self._assert_online()
self.singlepoints_ = self._client.make_request(
"get",
f"api/v1/records/neb/{self.id}/singlepoints",
List[NEBSinglepoint],
)
self.fetch_children(["singlepoints"])
def _fetch_initial_chain(self):
if self.initial_chain_molecule_ids_ is None:
self._assert_online()
self.initial_chain_molecule_ids_ = self._client.make_request(
"get",
f"api/v1/records/neb/{self.id}/initial_chain",
List[int],
)
self.initial_chain_ = self._client.get_molecules(self.initial_chain_molecule_ids_)
def _fetch_neb_result(self):
if self.neb_result_ is None:
self._assert_online()
self.neb_result_ = self._client.make_request(
"get",
f"api/v1/records/neb/{self.id}/neb_result",
Optional[Molecule],
)
[docs]
def get_cache_dict(self, **kwargs) -> Dict[str, Any]:
return self.dict(exclude={"optimization_records_", "singlepoint_records_", "ts_hessian_"}, **kwargs)
@property
def initial_chain(self) -> List[Molecule]:
if self.initial_chain_ is None:
self._fetch_initial_chain()
return self.initial_chain_
@property
def final_chain(self) -> List[SinglepointRecord]:
return self.singlepoints[max(self.singlepoints.keys())]
@property
def singlepoints(self) -> Dict[int, List[SinglepointRecord]]:
if self.singlepoint_records_ is None:
self._fetch_singlepoints()
return self.singlepoint_records_
@property
def result(self):
if self.status != RecordStatusEnum.complete:
raise ValueError("NEB result is only available after the calculation is complete.")
if self.neb_result_ is None:
# Fetch the result if possible
if self._client is not None:
self._fetch_neb_result()
elif self.singlepoints_ is not None:
max_iter = max(self.singlepoints.keys())
max_sp = max(self.singlepoints[max_iter], key=lambda x: x.properties["return_energy"])
self.neb_result_ = max_sp.molecule
else:
# Raise the usual exception
self.assert_online()
return self.neb_result_
@property
def optimizations(self) -> Optional[Dict[str, OptimizationRecord]]:
if self.optimization_records_ is None:
self._fetch_optimizations()
return self.optimization_records_
@property
def ts_optimization(self) -> Optional[OptimizationRecord]:
return self.optimizations.get("transition", None)
@property
def ts_hessian(self) -> Optional[SinglepointRecord]:
if self.singlepoint_records_ is None:
self._fetch_singlepoints()
return self.ts_hessian_
[docs]
def compare_neb_records(record_1: NEBRecord, record_2: NEBRecord):
compare_base_records(record_1, record_2)
assert len(record_1.initial_chain_) == len(record_2.initial_chain_)
for m1, m2 in zip(record_1.initial_chain_, record_2.initial_chain_):
assert m1.get_hash() == m2.get_hash()
if record_1.status == RecordStatusEnum.complete:
assert record_1.result.get_hash() == record_2.result.get_hash()
assert (record_1.singlepoint_records_.keys()) == (record_2.singlepoint_records_.keys())
# sort singlepoint info by iteration and position
assert len(record_1.singlepoints_) == len(record_2.singlepoints_)
singlepoint_info_1 = sorted(record_1.singlepoints_, key=lambda x: (x.chain_iteration, x.position))
singlepoint_info_2 = sorted(record_2.singlepoints_, key=lambda x: (x.chain_iteration, x.position))
assert len(singlepoint_info_1) == len(singlepoint_info_2)
for m1, m2 in zip(singlepoint_info_1, singlepoint_info_2):
assert m1.chain_iteration == m2.chain_iteration
assert m1.position == m2.position
# compare actual records
for k, sp1 in record_1.singlepoint_records_.items():
sp2 = record_2.singlepoint_records_[k]
assert len(sp1) == len(sp2)
for m1, m2 in zip(sp1, sp2):
compare_singlepoint_records(m1, m2)
# Hessian part
assert (record_1.ts_hessian_ is None) == (record_2.ts_hessian_ is None)
if record_1.ts_hessian_ is not None:
compare_singlepoint_records(record_1.ts_hessian_, record_2.ts_hessian_)
# Optimizations
assert (record_1.optimizations_ is None) == (record_2.optimizations_ is None)
if record_1.optimizations_ is not None:
assert len(record_1.optimizations_) == len(record_2.optimizations_)
for k, v in record_1.optimizations_.items():
v2 = record_2.optimizations_[k]
assert v.position == v2.position
assert v.ts == v2.ts
for k, c1 in record_1.optimization_records_.items():
c2 = record_2.optimization_records_[k]
compare_optimization_records(c1, c2)