Source code for qcportal.metadata_models

from __future__ import annotations

from collections.abc import Sequence

from pydantic import Field, field_validator, model_validator, BaseModel


class InsertMetadata(BaseModel):
    """
    Metadata returned by insertion / adding functions

    Every integer stored in ``errors``, ``inserted_idx`` and ``existing_idx`` is an
    index into the input/output list of the call that produced this object.
    """

    error_description: str | None = None
    errors: list[tuple[int, str]] = Field(default_factory=list)
    inserted_idx: list[int] = Field(default_factory=list)  # inserted into the db
    existing_idx: list[int] = Field(default_factory=list)  # existing but not updated

    @property
    def n_inserted(self) -> int:
        """Number of entries in ``inserted_idx``."""
        return len(self.inserted_idx)

    @property
    def n_existing(self) -> int:
        """Number of entries in ``existing_idx``."""
        return len(self.existing_idx)

    @property
    def n_errors(self) -> int:
        """Number of entries in ``errors``."""
        return len(self.errors)

    @property
    def error_idx(self) -> list[int]:
        """The index part of every ``(index, message)`` pair in ``errors``."""
        return [entry[0] for entry in self.errors]

    @property
    def success(self) -> bool:
        """True when there is no overall error description and no per-index error."""
        return self.error_description is None and not self.errors

    @property
    def error_string(self) -> str:
        """The error description (if any) followed by one line per entry in ``errors``."""
        header = self.error_description + "\n" if self.error_description else ""
        return header + "\n".join(f" Index {x}: {y}" for x, y in self.errors)

    @field_validator("errors", "inserted_idx", "existing_idx", mode="before")
    @classmethod
    def sort_fields(cls, v):
        """Sort the raw input for the index-bearing fields before validation."""
        ordered = list(v)
        ordered.sort()
        return ordered

    @model_validator(mode="after")
    def check_all_indices(self):
        """
        Check that the three index collections are pairwise disjoint and that,
        together, they cover every integer from 0 up to the largest index present.

        Raises ``ValueError`` when two collections overlap or when an index is missing.
        """
        inserted = set(self.inserted_idx)
        existing = set(self.existing_idx)
        failed = {entry[0] for entry in self.errors}

        # Checked in this order so the first reported overlap is deterministic
        pairs = (
            ("inserted_idx", inserted, "existing_idx", existing),
            ("inserted_idx", inserted, "error_idx", failed),
            ("existing_idx", existing, "error_idx", failed),
        )
        for left_name, left, right_name, right in pairs:
            overlap = left.intersection(right)
            if overlap:
                raise ValueError(f"{left_name} and {right_name} are not disjoint: intersection={overlap}")

        seen = inserted | existing | failed

        # Nothing recorded at all - there is nothing further to verify
        if not seen:
            return self

        # Every index in 0..max must appear in exactly one of the collections
        highest = max(seen)
        expected = set(range(highest + 1))
        if seen != expected:
            missing = expected - seen
            raise ValueError(f"All indices are not accounted for. Max is {highest} and we are missing {missing}")

        return self

    @staticmethod
    def merge(metadata: Sequence[InsertMetadata]) -> InsertMetadata:
        """
        Combine several metadata objects into a single one.

        The indices of each object are shifted by the total number of entries
        (inserted + existing + errors) of all the objects before it. Error
        descriptions that are not None are joined with newlines, in order.
        """
        all_inserted: list[int] = []
        all_existing: list[int] = []
        all_errors: list[tuple[int, str]] = []
        descriptions: list[str] = []

        offset = 0
        for meta in metadata:
            all_inserted += [idx + offset for idx in meta.inserted_idx]
            all_existing += [idx + offset for idx in meta.existing_idx]
            all_errors += [(idx + offset, msg) for idx, msg in meta.errors]

            if meta.error_description is not None:
                descriptions.append(meta.error_description)

            offset += meta.n_inserted + meta.n_existing + meta.n_errors

        return InsertMetadata(
            inserted_idx=all_inserted,
            existing_idx=all_existing,
            errors=all_errors,
            error_description="\n".join(descriptions) if descriptions else None,
        )
