Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic v2 Overhaul [DNM yet] #321

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4777b4b
Start migrating to Pydantic v2 still very much a WIP
Lnaden Aug 1, 2023
538a62f
Finish initial pass converting to v2 pydantic API. Have not done any …
Lnaden Aug 3, 2023
5bd215b
More headway, starting on tests. have to be careful with serializers …
Lnaden Aug 4, 2023
76b565b
Why does nested serialization behavior change between versions? WHY!?
Lnaden Aug 4, 2023
248676d
Finally have all tests passing. Huzzah! Still need to cleanup
Lnaden Aug 9, 2023
a2e3501
Handled CI and Lint. All tests pass locally.
Lnaden Aug 9, 2023
86d88f9
This is the dumbest dependency tree resolution. Autodoc-pydantic requ…
Lnaden Aug 10, 2023
e7f4549
fix black
Lnaden Aug 10, 2023
4b4c977
and apparently isort
Lnaden Aug 10, 2023
541beb0
So apparently I made a fun problem where black and isort undid each o…
Lnaden Aug 10, 2023
e32570b
Types from typing before Python 3.10 didn't have a `__name__` attribu…
Lnaden Aug 10, 2023
6201290
A found an even more frustrating bug with Numpy NDArray from numpy.ty…
Lnaden Aug 11, 2023
ad7d43b
One last doc bug and black
Lnaden Aug 11, 2023
a4d0ed2
Fixed a serialization problem in Datum because you cannot chain seria…
Lnaden Aug 15, 2023
37fcc09
Fixed serialization of complex ndarrays in Datum
Lnaden Aug 15, 2023
31f0b7c
Dont cast listified complex or ndarray to string on jsonify of data
Lnaden Aug 15, 2023
ad3b633
black
Lnaden Aug 15, 2023
6af962b
Fixed NumPy native types being unknown on how to serialize by Pydantic
Lnaden Aug 17, 2023
01f19ae
Debugging more of the JSON serializing
Lnaden Aug 30, 2023
f2b5989
Fixed typo in Array serialization to JSON where flatten was not calle…
Lnaden Aug 30, 2023
2ca4b14
Removed leftover debugging lines
Lnaden Aug 30, 2023
75e6aee
For serializing models, just try dumping them with model_dump_json wh…
Lnaden Aug 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Why does nested serialization behavior change between versions? WHY!?
  • Loading branch information
Lnaden committed Aug 4, 2023
commit 76b565b38eeb399a99451352f22217d258b72296
24 changes: 21 additions & 3 deletions qcelemental/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from decimal import Decimal
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from typing_extensions import Annotated

import numpy as np
Expand All @@ -22,10 +22,28 @@ def cast_complex(v: Any, nxt: SerializerFunctionWrapHandler) -> str:
"""Special helper to serialize NumPy arrays before serializing"""
if isinstance(v, complex):
return f'{nxt((v.real, v.imag))}'
return f'{nxt(v)}'
return nxt(v)


def preserve_decimal(v: Any, nxt: SerializerFunctionWrapHandler) -> Union[str, Decimal]:
"""
Ensure Decimal types are preserved on the way out

This arose because Decimal was serialized to string and "dump" is equal to "serialize" in v2 pydantic
https://docs.pydantic.dev/latest/migration/#changes-to-json-schema-generation
"""
if isinstance(v, Decimal):
return v
return nxt(v)


AnyArrayComplex = Annotated[Any, WrapSerializer(cast_ndarray), WrapSerializer(cast_complex)]
# Serializers are pop'd out of the list in FILO (right to left)
AnyArrayComplex = Annotated[
Any,
WrapSerializer(cast_ndarray, when_used="json"),
WrapSerializer(cast_complex, when_used="json"),
WrapSerializer(preserve_decimal)
]


