Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ exports_files([
".github/workflows/build.yml",
])

# To enable OneDNN BRGeMM support, build with:
# bazel build --define gemma_onednn_brgemm=1 ...
config_setting(
name = "gemma_onednn_brgemm",
define_values = {"gemma_onednn_brgemm": "1"},
)

cc_library(
name = "basics",
srcs = ["util/basics.cc"],
Expand Down Expand Up @@ -313,7 +320,14 @@ test_suite(
cc_library(
name = "matmul_env",
srcs = ["ops/matmul.cc"],
hdrs = ["ops/matmul.h"],
hdrs = [
"ops/brgemm.h",
"ops/matmul.h",
],
defines = select({
":gemma_onednn_brgemm": ["GEMMA_ONEDNN_BRGEMM=1", "DNNL_EXPERIMENTAL_UKERNEL"],
"//conditions:default": [],
}),
deps = [
":allocator",
":basics",
Expand All @@ -324,14 +338,20 @@ cc_library(
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
],
] + select({
":gemma_onednn_brgemm": ["@onednn//:onednn"],
"//conditions:default": [],
}),
)

cc_library(
name = "matmul",
# allow depending only on this target, without also matmul_env.
hdrs = ["ops/matmul.h"],
textual_hdrs = ["ops/matmul-inl.h"],
textual_hdrs = [
"ops/brgemm-inl.h",
"ops/matmul-inl.h",
],
deps = [
":allocator",
":basics",
Expand All @@ -345,7 +365,10 @@ cc_library(
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
],
] + select({
":gemma_onednn_brgemm": ["@onednn//:onednn"],
"//conditions:default": [],
}),
)

cc_library(
Expand All @@ -362,6 +385,7 @@ cc_library(
"ops/matmul_static.h",
],
textual_hdrs = [
"ops/brgemm-inl.h",
"ops/matmul_static-inl.h",
"ops/matmul-inl.h",
],
Expand All @@ -378,7 +402,10 @@ cc_library(
"@highway//:hwy",
"@highway//:profiler",
"@highway//:timer",
],
] + select({
":gemma_onednn_brgemm": ["@onednn//:onednn"],
"//conditions:default": [],
}),
)

cc_library(
Expand Down
31 changes: 31 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Optional: OneDNN BRGeMM micro-kernel support (x86-64 only).
# Enable with: cmake -DGEMMA_ONEDNN_BRGEMM=ON ...
option(GEMMA_ONEDNN_BRGEMM "Enable OneDNN BRGeMM micro-kernel for MatMul (x86-64)" OFF)

if(EMSCRIPTEN)
add_compile_options("-sMEMORY64")
add_compile_options("-msimd128")
Expand Down Expand Up @@ -85,6 +89,23 @@ if(EMSCRIPTEN)
target_compile_options(benchmark PRIVATE -Wno-c2y-extensions)
endif()

# OneDNN BRGeMM micro-kernel support (optional, x86-64 only).
if(GEMMA_ONEDNN_BRGEMM)
set(DNNL_BUILD_TESTS OFF CACHE BOOL "" FORCE)
set(DNNL_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(DNNL_CPU_RUNTIME "SEQ" CACHE STRING "" FORCE)
set(DNNL_GPU_RUNTIME "NONE" CACHE STRING "" FORCE)
set(DNNL_LIBRARY_TYPE "STATIC" CACHE STRING "" FORCE)
set(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE)
FetchContent_Declare(onednn
GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN.git
GIT_TAG v3.11
EXCLUDE_FROM_ALL
)
FetchContent_MakeAvailable(onednn)
message(STATUS "OneDNN BRGeMM micro-kernel support enabled")
endif()

# Base source files
set(SOURCES
compression/compress-inl.h
Expand Down Expand Up @@ -141,6 +162,8 @@ set(SOURCES
ops/matmul-inl.h
ops/matmul.cc
ops/matmul.h
ops/brgemm.h
ops/brgemm-inl.h
ops/ops-inl.h
ops/ops.h
ops/sum-inl.h
Expand Down Expand Up @@ -191,6 +214,10 @@ target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static)
target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR})
target_compile_definitions(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
target_compile_options(libgemma PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
if(GEMMA_ONEDNN_BRGEMM)
target_compile_definitions(libgemma PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL)
target_link_libraries(libgemma dnnl)
endif()
install(TARGETS libgemma DESTINATION lib)

# Shared library target for C# interop
Expand All @@ -215,6 +242,10 @@ target_compile_definitions(gemma_shared
$<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>
)
target_compile_options(gemma_shared PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
if(GEMMA_ONEDNN_BRGEMM)
target_compile_definitions(gemma_shared PUBLIC GEMMA_ONEDNN_BRGEMM=1 DNNL_EXPERIMENTAL_UKERNEL)
target_link_libraries(gemma_shared PRIVATE dnnl)
endif()
install(TARGETS gemma_shared DESTINATION lib)
install(FILES gemma/c_api.h DESTINATION include/gemma)
install(FILES gemma/GemmaInterop.cs DESTINATION include/gemma)
Expand Down
11 changes: 11 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ git_override(

http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# OneDNN v3.11 for BRGeMM micro-kernel support (optional, x86-64 only).
http_archive(
name = "onednn",
build_file = "@//bazel:onednn.BUILD",
sha256 = "04df98b18300daf6c3aa7cc2d5e7ce8a8f430fed1787151daed0254d8dd4e64e",
strip_prefix = "oneDNN-3.11",
urls = [
"https://github.com/uxlfoundation/oneDNN/archive/refs/tags/v3.11.tar.gz",
],
)

http_archive(
name = "com_google_absl_py",
sha256 = "8a3d0830e4eb4f66c4fa907c06edf6ce1c719ced811a12e26d9d3162f8471758",
Expand Down
227 changes: 227 additions & 0 deletions bazel/onednn.BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")

exports_files(["LICENSE"])

expand_template(
name = "dnnl_config_h",
out = "include/oneapi/dnnl/dnnl_config.h",
substitutions = {
"#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "#define DNNL_EXPERIMENTAL_UKERNEL 1",
"#cmakedefine DNNL_SAFE_RBP": "#undef DNNL_SAFE_RBP",
"#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_SEQ",
"#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_SEQ",
"#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#define DNNL_DISABLE_GPU_REF_KERNELS",
"#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE",
"#cmakedefine DNNL_GPU_VENDOR DNNL_VENDOR_${DNNL_GPU_VENDOR}": "#define DNNL_GPU_VENDOR DNNL_VENDOR_NONE",
"#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "#undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE",
"#cmakedefine DNNL_WITH_SYCL": "#undef DNNL_WITH_SYCL",
"#cmakedefine DNNL_WITH_LEVEL_ZERO": "#undef DNNL_WITH_LEVEL_ZERO",
"#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA",
"#cmakedefine DNNL_SYCL_GENERIC": "#undef DNNL_SYCL_GENERIC",
"#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP",
"#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER",
"#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH",
"#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE",
"#cmakedefine DNNL_EXPERIMENTAL_LOGGING": "#undef DNNL_EXPERIMENTAL_LOGGING",
"#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING",
"#cmakedefine DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER": "#undef DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER",
"#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL",
"#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1",
"#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0",
"#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1",
"#cmakedefine01 BUILD_BATCH_NORMALIZATION": "#define BUILD_BATCH_NORMALIZATION 0",
"#cmakedefine01 BUILD_BINARY": "#define BUILD_BINARY 0",
"#cmakedefine01 BUILD_CONCAT": "#define BUILD_CONCAT 0",
"#cmakedefine01 BUILD_CONVOLUTION": "#define BUILD_CONVOLUTION 0",
"#cmakedefine01 BUILD_DECONVOLUTION": "#define BUILD_DECONVOLUTION 0",
"#cmakedefine01 BUILD_ELTWISE": "#define BUILD_ELTWISE 0",
"#cmakedefine01 BUILD_GEMM_KERNELS_ALL": "#define BUILD_GEMM_KERNELS_ALL 1",
"#cmakedefine01 BUILD_GEMM_KERNELS_NONE": "#define BUILD_GEMM_KERNELS_NONE 0",
"#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 1",
"#cmakedefine01 BUILD_GEMM_AVX2": "#define BUILD_GEMM_AVX2 1",
"#cmakedefine01 BUILD_GEMM_AVX512": "#define BUILD_GEMM_AVX512 1",
"#cmakedefine01 BUILD_GROUP_NORMALIZATION": "#define BUILD_GROUP_NORMALIZATION 1",
"#cmakedefine01 BUILD_INNER_PRODUCT": "#define BUILD_INNER_PRODUCT 0",
"#cmakedefine01 BUILD_LAYER_NORMALIZATION": "#define BUILD_LAYER_NORMALIZATION 0",
"#cmakedefine01 BUILD_LRN": "#define BUILD_LRN 0",
"#cmakedefine01 BUILD_MATMUL": "#define BUILD_MATMUL 0",
"#cmakedefine01 BUILD_POOLING": "#define BUILD_POOLING 0",
"#cmakedefine01 BUILD_PRELU": "#define BUILD_PRELU 0",
"#cmakedefine01 BUILD_REDUCTION": "#define BUILD_REDUCTION 0",
"#cmakedefine01 BUILD_REORDER": "#define BUILD_REORDER 0",
"#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0",
"#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0",
"#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0",
"#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0",
"#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0",
"#cmakedefine01 BUILD_PRIMITIVE_CPU_ISA_ALL": "#define BUILD_PRIMITIVE_CPU_ISA_ALL 1",
"#cmakedefine01 BUILD_SSE41": "#define BUILD_SSE41 0",
"#cmakedefine01 BUILD_AVX2": "#define BUILD_AVX2 0",
"#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0",
"#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0",
"#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 0",
"#cmakedefine01 BUILD_XE2": "#define BUILD_XE2 0",
"#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0",
"#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0",
"#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0",
"#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0",
"#cmakedefine01 BUILD_SDPA": "#define BUILD_SDPA 1",
"#cmakedefine01 BUILD_XE3": "#define BUILD_XE3 0",
},
template = "include/oneapi/dnnl/dnnl_config.h.in",
)

expand_template(
name = "dnnl_version_h",
out = "include/oneapi/dnnl/dnnl_version.h",
substitutions = {
"@DNNL_VERSION_MAJOR@": "3",
"@DNNL_VERSION_MINOR@": "11",
"@DNNL_VERSION_PATCH@": "0",
},
template = "include/oneapi/dnnl/dnnl_version.h.in",
)

expand_template(
name = "dnnl_version_hash_h",
out = "include/oneapi/dnnl/dnnl_version_hash.h",
substitutions = {
"@DNNL_VERSION_HASH@": "fc6151651a4577beae5ffac5a4132e75d39e1409",
},
template = "include/oneapi/dnnl/dnnl_version_hash.h.in",
)

cc_library(
name = "onednn_autogen",
srcs = glob(["src/cpu/x64/gemm/**/*_kern_autogen*.cpp"]),
copts = [
"-O1",
"-U_FORTIFY_SOURCE",
"-fexceptions",
"-UUSE_MKL",
"-UUSE_CBLAS",
"-DDNNL_ENABLE_MAX_CPU_ISA",
"-DDNNL_ENABLE_ITT_TASKS",
"-DDNNL_ENABLE_GRAPH_DUMP",
"-DDNNL_EXPERIMENTAL_UKERNEL",
],
includes = [
"include",
"src",
"src/common",
"src/cpu",
"src/cpu/gemm",
"src/graph",
"third_party",
"third_party/ittnotify",
"third_party/xbyak",
],
textual_hdrs = glob([
"include/**/*",
"src/common/*.hpp",
"src/cpu/*.hpp",
"src/cpu/**/*.hpp",
"src/cpu/jit_utils/**/*.hpp",
"src/graph/interface/*.hpp",
"src/graph/backend/*.hpp",
"src/graph/backend/dnnl/*.hpp",
"src/graph/backend/dnnl/executables/*.hpp",
"src/graph/backend/fake/*.hpp",
"src/graph/backend/dnnl/passes/*.hpp",
"src/graph/backend/dnnl/patterns/*.hpp",
"src/graph/backend/dnnl/kernels/*.hpp",
"src/graph/utils/*.hpp",
"src/graph/utils/pm/*.hpp",
"third_party/ittnotify/**/*.h",
"third_party/spdlog/**/*.h",
"third_party/xbyak/*.h",
]) + [
":dnnl_config_h",
":dnnl_version_h",
":dnnl_version_hash_h",
],
visibility = ["//visibility:public"],
)

cc_library(
name = "onednn",
srcs = glob(
[
"src/common/*.cpp",
"src/cpu/*.cpp",
"src/cpu/**/*.cpp",
"src/cpu/jit_utils/**/*.cpp",
"src/cpu/x64/**/*.cpp",
"src/graph/interface/*.cpp",
"src/graph/backend/*.cpp",
"src/graph/backend/dnnl/*.cpp",
"src/graph/backend/dnnl/executables/*.cpp",
"src/graph/backend/fake/*.cpp",
"src/graph/backend/dnnl/passes/*.cpp",
"src/graph/backend/dnnl/patterns/*.cpp",
"src/graph/backend/dnnl/kernels/*.cpp",
"src/graph/utils/*.cpp",
"src/graph/utils/pm/*.cpp",
"third_party/ittnotify/*.c",
],
exclude = [
"src/cpu/aarch64/**",
"src/cpu/rv64/**",
"src/cpu/ppc64/**",
"src/cpu/s390x/**",
"src/cpu/x64/gemm/**/*_kern_autogen.cpp",
"src/cpu/sycl/**",
],
),
copts = [
"-fexceptions",
"-UUSE_MKL",
"-UUSE_CBLAS",
"-DDNNL_ENABLE_MAX_CPU_ISA",
"-DDNNL_ENABLE_ITT_TASKS",
"-DDNNL_ENABLE_GRAPH_DUMP",
"-DDNNL_EXPERIMENTAL_UKERNEL",
],
includes = [
"include",
"src",
"src/common",
"src/cpu",
"src/cpu/gemm",
"src/graph",
"third_party",
"third_party/ittnotify",
"third_party/xbyak",
],
linkopts = [
"-lrt",
"-Wl,--allow-multiple-definition",
],
textual_hdrs = glob([
"include/**/*",
"src/common/*.hpp",
"src/cpu/*.hpp",
"src/cpu/**/*.hpp",
"src/cpu/jit_utils/**/*.hpp",
"src/graph/interface/*.hpp",
"src/graph/backend/*.hpp",
"src/graph/backend/dnnl/*.hpp",
"src/graph/backend/fake/*.hpp",
"src/graph/backend/dnnl/passes/*.hpp",
"src/graph/backend/dnnl/patterns/*.hpp",
"src/graph/backend/dnnl/kernels/*.hpp",
"src/graph/utils/*.hpp",
"src/graph/utils/pm/*.hpp",
"third_party/ittnotify/**/*.h",
"third_party/spdlog/**/*.h",
"third_party/xbyak/*.h",
]) + [
":dnnl_config_h",
":dnnl_version_h",
":dnnl_version_hash_h",
],
visibility = ["//visibility:public"],
deps = [
":onednn_autogen",
],
)
Loading
Loading