Skip to content

Commit

Permalink
test float mult
Browse files Browse the repository at this point in the history
  • Loading branch information
loriab committed Sep 30, 2023
1 parent 506c791 commit ed797b3
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
25 changes: 21 additions & 4 deletions qcelemental/models/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class Molecule(ProtoModel):
description="Additional comments for this molecule. Intended for pure human/user consumption and clarity.",
)
molecular_charge: float = Field(0.0, description="The net electrostatic charge of the molecule.") # type: ignore
molecular_multiplicity: Union[int, float] = Field(1, description="The total multiplicity of the molecule.") # type: ignore
molecular_multiplicity: float = Field(1, description="The total multiplicity of the molecule.") # type: ignore

# Atom data
masses_: Optional[Array[float]] = Field( # type: ignore
Expand Down Expand Up @@ -251,7 +251,7 @@ class Molecule(ProtoModel):
"if not provided (and :attr:`~qcelemental.models.Molecule.fragments` are specified).",
shape=["nfr"],
)
fragment_multiplicities_: Optional[List[Union[int, float]]] = Field( # type: ignore
fragment_multiplicities_: Optional[List[float]] = Field( # type: ignore
None,
description="The multiplicity of each fragment in the :attr:`~qcelemental.models.Molecule.fragments` list. The index of this "
"list matches the 0-index indices of :attr:`~qcelemental.models.Molecule.fragments` list. Will be filled in based on a set of "
Expand Down Expand Up @@ -397,7 +397,7 @@ def _populate_real(cls, v, values, **kwargs):
v = np.array([True for _ in range(n)])
return v

@validator("fragment_charges_", "fragment_multiplicities_")
@validator("fragment_charges_")
def _must_be_n_frag(cls, v, values, **kwargs):
if "fragments_" in values and values["fragments_"] is not None:
n = len(values["fragments_"])
Expand All @@ -407,6 +407,23 @@ def _must_be_n_frag(cls, v, values, **kwargs):
)
return v

@validator("fragment_multiplicities_")
def _must_be_n_frag_mult(cls, v, values, **kwargs):
if "fragments_" in values and values["fragments_"] is not None:
n = len(values["fragments_"])
if len(v) != n:
raise ValueError(
"Fragment Charges and Fragment Multiplicities must be same number of entries as Fragments"
)
int_ized_v = [(int(m) if m.is_integer() else m) for m in v]
return int_ized_v

@validator("molecular_multiplicity")
def _int_if_possible(cls, v, values, **kwargs):
if v.is_integer():
v = int(v)
return v

@property
def hash_fields(self):
return [
Expand Down Expand Up @@ -478,7 +495,7 @@ def fragment_charges(self) -> List[float]:
return fragment_charges

@property
def fragment_multiplicities(self) -> List[int]:
def fragment_multiplicities(self) -> List[float]:
fragment_multiplicities = self.__dict__.get("fragment_multiplicities_")
if fragment_multiplicities is None:
fragment_multiplicities = [self.molecular_multiplicity]
Expand Down
21 changes: 21 additions & 0 deletions qcelemental/tests/test_molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,3 +734,24 @@ def test_extras():

mol = qcel.models.Molecule(symbols=["He"], geometry=[0, 0, 0], extras={"foo": "bar"})
assert mol.extras["foo"] == "bar"


@pytest.mark.parametrize(
"mult_in,mult_store",
[
pytest.param(3, 3),
pytest.param(3.1, 3.1),
pytest.param(3.00001, 3.00001),
pytest.param(3.0, 3),
pytest.param(1.0, 1),
pytest.param(1, 1),
pytest.param(2.000000000000000000002, 2),
pytest.param(2.000000000000002, 2.000000000000002),
],
)
def test_mol_multiplicity_types(mult_in, mult_store):
# validate=False b/c molparse can't check the physics of chg/mult for float multiplicity
mol = qcel.models.Molecule(symbols=["He"], geometry=[0, 0, 0], molecular_multiplicity=mult_in, validate=False)

assert mult_store == mol.molecular_multiplicity
assert type(mult_store) is type(mol.molecular_multiplicity)

0 comments on commit ed797b3

Please sign in to comment.