Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic v2 Overhaul [DNM yet] #321

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
4777b4b
Start migrating to Pydantic v2 still very much a WIP
Lnaden Aug 1, 2023
538a62f
Finish initial pass converting to v2 pydantic API. Have not done any …
Lnaden Aug 3, 2023
5bd215b
More headway, starting on tests. have to be careful with serializers …
Lnaden Aug 4, 2023
76b565b
Why does nested serialization behavior change between versions? WHY!?
Lnaden Aug 4, 2023
248676d
Finally have all tests passing. Huzzah! Still need to cleanup
Lnaden Aug 9, 2023
a2e3501
Handled CI and Lint. All tests pass locally.
Lnaden Aug 9, 2023
86d88f9
This is the dumbest dependency tree resolution. Autodoc-pydantic requ…
Lnaden Aug 10, 2023
e7f4549
fix black
Lnaden Aug 10, 2023
4b4c977
and apparently isort
Lnaden Aug 10, 2023
541beb0
So apparently I made a fun problem where black and isort undid each o…
Lnaden Aug 10, 2023
e32570b
Types from typing before Python 3.10 didn't have a `__name__` attribu…
Lnaden Aug 10, 2023
6201290
A found an even more frustrating bug with Numpy NDArray from numpy.ty…
Lnaden Aug 11, 2023
ad7d43b
One last doc bug and black
Lnaden Aug 11, 2023
a4d0ed2
Fixed a serialization problem in Datum because you cannot chain seria…
Lnaden Aug 15, 2023
37fcc09
Fixed serialization of complex ndarrays in Datum
Lnaden Aug 15, 2023
31f0b7c
Dont cast listified complex or ndarray to string on jsonify of data
Lnaden Aug 15, 2023
ad3b633
black
Lnaden Aug 15, 2023
6af962b
Fixed NumPy native types being unknown on how to serialize by Pydantic
Lnaden Aug 17, 2023
01f19ae
Debugging more of the JSON serializing
Lnaden Aug 30, 2023
f2b5989
Fixed typo in Array serialization to JSON where flatten was not calle…
Lnaden Aug 30, 2023
2ca4b14
Removed leftover debugging lines
Lnaden Aug 30, 2023
75e6aee
For serializing models, just try dumping them with model_dump_json wh…
Lnaden Aug 30, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ jobs:
fail-fast: true
matrix:
python-version: ["3.7", "3.9", "3.11"]
pydantic-version: ["1", "2"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might want to leave this line in for now to help the CI lanes "Required" labels cope.


steps:
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -24,15 +23,6 @@ jobs:
uses: actions/checkout@v3
- name: Install poetry
run: pip install poetry
# Force pydantic 1.0 by modifying poetry dep "pydantic" string with in-place sed
# -i is zero-length extension which does effectively in-place sub.
# Can't do -i '' because Ubuntu sed is -i{suffix} whereas OSX sed is -i {suffix}... ugh
# ^ start of line, pydantic, optional spaces and > sign, capture the version, replace with ^{version}
# Should avoid also replacing the autodoc-pydantic spec later on.
- name: Sed replace pydantic on repo
run: |
sed -i 's/^pydantic *= *">*= *\([0-9.]*\)"/pydantic = "^\1"/' pyproject.toml
if: matrix.pydantic-version == '1'
- name: Install repo with poetry (full deps)
if: matrix.python-version != '3.9'
run: poetry install --no-interaction --no-ansi --all-extras
Expand Down Expand Up @@ -65,10 +55,6 @@ jobs:
python-version: "3.7"
- name: Install poetry
run: pip install poetry
# Force pydantic 1.0 by modifying poetry dep "pydantic" string with in-place sed (see above for details)
- name: Sed replace pydantic on repo
run: |
sed -i 's/^pydantic *= *">*= *\([0-9.]*\)"/pydantic = "^\1"/' pyproject.toml
- name: Install repo
run: poetry install --no-interaction --no-ansi
- name: Build Documentation
Expand Down
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "qcelemental"
version = "0.26.0"
version = "0.27.0"
description = "Core data structures for Quantum Chemistry."
authors = ["The QCArchive Development Team <[email protected]>"]
license = "BSD-3-Clause"
Expand All @@ -31,7 +31,8 @@ numpy = [
]
python = "^3.7"
pint = ">=0.10.0"
pydantic = ">=1.8.2"
pydantic = "^2.1.0"
pydantic-settings = "*" # Maybe remove when Fractal merges next?
nglview = { version = "^3.0.3", optional = true }
ipykernel = { version = "<6.0.0", optional = true }
importlib-metadata = { version = ">=4.8", python = "<3.8" }
Expand Down Expand Up @@ -62,7 +63,7 @@ docutils = "<0.19"
sphinx = "<6.0.0"
sphinxcontrib-napoleon = "^0.7"
sphinx-rtd-theme = "^1.2.0"
autodoc-pydantic = "^1.8.0"
autodoc-pydantic = "^2.0.0"
sphinx-automodapi = "^0.15.0"
sphinx-autodoc-typehints = "^1.22"

Expand Down
65 changes: 49 additions & 16 deletions qcelemental/datum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,46 @@
"""

from decimal import Decimal
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import numpy as np
from pydantic import BaseModel, ConfigDict, SerializerFunctionWrapHandler, WrapSerializer, field_validator
from typing_extensions import Annotated

try:
from pydantic.v1 import BaseModel, validator
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import BaseModel, validator

def cast_ndarray(v: Any, nxt: SerializerFunctionWrapHandler) -> str:
"""Special helper to list NumPy arrays before serializing"""
if isinstance(v, np.ndarray):
return f"{nxt(v.flatten().tolist())}"
return f"{nxt(v)}"


def cast_complex(v: Any, nxt: SerializerFunctionWrapHandler) -> str:
"""Special helper to serialize NumPy arrays before serializing"""
if isinstance(v, complex):
return f"{nxt((v.real, v.imag))}"
return nxt(v)


def preserve_decimal(v: Any, nxt: SerializerFunctionWrapHandler) -> Union[str, Decimal]:
"""
Ensure Decimal types are preserved on the way out

This arose because Decimal was serialized to string and "dump" is equal to "serialize" in v2 pydantic
https://docs.pydantic.dev/latest/migration/#changes-to-json-schema-generation
"""
if isinstance(v, Decimal):
return v
return nxt(v)


# Serializers are pop'd out of the list in FILO (right to left)
AnyArrayComplex = Annotated[
Any,
WrapSerializer(cast_ndarray, when_used="json"),
WrapSerializer(cast_complex, when_used="json"),
WrapSerializer(preserve_decimal),
]


class Datum(BaseModel):
Expand Down Expand Up @@ -38,15 +70,15 @@ class Datum(BaseModel):
numeric: bool
label: str
units: str
data: Any
data: AnyArrayComplex
comment: str = ""
doi: Optional[str] = None
glossary: str = ""

class Config:
extra = "forbid"
allow_mutation = False
json_encoders = {np.ndarray: lambda v: v.flatten().tolist(), complex: lambda v: (v.real, v.imag)}
model_config = ConfigDict(
extra="forbid",
frozen=True,
)

def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None, numeric=True):
kwargs = {"label": label, "units": units, "data": data, "numeric": numeric}
Expand All @@ -59,20 +91,21 @@ def __init__(self, label, units, data, *, comment=None, doi=None, glossary=None,

super().__init__(**kwargs)

@validator("data")
def must_be_numerical(cls, v, values, **kwargs):
@field_validator("data")
@classmethod
def must_be_numerical(cls, v, info):
try:
1.0 * v
except TypeError:
try:
Decimal("1.0") * v
except TypeError:
if values["numeric"]:
if info.data["numeric"]:
raise ValueError(f"Datum data should be float, Decimal, or np.ndarray, not {type(v)}.")
else:
values["numeric"] = True
info.data["numeric"] = True
else:
values["numeric"] = True
info.data["numeric"] = True

return v

Expand All @@ -91,7 +124,7 @@ def __str__(self, label=""):
return "\n".join(text)

def dict(self, *args, **kwargs):
return super().dict(*args, **{**kwargs, **{"exclude_unset": True}})
return super().model_dump(*args, **{**kwargs, **{"exclude_unset": True}})

def to_units(self, units=None):
from .physical_constants import constants
Expand Down
21 changes: 13 additions & 8 deletions qcelemental/info/cpu_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
from functools import lru_cache
from typing import List, Optional

try:
from pydantic.v1 import Field
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import Field
from pydantic import BeforeValidator, Field
from typing_extensions import Annotated

from ..models import ProtoModel

Expand All @@ -25,6 +23,13 @@ class VendorEnum(str, Enum):
arm = "arm"


def stringify(v) -> str:
return str(v)


Stringify = Annotated[str, BeforeValidator(stringify)]


class InstructionSetEnum(int, Enum):
"""Allowed instruction sets for CPUs in an ordinal enum."""

Expand All @@ -40,13 +45,13 @@ class ProcessorInfo(ProtoModel):
ncores: int = Field(..., description="The number of physical cores on the chip.")
nthreads: Optional[int] = Field(..., description="The maximum number of concurrent threads.")
base_clock: float = Field(..., description="The base clock frequency (GHz).")
boost_clock: Optional[float] = Field(..., description="The boost clock frequency (GHz).")
model: str = Field(..., description="The model number of the chip.")
boost_clock: Optional[float] = Field(None, description="The boost clock frequency (GHz).")
model: Stringify = Field(..., description="The model number of the chip.")
family: str = Field(..., description="The family of the chip.")
launch_date: Optional[int] = Field(..., description="The launch year of the chip.")
launch_date: Optional[int] = Field(None, description="The launch year of the chip.")
target_use: str = Field(..., description="Target use case (Desktop, Server, etc).")
vendor: VendorEnum = Field(..., description="The vendor the chip is produced by.")
microarchitecture: Optional[str] = Field(..., description="The microarchitecture the chip follows.")
microarchitecture: Optional[str] = Field(None, description="The microarchitecture the chip follows.")
instructions: InstructionSetEnum = Field(..., description="The maximum vectorized instruction set available.")
type: str = Field(..., description="The type of chip (cpu, gpu, etc).")

Expand Down
7 changes: 2 additions & 5 deletions qcelemental/info/dft_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@

from typing import Dict

try:
from pydantic.v1 import Field
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import Field
from pydantic import Field

from ..models import ProtoModel

Expand Down Expand Up @@ -71,4 +68,4 @@ def get(name: str) -> DFTFunctionalInfo:
name = name.replace(x, "")
break

return dftfunctionalinfo.functionals[name].copy()
return dftfunctionalinfo.functionals[name].model_copy()
21 changes: 9 additions & 12 deletions qcelemental/models/align.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from typing import Optional

import numpy as np

try:
from pydantic.v1 import Field, validator
except ImportError: # Will also trap ModuleNotFoundError
from pydantic import Field, validator
from pydantic import Field, field_validator

from ..util import blockwise_contract, blockwise_expand
from .basemodels import ProtoModel
from .basemodels import ExtendedConfigDict, ProtoModel
from .types import Array

__all__ = ["AlignmentMill"]
Expand All @@ -30,19 +26,20 @@ class AlignmentMill(ProtoModel):
atommap: Optional[Array[int]] = Field(None, description="Atom exchange map (nat,) for coordinates.") # type: ignore
mirror: bool = Field(False, description="Do mirror invert coordinates?")

class Config:
force_skip_defaults = True
model_config = ExtendedConfigDict(force_skip_defaults=True)

@validator("shift")
def _must_be_3(cls, v, values, **kwargs):
@field_validator("shift")
@classmethod
def _must_be_3(cls, v):
try:
v = v.reshape(3)
except (ValueError, AttributeError):
raise ValueError("Shift must be castable to shape (3,)!")
return v

@validator("rotation")
def _must_be_33(cls, v, values, **kwargs):
@field_validator("rotation")
@classmethod
def _must_be_33(cls, v):
try:
v = v.reshape(3, 3)
except (ValueError, AttributeError):
Expand Down
Loading
Loading