From 0f2b2ec156ece94f0ffe0f0d452d21c0e0fecef2 Mon Sep 17 00:00:00 2001 From: Chen Nuo <49788094+Cstandardlib@users.noreply.github.com> Date: Wed, 6 May 2026 17:39:11 +0800 Subject: [PATCH] [Refactor] Remove unnecessary manual MPI_Reduce in hsolver dav code reduce_pool template already has single precision (std::complex) instantiation with MPI_IN_PLACE, so manually creating swap buffers and calling MPI_Reduce for complex types is unnecessary. Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- source/source_hsolver/diago_dav_subspace.cpp | 60 ++----------------- source/source_hsolver/diago_david.cpp | 22 +------ .../test/diago_david_float_test.cpp | 3 +- .../test/diago_david_real_test.cpp | 3 +- .../source_hsolver/test/diago_david_test.cpp | 3 +- 5 files changed, 13 insertions(+), 78 deletions(-) diff --git a/source/source_hsolver/diago_dav_subspace.cpp b/source/source_hsolver/diago_dav_subspace.cpp index 048653c2a87..96501fd6c0c 100644 --- a/source/source_hsolver/diago_dav_subspace.cpp +++ b/source/source_hsolver/diago_dav_subspace.cpp @@ -19,6 +19,7 @@ #ifdef __MPI #include +#include "source_base/parallel_comm.h" #endif using namespace hsolver; @@ -583,62 +584,9 @@ void Diago_DavSubspace::cal_elem(const int& dim, // Only on dsp hardware need an extra space to reduce data mtfunc::dsp_dav_subspace_reduce(hcc, scc, nbase, this->nbase_x, this->notconv, this->diag_comm.comm); #else - auto* swap = new T[notconv * this->nbase_x]; - - syncmem_complex_op()(swap, hcc + nbase * this->nbase_x, notconv * this->nbase_x); - - if (std::is_same::value) - { - Parallel_Reduce::reduce_pool(hcc + nbase * this->nbase_x, notconv * this->nbase_x); - Parallel_Reduce::reduce_pool(scc + nbase * this->nbase_x, notconv * this->nbase_x); - } - else - { - if (base_device::get_current_precision(swap) == "single") - { - MPI_Reduce(swap, - hcc + nbase * this->nbase_x, - notconv * this->nbase_x, - MPI_COMPLEX, - MPI_SUM, - 0, - this->diag_comm.comm); - } - else - { - MPI_Reduce(swap, - hcc + nbase * this->nbase_x, - notconv * this->nbase_x, - MPI_DOUBLE_COMPLEX, - MPI_SUM, - 0, - this->diag_comm.comm); - } - - syncmem_complex_op()(swap, scc + nbase * this->nbase_x, notconv * this->nbase_x); - - if (base_device::get_current_precision(swap) == "single") - { - MPI_Reduce(swap, - scc + nbase * this->nbase_x, - notconv * this->nbase_x, - MPI_COMPLEX, - MPI_SUM, - 0, - this->diag_comm.comm); - } - else - { - MPI_Reduce(swap, - scc + nbase * this->nbase_x, - notconv * this->nbase_x, - MPI_DOUBLE_COMPLEX, - MPI_SUM, - 0, - this->diag_comm.comm); - } - } - delete[] swap; + assert(this->diag_comm.comm == POOL_WORLD); + Parallel_Reduce::reduce_pool(hcc + nbase * this->nbase_x, notconv * this->nbase_x); + Parallel_Reduce::reduce_pool(scc + nbase * this->nbase_x, notconv * this->nbase_x); #endif } #endif diff --git a/source/source_hsolver/diago_david.cpp b/source/source_hsolver/diago_david.cpp index 5839ca2e2be..04e50e76c68 100644 --- a/source/source_hsolver/diago_david.cpp +++ b/source/source_hsolver/diago_david.cpp @@ -6,6 +6,7 @@ #include "source_hsolver/kernels/hegvd_op.h" #include "source_base/kernels/math_kernel_op.h" +#include "source_base/parallel_comm.h" using namespace hsolver; @@ -613,25 +614,8 @@ void DiagoDavid::cal_elem(const int& dim, { ModuleBase::matrixTranspose_op()(nbase_x, nbase_x, hcc, hcc); - auto* swap = new T[notconv * nbase_x]; - syncmem_complex_op()(swap, hcc + nbase * nbase_x, notconv * nbase_x); - if (std::is_same::value) - { - Parallel_Reduce::reduce_pool(hcc + nbase * nbase_x, notconv * nbase_x); - } - else - { - if (base_device::get_current_precision(swap) == "single") { - MPI_Reduce(swap, hcc + nbase * nbase_x, notconv * nbase_x, MPI_COMPLEX, MPI_SUM, 0, diag_comm.comm); - } - else { - MPI_Reduce(swap, hcc + nbase * nbase_x, notconv * nbase_x, MPI_DOUBLE_COMPLEX, MPI_SUM, 0, diag_comm.comm); - } - - } - delete[] swap; - - // Parallel_Reduce::reduce_complex_double_pool( hcc + nbase * nbase_x, notconv * nbase_x ); + assert(diag_comm.comm == POOL_WORLD); + Parallel_Reduce::reduce_pool(hcc + nbase * nbase_x, notconv * nbase_x); ModuleBase::matrixTranspose_op()(nbase_x, nbase_x, hcc, hcc); } diff --git a/source/source_hsolver/test/diago_david_float_test.cpp b/source/source_hsolver/test/diago_david_float_test.cpp index 000bb1b22ec..b061601caa9 100644 --- a/source/source_hsolver/test/diago_david_float_test.cpp +++ b/source/source_hsolver/test/diago_david_float_test.cpp @@ -1,5 +1,6 @@ #include"source_hsolver/diago_david.h" #include"source_hsolver/diago_iter_assist.h" +#include "source_base/parallel_comm.h" #include"source_pw/module_pwdft/hamilt_pw.h" #include"diago_mock.h" #include "source_psi/psi.h" @@ -85,7 +86,7 @@ class DiagoDavPrepare phm = new hamilt::HamiltPW>(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); #ifdef __MPI - const hsolver::diag_comm_info comm_info = {MPI_COMM_WORLD, mypnum, nprocs}; + const hsolver::diag_comm_info comm_info = {POOL_WORLD, mypnum, nprocs}; #else const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif diff --git a/source/source_hsolver/test/diago_david_real_test.cpp b/source/source_hsolver/test/diago_david_real_test.cpp index 8e0ecf35c4a..a3bfb15e5eb 100644 --- a/source/source_hsolver/test/diago_david_real_test.cpp +++ b/source/source_hsolver/test/diago_david_real_test.cpp @@ -1,5 +1,6 @@ #include"source_hsolver/diago_david.h" #include"source_hsolver/diago_iter_assist.h" +#include "source_base/parallel_comm.h" #include"source_pw/module_pwdft/hamilt_pw.h" #include"diago_mock.h" #include "source_psi/psi.h" @@ -84,7 +85,7 @@ class DiagoDavPrepare phm = new hamilt::HamiltPW(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); #ifdef __MPI - const hsolver::diag_comm_info comm_info = {MPI_COMM_WORLD, mypnum, nprocs}; + const hsolver::diag_comm_info comm_info = {POOL_WORLD, mypnum, nprocs}; #else const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif diff --git a/source/source_hsolver/test/diago_david_test.cpp b/source/source_hsolver/test/diago_david_test.cpp index 5f9de9833dc..9dfc3a03f78 100644 --- a/source/source_hsolver/test/diago_david_test.cpp +++ b/source/source_hsolver/test/diago_david_test.cpp @@ -1,5 +1,6 @@ #include"source_hsolver/diago_david.h" #include"source_hsolver/diago_iter_assist.h" +#include "source_base/parallel_comm.h" #include"source_pw/module_pwdft/hamilt_pw.h" #include"diago_mock.h" #include "source_psi/psi.h" @@ -89,7 +90,7 @@ class DiagoDavPrepare phm = new hamilt::HamiltPW>(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); #ifdef __MPI - const hsolver::diag_comm_info comm_info = {MPI_COMM_WORLD, mypnum, nprocs}; + const hsolver::diag_comm_info comm_info = {POOL_WORLD, mypnum, nprocs}; #else const hsolver::diag_comm_info comm_info = {mypnum, nprocs}; #endif