diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp
index ad43e78ed5..8ba8c3ae8e 100644
--- a/source/source_base/parallel_device.cpp
+++ b/source/source_base/parallel_device.cpp
@@ -86,14 +86,14 @@ NcclCommRegistry& get_nccl_registry()
 }
 
 template <typename T>
-void nccl_bcast_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, const int count_scale = 1)
+void nccl_bcast_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, int root = 0, const int count_scale = 1)
 {
     NcclCommContext& ctx = get_nccl_registry().get(comm);
     if (ctx.size <= 1 || n <= 0)
     {
         return;
     }
-    CHECK_NCCL(ncclBroadcast(object, object, static_cast<size_t>(n) * count_scale, datatype, 0, ctx.comm, ctx.stream));
+    CHECK_NCCL(ncclBroadcast(object, object, static_cast<size_t>(n) * count_scale, datatype, root, ctx.comm, ctx.stream));
     CHECK_CUDA(cudaStreamSynchronize(ctx.stream));
 }
@@ -183,24 +183,24 @@ void nccl_gatherv_impl(const T* sendbuf,
     }
 }
 } // namespace
-void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm)
+void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm, int root)
 {
-    nccl_bcast_impl(object, n, comm, ncclDouble);
+    nccl_bcast_impl(object, n, comm, ncclDouble, root);
 }
-void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm)
+void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm, int root)
 {
-    nccl_bcast_impl(reinterpret_cast<double*>(object), n, comm, ncclDouble, 2);
+    nccl_bcast_impl(reinterpret_cast<double*>(object), n, comm, ncclDouble, root, 2);
 }
-void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm)
+void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm, int root)
 {
-    nccl_bcast_impl(object, n, comm, ncclFloat);
+    nccl_bcast_impl(object, n, comm, ncclFloat, root);
 }
-void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm)
+void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm, int root)
 {
-    nccl_bcast_impl(reinterpret_cast<float*>(object), n, comm, ncclFloat, 2);
+    nccl_bcast_impl(reinterpret_cast<float*>(object), n, comm, ncclFloat, root, 2);
 }
 
 void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm)
@@ -302,21 +302,21 @@ void recv_data(std::complex<float>* buf, int count, int source, int tag, MPI_Com
 {
     MPI_Recv(buf, count, MPI_COMPLEX, source, tag, comm, status);
 }
-void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm)
+void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm, int root)
 {
-    MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm);
+    MPI_Bcast(object, n * 2, MPI_DOUBLE, root, comm);
 }
-void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm)
+void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm, int root)
 {
-    MPI_Bcast(object, n * 2, MPI_FLOAT, 0, comm);
+    MPI_Bcast(object, n * 2, MPI_FLOAT, root, comm);
 }
-void bcast_data(double* object, const int& n, const MPI_Comm& comm)
+void bcast_data(double* object, const int& n, const MPI_Comm& comm, int root)
 {
-    MPI_Bcast(object, n, MPI_DOUBLE, 0, comm);
+    MPI_Bcast(object, n, MPI_DOUBLE, root, comm);
 }
-void bcast_data(float* object, const int& n, const MPI_Comm& comm)
+void bcast_data(float* object, const int& n, const MPI_Comm& comm, int root)
 {
-    MPI_Bcast(object, n, MPI_FLOAT, 0, comm);
+    MPI_Bcast(object, n, MPI_FLOAT, root, comm);
 }
 void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm)
 {
diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h
index 264a0beb89..7826e4b465 100644
--- a/source/source_base/parallel_device.h
+++ b/source/source_base/parallel_device.h
@@ -4,6 +4,7 @@
 #include "mpi.h"
 #include <complex>
 #include <string>
+#include "source_base/module_device/types.h"
 
 namespace Parallel_Common
 {
@@ -18,10 +19,10 @@ void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_
 void recv_data(std::complex<double>* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
 void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
 void recv_data(std::complex<float>* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
-void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
-void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
-void bcast_data(double* object, const int& n, const MPI_Comm& comm);
-void bcast_data(float* object, const int& n, const MPI_Comm& comm);
+void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm, int root = 0);
+void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm, int root = 0);
+void bcast_data(double* object, const int& n, const MPI_Comm& comm, int root = 0);
+void bcast_data(float* object, const int& n, const MPI_Comm& comm, int root = 0);
 void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
 void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
 void reduce_data(double* object, const int& n, const MPI_Comm& comm);
@@ -32,10 +33,10 @@ void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int
 void gatherv_data(const std::complex<float>* sendbuf, int sendcount, std::complex<float>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);
 
 #if defined(__NCCL_PARALLEL_DEVICE)
-void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm);
-void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm);
-void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm);
-void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm);
+void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm, int root = 0);
+void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm, int root = 0);
+void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm, int root = 0);
+void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm, int root = 0);
 void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm);
 void nccl_reduce_data(std::complex<double>* object, const int& n, MPI_Comm& comm);
 void nccl_reduce_data(float* object, const int& n, MPI_Comm& comm);
