Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic v2 Overhaul [DNM yet] #321

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4777b4b
Start migrating to Pydantic v2 still very much a WIP
Lnaden Aug 1, 2023
538a62f
Finish initial pass converting to v2 pydantic API. Have not done any …
Lnaden Aug 3, 2023
5bd215b
More headway, starting on tests. have to be careful with serializers …
Lnaden Aug 4, 2023
76b565b
Why does nested serialization behavior change between versions? WHY!?
Lnaden Aug 4, 2023
248676d
Finally have all tests passing. Huzzah! Still need to cleanup
Lnaden Aug 9, 2023
a2e3501
Handled CI and Lint. All tests pass locally.
Lnaden Aug 9, 2023
86d88f9
This is the dumbest dependency tree resolution. Autodoc-pydantic requ…
Lnaden Aug 10, 2023
e7f4549
fix black
Lnaden Aug 10, 2023
4b4c977
and apparently isort
Lnaden Aug 10, 2023
541beb0
So apparently I made a fun problem where black and isort undid each o…
Lnaden Aug 10, 2023
e32570b
Types from typing before Python 3.10 didn't have a `__name__` attribu…
Lnaden Aug 10, 2023
6201290
A found an even more frustrating bug with Numpy NDArray from numpy.ty…
Lnaden Aug 11, 2023
ad7d43b
One last doc bug and black
Lnaden Aug 11, 2023
a4d0ed2
Fixed a serialization problem in Datum because you cannot chain seria…
Lnaden Aug 15, 2023
37fcc09
Fixed serialization of complex ndarrays in Datum
Lnaden Aug 15, 2023
31f0b7c
Dont cast listified complex or ndarray to string on jsonify of data
Lnaden Aug 15, 2023
ad3b633
black
Lnaden Aug 15, 2023
6af962b
Fixed NumPy native types being unknown on how to serialize by Pydantic
Lnaden Aug 17, 2023
01f19ae
Debugging more of the JSON serializing
Lnaden Aug 30, 2023
f2b5989
Fixed typo in Array serialization to JSON where flatten was not calle…
Lnaden Aug 30, 2023
2ca4b14
Removed leftover debugging lines
Lnaden Aug 30, 2023
75e6aee
For serializing models, just try dumping them with model_dump_json wh…
Lnaden Aug 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Finish initial pass converting to v2 pydantic API. Have not done any …
…testing so very much a WIP still.
  • Loading branch information
Lnaden committed Aug 3, 2023
commit 538a62f972effcbe898506c4770cf2f7e2ee6b49
33 changes: 18 additions & 15 deletions qcelemental/models/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import numpy as np

from pydantic import Field, constr, field_validator
from pydantic.v1 import Field as FE

