diff --git a/doc/nep/output_files/virial_out.rst b/doc/nep/output_files/virial_out.rst index 86de8c199..cb63697f8 100644 --- a/doc/nep/output_files/virial_out.rst +++ b/doc/nep/output_files/virial_out.rst @@ -7,15 +7,11 @@ ================ The ``virial_train.out`` and ``virial_test.out`` files contain the predicted and target virials. -There are 2 columns. -The first column gives the virials in units of eV/atom calculated using the :term:`NEP` model. -The second column gives the corresponding target virials in units of eV/atom. -The are :math:`6N_\mathrm{c}` rows, where :math:`N_\mathrm{c}` is the number of configurations in the :ref:`train.xyz and test.xyz input files `. +There are :math:`N_\mathrm{c}` rows, where :math:`N_\mathrm{c}` is the number of configurations in the :ref:`train.xyz and test.xyz input files `. -* The first :math:`N_\mathrm{c}` rows correspond to the :math:`xx` component of the virial. -* The second :math:`N_\mathrm{c}` rows correspond to the :math:`yy` component of the virial. -* The third :math:`N_\mathrm{c}` rows correspond to the :math:`zz` component of the virial. -* The fourth :math:`N_\mathrm{c}` rows correspond to the :math:`xy` component of the virial. -* The fifth :math:`N_\mathrm{c}` rows correspond to the :math:`yz` component of the virial. -* The sixth :math:`N_\mathrm{c}` rows correspond to the :math:`zx` component of the virial. +There are 12 columns. +The first 6 columns give the :math:`xx`, :math:`yy`, :math:`zz`, :math:`xy`, :math:`yz`, and :math:`zx` virial components calculated using the :term:`NEP` model. +The last 6 columns give the corresponding target virials. + +The virial values are in units of eV/atom. diff --git a/src/main_nep/fitness.cu b/src/main_nep/fitness.cu index cc70704ef..3e691983b 100644 --- a/src/main_nep/fitness.cu +++ b/src/main_nep/fitness.cu @@ -165,15 +165,25 @@ void Fitness::compute( } } -void Fitness::predict_energy_or_stress(FILE* fid, float* data, float* ref, Dataset& dataset) +void Fitness::output( + int num_components, FILE* fid, float* prediction, float* reference, Dataset& dataset) { for (int nc = 0; nc < dataset.Nc; ++nc) { - int offset = dataset.Na_sum_cpu[nc]; - float data_nc = 0.0f; - for (int m = 0; m < dataset.Na_cpu[nc]; ++m) { - data_nc += data[offset + m]; + for (int n = 0; n < num_components; ++n) { + int offset = n * dataset.N + dataset.Na_sum_cpu[nc]; + float data_nc = 0.0f; + for (int m = 0; m < dataset.Na_cpu[nc]; ++m) { + data_nc += prediction[offset + m]; + } + fprintf(fid, "%g ", data_nc / dataset.Na_cpu[nc]); + } + for (int n = 0; n < num_components; ++n) { + if (n == num_components - 1) { + fprintf(fid, "%g\n", reference[n * dataset.Nc + nc]); + } else { + fprintf(fid, "%g ", reference[n * dataset.Nc + nc]); + } } - fprintf(fid, "%g %g\n", data_nc / dataset.Na_cpu[nc], ref[nc]); } } @@ -348,7 +358,6 @@ void Fitness::update_energy_force_virial( dataset.virial.copy_to_host(dataset.virial_cpu.data()); dataset.force.copy_to_host(dataset.force_cpu.data()); - // update force.out for (int nc = 0; nc < dataset.Nc; ++nc) { int offset = dataset.Na_sum_cpu[nc]; for (int m = 0; m < dataset.structures[nc].num_atom; ++m) { @@ -360,36 +369,20 @@ void Fitness::update_energy_force_virial( } } - // update energy.out - predict_energy_or_stress( - fid_energy, dataset.energy_cpu.data(), dataset.energy_ref_cpu.data(), dataset); - - // update virial.out - for (int k = 0; k < 6; ++k) { - predict_energy_or_stress( - fid_virial, dataset.virial_cpu.data() + dataset.N * k, - dataset.virial_ref_cpu.data() + dataset.Nc * k, dataset); - } + output(1, fid_energy, dataset.energy_cpu.data(), dataset.energy_ref_cpu.data(), dataset); + output(6, fid_virial, dataset.virial_cpu.data(), dataset.virial_ref_cpu.data(), dataset); } void Fitness::update_dipole(FILE* fid_dipole, Dataset& dataset) { dataset.virial.copy_to_host(dataset.virial_cpu.data()); - for (int k = 0; k < 3; ++k) { - predict_energy_or_stress( - fid_dipole, dataset.virial_cpu.data() + dataset.N * k, - dataset.virial_ref_cpu.data() + dataset.Nc * k, dataset); - } + output(3, fid_dipole, dataset.virial_cpu.data(), dataset.virial_ref_cpu.data(), dataset); } void Fitness::update_polarizability(FILE* fid_polarizability, Dataset& dataset) { dataset.virial.copy_to_host(dataset.virial_cpu.data()); - for (int k = 0; k < 6; ++k) { - predict_energy_or_stress( - fid_polarizability, dataset.virial_cpu.data() + dataset.N * k, - dataset.virial_ref_cpu.data() + dataset.Nc * k, dataset); - } + output(6, fid_polarizability, dataset.virial_cpu.data(), dataset.virial_ref_cpu.data(), dataset); } void Fitness::predict(Parameters& para, float* elite) diff --git a/src/main_nep/fitness.cuh b/src/main_nep/fitness.cuh index fc0882e54..8c12153d7 100644 --- a/src/main_nep/fitness.cuh +++ b/src/main_nep/fitness.cuh @@ -47,7 +47,7 @@ protected: std::unique_ptr potential; std::vector> train_set; std::vector test_set; - void predict_energy_or_stress(FILE* fid, float* data, float* ref, Dataset& dataset); + void output(int num_components, FILE* fid, float* prediction, float* reference, Dataset& dataset); void update_energy_force_virial(FILE* fid_energy, FILE* fid_force, FILE* fid_virial, Dataset& dataset); void update_dipole(FILE* fid_dipole, Dataset& dataset); diff --git a/src/main_nep/snes.cu b/src/main_nep/snes.cu index dbfcd47ce..c01435e38 100644 --- a/src/main_nep/snes.cu +++ b/src/main_nep/snes.cu @@ -168,7 +168,7 @@ void SNES::compute(Parameters& para, Fitness* fitness_function) if (para.prediction == 0) { printf("Started training.\n"); } else { - printf("Started predcting.\n"); + printf("Started predicting.\n"); } print_line_2(); diff --git a/src/main_nep/structure.cu b/src/main_nep/structure.cu index d0aed9607..67646b872 100644 --- a/src/main_nep/structure.cu +++ b/src/main_nep/structure.cu @@ -260,7 +260,13 @@ static void read_one_structure(const Parameters& para, std::ifstream& input, Str } } if (!structure.has_virial) { - PRINT_INPUT_ERROR("'dipole' is missing in the second line of a frame."); + if (para.prediction == 0) { + PRINT_INPUT_ERROR("'dipole' is missing in the second line of a frame."); + } else { + for (int m = 0; m < 6; ++m) { + structure.virial[m] = -1e6; + } + } } } @@ -283,7 +289,13 @@ static void read_one_structure(const Parameters& para, std::ifstream& input, Str } } if (!structure.has_virial) { - PRINT_INPUT_ERROR("'pol' is missing in the second line of a frame."); + if (para.prediction == 0) { + PRINT_INPUT_ERROR("'pol' is missing in the second line of a frame."); + } else { + for (int m = 0; m < 6; ++m) { + structure.virial[m] = -1e6; + } + } } }