Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic v2 Overhaul [DNM yet] #321

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4777b4b
Start migrating to Pydantic v2 still very much a WIP
Lnaden Aug 1, 2023
538a62f
Finish initial pass converting to v2 pydantic API. Have not done any …
Lnaden Aug 3, 2023
5bd215b
More headway, starting on tests. have to be careful with serializers …
Lnaden Aug 4, 2023
76b565b
Why does nested serialization behavior change between versions? WHY!?
Lnaden Aug 4, 2023
248676d
Finally have all tests passing. Huzzah! Still need to cleanup
Lnaden Aug 9, 2023
a2e3501
Handled CI and Lint. All tests pass locally.
Lnaden Aug 9, 2023
86d88f9
This is the dumbest dependency tree resolution. Autodoc-pydantic requ…
Lnaden Aug 10, 2023
e7f4549
fix black
Lnaden Aug 10, 2023
4b4c977
and apparently isort
Lnaden Aug 10, 2023
541beb0
So apparently I made a fun problem where black and isort undid each o…
Lnaden Aug 10, 2023
e32570b
Types from typing before Python 3.10 didn't have a `__name__` attribu…
Lnaden Aug 10, 2023
6201290
Found an even more frustrating bug with NumPy NDArray from numpy.ty…
Lnaden Aug 11, 2023
ad7d43b
One last doc bug and black
Lnaden Aug 11, 2023
a4d0ed2
Fixed a serialization problem in Datum because you cannot chain seria…
Lnaden Aug 15, 2023
37fcc09
Fixed serialization of complex ndarrays in Datum
Lnaden Aug 15, 2023
31f0b7c
Don't cast listified complex or ndarray to string on jsonify of data
Lnaden Aug 15, 2023
ad3b633
black
Lnaden Aug 15, 2023
6af962b
Fixed NumPy native types being unknown on how to serialize by Pydantic
Lnaden Aug 17, 2023
01f19ae
Debugging more of the JSON serializing
Lnaden Aug 30, 2023
f2b5989
Fixed typo in Array serialization to JSON where flatten was not calle…
Lnaden Aug 30, 2023
2ca4b14
Removed leftover debugging lines
Lnaden Aug 30, 2023
75e6aee
For serializing models, just try dumping them with model_dump_json wh…
Lnaden Aug 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
More headway, starting on tests. have to be careful with serializers …
…and only use json on numpy arrays for now.
  • Loading branch information
Lnaden committed Aug 4, 2023
commit 5bd215b293975d36cb621ded7c27624df3c5494b
31 changes: 24 additions & 7 deletions qcelemental/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,28 @@

from decimal import Decimal
from typing import Any, Dict, Optional
from typing_extensions import Annotated

import numpy as np

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, ConfigDict, WrapSerializer, SerializerFunctionWrapHandler


def cast_ndarray(v: Any, nxt: SerializerFunctionWrapHandler) -> str:
    """Wrap-serializer helper that turns NumPy arrays into flat lists.

    Parameters
    ----------
    v
        The value being serialized.
    nxt
        The next serialization handler; it is called exactly once.

    Returns
    -------
    str
        The handler's result formatted as a string. ``np.ndarray`` inputs are
        flattened and converted with ``tolist()`` before being handed to
        ``nxt``; any other value is passed to ``nxt`` unchanged.
    """
    # NOTE(review): the f-string stringifies whatever the handler returns, so an
    # array comes out as the text of a Python list -- confirm this is intended.
    payload = v.flatten().tolist() if isinstance(v, np.ndarray) else v
    return f"{nxt(payload)}"


def cast_complex(v: Any, nxt: SerializerFunctionWrapHandler) -> str:
    """Wrap-serializer helper that splits complex numbers into ``(real, imag)``.

    Parameters
    ----------
    v
        The value being serialized.
    nxt
        The next serialization handler; it is called exactly once.

    Returns
    -------
    str
        The handler's result formatted as a string. ``complex`` inputs are
        replaced by the tuple ``(v.real, v.imag)`` before being handed to
        ``nxt``; any other value is passed to ``nxt`` unchanged.
    """
    # NOTE(review): the f-string stringifies whatever the handler returns, so a
    # complex value comes out as the text of a tuple -- confirm this is intended.
    payload = (v.real, v.imag) if isinstance(v, complex) else v
    return f"{nxt(payload)}"


# ``Any`` annotated with two wrap serializers: ``cast_ndarray`` (flattens and
# listifies ``np.ndarray`` values) and ``cast_complex`` (splits ``complex``
# values into ``(real, imag)``).
# NOTE(review): verify that Pydantic applies both serializers when they are
# stacked in one ``Annotated`` -- TODO confirm against the Pydantic docs.
AnyArrayComplex = Annotated[Any, WrapSerializer(cast_ndarray), WrapSerializer(cast_complex)]


