Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
6d81876
Harden GPU MPI staging helpers
Flying-dragon-boxing Apr 30, 2026
2a07d9b
Add NCCL collectives for parallel_device
Flying-dragon-boxing Apr 30, 2026
5f1cb81
Fix NCCL headers in parallel_device
Flying-dragon-boxing Apr 30, 2026
356fb28
Route PGemm collectives through device wrappers
Flying-dragon-boxing Apr 30, 2026
d7e8334
Tighten NCCL collective correctness
Flying-dragon-boxing Apr 30, 2026
d1c5979
Relax NCCL discovery for existing environments
Flying-dragon-boxing Apr 30, 2026
bb01cef
Decouple NCCL parallel_device from CUDA-aware MPI
Flying-dragon-boxing Apr 30, 2026
aaaaad9
Propagate NCCL headers to subdirectory targets
Flying-dragon-boxing Apr 30, 2026
cf0b949
Merge branch 'develop' into fix-gpu-mpi-staging-comm
Flying-dragon-boxing Apr 30, 2026
7c34f15
Fix: narrow CPU staging guards in para_gemm to respect NCCL collectives
Flying-dragon-boxing May 4, 2026
9c5eed4
Merge branch 'develop' into fix-gpu-mpi-staging-comm
Flying-dragon-boxing May 4, 2026
6c7559d
Merge branch 'develop' into fix-gpu-mpi-staging-comm
Flying-dragon-boxing May 6, 2026
d2ca432
Refactor(exx_pw): unify GPU/CPU bcast via Parallel_Common::bcast_dev
Flying-dragon-boxing May 6, 2026
7412273
Fix(sdft): update bcast_dev call with root parameter
Flying-dragon-boxing May 6, 2026
e1438ea
Fix: add missing include for base_device namespace in parallel_device.h
Flying-dragon-boxing May 6, 2026
a156bf8
Merge branch 'develop' into 260508-exxdev
Flying-dragon-boxing May 6, 2026
2b0e6ab
Fix: ELF on GPU and MPI-disabled compile errors
Flying-dragon-boxing May 6, 2026
6e98001
Merge branch 'develop' into 260508-exxdev
Flying-dragon-boxing May 6, 2026
1a2e27c
Docs: fix bcast_dev doxygen to match actual signature
Flying-dragon-boxing May 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions source/source_base/parallel_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ NcclCommRegistry& get_nccl_registry()
}

template <typename T>
void nccl_bcast_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, const int count_scale = 1)
void nccl_bcast_impl(T* object, const int n, MPI_Comm& comm, ncclDataType_t datatype, int root = 0, const int count_scale = 1)
{
NcclCommContext& ctx = get_nccl_registry().get(comm);
if (ctx.size <= 1 || n <= 0)
{
return;
}
CHECK_NCCL(ncclBroadcast(object, object, static_cast<size_t>(n) * count_scale, datatype, 0, ctx.comm, ctx.stream));
CHECK_NCCL(ncclBroadcast(object, object, static_cast<size_t>(n) * count_scale, datatype, root, ctx.comm, ctx.stream));
CHECK_CUDA(cudaStreamSynchronize(ctx.stream));
}

Expand Down Expand Up @@ -183,24 +183,24 @@ void nccl_gatherv_impl(const T* sendbuf,
}
} // namespace

void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm)
void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm, int root)
{
nccl_bcast_impl(object, n, comm, ncclDouble);
nccl_bcast_impl(object, n, comm, ncclDouble, root);
}

void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm)
void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm, int root)
{
nccl_bcast_impl(reinterpret_cast<double*>(object), n, comm, ncclDouble, 2);
nccl_bcast_impl(reinterpret_cast<double*>(object), n, comm, ncclDouble, root, 2);
}

void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm)
void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm, int root)
{
nccl_bcast_impl(object, n, comm, ncclFloat);
nccl_bcast_impl(object, n, comm, ncclFloat, root);
}

void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm)
void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm, int root)
{
nccl_bcast_impl(reinterpret_cast<float*>(object), n, comm, ncclFloat, 2);
nccl_bcast_impl(reinterpret_cast<float*>(object), n, comm, ncclFloat, root, 2);
}