class InsertCountsMetadata(BaseModel):
    """
    Metadata returned by insertion / adding functions, only including counts

    Unlike ``InsertMetadata``, no indices are stored here: ``n_inserted`` and
    ``n_existing`` are plain counts and ``errors`` holds only the error messages.
    """

    n_inserted: int
    n_existing: int
    error_description: str | None = None
    errors: list[str] = Field(default_factory=list)

    @property
    def n_errors(self) -> int:
        """Number of entries in ``errors``."""
        return len(self.errors)

    @property
    def success(self) -> bool:
        """True when there is no overall error description and no error message."""
        return self.error_description is None and not self.errors

    @property
    def error_string(self) -> str:
        """
        The error description (if any) followed by one line per error message.

        ``errors`` contains bare strings, not ``(index, message)`` pairs, so each
        message is emitted as-is; unpacking it into two names would raise a
        ``ValueError`` for any message that is not exactly two characters long.
        """
        header = self.error_description + "\n" if self.error_description else ""
        return header + "\n".join(f" {msg}" for msg in self.errors)

    @staticmethod
    def from_insert_metadata(insert_meta: InsertMetadata) -> InsertCountsMetadata:
        """Build the counts-only form of ``insert_meta``, dropping the index of each error."""
        return InsertCountsMetadata(
            n_inserted=insert_meta.n_inserted,
            n_existing=insert_meta.n_existing,
            error_description=insert_meta.error_description,
            errors=[msg for _, msg in insert_meta.errors],
        )
class DeleteMetadata(BaseModel):
    """
    Metadata returned by delete functions

    Every integer stored in ``errors`` and ``deleted_idx`` is an index into the
    input/output list of the call that produced this object.
    """

    error_description: str | None = None
    errors: list[tuple[int, str]] = Field(default_factory=list)
    deleted_idx: list[int] = Field(default_factory=list)
    n_children_deleted: int = 0

    @property
    def n_deleted(self) -> int:
        """Number of entries in ``deleted_idx``."""
        return len(self.deleted_idx)

    @property
    def n_errors(self) -> int:
        """Number of entries in ``errors``."""
        return len(self.errors)

    @property
    def error_idx(self) -> list[int]:
        """The index part of every ``(index, message)`` pair in ``errors``."""
        return [entry[0] for entry in self.errors]

    @property
    def success(self) -> bool:
        """True when there is no overall error description and no per-index error."""
        return self.error_description is None and not self.errors

    @property
    def error_string(self) -> str:
        """The error description (if any) followed by one line per entry in ``errors``."""
        header = self.error_description + "\n" if self.error_description else ""
        return header + "\n".join(f" Index {x}: {y}" for x, y in self.errors)

    @field_validator("errors", "deleted_idx", mode="before")
    @classmethod
    def sort_fields(cls, v):
        """Sort the raw input for the index-bearing fields before validation."""
        ordered = list(v)
        ordered.sort()
        return ordered

    @model_validator(mode="after")
    def check_all_indices(self):
        """
        Check that ``deleted_idx`` and the error indices are disjoint and that,
        together, they cover every integer from 0 up to the largest index present.

        Raises ``ValueError`` when the two overlap or when an index is missing.
        """
        deleted = set(self.deleted_idx)
        failed = {entry[0] for entry in self.errors}

        overlap = deleted.intersection(failed)
        if overlap:
            raise ValueError(f"deleted_idx and error_idx are not disjoint: intersection={overlap}")

        seen = deleted | failed

        # Nothing recorded at all - there is nothing further to verify
        if not seen:
            return self

        # Every index in 0..max must appear in one of the two collections
        highest = max(seen)
        expected = set(range(highest + 1))
        if seen != expected:
            missing = expected - seen
            raise ValueError(f"All indices are not accounted for. Max is {highest} and we are missing {missing}")

        return self
