Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ if(USE_NPU)
$ENV{PYTORCH_INSTALL_PATH}/include
$ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
$ENV{PYTORCH_NPU_INSTALL_PATH}/include
$ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/distributed
$ENV{NPU_HOME_PATH}/include
$ENV{ATB_HOME_PATH}/include
$ENV{NPU_HOME_PATH}/opp/vendors/xllm/op_api/include/
Expand Down
2 changes: 1 addition & 1 deletion xllm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ target_link_libraries(xllm PRIVATE glog::glog brpc leveldb::leveldb ZLIB::ZLIB p
add_dependencies(xllm brpc-static)

if(USE_NPU)
set(COMMON_LIBS Python::Python ascendcl atb_customize hccl c_sec nnopbase ms_tools_ext)
set(COMMON_LIBS Python::Python ascendcl atb_customize hccl c_sec nnopbase ms_tools_ext torch_npu torch_python)
elseif(USE_MLU)
set(COMMON_LIBS Python::Python)
endif()
Expand Down
9 changes: 8 additions & 1 deletion xllm/core/common/global_flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,4 +430,11 @@ DEFINE_bool(
enable_dp_balance,
false,
"Whether to enable dp load balance, if true, sequences within a single "
"dp batch will be shuffled.");
"dp batch will be shuffled.");

#if defined(USE_NPU)
DEFINE_string(
npu_kernel_backend,
"ATB",
"NPU kernel backend. Supported options: ATB, TORCH. Default is ATB.");
#endif
4 changes: 4 additions & 0 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,7 @@ DECLARE_bool(enable_prefetch_weight);
DECLARE_int32(flashinfer_workspace_buffer_size);

DECLARE_bool(enable_dp_balance);