void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm)
Expand Down Expand Up @@ -302,21 +302,21 @@ void recv_data(std::complex<float>* buf, int count, int source, int tag, MPI_Com
{
MPI_Recv(buf, count, MPI_COMPLEX, source, tag, comm, status);
}
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm)
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm, int root)
{
MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm);
MPI_Bcast(object, n * 2, MPI_DOUBLE, root, comm);
}
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm)
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm, int root)
{
MPI_Bcast(object, n * 2, MPI_FLOAT, 0, comm);
MPI_Bcast(object, n * 2, MPI_FLOAT, root, comm);
}
void bcast_data(double* object, const int& n, const MPI_Comm& comm)
void bcast_data(double* object, const int& n, const MPI_Comm& comm, int root)
{
MPI_Bcast(object, n, MPI_DOUBLE, 0, comm);
MPI_Bcast(object, n, MPI_DOUBLE, root, comm);
}
void bcast_data(float* object, const int& n, const MPI_Comm& comm)
void bcast_data(float* object, const int& n, const MPI_Comm& comm, int root)
{
MPI_Bcast(object, n, MPI_FLOAT, 0, comm);
MPI_Bcast(object, n, MPI_FLOAT, root, comm);
}
void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm)
{
Expand Down
39 changes: 20 additions & 19 deletions source/source_base/parallel_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mpi.h"
#include <complex>
#include <type_traits>
#include "source_base/module_device/types.h"
namespace Parallel_Common
{
void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request);
Expand All @@ -18,10 +19,10 @@ void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_
void recv_data(std::complex<double>* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
void recv_data(std::complex<float>* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status);
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
void bcast_data(double* object, const int& n, const MPI_Comm& comm);
void bcast_data(float* object, const int& n, const MPI_Comm& comm);
void bcast_data(std::complex<double>* object, const int& n, const MPI_Comm& comm, int root = 0);
void bcast_data(std::complex<float>* object, const int& n, const MPI_Comm& comm, int root = 0);
void bcast_data(double* object, const int& n, const MPI_Comm& comm, int root = 0);
void bcast_data(float* object, const int& n, const MPI_Comm& comm, int root = 0);
void reduce_data(std::complex<double>* object, const int& n, const MPI_Comm& comm);
void reduce_data(std::complex<float>* object, const int& n, const MPI_Comm& comm);
void reduce_data(double* object, const int& n, const MPI_Comm& comm);
Expand All @@ -32,10 +33,10 @@ void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int
void gatherv_data(const std::complex<float>* sendbuf, int sendcount, std::complex<float>* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm);

#if defined(__NCCL_PARALLEL_DEVICE)
void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm);
void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm);
void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm);
void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm);
void nccl_bcast_data(double* object, const int& n, MPI_Comm& comm, int root = 0);
void nccl_bcast_data(std::complex<double>* object, const int& n, MPI_Comm& comm, int root = 0);
void nccl_bcast_data(float* object, const int& n, MPI_Comm& comm, int root = 0);
void nccl_bcast_data(std::complex<float>* object, const int& n, MPI_Comm& comm, int root = 0);
void nccl_reduce_data(double* object, const int& n, MPI_Comm& comm);
void nccl_reduce_data(std::complex<double>* object, const int& n, MPI_Comm& comm);
void nccl_reduce_data(float* object, const int& n, MPI_Comm& comm);
Expand Down Expand Up @@ -116,35 +117,35 @@ void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Sta
}