@@ -116,35 +117,35 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta
 }
 
 /**
- * @brief bcast data in Device
+ * @brief broadcast data in Device
  *
  * @tparam T: float, double, std::complex<float>, std::complex<double>
  * @tparam Device
- * @param ctx Device ctx
- * @param object complex arrays in Device
- * @param n the size of complex arrays
+ * @param object arrays in Device
+ * @param n the size of array
  * @param comm MPI_Comm
- * @param tmp_space tmp space in CPU
+ * @param root root rank (default 0)
+ * @param tmp_space optional tmp space in CPU (default nullptr)
  */
 template <typename T, typename Device>
-void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
+void bcast_dev(T* object, const int& n, const MPI_Comm& comm, int root = 0, T* tmp_space = nullptr)
 {
 #if defined(__NCCL_PARALLEL_DEVICE)
     if (std::is_same<Device, base_device::DEVICE_GPU>::value)
     {
-        nccl_bcast_data(object, n, const_cast<MPI_Comm&>(comm));
+        nccl_bcast_data(object, n, const_cast<MPI_Comm&>(comm), root);
         return;
     }
#endif
 #ifdef __CUDA_MPI
-    bcast_data(object, n, comm);
+    bcast_data(object, n, comm, root);
 #else
     object_cpu_point<T, Device> o;
     int rank = 0;
     MPI_Comm_rank(comm, &rank);
-    T* object_cpu = rank == 0 ? o.get(object, n, tmp_space) : o.get_buffer(object, n, tmp_space);
-    bcast_data(object_cpu, n, comm);
-    if (rank != 0)
+    T* object_cpu = rank == root ? o.get(object, n, tmp_space) : o.get_buffer(object, n, tmp_space);
+    bcast_data(object_cpu, n, comm, root);
+    if (rank != root)
     {
         o.sync_h2d(object, object_cpu, n);
     }
diff --git a/source/source_hsolver/hsolver_pw_sdft.cpp b/source/source_hsolver/hsolver_pw_sdft.cpp
index 3840239d75..f3c3d2f66a 100644
--- a/source/source_hsolver/hsolver_pw_sdft.cpp
+++ b/source/source_hsolver/hsolver_pw_sdft.cpp
@@ -61,7 +61,7 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,
 #ifdef __MPI
             if (nbands > 0 && !PARAM.globalv.all_ks_run)
             {
-                Parallel_Common::bcast_dev<T, Device>(&psi(ik, 0, 0), npwx * nbands, BP_WORLD, &psi_cpu(ik, 0, 0));
+                Parallel_Common::bcast_dev<T, Device>(&psi(ik, 0, 0), npwx * nbands, BP_WORLD, 0, &psi_cpu(ik, 0, 0));
                 MPI_Bcast(&pes->ekb(ik, 0), nbands, MPI_DOUBLE, 0, BP_WORLD);
             }
 #endif
diff --git a/source/source_pw/module_pwdft/op_pw_exx.cpp b/source/source_pw/module_pwdft/op_pw_exx.cpp
index f38624479a..84f09cc8f6 100644
--- a/source/source_pw/module_pwdft/op_pw_exx.cpp
+++ b/source/source_pw/module_pwdft/op_pw_exx.cpp
@@ -3,6 +3,7 @@
 #include "source_base/constants.h"
 #include "source_base/global_variable.h"
 #include "source_base/parallel_common.h"
+#include "source_base/parallel_device.h"
 #include "source_base/parallel_comm.h" // use KP_WORLD
 #include "source_base/parallel_reduce.h"
 #include "source_base/module_external/lapack_connector.h"
@@ -350,27 +351,7 @@ void OperatorEXXPW<T, Device>::act_op_kpar(const int nbands,
             // send
         }
 #ifdef __MPI
-#ifdef __CUDA_MPI
-        MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
-#else
-        if (PARAM.inp.device == "cpu")
-        {
-            MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
-        }
-        else if (PARAM.inp.device == "gpu")
-        {
-            // need to copy to cpu first
-            T* psi_mq_real_cpu = new T[wfcpw->nrxx];
-            syncmem_complex_d2c_op()(psi_mq_real_cpu, psi_mq_real, wfcpw->nrxx);
-            MPI_Bcast(psi_mq_real_cpu, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
-            syncmem_complex_c2d_op()(psi_mq_real, psi_mq_real_cpu, wfcpw->nrxx);
-            delete[] psi_mq_real_cpu;
-        }
-        else
-        {
-            ModuleBase::WARNING_QUIT("OperatorEXXPW", "construct_ace: unknown device");
-        }
-#endif
+        Parallel_Common::bcast_dev<T, Device>(psi_mq_real, wfcpw->nrxx, KP_WORLD, iq_pool);
 #endif
         for (int n_iband = 0; n_iband < nbands; n_iband++)
         {
diff --git a/source/source_pw/module_pwdft/op_pw_exx_ace.cpp b/source/source_pw/module_pwdft/op_pw_exx_ace.cpp
index 4a47b2ac03..8685b355f0 100644
--- a/source/source_pw/module_pwdft/op_pw_exx_ace.cpp
+++ b/source/source_pw/module_pwdft/op_pw_exx_ace.cpp
@@ -1,5 +1,6 @@
 #include "op_pw_exx.h"
 #include "source_base/parallel_comm.h"
+#include "source_base/parallel_device.h"
 #include "source_base/parallel_reduce.h"
 #include "source_io/module_parameter/parameter.h"
 #include "source_hamilt/module_xc/exx_info.h"
@@ -46,7 +47,9 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
                            nbands_tot
     );
 
-    Parallel_Reduce::reduce_pool(Xi_psi, nbands_tot * nbands);
+#ifdef __MPI
+    Parallel_Common::reduce_dev<T, Device>(Xi_psi, nbands_tot * nbands, POOL_WORLD);
+#endif
 
     // Xi^\dagger * (Xi * psi)
     gemm_complex_op<T, Device>()(trans_C,
@@ -179,32 +182,9 @@ void OperatorEXXPW<T, Device>::construct_ace() const
         {
             const T* psi_mq = get_pw(m_iband, iq_loc);
             wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq_loc);
-            // send
         }
-        // if (iq == 0)
-        //     std::cout << "Bcast psi_mq_real" << std::endl;
 #ifdef __MPI
-#ifdef __CUDA_MPI
-        MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
-#else
-        if (PARAM.inp.device == "cpu")
-        {
-            MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
-        }
-        else if (PARAM.inp.device == "gpu")
-        {
-            // need to copy to cpu first
-            T* psi_mq_real_cpu = new T[wfcpw->nrxx];
-            syncmem_complex_d2c_op()(psi_mq_real_cpu, psi_mq_real, wfcpw->nrxx);
-            MPI_Bcast(psi_mq_real_cpu, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
-            syncmem_complex_c2d_op()(psi_mq_real, psi_mq_real_cpu, wfcpw->nrxx);
-            delete[] psi_mq_real_cpu;
-        }
-        else
-        {
-            ModuleBase::WARNING_QUIT("OperatorEXXPW", "construct_ace: unknown device");
-        }
-#endif
+        Parallel_Common::bcast_dev<T, Device>(psi_mq_real, wfcpw->nrxx, KP_WORLD, iq_pool);
 #endif
     } // end of iq
@@ -232,7 +212,9 @@ void OperatorEXXPW<T, Device>::construct_ace() const
                              nbands);
 
     // reduction of psi_h_psi_ace, due to distributed memory
-    Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands);
+#ifdef __MPI
+    Parallel_Common::reduce_dev<T, Device>(psi_h_psi_ace, nbands * nbands, POOL_WORLD);
+#endif
 
     T intermediate_minus_one = -1.0;
     axpy_complex_op<T, Device>()(nbands * nbands,