From a05713bc90f0cad35b2418c18f4209fd25d13556 Mon Sep 17 00:00:00 2001
From: zhangyue <138768300+zhangyue207@users.noreply.github.com>
Date: Tue, 21 Apr 2026 14:32:40 +0800
Subject: [PATCH 01/11] feat(ascend-framework): framework scaffolding + CI/generator fixes for operator split (#64)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix(ci): tolerate docker teardown SIGKILL when pytest passes cleanly

  Docker 18.09 occasionally SIGKILLs the container during its `chown` teardown step, causing `.ci/run.py` to exit 137 even when pytest completed normally. Parse `/workspace/results/test-results.xml` for `errors` / `failures` fields and treat 137 as success when pytest reports no failures.

  Also bundles a small Dockerfile update for the Ascend image used by `.ci/run.py`.

* fix(scripts): align `py::arg` order with C++ lambda params + optional defaults

  Two fixes in the pybind11 bindings generator:

  1. `py::arg("implementation_index")` was emitted before `py::arg("stream")` in the generated `def(...)` call, but the C++ lambda parameters were declared in the opposite order. Kwargs then silently swapped: the stream integer landed in the impl-index slot, and dispatch SIGABRT'd. Re-order so `py::arg` entries are positional-consistent with the C++ lambda signature.

  2. Only `std::optional<Tensor>` parameters had a `= py::none()` default; scalar `std::optional<...>` parameters had no default, forcing callers to pass them explicitly. Generalize the default emission to all `std::optional<...>` parameters.
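The two generator fixes above can be illustrated with a minimal Python sketch. It is not the code in `scripts/generate_wrappers.py`: `render_binding`, `py_arg_names`, and the `PARAMS` list (including the C++ types shown) are hypothetical names and values chosen for illustration. The idea is that the lambda parameter list and the `py::arg` list are both rendered from one ordered list, so they cannot drift apart, and every `std::optional<...>` parameter gets a `py::none()` default.

```python
import re
from typing import List, Tuple

# Hypothetical (name, C++ type) pairs; not taken from any real operator header.
Param = Tuple[str, str]


def render_binding(params: List[Param]) -> Tuple[str, str]:
    """Render the lambda parameter list and the matching `py::arg` list.

    Both strings are built from the same ordered `params` list, so a keyword
    argument cannot bind to a different positional slot than intended.
    """
    lambda_parts = []
    py_arg_parts = []

    for name, cpp_type in params:
        lambda_parts.append(f"{cpp_type} {name}")

        if cpp_type.startswith("std::optional<"):
            # Every optional, tensor or scalar, defaults to `None`.
            py_arg_parts.append(f'py::arg("{name}") = py::none()')
        else:
            py_arg_parts.append(f'py::arg("{name}")')

    return ", ".join(lambda_parts), ", ".join(py_arg_parts)


def py_arg_names(py_args: str) -> List[str]:
    """Extract the keyword names, in order, from a rendered `py::arg` list."""
    return re.findall(r'py::arg\("(\w+)"\)', py_args)


# Illustrative parameter list; the types are assumptions for this sketch.
PARAMS = [
    ("input", "py::object"),
    ("eps", "std::optional<float>"),
    ("stream", "std::uintptr_t"),
    ("implementation_index", "std::size_t"),
]

LAMBDA_SIGNATURE, PY_ARGS = render_binding(PARAMS)
```

A generator structured this way can also assert, as a self-check, that `py_arg_names(PY_ARGS)` equals the declared parameter order before it writes the binding.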
* feat(ascend): framework scaffolding + custom_kernel build infra

  Framework headers shared across all Ascend operators:

  - `common.h`: `AclTensorCache` descriptor-caching + `toAclDtype` helpers
  - `workspace_pool_.h`: stream-scoped `WorkspacePool` with named arenas; `GetWorkspacePool()` / `Pool::Ensure()` entry points (matches master PR #60 naming)
  - `atb_common_.h`: ATB `Context` management + `toAtbTensor` helper for operators wrapping ATB APIs
  - `data_type_.h`, `device_.h`: `TypeMap` + `Runtime` specialization
  - `runtime_.h` is the existing file; left untouched by this PR

  `custom_kernel/` ships the AscendC standalone build system for custom kernels. Gated by its own `CMakeLists.txt`; produces `libascend_kernel.so` consumed by `kernel_custom.h` op variants (landed in follow-up category PRs).

* feat: core framework, build, and test infra for Ascend operator split

  Shared changes needed by every Ascend operator PR:

  - `src/hash.h` + `src/operator.h`: cache-key plumbing used by `Operator` dispatch
  - `src/pybind11_utils.h`: tensor / optional-tensor / vector-tensor pybind11 casters used by the generator output
  - `CMakeLists.txt` + `src/CMakeLists.txt`: Ascend build target, atb discovery, `WITH_ASCEND` option
  - `tests/conftest.py`: `auto_act_and_assert` fixture + device parametrization (`--devices ascend/nvidia/...`)
  - `tests/utils.py`: `Payload`, `randn_strided`, `get_npu_stream`, and similar test helpers shared by every `tests/test_.py`

* test(conftest): auto-skip tests whose op has no impl on the target device

  Adds a `skip_op_without_platform_impl` autouse fixture that derives the InfiniOps class name from the test module filename (`tests/test_.py` → ``) and checks `active_implementation_indices` for the parametrized device. When the op has no backend specialization on the current branch, the test is skipped instead of SIGABRTing through `Operator::Make()`.
  This is essential for the operator split: each per-category branch contains only its category's Ascend impls but inherits test files for all operators from master. Without this guard, `pytest tests/ --devices ascend` crashes on ops lacking ascend impls on the branch.

* chore(custom_kernel): drop perf/design docs and standalone .so pytest files

  Remove content that duplicates what the pytest integration tests (`tests/test_rms_norm.py`, `tests/test_add_rms_norm.py`) already cover, or that's developer scratchpad rather than checked-in artifact:

  - `csrc/ops/rms_norm/{README,design}.md`: design scratch
  - `csrc/ops/rms_norm/test/{benchmark_rms_norm_msprof,run_rms_norm_case}.py`, `rms_norm_cases.jsonl`, `rms_norm_perf_report.md`, `rms_norm-test-cases.md`: per-op perf benchmarking + reports
  - `tests/test_{rms_norm,add_rms_norm}.py` under custom_kernel/: redundant with the top-level pytest integration tests

  Build infra, kernel sources, registration, and utility headers are unchanged; the `libascend_kernel.so` artifact and its consumers (`kernel_custom.h` variants in the op-norm-rope PR) are unaffected.

* style(scripts,custom_kernel): fix Python blank-line hygiene + drop redundant .gitignore entry

  Review items 1-5 on `scripts/generate_wrappers.py`:

  - Restore docstring quoting in `_find_optional_tensor_params` (reverts accidental change to ```int`` and the double-space).
  - Restore blank lines before `return` in `_find_optional_tensor_params`, `_is_optional_tensor`, and `_generate_params` / `_generate_arguments` (project CLAUDE.md Python style: "blank line before `return` unless inside a block body").
  - Add missing blank line before `return` in `_find_vector_tensor_params` and `_is_vector_tensor`.
  - Drop redundant `import re` inside `_find_vector_tensor_params`; `re` is imported at module level.

  Review item 10 on `src/ascend/custom_kernel/.gitignore`:

  - Drop redundant `build/` entry (already ignored globally via the project-root `.gitignore`).
    Keep `output/` and `python/`: both are AscendC-specific build artifacts not covered by the root ignore.

* refactor(custom): rename `custom_kernel` → `custom` and flatten to match vllm-ascend/csrc layout

  Reviewer top-level feedback on PR #64: mirror the directory layout of https://github.com/vllm-project/vllm-ascend/tree/main/csrc and drop the extra nesting layers.

  Directory changes:

  - `src/ascend/custom_kernel/` → `src/ascend/custom/`
  - Merge `csrc/` into the top: move `csrc/register.cpp`, `csrc/ops.h`, `csrc/utils/` up one level.
  - Rename `register.cpp` → `torch_binding.cpp` to match vllm-ascend naming.
  - Promote `csrc/ops//` to `/` at the top (drop the `ops/` layer).
  - Merge `csrc/CMakeLists.txt` content into top-level `CMakeLists.txt`; delete the now-empty `csrc/` layer.
  - Remove `src/ascend/custom_kernel/.gitignore` (root `.gitignore` already ignores `build/`; `output/`+`python/` were custom_kernel-scoped build artifacts that fit the root gitignore's scope too).

  Resulting layout:

    custom/
    ├── build.sh
    ├── CMakeLists.txt
    ├── cmake/{config_ascend,config_envs}.cmake
    ├── ops.h
    ├── torch_binding.cpp (was `register.cpp`)
    ├── utils/torch_kernel_helper.h
    ├── rms_norm/{op_host,op_kernel}/rms_norm.cpp
    └── add_rms_norm/{op_host,op_kernel}/add_rms_norm.cpp

  License preservation: files shared in structure/substance with vllm-ascend (`torch_binding.cpp`, `ops.h`, `utils/torch_kernel_helper.h`, top-level `CMakeLists.txt`) now carry proper Apache License 2.0 headers with the original Huawei Technologies copyright preserved alongside InfiniTensor's modification copyright.

  Callers:

  - `src/CMakeLists.txt`: `custom_kernel` → `custom` in two references.
  - Root `CMakeLists.txt`: updated inline comment pointing to the build script.
  - Library name (`ascend_kernel`), static lib (`no_workspace_kernel`), and Python module name remain unchanged: `kernel_custom.h` consumers in the op-norm-rope PR link via those identifiers, not by path, so this rename does not ripple into that branch.
  CI: `.ci/run.py --local --gpu-id 0` passes 3072/1782 on Ascend 910B with `BUILD_CUSTOM_KERNEL=OFF` (default); the custom kernel build itself is exercised by the op-norm-rope PR's `kernel_custom.h` integration.

* style: fix PR #64 review patterns found outside `custom/`

  Scan-and-fix pass for patterns flagged in reviewer comments on `custom_kernel/` that also appear in other files in this PR.

  - `src/ascend/common.h`: wrap `aclTensor` in backticks in two comments (matches comment 9 on Markdown formatting in custom_kernel).
  - `tests/utils.py`: add missing blank line before trailing `return` in `get_stream()` (matches comments 3/5 on missing blank line before return in non-block-body context).

  No camelCase-local violations in the framework C++ headers (atb_common_, common, data_type_, device_, workspace_pool_, hash, operator, pybind11_utils); reviewer comment 6 was specific to `custom/` op_host code adapted from vllm-ascend.

* style(custom): address PR #64 review comments 6+7: C++ naming

  Reviewer @voltjia on PR #64 inline comments:

  - Comment 6: local variables must follow Google C++ Style Guide (`dimLength` → `dim_length`, etc.). Applied across all locals in the two op_host files.
  - Comment 7: namespace `ascend_kernel` is non-standard; use `detail` or `ascend::detail` to match other platforms. Renamed to `ascend::detail` in `ops.h`, `torch_binding.cpp`, `utils/torch_kernel_helper.h`, and both `op_host/*.cpp` files.

  The library name (`ascend_kernel` → `libascend_kernel.so`), `OP_PLUGIN_NAME`, and Python-import name are unchanged; those are compile/link identity and are independent of the C++ namespace. `kernel_custom.h` in op-norm-rope links via the C `extern` launch symbol, not the namespace, so this rename does not ripple into that branch.

  Also took the opportunity to backtick-wrap identifiers in comments that the rename touched.
  Inline comments 8 and 9 (Markdown formatting in comments) were already covered by the backtick pass in commit 0aed3a5 for non-custom files; the custom/ comments here also get normalized as a side-effect of rewriting the affected lines.

* review(pr#64): address remaining unresolved inline comments

  Scanned ALL 30 inline comments on PR #64 (not just the 10 visible in collapsed view). 22 had been missed by the earlier passes.

  Generator (scripts/generate_wrappers.py):

  - Comments 8-10: swap `stream` and `implementation_index` in both the pybind lambda parameters and the `py::arg` declarations, to match the `Operator::Call(Handle, Config, ...)` order (Handle first, Config second). Previously ordered impl_index first for lambda-signature alignment; with the swap, both are reordered together so kwargs still resolve correctly.
  - Comment 11: restore backticks around device names in `--devices` help text.
  - Comment 12: `.def_static("clear_cache", ...)` kept; it is the API used by the new `_clear_operator_caches` pytest fixture.

  CMakeLists.txt:

  - Comments 13-14: wrap `NEEDED` and `torch_npu` in Markdown backticks in comments.

  tests/conftest.py (comments 23-29):

  - Reset the file to master's content and re-apply only the two new fixtures (`_clear_operator_caches`, `skip_op_without_platform_impl`) with Markdown docstrings (single backticks, not rST double). Reverts incidental changes to `pytest_addoption` help text, `skip_unsupported_dtypes` rename, `_PLATFORM_TO_TORCH_DEVICE` dict order, `_resolve_device` docstring, and the `torch_npu` comment line-wrap.
  - Fix comment 27's concern: `_TORCH_DEVICE_TO_PLATFORMS` now maps one torch device type to multiple platforms (`cuda` → `{nvidia, metax, iluvatar}`) and `skip_op_without_platform_impl` checks `active_implementation_indices` across all of them; it skips only when every mapped platform reports empty.

  tests/utils.py:

  - Comment 16: remove `get_npu_stream`; `get_stream(device)` covers all torch device types.
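The multi-platform skip rule described for `tests/conftest.py` above can be sketched in Python as follows. This is an illustration under assumptions, not the fixture itself: `should_skip` and `fake_lookup` are hypothetical helpers, the `npu` entry in the mapping is assumed, and the real fixture queries `active_implementation_indices` on the operator class rather than a lookup callable.

```python
from typing import Callable, Dict, FrozenSet, Sequence

# The `cuda` entry follows the commit note above; the `npu` entry is an
# assumption made for this sketch.
TORCH_DEVICE_TO_PLATFORMS: Dict[str, FrozenSet[str]] = {
    "cuda": frozenset({"nvidia", "metax", "iluvatar"}),
    "npu": frozenset({"ascend"}),
}


def should_skip(
    torch_device_type: str,
    active_indices: Callable[[str], Sequence[int]],
) -> bool:
    """Return True only when every platform mapped to the device has no impl."""
    platforms = TORCH_DEVICE_TO_PLATFORMS.get(torch_device_type, frozenset())

    if not platforms:
        # Unmapped device type: do not skip; leave the decision to other checks.
        return False

    return all(len(active_indices(platform)) == 0 for platform in platforms)


def fake_lookup(table: Dict[str, Sequence[int]]) -> Callable[[str], Sequence[int]]:
    """Build a stand-in for the per-platform implementation-index query."""

    def lookup(platform: str) -> Sequence[int]:
        return table.get(platform, [])

    return lookup
```

With this rule, a `cuda` test keeps running as long as any one of the mapped platforms reports an implementation, which is the behaviour comment 27 asked for.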
  tests/test_{add,causal_softmax,gemm,rms_norm,swiglu}.py:

  - Comments 17-22: replace the `if device.type == "npu"` branches with a single call that passes `stream=get_stream(.device)`. Single-line import restored in `test_add.py` (comment 22: format minimization after dropping the `get_npu_stream` import).

  test_gemm.py specifically: moved the "impl=2 on Ascend is broken because of `src/torch/gemm/gemm.h` SFINAE pollution" workaround from the helper-level conditional into a `pytest.skip` at the top of the test body, so the helper itself becomes unconditional.

* style: fix clang-format and ruff format violations

  - `src/ascend/custom/utils/torch_kernel_helper.h`: clang-format wrapped a long `ConvertTypes` macro continuation.
  - `tests/test_add.py`: ruff `format` wrapped the 5-import `tests.utils` line (89 chars, over the default 88 limit) back into multi-line form. Reviewer comment 22 suggested restoring a single line after dropping `get_npu_stream`, but with `get_stream` added the shortened form still exceeds the ruff line-length cap.

* style(comments): wrap remaining technical identifiers in Markdown backticks

  Scan-and-fix pass for identifiers in comments that still lack Markdown backticks, matching reviewer comments 9, 11, 13, 14 on PR #64. Applied only to files authored / modified by this PR (leaves custom/cmake/config_envs.cmake and similar vllm-ascend-verbatim content untouched to stay consistent with the upstream it was adapted from).

  - `CMakeLists.txt`: `pybind11` (line 7).
  - `src/ascend/common.h`: `shape`, `strides`, `storage_shape`, `dtype` in the `AclTensorCache` class doc.
  - `src/ascend/custom/CMakeLists.txt`: `AscendC` toolchain reference.
  - `src/ascend/custom/build.sh`: `AscendC`, `libascend_kernel.so`.
  - `src/ascend/custom/cmake/config_ascend.cmake`: `SOC_VERSION`, `CANN`, `AscendC`.

* style(ascend): rename free-function helpers from camelCase to PascalCase

  Per Google C++ Style Guide §Function Names: ordinary non-accessor functions are PascalCase.
  Accessors/mutators (get/set on class members) are snake_case. These 7 are standalone helpers / converters / predicates, not member accessors, so they need PascalCase.

    threadLocalAtbContext → ThreadLocalAtbContext
    getAtbContext → GetAtbContext
    toAtbTensor (×2) → ToAtbTensor
    isAclRuntimeAlive → IsAclRuntimeAlive
    buildAclTensor → BuildAclTensor
    toAclDtype → ToAclDtype
    isIntegerDtype → IsIntegerDtype

  CANN APIs (`aclrtGetDevice`, `aclCreateTensor`, …), STL/PyTorch interop methods (`begin`/`end`/`size`/`data`/…), and class accessors (`get_`/`set_`) are all kept as-is; they either belong to another vendor or match the "looks like a variable" exception.

  Callers on the three category branches (op-simple / op-norm-rope / op-cache-attn) will pick up the new names automatically on rebase.

* style(ascend): reformat CMake comments + restore `Gemm`/`MatMul` backticks in common.h

  - `src/ascend/custom/cmake/config_envs.cmake`: capitalize + period + Markdown backticks on all comments and status messages.
  - `src/ascend/custom/cmake/config_ascend.cmake`: fix `CANN` casing and backticks in the fatal-error message.
  - `src/ascend/custom/CMakeLists.txt`: polish status messages and inline comments (Markdown backticks + sentence case).
  - `src/ascend/common.h`: restore `Gemm` and `MatMul` backticks in the `BuildAclTensor` docstring per PR #64 review.

* refactor(ascend): address PR #64 review: clean headers, Markdown in `TORCH_CHECK`, Google C++ naming

  - `workspace_pool_.h`: uncomment `` / `` (needed for `PRIu64` and `fprintf` in the destructor; not transitively available on all platforms).
  - `device_.h`: switch relative `../device.h` to absolute `device.h`; the historical `src/ascend/device.h` naming collision is no longer relevant.
  - `custom/{add_rms_norm,rms_norm}/op_host/*.cpp`: drop unneeded BSD-3-Clause headers and switch `TORCH_CHECK` messages to Markdown-backticked identifiers.
  - `custom/{add_rms_norm,rms_norm}/op_kernel/*.cpp`: drop unneeded BSD-3-Clause headers.
  - Rename wrapper functions to PascalCase per Google C++ Style: `add_rms_norm` → `AddRmsNorm`, `rms_norm` → `RmsNorm` (ops.h + torch_binding.cpp updated; `torch.ops.npu.rms_norm` registry name unchanged; kernel entry-point names stay snake_case as required by `EXEC_KERNEL_CMD`).

---------

Co-authored-by: zhangyue
---
 .ci/images/ascend/Dockerfile | 8 +
 .ci/run.py | 56 +++-
 CMakeLists.txt | 31 ++-
 scripts/generate_wrappers.py | 43 ++-
 src/CMakeLists.txt | 46 +++-
 src/ascend/atb_common_.h | 95 +++++++
 src/ascend/common.h | 154 ++++++++++-
 src/ascend/custom/CMakeLists.txt | 97 +++++++
 src/ascend/custom/add_rms_norm/CMakeLists.txt | 1 +
 .../add_rms_norm/op_host/add_rms_norm.cpp | 134 ++++++++++
 .../add_rms_norm/op_kernel/add_rms_norm.cpp | 249 ++++++++++++++++++
 src/ascend/custom/build.sh | 30 +++
 src/ascend/custom/cmake/config_ascend.cmake | 40 +++
 src/ascend/custom/cmake/config_envs.cmake | 83 ++++++
 src/ascend/custom/ops.h | 28 ++
 src/ascend/custom/rms_norm/CMakeLists.txt | 1 +
 .../custom/rms_norm/op_host/rms_norm.cpp | 120 +++++++++
 .../custom/rms_norm/op_kernel/rms_norm.cpp | 215 +++++++++++++++
 src/ascend/custom/torch_binding.cpp | 31 +++
 src/ascend/custom/utils/torch_kernel_helper.h | 87 ++++++
 src/ascend/data_type_.h | 16 +-
 src/ascend/gemm/kernel.h | 8 +-
 src/ascend/workspace_pool_.h | 128 ++++++++-
 src/hash.h | 9 +
 src/operator.h | 25 +-
 src/pybind11_utils.h | 12 +-
 tests/conftest.py | 82 ++++++
 tests/test_add.py | 16 +-
 tests/test_causal_softmax.py | 6 +-
 tests/test_gemm.py | 12 +
 tests/test_rms_norm.py | 24 +-
 tests/test_swiglu.py | 39 ++-
 tests/utils.py | 41 ++-
 33 files changed, 1892 insertions(+), 75 deletions(-)
 create mode 100644 src/ascend/atb_common_.h
 create mode 100644 src/ascend/custom/CMakeLists.txt
 create mode 100644 src/ascend/custom/add_rms_norm/CMakeLists.txt
 create mode 100644 src/ascend/custom/add_rms_norm/op_host/add_rms_norm.cpp
 create mode 100644 src/ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp
 create mode 100755
src/ascend/custom/build.sh create mode 100644 src/ascend/custom/cmake/config_ascend.cmake create mode 100644 src/ascend/custom/cmake/config_envs.cmake create mode 100644 src/ascend/custom/ops.h create mode 100644 src/ascend/custom/rms_norm/CMakeLists.txt create mode 100644 src/ascend/custom/rms_norm/op_host/rms_norm.cpp create mode 100644 src/ascend/custom/rms_norm/op_kernel/rms_norm.cpp create mode 100644 src/ascend/custom/torch_binding.cpp create mode 100644 src/ascend/custom/utils/torch_kernel_helper.h diff --git a/.ci/images/ascend/Dockerfile b/.ci/images/ascend/Dockerfile index 3ff79e1c8..a542b99e0 100644 --- a/.ci/images/ascend/Dockerfile +++ b/.ci/images/ascend/Dockerfile @@ -18,4 +18,12 @@ RUN pip install --no-cache-dir --progress off \ pytest-xdist \ ruff +ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +ENV LD_LIBRARY_PATH=${ASCEND_TOOLKIT_HOME}/lib64:${ASCEND_TOOLKIT_HOME}/lib64/plugin/opskernel:${ASCEND_TOOLKIT_HOME}/lib64/plugin/nnengine:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/aarch64:${LD_LIBRARY_PATH} +ENV PYTHONPATH=${ASCEND_TOOLKIT_HOME}/python/site-packages:${ASCEND_TOOLKIT_HOME}/opp/built-in/op_impl/ai_core/tbe:${PYTHONPATH} +ENV PATH=${ASCEND_TOOLKIT_HOME}/bin:${ASCEND_TOOLKIT_HOME}/compiler/ccec_compiler/bin:${PATH} +ENV ASCEND_AICPU_PATH=${ASCEND_TOOLKIT_HOME} +ENV ASCEND_OPP_PATH=${ASCEND_TOOLKIT_HOME}/opp +ENV TOOLCHAIN_HOME=${ASCEND_TOOLKIT_HOME}/toolkit + WORKDIR /workspace diff --git a/.ci/run.py b/.ci/run.py index 7330d9694..e293b4a28 100644 --- a/.ci/run.py +++ b/.ci/run.py @@ -8,6 +8,7 @@ import subprocess import sys import uuid +import xml.etree.ElementTree as ET from datetime import datetime from pathlib import Path @@ -24,6 +25,42 @@ _PYTEST_VALUE_FLAGS = {"-n", "-k", "-m", "-p", "--tb", "--junitxml", "--rootdir"} +def _junit_xml_indicates_pass(results_dir): + """Return True if `pytest` junit XML under `results_dir` reports no failures/errors. 
+ + Used to distinguish a real CI failure from the docker 18.09 + container-teardown `SIGKILL` (exit code 137) that occurs on this host + after a child process exits successfully — bash returns 0 from inside + the container, but the docker daemon reports 137 due to a race in its + `--rm` cleanup path. The junit XML is written by pytest before that + teardown and reliably captures the real outcome of the test stage. + """ + for junit in Path(results_dir).rglob("test-results.xml"): + try: + root = ET.parse(junit).getroot() + except ET.ParseError: + continue + + suites = root.findall("testsuite") if root.tag == "testsuites" else [root] + + if not suites: + continue + + for suite in suites: + try: + if int(suite.get("failures", 0)) > 0: + return False + + if int(suite.get("errors", 0)) > 0: + return False + except ValueError: + return False + + return True + + return False + + def apply_test_override(run_cmd, test_path): """Replace positional test path(s) in a pytest stage command. @@ -437,8 +474,23 @@ def main(): pool.release(allocated_ids) if returncode != 0: - print(f"job {job_name} failed (exit code {returncode})", file=sys.stderr) - failed += 1 + # Docker 18.09 on this host occasionally SIGKILLs containers + # during `--rm` cleanup after the inner process already exited + # cleanly, producing exit code 137. Fall back to the pytest + # junit XML to recover the real outcome in that case. 
+ if returncode == 137 and _junit_xml_indicates_pass(results_dir): + print( + f"[warn] job {job_name}: container exited with 137 " + f"(likely docker teardown SIGKILL after clean pytest); " + f"junit XML reports no failures — treating as success", + file=sys.stderr, + ) + else: + print( + f"job {job_name} failed (exit code {returncode})", + file=sys.stderr, + ) + failed += 1 sys.exit(1 if failed else 0) diff --git a/CMakeLists.txt b/CMakeLists.txt index e88cc20c0..b72c34eaf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,7 +4,7 @@ project(InfiniOps LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) -# Internal variable to control pybind11's automatic optimization flags (like `-flto`). +# Internal variable to control `pybind11`'s automatic optimization flags (like `-flto`). set(PYBIND11_ENABLE_EXTRAS ON) # Options for backends. @@ -18,6 +18,13 @@ option(WITH_ASCEND "Enable Ascend backend" OFF) option(WITH_TORCH "Enable PyTorch C++ backend" OFF) +# Default OFF until CANN's `extract_host_stub.py` path handling is fixed for +# `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed +# object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the +# toolchain is compatible or when building via the standalone +# `src/ascend/custom/build.sh` script. +option(BUILD_CUSTOM_KERNEL "Build custom AscendC kernel PyTorch extension (requires `torch_npu`)" OFF) + option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF) option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF) @@ -130,6 +137,28 @@ if(WITH_TORCH) find_library(C10_LIB c10 HINTS ${_torch_lib_dirs} REQUIRED) set(TORCH_LIBRARIES ${TORCH_LIB} ${TORCH_CPU_LIB} ${C10_LIB}) + # `auditwheel`-repaired `torch` wheels bundle transitive dependencies + # (e.g. 
`libgfortran-.so`, `libopenblasp-.so`) in a sibling + # `torch.libs/` directory that `library_paths()` does not return. When + # building against such a wheel, the linker needs this path to resolve + # the bundled `NEEDED` entries (otherwise: `undefined reference to + # _gfortran_etime@GFORTRAN_8` etc.). + execute_process( + COMMAND ${Python_EXECUTABLE} -c "import os, torch; d = os.path.dirname(torch.__file__); p = os.path.join(os.path.dirname(d), 'torch.libs'); print(p if os.path.isdir(p) else '')" + OUTPUT_VARIABLE TORCH_BUNDLED_LIBS_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + + if(TORCH_BUNDLED_LIBS_DIR) + list(APPEND CMAKE_BUILD_RPATH "${TORCH_BUNDLED_LIBS_DIR}") + list(APPEND CMAKE_INSTALL_RPATH "${TORCH_BUNDLED_LIBS_DIR}") + # `rpath-link` is linker-only: lets `ld` resolve the bundled + # transitive `NEEDED` entries at link time without adding them to our + # own binary's direct `NEEDED` list. + add_link_options("-Wl,-rpath-link,${TORCH_BUNDLED_LIBS_DIR}") + message(STATUS "PyTorch bundled libs: ${TORCH_BUNDLED_LIBS_DIR}") + endif() + # Query the `CXX11` ABI setting that `torch` was compiled with. # A mismatch causes linker errors (e.g. undefined reference to # `c10::Device::Device(std::string const&)`). diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index c050b31c0..49b6c199f 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -103,8 +103,18 @@ def _find_optional_tensor_params(op_name): return set(re.findall(r"std::optional\s+(\w+)", source)) +def _find_vector_tensor_params(op_name): + """Return a set of parameter names declared as `std::vector` in + the base header. 
+ """ + source = (_BASE_DIR / f"{op_name}.h").read_text() + + return set(re.findall(r"std::vector\s+(\w+)", source)) + + def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) + vector_tensor_params = _find_vector_tensor_params(operator.name) def _is_optional_tensor(arg): if arg.spelling in optional_tensor_params: @@ -112,6 +122,15 @@ def _is_optional_tensor(arg): return "std::optional" in arg.type.spelling and "Tensor" in arg.type.spelling + def _is_optional(arg): + return "std::optional" in arg.type.spelling + + def _is_vector_tensor(arg): + if arg.spelling in vector_tensor_params: + return True + + return "std::vector" in arg.type.spelling and "Tensor" in arg.type.spelling + def _generate_params(node): parts = [] @@ -121,6 +140,8 @@ def _generate_params(node): if _is_optional_tensor(arg): parts.append(f"std::optional {arg.spelling}") + elif _is_vector_tensor(arg): + parts.append(f"std::vector {arg.spelling}") else: param = arg.type.spelling.replace("const Tensor", "py::object").replace( "Tensor", "py::object" @@ -138,6 +159,8 @@ def _generate_arguments(node): if _is_optional_tensor(arg): args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})") + elif _is_vector_tensor(arg): + args.append(f"VectorTensorFromPybind11Handle({arg.spelling})") elif "Tensor" in arg.type.spelling: args.append(f"TensorFromPybind11Handle({arg.spelling})") else: @@ -155,11 +178,18 @@ def _generate_init(constructor): }}))""" def _generate_py_args(node): - return ", ".join( - f'py::arg("{arg.spelling}")' - for arg in node.get_arguments() - if arg.spelling != "stream" - ) + parts = [] + + for arg in node.get_arguments(): + if arg.spelling == "stream": + continue + + if _is_optional(arg): + parts.append(f'py::arg("{arg.spelling}") = py::none()') + else: + parts.append(f'py::arg("{arg.spelling}")') + + return ", ".join(parts) def _generate_call(op_name, call, method=True): call_params = _generate_params(call) @@ -224,7 +254,8 @@ def 
_generate_call(op_name, call, method=True): {calls} .def_static("active_implementation_indices", [](const std::string& device) {{ return Self::active_implementation_indices(DeviceTypeFromString(device)); - }}); + }}) + .def_static("clear_cache", &Self::clear_cache); {callers} }} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cbdae6745..32c92949d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -178,8 +178,10 @@ if(WITH_ASCEND) "ascend/*.cc" "ascend/*.cpp" ) - # Exclude `kernel_impl.cpp` — AscendC device code, not compiled by the host C++ compiler. + # Exclude `kernel_impl.cpp` — `AscendC` device code, not compiled by the host C++ compiler. list(FILTER ASCEND_SOURCES EXCLUDE REGEX ".*kernel_impl\\.cpp$") + # Exclude custom/ — standalone PyTorch extension, built separately. + list(FILTER ASCEND_SOURCES EXCLUDE REGEX ".*/custom/.*") target_compile_definitions(infiniops PUBLIC WITH_ASCEND=1) target_sources(infiniops PRIVATE ${ASCEND_SOURCES}) @@ -215,7 +217,38 @@ if(WITH_ASCEND) "${ASCEND_HOME}/lib64/libopapi.so" "${ASCEND_HAL_LIB}") + # ATB (Ascend Transformer Boost) — provides fused operators like + # `PagedAttention` and `ReshapeAndCache` that are graph-capture safe. + set(ATB_HOME_DIR "$ENV{ATB_HOME_PATH}") + if(NOT ATB_HOME_DIR) + # Default search path under CANN nnal directory. 
+ file(GLOB ATB_SEARCH_DIRS "/usr/local/Ascend/nnal/atb/*/atb/cxx_abi_1") + if(ATB_SEARCH_DIRS) + list(SORT ATB_SEARCH_DIRS ORDER DESCENDING) + list(GET ATB_SEARCH_DIRS 0 ATB_HOME_DIR) + endif() + endif() + + if(ATB_HOME_DIR AND EXISTS "${ATB_HOME_DIR}/include/atb/operation.h") + message(STATUS "ATB found: ${ATB_HOME_DIR}") + target_compile_definitions(infiniops PUBLIC INFINI_HAS_ATB=1) + target_include_directories(infiniops PUBLIC "${ATB_HOME_DIR}/include") + target_link_libraries(infiniops PUBLIC "${ATB_HOME_DIR}/lib/libatb.so") + else() + message(STATUS "ATB not found — ATB-based operators disabled") + endif() + list(APPEND DEVICE_LIST "ascend") + + # Custom `AscendC` kernels (PyTorch extension, requires `torch_npu`). + if(BUILD_CUSTOM_KERNEL) + add_subdirectory(ascend/custom) + + # Link the compiled `AscendC` kernel objects into `infiniops` so that + # custom kernel implementations (e.g. `RmsNorm` index 1) can call + # them via the generated launch functions. + target_compile_definitions(infiniops PUBLIC INFINI_HAS_CUSTOM_KERNELS=1) + endif() endif() if(WITH_TORCH) @@ -340,6 +373,17 @@ if(GENERATE_PYTHON_BINDINGS) target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR}) target_link_libraries(ops PRIVATE infiniops) + # Custom `AscendC` kernel objects must be linked directly into ops + # because the `AscendC` toolchain compiles host stubs with hidden + # visibility — `libinfiniops.so` cannot re-export those symbols. + # The `Operator<..., 1>` template instantiations that call + # `aclrtlaunch_*` live in `ops.cc`, so link here with + # `--whole-archive` to ensure all launch functions are available. 
+ if(BUILD_CUSTOM_KERNEL) + target_link_libraries(ops PRIVATE + -Wl,--whole-archive no_workspace_kernel -Wl,--no-whole-archive) + endif() + set_target_properties(infiniops PROPERTIES INSTALL_RPATH "$ORIGIN") set_target_properties(ops PROPERTIES INSTALL_RPATH "$ORIGIN") diff --git a/src/ascend/atb_common_.h b/src/ascend/atb_common_.h new file mode 100644 index 000000000..dabda7e98 --- /dev/null +++ b/src/ascend/atb_common_.h @@ -0,0 +1,95 @@ +#ifndef INFINI_OPS_ASCEND_ATB_COMMON__H_ +#define INFINI_OPS_ASCEND_ATB_COMMON__H_ + +#ifdef INFINI_HAS_ATB + +#include +#include +#include +#include + +#include "acl/acl.h" +#include "ascend/data_type_.h" +#include "atb/context.h" +#include "atb/operation.h" +#include "atb/types.h" +#include "tensor.h" + +namespace infini::ops::ascend { + +// Thread-local ATB context. +// +// ATB requires a `Context` for Setup/Execute. Creating one per call is +// expensive (internal tiling buffer allocation), so we cache one per thread. +// `SetExecuteStream` is called before every `Execute` to match the caller's +// stream. +inline atb::Context*& ThreadLocalAtbContext() { + thread_local atb::Context* ctx = nullptr; + + return ctx; +} + +inline atb::Context* GetAtbContext(aclrtStream stream) { + auto*& ctx = ThreadLocalAtbContext(); + + if (!ctx) { + atb::Status s = atb::CreateContext(&ctx); + assert(s == atb::NO_ERROR && "atb::CreateContext failed"); + } + + atb::Status s = ctx->SetExecuteStream(stream); + assert(s == atb::NO_ERROR && "atb::Context::SetExecuteStream failed"); + + return ctx; +} + +// Build an `atb::Tensor` from an InfiniOps Tensor. +// +// Sets dtype, ND format, shape dimensions, and the device data pointer. +// The caller must keep the InfiniOps Tensor alive for the duration of the +// ATB operation. 
+inline atb::Tensor ToAtbTensor(const Tensor& t) { + atb::Tensor out; + out.desc.dtype = ToAclDtype(t.dtype()); + out.desc.format = ACL_FORMAT_ND; + out.desc.shape.dimNum = t.ndim(); + assert(t.ndim() <= atb::MAX_DIM); + + for (uint64_t i = 0; i < t.ndim(); ++i) { + out.desc.shape.dims[i] = static_cast(t.size(i)); + } + + out.deviceData = const_cast(t.data()); + out.dataSize = static_cast(t.numel()) * t.element_size(); + + return out; +} + +// Build an `atb::Tensor` from explicit shape, dtype, and data pointer. +// +// Useful for sub-views of a larger buffer (e.g. K-cache and V-cache halves +// of a fused KV cache tensor). +inline atb::Tensor ToAtbTensor(const std::vector& shape, + aclDataType dtype, void* data, + uint64_t data_size) { + atb::Tensor out; + out.desc.dtype = dtype; + out.desc.format = ACL_FORMAT_ND; + out.desc.shape.dimNum = shape.size(); + assert(shape.size() <= atb::MAX_DIM); + + for (size_t i = 0; i < shape.size(); ++i) { + out.desc.shape.dims[i] = shape[i]; + } + + out.deviceData = data; + out.dataSize = data_size; + + return out; +} + +} // namespace infini::ops::ascend + +#endif // INFINI_HAS_ATB + +#endif // INFINI_OPS_ASCEND_ATB_COMMON__H_ diff --git a/src/ascend/common.h b/src/ascend/common.h index fba4766b4..63f9690e9 100644 --- a/src/ascend/common.h +++ b/src/ascend/common.h @@ -11,12 +11,24 @@ namespace infini::ops::ascend { +// Check whether the ACL runtime is still usable. +// +// During process shutdown the CANN runtime may be torn down before C++ +// static destructors run. Calling `aclrtGetDevice` is the cheapest +// probe — it fails once the runtime is gone. Destructors that call +// ACL/ATB APIs must guard with this to avoid use-after-finalize crashes. +inline bool IsAclRuntimeAlive() { + int32_t dev_id = -1; + + return aclrtGetDevice(&dev_id) == ACL_SUCCESS; +} + // Build an `aclTensor` descriptor from an InfiniOps `Tensor`. 
// // When `transpose_last2` is true the last two dimensions are swapped in the // descriptor (shape and strides) without copying data. This is used by `Gemm` // and `MatMul` to express a transpose via the view. -inline aclTensor* buildAclTensor(const Tensor& t, +inline aclTensor* BuildAclTensor(const Tensor& t, bool transpose_last2 = false) { std::vector shape(t.shape().begin(), t.shape().end()); std::vector strides(t.strides().begin(), t.strides().end()); @@ -51,6 +63,146 @@ inline aclTensor* buildAclTensor(const Tensor& t, static_cast(storage_shape.size()), const_cast(t.data())); } +// Pre-computed tensor metadata for descriptor reuse. +// +// Stores `shape`, `strides`, `storage_shape`, and `dtype` once (avoiding +// per-call heap allocations). The `aclTensor` descriptor is created on the +// first `get()` call and its data pointer is updated in-place via +// `aclSetRawTensorAddr` on subsequent calls. +class AclTensorCache { + public: + AclTensorCache() = default; + + // Construct from explicit metadata (for device buffers not wrapped in + // Tensor). Computes contiguous strides from shape. 
+ AclTensorCache(std::vector shape, aclDataType dtype, void* data) + : shape_(std::move(shape)), dtype_(dtype) { + strides_.resize(shape_.size()); + int64_t stride = 1; + for (int i = static_cast(shape_.size()) - 1; i >= 0; --i) { + strides_[i] = stride; + stride *= shape_[i]; + } + storage_shape_ = {stride}; + + if (data) { + tensor_ = aclCreateTensor( + shape_.data(), static_cast(shape_.size()), dtype_, + strides_.data(), + /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape_.data(), + static_cast(storage_shape_.size()), data); + } + } + + explicit AclTensorCache(const Tensor& t, bool transpose_last2 = false) + : dtype_{ToAclDtype(t.dtype())} { + shape_.assign(t.shape().begin(), t.shape().end()); + strides_.assign(t.strides().begin(), t.strides().end()); + + if (transpose_last2 && shape_.size() >= 2) { + auto n = shape_.size(); + std::swap(shape_[n - 2], shape_[n - 1]); + std::swap(strides_[n - 2], strides_[n - 1]); + } + + int64_t storage_elems = 1; + for (size_t i = 0; i < shape_.size(); ++i) { + if (shape_[i] == 0) { + storage_elems = 0; + break; + } + if (strides_[i] > 0 && shape_[i] > 1) { + storage_elems += static_cast(shape_[i] - 1) * strides_[i]; + } + } + storage_shape_ = {storage_elems}; + } + + ~AclTensorCache() { + if (tensor_ && IsAclRuntimeAlive()) { + aclDestroyTensor(tensor_); + } + } + + AclTensorCache(const AclTensorCache&) = delete; + + AclTensorCache& operator=(const AclTensorCache&) = delete; + + AclTensorCache(AclTensorCache&& o) noexcept + : shape_(std::move(o.shape_)), + strides_(std::move(o.strides_)), + storage_shape_(std::move(o.storage_shape_)), + dtype_(o.dtype_), + tensor_(o.tensor_) { + o.tensor_ = nullptr; + } + + AclTensorCache& operator=(AclTensorCache&& o) noexcept { + if (this != &o) { + if (tensor_) { + aclDestroyTensor(tensor_); + } + shape_ = std::move(o.shape_); + strides_ = std::move(o.strides_); + storage_shape_ = std::move(o.storage_shape_); + dtype_ = o.dtype_; + tensor_ = o.tensor_; + o.tensor_ = nullptr; + } + + return 
*this; + } + + // Null the cached descriptor pointer without calling `aclDestroyTensor`. + // Call from the owning operator's destructor: the descriptor is still + // referenced by a Repeatable `aclOpExecutor` which would be destroyed + // alongside the tensor, and per CANN 8.5 docs that destruction is our + // responsibility. In practice `aclDestroyAclOpExecutor` during process + // shutdown crashes even with `IsAclRuntimeAlive()` true — see `64c367c` — + // so operators leak the executor at shutdown; skipping `aclDestroyTensor` + // here keeps `~AclTensorCache` from double-freeing a descriptor the + // executor still holds. + void release() { tensor_ = nullptr; } + + // Explicitly destroy the cached tensor and clear the pointer. + // Use only when the descriptor is owned outside any executor (e.g. an + // intermediate tensor not passed to `aclnn*GetWorkspaceSize`). + void destroy() { + if (tensor_) { + aclDestroyTensor(tensor_); + tensor_ = nullptr; + } + } + + // Update the data pointer and return the cached descriptor. + aclTensor* get(void* data) const { + if (tensor_) { + aclSetRawTensorAddr(tensor_, data); + + return tensor_; + } + + tensor_ = aclCreateTensor( + shape_.data(), static_cast(shape_.size()), dtype_, + strides_.data(), + /*storageOffset=*/0, ACL_FORMAT_ND, storage_shape_.data(), + static_cast(storage_shape_.size()), data); + + return tensor_; + } + + private: + std::vector shape_; + + std::vector strides_; + + std::vector storage_shape_; + + aclDataType dtype_{ACL_DT_UNDEFINED}; + + mutable aclTensor* tensor_ = nullptr; +}; + } // namespace infini::ops::ascend #endif diff --git a/src/ascend/custom/CMakeLists.txt b/src/ascend/custom/CMakeLists.txt new file mode 100644 index 000000000..ca6e6883f --- /dev/null +++ b/src/ascend/custom/CMakeLists.txt @@ -0,0 +1,97 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd. +# Copyright (c) 2025 InfiniTensor. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Adapted from https://github.com/vllm-project/vllm-ascend/blob/main/csrc/CMakeLists.txt
+
+cmake_minimum_required(VERSION 3.20 FATAL_ERROR)
+project(ascend-kernel LANGUAGES CXX)
+
+set(CMAKE_CXX_STANDARD 17)
+
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE RELEASE)
+endif()
+
+add_compile_options(-Wunused-value -Wcast-align -Wcast-qual -Wwrite-strings
+                    -Wsign-compare -Wextra)
+
+if(${CMAKE_BUILD_TYPE} MATCHES "RELEASE")
+  add_compile_options(-O3 -fvisibility=hidden -fvisibility-inlines-hidden
+                      -fstack-protector-strong -fPIE -fPIC)
+  message(STATUS "Build type set to `RELEASE`.")
+else()
+  add_compile_options(-g -rdynamic)
+endif()
+
+set(PROJECT_OP_SRC_BASE ${PROJECT_SOURCE_DIR})
+set(PROJECT_BUILD_PATH ${PROJECT_SOURCE_DIR}/build)
+set(PROJECT_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/output)
+
+include(cmake/config_envs.cmake)
+include(cmake/config_ascend.cmake)
+
+find_program(CCACHE_PROGRAM ccache)
+if(CCACHE_PROGRAM)
+  message(STATUS "Found `ccache`: ${CCACHE_PROGRAM}.")
+  set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+  set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
+endif()
+
+# Shared library output location.
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_OUTPUT_PATH})
+
+# Host-side files.
+file(GLOB OP_SRCS
+  ${PROJECT_OP_SRC_BASE}/torch_binding.cpp
+  ${PROJECT_OP_SRC_BASE}/rms_norm/op_host/rms_norm.cpp
+)
+
+# Shared library name — consumed by `kernel_custom.h` variants and by the
+# Python side via `import ascend_kernel`.
+set(OP_PLUGIN_NAME ascend_kernel)
+
+# Kernel-side files (device code compiled by the `AscendC` toolchain).
+ascendc_library(no_workspace_kernel STATIC
+  ${PROJECT_OP_SRC_BASE}/rms_norm/op_kernel/rms_norm.cpp
+)
+
+# Create the shared library `libascend_kernel.so`.
+add_library(${OP_PLUGIN_NAME} SHARED ${OP_SRCS})
+
+target_link_libraries(${OP_PLUGIN_NAME} PRIVATE
+  no_workspace_kernel
+  torch_npu
+  ascendcl
+  tiling_api
+  nnopbase
+  opapi
+  register
+  platform
+  ascendalog
+  dl
+)
+
+target_link_directories(${OP_PLUGIN_NAME} PRIVATE
+  ${TORCH_DIR}/lib
+  ${TORCH_NPU_DIR}/lib
+)
+
+target_include_directories(${OP_PLUGIN_NAME} PRIVATE
+  ${PROJECT_OP_SRC_BASE}/utils
+  ${PROJECT_SOURCE_DIR}/include
+  ${TORCH_DIR}/include
+  ${TORCH_DIR}/include/torch/csrc/api/include
+  ${TORCH_NPU_DIR}/include/third_party/acl/inc
+  ${TORCH_NPU_DIR}/include/third_party/hccl/inc
+  ${TORCH_NPU_DIR}/include
+  ${PYTHON_INCLUDE_DIR}
+  ${ASCEND_INCLUDE_DIR}/external
+  ${ASCEND_INCLUDE_DIR}/experiment/platform
+  ${ASCEND_INCLUDE_DIR}/experiment/runtime
+)
diff --git a/src/ascend/custom/add_rms_norm/CMakeLists.txt b/src/ascend/custom/add_rms_norm/CMakeLists.txt
new file mode 100644
index 000000000..1748afc06
--- /dev/null
+++ b/src/ascend/custom/add_rms_norm/CMakeLists.txt
@@ -0,0 +1 @@
+ascendc_add_operator(OP_NAME add_rms_norm)
diff --git a/src/ascend/custom/add_rms_norm/op_host/add_rms_norm.cpp b/src/ascend/custom/add_rms_norm/op_host/add_rms_norm.cpp
new file mode 100644
index 000000000..b8e0d504b
--- /dev/null
+++ b/src/ascend/custom/add_rms_norm/op_host/add_rms_norm.cpp
@@ -0,0 +1,134 @@
+#include "aclrtlaunch_add_rms_norm.h"
+#include "tiling/platform/platform_ascendc.h"
+#include "torch_kernel_helper.h"
+
+namespace ascend::detail {
+
+std::vector<at::Tensor> AddRmsNorm(const at::Tensor& x1, const at::Tensor& x2,
+                                   const at::Tensor& weight, double eps) {
+  // Input validation.
+  TORCH_CHECK(x1.dim() > 0,
+              "`AddRmsNorm`: `x1` must have at least 1 dimension.");
+  TORCH_CHECK(x1.sizes() == x2.sizes(),
+              "`AddRmsNorm`: `x1` and `x2` must have the same shape.");
+  TORCH_CHECK(x1.scalar_type() == x2.scalar_type(),
+              "`AddRmsNorm`: `x1` and `x2` must have the same dtype.");
+  TORCH_CHECK(x1.scalar_type() == at::kHalf || x1.scalar_type() == at::kFloat,
+              "`AddRmsNorm`: only `float16` and `float32` are supported; got ",
+              x1.scalar_type(), ".");
+  TORCH_CHECK(weight.dim() == 1,
+              "`AddRmsNorm`: `weight` must be 1-dimensional.");
+  TORCH_CHECK(weight.size(0) == x1.size(-1), "`AddRmsNorm`: `weight` size (",
+              weight.size(0), ") must match input last dim (", x1.size(-1),
+              ").");
+
+  int64_t dim_length = x1.size(-1);
+  int64_t total_rows = dim_length == 0 ? 0 : x1.numel() / dim_length;
+
+  if (total_rows == 0 || dim_length == 0) {
+    return {at::empty_like(x1), at::empty_like(x1)};
+  }
+
+  at::Tensor inp1 = x1.contiguous();
+  at::Tensor inp2 = x2.contiguous();
+  int64_t dtype_size = inp1.element_size();
+
+  // Hardware parameters.
+  auto ascendc_platform =
+      platform_ascendc::PlatformAscendCManager::GetInstance();
+  int64_t core_num = static_cast<int64_t>(ascendc_platform->GetCoreNumAiv());
+  uint64_t ub_size;
+  ascendc_platform->GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_size);
+  int64_t ub_size_limit = static_cast<int64_t>(ub_size);
+
+  // Alignment (32-byte boundary).
+  int64_t align_elements = 32 / dtype_size;
+  int64_t dim_length_align =
+      ((dim_length + align_elements - 1) / align_elements) * align_elements;
+
+  // UB capacity check.
+  //
+  // - `fp16`: `inQ_x1` (*2*2) + `inQ_x2` (*2*2) + `outQ_y` (*2*2) +
+  //   `outQ_xout` (*2*2) + `fp32Buf1` (*4) + `fp32Buf2` (*4) +
+  //   `weight` (*4) = 16 + 12 = 28
+  // - `fp32`: `inQ_x1` (*2*4) + `inQ_x2` (*2*4) + `outQ_y` (*2*4) +
+  //   `outQ_xout` (*2*4) + `weight` (*4) = 32 + 4 = 36
+  int64_t buffer_coefficient = (dtype_size == 2) ? 28 : 36;
+  int64_t max_dim_length = (ub_size_limit - 1024) / buffer_coefficient;
+  int64_t fp_align_elements = 32 / 4;
+  max_dim_length = (max_dim_length / fp_align_elements) * fp_align_elements;
+  TORCH_CHECK(dim_length_align <= max_dim_length,
+              "`AddRmsNorm`: `dim_length` ", dim_length, " (aligned ",
+              dim_length_align, ") exceeds UB capacity (max ", max_dim_length,
+              ").");
+
+  // Padding.
+  at::Tensor kernel_input1;
+  at::Tensor kernel_input2;
+
+  if (dim_length != dim_length_align) {
+    kernel_input1 = inp1.reshape({total_rows, dim_length});
+    kernel_input1 = at::constant_pad_nd(
+        kernel_input1, {0, dim_length_align - dim_length}, 0.0);
+    kernel_input1 = kernel_input1.contiguous();
+
+    kernel_input2 = inp2.reshape({total_rows, dim_length});
+    kernel_input2 = at::constant_pad_nd(
+        kernel_input2, {0, dim_length_align - dim_length}, 0.0);
+    kernel_input2 = kernel_input2.contiguous();
+  } else {
+    kernel_input1 = inp1.reshape({total_rows, dim_length_align}).contiguous();
+    kernel_input2 = inp2.reshape({total_rows, dim_length_align}).contiguous();
+  }
+
+  at::Tensor kernel_output_y = at::empty_like(kernel_input1);
+  at::Tensor kernel_output_x_out = at::empty_like(kernel_input1);
+
+  // Weight: always pass as fp32, padded to `dim_length_align`.
+  at::Tensor weight_float = weight.contiguous().to(at::kFloat);
+
+  if (dim_length != dim_length_align) {
+    weight_float = at::constant_pad_nd(
+        weight_float, {0, dim_length_align - dim_length}, 0.0);
+  }
+
+  weight_float = weight_float.contiguous();
+
+  // Block-level tiling (distribute rows across cores).
+  int64_t used_core_num = std::min(total_rows, core_num);
+  int64_t former_length = (total_rows + used_core_num - 1) / used_core_num;
+  int64_t tail_length = former_length - 1;
+  int64_t former_num = total_rows - tail_length * used_core_num;
+  uint32_t block_dim = static_cast<uint32_t>(used_core_num);
+
+  // All `EXEC_KERNEL_CMD` args must be lvalues.
+  float eps_float = static_cast<float>(eps);
+  int64_t dtype_size_val = dtype_size;
+
+  // The first arg `add_rms_norm` is the AscendC kernel entry-point name — it
+  // must match `ascendc_add_operator(OP_NAME add_rms_norm)` in `CMakeLists.txt`,
+  // the `__global__ __aicore__ void add_rms_norm(...)` definition in
+  // `op_kernel/`, and the generated `aclrtlaunch_add_rms_norm.h` header.
+  // Google C++ Style's PascalCase rule does NOT apply: this identifier is
+  // dictated by the AscendC toolchain's symbol convention.
+  EXEC_KERNEL_CMD(add_rms_norm, block_dim, kernel_input1, kernel_input2,
+                  weight_float, kernel_output_y, kernel_output_x_out,
+                  total_rows, dim_length, dim_length_align, former_num,
+                  former_length, tail_length, eps_float, dtype_size_val);
+
+  // Remove padding and reshape back to original shape.
+  at::Tensor output_y = kernel_output_y;
+  at::Tensor output_x_out = kernel_output_x_out;
+
+  if (dim_length != dim_length_align) {
+    output_y = output_y.narrow(-1, 0, dim_length).contiguous();
+    output_x_out = output_x_out.narrow(-1, 0, dim_length).contiguous();
+  }
+
+  output_y = output_y.reshape(x1.sizes());
+  output_x_out = output_x_out.reshape(x1.sizes());
+
+  return {output_y, output_x_out};
+}
+
+}  // namespace ascend::detail
diff --git a/src/ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp b/src/ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp
new file mode 100644
index 000000000..e2a08e555
--- /dev/null
+++ b/src/ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp
@@ -0,0 +1,249 @@
+#include "kernel_operator.h"
+
+constexpr int32_t BUFFER_NUM = 2;
+
+template <typename T>
+class KernelAddRmsNorm {
+ public:
+  __aicore__ inline KernelAddRmsNorm() {}
+
+  __aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR weight, GM_ADDR y,
+                              GM_ADDR x_out, int64_t totalRows,
+                              int64_t dimLength, int64_t dimLengthAlign,
+                              int64_t formerNum, int64_t formerLength,
+                              int64_t tailLength, float eps) {
+    this->dimLength = dimLength;
+    this->dimLengthAlign = dimLengthAlign;
+ this->eps = eps; + + // Block-level tiling: determine row range for this core. + int64_t blockIdx = AscendC::GetBlockIdx(); + int64_t rowOffset; + + if (blockIdx < formerNum) { + this->blockRows = formerLength; + rowOffset = formerLength * blockIdx; + } else { + this->blockRows = tailLength; + int64_t tailIdx = blockIdx - formerNum; + rowOffset = formerLength * formerNum + tailLength * tailIdx; + } + + // Global memory pointers. + x1Gm.SetGlobalBuffer((__gm__ T*)x1 + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + x2Gm.SetGlobalBuffer((__gm__ T*)x2 + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + yGm.SetGlobalBuffer((__gm__ T*)y + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + xOutGm.SetGlobalBuffer((__gm__ T*)x_out + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + weightGm.SetGlobalBuffer((__gm__ float*)weight, dimLengthAlign); + + int32_t dimLenAlign = static_cast(this->dimLengthAlign); + + // I/O queues (double-buffered). + pipe.InitBuffer(inQueueX1, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + pipe.InitBuffer(inQueueX2, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + pipe.InitBuffer(outQueueY, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + pipe.InitBuffer(outQueueXOut, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + + // Weight buffer (fp32, loaded once, reused for all rows). + pipe.InitBuffer(weightBuf, + dimLenAlign * static_cast(sizeof(float))); + + // FP16 path needs extra fp32 compute buffers. + // buf1: holds x_out in fp32 (reused from x1_fp32 after Add). + // buf2: holds x2_fp32 initially, then x_out^2, then final result. + if constexpr (sizeof(T) == 2) { + pipe.InitBuffer(fp32Buf1, + dimLenAlign * static_cast(sizeof(float))); + pipe.InitBuffer(fp32Buf2, + dimLenAlign * static_cast(sizeof(float))); + } + + // ReduceSum temporary buffer (size per API formula). 
+ constexpr int32_t ELEMS_PER_REPEAT = 256 / sizeof(float); + constexpr int32_t ELEMS_PER_BLOCK = 32 / sizeof(float); + int32_t firstMaxRepeat = + (dimLenAlign + ELEMS_PER_REPEAT - 1) / ELEMS_PER_REPEAT; + int32_t reduceTmpSize = + ((firstMaxRepeat + ELEMS_PER_BLOCK - 1) / ELEMS_PER_BLOCK) * + ELEMS_PER_BLOCK; + pipe.InitBuffer(reduceTmpBuf, + reduceTmpSize * static_cast(sizeof(float))); + + // Scalar buffer for reduction result (8 floats = 32 bytes). + pipe.InitBuffer(sumBuf, 32); + + // Load weight (fp32) from GM into `weightBuf`. + AscendC::LocalTensor wLocal = weightBuf.Get(); + AscendC::DataCopyExtParams wParams{ + 1, static_cast(dimLenAlign * sizeof(float)), 0, 0, 0}; + AscendC::DataCopyPadExtParams wPad{false, 0, 0, 0.0f}; + AscendC::DataCopyPad(wLocal, weightGm, wParams, wPad); + + // Ensure weight DMA completes before compute. + AscendC::PipeBarrier(); + } + + __aicore__ inline void Process() { + for (int64_t row = 0; row < this->blockRows; ++row) { + CopyIn(row); + Compute(row); + CopyOut(row); + } + } + + private: + __aicore__ inline void CopyIn(int64_t row) { + AscendC::LocalTensor x1Local = inQueueX1.AllocTensor(); + AscendC::LocalTensor x2Local = inQueueX2.AllocTensor(); + AscendC::DataCopyExtParams params{ + 1, static_cast(this->dimLengthAlign * sizeof(T)), 0, 0, 0}; + AscendC::DataCopyPadExtParams pad{false, 0, 0, static_cast(0)}; + AscendC::DataCopyPad(x1Local, x1Gm[row * this->dimLengthAlign], params, + pad); + AscendC::DataCopyPad(x2Local, x2Gm[row * this->dimLengthAlign], params, + pad); + inQueueX1.EnQue(x1Local); + inQueueX2.EnQue(x2Local); + } + + __aicore__ inline void Compute(int64_t row) { + AscendC::LocalTensor x1Local = inQueueX1.DeQue(); + AscendC::LocalTensor x2Local = inQueueX2.DeQue(); + AscendC::LocalTensor yLocal = outQueueY.AllocTensor(); + AscendC::LocalTensor xOutLocal = outQueueXOut.AllocTensor(); + + AscendC::LocalTensor wLocal = weightBuf.Get(); + AscendC::LocalTensor rTmp = reduceTmpBuf.Get(); + AscendC::LocalTensor sLocal = 
sumBuf.Get(); + + int32_t dimLen = static_cast(this->dimLength); + int32_t dimLenAlign = static_cast(this->dimLengthAlign); + + if constexpr (sizeof(T) == 4) { + // ---- FP32 path: compute directly. ---- + + // Step 1: x_out = x1 + x2. + AscendC::Add(xOutLocal, x1Local, x2Local, dimLenAlign); + + // Step 2: x_out^2 into yLocal (reuse output buffer temporarily). + AscendC::Mul(yLocal, xOutLocal, xOutLocal, dimLenAlign); + + // Step 3: ReduceSum(x_out^2) -> sLocal[0]. + // ReduceSum may modify yLocal, but we overwrite it below. + AscendC::ReduceSum(sLocal, yLocal, rTmp, dimLenAlign); + + // Step 4-5: scale = 1 / sqrt(mean(x_out^2) + eps). + float sumVal = sLocal.GetValue(0); + float meanVal = sumVal / static_cast(dimLen) + this->eps; + sLocal.SetValue(0, meanVal); + AscendC::Sqrt(sLocal, sLocal, 8); + float scale = 1.0f / sLocal.GetValue(0); + + // Step 6: y = x_out * scale. + AscendC::Muls(yLocal, xOutLocal, scale, dimLenAlign); + + // Step 7: y = y * weight. + AscendC::Mul(yLocal, yLocal, wLocal, dimLenAlign); + + } else { + // ---- FP16 path: cast → fp32 compute → cast back. ---- + AscendC::LocalTensor b1 = fp32Buf1.Get(); + AscendC::LocalTensor b2 = fp32Buf2.Get(); + + // Cast inputs fp16 → fp32. + AscendC::Cast(b1, x1Local, AscendC::RoundMode::CAST_NONE, dimLenAlign); + AscendC::Cast(b2, x2Local, AscendC::RoundMode::CAST_NONE, dimLenAlign); + + // Step 1: x_out = x1 + x2 (fp32), stored in b1. + AscendC::Add(b1, b1, b2, dimLenAlign); + + // Cast x_out fp32 → fp16 for the x_out output. + AscendC::Cast(xOutLocal, b1, AscendC::RoundMode::CAST_ROUND, dimLenAlign); + + // Step 2: x_out^2 in fp32, stored in b2. + AscendC::Mul(b2, b1, b1, dimLenAlign); + + // Step 3: ReduceSum(x_out^2) -> sLocal[0]. + AscendC::ReduceSum(sLocal, b2, rTmp, dimLenAlign); + + // Step 4-5: scale = 1 / sqrt(mean(x_out^2) + eps). 
+ float sumVal = sLocal.GetValue(0); + float meanVal = sumVal / static_cast(dimLen) + this->eps; + sLocal.SetValue(0, meanVal); + AscendC::Sqrt(sLocal, sLocal, 8); + float scale = 1.0f / sLocal.GetValue(0); + + // Step 6: y = x_out * scale (fp32), reuse b2. + AscendC::Muls(b2, b1, scale, dimLenAlign); + + // Step 7: y = y * weight (fp32). + AscendC::Mul(b2, b2, wLocal, dimLenAlign); + + // Cast result fp32 → fp16. + AscendC::Cast(yLocal, b2, AscendC::RoundMode::CAST_ROUND, dimLenAlign); + } + + inQueueX1.FreeTensor(x1Local); + inQueueX2.FreeTensor(x2Local); + outQueueY.EnQue(yLocal); + outQueueXOut.EnQue(xOutLocal); + } + + __aicore__ inline void CopyOut(int64_t row) { + AscendC::LocalTensor yLocal = outQueueY.DeQue(); + AscendC::LocalTensor xOutLocal = outQueueXOut.DeQue(); + AscendC::DataCopyExtParams params{ + 1, static_cast(this->dimLengthAlign * sizeof(T)), 0, 0, 0}; + AscendC::DataCopyPad(yGm[row * this->dimLengthAlign], yLocal, params); + AscendC::DataCopyPad(xOutGm[row * this->dimLengthAlign], xOutLocal, params); + outQueueY.FreeTensor(yLocal); + outQueueXOut.FreeTensor(xOutLocal); + } + + private: + AscendC::TPipe pipe; + AscendC::TQue inQueueX1; + AscendC::TQue inQueueX2; + AscendC::TQue outQueueY; + AscendC::TQue outQueueXOut; + + AscendC::TBuf weightBuf; + AscendC::TBuf fp32Buf1; + AscendC::TBuf fp32Buf2; + AscendC::TBuf reduceTmpBuf; + AscendC::TBuf sumBuf; + + AscendC::GlobalTensor x1Gm, x2Gm, yGm, xOutGm; + AscendC::GlobalTensor weightGm; + + int64_t blockRows; + int64_t dimLength; + int64_t dimLengthAlign; + float eps; +}; + +extern "C" __global__ __aicore__ void add_rms_norm( + GM_ADDR x1, GM_ADDR x2, GM_ADDR weight, GM_ADDR y, GM_ADDR x_out, + int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign, + int64_t formerNum, int64_t formerLength, int64_t tailLength, float eps, + int64_t dtypeSize) { + if (dtypeSize == 2) { + KernelAddRmsNorm op; + op.Init(x1, x2, weight, y, x_out, totalRows, dimLength, dimLengthAlign, + formerNum, formerLength, 
tailLength, eps);
+    op.Process();
+  } else {
+    KernelAddRmsNorm<float> op;
+    op.Init(x1, x2, weight, y, x_out, totalRows, dimLength, dimLengthAlign,
+            formerNum, formerLength, tailLength, eps);
+    op.Process();
+  }
+}
diff --git a/src/ascend/custom/build.sh b/src/ascend/custom/build.sh
new file mode 100755
index 000000000..258a88e4b
--- /dev/null
+++ b/src/ascend/custom/build.sh
@@ -0,0 +1,30 @@
+#!/bin/bash
+# Build custom `AscendC` kernels into `libascend_kernel.so`.
+set -e
+
+SOC_VERSION="${1:-Ascend910_9382}"
+
+# Detect CANN toolkit path.
+_CANN_TOOLKIT_INSTALL_PATH=$(grep "Toolkit_InstallPath" /etc/Ascend/ascend_cann_install.info | awk -F'=' '{print $2}')
+source "${_CANN_TOOLKIT_INSTALL_PATH}/set_env.sh"
+echo "CANN: ${ASCEND_TOOLKIT_HOME}"
+
+ASCEND_INCLUDE_DIR=${ASCEND_TOOLKIT_HOME}/$(arch)-linux/include
+CURRENT_DIR=$(pwd)
+OUTPUT_DIR=${CURRENT_DIR}/output
+mkdir -p "${OUTPUT_DIR}"
+
+BUILD_DIR=build
+rm -rf "${BUILD_DIR}"
+mkdir -p "${BUILD_DIR}"
+
+cmake \
+    -DASCEND_HOME_PATH="${ASCEND_HOME_PATH}" \
+    -DASCEND_INCLUDE_DIR="${ASCEND_INCLUDE_DIR}" \
+    -DSOC_VERSION="${SOC_VERSION}" \
+    -B "${BUILD_DIR}" \
+    -S .
+
+cmake --build "${BUILD_DIR}" -j 16
+
+echo "Build complete. Output: ${OUTPUT_DIR}"
diff --git a/src/ascend/custom/cmake/config_ascend.cmake b/src/ascend/custom/cmake/config_ascend.cmake
new file mode 100644
index 000000000..1772e9e70
--- /dev/null
+++ b/src/ascend/custom/cmake/config_ascend.cmake
@@ -0,0 +1,40 @@
+
+if(DEFINED ASCEND_HOME_PATH)
+elseif(DEFINED ENV{ASCEND_HOME_PATH})
+  set(ASCEND_HOME_PATH "$ENV{ASCEND_HOME_PATH}" CACHE PATH "ASCEND CANN package installation directory" FORCE)
+endif()
+
+set(ASCEND_CANN_PACKAGE_PATH ${ASCEND_HOME_PATH})
+
+# Auto-detect `SOC_VERSION` from `npu-smi info` if not set externally.
+# Required by `CANN`'s `ascendc.cmake` for `AscendC` kernel compilation.
+if(NOT DEFINED SOC_VERSION OR "${SOC_VERSION}" STREQUAL "")
+  execute_process(
+    COMMAND bash -c "npu-smi info 2>/dev/null | awk '/910B|910A|310/ {for (i=1;i<=NF;i++) if ($i ~ /^(910|310)/) {print \"Ascend\" $i; exit}}'"
+    OUTPUT_VARIABLE _DETECTED_SOC
+    OUTPUT_STRIP_TRAILING_WHITESPACE)
+
+  if(_DETECTED_SOC)
+    set(SOC_VERSION "${_DETECTED_SOC}" CACHE STRING "Ascend SOC version" FORCE)
+  else()
+    set(SOC_VERSION "Ascend910B4" CACHE STRING "Ascend SOC version" FORCE)
+  endif()
+
+  message(STATUS "SOC_VERSION auto-set to ${SOC_VERSION}")
+endif()
+
+if(EXISTS ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake)
+  set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake)
+elseif(EXISTS ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+  set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+elseif(EXISTS ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake)
+  set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake)
+else()
+  message(FATAL_ERROR "`ascendc_kernel_cmake` does not exist; please check whether the `CANN` package is installed.")
+endif()
+
+include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
+
+
+message(STATUS "ASCEND_CANN_PACKAGE_PATH = ${ASCEND_CANN_PACKAGE_PATH}")
+message(STATUS "ASCEND_HOME_PATH = ${ASCEND_HOME_PATH}")
diff --git a/src/ascend/custom/cmake/config_envs.cmake b/src/ascend/custom/cmake/config_envs.cmake
new file mode 100644
index 000000000..e715bfc7e
--- /dev/null
+++ b/src/ascend/custom/cmake/config_envs.cmake
@@ -0,0 +1,83 @@
+# Find the Python binary.
+find_program(PYTHON_EXECUTABLE NAMES python3)
+
+if (NOT EXISTS ${PYTHON_EXECUTABLE})
+  message(FATAL_ERROR "`python3` is not found; install Python first.")
+endif ()
+
+# Get `torch`, `torch_npu`, and `pybind11` paths via a Python helper.
+execute_process(
+  COMMAND ${PYTHON_EXECUTABLE} "-c"
+  "import torch; import torch_npu; import os; import pybind11; import sysconfig;
+torch_dir = os.path.realpath(os.path.dirname(torch.__file__));
+torch_npu_dir = os.path.realpath(os.path.dirname(torch_npu.__file__));
+pybind11_dir = os.path.realpath(os.path.dirname(pybind11.__file__));
+abi_enabled=torch.compiled_with_cxx11_abi();
+python_include_dir = sysconfig.get_path('include');
+print(torch_dir, torch_npu_dir, pybind11_dir, abi_enabled, python_include_dir, end='');
+quit(0)
+  "
+  RESULT_VARIABLE EXEC_RESULT
+  OUTPUT_VARIABLE OUTPUT_ENV_DEFINES)
+
+# Abort if the Python helper failed.
+if (NOT ${EXEC_RESULT} EQUAL 0)
+  message(FATAL_ERROR "Failed to run Python script to probe env vars like `TORCH_DIR`.")
+else ()
+  message(STATUS "Python probe succeeded; output string is [${OUTPUT_ENV_DEFINES}].")
+endif ()
+
+# Extract `TORCH_DIR`.
+execute_process(
+  COMMAND sh -c "echo \"${OUTPUT_ENV_DEFINES}\" | awk '{print $1}'"
+  OUTPUT_VARIABLE TORCH_DIR
+  RESULT_VARIABLE EXEC_RESULT
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+# Extract `TORCH_NPU_DIR`.
+execute_process(
+  COMMAND sh -c "echo \"${OUTPUT_ENV_DEFINES}\" | awk '{print $2}'"
+  OUTPUT_VARIABLE TORCH_NPU_DIR
+  RESULT_VARIABLE EXEC_RESULT
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+# Extract `PYBIND11_DIR`.
+execute_process(
+  COMMAND sh -c "echo \"${OUTPUT_ENV_DEFINES}\" | awk '{print $3}'"
+  OUTPUT_VARIABLE PYBIND11_DIR
+  RESULT_VARIABLE EXEC_RESULT
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+# Extract the PyTorch C++11 ABI flag.
+execute_process(
+  COMMAND sh -c "echo \"${OUTPUT_ENV_DEFINES}\" | awk '{print $4}'"
+  OUTPUT_VARIABLE TORCH_API_ENABLED
+  RESULT_VARIABLE EXEC_RESULT
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+# Extract `PYTHON_INCLUDE_DIR`.
+execute_process(
+  COMMAND sh -c "echo \"${OUTPUT_ENV_DEFINES}\" | awk '{print $5}'"
+  OUTPUT_VARIABLE PYTHON_INCLUDE_DIR
+  RESULT_VARIABLE EXEC_RESULT
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+message(STATUS "SOC_VERSION=${SOC_VERSION}")
+message(STATUS "TORCH_DIR=${TORCH_DIR}")
+message(STATUS "TORCH_NPU_DIR=${TORCH_NPU_DIR}")
+message(STATUS "PYBIND11_DIR=${PYBIND11_DIR}")
+message(STATUS "PYTHON_INCLUDE_DIR=${PYTHON_INCLUDE_DIR}")
+
+# Set `_GLIBCXX_USE_CXX11_ABI` to match the PyTorch build.
+if (${TORCH_API_ENABLED} STREQUAL "True")
+  add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=1)
+  message(STATUS "_GLIBCXX_USE_CXX11_ABI=1")
+else ()
+  add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0)
+  message(STATUS "_GLIBCXX_USE_CXX11_ABI=0")
+endif ()
diff --git a/src/ascend/custom/ops.h b/src/ascend/custom/ops.h
new file mode 100644
index 000000000..e9e6ad9d4
--- /dev/null
+++ b/src/ascend/custom/ops.h
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * Copyright (c) 2025 InfiniTensor.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *
+ * Adapted from https://github.com/vllm-project/vllm-ascend/blob/main/csrc/ops.h
+ */
+
+#ifndef OPS_H
+#define OPS_H
+
+namespace ascend::detail {
+
+at::Tensor RmsNorm(const at::Tensor& input, const at::Tensor& weight,
+                   double eps);
+
+}  // namespace ascend::detail
+
+#endif  // OPS_H
diff --git a/src/ascend/custom/rms_norm/CMakeLists.txt b/src/ascend/custom/rms_norm/CMakeLists.txt
new file mode 100644
index 000000000..94ceabaa6
--- /dev/null
+++ b/src/ascend/custom/rms_norm/CMakeLists.txt
@@ -0,0 +1 @@
+ascendc_add_operator(OP_NAME rms_norm)
diff --git a/src/ascend/custom/rms_norm/op_host/rms_norm.cpp b/src/ascend/custom/rms_norm/op_host/rms_norm.cpp
new file mode 100644
index 000000000..eb521c7b5
--- /dev/null
+++ b/src/ascend/custom/rms_norm/op_host/rms_norm.cpp
@@ -0,0 +1,120 @@
+#include "aclrtlaunch_rms_norm.h"
+#include "tiling/platform/platform_ascendc.h"
+#include "torch_kernel_helper.h"
+
+namespace ascend::detail {
+
+at::Tensor RmsNorm(const at::Tensor& input, const at::Tensor& weight,
+                   double eps) {
+  // Input validation.
+  TORCH_CHECK(input.dim() > 0,
+              "`RmsNorm`: `input` must have at least 1 dimension.");
+  TORCH_CHECK(
+      input.scalar_type() == at::kHalf || input.scalar_type() == at::kFloat,
+      "`RmsNorm`: only `float16` and `float32` are supported; got ",
+      input.scalar_type(), ".");
+  TORCH_CHECK(weight.dim() == 1, "`RmsNorm`: `weight` must be 1-dimensional.");
+  TORCH_CHECK(weight.size(0) == input.size(-1), "`RmsNorm`: `weight` size (",
+              weight.size(0), ") must match input last dim (", input.size(-1),
+              ").");
+
+  int64_t dim_length = input.size(-1);
+  int64_t total_rows = dim_length == 0 ? 0 : input.numel() / dim_length;
+
+  if (total_rows == 0 || dim_length == 0) {
+    return at::empty_like(input);
+  }
+
+  at::Tensor x = input.contiguous();
+  int64_t dtype_size = x.element_size();
+
+  // Hardware parameters.
+  auto ascendc_platform =
+      platform_ascendc::PlatformAscendCManager::GetInstance();
+  int64_t core_num = static_cast<int64_t>(ascendc_platform->GetCoreNumAiv());
+  uint64_t ub_size;
+  ascendc_platform->GetCoreMemSize(platform_ascendc::CoreMemType::UB, ub_size);
+  int64_t ub_size_limit = static_cast<int64_t>(ub_size);
+
+  // Alignment (32-byte boundary).
+  int64_t align_elements = 32 / dtype_size;
+  int64_t dim_length_align =
+      ((dim_length + align_elements - 1) / align_elements) * align_elements;
+
+  // UB capacity check.
+  //
+  // - `fp32`: `inQ` (*2) + `outQ` (*2) + `weight` = 5 * `dim_length_align`
+  //   * 4 = coeff 20.
+  // - `fp16`: `inQ` (*2) + `outQ` (*2) + `x_fp32` + `tmp_fp32` + `weight`
+  //   = 2 * `dim_length_align` * 2 * 2 + 3 * `dim_length_align` * 4 =
+  //   8 + 12 = coeff 20.
+  int64_t buffer_coefficient = 20;
+  // Reserve 1024 bytes for reduce buffers.
+  int64_t max_dim_length = (ub_size_limit - 1024) / buffer_coefficient;
+  // `fp32` alignment.
+  int64_t fp_align_elements = 32 / 4;
+  max_dim_length = (max_dim_length / fp_align_elements) * fp_align_elements;
+  TORCH_CHECK(dim_length_align <= max_dim_length,
+              "`RmsNorm`: `dim_length` ", dim_length, " (aligned ",
+              dim_length_align, ") exceeds UB capacity (max ", max_dim_length,
+              ").");
+
+  // Padding.
+  at::Tensor kernel_input;
+
+  if (dim_length != dim_length_align) {
+    kernel_input = x.reshape({total_rows, dim_length});
+    kernel_input = at::constant_pad_nd(
+        kernel_input, {0, dim_length_align - dim_length}, 0.0);
+    kernel_input = kernel_input.contiguous();
+  } else {
+    kernel_input = x.reshape({total_rows, dim_length_align}).contiguous();
+  }
+
+  at::Tensor kernel_output = at::empty_like(kernel_input);
+
+  // Weight: always pass as fp32, padded to `dim_length_align`.
+ at::Tensor weight_float = weight.contiguous().to(at::kFloat); + + if (dim_length != dim_length_align) { + weight_float = at::constant_pad_nd( + weight_float, {0, dim_length_align - dim_length}, 0.0); + } + + weight_float = weight_float.contiguous(); + + // Block-level tiling (distribute rows across cores). + int64_t used_core_num = std::min(total_rows, core_num); + int64_t former_length = (total_rows + used_core_num - 1) / used_core_num; + int64_t tail_length = former_length - 1; + int64_t former_num = total_rows - tail_length * used_core_num; + uint32_t block_dim = static_cast<uint32_t>(used_core_num); + + // All `EXEC_KERNEL_CMD` args must be lvalues. + float eps_float = static_cast<float>(eps); + int64_t dtype_size_val = dtype_size; + + // The first arg `rms_norm` is the AscendC kernel entry-point name; it + // must match `ascendc_add_operator(OP_NAME rms_norm)` in `CMakeLists.txt`, + // the `__global__ __aicore__ void rms_norm(...)` definition in `op_kernel/`, + // and the generated `aclrtlaunch_rms_norm.h` header. Google C++ Style's + // PascalCase rule does NOT apply: this identifier is dictated by the + // AscendC toolchain's symbol convention. + EXEC_KERNEL_CMD(rms_norm, block_dim, kernel_input, weight_float, + kernel_output, total_rows, dim_length, dim_length_align, + former_num, former_length, tail_length, eps_float, + dtype_size_val); + + // Remove padding and reshape back to original shape. 
+ at::Tensor output = kernel_output; + + if (dim_length != dim_length_align) { + output = output.narrow(-1, 0, dim_length).contiguous(); + } + + output = output.reshape(input.sizes()); + + return output; +} + +} // namespace ascend::detail diff --git a/src/ascend/custom/rms_norm/op_kernel/rms_norm.cpp b/src/ascend/custom/rms_norm/op_kernel/rms_norm.cpp new file mode 100644 index 000000000..5c8f4fc67 --- /dev/null +++ b/src/ascend/custom/rms_norm/op_kernel/rms_norm.cpp @@ -0,0 +1,215 @@ +#include "kernel_operator.h" + +constexpr int32_t BUFFER_NUM = 2; + +template +class KernelRmsNorm { + public: + __aicore__ inline KernelRmsNorm() {} + + __aicore__ inline void Init(GM_ADDR x, GM_ADDR weight, GM_ADDR y, + int64_t totalRows, int64_t dimLength, + int64_t dimLengthAlign, int64_t formerNum, + int64_t formerLength, int64_t tailLength, + float eps) { + this->dimLength = dimLength; + this->dimLengthAlign = dimLengthAlign; + this->eps = eps; + + // Block-level tiling: determine row range for this core. + int64_t blockIdx = AscendC::GetBlockIdx(); + int64_t rowOffset; + + if (blockIdx < formerNum) { + this->blockRows = formerLength; + rowOffset = formerLength * blockIdx; + } else { + this->blockRows = tailLength; + int64_t tailIdx = blockIdx - formerNum; + rowOffset = formerLength * formerNum + tailLength * tailIdx; + } + + // Global memory pointers. + xGm.SetGlobalBuffer((__gm__ T*)x + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + yGm.SetGlobalBuffer((__gm__ T*)y + rowOffset * dimLengthAlign, + this->blockRows * dimLengthAlign); + weightGm.SetGlobalBuffer((__gm__ float*)weight, dimLengthAlign); + + int32_t dimLenAlign = static_cast(this->dimLengthAlign); + + // I/O queues (double-buffered). + pipe.InitBuffer(inQueueX, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + pipe.InitBuffer(outQueueY, BUFFER_NUM, + dimLenAlign * static_cast(sizeof(T))); + + // Weight buffer (fp32, loaded once, reused for all rows). 
+ pipe.InitBuffer(weightBuf, + dimLenAlign * static_cast(sizeof(float))); + + // FP16 path needs extra fp32 compute buffers. + if constexpr (sizeof(T) == 2) { + pipe.InitBuffer(xFp32Buf, + dimLenAlign * static_cast(sizeof(float))); + pipe.InitBuffer(tmpFp32Buf, + dimLenAlign * static_cast(sizeof(float))); + } + + // ReduceSum temporary buffer (size per API formula). + constexpr int32_t ELEMS_PER_REPEAT = 256 / sizeof(float); + constexpr int32_t ELEMS_PER_BLOCK = 32 / sizeof(float); + int32_t firstMaxRepeat = + (dimLenAlign + ELEMS_PER_REPEAT - 1) / ELEMS_PER_REPEAT; + int32_t reduceTmpSize = + ((firstMaxRepeat + ELEMS_PER_BLOCK - 1) / ELEMS_PER_BLOCK) * + ELEMS_PER_BLOCK; + pipe.InitBuffer(reduceTmpBuf, + reduceTmpSize * static_cast(sizeof(float))); + + // Scalar buffer for reduction result (8 floats = 32 bytes). + pipe.InitBuffer(sumBuf, 32); + + // Load weight (fp32) from GM into `weightBuf`. + AscendC::LocalTensor wLocal = weightBuf.Get(); + AscendC::DataCopyExtParams wParams{ + 1, static_cast(dimLenAlign * sizeof(float)), 0, 0, 0}; + AscendC::DataCopyPadExtParams wPad{false, 0, 0, 0.0f}; + AscendC::DataCopyPad(wLocal, weightGm, wParams, wPad); + + // Ensure weight DMA completes before compute. 
+ AscendC::PipeBarrier(); + } + + __aicore__ inline void Process() { + for (int64_t row = 0; row < this->blockRows; ++row) { + CopyIn(row); + Compute(row); + CopyOut(row); + } + } + + private: + __aicore__ inline void CopyIn(int64_t row) { + AscendC::LocalTensor xLocal = inQueueX.AllocTensor(); + AscendC::DataCopyExtParams params{ + 1, static_cast(this->dimLengthAlign * sizeof(T)), 0, 0, 0}; + AscendC::DataCopyPadExtParams pad{false, 0, 0, static_cast(0)}; + AscendC::DataCopyPad(xLocal, xGm[row * this->dimLengthAlign], params, pad); + inQueueX.EnQue(xLocal); + } + + __aicore__ inline void Compute(int64_t row) { + AscendC::LocalTensor xLocal = inQueueX.DeQue(); + AscendC::LocalTensor yLocal = outQueueY.AllocTensor(); + + AscendC::LocalTensor wLocal = weightBuf.Get(); + AscendC::LocalTensor rTmp = reduceTmpBuf.Get(); + AscendC::LocalTensor sLocal = sumBuf.Get(); + + int32_t dimLen = static_cast(this->dimLength); + int32_t dimLenAlign = static_cast(this->dimLengthAlign); + + if constexpr (sizeof(T) == 4) { + // ---- FP32 path: compute directly. ---- + + // Step 1: x^2 into yLocal (reuse output buffer temporarily). + AscendC::Mul(yLocal, xLocal, xLocal, dimLenAlign); + + // Step 2: ReduceSum(x^2) -> sLocal[0]. + // ReduceSum may modify src (yLocal), but we overwrite it later. + AscendC::ReduceSum(sLocal, yLocal, rTmp, dimLenAlign); + + // Step 3-5: scale = 1 / sqrt(mean(x^2) + eps). + float sumVal = sLocal.GetValue(0); + float meanVal = sumVal / static_cast(dimLen) + this->eps; + sLocal.SetValue(0, meanVal); + AscendC::Sqrt(sLocal, sLocal, 8); + float scale = 1.0f / sLocal.GetValue(0); + + // Step 6: y = x * scale. + AscendC::Muls(yLocal, xLocal, scale, dimLenAlign); + + // Step 7: y = y * weight. + AscendC::Mul(yLocal, yLocal, wLocal, dimLenAlign); + + } else { + // ---- FP16 path: cast → fp32 compute → cast back. ---- + AscendC::LocalTensor xF32 = xFp32Buf.Get(); + AscendC::LocalTensor tmpF32 = tmpFp32Buf.Get(); + + // Cast input fp16 → fp32. 
+ AscendC::Cast(xF32, xLocal, AscendC::RoundMode::CAST_NONE, dimLenAlign); + + // Step 1: x^2 in fp32. + AscendC::Mul(tmpF32, xF32, xF32, dimLenAlign); + + // Step 2: ReduceSum(x^2) -> sLocal[0]. + AscendC::ReduceSum(sLocal, tmpF32, rTmp, dimLenAlign); + + // Step 3-5: scale = 1 / sqrt(mean(x^2) + eps). + float sumVal = sLocal.GetValue(0); + float meanVal = sumVal / static_cast(dimLen) + this->eps; + sLocal.SetValue(0, meanVal); + AscendC::Sqrt(sLocal, sLocal, 8); + float scale = 1.0f / sLocal.GetValue(0); + + // Step 6: y = x * scale (fp32). + AscendC::Muls(tmpF32, xF32, scale, dimLenAlign); + + // Step 7: y = y * weight (fp32). + AscendC::Mul(tmpF32, tmpF32, wLocal, dimLenAlign); + + // Cast result fp32 → fp16. + AscendC::Cast(yLocal, tmpF32, AscendC::RoundMode::CAST_ROUND, + dimLenAlign); + } + + inQueueX.FreeTensor(xLocal); + outQueueY.EnQue(yLocal); + } + + __aicore__ inline void CopyOut(int64_t row) { + AscendC::LocalTensor yLocal = outQueueY.DeQue(); + AscendC::DataCopyExtParams params{ + 1, static_cast(this->dimLengthAlign * sizeof(T)), 0, 0, 0}; + AscendC::DataCopyPad(yGm[row * this->dimLengthAlign], yLocal, params); + outQueueY.FreeTensor(yLocal); + } + + private: + AscendC::TPipe pipe; + AscendC::TQue inQueueX; + AscendC::TQue outQueueY; + + AscendC::TBuf weightBuf; + AscendC::TBuf xFp32Buf; + AscendC::TBuf tmpFp32Buf; + AscendC::TBuf reduceTmpBuf; + AscendC::TBuf sumBuf; + + AscendC::GlobalTensor xGm, yGm; + AscendC::GlobalTensor weightGm; + + int64_t blockRows; + int64_t dimLength; + int64_t dimLengthAlign; + float eps; +}; + +extern "C" __global__ __aicore__ void rms_norm( + GM_ADDR x, GM_ADDR weight, GM_ADDR y, int64_t totalRows, int64_t dimLength, + int64_t dimLengthAlign, int64_t formerNum, int64_t formerLength, + int64_t tailLength, float eps, int64_t dtypeSize) { + if (dtypeSize == 2) { + KernelRmsNorm op; + op.Init(x, weight, y, totalRows, dimLength, dimLengthAlign, formerNum, + formerLength, tailLength, eps); + op.Process(); + } else { + 
KernelRmsNorm op; + op.Init(x, weight, y, totalRows, dimLength, dimLengthAlign, formerNum, + formerLength, tailLength, eps); + op.Process(); + } +} diff --git a/src/ascend/custom/torch_binding.cpp b/src/ascend/custom/torch_binding.cpp new file mode 100644 index 000000000..1d343c064 --- /dev/null +++ b/src/ascend/custom/torch_binding.cpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * Copyright (c) 2025 InfiniTensor. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * Adapted from https://github.com/vllm-project/vllm-ascend/blob/main/csrc/torch_binding.cpp + */ + +#include +#include + +#include "ops.h" + +namespace { +TORCH_LIBRARY_FRAGMENT(npu, m) { + m.def("rms_norm(Tensor input, Tensor weight, float eps=1e-6) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(npu, PrivateUse1, m) { + m.impl("rms_norm", TORCH_FN(ascend::detail::RmsNorm)); +} +} // namespace diff --git a/src/ascend/custom/utils/torch_kernel_helper.h b/src/ascend/custom/utils/torch_kernel_helper.h new file mode 100644 index 000000000..c4679d1dc --- /dev/null +++ b/src/ascend/custom/utils/torch_kernel_helper.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * Copyright (c) 2025 InfiniTensor. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * + * Adapted from https://github.com/vllm-project/vllm-ascend/tree/main/csrc + */ + +#ifndef TORCH_KERNEL_HELPER_H +#define TORCH_KERNEL_HELPER_H + +#include +#include + +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "torch_npu/csrc/framework/OpCommand.h" + +namespace ascend::detail { + +#define DEVICE_TYPE c10::DeviceType::PrivateUse1 + +class TorchNpuHelper { + public: + inline static at::Tensor CopyTensorHostToDevice( + const at::Tensor& cpu_tensor) { + at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory(); + int deviceIndex = 0; + c10_npu::GetDevice(&deviceIndex); + return cpuPinMemTensor.to(c10::Device(DEVICE_TYPE, deviceIndex), + cpuPinMemTensor.scalar_type(), true, true); + } + + inline static at::Tensor CopyScalarToDevice(const c10::Scalar& cpu_scalar, + at::ScalarType scalar_data_type) { + return CopyTensorHostToDevice( + scalar_to_tensor(cpu_scalar).to(scalar_data_type)); + } + + inline static void* ConvertType(const at::Tensor& at_tensor) { + return const_cast(at_tensor.data_ptr()); + } + + template + inline static T ConvertType(T value) { + return value; + } + + template + inline static constexpr auto ConvertTypes(Ts&... args) { + return std::make_tuple(ConvertType(args)...); + } +}; +} // namespace ascend::detail + +/** + * @brief Launch real kernel function on NPU + * + * @param kernel_name [in] name of kernel + * @param blockdim [in] dim size of block + */ +#define EXEC_KERNEL_CMD(kernel_name, blockdim, ...) 
\ + do { \ + auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \ + auto converted_params = \ + ascend::detail::TorchNpuHelper::ConvertTypes(__VA_ARGS__); \ + auto acl_call = [acl_stream, blockdim, converted_params]() -> int { \ + std::apply( \ + [&](auto&&... params) { \ + ACLRT_LAUNCH_KERNEL(kernel_name) \ + (blockdim, acl_stream, params...); \ + }, \ + converted_params); \ + return 0; \ + }; \ + at_npu::native::OpCommand::RunOpApi(#kernel_name, acl_call); \ + } while (false) + +#endif // TORCH_KERNEL_HELPER_H diff --git a/src/ascend/data_type_.h b/src/ascend/data_type_.h index 9026f515f..4855713d2 100644 --- a/src/ascend/data_type_.h +++ b/src/ascend/data_type_.h @@ -11,6 +11,12 @@ namespace infini::ops::ascend { inline aclDataType ToAclDtype(DataType dt) { switch (dt) { + case DataType::kFloat16: + return ACL_FLOAT16; + case DataType::kBFloat16: + return ACL_BF16; + case DataType::kFloat32: + return ACL_FLOAT; case DataType::kInt8: return ACL_INT8; case DataType::kInt16: @@ -27,19 +33,13 @@ inline aclDataType ToAclDtype(DataType dt) { return ACL_UINT32; case DataType::kUInt64: return ACL_UINT64; - case DataType::kFloat16: - return ACL_FLOAT16; - case DataType::kBFloat16: - return ACL_BF16; - case DataType::kFloat32: - return ACL_FLOAT; default: - assert(false && "Unsupported dtype for Ascend backend."); + assert(false && "unsupported dtype for Ascend backend"); return ACL_DT_UNDEFINED; } } -// Returns true for integer (signed or unsigned) `DataType` values. +// Returns true for integer (signed or unsigned) DataType values. 
inline bool IsIntegerDtype(DataType dt) { switch (dt) { case DataType::kInt8: diff --git a/src/ascend/gemm/kernel.h b/src/ascend/gemm/kernel.h index 3360e7930..16f8c50ff 100644 --- a/src/ascend/gemm/kernel.h +++ b/src/ascend/gemm/kernel.h @@ -36,10 +36,10 @@ class Operator : public Gemm { std::optional trans_b, Tensor c) const override { auto stream = static_cast(stream_); - auto t_self = ascend::buildAclTensor(c); - auto t_a = ascend::buildAclTensor(a, trans_a_); - auto t_b = ascend::buildAclTensor(b, trans_b_); - auto t_out = ascend::buildAclTensor(c); + auto t_self = ascend::BuildAclTensor(c); + auto t_a = ascend::BuildAclTensor(a, trans_a_); + auto t_b = ascend::BuildAclTensor(b, trans_b_); + auto t_out = ascend::BuildAclTensor(c); uint64_t ws_needed = 0; aclOpExecutor* executor = nullptr; diff --git a/src/ascend/workspace_pool_.h b/src/ascend/workspace_pool_.h index 8c4c91961..bd3774fab 100644 --- a/src/ascend/workspace_pool_.h +++ b/src/ascend/workspace_pool_.h @@ -2,7 +2,11 @@ #define INFINI_OPS_ASCEND_WORKSPACE_POOL__H_ #include +#include #include +#include +#include +#include #include #include @@ -18,36 +22,132 @@ struct WorkspaceArena { class WorkspacePool { public: - WorkspaceArena& Ensure(aclrtStream stream, uint64_t needed) { + // Ensure the arena for `(stream, slot)` has at least `needed` bytes. + // + // The `slot` parameter defaults to `"workspace"` for backward + // compatibility. Operators needing a separate temp arena pass + // `"temp"`. + WorkspaceArena& Ensure(aclrtStream stream, uint64_t needed, + const char* slot = "workspace") { + // Thread-local fast path: a small flat array of recently used + // `(stream, slot, arena*)` triples. In practice operators use at + // most 2-3 slots, so linear scan is sufficient — no heap + // allocation on the hot path. 
+ struct CacheEntry { + aclrtStream stream = nullptr; + const char* slot = nullptr; + WorkspaceArena* arena = nullptr; + }; + static constexpr int kCacheSize = 4; + thread_local CacheEntry cache[kCacheSize] = {}; + + for (int i = 0; i < kCacheSize; ++i) { + auto& e = cache[i]; + + if (e.stream == stream && e.slot != nullptr && + std::strcmp(e.slot, slot) == 0 && e.arena != nullptr && + needed <= e.arena->capacity) { + return *e.arena; + } + } + + // Slow path: look up arena in the map under lock. + assert(!capturing_ && + "`WorkspacePool`: `aclrtMalloc` on slow path during graph " + "capture. Ensure all operators run at least once during " + "eager warmup."); + std::lock_guard lock(mutex_); - auto& arena = arenas_[stream]; - if (needed <= arena.capacity) return arena; - if (arena.capacity > 0) { - aclrtSynchronizeStream(stream); - aclrtFree(arena.buf); + + SlotKey key{stream, slot}; + auto& owned = arenas_[key]; + + if (!owned) { + owned = std::make_unique(); + } + + auto* arena = owned.get(); + + if (needed > arena->capacity) { + if (arena->capacity > 0) { + aclrtSynchronizeStream(stream); + aclrtFree(arena->buf); + } + + if (needed > 0) { + auto ret = aclrtMalloc(&arena->buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); + assert(ret == ACL_SUCCESS && "`WorkspacePool`: `aclrtMalloc` failed"); + } + + arena->capacity = needed; } - if (needed > 0) { - auto ret = aclrtMalloc(&arena.buf, needed, ACL_MEM_MALLOC_NORMAL_ONLY); - assert(ret == ACL_SUCCESS && "`WorkspacePool`: `aclrtMalloc` failed"); + + // Insert into the thread-local cache (evict oldest). + for (int i = kCacheSize - 1; i > 0; --i) { + cache[i] = cache[i - 1]; } - arena.capacity = needed; - return arena; + cache[0] = {stream, slot, arena}; + + return *arena; } + // Set to true before NPUGraph capture, false after. When true, + // the slow path (which calls `aclrtMalloc`) triggers an assert + // failure — a safety net against accidental device allocations + // being recorded into the graph. 
+ void set_capture_mode(bool capturing) { capturing_ = capturing; } + ~WorkspacePool() { - for (auto& [stream, arena] : arenas_) { - if (arena.capacity > 0) aclrtFree(arena.buf); + for (auto& [key, arena] : arenas_) { + if (arena && arena->capacity > 0) { + // The CANN runtime may already be torn down when this static + // destructor runs. `aclrtGetDevice` fails in that case — + // skip the free to avoid glibc "double free" abort. + int32_t dev_id = -1; + + if (aclrtGetDevice(&dev_id) == ACL_SUCCESS) { + aclrtFree(arena->buf); + } else { + fprintf(stderr, + "[InfiniOps] `WorkspacePool`: CANN runtime already " + "finalized, skipping `aclrtFree` (%" PRIu64 + " bytes leaked).\n", + arena->capacity); + } + } } } private: - std::unordered_map arenas_; + struct SlotKey { + aclrtStream stream; + std::string slot; + + bool operator==(const SlotKey& o) const { + return stream == o.stream && slot == o.slot; + } + }; + + struct SlotKeyHash { + size_t operator()(const SlotKey& k) const { + auto h1 = std::hash{}(static_cast(k.stream)); + auto h2 = std::hash{}(k.slot); + + return h1 ^ (h2 << 1); + } + }; + + std::unordered_map, SlotKeyHash> + arenas_; std::mutex mutex_; + + bool capturing_ = false; }; inline WorkspacePool& GetWorkspacePool() { static WorkspacePool pool; + return pool; } diff --git a/src/hash.h b/src/hash.h index efb34f751..4721f33f3 100644 --- a/src/hash.h +++ b/src/hash.h @@ -2,6 +2,7 @@ #define INFINI_OPS_HASH_H_ #include +#include template inline void HashCombine(std::size_t& seed, const T& v) { @@ -9,4 +10,12 @@ inline void HashCombine(std::size_t& seed, const T& v) { seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +template +inline void HashCombine(std::size_t& seed, const std::vector& v) { + HashCombine(seed, v.size()); + for (const auto& elem : v) { + HashCombine(seed, elem); + } +} + #endif diff --git a/src/operator.h b/src/operator.h index d4609e055..83fc4ec2c 100644 --- a/src/operator.h +++ b/src/operator.h @@ -37,6 +37,14 @@ struct 
CacheKey { tensors.push_back(t); } + void Absorb(const std::vector& ts) { + HashCombine(hash, ts.size()); + for (const auto& t : ts) { + HashCombine(hash, t); + tensors.push_back(t); + } + } + template void Absorb(const T& v) { HashCombine(hash, v); @@ -121,7 +129,16 @@ class OperatorBase { template class Operator : public OperatorBase { + // Generation counter for lazy cache invalidation. Bumped by + // `clear_cache()`; the next `call()` detects the mismatch and + // destroys all cached operator instances. + static inline std::size_t cache_generation_{0}; + public: + // Invalidate the operator cache. Cached operators are destroyed on the + // next `call()` invocation. Intended for test isolation — production + // code should never call this. + static void clear_cache() { ++cache_generation_; } template static auto Make(const Config& config, const Tensor tensor, Args&&... args) { std::unique_ptr op_ptr; @@ -166,6 +183,12 @@ class Operator : public OperatorBase { static auto Call(const Handle& handle, const Config& config, Args&&... args) { static std::unordered_map> cache; + static std::size_t generation{0}; + + if (generation != cache_generation_) { + cache.clear(); + generation = cache_generation_; + } auto key = detail::CacheKey::Build(config.implementation_index(), args...); @@ -174,7 +197,7 @@ class Operator : public OperatorBase { if (it == cache.end()) { // Pass args as lvalue refs so they remain valid for the `operator()` call // below. Forwarding rvalue temporaries into `Make()` would leave the args - // in a moved-from (empty) state before operator() can use them. + // in a moved-from (empty) state before `operator()` can use them. 
it = cache.emplace(std::move(key), Make(config, args...)).first; } diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index b595836cc..f13d3116c 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -66,10 +66,20 @@ inline Tensor TensorFromPybind11Handle(py::handle obj) { inline std::optional OptionalTensorFromPybind11Handle( const std::optional& obj) { - if (!obj.has_value()) return std::nullopt; + if (!obj.has_value() || obj->is_none()) return std::nullopt; return TensorFromPybind11Handle(*obj); } +inline std::vector VectorTensorFromPybind11Handle( + const std::vector& objs) { + std::vector result; + result.reserve(objs.size()); + for (const auto& obj : objs) { + result.push_back(TensorFromPybind11Handle(obj)); + } + return result; +} + } // namespace infini::ops #endif diff --git a/tests/conftest.py b/tests/conftest.py index 8a72355e5..7b39007f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,6 +44,31 @@ def set_seed_per_test(request): _set_random_seed(_hash(_test_case_path_from_request(request))) +@pytest.fixture(scope="module", autouse=True) +def _clear_operator_caches(): + """Clear the C++ operator cache between test modules. + + The `Operator::call()` cache keys on tensor geometry (shape, strides, + dtype) but not data pointers. When different test modules create tensors + with identical geometry but different data content (e.g., random + `cos_sin_cache` tables), a stale cached operator from a prior module + silently returns wrong results. Clearing the cache at module boundaries + ensures each module starts with a cold cache. 
+ """ + yield + + try: + import infini.ops as ops + + for name in dir(ops): + cls = getattr(ops, name) + + if hasattr(cls, "clear_cache"): + cls.clear_cache() + except ImportError: + pass + + _NPU_UNSUPPORTED_DTYPES = {torch.float64} # `torch_npu` does not implement random number generation for @@ -65,6 +90,63 @@ def skip_unsupported_dtypes(request): pytest.skip(f"{params['dtype']} not supported on Ascend 910B") +# PyTorch device type → InfiniOps platform names. A single torch device type +# can map to several platforms (e.g., `cuda` is shared by `nvidia`, `metax`, +# and `iluvatar`); at most one is actually available at runtime. +_TORCH_DEVICE_TO_PLATFORMS = { + "cuda": ("nvidia", "metax", "iluvatar"), + "mlu": ("cambricon",), + "musa": ("moore",), + "npu": ("ascend",), +} + + +@pytest.fixture(autouse=True) +def skip_op_without_platform_impl(request): + """Skip `device=` parametrizations when the op has no + implementation on any of the corresponding platforms. + + Derives the InfiniOps class name from the test module filename + (`tests/test_.py` → ``) and checks + `infini.ops..active_implementation_indices()` for every + platform that maps to the test's torch device type. Skips only when + every mapped platform reports no active implementation — avoids + `Fatal Python error: Aborted` from dispatching through a base class + that has no backend specialization on the current branch. 
+ """ + if not hasattr(request.node, "callspec"): + return + + device = request.node.callspec.params.get("device") + platforms = _TORCH_DEVICE_TO_PLATFORMS.get(device) + + if not platforms: + return + + module_name = request.node.module.__name__.rsplit(".", 1)[-1] + + if not module_name.startswith("test_"): + return + + op_snake = module_name[len("test_") :] + op_pascal = "".join(part.capitalize() for part in op_snake.split("_")) + + try: + import infini.ops as _ops + except ImportError: + return + + op_cls = getattr(_ops, op_pascal, None) + + if op_cls is None or not hasattr(op_cls, "active_implementation_indices"): + return + + if not any(op_cls.active_implementation_indices(p) for p in platforms): + pytest.skip( + f"{op_pascal} has no implementation on any `{device}`-mapped platform" + ) + + def _set_random_seed(seed): random.seed(seed) torch.manual_seed(seed) diff --git a/tests/test_add.py b/tests/test_add.py index 825fc932c..12c4b9b59 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -2,7 +2,13 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randint_strided, randn_strided +from tests.utils import ( + Payload, + empty_strided, + get_stream, + randint_strided, + randn_strided, +) _INT_DTYPES = (torch.int16, torch.int32, torch.int64) @@ -89,7 +95,13 @@ def test_add( def _add(input, other, out, implementation_index=0): - infini.ops.add(input, other, out, implementation_index=implementation_index) + infini.ops.add( + input, + other, + out, + stream=get_stream(input.device), + implementation_index=implementation_index, + ) return out diff --git a/tests/test_causal_softmax.py b/tests/test_causal_softmax.py index 8b35457a9..79e6be8e0 100644 --- a/tests/test_causal_softmax.py +++ b/tests/test_causal_softmax.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, empty_strided, get_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -40,7 
+40,7 @@ def test_causal_softmax(shape, input_strides, out_strides, dtype, device, rtol, def _causal_softmax(input, out): - infini.ops.causal_softmax(input, out) + infini.ops.causal_softmax(input, out, stream=get_stream(input.device)) return out @@ -48,7 +48,7 @@ def _causal_softmax(input, out): def _torch_causal_softmax(input, out): mask = torch.tril(torch.ones_like(input), diagonal=-1).flip(dims=[-2, -1]) masked = torch.where(mask == 1, -torch.inf, input.to(torch.float32)) - result = torch.nn.functional.softmax(masked, dim=-1, dtype=input.dtype) + result = torch.nn.functional.softmax(masked, dim=-1) out.copy_(result) return out diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 26e102d25..97d060696 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -71,6 +71,18 @@ def test_gemm( ): pytest.skip("ATen CPU `addmm`/`baddbmm` does not support half-precision") + if implementation_index == 2 and device == "npu": + # `src/torch/gemm/gemm.h` partial-specializes `Operator` + # for every `kDev` including `kAscend`, so the SFINAE-based + # `active_implementation_indices` reports `2` as active even though + # `torch/gemm/gemm.cc` only instantiates it for CPU/NVIDIA. + # Dispatching through the unused Ascend specialization reads from an + # uninitialized vtable and crashes. See PR #64 discussion. 
+ pytest.skip( + "Gemm impl=2 on Ascend is a torch-fallback stub without an " + "instantiated specialization" + ) + a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) b = randn_strided(b_shape, b_strides, dtype=dtype, device=device) diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py index d6d4dff17..52fd1ae4f 100644 --- a/tests/test_rms_norm.py +++ b/tests/test_rms_norm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, randn_strided +from tests.utils import Payload, empty_strided, get_stream, randn_strided @pytest.mark.auto_act_and_assert @@ -18,6 +18,7 @@ ), ) @pytest.mark.parametrize("eps", (1e-6, 1e-5)) +@pytest.mark.parametrize("implementation_index", (0, 1)) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( @@ -33,17 +34,25 @@ def test_rms_norm( weight_strides, out_strides, eps, + implementation_index, dtype, device, rtol, atol, ): + active_indices = infini.ops.RmsNorm.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + input = randn_strided(input_shape, input_strides, dtype=dtype, device=device) weight = randn_strided(weight_shape, weight_strides, dtype=dtype, device=device) out = empty_strided(input_shape, out_strides, dtype=dtype, device=device) return Payload( - _rms_norm, + lambda *args, **kwargs: _rms_norm( + *args, **kwargs, implementation_index=implementation_index + ), _torch_rms_norm, (input, weight), {"eps": eps, "out": out}, @@ -52,8 +61,15 @@ def test_rms_norm( ) -def _rms_norm(input, weight, *, eps=1e-6, out=None): - infini.ops.rms_norm(input, weight, eps, out) +def _rms_norm(input, weight, *, eps=1e-6, out=None, implementation_index=0): + infini.ops.rms_norm( + input, + weight, + eps, + out, + implementation_index=implementation_index, + stream=get_stream(input.device), + ) return out diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py index 
89c95f77f..23c299438 100644 --- a/tests/test_swiglu.py +++ b/tests/test_swiglu.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, empty_strided, rand_strided +from tests.utils import Payload, empty_strided, get_stream, rand_strided @pytest.mark.auto_act_and_assert @@ -19,6 +19,7 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) +@pytest.mark.parametrize("implementation_index", (0, 1)) @pytest.mark.parametrize( ("dtype", "rtol", "atol"), ( @@ -28,17 +29,45 @@ ), ) def test_swiglu( - shape, input_strides, gate_strides, out_strides, dtype, device, rtol, atol + shape, + input_strides, + gate_strides, + out_strides, + implementation_index, + dtype, + device, + rtol, + atol, ): + active_indices = infini.ops.Swiglu.active_implementation_indices(device) + + if implementation_index not in active_indices: + pytest.skip(f"implementation `{implementation_index}` not active on `{device}`") + input = rand_strided(shape, input_strides, dtype=dtype, device=device) gate = rand_strided(shape, gate_strides, dtype=dtype, device=device) out = empty_strided(shape, out_strides, dtype=dtype, device=device) - return Payload(_swiglu, _torch_swiglu, (input, gate, out), {}, rtol=rtol, atol=atol) + return Payload( + lambda *args, **kwargs: _swiglu( + *args, **kwargs, implementation_index=implementation_index + ), + _torch_swiglu, + (input, gate, out), + {}, + rtol=rtol, + atol=atol, + ) -def _swiglu(input, gate, out): - infini.ops.swiglu(input, gate, out) +def _swiglu(input, gate, out, implementation_index=0): + infini.ops.swiglu( + input, + gate, + out, + implementation_index=implementation_index, + stream=get_stream(input.device), + ) return out diff --git a/tests/utils.py b/tests/utils.py index 8f9532aa0..982d05aec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -82,11 +82,21 @@ def randint_strided(low, high, shape, strides, *, dtype=None, device=None): return output +_STREAM_ACCESSORS = { + "npu": ("npu", "npu_stream"), + "cuda": 
("cuda", "cuda_stream"), + "mlu": ("mlu", "mlu_stream"), + "musa": ("musa", "musa_stream"), +} + + def get_stream(device): """Return the raw stream handle for `device`, or 0 for CPU. - Uses `torch.accelerator.current_stream` when available, falling back to - device-specific APIs for older PyTorch versions. + Uses the device-specific `torch..current_stream()` API rather than + `torch.accelerator.current_stream()` — the latter returns a different + stream object on torch 2.9 + vllm-ascend, producing cross-stream data + hazards on cached-executor ops. """ if isinstance(device, torch.device): device = device.type @@ -97,32 +107,19 @@ def get_stream(device): if device == "cpu": return 0 - if hasattr(torch, "accelerator") and hasattr(torch.accelerator, "current_stream"): - stream = torch.accelerator.current_stream() - - # Each backend exposes the raw handle under a different attribute name. - for attr in ("npu_stream", "cuda_stream", "mlu_stream", "musa_stream"): - if hasattr(stream, attr): - return getattr(stream, attr) + mod_name, attr = _STREAM_ACCESSORS.get(device, (None, None)) + if mod_name is None: return 0 - # Fallback for older PyTorch builds without `torch.accelerator`. 
- _STREAM_ACCESSORS = { - "npu": ("npu", "npu_stream"), - "cuda": ("cuda", "cuda_stream"), - "mlu": ("mlu", "mlu_stream"), - "musa": ("musa", "musa_stream"), - } + mod = getattr(torch, mod_name, None) - if device in _STREAM_ACCESSORS: - mod_name, attr = _STREAM_ACCESSORS[device] - mod = getattr(torch, mod_name, None) + if mod is None: + return 0 - if mod is not None and hasattr(mod, "current_stream"): - return getattr(mod.current_stream(), attr) + stream = mod.current_stream() - return 0 + return getattr(stream, attr, 0) def clone_strided(input): From 13cf84a36548ad0ebeb249f76985421d80d9f8ce Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Wed, 22 Apr 2026 14:31:13 +0800 Subject: [PATCH 02/11] =?UTF-8?q?feat(ascend):=20op-simple=20group=20?= =?UTF-8?q?=E2=80=94=20Add,=20Mul,=20Cast,=20Cat,=20Matmul,=20Gemm,=20Line?= =?UTF-8?q?ar=20(#65)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(ascend): op-simple group — Add, Mul, Cast, Cat, Matmul, Gemm, Linear Seven foundational Ascend operators:

| op | impl |
|---|---|
| Add | aclnnAdd |
| Mul | aclnnMul |
| Cast | aclnnCast |
| Cat | aclnnCat |
| Matmul | aclnnMatmul |
| Gemm | aclnnMm (also carries the cached-executor / workspace-pool rework) |
| Linear | aclnnMatmul + optional bias |

Also ships:
- `src/base/<op>.h` for the 5 new ops (cast/cat/linear/matmul/mul); `add.h` and `gemm.h` existed on master and are updated in-place
- `src/cpu/<op>/<op>.h` reference impls for cast/cat/linear/mul (add/gemm/matmul had CPU refs on master already)
- `tests/test_<op>.py` for each operator (add and gemm have MODIFY diffs; others are new)

* fix(ascend): Add/Cat destructor — use `release()` for executor-owned caches - `add/kernel.h`: swap destroy() → release() on in_cache_/oth_cache_/out_cache_ and drop aclDestroyAclOpExecutor (both are referenced by the Repeatable executor; destroying them causes double-free at shutdown per the pattern documented in 
common.h and commit 64c367c). - `cat/kernel.h`: release all in_caches_[i] in the destructor; without it, ~AclTensorCache() on vector teardown double-frees descriptors held by tensor_list_ / executor_. - Also group the alpha_* storage members with blank lines to match file convention. * test: generate `implementation_index` dynamically from `active_implementation_indices` Replaces hardcoded `(0, 1)` / `(0, 1, 2)` tuples in test_add, test_gemm, test_rms_norm, test_swiglu with a union over the locally-available devices' active implementation indices. New helper `tests.utils.all_active_implementation_indices(op_cls)` only iterates `get_available_devices()` to avoid `DispatchFunc::std::abort` on device types outside the build's `ActiveDevices` set. Effect on Ascend CI: skipped-test count drops from 3246 to 1686 — impl=1 (`cuBLASLt`) no longer parametrized when no CUDA device is visible, and RmsNorm/Swiglu's custom-kernel slot drops out of the matrix on op-simple where the framework layer hasn't merged the AscendC impl yet. * test(conftest): joint `(device, implementation_index)` parametrize Replaces the per-test `@pytest.mark.parametrize("implementation_index", ...)` + runtime `if impl not in active_indices: skip` pattern with a single hook in `conftest.pytest_generate_tests` that emits only the (device, impl) pairs actually active on each device. Rationale: kernel dispatch is per-device, so cross-device union (previous `all_active_implementation_indices` helper) polluted the matrix with impls that the selected device can't run — runtime-skipped noise. Joint generation keeps the matrix to its semantic cell: "this device has this impl, so run it". - `tests/conftest.py`: when both `device` and `implementation_index` are in fixturenames, emit pairs via `op_cls.active_implementation_indices(dev)`; fall back to a skipped placeholder (`id="skip"`) when no device has an active impl, avoiding `[NOTSET-...]` test IDs. 
- `tests/{test_add,test_gemm,test_rms_norm,test_swiglu}.py`: drop the hardcoded `implementation_index` parametrize decorator and the runtime `active_indices` guard — conftest now handles both. - `tests/utils.py`: remove the `all_active_implementation_indices` helper (superseded by per-device generation in conftest). Same test outcome on Ascend CI (1935 passed / 1686 skipped), but the remaining skips are now semantically mandatory (uint dtypes unsupported by `torch_npu`, Gemm impl=2 SFINAE-only workaround, op missing ascend impl on op-simple pending PR #66) rather than mechanism artifacts. * refactor(conftest): dedupe `_op_class_from_module`, short-circuit redundant fixture Post-review cleanup of the joint-parametrize refactor (1dd288f): - Extract `_op_class_from_module` as a shared helper; `skip_op_without_platform_impl` fixture now calls it instead of re-deriving the snake→pascal class name inline. - Short-circuit the fixture when `implementation_index` is already in callspec — `pytest_generate_tests` has already pruned empty-impl pairs, so per-case `active_implementation_indices` calls are wasted work. - Drop `try/except ImportError` inside the helper — collection has already imported `infini.ops` via test modules; masking a real import failure only turns it into a cryptic NOTSET fixture. - Drop the `devices[0] if devices else "cpu"` fallback — `get_available_devices()` always includes `"cpu"`, making the `else` arm unreachable. * refactor(cpu): flatten nested `DispatchFunc` in Cast; snake_case variables in Linear Per PR #65 review: - `src/cpu/cast/cast.h`: replace nested `DispatchFunc(in_dtype, ...)` inside `DispatchFunc(out_dtype, ...)` with a single multi-dispatch call `DispatchFunc({in, out}, [](in_tag, out_tag) {...})` per the multi-dispatch idiom documented in `CONTRIBUTING.md`. 
- `src/cpu/linear/linear.h`: rename PascalCase locals to snake_case: `A/B/Out/Bias` → `a_ptr/b_ptr/out_ptr/bias_ptr`, `A_batch/B_batch/Out_batch` → `a_batch/b_batch/out_batch`, `M/N/K` → `m/n/k` (matching master's `src/cpu/gemm/gemm.h` which already uses lowercase dim names `m_/n_/k_`). * refactor(cpu/linear): drop redundant `&& bias` guard + narrating comment - `if (bias_ptr && bias)` → `if (bias_ptr)` (line 75). `bias_ptr` is `nullptr` iff `!bias` by construction at line 38, so `&& bias` is dead. - Remove `// Determine `m`, `n`, `k` from shapes and transpose flags.` — the three lines below literally do exactly that; self-describing now that names are snake_case. --------- Co-authored-by: zhangyue --- src/ascend/add/kernel.h | 93 +++++++++++++++++++++++++++ src/ascend/cast/kernel.h | 64 +++++++++++++++++++ src/ascend/cat/kernel.h | 105 +++++++++++++++++++++++++++++++ src/ascend/gemm/kernel.h | 75 ++++++++++++++-------- src/ascend/linear/kernel.h | 125 +++++++++++++++++++++++++++++++++++++ src/ascend/matmul/kernel.h | 68 ++++++++++++++++++++ src/ascend/mul/kernel.h | 68 ++++++++++++++++++++ src/base/cast.h | 52 +++++++++++++++ src/base/cat.h | 35 +++++++++++ src/base/linear.h | 65 +++++++++++++++++++ src/base/mat_mul.h | 31 --------- src/base/matmul.h | 41 ++++++++++++ src/base/mul.h | 67 ++++++++++++++++++++ src/cpu/cast/cast.h | 52 +++++++++++++++ src/cpu/cat/cat.h | 71 +++++++++++++++++++++ src/cpu/linear/linear.h | 107 +++++++++++++++++++++++++++++++ src/cpu/mul/mul.h | 63 +++++++++++++++++++ tests/conftest.py | 97 +++++++++++++++++++++------- tests/test_add.py | 8 --- tests/test_cast.py | 62 ++++++++++++++++++ tests/test_cat.py | 69 ++++++++++++++++++++ tests/test_gemm.py | 8 --- tests/test_linear.py | 90 ++++++++++++++++++++++++++ tests/test_matmul.py | 76 ++++++++++++++++++++++ tests/test_mul.py | 87 ++++++++++++++++++++++++++ tests/test_rms_norm.py | 6 -- tests/test_swiglu.py | 6 -- 27 files changed, 1584 insertions(+), 107 deletions(-) create mode 100644 
src/ascend/add/kernel.h create mode 100644 src/ascend/cast/kernel.h create mode 100644 src/ascend/cat/kernel.h create mode 100644 src/ascend/linear/kernel.h create mode 100644 src/ascend/matmul/kernel.h create mode 100644 src/ascend/mul/kernel.h create mode 100644 src/base/cast.h create mode 100644 src/base/cat.h create mode 100644 src/base/linear.h delete mode 100644 src/base/mat_mul.h create mode 100644 src/base/matmul.h create mode 100644 src/base/mul.h create mode 100644 src/cpu/cast/cast.h create mode 100644 src/cpu/cat/cat.h create mode 100644 src/cpu/linear/linear.h create mode 100644 src/cpu/mul/mul.h create mode 100644 tests/test_cast.py create mode 100644 tests/test_cat.py create mode 100644 tests/test_linear.py create mode 100644 tests/test_matmul.py create mode 100644 tests/test_mul.py diff --git a/src/ascend/add/kernel.h b/src/ascend/add/kernel.h new file mode 100644 index 000000000..251c31364 --- /dev/null +++ b/src/ascend/add/kernel.h @@ -0,0 +1,93 @@ +#ifndef INFINI_OPS_ASCEND_ADD_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Add { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Add(input, other, out), + in_cache_(input), + oth_cache_(other), + out_cache_(out) { + // `aclCreateScalar` stores the pointer rather than copying the value, so + // `alpha_storage_*` must remain alive for the lifetime of `alpha_`. + // The alpha scalar type must match the tensor dtype: use int64 for integer + // dtypes and float for floating-point dtypes. 
+ if (ascend::IsIntegerDtype(input.dtype())) { + alpha_ = aclCreateScalar(&alpha_int_storage_, ACL_INT64); + } else { + alpha_ = aclCreateScalar(&alpha_float_storage_, ACL_FLOAT); + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. The + // descriptors are still referenced by the Repeatable `executor_`, so + // skipping `aclDestroyTensor` (and leaking the executor at shutdown) + // avoids a double-free; see `64c367c`. + in_cache_.release(); + oth_cache_.release(); + out_cache_.release(); + + if (alpha_) aclDestroyScalar(alpha_); + } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_oth = oth_cache_.get(const_cast(other.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnAddGetWorkspaceSize(t_in, t_oth, alpha_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_oth, + const_cast(other.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnAdd(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache oth_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; + + // Stable address for `aclCreateScalar` (float). + float alpha_float_storage_ = 1.0f; + + // Stable address for `aclCreateScalar` (int). 
+ int64_t alpha_int_storage_ = 1; + + aclScalar* alpha_ = nullptr; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/cast/kernel.h b/src/ascend/cast/kernel.h new file mode 100644 index 000000000..d918aa843 --- /dev/null +++ b/src/ascend/cast/kernel.h @@ -0,0 +1,64 @@ +#ifndef INFINI_OPS_ASCEND_CAST_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAST_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/cast.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Cast { + public: + Operator(const Tensor input, Tensor out) + : Cast(input, out), + in_cache_(input), + out_cache_(out), + acl_out_dtype_(ascend::ToAclDtype(out.dtype())) {} + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. + in_cache_.release(); + out_cache_.release(); + } + + void operator()(const Tensor input, Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnCastGetWorkspaceSize(t_in, acl_out_dtype_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnCast(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + aclDataType acl_out_dtype_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/cat/kernel.h b/src/ascend/cat/kernel.h new file mode 100644 index 000000000..018f966a2 --- 
/dev/null +++ b/src/ascend/cat/kernel.h @@ -0,0 +1,105 @@ +#ifndef INFINI_OPS_ASCEND_CAT_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAT_KERNEL_H_ + +#include <vector> + +#include "acl/acl.h" +#include "aclnn/acl_meta.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cat.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/cat.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Cat { + public: + Operator(const Tensor first_input, std::vector<Tensor> rest_inputs, + int64_t dim, Tensor out) + : Cat(first_input, rest_inputs, dim, out), out_cache_(out) { + // Build `AclTensorCache` for each input tensor. + in_caches_.reserve(input_count_); + in_caches_.emplace_back(first_input); + for (const auto& t : rest_inputs) { + in_caches_.emplace_back(t); + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. The input + // descriptors are referenced by the Repeatable `executor_` via + // `tensor_list_`, so every `in_caches_[i]` must be released alongside + // `out_cache_`; otherwise `~AclTensorCache()` double-frees them when the + // vector destructs. + for (auto& c : in_caches_) { + c.release(); + } + out_cache_.release(); + + if (tensor_list_) aclDestroyTensorList(tensor_list_); + } + + void operator()(const Tensor first_input, std::vector<Tensor> rest_inputs, + int64_t /*dim*/, Tensor out) const override { + auto stream = static_cast(stream_); + + // Collect all input tensors in order. + std::vector<const Tensor*> inputs; + inputs.reserve(input_count_); + inputs.push_back(&first_input); + for (const auto& t : rest_inputs) { + inputs.push_back(&t); + } + + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + // First call: create descriptors, tensor list, and executor. 
+ std::vector acl_tensors(input_count_); + for (size_t i = 0; i < input_count_; ++i) { + acl_tensors[i] = + in_caches_[i].get(const_cast(inputs[i]->data())); + } + + tensor_list_ = + aclCreateTensorList(const_cast(acl_tensors.data()), + static_cast(input_count_)); + + aclnnCatGetWorkspaceSize(tensor_list_, dim_, t_out, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + // Subsequent calls: update data pointers on cached descriptors via + // `aclSetRawTensorAddr`. The executor holds references to the same + // `aclTensor*` objects inside `tensor_list_`, so updating their data + // pointers is sufficient — no `aclSetInputTensorAddr` needed. + for (size_t i = 0; i < input_count_; ++i) { + in_caches_[i].get(const_cast(inputs[i]->data())); + } + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnCat(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable std::vector in_caches_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclTensorList* tensor_list_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/gemm/kernel.h b/src/ascend/gemm/kernel.h index 16f8c50ff..1795baf2d 100644 --- a/src/ascend/gemm/kernel.h +++ b/src/ascend/gemm/kernel.h @@ -21,14 +21,26 @@ class Operator : public Gemm { : Gemm(a, b, alpha, beta, trans_a, trans_b, c), batched_{batch_count_ > 1}, alpha_val_{alpha.value_or(1.0f)}, - beta_val_{beta.value_or(1.0f)} { + beta_val_{beta.value_or(1.0f)}, + self_cache_(c), + a_cache_(a, trans_a_), + b_cache_(b, trans_b_), + out_cache_(c) { alpha_scalar_ = aclCreateScalar(&alpha_val_, ACL_FLOAT); beta_scalar_ = aclCreateScalar(&beta_val_, ACL_FLOAT); } ~Operator() { - aclDestroyScalar(alpha_scalar_); - aclDestroyScalar(beta_scalar_); + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see 
`AclTensorCache::release()`. + self_cache_.release(); + a_cache_.release(); + b_cache_.release(); + out_cache_.release(); + + if (alpha_scalar_) aclDestroyScalar(alpha_scalar_); + if (beta_scalar_) aclDestroyScalar(beta_scalar_); } void operator()(const Tensor a, const Tensor b, std::optional alpha, @@ -36,35 +48,36 @@ class Operator : public Gemm { std::optional trans_b, Tensor c) const override { auto stream = static_cast(stream_); - auto t_self = ascend::BuildAclTensor(c); - auto t_a = ascend::BuildAclTensor(a, trans_a_); - auto t_b = ascend::BuildAclTensor(b, trans_b_); - auto t_out = ascend::BuildAclTensor(c); - - uint64_t ws_needed = 0; - aclOpExecutor* executor = nullptr; - - if (batched_) { - aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, - alpha_scalar_, t_out, 0, &ws_needed, - &executor); + auto t_self = self_cache_.get(c.data()); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(c.data()); + + if (!executor_) { + if (batched_) { + aclnnBaddbmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } else { + aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } + aclSetAclOpExecutorRepeatable(executor_); } else { - aclnnAddmmGetWorkspaceSize(t_self, t_a, t_b, beta_scalar_, alpha_scalar_, - t_out, 0, &ws_needed, &executor); + aclSetInputTensorAddr(executor_, 0, t_self, c.data()); + aclSetInputTensorAddr(executor_, 1, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 2, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, c.data()); } - auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_needed); + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); if (batched_) { - aclnnBaddbmm(arena.buf, ws_needed, executor, stream); + aclnnBaddbmm(arena.buf, ws_size_, executor_, stream); } else { - 
aclnnAddmm(arena.buf, ws_needed, executor, stream); + aclnnAddmm(arena.buf, ws_size_, executor_, stream); } - - aclDestroyTensor(t_self); - aclDestroyTensor(t_a); - aclDestroyTensor(t_b); - aclDestroyTensor(t_out); } private: @@ -77,6 +90,18 @@ class Operator : public Gemm { aclScalar* alpha_scalar_ = nullptr; aclScalar* beta_scalar_ = nullptr; + + mutable ascend::AclTensorCache self_cache_; + + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; }; } // namespace infini::ops diff --git a/src/ascend/linear/kernel.h b/src/ascend/linear/kernel.h new file mode 100644 index 000000000..497dd8065 --- /dev/null +++ b/src/ascend/linear/kernel.h @@ -0,0 +1,125 @@ +#ifndef INFINI_OPS_ASCEND_LINEAR_KERNEL_H_ +#define INFINI_OPS_ASCEND_LINEAR_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_addmm.h" +#include "aclnnop/aclnn_baddbmm.h" +#include "aclnnop/aclnn_matmul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/linear.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Linear { + public: + Operator(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : Linear(a, b, bias, trans_a, trans_b, out), + batched_{out.ndim() > 2}, + a_cache_(a, trans_a), + b_cache_(b, trans_b), + out_cache_(out) { + if (has_bias_) { + bias_cache_ = ascend::AclTensorCache(*bias); + alpha_scalar_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT); + beta_scalar_ = aclCreateScalar(&beta_storage_, ACL_FLOAT); + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. 
+ bias_cache_.release(); + a_cache_.release(); + b_cache_.release(); + out_cache_.release(); + + if (alpha_scalar_) aclDestroyScalar(alpha_scalar_); + if (beta_scalar_) aclDestroyScalar(beta_scalar_); + } + + void operator()(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const override { + auto stream = static_cast(stream_); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(out.data()); + + if (has_bias_) { + auto t_bias = bias_cache_.get(const_cast(bias->data())); + + if (!executor_) { + if (batched_) { + aclnnBaddbmmGetWorkspaceSize(t_bias, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } else { + aclnnAddmmGetWorkspaceSize(t_bias, t_a, t_b, beta_scalar_, + alpha_scalar_, t_out, 0, &ws_size_, + &executor_); + } + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_bias, + const_cast(bias->data())); + aclSetInputTensorAddr(executor_, 1, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 2, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + + if (batched_) { + aclnnBaddbmm(arena.buf, ws_size_, executor_, stream); + } else { + aclnnAddmm(arena.buf, ws_size_, executor_, stream); + } + } else { + if (!executor_) { + int8_t cube_math_type = 1; + aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 1, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnMatmul(arena.buf, ws_size_, executor_, stream); + } + } + + private: + bool batched_; + + mutable 
ascend::AclTensorCache bias_cache_; + + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + float alpha_storage_ = 1.0f; + + float beta_storage_ = 1.0f; + + aclScalar* alpha_scalar_ = nullptr; + + aclScalar* beta_scalar_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/matmul/kernel.h b/src/ascend/matmul/kernel.h new file mode 100644 index 000000000..df05677f7 --- /dev/null +++ b/src/ascend/matmul/kernel.h @@ -0,0 +1,68 @@ +#ifndef INFINI_OPS_ASCEND_MATMUL_KERNEL_H_ +#define INFINI_OPS_ASCEND_MATMUL_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_matmul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/matmul.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Matmul { + public: + Operator(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) + : Matmul(a, b, c, trans_a, trans_b), + a_cache_(a, trans_a), + b_cache_(b, trans_b), + out_cache_(c) {} + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. 
+ a_cache_.release(); + b_cache_.release(); + out_cache_.release(); + } + + void operator()(const Tensor a, const Tensor b, Tensor c, bool trans_a, + bool trans_b) const override { + auto stream = static_cast(stream_); + auto t_a = a_cache_.get(const_cast(a.data())); + auto t_b = b_cache_.get(const_cast(b.data())); + auto t_out = out_cache_.get(c.data()); + + if (!executor_) { + int8_t cube_math_type = 1; + aclnnMatmulGetWorkspaceSize(t_a, t_b, t_out, cube_math_type, &ws_size_, + &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_a, const_cast(a.data())); + aclSetInputTensorAddr(executor_, 1, t_b, const_cast(b.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, c.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnMatmul(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache a_cache_; + + mutable ascend::AclTensorCache b_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/mul/kernel.h b/src/ascend/mul/kernel.h new file mode 100644 index 000000000..f1cfed673 --- /dev/null +++ b/src/ascend/mul/kernel.h @@ -0,0 +1,68 @@ +#ifndef INFINI_OPS_ASCEND_MUL_KERNEL_H_ +#define INFINI_OPS_ASCEND_MUL_KERNEL_H_ + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_mul.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/mul.h" +#include "operator.h" + +namespace infini::ops { + +template <> +class Operator : public Mul { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Mul(input, other, out), + in_cache_(input), + oth_cache_(other), + out_cache_(out) {} + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. 
+ in_cache_.release(); + oth_cache_.release(); + out_cache_.release(); + } + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + auto stream = static_cast(stream_); + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_oth = oth_cache_.get(const_cast(other.data())); + auto t_out = out_cache_.get(out.data()); + + if (!executor_) { + aclnnMulGetWorkspaceSize(t_in, t_oth, t_out, &ws_size_, &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_in, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_oth, + const_cast(other.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnMul(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache oth_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/cast.h b/src/base/cast.h new file mode 100644 index 000000000..29f1f40cf --- /dev/null +++ b/src/base/cast.h @@ -0,0 +1,52 @@ +#ifndef INFINI_OPS_BASE_CAST_H_ +#define INFINI_OPS_BASE_CAST_H_ + +#include "operator.h" + +namespace infini::ops { + +class Cast : public Operator { + public: + Cast(const Tensor input, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_dtype_{input.dtype()}, + out_dtype_{out.dtype()}, + input_shape_{input.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(input.numel() == out.numel() && + "the input and output of `Cast` must have the same number of " + "elements"); + } + + virtual void operator()(const Tensor input, Tensor out) const = 0; + + 
protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_dtype_; + + const DataType out_dtype_; + + Tensor::Shape input_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/cat.h b/src/base/cat.h new file mode 100644 index 000000000..dcb0ba587 --- /dev/null +++ b/src/base/cat.h @@ -0,0 +1,35 @@ +#ifndef INFINI_OPS_BASE_CAT_H_ +#define INFINI_OPS_BASE_CAT_H_ + +#include <vector> + +#include "operator.h" + +namespace infini::ops { + +class Cat : public Operator { + public: + Cat(const Tensor first_input, std::vector<Tensor> rest_inputs, int64_t dim, + Tensor out) + : input_count_{1 + rest_inputs.size()} { + assert(input_count_ >= 2 && "`Cat` requires at least 2 input tensors"); + + auto ndim = static_cast(out.ndim()); + // Normalize negative dim (e.g. -1 means last dimension). + dim_ = dim < 0 ? dim + ndim : dim; + assert(dim_ >= 0 && dim_ < ndim && "`Cat` dim out of range"); + } + + virtual void operator()(const Tensor first_input, + std::vector<Tensor> rest_inputs, int64_t dim, + Tensor out) const = 0; + + protected: + int64_t dim_; + + size_t input_count_; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/linear.h b/src/base/linear.h new file mode 100644 index 000000000..a5276e612 --- /dev/null +++ b/src/base/linear.h @@ -0,0 +1,65 @@ +#ifndef INFINI_OPS_BASE_LINEAR_H_ +#define INFINI_OPS_BASE_LINEAR_H_ + +#include <optional> + +#include "operator.h" + +namespace infini::ops { + +// Fused linear projection: out = a @ b (+ bias). +// +// When bias is present, computes out = a @ b + bias in a single dispatch. +// When bias is absent, computes out = a @ b (equivalent to Matmul). +// `trans_a` / `trans_b`: If true, transpose the last two dims before +// multiplying. 
+class Linear : public Operator { + public: + Linear(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : a_shape_{a.shape()}, + b_shape_{b.shape()}, + out_shape_{out.shape()}, + a_strides_{a.strides()}, + b_strides_{b.strides()}, + out_strides_{out.strides()}, + trans_a_{trans_a}, + trans_b_{trans_b}, + has_bias_{bias.has_value()} { + assert(a.dtype() == b.dtype() && + "operator `Linear` requires a and b to have the same dtype"); + assert(a.dtype() == out.dtype() && + "operator `Linear` requires a and out to have the same dtype"); + if (has_bias_) { + assert(bias->dtype() == out.dtype() && + "operator `Linear` requires bias and out to have the same dtype"); + } + } + + virtual void operator()(const Tensor a, const Tensor b, + std::optional bias, bool trans_a, + bool trans_b, Tensor out) const = 0; + + protected: + Tensor::Shape a_shape_; + + Tensor::Shape b_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides a_strides_; + + Tensor::Strides b_strides_; + + Tensor::Strides out_strides_; + + bool trans_a_{false}; + + bool trans_b_{false}; + + bool has_bias_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/mat_mul.h b/src/base/mat_mul.h deleted file mode 100644 index 6180c8bf2..000000000 --- a/src/base/mat_mul.h +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef INFINI_OPS_BASE_MAT_MUL_H_ -#define INFINI_OPS_BASE_MAT_MUL_H_ - -#include "operator.h" -#include "tensor.h" - -namespace infini::ops { - -class MatMul : public Operator { - public: - MatMul(const Tensor input, const Tensor other, Tensor out) - : input_shape_{input.shape()}, - other_shape_{other.shape()}, - out_shape_{out.shape()} { - assert(input.dtype() == other.dtype()); - } - - virtual void operator()(const Tensor input, const Tensor other, - Tensor out) const = 0; - - protected: - Tensor::Shape input_shape_; - - Tensor::Shape other_shape_; - - Tensor::Shape out_shape_; -}; - -} // namespace infini::ops - -#endif diff --git 
a/src/base/matmul.h b/src/base/matmul.h new file mode 100644 index 000000000..071feaeaa --- /dev/null +++ b/src/base/matmul.h @@ -0,0 +1,41 @@ +#ifndef INFINI_OPS_BASE_MATMUL_H_ +#define INFINI_OPS_BASE_MATMUL_H_ + +#include "operator.h" +#include "tensor.h" + +namespace infini::ops { + +class Matmul : public Operator { + public: + // `trans_a` / `trans_b`: If true, transpose the last two dims of `a` / `b` + // before multiplying. These are constructor parameters so the `CacheKey` + // encodes the transposition and distinct descriptors are cached for each + // combination. + Matmul(const Tensor a, const Tensor b, Tensor c, bool trans_a, bool trans_b) + : a_shape_{a.shape()}, + b_shape_{b.shape()}, + c_shape_{c.shape()}, + trans_a_{trans_a}, + trans_b_{trans_b} { + assert(a.dtype() == b.dtype()); + } + + virtual void operator()(const Tensor a, const Tensor b, Tensor c, + bool trans_a, bool trans_b) const = 0; + + protected: + Tensor::Shape a_shape_; + + Tensor::Shape b_shape_; + + Tensor::Shape c_shape_; + + bool trans_a_{false}; + + bool trans_b_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/base/mul.h b/src/base/mul.h new file mode 100644 index 000000000..9e7be2239 --- /dev/null +++ b/src/base/mul.h @@ -0,0 +1,67 @@ +#ifndef INFINI_OPS_BASE_MUL_H_ +#define INFINI_OPS_BASE_MUL_H_ + +#include "operator.h" + +namespace infini::ops { + +class Mul : public Operator { + public: + Mul(const Tensor input, const Tensor other, Tensor out) + : ndim_{out.ndim()}, + output_size_{out.numel()}, + input_type_{input.dtype()}, + other_type_{other.dtype()}, + out_type_{out.dtype()}, + input_shape_{input.shape()}, + other_shape_{other.shape()}, + out_shape_{out.shape()}, + input_strides_{input.strides()}, + other_strides_{other.strides()}, + out_strides_{out.strides()}, + is_input_contiguous_{input.IsContiguous()}, + is_other_contiguous_{other.IsContiguous()}, + is_out_contiguous_{out.IsContiguous()} { + assert(!out.HasBroadcastDim() && + "the output of `Mul` 
should NOT have broadcasted dim!"); + assert(input_type_ == other_type_ && other_type_ == out_type_ && + "operator `Mul` requires all input and output tensors to have the " + "same dtype"); + } + + virtual void operator()(const Tensor input, const Tensor other, + Tensor out) const = 0; + + protected: + Tensor::Size ndim_{0}; + + Tensor::Size output_size_{0}; + + const DataType input_type_; + + const DataType other_type_; + + const DataType out_type_; + + Tensor::Shape input_shape_; + + Tensor::Shape other_shape_; + + Tensor::Shape out_shape_; + + Tensor::Strides input_strides_; + + Tensor::Strides other_strides_; + + Tensor::Strides out_strides_; + + bool is_input_contiguous_{false}; + + bool is_other_contiguous_{false}; + + bool is_out_contiguous_{false}; +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/cast/cast.h b/src/cpu/cast/cast.h new file mode 100644 index 000000000..ef89b8ac3 --- /dev/null +++ b/src/cpu/cast/cast.h @@ -0,0 +1,52 @@ +#ifndef INFINI_OPS_CPU_CAST_CAST_H_ +#define INFINI_OPS_CPU_CAST_CAST_H_ + +#include "base/cast.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Cast { + public: + Operator(const Tensor input, Tensor out) : Cast{input, out} {} + + void operator()(const Tensor input, Tensor out) const override { + DispatchFunc( + {input_dtype_, out_dtype_}, + [&](auto in_tag, auto out_tag) { + using InT = typename decltype(in_tag)::type; + using OutT = typename decltype(out_tag)::type; + Compute(input, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor input, Tensor out) const { + const auto* in_ptr = static_cast(input.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, const auto* shape, + const auto* strides) { + return is_contig ? 
i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto in_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = + Caster::template Cast(in_ptr[in_idx]); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/cat/cat.h b/src/cpu/cat/cat.h new file mode 100644 index 000000000..18b45247a --- /dev/null +++ b/src/cpu/cat/cat.h @@ -0,0 +1,71 @@ +#ifndef INFINI_OPS_CPU_CAT_CAT_H_ +#define INFINI_OPS_CPU_CAT_CAT_H_ + +#include +#include + +#include "base/cat.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Cat { + public: + Operator(const Tensor first_input, std::vector rest_inputs, + int64_t dim, Tensor out) + : Cat{first_input, rest_inputs, dim, out} {} + + void operator()(const Tensor first_input, std::vector rest_inputs, + int64_t /*dim*/, Tensor out) const override { + // Collect all input tensors. + std::vector inputs; + inputs.reserve(input_count_); + inputs.push_back(&first_input); + for (const auto& t : rest_inputs) { + inputs.push_back(&t); + } + + // Use normalized `dim_` from base class (handles negative dim). + auto dim = dim_; + auto elem_size = kDataTypeToSize.at(out.dtype()); + auto ndim = out.ndim(); + auto out_shape = out.shape(); + + // Compute outer and inner sizes relative to the cat dimension. + Tensor::Size outer = 1; + for (int64_t i = 0; i < dim; ++i) { + outer *= out_shape[i]; + } + + Tensor::Size inner = 1; + for (size_t i = static_cast(dim) + 1; i < ndim; ++i) { + inner *= out_shape[i]; + } + + auto* out_ptr = static_cast(out.data()); + Tensor::Size out_dim_size = out_shape[dim]; + + // For each outer index, copy slices from each input along the cat dim. 
+ for (Tensor::Size o = 0; o < outer; ++o) { + Tensor::Size offset_in_dim = 0; + + for (size_t t = 0; t < input_count_; ++t) { + auto in_dim = inputs[t]->shape()[dim]; + auto in_ptr = static_cast(inputs[t]->data()); + + auto src_offset = (o * in_dim) * inner * elem_size; + auto dst_offset = + (o * out_dim_size + offset_in_dim) * inner * elem_size; + auto copy_size = in_dim * inner * elem_size; + + std::memcpy(out_ptr + dst_offset, in_ptr + src_offset, copy_size); + offset_in_dim += in_dim; + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/linear/linear.h b/src/cpu/linear/linear.h new file mode 100644 index 000000000..21e1bb265 --- /dev/null +++ b/src/cpu/linear/linear.h @@ -0,0 +1,107 @@ +#ifndef INFINI_OPS_CPU_LINEAR_LINEAR_H_ +#define INFINI_OPS_CPU_LINEAR_LINEAR_H_ + +#include + +#include "base/linear.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Linear, + Caster { + public: + Operator(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) + : Linear{a, b, bias, trans_a, trans_b, out} {} + + void operator()(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const override { + DispatchFunc( + out.dtype(), + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(a, b, bias, trans_a, trans_b, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor a, const Tensor b, std::optional bias, + bool trans_a, bool trans_b, Tensor out) const { + const auto* a_ptr = static_cast(a.data()); + const auto* b_ptr = static_cast(b.data()); + auto* out_ptr = static_cast(out.data()); + const T* bias_ptr = bias ? 
static_cast(bias->data()) : nullptr; + + auto ndim_a = a_shape_.size(); + auto ndim_b = b_shape_.size(); + auto ndim_out = out_shape_.size(); + + Tensor::Size m = out_shape_[ndim_out - 2]; + Tensor::Size n = out_shape_[ndim_out - 1]; + Tensor::Size k = trans_a ? a_shape_[ndim_a - 2] : a_shape_[ndim_a - 1]; + + // Compute strides for the inner matrix dimensions after transpose. + Tensor::Stride stride_a_m = + trans_a ? a_strides_[ndim_a - 1] : a_strides_[ndim_a - 2]; + Tensor::Stride stride_a_k = + trans_a ? a_strides_[ndim_a - 2] : a_strides_[ndim_a - 1]; + Tensor::Stride stride_b_k = + trans_b ? b_strides_[ndim_b - 1] : b_strides_[ndim_b - 2]; + Tensor::Stride stride_b_n = + trans_b ? b_strides_[ndim_b - 2] : b_strides_[ndim_b - 1]; + Tensor::Stride stride_out_m = out_strides_[ndim_out - 2]; + Tensor::Stride stride_out_n = out_strides_[ndim_out - 1]; + + // Batch dimensions. + Tensor::Size batch_count = 1; + for (size_t i = 0; i + 2 < ndim_out; ++i) { + batch_count *= out_shape_[i]; + } + + Tensor::Stride batch_stride_a = ndim_a > 2 ? a_strides_[ndim_a - 3] : 0; + Tensor::Stride batch_stride_b = ndim_b > 2 ? b_strides_[ndim_b - 3] : 0; + Tensor::Stride batch_stride_out = + ndim_out > 2 ? out_strides_[ndim_out - 3] : 0; + + // Bias stride: for 1D bias `[n]`, stride is 1. For batched bias, use last + // stride. 
+ Tensor::Stride bias_stride = 0; + if (bias_ptr) { + auto ndim_bias = bias->shape().size(); + bias_stride = bias->strides()[ndim_bias - 1]; + } + + for (Tensor::Size batch = 0; batch < batch_count; ++batch) { + const auto* a_batch = a_ptr + batch * batch_stride_a; + const auto* b_batch = b_ptr + batch * batch_stride_b; + auto* out_batch = out_ptr + batch * batch_stride_out; + + for (Tensor::Size i = 0; i < m; ++i) { + for (Tensor::Size j = 0; j < n; ++j) { + float sum = 0.0f; + + for (Tensor::Size l = 0; l < k; ++l) { + float a_val = Cast(a_batch[i * stride_a_m + l * stride_a_k]); + float b_val = Cast(b_batch[l * stride_b_k + j * stride_b_n]); + sum += a_val * b_val; + } + + if (bias_ptr) { + sum += Cast(bias_ptr[j * bias_stride]); + } + + out_batch[i * stride_out_m + j * stride_out_n] = Cast(sum); + } + } + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/src/cpu/mul/mul.h b/src/cpu/mul/mul.h new file mode 100644 index 000000000..0bdefb96b --- /dev/null +++ b/src/cpu/mul/mul.h @@ -0,0 +1,63 @@ +#ifndef INFINI_OPS_CPU_MUL_MUL_H_ +#define INFINI_OPS_CPU_MUL_MUL_H_ + +#include + +#include "base/mul.h" +#include "common/generic_utils.h" +#include "cpu/caster_.h" + +namespace infini::ops { + +template <> +class Operator : public Mul, + Caster { + public: + Operator(const Tensor input, const Tensor other, Tensor out) + : Mul{input, other, out} {} + + void operator()(const Tensor input, const Tensor other, + Tensor out) const override { + DispatchFunc( + out_type_, + [&](auto tag) { + using T = typename decltype(tag)::type; + Compute(input, other, out); + }, + "`Operator::operator()`"); + } + + private: + template + void Compute(const Tensor input, const Tensor other, Tensor out) const { + using ComputeType = std::conditional_t || + IsFP16, + float, T>; + + const auto* input_ptr = static_cast(input.data()); + const auto* other_ptr = static_cast(other.data()); + auto* out_ptr = static_cast(out.data()); + + auto get_idx = [&](Tensor::Size i, bool is_contig, 
const auto* shape, + const auto* strides) { + return is_contig ? i : utils::IndexToOffset(i, ndim_, shape, strides); + }; + +#pragma omp parallel for + for (Tensor::Size i = 0; i < output_size_; ++i) { + auto input_idx = get_idx(i, is_input_contiguous_, input_shape_.data(), + input_strides_.data()); + auto other_idx = get_idx(i, is_other_contiguous_, other_shape_.data(), + other_strides_.data()); + auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), + out_strides_.data()); + + out_ptr[out_idx] = Cast(Cast(input_ptr[input_idx]) * + Cast(other_ptr[other_idx])); + } + } +}; + +} // namespace infini::ops + +#endif diff --git a/tests/conftest.py b/tests/conftest.py index 7b39007f9..d995459fc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -106,44 +106,33 @@ def skip_op_without_platform_impl(request): """Skip `device=` parametrizations when the op has no implementation on any of the corresponding platforms. - Derives the InfiniOps class name from the test module filename - (`tests/test_.py` → ``) and checks - `infini.ops..active_implementation_indices()` for every - platform that maps to the test's torch device type. Skips only when - every mapped platform reports no active implementation — avoids - `Fatal Python error: Aborted` from dispatching through a base class - that has no backend specialization on the current branch. + Only runs for tests that parametrize `device` but not + `implementation_index` — joint `(device, impl_idx)` parametrize in + `pytest_generate_tests` already prunes empty-impl pairs at collection + time, making this check redundant (and wasteful) on those tests. 
""" if not hasattr(request.node, "callspec"): return - device = request.node.callspec.params.get("device") - platforms = _TORCH_DEVICE_TO_PLATFORMS.get(device) - - if not platforms: - return - - module_name = request.node.module.__name__.rsplit(".", 1)[-1] + params = request.node.callspec.params - if not module_name.startswith("test_"): + if "implementation_index" in params: return - op_snake = module_name[len("test_") :] - op_pascal = "".join(part.capitalize() for part in op_snake.split("_")) + platforms = _TORCH_DEVICE_TO_PLATFORMS.get(params.get("device")) - try: - import infini.ops as _ops - except ImportError: + if not platforms: return - op_cls = getattr(_ops, op_pascal, None) + op_cls = _op_class_from_module(request.node.module) if op_cls is None or not hasattr(op_cls, "active_implementation_indices"): return if not any(op_cls.active_implementation_indices(p) for p in platforms): pytest.skip( - f"{op_pascal} has no implementation on any `{device}`-mapped platform" + f"{op_cls.__name__} has no implementation on any " + f"`{params.get('device')}`-mapped platform" ) @@ -191,7 +180,69 @@ def pytest_generate_tests(metafunc): else: devices = () - metafunc.parametrize("device", devices or available) + devices = devices or available + + # Joint `(device, implementation_index)` parametrize: generate only + # pairs where the op has an active implementation on that device. + # Avoids cross-device pollution — an impl active on `cpu` but not on + # `npu` no longer appears as a runtime skip in the npu column. 
+ if ( + "implementation_index" in metafunc.fixturenames + and "implementation_index" not in already_parametrized + ): + op_cls = _op_class_from_module(metafunc.module) + + if op_cls is not None and hasattr(op_cls, "active_implementation_indices"): + pairs = [ + (dev, idx) + for dev in devices + for idx in op_cls.active_implementation_indices(dev) + ] + + if not pairs: + # Emit one skipped placeholder so test IDs read + # `[skip-dtype0-...]` instead of `[NOTSET-...]`. + # `get_available_devices()` always includes `"cpu"`, so + # `devices[0]` is safe. + pairs = [ + pytest.param( + devices[0], + 0, + marks=pytest.mark.skip( + reason=( + f"{op_cls.__name__} has no active " + "implementation on any available device" + ) + ), + id="skip", + ) + ] + + metafunc.parametrize("device, implementation_index", pairs) + + return + + metafunc.parametrize("device", devices) + + +def _op_class_from_module(module): + """Derive the `infini.ops.` class from a `tests/test_.py` module. + + Test modules have already imported `infini.ops` by the time this runs, so + no `try/except ImportError` is needed — a real import failure would have + aborted collection long before. + """ + module_name = module.__name__.rsplit(".", 1)[-1] + + if not module_name.startswith("test_"): + return None + + op_snake = module_name[len("test_") :] + op_pascal = "".join(part.capitalize() for part in op_snake.split("_")) + + import infini.ops as _ops + + return getattr(_ops, op_pascal, None) @pytest.hookimpl(tryfirst=True) diff --git a/tests/test_add.py b/tests/test_add.py index 12c4b9b59..e2266c30d 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -35,9 +35,6 @@ ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)), ), ) -# TODO: Generate implementation indices dynamically from -# `Add.active_implementation_indices` instead of hardcoding. 
-@pytest.mark.parametrize("implementation_index", (0, 1))
 @pytest.mark.parametrize(
     ("dtype", "rtol", "atol"),
     (
@@ -63,11 +60,6 @@ def test_add(
             "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`."
         )
 
-    active_indices = infini.ops.Add.active_implementation_indices(device)
-
-    if implementation_index not in active_indices:
-        pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
-
     if implementation_index == 1 and dtype in _UINT_DTYPES:
         pytest.skip("ATen `add` does not support unsigned integer types")
 
diff --git a/tests/test_cast.py b/tests/test_cast.py
new file mode 100644
index 000000000..bd19d934b
--- /dev/null
+++ b/tests/test_cast.py
@@ -0,0 +1,62 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, empty_strided, get_stream, randn_strided
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "shape, input_strides, out_strides",
+    (
+        ((13, 4), None, None),
+        ((13, 4), (10, 1), (10, 1)),
+        ((13, 4, 4), None, None),
+        ((16, 5632), None, None),
+        ((4, 4, 5632), None, None),
+    ),
+)
+@pytest.mark.parametrize(
+    ("input_dtype", "out_dtype", "rtol", "atol"),
+    (
+        (torch.float16, torch.float32, 1e-3, 1e-3),
+        (torch.float32, torch.float16, 1e-3, 1e-3),
+        (torch.bfloat16, torch.float32, 1e-2, 5e-3),
+        (torch.float32, torch.bfloat16, 1e-2, 5e-3),
+        (torch.float16, torch.bfloat16, 1e-2, 5e-3),
+        (torch.bfloat16, torch.float16, 1e-2, 5e-3),
+    ),
+)
+def test_cast(
+    shape,
+    input_strides,
+    out_strides,
+    input_dtype,
+    out_dtype,
+    device,
+    rtol,
+    atol,
+):
+    input = randn_strided(shape, input_strides, dtype=input_dtype, device=device)
+    out = empty_strided(shape, out_strides, dtype=out_dtype, device=device)
+
+    return Payload(
+        _cast,
+        _torch_cast,
+        (input, out),
+        {},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _cast(input, out):
+    infini.ops.cast(input, out, stream=get_stream(input.device))
+
+    return out
+
+
+def _torch_cast(input, out):
+    
out.copy_(input.to(out.dtype))
+
+    return out
diff --git a/tests/test_cat.py b/tests/test_cat.py
new file mode 100644
index 000000000..85428b53f
--- /dev/null
+++ b/tests/test_cat.py
@@ -0,0 +1,69 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, empty_strided, get_stream, randn_strided
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "shapes, dim, out_shape",
+    (
+        # 2 inputs, dim=0
+        (((4, 64), (4, 64)), 0, (8, 64)),
+        # 2 inputs, dim=1
+        (((4, 32), (4, 64)), 1, (4, 96)),
+        # 2 inputs, dim=-1 (negative dim)
+        (((4, 32), (4, 64)), -1, (4, 96)),
+        # 3 inputs, dim=1
+        (((4, 16), (4, 32), (4, 16)), 1, (4, 64)),
+        # 2 inputs, dim=0, 3D
+        (((2, 4, 64), (2, 4, 64)), 0, (4, 4, 64)),
+        # 2 inputs, dim=2, 3D
+        (((2, 4, 32), (2, 4, 64)), 2, (2, 4, 96)),
+        # 4 inputs, dim=1
+        (((1, 1024), (1, 1024), (1, 1024), (1, 1024)), 1, (1, 4096)),
+    ),
+)
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float32, 1e-7, 1e-7),
+        (torch.float16, 1e-3, 1e-3),
+        (torch.bfloat16, 1e-2, 5e-3),
+    ),
+)
+def test_cat(shapes, dim, out_shape, dtype, device, rtol, atol):
+    inputs = [randn_strided(s, None, dtype=dtype, device=device) for s in shapes]
+    out = empty_strided(out_shape, None, dtype=dtype, device=device)
+
+    return Payload(
+        lambda *args: _cat(*args, dim=dim),
+        lambda *args: _torch_cat(*args, dim=dim),
+        (*inputs, out),
+        {},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _cat(*args, dim):
+    inputs = list(args[:-1])
+    out = args[-1]
+
+    first = inputs[0]
+    rest = inputs[1:]
+
+    infini.ops.cat(first, rest, dim, out, stream=get_stream(first.device))
+
+    return out
+
+
+def _torch_cat(*args, dim):
+    inputs = list(args[:-1])
+    out = args[-1]
+
+    result = torch.cat(inputs, dim=dim)
+    out.copy_(result)
+
+    return out
diff --git a/tests/test_gemm.py b/tests/test_gemm.py
index 97d060696..71e0e8fde 100644
--- a/tests/test_gemm.py
+++ b/tests/test_gemm.py
@@ -20,9 +20,6 @@
 @pytest.mark.parametrize("beta", (-1, -0.5, 0, 0.5, 1))
 @pytest.mark.parametrize("trans_a", (False, True))
 @pytest.mark.parametrize("trans_b", (False, True))
-# TODO: Generate implementation indices dynamically from
-# `Gemm.active_implementation_indices` instead of hardcoding.
-@pytest.mark.parametrize("implementation_index", (0, 1, 2))
 @pytest.mark.parametrize(
     ("dtype", "rtol", "atol"),
     (
@@ -56,11 +53,6 @@ def test_gemm(
     if device == "mlu" and dtype == torch.bfloat16:
         pytest.skip("`bfloat16` is not supported by `cnnlBatchMatMulEx`")
 
-    active_indices = infini.ops.Gemm.active_implementation_indices(device)
-
-    if implementation_index not in active_indices:
-        pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
-
     if implementation_index == 1 and dtype in (torch.float16, torch.bfloat16):
         pytest.skip("cuBLASLt half-precision exceeds current tolerances")
 
diff --git a/tests/test_linear.py b/tests/test_linear.py
new file mode 100644
index 000000000..364ba5fcf
--- /dev/null
+++ b/tests/test_linear.py
@@ -0,0 +1,90 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, empty_strided, get_stream, randn_strided
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "a_shape, b_shape, out_shape",
+    (
+        ((4, 64), (64, 32), (4, 32)),
+        ((2, 128), (128, 256), (2, 256)),
+        ((1, 4096), (4096, 4096), (1, 4096)),
+        ((2, 4, 64), (2, 64, 32), (2, 4, 32)),
+        ((4, 8, 128), (4, 128, 64), (4, 8, 64)),
+    ),
+)
+@pytest.mark.parametrize("trans_a", (False, True))
+@pytest.mark.parametrize("trans_b", (False, True))
+@pytest.mark.parametrize("has_bias", (False, True))
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float32, 1e-2, 5e-2),
+        (torch.float16, 1e-2, 1e-2),
+        (torch.bfloat16, 1e-2, 1e-2),
+    ),
+)
+def test_linear(
+    a_shape,
+    b_shape,
+    out_shape,
+    trans_a,
+    trans_b,
+    has_bias,
+    dtype,
+    device,
+    rtol,
+    atol,
+):
+    a = randn_strided(a_shape, None, dtype=dtype, device=device)
+    b = randn_strided(b_shape, None, dtype=dtype, device=device)
+
+    if trans_a:
+        a = a.transpose(-2, -1)
+
+    if trans_b:
+        b = b.transpose(-2, -1)
+
+    # Bias shape is [N], the last dim of the output.
+    bias = None
+
+    if has_bias:
+        N = out_shape[-1]
+        bias = randn_strided((N,), None, dtype=dtype, device=device)
+
+    out = empty_strided(out_shape, None, dtype=dtype, device=device)
+
+    return Payload(
+        lambda *args: _linear(*args, trans_a=trans_a, trans_b=trans_b),
+        lambda *args: _torch_linear(*args, trans_a=trans_a, trans_b=trans_b),
+        (a, b, bias, out),
+        {},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _linear(a, b, bias, out, trans_a=False, trans_b=False):
+    infini.ops.linear(a, b, bias, trans_a, trans_b, out, stream=get_stream(a.device))
+
+    return out
+
+
+def _torch_linear(a, b, bias, out, trans_a=False, trans_b=False):
+    if trans_a:
+        a = a.transpose(-2, -1)
+
+    if trans_b:
+        b = b.transpose(-2, -1)
+
+    result = torch.matmul(a.float(), b.float())
+
+    if bias is not None:
+        result = result + bias.float()
+
+    out.copy_(result.to(out.dtype))
+
+    return out
diff --git a/tests/test_matmul.py b/tests/test_matmul.py
new file mode 100644
index 000000000..fea3822a8
--- /dev/null
+++ b/tests/test_matmul.py
@@ -0,0 +1,76 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import Payload, empty_strided, get_stream, randn_strided
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "a_shape, b_shape, c_shape",
+    (
+        ((4, 64), (64, 32), (4, 32)),
+        ((2, 128), (128, 256), (2, 256)),
+        ((2, 4, 64), (2, 64, 32), (2, 4, 32)),
+        ((4, 8, 128), (4, 128, 64), (4, 8, 64)),
+    ),
+)
+@pytest.mark.parametrize("trans_a", (False, True))
+@pytest.mark.parametrize("trans_b", (False, True))
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        (torch.float32, 1e-2, 1e-2),
+        (torch.float16, 1e-2, 1e-2),
+        (torch.bfloat16, 1e-2, 1e-2),
+    ),
+)
+def test_matmul(
+    a_shape,
+    b_shape,
+    c_shape,
+    trans_a,
+    trans_b,
+    dtype,
+    device,
+    rtol,
+    atol,
+):
+    a = randn_strided(a_shape, None, dtype=dtype, device=device)
+    b
= randn_strided(b_shape, None, dtype=dtype, device=device)
+
+    if trans_a:
+        a = a.transpose(-2, -1)
+
+    if trans_b:
+        b = b.transpose(-2, -1)
+
+    c = empty_strided(c_shape, None, dtype=dtype, device=device)
+
+    return Payload(
+        lambda *args: _matmul(*args, trans_a=trans_a, trans_b=trans_b),
+        lambda *args: _torch_matmul(*args, trans_a=trans_a, trans_b=trans_b),
+        (a, b, c),
+        {},
+        rtol=rtol,
+        atol=atol,
+    )
+
+
+def _matmul(a, b, c, trans_a=False, trans_b=False):
+    infini.ops.matmul(a, b, c, trans_a, trans_b, stream=get_stream(a.device))
+
+    return c
+
+
+def _torch_matmul(a, b, c, trans_a=False, trans_b=False):
+    if trans_a:
+        a = a.transpose(-2, -1)
+
+    if trans_b:
+        b = b.transpose(-2, -1)
+
+    result = torch.matmul(a.float(), b.float()).to(c.dtype)
+    c.copy_(result)
+
+    return c
diff --git a/tests/test_mul.py b/tests/test_mul.py
new file mode 100644
index 000000000..e368f96df
--- /dev/null
+++ b/tests/test_mul.py
@@ -0,0 +1,87 @@
+import infini.ops
+import pytest
+import torch
+
+from tests.utils import (
+    Payload,
+    empty_strided,
+    get_stream,
+    randint_strided,
+    randn_strided,
+)
+
+_INT_DTYPES = (torch.int16, torch.int32, torch.int64)
+
+_UINT_DTYPES = tuple(
+    filter(None, (getattr(torch, f"uint{bits}", None) for bits in (16, 32, 64)))
+)
+
+
+@pytest.mark.auto_act_and_assert
+@pytest.mark.parametrize(
+    "shape, input_strides, other_strides, out_strides",
+    (
+        ((13, 4), None, None, None),
+        ((13, 4), (10, 1), (10, 1), (10, 1)),
+        ((13, 4), (0, 1), None, None),
+        ((13, 4, 4), None, None, None),
+        ((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
+        ((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
+        ((16, 5632), None, None, None),
+        ((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
+        ((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)),
+        ((13, 16, 2), (128, 4, 1), (2, 0, 1), (64, 4, 1)),
+        ((4, 4, 5632), None, None, None),
+        ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
+    ),
+)
+@pytest.mark.parametrize(
+    ("dtype", "rtol", "atol"),
+    (
+        
(torch.float32, 1e-7, 1e-7),
+        (torch.float16, 1e-3, 1e-3),
+        (torch.bfloat16, 1e-2, 5e-3),
+    )
+    + tuple((dtype, 0, 0) for dtype in _INT_DTYPES + _UINT_DTYPES),
+)
+def test_mul(
+    shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol
+):
+    if device == "musa" and dtype in _UINT_DTYPES:
+        pytest.skip(
+            "The `torch.musa` test cloning path does not support `uint16`, `uint32`, or `uint64`."
+        )
+
+    if dtype in _INT_DTYPES or dtype in _UINT_DTYPES:
+        input = randint_strided(
+            0, 100, shape, input_strides, dtype=dtype, device=device
+        )
+        other = randint_strided(
+            0, 100, shape, other_strides, dtype=dtype, device=device
+        )
+    else:
+        input = randn_strided(shape, input_strides, dtype=dtype, device=device)
+        other = randn_strided(shape, other_strides, dtype=dtype, device=device)
+
+    out = empty_strided(shape, out_strides, dtype=dtype, device=device)
+
+    return Payload(_mul, _torch_mul, (input, other, out), {}, rtol=rtol, atol=atol)
+
+
+def _mul(input, other, out):
+    infini.ops.mul(input, other, out, stream=get_stream(input.device))
+
+    return out
+
+
+def _torch_mul(input, other, out):
+    if input.dtype in _UINT_DTYPES:
+        input = input.to(torch.int64)
+
+    if other.dtype in _UINT_DTYPES:
+        other = other.to(torch.int64)
+
+    res = torch.mul(input, other)
+    out.copy_(res.to(out.dtype))
+
+    return out
diff --git a/tests/test_rms_norm.py b/tests/test_rms_norm.py
index 52fd1ae4f..45f9199ba 100644
--- a/tests/test_rms_norm.py
+++ b/tests/test_rms_norm.py
@@ -18,7 +18,6 @@
     ),
 )
 @pytest.mark.parametrize("eps", (1e-6, 1e-5))
-@pytest.mark.parametrize("implementation_index", (0, 1))
 @pytest.mark.parametrize(
     ("dtype", "rtol", "atol"),
     (
@@ -40,11 +39,6 @@ def test_rms_norm(
     rtol,
     atol,
 ):
-    active_indices = infini.ops.RmsNorm.active_implementation_indices(device)
-
-    if implementation_index not in active_indices:
-        pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
-
     input = randn_strided(input_shape, input_strides, dtype=dtype,
device=device)
     weight = randn_strided(weight_shape, weight_strides, dtype=dtype, device=device)
     out = empty_strided(input_shape, out_strides, dtype=dtype, device=device)
diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py
index 23c299438..f159742ca 100644
--- a/tests/test_swiglu.py
+++ b/tests/test_swiglu.py
@@ -19,7 +19,6 @@
         ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
     ),
 )
-@pytest.mark.parametrize("implementation_index", (0, 1))
 @pytest.mark.parametrize(
     ("dtype", "rtol", "atol"),
     (
@@ -39,11 +38,6 @@ def test_swiglu(
     rtol,
     atol,
 ):
-    active_indices = infini.ops.Swiglu.active_implementation_indices(device)
-
-    if implementation_index not in active_indices:
-        pytest.skip(f"implementation `{implementation_index}` not active on `{device}`")
-
     input = rand_strided(shape, input_strides, dtype=dtype, device=device)
     gate = rand_strided(shape, gate_strides, dtype=dtype, device=device)
     out = empty_strided(shape, out_strides, dtype=dtype, device=device)
From 123fddca9a38b931845f0102611118b6e67123cc Mon Sep 17 00:00:00 2001
From: gongchensu
Date: Wed, 22 Apr 2026 19:49:03 +0800
Subject: [PATCH 03/11] fix: add `-std=c++17` to Iluvatar CUDA flags

- pass -std=c++17 through CMAKE_CUDA_FLAGS for Iluvatar clang builds

Co-authored-by: zhuyue
---
 CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b72c34eaf..91c2b0154 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -211,7 +211,7 @@ if(WITH_ILUVATAR)
     # `-x ivcore` must not be in `CMAKE_CUDA_FLAGS` — CMake passes those flags
     # to both compile and link steps. During linking, `-x ivcore` causes
     # `clang++` to re-parse `.o` files as source code.
-    set(CMAKE_CUDA_FLAGS "--cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags")
+    set(CMAKE_CUDA_FLAGS "--cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable -std=c++17" CACHE STRING "Iluvatar CUDA flags")
     set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Iluvatar")
     set(CMAKE_CUDA_ARCHITECTURES OFF CACHE STRING "Iluvatar CUDA architectures (passed via CMAKE_CUDA_FLAGS)")
     # Iluvatar does not ship `libcudadevrt`, which CMake's compiler test
From 5636cec470ab6141960a522350662f91a416fafc Mon Sep 17 00:00:00 2001
From: zhangyue <138768300+zhangyue207@users.noreply.github.com>
Date: Thu, 23 Apr 2026 15:44:42 +0800
Subject: [PATCH 04/11] chore(lint): add `.clang-tidy` for Google-style naming enforcement (#70)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* chore(lint): add .clang-tidy for Google-style naming enforcement

`clang-format` only enforces whitespace/braces/include order — naming
violations (`BUFFER_NUM`, `dimLength`, `inQueueX1`, missing private-member
trailing `_`, etc.) pass silently. This PR adds `clang-tidy` with
`readability-identifier-naming.*` wired to the Google C++ Style Guide so
the `code-lint` skill can catch them.

- `.clang-tidy` at repo root: types `PascalCase`, functions `PascalCase`,
  variables / parameters `snake_case`, private members `snake_case_`,
  constants `kPascalCase`, macros `UPPER_CASE`, namespaces `lower_case`.
  Only `readability-identifier-naming.*` is `WarningsAsErrors`; the
  `google-*` / `modernize-*` checks are advisory.
- `src/ascend/custom/.clang-tidy`: relaxes `FunctionCase` to `lower_case`
  because `ascendc_add_operator(OP_NAME …)` dictates snake_case kernel
  entry symbol names that cannot be `PascalCase`d.
- `src/ascend/custom/rms_norm/op_kernel/.clang-tidy`: disables all checks
  for device code compiled by `ccec` (absent from `compile_commands.json`,
  `__aicore__` macro parses incorrectly without `kernel_operator.h`).
- `pyproject.toml`: turns on `CMAKE_EXPORT_COMPILE_COMMANDS` so every
  editable `pip install` emits `compile_commands.json` for `clang-tidy`.
- `src/device.h`: adds missing `` / `` includes — pre-existing
  transitive-include bug surfaced by `clang-tidy`'s stricter parsing.

* chore(pr70-review): address review comments

- `pyproject.toml`: wrap `scikit-build` in backticks; insert blank line
  between build-related defines and tool-related defines.
- `.clang-tidy`: rewrite section divider comments as complete sentences
  ending in a period, per project convention.

---------

Co-authored-by: zhangyue
---
 .clang-tidy                                   | 81 +++++++++++++++++++
 pyproject.toml                                |  5 ++
 src/ascend/custom/.clang-tidy                 | 12 +++
 .../custom/rms_norm/op_kernel/.clang-tidy     |  9 +++
 src/device.h                                  |  3 +
 5 files changed, 110 insertions(+)
 create mode 100644 .clang-tidy
 create mode 100644 src/ascend/custom/.clang-tidy
 create mode 100644 src/ascend/custom/rms_norm/op_kernel/.clang-tidy

diff --git a/.clang-tidy b/.clang-tidy
new file mode 100644
index 000000000..63af38e3b
--- /dev/null
+++ b/.clang-tidy
@@ -0,0 +1,81 @@
+---
+# Google C++ Style Guide enforcement.
+#
+# `clang-format` (via `.clang-format`) handles formatting; this file handles
+# everything `clang-format` cannot see — primarily identifier naming.
+#
+# Usage:
+#   clang-tidy -p build/ --quiet (needs `compile_commands.json`)
+#
+# The project's scikit-build config sets `CMAKE_EXPORT_COMPILE_COMMANDS=ON`,
+# so `pip install -e .[dev]` produces `build*/compile_commands.json`.
+#
+# Subdirectories may ship their own `.clang-tidy` with
+# `InheritParentConfig: true` to override specific rules (e.g.
+# `src/ascend/custom/.clang-tidy` relaxes `FunctionCase` for `extern "C"` +# `AscendC` kernel entries whose symbol names are dictated by +# `ascendc_add_operator(OP_NAME …)`). + +Checks: > + readability-identifier-naming, + google-explicit-constructor, + google-readability-braces-around-statements, + google-readability-casting, + modernize-use-nullptr, + modernize-use-override + +# Only naming violations fail the lint — the rest are advisory while the +# project ramps up on `clang-tidy`. +WarningsAsErrors: 'readability-identifier-naming.*' + +HeaderFilterRegex: '^(src|tests)/.*\.(h|hpp|cuh)$' + +CheckOptions: + # Types. + - {key: readability-identifier-naming.ClassCase, value: CamelCase} + - {key: readability-identifier-naming.StructCase, value: CamelCase} + - {key: readability-identifier-naming.UnionCase, value: CamelCase} + - {key: readability-identifier-naming.EnumCase, value: CamelCase} + - {key: readability-identifier-naming.TypeAliasCase, value: CamelCase} + - {key: readability-identifier-naming.TypedefCase, value: CamelCase} + - {key: readability-identifier-naming.TypeTemplateParameterCase, value: CamelCase} + + # Enumerators and constants use `k` + PascalCase. + - {key: readability-identifier-naming.EnumConstantCase, value: CamelCase} + - {key: readability-identifier-naming.EnumConstantPrefix, value: 'k'} + - {key: readability-identifier-naming.ConstexprVariableCase, value: CamelCase} + - {key: readability-identifier-naming.ConstexprVariablePrefix, value: 'k'} + - {key: readability-identifier-naming.GlobalConstantCase, value: CamelCase} + - {key: readability-identifier-naming.GlobalConstantPrefix, value: 'k'} + - {key: readability-identifier-naming.StaticConstantCase, value: CamelCase} + - {key: readability-identifier-naming.StaticConstantPrefix, value: 'k'} + + # Functions and methods use PascalCase. 
+ - {key: readability-identifier-naming.FunctionCase, value: CamelCase} + - {key: readability-identifier-naming.MethodCase, value: CamelCase} + + # Variables and parameters use snake_case. + - {key: readability-identifier-naming.VariableCase, value: lower_case} + - {key: readability-identifier-naming.ParameterCase, value: lower_case} + - {key: readability-identifier-naming.LocalVariableCase, value: lower_case} + + # Class data members use snake_case with a trailing `_`. + - {key: readability-identifier-naming.PrivateMemberCase, value: lower_case} + - {key: readability-identifier-naming.PrivateMemberSuffix, value: '_'} + - {key: readability-identifier-naming.ProtectedMemberCase, value: lower_case} + - {key: readability-identifier-naming.ProtectedMemberSuffix, value: '_'} + - {key: readability-identifier-naming.PublicMemberCase, value: lower_case} + + # Macros use UPPER_CASE. + - {key: readability-identifier-naming.MacroDefinitionCase, value: UPPER_CASE} + # Include guards end in a trailing `_` (e.g. `INFINI_OPS_FOO_H_`), which + # the default `UPPER_CASE` style rejects — skip them by regex. + - {key: readability-identifier-naming.MacroDefinitionIgnoredRegexp, value: '^INFINI_OPS_.*_H_$'} + + # Namespaces use lower_case. + - {key: readability-identifier-naming.NamespaceCase, value: lower_case} + + # Exemptions. + # Type-trait variables (`IsFP16`, `IsBFloat16`, `HasKey`, …) mirror + # `std::is_*_v<>` naming — `PascalCase` without `k` prefix. + - {key: readability-identifier-naming.ConstexprVariableIgnoredRegexp, value: '^(Is|Has|Can)[A-Z].*'} diff --git a/pyproject.toml b/pyproject.toml index 58740166e..959699f90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,5 +17,10 @@ AUTO_DETECT_DEVICES = "ON" AUTO_DETECT_BACKENDS = "ON" GENERATE_PYTHON_BINDINGS = "ON" +# Enables `compile_commands.json` under `SKBUILD_BUILD_DIR` (or the default +# `scikit-build` build dir) for `clang-tidy -p ` used by +# the `code-lint` skill. 
+CMAKE_EXPORT_COMPILE_COMMANDS = "ON" + [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/src/ascend/custom/.clang-tidy b/src/ascend/custom/.clang-tidy new file mode 100644 index 000000000..d1f1296a6 --- /dev/null +++ b/src/ascend/custom/.clang-tidy @@ -0,0 +1,12 @@ +--- +# `CANN`'s `ascendc_add_operator(OP_NAME )` dictates the +# symbol name of the `extern "C" __global__ __aicore__ void (…)` +# kernel entry. The `aclrtlaunch_` wrapper and `op_host/` call site +# must match, so these entries cannot follow Google's `PascalCase` +# convention. Relax `FunctionCase` to `lower_case` within this subtree; +# class and method names still inherit `CamelCase` from the root config. + +InheritParentConfig: true + +CheckOptions: + - {key: readability-identifier-naming.FunctionCase, value: lower_case} diff --git a/src/ascend/custom/rms_norm/op_kernel/.clang-tidy b/src/ascend/custom/rms_norm/op_kernel/.clang-tidy new file mode 100644 index 000000000..ccf13972c --- /dev/null +++ b/src/ascend/custom/rms_norm/op_kernel/.clang-tidy @@ -0,0 +1,9 @@ +--- +# `op_kernel/*.cpp` is `AscendC` device code compiled by `ccec`, not by +# the host toolchain, so it has no entry in `compile_commands.json` and +# `clang-tidy` cannot parse it correctly (the `__aicore__` macro expands +# unexpectedly when `kernel_operator.h` is absent). Disable all checks +# here — the `op_host/` side and the `kernel_custom.h` launcher still +# enforce the full ruleset. 
+ +Checks: '-*' diff --git a/src/device.h b/src/device.h index 38e4bce11..688cd0dc2 100644 --- a/src/device.h +++ b/src/device.h @@ -1,6 +1,9 @@ #ifndef INFINI_OPS_DEVICE_H_ #define INFINI_OPS_DEVICE_H_ +#include +#include + #include "common/constexpr_map.h" #include "common/traits.h" #include "hash.h" From 38a23cfd0adb803107672a040462349d0295af34 Mon Sep 17 00:00:00 2001 From: zhangyue <138768300+zhangyue207@users.noreply.github.com> Date: Fri, 24 Apr 2026 14:12:33 +0800 Subject: [PATCH 05/11] =?UTF-8?q?feat(ascend):=20op-norm-rope=20group=20?= =?UTF-8?q?=E2=80=94=20Swiglu,=20SiluAndMul,=20CausalSoftmax,=20RmsNorm,?= =?UTF-8?q?=20AddRmsNorm,=20ApplyRotaryPosEmb,=20RotaryEmbedding=20(#66)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(ascend): op-norm-rope group — Swiglu, SiluAndMul, CausalSoftmax, RmsNorm, AddRmsNorm, ApplyRotaryPosEmb, RotaryEmbedding

Seven layer-level Ascend operators:

| op | impl |
|---|---|
| Swiglu | aclnnSilu + aclnnMul (decomposed); `kernel_fused.h` wraps fused swiglu where available |
| SiluAndMul | custom AscendC kernel |
| CausalSoftmax | aclnnSoftmax + pre-computed mask |
| RmsNorm | aclnnRmsNorm (kernel.h); custom AscendC variant (kernel_custom.h) |
| AddRmsNorm | 3 impls: decomposed aclnnAdd+aclnnRmsNorm (kernel.h); fused aclnnAddRmsNorm (kernel_fused.h); custom AscendC (kernel_custom.h) |
| ApplyRotaryPosEmb | aclnnApplyRotaryPosEmbV2 (kernel.h); ATB RopeParam (kernel_atb.h) |
| RotaryEmbedding | **3 impls**: aclnnApplyRotaryPosEmbV2 (kernel.h); ATB RopeParam with both neox/interleave (kernel_atb.h); aclnnRopeWithSinCosCache for partial rotary (kernel_sincos_cache.h) |

Bundles the RotaryEmbedding API alignment: `query_out` / `key_out` are now `std::optional` — omitted → inplace on `query` / `key` (matches vLLM `RotaryEmbedding.forward(positions, query, key)`). New `src/base/.h`: apply_rotary_pos_emb, silu_and_mul.
Modified: add_rms_norm (constructor signature alignment), rotary_embedding (optional query_out/key_out). * fix(ascend): norm/swiglu destructors + missing add_rms_norm custom kernel registration - swiglu/kernel_fused.h: release() cat_out_cache_ and out_staging_cache_ to avoid double-free; drop aclDestroyTensorList per 64c367c convention. - silu_and_mul/kernel.h: release() out_staging_cache_ to avoid double-free. - custom/CMakeLists.txt: add add_rms_norm sources to OP_SRCS and register its op_kernel via ascendc_library(no_workspace_kernel ...); without this, aclrtlaunch_add_rms_norm has no backing implementation. * style(ascend): rename `AddRmsNorm` parameters to PyTorch-aligned names - `x1/x2/gamma/y_out/x_out` -> `input/other/weight/out/rstd_out`. - Propagate through base header, all three Ascend kernel variants (`kernel.h`, `kernel_fused.h`, `kernel_custom.h`), and test file. - Remove stale `rstd_shape_` field from base (unused; `kernel.h` holds its own copy). - Upgrade assertion messages to complete sentences with backticked identifiers. * style(ascend): comment + assert message audit for norm/swiglu/softmax kernels - Wrap `aclnn*` / `aclrt*` identifiers in backticks and ensure complete-sentence, period-terminated comments per CONTRIBUTING.md. - `silu_and_mul` base header: upgrade assertion message to a complete sentence with backticked identifiers. - Files touched: causal_softmax/kernel.h, rms_norm/kernel.h, swiglu/kernel.h, swiglu/kernel_fused.h, base/silu_and_mul.h. * test(silu_and_mul): add `implementation_index` parametrize and strided coverage - Wire `implementation_index` into joint `(device, implementation_index)` parametrize via conftest; enforces fixture symmetry with `test_swiglu.py`. - Add two non-contiguous shape cases to exercise the staging-buffer copy path in `src/ascend/silu_and_mul/kernel.h`. 
* refactor(ascend/rotary_embedding): unify RotaryEmbedding and ApplyRotaryPosEmb base ops Merge the two rope base headers into one vLLM-compatible op matching `RotaryEmbedding.forward(positions, query, key=None) -> (query, key|None)`. `key` becomes `std::optional` (MLA), `query_out` / `key_out` remain optional for the vLLM-native inplace path, and a new `bool pre_gathered` constructor flag folds the old `ApplyRotaryPosEmb` fast path into the unified op. Kernel updates across all three Ascend impls: - impl 0 (`aclnnApplyRotaryPosEmbV2`) and impl 1 (ATB `RopeParam`) accept the optional `key` / out tensors and honor `pre_gathered` (skipping internal `aclnnIndexSelect` when the caller has pre-gathered). - impl 0 and impl 1 re-upload the expanded cos/sin tables on cache-pointer change (reviewer-flagged stale-pointer bug). - impl 2 (`aclnnRopeWithSinCosCache`) destroys its per-call `aclOpExecutor` instead of leaking it (reviewer-flagged leak). - Uppercase locals (`D`, `T`, `Nq`, `Nkv`, `half_D`, `hiddenQ`, `hiddenK`) renamed to snake_case, and `uploadCosSinCache` renamed to `UploadCosSinCache` per Google C++ style. * feat(scripts/generate_wrappers): emit `apply_rotary_pos_emb` Python shim After the `ApplyRotaryPosEmb` base class was folded into the unified `RotaryEmbedding` op, vllm-infini still calls `infini.ops.apply_rotary_pos_emb(...)` — preserve that symbol as a pybind11 Python-level shim bound alongside the generated `rotary_embedding` binding. The shim un-expands the caller's neox-duplicated `[T, head_size]` cos / sin halves, concats into a `[T, head_size*2]` pre-gathered cache, synthesizes `positions = arange(T)`, and forwards to the unified op with `pre_gathered=True`. No vllm-infini changes are needed. 
* test(rotary_embedding): merge apply_rotary_pos_emb cases + cover MLA/3D/partial Consolidate `test_apply_rotary_pos_emb.py` (deleted separately) into `test_rotary_embedding.py`: - `test_apply_rotary_pos_emb` — pre-gathered fast path through the new Python shim; asserts bit-exact parity against `infini.ops.rotary_embedding` on the same data. - `test_apply_rotary_pos_emb_3d` — 3D `[T, Nq, D]` / `[T, Nkv, D]` layout through the shim (reviewer gap). - `test_rotary_embedding_partial` — extend to cover `is_neox_style=False` on impl 2 (`aclnnRopeWithSinCosCache`), matching the reviewer's partial-rotary gap on the non-neox path. - `_ref_rotary_embedding` now tolerates `key=None` (MLA). * fix(generate_wrappers): propagate scalar param defaults to pybind signature Without this, the unified `RotaryEmbedding`'s new `bool pre_gathered` parameter became a required positional kwarg on the Python side, breaking every existing `infini.ops.rotary_embedding(...)` caller that did not pass it. Regex-scan the base header for ` name = ` patterns and emit `py::arg(name) = ` in `_generate_py_args`. Also restore the default on the virtual `operator()` override in `src/base/rotary_embedding.h` so the regex picks it up. * fix(ascend/rotary_embedding): correct pre-gathered layout + revert sincos executor destroy Two in-flight regressions from the previous commit: 1. The `pre_gathered=true` path in kernel.h / kernel_atb.h assumed the caller's `cos_sin_cache` is `[T, head_size*2]` (dim-1 concat), but that layout can't be split with a flat byte offset because row-major contiguous layout interleaves cos and sin per row. Change the wire format to `[2T, head_size]` (dim-0 concat) so the first `T * head_size * elem_sz` bytes are contiguous cos and the next are contiguous sin; update both kernels and the `apply_rotary_pos_emb` Python shim to match. Also set the initial `sin_v2_cache_` base pointer to the sin offset so the V2 executor captures distinct cos/sin addresses on first call. 2. 
`kernel_sincos_cache.h` (impl 2) SIGABRTs when the per-call `aclOpExecutor*` is destroyed right after `aclnnRopeWithSinCosCache` — the kernel is async on the stream and the executor backs the enqueued launch. Revert the `aclDestroyAclOpExecutor` call (still leaks, but matches the prior behavior that passed all partial-rotary tests) and leave a TODO for proper Repeatable-executor caching once the input-address index layout for this kernel is confirmed. * test(rotary_embedding): fix GPT-J reference for partial rotary The GPT-J-style branch in `_ref_rotary_embedding` indexed `x[t, :, 0::2]` and `x[t, :, 1::2]` across the full `head_size` — correct only when `rotary_dim == head_size`. For partial rotary, only the first `rotary_dim` features rotate; restrict slices to `0:R:2` and `1:R:2`. * refactor(pr66-simplify): correct `rstd_out` semantic name + clarity fixes Post-merge /simplify review findings applied: - **`AddRmsNorm` param rename** (`src/base/add_rms_norm.h` + 3 Ascend kernels + test): `rstd_out` → `residual_out`. The slot actually holds `xOut` (the `input + other` residual sum) per `aclnnAddRmsNorm`'s API — the internal `rstd_tensor_` reciprocal-std buffer is private. Prior name was misleading. - **Generator shim for `apply_rotary_pos_emb`** (`scripts/generate_wrappers.py`): rename the `head_size`-as-`rotary_dim` positional forward to a named local `rotary_dim_shim` + comment noting the legacy shim assumes full rotary (`rotary_dim == head_size`). - **`kernel_sincos_cache.h` leak comment**: TODO → FIXME with persistent-worker impact call-out. Actual fix still blocked on undocumented input-address index layout for `aclnnRopeWithSinCosCache`. Skipped findings: reviewer false positives on `src/base/rotary_embedding.h` members (all consumed by kernels) and `max_seq_len_` (used in constructor body). Larger refactors (UploadCosSinCache + IndexSelect helpers, ~100 lines copy-paste) deferred to a follow-up PR. 
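For the GPT-J partial-rotary reference fix described above, a minimal pure-Python sketch of the corrected slicing may help. It assumes a single token and a single head, with `x` holding `head_size` features and `cos` / `sin` holding `rotary_dim // 2` values each. The helper name `rotate_gptj_partial` is hypothetical; it is not the test suite's `_ref_rotary_embedding`.

```python
def rotate_gptj_partial(x, cos, sin, rotary_dim):
    """Hypothetical reference: GPT-J (interleaved) rotation on the first
    `rotary_dim` features only; the remaining features pass through."""
    out = list(x)
    # Restrict the even/odd slices to 0:R:2 and 1:R:2 rather than the
    # whole head, which is the point of the fix.
    even = x[0:rotary_dim:2]
    odd = x[1:rotary_dim:2]
    for i, (e, o) in enumerate(zip(even, odd)):
        out[2 * i] = e * cos[i] - o * sin[i]
        out[2 * i + 1] = o * cos[i] + e * sin[i]
    return out


# head_size = 8, rotary_dim = 4, every pair rotated by 90 degrees.
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
rotated = rotate_gptj_partial(x, cos=[0.0, 0.0], sin=[1.0, 1.0], rotary_dim=4)
print(rotated)
```

Only the first four features change here; slicing `0::2` / `1::2` across the full head would also pair up features 4..7, which is only valid when `rotary_dim == head_size`.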
* style(tests): ruff format `test_add_rms_norm.py` after `residual_out` rename * build(ascend-custom): drive `build.sh` from `pip install` with proper dep tracking In-tree `ascendc_library()` trips a `CANN` `extract_host_stub.py` path bug (`KeyError` on `/./workspace/...` paths in `$`) whenever it runs under `scikit-build-core`'s temp-dir builds. Standalone `src/ascend/custom/build.sh` avoids the bug by invoking a separate `cmake` with `src/ascend/custom/` as its `SOURCE_DIR`. This commit drives `build.sh` from the main build so devs / CI get a working install from a single `pip install` call. - `option(BUILD_ASCEND_CUSTOM …)` replaces the old `BUILD_CUSTOM_KERNEL` (name is Ascend-specific now that the driver is CMake-native) and **defaults to ON**. Non-Ascend builds ignore it (gated by `WITH_ASCEND` in `src/CMakeLists.txt`); users who don't want the `ccec` build on Ascend pass `-DBUILD_ASCEND_CUSTOM=OFF`. - `src/CMakeLists.txt` registers `build.sh` as a build-phase `add_custom_command(OUTPUT …/libno_workspace_kernel.a)` with explicit dependencies on every `src/ascend/custom/**/*.{cpp,h}` file (via `file(GLOB_RECURSE … CONFIGURE_DEPENDS)`) — edits to any `op_host/` or `op_kernel/` source now re-trigger the build, instead of silently reusing a stale `.a`. The outer `scikit-build-core` env (`CMAKE_GENERATOR`, `CMAKE_EXPORT_COMPILE_COMMANDS`, …) is scrubbed via `cmake -E env --unset=…` before invoking `build.sh` — leaving them set makes the nested `cmake`'s `ninja` generator emit the bug-triggering `/./workspace/...` paths even though the outer configure dir is clean. - `src/ascend/custom/cmake/detect_soc.cmake` holds `infiniops_detect_soc()`, which parses `npu-smi info` for the first `910*` / `310*` entry and falls back to `Ascend910B4`. Both `src/CMakeLists.txt` (outer build) and `src/ascend/custom/cmake/config_ascend.cmake` (sub-build driven by `build.sh`) `include()` this file — SOC detection lives in one place. 
- `src/ascend/custom/CMakeLists.txt` pushes the main `src/` dir onto the interface target's `INCLUDES` property so the kernel TU can `#include "data_type.h"`. - `src/ascend/custom/add_rms_norm/op_kernel/.clang-tidy`: disables all `clang-tidy` checks on `ccec`-compiled device code (absent from `compile_commands.json`, `__aicore__` macro parses incorrectly without `kernel_operator.h`). Dev workflow: `pip install -e .[dev]` gives a fully working install on Ascend; editing any custom-kernel source and re-running `pip install` re-triggers the `ccec` build automatically. * refactor(data_type): pin `DataType` enum values explicitly The `AscendC` custom kernels forward `static_cast(input.dtype())` to their `aclrtlaunch_*` entry points and dispatch on the same enum — making `DataType`'s integer values part of a host↔device ABI. Assign explicit values (`kInt8 = 0, …, kFloat64 = 11`) to pin that ABI: reordering or inserting entries above existing ones would silently change the integers seen by device code. No behaviour change at call sites (the enum is still accessed by symbolic name everywhere except the `int64_t` cast). * feat(ascend-custom): add bf16 support + Google-style identifier renames bf16 was silently producing garbage / NaN on impl 1 (`rms_norm`) and impl 2 (`add_rms_norm`): the kernels only instantiated `` and ``, and the launchers mapped bf16 to the fp32 byte-size path, so bf16 weight was read as if it were fp32 and the fp16 output cast used `CAST_ROUND` (fp16-only alias). Kernel dispatch: - `op_kernel/rms_norm.cpp` / `op_kernel/add_rms_norm.cpp`: add a `KernelXxx` instantiation; dispatch in the `extern "C"` entry is now `switch (static_cast(dtypeCode))` (shared enum forwarded from the launcher via `int64_t`). The fp16/bf16 branch uses `CAST_RINT` for the fp32 → T writeback — defined for both `half` and `bfloat16_t` destinations, whereas `CAST_ROUND` is a `half`-specific alias. 
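To illustrate why the dispatch has to key on the dtype code rather than on a byte size, here is a small host-side sketch in Python. It is not the AscendC kernel: `DTypeCode` and its integer values are hypothetical stand-ins for the shared `DataType` enum, and `decode` only mimics how the same buffer is interpreted under each dtype.

```python
import struct
from enum import IntEnum


class DTypeCode(IntEnum):
    # Hypothetical codes for illustration; not the project's `DataType` values.
    FLOAT16 = 100
    BFLOAT16 = 101
    FLOAT32 = 102


def decode(raw, code):
    """Interpret a little-endian byte buffer according to a dtype code."""
    code = DTypeCode(code)
    if code == DTypeCode.FLOAT32:
        return list(struct.unpack(f"<{len(raw) // 4}f", raw))
    if code == DTypeCode.FLOAT16:
        return list(struct.unpack(f"<{len(raw) // 2}e", raw))
    # bf16 is the upper 16 bits of an fp32 value.
    halves = struct.unpack(f"<{len(raw) // 2}H", raw)
    return [struct.unpack("<f", struct.pack("<I", h << 16))[0] for h in halves]


# Two bf16 values, both 1.0 (bit pattern 0x3F80).
raw = struct.pack("<2H", 0x3F80, 0x3F80)
as_bf16 = decode(raw, DTypeCode.BFLOAT16)
as_fp16 = decode(raw, DTypeCode.FLOAT16)
as_fp32 = decode(raw, DTypeCode.FLOAT32)
print(as_bf16, as_fp16, as_fp32)
```

fp16 and bf16 share a 2-byte element size, so a size-based switch cannot tell them apart, and reading bf16 bytes through the fp32 path yields a different element count as well as different values.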
Launchers (`kernel_custom.h`): - Store `DataType dtype_` (replaces the old `int64_t dtype_size_` which collapsed fp16 and bf16 onto the same code). - Use `ascend::ToAclDtype(dtype_)` and `kDataTypeToSize.at(dtype_)` instead of hand-rolled ternaries (consistent with the rest of the Ascend backend). - Forward `static_cast(dtype_)` as the kernel's `dtypeCode`. - `extern "C" aclrtlaunch_*` forward-decl parameters renamed to `snake_case`; the function name itself is generated by `ascendc_add_operator(OP_NAME …)` and carries `// NOLINTNEXTLINE(readability-identifier-naming)` so `clang-tidy` accepts it. Identifier naming (Google C++ Style): - `op_kernel/*.cpp` members `snake_case_`, params / locals `snake_case`, constants `kPascalCase` (was `BUFFER_NUM` / `dimLength` / `inQueueX1` / `blockRows`, etc. — inherited from the `vllm-ascend` sample style). Verified: `pytest tests/test_rms_norm.py tests/test_add_rms_norm.py --devices ascend` → 144 passed / 0 failed (fp32 / fp16 / bf16 × both ops × full shape + stride matrix). * refactor(base): align Linear/SiluAndMul/AddRmsNorm/RotaryEmbedding with vLLM Bring `src/base/*.h` interfaces and tensor conventions into strict alignment with vLLM's public kernel contracts. Derived Ascend kernels and tests follow. `generated/bindings/` will regenerate on next build. - **`SiluAndMul`**: rename `x` → `input` (matches `F.glu(input, dim)`); add `(input, out)` overload with `dim = -1` default to match vLLM's hardcoded last-dim behavior. - **`Linear`**: add vLLM-aligned `(input, weight, bias?, out)` overload with weight stored as `[out_features, in_features]` (identical to `F.linear(input, weight, bias)`). Deprecated 6-arg `(a, b, bias, trans_a, trans_b, out)` form retained. CPU and Ascend subclasses gain matching 4-arg ctors that delegate to the 6-arg form with `trans_a = false, trans_b = true`. 
- **`AddRmsNorm`**: rename `other` → `residual` (matches vLLM's `fused_add_rms_norm(input, residual, weight, eps)` schema); add inplace `(input, residual, weight, eps)` overload that forwards to the out-of-place primary form with aliased buffers. - **`RotaryEmbedding`**: reorder first six parameters to match vLLM's `rotary_embedding(positions, query, key?, head_size, cos_sin_cache, is_neox)` schema verbatim; `rotary_dim` / `query_out?` / `key_out?` / `pre_gathered` remain as InfiniOps extensions at the tail. Added `positions.dtype() == int64` assert per vLLM convention. Verified on NPU: `pytest tests/test_{silu_and_mul,add_rms_norm,rotary_embedding,linear}.py --devices ascend` → 295 passed, 4 skipped, 0 failed. * refactor(base): trim narrative comments and collapse CPU Linear ctors Follow-up to `c23901a`. Per CLAUDE.md "default to writing no comments", strip doc-comments that narrate the change or restate well-named identifiers from the four refactored base headers. Keep only the one WHY comment in `rotary_embedding.h` explaining `pre_gathered`'s index_select+neox precondition (the name alone doesn't carry it). Also replace the two delegating ctors in `src/cpu/linear/linear.h` with `using Linear::Linear;` — matches the pattern already used in `src/cpu/{rms_norm,swiglu}/*.h`, `src/cuda/{rms_norm,causal_softmax}/*.h`. Verified: `pytest tests/test_{silu_and_mul,add_rms_norm,rotary_embedding,linear}.py --devices ascend` → 295 passed, 4 skipped. * fix(pr66-review): address review findings 1-3 - `tests/test_add_rms_norm.py`: extend `implementation_index` parametrize to `(0, 1, 2)`; add `_clear_add_rms_norm_cache` autouse fixture to avoid cross-test state pollution in the custom AscendC kernel (impl 2) whose cached fp32 weight buffer collides across tests with matching shape/dtype keys. Coverage: +54 test cases (108 total, all green). - `src/base/rotary_embedding.h`: assert `key.has_value()` with a TODO noting MLA is not yet implemented on any Ascend backend. 
All three impls already assert `has_key_` individually; hoisting the check to base turns a silent crash (if a caller passes `key=None`) into a clean assert. Keeps `std::optional key` in the signature for future MLA support without breaking vLLM API compatibility. - `src/ascend/causal_softmax/kernel.h`: add justification for the 3-primitive decomposition (no single CANN 8.5 API covers causal-mask + softmax; `aclnnSoftmaxV2` lacks the mask argument, and `aclnnScaledMaskedSoftmax` requires a pre-scaled attention score), per CLAUDE.md Ascend rule "never decompose when a fused API exists". Verified: `pytest tests/test_{silu_and_mul,add_rms_norm,rotary_embedding,linear,causal_softmax}.py --devices ascend` → 349 passed, 4 skipped. * refactor(pr66): drop `apply_rotary_pos_emb` wrapper + tests The legacy `apply_rotary_pos_emb` shim existed only as a vllm-infini compat alias after the `ApplyRotaryPosEmb` base op was folded into the unified `RotaryEmbedding`. vllm-infini is out of scope for this PR, so drop the shim entirely: - `scripts/generate_wrappers.py`: remove `_generate_apply_rotary_pos_emb_shim` and the `extra_shim` emission hook — the Python-level wrapper was ~60 lines of pybind C++ that concatenated cos/sin, synthesized `positions = arange(T)`, and forwarded to `rotary_embedding` with `pre_gathered=True`. Callers that need the pre-gather fast path can invoke `infini.ops.rotary_embedding(..., pre_gathered=True)` directly. - `tests/test_rotary_embedding.py`: remove `test_apply_rotary_pos_emb` / `test_apply_rotary_pos_emb_3d` and the `_expand_cos_sin` helper that only those tests used. `pre_gathered=True` remains exercised indirectly via `test_rotary_embedding_full` when impl 2 requires the caller to pre-gather (handled internally by the kernel). - Touch up two stale `apply_rotary_pos_emb shim` comments in `kernel{,_atb}.h` that no longer point anywhere. 
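For callers that now build the pre-gathered cache themselves, the following pure-Python sketch shows why the `[2T, head_size]` (dim-0 concat) wire format can be split with a flat offset while `[T, head_size*2]` (dim-1 concat) cannot. The nested lists and the `flatten` helper are illustrative assumptions that stand in for contiguous row-major tensors; no `infini.ops` call is made.

```python
def flatten(rows):
    """Row-major flatten of a 2-D list, mimicking a contiguous buffer."""
    return [v for row in rows for v in row]


T, HEAD_SIZE = 3, 4
# Tagged values: cos entries are non-negative, sin entries are negative.
cos = [[100 * t + j for j in range(HEAD_SIZE)] for t in range(T)]
sin = [[-(100 * t + j) - 1 for j in range(HEAD_SIZE)] for t in range(T)]

dim1_concat = [c + s for c, s in zip(cos, sin)]  # shape [T, 2 * HEAD_SIZE]
dim0_concat = cos + sin                          # shape [2 * T, HEAD_SIZE]

split = T * HEAD_SIZE  # flat element offset where sin is expected to start
flat_dim0 = flatten(dim0_concat)
flat_dim1 = flatten(dim1_concat)

dim0_ok = flat_dim0[:split] == flatten(cos) and flat_dim0[split:] == flatten(sin)
dim1_ok = flat_dim1[:split] == flatten(cos)
print(dim0_ok, dim1_ok)  # prints: True False
```

With the dim-1 layout the first `T * HEAD_SIZE` elements already contain sin values from row 0, so a single base-pointer offset cannot separate the two tables.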
Verified: `pytest tests/ --devices ascend` → 2278 passed, 1612 skipped (was 2306 / 1612 — delta is the 28 removed `apply_rotary_pos_emb` cases). * test(rotary_embedding): add `pre_gathered=True` coverage Fold the deleted `test_apply_rotary_pos_emb` / `_3d` cases into a single `test_rotary_embedding_pre_gathered` that exercises the `pre_gathered` fast path directly on the `rotary_embedding` overload (no shim). Parametrize over 2D / 3D query-key layouts, impls 0 and 1 (impl 2 asserts `!pre_gathered_`), neox / GPT-J styles, fp16 / bf16. The new `_build_pre_gathered_cache` helper constructs the `[2*T, head_size]` wire format that `src/ascend/rotary_embedding/kernel.h` expects — cos rows 0..T-1, sin rows T..2T-1, both neox-expanded per token. Coverage: 12 new cases pass (4 skip for `impl=0 + not-neox`, same as the `test_rotary_embedding_full` skip — V2 only supports `rotaryMode="half"`). Full rotary suite: 88 passed, 8 skipped (was 80 passed, 4 skipped before this test was added). * chore(pr66): drop unused headers - `src/base/add_rms_norm.h`: `#include ` — no `size_t` usage. - `src/base/rotary_embedding.h`: same. - `src/ascend/add_rms_norm/kernel_custom.h`: `#include ` — no `std::vector` / `std::array` usage. Build + 355 passed / 8 skipped on Ascend unchanged. * style(pr66): sweep assert-message periods + comment backticks Addresses inline review comments on #66 (reviewer: Ziminli) across all PR-touched files: - C4: strip trailing periods from assert messages; lowercase the sentence-starting word when it is bare English (e.g. "Ascend ..." → "ascend ..."), leave backticked identifiers untouched. - G4: backtick `RmsNorm` in kernel_custom.h header comment; backtick `aclnn` / `cos_sin_cache` / `infini.ops.add_rms_norm(...)` in kernel comments that were still running raw text. 
- C2: rename `aclrtlaunch_add_rms_norm` / `aclrtlaunch_rms_norm` forward-decl parameter names from AscendC internals (`x1, x2, y, x_out`) to the base-header semantic names (`input, residual, weight, out, residual_out`). The extern "C" symbol is name-blind so the AscendC kernel .cpp can keep its local names — the wrapper .h just presents the public contract. - Pre-gathered rotary test: drop the hardcoded `implementation_index=(0, 1)` parametrize, let conftest auto-inject and skip impl 2 inline (the impl 2 kernel asserts `!pre_gathered_`). Verified locally (`--gpu-id 3/4/5 --local`): test_add_rms_norm.py: 108 passed test_rms_norm.py: 72 passed test_rotary_embedding.py: 88 passed, 16 skipped (expected: impl 2 + pre_gathered, impl 0 + non-neox) * refactor(pr66): rename AscendC custom kernels to PascalCase + C2 param order Addresses Ziminli's comment on `aclrtlaunch_add_rms_norm` forward-decl (#66 discussion 3115868675 / 3129096521): - **Function name format:** the AscendC kernel entry-points `add_rms_norm` / `rms_norm` are renamed to `AddRmsNorm` / `RmsNorm`. The AscendC toolchain prepends `aclrtlaunch_` on the symbol regardless of case, so the exported names become `aclrtlaunch_AddRmsNorm` / `aclrtlaunch_RmsNorm` — matching the base-class names and `readability-identifier-naming.FunctionCase = CamelCase`. The `NOLINTNEXTLINE(readability-identifier-naming)` shim and the "PascalCase rule does not apply" apology comments go away. - **Parameter list order (C2):** reorder parameters to `inputs, attributes, outputs`. The `.cpp` kernel entry, `KernelAddRmsNorm::Init` / `KernelRmsNorm::Init`, and the `extern "C"` forward-decl in `kernel_custom.h` are updated together, along with the call sites in `operator()`. - **Variable naming (`.cpp` internals):** `x1/x2/y/x_out` → `input/residual/out/residual_out`; `x/y` → `input/out`. Cascaded through member names (`*_gm_`, `*_queue_*`, `*_local`) for consistency — internal to each kernel class, no ABI impact.
- **`op_host/*.cpp`:** updated to include the PascalCase generated header `aclrtlaunch_AddRmsNorm.h` / `aclrtlaunch_RmsNorm.h` and to match the reordered `EXEC_KERNEL_CMD` argument list. Verified locally with `.ci/run.py --local`: test_add_rms_norm.py: 108 passed test_rms_norm.py: 72 passed The AscendC toolchain successfully compiles the PascalCase kernel entries and generates matching launch headers — the `aclrtlaunch_` macro concatenates regardless of case. * refactor(pr66): trim commit-narration comments /simplify found 4 comment blocks that narrate the rename rationale rather than encode load-bearing contracts: - `kernel_custom.h` forward-decl — compress build-system detail (`no_workspace_kernel`, `ascendc_library()`) to one line, keep only the ABI contract (`aclrtlaunch_` is generated by AscendC from `op_kernel/`). - `op_host/.cpp` `EXEC_KERNEL_CMD` — drop "Parameter order follows the base class: inputs, attributes, outputs."; the signature itself is self-evident. - `op_kernel/.cpp` kernel entry — drop "Parameters follow the C2 convention ..." and "`aclrtlaunch_AddRmsNorm` matches the base `AddRmsNorm` class name"; these are commit-message material, not comments. 
--------- Co-authored-by: zhangyue --- CMakeLists.txt | 21 +- pyproject.toml | 9 + scripts/generate_wrappers.py | 27 +- src/CMakeLists.txt | 70 +- src/ascend/add_rms_norm/kernel.h | 144 ++++ src/ascend/add_rms_norm/kernel_custom.h | 171 +++++ src/ascend/add_rms_norm/kernel_fused.h | 132 ++++ src/ascend/causal_softmax/kernel.h | 173 +++++ src/ascend/custom/CMakeLists.txt | 18 +- .../add_rms_norm/op_host/add_rms_norm.cpp | 19 +- .../custom/add_rms_norm/op_kernel/.clang-tidy | 9 + .../add_rms_norm/op_kernel/add_rms_norm.cpp | 350 +++++---- src/ascend/custom/build.sh | 33 +- src/ascend/custom/cmake/config_ascend.cmake | 14 +- src/ascend/custom/cmake/detect_soc.cmake | 24 + .../custom/rms_norm/op_host/rms_norm.cpp | 18 +- .../custom/rms_norm/op_kernel/rms_norm.cpp | 281 +++---- src/ascend/linear/kernel.h | 6 + src/ascend/rms_norm/kernel.h | 100 +++ src/ascend/rms_norm/kernel_custom.h | 155 ++++ src/ascend/rotary_embedding/kernel.h | 373 +++++++++ src/ascend/rotary_embedding/kernel_atb.h | 449 +++++++++++ .../rotary_embedding/kernel_sincos_cache.h | 177 +++++ src/ascend/silu_and_mul/kernel.h | 127 +++ src/ascend/swiglu/kernel.h | 109 +++ src/ascend/swiglu/kernel_fused.h | 202 +++++ src/base/add_rms_norm.h | 41 +- src/base/linear.h | 33 +- src/base/rotary_embedding.h | 107 +-- src/base/silu_and_mul.h | 62 ++ src/cpu/linear/linear.h | 4 +- src/data_type.h | 31 +- tests/test_add_rms_norm.py | 113 +++ tests/test_rotary_embedding.py | 723 ++++++++++++++++++ tests/test_silu_and_mul.py | 76 ++ 35 files changed, 3963 insertions(+), 438 deletions(-) create mode 100644 src/ascend/add_rms_norm/kernel.h create mode 100644 src/ascend/add_rms_norm/kernel_custom.h create mode 100644 src/ascend/add_rms_norm/kernel_fused.h create mode 100644 src/ascend/causal_softmax/kernel.h create mode 100644 src/ascend/custom/add_rms_norm/op_kernel/.clang-tidy create mode 100644 src/ascend/custom/cmake/detect_soc.cmake create mode 100644 src/ascend/rms_norm/kernel.h create mode 100644 
src/ascend/rms_norm/kernel_custom.h create mode 100644 src/ascend/rotary_embedding/kernel.h create mode 100644 src/ascend/rotary_embedding/kernel_atb.h create mode 100644 src/ascend/rotary_embedding/kernel_sincos_cache.h create mode 100644 src/ascend/silu_and_mul/kernel.h create mode 100644 src/ascend/swiglu/kernel.h create mode 100644 src/ascend/swiglu/kernel_fused.h create mode 100644 src/base/silu_and_mul.h create mode 100644 tests/test_add_rms_norm.py create mode 100644 tests/test_rotary_embedding.py create mode 100644 tests/test_silu_and_mul.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 91c2b0154..2e10db2e3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,12 +18,21 @@ option(WITH_ASCEND "Enable Ascend backend" OFF) option(WITH_TORCH "Enable PyTorch C++ backend" OFF) -# Default OFF until CANN's `extract_host_stub.py` path handling is fixed for -# `scikit-build-core` temp-dir builds (triggers `KeyError` on the preprocessed -# object path). Enable explicitly with `-DBUILD_CUSTOM_KERNEL=ON` when the -# toolchain is compatible or when building via the standalone -# `src/ascend/custom/build.sh` script. -option(BUILD_CUSTOM_KERNEL "Build custom AscendC kernel PyTorch extension (requires `torch_npu`)" OFF) +# Custom `AscendC` kernels under `src/ascend/custom/`. `ON` by default +# so CI and routine dev builds always exercise `implementation_index=1/2` +# for `RmsNorm` / `AddRmsNorm`. Gated by `WITH_ASCEND` in +# `src/CMakeLists.txt` — non-Ascend builds ignore it. Pass +# `-DBUILD_ASCEND_CUSTOM=OFF` to skip the `ccec` build on Ascend +# machines where the custom kernels aren't needed. +# +# When `ON`, `src/CMakeLists.txt` drives the standalone +# `src/ascend/custom/build.sh` via `execute_process` at configure time +# (sidesteps a `CANN` `extract_host_stub.py` path bug that breaks +# in-tree `ascendc_library()` under `scikit-build-core` temp-dir builds) +# and links the produced `libno_workspace_kernel.a` into the `ops` +# module with `--whole-archive`. 
Requires `torch_npu` and the +# `AscendC` toolchain (`ccec`). +option(BUILD_ASCEND_CUSTOM "Build custom AscendC kernels" ON) option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF) option(AUTO_DETECT_BACKENDS "Automatically detect available backends" OFF) diff --git a/pyproject.toml b/pyproject.toml index 959699f90..6b5170266 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,15 @@ name = "InfiniOps" version = "0.1.0" [project.optional-dependencies] +# TODO: `torch` here is unconstrained. On Ascend hosts, the working +# torch is the Ascend-matched `torch 2.9.0+cpu` paired with +# `torch_npu 2.9.0.post1+…`. A `pip install -e .[dev] --force-reinstall` +# will re-resolve `torch` to the latest PyPI version (currently +# `torch 2.11.0`), which now declares `cuda-toolkit` / `nvidia-cublas` / +# `nvidia-cudnn` / … as hard deps — downloads GBs of CUDA wheels and +# kills the `torch_npu` / `vllm-ascend` pairing. Needs a platform-aware +# split (e.g. `torch; platform_machine != 'aarch64'`, or move `torch` +# out of `dev` and require it pre-installed in the container image). dev = ["pytest", "pytest-cov", "pytest-xdist", "ruff", "torch", "pyyaml"] [tool.scikit-build.wheel] diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 49b6c199f..9810404d2 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -112,9 +112,29 @@ def _find_vector_tensor_params(op_name): return set(re.findall(r"std::vector\s+(\w+)", source)) +def _find_params_with_defaults(op_name): + """Return ``{param_name: default_literal}`` for base-header params that + carry a `= ` default value. `libclang`'s cursor API does not + expose defaults reliably, so we regex-scan the source. Only used for + plain scalar defaults such as ``bool pre_gathered = false``. 
+ """ + source = (_BASE_DIR / f"{op_name}.h").read_text() + + mapping = {} + + for name, default in re.findall( + r"\b(?:bool|int(?:64_t|32_t|8_t|16_t)?|std::size_t|std::uint\w+_t|float|double)\s+(\w+)\s*=\s*([^,\)]+?)\s*(?:,|\))", + source, + ): + mapping[name] = default.strip() + + return mapping + + def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) vector_tensor_params = _find_vector_tensor_params(operator.name) + params_with_defaults = _find_params_with_defaults(operator.name) def _is_optional_tensor(arg): if arg.spelling in optional_tensor_params: @@ -186,6 +206,10 @@ def _generate_py_args(node): if _is_optional(arg): parts.append(f'py::arg("{arg.spelling}") = py::none()') + elif arg.spelling in params_with_defaults: + parts.append( + f'py::arg("{arg.spelling}") = {params_with_defaults[arg.spelling]}' + ) else: parts.append(f'py::arg("{arg.spelling}")') @@ -257,8 +281,7 @@ def _generate_call(op_name, call, method=True): }}) .def_static("clear_cache", &Self::clear_cache); -{callers} -}} +{callers}}} }} // namespace infini::ops diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 32c92949d..443ac0e2b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -241,8 +241,66 @@ if(WITH_ASCEND) list(APPEND DEVICE_LIST "ascend") # Custom `AscendC` kernels (PyTorch extension, requires `torch_npu`). - if(BUILD_CUSTOM_KERNEL) - add_subdirectory(ascend/custom) + if(BUILD_ASCEND_CUSTOM) + # In-tree `ascendc_library()` trips the `CANN` `extract_host_stub.py` + # path-handling bug under `scikit-build-core`'s temp-dir builds + # (`KeyError` on `/./workspace/...` paths in `$`). + # Work around it by driving the standalone `src/ascend/custom/build.sh` + # — that script invokes a separate `cmake` with + # `src/ascend/custom/` as its `SOURCE_DIR`, avoiding the buggy + # path shape. The produced `.a` is imported and linked into + # `ops` with `--whole-archive`. 
+ set(_custom_build_dir "${CMAKE_SOURCE_DIR}/build/build_ascend_custom") + set(_custom_lib "${_custom_build_dir}/lib/libno_workspace_kernel.a") + + if(NOT DEFINED SOC_VERSION OR "${SOC_VERSION}" STREQUAL "") + include(${CMAKE_CURRENT_SOURCE_DIR}/ascend/custom/cmake/detect_soc.cmake) + infiniops_detect_soc(SOC_VERSION) + endif() + + # Drive `build.sh` as a build-phase target with explicit source + # dependencies so that editing any `op_host/` or `op_kernel/` + # source re-triggers the build (plain `execute_process` at + # configure time would only gate on file existence and leave + # stale `.a` files in place). + file(GLOB_RECURSE _custom_srcs CONFIGURE_DEPENDS + "${CMAKE_CURRENT_SOURCE_DIR}/ascend/custom/*.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/ascend/custom/*.h" + "${CMAKE_CURRENT_SOURCE_DIR}/ascend/custom/build.sh") + + # Scrub env inherited from the outer `scikit-build-core` invocation + # before handing control to `build.sh`: + # * `CMAKE_GENERATOR` / `CMAKE_EXPORT_COMPILE_COMMANDS` leaking + # into the inner `cmake` change the path format passed to + # `ninja`'s `_host_cpp` rule and re-trigger the `CANN` + # `extract_host_stub.py` `KeyError` (`/./workspace/...`) that + # standalone `build.sh` avoids. + # * `PYTHONPATH` from `pip`'s build-isolation overlay makes the + # child `python3` skip the system `site-packages` — child + # `cmake` modules that `import torch` (`config_envs.cmake`) + # then fail with `ModuleNotFoundError` even though `torch` is + # installed. 
+ add_custom_command( + OUTPUT ${_custom_lib} + COMMAND ${CMAKE_COMMAND} -E env + --unset=CMAKE_GENERATOR + --unset=CMAKE_EXPORT_COMPILE_COMMANDS + --unset=CMAKE_BUILD_PARALLEL_LEVEL + --unset=PYTHONPATH + "BUILD_DIR=${_custom_build_dir}" + "CMAKE_EXE=${CMAKE_COMMAND}" + bash ${CMAKE_CURRENT_SOURCE_DIR}/ascend/custom/build.sh ${SOC_VERSION} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ascend/custom + DEPENDS ${_custom_srcs} + COMMENT "Building custom AscendC kernels (SOC_VERSION=${SOC_VERSION})" + VERBATIM) + + add_custom_target(no_workspace_kernel_build ALL DEPENDS ${_custom_lib}) + + add_library(no_workspace_kernel STATIC IMPORTED GLOBAL) + set_target_properties(no_workspace_kernel PROPERTIES + IMPORTED_LOCATION "${_custom_lib}") + add_dependencies(no_workspace_kernel no_workspace_kernel_build) # Link the compiled `AscendC` kernel objects into `infiniops` so that # custom kernel implementations (e.g. `RmsNorm` index 1) can call @@ -379,9 +437,13 @@ if(GENERATE_PYTHON_BINDINGS) # The `Operator<..., 1>` template instantiations that call # `aclrtlaunch_*` live in `ops.cc`, so link here with # `--whole-archive` to ensure all launch functions are available. - if(BUILD_CUSTOM_KERNEL) + # `$` works for both real `ascendc_library()` targets and + # `IMPORTED` targets pointing at a pre-built `.a`. + if(BUILD_ASCEND_CUSTOM) target_link_libraries(ops PRIVATE - -Wl,--whole-archive no_workspace_kernel -Wl,--no-whole-archive) + -Wl,--whole-archive $ -Wl,--no-whole-archive) + # `ops` link step must wait for `build.sh` to produce the `.a`. 
+ add_dependencies(ops no_workspace_kernel_build) endif() set_target_properties(infiniops PROPERTIES INSTALL_RPATH "$ORIGIN") diff --git a/src/ascend/add_rms_norm/kernel.h b/src/ascend/add_rms_norm/kernel.h new file mode 100644 index 000000000..38b0a5ab5 --- /dev/null +++ b/src/ascend/add_rms_norm/kernel.h @@ -0,0 +1,144 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_add.h" +#include "aclnn_rms_norm.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add_rms_norm.h" +#include "operator.h" + +namespace infini::ops { + +// Decomposed implementation: `aclnnAdd` + `aclnnRmsNorm`. +// +// The fused `aclnnAddRmsNorm` API has ~200 us host-side launch overhead that +// dominates small-tensor dispatch. Decomposing into two fast ACLNN calls +// reduces host dispatch from ~224 us to ~56 us (4x faster) with negligible +// NPU-side impact for inference tensor sizes. +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor input, const Tensor residual, const Tensor weight, + float eps, Tensor out, Tensor residual_out) + : AddRmsNorm(input, residual, weight, eps, out, residual_out), + input_cache_(input), + residual_cache_(residual), + weight_cache_(weight), + out_cache_(out), + residual_out_cache_(residual_out) { + // Alpha scalar for `aclnnAdd` (`residual_out = input + 1.0 * residual`). + alpha_ = aclCreateScalar(&alpha_storage_, ACL_FLOAT); + + // `aclnnRmsNorm` writes `rstd` as a required side output. Size is + // computed here; the buffer is obtained from the pool in `operator()`. + rstd_shape_ = {static_cast(batch_size_), + static_cast(nhead_)}; + rstd_size_ = batch_size_ * nhead_ * sizeof(float); + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. 
+ input_cache_.release(); + residual_cache_.release(); + weight_cache_.release(); + out_cache_.release(); + residual_out_cache_.release(); + + // `rstd_tensor_` leaks with `norm_exec_` at shutdown (see `64c367c`). + if (alpha_) aclDestroyScalar(alpha_); + } + + void operator()(const Tensor input, const Tensor residual, + const Tensor weight, float eps, Tensor out, + Tensor residual_out) const override { + auto t_input = input_cache_.get(const_cast(input.data())); + auto t_residual = residual_cache_.get(const_cast(residual.data())); + auto t_weight = weight_cache_.get(const_cast(weight.data())); + auto t_out = out_cache_.get(out.data()); + auto t_residual_out = residual_out_cache_.get(residual_out.data()); + auto stream = static_cast(stream_); + + // Step 1: `residual_out = input + residual`. + if (!add_exec_) { + aclnnAddGetWorkspaceSize(t_input, t_residual, alpha_, t_residual_out, + &add_ws_, &add_exec_); + aclSetAclOpExecutorRepeatable(add_exec_); + } else { + aclSetInputTensorAddr(add_exec_, 0, t_input, + const_cast(input.data())); + aclSetInputTensorAddr(add_exec_, 1, t_residual, + const_cast(residual.data())); + aclSetOutputTensorAddr(add_exec_, 0, t_residual_out, residual_out.data()); + } + auto& add_arena = ascend::GetWorkspacePool().Ensure(stream, add_ws_); + aclnnAdd(add_arena.buf, add_ws_, add_exec_, stream); + + // Obtain shared `rstd` buffer from pool. + auto& rstd_arena = + ascend::GetWorkspacePool().Ensure(stream, rstd_size_, "temp"); + + // Lazily create the `rstd` tensor descriptor on first call. + if (!rstd_tensor_) { + rstd_tensor_ = aclCreateTensor(rstd_shape_.data(), 2, ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, + rstd_shape_.data(), 2, rstd_arena.buf); + } else { + aclSetRawTensorAddr(rstd_tensor_, rstd_arena.buf); + } + + // Step 2: `out = rms_norm(residual_out, weight, eps)`. 
+ if (!norm_exec_) { + aclnnRmsNormGetWorkspaceSize(t_residual_out, t_weight, eps, t_out, + rstd_tensor_, &norm_ws_, &norm_exec_); + aclSetAclOpExecutorRepeatable(norm_exec_); + } else { + aclSetInputTensorAddr(norm_exec_, 0, t_residual_out, residual_out.data()); + aclSetInputTensorAddr(norm_exec_, 1, t_weight, + const_cast(weight.data())); + aclSetOutputTensorAddr(norm_exec_, 0, t_out, out.data()); + aclSetOutputTensorAddr(norm_exec_, 1, rstd_tensor_, rstd_arena.buf); + } + auto& norm_arena = ascend::GetWorkspacePool().Ensure(stream, norm_ws_); + aclnnRmsNorm(norm_arena.buf, norm_ws_, norm_exec_, stream); + } + + private: + mutable ascend::AclTensorCache input_cache_; + + mutable ascend::AclTensorCache residual_cache_; + + mutable ascend::AclTensorCache weight_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache residual_out_cache_; + + float alpha_storage_ = 1.0f; + + aclScalar* alpha_ = nullptr; + + std::vector rstd_shape_; + + uint64_t rstd_size_ = 0; + + mutable aclTensor* rstd_tensor_ = nullptr; + + mutable aclOpExecutor* add_exec_ = nullptr; + + mutable uint64_t add_ws_ = 0; + + mutable aclOpExecutor* norm_exec_ = nullptr; + + mutable uint64_t norm_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/add_rms_norm/kernel_custom.h b/src/ascend/add_rms_norm/kernel_custom.h new file mode 100644 index 000000000..daaa8c394 --- /dev/null +++ b/src/ascend/add_rms_norm/kernel_custom.h @@ -0,0 +1,171 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_ + +#ifdef INFINI_HAS_CUSTOM_KERNELS + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_cast.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add_rms_norm.h" +#include "operator.h" + +// Forward-declare the `aclrtlaunch_AddRmsNorm` launch symbol defined +// by the AscendC toolchain from 
`custom/add_rms_norm/op_kernel/`. +extern "C" uint32_t aclrtlaunch_AddRmsNorm( + uint32_t block_dim, void* stream, void* input, void* residual, void* weight, + int64_t total_rows, int64_t dim_length, int64_t dim_length_align, + int64_t former_num, int64_t former_length, int64_t tail_length, float eps, + int64_t dtype_code, void* out, void* residual_out); + +namespace infini::ops { + +// Custom AscendC fused `AddRmsNorm` kernel (implementation index 2). +// +// A single-kernel implementation that computes `residual_out = input + +// residual` followed by `out = rms_norm(residual_out, weight, eps)` in one +// launch, avoiding the decomposed `aclnnAdd` + `aclnnRmsNorm` calls (index 0) +// or the fused `aclnnAddRmsNorm` call (index 1). Migrated from the custom +// `RmsNorm` kernel (index 1 of `RmsNorm`). +// +// Select via `implementation_index=2` in Python: +// `infini.ops.add_rms_norm(input, residual, weight, eps, out, residual_out, +// implementation_index=2, stream=s)`. +// +// Requirements: +// - Input last dimension must be 32-byte aligned (divisible by 16 for +// `float16` or 8 for `float32`). All standard LLM hidden dimensions +// satisfy this. +// - `weight` must have the same dtype as `input`. +// - The custom kernel binary must be linked (`BUILD_ASCEND_CUSTOM=ON`). +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor input, const Tensor residual, const Tensor weight, + float eps, Tensor out, Tensor residual_out) + : AddRmsNorm(input, residual, weight, eps, out, residual_out), + dtype_{input.dtype()} { + assert((dtype_ == DataType::kFloat16 || dtype_ == DataType::kBFloat16 || + dtype_ == DataType::kFloat32) && + "`AddRmsNorm` custom kernel: `input` must be `fp16`, `bf16`, or " + "`fp32`"); + + // 32-byte alignment on the last dimension — kernel relies on aligned + // `DataCopyPad` loads/stores. 
+ int64_t align_elems = 32 / static_cast(kDataTypeToSize.at(dtype_)); + dim_length_align_ = + ((static_cast(dim_) + align_elems - 1) / align_elems) * + align_elems; + assert(static_cast(dim_) == dim_length_align_ && + "`AddRmsNorm` custom kernel: last dimension must be 32-byte " + "aligned"); + + total_rows_ = + static_cast(batch_size_) * static_cast(nhead_); + + // The custom kernel always reads `weight` as fp32. fp16 / bf16 inputs + // trigger a lazy cast in `operator()` (guarded by `last_weight_ptr_` + // so that the cast runs only when the weight pointer changes — model + // weights are typically fixed after loading). + if (dtype_ != DataType::kFloat32) { + size_t fp32_bytes = static_cast(dim_) * sizeof(float); + aclrtMalloc(&weight_fp32_data_, fp32_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + weight_src_cache_ = ascend::AclTensorCache( + {static_cast(dim_)}, ascend::ToAclDtype(dtype_), nullptr); + weight_dst_cache_ = ascend::AclTensorCache({static_cast(dim_)}, + ACL_FLOAT, weight_fp32_data_); + } + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. + weight_src_cache_.release(); + weight_dst_cache_.release(); + + if (weight_fp32_data_) aclrtFree(weight_fp32_data_); + } + + void operator()(const Tensor input, const Tensor residual, + const Tensor weight, float eps, Tensor out, + Tensor residual_out) const override { + auto stream = static_cast(stream_); + + void* weight_fp32; + + if (dtype_ != DataType::kFloat32) { + const void* cur_weight = weight.data(); + + // Model weights are fixed after loading, so the cast typically runs + // once on the first call and is skipped on all subsequent calls. 
+ if (cur_weight != last_weight_ptr_) { + auto t_src = weight_src_cache_.get(const_cast(cur_weight)); + auto t_dst = weight_dst_cache_.get(weight_fp32_data_); + + if (!cast_exec_) { + aclnnCastGetWorkspaceSize(t_src, ACL_FLOAT, t_dst, &cast_ws_, + &cast_exec_); + aclSetAclOpExecutorRepeatable(cast_exec_); + } else { + aclSetInputTensorAddr(cast_exec_, 0, t_src, + const_cast(cur_weight)); + aclSetOutputTensorAddr(cast_exec_, 0, t_dst, weight_fp32_data_); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, cast_ws_); + aclnnCast(arena.buf, cast_ws_, cast_exec_, stream); + last_weight_ptr_ = cur_weight; + } + + weight_fp32 = weight_fp32_data_; + } else { + weight_fp32 = const_cast(weight.data()); + } + + // Block-level tiling. Ascend 910B has 20–40 AIV cores; over-subscribing + // is safe (runtime multiplexes) but wastes one weight load per block. + static constexpr int64_t kMaxBlockDim = 40; + int64_t used_cores = std::min(total_rows_, kMaxBlockDim); + int64_t former_length = (total_rows_ + used_cores - 1) / used_cores; + int64_t tail_length = former_length - 1; + int64_t former_num = total_rows_ - tail_length * used_cores; + uint32_t block_dim = static_cast(used_cores); + + aclrtlaunch_AddRmsNorm(block_dim, stream, const_cast(input.data()), + const_cast(residual.data()), weight_fp32, + total_rows_, static_cast(dim_), + dim_length_align_, former_num, former_length, + tail_length, eps, static_cast(dtype_), + out.data(), residual_out.data()); + } + + private: + DataType dtype_; + + int64_t dim_length_align_; + + int64_t total_rows_; + + void* weight_fp32_data_ = nullptr; + + mutable ascend::AclTensorCache weight_src_cache_; + + mutable ascend::AclTensorCache weight_dst_cache_; + + mutable const void* last_weight_ptr_ = nullptr; + + mutable aclOpExecutor* cast_exec_ = nullptr; + + mutable uint64_t cast_ws_ = 0; +}; + +} // namespace infini::ops + +#endif // INFINI_HAS_CUSTOM_KERNELS +#endif // INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_CUSTOM_H_ diff --git 
a/src/ascend/add_rms_norm/kernel_fused.h b/src/ascend/add_rms_norm/kernel_fused.h new file mode 100644 index 000000000..e28d7c287 --- /dev/null +++ b/src/ascend/add_rms_norm/kernel_fused.h @@ -0,0 +1,132 @@ +#ifndef INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_FUSED_H_ +#define INFINI_OPS_ASCEND_ADD_RMS_NORM_KERNEL_FUSED_H_ + +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnnop/aclnn_add_rms_norm.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/add_rms_norm.h" +#include "operator.h" + +namespace infini::ops { + +// Fused implementation via `aclnnAddRmsNorm` (implementation index 1). +// +// Computes `residual_out = input + residual` and `out = rms_norm(residual_out, +// weight, eps)` in a single CANN launch. The fused API has higher host-side +// launch overhead (~200 us) compared to the decomposed `aclnnAdd` + +// `aclnnRmsNorm` path (~39 us), but may offer better NPU-side efficiency for +// large tensors where kernel fusion reduces memory traffic. +// +// Select via `implementation_index=1` in Python: +// `infini.ops.add_rms_norm(..., implementation_index=1, stream=s)`. +template <> +class Operator : public AddRmsNorm { + public: + Operator(const Tensor input, const Tensor residual, const Tensor weight, + float eps, Tensor out, Tensor residual_out) + : AddRmsNorm(input, residual, weight, eps, out, residual_out), + input_cache_(input), + residual_cache_(residual), + weight_cache_(weight), + out_cache_(out), + residual_out_cache_(residual_out) { + // `aclnnAddRmsNorm` requires `rstdOut` to have the same ndim as `input`, + // with the last `weight.ndim()` dimensions set to 1. For example: + // `input` (2, 32, 128), `weight` (128) -> `rstdOut` (2, 32, 1). + // `input` (64, 128), `weight` (128) -> `rstdOut` (64, 1). 
+ fused_rstd_shape_.reserve(ndim_); + for (size_t i = 0; i < ndim_ - weight.ndim(); ++i) { + fused_rstd_shape_.push_back(static_cast(input.size(i))); + } + for (size_t i = 0; i < weight.ndim(); ++i) { + fused_rstd_shape_.push_back(1); + } + + size_t rstd_elems = 1; + for (auto d : fused_rstd_shape_) { + rstd_elems *= static_cast(d); + } + size_t rstd_bytes = rstd_elems * sizeof(float); + aclrtMalloc(&rstd_data_, rstd_bytes, ACL_MEM_MALLOC_NORMAL_ONLY); + + rstd_tensor_ = aclCreateTensor( + fused_rstd_shape_.data(), + static_cast(fused_rstd_shape_.size()), ACL_FLOAT, + /*strides=*/nullptr, 0, ACL_FORMAT_ND, fused_rstd_shape_.data(), + static_cast(fused_rstd_shape_.size()), rstd_data_); + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. + input_cache_.release(); + residual_cache_.release(); + weight_cache_.release(); + out_cache_.release(); + residual_out_cache_.release(); + + // `rstd_tensor_` leaks with the executor at shutdown (see `64c367c`). 
+ if (rstd_data_) aclrtFree(rstd_data_); + } + + void operator()(const Tensor input, const Tensor residual, + const Tensor weight, float eps, Tensor out, + Tensor residual_out) const override { + auto t_input = input_cache_.get(const_cast(input.data())); + auto t_residual = residual_cache_.get(const_cast(residual.data())); + auto t_weight = weight_cache_.get(const_cast(weight.data())); + auto t_out = out_cache_.get(out.data()); + auto t_residual_out = residual_out_cache_.get(residual_out.data()); + auto stream = static_cast(stream_); + + if (!executor_) { + aclnnAddRmsNormGetWorkspaceSize( + t_input, t_residual, t_weight, static_cast(eps), t_out, + rstd_tensor_, t_residual_out, &ws_size_, &executor_); + aclSetAclOpExecutorRepeatable(executor_); + } else { + aclSetInputTensorAddr(executor_, 0, t_input, + const_cast(input.data())); + aclSetInputTensorAddr(executor_, 1, t_residual, + const_cast(residual.data())); + aclSetInputTensorAddr(executor_, 2, t_weight, + const_cast(weight.data())); + aclSetOutputTensorAddr(executor_, 0, t_out, out.data()); + // `rstd` at output index 1 has a stable address — no update needed. 
+ aclSetOutputTensorAddr(executor_, 2, t_residual_out, residual_out.data()); + } + + auto& arena = ascend::GetWorkspacePool().Ensure(stream, ws_size_); + aclnnAddRmsNorm(arena.buf, ws_size_, executor_, stream); + } + + private: + mutable ascend::AclTensorCache input_cache_; + + mutable ascend::AclTensorCache residual_cache_; + + mutable ascend::AclTensorCache weight_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache residual_out_cache_; + + std::vector fused_rstd_shape_; + + void* rstd_data_ = nullptr; + + aclTensor* rstd_tensor_ = nullptr; + + mutable aclOpExecutor* executor_ = nullptr; + + mutable uint64_t ws_size_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/causal_softmax/kernel.h b/src/ascend/causal_softmax/kernel.h new file mode 100644 index 000000000..975a03463 --- /dev/null +++ b/src/ascend/causal_softmax/kernel.h @@ -0,0 +1,173 @@ +#ifndef INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ +#define INFINI_OPS_ASCEND_CAUSAL_SOFTMAX_KERNEL_H_ + +#include +#include + +#include "acl/acl.h" +#include "aclnn/aclnn_base.h" +#include "aclnn_copy.h" +#include "aclnn_masked_fill_scalar.h" +#include "aclnn_softmax.h" +#include "ascend/common.h" +#include "ascend/workspace_pool_.h" +#include "base/causal_softmax.h" +#include "data_type.h" +#include "operator.h" + +namespace infini::ops { + +// CANN 8.5 has no single API covering causal-mask-then-softmax: the nearest +// candidates (`aclnnSoftmaxV2`, `aclnnScaledSoftmaxGrad`) do not accept a +// boolean mask argument, and `aclnnScaledMaskedSoftmax` requires a +// pre-scaled attention-score tensor produced inside flash-attention, not a +// standalone softmax input. Decomposing into three ACLNN calls is therefore +// unavoidable until an `aclnnCausalSoftmax` ships: +// 1. `aclnnInplaceCopy(temp, input)` — stride-aware copy to a contiguous +// `temp` buffer. +// 2. `aclnnInplaceMaskedFillScalar(temp, mask, -inf)` — apply the +// upper-triangle mask. +// 3. 
`aclnnSoftmax(temp, dim=-1, out)` — softmax over the last dimension. +// +// The boolean causal mask is pre-computed and uploaded to device once in the +// constructor. Its shape `(seq_len, total_seq_len)` broadcasts over the +// batch dimension. +template <> +class Operator : public CausalSoftmax { + public: + Operator(const Tensor input, Tensor out) + : CausalSoftmax(input, out), in_cache_(input), out_cache_(out) { + // Compute `temp` buffer size — allocated lazily from the pool in + // `operator()`. + size_t n_elems = input.numel(); + size_t elem_bytes = kDataTypeToSize.at(dtype_); + temp_size_ = n_elems * elem_bytes; + + // Build a contiguous `Tensor` descriptor — data pointer set on first use. + Tensor temp_t{nullptr, input.shape(), input.dtype(), input.device()}; + temp_cache_ = ascend::AclTensorCache(temp_t); + + // Causal mask: `mask[i][j] = 1` when position `j` must be masked for + // query `i`. Shape `(seq_len, total_seq_len)` broadcasts over the batch + // dimension. + size_t mask_elems = seq_len_ * total_seq_len_; + std::vector mask_host(mask_elems, 0); + + for (size_t i = 0; i < seq_len_; ++i) { + auto vis_end = static_cast(total_seq_len_ - seq_len_ + i); + + for (auto j = vis_end + 1; j < static_cast(total_seq_len_); + ++j) { + mask_host[i * total_seq_len_ + j] = 1; + } + } + + aclrtMalloc(&mask_buf_, mask_elems, ACL_MEM_MALLOC_NORMAL_ONLY); + aclrtMemcpy(mask_buf_, mask_elems, mask_host.data(), mask_elems, + ACL_MEMCPY_HOST_TO_DEVICE); + + std::vector mshape = {static_cast(seq_len_), + static_cast(total_seq_len_)}; + std::vector mstrides = {static_cast(total_seq_len_), 1}; + mask_tensor_ = aclCreateTensor(mshape.data(), mshape.size(), ACL_BOOL, + mstrides.data(), 0, ACL_FORMAT_ND, + mshape.data(), mshape.size(), mask_buf_); + + // Scalar `-inf` for the masked-fill step. `aclCreateScalar` stores the + // pointer rather than copying, so `neg_inf_storage_` must stay alive + // with the object. 
+ neg_inf_ = aclCreateScalar(&neg_inf_storage_, ACL_FLOAT); + // Workspaces are allocated lazily on the first `operator()` call. + } + + ~Operator() { + if (!ascend::IsAclRuntimeAlive()) return; + + // Null cached descriptors — see `AclTensorCache::release()`. + in_cache_.release(); + out_cache_.release(); + temp_cache_.release(); + + // `mask_tensor_` leaks with `fill_exec_` at shutdown (see `64c367c`). + if (mask_buf_) aclrtFree(mask_buf_); + if (neg_inf_) aclDestroyScalar(neg_inf_); + } + + void operator()(const Tensor input, Tensor out) const override { + auto t_in = in_cache_.get(const_cast(input.data())); + auto t_out = out_cache_.get(out.data()); + auto stream = static_cast(stream_); + + // Obtain shared `temp` buffer from the pool. + auto& temp = ascend::GetWorkspacePool().Ensure(stream, temp_size_, "temp"); + auto t_temp = temp_cache_.get(temp.buf); + + // Step 1: copy `input` (possibly non-contiguous) into a contiguous `temp`. + if (!copy_exec_) { + aclnnInplaceCopyGetWorkspaceSize(t_temp, t_in, &copy_ws_, &copy_exec_); + aclSetAclOpExecutorRepeatable(copy_exec_); + } else { + aclSetInputTensorAddr(copy_exec_, 0, t_temp, temp.buf); + aclSetInputTensorAddr(copy_exec_, 1, t_in, + const_cast(input.data())); + } + auto& copy_arena = ascend::GetWorkspacePool().Ensure(stream, copy_ws_); + aclnnInplaceCopy(copy_arena.buf, copy_ws_, copy_exec_, stream); + + // Step 2: mask upper-triangle positions with `-inf` in-place. + // `mask_tensor_` and `neg_inf_` have stable addresses — first-call only. + if (!fill_exec_) { + aclnnInplaceMaskedFillScalarGetWorkspaceSize( + t_temp, mask_tensor_, neg_inf_, &fill_ws_, &fill_exec_); + aclSetAclOpExecutorRepeatable(fill_exec_); + } + auto& fill_arena = ascend::GetWorkspacePool().Ensure(stream, fill_ws_); + aclnnInplaceMaskedFillScalar(fill_arena.buf, fill_ws_, fill_exec_, stream); + + // Step 3: softmax over the last dimension -> `out`. 
+ if (!softmax_exec_) { + constexpr int64_t kLastDim = -1; + aclnnSoftmaxGetWorkspaceSize(t_temp, kLastDim, t_out, &softmax_ws_, + &softmax_exec_); + aclSetAclOpExecutorRepeatable(softmax_exec_); + } else { + aclSetOutputTensorAddr(softmax_exec_, 0, t_out, out.data()); + } + auto& softmax_arena = + ascend::GetWorkspacePool().Ensure(stream, softmax_ws_); + aclnnSoftmax(softmax_arena.buf, softmax_ws_, softmax_exec_, stream); + } + + private: + mutable ascend::AclTensorCache in_cache_; + + mutable ascend::AclTensorCache out_cache_; + + mutable ascend::AclTensorCache temp_cache_; + + float neg_inf_storage_ = -std::numeric_limits::infinity(); + + uint64_t temp_size_ = 0; + + void* mask_buf_ = nullptr; + + aclTensor* mask_tensor_ = nullptr; + + aclScalar* neg_inf_ = nullptr; + + mutable aclOpExecutor* copy_exec_ = nullptr; + + mutable uint64_t copy_ws_ = 0; + + mutable aclOpExecutor* fill_exec_ = nullptr; + + mutable uint64_t fill_ws_ = 0; + + mutable aclOpExecutor* softmax_exec_ = nullptr; + + mutable uint64_t softmax_ws_ = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/ascend/custom/CMakeLists.txt b/src/ascend/custom/CMakeLists.txt index ca6e6883f..fb9004199 100644 --- a/src/ascend/custom/CMakeLists.txt +++ b/src/ascend/custom/CMakeLists.txt @@ -30,8 +30,6 @@ else() endif() set(PROJECT_OP_SRC_BASE ${PROJECT_SOURCE_DIR}) -set(PROJECT_BUILD_PATH ${PROJECT_SOURCE_DIR}/build) -set(PROJECT_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/output) include(cmake/config_envs.cmake) include(cmake/config_ascend.cmake) @@ -43,13 +41,15 @@ if(CCACHE_PROGRAM) set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") endif() -# Shared library output location. -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_OUTPUT_PATH}) +# `CMAKE_LIBRARY_OUTPUT_DIRECTORY` is set by `build.sh` so that the +# standalone `libascend_kernel.so` lands next to `libno_workspace_kernel.a` +# under `/build/build_ascend_custom/output/`. # Host-side files. 
file(GLOB OP_SRCS ${PROJECT_OP_SRC_BASE}/torch_binding.cpp ${PROJECT_OP_SRC_BASE}/rms_norm/op_host/rms_norm.cpp + ${PROJECT_OP_SRC_BASE}/add_rms_norm/op_host/add_rms_norm.cpp ) # Shared library name — consumed by `kernel_custom.h` variants and by the @@ -59,8 +59,18 @@ set(OP_PLUGIN_NAME ascend_kernel) # Kernel-side files (device code compiled by the `AscendC` toolchain). ascendc_library(no_workspace_kernel STATIC ${PROJECT_OP_SRC_BASE}/rms_norm/op_kernel/rms_norm.cpp + ${PROJECT_OP_SRC_BASE}/add_rms_norm/op_kernel/add_rms_norm.cpp ) +# The kernel translation units include `"data_type_enum.h"` from the main +# project's `src/` so that launcher and device code share one `DataType` +# enum. `ascendc_library` forwards the interface target's `INCLUDES` +# property to the nested `ExternalProject_Add` (see +# `${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake/legacy_modules/function.cmake`), +# so append the main `src/` dir here. +set_property(TARGET no_workspace_kernel_interface APPEND PROPERTY + INCLUDES ${PROJECT_OP_SRC_BASE}/../..) + # Create the shared library `libascend_kernel.so`. 
add_library(${OP_PLUGIN_NAME} SHARED ${OP_SRCS}) diff --git a/src/ascend/custom/add_rms_norm/op_host/add_rms_norm.cpp b/src/ascend/custom/add_rms_norm/op_host/add_rms_norm.cpp index b8e0d504b..b561eaaa7 100644 --- a/src/ascend/custom/add_rms_norm/op_host/add_rms_norm.cpp +++ b/src/ascend/custom/add_rms_norm/op_host/add_rms_norm.cpp @@ -1,4 +1,4 @@ -#include "aclrtlaunch_add_rms_norm.h" +#include "aclrtlaunch_AddRmsNorm.h" #include "tiling/platform/platform_ascendc.h" #include "torch_kernel_helper.h" @@ -105,16 +105,13 @@ std::vector AddRmsNorm(const at::Tensor& x1, const at::Tensor& x2, float eps_float = static_cast(eps); int64_t dtype_size_val = dtype_size; - // The first arg `add_rms_norm` is the AscendC kernel entry-point name — it - // must match `ascendc_add_operator(OP_NAME add_rms_norm)` in `CMakeLists.txt`, - // the `__global__ __aicore__ void add_rms_norm(...)` definition in - // `op_kernel/`, and the generated `aclrtlaunch_add_rms_norm.h` header. - // Google C++ Style's PascalCase rule does NOT apply: this identifier is - // dictated by the AscendC toolchain's symbol convention. - EXEC_KERNEL_CMD(add_rms_norm, block_dim, kernel_input1, kernel_input2, - weight_float, kernel_output_y, kernel_output_x_out, - total_rows, dim_length, dim_length_align, former_num, - former_length, tail_length, eps_float, dtype_size_val); + // The first arg `AddRmsNorm` is the AscendC kernel entry-point name — it + // must match the `__global__ __aicore__ void AddRmsNorm(...)` definition + // in `op_kernel/` and the generated `aclrtlaunch_AddRmsNorm.h` header. + EXEC_KERNEL_CMD(AddRmsNorm, block_dim, kernel_input1, kernel_input2, + weight_float, total_rows, dim_length, dim_length_align, + former_num, former_length, tail_length, eps_float, + dtype_size_val, kernel_output_y, kernel_output_x_out); // Remove padding and reshape back to original shape. 
   at::Tensor output_y = kernel_output_y;
diff --git a/src/ascend/custom/add_rms_norm/op_kernel/.clang-tidy b/src/ascend/custom/add_rms_norm/op_kernel/.clang-tidy
new file mode 100644
index 000000000..ccf13972c
--- /dev/null
+++ b/src/ascend/custom/add_rms_norm/op_kernel/.clang-tidy
@@ -0,0 +1,9 @@
+---
+# `op_kernel/*.cpp` is `AscendC` device code compiled by `ccec`, not by
+# the host toolchain, so it has no entry in `compile_commands.json` and
+# `clang-tidy` cannot parse it correctly (the `__aicore__` macro expands
+# unexpectedly when `kernel_operator.h` is absent). Disable all checks
+# here — the `op_host/` side and the `kernel_custom.h` launcher still
+# enforce the full ruleset.
+
+Checks: '-*'
diff --git a/src/ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp b/src/ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp
index e2a08e555..4b677d357 100644
--- a/src/ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp
+++ b/src/ascend/custom/add_rms_norm/op_kernel/add_rms_norm.cpp
@@ -1,98 +1,102 @@
+#include "data_type.h"
 #include "kernel_operator.h"
-constexpr int32_t BUFFER_NUM = 2;
+constexpr int32_t kBufferNum = 2;
 template class KernelAddRmsNorm {
  public:
   __aicore__ inline KernelAddRmsNorm() {}
-  __aicore__ inline void Init(GM_ADDR x1, GM_ADDR x2, GM_ADDR weight, GM_ADDR y,
-                              GM_ADDR x_out, int64_t totalRows,
-                              int64_t dimLength, int64_t dimLengthAlign,
-                              int64_t formerNum, int64_t formerLength,
-                              int64_t tailLength, float eps) {
-    this->dimLength = dimLength;
-    this->dimLengthAlign = dimLengthAlign;
-    this->eps = eps;
+  __aicore__ inline void Init(GM_ADDR input, GM_ADDR residual, GM_ADDR weight,
+                              int64_t total_rows, int64_t dim_length,
+                              int64_t dim_length_align, int64_t former_num,
+                              int64_t former_length, int64_t tail_length,
+                              float eps, GM_ADDR out, GM_ADDR residual_out) {
+    dim_length_ = dim_length;
+    dim_length_align_ = dim_length_align;
+    eps_ = eps;
     // Block-level tiling: determine row range for this core.
-    int64_t blockIdx = AscendC::GetBlockIdx();
-    int64_t rowOffset;
+    int64_t block_idx = AscendC::GetBlockIdx();
+    int64_t row_offset;
-    if (blockIdx < formerNum) {
-      this->blockRows = formerLength;
-      rowOffset = formerLength * blockIdx;
+    if (block_idx < former_num) {
+      block_rows_ = former_length;
+      row_offset = former_length * block_idx;
     } else {
-      this->blockRows = tailLength;
-      int64_t tailIdx = blockIdx - formerNum;
-      rowOffset = formerLength * formerNum + tailLength * tailIdx;
+      block_rows_ = tail_length;
+      int64_t tail_idx = block_idx - former_num;
+      row_offset = former_length * former_num + tail_length * tail_idx;
     }
     // Global memory pointers.
-    x1Gm.SetGlobalBuffer((__gm__ T*)x1 + rowOffset * dimLengthAlign,
-                         this->blockRows * dimLengthAlign);
-    x2Gm.SetGlobalBuffer((__gm__ T*)x2 + rowOffset * dimLengthAlign,
-                         this->blockRows * dimLengthAlign);
-    yGm.SetGlobalBuffer((__gm__ T*)y + rowOffset * dimLengthAlign,
-                        this->blockRows * dimLengthAlign);
-    xOutGm.SetGlobalBuffer((__gm__ T*)x_out + rowOffset * dimLengthAlign,
-                           this->blockRows * dimLengthAlign);
-    weightGm.SetGlobalBuffer((__gm__ float*)weight, dimLengthAlign);
-
-    int32_t dimLenAlign = static_cast(this->dimLengthAlign);
+    input_gm_.SetGlobalBuffer((__gm__ T*)input + row_offset * dim_length_align,
+                              block_rows_ * dim_length_align);
+    residual_gm_.SetGlobalBuffer(
+        (__gm__ T*)residual + row_offset * dim_length_align,
+        block_rows_ * dim_length_align);
+    out_gm_.SetGlobalBuffer((__gm__ T*)out + row_offset * dim_length_align,
+                            block_rows_ * dim_length_align);
+    residual_out_gm_.SetGlobalBuffer(
+        (__gm__ T*)residual_out + row_offset * dim_length_align,
+        block_rows_ * dim_length_align);
+    weight_gm_.SetGlobalBuffer((__gm__ float*)weight, dim_length_align);
+
+    int32_t dim_len_align = static_cast(dim_length_align_);
     // I/O queues (double-buffered).
-    pipe.InitBuffer(inQueueX1, BUFFER_NUM,
-                    dimLenAlign * static_cast(sizeof(T)));
-    pipe.InitBuffer(inQueueX2, BUFFER_NUM,
-                    dimLenAlign * static_cast(sizeof(T)));
-    pipe.InitBuffer(outQueueY, BUFFER_NUM,
-                    dimLenAlign * static_cast(sizeof(T)));
-    pipe.InitBuffer(outQueueXOut, BUFFER_NUM,
-                    dimLenAlign * static_cast(sizeof(T)));
+    pipe_.InitBuffer(in_queue_input_, kBufferNum,
+                     dim_len_align * static_cast(sizeof(T)));
+    pipe_.InitBuffer(in_queue_residual_, kBufferNum,
+                     dim_len_align * static_cast(sizeof(T)));
+    pipe_.InitBuffer(out_queue_out_, kBufferNum,
+                     dim_len_align * static_cast(sizeof(T)));
+    pipe_.InitBuffer(out_queue_residual_out_, kBufferNum,
+                     dim_len_align * static_cast(sizeof(T)));
     // Weight buffer (fp32, loaded once, reused for all rows).
-    pipe.InitBuffer(weightBuf,
-                    dimLenAlign * static_cast(sizeof(float)));
+    pipe_.InitBuffer(weight_buf_,
+                     dim_len_align * static_cast(sizeof(float)));
-    // FP16 path needs extra fp32 compute buffers.
-    // buf1: holds x_out in fp32 (reused from x1_fp32 after Add).
-    // buf2: holds x2_fp32 initially, then x_out^2, then final result.
+    // FP16/BF16 path needs extra fp32 compute buffers.
+    // `fp32_buf1_`: holds `x_out` in fp32 (reused from `x1_fp32` after Add).
+    // `fp32_buf2_`: holds `x2_fp32` initially, then `x_out^2`, then final
+    // result.
     if constexpr (sizeof(T) == 2) {
-      pipe.InitBuffer(fp32Buf1,
-                      dimLenAlign * static_cast(sizeof(float)));
-      pipe.InitBuffer(fp32Buf2,
-                      dimLenAlign * static_cast(sizeof(float)));
+      pipe_.InitBuffer(fp32_buf1_,
+                       dim_len_align * static_cast(sizeof(float)));
+      pipe_.InitBuffer(fp32_buf2_,
+                       dim_len_align * static_cast(sizeof(float)));
     }
-    // ReduceSum temporary buffer (size per API formula).
-    constexpr int32_t ELEMS_PER_REPEAT = 256 / sizeof(float);
-    constexpr int32_t ELEMS_PER_BLOCK = 32 / sizeof(float);
-    int32_t firstMaxRepeat =
-        (dimLenAlign + ELEMS_PER_REPEAT - 1) / ELEMS_PER_REPEAT;
-    int32_t reduceTmpSize =
-        ((firstMaxRepeat + ELEMS_PER_BLOCK - 1) / ELEMS_PER_BLOCK) *
-        ELEMS_PER_BLOCK;
-    pipe.InitBuffer(reduceTmpBuf,
-                    reduceTmpSize * static_cast(sizeof(float)));
+    // `ReduceSum` temporary buffer (size per API formula).
+    constexpr int32_t kElemsPerRepeat = 256 / sizeof(float);
+    constexpr int32_t kElemsPerBlock = 32 / sizeof(float);
+    int32_t first_max_repeat =
+        (dim_len_align + kElemsPerRepeat - 1) / kElemsPerRepeat;
+    int32_t reduce_tmp_size =
+        ((first_max_repeat + kElemsPerBlock - 1) / kElemsPerBlock) *
+        kElemsPerBlock;
+    pipe_.InitBuffer(reduce_tmp_buf_,
+                     reduce_tmp_size * static_cast(sizeof(float)));
     // Scalar buffer for reduction result (8 floats = 32 bytes).
-    pipe.InitBuffer(sumBuf, 32);
+    pipe_.InitBuffer(sum_buf_, 32);
-    // Load weight (fp32) from GM into `weightBuf`.
-    AscendC::LocalTensor wLocal = weightBuf.Get();
-    AscendC::DataCopyExtParams wParams{
-        1, static_cast(dimLenAlign * sizeof(float)), 0, 0, 0};
-    AscendC::DataCopyPadExtParams wPad{false, 0, 0, 0.0f};
-    AscendC::DataCopyPad(wLocal, weightGm, wParams, wPad);
+    // Load weight (fp32) from GM into `weight_buf_`.
+    AscendC::LocalTensor w_local = weight_buf_.Get();
+    AscendC::DataCopyExtParams w_params{
+        1, static_cast(dim_len_align * sizeof(float)), 0, 0, 0};
+    AscendC::DataCopyPadExtParams w_pad{false, 0, 0, 0.0f};
+    AscendC::DataCopyPad(w_local, weight_gm_, w_params, w_pad);
     // Ensure weight DMA completes before compute.
     AscendC::PipeBarrier();
   }
   __aicore__ inline void Process() {
-    for (int64_t row = 0; row < this->blockRows; ++row) {
+    for (int64_t row = 0; row < block_rows_; ++row) {
       CopyIn(row);
       Compute(row);
       CopyOut(row);
@@ -101,149 +105,175 @@ class KernelAddRmsNorm {
  private:
   __aicore__ inline void CopyIn(int64_t row) {
-    AscendC::LocalTensor x1Local = inQueueX1.AllocTensor();
-    AscendC::LocalTensor x2Local = inQueueX2.AllocTensor();
+    AscendC::LocalTensor input_local = in_queue_input_.AllocTensor();
+    AscendC::LocalTensor residual_local =
+        in_queue_residual_.AllocTensor();
     AscendC::DataCopyExtParams params{
-        1, static_cast(this->dimLengthAlign * sizeof(T)), 0, 0, 0};
+        1, static_cast(dim_length_align_ * sizeof(T)), 0, 0, 0};
     AscendC::DataCopyPadExtParams pad{false, 0, 0, static_cast(0)};
-    AscendC::DataCopyPad(x1Local, x1Gm[row * this->dimLengthAlign], params,
-                         pad);
-    AscendC::DataCopyPad(x2Local, x2Gm[row * this->dimLengthAlign], params,
-                         pad);
-    inQueueX1.EnQue(x1Local);
-    inQueueX2.EnQue(x2Local);
+    AscendC::DataCopyPad(input_local, input_gm_[row * dim_length_align_],
+                         params, pad);
+    AscendC::DataCopyPad(residual_local, residual_gm_[row * dim_length_align_],
+                         params, pad);
+    in_queue_input_.EnQue(input_local);
+    in_queue_residual_.EnQue(residual_local);
   }
   __aicore__ inline void Compute(int64_t row) {
-    AscendC::LocalTensor x1Local = inQueueX1.DeQue();
-    AscendC::LocalTensor x2Local = inQueueX2.DeQue();
-    AscendC::LocalTensor yLocal = outQueueY.AllocTensor();
-    AscendC::LocalTensor xOutLocal = outQueueXOut.AllocTensor();
+    AscendC::LocalTensor input_local = in_queue_input_.DeQue();
+    AscendC::LocalTensor residual_local = in_queue_residual_.DeQue();
+    AscendC::LocalTensor out_local = out_queue_out_.AllocTensor();
+    AscendC::LocalTensor residual_out_local =
+        out_queue_residual_out_.AllocTensor();
-    AscendC::LocalTensor wLocal = weightBuf.Get();
-    AscendC::LocalTensor rTmp = reduceTmpBuf.Get();
-    AscendC::LocalTensor sLocal = sumBuf.Get();
+    AscendC::LocalTensor w_local = weight_buf_.Get();
+    AscendC::LocalTensor r_tmp = reduce_tmp_buf_.Get();
+    AscendC::LocalTensor s_local = sum_buf_.Get();
-    int32_t dimLen = static_cast(this->dimLength);
-    int32_t dimLenAlign = static_cast(this->dimLengthAlign);
+    int32_t dim_len = static_cast(dim_length_);
+    int32_t dim_len_align = static_cast(dim_length_align_);
     if constexpr (sizeof(T) == 4) {
       // ---- FP32 path: compute directly. ----
       // Step 1: x_out = x1 + x2.
-      AscendC::Add(xOutLocal, x1Local, x2Local, dimLenAlign);
+      AscendC::Add(residual_out_local, input_local, residual_local,
+                   dim_len_align);
-      // Step 2: x_out^2 into yLocal (reuse output buffer temporarily).
-      AscendC::Mul(yLocal, xOutLocal, xOutLocal, dimLenAlign);
+      // Step 2: x_out^2 into out_local (reuse output buffer temporarily).
+      AscendC::Mul(out_local, residual_out_local, residual_out_local,
+                   dim_len_align);
-      // Step 3: ReduceSum(x_out^2) -> sLocal[0].
-      // ReduceSum may modify yLocal, but we overwrite it below.
-      AscendC::ReduceSum(sLocal, yLocal, rTmp, dimLenAlign);
+      // Step 3: ReduceSum(x_out^2) -> s_local[0].
+      // `ReduceSum` may modify `out_local`, but we overwrite it below.
+      AscendC::ReduceSum(s_local, out_local, r_tmp, dim_len_align);
       // Step 4-5: scale = 1 / sqrt(mean(x_out^2) + eps).
-      float sumVal = sLocal.GetValue(0);
-      float meanVal = sumVal / static_cast(dimLen) + this->eps;
-      sLocal.SetValue(0, meanVal);
-      AscendC::Sqrt(sLocal, sLocal, 8);
-      float scale = 1.0f / sLocal.GetValue(0);
+      float sum_val = s_local.GetValue(0);
+      float mean_val = sum_val / static_cast(dim_len) + eps_;
+      s_local.SetValue(0, mean_val);
+      AscendC::Sqrt(s_local, s_local, 8);
+      float scale = 1.0f / s_local.GetValue(0);
       // Step 6: y = x_out * scale.
-      AscendC::Muls(yLocal, xOutLocal, scale, dimLenAlign);
+      AscendC::Muls(out_local, residual_out_local, scale, dim_len_align);
       // Step 7: y = y * weight.
-      AscendC::Mul(yLocal, yLocal, wLocal, dimLenAlign);
+      AscendC::Mul(out_local, out_local, w_local, dim_len_align);
     } else {
-      // ---- FP16 path: cast → fp32 compute → cast back. ----
-      AscendC::LocalTensor b1 = fp32Buf1.Get();
-      AscendC::LocalTensor b2 = fp32Buf2.Get();
+      // ---- FP16/BF16 path: cast → fp32 compute → cast back. ----
+      AscendC::LocalTensor b1 = fp32_buf1_.Get();
+      AscendC::LocalTensor b2 = fp32_buf2_.Get();
-      // Cast inputs fp16 → fp32.
-      AscendC::Cast(b1, x1Local, AscendC::RoundMode::CAST_NONE, dimLenAlign);
-      AscendC::Cast(b2, x2Local, AscendC::RoundMode::CAST_NONE, dimLenAlign);
+      // Cast inputs fp16/bf16 → fp32.
+      AscendC::Cast(b1, input_local, AscendC::RoundMode::CAST_NONE,
+                    dim_len_align);
+      AscendC::Cast(b2, residual_local, AscendC::RoundMode::CAST_NONE,
+                    dim_len_align);
       // Step 1: x_out = x1 + x2 (fp32), stored in b1.
-      AscendC::Add(b1, b1, b2, dimLenAlign);
+      AscendC::Add(b1, b1, b2, dim_len_align);
-      // Cast x_out fp32 → fp16 for the x_out output.
-      AscendC::Cast(xOutLocal, b1, AscendC::RoundMode::CAST_ROUND, dimLenAlign);
+      // Cast `x_out` fp32 → fp16/bf16 for the residual output.
+      AscendC::Cast(residual_out_local, b1, AscendC::RoundMode::CAST_RINT,
+                    dim_len_align);
       // Step 2: x_out^2 in fp32, stored in b2.
-      AscendC::Mul(b2, b1, b1, dimLenAlign);
+      AscendC::Mul(b2, b1, b1, dim_len_align);
-      // Step 3: ReduceSum(x_out^2) -> sLocal[0].
-      AscendC::ReduceSum(sLocal, b2, rTmp, dimLenAlign);
+      // Step 3: ReduceSum(x_out^2) -> s_local[0].
+      AscendC::ReduceSum(s_local, b2, r_tmp, dim_len_align);
       // Step 4-5: scale = 1 / sqrt(mean(x_out^2) + eps).
-      float sumVal = sLocal.GetValue(0);
-      float meanVal = sumVal / static_cast(dimLen) + this->eps;
-      sLocal.SetValue(0, meanVal);
-      AscendC::Sqrt(sLocal, sLocal, 8);
-      float scale = 1.0f / sLocal.GetValue(0);
+      float sum_val = s_local.GetValue(0);
+      float mean_val = sum_val / static_cast(dim_len) + eps_;
+      s_local.SetValue(0, mean_val);
+      AscendC::Sqrt(s_local, s_local, 8);
+      float scale = 1.0f / s_local.GetValue(0);
       // Step 6: y = x_out * scale (fp32), reuse b2.
-      AscendC::Muls(b2, b1, scale, dimLenAlign);
+      AscendC::Muls(b2, b1, scale, dim_len_align);
       // Step 7: y = y * weight (fp32).
-      AscendC::Mul(b2, b2, wLocal, dimLenAlign);
+      AscendC::Mul(b2, b2, w_local, dim_len_align);
-      // Cast result fp32 → fp16.
-      AscendC::Cast(yLocal, b2, AscendC::RoundMode::CAST_ROUND, dimLenAlign);
+      AscendC::Cast(out_local, b2, AscendC::RoundMode::CAST_RINT,
+                    dim_len_align);
     }
-    inQueueX1.FreeTensor(x1Local);
-    inQueueX2.FreeTensor(x2Local);
-    outQueueY.EnQue(yLocal);
-    outQueueXOut.EnQue(xOutLocal);
+    in_queue_input_.FreeTensor(input_local);
+    in_queue_residual_.FreeTensor(residual_local);
+    out_queue_out_.EnQue(out_local);
+    out_queue_residual_out_.EnQue(residual_out_local);
   }
   __aicore__ inline void CopyOut(int64_t row) {
-    AscendC::LocalTensor yLocal = outQueueY.DeQue();
-    AscendC::LocalTensor xOutLocal = outQueueXOut.DeQue();
+    AscendC::LocalTensor out_local = out_queue_out_.DeQue();
+    AscendC::LocalTensor residual_out_local =
+        out_queue_residual_out_.DeQue();
     AscendC::DataCopyExtParams params{
-        1, static_cast(this->dimLengthAlign * sizeof(T)), 0, 0, 0};
-    AscendC::DataCopyPad(yGm[row * this->dimLengthAlign], yLocal, params);
-    AscendC::DataCopyPad(xOutGm[row * this->dimLengthAlign], xOutLocal, params);
-    outQueueY.FreeTensor(yLocal);
-    outQueueXOut.FreeTensor(xOutLocal);
+        1, static_cast(dim_length_align_ * sizeof(T)), 0, 0, 0};
+    AscendC::DataCopyPad(out_gm_[row * dim_length_align_], out_local, params);
+    AscendC::DataCopyPad(residual_out_gm_[row * dim_length_align_],
+                         residual_out_local, params);
+    out_queue_out_.FreeTensor(out_local);
+    out_queue_residual_out_.FreeTensor(residual_out_local);
   }
  private:
-  AscendC::TPipe pipe;
-  AscendC::TQue inQueueX1;
-  AscendC::TQue inQueueX2;
-  AscendC::TQue outQueueY;
-  AscendC::TQue outQueueXOut;
-
-  AscendC::TBuf weightBuf;
-  AscendC::TBuf fp32Buf1;
-  AscendC::TBuf fp32Buf2;
-  AscendC::TBuf reduceTmpBuf;
-  AscendC::TBuf sumBuf;
-
-  AscendC::GlobalTensor x1Gm, x2Gm, yGm, xOutGm;
-  AscendC::GlobalTensor weightGm;
-
-  int64_t blockRows;
-  int64_t dimLength;
-  int64_t dimLengthAlign;
-  float eps;
+  AscendC::TPipe pipe_;
+  AscendC::TQue in_queue_input_;
+  AscendC::TQue in_queue_residual_;
+  AscendC::TQue out_queue_out_;
+  AscendC::TQue out_queue_residual_out_;
+
+  AscendC::TBuf weight_buf_;
+  AscendC::TBuf fp32_buf1_;
+  AscendC::TBuf fp32_buf2_;
+  AscendC::TBuf reduce_tmp_buf_;
+  AscendC::TBuf sum_buf_;
+
+  AscendC::GlobalTensor input_gm_, residual_gm_, out_gm_, residual_out_gm_;
+  AscendC::GlobalTensor weight_gm_;
+
+  int64_t block_rows_;
+  int64_t dim_length_;
+  int64_t dim_length_align_;
+  float eps_;
 };
-extern "C" __global__ __aicore__ void add_rms_norm(
-    GM_ADDR x1, GM_ADDR x2, GM_ADDR weight, GM_ADDR y, GM_ADDR x_out,
-    int64_t totalRows, int64_t dimLength, int64_t dimLengthAlign,
-    int64_t formerNum, int64_t formerLength, int64_t tailLength, float eps,
-    int64_t dtypeSize) {
-  if (dtypeSize == 2) {
-    KernelAddRmsNorm op;
-    op.Init(x1, x2, weight, y, x_out, totalRows, dimLength, dimLengthAlign,
-            formerNum, formerLength, tailLength, eps);
-    op.Process();
-  } else {
-    KernelAddRmsNorm op;
-    op.Init(x1, x2, weight, y, x_out, totalRows, dimLength, dimLengthAlign,
-            formerNum, formerLength, tailLength, eps);
-    op.Process();
+// `dtype_code` is `static_cast(infini::ops::DataType)` forwarded
+// by the host launcher. fp16 and bf16 both have `sizeof == 2` but need
+// distinct numeric paths, so dispatch is on the `DataType` tag rather
+// than the byte size.
+extern "C" __global__ __aicore__ void AddRmsNorm(
+    GM_ADDR input, GM_ADDR residual, GM_ADDR weight, int64_t total_rows,
+    int64_t dim_length, int64_t dim_length_align, int64_t former_num,
+    int64_t former_length, int64_t tail_length, float eps, int64_t dtype_code,
+    GM_ADDR out, GM_ADDR residual_out) {
+  switch (static_cast(dtype_code)) {
+    case infini::ops::DataType::kFloat16: {
+      KernelAddRmsNorm op;
+      op.Init(input, residual, weight, total_rows, dim_length, dim_length_align,
+              former_num, former_length, tail_length, eps, out, residual_out);
+      op.Process();
+      break;
+    }
+    case infini::ops::DataType::kBFloat16: {
+      KernelAddRmsNorm op;
+      op.Init(input, residual, weight, total_rows, dim_length, dim_length_align,
+              former_num, former_length, tail_length, eps, out, residual_out);
+      op.Process();
+      break;
+    }
+    case infini::ops::DataType::kFloat32:
+    default: {
+      KernelAddRmsNorm op;
+      op.Init(input, residual, weight, total_rows, dim_length, dim_length_align,
+              former_num, former_length, tail_length, eps, out, residual_out);
+      op.Process();
+      break;
+    }
   }
 }
diff --git a/src/ascend/custom/build.sh b/src/ascend/custom/build.sh
index 258a88e4b..837408816 100755
--- a/src/ascend/custom/build.sh
+++ b/src/ascend/custom/build.sh
@@ -1,30 +1,45 @@
 #!/bin/bash
-# Build custom `AscendC` kernels into `libascend_kernel.so`.
+# Build custom `AscendC` kernels into `libno_workspace_kernel.a` (+ the
+# standalone `libascend_kernel.so`).
+#
+# Intermediate artefacts default to `/build/build_ascend_custom/`
+# so the source tree under `src/` stays free of build output. Override
+# via `BUILD_DIR= bash build.sh …` if needed.
 set -e
 SOC_VERSION="${1:-Ascend910_9382}"
+# Use the same `cmake` the caller resolved (default: first `cmake` on
+# PATH). The outer `src/CMakeLists.txt` forwards `${CMAKE_COMMAND}`
+# via `CMAKE_EXE` so the child build doesn't accidentally pick up the
+# PyPI `cmake` shim whose Python package only exists in `pip`'s
+# build-isolation overlay.
+CMAKE_EXE="${CMAKE_EXE:-cmake}"
+
 # Detect CANN toolkit path.
 _CANN_TOOLKIT_INSTALL_PATH=$(grep "Toolkit_InstallPath" /etc/Ascend/ascend_cann_install.info | awk -F'=' '{print $2}')
 source "${_CANN_TOOLKIT_INSTALL_PATH}/set_env.sh"
 echo "CANN: ${ASCEND_TOOLKIT_HOME}"
 ASCEND_INCLUDE_DIR=${ASCEND_TOOLKIT_HOME}/$(arch)-linux/include
-CURRENT_DIR=$(pwd)
-OUTPUT_DIR=${CURRENT_DIR}/output
-mkdir -p "${OUTPUT_DIR}"
-BUILD_DIR=build
+# Resolve build directory. `