class Datum(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion qcelemental/info/dft_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ def get(name: str) -> DFTFunctionalInfo:
name = name.replace(x, "")
break

return dftfunctionalinfo.functionals[name].copy()
return dftfunctionalinfo.functionals[name].model_copy()
1 change: 1 addition & 0 deletions qcelemental/models/basemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ProtoModel(BaseModel):
model_config = ExtendedConfigDict(
frozen=True,
extra="forbid",
populate_by_name=True, # Allows using alias to populate
serialize_default_excludes=set(),
serialize_skip_defaults=False,
force_skip_defaults=False
Expand Down
4 changes: 2 additions & 2 deletions qcelemental/models/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ def _check_atom_map(cls, v, info):

@field_validator("nbf")
@classmethod
def _check_nbf(cls, v, values):
def _check_nbf(cls, v, info):
# Bad construction, pass on errors
try:
nbf = cls._calculate_nbf(values["atom_map"], values["center_data"])
nbf = cls._calculate_nbf(info.data["atom_map"], info.data["center_data"])
except KeyError:
return v

Expand Down
16 changes: 10 additions & 6 deletions qcelemental/models/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def orient_molecule(self):
r"""
Centers the molecule and orients via inertia tensor before returning a new Molecule
"""
return Molecule(orient=True, **self.dict())
return Molecule(orient=True, **self.model_dump())

def compare(self, other):
warnings.warn(
Expand All @@ -600,10 +600,14 @@ def __eq__(self, other):

return self.get_hash() == other.get_hash()

def dict(self, *args, **kwargs):
def dict(self, **kwargs):
warnings.warn('The `dict` method is deprecated; use `model_dump` instead.', DeprecationWarning)
return self.model_dump(**kwargs)

def model_dump(self, **kwargs) -> Dict[str, Any]:
kwargs["by_alias"] = True
kwargs["exclude_unset"] = True
return super().model_dump(*args, **kwargs)
return super().model_dump(**kwargs)

def pretty_print(self):
r"""Print the molecule in Angstroms. Same as :py:func:`print_out` only always in Angstroms.
Expand Down Expand Up @@ -790,7 +794,7 @@ def to_string( # type: ignore

Suggest psi4 --> psi4frag and psi4 route to to_string
"""
molrec = from_schema(self.dict(), nonphysical=True)
molrec = from_schema(self.model_dump(), nonphysical=True)
return to_string(
molrec,
dtype=dtype,
Expand Down Expand Up @@ -1291,7 +1295,7 @@ def align(
"atomic_numbers": solution.align_atoms(concern_mol.atomic_numbers),
"mass_numbers": solution.align_atoms(concern_mol.mass_numbers),
}
adict = {**concern_mol.dict(), **aupdate}
adict = {**concern_mol.model_dump(), **aupdate}

# preserve intrinsic symmetry with lighter truncation
amol = Molecule(validate=True, **adict, geometry_noise=13)
Expand Down Expand Up @@ -1415,7 +1419,7 @@ def scramble(
"atomic_numbers": perturbation.align_atoms(ref_mol.atomic_numbers),
"mass_numbers": perturbation.align_atoms(ref_mol.mass_numbers),
}
cdict = {**ref_mol.dict(), **cupdate}
cdict = {**ref_mol.model_dump(), **cupdate}

# preserve intrinsic symmetry with lighter truncation
cmol = Molecule(validate=True, **cdict, geometry_noise=13)
Expand Down
2 changes: 1 addition & 1 deletion qcelemental/models/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class OptimizationInput(ProtoModel):

def __repr_args__(self) -> "ReprArgs":
return [
("model", self.input_specification.model.dict()),
("model", self.input_specification.model.model_dump()),
("molecule_hash", self.initial_molecule.get_hash()[:7]),
]

Expand Down
4 changes: 2 additions & 2 deletions qcelemental/models/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,8 @@ def _assert2d_nao_x(cls, v, info):
"scf_fock_b",
)
@classmethod
def _assert2d(cls, v, values):
bas = values.get("basis", None)
def _assert2d(cls, v, info):
bas = info.data.get("basis", None)

# Do not raise multiple errors
if bas is None:
Expand Down
8 changes: 4 additions & 4 deletions qcelemental/molutil/test_molutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def test_relative_geoms_align_free(request):
do_shift=True, do_rotate=True, do_resort=False, do_plot=False, verbose=2, do_test=True
)

rmolrec = qcel.molparse.from_schema(s22_12.dict())
cmolrec = qcel.molparse.from_schema(cmol.dict())
rmolrec = qcel.molparse.from_schema(s22_12.model_dump())
cmolrec = qcel.molparse.from_schema(cmol.model_dump())
assert compare_molrecs(rmolrec, cmolrec, atol=1.0e-4, relative_geoms="align")


Expand All @@ -68,8 +68,8 @@ def test_relative_geoms_align_fixed(request):
do_shift=False, do_rotate=False, do_resort=False, do_plot=False, verbose=2, do_test=True
)

rmolrec = qcel.molparse.from_schema(s22_12.dict())
cmolrec = qcel.molparse.from_schema(cmol.dict())
rmolrec = qcel.molparse.from_schema(s22_12.model_dump())
cmolrec = qcel.molparse.from_schema(cmol.model_dump())
assert compare_molrecs(rmolrec, cmolrec, atol=1.0e-4, relative_geoms="align")


Expand Down
4 changes: 2 additions & 2 deletions qcelemental/tests/test_model_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def test_result_derivatives_array(request):
def test_model_dictable(result_data_fixture, optimization_data_fixture, smodel):
if smodel == "molecule":
model = qcel.models.Molecule
data = result_data_fixture["molecule"].dict()
data = result_data_fixture["molecule"].model_dump()

elif smodel == "atomicresultproperties":
model = qcel.models.AtomicResultProperties
Expand All @@ -514,7 +514,7 @@ def test_model_dictable(result_data_fixture, optimization_data_fixture, smodel):
data = optimization_data_fixture

instance = model(**data)
assert model(**instance.dict())
assert model(**instance.model_dump())


def test_result_model_deprecations(result_data_fixture, optimization_data_fixture):
Expand Down
38 changes: 19 additions & 19 deletions qcelemental/tests/test_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@


def test_molecule_data_constructor_numpy():
water_psi = water_dimer_minima.copy()
water_psi = water_dimer_minima.model_copy()
ele = np.array(water_psi.atomic_numbers).reshape(-1, 1)
npwater = np.hstack((ele, water_psi.geometry * qcel.constants.conversion_factor("Bohr", "angstrom")))

Expand All @@ -53,10 +53,10 @@ def test_molecule_data_constructor_numpy():


def test_molecule_data_constructor_dict():
water_psi = water_dimer_minima.copy()
water_psi = water_dimer_minima.model_copy()

# Check the JSON construct/deconstruct
water_from_json = Molecule.from_data(water_psi.dict())
water_from_json = Molecule.from_data(water_psi.model_dump())
assert water_psi == water_from_json

water_from_json = Molecule.from_data(water_psi.json(), "json")
Expand Down Expand Up @@ -134,16 +134,16 @@ def test_molecule_np_constructors():


def test_molecule_compare():
water_molecule2 = water_molecule.copy()
water_molecule2 = water_molecule.model_copy()
assert water_molecule2 == water_molecule

water_molecule3 = water_molecule.copy(update={"geometry": (water_molecule.geometry + np.array([0.1, 0, 0]))})
water_molecule3 = water_molecule.model_copy(update={"geometry": (water_molecule.geometry + np.array([0.1, 0, 0]))})
assert water_molecule != water_molecule3


def test_water_minima_data():
# Give it a name
mol_dict = water_dimer_minima.dict()
mol_dict = water_dimer_minima.model_dump()
mol_dict["name"] = "water dimer"
mol = Molecule(orient=True, **mol_dict)

Expand Down Expand Up @@ -174,7 +174,7 @@ def test_water_minima_data():


def test_water_minima_fragment():
mol = water_dimer_minima.copy()
mol = water_dimer_minima.model_copy()
frag_0 = mol.get_fragment(0, orient=True)
frag_1 = mol.get_fragment(1, orient=True)
assert frag_0.get_hash() == "5f31757232a9a594c46073082534ca8a6806d367" # pragma: allowlist secret
Expand All @@ -194,12 +194,12 @@ def test_water_minima_fragment():


def test_pretty_print():
mol = water_dimer_minima.copy()
mol = water_dimer_minima.model_copy()
assert isinstance(mol.pretty_print(), str)


def test_to_string():
mol = water_dimer_minima.copy()
mol = water_dimer_minima.model_copy()
assert isinstance(mol.to_string("psi4"), str)


Expand Down Expand Up @@ -365,21 +365,21 @@ def test_water_orient():


def test_molecule_errors_extra():
data = water_dimer_minima.dict(exclude_unset=True)
data = water_dimer_minima.model_dump(exclude_unset=True)
data["whatever"] = 5
with pytest.raises(Exception):
Molecule(**data, validate=False)


def test_molecule_errors_connectivity():
data = water_molecule.dict()
data = water_molecule.model_dump()
data["connectivity"] = [(-1, 5, 5)]
with pytest.raises(Exception):
Molecule(**data)


def test_molecule_errors_shape():
data = water_molecule.dict()
data = water_molecule.model_dump()
data["geometry"] = list(range(8))
with pytest.raises(Exception):
Molecule(**data)
Expand All @@ -388,7 +388,7 @@ def test_molecule_errors_shape():
def test_molecule_json_serialization():
assert isinstance(water_dimer_minima.json(), str)

assert isinstance(water_dimer_minima.dict(encoding="json")["geometry"], list)
assert isinstance(water_dimer_minima.model_dump(encoding="json")["geometry"], list)

assert water_dimer_minima == Molecule.from_data(water_dimer_minima.model_dump_json(), dtype="json")

Expand Down Expand Up @@ -521,10 +521,10 @@ def test_molecule_repeated_hashing():
h1 = mol.get_hash()
assert mol.get_molecular_formula() == "H2O2"

mol2 = Molecule(orient=False, **mol.dict())
mol2 = Molecule(orient=False, **mol.model_dump())
assert h1 == mol2.get_hash()

mol3 = Molecule(orient=False, **mol2.dict())
mol3 = Molecule(orient=False, **mol2.model_dump())
assert h1 == mol3.get_hash()


Expand Down Expand Up @@ -694,7 +694,7 @@ def test_sparse_molecule_fields(mol_string, extra_keys):
if extra_keys is not None:
expected_keys |= extra_keys

diff_keys = mol.dict().keys() ^ expected_keys
diff_keys = mol.model_dump().keys() ^ expected_keys
assert len(diff_keys) == 0, f"Diff Keys {diff_keys}"


Expand All @@ -703,11 +703,11 @@ def test_sparse_molecule_connectivity():
A bit of a weird test, but because we set connectivity it should carry through.
"""
mol = Molecule(symbols=["He", "He"], geometry=[0, 0, -2, 0, 0, 2], connectivity=None)
assert "connectivity" in mol.dict()
assert mol.dict()["connectivity"] is None
assert "connectivity" in mol.model_dump()
assert mol.model_dump()["connectivity"] is None

mol = Molecule(symbols=["He", "He"], geometry=[0, 0, -2, 0, 0, 2])
assert "connectivity" not in mol.dict()
assert "connectivity" not in mol.model_dump()


def test_bad_isotope_spec():
Expand Down
6 changes: 3 additions & 3 deletions qcelemental/tests/test_molparse_from_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_psi4_qm_1a():
assert compare_molrecs(fullans, final["qm"], tnm() + ": full")

kmol = Molecule.from_data(subject)
_check_eq_molrec_minimal_model([], kmol.dict(), fullans)
_check_eq_molrec_minimal_model([], kmol.model_dump(), fullans)


def test_psi4_qm_1ab():
Expand Down Expand Up @@ -150,7 +150,7 @@ def test_psi4_qm_1c():
assert compare_molrecs(fullans, final["qm"], tnm() + ": full")

kmol = Molecule.from_data(subject)
_check_eq_molrec_minimal_model([], kmol.dict(), fullans)
_check_eq_molrec_minimal_model([], kmol.model_dump(), fullans)


def test_psi4_qm_1d():
Expand Down Expand Up @@ -346,7 +346,7 @@ def test_psi4_qm_2a():
kmol = Molecule.from_data(subject)
_check_eq_molrec_minimal_model(
["fragments", "fragment_charges", "fragment_multiplicities", "mass_numbers", "masses", "atom_labels", "real"],
kmol.dict(),
kmol.model_dump(),
fullans,
)

Expand Down
8 changes: 4 additions & 4 deletions qcelemental/tests/test_molutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def test_relative_geoms_align_free(request):
do_shift=True, do_rotate=True, do_resort=False, do_plot=False, verbose=2, do_test=True
)

rmolrec = qcel.molparse.from_schema(s22_12.dict())
cmolrec = qcel.molparse.from_schema(cmol.dict())
rmolrec = qcel.molparse.from_schema(s22_12.model_dump())
cmolrec = qcel.molparse.from_schema(cmol.model_dump())
assert compare_molrecs(rmolrec, cmolrec, atol=1.0e-4, relative_geoms="align")


Expand All @@ -68,8 +68,8 @@ def test_relative_geoms_align_fixed(request):
do_shift=False, do_rotate=False, do_resort=False, do_plot=False, verbose=2, do_test=True
)

rmolrec = qcel.molparse.from_schema(s22_12.dict())
cmolrec = qcel.molparse.from_schema(cmol.dict())
rmolrec = qcel.molparse.from_schema(s22_12.model_dump())
cmolrec = qcel.molparse.from_schema(cmol.model_dump())
assert compare_molrecs(rmolrec, cmolrec, atol=1.0e-4, relative_geoms="align")


Expand Down