Skip to content

Commit

Permalink
Merge pull request brucefan1983#399 from brucefan1983/change_virial_o…
Browse files Browse the repository at this point in the history
…utput

Change virial output
  • Loading branch information
brucefan1983 committed Mar 25, 2023
2 parents ea13be6 + 403c390 commit 398fc9d
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 41 deletions.
16 changes: 6 additions & 10 deletions doc/nep/output_files/virial_out.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,11 @@
================

The ``virial_train.out`` and ``virial_test.out`` files contain the predicted and target virials.
There are 2 columns.
The first column gives the virials in units of eV/atom calculated using the :term:`NEP` model.
The second column gives the corresponding target virials in units of eV/atom.

The are :math:`6N_\mathrm{c}` rows, where :math:`N_\mathrm{c}` is the number of configurations in the :ref:`train.xyz and test.xyz input files <train_test_xyz>`.
There are :math:`N_\mathrm{c}` rows, where :math:`N_\mathrm{c}` is the number of configurations in the :ref:`train.xyz and test.xyz input files <train_test_xyz>`.

* The first :math:`N_\mathrm{c}` rows correspond to the :math:`xx` component of the virial.
* The second :math:`N_\mathrm{c}` rows correspond to the :math:`yy` component of the virial.
* The third :math:`N_\mathrm{c}` rows correspond to the :math:`zz` component of the virial.
* The fourth :math:`N_\mathrm{c}` rows correspond to the :math:`xy` component of the virial.
* The fifth :math:`N_\mathrm{c}` rows correspond to the :math:`yz` component of the virial.
* The sixth :math:`N_\mathrm{c}` rows correspond to the :math:`zx` component of the virial.
There are 12 columns.
The first 6 columns give the :math:`xx`, :math:`yy`, :math:`zz`, :math:`xy`, :math:`yz`, and :math:`zx` virial components calculated using the :term:`NEP` model.
The last 6 columns give the corresponding target virials.

The virial values are in units of eV/atom.
47 changes: 20 additions & 27 deletions src/main_nep/fitness.cu
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,25 @@ void Fitness::compute(
}
}

void Fitness::predict_energy_or_stress(FILE* fid, float* data, float* ref, Dataset& dataset)
void Fitness::output(
int num_components, FILE* fid, float* prediction, float* reference, Dataset& dataset)
{
for (int nc = 0; nc < dataset.Nc; ++nc) {
int offset = dataset.Na_sum_cpu[nc];
float data_nc = 0.0f;
for (int m = 0; m < dataset.Na_cpu[nc]; ++m) {
data_nc += data[offset + m];
for (int n = 0; n < num_components; ++n) {
int offset = n * dataset.N + dataset.Na_sum_cpu[nc];
float data_nc = 0.0f;
for (int m = 0; m < dataset.Na_cpu[nc]; ++m) {
data_nc += prediction[offset + m];
}
fprintf(fid, "%g ", data_nc / dataset.Na_cpu[nc]);
}
for (int n = 0; n < num_components; ++n) {
if (n == num_components - 1) {
fprintf(fid, "%g\n", reference[n * dataset.Nc + nc]);
} else {
fprintf(fid, "%g ", reference[n * dataset.Nc + nc]);
}
}
fprintf(fid, "%g %g\n", data_nc / dataset.Na_cpu[nc], ref[nc]);
}
}

Expand Down Expand Up @@ -348,7 +358,6 @@ void Fitness::update_energy_force_virial(
dataset.virial.copy_to_host(dataset.virial_cpu.data());
dataset.force.copy_to_host(dataset.force_cpu.data());

// update force.out
for (int nc = 0; nc < dataset.Nc; ++nc) {
int offset = dataset.Na_sum_cpu[nc];
for (int m = 0; m < dataset.structures[nc].num_atom; ++m) {
Expand All @@ -360,36 +369,20 @@ void Fitness::update_energy_force_virial(
}
}

// update energy.out
predict_energy_or_stress(
fid_energy, dataset.energy_cpu.data(), dataset.energy_ref_cpu.data(), dataset);

// update virial.out
for (int k = 0; k < 6; ++k) {
predict_energy_or_stress(
fid_virial, dataset.virial_cpu.data() + dataset.N * k,
dataset.virial_ref_cpu.data() + dataset.Nc * k, dataset);
}
output(1, fid_energy, dataset.energy_cpu.data(), dataset.energy_ref_cpu.data(), dataset);
output(6, fid_virial, dataset.virial_cpu.data(), dataset.virial_ref_cpu.data(), dataset);
}

void Fitness::update_dipole(FILE* fid_dipole, Dataset& dataset)
{
dataset.virial.copy_to_host(dataset.virial_cpu.data());
for (int k = 0; k < 3; ++k) {
predict_energy_or_stress(
fid_dipole, dataset.virial_cpu.data() + dataset.N * k,
dataset.virial_ref_cpu.data() + dataset.Nc * k, dataset);
}
output(3, fid_dipole, dataset.virial_cpu.data(), dataset.virial_ref_cpu.data(), dataset);
}

void Fitness::update_polarizability(FILE* fid_polarizability, Dataset& dataset)
{
dataset.virial.copy_to_host(dataset.virial_cpu.data());
for (int k = 0; k < 6; ++k) {
predict_energy_or_stress(
fid_polarizability, dataset.virial_cpu.data() + dataset.N * k,
dataset.virial_ref_cpu.data() + dataset.Nc * k, dataset);
}
output(6, fid_polarizability, dataset.virial_cpu.data(), dataset.virial_ref_cpu.data(), dataset);
}

void Fitness::predict(Parameters& para, float* elite)
Expand Down
2 changes: 1 addition & 1 deletion src/main_nep/fitness.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ protected:
std::unique_ptr<Potential> potential;
std::vector<std::vector<Dataset>> train_set;
std::vector<Dataset> test_set;
void predict_energy_or_stress(FILE* fid, float* data, float* ref, Dataset& dataset);
void output(int num_components, FILE* fid, float* prediction, float* reference, Dataset& dataset);
void
update_energy_force_virial(FILE* fid_energy, FILE* fid_force, FILE* fid_virial, Dataset& dataset);
void update_dipole(FILE* fid_dipole, Dataset& dataset);
Expand Down
2 changes: 1 addition & 1 deletion src/main_nep/snes.cu
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ void SNES::compute(Parameters& para, Fitness* fitness_function)
if (para.prediction == 0) {
printf("Started training.\n");
} else {
printf("Started predcting.\n");
printf("Started predicting.\n");
}

print_line_2();
Expand Down
16 changes: 14 additions & 2 deletions src/main_nep/structure.cu
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,13 @@ static void read_one_structure(const Parameters& para, std::ifstream& input, Str
}
}
if (!structure.has_virial) {
PRINT_INPUT_ERROR("'dipole' is missing in the second line of a frame.");
if (para.prediction == 0) {
PRINT_INPUT_ERROR("'dipole' is missing in the second line of a frame.");
} else {
for (int m = 0; m < 6; ++m) {
structure.virial[m] = -1e6;
}
}
}
}

Expand All @@ -283,7 +289,13 @@ static void read_one_structure(const Parameters& para, std::ifstream& input, Str
}
}
if (!structure.has_virial) {
PRINT_INPUT_ERROR("'pol' is missing in the second line of a frame.");
if (para.prediction == 0) {
PRINT_INPUT_ERROR("'pol' is missing in the second line of a frame.");
} else {
for (int m = 0; m < 6; ++m) {
structure.virial[m] = -1e6;
}
}
}
}

Expand Down

0 comments on commit 398fc9d

Please sign in to comment.