class UpdateMetadata(BaseModel):
    """
    Metadata returned by update functions

    ``updated_idx`` and the first element of each ``errors`` entry are integers;
    the validator below requires them to be disjoint and contiguous from 0.
    """

    error_description: str | None = None
    errors: list[tuple[int, str]] = Field(default_factory=list)
    updated_idx: list[int] = Field(default_factory=list)
    n_children_updated: int = 0

    @property
    def n_updated(self) -> int:
        """Number of entries in ``updated_idx``."""
        return len(self.updated_idx)

    @property
    def n_errors(self) -> int:
        """Number of entries in ``errors``."""
        return len(self.errors)

    @property
    def error_idx(self) -> list[int]:
        """The index part of every ``(index, message)`` pair in ``errors``."""
        return [entry[0] for entry in self.errors]

    @property
    def success(self) -> bool:
        """True when there is no overall error description and no per-index error."""
        return self.error_description is None and not self.errors

    @property
    def error_string(self) -> str:
        """The error description (if any) followed by one line per entry in ``errors``."""
        header = self.error_description + "\n" if self.error_description else ""
        return header + "\n".join(f" Index {x}: {y}" for x, y in self.errors)

    @field_validator("errors", "updated_idx", mode="before")
    @classmethod
    def sort_fields(cls, v):
        """Sort the raw input for the index-bearing fields before validation."""
        ordered = list(v)
        ordered.sort()
        return ordered

    @model_validator(mode="after")
    def check_all_indices(self):
        """
        Check that ``updated_idx`` and the error indices are disjoint and that,
        together, they cover every integer from 0 up to the largest index present.

        Raises ``ValueError`` when the two overlap or when an index is missing.
        """
        updated = set(self.updated_idx)
        failed = {entry[0] for entry in self.errors}

        overlap = updated.intersection(failed)
        if overlap:
            raise ValueError(f"updated_idx and error_idx are not disjoint: intersection={overlap}")

        seen = updated | failed

        # Nothing recorded at all - there is nothing further to verify
        if not seen:
            return self

        # Every index in 0..max must appear in one of the two collections
        highest = max(seen)
        expected = set(range(highest + 1))
        if seen != expected:
            missing = expected - seen
            raise ValueError(f"All indices are not accounted for. Max is {highest} and we are missing {missing}")

        return self
class TaskReturnMetadata(BaseModel):
    """
    Metadata returned to managers that have sent completed tasks back to the server

    The integers in ``rejected_info`` and ``accepted_ids`` are task ids.
    """

    error_description: str | None = None
    rejected_info: list[tuple[int, str]] = Field(default_factory=list)
    accepted_ids: list[int] = Field(default_factory=list)  # Accepted by the server

    @property
    def n_accepted(self) -> int:
        """Number of entries in ``accepted_ids``."""
        return len(self.accepted_ids)

    @property
    def n_rejected(self) -> int:
        """Number of entries in ``rejected_info``."""
        return len(self.rejected_info)

    @property
    def rejected_ids(self) -> list[int]:
        """The task id part of every ``(task id, reason)`` pair in ``rejected_info``."""
        return [entry[0] for entry in self.rejected_info]

    @property
    def success(self) -> bool:
        """True when ``error_description`` is None; ``rejected_info`` is not consulted."""
        return self.error_description is None

    @property
    def error_string(self) -> str:
        """The error description (if any) followed by one line per rejected task."""
        header = self.error_description + "\n" if self.error_description else ""
        return header + "\n".join(f" Task id {x}: {y}" for x, y in self.rejected_info)

    @field_validator("rejected_info", "accepted_ids", mode="before")
    @classmethod
    def sort_fields(cls, v):
        """Sort the raw input for the id-bearing fields before validation."""
        ordered = list(v)
        ordered.sort()
        return ordered