From 6d81876053ba546bad8af605e10d1f3d300407ca Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 11:22:20 +0800 Subject: [PATCH 01/14] Harden GPU MPI staging helpers --- source/source_base/parallel_device.cpp | 13 +++++++++++-- source/source_base/parallel_device.h | 21 +++++++++++---------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index 933064e2486..887dde3bebf 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -104,7 +104,7 @@ template struct object_cpu_point { bool alloc = false; - T* get(const T* object, const int& n, T* tmp_space = nullptr) + T* get_buffer(const T* object, const int& n, T* tmp_space = nullptr) { T* object_cpu = nullptr; alloc = false; @@ -118,6 +118,11 @@ struct object_cpu_point { object_cpu = tmp_space; } + return object_cpu; + } + T* get(const T* object, const int& n, T* tmp_space = nullptr) + { + T* object_cpu = get_buffer(object, n, tmp_space); base_device::memory::synchronize_memory_op()(object_cpu, object, n); @@ -149,6 +154,10 @@ template struct object_cpu_point { bool alloc = false; + T* get_buffer(const T* object, const int& n, T* tmp_space = nullptr) + { + return const_cast(object); + } T* get(const T* object, const int& n, T* tmp_space = nullptr) { return const_cast(object); @@ -175,4 +184,4 @@ template struct object_cpu_point, base_device::DEVICE_GPU>; #endif } // namespace Parallel_Common -#endif \ No newline at end of file +#endif diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h index 7293b375d74..40e325d8ac0 100644 --- a/source/source_base/parallel_device.h +++ b/source/source_base/parallel_device.h @@ -37,6 +37,7 @@ template struct object_cpu_point { bool alloc = false; + T* get_buffer(const T* object, const int& n, T* tmp_space = nullptr); T* get(const T* object, const int& n, T* tmp_space = nullptr); void del(T* object); void sync_d2h(T* object_cpu, const T* object, const int& n); @@ -56,7 +57,6 @@ void send_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, T* #else object_cpu_point o; T* object_cpu = o.get(object, count, tmp_space); - o.sync_d2h(object_cpu, object, count); send_data(object_cpu, count, dest, tag, comm); o.del(object_cpu); #endif @@ -76,7 +76,6 @@ void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MP #else object_cpu_point o; T* object_cpu = o.get(object, count, send_space); - o.sync_d2h(object_cpu, object, count); isend_data(object_cpu, count, dest, tag, comm, request); o.del(object_cpu); #endif @@ -94,7 +93,7 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta recv_data(object, count, source, tag, comm, status); #else object_cpu_point o; - T* object_cpu = o.get(object, count, tmp_space); + T* object_cpu = o.get_buffer(object, count, tmp_space); recv_data(object_cpu, count, source, tag, comm, status); o.sync_h2d(object, object_cpu, count); o.del(object_cpu); @@ -120,10 +119,14 @@ void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nul bcast_data(object, n, comm); #else object_cpu_point o; - T* object_cpu = o.get(object, n, tmp_space); - o.sync_d2h(object_cpu, object, n); + int rank = 0; + MPI_Comm_rank(comm, &rank); + T* object_cpu = rank == 0 ? o.get(object, n, tmp_space) : o.get_buffer(object, n, tmp_space); bcast_data(object_cpu, n, comm); - o.sync_h2d(object, object_cpu, n); + if (rank != 0) + { + o.sync_h2d(object, object_cpu, n); + } o.del(object_cpu); #endif return; @@ -137,7 +140,6 @@ void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nu #else object_cpu_point o; T* object_cpu = o.get(object, n, tmp_space); - o.sync_d2h(object_cpu, object, n); reduce_data(object_cpu, n, comm); o.sync_h2d(object, object_cpu, n); o.del(object_cpu); @@ -163,8 +165,7 @@ void gatherv_dev(const T* sendbuf, MPI_Comm_size(comm, &size); int gather_space = displs[size - 1] + recvcounts[size - 1]; T* sendbuf_cpu = o1.get(sendbuf, sendcount, tmp_sspace); - T* recvbuf_cpu = o2.get(recvbuf, gather_space, tmp_rspace); - o1.sync_d2h(sendbuf_cpu, sendbuf, sendcount); + T* recvbuf_cpu = o2.get_buffer(recvbuf, gather_space, tmp_rspace); gatherv_data(sendbuf_cpu, sendcount, recvbuf_cpu, recvcounts, displs, comm); o2.sync_h2d(recvbuf, recvbuf_cpu, gather_space); o1.del(sendbuf_cpu); @@ -177,4 +178,4 @@ void gatherv_dev(const T* sendbuf, #endif -#endif \ No newline at end of file +#endif From 2a07d9b679ce2b15c687ec442200512e4fee22cb Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 11:32:50 +0800 Subject: [PATCH 02/14] Add NCCL collectives for parallel_device --- CMakeLists.txt | 14 ++ cmake/SetupNccl.cmake | 27 +++ source/source_base/parallel_device.cpp | 243 +++++++++++++++++++++++++ source/source_base/parallel_device.h | 37 ++++ 4 files changed, 321 insertions(+) create mode 100644 cmake/SetupNccl.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index 6dc0bf4b5f8..02c22b87d99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,7 @@ option(ENABLE_GOOGLEBENCH "Enable GOOGLE-benchmark usage" OFF) option(ENABLE_RAPIDJSON "Enable rapid-json usage" OFF) option(ENABLE_CNPY "Enable cnpy usage" OFF) option(ENABLE_CUSOLVERMP "Enable cusolvermp" OFF) +option(ENABLE_NCCL_PARALLEL_DEVICE "Enable NCCL-backed collectives in parallel_device" OFF) if(NOT DEFINED NVHPC_ROOT_DIR AND DEFINED ENV{NVHPC_ROOT}) set(NVHPC_ROOT_DIR @@ -451,6 +452,19 @@ if(USE_CUDA) if (USE_OPENMP AND OpenMP_CXX_FOUND) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=${OpenMP_CXX_FLAGS}" CACHE STRING "CUDA flags" FORCE) endif() + if (ENABLE_NCCL_PARALLEL_DEVICE) + if (NOT ENABLE_MPI) + message(FATAL_ERROR + "ENABLE_NCCL_PARALLEL_DEVICE requires ENABLE_MPI=ON.") + endif() + if (NOT USE_CUDA_MPI) + message(FATAL_ERROR + "ENABLE_NCCL_PARALLEL_DEVICE requires USE_CUDA_MPI=ON.") + endif() + add_compile_definitions(__NCCL_PARALLEL_DEVICE) + include(cmake/SetupNccl.cmake) + abacus_setup_nccl(${ABACUS_BIN_NAME}) + endif() if (ENABLE_CUSOLVERMP) # Keep cuSOLVERMp discovery/linking logic in a dedicated module. include(cmake/SetupCuSolverMp.cmake) diff --git a/cmake/SetupNccl.cmake b/cmake/SetupNccl.cmake new file mode 100644 index 00000000000..0ecf513b589 --- /dev/null +++ b/cmake/SetupNccl.cmake @@ -0,0 +1,27 @@ +include_guard(GLOBAL) + +function(abacus_setup_nccl target_name) + find_library(NCCL_LIBRARY NAMES nccl + HINTS ${NCCL_PATH} ${NVHPC_ROOT_DIR} + PATH_SUFFIXES lib lib64 comm_libs/nccl/lib) + find_path(NCCL_INCLUDE_DIR NAMES nccl.h + HINTS ${NCCL_PATH} ${NVHPC_ROOT_DIR} + PATHS ${CUDAToolkit_ROOT} + PATH_SUFFIXES include comm_libs/nccl/include) + + if(NOT NCCL_LIBRARY OR NOT NCCL_INCLUDE_DIR) + message(FATAL_ERROR + "NCCL not found. Set NCCL_PATH or NVHPC_ROOT_DIR.") + endif() + + message(STATUS "Found NCCL for parallel_device: ${NCCL_LIBRARY}") + if(NOT TARGET NCCL::NCCL) + add_library(NCCL::NCCL IMPORTED INTERFACE) + set_target_properties(NCCL::NCCL PROPERTIES + INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}") + endif() + + include_directories(${NCCL_INCLUDE_DIR}) + target_link_libraries(${target_name} NCCL::NCCL) +endfunction() diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index 887dde3bebf..10c5dc489e8 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -1,7 +1,250 @@ #include "parallel_device.h" + +#if defined(__MPI) && defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +#include "source_base/module_device/device_check.h" + +#include +#include +#include +#endif + #ifdef __MPI namespace Parallel_Common { +#if defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +namespace +{ +struct NcclCommContext +{ + ncclComm_t comm = nullptr; + cudaStream_t stream = nullptr; + int size = 0; +}; + +class NcclCommRegistry +{ + public: + ~NcclCommRegistry() + { + for (std::map::iterator it = contexts_.begin(); it != contexts_.end(); ++it) + { + if (it->second.stream != nullptr) + { + cudaStreamDestroy(it->second.stream); + } + if (it->second.comm != nullptr) + { + ncclCommDestroy(it->second.comm); + } + } + } + + NcclCommContext& get(MPI_Comm comm) + { + const MPI_Fint key = MPI_Comm_c2f(comm); + std::lock_guard lock(mutex_); + std::map::iterator found = contexts_.find(key); + if (found != contexts_.end()) + { + return found->second; + } + + int rank = 0; + int size = 0; + MPI_Comm_rank(comm, &rank); + MPI_Comm_size(comm, &size); + + NcclCommContext ctx; + ctx.size = size; + if (size > 1) + { + ncclUniqueId id; + if (rank == 0) + { + CHECK_NCCL(ncclGetUniqueId(&id)); + } + MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, comm); + CHECK_NCCL(ncclCommInitRank(&ctx.comm, size, id, rank)); + CHECK_CUDA(cudaStreamCreateWithFlags(&ctx.stream, cudaStreamNonBlocking)); + } + + std::pair::iterator, bool> inserted = contexts_.insert(std::make_pair(key, ctx)); + return inserted.first->second; + } + + private: + std::map contexts_; + std::mutex mutex_; +}; + +NcclCommRegistry& get_nccl_registry() +{ + static NcclCommRegistry registry; + return registry; +} + +template +void nccl_bcast_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, const int count_scale = 1) +{ + NcclCommContext& ctx = get_nccl_registry().get(comm); + if (ctx.size <= 1 || n <= 0) + { + return; + } + CHECK_NCCL(ncclBroadcast(object, object, static_cast(n) * count_scale, datatype, 0, ctx.comm, ctx.stream)); + CHECK_CUDA(cudaStreamSynchronize(ctx.stream)); +} + +template +void nccl_reduce_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, const int count_scale = 1) +{ + NcclCommContext& ctx = get_nccl_registry().get(comm); + if (ctx.size <= 1 || n <= 0) + { + return; + } + CHECK_NCCL(ncclAllReduce(object, object, static_cast(n) * count_scale, datatype, ncclSum, ctx.comm, ctx.stream)); + CHECK_CUDA(cudaStreamSynchronize(ctx.stream)); +} + +template +void nccl_gatherv_impl(const T* sendbuf, + const int sendcount, + T* recvbuf, + const int* recvcounts, + const int* displs, + MPI_Comm& comm) +{ + NcclCommContext& ctx = get_nccl_registry().get(comm); + if (ctx.size <= 1) + { + if (sendbuf != recvbuf && sendcount > 0) + { + CHECK_CUDA(cudaMemcpy(recvbuf, sendbuf, static_cast(sendcount) * sizeof(T), cudaMemcpyDeviceToDevice)); + } + return; + } + + int chunk_count = 0; + for (int i = 0; i < ctx.size; ++i) + { + if (recvcounts[i] > chunk_count) + { + chunk_count = recvcounts[i]; + } + } + if (chunk_count <= 0) + { + return; + } + + const size_t chunk_bytes = static_cast(chunk_count) * sizeof(T); + const size_t recv_bytes = chunk_bytes * ctx.size; + unsigned char* staged_send = nullptr; + unsigned char* staged_recv = nullptr; + + CHECK_CUDA(cudaMalloc(&staged_send, chunk_bytes)); + CHECK_CUDA(cudaMalloc(&staged_recv, recv_bytes)); + if (sendcount > 0) + { + CHECK_CUDA(cudaMemcpyAsync(staged_send, + sendbuf, + static_cast(sendcount) * sizeof(T), + cudaMemcpyDeviceToDevice, + ctx.stream)); + } + + CHECK_NCCL(ncclAllGather(staged_send, staged_recv, chunk_bytes, ncclUint8, ctx.comm, ctx.stream)); + + for (int i = 0; i < ctx.size; ++i) + { + if (recvcounts[i] > 0) + { + CHECK_CUDA(cudaMemcpyAsync(recvbuf + displs[i], + staged_recv + static_cast(i) * chunk_bytes, + static_cast(recvcounts[i]) * sizeof(T), + cudaMemcpyDeviceToDevice, + ctx.stream)); + } + } + + CHECK_CUDA(cudaStreamSynchronize(ctx.stream)); + CHECK_CUDA(cudaFree(staged_send)); + CHECK_CUDA(cudaFree(staged_recv)); +} +} // namespace + +void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm) +{ + nccl_bcast_impl(object, n, comm, ncclDouble); +} + +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm) +{ + nccl_bcast_impl(reinterpret_cast(object), n, comm, ncclDouble, 2); +} + +void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm) +{ + nccl_bcast_impl(object, n, comm, ncclFloat); +} + +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm) +{ + nccl_bcast_impl(reinterpret_cast(object), n, comm, ncclFloat, 2); +} + +void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm) +{ + nccl_reduce_impl(object, n, comm, ncclDouble); +} + +void nccl_reduce_data(std::complex* object, const int& n, MPI_Comm& comm) +{ + nccl_reduce_impl(reinterpret_cast(object), n, comm, ncclDouble, 2); +} + +void nccl_reduce_data(float* object, const int& n, MPI_Comm& comm) +{ + nccl_reduce_impl(object, n, comm, ncclFloat); +} + +void nccl_reduce_data(std::complex* object, const int& n, MPI_Comm& comm) +{ + nccl_reduce_impl(reinterpret_cast(object), n, comm, ncclFloat, 2); +} + +void nccl_gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); +} + +void nccl_gatherv_data(const std::complex* sendbuf, + int sendcount, + std::complex* recvbuf, + const int* recvcounts, + const int* displs, + MPI_Comm& comm) +{ + nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); +} + +void nccl_gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); +} + +void nccl_gatherv_data(const std::complex* sendbuf, + int sendcount, + std::complex* recvbuf, + const int* recvcounts, + const int* displs, + MPI_Comm& comm) +{ + nccl_gatherv_impl(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); +} +#endif + void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request) { MPI_Isend(buf, count, MPI_DOUBLE, dest, tag, comm, request); diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h index 40e325d8ac0..46beb7080b0 100644 --- a/source/source_base/parallel_device.h +++ b/source/source_base/parallel_device.h @@ -5,6 +5,7 @@ #include "source_base/module_device/device.h" #include "source_base/module_device/memory_op.h" #include +#include namespace Parallel_Common { void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); @@ -32,6 +33,21 @@ void gatherv_data(const std::complex* sendbuf, int sendcount, std::compl void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +#if defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm); +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm); +void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm); +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm); +void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm); +void nccl_reduce_data(std::complex* object, const int& n, MPI_Comm& comm); +void nccl_reduce_data(float* object, const int& n, MPI_Comm& comm); +void nccl_reduce_data(std::complex* object, const int& n, MPI_Comm& comm); +void nccl_gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void nccl_gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void nccl_gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void nccl_gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +#endif + #ifndef __CUDA_MPI template struct object_cpu_point @@ -116,6 +132,13 @@ template void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { #ifdef __CUDA_MPI +#if defined(__NCCL_PARALLEL_DEVICE) + if (std::is_same::value) + { + nccl_bcast_data(object, n, const_cast(comm)); + return; + } +#endif bcast_data(object, n, comm); #else object_cpu_point o; @@ -136,6 +159,13 @@ template void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { #ifdef __CUDA_MPI +#if defined(__NCCL_PARALLEL_DEVICE) + if (std::is_same::value) + { + nccl_reduce_data(object, n, const_cast(comm)); + return; + } +#endif reduce_data(object, n, comm); #else object_cpu_point o; @@ -158,6 +188,13 @@ void gatherv_dev(const T* sendbuf, T* tmp_rspace = nullptr) { #ifdef __CUDA_MPI +#if defined(__NCCL_PARALLEL_DEVICE) + if (std::is_same::value) + { + nccl_gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); + return; + } +#endif gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); #else object_cpu_point o1, o2; From 5f1cb8131452bcb1f8efafcfb6c37ed87cc2f694 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 11:43:07 +0800 Subject: [PATCH 03/14] Fix NCCL headers in parallel_device --- source/source_base/parallel_device.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index 10c5dc489e8..6badd60bd12 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -4,8 +4,27 @@ #include "source_base/module_device/device_check.h" #include +#include + #include #include + +#include +#include + +#ifndef CHECK_NCCL +#define CHECK_NCCL(func) \ + do \ + { \ + ncclResult_t status = (func); \ + if (status != ncclSuccess) \ + { \ + fprintf(stderr, "In File %s : NCCL API failed at line %d with error: %s (%d)\n", __FILE__, __LINE__, \ + ncclGetErrorString(status), status); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) +#endif #endif #ifdef __MPI From 356fb280373adf39ccd61a950182d22383045430 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 11:50:56 +0800 Subject: [PATCH 04/14] Route PGemm collectives through device wrappers --- source/source_base/para_gemm.cpp | 41 ++++++++++++-------------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/source/source_base/para_gemm.cpp b/source/source_base/para_gemm.cpp index edb798554cc..40913994eb4 100644 --- a/source/source_base/para_gemm.cpp +++ b/source/source_base/para_gemm.cpp @@ -277,38 +277,27 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con if (this->gatherC) { -#ifdef __CUDA_MPI - T* Clocal_mpi = C_local; - T* Cglobal_mpi = C; -#else - T* Clocal_mpi = C_tmp_.data(); - T* Cglobal_mpi = nullptr; + T* reduce_tmp = nullptr; + T* gather_tmp = nullptr; +#ifndef __CUDA_MPI if (std::is_same::value) { - syncmem_d2h_op()(Clocal_mpi, C_local, size_C_local); - Cglobal_mpi = C_global_tmp_.data(); - } - else - { - Cglobal_mpi = C; + reduce_tmp = C_tmp_.data(); + gather_tmp = C_global_tmp_.data(); } #endif if (this->row_nproc > 1) { - Parallel_Common::reduce_data(Clocal_mpi, size_C_local, row_world); + Parallel_Common::reduce_dev(C_local, size_C_local, row_world, reduce_tmp); } - Parallel_Common::gatherv_data(Clocal_mpi, - size_C_local, - Cglobal_mpi, - recv_counts.data(), - displs.data(), - col_world); -#ifndef __CUDA_MPI - if (std::is_same::value) - { - syncmem_h2d_op()(C, Cglobal_mpi, size_C_global); - } -#endif + Parallel_Common::gatherv_dev(C_local, + size_C_local, + C, + recv_counts.data(), + displs.data(), + col_world, + reduce_tmp, + gather_tmp); } else { @@ -409,4 +398,4 @@ template class PGemmCN, base_device::DEVICE_GPU>; template class PGemmCN, base_device::DEVICE_GPU>; #endif -} // namespace ModuleBase \ No newline at end of file +} // namespace ModuleBase From d7e8334157a19680aab6214fd5de789ba6ae7409 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 12:01:34 +0800 Subject: [PATCH 05/14] Tighten NCCL collective correctness --- source/source_base/parallel_device.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index 6badd60bd12..ab867ce9318 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -11,6 +11,7 @@ #include #include +#include #ifndef CHECK_NCCL #define CHECK_NCCL(func) \ @@ -47,10 +48,6 @@ class NcclCommRegistry { for (std::map::iterator it = contexts_.begin(); it != contexts_.end(); ++it) { - if (it->second.stream != nullptr) - { - cudaStreamDestroy(it->second.stream); - } if (it->second.comm != nullptr) { ncclCommDestroy(it->second.comm); @@ -84,7 +81,6 @@ class NcclCommRegistry } MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, comm); CHECK_NCCL(ncclCommInitRank(&ctx.comm, size, id, rank)); - CHECK_CUDA(cudaStreamCreateWithFlags(&ctx.stream, cudaStreamNonBlocking)); } std::pair::iterator, bool> inserted = contexts_.insert(std::make_pair(key, ctx)); @@ -145,6 +141,8 @@ void nccl_gatherv_impl(const T* sendbuf, } int chunk_count = 0; + int rank = 0; + MPI_Comm_rank(comm, &rank); for (int i = 0; i < ctx.size; ++i) { if (recvcounts[i] > chunk_count) @@ -152,6 +150,10 @@ void nccl_gatherv_impl(const T* sendbuf, chunk_count = recvcounts[i]; } } + if (recvcounts[rank] != sendcount) + { + throw std::runtime_error("nccl_gatherv_data: sendcount does not match recvcounts[rank]"); + } if (chunk_count <= 0) { return; @@ -164,6 +166,7 @@ void nccl_gatherv_impl(const T* sendbuf, CHECK_CUDA(cudaMalloc(&staged_send, chunk_bytes)); CHECK_CUDA(cudaMalloc(&staged_recv, recv_bytes)); + CHECK_CUDA(cudaMemsetAsync(staged_send, 0, chunk_bytes, ctx.stream)); if (sendcount > 0) { CHECK_CUDA(cudaMemcpyAsync(staged_send, From d1c59795691b61714aafd23b96e85a242b30acb0 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 12:18:03 +0800 Subject: [PATCH 06/14] Relax NCCL discovery for existing environments --- cmake/SetupNccl.cmake | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/cmake/SetupNccl.cmake b/cmake/SetupNccl.cmake index 0ecf513b589..31f2cb75c6c 100644 --- a/cmake/SetupNccl.cmake +++ b/cmake/SetupNccl.cmake @@ -1,5 +1,7 @@ include_guard(GLOBAL) +include(CheckIncludeFileCXX) + function(abacus_setup_nccl target_name) find_library(NCCL_LIBRARY NAMES nccl HINTS ${NCCL_PATH} ${NVHPC_ROOT_DIR} @@ -9,19 +11,36 @@ function(abacus_setup_nccl target_name) PATHS ${CUDAToolkit_ROOT} PATH_SUFFIXES include comm_libs/nccl/include) - if(NOT NCCL_LIBRARY OR NOT NCCL_INCLUDE_DIR) + check_include_file_cxx("nccl.h" HAVE_NCCL_HEADER) + + if(NOT NCCL_LIBRARY) + set(NCCL_LIBRARY nccl) + endif() + + if(NOT NCCL_INCLUDE_DIR AND NOT HAVE_NCCL_HEADER) message(FATAL_ERROR "NCCL not found. Set NCCL_PATH or NVHPC_ROOT_DIR.") endif() - message(STATUS "Found NCCL for parallel_device: ${NCCL_LIBRARY}") + if(NCCL_INCLUDE_DIR) + message(STATUS "Found NCCL for parallel_device: ${NCCL_LIBRARY}") + else() + message(STATUS "Using default compiler/linker search paths for NCCL: ${NCCL_LIBRARY}") + endif() if(NOT TARGET NCCL::NCCL) add_library(NCCL::NCCL IMPORTED INTERFACE) - set_target_properties(NCCL::NCCL PROPERTIES - INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}" - INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}") + if(NCCL_INCLUDE_DIR) + set_target_properties(NCCL::NCCL PROPERTIES + INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}") + else() + set_target_properties(NCCL::NCCL PROPERTIES + INTERFACE_LINK_LIBRARIES "${NCCL_LIBRARY}") + endif() endif() - include_directories(${NCCL_INCLUDE_DIR}) + if(NCCL_INCLUDE_DIR) + target_include_directories(${target_name} PRIVATE ${NCCL_INCLUDE_DIR}) + endif() target_link_libraries(${target_name} NCCL::NCCL) endfunction() From bb01cefc26e66ce796e3ed2f52fa7ac6d1e23f33 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 17:46:49 +0800 Subject: [PATCH 07/14] Decouple NCCL parallel_device from CUDA-aware MPI --- CMakeLists.txt | 4 ---- source/source_base/para_gemm.cpp | 6 +++--- source/source_base/parallel_device.cpp | 4 ++-- source/source_base/parallel_device.h | 8 ++++---- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 02c22b87d99..707f8caaca3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -457,10 +457,6 @@ if(USE_CUDA) message(FATAL_ERROR "ENABLE_NCCL_PARALLEL_DEVICE requires ENABLE_MPI=ON.") endif() - if (NOT USE_CUDA_MPI) - message(FATAL_ERROR - "ENABLE_NCCL_PARALLEL_DEVICE requires USE_CUDA_MPI=ON.") - endif() add_compile_definitions(__NCCL_PARALLEL_DEVICE) include(cmake/SetupNccl.cmake) abacus_setup_nccl(${ABACUS_BIN_NAME}) diff --git a/source/source_base/para_gemm.cpp b/source/source_base/para_gemm.cpp index 40913994eb4..36d99d53b80 100644 --- a/source/source_base/para_gemm.cpp +++ b/source/source_base/para_gemm.cpp @@ -105,7 +105,7 @@ void PGemmCN::set_dimension( if (std::is_same::value) { resmem_dev_op()(A_tmp_device_, max_colA * LDA); -#ifndef __CUDA_MPI +#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) isend_tmp_.resize(max_colA * LDA); #endif } @@ -133,7 +133,7 @@ void PGemmCN::set_dimension( if (std::is_same::value) { resmem_dev_op()(C_local_tmp_, size_C_local); -#ifndef __CUDA_MPI +#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) C_global_tmp_.resize(size_C_global); #endif } @@ -279,7 +279,7 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con { T* reduce_tmp = nullptr; T* gather_tmp = nullptr; -#ifndef __CUDA_MPI +#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { reduce_tmp = C_tmp_.data(); diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index ab867ce9318..d8e9b690903 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -1,6 +1,6 @@ #include "parallel_device.h" -#if defined(__MPI) && defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +#if defined(__MPI) && defined(__NCCL_PARALLEL_DEVICE) #include "source_base/module_device/device_check.h" #include @@ -31,7 +31,7 @@ #ifdef __MPI namespace Parallel_Common { -#if defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +#if defined(__NCCL_PARALLEL_DEVICE) namespace { struct NcclCommContext diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h index 46beb7080b0..a148226e9fd 100644 --- a/source/source_base/parallel_device.h +++ b/source/source_base/parallel_device.h @@ -33,7 +33,7 @@ void gatherv_data(const std::complex* sendbuf, int sendcount, std::compl void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); -#if defined(__CUDA_MPI) && defined(__NCCL_PARALLEL_DEVICE) +#if defined(__NCCL_PARALLEL_DEVICE) void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm); void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm); void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm); @@ -131,7 +131,6 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta template void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { -#ifdef __CUDA_MPI #if defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { @@ -139,6 +138,7 @@ void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nul return; } #endif +#ifdef __CUDA_MPI bcast_data(object, n, comm); #else object_cpu_point o; @@ -158,7 +158,6 @@ void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nul template void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { -#ifdef __CUDA_MPI #if defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { @@ -166,6 +165,7 @@ void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nu return; } #endif +#ifdef __CUDA_MPI reduce_data(object, n, comm); #else object_cpu_point o; @@ -187,7 +187,6 @@ void gatherv_dev(const T* sendbuf, T* tmp_sspace = nullptr, T* tmp_rspace = nullptr) { -#ifdef __CUDA_MPI #if defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { @@ -195,6 +194,7 @@ void gatherv_dev(const T* sendbuf, return; } #endif +#ifdef __CUDA_MPI gatherv_data(sendbuf, sendcount, recvbuf, recvcounts, displs, comm); #else object_cpu_point o1, o2; From aaaaad9e60b8eb41a6ac76966563199474bf85b7 Mon Sep 17 00:00:00 2001 From: someone Date: Thu, 30 Apr 2026 18:19:07 +0800 Subject: [PATCH 08/14] Propagate NCCL headers to subdirectory targets --- cmake/SetupNccl.cmake | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmake/SetupNccl.cmake b/cmake/SetupNccl.cmake index 31f2cb75c6c..56e8e10e7b0 100644 --- a/cmake/SetupNccl.cmake +++ b/cmake/SetupNccl.cmake @@ -40,6 +40,9 @@ function(abacus_setup_nccl target_name) endif() if(NCCL_INCLUDE_DIR) + # `parallel_device.cpp` is compiled inside the later `base` OBJECT library, + # so the header path must also be visible to targets created in subdirs. + include_directories(${NCCL_INCLUDE_DIR}) target_include_directories(${target_name} PRIVATE ${NCCL_INCLUDE_DIR}) endif() target_link_libraries(${target_name} NCCL::NCCL) From 7c34f158f3bd998c9f05e4a76923ad8a80e40cde Mon Sep 17 00:00:00 2001 From: someone Date: Mon, 4 May 2026 12:19:53 +0800 Subject: [PATCH 09/14] Fix: narrow CPU staging guards in para_gemm to respect NCCL collectives MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit isend_dev has no NCCL path — keep guard as #ifndef __CUDA_MPI. reduce_dev / gatherv_dev have NCCL early-returns — exclude CPU staging when __NCCL_PARALLEL_DEVICE is defined (&& !defined). --- source/source_base/para_gemm.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/source_base/para_gemm.cpp b/source/source_base/para_gemm.cpp index 36d99d53b80..3e56aa83ac2 100644 --- a/source/source_base/para_gemm.cpp +++ b/source/source_base/para_gemm.cpp @@ -105,7 +105,7 @@ void PGemmCN::set_dimension( if (std::is_same::value) { resmem_dev_op()(A_tmp_device_, max_colA * LDA); -#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) +#ifndef __CUDA_MPI isend_tmp_.resize(max_colA * LDA); #endif } @@ -133,7 +133,7 @@ void PGemmCN::set_dimension( if (std::is_same::value) { resmem_dev_op()(C_local_tmp_, size_C_local); -#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) +#if !defined(__CUDA_MPI) && !defined(__NCCL_PARALLEL_DEVICE) C_global_tmp_.resize(size_C_global); #endif } @@ -279,7 +279,7 @@ void PGemmCN::multiply_col(const T alpha, const T* A, const T* B, con { T* reduce_tmp = nullptr; T* gather_tmp = nullptr; -#if !defined(__CUDA_MPI) || defined(__NCCL_PARALLEL_DEVICE) +#if !defined(__CUDA_MPI) && !defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { reduce_tmp = C_tmp_.data(); From d2ca43270c7b205bfb0186b697eed9eab01c781b Mon Sep 17 00:00:00 2001 From: someone Date: Wed, 6 May 2026 10:42:26 +0800 Subject: [PATCH 10/14] Refactor(exx_pw): unify GPU/CPU bcast via Parallel_Common::bcast_dev Add configurable root parameter to bcast_data/nccl_bcast_data/bcast_dev (default root=0 for backward compat). Replace manual MPI_Bcast with GPU/CPU branching in EXX PW operator with unified bcast_dev/reduce_dev. --- source/source_base/parallel_device.cpp | 36 +++++++++---------- source/source_base/parallel_device.h | 28 +++++++-------- source/source_pw/module_pwdft/op_pw_exx.cpp | 23 ++---------- .../source_pw/module_pwdft/op_pw_exx_ace.cpp | 30 +++------------- 4 files changed, 38 insertions(+), 79 deletions(-) diff --git a/source/source_base/parallel_device.cpp b/source/source_base/parallel_device.cpp index ad43e78ed5a..8ba8c3ae8ed 100644 --- a/source/source_base/parallel_device.cpp +++ b/source/source_base/parallel_device.cpp @@ -86,14 +86,14 @@ NcclCommRegistry& get_nccl_registry() } template -void nccl_bcast_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, const int count_scale = 1) +void nccl_bcast_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, int root = 0, const int count_scale = 1) { NcclCommContext& ctx = get_nccl_registry().get(comm); if (ctx.size <= 1 || n <= 0) { return; } - CHECK_NCCL(ncclBroadcast(object, object, static_cast(n) * count_scale, datatype, 0, ctx.comm, ctx.stream)); + CHECK_NCCL(ncclBroadcast(object, object, static_cast(n) * count_scale, datatype, root, ctx.comm, ctx.stream)); CHECK_CUDA(cudaStreamSynchronize(ctx.stream)); } @@ -183,24 +183,24 @@ void nccl_gatherv_impl(const T* sendbuf, } } // namespace -void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm) +void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm, int root) { - nccl_bcast_impl(object, n, comm, ncclDouble); + nccl_bcast_impl(object, n, comm, ncclDouble, root); } -void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm) +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm, int root) { - nccl_bcast_impl(reinterpret_cast(object), n, comm, ncclDouble, 2); + nccl_bcast_impl(reinterpret_cast(object), n, comm, ncclDouble, root, 2); } -void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm) +void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm, int root) { - nccl_bcast_impl(object, n, comm, ncclFloat); + nccl_bcast_impl(object, n, comm, ncclFloat, root); } -void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm) +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm, int root) { - nccl_bcast_impl(reinterpret_cast(object), n, comm, ncclFloat, 2); + nccl_bcast_impl(reinterpret_cast(object), n, comm, ncclFloat, root, 2); } void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm) @@ -302,21 +302,21 @@ void recv_data(std::complex* buf, int count, int source, int tag, MPI_Com { MPI_Recv(buf, count, MPI_COMPLEX, source, tag, comm, status); } -void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm) +void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm, int root) { - MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm); + MPI_Bcast(object, n * 2, MPI_DOUBLE, root, comm); } -void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm) +void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm, int root) { - MPI_Bcast(object, n * 2, MPI_FLOAT, 0, comm); + MPI_Bcast(object, n * 2, MPI_FLOAT, root, comm); } -void bcast_data(double* object, const int& n, const MPI_Comm& comm) +void bcast_data(double* object, const int& n, const MPI_Comm& comm, int root) { - MPI_Bcast(object, n, MPI_DOUBLE, 0, comm); + MPI_Bcast(object, n, MPI_DOUBLE, root, comm); } -void bcast_data(float* object, const int& n, const MPI_Comm& comm) +void bcast_data(float* object, const int& n, const MPI_Comm& comm, int root) { - MPI_Bcast(object, n, MPI_FLOAT, 0, comm); + MPI_Bcast(object, n, MPI_FLOAT, root, comm); } void reduce_data(std::complex* object, const int& n, const MPI_Comm& comm) { diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h index 264a0beb89f..26fb8f50687 100644 --- a/source/source_base/parallel_device.h +++ b/source/source_base/parallel_device.h @@ -18,10 +18,10 @@ void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_ void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); -void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm); -void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm); -void bcast_data(double* object, const int& n, const MPI_Comm& comm); -void bcast_data(float* object, const int& n, const MPI_Comm& comm); +void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm, int root = 0); +void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm, int root = 0); +void bcast_data(double* object, const int& n, const MPI_Comm& comm, int root = 0); +void bcast_data(float* object, const int& n, const MPI_Comm& comm, int root = 0); void reduce_data(std::complex* object, const int& n, const MPI_Comm& comm); void reduce_data(std::complex* object, const int& n, const MPI_Comm& comm); void reduce_data(double* object, const int& n, const MPI_Comm& comm); @@ -32,10 +32,10 @@ void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); #if defined(__NCCL_PARALLEL_DEVICE) -void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm); -void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm); -void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm); -void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm); +void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm, int root = 0); +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm, int root = 0); +void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm, int root = 0); +void nccl_bcast_data(std::complex* object, const int& n, MPI_Comm& comm, int root = 0); void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm); void nccl_reduce_data(std::complex* object, const int& n, MPI_Comm& comm); void nccl_reduce_data(float* object, const int& n, MPI_Comm& comm); @@ -127,24 +127,24 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta * @param tmp_space tmp space in CPU */ template -void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) +void bcast_dev(T* object, const int& n, const MPI_Comm& comm, int root = 0, T* tmp_space = nullptr) { #if defined(__NCCL_PARALLEL_DEVICE) if (std::is_same::value) { - nccl_bcast_data(object, n, const_cast(comm)); + nccl_bcast_data(object, n, const_cast(comm), root); return; } #endif #ifdef __CUDA_MPI - bcast_data(object, n, comm); + bcast_data(object, n, comm, root); #else object_cpu_point o; int rank = 0; MPI_Comm_rank(comm, &rank); - T* object_cpu = rank == 0 ? o.get(object, n, tmp_space) : o.get_buffer(object, n, tmp_space); - bcast_data(object_cpu, n, comm); - if (rank != 0) + T* object_cpu = rank == root ? o.get(object, n, tmp_space) : o.get_buffer(object, n, tmp_space); + bcast_data(object_cpu, n, comm, root); + if (rank != root) { o.sync_h2d(object, object_cpu, n); } diff --git a/source/source_pw/module_pwdft/op_pw_exx.cpp b/source/source_pw/module_pwdft/op_pw_exx.cpp index f38624479a7..84f09cc8f65 100644 --- a/source/source_pw/module_pwdft/op_pw_exx.cpp +++ b/source/source_pw/module_pwdft/op_pw_exx.cpp @@ -3,6 +3,7 @@ #include "source_base/constants.h" #include "source_base/global_variable.h" #include "source_base/parallel_common.h" +#include "source_base/parallel_device.h" #include "source_base/parallel_comm.h" // use KP_WORLD #include "source_base/parallel_reduce.h" #include "source_base/module_external/lapack_connector.h" @@ -350,27 +351,7 @@ void OperatorEXXPW::act_op_kpar(const int nbands, // send } #ifdef __MPI -#ifdef __CUDA_MPI - MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD); -#else - if (PARAM.inp.device == "cpu") - { - MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD); - } - else if (PARAM.inp.device == "gpu") - { - // need to copy to cpu first - T* psi_mq_real_cpu = new T[wfcpw->nrxx]; - syncmem_complex_d2c_op()(psi_mq_real_cpu, psi_mq_real, wfcpw->nrxx); - MPI_Bcast(psi_mq_real_cpu, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD); - syncmem_complex_c2d_op()(psi_mq_real, psi_mq_real_cpu, wfcpw->nrxx); - delete[] psi_mq_real_cpu; - } - else - { - ModuleBase::WARNING_QUIT("OperatorEXXPW", "construct_ace: unknown device"); - } -#endif + Parallel_Common::bcast_dev(psi_mq_real, wfcpw->nrxx, KP_WORLD, iq_pool); #endif for (int n_iband = 0; n_iband < nbands; n_iband++) { diff --git a/source/source_pw/module_pwdft/op_pw_exx_ace.cpp b/source/source_pw/module_pwdft/op_pw_exx_ace.cpp index 4a47b2ac030..367d46f2f99 100644 --- a/source/source_pw/module_pwdft/op_pw_exx_ace.cpp +++ b/source/source_pw/module_pwdft/op_pw_exx_ace.cpp @@ -1,5 +1,6 @@ #include "op_pw_exx.h" #include "source_base/parallel_comm.h" +#include "source_base/parallel_device.h" #include "source_base/parallel_reduce.h" #include "source_io/module_parameter/parameter.h" #include "source_hamilt/module_xc/exx_info.h" @@ -46,7 +47,7 @@ void OperatorEXXPW::act_op_ace(const int nbands, nbands_tot ); - Parallel_Reduce::reduce_pool(Xi_psi, nbands_tot * nbands); + Parallel_Common::reduce_dev(Xi_psi, nbands_tot * nbands, POOL_WORLD); // Xi^\dagger * (Xi * psi) gemm_complex_op()(trans_C, @@ -179,32 +180,9 @@ void OperatorEXXPW::construct_ace() const { const T* psi_mq = get_pw(m_iband, iq_loc); wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq_loc); - // send } - // if (iq == 0) - // std::cout << "Bcast psi_mq_real" << std::endl; #ifdef __MPI -#ifdef __CUDA_MPI - MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD); -#else - if (PARAM.inp.device == "cpu") - { - MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD); - } - else if (PARAM.inp.device == "gpu") - { - // need to copy to cpu first - T* psi_mq_real_cpu = new T[wfcpw->nrxx]; - syncmem_complex_d2c_op()(psi_mq_real_cpu, psi_mq_real, wfcpw->nrxx); - MPI_Bcast(psi_mq_real_cpu, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD); - syncmem_complex_c2d_op()(psi_mq_real, psi_mq_real_cpu, wfcpw->nrxx); - delete[] psi_mq_real_cpu; - } - else - { - ModuleBase::WARNING_QUIT("OperatorEXXPW", "construct_ace: unknown device"); - } -#endif + Parallel_Common::bcast_dev(psi_mq_real, wfcpw->nrxx, KP_WORLD, iq_pool); #endif } // end of iq @@ -232,7 +210,7 @@ void OperatorEXXPW::construct_ace() const nbands); // reduction of psi_h_psi_ace, due to distributed memory - Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands); + Parallel_Common::reduce_dev(psi_h_psi_ace, nbands * nbands, POOL_WORLD); T intermediate_minus_one = -1.0; axpy_complex_op()(nbands * nbands, From 7412273257eb79d438655900e3ab4785c902a616 Mon Sep 17 00:00:00 2001 From: someone Date: Wed, 6 May 2026 10:53:06 +0800 Subject: [PATCH 11/14] Fix(sdft): update bcast_dev call with root parameter --- source/source_hsolver/hsolver_pw_sdft.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/source_hsolver/hsolver_pw_sdft.cpp b/source/source_hsolver/hsolver_pw_sdft.cpp index 3840239d758..f3c3d2f66a3 100644 --- a/source/source_hsolver/hsolver_pw_sdft.cpp +++ b/source/source_hsolver/hsolver_pw_sdft.cpp @@ -61,7 +61,7 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, #ifdef __MPI if (nbands > 0 && !PARAM.globalv.all_ks_run) { - Parallel_Common::bcast_dev(&psi(ik, 0, 0), npwx * nbands, BP_WORLD, &psi_cpu(ik, 0, 0)); + Parallel_Common::bcast_dev(&psi(ik, 0, 0), npwx * nbands, BP_WORLD, 0, &psi_cpu(ik, 0, 0)); MPI_Bcast(&pes->ekb(ik, 0), nbands, MPI_DOUBLE, 0, BP_WORLD); } #endif From e1438ea0cd9dc1d9614c4e77b2f74746a3a7e9e5 Mon Sep 17 00:00:00 2001 From: someone Date: Wed, 6 May 2026 10:55:22 +0800 Subject: [PATCH 12/14] Fix: add missing include for base_device namespace in parallel_device.h --- source/source_base/parallel_device.h | 1 + 1 file changed, 1 insertion(+) diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h index 26fb8f50687..6f375c4b5e0 100644 --- a/source/source_base/parallel_device.h +++ b/source/source_base/parallel_device.h @@ -4,6 +4,7 @@ #include "mpi.h" #include #include +#include "source_base/module_device/types.h" namespace Parallel_Common { void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); From 2b0e6ab5166f0631d072ae8cf6d172574100270b Mon Sep 17 00:00:00 2001 From: someone Date: Wed, 6 May 2026 16:15:10 +0800 Subject: [PATCH 13/14] Fix: ELF on GPU and MPI-disabled compile errors - ELF GPU: use static_cast with Device-typed psi to bypass virtual dispatch mismatch where DEVICE_GPU cal_tau() did not override base class DEVICE_CPU cal_tau(). Meta-GGA path was unaffected because tau is computed in psiToRho during SCF. - ACE EXX: guard Parallel_Common::reduce_dev calls with __MPI since POOL_WORLD and reduce_dev are only declared when MPI is enabled. --- source/source_pw/module_pwdft/op_pw_exx_ace.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/source_pw/module_pwdft/op_pw_exx_ace.cpp b/source/source_pw/module_pwdft/op_pw_exx_ace.cpp index 367d46f2f99..8685b355f08 100644 --- a/source/source_pw/module_pwdft/op_pw_exx_ace.cpp +++ b/source/source_pw/module_pwdft/op_pw_exx_ace.cpp @@ -47,7 +47,9 @@ void OperatorEXXPW::act_op_ace(const int nbands, nbands_tot ); +#ifdef __MPI Parallel_Common::reduce_dev(Xi_psi, nbands_tot * nbands, POOL_WORLD); +#endif // Xi^\dagger * (Xi * psi) gemm_complex_op()(trans_C, @@ -210,7 +212,9 @@ void OperatorEXXPW::construct_ace() const nbands); // reduction of psi_h_psi_ace, due to distributed memory +#ifdef __MPI Parallel_Common::reduce_dev(psi_h_psi_ace, nbands * nbands, POOL_WORLD); +#endif T intermediate_minus_one = -1.0; axpy_complex_op()(nbands * nbands, From 1a2e27c236875682aa3c6db22e5dcc2d3810b764 Mon Sep 17 00:00:00 2001 From: someone Date: Wed, 6 May 2026 16:54:37 +0800 Subject: [PATCH 14/14] Docs: fix bcast_dev doxygen to match actual signature --- source/source_base/parallel_device.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/source/source_base/parallel_device.h b/source/source_base/parallel_device.h index 6f375c4b5e0..7826e4b4653 100644 --- a/source/source_base/parallel_device.h +++ b/source/source_base/parallel_device.h @@ -117,15 +117,15 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta } /** - * @brief bcast data in Device + * @brief broadcast data in Device * * @tparam T: float, double, std::complex, std::complex * @tparam Device - * @param ctx Device ctx - * @param object complex arrays in Device - * @param n the size of complex arrays + * @param object arrays in Device + * @param n the size of array * @param comm MPI_Comm - * @param tmp_space tmp space in CPU + * @param root root rank (default 0) + * @param tmp_space optional tmp space in CPU (default nullptr) */ template void bcast_dev(T* object, const int& n, const MPI_Comm& comm, int root = 0, T* tmp_space = nullptr)