#if defined(USE_NPU)
DECLARE_string(npu_kernel_backend);
#endif
2 changes: 0 additions & 2 deletions xllm/core/distributed_runtime/worker_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,7 @@ void WorkerServer::create_server(

CollectiveCommunicator comm(worker_global_rank, world_size, dp_size, ep_size);
const ParallelArgs* parallel_args = comm.parallel_args();
#if defined(USE_MLU) || defined(USE_CUDA)
comm.create_process_groups(master_node_addr, device);
#endif

std::unique_ptr<Worker> worker =
std::make_unique<Worker>(*parallel_args, device, options, worker_type);
Expand Down
31 changes: 13 additions & 18 deletions xllm/core/framework/parallel_state/collective_communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ limitations under the License.
#include "mapping_npu.h"

#if defined(USE_NPU)
#include "npu_process_group.h"
#include "xllm_kernels/core/include/atb_speed/base/external_comm_manager.h"
#include "xllm_kernels/core/include/atb_speed/utils/singleton.h"
#include "xllm_kernels/models/base/param/mapping.h"
#elif defined(USE_MLU)
#include "mlu_process_group.h"
#elif defined(USE_CUDA)
Expand All @@ -30,23 +30,6 @@ limitations under the License.
#include "parallel_args.h"
#include "util/net.h"

namespace {
#if defined(USE_NPU)
std::unique_ptr<xllm::ProcessGroup> create_process_group(
int rank,
int world_size,
int rank_size,
int port,
bool trans,
const std::string& host,
const std::string& group_name,
const torch::Device& device) {
LOG(FATAL) << "Unsupported device type";
return nullptr;
}
#endif
} // namespace

namespace xllm {

CollectiveCommunicator::CollectiveCommunicator(int global_rank,
Expand All @@ -72,6 +55,13 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank,
// std::make_unique<ProcessGroupHCCL>(
// global_rank, world_size, device, comm);

// comunicator will be inited in torch.
if (FLAGS_npu_kernel_backend == "TORCH") {
parallel_args_ = std::make_unique<ParallelArgs>(
global_rank, world_size, dp_size, nullptr, ep_size);
return;
}

// comunicator will be inited in atb.
MappingNPU::Options mapping_options;
mapping_options.dp_size(dp_size)
Expand Down Expand Up @@ -116,6 +106,11 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank,
void CollectiveCommunicator::create_process_groups(
const std::string& master_addr,
const torch::Device& device) {
#if defined(USE_NPU)
if (FLAGS_npu_kernel_backend == "ATB") {
return;
}
#endif
std::string host;
int port;
net::parse_host_port_from_addr(master_addr, host, port);
Expand Down
3 changes: 2 additions & 1 deletion xllm/core/framework/parallel_state/cuda_process_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ class ProcessGroupNccl : public ProcessGroup {
: ProcessGroup(device) {
c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> pg_options =
c10d::ProcessGroupNCCL::Options::create();
#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 7
#if TORCH_VERSION_MAJOR > 2 || \
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7)
pg_options->group_name = group_name;
#endif
int rank = global_rank;
Expand Down
147 changes: 53 additions & 94 deletions xllm/core/framework/parallel_state/npu_process_group.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ limitations under the License.

#include "npu_process_group.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

npu_process_group.cpp should be deleted, because npu_process_group.h is enough, like cuda/mlu_process_group.h.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. However, I strongly prefer to keep the .cpp file. Defining implementation directly in the header is generally considered bad practice. I hope you understand my decision to maintain this separation.


#include <c10d/ProcessGroup.hpp>
#include <c10d/TCPStore.hpp>
#include <torch_npu/csrc/distributed/ProcessGroupHCCL.hpp>

namespace {

#define HCCLCHECK(cmd) \
Expand All @@ -24,113 +28,68 @@ namespace {
LOG(FATAL) << "Failed, HCCL error :" << HcclGetErrorString(r); \
} \
} while (0)
} // namespace

inline bool is_npu(const at::Tensor& tensor) {
if (!tensor.defined()) {
return false;
}
return tensor.device().is_privateuseone();
}

inline bool is_npu(const at::TensorOptions& options) {
return options.device().is_privateuseone();
}
namespace xllm {

inline bool is_npu(const at::Device& device) {
return device.is_privateuseone();
}
ProcessGroupHCCL::ProcessGroupHCCL(int global_rank,
int world_size,
int rank_size,
int port,
bool trans,
const std::string& host,
const std::string& group_name,
const torch::Device& device)
: ProcessGroup(device) {
c10::intrusive_ptr<c10d_npu::ProcessGroupHCCL::Options> hccl_pg_options =
c10d_npu::ProcessGroupHCCL::Options::create();
#if TORCH_VERSION_MAJOR > 2 || \
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7)
hccl_pg_options->group_name = group_name;
#endif
int rank = global_rank;
if (world_size != rank_size) {
auto [local_rank, group_ranks] =
get_group_rank(world_size, global_rank, rank_size, trans);
std::vector<uint32_t> uint32_ranks;
for (auto rank : group_ranks) {
uint32_ranks.push_back(static_cast<uint32_t>(rank));
}
hccl_pg_options->global_ranks_in_group = uint32_ranks;
rank = local_rank;
}

at::Tensor flatten_for_scatter_gather(std::vector<at::Tensor>& tensors) {
auto& t = tensors[0];
std::vector<int64_t> sizes{static_cast<int64_t>(tensors.size())};
sizes.insert(sizes.end(), t.sizes().begin(), t.sizes().end());
return at::empty(sizes, t.options());
auto store = create_tcp_store(host, port, rank);
pg_ = std::make_unique<c10d_npu::ProcessGroupHCCL>(
store, rank, rank_size, hccl_pg_options);
}

HcclDataType to_hccl_data_type(const torch::Tensor& input) {
const auto type = input.scalar_type();
switch (type) {
case at::kFloat:
return HCCL_DATA_TYPE_FP32;
case at::kHalf:
return HCCL_DATA_TYPE_FP16;
case at::kDouble:
return HCCL_DATA_TYPE_FP64;
case at::kLong:
return HCCL_DATA_TYPE_INT64;
case at::kInt:
return HCCL_DATA_TYPE_INT32;
case at::kChar:
return HCCL_DATA_TYPE_INT8;
case at::kByte:
return HCCL_DATA_TYPE_UINT8;
case at::kBool:
return HCCL_DATA_TYPE_UINT8;
case at::kBFloat16:
return HCCL_DATA_TYPE_BFP16;
default:
LOG(FATAL) << "Unconvertible HCCL type: " << type;
// Destructor.
ProcessGroupHCCL::~ProcessGroupHCCL() {
if (pg_) {
pg_->shutdown();
} else {
HCCLCHECK(HcclCommDestroy(comm_));
}
}

void check_input(torch::Tensor input) {
CHECK(is_npu(input)) << "input should be npu tensor";
CHECK(input.is_contiguous()) << "input should be contiguous";
CHECK(!input.is_sparse()) << "input have to be npu dense tensor";
}

} // namespace

namespace xllm {

ProcessGroupHCCL::ProcessGroupHCCL(int rank,
int world_size,
const torch::Device& device,
HcclComm comm)
: ProcessGroup(device), comm_(comm) {}
// Destructor.
ProcessGroupHCCL::~ProcessGroupHCCL() { HCCLCHECK(HcclCommDestroy(comm_)); }

void ProcessGroupHCCL::allreduce(torch::Tensor& input) {
DCHECK(input.device() == device())
<< "input should be on the same device as the process group";
check_input(input);
// inplace all reduce
// const auto count = input.numel();
// const auto data_type = to_hccl_data_type(input);
// auto stream = c10_npu::getCurrentNPUStream();
// torch::DeviceGuard device_guard(device());
// HCCLCHECK(HcclAllReduce(
// /*sendbuff=*/input.data_ptr(),
// /*recvbuff=*/input.data_ptr(),
// /*count=*/count,
// /*datatype=*/data_type,
// /*op=*/HCCL_REDUCE_SUM,
// /*comm=*/comm_,
// /*stream=*/stream));
}
void ProcessGroupHCCL::allgather(const torch::Tensor& input,
std::vector<torch::Tensor>& outputs) {
check_input(input);
// CHECK(outputs.size() == world_size())
// << "outputs should have the same size as world_size";
// DCHECK(input.device() == device())
// << "input should be on the same device as the process group";
// torch::DeviceGuard device_guard(device());
// torch::Tensor flattened_output = flatten_for_scatter_gather(outputs);
// const auto count = input.numel();
// const auto data_type = to_hccl_data_type(input);
// auto stream = c10_npu::getCurrentNPUStream();
// HCCLCHECK(HcclAllGather(
// /*sendbuff=*/input.data_ptr(),
// /*recvbuff=*/flattened_output.data_ptr(),
// /*sendcount=*/count,
// /*datatype=*/data_type,
// /*comm=*/comm_,
// /*stream=*/stream));
// // copy the flattened output tensors to the outputs.
// for (int i = 0; i < outputs.size(); ++i) {
// outputs[i].copy_(flattened_output[i], /*non_blocking=*/true);
// }
std::unique_ptr<xllm::ProcessGroup> create_process_group(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create_process_group function can placed into an anonymous namespace in collective_communicator.cpp for all devices.

Copy link
Collaborator Author

@yingxudeng yingxudeng Dec 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback.

Regarding the suggestion to consolidate the create_process_group functions, I have a few concerns:

Since ProcessGroupHCCL, ProcessGroupCncl, and ProcessGroupNccl are device-specific implementations , moving them to collective_communicator.cpp would introduce excessive #if/#elif preprocessor directives.

int rank,
int world_size,
int rank_size,
int port,
bool trans,
const std::string& host,
const std::string& group_name,
const torch::Device& device) {
return std::make_unique<ProcessGroupHCCL>(
rank, world_size, rank_size, port, trans, host, group_name, device);
}

} // namespace xllm
24 changes: 19 additions & 5 deletions xllm/core/framework/parallel_state/npu_process_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,30 @@ class ProcessGroupHCCL : public ProcessGroup {
const torch::Device& device,
HcclComm comm);

ProcessGroupHCCL(int rank,
int world_size,
int rank_size,
int port,
bool trans,
const std::string& host,
const std::string& group_name,
const torch::Device& device);

// Destructor.
~ProcessGroupHCCL() override;

void allreduce(torch::Tensor& input) override;

void allgather(const torch::Tensor& input,
std::vector<torch::Tensor>& outputs) override;

private:
HcclComm comm_ = nullptr;
};

std::unique_ptr<xllm::ProcessGroup> create_process_group(
int rank,
int world_size,
int rank_size,
int port,
bool trans,
const std::string& host,
const std::string& group_name,
const torch::Device& device);

} // namespace xllm
14 changes: 14 additions & 0 deletions xllm/core/framework/parallel_state/process_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ limitations under the License.

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/TCPStore.hpp>

#if defined(USE_NPU)
#include <torch_npu/csrc/distributed/ProcessGroupHCCL.hpp>
#endif

namespace xllm {
std::pair<int, std::vector<uint64_t>> get_group_rank(int world_size,
int global_rank,
Expand Down Expand Up @@ -60,7 +65,16 @@ class ProcessGroup {
torch::Device device_;

protected:
#if defined(USE_NPU) && \
(TORCH_VERSION_MAJOR < 2 || \
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 7))
// Using ProcessGroupHCCL for NPU devices
// Note: torch_npu uses an older torch version where c10d::Backend lacks
// shutdown() method
std::unique_ptr<c10d_npu::ProcessGroupHCCL> pg_{nullptr};
#else
std::unique_ptr<c10d::Backend> pg_{nullptr};
#endif
};

} // namespace xllm
3 changes: 2 additions & 1 deletion xllm/core/layers/common/tests/tests_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ class MockBackend : public c10d::Backend {

int64_t getSize() const { return world_size_; }

#if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 7
#if TORCH_VERSION_MAJOR > 2 || \
(TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 7)
void shutdown() override {
// Mock implementation - do nothing
}
Expand Down