class Datum(BaseModel):
Expand Down Expand Up @@ -35,15 +53,14 @@ class Datum(BaseModel):
numeric: bool
label: str
units: str
data: Any
data: AnyArrayComplex
comment: str = ""
doi: Optional[str] = None
glossary: str = ""

class Config:
extra = "forbid"
allow_mutation = False
json_encoders = {np.ndarray: lambda v: v.flatten().tolist(), complex: lambda v: (v.real, v.imag)}
model_config = ConfigDict(extra="forbid",
frozen=True,
)

def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None, numeric=True):
kwargs = {"label": label, "units": units, "data": data, "numeric": numeric}
Expand Down Expand Up @@ -89,7 +106,7 @@ def __str__(self, label=""):
return "\n".join(text)

def dict(self, *args, **kwargs):
return super().dict(*args, **{**kwargs, **{"exclude_unset": True}})
return super().model_dump(*args, **{**kwargs, **{"exclude_unset": True}})

def to_units(self, units=None):
from .physical_constants import constants
Expand Down
12 changes: 10 additions & 2 deletions qcelemental/info/cpu_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from enum import Enum
from functools import lru_cache
from typing import List, Optional
from typing_extensions import Annotated

from pydantic import Field
from pydantic import Field, BeforeValidator

from ..models import ProtoModel

Expand All @@ -22,6 +23,13 @@ class VendorEnum(str, Enum):
arm = "arm"


def stringify(v) -> str:
    """Return ``str(v)`` for any input value."""
    text = str(v)
    return text


# ``str`` field type whose raw input is first passed through ``stringify``
# (i.e. ``str(value)``) by a ``BeforeValidator``.
Stringify = Annotated[str, BeforeValidator(stringify)]


class InstructionSetEnum(int, Enum):
"""Allowed instruction sets for CPUs in an ordinal enum."""

Expand All @@ -38,7 +46,7 @@ class ProcessorInfo(ProtoModel):
nthreads: Optional[int] = Field(..., description="The maximum number of concurrent threads.")
base_clock: float = Field(..., description="The base clock frequency (GHz).")
boost_clock: Optional[float] = Field(None, description="The boost clock frequency (GHz).")
model: str = Field(..., description="The model number of the chip.")
model: Stringify = Field(..., description="The model number of the chip.")
family: str = Field(..., description="The family of the chip.")
launch_date: Optional[int] = Field(None, description="The launch year of the chip.")
target_use: str = Field(..., description="Target use case (Desktop, Server, etc).")
Expand Down
5 changes: 2 additions & 3 deletions qcelemental/models/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import Field, field_validator

from ..util import blockwise_contract, blockwise_expand
from .basemodels import ProtoModel
from .basemodels import ProtoModel, ExtendedConfigDict
from .types import Array

__all__ = ["AlignmentMill"]
Expand All @@ -27,8 +27,7 @@ class AlignmentMill(ProtoModel):
atommap: Optional[Array[int]] = Field(None, description="Atom exchange map (nat,) for coordinates.") # type: ignore
mirror: bool = Field(False, description="Do mirror invert coordinates?")

class Config:
force_skip_defaults = True
model_config = ExtendedConfigDict(force_skip_defaults=True)

@field_validator("shift")
@classmethod
Expand Down
6 changes: 3 additions & 3 deletions qcelemental/models/basemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def model_dump(self, **kwargs) -> Dict[str, Any]:

kwargs["exclude"] = (
kwargs.get("exclude", None) or set()
) | self.__config__.serialize_default_excludes # type: ignore
kwargs.setdefault("exclude_unset", self.__config__.serialize_skip_defaults) # type: ignore
if self.__config__.force_skip_defaults: # type: ignore
) | self.model_config["serialize_default_excludes"] # type: ignore
kwargs.setdefault("exclude_unset", self.model_config["serialize_skip_defaults"]) # type: ignore
if self.model_config["force_skip_defaults"]: # type: ignore
kwargs["exclude_unset"] = True

data = super().model_dump(**kwargs)
Expand Down
20 changes: 12 additions & 8 deletions qcelemental/models/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ElectronShell(ProtoModel):
...,
description="General contraction coefficients for the shell; "
"individual list components will be the individual segment contraction coefficients.",
min_items=1,
min_length=1,
)