/**
* @brief bcast data in Device
* @brief broadcast data in Device
*
* @tparam T: float, double, std::complex<float>, std::complex<double>
* @tparam Device
* @param ctx Device ctx
* @param object complex arrays in Device
* @param n the size of complex arrays
* @param object arrays in Device
* @param n number of elements of type T in the array
* @param comm MPI_Comm
* @param tmp_space tmp space in CPU
* @param root rank within comm whose data is broadcast to all other ranks (default 0)
* @param tmp_space optional tmp space in CPU (default nullptr)
*/
template <typename T, typename Device>
void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr)
void bcast_dev(T* object, const int& n, const MPI_Comm& comm, int root = 0, T* tmp_space = nullptr)
{
#if defined(__NCCL_PARALLEL_DEVICE)
if (std::is_same<Device, base_device::DEVICE_GPU>::value)
{
nccl_bcast_data(object, n, const_cast<MPI_Comm&>(comm));
nccl_bcast_data(object, n, const_cast<MPI_Comm&>(comm), root);
return;
}
#endif
#ifdef __CUDA_MPI
bcast_data(object, n, comm);
bcast_data(object, n, comm, root);
#else
object_cpu_point<T,Device> o;
int rank = 0;
MPI_Comm_rank(comm, &rank);
T* object_cpu = rank == 0 ? o.get(object, n, tmp_space) : o.get_buffer(object, n, tmp_space);
bcast_data(object_cpu, n, comm);
if (rank != 0)
T* object_cpu = rank == root ? o.get(object, n, tmp_space) : o.get_buffer(object, n, tmp_space);
bcast_data(object_cpu, n, comm, root);
if (rank != root)
{
o.sync_h2d(object, object_cpu, n);
}
Expand Down
2 changes: 1 addition & 1 deletion source/source_hsolver/hsolver_pw_sdft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void HSolverPW_SDFT<T, Device>::solve(const UnitCell& ucell,
#ifdef __MPI
if (nbands > 0 && !PARAM.globalv.all_ks_run)
{
Parallel_Common::bcast_dev<T,Device>(&psi(ik, 0, 0), npwx * nbands, BP_WORLD, &psi_cpu(ik, 0, 0));
Parallel_Common::bcast_dev<T,Device>(&psi(ik, 0, 0), npwx * nbands, BP_WORLD, 0, &psi_cpu(ik, 0, 0));
MPI_Bcast(&pes->ekb(ik, 0), nbands, MPI_DOUBLE, 0, BP_WORLD);
}
#endif
Expand Down
23 changes: 2 additions & 21 deletions source/source_pw/module_pwdft/op_pw_exx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "source_base/constants.h"
#include "source_base/global_variable.h"
#include "source_base/parallel_common.h"
#include "source_base/parallel_device.h"
#include "source_base/parallel_comm.h" // use KP_WORLD
#include "source_base/parallel_reduce.h"
#include "source_base/module_external/lapack_connector.h"
Expand Down Expand Up @@ -350,27 +351,7 @@ void OperatorEXXPW<T, Device>::act_op_kpar(const int nbands,
// send
}
#ifdef __MPI
#ifdef __CUDA_MPI
MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
#else
if (PARAM.inp.device == "cpu")
{
MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
}
else if (PARAM.inp.device == "gpu")
{
// need to copy to cpu first
T* psi_mq_real_cpu = new T[wfcpw->nrxx];
syncmem_complex_d2c_op()(psi_mq_real_cpu, psi_mq_real, wfcpw->nrxx);
MPI_Bcast(psi_mq_real_cpu, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
syncmem_complex_c2d_op()(psi_mq_real, psi_mq_real_cpu, wfcpw->nrxx);
delete[] psi_mq_real_cpu;
}
else
{
ModuleBase::WARNING_QUIT("OperatorEXXPW", "construct_ace: unknown device");
}
#endif
Parallel_Common::bcast_dev<T, Device>(psi_mq_real, wfcpw->nrxx, KP_WORLD, iq_pool);
Comment thread
Flying-dragon-boxing marked this conversation as resolved.
#endif
for (int n_iband = 0; n_iband < nbands; n_iband++)
{
Expand Down
34 changes: 8 additions & 26 deletions source/source_pw/module_pwdft/op_pw_exx_ace.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "op_pw_exx.h"
#include "source_base/parallel_comm.h"
#include "source_base/parallel_device.h"
#include "source_base/parallel_reduce.h"
#include "source_io/module_parameter/parameter.h"
#include "source_hamilt/module_xc/exx_info.h"
Expand Down Expand Up @@ -46,7 +47,9 @@ void OperatorEXXPW<T, Device>::act_op_ace(const int nbands,
nbands_tot
);

Parallel_Reduce::reduce_pool(Xi_psi, nbands_tot * nbands);
#ifdef __MPI
Parallel_Common::reduce_dev<T, Device>(Xi_psi, nbands_tot * nbands, POOL_WORLD);
#endif

// Xi^\dagger * (Xi * psi)
gemm_complex_op()(trans_C,
Expand Down Expand Up @@ -179,32 +182,9 @@ void OperatorEXXPW<T, Device>::construct_ace() const
{
const T* psi_mq = get_pw(m_iband, iq_loc);
wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq_loc);
// send
}
// if (iq == 0)
// std::cout << "Bcast psi_mq_real" << std::endl;
#ifdef __MPI
#ifdef __CUDA_MPI
MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
#else
if (PARAM.inp.device == "cpu")
{
MPI_Bcast(psi_mq_real, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
}
else if (PARAM.inp.device == "gpu")
{
// need to copy to cpu first
T* psi_mq_real_cpu = new T[wfcpw->nrxx];
syncmem_complex_d2c_op()(psi_mq_real_cpu, psi_mq_real, wfcpw->nrxx);
MPI_Bcast(psi_mq_real_cpu, wfcpw->nrxx, MPI_DOUBLE_COMPLEX, iq_pool, KP_WORLD);
syncmem_complex_c2d_op()(psi_mq_real, psi_mq_real_cpu, wfcpw->nrxx);
delete[] psi_mq_real_cpu;
}
else
{
ModuleBase::WARNING_QUIT("OperatorEXXPW", "construct_ace: unknown device");
}
#endif
Parallel_Common::bcast_dev<T, Device>(psi_mq_real, wfcpw->nrxx, KP_WORLD, iq_pool);
Comment thread
Flying-dragon-boxing marked this conversation as resolved.
#endif

} // end of iq
Expand Down Expand Up @@ -232,7 +212,9 @@ void OperatorEXXPW<T, Device>::construct_ace() const
nbands);

// reduction of psi_h_psi_ace, due to distributed memory
Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands);
#ifdef __MPI
Parallel_Common::reduce_dev<T, Device>(psi_h_psi_ace, nbands * nbands, POOL_WORLD);
#endif

T intermediate_minus_one = -1.0;
axpy_complex_op()(nbands * nbands,
Expand Down
Loading