diff --git a/BUILD.bazel b/BUILD.bazel index deb376bf..d61492b0 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -27,6 +27,13 @@ exports_files([ ".github/workflows/build.yml", ]) +# To enable OneDNN BRGeMM support, build with: +# bazel build --define gemma_onednn_brgemm=1 ... +config_setting( + name = "gemma_onednn_brgemm", + define_values = {"gemma_onednn_brgemm": "1"}, +) + cc_library( name = "basics", srcs = ["util/basics.cc"], @@ -313,7 +320,14 @@ test_suite( cc_library( name = "matmul_env", srcs = ["ops/matmul.cc"], - hdrs = ["ops/matmul.h"], + hdrs = [ + "ops/brgemm.h", + "ops/matmul.h", + ], + defines = select({ + ":gemma_onednn_brgemm": ["GEMMA_ONEDNN_BRGEMM=1", "DNNL_EXPERIMENTAL_UKERNEL"], + "//conditions:default": [], + }), deps = [ ":allocator", ":basics", @@ -324,14 +338,20 @@ cc_library( "@highway//:hwy", "@highway//:nanobenchmark", "@highway//:profiler", - ], + ] + select({ + ":gemma_onednn_brgemm": ["@onednn//:onednn"], + "//conditions:default": [], + }), ) cc_library( name = "matmul", # allow depending only on this target, without also matmul_env. 
hdrs = ["ops/matmul.h"], - textual_hdrs = ["ops/matmul-inl.h"], + textual_hdrs = [ + "ops/brgemm-inl.h", + "ops/matmul-inl.h", + ], deps = [ ":allocator", ":basics", @@ -345,7 +365,10 @@ cc_library( "@highway//:hwy", "@highway//:nanobenchmark", "@highway//:profiler", - ], + ] + select({ + ":gemma_onednn_brgemm": ["@onednn//:onednn"], + "//conditions:default": [], + }), ) cc_library( @@ -362,6 +385,7 @@ cc_library( "ops/matmul_static.h", ], textual_hdrs = [ + "ops/brgemm-inl.h", "ops/matmul_static-inl.h", "ops/matmul-inl.h", ], @@ -378,7 +402,10 @@ cc_library( "@highway//:hwy", "@highway//:profiler", "@highway//:timer", - ], + ] + select({ + ":gemma_onednn_brgemm": ["@onednn//:onednn"], + "//conditions:default": [], + }), ) cc_library( diff --git a/CMakeLists.txt b/CMakeLists.txt index 52dc7ca7..2ca0c433 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,10 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +# Optional: OneDNN BRGeMM micro-kernel support (x86-64 only). +# Enable with: cmake -DGEMMA_ONEDNN_BRGEMM=ON ... +option(GEMMA_ONEDNN_BRGEMM "Enable OneDNN BRGeMM micro-kernel for MatMul (x86-64)" OFF) + if(EMSCRIPTEN) add_compile_options("-sMEMORY64") add_compile_options("-msimd128") @@ -85,6 +89,23 @@ if(EMSCRIPTEN) target_compile_options(benchmark PRIVATE -Wno-c2y-extensions) endif() +# OneDNN BRGeMM micro-kernel support (optional, x86-64 only). 
+if(GEMMA_ONEDNN_BRGEMM) + set(DNNL_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(DNNL_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) + set(DNNL_CPU_RUNTIME "SEQ" CACHE STRING "" FORCE) + set(DNNL_GPU_RUNTIME "NONE" CACHE STRING "" FORCE) + set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "" FORCE) + set(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE) + FetchContent_Declare(onednn + GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN.git + GIT_TAG v3.11 + EXCLUDE_FROM_ALL + ) + FetchContent_MakeAvailable(onednn) + message(STATUS "OneDNN BRGeMM micro-kernel support enabled") +endif() + # Base source files set(SOURCES compression/compress-inl.h @@ -141,6 +162,8 @@ set(SOURCES ops/matmul-inl.h ops/matmul.cc ops/matmul.h + ops/brgemm.h + ops/brgemm-inl.h ops/ops-inl.h ops/ops.h ops/sum-inl.h @@ -191,6 +214,10 @@ target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static) target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR}) target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>) +if(GEMMA_ONEDNN_BRGEMM) + target_compile_definitions(libgemma PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL) + target_link_libraries(libgemma dnnl) +endif() install(TARGETS libgemma DESTINATION lib) # Shared library target for C# interop @@ -215,6 +242,10 @@ target_compile_definitions(gemma_shared $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX> ) target_compile_options(gemma_shared PRIVATE $<$:-Wno-deprecated-declarations>) +if(GEMMA_ONEDNN_BRGEMM) + target_compile_definitions(gemma_shared PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL) + target_link_libraries(gemma_shared PRIVATE dnnl) +endif() install(TARGETS gemma_shared DESTINATION lib) install(FILES gemma/c_api.h DESTINATION include/gemma) install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma) diff --git a/MODULE.bazel b/MODULE.bazel index 0dea7752..e44c370b 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ 
-25,6 +25,17 @@ git_override( http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") +# OneDNN v3.11 for BRGeMM micro-kernel support (optional, x86-64 only). +http_archive( + name = "onednn", + build_file = "@//bazel:onednn.BUILD", + sha256 = "04df98b18300daf6c3aa7cc2d5e7ce8a8f430fed1787151daed0254d8dd4e64e", + strip_prefix = "oneDNN-3.11", + urls = [ + "https://github.com/uxlfoundation/oneDNN/archive/refs/tags/v3.11.tar.gz", + ], +) + http_archive( name = "com_google_absl_py", sha256 = "8a3d0830e4eb4f66c4fa907c06edf6ce1c719ced811a12e26d9d3162f8471758", diff --git a/bazel/onednn.BUILD b/bazel/onednn.BUILD new file mode 100644 index 00000000..0cbd436d --- /dev/null +++ b/bazel/onednn.BUILD @@ -0,0 +1,227 @@ +load("@bazel_skylib//rules:expand_template.bzl", "expand_template") + +exports_files(["LICENSE"]) + +expand_template( + name = "dnnl_config_h", + out = "include/oneapi/dnnl/dnnl_config.h", + substitutions = { + "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "#define DNNL_EXPERIMENTAL_UKERNEL 1", + "#cmakedefine DNNL_SAFE_RBP": "#undef DNNL_SAFE_RBP", + "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_SEQ", + "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_SEQ", + "#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#define DNNL_DISABLE_GPU_REF_KERNELS", + "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", + "#cmakedefine DNNL_GPU_VENDOR DNNL_VENDOR_${DNNL_GPU_VENDOR}": "#define DNNL_GPU_VENDOR DNNL_VENDOR_NONE", + "#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "#undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE", + "#cmakedefine DNNL_WITH_SYCL": "#undef DNNL_WITH_SYCL", + "#cmakedefine DNNL_WITH_LEVEL_ZERO": "#undef DNNL_WITH_LEVEL_ZERO", + "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA", + "#cmakedefine DNNL_SYCL_GENERIC": 
"#undef DNNL_SYCL_GENERIC", + "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP", + "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", + "#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH", + "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE", + "#cmakedefine DNNL_EXPERIMENTAL_LOGGING": "#undef DNNL_EXPERIMENTAL_LOGGING", + "#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING", + "#cmakedefine DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER": "#undef DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER", + "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", + "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", + "#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0", + "#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1", + "#cmakedefine01 BUILD_BATCH_NORMALIZATION": "#define BUILD_BATCH_NORMALIZATION 0", + "#cmakedefine01 BUILD_BINARY": "#define BUILD_BINARY 0", + "#cmakedefine01 BUILD_CONCAT": "#define BUILD_CONCAT 0", + "#cmakedefine01 BUILD_CONVOLUTION": "#define BUILD_CONVOLUTION 0", + "#cmakedefine01 BUILD_DECONVOLUTION": "#define BUILD_DECONVOLUTION 0", + "#cmakedefine01 BUILD_ELTWISE": "#define BUILD_ELTWISE 0", + "#cmakedefine01 BUILD_GEMM_KERNELS_ALL": "#define BUILD_GEMM_KERNELS_ALL 1", + "#cmakedefine01 BUILD_GEMM_KERNELS_NONE": "#define BUILD_GEMM_KERNELS_NONE 0", + "#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 1", + "#cmakedefine01 BUILD_GEMM_AVX2": "#define BUILD_GEMM_AVX2 1", + "#cmakedefine01 BUILD_GEMM_AVX512": "#define BUILD_GEMM_AVX512 1", + "#cmakedefine01 BUILD_GROUP_NORMALIZATION": "#define BUILD_GROUP_NORMALIZATION 1", + "#cmakedefine01 BUILD_INNER_PRODUCT": "#define BUILD_INNER_PRODUCT 0", + "#cmakedefine01 BUILD_LAYER_NORMALIZATION": "#define BUILD_LAYER_NORMALIZATION 0", + "#cmakedefine01 BUILD_LRN": "#define BUILD_LRN 0", + "#cmakedefine01 BUILD_MATMUL": "#define BUILD_MATMUL 0", + "#cmakedefine01 BUILD_POOLING": 
"#define BUILD_POOLING 0", + "#cmakedefine01 BUILD_PRELU": "#define BUILD_PRELU 0", + "#cmakedefine01 BUILD_REDUCTION": "#define BUILD_REDUCTION 0", + "#cmakedefine01 BUILD_REORDER": "#define BUILD_REORDER 0", + "#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0", + "#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0", + "#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0", + "#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0", + "#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0", + "#cmakedefine01 BUILD_PRIMITIVE_CPU_ISA_ALL": "#define BUILD_PRIMITIVE_CPU_ISA_ALL 1", + "#cmakedefine01 BUILD_SSE41": "#define BUILD_SSE41 0", + "#cmakedefine01 BUILD_AVX2": "#define BUILD_AVX2 0", + "#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0", + "#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0", + "#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 0", + "#cmakedefine01 BUILD_XE2": "#define BUILD_XE2 0", + "#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0", + "#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0", + "#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0", + "#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0", + "#cmakedefine01 BUILD_SDPA": "#define BUILD_SDPA 1", + "#cmakedefine01 BUILD_XE3": "#define BUILD_XE3 0", + }, + template = "include/oneapi/dnnl/dnnl_config.h.in", +) + +expand_template( + name = "dnnl_version_h", + out = "include/oneapi/dnnl/dnnl_version.h", + substitutions = { + "@DNNL_VERSION_MAJOR@": "3", + "@DNNL_VERSION_MINOR@": "11", + "@DNNL_VERSION_PATCH@": "0", + }, + template = "include/oneapi/dnnl/dnnl_version.h.in", +) + +expand_template( + name = "dnnl_version_hash_h", + out = "include/oneapi/dnnl/dnnl_version_hash.h", + substitutions = { + "@DNNL_VERSION_HASH@": "fc6151651a4577beae5ffac5a4132e75d39e1409", + }, + template = "include/oneapi/dnnl/dnnl_version_hash.h.in", +) + +cc_library( + name = "onednn_autogen", + srcs = glob(["src/cpu/x64/gemm/**/*_kern_autogen*.cpp"]), + copts = [ + 
"-O1", + "-U_FORTIFY_SOURCE", + "-fexceptions", + "-UUSE_MKL", + "-UUSE_CBLAS", + "-DDNNL_ENABLE_MAX_CPU_ISA", + "-DDNNL_ENABLE_ITT_TASKS", + "-DDNNL_ENABLE_GRAPH_DUMP", + "-DDNNL_EXPERIMENTAL_UKERNEL", + ], + includes = [ + "include", + "src", + "src/common", + "src/cpu", + "src/cpu/gemm", + "src/graph", + "third_party", + "third_party/ittnotify", + "third_party/xbyak", + ], + textual_hdrs = glob([ + "include/**/*", + "src/common/*.hpp", + "src/cpu/*.hpp", + "src/cpu/**/*.hpp", + "src/cpu/jit_utils/**/*.hpp", + "src/graph/interface/*.hpp", + "src/graph/backend/*.hpp", + "src/graph/backend/dnnl/*.hpp", + "src/graph/backend/dnnl/executables/*.hpp", + "src/graph/backend/fake/*.hpp", + "src/graph/backend/dnnl/passes/*.hpp", + "src/graph/backend/dnnl/patterns/*.hpp", + "src/graph/backend/dnnl/kernels/*.hpp", + "src/graph/utils/*.hpp", + "src/graph/utils/pm/*.hpp", + "third_party/ittnotify/**/*.h", + "third_party/spdlog/**/*.h", + "third_party/xbyak/*.h", + ]) + [ + ":dnnl_config_h", + ":dnnl_version_h", + ":dnnl_version_hash_h", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "onednn", + srcs = glob( + [ + "src/common/*.cpp", + "src/cpu/*.cpp", + "src/cpu/**/*.cpp", + "src/cpu/jit_utils/**/*.cpp", + "src/cpu/x64/**/*.cpp", + "src/graph/interface/*.cpp", + "src/graph/backend/*.cpp", + "src/graph/backend/dnnl/*.cpp", + "src/graph/backend/dnnl/executables/*.cpp", + "src/graph/backend/fake/*.cpp", + "src/graph/backend/dnnl/passes/*.cpp", + "src/graph/backend/dnnl/patterns/*.cpp", + "src/graph/backend/dnnl/kernels/*.cpp", + "src/graph/utils/*.cpp", + "src/graph/utils/pm/*.cpp", + "third_party/ittnotify/*.c", + ], + exclude = [ + "src/cpu/aarch64/**", + "src/cpu/rv64/**", + "src/cpu/ppc64/**", + "src/cpu/s390x/**", + "src/cpu/x64/gemm/**/*_kern_autogen.cpp", + "src/cpu/sycl/**", + ], + ), + copts = [ + "-fexceptions", + "-UUSE_MKL", + "-UUSE_CBLAS", + "-DDNNL_ENABLE_MAX_CPU_ISA", + "-DDNNL_ENABLE_ITT_TASKS", + "-DDNNL_ENABLE_GRAPH_DUMP", + 
"-DDNNL_EXPERIMENTAL_UKERNEL", + ], + includes = [ + "include", + "src", + "src/common", + "src/cpu", + "src/cpu/gemm", + "src/graph", + "third_party", + "third_party/ittnotify", + "third_party/xbyak", + ], + linkopts = [ + "-lrt", + "-Wl,--allow-multiple-definition", + ], + textual_hdrs = glob([ + "include/**/*", + "src/common/*.hpp", + "src/cpu/*.hpp", + "src/cpu/**/*.hpp", + "src/cpu/jit_utils/**/*.hpp", + "src/graph/interface/*.hpp", + "src/graph/backend/*.hpp", + "src/graph/backend/dnnl/*.hpp", + "src/graph/backend/fake/*.hpp", + "src/graph/backend/dnnl/passes/*.hpp", + "src/graph/backend/dnnl/patterns/*.hpp", + "src/graph/backend/dnnl/kernels/*.hpp", + "src/graph/utils/*.hpp", + "src/graph/utils/pm/*.hpp", + "third_party/ittnotify/**/*.h", + "third_party/spdlog/**/*.h", + "third_party/xbyak/*.h", + ]) + [ + ":dnnl_config_h", + ":dnnl_version_h", + ":dnnl_version_hash_h", + ], + visibility = ["//visibility:public"], + deps = [ + ":onednn_autogen", + ], +) diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 67c702f5..e9432276 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -130,7 +130,11 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { keep += hwy::ConvertScalarTo(C.Row(0)[hwy::Unpredictable1()]); // Only record times after autotuning finished. - if (per_key->autotune.Best()) times.push_back(elapsed); + bool done = per_key->autotune.Best(); +#if GEMMA_ONEDNN_BRGEMM + done = done || per_key->brgemm_autotune.Best(); +#endif + if (done) times.push_back(elapsed); } hwy::PreventElision(keep); env.ctx.pools.MaybeStopSpinning(use_spinning); diff --git a/ops/brgemm-inl.h b/ops/brgemm-inl.h new file mode 100644 index 00000000..78137a3f --- /dev/null +++ b/ops/brgemm-inl.h @@ -0,0 +1,492 @@ +// Copyright 2026 DeepMind Technologies Limited. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// BRGeMM dispatch. Included from matmul-inl.h inside gcpp::HWY_NAMESPACE. + +#if GEMMA_ONEDNN_BRGEMM + +static bool MakeBrgemm(dnnl::ukernel::brgemm& brg, int64_t m, int64_t n, + int64_t k, int64_t batch, int64_t lda, int64_t ldb, + int64_t ldc, dnnl::memory::data_type a_dt, + dnnl::memory::data_type b_dt, + dnnl::memory::data_type c_dt, bool add_C) { + try { + brg = dnnl::ukernel::brgemm(m, n, k, batch, lda, ldb, ldc, a_dt, b_dt, + c_dt, true); + if (!brg) return false; + brg.set_add_C(add_C); + if (!brg.finalize()) return false; + brg.generate(); + return true; + } catch (...) { + return false; + } +} + +template +static HWY_NOINLINE void DoMatMul_BRGeMM( + const MatPtrT& A, const MatPtrT& B, RowPtrs C, size_t M, + size_t K, size_t N, float scale, const float* HWY_RESTRICT add, + const BRGeMMConfig& cfg, ThreadingContext& ctx, size_t cluster_idx) { + using dnnl::ukernel::brgemm; + using dnnl::ukernel::pack_type; + using dnnl::ukernel::transform; + + // Level-1 cache: kernels keyed on (M, K, N, config). 
+ const BRGeMMKernelKey kern_key{M, K, N, cfg.M_blk, cfg.N_blk, cfg.K_blk, + cfg.batch_size}; + auto& kern_cache = GetBRGeMMKernelCache(); + auto kern_it = kern_cache.find(kern_key); + + if (kern_it == kern_cache.end()) { + BRGeMMKernelEntry ke; + + ke.K_blk = cfg.K_blk; + ke.N_blk = cfg.N_blk; + ke.M_blk = + static_cast(std::min(static_cast(cfg.M_blk), M)); + + ke.M_tail = M % ke.M_blk; + ke.N_tail = N % ke.N_blk; + ke.K_tail = K % ke.K_blk; + + ke.K_chunks = K / ke.K_blk; + ke.N_full_tiles = N / ke.N_blk; + ke.M_full_tiles = M / ke.M_blk; + ke.N_total_tiles = ke.N_full_tiles + (ke.N_tail ? 1 : 0); + ke.M_total_tiles = ke.M_full_tiles + (ke.M_tail ? 1 : 0); + ke.N_padded = ke.N_total_tiles * ke.N_blk; + + if (ke.M_total_tiles == 0 || ke.N_total_tiles == 0 || + (ke.K_chunks == 0 && ke.K_tail == 0)) { + return; + } + + ke.K_super_size = std::min(cfg.batch_size, ke.K_chunks); + ke.K_super_blocks = (ke.K_chunks > 0) ? ke.K_chunks / ke.K_super_size : 0; + ke.K_super_rem = (ke.K_chunks > 0) ? ke.K_chunks % ke.K_super_size : 0; + ke.batch_full = ke.K_super_size; + ke.batch_rem = ke.K_super_rem; + + const auto a_dt = dnnl::memory::data_type::bf16; + const auto b_dt = dnnl::memory::data_type::bf16; + const auto c_dt = dnnl::memory::data_type::f32; + ke.a_dt_size = dnnl::memory::data_type_size(a_dt); + ke.b_dt_size = dnnl::memory::data_type_size(b_dt); + + const auto pack = brgemm::get_B_pack_type(a_dt, b_dt); + if (pack == pack_type::undef) return; + ke.need_pack = (pack != pack_type::no_trans); + + ke.lda = A.Stride(); + ke.ldb_orig = B.Stride(); + + ke.m_sizes[0] = ke.M_blk; + ke.m_sizes[1] = ke.M_tail ? ke.M_tail : ke.M_blk; + ke.n_sizes[0] = ke.N_blk; + ke.n_sizes[1] = ke.N_tail ? ke.N_tail : ke.N_blk; + const int64_t ldb_for[2] = {ke.N_blk, ke.N_tail ? ke.N_tail : ke.N_blk}; + const int64_t ldc_for[2] = {ke.N_blk, ke.N_tail ? ke.N_tail : ke.N_blk}; + + // Create brgemm kernels for each (M-tile, N-tile) variant. 
+ size_t max_sp = 0; + for (int mi = 0; mi < 2; ++mi) { + for (int ni = 0; ni < 2; ++ni) { + if (mi == 1 && ke.M_tail == 0) continue; + if (ni == 1 && ke.N_tail == 0) continue; + if (mi == 0 && ke.M_full_tiles == 0) continue; + if (ni == 0 && ke.N_full_tiles == 0) continue; + + const int64_t ms = ke.m_sizes[mi]; + const int64_t ns = ke.n_sizes[ni]; + + if (ke.K_chunks > 0) { + if (!MakeBrgemm(ke.brg_first_all[mi][ni], ms, ns, ke.K_blk, + ke.K_super_size, ke.lda, ldb_for[ni], ldc_for[ni], + a_dt, b_dt, c_dt, false)) { + return; + } + max_sp = std::max(max_sp, + ke.brg_first_all[mi][ni].get_scratchpad_size()); + } + if (ke.K_super_blocks > 1) { + if (!MakeBrgemm(ke.brg_full[mi][ni], ms, ns, ke.K_blk, + ke.batch_full, ke.lda, ldb_for[ni], ldc_for[ni], + a_dt, b_dt, c_dt, true)) { + return; + } + max_sp = + std::max(max_sp, ke.brg_full[mi][ni].get_scratchpad_size()); + } + if (ke.K_super_rem > 0) { + const bool rem_is_first = (ke.K_super_blocks == 0); + auto& target = rem_is_first ? ke.brg_first_rem[mi][ni] + : ke.brg_rem[mi][ni]; + if (!MakeBrgemm(target, ms, ns, ke.K_blk, ke.batch_rem, ke.lda, + ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, + !rem_is_first)) { + return; + } + max_sp = std::max(max_sp, target.get_scratchpad_size()); + } + if (ke.K_tail > 0) { + const bool add_c = (ke.K_chunks > 0); + if (!MakeBrgemm(ke.brg_ktail[mi][ni], ms, ns, ke.K_tail, 1, ke.lda, + ldb_for[ni], ldc_for[ni], a_dt, b_dt, c_dt, + add_c)) { + return; + } + max_sp = + std::max(max_sp, ke.brg_ktail[mi][ni].get_scratchpad_size()); + } + } + } + ke.scratchpad_size = max_sp + 64; + + // Create B-packing transforms. 
+ if (ke.need_pack) { + for (int ni = 0; ni < 2; ++ni) { + if (ni == 1 && ke.N_tail == 0) continue; + if (ni == 0 && ke.N_full_tiles == 0) continue; + + const int64_t ns = ke.n_sizes[ni]; + if (ke.K_chunks > 0) { + const int64_t K_full = ke.K_chunks * ke.K_blk; + try { + ke.pack_B[ni] = transform(K_full, ns, pack_type::trans, + ke.ldb_orig, ldb_for[ni], b_dt, b_dt); + if (!ke.pack_B[ni]) return; + ke.pack_B[ni].generate(); + ke.blocked_B_size[ni] = ldb_for[ni] * ke.K_blk * ke.b_dt_size; + } catch (...) { + return; + } + } + if (ke.K_tail > 0) { + try { + ke.pack_B_ktail[ni] = transform( + ke.K_tail, ns, pack_type::trans, ke.ldb_orig, ldb_for[ni], + b_dt, b_dt); + if (!ke.pack_B_ktail[ni]) return; + ke.pack_B_ktail[ni].generate(); + ke.blocked_B_ktail_size[ni] = + ldb_for[ni] * ke.K_tail * ke.b_dt_size; + } catch (...) { + return; + } + } + } + } + + // Precompute A/B offset tables for each K-super-block. + for (int ni = 0; ni < 2; ++ni) { + if (ni == 1 && ke.N_tail == 0) continue; + if (ni == 0 && ke.N_full_tiles == 0) continue; + const int64_t cur_n = ke.n_sizes[ni]; + + if (ke.K_chunks > 0) { + ke.offsets_first_all[ni].resize(ke.K_super_size); + for (int64_t i = 0; i < ke.K_super_size; ++i) { + const int64_t a_off = + i * ke.K_blk * static_cast(ke.a_dt_size); + const int64_t b_off = + ke.need_pack + ? i * static_cast(ke.blocked_B_size[ni]) + : i * cur_n * ke.K_blk * static_cast(ke.b_dt_size); + ke.offsets_first_all[ni][i] = {a_off, b_off}; + } + } + + if (ke.K_super_blocks > 1) { + ke.offsets_full[ni].resize(ke.K_super_blocks - 1); + for (int64_t ks = 1; ks < ke.K_super_blocks; ++ks) { + auto& tbl = ke.offsets_full[ni][ks - 1]; + tbl.resize(ke.batch_full); + const int64_t k_start = ks * ke.K_super_size; + for (int64_t i = 0; i < ke.batch_full; ++i) { + const int64_t k_idx = k_start + i; + const int64_t a_off = + k_idx * ke.K_blk * static_cast(ke.a_dt_size); + const int64_t b_off = + ke.need_pack + ? 
k_idx * static_cast(ke.blocked_B_size[ni]) + : k_idx * cur_n * ke.K_blk * + static_cast(ke.b_dt_size); + tbl[i] = {a_off, b_off}; + } + } + } + + if (ke.K_super_rem > 0) { + const int64_t k_base = ke.K_super_blocks * ke.K_super_size; + auto& rem_tbl = (ke.K_super_blocks == 0) ? ke.offsets_first_rem[ni] + : ke.offsets_rem[ni]; + rem_tbl.resize(ke.K_super_rem); + for (int64_t i = 0; i < ke.K_super_rem; ++i) { + const int64_t k_idx = k_base + i; + const int64_t a_off = + k_idx * ke.K_blk * static_cast(ke.a_dt_size); + const int64_t b_off = + ke.need_pack + ? k_idx * static_cast(ke.blocked_B_size[ni]) + : k_idx * cur_n * ke.K_blk * + static_cast(ke.b_dt_size); + rem_tbl[i] = {a_off, b_off}; + } + } + } + + kern_it = kern_cache.emplace(kern_key, std::move(ke)).first; + } + + BRGeMMKernelEntry& ke = kern_it->second; + if (ke.M_total_tiles == 0 || ke.N_total_tiles == 0) return; + + // Level-2 cache: packed B keyed on (B_ptr, K, N, config). + const uint8_t* A_base = reinterpret_cast(A.Row(0)); + const uint8_t* B_base = reinterpret_cast(B.Row(0)); + + const BRGeMMPackedBKey pb_key{reinterpret_cast(B_base), K, N, + ke.K_blk, ke.N_blk}; + auto& pb_cache = GetBRGeMMPackedBCache(); + auto pb_it = pb_cache.find(pb_key); + + if (pb_it == pb_cache.end()) { + BRGeMMPackedBEntry pe; + pe.B_tile_offset.resize(ke.N_total_tiles, 0); + pe.B_ktail_offset.resize(ke.N_total_tiles, 0); + + if (ke.need_pack) { + size_t total_packed = 0; + for (int64_t nt = 0; nt < ke.N_total_tiles; ++nt) { + const int ni = (nt < ke.N_full_tiles) ? 0 : 1; + pe.B_tile_offset[nt] = total_packed; + if (ke.K_chunks > 0) + total_packed += ke.blocked_B_size[ni] * ke.K_chunks; + pe.B_ktail_offset[nt] = total_packed; + if (ke.K_tail > 0) total_packed += ke.blocked_B_ktail_size[ni]; + } + + pe.B_packed_buf.Resize(total_packed); + uint8_t* B_packed = pe.B_packed_buf.data(); + if (!B_packed) return; + + for (int64_t nt = 0; nt < ke.N_total_tiles; ++nt) { + const int ni = (nt < ke.N_full_tiles) ? 
0 : 1; + const int64_t b_row = (nt < ke.N_full_tiles) + ? nt * ke.N_blk + : ke.N_full_tiles * ke.N_blk; + const uint8_t* B_in = + B_base + b_row * ke.ldb_orig * ke.b_dt_size; + + try { + if (ke.K_chunks > 0) { + ke.pack_B[ni].execute(const_cast(B_in), + B_packed + pe.B_tile_offset[nt]); + } + if (ke.K_tail > 0) { + const uint8_t* B_in_ktail = + B_in + ke.K_chunks * ke.K_blk * ke.b_dt_size; + ke.pack_B_ktail[ni].execute(const_cast(B_in_ktail), + B_packed + pe.B_ktail_offset[nt]); + } + } catch (...) { + return; + } + } + } + + pb_it = pb_cache.emplace(pb_key, std::move(pe)).first; + } + + const BRGeMMPackedBEntry& pe = pb_it->second; + const uint8_t* B_packed = + ke.need_pack ? pe.B_packed_buf.data() : nullptr; + + std::vector> offsets_ktail(1); + if (ke.K_tail > 0) offsets_ktail[0] = {0, 0}; + + // Execute one (m, n) tile for a given K-super-block. + const auto execute_tile = [&](size_t m_start, size_t n_start, + int64_t k_super, float* temp_C, + uint8_t* scratch) HWY_ATTR { + const int64_t m_tile_idx = m_start / ke.M_blk; + const int64_t n_tile_idx = n_start / ke.N_blk; + const int mi = (m_tile_idx < ke.M_full_tiles) ? 0 : 1; + const int ni = (n_tile_idx < ke.N_full_tiles) ? 0 : 1; + const int64_t cur_m = ke.m_sizes[mi]; + const int64_t cur_n = ke.n_sizes[ni]; + + const size_t real_m = (m_tile_idx < ke.M_full_tiles) + ? m_tile_idx * ke.M_blk + : ke.M_full_tiles * ke.M_blk; + const size_t real_n = (n_tile_idx < ke.N_full_tiles) + ? n_tile_idx * ke.N_blk + : ke.N_full_tiles * ke.N_blk; + + const uint8_t* A_tile = A_base + real_m * ke.lda * ke.a_dt_size; + const void* B_tile = + ke.need_pack + ? static_cast(B_packed + + pe.B_tile_offset[n_tile_idx]) + : static_cast(B_base + + real_n * ke.ldb_orig * ke.b_dt_size); + + float* C_tile_ptr = temp_C; + const int64_t k_total = + ke.K_super_blocks + (ke.K_super_rem > 0 ? 
1 : 0); + + if (k_super < ke.K_super_blocks) { + if (k_super == 0) { + ke.brg_first_all[mi][ni].execute(A_tile, const_cast(B_tile), + ke.offsets_first_all[ni], C_tile_ptr, + scratch); + } else { + ke.brg_full[mi][ni].execute(A_tile, const_cast(B_tile), + ke.offsets_full[ni][k_super - 1], + C_tile_ptr, scratch); + } + } else if (ke.K_super_rem > 0 && k_super == ke.K_super_blocks) { + if (ke.K_super_blocks == 0) { + ke.brg_first_rem[mi][ni].execute(A_tile, const_cast(B_tile), + ke.offsets_first_rem[ni], C_tile_ptr, + scratch); + } else { + ke.brg_rem[mi][ni].execute(A_tile, const_cast(B_tile), + ke.offsets_rem[ni], C_tile_ptr, scratch); + } + } + + const bool is_last = (k_total > 0) ? (k_super == k_total - 1) : true; + if (is_last) { + if (ke.K_tail > 0) { + const uint8_t* A_ktail = + A_tile + ke.K_chunks * ke.K_blk * ke.a_dt_size; + const void* B_ktail = + ke.need_pack + ? static_cast(B_packed + + pe.B_ktail_offset[n_tile_idx]) + : static_cast( + B_base + (real_n * ke.ldb_orig + + ke.K_chunks * ke.K_blk) * + ke.b_dt_size); + ke.brg_ktail[mi][ni].execute(A_ktail, const_cast(B_ktail), + offsets_ktail, C_tile_ptr, scratch); + } + + // Scale and copy temp_C to output. + const hn::ScalableTag df; + const auto vscale = hn::Set(df, scale); + const size_t lanes = hn::Lanes(df); + for (int64_t m = 0; m < cur_m; ++m) { + TC* C_row = C.Row(real_m + m) + real_n; + const float* t_row = C_tile_ptr + m * cur_n; + const float* add_row = add ? 
add + real_n : nullptr; + int64_t n = 0; + if (add_row) { + for (; n + static_cast(lanes) <= cur_n; + n += static_cast(lanes)) { + const auto v = hn::Load(df, t_row + n); + const auto va = hn::Load(df, add_row + n); + const auto result = hn::MulAdd(v, vscale, va); + if constexpr (hwy::IsSame()) { + hn::Store(result, df, reinterpret_cast(C_row) + n); + } else { + const hn::Rebind dc; + hn::Store(hn::DemoteTo(dc, result), dc, C_row + n); + } + } + for (; n < cur_n; ++n) { + float val = t_row[n] * scale + add_row[n]; + C_row[n] = hwy::ConvertScalarTo(val); + } + } else { + for (; n + static_cast(lanes) <= cur_n; + n += static_cast(lanes)) { + const auto v = hn::Load(df, t_row + n); + const auto result = hn::Mul(v, vscale); + if constexpr (hwy::IsSame()) { + hn::Store(result, df, reinterpret_cast(C_row) + n); + } else { + const hn::Rebind dc; + hn::Store(hn::DemoteTo(dc, result), dc, C_row + n); + } + } + for (; n < cur_n; ++n) { + float val = t_row[n] * scale; + C_row[n] = hwy::ConvertScalarTo(val); + } + } + } + } + }; + + // Parallel dispatch: K-super outer, N middle, M inner (keeps B in L2). + const int64_t k_total_supers = + ke.K_super_blocks + (ke.K_super_rem > 0 ? 1 : 0); + const int64_t k_iters = (k_total_supers > 0) ? k_total_supers : 1; + + const size_t num_threads = ctx.pools.MaxWorkersPerCluster(); + const size_t total_n_tiles = ke.N_total_tiles; + const size_t total_m_tiles = ke.M_total_tiles; + const size_t n_tasks = + std::max(size_t{1}, std::min(total_n_tiles, num_threads)); + + const hwy::pool::Caller caller = + ctx.pool_callers.Get(Callers::kBRGeMM); + + ParallelForWithinCluster( + n_tasks, ctx, cluster_idx, caller, + [&](uint64_t task_idx, size_t /*worker*/) HWY_ATTR { + const size_t tiles_per_task = total_n_tiles / n_tasks; + const size_t extra = total_n_tiles % n_tasks; + const size_t n_begin = + task_idx * tiles_per_task + + std::min(static_cast(task_idx), extra); + const size_t n_end = + n_begin + tiles_per_task + (task_idx < extra ? 
1 : 0); + + auto& tbufs = GetBRGeMMThreadBufs(); + tbufs.MaybeSetHwContext(ke.brg_first_all[0][0]); + uint8_t* sp = tbufs.EnsureScratch(ke.scratchpad_size); + + const size_t n_tiles_in_range = n_end - n_begin; + const size_t total_tc = total_m_tiles * n_tiles_in_range; + float* tc_base = tbufs.EnsureTempC(total_tc); + + for (int64_t ks = 0; ks < k_iters; ++ks) { + size_t n_idx = 0; + for (size_t nt = n_begin; nt < n_end; ++nt) { + const size_t n = nt * ke.N_blk; + for (int64_t mt = 0; mt < static_cast(total_m_tiles); + ++mt) { + const size_t m = mt * ke.M_blk; + float* temp_C = + tc_base + (mt * n_tiles_in_range + n_idx) * + BRGeMMThreadBufs::kMaxTempCSize; + execute_tile(m, n, ks, temp_C, sp); + } + ++n_idx; + } + } + }); + + dnnl::ukernel::brgemm::release_hw_context(); + auto& main_bufs = GetBRGeMMThreadBufs(); + main_bufs.hw_ctx_set = false; + main_bufs.hw_ctx_kernel = nullptr; +} + +#endif // GEMMA_ONEDNN_BRGEMM diff --git a/ops/brgemm.h b/ops/brgemm.h new file mode 100644 index 00000000..38e05509 --- /dev/null +++ b/ops/brgemm.h @@ -0,0 +1,288 @@ +// Copyright 2026 DeepMind Technologies Limited. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// OneDNN BRGeMM micro-kernel integration for MatMul on Intel AMX/AVX-512. +// Enabled at compile time via GEMMA_ONEDNN_BRGEMM=1 (Bazel: --define gemma_onednn_brgemm=1). 
+
+#ifndef THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_
+#define THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_
+
+// NOTE(review): the original include targets were lost in extraction; these
+// are reconstructed from what the header uses (size_t/int64_t, std::min,
+// std::pair, std::vector, std::unordered_map) -- confirm against the repo.
+#include <stddef.h>
+#include <stdint.h>
+
+#include <algorithm>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "hwy/base.h"
+
+#if GEMMA_ONEDNN_BRGEMM
+// mmap/madvise for HugePageBuffer below.
+#include <sys/mman.h>
+
+#include "oneapi/dnnl/dnnl.hpp"
+#include "oneapi/dnnl/dnnl_ukernel.hpp"
+#endif  // GEMMA_ONEDNN_BRGEMM
+
+namespace gcpp {
+
+// Blocking parameters for one BRGeMM dispatch. Produced by BRGeMMCandidates
+// and chosen by the autotuner.
+struct BRGeMMConfig {
+  int64_t M_blk;       // rows of A per tile
+  int64_t N_blk;       // columns of B per tile
+  int64_t K_blk;       // depth of one batched sub-GEMM
+  int64_t batch_size;  // K-chunks fused into a single brgemm batch
+  int64_t par_m;
+};
+
+#if GEMMA_ONEDNN_BRGEMM
+
+// Generates autotuning candidates. Fixed: N_blk=32, K_blk=32 (AMX BF16).
+// Tunable: M_blk in {32,64}, batch_size in {16,32,64,128,256}.
+// batch_size is clamped to the number of whole K-chunks, so several raw
+// batch values can collapse to the same effective config; duplicates are
+// dropped. Always returns at least one (possibly clamped) candidate.
+inline std::vector<BRGeMMConfig> BRGeMMCandidates(size_t M, size_t K,
+                                                  size_t N) {
+  std::vector<BRGeMMConfig> out;
+  static constexpr int64_t kNBlk = 32;
+  static constexpr int64_t kKBlk = 32;
+  static constexpr int64_t kMBlkValues[] = {32, 64};
+  static constexpr int64_t kBatchValues[] = {16, 32, 64, 128, 256};
+
+  const int64_t k_chunks = static_cast<int64_t>(K) / kKBlk;
+  for (int64_t mb : kMBlkValues) {
+    // Skip block sizes that exceed the problem; tails are handled by the
+    // kernel entry, but a full block must fit at least once.
+    if (mb > static_cast<int64_t>(M)) continue;
+    if (kNBlk > static_cast<int64_t>(N)) continue;
+    for (int64_t bs : kBatchValues) {
+      const int64_t eff_bs =
+          (k_chunks > 0) ? std::min(bs, k_chunks) : int64_t{1};
+      bool dup = false;
+      for (const auto& c : out) {
+        if (c.M_blk == mb && c.batch_size == eff_bs) {
+          dup = true;
+          break;
+        }
+      }
+      if (dup) continue;
+      out.push_back({mb, kNBlk, kKBlk, eff_bs, /*par_m=*/1});
+    }
+  }
+  // Fallback for tiny problems (M < 32 or N < 32): one clamped candidate so
+  // the caller always has something to try.
+  if (out.empty()) {
+    out.push_back({static_cast<int64_t>(std::min(M, size_t{32})),
+                   static_cast<int64_t>(std::min(N, size_t{32})), 32, 1, 1});
+  }
+  return out;
+}
+
+// Hugepage-backed buffer via mmap with MADV_HUGEPAGE for packed-B matrices.
// Owns an anonymous mmap region advised to use 2 MiB transparent hugepages.
// Used to hold packed-B matrices; move-only.
class HugePageBuffer {
 public:
  HugePageBuffer() = default;
  ~HugePageBuffer() { Release(); }

  HugePageBuffer(HugePageBuffer&& o) noexcept : ptr_(o.ptr_), size_(o.size_) {
    o.ptr_ = nullptr;
    o.size_ = 0;
  }
  HugePageBuffer& operator=(HugePageBuffer&& o) noexcept {
    if (this != &o) {
      Release();
      ptr_ = o.ptr_;
      size_ = o.size_;
      o.ptr_ = nullptr;
      o.size_ = 0;
    }
    return *this;
  }

  HugePageBuffer(const HugePageBuffer&) = delete;
  HugePageBuffer& operator=(const HugePageBuffer&) = delete;

  // Replaces any current mapping with a fresh one of at least `n` bytes,
  // rounded up to a multiple of 2 MiB. Contents are zero (MAP_ANONYMOUS) and
  // pre-faulted one byte per hugepage. On failure — or when n == 0, which
  // previously reached mmap with length 0 and failed with EINVAL — the
  // buffer is left empty: data() == nullptr, size() == 0.
  void Resize(size_t n) {
    Release();
    if (n == 0) return;  // fix: avoid mmap(nullptr, 0, ...) EINVAL
    static constexpr size_t kHugePageSize = size_t{2} << 20;  // 2 MiB
    size_ = (n + kHugePageSize - 1) & ~(kHugePageSize - 1);
    void* p = mmap(nullptr, size_, PROT_READ | PROT_WRITE,
                   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (p == MAP_FAILED) {
      ptr_ = nullptr;
      size_ = 0;
      return;
    }
    ptr_ = static_cast<uint8_t*>(p);
    // Best-effort: the buffer still works (without hugepages) if this fails.
    madvise(ptr_, size_, MADV_HUGEPAGE);
    // Touch one byte per hugepage so the kernel materializes the pages
    // (ideally as hugepages) before the hot MatMul loop.
    for (size_t off = 0; off < size_; off += kHugePageSize) {
      ptr_[off] = 0;
    }
  }

  uint8_t* data() { return ptr_; }
  const uint8_t* data() const { return ptr_; }
  size_t size() const { return size_; }

 private:
  // Unmaps the current region, if any, and resets to the empty state.
  void Release() {
    if (ptr_ && size_) munmap(ptr_, size_);
    ptr_ = nullptr;
    size_ = 0;
  }

  uint8_t* ptr_ = nullptr;
  size_t size_ = 0;
};

// Kernel cache key: identifies a JIT-compiled kernel set.
+struct BRGeMMKernelKey { + size_t M, K, N; + int64_t M_blk, N_blk, K_blk, batch_size; + bool operator==(const BRGeMMKernelKey& o) const { + return M == o.M && K == o.K && N == o.N && M_blk == o.M_blk && + N_blk == o.N_blk && K_blk == o.K_blk && batch_size == o.batch_size; + } +}; + +struct BRGeMMKernelKeyHash { + size_t operator()(const BRGeMMKernelKey& k) const { + size_t h = 14695981039346656037ULL; + h = (h ^ k.M) * 1099511628211ULL; + h = (h ^ k.K) * 1099511628211ULL; + h = (h ^ k.N) * 1099511628211ULL; + h = (h ^ static_cast(k.M_blk)) * 1099511628211ULL; + h = (h ^ static_cast(k.N_blk)) * 1099511628211ULL; + h = (h ^ static_cast(k.K_blk)) * 1099511628211ULL; + h = (h ^ static_cast(k.batch_size)) * 1099511628211ULL; + return h; + } +}; + +// Cached JIT-compiled kernels with precomputed tile parameters and offsets. +struct BRGeMMKernelEntry { + int64_t M_blk, N_blk, K_blk; + int64_t M_tail, N_tail, K_tail; + int64_t K_chunks; + int64_t M_full_tiles, N_full_tiles; + int64_t M_total_tiles, N_total_tiles; + int64_t K_super_size, K_super_blocks; + int64_t K_super_rem; + int64_t batch_full, batch_rem; + int64_t m_sizes[2], n_sizes[2]; + int64_t lda; + int64_t ldb_orig; + bool need_pack; + size_t a_dt_size, b_dt_size; + size_t N_padded; + + // Kernels indexed by [m_tail_flag][n_tail_flag]. + dnnl::ukernel::brgemm brg_first_all[2][2]; + dnnl::ukernel::brgemm brg_full[2][2]; + dnnl::ukernel::brgemm brg_ktail[2][2]; + dnnl::ukernel::brgemm brg_first_rem[2][2]; + dnnl::ukernel::brgemm brg_rem[2][2]; + + // B-packing transforms indexed by n_tail_flag. + dnnl::ukernel::transform pack_B[2], pack_B_ktail[2]; + size_t blocked_B_size[2] = {0, 0}; + size_t blocked_B_ktail_size[2] = {0, 0}; + + size_t scratchpad_size = 0; + + using OffsetVec = + std::vector>; + OffsetVec offsets_first_all[2]; + std::vector offsets_full[2]; + OffsetVec offsets_first_rem[2]; + OffsetVec offsets_rem[2]; +}; + +// Packed-B cache key. 
+struct BRGeMMPackedBKey { + uintptr_t B_ptr; + size_t K, N; + int64_t K_blk, N_blk; + bool operator==(const BRGeMMPackedBKey& o) const { + return B_ptr == o.B_ptr && K == o.K && N == o.N && K_blk == o.K_blk && + N_blk == o.N_blk; + } +}; + +struct BRGeMMPackedBKeyHash { + size_t operator()(const BRGeMMPackedBKey& k) const { + size_t h = 14695981039346656037ULL; + h = (h ^ k.B_ptr) * 1099511628211ULL; + h = (h ^ k.K) * 1099511628211ULL; + h = (h ^ k.N) * 1099511628211ULL; + h = (h ^ static_cast(k.K_blk)) * 1099511628211ULL; + h = (h ^ static_cast(k.N_blk)) * 1099511628211ULL; + return h; + } +}; + +struct BRGeMMPackedBEntry { + HugePageBuffer B_packed_buf; + std::vector B_tile_offset; + std::vector B_ktail_offset; +}; + +// Thread-local buffers for BRGeMM parallel dispatch. +struct BRGeMMThreadBufs { + static constexpr size_t kMaxTempCSize = 64 * 64; + + std::vector scratch; + std::vector tc_storage; + bool hw_ctx_set = false; + const void* hw_ctx_kernel = nullptr; + + uint8_t* EnsureScratch(size_t size) { + if (scratch.size() < size + 64) scratch.resize(size + 64); + return scratch.data() + + (64 - (reinterpret_cast(scratch.data()) % 64)); + } + + float* EnsureTempC(size_t n_tiles) { + const size_t need = n_tiles * kMaxTempCSize * sizeof(float) + 64; + if (tc_storage.size() < need) tc_storage.resize(need); + return reinterpret_cast( + (reinterpret_cast(tc_storage.data()) + 63) & + ~uintptr_t{63}); + } + + void MaybeSetHwContext(const dnnl::ukernel::brgemm& brg) { + const void* brg_ptr = &brg; + if (!hw_ctx_set || hw_ctx_kernel != brg_ptr) { + brg.set_hw_context(); + hw_ctx_set = true; + hw_ctx_kernel = brg_ptr; + } + } +}; + +inline BRGeMMThreadBufs& GetBRGeMMThreadBufs() { + static thread_local BRGeMMThreadBufs bufs; + return bufs; +} + +// Singleton caches. Thread-safety: MatMul is not called concurrently per env. 
+inline auto& GetBRGeMMKernelCache() { + static std::unordered_map + cache; + return cache; +} + +inline auto& GetBRGeMMPackedBCache() { + static std::unordered_map + cache; + return cache; +} + +#endif // GEMMA_ONEDNN_BRGEMM + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_OPS_BRGEMM_H_ diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 4b217a15..fda9d821 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -47,6 +47,10 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; +#if GEMMA_ONEDNN_BRGEMM +#include "ops/brgemm-inl.h" // DoMatMul_BRGeMM +#endif // GEMMA_ONEDNN_BRGEMM + // Like hn::PromoteOddTo, but uses assembly to avoid an extra vector register. template > static hn::VFromD FastPromoteOddTo(DF df, hn::VFromD vbf) { @@ -1077,6 +1081,46 @@ HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, MMPerKey& per_key = MMImpl::FindOrAddPerKey( M, K, N, num_B, cache.VectorBytes(), env.per_cluster[cluster_idx]); +#if GEMMA_ONEDNN_BRGEMM + // BRGeMM path for BF16×BF16 on Intel AMX/AVX-512. + // Requires M,N,K >= 32 and K % 32 == 0 (AMX tile constraint). + if constexpr (IsBF16() && IsBF16()) { + if (M >= 32 && N >= 32 && K >= 32 && (K % 32) == 0) { + const float scale = A.Scale() * B.Scale(); + MMAutoTune& brg_tuner = per_key.brgemm_autotune; + + if (HWY_LIKELY(brg_tuner.Best())) { + DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, *brg_tuner.Best(), + env.ctx, cluster_idx); + return &per_key; + } + + if (HWY_UNLIKELY(!brg_tuner.HasCandidates())) { + brg_tuner.SetCandidates(BRGeMMCandidates(M, K, N)); + } + + const BRGeMMConfig& cfg = brg_tuner.NextConfig(); + const uint64_t t0 = hwy::timer::Start(); + DoMatMul_BRGeMM(A, B, C_rows, M, K, N, scale, add, cfg, env.ctx, + cluster_idx); + const uint64_t t1 = + env.have_timer_stop ? 
hwy::timer::Stop() : hwy::timer::Start(); + brg_tuner.NotifyTicks(t1 - t0); + + if (HWY_UNLIKELY(env.print_best && brg_tuner.Best())) { + const BRGeMMConfig& best = *brg_tuner.Best(); + fprintf(stderr, + "BRGeMM best: %zux%zux%zu M_blk=%ld N_blk=%ld K_blk=%ld " + "batch=%ld\n", + M, K, N, static_cast(best.M_blk), + static_cast(best.N_blk), static_cast(best.K_blk), + static_cast(best.batch_size)); + } + return &per_key; + } + } // if constexpr BF16/float +#endif // GEMMA_ONEDNN_BRGEMM + // (Also auto-tunes, hence outside the timed section to prevent interference.) const StridedViewBF A_view = MMDecompress::MaybeDecompressA(A, per_key.autotune_par_a, env, options); diff --git a/ops/matmul.h b/ops/matmul.h index 0f3d2866..4724bad9 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -32,6 +32,7 @@ #include "hwy/base.h" #include "hwy/bit_set.h" #include "hwy/profiler.h" +#include "ops/brgemm.h" // BRGeMMConfig, GEMMA_ONEDNN_BRGEMM // IWYU pragma: end_exports namespace gcpp { @@ -639,6 +640,9 @@ class MMKeys { struct MMPerKey { MMAutoTune autotune; MMAutoTune autotune_par_a; +#if GEMMA_ONEDNN_BRGEMM + MMAutoTune brgemm_autotune; +#endif // GEMMA_ONEDNN_BRGEMM }; // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive diff --git a/util/zones.cc b/util/zones.cc index aec4bbd0..b552bb17 100644 --- a/util/zones.cc +++ b/util/zones.cc @@ -135,6 +135,8 @@ const char* CallerName(Callers caller) { return "Att.DotSoftmaxWeightedSum"; case Callers::kBlobWriter: return "BlobWriter"; + case Callers::kBRGeMM: + return "BRGeMM"; case Callers::kCompress: return "Compress"; case Callers::kFixupWeights: diff --git a/util/zones.h b/util/zones.h index 64b859d2..ba3d5a9b 100644 --- a/util/zones.h +++ b/util/zones.h @@ -81,6 +81,7 @@ enum class Callers { // Keep sorted kAttComputeQKV, kAttDotSoftmaxWeightedSum, kBlobWriter, + kBRGeMM, kCompress, kFixupWeights, kFlashAttention,