model_config = ExtendedConfigDict(json_schema_extra=electron_shell_json_schema_extra,
Expand Down Expand Up @@ -111,15 +111,15 @@ class ECPPotential(ProtoModel):

ecp_type: ECPType = Field(..., description=str(ECPType.__doc__))
angular_momentum: List[NonnegativeInt] = Field(
..., description="Angular momentum for the potential as an array of integers.", min_items=1
..., description="Angular momentum for the potential as an array of integers.", min_length=1
)
r_exponents: List[int] = Field(..., description="Exponents of the 'r' term.", min_items=1)
gaussian_exponents: List[float] = Field(..., description="Exponents of the 'gaussian' term.", min_items=1)
r_exponents: List[int] = Field(..., description="Exponents of the 'r' term.", min_length=1)
gaussian_exponents: List[float] = Field(..., description="Exponents of the 'gaussian' term.", min_length=1)
coefficients: List[List[float]] = Field(
...,
description="General contraction coefficients for the potential; "
"individual list components will be the individual segment contraction coefficients.",
min_items=1,
min_length=1,
)

model_config = ExtendedConfigDict(json_schema_extra=ecp_json_schema_extra,
Expand Down Expand Up @@ -163,6 +163,10 @@ class BasisCenter(ProtoModel):
**ProtoModel.model_config)


def basis_set_json_schema_extra(schema, model):
    """Set the ``"$schema"`` entry of ``schema`` to ``qcschema_draft``.

    Parameters
    ----------
    schema
        The schema mapping to modify in place.
    model
        Accepted as the second argument but not used here.
    """
    # Mutates the mapping in place; nothing is returned.
    schema["$schema"] = qcschema_draft


class BasisSet(ProtoModel):
"""
A quantum chemistry basis description.
Expand Down Expand Up @@ -191,9 +195,9 @@ class BasisSet(ProtoModel):
description="The number of basis functions. Use for convenience or as checksum",
validate_default=True)

class Config(ProtoModel.Config):
def schema_extra(schema, model):
schema["$schema"] = qcschema_draft
model_config = ExtendedConfigDict(**ProtoModel.model_config,
json_schema_extra=basis_set_json_schema_extra
)

@field_validator("atom_map")
@classmethod
Expand Down
18 changes: 11 additions & 7 deletions qcelemental/models/common_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ class Provenance(ProtoModel):
)
routine: str = Field("", description="The name of the routine or function within the creator, blank otherwise.")

model_config = ExtendedConfigDict(canonical_repr=True,
json_schema_extra=provenance_json_schema_extra,
**ProtoModel.model_config,
extra="allow")
model_config = ExtendedConfigDict(**{**ProtoModel.model_config,
**ExtendedConfigDict(canonical_repr=True,
json_schema_extra=provenance_json_schema_extra,
extra="allow")
}
)


class Model(ProtoModel):
Expand All @@ -52,9 +54,11 @@ class Model(ProtoModel):
)

# basis_spec: BasisSpec = None # This should be exclusive with basis, but for now will be omitted
model_config = ExtendedConfigDict(canonical_repr=True,
**ProtoModel.model_config,
extra="allow")
model_config = ExtendedConfigDict(**{**ProtoModel.model_config,
**ExtendedConfigDict(canonical_repr=True,
extra="allow")
}
)


class DriverEnum(str, Enum):
Expand Down
20 changes: 12 additions & 8 deletions qcelemental/models/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ class Identifiers(ProtoModel):
pubchem_sid: Optional[str] = Field(None, description="PubChem Substance ID")
pubchem_conformerid: Optional[str] = Field(None, description="PubChem Conformer ID")

model_config = ExtendedConfigDict(**ProtoModel.model_config,
serialize_skip_defaults=True
model_config = ExtendedConfigDict(**{**ProtoModel.model_config,
**ExtendedConfigDict(serialize_skip_defaults=True)
}
)


Expand Down Expand Up @@ -342,12 +343,15 @@ class Molecule(ProtoModel):
description="Additional information to bundle with the molecule. Use for schema development and scratch space.",
)

model_config = ExtendedConfigDict(**ProtoModel.model_config,
serialize_skip_defaults=True,
repr_style=lambda self: [("name", self.name),
("formula", self.get_molecular_formula()),
("hash", self.get_hash()[:7])],
json_schema_extra= molecule_json_schema_extras
model_config = ExtendedConfigDict(**{**ProtoModel.model_config,
**ExtendedConfigDict(serialize_skip_defaults=True,
repr_style=lambda self: [("name", self.name),
("formula",
self.get_molecular_formula()),
("hash", self.get_hash()[:7])],
json_schema_extra=molecule_json_schema_extras
)
}
)
# Alias fields are handled with the Field objects above

Expand Down
4 changes: 2 additions & 2 deletions qcelemental/models/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
qcschema_optimization_output_default,
qcschema_torsion_drive_input_default,
qcschema_torsion_drive_output_default,
ExtendedConfigDict,
)
from .molecule import Molecule
from .results import AtomicResult
Expand Down Expand Up @@ -43,8 +44,7 @@ class OptimizationProtocols(ProtoModel):
TrajectoryProtocolEnum.all, description=str(TrajectoryProtocolEnum.__doc__)
)

class Config:
force_skip_defaults = True
model_config = ExtendedConfigDict(force_skip_defaults=True)


class QCInputSpecification(ProtoModel):
Expand Down
Loading