# molparse imports separated b/c https://github.com/python/mypy/issues/7203
from ..molparse.from_arrays import from_arrays
Expand Down Expand Up @@ -401,34 +400,38 @@ def __init__(self, orient: bool = False, validate: Optional[bool] = None, **kwar
elif validate or geometry_prep:
values["geometry"] = float_prep(values["geometry"], geometry_noise)

@validator("geometry")
def _must_be_3n(cls, v, values, **kwargs):
n = len(values["symbols"])
@field_validator("geometry")
@classmethod
def _must_be_3n(cls, v, info):
n = len(info.data["symbols"])
try:
v = v.reshape(n, 3)
except (ValueError, AttributeError):
raise ValueError("Geometry must be castable to shape (N,3)!")
return v

@validator("masses_", "real_")
def _must_be_n(cls, v, values, **kwargs):
n = len(values["symbols"])
@field_validator("masses_", "real_")
@classmethod
def _must_be_n(cls, v, info):
n = len(info.data["symbols"])
if len(v) != n:
raise ValueError("Masses and Real must be same number of entries as Symbols")
return v

@validator("real_")
def _populate_real(cls, v, values, **kwargs):
@field_validator("real_")
@classmethod
def _populate_real(cls, v, info):
# Can't use geometry here since its already been validated and not in values
n = len(values["symbols"])
n = len(info.data["symbols"])
if len(v) == 0:
v = np.array([True for _ in range(n)])
return v

@validator("fragment_charges_", "fragment_multiplicities_")
def _must_be_n_frag(cls, v, values, **kwargs):
if "fragments_" in values and values["fragments_"] is not None:
n = len(values["fragments_"])
@field_validator("fragment_charges_", "fragment_multiplicities_")
@classmethod
def _must_be_n_frag(cls, v, info):
if "fragments_" in info.data and info.data["fragments_"] is not None:
n = len(info.data["fragments_"])
if len(v) != n:
raise ValueError(
"Fragment Charges and Fragment Multiplicities must be same number of entries as Fragments"
Expand Down Expand Up @@ -596,7 +599,7 @@ def __eq__(self, other):
def dict(self, *args, **kwargs):
kwargs["by_alias"] = True
kwargs["exclude_unset"] = True
return super().dict(*args, **kwargs)
return super().model_dump(*args, **kwargs)

def pretty_print(self):
r"""Print the molecule in Angstroms. Same as :py:func:`print_out` only always in Angstroms.
Expand Down
45 changes: 24 additions & 21 deletions qcelemental/models/procedures.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

try:
from pydantic.v1 import Field, conlist, constr, validator
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import Field, conlist, constr, validator
from pydantic import Field, conlist, constr, field_validator

from ..util import provenance_stamp
from .basemodels import ProtoModel
Expand All @@ -23,10 +20,7 @@
from .results import AtomicResult

if TYPE_CHECKING:
try:
from pydantic.v1.typing import ReprArgs
except ImportError: # Will also trap ModuleNotFoundError
from pydantic.typing import ReprArgs
from .common_models import ReprArgs


class TrajectoryProtocolEnum(str, Enum):
Expand Down Expand Up @@ -58,7 +52,7 @@ class QCInputSpecification(ProtoModel):
A compute description for energy, gradient, and Hessian computations used in a geometry optimization.
"""

schema_name: constr(strip_whitespace=True, regex=qcschema_input_default) = qcschema_input_default # type: ignore
schema_name: constr(strip_whitespace=True, pattern=qcschema_input_default) = qcschema_input_default # type: ignore
schema_version: int = 1

driver: DriverEnum = Field(DriverEnum.gradient, description=str(DriverEnum.__doc__))
Expand All @@ -75,7 +69,7 @@ class OptimizationInput(ProtoModel):
id: Optional[str] = None
hash_index: Optional[str] = None
schema_name: constr( # type: ignore
strip_whitespace=True, regex=qcschema_optimization_input_default
strip_whitespace=True, pattern=qcschema_optimization_input_default
) = qcschema_optimization_input_default
schema_version: int = 1

Expand All @@ -97,7 +91,7 @@ def __repr_args__(self) -> "ReprArgs":

class OptimizationResult(OptimizationInput):
schema_name: constr( # type: ignore
strip_whitespace=True, regex=qcschema_optimization_output_default
strip_whitespace=True, pattern=qcschema_optimization_output_default
) = qcschema_optimization_output_default

final_molecule: Optional[Molecule] = Field(..., description="The final molecule of the geometry optimization.")
Expand All @@ -115,13 +109,14 @@ class OptimizationResult(OptimizationInput):
error: Optional[ComputeError] = Field(None, description=str(ComputeError.__doc__))
provenance: Provenance = Field(..., description=str(Provenance.__doc__))

@validator("trajectory", each_item=False)
def _trajectory_protocol(cls, v, values):
@field_validator("trajectory")
@classmethod
def _trajectory_protocol(cls, v, info):
# Do not propogate validation errors
if "protocols" not in values:
if "protocols" not in info.data:
raise ValueError("Protocols was not properly formed.")

keep_enum = values["protocols"].trajectory
keep_enum = info.data["protocols"].trajectory
if keep_enum == "all":
pass
elif keep_enum == "initial_and_final":
Expand All @@ -148,14 +143,17 @@ class OptimizationSpecification(ProtoModel):
* This class is still provisional and may be subject to removal and re-design.
"""

schema_name: constr(strip_whitespace=True, regex="qcschema_optimization_specification") = "qcschema_optimization_specification" # type: ignore
schema_name: constr(strip_whitespace=True,
pattern="qcschema_optimization_specification"
) = "qcschema_optimization_specification" # type: ignore
schema_version: int = 1

procedure: str = Field(..., description="Optimization procedure to run the optimization with.")
keywords: Dict[str, Any] = Field({}, description="The optimization specific keywords to be used.")
protocols: OptimizationProtocols = Field(OptimizationProtocols(), description=str(OptimizationProtocols.__doc__))

@validator("procedure")
@field_validator("procedure")
@classmethod
def _check_procedure(cls, v):
return v.lower()

Expand Down Expand Up @@ -205,14 +203,16 @@ class TorsionDriveInput(ProtoModel):
* This class is still provisional and may be subject to removal and re-design.
"""

schema_name: constr(strip_whitespace=True, regex=qcschema_torsion_drive_input_default) = qcschema_torsion_drive_input_default # type: ignore
schema_name: constr(strip_whitespace=True,
pattern=qcschema_torsion_drive_input_default
) = qcschema_torsion_drive_input_default # type: ignore
schema_version: int = 1

keywords: TDKeywords = Field(..., description="The torsion drive specific keywords to be used.")
extras: Dict[str, Any] = Field({}, description="Extra fields that are not part of the schema.")

input_specification: QCInputSpecification = Field(..., description=str(QCInputSpecification.__doc__))
initial_molecule: conlist(item_type=Molecule, min_items=1) = Field(
initial_molecule: conlist(item_type=Molecule, min_length=1) = Field(
..., description="The starting molecule(s) for the torsion drive."
)

Expand All @@ -222,7 +222,8 @@ class TorsionDriveInput(ProtoModel):

provenance: Provenance = Field(Provenance(**provenance_stamp(__name__)), description=str(Provenance.__doc__))

@validator("input_specification")
@field_validator("input_specification")
@classmethod
def _check_input_specification(cls, value):
assert value.driver == DriverEnum.gradient, "driver must be set to gradient"
return value
Expand All @@ -236,7 +237,9 @@ class TorsionDriveResult(TorsionDriveInput):
* This class is still provisional and may be subject to removal and re-design.
"""

schema_name: constr(strip_whitespace=True, regex=qcschema_torsion_drive_output_default) = qcschema_torsion_drive_output_default # type: ignore
schema_name: constr(strip_whitespace=True,
pattern=qcschema_torsion_drive_output_default
) = qcschema_torsion_drive_output_default # type: ignore
schema_version: int = 1

final_energies: Dict[str, float] = Field(
Expand Down
Loading