From 0dff925df5c867032f53c1c6bd8780be44da93f7 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Tue, 31 Mar 2026 14:18:45 +0800 Subject: [PATCH 01/10] issue/1083: gptq_marlin_gemm --- include/infiniop/ops/gptq_marlin_gemm.h | 42 + .../ops/gptq_marlin_gemm/gptq_marlin_gemm.h | 66 + src/infiniop/ops/gptq_marlin_gemm/info.h | 59 + .../marlin/awq_marlin_repack.cuh | 281 +++ .../ops/gptq_marlin_gemm/marlin/dequant.h | 504 +++++ .../gptq_marlin_gemm/marlin/gptq_marlin.cuh | 1085 ++++++++++ .../marlin/gptq_marlin_repack.cuh | 398 ++++ .../ops/gptq_marlin_gemm/marlin/kernel.h | 34 + .../ops/gptq_marlin_gemm/marlin/marlin.cuh | 92 + .../gptq_marlin_gemm/marlin/marlin_dtypes.cuh | 78 + .../gptq_marlin_gemm/marlin/marlin_template.h | 1917 +++++++++++++++++ .../nvidia/gptq_marlin_gemm_nvidia.cu | 1141 ++++++++++ .../nvidia/gptq_marlin_gemm_nvidia.cuh | 8 + src/infiniop/ops/gptq_marlin_gemm/operator.cc | 120 ++ .../sgl_kernel/scalar_type.hpp | 335 +++ .../sgl_kernel/source_location.h | 41 + .../ops/gptq_marlin_gemm/sgl_kernel/tensor.h | 621 ++++++ .../ops/gptq_marlin_gemm/sgl_kernel/utils.cuh | 310 +++ .../ops/gptq_marlin_gemm/sgl_kernel/utils.h | 241 +++ test/infiniop/gptq_marlin_gemm.py | 623 ++++++ xmake/nvidia.lua | 20 + 21 files changed, 8016 insertions(+) create mode 100644 include/infiniop/ops/gptq_marlin_gemm.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/gptq_marlin_gemm.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/info.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_dtypes.cuh create mode 
100644 src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu create mode 100644 src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/operator.cc create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h create mode 100644 test/infiniop/gptq_marlin_gemm.py diff --git a/include/infiniop/ops/gptq_marlin_gemm.h b/include/infiniop/ops/gptq_marlin_gemm.h new file mode 100644 index 000000000..37e22baec --- /dev/null +++ b/include/infiniop/ops/gptq_marlin_gemm.h @@ -0,0 +1,42 @@ +#ifndef __INFINIOP_GPTQ_MARLIN_GEMM_API_H__ +#define __INFINIOP_GPTQ_MARLIN_GEMM_API_H__ + +#include "../operator_descriptor.h" +#include + +typedef struct InfiniopDescriptor *infiniopGptqMarlinGemmDescriptor_t; + +__INFINI_C __export infiniStatus_t infiniopCreateGptqMarlinGemmDescriptor(infiniopHandle_t handle, + infiniopGptqMarlinGemmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t global_scale_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t g_idx_desc, + infiniopTensorDescriptor_t perm_desc); + +__INFINI_C __export infiniStatus_t infiniopGetGptqMarlinGemmWorkspaceSize(infiniopGptqMarlinGemmDescriptor_t desc, size_t *size); + +__INFINI_C __export infiniStatus_t infiniopGptqMarlinGemm(infiniopGptqMarlinGemmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *b_scales, + 
void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + void *stream); + +__INFINI_C __export infiniStatus_t infiniopDestroyGptqMarlinGemmDescriptor(infiniopGptqMarlinGemmDescriptor_t desc); + +#endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/gptq_marlin_gemm.h b/src/infiniop/ops/gptq_marlin_gemm/gptq_marlin_gemm.h new file mode 100644 index 000000000..4b02f5e1b --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/gptq_marlin_gemm.h @@ -0,0 +1,66 @@ +#ifndef __GPTQ_MARLIN_GEMM_H__ +#define __GPTQ_MARLIN_GEMM_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::gptq_marlin_gemm::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + GptqMarlinGemmInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + size_t workspace_size_, \ + Opaque *opaque, \ + GptqMarlinGemmInfo info, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t a_desc, \ + infiniopTensorDescriptor_t b_desc, \ + infiniopTensorDescriptor_t b_scales_desc, \ + infiniopTensorDescriptor_t global_scale_desc, \ + infiniopTensorDescriptor_t b_zeros_desc, \ + infiniopTensorDescriptor_t g_idx_desc, \ + infiniopTensorDescriptor_t perm_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *out, \ + const void *a, \ + const void *b, \ + void *b_scales, \ + void *global_scale, \ + void 
*b_zeros, \ + void *g_idx, \ + void *perm, \ + int64_t b_q_type_id, \ + bool is_k_full, \ + bool use_atomic_add, \ + bool use_fp32_reduce, \ + bool is_zp_float, \ + void *stream) const; \ + }; \ + } + +#endif //__GPTQ_MARLIN_GEMM_H__ diff --git a/src/infiniop/ops/gptq_marlin_gemm/info.h b/src/infiniop/ops/gptq_marlin_gemm/info.h new file mode 100644 index 000000000..422a53e3e --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/info.h @@ -0,0 +1,59 @@ +#ifndef __GPTQ_MARLIN_GEMM_INFO_H__ +#define __GPTQ_MARLIN_GEMM_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include + +#include + +namespace op::gptq_marlin_gemm { + +class GptqMarlinGemmInfo { + GptqMarlinGemmInfo() = default; + +public: + infiniDtype_t dtype; + size_t M, K, N, b_q_size_1; + int num_groups; + ptrdiff_t a_stride_0; + + static utils::Result create( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t global_scale_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t g_idx_desc, + infiniopTensorDescriptor_t perm_desc) { + CHECK_OR_RETURN( + out_desc != nullptr && a_desc != nullptr && b_desc != nullptr && b_scales_desc != nullptr, + INFINI_STATUS_NULL_POINTER); + const infiniDtype_t dtype = a_desc->dtype(); + size_t M = out_desc->dim(0); + size_t N = out_desc->dim(1); + size_t K = a_desc->dim(1); + size_t b_q_size_1 = b_desc->dim(1); + int num_groups = static_cast(b_scales_desc->dim(0)); + ptrdiff_t a_stride_0 = a_desc->strides()[0]; + + auto ndim = out_desc->ndim(); + CHECK_OR_RETURN(ndim == 2 + && a_desc->ndim() == ndim + && b_desc->ndim() == ndim + && b_scales_desc->ndim() == ndim, + INFINI_STATUS_BAD_TENSOR_SHAPE); + + CHECK_OR_RETURN(b_scales_desc->shape()[1] == N + && a_stride_0 % 8 == 0, + INFINI_STATUS_BAD_TENSOR_SHAPE); + + return utils::Result( + GptqMarlinGemmInfo{dtype, M, K, N, b_q_size_1, num_groups, 
a_stride_0}); + } +}; + +} // namespace op::gptq_marlin_gemm + +#endif // __GPTQ_MARLIN_GEMM_INFO_H__ diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh new file mode 100644 index 000000000..2aea26529 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh @@ -0,0 +1,281 @@ +#pragma once + +#include "../sgl_kernel/tensor.h" + +#include "../sgl_kernel/utils.cuh" + +#include "marlin.cuh" + +namespace device::marlin +{ + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + template + __global__ void awq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) + { + return; + } +#else + + template + __global__ void awq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) + { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, (int)gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) + { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() + { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int tile_n_ints = tile_n_size / pack_factor; + + constexpr int stage_n_threads = tile_n_ints / 4; + constexpr int stage_k_threads = tile_k_size; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) + { + if (n_tile_id >= n_tiles) + { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + int first_n_packed = first_n / pack_factor; + + int4 *sh_ptr = sh + stage_size * pipe; + + if (threadIdx.x < stage_size) + { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) + { + if (n_tile_id >= n_tiles) + { + return; + } + + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) + { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; + + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4 *sh_stage_ptr = sh + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) + { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + else + { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + + uint32_t vals[8]; +#pragma unroll + for (int i = 0; i < 4; i++) + { + int cur_elem = tc_row 
+ tc_offsets[i]; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } + + constexpr int tile_size_val = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size_val; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) + { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) + { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + } + else + { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) + { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) + { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) + { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) + { + int n_tile_id = 0; + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) + { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) + { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; 
+ } + } + } +#endif + +} // namespace device::marlin + +// Host wrapper +void awq_marlin_repack( + tvm::ffi::TensorView out, tvm::ffi::TensorView b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) +{ + using namespace host; + using namespace device::marlin; + + // Validate alignment + RuntimeCheck(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); + RuntimeCheck(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + + int const pack_factor = 32 / num_bits; + + // Validate tensors + SymbolicDevice cuda_device; + cuda_device.set_options(); + + TensorMatcher({size_k, size_n / pack_factor}).with_dtype().with_device(cuda_device).verify(b_q_weight); + + TensorMatcher({size_k / tile_size, size_n * tile_size / pack_factor}) + .with_dtype() + .with_device(cuda_device) + .verify(out); + + // Get device and stream + auto device = cuda_device.unwrap(); + auto stream = LaunchKernel::resolve_device(device); + + // Get pointers + auto *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + auto *out_ptr = reinterpret_cast(out.data_ptr()); + + // Get device attributes + int blocks = 0; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device.device_id); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device.device_id); + RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + // Dispatch based on num_bits + if (num_bits == 4) + { + cudaFuncSetAttribute( + awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } + else if (num_bits == 8) + { + cudaFuncSetAttribute( + 
awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } + else + { + RuntimeCheck(false, "Unsupported repack config: num_bits = ", num_bits); + } +} diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h new file mode 100644 index 000000000..764375f62 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h @@ -0,0 +1,504 @@ +/* +Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16) + +The process of fast dequantization can be summarized as a combination +of bitwise operations and floating-point computations: + +weight =>(bit_op / bitwise operations)=> +f16_value =>(flop / floating-point computation)=> +dequantized_weight + +Since the dequantized weights typically require subtracting the zero point and +applying a scale factor, the floating-point computation step can be fused with +the zero-point subtraction and scaling operations. + +The following are the parts that need to be modified for the fused operation +of zero-point subtraction and scaling. + +## INT4 => FP16/BF16 or INT8 => FP16 + +The floating-point computation is `__hsub2` + +If has zero points: + + flop(bit_op(weight)) - flop(bit_op(zp)) + = sub(bit_op(weight), bias) - sub(bit_op(zp), bias) + = bit_op(weight) - bit_op(zp) + +so we don't need additional modification. + +If has float zero points: + + flop(bit_op(weight)) - fzp + = sub(bit_op(weight), bias) - fzp + = bit_op(weight) - (fzp + bias) + +where the `fzp + bias` can be computed at weight loading. But this +may have accuracy issue, so we should not use this in most cases. 
+ +If has not zero points: + + scale(flop(bit_op(weight))) + = scale(sub(bit_op(weight), bias)) + = scale(bit_op(weight)) - scale(bias) + = fma(bit_op(weight), scale_factor, scale(bias)) + +where the `scale(bias)` can be cached. But this may have accuracy issue, +so we should not use this in most cases. + + +## INT8 => BF16 + +INT8 => BF16 is a special case, it use byte_perm instead of flop. +We cannot fused byte_perm with scaling. + + +## FP4/FP8 => FP16/BF16 + + scale(flop(bit_op(weight))) + = scale(mul(bit_op(weight), multiplier)) + = mul(bit_op(weight), scale_factor * multiplier) + +where `scale_factor * multiplier` can be computed at weight loading. + +*/ + +#include "marlin_dtypes.cuh" + +namespace device::marlin { + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. 
We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. 
+ // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + // clang-format on + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43084308; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43004300; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// 
https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline void dequant(int q, half2* frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = 
reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + + // Construct and 
apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; 
+ int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + 
frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +// New version with s_type_id parameter for marlin_moe_wna16_v2 +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is 
intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16, + // but we assume that such a extreme value would not occur in real models. + int Out1 = (q & 0xFF00FF00) >> 1; + q <<= 7; + int Out2 = q & 0x7F807F80; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +#endif + +} // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh new file mode 100644 index 000000000..653501357 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh @@ -0,0 +1,1085 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#pragma once + +#include "../sgl_kernel/tensor.h" + +#include "../sgl_kernel/scalar_type.hpp" + +#include "kernel.h" +#include "marlin_template.h" + +namespace device::marlin +{ + + __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + + using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + + __global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) {} + +#else + + // For a given "a" of size [M,K] performs a permutation of the K columns based + // on the given "perm" indices. + __global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) + { + auto start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) + { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int input_row_stride = lda * sizeof(half) / 16; + int output_row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) + { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int input_offset = row * input_row_stride; + int output_offset = row * output_row_stride; + + half const *a_row_half = reinterpret_cast(a_int4_ptr + input_offset); + half *out_half = reinterpret_cast(out_int4_ptr + output_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) + { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) + { + if (threadIdx.x < rest) + { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = 
a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) + { + int cur_row = start_row + i; + if (cur_row < size_m) + { + permute_row(cur_row); + } + } + } + + typedef struct + { + int thread_k; + int thread_n; + int num_threads; + } thread_config_t; + + thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}}; + + thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; + + typedef struct + { + int blocks_per_sm; + thread_config_t tb_cfg; + } exec_config_t; + + int get_scales_cache_size( + thread_config_t const &th_config, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full) + { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) + { + tb_groups = 1; + } + else if (group_size == 0) + { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } + else + { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) + { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } + else + { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } + } + + int get_kernel_cache_size( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float) + { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = 
thread_m_blocks * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8); + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) + { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; + } + + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; + + return total_size; + } + + bool is_valid_config( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float, + int max_shared_mem) + { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) + { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) + { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) + { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) + { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + return cache_size <= max_shared_mem; + } + +#define _GET_IF( \ + W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == 
THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) \ + { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } + +// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) +// this is the most common cases +// BIGGROUP: cases for big group size (group_blocks in [-1, 8]) +// FZP: cases for float-zero-point (is_zp_float = true) +// ACT: cases for act order case (group_blocks == 0) +// FP4: cases for nvfp4(e2m1) (group_blocks == 1) +#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + 
_GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, 
false, 1, NUM_THREADS, false) + +#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF(W_TYPE) \ + 
ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) + + template + MarlinFuncPtr get_marlin_kernel( + const host::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) + { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) + { + } + + COMMON_GET_IF(host::kU4) + COMMON_GET_IF(host::kU4B8) + COMMON_GET_IF(host::kU8B128) + + FP4_GET_IF(host::kFE2M1f) + + BIGGROUP_GET_IF(host::kFE4M3fn) + + ACT_GET_IF(host::kU4B8) + ACT_GET_IF(host::kU8B128) + + if (std::is_same::value) + { + if (false) + { + } + FZP_GET_IF(host::kU4) + } + + return kernel; + } + + template + exec_config_t determine_exec_config( + const host::ScalarType &q_type, + int prob_m, + int prob_n, + int prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem, + int sms) + { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t *thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; + int thread_configs_size = thread_m_blocks > 1 ? 
sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + for (int i = 0; i < thread_configs_size; i++) + { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) + { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) + { + group_blocks = group_size == -1 ? -1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) + continue; + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; + } + + return exec_cfg; + } + + template + void marlin_mm( + const void *A, + const void *B, + void *C, + void *C_tmp, + void *s, + void *s2, + void *zp, + void *g_idx, + void *perm, + void *a_tmp, + int prob_m, + int prob_n, + int prob_k, + int lda, + void *workspace, + host::ScalarType const &q_type, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k_init, + int thread_n_init, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) + { + if (has_zp) + { + host::RuntimeCheck( + q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. 
Got = ", q_type.str()); + } + else + { + host::RuntimeCheck( + q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); + } + + host::RuntimeCheck( + prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) + { + if (is_k_full) + { + host::RuntimeCheck(group_size != -1); + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + else + { + host::RuntimeCheck(group_size == 0); + group_blocks = 0; + } + } + else + { + if (group_size == -1) + { + group_blocks = -1; + } + else + { + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + int4 *C_ptr = (int4 *)C; + int4 *C_tmp_ptr = (int4 *)C_tmp; + const int4 *s_ptr = (const int4 *)s; + const uint16_t *s2_ptr = (const uint16_t *)s2; + const int4 *zp_ptr = (const int4 *)zp; + const int *g_idx_ptr = (const int *)g_idx; + const int *perm_ptr = (const int *)perm; + int4 *a_tmp_ptr = (int4 *)a_tmp; + + int *locks = (int *)workspace; + + if (has_act_order) + { + // Permute A columns + int block_rows = div_ceil(prob_m, sms); + host::LaunchKernel(sms, default_threads, stream)( + permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + A_ptr = a_tmp_ptr; + lda = prob_k; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) + has_act_order = false; 
+ } + + int max_shared_mem = 0; + host::RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + host::RuntimeCheck(max_shared_mem > 0); + + int max_par = 16; + if (prob_n <= 4096) + max_par = 16 * 8; + int max_shared_mem_new = max_shared_mem; + int rest_m = prob_m; + int max_thread_m_blocks = 4; + while (rest_m) + { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) + par_count = max_par; + int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) + { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } + else + { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m_split, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem, + sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) + { + max_thread_m_blocks--; + continue; + } + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) + max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + host::RuntimeCheck( + 
is_valid_config( + thread_tfg, + thread_m_blocks, + prob_m_split, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", + thread_m_blocks, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", + num_bits, + ", prob_m_split = ", + prob_m_split, + ", group_size = ", + group_size, + ", has_act_order = ", + has_act_order, + ", is_k_full = ", + is_k_full, + ", has_zp = ", + has_zp, + ", is_zp_float = ", + is_zp_float, + ", max_shared_mem_new = ", + max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + thread_n_blocks, + thread_k_blocks, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + num_threads, + is_zp_float); + + if (kernel == MarlinDefault) + { + host::Panic( + "Unsupported shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", prob_m_split = ", + prob_m_split, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_threads = ", + num_threads, + ", num_bits = ", + num_bits); + } + + host::RuntimeDeviceCheck( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new)); + + bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)( + kernel, + A_ptr, + B_ptr, + C_ptr, + C_tmp_ptr, + s_ptr, + s2_ptr, + zp_ptr, + g_idx_ptr, + num_groups, + prob_m_split, + prob_n, + prob_k, + lda, + locks, + part_use_atomic_add, + use_fp32_reduce, + max_shared_mem_new); + + A_ptr += prob_m_split * (lda / 8); 
+ C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; + } + } + +#endif + +} // namespace device::marlin + +template +void gptq_marlin_gemm( + tvm::ffi::TensorView a, + tvm::ffi::TensorView b_q_weight, + tvm::ffi::TensorView b_scales, + tvm::ffi::TensorView global_scale, + tvm::ffi::TensorView b_zeros, + tvm::ffi::TensorView g_idx, + tvm::ffi::TensorView perm, + tvm::ffi::TensorView c, + tvm::ffi::TensorView c_tmp, + tvm::ffi::TensorView a_tmp, + tvm::ffi::TensorView workspace, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) +{ + using namespace host; + + ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + // Bind symbolic sizes + auto M = SymbolicSize{"M"}; + auto K = SymbolicSize{"K"}; + auto N = SymbolicSize{"N"}; + auto device = SymbolicDevice{}; + device.set_options(); + + // Verify a: [M, K] + auto lda = SymbolicSize{"lda"}; + TensorMatcher({M, K}).with_strides({lda, 1}).with_dtype().with_device(device).verify(a); + + int64_t size_m = M.unwrap(); + int64_t size_k = K.unwrap(); + + // Verify b_q_weight: [K/tile_size, packed_N] + RuntimeCheck( + size_k % device::marlin::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size; + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(expected_bqw_dim0); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device).verify(b_q_weight); + + RuntimeCheck( + b_q_weight.size(1) % device::marlin::tile_size == 0, + "b_q_weight.size(1) = ", + b_q_weight.size(1), + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t actual_size_n = (b_q_weight.size(1) / device::marlin::tile_size) * pack_factor; + N.set_value(actual_size_n); + int64_t size_n = N.unwrap(); + + // Verify stride alignment + 
int64_t a_stride0 = a.stride(0); + RuntimeCheck(a_stride0 % 8 == 0, "a.stride(0) must be divisible by 8"); + + // Verify b_scales: [num_groups, N] + auto num_groups_sym = SymbolicSize{"num_groups"}; + TensorMatcher({num_groups_sym, N}).with_device(device).verify(b_scales); + int num_groups = static_cast(num_groups_sym.unwrap()); + + // Verify c: [M, N] + TensorMatcher({M, N}).with_dtype().with_device(device).verify(c); + + // Early return for zero-size M + if (size_m == 0) + return; + + // Determine has_act_order from g_idx/perm sizes + int64_t g_idx_size = g_idx.size(0); + int64_t perm_size = perm.size(0); + bool has_act_order = g_idx_size > 0 && perm_size > 0; + + if (has_act_order) + { + RuntimeCheck( + (g_idx_size == size_k && perm_size == size_k), + "Unexpected g_idx.size(0) = ", + g_idx_size, + " and perm.size(0) = ", + perm_size, + ", where size_k = ", + size_k); + } + + // Determine has_zp from b_zeros size + int64_t b_zeros_size = b_zeros.size(0); + bool has_zp = b_zeros_size > 0; + + if (has_zp) + { + RuntimeCheck( + b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + } + else + { + RuntimeCheck( + b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. 
Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) + { + RuntimeCheck( + std::is_same::value, "Computation type must be float16 (half) when using float zero points."); + } + + // Verify b_zeros shape + if (has_zp) + { + RuntimeCheck(b_zeros.dim() == 2, "b_zeros rank = ", b_zeros.dim(), " is not 2"); + if (is_zp_float) + { + RuntimeCheck(b_zeros.size(1) == size_n, "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n = ", size_n); + RuntimeCheck( + num_groups == b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck(num_groups != -1, "num_groups must be != -1"); + } + else + { + RuntimeCheck( + b_zeros.size(0) == num_groups, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck( + b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", + b_zeros.size(1), + " is not size_n / pack_factor = ", + size_n / pack_factor); + } + } + + // Verify global_scale + int64_t global_scale_size = global_scale.size(0); + if (global_scale_size > 0) + { + RuntimeCheck(b_q_type == kFE2M1f, "global_scale can only be used for float4_e2m1f."); + } + else + { + RuntimeCheck(!(b_q_type == kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); + } + + // Derive group_size + int group_size = -1; + if (has_act_order) + { + if (is_k_full) + { + RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } + else + { + group_size = 0; + } + } + else + { + if (num_groups > 1) + { + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } + else + { + group_size = -1; + } + } + + // Verify workspace and get device info + RuntimeCheck( + size_n % device::marlin::min_thread_n == 0, + "size_n = ", + 
size_n, + ", is not divisible by min_thread_n = ", + device::marlin::min_thread_n); + + DLDevice dl_device = device.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + + int sms = -1; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev)); + + RuntimeCheck( + workspace.size(0) >= sms, "workspace.size(0) = ", workspace.size(0), " is below min_workspace_size = ", sms); + + // Hardcoded defaults (auto config) + int thread_k_init = -1; + int thread_n_init = -1; + + // Compute c_tmp and a_tmp pointers + // c_tmp and a_tmp are pre-allocated by caller + + device::marlin::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + b_scales.data_ptr(), + global_scale.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + static_cast(size_m), + static_cast(size_n), + static_cast(size_k), + static_cast(a_stride0), + workspace.data_ptr(), + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + dev, + stream, + thread_k_init, + thread_n_init, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); +} diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh new file mode 100644 index 000000000..d0c2d5414 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh @@ -0,0 +1,398 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#pragma once + +#include "../sgl_kernel/tensor.h" + +#include "../sgl_kernel/utils.cuh" + +#include "marlin.cuh" + +namespace device::marlin +{ + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + template + __global__ void gptq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, + int size_k, + int size_n) + { + return; + } +#else + template + __global__ void gptq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, + int size_k, + int size_n) + { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) + { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() + { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4 *sh_perm_ptr = sh; + int4 *sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) + { + sh_pipe_ptr += perm_size; + } + + constexpr int tile_ints = tile_k_size / pack_factor; + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) + { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) + { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) + { + if (n_tile_id >= n_tiles) + { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) + { + if (threadIdx.x < stage_size) + { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + uint32_t const *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + } + else + { + if (threadIdx.x < stage_size) + { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) + { + if (n_tile_id >= 
n_tiles) + { + return; + } + + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) + { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[8]; + + if constexpr (has_perm) + { + for (int i = 0; i < 4; i++) + { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + } + else + { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; + +#pragma unroll + for (int i = 0; i < tile_ints; i++) + { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } + +#pragma unroll + for (int i = 0; i < 4; i++) + { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr 
(num_bits == 4) + { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) + { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + } + else + { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) + { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) + { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) + { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) + { + int n_tile_id = 0; + + if constexpr (has_perm) + { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) + { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) + { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } + } +#endif + +} // namespace device::marlin + +#define CALL_IF_REPACK(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) \ + { \ + host::RuntimeDeviceCheck(cudaFuncSetAttribute( \ + device::marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem)); \ + host::LaunchKernel(blocks, device::marlin::repack_threads, stream, static_cast(max_shared_mem))( \ + device::marlin::gptq_marlin_repack_kernel, \ + b_q_weight_ptr, \ + perm_ptr, \ + out_ptr, \ + size_k, \ + size_n); \ + } + +void gptq_marlin_repack( + 
tvm::ffi::TensorView b_q_weight, + tvm::ffi::TensorView perm, + tvm::ffi::TensorView out, + int64_t size_k, + int64_t size_n, + int64_t num_bits) +{ + using namespace host; + + // Validate num_bits + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / static_cast(num_bits); + + // Validate size alignment + RuntimeCheck( + size_k % device::marlin::tile_k_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_k_size = ", + device::marlin::tile_k_size); + RuntimeCheck( + size_n % device::marlin::tile_n_size == 0, + "size_n = ", + size_n, + " is not divisible by tile_n_size = ", + device::marlin::tile_n_size); + + // Validate b_q_weight + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(size_k / pack_factor); + bqw_dim1.set_value(size_n); + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device_).verify(b_q_weight); + + // Validate out + auto out_dim0 = SymbolicSize{"out_dim0"}; + auto out_dim1 = SymbolicSize{"out_dim1"}; + out_dim0.set_value(size_k / device::marlin::tile_size); + out_dim1.set_value(size_n * device::marlin::tile_size / pack_factor); + TensorMatcher({out_dim0, out_dim1}).with_dtype().with_device(device_).verify(out); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const *perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + DLDevice dl_device = device_.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + int blocks; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev)); + + int max_shared_mem = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, 
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + if (false) + { + } + CALL_IF_REPACK(4, false) + CALL_IF_REPACK(4, true) + CALL_IF_REPACK(8, false) + CALL_IF_REPACK(8, true) + else + { + Panic("Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); + } +} + +#undef CALL_IF_REPACK diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h new file mode 100644 index 000000000..e0e36cdd4 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h @@ -0,0 +1,34 @@ + +#include "../sgl_kernel/scalar_type.hpp" + +#include "marlin.cuh" +#include "marlin_dtypes.cuh" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, \ + const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ + bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem + +namespace device::marlin +{ + template < + typename scalar_t, // compute dtype, half or nv_bfloat16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type?
+ > + __global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh new file mode 100644 index 000000000..9e99d0f4d --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh @@ -0,0 +1,92 @@ +#pragma once + +#include "../sgl_kernel/utils.cuh" + +#include + +namespace device::marlin +{ + // Marlin params + + // 8 warps are a good choice since every SM has 4 schedulers and having more + // than 1 warp per scheduler allows some more latency hiding. At the same time, + // we want relatively few warps to have many registers per warp and small tiles. + static constexpr int default_threads = 256; + + static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + + static constexpr int min_thread_n = 64; + static constexpr int min_thread_k = 64; + static constexpr int max_thread_n = 256; + + static constexpr int tile_size = 16; + static constexpr int max_par = 16; + + // Repack params + static constexpr int repack_stages = 8; + + static constexpr int repack_threads = 256; + + static constexpr int tile_k_size = tile_size; + static constexpr int tile_n_size = tile_k_size * 4; + + // Helpers + template + struct Vec + { + T elems[n]; + __device__ T &operator[](int i) + { + return elems[i]; + } + }; + + using I4 = Vec; + + using host::div_ceil; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async +#else + + __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, bool pred = true) + { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), + "l"(glob_ptr), + "n"(BYTES)); + } + + __device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) + { + const int BYTES =
16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), + "n"(BYTES)); + } + + __device__ inline void cp_async_fence() + { + asm volatile("cp.async.commit_group;\n" ::); + } + + template + __device__ inline void cp_async_wait() + { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); + } + +#endif + +} // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_dtypes.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_dtypes.cuh new file mode 100644 index 000000000..783374ff2 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_dtypes.cuh @@ -0,0 +1,78 @@ +#ifndef _data_types_cuh +#define _data_types_cuh +#include "../sgl_kernel/utils.cuh" + +#include "marlin.cuh" + +namespace device::marlin { + +template +class ScalarType { +}; + +template <> +class ScalarType<__half> { +public: + using scalar_t = __half; + using scalar_t2 = fp16x2_t; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + + static __device__ float inline num2float(const __half x) { + return __half2float(x); + } + + static __device__ fp16x2_t inline num2num2(const __half x) { + return __half2half2(x); + } + + static __device__ fp16x2_t inline nums2num2(const __half x1, const __half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ __half inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType<__nv_bfloat16> { +public: + using scalar_t = __nv_bfloat16; + using scalar_t2 = bf16x2_t; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = 
Vec; + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const __nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ bf16x2_t inline num2num2(const __nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ bf16x2_t inline nums2num2(const __nv_bfloat16 x1, const __nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ __nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } +#endif +}; + +} // namespace device::marlin + +#endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h new file mode 100644 index 000000000..8f35f227d --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h @@ -0,0 +1,1917 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ +#include "../sgl_kernel/scalar_type.hpp" + +#include "dequant.h" +#include "marlin.cuh" +#include "marlin_dtypes.cuh" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert( \ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace device::marlin +{ + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + + template < + typename scalar_t, // compute dtype, half or nv_bfloat16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled; NOTE(review): the kernel.h declaration and the non-stub definition take no such parameter (has_act_order is derived from group_blocks == 0 there) - confirm this stub's signature is intended + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type?
+ > + __global__ void Marlin( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce + ) + { + } + +} // namespace device::marlin + +#else + + // m16n8k16 tensor core mma instruction with fp16 or bf16 inputs and fp32 + // output/accumulation. + template + __device__ inline void + mma(const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + typename ScalarType::FragC &frag_c) + { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + float *c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } + else if constexpr (std::is_same::value) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } + else + { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } + } + + template +
__device__ inline void mma_trans( + const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + const typename ScalarType::FragB &frag_b2, + typename ScalarType::FragC &frag_c) + { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + const uint32_t *b2 = reinterpret_cast(&frag_b2); + float *c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } + else if constexpr (std::is_same::value) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } + else + { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } + } + + // Instruction for loading a full 16x16 matrix fragment of operand A from shared + // memory, directly in tensor core layout. 
+ template + __device__ inline void ldsm(typename ScalarType::FragA &frag_a, const void *smem_ptr) + { + uint32_t *a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) + { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } + else if constexpr (count == 2) + { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); + } + else if constexpr (count == 1) + { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(a[0]) : "r"(smem)); + } + else + { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } + } + + // Multiply dequantized values by the corresponding quantization scale; used + // only for grouped quantization. + template + __device__ inline void + scale(typename ScalarType::FragB &frag_b, typename ScalarType::FragS &frag_s, int i) + { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); + } + + template + __device__ inline void scale_and_sub(typename ScalarType::FragB &frag_b, scalar_t s, scalar_t zp) + { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); + } + + template + __device__ inline void + sub_zp(typename ScalarType::FragB &frag_b, typename ScalarType::scalar_t2 &frag_zp, int i) + { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); + } + + // Same as above, but for act_order (each K is multiplied 
individually) + template + __device__ inline void scale4( + typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s_1, + typename ScalarType::FragS &frag_s_2, + typename ScalarType::FragS &frag_s_3, + typename ScalarType::FragS &frag_s_4, + int i) + { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); + } + + // Given 2 floats multiply by 2 scales (halves) + template + __device__ inline void scale_float(float *c, typename ScalarType::FragS &s) + { + scalar_t *s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); + } + + // Wait until barrier reaches `count`, then lock for current threadblock. + __device__ inline void barrier_acquire(int *lock, int count) + { + if (threadIdx.x == 0) + { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); + } + + // Release barrier and increment visitation count. + __device__ inline void barrier_release(int *lock, bool reset = false) + { + __syncthreads(); + if (threadIdx.x == 0) + { + if (reset) + { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. 
+ asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } + } + + // Wait until the value of lock is negative, then atomically add 1 to it + __device__ inline void wait_negative_and_add(int *lock) + { + if (threadIdx.x == 0) + { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); + } + + template < + typename scalar_t, // compute dtype, half or nv_bfloat16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type?
+ > + __global__ void Marlin( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t *__restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4 *__restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int lda, // A.stride(0), equal to prob_k if A is contiguous + int *locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) + { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible.
+ using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + static constexpr auto w_type = host::ScalarType::from_id(w_type_id); + constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8; + constexpr bool is_int_type = + w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == host::kU8); + + scalar_t2 global_scale; + + if constexpr (w_type == host::kFE2M1f) + { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + + constexpr bool has_act_order = group_blocks == 0; + constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > m_block_size) + { + parallel = prob_m / m_block_size; + prob_m = m_block_size; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) + { + if (group_blocks >= thread_k_blocks) + { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. 
+ iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) + { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) + { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } + else + { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // Compute all information about the current slice which is required for + // synchronization. 
+ auto init_slice = [&](bool first_init = false) + { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) + { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else + { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) + { + if (slice_count > 1 && slice_idx == slice_count - 1) + { + locks_off++; + } + } + else + { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) + { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m); + if (m_block_size_8) + m_per_thread = div_ceil(8, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) + { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < prob_m) + { + int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; + C[row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. 
+ __syncthreads(); + if (threadIdx.x == 0) + locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) + { + A += 16 * thread_m_blocks * lda / 8; + C += 16 * thread_m_blocks * prob_n / 8; + slice_col = 0; + par_id++; + } + }; + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = lda / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * m_block_size; + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float ? 16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. 
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) + { + if constexpr (group_blocks == -1) + { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + else + { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + + s_sh_stride * slice_col + threadIdx.x; + } + } + auto s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) + { + if constexpr (group_blocks == -1) + { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + else + { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } + } + auto zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. 
+ int s_sh_rd; + if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) + { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + } + else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) + { + if constexpr (is_zp_float) + { + if constexpr (group_blocks != -1) + { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } + } + else + { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. 
The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) + { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses within a tile by + // maintaining multiple pointers (we have enough registers), a tiny + // optimization. + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4 *sh_b = sh; + int4 *sh_red = sh; + int4 *sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + int4 *sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? 
(act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); + int4 *sh_s = sh_zp + (stages * zp_sh_stage); + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); + int4 *sh_a = sh_s + sh_s_size; + // constexpr int shm_size_used = + // stages * (g_idx_stage + zp_sh_stage) + sh_s_size + + // (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. + auto zero_accums = [&]() + { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) + { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) + { + sh_num_groups = act_s_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) + { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) + { + for (int i = 0; i < sh_num_groups; i++) + { + if (threadIdx.x < s_sh_stride) + { + cp_async4_pred( + &sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } + else + { + for (int i = 0; i < sh_num_groups; i++) + { + if (threadIdx.x < s_sh_stride) + { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * 
s_gl_stride) + slice_n_offset + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) + { + if (pred) + { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) + { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) + { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) + { + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const *cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) + { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } + else + { + if constexpr (group_blocks != -1) + { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) + { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) + { + if (s_sh_wr_pred) + { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + else + { + for (int i = 0; i < s_tb_groups; i++) + { + if (s_sh_wr_pred) + { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) + { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= 
thread_k_blocks) + { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) + { + if (zp_sh_wr_pred) + { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + else + { + for (int i = 0; i < zp_tb_groups; i++) + { + if (zp_sh_wr_pred) + { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() + { + if (zp_sh_wr_pred) + { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() + { + if (s_sh_wr_pred) + { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() + { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. 
+ auto fetch_to_registers = [&](int k, int pipe) + { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) + { + frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) + { + if constexpr (!has_act_order) + { + return; + } + + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) + { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) + { + // No act-order case + if constexpr (group_blocks == -1) + { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) + { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + else if constexpr (group_blocks != -1) + { + if constexpr (group_blocks >= thread_k_blocks) + { + if (k % b_sh_wr_iters == 0) + { + int4 *sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } + else + { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } + } + else + { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / (group_blocks * 
(w_type == host::kFE2M1f ? 2 : 1)); + + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (w_type_id != host::kFE2M1f.id()) + { + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + else + { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) + { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) + { + if (k % 2 == 0) + { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } + else + { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) + { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread + +#pragma unroll + for (int 
i = 0; i < 4; i++) + { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) + { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) + { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) + { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) + { +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) + { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + } + else if constexpr (group_blocks >= thread_k_blocks) + { + if (k % b_sh_wr_iters == 0) + { + int4 *sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) + { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + else + { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) + { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) + { + int pipe = 
full_pipe % stages; + + if constexpr (group_blocks != -1) + { + if constexpr (group_blocks >= thread_k_blocks) + { + if (k % b_sh_wr_iters == 0) + { + int4 *sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } + } + else + { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + auto dequant_data = [&](int q, scalar_t2 *frag_b_ptr) + { + dequant(q, frag_b_ptr); + }; + + // Execute the actual tensor core matmul of a sub-tile. 
+ bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) + { + int k2 = k % 2; + const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) + { + if (is_new_zp) + { + if constexpr (group_blocks == -1) + is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) + { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } + else + { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) + { + if (is_new_zp) + { + reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; + } + } + + if constexpr (w_type == host::kFE2M1f) + { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. 
+#pragma unroll + for (int j = 0; j < 4; j++) + { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type_id == host::kFE2M1f.id()) + { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } + else if constexpr (w_type.size_bits() == 4) + { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } + else + { + static_assert(w_type.size_bits() == 8); + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) + { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) + { + static_assert(group_blocks != -1); + scale4( + frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4( + frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } + else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) + { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } + else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) + { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } + else if constexpr (group_blocks != -1) + { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + +#pragma 
unroll + for (int i = 0; i < thread_m_blocks; i++) + { + if constexpr (m_block_size_8) + { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } + else + { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() + { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) + { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) + { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) + { + if (i <= red_idx && red_idx < 2 * i) + { +#pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 
2 : 1)) + { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) + { + float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh_red[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) + { +#pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) + { + float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) + { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). 
+ constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) + { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) + { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + else + { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + auto c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) + { +// Interestingly, doing direct global accesses here really seems to mess up +// the compiler and lead to slowdowns, hence we also use async-copies even +// though these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) + { + if constexpr (m_block_size_8) + { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], + (threadIdx.x % 4) * 2 + i < prob_m); + } + else + { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) + { + bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) || + (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); + if (mask) + { + if (!first) + { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) + { + int delta = 0; + if constexpr (m_block_size_8) + { + delta = j % 2 == 1 ? 
-2 : 0; + } + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) + { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) + { + int delta = 0; + if constexpr (m_block_size_8) + { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + if constexpr (m_block_size_8) + C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; + else + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) + { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) + { + return; + } + + if (!first) + { + float *frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) + { + sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float *sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); +#pragma unroll + for (int f = 0; f < 4; f++) + { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) + { + int4 *frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 
2 : 1)) + { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() + { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) + { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } + else + { + c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS &s) + { + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) + { + res = __hmul2(res, s[0]); + } + + if constexpr (w_type == host::kFE2M1f) + { + res = __hmul2(res, global_scale); + } + + if constexpr (m_block_size_8) + { + ((scalar_t *)sh_red)[idx] = res.x; + ((scalar_t *)sh_red)[idx + 8 * c_sh_stride] = res.y; + } + else + { + ((scalar_t2 *)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < 
thread_n_blocks / 4) + { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + { +#pragma unroll + for (int j = 0; j < 4; j++) + { + if constexpr (m_block_size_8) + { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + else + { + int wr = c_sh_wr + 8 * j; + write( + wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write( + wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) + { + if (c_gl_wr < c_gl_wr_end) + { + if (use_atomic_add && slice_count > 1) + { + scalar_t2 *C_half2 = reinterpret_cast(&C[c_gl_wr]); + scalar_t2 *sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); +#pragma unroll + for (int a = 0; a < 4; a++) + { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } + else + { + C[c_gl_wr] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. 
+ auto start_pipes = [&]() + { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) + { + if (has_act_order && i == 0) + { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) + { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) + { + if (i == 0) + { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) + { + fetch_col_scale_to_shared(); + } + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) + { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } + }; + if (slice_iters) + { + start_pipes(); + } + + // Main loop. + while (slice_iters) + { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. 
+ +#pragma unroll + for (int pipe = 0; pipe < stages;) + { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) + { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) + { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) + { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) + { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) + { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) + { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) + { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
+ if (slice_iters == 0) + { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) + { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) + { + if (s_sh_wr_pred) + { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) + { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) + { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) + { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) + { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 *frag_s_half2 = reinterpret_cast(frag_s); +#pragma unroll + for (int i = 0; i < 8; i++) + { + frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) + { + if (threadIdx.x / 32 < thread_n_blocks / 4) + { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + { +#pragma unroll + for (int j = 0; j < 4; j++) + { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 
1 : 0)]); + + if constexpr (!m_block_size_8) + { + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) + { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) + { + global_reduce_fp32(slice_idx == 0, last); + } + else + { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) + wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + if (slice_iters) + { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) + { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) + { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + } + else + { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } + } + +} // namespace device::marlin + +#endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu new file mode 100644 index 000000000..3e424ac4f --- /dev/null +++ 
b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu @@ -0,0 +1,1141 @@ +#if defined ENABLE_NVIDIA_API + +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../gptq_marlin_gemm.h" +#include "../sgl_kernel/tensor.h" +#include "gptq_marlin_gemm_nvidia.cuh" + +#include "../sgl_kernel/scalar_type.hpp" + +#include "../marlin/kernel.h" +#include "../marlin/marlin_template.h" + +namespace device::marlin { + +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) {} + +#else + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) { + auto start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int input_row_stride = lda * sizeof(half) / 16; + int output_row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int input_offset = row * input_row_stride; + int output_offset = row * output_row_stride; + + half const *a_row_half = reinterpret_cast(a_int4_ptr + input_offset); + half *out_half = reinterpret_cast(out_int4_ptr + output_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + 
out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +typedef struct +{ + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +typedef struct +{ + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; + +int get_scales_cache_size( + thread_config_t const &th_config, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +int get_kernel_cache_size( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool 
is_k_full, + int has_zp, + int is_zp_float) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8); + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) { + sh_zp_size = sh_s_size; + } else if (num_bits == 4) { + sh_zp_size = sh_s_size / 4; + } else if (num_bits == 8) { + sh_zp_size = sh_s_size / 2; + } + } + + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; + + return total_size; +} + +bool is_valid_config( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + return cache_size <= max_shared_mem; +} + +#define _GET_IF( \ + W_TYPE, THREAD_M_BLOCKS, 
THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } + +// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) +// this is the most common cases +// BIGGROUP: cases for big group size (group_blocks in [-1, 8]) +// FZP: cases for float-zero-point (is_zp_float = true) +// ACT: cases for act order case (group_blocks == 0) +// FP4: cases for nvfp4(e2m1) (group_blocks == 1) +#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, 
NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define 
FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, 
N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) + +template +MarlinFuncPtr get_marlin_kernel( + const host::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } + + COMMON_GET_IF(host::kU4) + COMMON_GET_IF(host::kU4B8) + COMMON_GET_IF(host::kU8B128) + + FP4_GET_IF(host::kFE2M1f) + + BIGGROUP_GET_IF(host::kFE4M3fn) + + ACT_GET_IF(host::kU4B8) + ACT_GET_IF(host::kU8B128) + + if (std::is_same::value) { + if (false) { + } + FZP_GET_IF(host::kU4) + } + + return kernel; +} + +template +exec_config_t determine_exec_config( + const host::ScalarType &q_type, + int prob_m, + int prob_n, + int prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem, + int sms) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t *thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; + int thread_configs_size = thread_m_blocks > 1 ? 
sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? -1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + continue; + } + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; + } + + return exec_cfg; +} + +template +void marlin_mm( + const void *A, + const void *B, + void *C, + void *C_tmp, + void *s, + void *s2, + void *zp, + void *g_idx, + void *perm, + void *a_tmp, + int prob_m, + int prob_n, + int prob_k, + int lda, + void *workspace, + host::ScalarType const &q_type, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k_init, + int thread_n_init, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + if (has_zp) { + host::RuntimeCheck( + q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. 
Got = ", q_type.str()); + } else { + host::RuntimeCheck( + q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); + } + + host::RuntimeCheck( + prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + host::RuntimeCheck(group_size != -1); + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } else { + host::RuntimeCheck(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + int4 *C_ptr = (int4 *)C; + int4 *C_tmp_ptr = (int4 *)C_tmp; + const int4 *s_ptr = (const int4 *)s; + const uint16_t *s2_ptr = (const uint16_t *)s2; + const int4 *zp_ptr = (const int4 *)zp; + const int *g_idx_ptr = (const int *)g_idx; + const int *perm_ptr = (const int *)perm; + int4 *a_tmp_ptr = (int4 *)a_tmp; + + int *locks = (int *)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, sms); + host::LaunchKernel(sms, default_threads, stream)( + permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + A_ptr = a_tmp_ptr; + lda = prob_k; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + } + + int 
max_shared_mem = 0; + host::RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + host::RuntimeCheck(max_shared_mem > 0); + + int max_par = 16; + if (prob_n <= 4096) { + max_par = 16 * 8; + } + int max_shared_mem_new = max_shared_mem; + int rest_m = prob_m; + int max_thread_m_blocks = 4; + while (rest_m) { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) { + par_count = max_par; + } + int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m_split, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem, + sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { + max_thread_m_blocks--; + continue; + } + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) { + max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + host::RuntimeCheck( + is_valid_config( + 
thread_tfg, + thread_m_blocks, + prob_m_split, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", + thread_m_blocks, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", + num_bits, + ", prob_m_split = ", + prob_m_split, + ", group_size = ", + group_size, + ", has_act_order = ", + has_act_order, + ", is_k_full = ", + is_k_full, + ", has_zp = ", + has_zp, + ", is_zp_float = ", + is_zp_float, + ", max_shared_mem_new = ", + max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + thread_n_blocks, + thread_k_blocks, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + host::Panic( + "Unsupported shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", prob_m_split = ", + prob_m_split, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_threads = ", + num_threads, + ", num_bits = ", + num_bits); + } + + host::RuntimeDeviceCheck( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new)); + + bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)( + kernel, + A_ptr, + B_ptr, + C_ptr, + C_tmp_ptr, + s_ptr, + s2_ptr, + zp_ptr, + g_idx_ptr, + num_groups, + prob_m_split, + prob_n, + prob_k, + lda, + locks, + part_use_atomic_add, + use_fp32_reduce, + max_shared_mem_new); + + A_ptr += prob_m_split * (lda / 8); + C_ptr += 
prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; + } +} + +#endif + +} // namespace device::marlin + +template +void gptq_marlin_gemm(const void *a, + const void *b_q_weight, + void *b_scales, + void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + void *c, + void *c_tmp, + void *a_tmp, + void *workspace, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + int64_t size_m, + int64_t size_k, + int64_t size_n, + int64_t b_q_size_1, + int64_t a_stride0, + int num_groups, cudaStream_t stream) + +{ + using namespace host; + + // Verify a: [M, K] + ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + // Verify b_q_weight: [K/tile_size, packed_N] + RuntimeCheck( + size_k % device::marlin::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + device::marlin::tile_size); + + int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size; + RuntimeCheck( + b_q_size_1 % device::marlin::tile_size == 0, + "b_q_weight.size(1) = ", + b_q_size_1, + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t actual_size_n = (b_q_size_1 / device::marlin::tile_size) * pack_factor; + RuntimeCheck(actual_size_n == size_n, "actual_size_n must = size_n"); + // size_n = actual_size_n + RuntimeCheck(a_stride0 % 8 == 0, "a.stride(0) must be divisible by 8"); + // Verify b_scales: [num_groups, N] + // Early return for zero-size M + if (size_m == 0) { + return; + } + + // int64_t g_idx_size = g_idx.size(0);// g_idx_size == size_k + // int64_t perm_size = perm.size(0);// perm_size == size_k + bool has_act_order = (g_idx != nullptr && perm != nullptr); + + // Determine has_zp from b_zeros size + // int64_t b_zeros_size = b_zeros.size(0); + bool has_zp = (b_zeros != nullptr); + + if (has_zp) { + RuntimeCheck( + b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. 
Got = ", b_q_type.str()); + } else { + RuntimeCheck( + b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + RuntimeCheck( + std::is_same::value, "Computation type must be float16 (half) when using float zero points."); + } + + // int64_t global_scale_size = global_scale.size(0); + if (global_scale != nullptr) { + RuntimeCheck(b_q_type == kFE2M1f, "global_scale can only be used for float4_e2m1f."); + } else { + RuntimeCheck(!(b_q_type == kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); + } + // Derive group_size + int group_size = -1; + if (has_act_order) { + if (is_k_full) { + RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = 0; + } + } else { + if (num_groups > 1) { + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = -1; + } + } + + // Verify workspace and get device info + RuntimeCheck( + size_n % device::marlin::min_thread_n == 0, + "size_n = ", + size_n, + ", is not divisible by min_thread_n = ", + device::marlin::min_thread_n); + + int device_id = 0; + int sms = -1; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device_id)); + // RuntimeCheck(workspace.size(0) >= sms, "workspace.size(0) = ", workspace.size(0), " is below min_workspace_size = ", sms); + // Hardcoded defaults (auto config) + int thread_k_init = -1; + int thread_n_init = -1; + + // Compute c_tmp and a_tmp pointers + // c_tmp and a_tmp are pre-allocated by caller + + device::marlin::marlin_mm( + a, + 
b_q_weight, + c, + c_tmp, + b_scales, + global_scale, + b_zeros, + g_idx, + perm, + a_tmp, + static_cast(size_m), + static_cast(size_n), + static_cast(size_k), + static_cast(a_stride0), + workspace, + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + device_id, + stream, + thread_k_init, + thread_n_init, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); +} + +template +infiniStatus_t gptq_marlin_gemm_kernel(void *c, + const void *a, + const void *b_q_weight, + void *b_scales, + void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + int64_t size_m, + int64_t size_k, + int64_t size_n, + int64_t b_q_size_1, + int64_t a_stride0, + int num_groups, void *total_buffer, cudaStream_t stream) { + int _MAX_THREAD_N = 256; + int max_blocks_per_sm = 1; + float *c_tmp = nullptr; + void *a_tmp = nullptr; + void *workspace = nullptr; + + // 获取设备 SM 数量(只查询 1 次!) + int dev; + cudaGetDevice(&dev); // 获取当前设备号 + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, dev); + const int sms = prop.multiProcessorCount; + + // ===================== 1. 计算每块内存大小 ===================== + size_t c_tmp_bytes = 0; + if (use_fp32_reduce) { + int max_m_block = ((size_m + 15) / 16) * 16; + max_m_block = min(max_m_block, 64); + const size_t c_elems = (size_t)sms * max_m_block * _MAX_THREAD_N; + c_tmp_bytes = c_elems * sizeof(float); + } + + size_t a_tmp_bytes = 0; + bool has_act_order = false; + if (g_idx != nullptr && perm != nullptr) { + has_act_order = true; + } + if (has_act_order) { + a_tmp_bytes = (size_t)size_m * size_k * sizeof(scalar_t); + } + + // workspace 大小(int 类型,必须分配) + const size_t workspace_elems = (size_t)sms * max_blocks_per_sm; + const size_t workspace_bytes = workspace_elems * sizeof(int); + + // ===================== 2. 
计算总内存大小 ===================== + const size_t total_bytes = c_tmp_bytes + a_tmp_bytes + workspace_bytes; + + // ===================== 3. 单次 cudaMalloc 分配 ===================== + if (total_bytes > 0) { + cudaMemset(total_buffer, 0, total_bytes); + } + + // ===================== 4. 手动切分指针(核心!) ===================== + uint8_t *ptr = reinterpret_cast(total_buffer); + + // 分配 c_tmp + if (use_fp32_reduce && c_tmp_bytes > 0) { + c_tmp = reinterpret_cast(ptr); + ptr += c_tmp_bytes; + } + + // 分配 a_tmp + if (has_act_order && a_tmp_bytes > 0) { + a_tmp = ptr; + ptr += a_tmp_bytes; + } + + // 分配 workspace + if (workspace_bytes > 0) { + workspace = ptr; + ptr += workspace_bytes; + } + + gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + global_scale, + b_zeros, + g_idx, + perm, + c, + c_tmp, + a_tmp, + workspace, + b_q_type_id, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + size_m, + size_k, + size_n, + b_q_size_1, + a_stride0, + num_groups, + stream); + return INFINI_STATUS_SUCCESS; +} + +int getCudaDeviceSMCount() { + int dev; + cudaGetDevice(&dev); + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, dev); + + return prop.multiProcessorCount; +} + +namespace op::gptq_marlin_gemm::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { delete _opaque; } + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t global_scale_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t g_idx_desc, + infiniopTensorDescriptor_t perm_desc) { + + auto handle = reinterpret_cast(handle_); + auto result = GptqMarlinGemmInfo::create(out_desc, a_desc, b_desc, b_scales_desc, global_scale_desc, b_zeros_desc, g_idx_desc, perm_desc); + + int sms = getCudaDeviceSMCount(); + int _MAX_THREAD_N 
= 256; + int max_blocks_per_sm = 1; + int max_m_block = ((out_desc->dim(0) + 15) / 16) * 16; + max_m_block = min(max_m_block, 64); + const size_t c_elems = (size_t)sms * max_m_block * _MAX_THREAD_N; + size_t c_tmp_bytes = c_elems * sizeof(float); + size_t a_tmp_bytes = (size_t)a_desc->dim(0) * a_desc->dim(1) * infiniSizeOf(a_desc->dtype()); + const size_t workspace_elems = (size_t)sms * max_blocks_per_sm; + const size_t workspace_bytes = workspace_elems * sizeof(int); + size_t workspace_size = c_tmp_bytes + a_tmp_bytes + workspace_bytes; + + *desc_ptr = new Descriptor( + workspace_size, + new Opaque{handle->internal()}, + result.take(), + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t +Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, + const void *a, + const void *b, + void *b_scales, + void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + void *stream_) const { + + cudaStream_t stream = (cudaStream_t)stream_; + int64_t M = static_cast(_info.M); + int64_t K = static_cast(_info.K); + int64_t N = static_cast(_info.N); + int64_t b_q_size_1 = static_cast(_info.b_q_size_1); + int64_t a_stride_0 = static_cast(_info.a_stride_0); + int num_groups = _info.num_groups; + +#define MARLIN(TDATA) \ + gptq_marlin_gemm_kernel(out, a, b, b_scales, global_scale, b_zeros, g_idx, perm, b_q_type_id, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float, M, K, N, b_q_size_1, a_stride_0, num_groups, workspace, stream) + + if (_info.dtype == INFINI_DTYPE_F16) { + return MARLIN(half); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + return MARLIN(__nv_bfloat16); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::gptq_marlin_gemm::nvidia + +#endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cuh 
b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cuh new file mode 100644 index 000000000..f9c7eb6e9 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __GPTQ_MARLIN_GEMM_CUDA_CUH__ +#define __GPTQ_MARLIN_GEMM_CUDA_CUH__ + +#include "../gptq_marlin_gemm.h" + +DESCRIPTOR(nvidia) + +#endif // __GPTQ_MARLIN_GEMM_CUDA_CUH__ diff --git a/src/infiniop/ops/gptq_marlin_gemm/operator.cc b/src/infiniop/ops/gptq_marlin_gemm/operator.cc new file mode 100644 index 000000000..04d9e43f6 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/operator.cc @@ -0,0 +1,120 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/gptq_marlin_gemm.h" + +#if defined ENABLE_NVIDIA_API +#include "nvidia/gptq_marlin_gemm_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateGptqMarlinGemmDescriptor( + infiniopHandle_t handle, + infiniopGptqMarlinGemmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t global_scale_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t g_idx_desc, + infiniopTensorDescriptor_t perm_desc) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::gptq_marlin_gemm::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, \ + a_desc, \ + b_desc, \ + b_scales_desc, \ + global_scale_desc, \ + b_zeros_desc, \ + g_idx_desc, \ + perm_desc) + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetGptqMarlinGemmWorkspaceSize(infiniopGptqMarlinGemmDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = 
reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__INFINI_C infiniStatus_t infiniopGptqMarlinGemm( + infiniopGptqMarlinGemmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *b_scales, + void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, out, a, b, b_scales, global_scale, b_zeros, g_idx, perm, b_q_type_id, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float, stream) + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t +infiniopDestroyGptqMarlinGemmDescriptor(infiniopGptqMarlinGemmDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} + +// #endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp new file mode 100644 index 000000000..15f46457f --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp @@ -0,0 +1,335 @@ +#pragma once + +#include +#include +#ifndef __CUDACC__ +#include +#endif + +namespace host { + +// +// ScalarType can represent a wide range of 
floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType( + uint8_t exponent, + uint8_t mantissa, + bool signed_, + int32_t bias, + bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { + assert(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { + assert(nan_repr < NAN_REPR_ID_MAX); + assert(mantissa > 0 && exponent > 0); + assert(nan_repr != NAN_IEEE_754); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool 
const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << 
bit_offset, bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { + return signed_; + } + constexpr bool is_integer() const { + return exponent == 0; + } + constexpr bool is_floating_point() const { + return exponent > 0; + } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { + return bias != 0; + } + +#ifndef __CUDACC__ + private: + double _floating_point_max() const { + assert(mantissa <= 52 && exponent <= 11); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + assert(exponent < 11); + max_exponent += 1; + } 
+ + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + assert(size_bits() < 64 || (size_bits() == 64 && is_signed())); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + assert(is_signed()); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + assert(!is_signed() || size_bits() <= 64); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. 
+ // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); + } + + // Min representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); + } +#endif // __CUDACC__ + + public: + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = + "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? 
"int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = 
kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +} // namespace host + diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h new file mode 100644 index 000000000..57573171a --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h @@ -0,0 +1,41 @@ +/// \file source_location.h +/// \brief Portable `source_location` wrapper. +/// +/// Uses `std::source_location` when available (C++20), otherwise falls +/// back to a minimal stub that returns empty/zero values. 
+ +#pragma once +#include + +/// NOTE: fallback to a minimal source_location implementation +#if defined(__cpp_lib_source_location) +#include + +using source_location_t = std::source_location; + +#else + +struct source_location_fallback { + public: + static constexpr source_location_fallback current() noexcept { + return source_location_fallback{}; + } + constexpr source_location_fallback() noexcept = default; + constexpr unsigned line() const noexcept { + return 0; + } + constexpr unsigned column() const noexcept { + return 0; + } + constexpr const char* file_name() const noexcept { + return ""; + } + constexpr const char* function_name() const noexcept { + return ""; + } +}; + +using source_location_t = source_location_fallback; + +#endif + diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h new file mode 100644 index 000000000..9f48edd96 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h @@ -0,0 +1,621 @@ +/// \file tensor.h +/// \brief Tensor validation and symbolic matching utilities. 
+#pragma once +#include "utils.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include "utils.cuh" +#endif + +namespace host +{ + struct SymbolicSize; + struct SymbolicDType; + struct SymbolicDevice; + + namespace details + { + inline constexpr auto kAnyDeviceID = -1; + inline constexpr auto kAnySize = static_cast(-1); + inline constexpr auto kNullSize = static_cast(-1); + inline constexpr auto kNullDType = static_cast(18u); + inline constexpr auto kNullDevice = static_cast(-1); + + template + struct ArrayView + { + const T *data; + size_t size; + + __host__ __device__ ArrayView() : data(nullptr), size(0) {} + __host__ __device__ ArrayView(const T *d, size_t s) : data(d), size(s) {} + + template + __host__ __device__ ArrayView(const std::array &arr) + : data(arr.data()), size(arr.size()) {} + + __host__ __device__ const T &operator[](size_t i) const { return data[i]; } + __host__ __device__ bool empty() const { return size == 0; } + }; + + template + struct PrintAbleSpan + { + const T *data; + size_t length; + + PrintAbleSpan(const T *p, size_t l) : data(p), length(l) {} + size_t size() const { return length; } + const T &operator[](size_t i) const { return data[i]; } + }; + + inline constexpr const char *kDeviceStringMap[] = { + "", // 0 + "cpu", // 1 + "cuda", // 2 + "cuda_host", // 3 + "opencl", // 4 + "vulkan", // 5 + "metal", // 6 + "vpi", // 7 + "rocm", // 8 + "rocm_host", // 9 + "ext_dev", // 10 + "cuda_managed", // 11 + "oneapi", // 12 + "webgpu", // 13 + "hexagon", // 14 + "maia", // 15 + "trn", // 16 + }; + + constexpr int kMaxDeviceType = 16; + + struct PrintableDevice + { + DLDevice device; + }; + + template + struct _dtype_trait; + + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLInt, 8, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLInt, 16, 1}; + }; + 
template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLInt, 32, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLInt, 64, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLUInt, 8, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLUInt, 16, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLUInt, 32, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLUInt, 64, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLFloat, 32, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLFloat, 64, 1}; + }; + +#ifdef __CUDACC__ + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLFloat, 16, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLBfloat, 16, 1}; + }; + template <> + struct _dtype_trait + { + static constexpr DLDataType value = {kDLFloat8_e4m3fn, 8, 1}; + }; +#endif + + template + struct _device_trait + { + static constexpr DLDevice value = {Code, kAnyDeviceID}; + }; + + template + inline constexpr std::array kDTypeList = { + _dtype_trait::value...}; + + template + inline constexpr std::array kDeviceList = { + _device_trait::value...}; + + } // namespace details + + inline std::ostream &operator<<(std::ostream &os, DLDevice device) + { + int code = static_cast(device.device_type); + if (code < 1 || code > details::kMaxDeviceType) + RuntimeCheck(false, "Unknown device: ", code); + os << details::kDeviceStringMap[code]; + if (device.device_id != details::kAnyDeviceID && device.device_type != kDLCPU) + os << ":" << device.device_id; + return os; + } + + inline std::ostream &operator<<(std::ostream &os, details::PrintableDevice pd) + { + return os << pd.device; + } + + template + inline std::ostream 
&operator<<(std::ostream &os, const details::PrintAbleSpan &span) + { + os << "["; + for (size_t i = 0; i < span.size(); ++i) + { + if (i > 0) + os << ", "; + os << span[i]; + } + os << "]"; + return os; + } + + // ============================================== + // SymbolicSize 完整定义 + // ============================================== + struct SymbolicSize + { + public: + explicit SymbolicSize(std::string_view ann = {}) + : m_value(details::kNullSize), m_ann(ann) {} + + SymbolicSize(const SymbolicSize &) = delete; + SymbolicSize &operator=(const SymbolicSize &) = delete; + + std::string_view get_name() const { return m_ann; } + bool has_value() const { return m_value != details::kNullSize; } + + void set_value(int64_t v) + { + RuntimeCheck(!has_value(), "Size already set"); + m_value = v; + } + + std::optional get_value() const + { + return has_value() ? std::optional(m_value) : std::nullopt; + } + + int64_t unwrap(DebugInfo info = {}) const + { + RuntimeCheck(info, has_value(), "Size not set"); + return m_value; + } + + void verify(int64_t v, const char *prefix, int64_t dim) + { + if (has_value()) + { + if (m_value != v) [[unlikely]] + { + Panic("Size mismatch for ", m_name_str(prefix, dim), ": expected ", m_value, " got ", v); + } + } + else + { + set_value(v); + } + } + + std::string value_or_name(const char *prefix, int64_t dim) const + { + if (auto v = get_value()) + return std::to_string(*v); + return m_name_str(prefix, dim); + } + + private: + std::string m_name_str(const char *prefix, int64_t dim) const + { + std::ostringstream os; + os << prefix << '#' << dim; + if (!m_ann.empty()) + os << "('" << m_ann << "')"; + return os.str(); + } + + int64_t m_value; + std::string_view m_ann; + }; + + inline bool operator==(DLDevice a, DLDevice b) + { + return a.device_type == b.device_type && a.device_id == b.device_id; + } + + // ============================================== + // SymbolicDType 完整定义 + // ============================================== + struct 
SymbolicDType + { + public: + SymbolicDType() : m_value({details::kNullDType, 0, 0}) {} + SymbolicDType(const SymbolicDType &) = delete; + SymbolicDType &operator=(const SymbolicDType &) = delete; + + bool has_value() const { return m_value.code != details::kNullDType; } + + void set_value(DLDataType v) + { + RuntimeCheck(!has_value(), "DType already set"); + RuntimeCheck(m_check(v), "DType not allowed: ", v); + m_value = v; + } + + std::optional get_value() const + { + return has_value() ? std::optional(m_value) : std::nullopt; + } + + DLDataType unwrap(DebugInfo info = {}) const + { + RuntimeCheck(info, has_value(), "DType not set"); + return m_value; + } + + void set_options(details::ArrayView opts) { m_opts = opts; } + + template + void set_options() + { + m_opts = details::ArrayView(details::kDTypeList.data(), details::kDTypeList.size()); + } + + void verify(DLDataType dtype) + { + if (has_value()) + { + RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " got ", dtype); + } + else + { + set_value(dtype); + } + } + + template + bool is_type() const + { + return m_value == details::_dtype_trait::value; + } + + private: + bool m_check(DLDataType v) const + { + if (m_opts.empty()) + return true; + for (size_t i = 0; i < m_opts.size; ++i) + if (m_opts[i] == v) + return true; + return false; + } + + details::ArrayView m_opts; + DLDataType m_value; + }; + + // ============================================== + // SymbolicDevice 完整定义 + // ============================================== + struct SymbolicDevice + { + public: + SymbolicDevice() : m_value({details::kNullDevice, details::kAnyDeviceID}) {} + SymbolicDevice(const SymbolicDevice &) = delete; + SymbolicDevice &operator=(const SymbolicDevice &) = delete; + + bool has_value() const { return m_value.device_type != details::kNullDevice; } + + void set_value(DLDevice v) + { + RuntimeCheck(!has_value(), "Device already set"); + RuntimeCheck(m_check(v), "Device not allowed: ", 
details::PrintableDevice{v}); + m_value = v; + } + + std::optional get_value() const + { + return has_value() ? std::optional(m_value) : std::nullopt; + } + + DLDevice unwrap(DebugInfo info = {}) const + { + RuntimeCheck(info, has_value(), "Device not set"); + return m_value; + } + + void set_options(details::ArrayView opts) { m_opts = opts; } + + template + void set_options() + { + m_opts = details::ArrayView(details::kDeviceList.data(), details::kDeviceList.size()); + } + + void verify(DLDevice dev) + { + if (has_value()) + { + RuntimeCheck(m_value == dev, "Device mismatch: expected ", + details::PrintableDevice{m_value}, " got ", details::PrintableDevice{dev}); + } + else + { + set_value(dev); + } + } + + private: + bool m_check(DLDevice v) const + { + if (m_opts.empty()) + return true; + for (size_t i = 0; i < m_opts.size; ++i) + { + auto o = m_opts[i]; + if (o.device_type != v.device_type) + continue; + if (o.device_id == details::kAnyDeviceID || o.device_id == v.device_id) + return true; + } + return false; + } + + details::ArrayView m_opts; + DLDevice m_value; + }; + + // ============================================== + // BaseRef / Ref 类型(现在类型已完整定义) + // ============================================== + namespace details + { + template + struct BaseRef + { + BaseRef() : m_ref(&m_cache) {} + explicit BaseRef(T &r) : m_ref(&r) {} + BaseRef(const BaseRef &) = delete; + BaseRef &operator=(const BaseRef &) = delete; + + T *operator->() const { return m_ref; } + T &operator*() const { return *m_ref; } + void rebind(T &r) { m_ref = &r; } + + private: + T *m_ref; + T m_cache; + }; + + struct SizeRef : public BaseRef + { + using BaseRef::BaseRef; + SizeRef(int64_t v); + }; + + struct DTypeRef : public BaseRef + { + using BaseRef::BaseRef; + DTypeRef(DLDataType); + DTypeRef(std::initializer_list); + DTypeRef(ArrayView); + }; + + struct DeviceRef : public BaseRef + { + using BaseRef::BaseRef; + DeviceRef(DLDevice); + DeviceRef(std::initializer_list); + 
DeviceRef(ArrayView); + }; + + inline SizeRef::SizeRef(int64_t v) + { + if (v != kAnySize) + (**this).set_value(v); + } + inline DTypeRef::DTypeRef(DLDataType v) { (**this).set_value(v); } + inline DTypeRef::DTypeRef(std::initializer_list l) : DTypeRef(ArrayView(l.begin(), l.size())) {} + inline DTypeRef::DTypeRef(ArrayView v) { (**this).set_options(v); } + inline DeviceRef::DeviceRef(DLDevice v) { (**this).set_value(v); } + inline DeviceRef::DeviceRef(std::initializer_list l) : DeviceRef(ArrayView(l.begin(), l.size())) {} + inline DeviceRef::DeviceRef(ArrayView v) { (**this).set_options(v); } + + } // namespace details + + template + inline bool is_type(DLDataType dtype) + { + return dtype == details::_dtype_trait::value; + } + + // ============================================== + // TensorMatcher + // ============================================== + struct TensorMatcher + { + using SizeRef = details::SizeRef; + using DTypeRef = details::DTypeRef; + using DeviceRef = details::DeviceRef; + + TensorMatcher(const TensorMatcher &) = delete; + TensorMatcher &operator=(const TensorMatcher &) = delete; + + explicit TensorMatcher(std::initializer_list s) + : m_shape(s.begin(), s.size()), m_strides(nullptr, 0) {} + + TensorMatcher &&with_strides(std::initializer_list s) && + { + RuntimeCheck(m_strides.empty(), "Strides already set"); + RuntimeCheck(m_shape.size == s.size(), "Stride/shape size mismatch"); + m_strides = details::ArrayView(s.begin(), s.size()); + return std::move(*this); + } + + template + TensorMatcher &&with_dtype(DTypeRef &&d) && + { + m_dtype.rebind(*d); + m_dtype->template set_options(); + return std::move(*this); + } + + template + TensorMatcher &&with_dtype() && + { + m_dtype->template set_options(); + return std::move(*this); + } + + template + TensorMatcher &&with_device(DeviceRef &&d) && + { + m_device.rebind(*d); + m_device->template set_options(); + return std::move(*this); + } + + template + TensorMatcher &&with_device() && + { + 
m_device->template set_options(); + return std::move(*this); + } + + const TensorMatcher &&verify(tvm::ffi::TensorView, DebugInfo = {}) const &&; + + private: + static void s_print_tensor(std::ostringstream &, tvm::ffi::TensorView); + void m_verify_impl(tvm::ffi::TensorView) const; + + details::ArrayView m_shape; + details::ArrayView m_strides; + DTypeRef m_dtype; + DeviceRef m_device; + }; + + inline void TensorMatcher::s_print_tensor(std::ostringstream &os, tvm::ffi::TensorView v) + { + os << "Tensor<"; + size_t d = 0; + for (int64_t s : v.shape()) + { + if (d++) + os << ", "; + os << s; + } + os << ">[strides=<"; + d = 0; + for (int64_t s : v.strides()) + { + if (d++) + os << ", "; + os << s; + } + os << ">, dtype=" << v.dtype(); + os << ", device=" << details::PrintableDevice{v.device()} << "]"; + } + + inline const TensorMatcher &&TensorMatcher::verify(tvm::ffi::TensorView v, DebugInfo info) const && + { + try + { + m_verify_impl(v); + } + catch (PanicError &e) + { + std::ostringstream os; + os << "Tensor match failed: "; + s_print_tensor(os, v); + os << " @ " << info.file_name() << ":" << info.line() << "\n- cause: " << e.root_cause(); + throw PanicError(os.str()); + } + return std::move(*this); + } + + inline void TensorMatcher::m_verify_impl(tvm::ffi::TensorView v) const + { + size_t dim = static_cast(v.dim()); + RuntimeCheck(dim == m_shape.size, "Dim mismatch: expected ", m_shape.size, " got ", dim); + + for (size_t i = 0; i < dim; ++i) + m_shape[i]->verify(v.size(i), "shape", (int64_t)i); + + if (!m_strides.empty()) + { + for (size_t i = 0; i < dim; ++i) + { + if (v.size(i) != 1 || !m_strides[i]->has_value()) + m_strides[i]->verify(v.stride(i), "stride", (int64_t)i); + } + } + else + { + RuntimeCheck(v.is_contiguous(), "Tensor not contiguous"); + } + + m_dtype->verify(v.dtype()); + m_device->verify(v.device()); + } + +} // namespace host diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh 
b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh new file mode 100644 index 000000000..d73c2ac04 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh @@ -0,0 +1,310 @@ +/// \file utils.cuh +/// \brief Core CUDA/device utilities: type aliases, PDL helpers, +/// typed pointer access, kernel launch wrapper, and error checking. +/// +/// This header is included (directly or transitively) by nearly every +/// JIT kernel. It provides: +/// - Scalar/packed type aliases (`fp16_t`, `bf16_t`, `fp8_e4m3_t`, ...). +/// - `SGL_DEVICE` macro (forced-inline device function qualifier). +/// - `kWarpThreads` constant (32). +/// - PDL (Programmatic Dependent Launch) helpers for Hopper (sm_90+). +/// - Typed `load_as` / `store_as` for void-pointer access. +/// - `pointer::offset` for safe void-pointer arithmetic. +/// - `host::LaunchKernel` - kernel launcher with optional PDL. +/// - `host::RuntimeDeviceCheck` - CUDA error checking. + +#pragma once + +#include "utils.h" + +#include +#include + +#include +#include +#include +#ifndef USE_ROCM +#include +#include +#include +#include +#else +#include +#include +#include +#ifndef __grid_constant__ +#define __grid_constant__ +#endif +using cudaError_t = hipError_t; +using cudaStream_t = hipStream_t; +using cudaLaunchConfig_t = hipLaunchConfig_t; +using cudaLaunchAttribute = hipLaunchAttribute; +inline constexpr auto cudaSuccess = hipSuccess; +#define cudaStreamPerThread hipStreamPerThread +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaLaunchKernel hipLaunchKernel +#endif + +#ifndef USE_ROCM +using fp32_t = float; +// using fp16_t = __half; +// using bf16_t = __nv_bfloat16; +using fp8_e4m3_t = __nv_fp8_e4m3; +using fp8_e5m2_t = __nv_fp8_e5m2; + +using fp32x2_t = float2; +using fp16x2_t = __half2; +using bf16x2_t = __nv_bfloat162; +using fp8x2_e4m3_t = __nv_fp8x2_e4m3; +using fp8x2_e5m2_t = __nv_fp8x2_e5m2; + +using fp32x4_t = float4; +#else +using fp32_t = 
float; +using fp16_t = __half; +using bf16_t = __hip_bfloat16; +using fp8_e4m3_t = uint8_t; +using fp8_e5m2_t = uint8_t; +using fp32x2_t = float2; +using fp16x2_t = half2; +using bf16x2_t = __hip_bfloat162; +using fp8x2_e4m3_t = uint16_t; +using fp8x2_e5m2_t = uint16_t; +using fp32x4_t = float4; +#endif + +/* + * LDG Support + */ +#ifndef USE_ROCM +#define SGLANG_LDG(arg) __ldg(arg) +#else +#define SGLANG_LDG(arg) *(arg) +#endif + +namespace device { + +/// \brief Macro: forced-inline device function qualifier. +#define SGL_DEVICE __forceinline__ __device__ + +// Architecture detection: SGL_CUDA_ARCH is injected by load_jit() and is +// available in both host and device compilation passes, whereas __CUDA_ARCH__ +// is only defined by nvcc during the device pass. +#if !defined(USE_ROCM) +#if !defined(SGL_CUDA_ARCH) +#error "SGL_CUDA_ARCH is not defined. JIT compilation must inject -DSGL_CUDA_ARCH via load_jit()." +#endif +#if defined(__CUDA_ARCH__) +static_assert( + __CUDA_ARCH__ == SGL_CUDA_ARCH, "SGL_CUDA_ARCH mismatch: injected arch flag does not match device target"); +#endif +#define SGL_ARCH_HOPPER_OR_GREATER (SGL_CUDA_ARCH >= 900) +#define SGL_ARCH_BLACKWELL_OR_GREATER ((SGL_CUDA_ARCH >= 1000) && (CUDA_VERSION >= 12090)) +#else // USE_ROCM +#define SGL_ARCH_HOPPER_OR_GREATER 0 +#define SGL_ARCH_BLACKWELL_OR_GREATER 0 +#endif + +// Maximum vector size in bytes supported by current architecture. +// Pre-Blackwell / AMD: 128-bit (16 bytes) +// Blackwell or greater: 256-bit (32 bytes) +inline constexpr std::size_t kMaxVecBytes = SGL_ARCH_BLACKWELL_OR_GREATER ? 32 : 16; + +/// \brief Number of threads per warp (always 32 on NVIDIA/AMD GPUs). +inline constexpr auto kWarpThreads = 32u; +/// \brief Full warp active mask (all 32 lanes). +inline constexpr auto kFullMask = 0xffffffffu; + +/** + * \brief PDL (Programmatic Dependent Launch): wait for the primary kernel. 
+ * + * On Hopper (sm_90+), inserts a `griddepcontrol.wait` instruction to + * synchronize with a preceding kernel in the same stream. On older + * architectures or ROCm this is a no-op. + */ +template +SGL_DEVICE void PDLWaitPrimary() { +#if SGL_ARCH_HOPPER_OR_GREATER + if constexpr (kUsePDL) { + asm volatile("griddepcontrol.wait;" ::: "memory"); + } +#endif +} + +/** + * \brief PDL: trigger dependent (secondary) kernel launch. + * + * On Hopper (sm_90+), inserts a `griddepcontrol.launch_dependents` + * instruction. On older architectures or ROCm this is a no-op. + */ +template +SGL_DEVICE void PDLTriggerSecondary() { +#if SGL_ARCH_HOPPER_OR_GREATER + if constexpr (kUsePDL) { + asm volatile("griddepcontrol.launch_dependents;" :::); + } +#endif +} + +template +SGL_DEVICE constexpr auto div_ceil(T a, U b) { + static_assert(std::is_integral::value && std::is_integral::value, + "div_ceil requires integer types"); + return (a + b - 1) / b; +} + +/** + * \brief Load data with the specified type and offset from a void pointer. + * \tparam T The type to load. + * \param ptr The base pointer. + * \param offset The offset in number of elements of type T. + */ +template +SGL_DEVICE T load_as(const void *ptr, int64_t offset = 0) { + return static_cast(ptr)[offset]; +} + +/** + * \brief Store data with the specified type and offset to a void pointer. + * \tparam T The type to store. + * \param ptr The base pointer. + * \param val The value to store. + * \param offset The offset in number of elements of type T. + * \note we use type_identity_t to force the caller to explicitly specify + * the template parameter `T`, which can avoid accidentally using the wrong type. + */ +template +SGL_DEVICE void store_as(void *ptr, T val, int64_t offset = 0) { + static_cast(ptr)[offset] = val; +} + +/// \brief Safe void-pointer arithmetic (byte-level by default). +namespace pointer { +// we only allow void * pointer arithmetic for safety + +template +SGL_DEVICE auto offset(void *ptr, U... 
offset) -> void * { + return static_cast(ptr) + (offset + ...); +} + +template +SGL_DEVICE auto offset(const void *ptr, U... offset) -> const void * { + return static_cast(ptr) + (offset + ...); +} + +} // namespace pointer + +} // namespace device + +namespace host { + +/** + * \brief Check the CUDA error code and panic with location info on failure. + */ +inline void RuntimeDeviceCheck(::cudaError_t error, DebugInfo location = {}) { + if (error != ::cudaSuccess) { + [[unlikely]]; + ::host::panic(location, "CUDA error: ", ::cudaGetErrorString(error)); + } +} + +/// \brief Check the last CUDA error (calls `cudaGetLastError`). +inline void RuntimeDeviceCheck(DebugInfo location = {}) { + return RuntimeDeviceCheck(::cudaGetLastError(), location); +} + +/** + * \brief Kernel launcher with automatic stream resolution and PDL support. + * + * Usage: + * \code + * host::LaunchKernel(grid, block, device) + * .enable_pdl(true) + * (my_kernel, arg1, arg2); + * \endcode + * + * The constructor resolves the CUDA stream from a `DLDevice` (via + * `TVMFFIEnvGetStream`) or accepts a raw `cudaStream_t`. The call + * operator launches the kernel and checks for errors. 
+ */ +struct LaunchKernel { +public: + explicit LaunchKernel( + dim3 grid_dim, + dim3 block_dim, + DLDevice device, + std::size_t dynamic_shared_mem_bytes = 0, + DebugInfo location = {}) noexcept + : m_config(s_make_config(grid_dim, block_dim, resolve_device(device), dynamic_shared_mem_bytes)), + m_location(location) {} + + explicit LaunchKernel( + dim3 grid_dim, + dim3 block_dim, + cudaStream_t stream, + std::size_t dynamic_shared_mem_bytes = 0, + DebugInfo location = {}) noexcept + : m_config(s_make_config(grid_dim, block_dim, stream, dynamic_shared_mem_bytes)), m_location(location) {} + + LaunchKernel(const LaunchKernel &) = delete; + LaunchKernel &operator=(const LaunchKernel &) = delete; + + static auto resolve_device(DLDevice device) -> cudaStream_t { + return static_cast(::TVMFFIEnvGetStream(device.device_type, device.device_id)); + } + + auto enable_pdl(bool enabled = true) -> LaunchKernel & { +#ifdef USE_ROCM + (void)enabled; + m_config.numAttrs = 0; +#else + if (enabled) { + m_attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + m_attrs[0].val.programmaticStreamSerializationAllowed = true; + m_config.numAttrs = 1; + m_config.attrs = m_attrs; + } else { + m_config.numAttrs = 0; + } +#endif + return *this; + } + + template + auto operator()(T &&kernel, Args &&...args) const -> void { +#ifdef USE_ROCM + hipLaunchKernelGGL( + std::forward(kernel), + m_config.gridDim, + m_config.blockDim, + m_config.dynamicSmemBytes, + m_config.stream, + std::forward(args)...); + RuntimeDeviceCheck(m_location); +#else + RuntimeDeviceCheck(::cudaLaunchKernelEx(&m_config, kernel, std::forward(args)...), m_location); +#endif + } + +private: + static auto s_make_config( // Make a config for kernel launch + dim3 grid_dim, + dim3 block_dim, + cudaStream_t stream, + std::size_t smem) -> cudaLaunchConfig_t { + auto config = ::cudaLaunchConfig_t{}; + config.gridDim = grid_dim; + config.blockDim = block_dim; + config.dynamicSmemBytes = smem; + config.stream = stream; + 
config.numAttrs = 0; + return config; + } + + cudaLaunchConfig_t m_config; + const DebugInfo m_location; + cudaLaunchAttribute m_attrs[1]; +}; + +} // namespace host diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h new file mode 100644 index 000000000..bf7a5ce40 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h @@ -0,0 +1,241 @@ +/// \file utils.h +/// \brief Host-side C++ utilities used by JIT kernel wrappers. +/// +/// Provides: +/// - `DebugInfo` - wraps `std::source_location` for error reporting. +/// - `RuntimeCheck` - runtime assertion with formatted error messages. +/// - `Panic` - unconditional abort with formatted error messages. +/// - `pointer::offset` - safe void-pointer arithmetic (host side). +/// - `div_ceil` - integer ceiling division. +/// - `dtype_bytes` - byte width of a `DLDataType`. +/// - `irange` - Python-style integer range for range-for loops. + +#pragma once + +// ref: https://forums.developer.nvidia.com/t/c-20s-source-location-compilation-error-when-using-nvcc-12-1/258026/3 +#ifdef __CUDACC__ +#include +#if CUDA_VERSION <= 12010 + +#pragma push_macro("__cpp_consteval") +#pragma push_macro("_NODISCARD") +#pragma push_macro("__builtin_LINE") + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wbuiltin-macro-redefined" +#define __cpp_consteval 201811L +#pragma clang diagnostic pop + +#ifdef _NODISCARD +#undef _NODISCARD +#define _NODISCARD +#endif + +#define consteval constexpr + +#include "source_location.h" + +#undef consteval +#pragma pop_macro("__cpp_consteval") +#pragma pop_macro("_NODISCARD") +#else // __CUDACC__ && CUDA_VERSION > 12010 +#include "source_location.h" +#endif +#else // no __CUDACC__ +#include "source_location.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace host +{ + + template + inline constexpr bool dependent_false_v = false; + + /// \brief 
Source-location wrapper for debug/error messages. + struct DebugInfo : public source_location_t + { + DebugInfo(source_location_t loc = source_location_t::current()) : source_location_t(loc) {} + }; + + /// \brief Exception type thrown by `RuntimeCheck` and `Panic`. + struct PanicError : public std::runtime_error + { + public: + explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {} + auto root_cause() const -> std::string_view + { + const auto str = std::string_view{m_message}; + const auto pos = str.find(": "); + return pos == std::string_view::npos ? str : str.substr(pos + 2); + } + + private: + std::string m_message; + }; + + /// \brief Unconditionally abort with a formatted error message. + template + [[noreturn]] + inline auto panic(DebugInfo location, Args &&...args) -> void + { + std::ostringstream os; + os << "Runtime check failed at " << location.file_name() << ":" << location.line(); + if constexpr (sizeof...(args) > 0) + { + os << ": "; + (os << ... << std::forward(args)); + } + else + { + os << " in " << location.function_name(); + } + throw PanicError(std::move(os).str()); + } + + /** + * \brief Runtime assertion: panics with a formatted message when `condition` + * is false. Extra `args` are streamed to the error message. 
+ * + * Example: + * \code + * RuntimeCheck(n > 0, "n must be positive, got ", n); + * \endcode + */ + template + struct RuntimeCheck + { + template + explicit RuntimeCheck(Cond &&condition, Args &&...args, DebugInfo location = {}) + { + if (condition) + return; + [[unlikely]] ::host::panic(location, std::forward(args)...); + } + template + explicit RuntimeCheck(DebugInfo location, Cond &&condition, Args &&...args) + { + if (condition) + return; + [[unlikely]] ::host::panic(location, std::forward(args)...); + } + }; + + template + struct Panic + { + explicit Panic(Args &&...args, DebugInfo location = {}) + { + ::host::panic(location, std::forward(args)...); + } + explicit Panic(DebugInfo location, Args &&...args) + { + ::host::panic(location, std::forward(args)...); + } + [[noreturn]] ~Panic() + { + std::terminate(); + } + }; + + template + explicit RuntimeCheck(Cond &&, Args &&...) -> RuntimeCheck; + + template + explicit RuntimeCheck(DebugInfo, Cond &&, Args &&...) -> RuntimeCheck; + + template + explicit Panic(Args &&...) -> Panic; + + template + explicit Panic(DebugInfo, Args &&...) -> Panic; + + namespace pointer + { + + // we only allow void * pointer arithmetic for safety + + template ::value && ...)>> + inline auto offset(void *ptr, U... offset) -> void * + { + return static_cast(ptr) + (... + offset); + } + + template ::value && ...)>> + inline auto offset(const void *ptr, U... offset) -> const void * + { + return static_cast(ptr) + (... + offset); + } + + } // namespace pointer + + /// \brief Integer ceiling division: ceil(a / b). + template + inline constexpr auto div_ceil(T a, U b) + { + static_assert(std::is_integral::value, "T must be integral"); + static_assert(std::is_integral::value, "U must be integral"); + return (a + b - 1) / b; + } + + /// \brief Returns the byte width of a DLPack data type. 
+ inline auto dtype_bytes(DLDataType dtype) -> std::size_t + { + return static_cast(dtype.bits / 8); + } + + // ====================== 修复开始:纯 C++11 兼容版 irange ====================== + // 移除所有 std::ranges / std::integral,完全兼容旧版 CUDA 编译器 + + template + struct IntegerRange + { + T start_; + T end_; + + struct Iterator + { + T value; + + T operator*() const { return value; } + Iterator &operator++() + { + ++value; + return *this; + } + bool operator!=(const Iterator &other) const + { + return value != other.value; + } + }; + + Iterator begin() const { return {start_}; } + Iterator end() const { return {end_}; } + }; + + /// Python-style integer range: irange(n) -> [0, n) + template + IntegerRange irange(T end) + { + return {0, end}; + } + + /// Python-style integer range: irange(start, end) -> [start, end) + template + IntegerRange irange(T start, T end) + { + return {start, end}; + } + // ====================== 修复结束 ====================== + +} // namespace host diff --git a/test/infiniop/gptq_marlin_gemm.py b/test/infiniop/gptq_marlin_gemm.py new file mode 100644 index 000000000..9ba296d18 --- /dev/null +++ b/test/infiniop/gptq_marlin_gemm.py @@ -0,0 +1,623 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + TestWorkspace, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) +import itertools +from libinfiniop.scalar_type import scalar_types, ScalarType +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +import numpy as np + + +# ============================================================================== +# Configuration +# ============================================================================== + +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (13, 17, 67), + (257, 13, 11), +] + +k_chunk = 128 +n_chunk = [64, 256] 
quant_type = [scalar_types.uint4, scalar_types.uint4b8]
group_size = [-1, 128]
mnk_factors = MNK_FACTORS
act_order = [False, True]


def to_iter(x):
    """Return `x` unchanged if it is a list or tuple, else wrap it in a 1-tuple."""
    return x if isinstance(x, (list, tuple)) else (x,)


# Cartesian product of every configuration axis; each entry is
# (k_chunk, n_chunk, quant_type, group_size, mnk_factors, act_order).
_TEST_CASES = list(itertools.product(
    to_iter(k_chunk),
    to_iter(n_chunk),
    to_iter(quant_type),
    to_iter(group_size),
    to_iter(mnk_factors),
    to_iter(act_order),
))

_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


# ==============================================================================
#  Reference Implementation (matches CUDA kernel)
# ==============================================================================


GPTQ_MARLIN_TILE = 16
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]


def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    """Quantize a (size_k, size_n) float weight matrix group-wise along K.

    Args:
        w: Floating-point weights of shape (size_k, size_n).
        quant_type: Integer scalar type providing `min()`, `max()` and,
            when `has_bias()` is true, `bias`.
        group_size: Rows per quantization group; -1 means one group spanning
            all of K; None disables scaling (scale stays 1.0).
        zero_points: Compute per-group zero points (requires `group_size`).
        ref_zero_points_after_scales: Build the reference by scaling the codes
            and the zero points separately before subtracting.

    Returns:
        Tuple `(w_ref, w_q, w_s, w_zp)`: the dequantized reference in the
        input dtype, the integer codes (bias added when the type has one),
        the scales (None when `group_size` is None) and the zero points
        (None unless `zero_points` was requested).
    """
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1] so that each column holds one group.
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )


def get_pack_factor(num_bits):
    """Return how many `num_bits`-wide values fit in one 32-bit word."""
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    """Regroup `q_w` into `tile` x `tile` blocks and apply `perm` to each
    run of `perm.numel()` consecutive elements.

    Returns a tensor of shape (size_k // tile, size_n * tile).
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # The message previously reported this value under the label "size_k".
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w


def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    """Permute `q_w` with `marlin_permute_weights` and pack the codes.

    `get_pack_factor(num_bits)` neighbouring columns are packed into one
    32-bit word, lowest bits first. The result is an int32 tensor on the
    input's device.
    """
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(np.uint32)

    q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)

    return q_packed


def get_weight_perm(num_bits: int):
    """Build the 1024-entry weight permutation used by `marlin_weights`.

    Raises:
        Exception: if `num_bits` is neither 4 nor 8.
    """
    perm_list: list[int] = []
    for i in range(32):
        perm1: list[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = np.array(perm_list)

    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


def get_scale_perms():
    """Return `(scale_perm, scale_perm_single)`: a 64-entry permutation for
    grouped scales and a 32-entry permutation for the single-group case."""
    scale_perm: list[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: list[int] = []
    for i in range(4):
        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def marlin_permute_scales(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
    """Permute scales `s` and return them as a contiguous (-1, size_n) tensor.

    The grouped permutation is used when `group_size` is a real group smaller
    than `size_k`; otherwise the single-group permutation is used.
    """
    scale_perm, scale_perm_single = get_scale_perms()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    """Pack `get_pack_factor(num_bits)` neighbouring columns of `q_w` into one
    32-bit word, lowest bits first.

    Returns a contiguous int32 tensor of shape
    (size_k, size_n // pack_factor) on the input's device.
    """
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(np.uint32)

    q_res = np.zeros((size_k, size_n // pack_factor), dtype=np.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(np.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Permute, interleave and pack zero points.

    `size_k` is the number of rows of `zp`; the caller in this file passes
    the number of groups.

    Raises:
        Exception: if `num_bits` is neither 4 nor 8.
    """
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    scale_perm, _ = get_scale_perms()
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = np.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    zp = pack_cols(zp, num_bits, size_k, size_n)

    return zp


def permute_rows(
    q_w: torch.Tensor,
    w_ref: torch.Tensor,
    group_size: int,
    test_perm: Optional[torch.Tensor] = None,
):
    """Simulate act_order by shuffling the K rows.

    Each row starts with group id `row // group_size`. One permutation
    (`test_perm` if given, otherwise a fresh random one) is then applied to
    the group ids, `q_w` and `w_ref`.

    Returns:
        Tuple `(w_ref, q_w, g_idx, rand_perm)`, all on `q_w`'s device.
    """
    assert q_w.shape == w_ref.shape

    orig_device = q_w.device
    k_size, _ = q_w.shape

    # Vectorized equivalent of assigning g_idx[i] = i // group_size per row.
    g_idx = torch.arange(k_size, dtype=torch.int32) // group_size

    # Simulate act_order by doing a random permutation on K
    rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)

    g_idx = g_idx[rand_perm].contiguous()
    q_w = q_w[rand_perm, :].contiguous()
    w_ref = w_ref[rand_perm, :].contiguous()

    return (
        w_ref.to(device=orig_device),
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        rand_perm.to(device=orig_device),
    )


def gptq_quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    """Quantize `w` without zero points and optionally apply act_order.

    Returns:
        Tuple `(w_ref, w_q, w_s, g_idx, rand_perm)`; `g_idx` and `rand_perm`
        are empty tensors when `act_order` is false.
    """
    size_k, _ = w.shape

    assert w.is_floating_point(), "w must be float"
    assert (
        quant_type in SUPPORTED_GPTQ_QUANT_TYPES
    ), f"Unsupported gptq type = {quant_type}"
    assert group_size in SUPPORTED_GROUP_SIZES + [
        size_k
    ], f"Unsupported groupsize = {group_size}"

    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)

    # Apply act_order
    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        assert (
            group_size < size_k
        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
            group_size, size_k
        )

        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)

    return w_ref, w_q, w_s, g_idx, rand_perm


def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
    """Reorder the rows of `q_w` so that `g_idx` becomes non-decreasing.

    Returns:
        Tuple `(q_w, g_idx, sort_indices)` on `q_w`'s device.
    """
    orig_device = q_w.device

    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)  # Sort based on g_idx

    g_idx = g_idx[sort_indices].contiguous()
    q_w = q_w[sort_indices, :].contiguous()

    return (
        q_w.to(device=orig_device),
        g_idx.to(device=orig_device),
        sort_indices.to(device=orig_device),
    )


def marlin_quantize(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: int,
    act_order: bool,
    test_perm: Optional[torch.Tensor] = None,
):
    """Quantize `w` (GPTQ style) and reformat the result for Marlin.

    Returns:
        List `[w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]`,
        every entry moved to `w`'s device.
    """
    size_k, size_n = w.shape
    num_bits = quant_type.size_bits

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        w, quant_type, group_size, act_order, test_perm
    )

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list


def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int):
    """Quantize `w` with zero points (AWQ style) and reformat for Marlin.

    Returns:
        List `[w_ref, marlin_q_w, marlin_s, marlin_zp]`, every entry moved to
        `w`'s device.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Detect num groups
    assert size_k % group_size == 0
    num_groups = size_k // group_size

    # Quantize with zp
    w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True)

    # Reformat to marlin
    weight_perm = get_weight_perm(quant_type.size_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)
    marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list


def marlin_make_workspace(
    device: torch.device, max_blocks_per_sm: int = 1
) -> torch.Tensor:
    """Return a zeroed int tensor with `sms * max_blocks_per_sm` elements on
    `device`, where `sms` is the device's multiprocessor count."""
    # In the new marlin kernel, we use the num of threadblocks as workspace
    # size. The num of threadblocks is sms_count * max_blocks_per_sm.
    sms = torch.cuda.get_device_properties(device).multi_processor_count
    return torch.zeros(
        sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
    )


# ==============================================================================
#  Test Entrypoint
# ==============================================================================


def test(
    handle,
    device,
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    dtype=None,
    sync=None,
):
    m_factor, n_factor, k_factor = mnk_factors
    # Unsigned types without a bias take the zero-point (AWQ-style) path.
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Skip combinations that act_order cannot be combined with.
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if has_zp:
            return

    if size_k % group_size != 0:
        return

    print(
        f"Testing Gptq Marlin Gemm on {InfiniDeviceNames[device]} with M-K-N:({size_m, size_k, size_n}), group_size:{group_size}, dtype:{InfiniDtypeNames[dtype]}"
    )

    a_input = TestTensor((size_m, size_k), None, dtype, device)
    b_weight = TestTensor((size_k, size_n), None, dtype, device)
    if has_zp:
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight.torch_tensor(), quant_type, group_size
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight.torch_tensor(), quant_type, group_size, act_order
        )
        marlin_zp = None
        marlin_s2 = None
    # Reference result computed from the dequantized weights.
    output_ref = torch.matmul(a_input.torch_tensor(), w_ref)
    b = TestTensor(marlin_q_w.shape, marlin_q_w.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=marlin_q_w)
    c = TestTensor(output_ref.shape, None, dtype, device)
    b_scales = TestTensor(marlin_s.shape, marlin_s.stride(), dtype, device, mode="manual", set_tensor=marlin_s)
    global_scale = None
    if marlin_zp is not None:
        b_zeros = TestTensor(marlin_zp.shape, marlin_zp.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=marlin_zp)
    else:
        b_zeros = None
    if g_idx is not None:
        b_g_idx = TestTensor(g_idx.shape, g_idx.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=g_idx)
    else:
        b_g_idx = None
    if sort_indices is not None:
        perm = TestTensor(sort_indices.shape, sort_indices.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=sort_indices)
    else:
        perm = None

    is_k_full = True
    use_atomic_add = False
    use_fp32_reduce = False
    is_zp_float = False

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateGptqMarlinGemmDescriptor(
            handle,
            ctypes.byref(descriptor),
            c.descriptor,
            a_input.descriptor,
            b.descriptor,
            b_scales.descriptor,
            global_scale.descriptor if global_scale is not None else None,
            b_zeros.descriptor if b_zeros is not None else None,
            b_g_idx.descriptor if b_g_idx is not None else None,
            perm.descriptor if perm is not None else None,
        )
    )

    # Invalidate descriptors (same pattern as other tests)
    for tensor in [c, a_input, b, b_scales, global_scale, b_zeros, b_g_idx, perm]:
        if tensor is not None:
            tensor.destroy_desc()

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetGptqMarlinGemmWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, device)

    def lib_gptq_marlin_gemm():
        check_error(
            LIBINFINIOP.infiniopGptqMarlinGemm(
                descriptor,
                workspace.data(),
                workspace_size.value,
                c.data(),
                a_input.data(),
                b.data(),
                b_scales.data(),
                global_scale.data() if global_scale is not None else None,
                b_zeros.data() if b_zeros is not None else None,
                b_g_idx.data() if b_g_idx is not None else None,
                perm.data() if perm is not None else None,
                quant_type.id,
                is_k_full,
                use_atomic_add,
                use_fp32_reduce,
                is_zp_float,
                None,
            )
        )

    lib_gptq_marlin_gemm()
+ max_diff = torch.mean(torch.abs(c.actual_tensor() - output_ref)) / torch.mean( + torch.abs(output_ref) + ) + assert max_diff < 0.04 + + if PROFILE: + profile_operation( + "PyTorch", + lambda: torch.matmul(a_input.torch_tensor(), w_ref), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", lambda: lib_gptq_marlin_gemm(), device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroyGptqMarlinGemmDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index f0d273d77..086e6a924 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -9,6 +9,17 @@ local FLASH_ATTN_ROOT = get_config("flash-attn") local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini") + +-- Map an "sm_NN" arch string to NN * 10 (e.g. "sm_80" -> 800); returns nil for non-string input or when no "sm_NN" is found. +function parse_sgl_cuda_arch(arch) + local num = type(arch) == "string" and arch:match("sm_(%d+)") + if not num then + return nil + end + + return tonumber(num) * 10 +end + target("infiniop-nvidia") set_kind("static") add_deps("infini-utils") @@ -100,6 +111,15 @@ target("infiniop-nvidia") end local arch_opt = get_config("cuda_arch") + if arch_opt then + local sgl_arch = parse_sgl_cuda_arch(arch_opt) + if sgl_arch then + add_defines("SGL_CUDA_ARCH=" .. 
sgl_arch) + print("SGL_CUDA_ARCH =", sgl_arch) + else + print("Invalid cuda_arch:", arch_opt) + end + end if arch_opt and type(arch_opt) == "string" then for _, arch in ipairs(arch_opt:split(",")) do arch = arch:trim() From 5d2890367ee2cc6d4ac5b451d38cc858049719cc Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Tue, 31 Mar 2026 14:22:48 +0800 Subject: [PATCH 02/10] issue/1083: modified format --- .../marlin/awq_marlin_repack.cuh | 369 ++- .../ops/gptq_marlin_gemm/marlin/dequant.h | 580 ++-- .../gptq_marlin_gemm/marlin/gptq_marlin.cuh | 1630 ++++++----- .../marlin/gptq_marlin_repack.cuh | 501 ++-- .../ops/gptq_marlin_gemm/marlin/kernel.h | 39 +- .../ops/gptq_marlin_gemm/marlin/marlin.cuh | 75 +- .../gptq_marlin_gemm/marlin/marlin_template.h | 2470 ++++++++--------- .../sgl_kernel/scalar_type.hpp | 509 ++-- .../sgl_kernel/source_location.h | 35 +- .../ops/gptq_marlin_gemm/sgl_kernel/tensor.h | 826 +++--- .../ops/gptq_marlin_gemm/sgl_kernel/utils.cuh | 6 +- .../ops/gptq_marlin_gemm/sgl_kernel/utils.h | 249 +- test/infiniop/gptq_marlin_gemm.py | 110 +- 13 files changed, 3471 insertions(+), 3928 deletions(-) diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh index 2aea26529..2963dbf6b 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/awq_marlin_repack.cuh @@ -6,22 +6,19 @@ #include "marlin.cuh" -namespace device::marlin -{ +namespace device::marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - template - __global__ void awq_marlin_repack_kernel( - uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) - { +template +__global__ void awq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) { return; - } +} #else - template - __global__ void awq_marlin_repack_kernel( - uint32_t const 
*__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) - { +template +__global__ void awq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) { constexpr int pack_factor = 32 / num_bits; int k_tiles = size_k / tile_k_size; @@ -29,22 +26,20 @@ namespace device::marlin int block_k_tiles = div_ceil(k_tiles, (int)gridDim.x); auto start_k_tile = blockIdx.x * block_k_tiles; - if (start_k_tile >= k_tiles) - { - return; + if (start_k_tile >= k_tiles) { + return; } int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() - { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); }; extern __shared__ int4 sh[]; @@ -55,227 +50,201 @@ namespace device::marlin constexpr int stage_k_threads = tile_k_size; constexpr int stage_size = stage_k_threads * stage_n_threads; - auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) - { - if (n_tile_id >= n_tiles) - { - cp_async_fence(); - return; - } + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } - int first_n = n_tile_id * tile_n_size; - int first_n_packed = first_n / pack_factor; + int first_n = n_tile_id * tile_n_size; + int first_n_packed = first_n / pack_factor; - int4 *sh_ptr = sh + stage_size * pipe; + int4 *sh_ptr = sh + stage_size * pipe; - if (threadIdx.x < stage_size) - { - auto k_id = threadIdx.x / stage_n_threads; - auto n_id = threadIdx.x % stage_n_threads; + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; - int first_k = k_tile_id * tile_k_size; + int first_k = k_tile_id * tile_k_size; - cp_async4( - &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( - &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); - } + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); + } - cp_async_fence(); + cp_async_fence(); }; - auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) - { - if (n_tile_id >= n_tiles) - { - return; - } + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } - auto warp_id = threadIdx.x / 32; - auto th_id = threadIdx.x % 32; + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; - if (warp_id >= 4) - { - return; - } + if (warp_id >= 4) { + return; + } - int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; + int tc_col = th_id / 4; 
+ int tc_row = (th_id % 4) * 2; - constexpr int tc_offsets[4] = {0, 1, 8, 9}; + constexpr int tc_offsets[4] = {0, 1, 8, 9}; - int cur_n = warp_id * 16 + tc_col; - int cur_n_packed = cur_n / pack_factor; - int cur_n_pos = cur_n % pack_factor; + int cur_n = warp_id * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; - constexpr int sh_stride = tile_n_ints; - constexpr uint32_t mask = (1 << num_bits) - 1; + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; - int4 *sh_stage_ptr = sh + stage_size * pipe; - uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + int4 *sh_stage_ptr = sh + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - // Undo interleaving - int cur_n_pos_unpacked; - if constexpr (num_bits == 4) - { - constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; - cur_n_pos_unpacked = undo_pack[cur_n_pos]; - } - else - { - constexpr int undo_pack[4] = {0, 2, 1, 3}; - cur_n_pos_unpacked = undo_pack[cur_n_pos]; - } + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } else { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } - uint32_t vals[8]; + uint32_t vals[8]; #pragma unroll - for (int i = 0; i < 4; i++) - { - int cur_elem = tc_row + tc_offsets[i]; + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; - int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; - int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; - vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; - vals[4 + i] = (packed_src_1 >> 
(cur_n_pos_unpacked * num_bits)) & mask; - } + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } - constexpr int tile_size_val = tile_k_size * tile_n_size / pack_factor; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size_val; + constexpr int tile_size_val = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size_val; - // Result of: - // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) - { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - uint32_t res = 0; + uint32_t res = 0; #pragma unroll - for (int i = 0; i < 8; i++) - { - res |= vals[pack_idx[i]] << (i * 4); - } + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } - out_ptr[out_offset + th_id * 4 + warp_id] = res; - } - else - { - constexpr int pack_idx[4] = {0, 2, 1, 3}; + out_ptr[out_offset + th_id * 4 + warp_id] = res; + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; - uint32_t res1 = 0; - uint32_t res2 = 0; + uint32_t res1 = 0; + uint32_t res2 = 0; #pragma unroll - for (int i = 0; i < 4; i++) - { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); - } + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; - } + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + 
(warp_id * 2) + 1] = res2; + } }; - auto start_pipes = [&](int k_tile_id, int n_tile_id) - { + auto start_pipes = [&](int k_tile_id, int n_tile_id) { #pragma unroll - for (int pipe = 0; pipe < repack_stages - 1; pipe++) - { - fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); - } + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } - wait_for_stage(); + wait_for_stage(); }; #pragma unroll - for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) - { - int n_tile_id = 0; + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; - start_pipes(k_tile_id, n_tile_id); + start_pipes(k_tile_id, n_tile_id); - while (n_tile_id < n_tiles) - { + while (n_tile_id < n_tiles) { #pragma unroll - for (int pipe = 0; pipe < repack_stages; pipe++) - { - fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); - repack_tile(pipe, k_tile_id, n_tile_id + pipe); - wait_for_stage(); + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; } - n_tile_id += repack_stages; - } } - } +} #endif } // namespace device::marlin // Host wrapper void awq_marlin_repack( - tvm::ffi::TensorView out, tvm::ffi::TensorView b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) -{ - using namespace host; - using namespace device::marlin; - - // Validate alignment - RuntimeCheck(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); - RuntimeCheck(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); - RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. 
Got = ", num_bits); - - int const pack_factor = 32 / num_bits; - - // Validate tensors - SymbolicDevice cuda_device; - cuda_device.set_options(); - - TensorMatcher({size_k, size_n / pack_factor}).with_dtype().with_device(cuda_device).verify(b_q_weight); - - TensorMatcher({size_k / tile_size, size_n * tile_size / pack_factor}) - .with_dtype() - .with_device(cuda_device) - .verify(out); - - // Get device and stream - auto device = cuda_device.unwrap(); - auto stream = LaunchKernel::resolve_device(device); - - // Get pointers - auto *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); - auto *out_ptr = reinterpret_cast(out.data_ptr()); - - // Get device attributes - int blocks = 0; - cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device.device_id); - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device.device_id); - RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); - - // Dispatch based on num_bits - if (num_bits == 4) - { - cudaFuncSetAttribute( - awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); - LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( - awq_marlin_repack_kernel, - b_q_weight_ptr, - out_ptr, - static_cast(size_k), - static_cast(size_n)); - } - else if (num_bits == 8) - { - cudaFuncSetAttribute( - awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); - LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( - awq_marlin_repack_kernel, - b_q_weight_ptr, - out_ptr, - static_cast(size_k), - static_cast(size_n)); - } - else - { - RuntimeCheck(false, "Unsupported repack config: num_bits = ", num_bits); - } + tvm::ffi::TensorView out, tvm::ffi::TensorView b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) { + using namespace host; + using namespace device::marlin; + + // Validate alignment + RuntimeCheck(size_k % tile_k_size == 0, "size_k = ", size_k, " is not 
divisible by tile_k_size = ", tile_k_size); + RuntimeCheck(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + + int const pack_factor = 32 / num_bits; + + // Validate tensors + SymbolicDevice cuda_device; + cuda_device.set_options(); + + TensorMatcher({size_k, size_n / pack_factor}).with_dtype().with_device(cuda_device).verify(b_q_weight); + + TensorMatcher({size_k / tile_size, size_n * tile_size / pack_factor}) + .with_dtype() + .with_device(cuda_device) + .verify(out); + + // Get device and stream + auto device = cuda_device.unwrap(); + auto stream = LaunchKernel::resolve_device(device); + + // Get pointers + auto *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + auto *out_ptr = reinterpret_cast(out.data_ptr()); + + // Get device attributes + int blocks = 0; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device.device_id); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device.device_id); + RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + // Dispatch based on num_bits + if (num_bits == 4) { + cudaFuncSetAttribute( + awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } else if (num_bits == 8) { + cudaFuncSetAttribute( + awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } else { + RuntimeCheck(false, "Unsupported repack config: num_bits = ", num_bits); + } } diff --git 
a/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h index 764375f62..6a0d90e5d 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/dequant.h @@ -73,22 +73,26 @@ namespace device::marlin { // all cases. template __device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; } // Constructs destination register by taking bytes from 2 sources (based on // mask) template __device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); - return res; + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; } template -__device__ inline void dequant(int q, scalar_t2* frag_b); +__device__ inline void dequant(int q, scalar_t2 *frag_b); // // Efficiently dequantize 4bit values packed in an int32 value into a full @@ -100,102 +104,102 @@ __device__ inline void dequant(int q, scalar_t2* frag_b); // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 // template <> -__device__ inline void dequant(int q, half2* frag_b) { - const int MASK = 0x000f000f; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. 
- int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - q >>= 4; - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - - frag_b[0] = *reinterpret_cast(&lo); - frag_b[1] = *reinterpret_cast(&hi); +__device__ inline void dequant(int q, half2 *frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - // clang-format off +__device__ inline void dequant(int q, half2 *frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // clang-format on - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2( - *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. 
+ const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - // clang-format off +__device__ inline void dequant(int q, half2 *frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // clang-format on - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2( - *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); + // clang-format on + // No zero point is applied in this variant: `SUB` and `ADD` only remove the + // fp16 magic-number offset introduced by `EX` (1024 and 64, respectively). 
+ const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2( + *reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; - // Guarantee that the `(a & b) | c` operations are LOP3s. - // clang-format off + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); - // clang-format on + // clang-format on - frag_b[0] = *reinterpret_cast(&lo); - frag_b[1] = *reinterpret_cast(&hi); + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); - static constexpr uint32_t SUB = 0x43084308; + static constexpr uint32_t SUB = 0x43084308; - frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); - frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); - static constexpr uint32_t SUB = 0x43004300; + static constexpr 
uint32_t SUB = 0x43004300; - frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); - frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); } // @@ -207,298 +211,298 @@ __device__ inline void dequant(int q, nv_bf // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 // template <> -__device__ inline void dequant(int q, half2* frag_b) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; +__device__ inline void dequant(int q, half2 *frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); - frag_b[0] = *reinterpret_cast(&lo); - frag_b[1] = *reinterpret_cast(&hi); + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + 
dequant(q, frag_b); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388736.f; - fp32_intermediates[1] -= 8388736.f; - fp32_intermediates[2] -= 8388736.f; - fp32_intermediates[3] -= 8388736.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + float fp32_intermediates[4]; + uint32_t *fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + 
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t *bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388608.f; - fp32_intermediates[1] -= 8388608.f; - fp32_intermediates[2] -= 8388608.f; - fp32_intermediates[3] -= 8388608.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + float fp32_intermediates[4]; + uint32_t *fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + 
fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t *bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - constexpr int MASK = 0x7F007F00; - - // Extract and shift FP8 values to FP16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 8; - int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant(int q, half2 *frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); - // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - 
(1 << (FP8_EXPONENT - 1)); - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - // Convert to half2 and apply bias - frag_b[1] = __hmul2(frag_b[1], bias_reg); - frag_b[0] = __hmul2(frag_b[0], bias_reg); + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - constexpr int MASK = 0x7F007F00; + constexpr int MASK = 0x7F007F00; - // Extract and shift FP8 values to BF16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 8; - int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); - - // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; - - // Construct and apply exponent bias - constexpr int 
BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent - // position - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - - // Convert to bfloat162 and apply bias - frag_b[1] = __hmul2(frag_b[1], bias_reg); - frag_b[0] = __hmul2(frag_b[0], bias_reg); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - // Constants for FP4 (E2M1) and FP16 formats - constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; - constexpr int MASK = 0x70007000; - - // Extract and shift FP4 values to FP16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 4; - int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant(int q, half2 *frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 
format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } template <> -__device__ inline void dequant(int q, half2* frag_b) { - dequant(q, frag_b); +__device__ inline void dequant(int q, half2 *frag_b) { + dequant(q, frag_b); - // Constants for FP4 (E2M1) and FP16 formats - constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - // Convert to half2 and apply bias - frag_b[1] = __hmul2(frag_b[1], bias_reg); - frag_b[0] = __hmul2(frag_b[0], bias_reg); + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - // Constants for FP4 (E2M1) and FP16 formats - constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; - constexpr int MASK = 0x70007000; - - // Extract and shift FP4 values to FP16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 4; - int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + // Constants for FP4 
(E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } template <> -__device__ inline void dequant(int q, nv_bfloat162* frag_b) { - dequant(q, frag_b); - - // Constants for FP4 (E2M1) and BF16 formats - constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); - // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent - // position - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - - // Convert to half2 and apply bias - frag_b[1] = __hmul2(frag_b[1], bias_reg); - frag_b[0] = __hmul2(frag_b[0], bias_reg); +__device__ inline void dequant(int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); } template -__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); +__device__ inline void dequant_fp8_scales(int q, 
scalar_t2 *frag_b); template <> -__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { - int Out1 = (q & 0xFF00FF00) >> 1; - ; - q <<= 8; - int Out2 = (q & 0xFF00FF00) >> 1; - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant_fp8_scales(int q, half2 *frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); }; template <> -__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { - constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - constexpr int MASK = 0x7F007F00; - - // Extract and shift FP8 values to BF16 format - int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 8; - int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162 *frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); }; // New version with s_type_id parameter for marlin_moe_wna16_v2 template -__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); +__device__ inline void 
dequant_fp8_scales(int q, scalar_t2 *frag_b); template <> -__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { - int Out1 = (q & 0xFF00FF00) >> 1; - ; - q <<= 8; - int Out2 = (q & 0xFF00FF00) >> 1; - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant_fp8_scales(int q, half2 *frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); }; template <> -__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { - constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - constexpr int MASK = 0x7F007F00; - - // Extract and shift FP8 values to BF16 format - int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); - q <<= 8; - int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162 *frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } template <> -__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { - // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16, - // 
but we assume that such a extreme value would not occur in real models. - int Out1 = (q & 0xFF00FF00) >> 1; - q <<= 7; - int Out2 = q & 0x7F807F80; - - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = *reinterpret_cast(&Out1); - frag_b[0] = *reinterpret_cast(&Out2); +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162 *frag_b) { + // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16, + // but we assume that such a extreme value would not occur in real models. + int Out1 = (q & 0xFF00FF00) >> 1; + q <<= 7; + int Out2 = q & 0x7F807F80; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); } #endif -} // namespace device::marlin +} // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh index 653501357..ca85889f5 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin.cuh @@ -28,132 +28,122 @@ #include "kernel.h" #include "marlin_template.h" -namespace device::marlin -{ +namespace device::marlin { - __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; - using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __global__ void permute_cols_kernel( - int4 const *__restrict__ a_int4_ptr, - int const *__restrict__ perm_int_ptr, - int4 *__restrict__ out_int4_ptr, - int size_m, - int size_k, - int lda, - int block_rows) {} +__global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) {} #else - // For a given "a" of size [M,K] performs a 
permutation of the K columns based - // on the given "perm" indices. - __global__ void permute_cols_kernel( - int4 const *__restrict__ a_int4_ptr, - int const *__restrict__ perm_int_ptr, - int4 *__restrict__ out_int4_ptr, - int size_m, - int size_k, - int lda, - int block_rows) - { +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel( + int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) { auto start_row = block_rows * blockIdx.x; int finish_row = start_row + block_rows; - if (finish_row > size_m) - { - finish_row = size_m; + if (finish_row > size_m) { + finish_row = size_m; } int cur_block_rows = finish_row - start_row; int input_row_stride = lda * sizeof(half) / 16; int output_row_stride = size_k * sizeof(half) / 16; - auto permute_row = [&](int row) - { - int iters = size_k / default_threads; - int rest = size_k % default_threads; + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; - int input_offset = row * input_row_stride; - int output_offset = row * output_row_stride; + int input_offset = row * input_row_stride; + int output_offset = row * output_row_stride; - half const *a_row_half = reinterpret_cast(a_int4_ptr + input_offset); - half *out_half = reinterpret_cast(out_int4_ptr + output_offset); + half const *a_row_half = reinterpret_cast(a_int4_ptr + input_offset); + half *out_half = reinterpret_cast(out_int4_ptr + output_offset); - int base_k = 0; + int base_k = 0; - for (int i = 0; i < iters; i++) - { - auto cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; + for (int i = 0; i < iters; i++) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; - out_half[cur_k] = a_row_half[src_pos]; + out_half[cur_k] = a_row_half[src_pos]; - base_k += 
default_threads; - } + base_k += default_threads; + } - if (rest) - { - if (threadIdx.x < rest) - { - auto cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; + if (rest) { + if (threadIdx.x < rest) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; - out_half[cur_k] = a_row_half[src_pos]; + out_half[cur_k] = a_row_half[src_pos]; + } } - } }; - for (int i = 0; i < cur_block_rows; i++) - { - int cur_row = start_row + i; - if (cur_row < size_m) - { - permute_row(cur_row); - } + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } } - } +} - typedef struct - { +typedef struct +{ int thread_k; int thread_n; int num_threads; - } thread_config_t; +} thread_config_t; - thread_config_t small_batch_thread_configs[] = { - // Ordered by priority +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority - // thread_k, thread_n, num_threads - {128, 128, 256}, - {64, 128, 128}, - {128, 64, 128}}; + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}}; - thread_config_t large_batch_thread_configs[] = { - // Ordered by priority +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority - // thread_k, thread_n, num_threads - {64, 256, 256}, - {64, 128, 128}, - {128, 64, 128}}; + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; - typedef struct - { +typedef struct +{ int blocks_per_sm; thread_config_t tb_cfg; - } exec_config_t; - - int get_scales_cache_size( - thread_config_t const &th_config, - int prob_m, - int prob_n, - int prob_k, - int num_bits, - int group_size, - bool has_act_order, - bool is_k_full) - { +} exec_config_t; + +int get_scales_cache_size( + thread_config_t const &th_config, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full) { bool cache_scales_chunk = has_act_order && 
!is_k_full; int tb_n = th_config.thread_n; @@ -161,46 +151,37 @@ namespace device::marlin // Get max scale groups per thread-block int tb_groups; - if (group_size == -1) - { - tb_groups = 1; - } - else if (group_size == 0) - { - tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size - } - else - { - tb_groups = div_ceil(tb_k, group_size); + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); } - if (cache_scales_chunk) - { - int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 2; - } - else - { - int tb_scales = tb_groups * tb_n * 2; + if (cache_scales_chunk) { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } else { + int tb_scales = tb_groups * tb_n * 2; - return tb_scales * pipe_stages; + return tb_scales * pipe_stages; } - } - - int get_kernel_cache_size( - thread_config_t const &th_config, - int thread_m_blocks, - int prob_m, - int prob_n, - int prob_k, - int num_bits, - int group_size, - bool has_act_order, - bool is_k_full, - int has_zp, - int is_zp_float) - { +} + +int get_kernel_cache_size( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float) { int pack_factor = 32 / num_bits; // Get B size @@ -210,61 +191,55 @@ namespace device::marlin int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8); - int sh_s_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, 
group_size, has_act_order, is_k_full); + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; int sh_zp_size = 0; - if (has_zp) - { - if (is_zp_float) - sh_zp_size = sh_s_size; - else if (num_bits == 4) - sh_zp_size = sh_s_size / 4; - else if (num_bits == 8) - sh_zp_size = sh_s_size / 2; + if (has_zp) { + if (is_zp_float) { + sh_zp_size = sh_s_size; + } else if (num_bits == 4) { + sh_zp_size = sh_s_size / 4; + } else if (num_bits == 8) { + sh_zp_size = sh_s_size / 2; + } } int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; return total_size; - } - - bool is_valid_config( - thread_config_t const &th_config, - int thread_m_blocks, - int prob_m, - int prob_n, - int prob_k, - int num_bits, - int group_size, - bool has_act_order, - bool is_k_full, - int has_zp, - int is_zp_float, - int max_shared_mem) - { +} + +bool is_valid_config( + thread_config_t const &th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float, + int max_shared_mem) { // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) - { - return false; + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; } // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) - { - return false; + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; } // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) - { - return false; + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; } // num_threads must be at least 128 (= 4 
warps) - if (th_config.num_threads < 128) - { - return false; + if (th_config.num_threads < 128) { + return false; } // Check that pipeline fits into cache @@ -281,27 +256,24 @@ namespace device::marlin has_zp, is_zp_float); return cache_size <= max_shared_mem; - } - -#define _GET_IF( \ - W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if ( \ - q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) \ - { \ - kernel = Marlin< \ - scalar_t, \ - W_TYPE.id(), \ - NUM_THREADS, \ - THREAD_M_BLOCKS, \ - THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, \ - pipe_stages, \ - GROUP_BLOCKS, \ - IS_ZP_FLOAT>; \ - } +} + +#define _GET_IF( \ + W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) // this is the most common cases @@ -309,132 +281,130 @@ namespace device::marlin // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) // FP4: cases for nvfp4(e2m1) (group_blocks == 1) -#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) 
\ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) - -#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, 
K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - -#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - -#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - -#define FP4_GET_IF(W_TYPE) \ - FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FP4_GET_IF_M234(W_TYPE, 4, 8, 128) +#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, 
K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, 
NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) // We currently have 4-bit models only with group_blocks == 4 -#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - -#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - 
_GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - -#define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M234(W_TYPE, 4, 8, 128) +#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) // We currently have 4-bit models only with group_blocks == 4 -#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - -#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - -#define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 4, 8, 128) - - template - MarlinFuncPtr get_marlin_kernel( - const host::ScalarType q_type, - int thread_m_blocks, - 
int thread_n_blocks, - int thread_k_blocks, - bool m_block_size_8, - bool has_act_order, - bool has_zp, - int group_blocks, - int num_threads, - bool is_zp_float) - { +#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) + +template +MarlinFuncPtr get_marlin_kernel( + const host::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) { int num_bits = q_type.size_bits(); auto kernel = MarlinDefault; - if (false) - { + if (false) { } COMMON_GET_IF(host::kU4) @@ -448,181 +418,163 @@ namespace device::marlin ACT_GET_IF(host::kU4B8) ACT_GET_IF(host::kU8B128) - if (std::is_same::value) - { - if (false) - { - } - FZP_GET_IF(host::kU4) + if (std::is_same::value) { + if (false) { + } + FZP_GET_IF(host::kU4) } return kernel; - } - - template - exec_config_t determine_exec_config( - const host::ScalarType &q_type, - int prob_m, - int prob_n, - int prob_k, - int thread_m_blocks, - bool m_block_size_8, - int num_bits, - int group_size, - bool has_act_order, - bool is_k_full, - bool has_zp, - bool is_zp_float, - int max_shared_mem, - int sms) - { +} + +template +exec_config_t determine_exec_config( + const host::ScalarType &q_type, + int 
prob_m, + int prob_n, + int prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem, + int sms) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; thread_config_t *thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; int thread_configs_size = thread_m_blocks > 1 ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); - for (int i = 0; i < thread_configs_size; i++) - { - thread_config_t th_config = thread_configs[i]; - - if (!is_valid_config( - th_config, - thread_m_blocks, - prob_m, - prob_n, - prob_k, - num_bits, - group_size, - has_act_order, - is_k_full, - has_zp, - is_zp_float, - max_shared_mem)) - { - continue; - } - - int cache_size = get_kernel_cache_size( - th_config, - thread_m_blocks, - prob_m, - prob_n, - prob_k, - num_bits, - group_size, - has_act_order, - is_k_full, - has_zp, - is_zp_float); - - int group_blocks = 0; - if (!has_act_order) - { - group_blocks = group_size == -1 ? 
-1 : group_size / 16; - } - - auto kernel = get_marlin_kernel( - q_type, - thread_m_blocks, - th_config.thread_n / 16, - th_config.thread_k / 16, - m_block_size_8, - has_act_order, - has_zp, - group_blocks, - th_config.num_threads, - is_zp_float); - - if (kernel == MarlinDefault) - continue; - - // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); - // int n_tiles = prob_n / th_config.thread_n; - // int k_tiles = prob_k / th_config.thread_k; - - return {1, th_config}; + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? 
-1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + continue; + } + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; } return exec_cfg; - } - - template - void marlin_mm( - const void *A, - const void *B, - void *C, - void *C_tmp, - void *s, - void *s2, - void *zp, - void *g_idx, - void *perm, - void *a_tmp, - int prob_m, - int prob_n, - int prob_k, - int lda, - void *workspace, - host::ScalarType const &q_type, - bool has_act_order, - bool is_k_full, - bool has_zp, - int num_groups, - int group_size, - int dev, - cudaStream_t stream, - int thread_k_init, - int thread_n_init, - int sms, - bool use_atomic_add, - bool use_fp32_reduce, - bool is_zp_float) - { - if (has_zp) - { - host::RuntimeCheck( - q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); - } - else - { - host::RuntimeCheck( - q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, - "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. 
Got = ", - q_type.str()); +} + +template +void marlin_mm( + const void *A, + const void *B, + void *C, + void *C_tmp, + void *s, + void *s2, + void *zp, + void *g_idx, + void *perm, + void *a_tmp, + int prob_m, + int prob_n, + int prob_k, + int lda, + void *workspace, + host::ScalarType const &q_type, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k_init, + int thread_n_init, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + if (has_zp) { + host::RuntimeCheck( + q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + host::RuntimeCheck( + q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + q_type.str()); } host::RuntimeCheck( prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); int group_blocks = 0; - if (has_act_order) - { - if (is_k_full) - { - host::RuntimeCheck(group_size != -1); - group_blocks = group_size / 16; - host::RuntimeCheck( - prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); - } - else - { - host::RuntimeCheck(group_size == 0); - group_blocks = 0; - } - } - else - { - if (group_size == -1) - { - group_blocks = -1; - } - else - { - group_blocks = group_size / 16; - host::RuntimeCheck( - prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); - } + if (has_act_order) { + if (is_k_full) { + host::RuntimeCheck(group_size != -1); + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } else { + host::RuntimeCheck(group_size == 0); + group_blocks = 0; + } + 
} else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } } int num_bits = q_type.size_bits(); @@ -639,20 +591,20 @@ namespace device::marlin int *locks = (int *)workspace; - if (has_act_order) - { - // Permute A columns - int block_rows = div_ceil(prob_m, sms); - host::LaunchKernel(sms, default_threads, stream)( - permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); - A_ptr = a_tmp_ptr; - lda = prob_k; - - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by - // having a full K, we have full original groups) - if (is_k_full) - has_act_order = false; + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, sms); + host::LaunchKernel(sms, default_threads, stream)( + permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + A_ptr = a_tmp_ptr; + lda = prob_k; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } } int max_shared_mem = 0; @@ -660,187 +612,184 @@ namespace device::marlin host::RuntimeCheck(max_shared_mem > 0); int max_par = 16; - if (prob_n <= 4096) - max_par = 16 * 8; + if (prob_n <= 4096) { + max_par = 16 * 8; + } int max_shared_mem_new = max_shared_mem; int rest_m = prob_m; int max_thread_m_blocks = 4; - while (rest_m) - { - int par_count = rest_m / (max_thread_m_blocks * 16); - if (par_count > max_par) - par_count = max_par; - int prob_m_split = par_count > 0 ? 
(par_count * (max_thread_m_blocks * 16)) : rest_m; - - int thread_k = thread_k_init; - int thread_n = thread_n_init; - - int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); - int m_block_size_8 = prob_m_split <= 8; - - // Set thread config - exec_config_t exec_cfg; - thread_config_t thread_tfg; - if (thread_k != -1 && thread_n != -1) - { - thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; - exec_cfg = exec_config_t{1, thread_tfg}; - host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); - host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); - } - else - { - // Auto config - exec_cfg = determine_exec_config( - q_type, - prob_m_split, - prob_n, - prob_k, + while (rest_m) { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) { + par_count = max_par; + } + int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m_split, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem, + sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == 
-1 && max_thread_m_blocks > 1) { + max_thread_m_blocks--; + continue; + } + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) { + max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + host::RuntimeCheck( + is_valid_config( + thread_tfg, + thread_m_blocks, + prob_m_split, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, - m_block_size_8, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", num_bits, + ", prob_m_split = ", + prob_m_split, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, + ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, + ", is_zp_float = ", is_zp_float, - max_shared_mem, - sms); - thread_tfg = exec_cfg.tb_cfg; - if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) - { - max_thread_m_blocks--; - continue; - } - } - - int num_threads = thread_tfg.num_threads; - thread_k = thread_tfg.thread_k; - thread_n = thread_tfg.thread_n; - int blocks = sms * exec_cfg.blocks_per_sm; - if (exec_cfg.blocks_per_sm > 1) - max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - host::RuntimeCheck( - is_valid_config( - thread_tfg, - thread_m_blocks, - prob_m_split, - prob_n, - prob_k, - num_bits, - group_size, - has_act_order, - is_k_full, - has_zp, - is_zp_float, - max_shared_mem_new), - "Invalid thread config: thread_m_blocks = ", - thread_m_blocks, - ", thread_k = ", - thread_tfg.thread_k, - ", 
thread_n = ", - thread_tfg.thread_n, - ", num_threads = ", - thread_tfg.num_threads, - " for MKN = [", - prob_m, - ", ", - prob_k, - ", ", - prob_n, - "] and num_bits = ", - num_bits, - ", prob_m_split = ", - prob_m_split, - ", group_size = ", - group_size, - ", has_act_order = ", - has_act_order, - ", is_k_full = ", - is_k_full, - ", has_zp = ", - has_zp, - ", is_zp_float = ", - is_zp_float, - ", max_shared_mem_new = ", - max_shared_mem_new); - - auto kernel = get_marlin_kernel( - q_type, - thread_m_blocks, - thread_n_blocks, - thread_k_blocks, - m_block_size_8, - has_act_order, - has_zp, - group_blocks, - num_threads, - is_zp_float); - - if (kernel == MarlinDefault) - { - host::Panic( - "Unsupported shapes: MNK = [", - prob_m, - ", ", - prob_n, - ", ", - prob_k, - "]", - ", has_act_order = ", - has_act_order, - ", num_groups = ", - num_groups, - ", group_size = ", - group_size, - ", prob_m_split = ", - prob_m_split, - ", thread_m_blocks = ", + ", max_shared_mem_new = ", + max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, - ", thread_n_blocks = ", thread_n_blocks, - ", thread_k_blocks = ", thread_k_blocks, - ", num_threads = ", + m_block_size_8, + has_act_order, + has_zp, + group_blocks, num_threads, - ", num_bits = ", - num_bits); - } - - host::RuntimeDeviceCheck( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new)); - - bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; - - host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)( - kernel, - A_ptr, - B_ptr, - C_ptr, - C_tmp_ptr, - s_ptr, - s2_ptr, - zp_ptr, - g_idx_ptr, - num_groups, - prob_m_split, - prob_n, - prob_k, - lda, - locks, - part_use_atomic_add, - use_fp32_reduce, - max_shared_mem_new); - - A_ptr += prob_m_split * (lda / 8); - C_ptr += prob_m_split * (prob_n / 8); - rest_m -= prob_m_split; + is_zp_float); + + if (kernel == MarlinDefault) { + host::Panic( + "Unsupported 
shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", prob_m_split = ", + prob_m_split, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_threads = ", + num_threads, + ", num_bits = ", + num_bits); + } + + host::RuntimeDeviceCheck( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new)); + + bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)( + kernel, + A_ptr, + B_ptr, + C_ptr, + C_tmp_ptr, + s_ptr, + s2_ptr, + zp_ptr, + g_idx_ptr, + num_groups, + prob_m_split, + prob_n, + prob_k, + lda, + locks, + part_use_atomic_add, + use_fp32_reduce, + max_shared_mem_new); + + A_ptr += prob_m_split * (lda / 8); + C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; } - } +} #endif @@ -863,223 +812,202 @@ void gptq_marlin_gemm( bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) -{ - using namespace host; - - ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); - int pack_factor = 32 / b_q_type.size_bits(); - - // Bind symbolic sizes - auto M = SymbolicSize{"M"}; - auto K = SymbolicSize{"K"}; - auto N = SymbolicSize{"N"}; - auto device = SymbolicDevice{}; - device.set_options(); - - // Verify a: [M, K] - auto lda = SymbolicSize{"lda"}; - TensorMatcher({M, K}).with_strides({lda, 1}).with_dtype().with_device(device).verify(a); - - int64_t size_m = M.unwrap(); - int64_t size_k = K.unwrap(); - - // Verify b_q_weight: [K/tile_size, packed_N] - RuntimeCheck( - size_k % device::marlin::tile_size == 0, - "size_k = ", - size_k, - " is not divisible by tile_size = ", - device::marlin::tile_size); - int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size; - auto 
bqw_dim0 = SymbolicSize{"bqw_dim0"}; - auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; - bqw_dim0.set_value(expected_bqw_dim0); - TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device).verify(b_q_weight); - - RuntimeCheck( - b_q_weight.size(1) % device::marlin::tile_size == 0, - "b_q_weight.size(1) = ", - b_q_weight.size(1), - " is not divisible by tile_size = ", - device::marlin::tile_size); - int64_t actual_size_n = (b_q_weight.size(1) / device::marlin::tile_size) * pack_factor; - N.set_value(actual_size_n); - int64_t size_n = N.unwrap(); - - // Verify stride alignment - int64_t a_stride0 = a.stride(0); - RuntimeCheck(a_stride0 % 8 == 0, "a.stride(0) must be divisible by 8"); - - // Verify b_scales: [num_groups, N] - auto num_groups_sym = SymbolicSize{"num_groups"}; - TensorMatcher({num_groups_sym, N}).with_device(device).verify(b_scales); - int num_groups = static_cast(num_groups_sym.unwrap()); - - // Verify c: [M, N] - TensorMatcher({M, N}).with_dtype().with_device(device).verify(c); - - // Early return for zero-size M - if (size_m == 0) - return; - - // Determine has_act_order from g_idx/perm sizes - int64_t g_idx_size = g_idx.size(0); - int64_t perm_size = perm.size(0); - bool has_act_order = g_idx_size > 0 && perm_size > 0; - - if (has_act_order) - { - RuntimeCheck( - (g_idx_size == size_k && perm_size == size_k), - "Unexpected g_idx.size(0) = ", - g_idx_size, - " and perm.size(0) = ", - perm_size, - ", where size_k = ", - size_k); - } - - // Determine has_zp from b_zeros size - int64_t b_zeros_size = b_zeros.size(0); - bool has_zp = b_zeros_size > 0; - - if (has_zp) - { - RuntimeCheck( - b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. 
Got = ", b_q_type.str()); - } - else - { + bool is_zp_float) { + using namespace host; + + ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + // Bind symbolic sizes + auto M = SymbolicSize{"M"}; + auto K = SymbolicSize{"K"}; + auto N = SymbolicSize{"N"}; + auto device = SymbolicDevice{}; + device.set_options(); + + // Verify a: [M, K] + auto lda = SymbolicSize{"lda"}; + TensorMatcher({M, K}).with_strides({lda, 1}).with_dtype().with_device(device).verify(a); + + int64_t size_m = M.unwrap(); + int64_t size_k = K.unwrap(); + + // Verify b_q_weight: [K/tile_size, packed_N] RuntimeCheck( - b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, - "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " - "has_zp = False. Got = ", - b_q_type.str()); - } - - if (has_zp && is_zp_float) - { + size_k % device::marlin::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size; + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(expected_bqw_dim0); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device).verify(b_q_weight); + RuntimeCheck( - std::is_same::value, "Computation type must be float16 (half) when using float zero points."); - } - - // Verify b_zeros shape - if (has_zp) - { - RuntimeCheck(b_zeros.dim() == 2, "b_zeros rank = ", b_zeros.dim(), " is not 2"); - if (is_zp_float) - { - RuntimeCheck(b_zeros.size(1) == size_n, "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n = ", size_n); - RuntimeCheck( - num_groups == b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); - RuntimeCheck(num_groups != -1, "num_groups must be != -1"); + b_q_weight.size(1) % device::marlin::tile_size == 0, + "b_q_weight.size(1) = ", + b_q_weight.size(1), + 
" is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t actual_size_n = (b_q_weight.size(1) / device::marlin::tile_size) * pack_factor; + N.set_value(actual_size_n); + int64_t size_n = N.unwrap(); + + // Verify stride alignment + int64_t a_stride0 = a.stride(0); + RuntimeCheck(a_stride0 % 8 == 0, "a.stride(0) must be divisible by 8"); + + // Verify b_scales: [num_groups, N] + auto num_groups_sym = SymbolicSize{"num_groups"}; + TensorMatcher({num_groups_sym, N}).with_device(device).verify(b_scales); + int num_groups = static_cast(num_groups_sym.unwrap()); + + // Verify c: [M, N] + TensorMatcher({M, N}).with_dtype().with_device(device).verify(c); + + // Early return for zero-size M + if (size_m == 0) { + return; } - else - { - RuntimeCheck( - b_zeros.size(0) == num_groups, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); - RuntimeCheck( - b_zeros.size(1) == size_n / pack_factor, - "b_zeros dim 1 = ", - b_zeros.size(1), - " is not size_n / pack_factor = ", - size_n / pack_factor); + + // Determine has_act_order from g_idx/perm sizes + int64_t g_idx_size = g_idx.size(0); + int64_t perm_size = perm.size(0); + bool has_act_order = g_idx_size > 0 && perm_size > 0; + + if (has_act_order) { + RuntimeCheck( + (g_idx_size == size_k && perm_size == size_k), + "Unexpected g_idx.size(0) = ", + g_idx_size, + " and perm.size(0) = ", + perm_size, + ", where size_k = ", + size_k); } - } - - // Verify global_scale - int64_t global_scale_size = global_scale.size(0); - if (global_scale_size > 0) - { - RuntimeCheck(b_q_type == kFE2M1f, "global_scale can only be used for float4_e2m1f."); - } - else - { - RuntimeCheck(!(b_q_type == kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); - } - - // Derive group_size - int group_size = -1; - if (has_act_order) - { - if (is_k_full) - { - RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); - RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not 
divisible by num_groups = ", num_groups); - group_size = static_cast(size_k / num_groups); + + // Determine has_zp from b_zeros size + int64_t b_zeros_size = b_zeros.size(0); + bool has_zp = b_zeros_size > 0; + + if (has_zp) { + RuntimeCheck( + b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + } else { + RuntimeCheck( + b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", + b_q_type.str()); } - else - { - group_size = 0; + + if (has_zp && is_zp_float) { + RuntimeCheck( + std::is_same::value, "Computation type must be float16 (half) when using float zero points."); } - } - else - { - if (num_groups > 1) - { - RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); - group_size = static_cast(size_k / num_groups); + + // Verify b_zeros shape + if (has_zp) { + RuntimeCheck(b_zeros.dim() == 2, "b_zeros rank = ", b_zeros.dim(), " is not 2"); + if (is_zp_float) { + RuntimeCheck(b_zeros.size(1) == size_n, "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n = ", size_n); + RuntimeCheck( + num_groups == b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck(num_groups != -1, "num_groups must be != -1"); + } else { + RuntimeCheck( + b_zeros.size(0) == num_groups, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck( + b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", + b_zeros.size(1), + " is not size_n / pack_factor = ", + size_n / pack_factor); + } } - else - { - group_size = -1; + + // Verify global_scale + int64_t global_scale_size = global_scale.size(0); + if (global_scale_size > 0) { + RuntimeCheck(b_q_type == kFE2M1f, "global_scale can only be used for float4_e2m1f."); + } else { + RuntimeCheck(!(b_q_type == 
kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); } - } - - // Verify workspace and get device info - RuntimeCheck( - size_n % device::marlin::min_thread_n == 0, - "size_n = ", - size_n, - ", is not divisible by min_thread_n = ", - device::marlin::min_thread_n); - - DLDevice dl_device = device.unwrap(); - int dev = dl_device.device_id; - cudaStream_t stream = LaunchKernel::resolve_device(dl_device); - - int sms = -1; - RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev)); - - RuntimeCheck( - workspace.size(0) >= sms, "workspace.size(0) = ", workspace.size(0), " is below min_workspace_size = ", sms); - - // Hardcoded defaults (auto config) - int thread_k_init = -1; - int thread_n_init = -1; - - // Compute c_tmp and a_tmp pointers - // c_tmp and a_tmp are pre-allocated by caller - - device::marlin::marlin_mm( - a.data_ptr(), - b_q_weight.data_ptr(), - c.data_ptr(), - c_tmp.data_ptr(), - b_scales.data_ptr(), - global_scale.data_ptr(), - b_zeros.data_ptr(), - g_idx.data_ptr(), - perm.data_ptr(), - a_tmp.data_ptr(), - static_cast(size_m), - static_cast(size_n), - static_cast(size_k), - static_cast(a_stride0), - workspace.data_ptr(), - b_q_type, - has_act_order, - is_k_full, - has_zp, - num_groups, - group_size, - dev, - stream, - thread_k_init, - thread_n_init, - sms, - use_atomic_add, - use_fp32_reduce, - is_zp_float); + + // Derive group_size + int group_size = -1; + if (has_act_order) { + if (is_k_full) { + RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = 0; + } + } else { + if (num_groups > 1) { + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = -1; + } + } + + // 
Verify workspace and get device info + RuntimeCheck( + size_n % device::marlin::min_thread_n == 0, + "size_n = ", + size_n, + ", is not divisible by min_thread_n = ", + device::marlin::min_thread_n); + + DLDevice dl_device = device.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + + int sms = -1; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev)); + + RuntimeCheck( + workspace.size(0) >= sms, "workspace.size(0) = ", workspace.size(0), " is below min_workspace_size = ", sms); + + // Hardcoded defaults (auto config) + int thread_k_init = -1; + int thread_n_init = -1; + + // Compute c_tmp and a_tmp pointers + // c_tmp and a_tmp are pre-allocated by caller + + device::marlin::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + b_scales.data_ptr(), + global_scale.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + static_cast(size_m), + static_cast(size_n), + static_cast(size_k), + static_cast(a_stride0), + workspace.data_ptr(), + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + dev, + stream, + thread_k_init, + thread_n_init, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); } diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh index d0c2d5414..f23f73cbf 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/gptq_marlin_repack.cuh @@ -27,29 +27,26 @@ #include "marlin.cuh" -namespace device::marlin -{ +namespace device::marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - template - __global__ void gptq_marlin_repack_kernel( - uint32_t const *__restrict__ b_q_weight_ptr, - uint32_t const *__restrict__ perm_ptr, - uint32_t *__restrict__ out_ptr, - int size_k, - int size_n) - { +template 
+__global__ void gptq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, + int size_k, + int size_n) { return; - } +} #else - template - __global__ void gptq_marlin_repack_kernel( - uint32_t const *__restrict__ b_q_weight_ptr, - uint32_t const *__restrict__ perm_ptr, - uint32_t *__restrict__ out_ptr, - int size_k, - int size_n) - { +template +__global__ void gptq_marlin_repack_kernel( + uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, + int size_k, + int size_n) { constexpr int pack_factor = 32 / num_bits; int k_tiles = size_k / tile_k_size; @@ -57,22 +54,20 @@ namespace device::marlin int block_k_tiles = div_ceil(k_tiles, gridDim.x); auto start_k_tile = blockIdx.x * block_k_tiles; - if (start_k_tile >= k_tiles) - { - return; + if (start_k_tile >= k_tiles) { + return; } int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() - { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); }; extern __shared__ int4 sh[]; @@ -81,9 +76,8 @@ namespace device::marlin int4 *sh_perm_ptr = sh; int4 *sh_pipe_ptr = sh_perm_ptr; - if constexpr (has_perm) - { - sh_pipe_ptr += perm_size; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; } constexpr int tile_ints = tile_k_size / pack_factor; @@ -92,232 +86,202 @@ namespace device::marlin constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; constexpr int stage_size = stage_k_threads * stage_n_threads; - auto load_perm_to_shared = [&](int k_tile_id) - { - int first_k_int4 = (k_tile_id * tile_k_size) / 4; + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; - int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); + int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); - if (threadIdx.x < perm_size) - { - sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; - } - __syncthreads(); + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); }; - auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) - { - if (n_tile_id >= n_tiles) - { - cp_async_fence(); - return; - } + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } - int first_n = n_tile_id * tile_n_size; + int first_n = n_tile_id * tile_n_size; - int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; + int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; - if constexpr (has_perm) - { - if (threadIdx.x < stage_size) - { - auto k_id = threadIdx.x / stage_n_threads; - auto n_id = threadIdx.x % stage_n_threads; + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; - uint32_t const *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + uint32_t const *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); - 
int src_k = sh_perm_int_ptr[k_id]; - int src_k_packed = src_k / pack_factor; + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; - cp_async4( - &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); - } - } - else - { - if (threadIdx.x < stage_size) - { - auto k_id = threadIdx.x / stage_n_threads; - auto n_id = threadIdx.x % stage_n_threads; - - int first_k = k_tile_id * tile_k_size; - int first_k_packed = first_k / pack_factor; - - cp_async4( - &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + } else { + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); + } } - } - cp_async_fence(); + cp_async_fence(); }; - auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) - { - if (n_tile_id >= n_tiles) - { - return; - } + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } - auto warp_id = threadIdx.x / 32; - auto th_id = threadIdx.x % 32; + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; - if (warp_id >= 4) - { - return; - } + if (warp_id >= 4) { + return; + } - int tc_col = th_id / 4; - int tc_row = (th_id % 4) * 2; + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; - constexpr int tc_offsets[4] = {0, 1, 8, 9}; + constexpr int tc_offsets[4] = {0, 1, 8, 9}; - int cur_n = warp_id * 16 + tc_col; + int cur_n = warp_id * 16 + tc_col; - constexpr int 
sh_stride = 64; - constexpr uint32_t mask = (1 << num_bits) - 1; + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; - int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; - uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); - uint32_t vals[8]; + uint32_t vals[8]; - if constexpr (has_perm) - { - for (int i = 0; i < 4; i++) - { - int k_idx = tc_row + tc_offsets[i]; + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; - uint32_t src_k = sh_perm_int_ptr[k_idx]; - uint32_t src_k_pos = src_k % pack_factor; + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; - uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; - uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; - uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; - uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; - vals[i] = b1_cur_val; - vals[4 + i] = b2_cur_val; - } - } - else - { - uint32_t b1_vals[tile_ints]; - uint32_t b2_vals[tile_ints]; + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + } else { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; #pragma unroll - for (int i = 0; i < tile_ints; i++) - { - b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; - b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; - } + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + 
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } #pragma unroll - for (int i = 0; i < 4; i++) - { - int cur_elem = tc_row + tc_offsets[i]; - int cur_int = cur_elem / pack_factor; - int cur_pos = cur_elem % pack_factor; - - vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; - vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } } - } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; - // Result of: - // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - if constexpr (num_bits == 4) - { - constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - uint32_t res = 0; + uint32_t res = 0; #pragma unroll - for (int i = 0; i < 8; i++) - { - res |= vals[pack_idx[i]] << (i * 4); - } + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } - out_ptr[out_offset + th_id * 4 + warp_id] = res; - } - else - { - constexpr int pack_idx[4] = {0, 2, 1, 3}; + out_ptr[out_offset + th_id * 4 + warp_id] = res; + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; - uint32_t res1 = 0; - uint32_t res2 = 0; + uint32_t res1 = 0; + uint32_t res2 = 0; #pragma unroll - for (int i = 
0; i < 4; i++) - { - res1 |= vals[pack_idx[i]] << (i * 8); - res2 |= vals[4 + pack_idx[i]] << (i * 8); - } + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; - out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; - } + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } }; - auto start_pipes = [&](int k_tile_id, int n_tile_id) - { + auto start_pipes = [&](int k_tile_id, int n_tile_id) { #pragma unroll - for (int pipe = 0; pipe < repack_stages - 1; pipe++) - { - fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); - } + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } - wait_for_stage(); + wait_for_stage(); }; #pragma unroll - for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) - { - int n_tile_id = 0; + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; - if constexpr (has_perm) - { - load_perm_to_shared(k_tile_id); - } + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } - start_pipes(k_tile_id, n_tile_id); + start_pipes(k_tile_id, n_tile_id); - while (n_tile_id < n_tiles) - { + while (n_tile_id < n_tiles) { #pragma unroll - for (int pipe = 0; pipe < repack_stages; pipe++) - { - fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); - repack_tile(pipe, k_tile_id, n_tile_id + pipe); - wait_for_stage(); + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; } - n_tile_id += repack_stages; - } } - } +} #endif } // namespace device::marlin -#define 
CALL_IF_REPACK(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) \ - { \ - host::RuntimeDeviceCheck(cudaFuncSetAttribute( \ - device::marlin::gptq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - max_shared_mem)); \ - host::LaunchKernel(blocks, device::marlin::repack_threads, stream, static_cast(max_shared_mem))( \ - device::marlin::gptq_marlin_repack_kernel, \ - b_q_weight_ptr, \ - perm_ptr, \ - out_ptr, \ - size_k, \ - size_n); \ - } +#define CALL_IF_REPACK(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + host::RuntimeDeviceCheck(cudaFuncSetAttribute( \ + device::marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem)); \ + host::LaunchKernel(blocks, device::marlin::repack_threads, stream, static_cast(max_shared_mem))( \ + device::marlin::gptq_marlin_repack_kernel, \ + b_q_weight_ptr, \ + perm_ptr, \ + out_ptr, \ + size_k, \ + size_n); \ + } void gptq_marlin_repack( tvm::ffi::TensorView b_q_weight, @@ -325,74 +289,71 @@ void gptq_marlin_repack( tvm::ffi::TensorView out, int64_t size_k, int64_t size_n, - int64_t num_bits) -{ - using namespace host; - - // Validate num_bits - RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. 
Got = ", num_bits); - int const pack_factor = 32 / static_cast(num_bits); - - // Validate size alignment - RuntimeCheck( - size_k % device::marlin::tile_k_size == 0, - "size_k = ", - size_k, - " is not divisible by tile_k_size = ", - device::marlin::tile_k_size); - RuntimeCheck( - size_n % device::marlin::tile_n_size == 0, - "size_n = ", - size_n, - " is not divisible by tile_n_size = ", - device::marlin::tile_n_size); - - // Validate b_q_weight - auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; - auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; - bqw_dim0.set_value(size_k / pack_factor); - bqw_dim1.set_value(size_n); - auto device_ = SymbolicDevice{}; - device_.set_options(); - TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device_).verify(b_q_weight); - - // Validate out - auto out_dim0 = SymbolicSize{"out_dim0"}; - auto out_dim1 = SymbolicSize{"out_dim1"}; - out_dim0.set_value(size_k / device::marlin::tile_size); - out_dim1.set_value(size_n * device::marlin::tile_size / pack_factor); - TensorMatcher({out_dim0, out_dim1}).with_dtype().with_device(device_).verify(out); - - // Detect if there is act_order - bool has_perm = perm.size(0) != 0; - - // Get ptrs - uint32_t const *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); - uint32_t const *perm_ptr = reinterpret_cast(perm.data_ptr()); - uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); - - // Get dev info - DLDevice dl_device = device_.unwrap(); - int dev = dl_device.device_id; - cudaStream_t stream = LaunchKernel::resolve_device(dl_device); - int blocks; - RuntimeDeviceCheck(cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev)); - - int max_shared_mem = 0; - RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); - RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); - - if (false) - { - } - CALL_IF_REPACK(4, false) - CALL_IF_REPACK(4, true) - CALL_IF_REPACK(8, false) - CALL_IF_REPACK(8, true) - else - { - 
Panic("Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); - } + int64_t num_bits) { + using namespace host; + + // Validate num_bits + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / static_cast(num_bits); + + // Validate size alignment + RuntimeCheck( + size_k % device::marlin::tile_k_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_k_size = ", + device::marlin::tile_k_size); + RuntimeCheck( + size_n % device::marlin::tile_n_size == 0, + "size_n = ", + size_n, + " is not divisible by tile_n_size = ", + device::marlin::tile_n_size); + + // Validate b_q_weight + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(size_k / pack_factor); + bqw_dim1.set_value(size_n); + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device_).verify(b_q_weight); + + // Validate out + auto out_dim0 = SymbolicSize{"out_dim0"}; + auto out_dim1 = SymbolicSize{"out_dim1"}; + out_dim0.set_value(size_k / device::marlin::tile_size); + out_dim1.set_value(size_n * device::marlin::tile_size / pack_factor); + TensorMatcher({out_dim0, out_dim1}).with_dtype().with_device(device_).verify(out); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const *b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const *perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + DLDevice dl_device = device_.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + int blocks; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev)); + + int max_shared_mem = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, 
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + if (false) { + } + CALL_IF_REPACK(4, false) + CALL_IF_REPACK(4, true) + CALL_IF_REPACK(8, false) + CALL_IF_REPACK(8, true) + else { + Panic("Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); + } } #undef CALL_IF_REPACK diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h index e0e36cdd4..785d5e9b9 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/kernel.h @@ -10,25 +10,24 @@ const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem -namespace device::marlin -{ - template < - typename scalar_t, // compute dtype, half or nv_float16 - const host::ScalarTypeId w_type_id, // weight ScalarType id - const int threads, // number of threads in a threadblock - const int thread_m_blocks, // number of 16x16 blocks in the m - // dimension (batchsize) of the - // threadblock - const int thread_n_blocks, // same for n dimension (output) - const int thread_k_blocks, // same for k dimension (reduction) - const bool m_block_size_8, // whether m_block_size == 8 - // only works when thread_m_blocks == 1 - const int stages, // number of stages for the async global->shared - // fetch pipeline - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? 
- > - __global__ void Marlin(MARLIN_KERNEL_PARAMS); +namespace device::marlin { +template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); } // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh index 9e99d0f4d..944ca3522 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin.cuh @@ -4,53 +4,49 @@ #include -namespace device::marlin -{ - // Marlin params +namespace device::marlin { +// Marlin params - // 8 warps are a good choice since every SM has 4 schedulers and having more - // than 1 warp per schedule allows some more latency hiding. At the same time, - // we want relatively few warps to have many registers per warp and small tiles. - static constexpr int default_threads = 256; +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. 
+static constexpr int default_threads = 256; - static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory - static constexpr int min_thread_n = 64; - static constexpr int min_thread_k = 64; - static constexpr int max_thread_n = 256; +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; - static constexpr int tile_size = 16; - static constexpr int max_par = 16; +static constexpr int tile_size = 16; +static constexpr int max_par = 16; - // Repack params - static constexpr int repack_stages = 8; +// Repack params +static constexpr int repack_stages = 8; - static constexpr int repack_threads = 256; +static constexpr int repack_threads = 256; - static constexpr int tile_k_size = tile_size; - static constexpr int tile_n_size = tile_k_size * 4; +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; - // Helpers - template - struct Vec - { +// Helpers +template +struct Vec { T elems[n]; - __device__ T &operator[](int i) - { - return elems[i]; + __device__ T &operator[](int i) { + return elems[i]; } - }; +}; - using I4 = Vec; +using I4 = Vec; - using host::div_ceil; +using host::div_ceil; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 // No support for async #else - __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, bool pred = true) - { +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, bool pred = true) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( @@ -62,10 +58,9 @@ namespace device::marlin "r"(smem), "l"(glob_ptr), "n"(BYTES)); - } +} - __device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) - { +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { const int BYTES = 16; uint32_t smem = 
static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile( @@ -74,18 +69,16 @@ namespace device::marlin "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES)); - } +} - __device__ inline void cp_async_fence() - { +__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); - } +} - template - __device__ inline void cp_async_wait() - { +template +__device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); - } +} #endif diff --git a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h index 8f35f227d..4ea220265 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h +++ b/src/infiniop/ops/gptq_marlin_gemm/marlin/marlin_template.h @@ -24,209 +24,187 @@ #include "marlin.cuh" #include "marlin_dtypes.cuh" -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ - static_assert( \ - std::is_same::value || std::is_same::value, \ - "only float16 and bfloat16 is supported"); +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert( \ + std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); -namespace device::marlin -{ +namespace device::marlin { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - template < - typename scalar_t, // compute dtype, half or nv_float16 - const host::ScalarTypeId w_type_id, // weight ScalarType id - const int threads, // number of threads in a threadblock - const int thread_m_blocks, // number of 16x16 blocks in the m - // dimension (batchsize) of the - // threadblock - const int thread_n_blocks, // same for n dimension (output) - const int thread_k_blocks, // same for k dimension (reduction) - const bool m_block_size_8, // whether m_block_size == 8 - // only works when thread_m_blocks == 1 - const int stages, // number of stages for the async global->shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks, // 
number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? - > - __global__ void Marlin( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) - const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int *__restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int *locks, // extra global storage for barrier synchronization - bool use_fp32_reduce // whether to use fp32 global reduce - ) - { - } +template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce +) { +} } // namespace device::marlin #else - // m16n8k16 tensor core mma instruction with fp16 inputs and fp32 - // output/accumulation. - template - __device__ inline void - mma(const typename ScalarType::FragA &a_frag, - const typename ScalarType::FragB &frag_b, - typename ScalarType::FragC &frag_c) - { +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
+template +__device__ inline void +mma(const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + typename ScalarType::FragC &frag_c) { const uint32_t *a = reinterpret_cast(&a_frag); const uint32_t *b = reinterpret_cast(&frag_b); float *c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } - else if constexpr (std::is_same::value) - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } - else - { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } - } - - template - __device__ inline void mma_trans( - const typename ScalarType::FragA &a_frag, - const typename ScalarType::FragB &frag_b, - const typename ScalarType::FragB &frag_b2, - typename 
ScalarType::FragC &frag_c) - { +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + const typename ScalarType::FragB &frag_b2, + typename ScalarType::FragC &frag_c) { const uint32_t *a = reinterpret_cast(&a_frag); const uint32_t *b = reinterpret_cast(&frag_b); const uint32_t *b2 = reinterpret_cast(&frag_b2); float *c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), - "r"(b2[0]), - "r"(b[1]), - "r"(b2[1]), - "r"(a[0]), - "r"(a[1]), - "f"(c[0]), - "f"(c[1]), - "f"(c[2]), - "f"(c[3])); - } - else if constexpr (std::is_same::value) - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), - "r"(b2[0]), - "r"(b[1]), - "r"(b2[1]), - "r"(a[0]), - "r"(a[1]), - "f"(c[0]), - "f"(c[1]), - "f"(c[2]), - "f"(c[3])); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), + "r"(b2[0]), + "r"(b[1]), + "r"(b2[1]), + "r"(a[0]), + "r"(a[1]), + "f"(c[0]), + "f"(c[1]), + "f"(c[2]), + "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); } - else - { - 
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } - } +} - // Instruction for loading a full 16x16 matrix fragment of operand A from shared - // memory, directly in tensor core layout. - template - __device__ inline void ldsm(typename ScalarType::FragA &frag_a, const void *smem_ptr) - { +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm(typename ScalarType::FragA &frag_a, const void *smem_ptr) { uint32_t *a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - if constexpr (count == 4) - { - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); - } - else if constexpr (count == 2) - { - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); - } - else if constexpr (count == 1) - { - asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(a[0]) : "r"(smem)); + if constexpr (count == 4) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(a[0]) + : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); } - else - { - static_assert(count == 1 || count == 2 || count == 4, "invalid count"); - } - } - - // Multiply dequantized values by the corresponding quantization scale; used - // only for grouped quantization. 
- template - __device__ inline void - scale(typename ScalarType::FragB &frag_b, typename ScalarType::FragS &frag_s, int i) - { +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void +scale(typename ScalarType::FragB &frag_b, typename ScalarType::FragS &frag_s, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); - } +} - template - __device__ inline void scale_and_sub(typename ScalarType::FragB &frag_b, scalar_t s, scalar_t zp) - { +template +__device__ inline void scale_and_sub(typename ScalarType::FragB &frag_b, scalar_t s, scalar_t zp) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s2 = ScalarType::num2num2(s); scalar_t2 zp2 = ScalarType::num2num2(zp); frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); - } +} - template - __device__ inline void - sub_zp(typename ScalarType::FragB &frag_b, typename ScalarType::scalar_t2 &frag_zp, int i) - { +template +__device__ inline void +sub_zp(typename ScalarType::FragB &frag_b, typename ScalarType::scalar_t2 &frag_zp, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 zp = ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); frag_b[0] = __hsub2(frag_b[0], zp); frag_b[1] = __hsub2(frag_b[1], zp); - } - - // Same as above, but for act_order (each K is multiplied individually) - template - __device__ inline void scale4( - typename ScalarType::FragB &frag_b, - typename ScalarType::FragS &frag_s_1, - typename ScalarType::FragS &frag_s_2, - typename ScalarType::FragS &frag_s_3, - typename ScalarType::FragS &frag_s_4, - int i) - { +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4( + typename ScalarType::FragB &frag_b, + typename ScalarType::FragS 
&frag_s_1, + typename ScalarType::FragS &frag_s_2, + typename ScalarType::FragS &frag_s_3, + typename ScalarType::FragS &frag_s_4, + int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s_val_1_2; s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; @@ -238,106 +216,103 @@ namespace device::marlin frag_b[0] = __hmul2(frag_b[0], s_val_1_2); frag_b[1] = __hmul2(frag_b[1], s_val_3_4); - } +} - // Given 2 floats multiply by 2 scales (halves) - template - __device__ inline void scale_float(float *c, typename ScalarType::FragS &s) - { +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float *c, typename ScalarType::FragS &s) { scalar_t *s_ptr = reinterpret_cast(&s); c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); - } - - // Wait until barrier reaches `count`, then lock for current threadblock. - __device__ inline void barrier_acquire(int *lock, int count) - { - if (threadIdx.x == 0) - { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); - while (state != count); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int *lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do { + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + } while (state != count); } __syncthreads(); - } +} - // Release barrier and increment visitation count. - __device__ inline void barrier_release(int *lock, bool reset = false) - { +// Release barrier and increment visitation count. 
+__device__ inline void barrier_release(int *lock, bool reset = false) { __syncthreads(); - if (threadIdx.x == 0) - { - if (reset) - { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); } - } - - // Wait until value of lock to be negative, and then add 1 - __device__ inline void wait_negative_and_add(int *lock) - { - if (threadIdx.x == 0) - { - int state = 0; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); - while (state >= 0); - atomicAdd(lock, 1); +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int *lock) { + if (threadIdx.x == 0) { + int state = 0; + do { + // Guarantee that subsequent writes by this threadblock will be visible + // globally. 
+ asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + } while (state >= 0); + atomicAdd(lock, 1); } __syncthreads(); - } - - template < - typename scalar_t, // compute dtype, half or nv_float16 - const host::ScalarTypeId w_type_id, // weight ScalarType id - const int threads, // number of threads in a threadblock - const int thread_m_blocks, // number of 16x16 blocks in the m - // dimension (batchsize) of the - // threadblock - const int thread_n_blocks, // same for n dimension (output) - const int thread_k_blocks, // same for k dimension (reduction) - const bool m_block_size_8, // whether m_block_size == 8 - // only works when thread_m_blocks == 1 - const int stages, // number of stages for the async global->shared - // fetch pipeline - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? - > - __global__ void Marlin( - const int4 *__restrict__ A, // fp16 input matrix of shape mxk - const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn - int4 *__restrict__ C, // fp16 output buffer of shape mxn - int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) - const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const uint16_t *__restrict__ scale2_ptr, // fp16 global scale (for nvfp4 - // only) - const int4 *__restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int *__restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int lda, // A.stride(0), equal to prob_k is A is contiguous - int *locks, // extra global storage for barrier synchronization - bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce, // whether to use 
fp32 global reduce - int max_shared_mem) - { +} + +template < + typename scalar_t, // compute dtype, half or nv_float16 + const host::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const bool m_block_size_8, // whether m_block_size == 8 + // only works when thread_m_blocks == 1 + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t *__restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4 *__restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int lda, // A.stride(0), equal to prob_k is A is contiguous + int *locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { // Each threadblock processes one "stripe" of the B matrix with 
(roughly) the // same size, which might involve multiple column "slices" (of width 16 * // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM @@ -359,19 +334,15 @@ namespace device::marlin static constexpr auto w_type = host::ScalarType::from_id(w_type_id); constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8; - constexpr bool is_int_type = - w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128; + constexpr bool is_int_type = w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128; // see comments of dequant.h for more details - constexpr bool dequant_skip_flop = !is_int_type || - has_zp && !is_zp_float && !std::is_same::value || - has_zp && !is_zp_float && !(w_type == host::kU8); + constexpr bool dequant_skip_flop = !is_int_type || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(w_type == host::kU8); scalar_t2 global_scale; - if constexpr (w_type == host::kFE2M1f) - { - uint16_t val = scale2_ptr[0]; - global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + if constexpr (w_type == host::kFE2M1f) { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); } constexpr bool has_act_order = group_blocks == 0; @@ -383,25 +354,22 @@ namespace device::marlin // For larger GEMMs we run multiple batchsize 64 versions in parallel for a // better partitioning with less reductions int parallel = 1; - if (prob_m > m_block_size) - { - parallel = prob_m / m_block_size; - prob_m = m_block_size; + if (prob_m > m_block_size) { + parallel = prob_m / m_block_size; + prob_m = m_block_size; } int k_tiles = prob_k / 16 / thread_k_blocks; int n_tiles = prob_n / 16 / thread_n_blocks; int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - if constexpr (!has_act_order && group_blocks != -1) - { - if (group_blocks >= thread_k_blocks) - { - // Ensure that the number of tiles in each stripe is a multiple 
of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); - } + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } } int slice_row = (iters * blockIdx.x) % k_tiles; @@ -417,97 +385,89 @@ namespace device::marlin // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers - if (slice_col_par >= n_tiles) - { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - slice_col = slice_col_par % n_tiles; - par_id = slice_col_par / n_tiles; - } - if (parallel * n_tiles >= gridDim.x) - { - // when parallel * n_tiles >= sms - // then there are at most $sms$ conflict tile blocks - locks_off = blockIdx.x; + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; } - else - { - locks_off = (iters * blockIdx.x) / k_tiles - 1; + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; } // Compute all information about the current slice which is required for // synchronization. 
- auto init_slice = [&](bool first_init = false) - { - slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) - slice_iters = 0; - if (slice_iters == 0) - return; - if (slice_row + slice_iters > k_tiles) - slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) - { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) - slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else - { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) - slice_idx--; + auto init_slice = [&](bool first_init = false) { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) { + slice_iters = 0; } - } - if (parallel * n_tiles >= gridDim.x) - { - if (slice_count > 1 && slice_idx == slice_count - 1) - { - locks_off++; + if (slice_iters == 0) { + return; } - } - else - { - locks_off++; - } - - if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) - { - constexpr int threads_per_m = 16 * thread_n_blocks / 8; - int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m); - if (m_block_size_8) - m_per_thread = div_ceil(8, threads / threads_per_m); - for (int i = 0; i < m_per_thread; i++) - { - int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; - if (row < prob_m) - { - int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; - C[row * prob_n / 8 + col] = {0, 0, 0, 0}; - } + if (slice_row + slice_iters > k_tiles) { + slice_iters = k_tiles - slice_row; + } + slice_count = 1; + slice_idx = 0; + int col_first = iters * 
div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) { + slice_count++; + } + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) { + slice_idx = slice_count - 1; + } else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) { + slice_idx--; + } + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m); + if (m_block_size_8) { + m_per_thread = div_ceil(8, threads / threads_per_m); + } + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < prob_m) { + int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; + C[row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. + __syncthreads(); + if (threadIdx.x == 0) { + locks[locks_off] = 1 - slice_count; + } + } + + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * lda / 8; + C += 16 * thread_m_blocks * prob_n / 8; + slice_col = 0; + par_id++; } - // After write zero to output, write a negative value to lock. - // Every SM that processes the same slice would wait for - // the negative value, and then atomicAdd 1 to it. - // After all SMs are processed, the lock value would back to 0 again. 
- __syncthreads(); - if (threadIdx.x == 0) - locks[locks_off] = 1 - slice_count; - } - - if (slice_col == n_tiles) - { - A += 16 * thread_m_blocks * lda / 8; - C += 16 * thread_m_blocks * prob_n / 8; - slice_col = 0; - par_id++; - } }; init_slice(true); @@ -549,8 +509,8 @@ namespace device::marlin int s_gl_stride = prob_n / 8; constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) - : 1; + ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) + : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -579,8 +539,7 @@ namespace device::marlin // Shared write index of current thread. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); // Shared read index. - int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + - (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; @@ -598,33 +557,24 @@ namespace device::marlin // No act_order int s_gl_rd; - if constexpr (!has_act_order) - { - if constexpr (group_blocks == -1) - { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } - else - { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 
2 : 1) + - s_sh_stride * slice_col + threadIdx.x; - } + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + s_sh_stride * slice_col + threadIdx.x; + } } auto s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; // Zero-points int zp_gl_rd; - if constexpr (has_zp) - { - if constexpr (group_blocks == -1) - { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - else - { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; - } + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } } auto zp_sh_wr = threadIdx.x; bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; @@ -633,21 +583,20 @@ namespace device::marlin // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) - { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + warp_row % 2; + if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + } else if constexpr (group_blocks != -1) { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; + } else { + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; } - else if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; // Zero-points have the same read layout as the scales // (without column-wise case) @@ -655,20 +604,14 @@ namespace device::marlin constexpr int num_row_threads = 4; constexpr int num_ints_per_thread = 8 / pack_factor; int zp_sh_rd; - if constexpr (has_zp) - { - if constexpr (is_zp_float) - { - if constexpr (group_blocks != -1) - { - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % 
(thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); } - } - else - { - zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } } // Precompute which thread should not read memory in which iterations; this is @@ -676,8 +619,9 @@ namespace device::marlin // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } // To ensure that writing and reading A tiles to/from shared memory, the // latter in fragment format, is fully bank conflict free, we need to use a @@ -685,25 +629,25 @@ namespace device::marlin // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the // same shared memory banks. Further, it seems (based on NSight-Compute) that // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) - { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); }; // Since the computation of this remapping is non-trivial and, due to our main // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. 
int a_sh_wr_trans[a_sh_wr_iters]; #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + } int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - { + for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } } // Since B-accesses have non-constant stride they have to be computed at @@ -712,8 +656,9 @@ namespace device::marlin // optimization. const int4 *B_ptr[b_sh_wr_iters]; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + for (int i = 0; i < b_sh_wr_iters; i++) { + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + } extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. @@ -744,1173 +689,956 @@ namespace device::marlin FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ // Zero accumulators. 
- auto zero_accums = [&]() - { + auto zero_accums = [&]() { #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; }; int sh_first_group_id = -1; int sh_num_groups = -1; - auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) - { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups > act_s_max_num_groups) - { - sh_num_groups = act_s_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) - { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) - { - for (int i = 0; i < sh_num_groups; i++) - { - if (threadIdx.x < s_sh_stride) - { - cp_async4_pred( - &sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); - } + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; } - } - else - { - for (int i = 0; i < sh_num_groups; i++) - { - if (threadIdx.x < s_sh_stride) - { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; - } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred( + &sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if 
(threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; + } + } } - } }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) - { - if (pred) - { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - { + for (int i = 0; i < b_sh_wr_iters; i++) { #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) - { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } - B_ptr[i] += b_gl_rd_delta_o; - } + B_ptr[i] += b_gl_rd_delta_o; + } - if constexpr (has_act_order) - { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) - { - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int4 const 
*cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + int4 const *cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); - if (threadIdx.x < g_idx_stage) - { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } - else - { - if constexpr (group_blocks != -1) - { - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) - { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) - { - if (s_sh_wr_pred) - { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - else - { - for (int i = 0; i < s_tb_groups; i++) - { - if (s_sh_wr_pred) - { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); + } } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) - { - int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) - { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) - { - if (zp_sh_wr_pred) - { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } else { + if constexpr (group_blocks != -1) { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } } - zp_gl_rd += zp_gl_rd_delta; - } - } - else - { - for (int i = 0; i < zp_tb_groups; i++) - { - if (zp_sh_wr_pred) - { - 
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); + + if constexpr (has_zp && group_blocks != -1) { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } } - zp_gl_rd += zp_gl_rd_delta; - } } - } } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); }; - auto fetch_col_zp_to_shared = [&]() - { - if (zp_sh_wr_pred) - { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } }; - auto fetch_col_scale_to_shared = [&]() - { - if (s_sh_wr_pred) - { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } }; // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() - { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). 
- cp_async_wait(); - __syncthreads(); + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); }; // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) - { - int4 *sh_a_stage = sh_a + a_sh_stage * pipe; + auto fetch_to_registers = [&](int k, int pipe) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + for (int i = 0; i < thread_m_blocks; i++) + ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) - { - frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; int same_group_id[stages]; - auto init_same_group = [&](int pipe) - { - if constexpr (!has_act_order) - { - return; - } - - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) - { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) - { - // No act-order case - if constexpr (group_blocks == -1) - { - 
// load only when starting a new slice - if (k == 0 && full_pipe == 0) - { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; } - else if constexpr (group_blocks != -1) - { - if constexpr (group_blocks >= thread_k_blocks) - { - if (k % b_sh_wr_iters == 0) - { - int4 *sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } - else - { - reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; - } - } - else - { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 
2 : 1)); + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; - int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; - if constexpr (w_type_id != host::kFE2M1f.id()) - { - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - else - { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4 *sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 
2 : 1)); + + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (w_type_id != host::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } } - } - } - return; - } + return; + } - // Act-order case + // Act-order case - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) - { - return; - } + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); - // Determine "position" inside the thread-block (based on warp and - // thread-id) - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; - cur_k += warp_row * 16; + cur_k += warp_row * 16; - auto th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * 
warp_col) + (th_id / 4) * act_s_col_stride; + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; - if (is_same_group[pipe]) - { - if (k % 2 == 0) - { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; - } - else - { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } - for (int i = 1; i < 4; i++) - { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; } - return; - } - int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread + constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread #pragma unroll - for (int i = 0; i < 4; i++) - { - int actual_k = cur_k + k_frag_offsets[i]; + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = 
sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } }; - auto fetch_zp_to_registers = [&](int k, int full_pipe) - { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(!has_zp || group_blocks != 0); + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); - if constexpr (has_zp && !is_zp_float) - { - int pipe = full_pipe % stages; + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; - if constexpr (group_blocks == -1) - { - // load only when starting a new slice - if (k == 0 && full_pipe == 0) - { + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { #pragma unroll - for (int i = 0; i < num_ints_per_thread; i++) - { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - } - } - else if constexpr (group_blocks >= thread_k_blocks) - { - if (k % b_sh_wr_iters == 0) - { - int4 *sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + } else if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); #pragma unroll - for (int i = 0; i < num_ints_per_thread; i++) - { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - } - else - { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } else { + auto 
warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); - int k_blocks = cur_k / 16; - int cur_group_id = 0; + int k_blocks = cur_k / 16; + int cur_group_id = 0; - // Suppress bogus and persistent divide-by-zero warning + // Suppress bogus and persistent divide-by-zero warning #pragma nv_diagnostic push #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; + cur_group_id = k_blocks / group_blocks; #pragma nv_diagnostic pop - int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; - sh_zp_stage += cur_group_id * zp_sh_stride; + sh_zp_stage += cur_group_id * zp_sh_stride; #pragma unroll - for (int i = 0; i < num_ints_per_thread; i++) - { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } } - } - else if constexpr (has_zp && is_zp_float) - { - int pipe = full_pipe % stages; + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; - if constexpr (group_blocks != -1) - { - if constexpr (group_blocks >= thread_k_blocks) - { - if (k % b_sh_wr_iters == 0) - { - int4 *sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; - } - } - else - { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 
2])[0] = sh_zp_stage[zp_sh_rd]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; + int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); - int k_blocks = cur_k / 16; - // Suppress bogus and persistent divide-by-zero warning + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning #pragma nv_diagnostic push #pragma nv_diag_suppress divide_by_zero - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = k_blocks / group_blocks; #pragma nv_diagnostic pop - int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; - reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; - } + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } } - } }; - auto dequant_data = [&](int q, scalar_t2 *frag_b_ptr) - { - dequant(q, frag_b_ptr); + auto dequant_data = [&](int q, scalar_t2 *frag_b_ptr) { + dequant(q, frag_b_ptr); }; // Execute the actual tensor core matmul of a sub-tile. 
bool is_first_matmul_in_slice = true; - auto matmul = [&](int k) - { - int k2 = k % 2; - const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || - (group_blocks == -1 && is_first_matmul_in_slice); - if constexpr (has_zp && !is_zp_float) - { - if (is_new_zp) - { - if constexpr (group_blocks == -1) - is_first_matmul_in_slice = false; - FragB frag_zp_0; - FragB frag_zp_1; - int zp_quant_0, zp_quant_1; - - if constexpr (w_type.size_bits() == 4) - { - zp_quant_0 = frag_qzp[k2][0]; - zp_quant_1 = zp_quant_0 >> 8; - } - else - { - static_assert(w_type.size_bits() == 8); - zp_quant_0 = frag_qzp[k2][0]; - zp_quant_1 = frag_qzp[k2][1]; - } - - dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); - dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) { + is_first_matmul_in_slice = false; + } + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } } - } - if constexpr (!dequant_skip_flop && has_zp && is_zp_float) - { - if (is_new_zp) - { - reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; + } } - } - if constexpr (w_type == host::kFE2M1f) - { - int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; - int s_quant_1 = 
reinterpret_cast(frag_s[k2])[1]; + if constexpr (w_type == host::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); - dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); - } + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll - for (int j = 0; j < 4; j++) - { - FragB frag_b0; - FragB frag_b1; - int b_quant_0, b_quant_1; - - if constexpr (w_type_id == host::kFE2M1f.id()) - { - b_quant_1 = frag_b_quant[k2][0][j]; - b_quant_0 = b_quant_1 << 8; - } - else if constexpr (w_type.size_bits() == 4) - { - b_quant_0 = frag_b_quant[k2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } - else - { - static_assert(w_type.size_bits() == 8); - int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type_id == host::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } - dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); - dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); - if constexpr (dequant_skip_flop && has_zp && !is_zp_float) - { - sub_zp(frag_b0, frag_zp[j], 0); - 
sub_zp(frag_b1, frag_zp[j], 1); - } + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } - // Apply scale to frag_b0 - if constexpr (has_act_order) - { - static_assert(group_blocks != -1); - scale4( - frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); - scale4( - frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); - } - else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) - { - int idx = (threadIdx.x / 4) % 2; - scalar_t2 s2 = Dtype::nums2num2( - reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], - reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); - if (is_new_zp) - frag_zp[j] = __hmul2(frag_zp[j], s2); - scale_and_sub(frag_b0, s2.x, frag_zp[j].x); - scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } - else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) - { - if (is_new_zp) - frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); - scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); - } - else if constexpr (group_blocks != -1) - { - scale(frag_b0, frag_s[k2][j], 0); - scale(frag_b1, frag_s[k2][j], 1); - } + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4( + frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4( + frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], s2); + 
scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - { - if constexpr (m_block_size_8) - { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); - } - else - { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); - } + for (int i = 0; i < thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } } - } }; // Since we slice across the k dimension of a tile in order to increase the // number of warps while keeping the n dimension of a tile reasonable, we have // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() - { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) - { - auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. 
+ auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) - { + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { #pragma unroll - for (int i = red_off; i > 0; i /= 2) - { - if (i <= red_idx && red_idx < 2 * i) - { + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { #pragma unroll - for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) - { - int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) - { - float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); - float *c_wr = reinterpret_cast(&sh_red[red_sh_wr]); + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 
2 : 1)) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh_red[red_sh_wr]); #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); } - sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) - { + if (red_idx == 0) { #pragma unroll - for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) - { - float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); } - } - __syncthreads(); } - } }; // Since multiple threadblocks may process parts of the same column slice, we // finally have to globally reduce over the results. As the striped // partitioning minimizes the number of such reductions and our outputs are // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce_fp16 = [&](bool first = false, bool last = false) - { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). 
- constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) - { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr; - if constexpr (m_block_size_8) - { - c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - } - else - { - c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - } - constexpr int c_sh_wr_delta = active_threads; - auto c_sh_wr = threadIdx.x; + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + auto c_sh_wr = threadIdx.x; - int row = (threadIdx.x % 32) / 4; + int row = (threadIdx.x % 32) / 4; - if (!first) - { + if (!first) { // Interestingly, doing direct global accesses here really seems to mess up // the compiler and lead to slowdowns, hence we also use async-copies even // though these fetches are not actually asynchronous. #pragma unroll - for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) - { - if constexpr (m_block_size_8) - { - cp_async4_pred( - &sh_red[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], - (threadIdx.x % 4) * 2 + i < prob_m); - } - else - { - cp_async4_pred( - &sh_red[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if constexpr (m_block_size_8) { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], + (threadIdx.x % 4) * 2 + i < prob_m); + } else { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + } + cp_async_fence(); + cp_async_wait<0>(); } - } - cp_async_fence(); - cp_async_wait<0>(); - } #pragma unroll - for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) - { - bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) || - (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); - if (mask) - { - if (!first) - { - int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) || (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); + if (mask) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; #pragma unroll - for (int j = 0; j < 2 * 4; j++) - { - int delta = 0; - if constexpr (m_block_size_8) - { - delta = j % 2 == 1 ? 
-2 : 0; - } - reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) - { - int4 c; + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; #pragma unroll - for (int j = 0; j < 2 * 4; j++) - { - int delta = 0; - if constexpr (m_block_size_8) - { - delta = j % 2 == 1 ? -2 : 0; + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + if constexpr (m_block_size_8) + C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; + else + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } } - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); - } - if constexpr (m_block_size_8) - C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; - else - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; } - } } - } }; // Globally reduce over threadblocks that compute the same column block. // We use a tmp C buffer to reduce in full fp32 precision. 
- auto global_reduce_fp32 = [&](bool first = false, bool last = false) - { - constexpr int tb_m = thread_m_blocks * 16; - constexpr int tb_n = thread_n_blocks * 16; + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; - constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; - constexpr int active_threads = 32 * thread_n_blocks / 4; - bool is_th_active = threadIdx.x < active_threads; + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; - constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; - constexpr int th_size = num_floats * sizeof(float) / 16; + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; - int c_cur_offset = locks_off * c_size; + int c_cur_offset = locks_off * c_size; - if (!is_th_active) - { - return; - } + if (!is_th_active) { + return; + } - if (!first) - { - float *frag_c_ptr = reinterpret_cast(&frag_c); + if (!first) { + float *frag_c_ptr = reinterpret_cast(&frag_c); #pragma unroll - for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) - { - sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; - float *sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); + float *sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); #pragma unroll - for (int f = 0; f < 4; f++) - { - frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; - } + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } } - } - if (!last) - { - int4 *frag_c_ptr = reinterpret_cast(&frag_c); + if (!last) { + int4 *frag_c_ptr = reinterpret_cast(&frag_c); #pragma unroll - for (int k = 0; k < th_size; k += (m_block_size_8 ? 
2 : 1)) - { - C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } } - } }; // Write out the reduce final result in the correct layout. We only actually // reshuffle matrix fragments in this step, the reduction above is performed // in fragment layout. - auto write_result = [&]() - { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr; - if constexpr (m_block_size_8) - { - c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; - c_sh_wr += 64 * (threadIdx.x / 32); - } - else - { - c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - } - - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS &s) - { - scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr ( - !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) - { - res = __hmul2(res, s[0]); + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + 
constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); } - if constexpr (w_type == host::kFE2M1f) - { - res = __hmul2(res, global_scale); - } + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); - if constexpr (m_block_size_8) - { - ((scalar_t *)sh_red)[idx] = res.x; - ((scalar_t *)sh_red)[idx + 8 * c_sh_stride] = res.y; - } - else - { - ((scalar_t2 *)sh_red)[idx] = res; - } - }; + int c_gl_wr_end = c_gl_stride * prob_m; + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS &s) { + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { + res = __hmul2(res, s[0]); + } + + if constexpr (w_type == host::kFE2M1f) { + res = __hmul2(res, global_scale); + } - if (threadIdx.x / 32 < thread_n_blocks / 4) - { + if constexpr (m_block_size_8) { + ((scalar_t *)sh_red)[idx] = res.x; + ((scalar_t *)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2 *)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - { + for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll - for (int j = 0; j < 4; j++) - { - if 
constexpr (m_block_size_8) - { - int wr = c_sh_wr + 16 * j; - write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - else - { - int wr = c_sh_wr + 8 * j; - write( - wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write( - wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write( - wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write( - wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write( + wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write( + wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write( + wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); } - } - c_sh_wr += 16 * (4 * c_sh_stride); } - } - __syncthreads(); + __syncthreads(); #pragma unroll - for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) - { - if (c_gl_wr < c_gl_wr_end) - { - if (use_atomic_add && slice_count > 1) - { - scalar_t2 *C_half2 = reinterpret_cast(&C[c_gl_wr]); - scalar_t2 *sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); + for (int i = 0; i < div_ceil(16 * 
thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (c_gl_wr < c_gl_wr_end) { + if (use_atomic_add && slice_count > 1) { + scalar_t2 *C_half2 = reinterpret_cast(&C[c_gl_wr]); + scalar_t2 *sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); #pragma unroll - for (int a = 0; a < 4; a++) - { - atomicAdd(&C_half2[a], sh_red_half2[a]); + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[c_gl_wr] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; } - } - else - { - C[c_gl_wr] = sh_red[c_sh_rd]; - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; } - } - __syncthreads(); + __syncthreads(); }; // Start global fetch and register load pipelines. - auto start_pipes = [&]() - { + auto start_pipes = [&]() { #pragma unroll - for (int i = 0; i < stages - 1; i++) - { - if (has_act_order && i == 0) - { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) - { - last_g_idx = prob_k - 1; - } - fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } - if constexpr (has_zp && !is_zp_float && group_blocks == -1) - { - if (i == 0) - { - fetch_col_zp_to_shared(); - if constexpr (!dequant_skip_flop) - { - fetch_col_scale_to_shared(); + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } + } } - } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + if constexpr 
(has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - if constexpr (has_act_order) - { - slice_k_start_shared_fetch += tb_k * (stages - 1); - } }; - if (slice_iters) - { - start_pipes(); + if (slice_iters) { + start_pipes(); } // Main loop. - while (slice_iters) - { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. 
#pragma unroll - for (int pipe = 0; pipe < stages;) - { + for (int pipe = 0; pipe < stages;) { #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) - { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) - { - fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) - { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - - if constexpr (has_act_order) - { - slice_k_start += tb_k * stages; - - if (slice_k_start < prob_k) - { - slice_k_start_shared_fetch += tb_k * stages; - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) - { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) - { - fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. 
- if (slice_iters == 0) - { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) - { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) - { - if (s_sh_wr_pred) - { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + + a_gl_rd += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } } - cp_async_fence(); - } } - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) - { - if (w_type.size_bits() == 8 || (last || use_atomic_add)) - { + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
+ if (slice_iters == 0) { cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) - { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - if constexpr (m_block_size_8) - { - int idx = (threadIdx.x / 4) % 2; - scalar_t2 *frag_s_half2 = reinterpret_cast(frag_s); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 *frag_s_half2 = reinterpret_cast(frag_s); #pragma unroll - for (int i = 0; i < 8; i++) - { - frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } } - } } - } - } - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr ( - !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) - { - if (threadIdx.x / 32 < thread_n_blocks / 4) - { + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we 
avoid possible + // overflow in fp16) + if constexpr ( + !has_act_order && group_blocks == -1 && w_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - { + for (int i = 0; i < thread_m_blocks; i++) { #pragma unroll - for (int j = 0; j < 4; j++) - { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( - reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); - - if constexpr (!m_block_size_8) - { - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 
1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } } - } } - } - } - if (slice_count > 1 && !use_atomic_add) - { - // only globally reduce if there is more than one block in a slice - barrier_acquire(&locks[locks_off], slice_idx); - if (use_fp32_reduce) - { - global_reduce_fp32(slice_idx == 0, last); - } - else - { - global_reduce_fp16(slice_idx == 0, last); - } - barrier_release(&locks[locks_off], last); - } - if (use_atomic_add && slice_count > 1 && slice_idx != 0) - wait_negative_and_add(&locks[locks_off]); - if (last || use_atomic_add) - // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - is_first_matmul_in_slice = true; - init_slice(); - - if (slice_iters) - { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) { + wait_negative_and_add(&locks[locks_off]); + } + if (last || use_atomic_add) { + // only the last block in a slice actually writes the result + write_result(); + } + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) - { + for (int i = 0; i < b_sh_wr_iters; i++) { + 
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + } + if (slice_col == 0) { #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) - { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - } - else - { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } + for (int i = 0; i < b_sh_wr_iters; i++) { + B_ptr[i] -= b_gl_stride; + } + } - start_pipes(); + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } } - } } - } +} } // namespace device::marlin diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp index 15f46457f..04ce7a537 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/scalar_type.hpp @@ -18,276 +18,274 @@ namespace host { // here. 
// class ScalarType { - public: - enum NanRepr : uint8_t { - NAN_NONE = 0, // nans are not supported - NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s - NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s - - NAN_REPR_ID_MAX - }; - - constexpr ScalarType( - uint8_t exponent, - uint8_t mantissa, - bool signed_, - int32_t bias, - bool finite_values_only = false, - NanRepr nan_repr = NAN_IEEE_754) - : exponent(exponent), - mantissa(mantissa), - signed_(signed_), - bias(bias), - finite_values_only(finite_values_only), - nan_repr(nan_repr) {}; - - static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { - return ScalarType(0, size_bits - 1, true, bias); - } - - static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { - return ScalarType(0, size_bits, false, bias); - } - - // IEEE 754 compliant floating point type - static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { - assert(mantissa > 0 && exponent > 0); - return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); - } - - // IEEE 754 non-compliant floating point type - static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { - assert(nan_repr < NAN_REPR_ID_MAX); - assert(mantissa > 0 && exponent > 0); - assert(nan_repr != NAN_IEEE_754); - return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); - } - - uint8_t const exponent; // size of the exponent field (0 for integer types) - uint8_t const mantissa; // size of the mantissa field (size of the integer - // excluding the sign bit for integer types) - bool const signed_; // flag if the type supports negative numbers (i.e. has a - // sign bit) - int32_t const bias; // stored values equal value + bias, - // used for quantized type - - // Extra Floating point info - bool const finite_values_only; // i.e. 
no +/-inf if true - NanRepr const nan_repr; // how NaNs are represented - // (not applicable for integer types) - - using Id = int64_t; - - private: - // Field size in id - template - static constexpr size_t member_id_field_width() { - using T = std::decay_t; - return std::is_same_v ? 1 : sizeof(T) * 8; - } - - template - static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { - auto new_val = f(val, member); - if constexpr (sizeof...(rest) > 0) { - return reduce_members_helper(f, new_val, rest...); - } else { - return new_val; +public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX }; - } - - template - constexpr auto reduce_members(Fn f, Init init) const { - // Should be in constructor order for `from_id` - return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); - }; - - template - static constexpr auto reduce_member_types(Fn f, Init init) { - constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); - return dummy_type.reduce_members(f, init); - }; - - static constexpr auto id_size_bits() { - return reduce_member_types( - [](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); - } - - public: - // unique id for this scalar type that can be computed at compile time for - // c++17 template specialization this is not needed once we migrate to - // c++20 and can pass literal classes as template parameters - constexpr Id id() const { - static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); - - auto or_and_advance = [](std::pair result, auto member) -> std::pair { - auto [id, bit_offset] = result; - auto constexpr bits = member_id_field_width(); - return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + 
bits}; + + constexpr ScalarType( + uint8_t exponent, + uint8_t mantissa, + bool signed_, + int32_t bias, + bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr){}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { + assert(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { + assert(nan_repr < NAN_REPR_ID_MAX); + assert(mantissa > 0 && exponent > 0); + assert(nan_repr != NAN_IEEE_754); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + +private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 
1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); }; - return reduce_members(or_and_advance, std::pair{}).first; - } - - // create a ScalarType from an id, for c++17 template specialization, - // this is not needed once we migrate to c++20 and can pass literal - // classes as template parameters - static constexpr ScalarType from_id(Id id) { - auto extract_and_advance = [id](auto result, auto member) { - using T = decltype(member); - auto [tuple, bit_offset] = result; - auto constexpr bits = member_id_field_width(); - auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); - auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); - return std::pair{new_tuple, bit_offset + bits}; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); }; - auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); - return std::apply([](auto... 
args) { return ScalarType(args...); }, tuple_args); - } - - constexpr int64_t size_bits() const { - return mantissa + exponent + is_signed(); - } - constexpr bool is_signed() const { - return signed_; - } - constexpr bool is_integer() const { - return exponent == 0; - } - constexpr bool is_floating_point() const { - return exponent > 0; - } - constexpr bool is_ieee_754() const { - return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; - } - constexpr bool has_nans() const { - return is_floating_point() && nan_repr != NAN_NONE; - } - constexpr bool has_infs() const { - return is_floating_point() && finite_values_only == false; - } - constexpr bool has_bias() const { - return bias != 0; - } + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); + } -#ifndef __CUDACC__ - private: - double _floating_point_max() const { - assert(mantissa <= 52 && exponent <= 11); +public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } - uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; - if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { - max_mantissa -= 1; + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType 
from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args); } - uint64_t max_exponent = (uint64_t(1) << exponent) - 2; - if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { - assert(exponent < 11); - max_exponent += 1; + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { + return signed_; + } + constexpr bool is_integer() const { + return exponent == 0; + } + constexpr bool is_floating_point() const { + return exponent > 0; + } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { + return bias != 0; + } + +#ifndef __CUDACC__ +private: + double _floating_point_max() const { + assert(mantissa <= 52 && exponent <= 11); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + assert(exponent < 11); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the 
exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; - // adjust the exponent to match that of a double - // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e - // is the exponent bits), there is some precedent for non-standard biases, - // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes - // but to avoid premature over complication we are just assuming the - // standard exponent bias until there is a need to support non-standard - // biases - uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; - uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 - - uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; - - // shift the mantissa into the position for a double and - // the exponent - uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); - - return *reinterpret_cast(&double_raw); - } - - constexpr std::variant _raw_max() const { - if (is_floating_point()) { - return {_floating_point_max()}; - } else { - assert(size_bits() < 64 || (size_bits() == 64 && is_signed())); - return {(int64_t(1) << mantissa) - 1}; + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); } - } - - constexpr std::variant _raw_min() const { - if (is_floating_point()) { - assert(is_signed()); - constexpr uint64_t sign_bit_double = (uint64_t(1) 
<< 63); - - double max = _floating_point_max(); - uint64_t max_raw = *reinterpret_cast(&max); - uint64_t min_raw = max_raw | sign_bit_double; - return {*reinterpret_cast(&min_raw)}; - } else { - assert(!is_signed() || size_bits() <= 64); - if (is_signed()) { - // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 - // then perform an arithmetic shift right to set all the bits above - // (size_bits() - 1) to 1 - return {INT64_MIN >> (64 - size_bits())}; - } else { - return {int64_t(0)}; - } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + assert(size_bits() < 64 || (size_bits() == 64 && is_signed())); + return {(int64_t(1) << mantissa) - 1}; + } } - } - - public: - // Max representable value for this scalar type. - // (accounting for bias if there is one) - constexpr std::variant max() const { - return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); - } - - // Min representable value for this scalar type. 
- // (accounting for bias if there is one) - constexpr std::variant min() const { - return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); - } -#endif // __CUDACC__ - - public: - std::string str() const { - /* naming generally follows: https://github.com/jax-ml/ml_dtypes - * for floating point types (leading f) the scheme is: - * `float_em[flags]` - * flags: - * - no-flags: means it follows IEEE 754 conventions - * - f: means finite values only (no infinities) - * - n: means nans are supported (non-standard encoding) - * for integer types the scheme is: - * `[u]int[b]` - * - if bias is not present it means its zero - */ - if (is_floating_point()) { - auto ret = - "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); - if (!is_ieee_754()) { - if (finite_values_only) { - ret += "f"; + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { + assert(is_signed()); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + assert(!is_signed() || size_bits() <= 64); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } } - if (nan_repr != NAN_NONE) { - ret += "n"; + } + +public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant max() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); + } + + // Min representable value for this scalar type. 
+ // (accounting for bias if there is one) + constexpr std::variant min() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); + } +#endif // __CUDACC__ + +public: + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float_em[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int[b]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; } - } - return ret; - } else { - auto ret = ((is_signed()) ? 
"int" : "uint") + std::to_string(size_bits()); - if (has_bias()) { - ret += "b" + std::to_string(bias); - } - return ret; } - } - constexpr bool operator==(ScalarType const& other) const { - return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && - finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; - } + constexpr bool operator==(ScalarType const &other) const { + return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; + } }; using ScalarTypeId = ScalarType::Id; @@ -331,5 +329,4 @@ static inline constexpr auto kFloat16 = kHalf; static inline constexpr auto kBFloat16 = kFE8M7; static inline constexpr auto kFloat16Id = kFloat16.id(); -} // namespace host - +} // namespace host diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h index 57573171a..9a06fb380 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/source_location.h @@ -16,26 +16,25 @@ using source_location_t = std::source_location; #else struct source_location_fallback { - public: - static constexpr source_location_fallback current() noexcept { - return source_location_fallback{}; - } - constexpr source_location_fallback() noexcept = default; - constexpr unsigned line() const noexcept { - return 0; - } - constexpr unsigned column() const noexcept { - return 0; - } - constexpr const char* file_name() const noexcept { - return ""; - } - constexpr const char* function_name() const noexcept { - return ""; - } +public: + static constexpr source_location_fallback current() noexcept { + return source_location_fallback{}; + } + constexpr source_location_fallback() noexcept = default; + constexpr unsigned line() const noexcept { + 
return 0; + } + constexpr unsigned column() const noexcept { + return 0; + } + constexpr const char *file_name() const noexcept { + return ""; + } + constexpr const char *function_name() const noexcept { + return ""; + } }; using source_location_t = source_location_fallback; #endif - diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h index 9f48edd96..f30492621 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h @@ -23,199 +23,178 @@ #include "utils.cuh" #endif -namespace host -{ - struct SymbolicSize; - struct SymbolicDType; - struct SymbolicDevice; - - namespace details - { - inline constexpr auto kAnyDeviceID = -1; - inline constexpr auto kAnySize = static_cast(-1); - inline constexpr auto kNullSize = static_cast(-1); - inline constexpr auto kNullDType = static_cast(18u); - inline constexpr auto kNullDevice = static_cast(-1); - - template - struct ArrayView - { - const T *data; - size_t size; - - __host__ __device__ ArrayView() : data(nullptr), size(0) {} - __host__ __device__ ArrayView(const T *d, size_t s) : data(d), size(s) {} - - template - __host__ __device__ ArrayView(const std::array &arr) - : data(arr.data()), size(arr.size()) {} - - __host__ __device__ const T &operator[](size_t i) const { return data[i]; } - __host__ __device__ bool empty() const { return size == 0; } - }; - - template - struct PrintAbleSpan - { - const T *data; - size_t length; - - PrintAbleSpan(const T *p, size_t l) : data(p), length(l) {} - size_t size() const { return length; } - const T &operator[](size_t i) const { return data[i]; } - }; - - inline constexpr const char *kDeviceStringMap[] = { - "", // 0 - "cpu", // 1 - "cuda", // 2 - "cuda_host", // 3 - "opencl", // 4 - "vulkan", // 5 - "metal", // 6 - "vpi", // 7 - "rocm", // 8 - "rocm_host", // 9 - "ext_dev", // 10 - "cuda_managed", // 11 - "oneapi", // 12 - "webgpu", // 13 - 
"hexagon", // 14 - "maia", // 15 - "trn", // 16 - }; - - constexpr int kMaxDeviceType = 16; - - struct PrintableDevice - { - DLDevice device; - }; - - template - struct _dtype_trait; - - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLInt, 8, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLInt, 16, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLInt, 32, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLInt, 64, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLUInt, 8, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLUInt, 16, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLUInt, 32, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLUInt, 64, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLFloat, 32, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLFloat, 64, 1}; - }; +namespace host { +struct SymbolicSize; +struct SymbolicDType; +struct SymbolicDevice; + +namespace details { +inline constexpr auto kAnyDeviceID = -1; +inline constexpr auto kAnySize = static_cast(-1); +inline constexpr auto kNullSize = static_cast(-1); +inline constexpr auto kNullDType = static_cast(18u); +inline constexpr auto kNullDevice = static_cast(-1); + +template +struct ArrayView { + const T *data; + size_t size; + + __host__ __device__ ArrayView() : data(nullptr), size(0) {} + __host__ __device__ ArrayView(const T *d, size_t s) : data(d), size(s) {} + + template + __host__ __device__ ArrayView(const std::array &arr) + : data(arr.data()), size(arr.size()) {} + + __host__ __device__ const T &operator[](size_t i) const { return data[i]; } + __host__ __device__ bool empty() const { return 
size == 0; } +}; + +template +struct PrintAbleSpan { + const T *data; + size_t length; + + PrintAbleSpan(const T *p, size_t l) : data(p), length(l) {} + size_t size() const { return length; } + const T &operator[](size_t i) const { return data[i]; } +}; + +inline constexpr const char *kDeviceStringMap[] = { + "", // 0 + "cpu", // 1 + "cuda", // 2 + "cuda_host", // 3 + "opencl", // 4 + "vulkan", // 5 + "metal", // 6 + "vpi", // 7 + "rocm", // 8 + "rocm_host", // 9 + "ext_dev", // 10 + "cuda_managed", // 11 + "oneapi", // 12 + "webgpu", // 13 + "hexagon", // 14 + "maia", // 15 + "trn", // 16 +}; + +constexpr int kMaxDeviceType = 16; + +struct PrintableDevice { + DLDevice device; +}; + +template +struct _dtype_trait; + +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLInt, 8, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLInt, 16, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLInt, 32, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLInt, 64, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLUInt, 8, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLUInt, 16, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLUInt, 32, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLUInt, 64, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLFloat, 32, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLFloat, 64, 1}; +}; #ifdef __CUDACC__ - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLFloat, 16, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = {kDLBfloat, 16, 1}; - }; - template <> - struct _dtype_trait - { - static constexpr DLDataType value = 
{kDLFloat8_e4m3fn, 8, 1}; - }; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLFloat, 16, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLBfloat, 16, 1}; +}; +template <> +struct _dtype_trait { + static constexpr DLDataType value = {kDLFloat8_e4m3fn, 8, 1}; +}; #endif - template - struct _device_trait - { - static constexpr DLDevice value = {Code, kAnyDeviceID}; - }; +template +struct _device_trait { + static constexpr DLDevice value = {Code, kAnyDeviceID}; +}; - template - inline constexpr std::array kDTypeList = { - _dtype_trait::value...}; +template +inline constexpr std::array kDTypeList = { + _dtype_trait::value...}; - template - inline constexpr std::array kDeviceList = { - _device_trait::value...}; +template +inline constexpr std::array kDeviceList = { + _device_trait::value...}; - } // namespace details +} // namespace details - inline std::ostream &operator<<(std::ostream &os, DLDevice device) - { +inline std::ostream &operator<<(std::ostream &os, DLDevice device) { int code = static_cast(device.device_type); - if (code < 1 || code > details::kMaxDeviceType) - RuntimeCheck(false, "Unknown device: ", code); + if (code < 1 || code > details::kMaxDeviceType) { + RuntimeCheck(false, "Unknown device: ", code); + } os << details::kDeviceStringMap[code]; - if (device.device_id != details::kAnyDeviceID && device.device_type != kDLCPU) - os << ":" << device.device_id; + if (device.device_id != details::kAnyDeviceID && device.device_type != kDLCPU) { + os << ":" << device.device_id; + } return os; - } +} - inline std::ostream &operator<<(std::ostream &os, details::PrintableDevice pd) - { +inline std::ostream &operator<<(std::ostream &os, details::PrintableDevice pd) { return os << pd.device; - } +} - template - inline std::ostream &operator<<(std::ostream &os, const details::PrintAbleSpan &span) - { +template +inline std::ostream &operator<<(std::ostream &os, const details::PrintAbleSpan &span) { os 
<< "["; - for (size_t i = 0; i < span.size(); ++i) - { - if (i > 0) - os << ", "; - os << span[i]; + for (size_t i = 0; i < span.size(); ++i) { + if (i > 0) { + os << ", "; + } + os << span[i]; } os << "]"; return os; - } - - // ============================================== - // SymbolicSize 完整定义 - // ============================================== - struct SymbolicSize - { - public: +} + +// ============================================== +// SymbolicSize 完整定义 +// ============================================== +struct SymbolicSize { +public: explicit SymbolicSize(std::string_view ann = {}) : m_value(details::kNullSize), m_ann(ann) {} @@ -225,275 +204,243 @@ namespace host std::string_view get_name() const { return m_ann; } bool has_value() const { return m_value != details::kNullSize; } - void set_value(int64_t v) - { - RuntimeCheck(!has_value(), "Size already set"); - m_value = v; + void set_value(int64_t v) { + RuntimeCheck(!has_value(), "Size already set"); + m_value = v; } - std::optional get_value() const - { - return has_value() ? std::optional(m_value) : std::nullopt; + std::optional get_value() const { + return has_value() ? 
std::optional(m_value) : std::nullopt; } - int64_t unwrap(DebugInfo info = {}) const - { - RuntimeCheck(info, has_value(), "Size not set"); - return m_value; + int64_t unwrap(DebugInfo info = {}) const { + RuntimeCheck(info, has_value(), "Size not set"); + return m_value; } - void verify(int64_t v, const char *prefix, int64_t dim) - { - if (has_value()) - { - if (m_value != v) [[unlikely]] - { - Panic("Size mismatch for ", m_name_str(prefix, dim), ": expected ", m_value, " got ", v); + void verify(int64_t v, const char *prefix, int64_t dim) { + if (has_value()) { + if (m_value != v) [[unlikely]] { + Panic("Size mismatch for ", m_name_str(prefix, dim), ": expected ", m_value, " got ", v); + } + } else { + set_value(v); } - } - else - { - set_value(v); - } } - std::string value_or_name(const char *prefix, int64_t dim) const - { - if (auto v = get_value()) - return std::to_string(*v); - return m_name_str(prefix, dim); + std::string value_or_name(const char *prefix, int64_t dim) const { + if (auto v = get_value()) { + return std::to_string(*v); + } + return m_name_str(prefix, dim); } - private: - std::string m_name_str(const char *prefix, int64_t dim) const - { - std::ostringstream os; - os << prefix << '#' << dim; - if (!m_ann.empty()) - os << "('" << m_ann << "')"; - return os.str(); +private: + std::string m_name_str(const char *prefix, int64_t dim) const { + std::ostringstream os; + os << prefix << '#' << dim; + if (!m_ann.empty()) { + os << "('" << m_ann << "')"; + } + return os.str(); } int64_t m_value; std::string_view m_ann; - }; +}; - inline bool operator==(DLDevice a, DLDevice b) - { +inline bool operator==(DLDevice a, DLDevice b) { return a.device_type == b.device_type && a.device_id == b.device_id; - } - - // ============================================== - // SymbolicDType 完整定义 - // ============================================== - struct SymbolicDType - { - public: +} + +// ============================================== +// SymbolicDType 完整定义 +// 
============================================== +struct SymbolicDType { +public: SymbolicDType() : m_value({details::kNullDType, 0, 0}) {} SymbolicDType(const SymbolicDType &) = delete; SymbolicDType &operator=(const SymbolicDType &) = delete; bool has_value() const { return m_value.code != details::kNullDType; } - void set_value(DLDataType v) - { - RuntimeCheck(!has_value(), "DType already set"); - RuntimeCheck(m_check(v), "DType not allowed: ", v); - m_value = v; + void set_value(DLDataType v) { + RuntimeCheck(!has_value(), "DType already set"); + RuntimeCheck(m_check(v), "DType not allowed: ", v); + m_value = v; } - std::optional get_value() const - { - return has_value() ? std::optional(m_value) : std::nullopt; + std::optional get_value() const { + return has_value() ? std::optional(m_value) : std::nullopt; } - DLDataType unwrap(DebugInfo info = {}) const - { - RuntimeCheck(info, has_value(), "DType not set"); - return m_value; + DLDataType unwrap(DebugInfo info = {}) const { + RuntimeCheck(info, has_value(), "DType not set"); + return m_value; } void set_options(details::ArrayView opts) { m_opts = opts; } template - void set_options() - { - m_opts = details::ArrayView(details::kDTypeList.data(), details::kDTypeList.size()); + void set_options() { + m_opts = details::ArrayView(details::kDTypeList.data(), details::kDTypeList.size()); } - void verify(DLDataType dtype) - { - if (has_value()) - { - RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " got ", dtype); - } - else - { - set_value(dtype); - } + void verify(DLDataType dtype) { + if (has_value()) { + RuntimeCheck(m_value == dtype, "DType mismatch: expected ", m_value, " got ", dtype); + } else { + set_value(dtype); + } } template - bool is_type() const - { - return m_value == details::_dtype_trait::value; + bool is_type() const { + return m_value == details::_dtype_trait::value; } - private: - bool m_check(DLDataType v) const - { - if (m_opts.empty()) - return true; - for (size_t i = 0; i 
< m_opts.size; ++i) - if (m_opts[i] == v) - return true; - return false; +private: + bool m_check(DLDataType v) const { + if (m_opts.empty()) { + return true; + } + for (size_t i = 0; i < m_opts.size; ++i) { + if (m_opts[i] == v) { + return true; + } + } + return false; } details::ArrayView m_opts; DLDataType m_value; - }; - - // ============================================== - // SymbolicDevice 完整定义 - // ============================================== - struct SymbolicDevice - { - public: +}; + +// ============================================== +// SymbolicDevice 完整定义 +// ============================================== +struct SymbolicDevice { +public: SymbolicDevice() : m_value({details::kNullDevice, details::kAnyDeviceID}) {} SymbolicDevice(const SymbolicDevice &) = delete; SymbolicDevice &operator=(const SymbolicDevice &) = delete; bool has_value() const { return m_value.device_type != details::kNullDevice; } - void set_value(DLDevice v) - { - RuntimeCheck(!has_value(), "Device already set"); - RuntimeCheck(m_check(v), "Device not allowed: ", details::PrintableDevice{v}); - m_value = v; + void set_value(DLDevice v) { + RuntimeCheck(!has_value(), "Device already set"); + RuntimeCheck(m_check(v), "Device not allowed: ", details::PrintableDevice{v}); + m_value = v; } - std::optional get_value() const - { - return has_value() ? std::optional(m_value) : std::nullopt; + std::optional get_value() const { + return has_value() ? 
std::optional(m_value) : std::nullopt; } - DLDevice unwrap(DebugInfo info = {}) const - { - RuntimeCheck(info, has_value(), "Device not set"); - return m_value; + DLDevice unwrap(DebugInfo info = {}) const { + RuntimeCheck(info, has_value(), "Device not set"); + return m_value; } void set_options(details::ArrayView opts) { m_opts = opts; } template - void set_options() - { - m_opts = details::ArrayView(details::kDeviceList.data(), details::kDeviceList.size()); - } - - void verify(DLDevice dev) - { - if (has_value()) - { - RuntimeCheck(m_value == dev, "Device mismatch: expected ", - details::PrintableDevice{m_value}, " got ", details::PrintableDevice{dev}); - } - else - { - set_value(dev); - } - } - - private: - bool m_check(DLDevice v) const - { - if (m_opts.empty()) - return true; - for (size_t i = 0; i < m_opts.size; ++i) - { - auto o = m_opts[i]; - if (o.device_type != v.device_type) - continue; - if (o.device_id == details::kAnyDeviceID || o.device_id == v.device_id) - return true; - } - return false; + void set_options() { + m_opts = details::ArrayView(details::kDeviceList.data(), details::kDeviceList.size()); + } + + void verify(DLDevice dev) { + if (has_value()) { + RuntimeCheck(m_value == dev, "Device mismatch: expected ", + details::PrintableDevice{m_value}, " got ", details::PrintableDevice{dev}); + } else { + set_value(dev); + } + } + +private: + bool m_check(DLDevice v) const { + if (m_opts.empty()) { + return true; + } + for (size_t i = 0; i < m_opts.size; ++i) { + auto o = m_opts[i]; + if (o.device_type != v.device_type) { + continue; + } + if (o.device_id == details::kAnyDeviceID || o.device_id == v.device_id) { + return true; + } + } + return false; } details::ArrayView m_opts; DLDevice m_value; - }; - - // ============================================== - // BaseRef / Ref 类型(现在类型已完整定义) - // ============================================== - namespace details - { - template - struct BaseRef - { - BaseRef() : m_ref(&m_cache) {} - explicit BaseRef(T &r) 
: m_ref(&r) {} - BaseRef(const BaseRef &) = delete; - BaseRef &operator=(const BaseRef &) = delete; - - T *operator->() const { return m_ref; } - T &operator*() const { return *m_ref; } - void rebind(T &r) { m_ref = &r; } - - private: - T *m_ref; - T m_cache; - }; - - struct SizeRef : public BaseRef - { - using BaseRef::BaseRef; - SizeRef(int64_t v); - }; - - struct DTypeRef : public BaseRef - { - using BaseRef::BaseRef; - DTypeRef(DLDataType); - DTypeRef(std::initializer_list); - DTypeRef(ArrayView); - }; - - struct DeviceRef : public BaseRef - { - using BaseRef::BaseRef; - DeviceRef(DLDevice); - DeviceRef(std::initializer_list); - DeviceRef(ArrayView); - }; - - inline SizeRef::SizeRef(int64_t v) - { - if (v != kAnySize) +}; + +// ============================================== +// BaseRef / Ref 类型(现在类型已完整定义) +// ============================================== +namespace details { +template +struct BaseRef { + BaseRef() : m_ref(&m_cache) {} + explicit BaseRef(T &r) : m_ref(&r) {} + BaseRef(const BaseRef &) = delete; + BaseRef &operator=(const BaseRef &) = delete; + + T *operator->() const { return m_ref; } + T &operator*() const { return *m_ref; } + void rebind(T &r) { m_ref = &r; } + +private: + T *m_ref; + T m_cache; +}; + +struct SizeRef : public BaseRef { + using BaseRef::BaseRef; + SizeRef(int64_t v); +}; + +struct DTypeRef : public BaseRef { + using BaseRef::BaseRef; + DTypeRef(DLDataType); + DTypeRef(std::initializer_list); + DTypeRef(ArrayView); +}; + +struct DeviceRef : public BaseRef { + using BaseRef::BaseRef; + DeviceRef(DLDevice); + DeviceRef(std::initializer_list); + DeviceRef(ArrayView); +}; + +inline SizeRef::SizeRef(int64_t v) { + if (v != kAnySize) { (**this).set_value(v); } - inline DTypeRef::DTypeRef(DLDataType v) { (**this).set_value(v); } - inline DTypeRef::DTypeRef(std::initializer_list l) : DTypeRef(ArrayView(l.begin(), l.size())) {} - inline DTypeRef::DTypeRef(ArrayView v) { (**this).set_options(v); } - inline DeviceRef::DeviceRef(DLDevice 
v) { (**this).set_value(v); } - inline DeviceRef::DeviceRef(std::initializer_list l) : DeviceRef(ArrayView(l.begin(), l.size())) {} - inline DeviceRef::DeviceRef(ArrayView v) { (**this).set_options(v); } +} +inline DTypeRef::DTypeRef(DLDataType v) { (**this).set_value(v); } +inline DTypeRef::DTypeRef(std::initializer_list l) : DTypeRef(ArrayView(l.begin(), l.size())) {} +inline DTypeRef::DTypeRef(ArrayView v) { (**this).set_options(v); } +inline DeviceRef::DeviceRef(DLDevice v) { (**this).set_value(v); } +inline DeviceRef::DeviceRef(std::initializer_list l) : DeviceRef(ArrayView(l.begin(), l.size())) {} +inline DeviceRef::DeviceRef(ArrayView v) { (**this).set_options(v); } - } // namespace details +} // namespace details - template - inline bool is_type(DLDataType dtype) - { +template +inline bool is_type(DLDataType dtype) { return dtype == details::_dtype_trait::value; - } +} - // ============================================== - // TensorMatcher - // ============================================== - struct TensorMatcher - { +// ============================================== +// TensorMatcher +// ============================================== +struct TensorMatcher { using SizeRef = details::SizeRef; using DTypeRef = details::DTypeRef; using DeviceRef = details::DeviceRef; @@ -504,47 +451,42 @@ namespace host explicit TensorMatcher(std::initializer_list s) : m_shape(s.begin(), s.size()), m_strides(nullptr, 0) {} - TensorMatcher &&with_strides(std::initializer_list s) && - { - RuntimeCheck(m_strides.empty(), "Strides already set"); - RuntimeCheck(m_shape.size == s.size(), "Stride/shape size mismatch"); - m_strides = details::ArrayView(s.begin(), s.size()); - return std::move(*this); + TensorMatcher &&with_strides(std::initializer_list s) && { + RuntimeCheck(m_strides.empty(), "Strides already set"); + RuntimeCheck(m_shape.size == s.size(), "Stride/shape size mismatch"); + m_strides = details::ArrayView(s.begin(), s.size()); + return std::move(*this); } template - 
TensorMatcher &&with_dtype(DTypeRef &&d) && - { - m_dtype.rebind(*d); - m_dtype->template set_options(); - return std::move(*this); + TensorMatcher &&with_dtype(DTypeRef &&d) && { + m_dtype.rebind(*d); + m_dtype->template set_options(); + return std::move(*this); } template - TensorMatcher &&with_dtype() && - { - m_dtype->template set_options(); - return std::move(*this); + TensorMatcher &&with_dtype() && { + m_dtype->template set_options(); + return std::move(*this); } template - TensorMatcher &&with_device(DeviceRef &&d) && - { - m_device.rebind(*d); - m_device->template set_options(); - return std::move(*this); + TensorMatcher &&with_device(DeviceRef &&d) && { + m_device.rebind(*d); + m_device->template set_options(); + return std::move(*this); } template - TensorMatcher &&with_device() && - { - m_device->template set_options(); - return std::move(*this); + TensorMatcher &&with_device() && { + m_device->template set_options(); + return std::move(*this); } const TensorMatcher &&verify(tvm::ffi::TensorView, DebugInfo = {}) const &&; - private: +private: static void s_print_tensor(std::ostringstream &, tvm::ffi::TensorView); void m_verify_impl(tvm::ffi::TensorView) const; @@ -552,70 +494,62 @@ namespace host details::ArrayView m_strides; DTypeRef m_dtype; DeviceRef m_device; - }; +}; - inline void TensorMatcher::s_print_tensor(std::ostringstream &os, tvm::ffi::TensorView v) - { +inline void TensorMatcher::s_print_tensor(std::ostringstream &os, tvm::ffi::TensorView v) { os << "Tensor<"; size_t d = 0; - for (int64_t s : v.shape()) - { - if (d++) - os << ", "; - os << s; + for (int64_t s : v.shape()) { + if (d++) { + os << ", "; + } + os << s; } os << ">[strides=<"; d = 0; - for (int64_t s : v.strides()) - { - if (d++) - os << ", "; - os << s; + for (int64_t s : v.strides()) { + if (d++) { + os << ", "; + } + os << s; } os << ">, dtype=" << v.dtype(); os << ", device=" << details::PrintableDevice{v.device()} << "]"; - } - - inline const TensorMatcher 
&&TensorMatcher::verify(tvm::ffi::TensorView v, DebugInfo info) const && - { - try - { - m_verify_impl(v); - } - catch (PanicError &e) - { - std::ostringstream os; - os << "Tensor match failed: "; - s_print_tensor(os, v); - os << " @ " << info.file_name() << ":" << info.line() << "\n- cause: " << e.root_cause(); - throw PanicError(os.str()); +} + +inline const TensorMatcher &&TensorMatcher::verify(tvm::ffi::TensorView v, DebugInfo info) const && { + try { + m_verify_impl(v); + } catch (PanicError &e) { + std::ostringstream os; + os << "Tensor match failed: "; + s_print_tensor(os, v); + os << " @ " << info.file_name() << ":" << info.line() << "\n- cause: " << e.root_cause(); + throw PanicError(os.str()); } return std::move(*this); - } +} - inline void TensorMatcher::m_verify_impl(tvm::ffi::TensorView v) const - { +inline void TensorMatcher::m_verify_impl(tvm::ffi::TensorView v) const { size_t dim = static_cast(v.dim()); RuntimeCheck(dim == m_shape.size, "Dim mismatch: expected ", m_shape.size, " got ", dim); - for (size_t i = 0; i < dim; ++i) - m_shape[i]->verify(v.size(i), "shape", (int64_t)i); - - if (!m_strides.empty()) - { - for (size_t i = 0; i < dim; ++i) - { - if (v.size(i) != 1 || !m_strides[i]->has_value()) - m_strides[i]->verify(v.stride(i), "stride", (int64_t)i); - } + for (size_t i = 0; i < dim; ++i) { + m_shape[i]->verify(v.size(i), "shape", (int64_t)i); } - else - { - RuntimeCheck(v.is_contiguous(), "Tensor not contiguous"); + + if (!m_strides.empty()) { + for (size_t i = 0; i < dim; ++i) { + if (v.size(i) != 1 || !m_strides[i]->has_value()) { + m_strides[i]->verify(v.stride(i), "stride", (int64_t)i); + } + } + } else { + RuntimeCheck(v.is_contiguous(), "Tensor not contiguous"); } m_dtype->verify(v.dtype()); m_device->verify(v.device()); - } +} } // namespace host diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh index d73c2ac04..18d5da7c3 100644 --- 
a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh @@ -127,7 +127,8 @@ template SGL_DEVICE void PDLWaitPrimary() { #if SGL_ARCH_HOPPER_OR_GREATER if constexpr (kUsePDL) { - asm volatile("griddepcontrol.wait;" ::: "memory"); + asm volatile("griddepcontrol.wait;" :: + : "memory"); } #endif } @@ -142,7 +143,8 @@ template SGL_DEVICE void PDLTriggerSecondary() { #if SGL_ARCH_HOPPER_OR_GREATER if constexpr (kUsePDL) { - asm volatile("griddepcontrol.launch_dependents;" :::); + asm volatile("griddepcontrol.launch_dependents;" :: + :); } #endif } diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h index bf7a5ce40..d6892d0dd 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h @@ -45,197 +45,172 @@ #include "source_location.h" #endif -#include -#include #include #include +#include #include #include #include +#include #include -namespace host -{ +namespace host { - template - inline constexpr bool dependent_false_v = false; +template +inline constexpr bool dependent_false_v = false; - /// \brief Source-location wrapper for debug/error messages. - struct DebugInfo : public source_location_t - { +/// \brief Source-location wrapper for debug/error messages. +struct DebugInfo : public source_location_t { DebugInfo(source_location_t loc = source_location_t::current()) : source_location_t(loc) {} - }; +}; - /// \brief Exception type thrown by `RuntimeCheck` and `Panic`. - struct PanicError : public std::runtime_error - { - public: +/// \brief Exception type thrown by `RuntimeCheck` and `Panic`. 
+struct PanicError : public std::runtime_error { +public: explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {} - auto root_cause() const -> std::string_view - { - const auto str = std::string_view{m_message}; - const auto pos = str.find(": "); - return pos == std::string_view::npos ? str : str.substr(pos + 2); + auto root_cause() const -> std::string_view { + const auto str = std::string_view{m_message}; + const auto pos = str.find(": "); + return pos == std::string_view::npos ? str : str.substr(pos + 2); } - private: +private: std::string m_message; - }; +}; - /// \brief Unconditionally abort with a formatted error message. - template - [[noreturn]] - inline auto panic(DebugInfo location, Args &&...args) -> void - { +/// \brief Unconditionally abort with a formatted error message. +template +[[noreturn]] inline auto panic(DebugInfo location, Args &&...args) -> void { std::ostringstream os; os << "Runtime check failed at " << location.file_name() << ":" << location.line(); - if constexpr (sizeof...(args) > 0) - { - os << ": "; - (os << ... << std::forward(args)); - } - else - { - os << " in " << location.function_name(); + if constexpr (sizeof...(args) > 0) { + os << ": "; + (os << ... << std::forward(args)); + } else { + os << " in " << location.function_name(); } throw PanicError(std::move(os).str()); - } - - /** - * \brief Runtime assertion: panics with a formatted message when `condition` - * is false. Extra `args` are streamed to the error message. - * - * Example: - * \code - * RuntimeCheck(n > 0, "n must be positive, got ", n); - * \endcode - */ - template - struct RuntimeCheck - { +} + +/** + * \brief Runtime assertion: panics with a formatted message when `condition` + * is false. Extra `args` are streamed to the error message. 
+ * + * Example: + * \code + * RuntimeCheck(n > 0, "n must be positive, got ", n); + * \endcode + */ +template +struct RuntimeCheck { template - explicit RuntimeCheck(Cond &&condition, Args &&...args, DebugInfo location = {}) - { - if (condition) - return; - [[unlikely]] ::host::panic(location, std::forward(args)...); + explicit RuntimeCheck(Cond &&condition, Args &&...args, DebugInfo location = {}) { + if (condition) { + return; + } + [[unlikely]] ::host::panic(location, std::forward(args)...); } template - explicit RuntimeCheck(DebugInfo location, Cond &&condition, Args &&...args) - { - if (condition) - return; - [[unlikely]] ::host::panic(location, std::forward(args)...); + explicit RuntimeCheck(DebugInfo location, Cond &&condition, Args &&...args) { + if (condition) { + return; + } + [[unlikely]] ::host::panic(location, std::forward(args)...); } - }; - - template - struct Panic - { - explicit Panic(Args &&...args, DebugInfo location = {}) - { - ::host::panic(location, std::forward(args)...); +}; + +template +struct Panic { + explicit Panic(Args &&...args, DebugInfo location = {}) { + ::host::panic(location, std::forward(args)...); } - explicit Panic(DebugInfo location, Args &&...args) - { - ::host::panic(location, std::forward(args)...); + explicit Panic(DebugInfo location, Args &&...args) { + ::host::panic(location, std::forward(args)...); } - [[noreturn]] ~Panic() - { - std::terminate(); + [[noreturn]] ~Panic() { + std::terminate(); } - }; +}; - template - explicit RuntimeCheck(Cond &&, Args &&...) -> RuntimeCheck; +template +explicit RuntimeCheck(Cond &&, Args &&...) -> RuntimeCheck; - template - explicit RuntimeCheck(DebugInfo, Cond &&, Args &&...) -> RuntimeCheck; +template +explicit RuntimeCheck(DebugInfo, Cond &&, Args &&...) -> RuntimeCheck; - template - explicit Panic(Args &&...) -> Panic; +template +explicit Panic(Args &&...) -> Panic; - template - explicit Panic(DebugInfo, Args &&...) -> Panic; +template +explicit Panic(DebugInfo, Args &&...) 
-> Panic; - namespace pointer - { +namespace pointer { - // we only allow void * pointer arithmetic for safety +// we only allow void * pointer arithmetic for safety - template ::value && ...)>> - inline auto offset(void *ptr, U... offset) -> void * - { - return static_cast(ptr) + (... + offset); - } +template ::value && ...)>> +inline auto offset(void *ptr, U... offset) -> void * { + return static_cast(ptr) + (... + offset); +} - template ::value && ...)>> - inline auto offset(const void *ptr, U... offset) -> const void * - { - return static_cast(ptr) + (... + offset); - } +template ::value && ...)>> +inline auto offset(const void *ptr, U... offset) -> const void * { + return static_cast(ptr) + (... + offset); +} - } // namespace pointer +} // namespace pointer - /// \brief Integer ceiling division: ceil(a / b). - template - inline constexpr auto div_ceil(T a, U b) - { +/// \brief Integer ceiling division: ceil(a / b). +template +inline constexpr auto div_ceil(T a, U b) { static_assert(std::is_integral::value, "T must be integral"); static_assert(std::is_integral::value, "U must be integral"); return (a + b - 1) / b; - } +} - /// \brief Returns the byte width of a DLPack data type. - inline auto dtype_bytes(DLDataType dtype) -> std::size_t - { +/// \brief Returns the byte width of a DLPack data type. 
+inline auto dtype_bytes(DLDataType dtype) -> std::size_t { return static_cast(dtype.bits / 8); - } +} - // ====================== 修复开始:纯 C++11 兼容版 irange ====================== - // 移除所有 std::ranges / std::integral,完全兼容旧版 CUDA 编译器 +// ====================== 修复开始:纯 C++11 兼容版 irange ====================== +// 移除所有 std::ranges / std::integral,完全兼容旧版 CUDA 编译器 - template - struct IntegerRange - { +template +struct IntegerRange { T start_; T end_; - struct Iterator - { - T value; - - T operator*() const { return value; } - Iterator &operator++() - { - ++value; - return *this; - } - bool operator!=(const Iterator &other) const - { - return value != other.value; - } + struct Iterator { + T value; + + T operator*() const { return value; } + Iterator &operator++() { + ++value; + return *this; + } + bool operator!=(const Iterator &other) const { + return value != other.value; + } }; Iterator begin() const { return {start_}; } Iterator end() const { return {end_}; } - }; +}; - /// Python-style integer range: irange(n) -> [0, n) - template - IntegerRange irange(T end) - { +/// Python-style integer range: irange(n) -> [0, n) +template +IntegerRange irange(T end) { return {0, end}; - } +} - /// Python-style integer range: irange(start, end) -> [start, end) - template - IntegerRange irange(T start, T end) - { +/// Python-style integer range: irange(start, end) -> [start, end) +template +IntegerRange irange(T start, T end) { return {start, end}; - } - // ====================== 修复结束 ====================== +} +// ====================== 修复结束 ====================== } // namespace host diff --git a/test/infiniop/gptq_marlin_gemm.py b/test/infiniop/gptq_marlin_gemm.py index 9ba296d18..8119fbe8d 100644 --- a/test/infiniop/gptq_marlin_gemm.py +++ b/test/infiniop/gptq_marlin_gemm.py @@ -41,17 +41,21 @@ mnk_factors = MNK_FACTORS act_order = [False, True] + def to_iter(x): return x if isinstance(x, (list, tuple)) else (x,) -_TEST_CASES = list(itertools.product( - to_iter(k_chunk), - 
to_iter(n_chunk), - to_iter(quant_type), - to_iter(group_size), - to_iter(mnk_factors), - to_iter(act_order), -)) + +_TEST_CASES = list( + itertools.product( + to_iter(k_chunk), + to_iter(n_chunk), + to_iter(quant_type), + to_iter(group_size), + to_iter(mnk_factors), + to_iter(act_order), + ) +) _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16] @@ -70,6 +74,7 @@ def to_iter(x): SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + def quantize_weights( w: torch.Tensor, quant_type: ScalarType, @@ -164,10 +169,12 @@ def reshape_w(w): maybe_w_zp, ) + def get_pack_factor(num_bits): assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" return 32 // num_bits + def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): assert q_w.shape == (size_k, size_n) assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" @@ -182,6 +189,7 @@ def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): return q_w + def marlin_weights(q_w, size_k, size_n, num_bits, perm): # Permute q_w = marlin_permute_weights(q_w, size_k, size_n, perm) @@ -230,6 +238,7 @@ def get_weight_perm(num_bits: int): perm = torch.from_numpy(perm) return perm + def get_scale_perms(): scale_perm: list[int] = [] for i in range(8): @@ -253,6 +262,7 @@ def marlin_permute_scales( return s + def pack_cols( q_w: torch.Tensor, num_bits: int, @@ -278,6 +288,7 @@ def pack_cols( return q_res + def marlin_zero_points( zp: torch.Tensor, size_k: int, size_n: int, num_bits: int ) -> torch.Tensor: @@ -300,6 +311,7 @@ def marlin_zero_points( return zp + def permute_rows( q_w: torch.Tensor, w_ref: torch.Tensor, @@ -329,6 +341,7 @@ def permute_rows( rand_perm.to(device=orig_device), ) + def gptq_quantize_weights( w: torch.Tensor, quant_type: ScalarType, @@ -377,6 +390,7 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): sort_indices.to(device=orig_device), ) + def marlin_quantize( w: torch.Tensor, 
quant_type: ScalarType, @@ -415,6 +429,7 @@ def marlin_quantize( return res_list + def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): size_k, size_n = w.shape @@ -443,6 +458,7 @@ def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int return res_list + def marlin_make_workspace( device: torch.device, max_blocks_per_sm: int = 1 ) -> torch.Tensor: @@ -488,7 +504,7 @@ def test( if size_k % group_size != 0: return - + print( f"Testing Gptq Marlin Gemm on {InfiniDeviceNames[device]} with M-K-N:({size_m, size_k, size_n}), group_size:{group_size}, dtype:{InfiniDtypeNames[dtype]}" ) @@ -509,28 +525,63 @@ def test( marlin_zp = None marlin_s2 = None output_ref = torch.matmul(a_input.torch_tensor(), w_ref) - b = TestTensor(marlin_q_w.shape, marlin_q_w.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=marlin_q_w) + b = TestTensor( + marlin_q_w.shape, + marlin_q_w.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=marlin_q_w, + ) c = TestTensor(output_ref.shape, None, dtype, device) - b_scales = TestTensor(marlin_s.shape, marlin_s.stride(), dtype, device, mode="manual", set_tensor=marlin_s) + b_scales = TestTensor( + marlin_s.shape, + marlin_s.stride(), + dtype, + device, + mode="manual", + set_tensor=marlin_s, + ) global_scale = None if marlin_zp is not None: - b_zeros = TestTensor(marlin_zp.shape, marlin_zp.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=marlin_zp) + b_zeros = TestTensor( + marlin_zp.shape, + marlin_zp.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=marlin_zp, + ) else: b_zeros = None if g_idx is not None: - b_g_idx = TestTensor(g_idx.shape, g_idx.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=g_idx) + b_g_idx = TestTensor( + g_idx.shape, + g_idx.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=g_idx, + ) else: b_g_idx = None if sort_indices is not None: - perm = TestTensor(sort_indices.shape, 
sort_indices.stride(), InfiniDtype.I32, device, mode="manual", set_tensor=sort_indices) + perm = TestTensor( + sort_indices.shape, + sort_indices.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=sort_indices, + ) else: perm = None - - is_k_full=True - use_atomic_add=False - use_fp32_reduce=False - is_zp_float=False - + + is_k_full = True + use_atomic_add = False + use_fp32_reduce = False + is_zp_float = False + if sync is not None: sync() @@ -554,7 +605,7 @@ def test( for tensor in [c, a_input, b, b_scales, global_scale, b_zeros, b_g_idx, perm]: if tensor is not None: tensor.destroy_desc() - + workspace_size = c_uint64(0) check_error( LIBINFINIOP.infiniopGetGptqMarlinGemmWorkspaceSize( @@ -577,18 +628,17 @@ def lib_gptq_marlin_gemm(): b_zeros.data() if b_zeros is not None else None, b_g_idx.data() if b_g_idx is not None else None, perm.data() if perm is not None else None, - quant_type.id, - is_k_full, - use_atomic_add, - use_fp32_reduce, - is_zp_float, + quant_type.id, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, None, ) ) lib_gptq_marlin_gemm() - max_diff = torch.mean(torch.abs(c.actual_tensor() - output_ref)) / torch.mean( torch.abs(output_ref) ) @@ -603,7 +653,11 @@ def lib_gptq_marlin_gemm(): NUM_ITERATIONS, ) profile_operation( - " lib", lambda: lib_gptq_marlin_gemm(), device, NUM_PRERUN, NUM_ITERATIONS + " lib", + lambda: lib_gptq_marlin_gemm(), + device, + NUM_PRERUN, + NUM_ITERATIONS, ) check_error(LIBINFINIOP.infiniopDestroyGptqMarlinGemmDescriptor(descriptor)) From d6a778d18dab3a12638806364592ab2b4b720dc3 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 1 Apr 2026 14:12:32 +0800 Subject: [PATCH 03/10] issue/1083: modified global --- .../ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu index 
3e424ac4f..59271f78b 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu +++ b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu @@ -13,13 +13,13 @@ namespace device::marlin { -__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; +INFINIOP_CUDA_KERNEL MarlinDefault(MARLIN_KERNEL_PARAMS){}; using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -__global__ void permute_cols_kernel( +INFINIOP_CUDA_KERNEL permute_cols_kernel( int4 const *__restrict__ a_int4_ptr, int const *__restrict__ perm_int_ptr, int4 *__restrict__ out_int4_ptr, @@ -32,7 +32,7 @@ __global__ void permute_cols_kernel( // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. -__global__ void permute_cols_kernel( +INFINIOP_CUDA_KERNEL permute_cols_kernel( int4 const *__restrict__ a_int4_ptr, int const *__restrict__ perm_int_ptr, int4 *__restrict__ out_int4_ptr, From c3e1fba73d8ff9a9672848243a83141f38ba802d Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 20 May 2026 02:35:58 +0000 Subject: [PATCH 04/10] issue/1083: success gptq_marlin --- .../sgl_kernel/dlpack/dlpack.h | 639 ++++++ .../ops/gptq_marlin_gemm/sgl_kernel/tensor.h | 6 +- .../gptq_marlin_gemm/sgl_kernel/tvm/ffi/any.h | 682 +++++++ .../sgl_kernel/tvm/ffi/base_details.h | 317 +++ .../sgl_kernel/tvm/ffi/c_api.h | 1226 ++++++++++++ .../sgl_kernel/tvm/ffi/cast.h | 79 + .../sgl_kernel/tvm/ffi/container/array.h | 1164 +++++++++++ .../tvm/ffi/container/container_details.h | 360 ++++ .../sgl_kernel/tvm/ffi/container/map.h | 1781 +++++++++++++++++ .../sgl_kernel/tvm/ffi/container/shape.h | 343 ++++ .../sgl_kernel/tvm/ffi/container/tensor.h | 785 ++++++++ .../sgl_kernel/tvm/ffi/container/tuple.h | 395 ++++ .../sgl_kernel/tvm/ffi/container/variant.h | 311 +++ .../sgl_kernel/tvm/ffi/dtype.h | 199 ++ .../sgl_kernel/tvm/ffi/endian.h | 90 + .../sgl_kernel/tvm/ffi/error.h | 398 ++++ 
.../sgl_kernel/tvm/ffi/extra/base.h | 48 + .../sgl_kernel/tvm/ffi/extra/base64.h | 142 ++ .../sgl_kernel/tvm/ffi/extra/c_env_api.h | 158 ++ .../sgl_kernel/tvm/ffi/extra/cuda/base.h | 54 + .../tvm/ffi/extra/cuda/cubin_launcher.h | 604 ++++++ .../tvm/ffi/extra/cuda/device_guard.h | 74 + .../sgl_kernel/tvm/ffi/extra/json.h | 84 + .../sgl_kernel/tvm/ffi/extra/module.h | 301 +++ .../sgl_kernel/tvm/ffi/extra/serialization.h | 72 + .../tvm/ffi/extra/structural_equal.h | 78 + .../tvm/ffi/extra/structural_hash.h | 57 + .../sgl_kernel/tvm/ffi/function.h | 998 +++++++++ .../sgl_kernel/tvm/ffi/function_details.h | 272 +++ .../sgl_kernel/tvm/ffi/memory.h | 274 +++ .../sgl_kernel/tvm/ffi/object.h | 1207 +++++++++++ .../sgl_kernel/tvm/ffi/optional.h | 428 ++++ .../tvm/ffi/reflection/access_path.h | 444 ++++ .../sgl_kernel/tvm/ffi/reflection/accessor.h | 260 +++ .../sgl_kernel/tvm/ffi/reflection/creator.h | 120 ++ .../sgl_kernel/tvm/ffi/reflection/registry.h | 741 +++++++ .../sgl_kernel/tvm/ffi/rvalue_ref.h | 163 ++ .../sgl_kernel/tvm/ffi/string.h | 1102 ++++++++++ .../sgl_kernel/tvm/ffi/type_traits.h | 828 ++++++++ .../ops/gptq_marlin_gemm/sgl_kernel/utils.cuh | 4 +- .../ops/gptq_marlin_gemm/sgl_kernel/utils.h | 2 +- test/infiniop/libinfiniop/op_register.py | 46 + 42 files changed, 17330 insertions(+), 6 deletions(-) create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/any.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/base_details.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/c_api.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/array.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/container_details.h create mode 100644 
src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/shape.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tensor.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/dtype.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/error.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/c_env_api.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function_details.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/memory.h create mode 100644 
src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/object.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/optional.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/string.h create mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/type_traits.h diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h new file mode 100644 index 000000000..9a710ebde --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h @@ -0,0 +1,639 @@ +/*! + * Copyright (c) 2017 - by Contributors + * \file dlpack.h + * \brief The common header of DLPack. + */ +#ifndef DLPACK_DLPACK_H_ +#define DLPACK_DLPACK_H_ + +/** + * \brief Compatibility with C++ + */ +#ifdef __cplusplus +#define DLPACK_EXTERN_C extern "C" +#else +#define DLPACK_EXTERN_C +#endif + +/*! \brief The current major version of dlpack */ +#define DLPACK_MAJOR_VERSION 1 + +/*! \brief The current minor version of dlpack */ +#define DLPACK_MINOR_VERSION 2 + +/*! \brief DLPACK_DLL prefix for windows */ +#ifdef _WIN32 +#ifdef DLPACK_EXPORTS +#define DLPACK_DLL __declspec(dllexport) +#else +#define DLPACK_DLL __declspec(dllimport) +#endif +#else +#define DLPACK_DLL +#endif + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! + * \brief The DLPack version. + * + * A change in major version indicates that we have changed the + * data layout of the ABI - DLManagedTensorVersioned. 
+ * + * A change in minor version indicates that we have added new + * code, such as a new device type, but the ABI is kept the same. + * + * If an obtained DLPack tensor has a major version that disagrees + * with the version number specified in this header file + * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter + * (and it is safe to do so). It is not safe to access any other fields + * as the memory layout will have changed. + * + * In the case of a minor version mismatch, the tensor can be safely used as + * long as the consumer knows how to interpret all fields. Minor version + * updates indicate the addition of enumeration values. + */ +typedef struct { + /*! \brief DLPack major version. */ + uint32_t major; + /*! \brief DLPack minor version. */ + uint32_t minor; +} DLPackVersion; + +/*! + * \brief The device type in DLDevice. + */ +#ifdef __cplusplus +typedef enum : int32_t { +#else +typedef enum { +#endif + /*! \brief CPU device */ + kDLCPU = 1, + /*! \brief CUDA GPU device */ + kDLCUDA = 2, + /*! + * \brief Pinned CUDA CPU memory by cudaMallocHost + */ + kDLCUDAHost = 3, + /*! \brief OpenCL devices. */ + kDLOpenCL = 4, + /*! \brief Vulkan buffer for next generation graphics. */ + kDLVulkan = 7, + /*! \brief Metal for Apple GPU. */ + kDLMetal = 8, + /*! \brief Verilog simulator buffer */ + kDLVPI = 9, + /*! \brief ROCm GPUs for AMD GPUs */ + kDLROCM = 10, + /*! + * \brief Pinned ROCm CPU memory allocated by hipMallocHost + */ + kDLROCMHost = 11, + /*! + * \brief Reserved extension device type, + * used for quickly test extension device + * The semantics can differ depending on the implementation. + */ + kDLExtDev = 12, + /*! + * \brief CUDA managed/unified memory allocated by cudaMallocManaged + */ + kDLCUDAManaged = 13, + /*! + * \brief Unified shared memory allocated on a oneAPI non-partititioned + * device. 
Call to oneAPI runtime is required to determine the device + * type, the USM allocation type and the sycl context it is bound to. + * + */ + kDLOneAPI = 14, + /*! \brief GPU support for next generation WebGPU standard. */ + kDLWebGPU = 15, + /*! \brief Qualcomm Hexagon DSP */ + kDLHexagon = 16, + /*! \brief Microsoft MAIA devices */ + kDLMAIA = 17, + /*! \brief AWS Trainium */ + kDLTrn = 18, +} DLDeviceType; + +/*! + * \brief A Device for Tensor and operator. + */ +typedef struct { + /*! \brief The device type used in the device. */ + DLDeviceType device_type; + /*! + * \brief The device index. + * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0. + */ + int32_t device_id; +} DLDevice; + +/*! + * \brief The type code options DLDataType. + */ +typedef enum { + /*! \brief signed integer */ + kDLInt = 0U, + /*! \brief unsigned integer */ + kDLUInt = 1U, + /*! \brief IEEE floating point */ + kDLFloat = 2U, + /*! + * \brief Opaque handle type, reserved for testing purposes. + * Frameworks need to agree on the handle data type for the exchange to be well-defined. + */ + kDLOpaqueHandle = 3U, + /*! \brief bfloat16 */ + kDLBfloat = 4U, + /*! + * \brief complex number + * (C/C++/Python layout: compact struct per complex number) + */ + kDLComplex = 5U, + /*! \brief boolean */ + kDLBool = 6U, + /*! \brief FP8 data types */ + kDLFloat8_e3m4 = 7U, + kDLFloat8_e4m3 = 8U, + kDLFloat8_e4m3b11fnuz = 9U, + kDLFloat8_e4m3fn = 10U, + kDLFloat8_e4m3fnuz = 11U, + kDLFloat8_e5m2 = 12U, + kDLFloat8_e5m2fnuz = 13U, + kDLFloat8_e8m0fnu = 14U, + /*! \brief FP6 data types + * Setting bits != 6 is currently unspecified, and the producer must ensure it is set + * while the consumer must stop importing if the value is unexpected. + */ + kDLFloat6_e2m3fn = 15U, + kDLFloat6_e3m2fn = 16U, + /*! \brief FP4 data types + * Setting bits != 4 is currently unspecified, and the producer must ensure it is set + * while the consumer must stop importing if the value is unexpected. 
+ */ + kDLFloat4_e2m1fn = 17U, +} DLDataTypeCode; + +/*! + * \brief The data type the tensor can hold. The data type is assumed to follow the + * native endian-ness. An explicit error message should be raised when attempting to + * export an array with non-native endianness + * + * Examples + * - float: type_code = 2, bits = 32, lanes = 1 + * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4 + * - int8: type_code = 0, bits = 8, lanes = 1 + * - std::complex: type_code = 5, bits = 64, lanes = 1 + * - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits) + * - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory) + * - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory) + * - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory) + * + * When a sub-byte type is packed, DLPack requires the data to be in little bit-endian, i.e., + * for a packed data set D ((D >> (i * bits)) && bit_mask) stores the i-th element. + */ +typedef struct { + /*! + * \brief Type code of base types. + * We keep it uint8_t instead of DLDataTypeCode for minimal memory + * footprint, but the value should be one of DLDataTypeCode enum values. + * */ + uint8_t code; + /*! + * \brief Number of bits, common choices are 8, 16, 32. + */ + uint8_t bits; + /*! \brief Number of lanes in the type, used for vector types. */ + uint16_t lanes; +} DLDataType; + +/*! + * \brief Plain C Tensor object, does not manage memory. + */ +typedef struct { + /*! + * \brief The data pointer points to the allocated data. This will be CUDA + * device pointer or cl_mem handle in OpenCL. It may be opaque on some device + * types. This pointer is always aligned to 256 bytes as in CUDA. The + * `byte_offset` field should be used to point to the beginning of the data. 
+ * + * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow, + * TVM, perhaps others) do not adhere to this 256 byte alignment requirement + * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed + * (after which this note will be updated); at the moment it is recommended + * to not rely on the data pointer being correctly aligned. + * + * For given DLTensor, the size of memory required to store the contents of + * data is calculated as follows: + * + * \code{.c} + * static inline size_t GetDataSize(const DLTensor* t) { + * size_t size = 1; + * for (tvm_index_t i = 0; i < t->ndim; ++i) { + * size *= t->shape[i]; + * } + * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; + * return size; + * } + * \endcode + * + * Note that if the tensor is of size zero, then the data pointer should be + * set to `NULL`. + */ + void* data; + /*! \brief The device of the tensor */ + DLDevice device; + /*! \brief Number of dimensions */ + int32_t ndim; + /*! \brief The data type of the pointer*/ + DLDataType dtype; + /*! + * \brief The shape of the tensor + * + * When ndim == 0, shape can be set to NULL. + */ + int64_t* shape; + /*! + * \brief strides of the tensor (in number of elements, not bytes), + * can not be NULL if ndim != 0, must points to + * an array of ndim elements that specifies the strides, + * so consumer can always rely on strides[dim] being valid for 0 <= dim < ndim. + * + * When ndim == 0, strides can be set to NULL. + * + * \note Before DLPack v1.2, strides can be NULL to indicate contiguous data. + * This is not allowed in DLPack v1.2 and later. The rationale + * is to simplify the consumer handling. + */ + int64_t* strides; + /*! \brief The offset in bytes to the beginning pointer to data */ + uint64_t byte_offset; +} DLTensor; + +/*! + * \brief C Tensor object, manage memory of DLTensor. This data structure is + * intended to facilitate the borrowing of DLTensor by another framework. 
It is + * not meant to transfer the tensor. When the borrowing framework doesn't need + * the tensor, it should call the deleter to notify the host that the resource + * is no longer needed. + * + * \note This data structure is used as Legacy DLManagedTensor + * in DLPack exchange and is deprecated after DLPack v0.8 + * Use DLManagedTensorVersioned instead. + * This data structure may get renamed or deleted in future versions. + * + * \sa DLManagedTensorVersioned + */ +typedef struct DLManagedTensor { + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; + /*! \brief the context of the original host framework of DLManagedTensor in + * which DLManagedTensor is used in the framework. It can also be NULL. + */ + void * manager_ctx; + /*! + * \brief Destructor - this should be called + * to destruct the manager_ctx which backs the DLManagedTensor. It can be + * NULL if there is no way for the caller to provide a reasonable destructor. + * The destructor deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensor * self); +} DLManagedTensor; + +// bit masks used in the DLManagedTensorVersioned + +/*! \brief bit mask to indicate that the tensor is read only. */ +#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL) + +/*! + * \brief bit mask to indicate that the tensor is a copy made by the producer. + * + * If set, the tensor is considered solely owned throughout its lifetime by the + * consumer, until the producer-provided deleter is invoked. + */ +#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL) + +/*! + * \brief bit mask to indicate that whether a sub-byte type is packed or padded. + * + * The default for sub-byte types (ex: fp4/fp6) is assumed packed. This flag can + * be set by the producer to signal that a tensor of sub-byte type is padded. + */ +#define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL) + +/*! + * \brief A versioned and managed C Tensor object, manage memory of DLTensor. 
+ * + * This data structure is intended to facilitate the borrowing of DLTensor by + * another framework. It is not meant to transfer the tensor. When the borrowing + * framework doesn't need the tensor, it should call the deleter to notify the + * host that the resource is no longer needed. + * + * \note This is the current standard DLPack exchange data structure. + */ +typedef struct DLManagedTensorVersioned { + /*! + * \brief The API and ABI version of the current managed Tensor + */ + DLPackVersion version; + /*! + * \brief the context of the original host framework. + * + * Stores DLManagedTensorVersioned is used in the + * framework. It can also be NULL. + */ + void *manager_ctx; + /*! + * \brief Destructor. + * + * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned. + * It can be NULL if there is no way for the caller to provide a reasonable + * destructor. The destructor deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensorVersioned *self); + /*! + * \brief Additional bitmask flags information about the tensor. + * + * By default the flags should be set to 0. + * + * \note Future ABI changes should keep everything until this field + * stable, to ensure that deleter can be correctly called. + * + * \sa DLPACK_FLAG_BITMASK_READ_ONLY + * \sa DLPACK_FLAG_BITMASK_IS_COPIED + */ + uint64_t flags; + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; +} DLManagedTensorVersioned; + +//---------------------------------------------------------------------- +// DLPack `__c_dlpack_exchange_api__` fast exchange protocol definitions +//---------------------------------------------------------------------- +/*! + * \brief Request a producer library to create a new tensor. + * + * Create a new `DLManagedTensorVersioned` within the context of the producer + * library. The allocation is defined via the prototype DLTensor. 
+ * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param prototype The prototype DLTensor. Only the dtype, ndim, shape, + * and device fields are used. + * \param out The output DLManagedTensorVersioned. + * \param error_ctx Context for `SetError`. + * \param SetError The function to set the error. + * \return The owning DLManagedTensorVersioned* or NULL on failure. + * SetError is called exactly when NULL is returned (the implementor + * must ensure this). + * \note - As a C function, must not thrown C++ exceptions. + * - Error propagation via SetError to avoid any direct need + * of Python API. Due to this `SetError` may have to ensure the GIL is + * held since it will presumably set a Python error. + * + * \sa DLPackExchangeAPI + */ +typedef int (*DLPackManagedTensorAllocator)( // + DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, // + void (*SetError)(void* error_ctx, const char* kind, const char* message) // +); + +/*! + * \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned. + * + * This function does not perform any stream synchronization. The consumer should query + * DLPackCurrentWorkStream to get the current work stream and launch kernels on it. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param py_object The Python object to convert. Must have the same type + * as the one the `DLPackExchangeAPI` was discovered from. + * \return The owning DLManagedTensorVersioned* or NULL on failure with a + * Python exception set. If the data cannot be described using DLPack + * this should be a BufferError if possible. + * \note - As a C function, must not thrown C++ exceptions. + * + * \sa DLPackExchangeAPI, DLPackCurrentWorkStream + */ +typedef int (*DLPackManagedTensorFromPyObjectNoSync)( // + void* py_object, // + DLManagedTensorVersioned** out // +); + +/*! + * \brief Exports a PyObject* Tensor/NDArray to a provided DLTensor. 
+ * + * This function provides a faster interface for temporary, non-owning, exchange. + * The producer (implementor) still owns the memory of data, strides, shape. + * The liveness of the DLTensor and the data it views is only guaranteed until + * control is returned. + * + * This function currently assumes that the producer (implementor) can fill + * in the DLTensor shape and strides without the need for temporary allocations. + * + * This function does not perform any stream synchronization. The consumer should query + * DLPackCurrentWorkStream to get the current work stream and launch kernels on it. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param py_object The Python object to convert. Must have the same type + * as the one the `DLPackExchangeAPI` was discovered from. + * \param out The output DLTensor, whose space is pre-allocated on stack. + * \return 0 on success, -1 on failure with a Python exception set. + * \note - As a C function, must not thrown C++ exceptions. + * + * \sa DLPackExchangeAPI, DLPackCurrentWorkStream + */ +typedef int (*DLPackDLTensorFromPyObjectNoSync)( // + void* py_object, // + DLTensor* out // +); + +/*! + * \brief Obtain the current work stream of a device. + * + * Obtain the current work stream of a device from the producer framework. + * For example, it should map to torch.cuda.current_stream in PyTorch. + * + * When device_type is kDLCPU, the consumer do not have to query the stream + * and the producer can simply return NULL when queried. + * The consumer do not have to do anything on stream sync or setting. + * So CPU only framework can just provide a dummy implementation that + * always set out_current_stream[0] to NULL. + * + * \param device_type The device type. + * \param device_id The device id. + * \param out_current_stream The output current work stream. + * + * \return 0 on success, -1 on failure with a Python exception set. 
+ * \note - As a C function, must not throw C++ exceptions. + * + * \sa DLPackExchangeAPI + */ +typedef int (*DLPackCurrentWorkStream)( // + DLDeviceType device_type, // + int32_t device_id, // + void** out_current_stream // +); + +/*! + * \brief Imports a DLManagedTensorVersioned to a PyObject* Tensor/NDArray. + * + * Convert an owning DLManagedTensorVersioned* to the Python tensor of the + * producer (implementor) library with the correct type. + * + * This function does not perform any stream synchronization. + * + * This function is exposed by the framework through the DLPackExchangeAPI. + * + * \param tensor The DLManagedTensorVersioned to convert; the ownership of the + * tensor is stolen. + * \param out_py_object The output Python object. + * \return 0 on success, -1 on failure with a Python exception set. + * + * \sa DLPackExchangeAPI + */ +typedef int (*DLPackManagedTensorToPyObjectNoSync)( // + DLManagedTensorVersioned* tensor, // + void** out_py_object // +); + +/*! + * \brief DLPackExchangeAPI stable header. + * \sa DLPackExchangeAPI + */ +typedef struct DLPackExchangeAPIHeader { + /*! + * \brief The provided DLPack version; the consumer must check major version + * compatibility before using this struct. + */ + DLPackVersion version; + /*! + * \brief Optional pointer to an older DLPackExchangeAPI in the chain. + * + * It must be NULL if the framework does not support older versions. + * If the current major version is larger than the one supported by the + * consumer, the consumer may walk this to find an earlier supported version. + * + * \sa DLPackExchangeAPI + */ + struct DLPackExchangeAPIHeader* prev_api; +} DLPackExchangeAPIHeader; + +/*! + * \brief Framework-specific function pointers table for DLPack exchange. + * + * Additionally to `__dlpack__()` we define a C function table sharable by + * Python implementations via `__c_dlpack_exchange_api__`. 
+ * This attribute must be set on the type as a Python integer compatible + * with `PyLong_FromVoidPtr`/`PyLong_AsVoidPtr`. + * + * A consumer library may use a pattern such as: + * + * \code + * + * PyObject *api_obj = type(tensor_obj).__c_dlpack_exchange_api__; // as C-code + * MyDLPackExchangeAPI *api = PyLong_AsVoidPtr(api_obj); + * if (api == NULL && PyErr_Occurred()) { goto handle_error; } + * + * \endcode + * + * Note that this must be defined on the type. The consumer should look up the + * attribute on the type and may cache the result for each unique type. + * + * The precise API table is given by: + * \code + * struct MyDLPackExchangeAPI : public DLPackExchangeAPI { + * MyDLPackExchangeAPI() { + * header.version.major = DLPACK_MAJOR_VERSION; + * header.version.minor = DLPACK_MINOR_VERSION; + * header.prev_api = nullptr; + * + * managed_tensor_allocator = MyDLPackManagedTensorAllocator; + * managed_tensor_from_py_object_no_sync = MyDLPackManagedTensorFromPyObjectNoSync; + * managed_tensor_to_py_object_no_sync = MyDLPackManagedTensorToPyObjectNoSync; + * dltensor_from_py_object_no_sync = MyDLPackDLTensorFromPyObjectNoSync; + * current_work_stream = MyDLPackCurrentWorkStream; + * } + * + * static const DLPackExchangeAPI* Global() { + * static MyDLPackExchangeAPI inst; + * return &inst; + * } + * }; + * \endcode + * + * Guidelines for leveraging DLPackExchangeAPI: + * + * There are generally two kinds of consumer needs for DLPack exchange: + * - N0: library support, where consumer.kernel(x, y, z) would like to run a kernel + * with the data from x, y, z. The consumer is also expected to run the kernel with the same + * stream context as the producer. For example, when x, y, z are torch.Tensor, + * consumer should query exchange_api->current_work_stream to get the + * current stream and launch the kernel with the same stream. 
+ * This setup is necessary for no synchronization in kernel launch and maximum compatibility + * with CUDA graph capture in the producer. + * This is the desirable behavior for library extension support for frameworks like PyTorch. + * - N1: data ingestion and retention + * + * Note that obj.__dlpack__() API should provide useful ways for N1. + * The primary focus of the current DLPackExchangeAPI is to enable faster exchange N0 + * with the support of the function pointer current_work_stream. + * + * Array/Tensor libraries should statically create and initialize this structure + * then return a pointer to DLPackExchangeAPI as an int value in Tensor/Array. + * The DLPackExchangeAPI* must stay alive throughout the lifetime of the process. + * + * One simple way to do so is to create a static instance of DLPackExchangeAPI + * within the framework and return a pointer to it. The code example above + * shows how to do so in C++. It should also be reasonably easy + * to do so in other languages. + */ +typedef struct DLPackExchangeAPI { + /*! + * \brief The header that remains stable across versions. + */ + DLPackExchangeAPIHeader header; + /*! + * \brief Producer function pointer for DLPackManagedTensorAllocator + * This function must not be NULL. + * \sa DLPackManagedTensorAllocator + */ + DLPackManagedTensorAllocator managed_tensor_allocator; + /*! + * \brief Producer function pointer for DLPackManagedTensorFromPyObjectNoSync + * This function must not be NULL. + * \sa DLPackManagedTensorFromPyObjectNoSync + */ + DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackManagedTensorToPyObjectNoSync + * This function must not be NULL. + * \sa DLPackManagedTensorToPyObjectNoSync + */ + DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackDLTensorFromPyObjectNoSync + * This function can be NULL when the producer does not support this function. 
+ * \sa DLPackDLTensorFromPyObjectNoSync + */ + DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackCurrentWorkStream + * This function must be not NULL. + * \sa DLPackCurrentWorkStream + */ + DLPackCurrentWorkStream current_work_stream; +} DLPackExchangeAPI; + +#ifdef __cplusplus +} // DLPACK_EXTERN_C +#endif +#endif // DLPACK_DLPACK_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h index f30492621..308d6fac3 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h @@ -3,9 +3,9 @@ #pragma once #include "utils.h" -#include -#include -#include +#include "dlpack/dlpack.h" +#include "tvm/ffi/container/tensor.h" +#include "tvm/ffi/dtype.h" #include #include diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/any.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/any.h new file mode 100644 index 000000000..2c79b383b --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/any.h @@ -0,0 +1,682 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file tvm/ffi/any.h + * \brief Any value support. + */ +#ifndef TVM_FFI_ANY_H_ +#define TVM_FFI_ANY_H_ + +#include "c_api.h" +#include "string.h" +#include "type_traits.h" + +#include +#include + +namespace tvm { +namespace ffi { + +class Any; + +namespace details { +// Helper to perform +// unsafe operations related to object +struct AnyUnsafe; +} // namespace details + +/*! + * \brief AnyView allows us to take un-managed reference view of any value. + */ +class AnyView { +protected: + /*! \brief The underlying backing data of the any object */ + TVMFFIAny data_; + // Any can see AnyView + friend class Any; + +public: + // NOTE: the following functions use style + // since they are common functions appearing in FFI. + /*! + * \brief Reset any view to None + */ + void reset() { + data_.type_index = TypeIndex::kTVMFFINone; + // invariance: always set the union padding part to 0 + data_.zero_padding = 0; + data_.v_int64 = 0; + } + /*! + * \brief Swap this AnyView with another AnyView + * \param other The other AnyView + */ + TVM_FFI_INLINE void swap(AnyView &other) noexcept { std::swap(data_, other.data_); } + /*! \return the internal type index */ + TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; } + /*! \brief Default constructor */ + AnyView() { + data_.type_index = TypeIndex::kTVMFFINone; + data_.zero_padding = 0; + data_.v_int64 = 0; + } + ~AnyView() = default; + // constructors from any view + /*! \brief Copy constructor */ + AnyView(const AnyView &) = default; + /*! \brief Copy assignment operator */ + AnyView &operator=(const AnyView &) = default; + /*! \brief Move constructor */ + AnyView(AnyView &&other) noexcept : data_(other.data_) { + other.data_.type_index = TypeIndex::kTVMFFINone; + other.data_.zero_padding = 0; + other.data_.v_int64 = 0; + } + TVM_FFI_INLINE AnyView &operator=(AnyView &&other) noexcept { + // copy-and-swap idiom + AnyView(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + /*! 
+ * \brief Constructor from a general type. + * \tparam T The type to convert from. + * \param other The value to convert from. + */ + template ::convert_enabled>> + AnyView(const T &other) { // NOLINT(*) + TypeTraits::CopyToAnyView(other, &data_); + } + /*! + * \brief Assign from a general type. + * \tparam T The type to convert from. + * \param other The value to convert from. + */ + template ::convert_enabled>> + TVM_FFI_INLINE AnyView &operator=(const T &other) { // NOLINT(*) + // copy-and-swap idiom + AnyView(other).swap(*this); // NOLINT(*) + return *this; + } + + /*! + * \brief Try to see if we can reinterpret the AnyView to as T object. + * + * \tparam T The type to cast to. + * \return The casted value, or std::nullopt if the cast is not possible. + * \note This function won't try run type conversion (use try_cast for that purpose). + */ + template ::convert_enabled>> + TVM_FFI_INLINE std::optional as() const { + if (TypeTraits::CheckAnyStrict(&data_)) { + return TypeTraits::CopyFromAnyViewAfterCheck(&data_); + } else { + return std::optional(std::nullopt); + } + } + /*! + * \brief Shortcut of as Object to cast to a const pointer when T is an Object. + * + * \tparam T The object type. + * \return The requested pointer, returns nullptr if type mismatches. + */ + template >> + TVM_FFI_INLINE const T *as() const { + return this->as().value_or(nullptr); + } + + /*! + * \brief Cast to a type T. + * + * \tparam T The type to cast to. + * \return The casted value, or throws an exception if the cast is not possible. + */ + template ::convert_enabled>> + TVM_FFI_INLINE T cast() const { + std::optional opt = TypeTraits::TryCastFromAnyView(&data_); + if (!opt.has_value()) { + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" + << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" + << TypeTraits::TypeStr() << "`"; + } + return *std::move(opt); + } + + /*! + * \brief Try to cast to a type T, return std::nullopt if the cast is not possible. 
+ * + * \tparam T The type to cast to. + * \return The casted value, or std::nullopt if the cast is not possible. + */ + template ::convert_enabled>> + TVM_FFI_INLINE std::optional try_cast() const { + return TypeTraits::TryCastFromAnyView(&data_); + } + + // comparison with nullptr + TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { + return data_.type_index == TypeIndex::kTVMFFINone; + } + TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { + return data_.type_index != TypeIndex::kTVMFFINone; + } + /*! + * \brief Get the type key of the Any + * \return The type key of the Any + */ + TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); } + // The following functions are only used for testing purposes + /*! + * \return The underlying supporting data of any view + * \note This function is used only for testing purposes. + */ + TVM_FFI_INLINE TVMFFIAny CopyToTVMFFIAny() const { return data_; } + /*! + * \return Create an AnyView from TVMFFIAny + * \param data the underlying ffi data. + */ + TVM_FFI_INLINE static AnyView CopyFromTVMFFIAny(TVMFFIAny data) { + AnyView view; + view.data_ = data; + return view; + } +}; + +namespace details { +/*! + * \brief Helper function to inplace convert any view to any. + * \param data The pointer that represents the format as any view. + * \param extra_any_bytes Indicate that the data may contain extra bytes following + * the TVMFFIAny data structure. This is reserved for future possible optimizations + * of small-string and extended any object. 
+ */ +TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny *data, + [[maybe_unused]] size_t extra_any_bytes = 0) { + if (data->type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectUnsafe::IncRefObjectHandle(data->v_obj); + } else if (data->type_index >= TypeIndex::kTVMFFIRawStr) { + if (data->type_index == TypeIndex::kTVMFFIRawStr) { + // convert raw string to owned string object + String temp(data->v_c_str); + TypeTraits::MoveToAny(std::move(temp), data); + } else if (data->type_index == TypeIndex::kTVMFFIByteArrayPtr) { + // convert byte array to owned bytes object + Bytes temp(*static_cast(data->v_ptr)); + TypeTraits::MoveToAny(std::move(temp), data); + } else if (data->type_index == TypeIndex::kTVMFFIObjectRValueRef) { + // convert rvalue ref to owned object + Object **obj_addr = static_cast(data->v_ptr); + TVM_FFI_ICHECK(obj_addr[0] != nullptr) << "RValueRef already moved"; + ObjectRef temp(details::ObjectUnsafe::ObjectPtrFromOwned(obj_addr[0])); + // set the rvalue ref to nullptr to avoid double move + obj_addr[0] = nullptr; + TypeTraits::MoveToAny(std::move(temp), data); + } + } +} +} // namespace details + +/*! + * \brief Managed Any that takes a strong reference to a value. + * + * \note Developer invariant: the TVMFFIAny data_ + * in the Any can be safely used in AnyView. + */ +class Any { +protected: + /*! \brief The underlying backing data of the any object */ + TVMFFIAny data_; + +public: + /*! + * \brief Reset any to None + */ + TVM_FFI_INLINE void reset() { + if (data_.type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); + } + data_.type_index = TVMFFITypeIndex::kTVMFFINone; + data_.zero_padding = 0; + data_.v_int64 = 0; + } + /*! + * \brief Swap this Any with another Any + * \param other The other Any + */ + TVM_FFI_INLINE void swap(Any &other) noexcept { std::swap(data_, other.data_); } + /*! 
\return the internal type index */ + TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; } + /*! + * \brief Default constructor + */ + Any() { + data_.type_index = TypeIndex::kTVMFFINone; + data_.zero_padding = 0; + data_.v_int64 = 0; + } + /*! + * \brief Destructor + */ + ~Any() { this->reset(); } + /*! + * \brief Constructor from another Any + * \param other The other Any + */ + Any(const Any &other) : data_(other.data_) { + if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj); + } + } + /*! + * \brief Move constructor from another Any + * \param other The other Any + */ + Any(Any &&other) noexcept : data_(other.data_) { + other.data_.type_index = TypeIndex::kTVMFFINone; + other.data_.zero_padding = 0; + other.data_.v_int64 = 0; + } + /*! + * \brief Assign from another Any + * \param other The other Any + */ + TVM_FFI_INLINE Any &operator=(const Any &other) { + // copy-and-swap idiom + Any(other).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief Move assign from another Any + * \param other The other Any + */ + TVM_FFI_INLINE Any &operator=(Any &&other) noexcept { + // copy-and-swap idiom + Any(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief Constructor from another AnyView + * \param other The other AnyView + */ + Any(const AnyView &other) : data_(other.data_) { // NOLINT(*) + details::InplaceConvertAnyViewToAny(&data_); + } + /*! + * \brief Assign from another AnyView + * \param other The other AnyView + */ + TVM_FFI_INLINE Any &operator=(const AnyView &other) { + // copy-and-swap idiom + Any(other).swap(*this); // NOLINT(*) + return *this; + } + /*! \brief Any can be converted to AnyView in zero cost. */ + operator AnyView() const { // NOLINT(google-explicit-constructor) + return AnyView::CopyFromTVMFFIAny(data_); + } + /*! 
+ * \brief Constructor from a general type + * \tparam T The value type of the other + */ + template ::convert_enabled>> + Any(T other) { // NOLINT(*) + TypeTraits::MoveToAny(std::move(other), &data_); + } + /*! + * \brief Assignment from a general type + * \tparam T The value type of the other + */ + template ::convert_enabled>> + TVM_FFI_INLINE Any &operator=(T other) { // NOLINT(*) + // copy-and-swap idiom + Any(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + + /** + * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible. + * + * \tparam T The type to cast to. + * \return The casted value, or std::nullopt if the cast is not possible. + * \note This function won't try to run type conversion (use try_cast for that purpose). + */ + template ::storage_enabled || std::is_same_v>> + TVM_FFI_INLINE std::optional as() && { + if constexpr (std::is_same_v) { + return std::move(*this); + } else { + if (TypeTraits::CheckAnyStrict(&data_)) { + return TypeTraits::MoveFromAnyAfterCheck(&data_); + } else { + return std::optional(std::nullopt); + } + } + } + + /** + * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible. + * + * \tparam T The type to cast to. + * \return The casted value, or std::nullopt if the cast is not possible. + * \note This function won't try to run type conversion (use try_cast for that purpose). + */ + template ::convert_enabled || std::is_same_v>> + TVM_FFI_INLINE std::optional as() const & { + if constexpr (std::is_same_v) { + return *this; + } else { + if (TypeTraits::CheckAnyStrict(&data_)) { + return TypeTraits::CopyFromAnyViewAfterCheck(&data_); + } else { + return std::optional(std::nullopt); + } + } + } + + /*! + * \brief Shortcut of as Object to cast to a const pointer when T is an Object. + * + * \tparam T The object type. + * \return The requested pointer, returns nullptr if type mismatches. 
+ */ + template >> + TVM_FFI_INLINE const T *as() const & { + return this->as().value_or(nullptr); + } + + /** + * \brief Cast to a type T, throw an exception if the cast is not possible. + * + * \tparam T The type to cast to. + */ + template ::convert_enabled>> + TVM_FFI_INLINE T cast() const & { + std::optional opt = TypeTraits::TryCastFromAnyView(&data_); + if (!opt.has_value()) { + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" + << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" + << TypeTraits::TypeStr() << "`"; + } + return *std::move(opt); + } + + /** + * \brief Cast to a type T, throw an exception if the cast is not possible. + * + * \tparam T The type to cast to. + */ + template ::storage_enabled>> + TVM_FFI_INLINE T cast() && { + if (TypeTraits::CheckAnyStrict(&data_)) { + return TypeTraits::MoveFromAnyAfterCheck(&data_); + } + // slow path, try to do fallback convert + std::optional opt = TypeTraits::TryCastFromAnyView(&data_); + if (!opt.has_value()) { + TVM_FFI_THROW(TypeError) << "Cannot convert from type `" + << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" + << TypeTraits::TypeStr() << "`"; + } + return *std::move(opt); + } + + /** + * \brief Try to cast to a type T. + * + * \tparam T The type to cast to. + * \return The casted value, or std::nullopt if the cast is not possible. + * \note use STL name since it to be more consistent with cast API. + */ + template ::convert_enabled || std::is_same_v>> + TVM_FFI_INLINE std::optional try_cast() const { + if constexpr (std::is_same_v) { + return *this; + } else { + return TypeTraits::TryCastFromAnyView(&data_); + } + } + /*! + * \brief Check if the two Any are same type and value in shallow comparison. + * \param other The other Any + * \return True if the two Any are same type and value, false otherwise. 
+ */ + TVM_FFI_INLINE bool same_as(const Any &other) const noexcept { + return data_.type_index == other.data_.type_index && data_.zero_padding == other.data_.zero_padding && data_.v_int64 == other.data_.v_int64; + } + + /*! + * \brief Check if any and ObjectRef are same type and value in shallow comparison. + * \param other The other ObjectRef + * \return True if the two Any are same type and value, false otherwise. + */ + TVM_FFI_INLINE bool same_as(const ObjectRef &other) const noexcept { + if (other.get() != nullptr) { + return (data_.type_index == other->type_index() && reinterpret_cast(data_.v_obj) == other.get()); + } else { + return data_.type_index == TypeIndex::kTVMFFINone; + } + } + + TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { + return data_.type_index == TypeIndex::kTVMFFINone; + } + TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { + return data_.type_index != TypeIndex::kTVMFFINone; + } + + /*! + * \brief Get the type key of the Any + * \return The type key of the Any + */ + TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); } + + friend struct details::AnyUnsafe; + friend struct AnyHash; + friend struct AnyEqual; +}; + +// layout assert to ensure we can freely cast between the two types +static_assert(sizeof(AnyView) == sizeof(TVMFFIAny)); +static_assert(sizeof(Any) == sizeof(TVMFFIAny)); + +namespace details { + +template +struct Type2Str { + static std::string v() { return TypeTraitsNoCR::TypeStr(); } +}; + +template <> +struct Type2Str { + static std::string v() { return "Any"; } +}; + +template <> +struct Type2Str { + static std::string v() { return "Any"; } +}; + +template <> +struct Type2Str { + static std::string v() { return "AnyView"; } +}; + +template <> +struct Type2Str { + static std::string v() { return "AnyView"; } +}; + +template <> +struct Type2Str { + static std::string v() { return "void"; } +}; + +// Extra unsafe method to help any manipulation +struct 
AnyUnsafe : public ObjectUnsafe { + // FFI related operations + TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny(Any &&any) { + TVMFFIAny result = any.data_; + any.data_.type_index = TypeIndex::kTVMFFINone; + any.data_.zero_padding = 0; + any.data_.v_int64 = 0; + return result; + } + + TVM_FFI_INLINE static Any MoveTVMFFIAnyToAny(TVMFFIAny *data) { + Any any; + any.data_ = *data; + data->type_index = TypeIndex::kTVMFFINone; + data->zero_padding = 0; + data->v_int64 = 0; + return any; + } + + template + TVM_FFI_INLINE static bool CheckAnyStrict(const Any &ref) { + return TypeTraits::CheckAnyStrict(&(ref.data_)); + } + + template + TVM_FFI_INLINE static T CopyFromAnyViewAfterCheck(const Any &ref) { + if constexpr (!std::is_same_v) { + return TypeTraits::CopyFromAnyViewAfterCheck(&(ref.data_)); + } else { + return ref; + } + } + + template + TVM_FFI_INLINE static T MoveFromAnyAfterCheck(Any &&ref) { + if constexpr (!std::is_same_v) { + return TypeTraits::MoveFromAnyAfterCheck(&(ref.data_)); + } else { + return std::move(ref); + } + } + + TVM_FFI_INLINE static Object *ObjectPtrFromAnyAfterCheck(const Any &ref) { + return reinterpret_cast(ref.data_.v_obj); + } + + TVM_FFI_INLINE static const TVMFFIAny *TVMFFIAnyPtrFromAny(const Any &ref) { + return &(ref.data_); + } + + template + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const Any &ref) { + return TypeTraits::GetMismatchTypeInfo(&(ref.data_)); + } +}; +} // namespace details + +/*! \brief String-aware Any hash functor */ +struct AnyHash { + /*! + * \brief Calculate the hash code of an Any + * \param src The given Any + * \return Hash code of src, string hash for strings and pointer address otherwise. 
+ */ + uint64_t operator()(const Any &src) const { + if (src.data_.type_index == TypeIndex::kTVMFFISmallStr) { + // for small string, we use the same type key hash as normal string + // so heap allocated string and on stack string will have the same hash + return details::StableHashCombine(TypeIndex::kTVMFFIStr, + details::StableHashSmallStrBytes(&src.data_)); + } else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) { + // for small bytes, use the same type key as bytes + return details::StableHashCombine(TypeIndex::kTVMFFIBytes, + details::StableHashSmallStrBytes(&src.data_)); + } else if (src.data_.type_index == TypeIndex::kTVMFFIStr || src.data_.type_index == TypeIndex::kTVMFFIBytes) { + const details::BytesObjBase *src_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(src); + return details::StableHashCombine(src.data_.type_index, + details::StableHashBytes(src_str->data, src_str->size)); + } else { + return details::StableHashCombine(src.data_.type_index, src.data_.v_uint64); + } + } +}; + +/*! \brief String-aware Any equal functor */ +struct AnyEqual { + /*! + * \brief Check if the two Any are equal + * \param lhs left operand. + * \param rhs right operand + * \return String equality if both are strings, pointer address equality otherwise. 
+ */ + bool operator()(const Any &lhs, const Any &rhs) const { + // header with type index + const int64_t *lhs_as_int64 = reinterpret_cast(&lhs.data_); + const int64_t *rhs_as_int64 = reinterpret_cast(&rhs.data_); + static_assert(sizeof(TVMFFIAny) == 16); + // fast path, check byte equality + if (lhs_as_int64[0] == rhs_as_int64[0] && lhs_as_int64[1] == rhs_as_int64[1]) { + return true; + } + // common false case type index match, in this case we only need to pay attention to string + // equality + if (lhs.data_.type_index == rhs.data_.type_index) { + // specialy handle string hash + if (lhs.data_.type_index == TypeIndex::kTVMFFIStr || lhs.data_.type_index == TypeIndex::kTVMFFIBytes) { + const details::BytesObjBase *lhs_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); + const details::BytesObjBase *rhs_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); + return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size); + } + return false; + } else { + // type_index mismatch, if index is not string, return false + if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr && lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index != kTVMFFIBytes) { + return false; + } + // small string and normal string comparison + if (lhs.data_.type_index == kTVMFFIStr && rhs.data_.type_index == kTVMFFISmallStr) { + const details::BytesObjBase *lhs_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); + return Bytes::memequal(lhs_str->data, rhs.data_.v_bytes, lhs_str->size, + rhs.data_.small_str_len); + } + if (lhs.data_.type_index == kTVMFFISmallStr && rhs.data_.type_index == kTVMFFIStr) { + const details::BytesObjBase *rhs_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); + return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data, lhs.data_.small_str_len, + rhs_str->size); + } + if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index == kTVMFFISmallBytes) { + const details::BytesObjBase 
*lhs_bytes = details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); + return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes, lhs_bytes->size, + rhs.data_.small_str_len); + } + if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index == kTVMFFIBytes) { + const details::BytesObjBase *rhs_bytes = details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); + return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data, lhs.data_.small_str_len, + rhs_bytes->size); + } + return false; + } + } +}; +} // namespace ffi + +// Expose to the tvm namespace for usability +// Rationale: no ambiguity even in root +using tvm::ffi::Any; +using tvm::ffi::AnyView; + +} // namespace tvm +#endif // TVM_FFI_ANY_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/base_details.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/base_details.h new file mode 100644 index 000000000..147862117 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/base_details.h @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/base_details.h + * \brief Internal detail utils that can be used by files in tvm/ffi. 
+ * \note details headers are for internal use only + * and not to be directly used by user. + */ +#ifndef TVM_FFI_BASE_DETAILS_H_ +#define TVM_FFI_BASE_DETAILS_H_ + +#include "c_api.h" +#include "endian.h" + +#include +#include + +#if defined(_MSC_VER) +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif + +#ifndef NOMINMAX +#define NOMINMAX +#endif + +#include +#if (defined(_M_ARM64) || defined(_ARM64_) || defined(_M_ARM64EC)) && !defined(_InlineInterlockedAdd64) +#define _InlineInterlockedAdd64 InterlockedAdd64 +#endif + +#ifdef ERROR +#undef ERROR +#endif + +#endif +/// \cond Doxygen_Suppress + +#if defined(_MSC_VER) +#define TVM_FFI_INLINE [[msvc::forceinline]] inline +#else +#define TVM_FFI_INLINE [[gnu::always_inline]] inline +#endif + +/*! + * \brief Macro helper to force a function not to be inlined. + * It is only used in places that we know not inlining is good, + * e.g. some logging functions. + */ +#if defined(_MSC_VER) +#define TVM_FFI_NO_INLINE [[msvc::noinline]] +#else +#define TVM_FFI_NO_INLINE [[gnu::noinline]] +#endif + +#if defined(_MSC_VER) +#define TVM_FFI_UNREACHABLE() __assume(false) +#else +#define TVM_FFI_UNREACHABLE() __builtin_unreachable() +#endif + +#define TVM_FFI_STR_CONCAT_(__x, __y) __x##__y +#define TVM_FFI_STR_CONCAT(__x, __y) TVM_FFI_STR_CONCAT_(__x, __y) + +#if defined(__GNUC__) || defined(__clang__) +#define TVM_FFI_FUNC_SIG __PRETTY_FUNCTION__ +#elif defined(_MSC_VER) +#define TVM_FFI_FUNC_SIG __FUNCSIG__ +#else +#define TVM_FFI_FUNC_SIG __func__ +#endif + +#if defined(__GNUC__) +// gcc and clang and attribute constructor +/// \cond Doxygen_Suppress +#define TVM_FFI_STATIC_INIT_BLOCK_DEF_(FnName) __attribute__((constructor)) static void FnName() +/// \endcond +/* + * \brief Macro that defines a block that will be called during static initialization. 
+ * + * \code + * TVM_FFI_STATIC_INIT_BLOCK() { + * RegisterFunctions(); + * } + * \endcode + */ +#define TVM_FFI_STATIC_INIT_BLOCK() \ + TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__)) + +#else +/// \cond Doxygen_Suppress +// for other compilers, use the variable trick +#define TVM_FFI_STATIC_INIT_BLOCK_DEF_(FnName, RegVar) \ + static void FnName(); \ + [[maybe_unused]] static inline int RegVar = []() { \ + FnName(); \ + return 0; \ + }(); \ + static void FnName() + +#define TVM_FFI_STATIC_INIT_BLOCK() \ + TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__), \ + TVM_FFI_STR_CONCAT(__TVMFFIStaticInitReg, __COUNTER__)) +/// \endcond +#endif + +/* + * \brief Define the default copy/move constructor and assign operator + * \param TypeName The class typename. + */ +#define TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + TypeName(const TypeName &other) = default; /* NOLINT(bugprone-macro-parentheses) */ \ + TypeName(TypeName &&other) noexcept = default; /* NOLINT(bugprone-macro-parentheses) */ \ + TypeName &operator=(const TypeName &other) = default; /* NOLINT(bugprone-macro-parentheses) */ \ + TypeName &operator=(TypeName &&other) noexcept = default; /* NOLINT(bugprone-macro-parentheses)*/ + +/*! + * \brief Marks the beginning of a C call that logs exceptions + */ +#define TVM_FFI_LOG_EXCEPTION_CALL_BEGIN() \ + try { \ + (void)0 + +/*! + * \brief Marks the end of a C call that logs exceptions + */ +#define TVM_FFI_LOG_EXCEPTION_CALL_END(Name) \ + } \ + catch (const std::exception &err) { \ + std::cerr << "Exception caught during " << #Name << ":\n" \ + << err.what() << std::endl; \ + exit(-1); \ + } + +/*! + * \brief Clear the padding parts so we can safely use v_int64 for hash + * and equality check even when the value stored is a pointer. + * + * This macro is used to clear the padding parts for hash and equality check + * on 32-bit platforms. 
+ */ +#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \ + if constexpr (sizeof(void *) != sizeof(int64_t)) { \ + (result)->v_int64 = 0; \ + } + +namespace tvm { +namespace ffi { +namespace details { + +// for each iterator +struct for_each_dispatcher { + template + static void run(std::index_sequence, const F &f, Args &&...args) { // NOLINT(*) + (f(I, std::forward(args)), ...); + } +}; + +template +void for_each(const F &f, Args &&...args) { // NOLINT(*) + for_each_dispatcher::run(std::index_sequence_for{}, f, std::forward(args)...); +} + +/*! + * \brief hash an object and combines uint64_t key with previous keys + * + * This hash function is stable across platforms. + * + * \param key The left operand. + * \param value The right operand. + * \return the combined result. + */ +template , bool> = true> +TVM_FFI_INLINE uint64_t StableHashCombine(uint64_t key, const T &value) { + // XXX: do not use std::hash in this function. This hash must be stable + // across different platforms and std::hash is implementation dependent. + return key ^ (uint64_t(value) + 0x9e3779b9 + (key << 6) + (key >> 2)); +} + +/*! + * \brief Hash the binary bytes + * \param data The data pointer + * \param size The size of the bytes. + * \return the hash value. 
+ */ +TVM_FFI_INLINE uint64_t StableHashBytes(const void *data_ptr, size_t size) { + // NOLINTBEGIN(clang-analyzer-security.ArrayBound) + const char *data = reinterpret_cast(data_ptr); + const constexpr uint64_t kMultiplier = 1099511628211ULL; + const constexpr uint64_t kMod = 2147483647ULL; + union Union { + uint8_t a[8]; + uint64_t b; + } u; + static_assert(sizeof(Union) == sizeof(uint64_t), "sizeof(Union) != sizeof(uint64_t)"); + const char *it = data; + const char *end = it + size; + uint64_t result = 0; + if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { + // if alignment requirement is met, directly use load + if (reinterpret_cast(it) % 8 == 0) { + for (; it + 8 <= end; it += 8) { + u.b = *reinterpret_cast(it); + result = (result * kMultiplier + u.b) % kMod; + } + } else { + // unaligned version + for (; it + 8 <= end; it += 8) { + u.a[0] = it[0]; + u.a[1] = it[1]; + u.a[2] = it[2]; + u.a[3] = it[3]; + u.a[4] = it[4]; + u.a[5] = it[5]; + u.a[6] = it[6]; + u.a[7] = it[7]; + result = (result * kMultiplier + u.b) % kMod; + } + } + } else { + // need endian swap + for (; it + 8 <= end; it += 8) { + u.a[0] = it[7]; + u.a[1] = it[6]; + u.a[2] = it[5]; + u.a[3] = it[4]; + u.a[4] = it[3]; + u.a[5] = it[2]; + u.a[6] = it[1]; + u.a[7] = it[0]; + result = (result * kMultiplier + u.b) % kMod; + } + } + + if (it < end) { + u.b = 0; + uint8_t *a = u.a; + if (it + 4 <= end) { + a[0] = it[0]; + a[1] = it[1]; + a[2] = it[2]; + a[3] = it[3]; + it += 4; + a += 4; + } + if (it + 2 <= end) { + a[0] = it[0]; + a[1] = it[1]; + it += 2; + a += 2; + } + if (it + 1 <= end) { + a[0] = it[0]; + } + if constexpr (!TVM_FFI_IO_NO_ENDIAN_SWAP) { + std::swap(u.a[0], u.a[7]); + std::swap(u.a[1], u.a[6]); + std::swap(u.a[2], u.a[5]); + std::swap(u.a[3], u.a[4]); + } + result = (result * kMultiplier + u.b) % kMod; + } + // NOLINTEND(clang-analyzer-security.ArrayBound) + return result; +} + +/*! + * \brief Same as StableHashBytes, but for small string data. 
+ * \param data The data pointer + * \return the hash value. + */ +TVM_FFI_INLINE uint64_t StableHashSmallStrBytes(const TVMFFIAny *data) { + if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { + // fast path, no endian swap, simply hash as uint64_t + const constexpr uint64_t kMod = 2147483647ULL; + return data->v_uint64 % kMod; + } + return StableHashBytes(reinterpret_cast(data), sizeof(data->v_uint64)); +} + +/*! + * \brief Helper to generate a JSON-based type schema for a given type. + * \tparam T The type to generate the schema for. Assuming `T` is not + * const-qualified or reference-qualified. + */ +template +struct TypeSchemaImpl; +/*! + * \brief Helper to generate a JSON-based type schema for a given type. + * \tparam T The type to generate the schema for. + * \note This type removes const and reference qualifiers from `T` before + * passing it to `TypeSchemaImpl`. + */ +template +using TypeSchema = TypeSchemaImpl>>; + +} // namespace details +} // namespace ffi +} // namespace tvm +/// \endcond +#endif // TVM_FFI_BASE_DETAILS_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/c_api.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/c_api.h new file mode 100644 index 000000000..4b721f66d --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/c_api.h @@ -0,0 +1,1226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +// NOLINTBEGIN(modernize-use-using,bugprone-reserved-identifier,modernize-deprecated-headers) +/* + * \file tvm/ffi/c_api.h + * \brief This file defines the C convention of the FFI convention + */ +#ifndef TVM_FFI_C_API_H_ +#define TVM_FFI_C_API_H_ + +#include "../../dlpack/dlpack.h" +#include + +// Macros to do weak linking +#ifdef _MSC_VER +#define TVM_FFI_WEAK __declspec(selectany) +#else +#define TVM_FFI_WEAK __attribute__((weak)) +#endif + +// Defines two macros +// TVM_FFI_DLL: marks the function as a DLL export/import +// depending on whether TVM_FFI_EXPORTS is defined +// TVM_FFI_DLL_EXPORT: always marks the function as a DLL export +#if !defined(TVM_FFI_DLL) && defined(__EMSCRIPTEN__) +#include +#define TVM_FFI_DLL EMSCRIPTEN_KEEPALIVE +#define TVM_FFI_DLL_EXPORT EMSCRIPTEN_KEEPALIVE +#endif +#if !defined(TVM_FFI_DLL) && defined(_MSC_VER) +#ifdef TVM_FFI_EXPORTS +#define TVM_FFI_DLL __declspec(dllexport) +#else +#define TVM_FFI_DLL __declspec(dllimport) +#endif +#define TVM_FFI_DLL_EXPORT __declspec(dllexport) +#endif +#ifndef TVM_FFI_DLL +#define TVM_FFI_DLL __attribute__((visibility("default"))) +#define TVM_FFI_DLL_EXPORT __attribute__((visibility("default"))) +#endif + +// NOLINTBEGIN(modernize-macro-to-enum) +/*! \brief TVM FFI major version. */ +#define TVM_FFI_VERSION_MAJOR 0 +/*! \brief TVM FFI minor version. */ +#define TVM_FFI_VERSION_MINOR 1 +/*! \brief TVM FFI patch version. */ +#define TVM_FFI_VERSION_PATCH 4 +// NOLINTEND(modernize-macro-to-enum) + +#ifdef __cplusplus +extern "C" { +#endif + +/*! 
+ * \brief TVM FFI version. + */ +typedef struct { + /*! \brief TVM FFI major version. */ + uint32_t major; + /*! \brief TVM FFI minor version. */ + uint32_t minor; + /*! \brief TVM FFI patch version. */ + uint32_t patch; +} TVMFFIVersion; + +#ifdef __cplusplus +enum TVMFFITypeIndex : int32_t { +#else +typedef enum { +#endif + /* + * \brief The root type of all FFI objects. + * + * We include it so TypeIndex captures all possible runtime values. + * `kTVMFFIAny` code will never appear in Any::type_index. + * However, it may appear in field annotations during reflection. + */ + kTVMFFIAny = -1, + // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin) + // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array, + // which is not owned by TVMFFIAny. It is required that the following + // invariant holds: + // - `Any::type_index` is never `kTVMFFIRawStr` + // - `AnyView::type_index` can be `kTVMFFIRawStr` + // + /*! \brief None/nullptr value */ + kTVMFFINone = 0, + /*! \brief POD int value */ + kTVMFFIInt = 1, + /*! \brief POD bool value */ + kTVMFFIBool = 2, + /*! \brief POD float value */ + kTVMFFIFloat = 3, + /*! \brief Opaque pointer object */ + kTVMFFIOpaquePtr = 4, + /*! \brief DLDataType */ + kTVMFFIDataType = 5, + /*! \brief DLDevice */ + kTVMFFIDevice = 6, + /*! \brief DLTensor* */ + kTVMFFIDLTensorPtr = 7, + /*! \brief const char* */ + kTVMFFIRawStr = 8, + /*! \brief TVMFFIByteArray* */ + kTVMFFIByteArrayPtr = 9, + /*! \brief R-value reference to ObjectRef */ + kTVMFFIObjectRValueRef = 10, + /*! \brief Small string on stack */ + kTVMFFISmallStr = 11, + /*! \brief Small bytes on stack */ + kTVMFFISmallBytes = 12, + /*! \brief Start of statically defined objects. */ + kTVMFFIStaticObjectBegin = 64, + /*! + * \brief Object, all objects starts with TVMFFIObject as its header. + * \note We will also add other fields + */ + kTVMFFIObject = 64, + /*! + * \brief String object, layout = { TVMFFIObject, TVMFFIByteArray, ... 
} + */ + kTVMFFIStr = 65, + /*! + * \brief Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... } + */ + kTVMFFIBytes = 66, + /*! \brief Error object. */ + kTVMFFIError = 67, + /*! \brief Function object. */ + kTVMFFIFunction = 68, + /*! + * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... } + */ + kTVMFFIShape = 69, + /*! + * \brief Tensor object, layout = { TVMFFIObject, DLTensor, ... } + */ + kTVMFFITensor = 70, + /*! \brief Array object. */ + kTVMFFIArray = 71, + /*! \brief Map object. */ + kTVMFFIMap = 72, + /*! \brief Runtime dynamic loaded module object. */ + kTVMFFIModule = 73, + /*! + * \brief Opaque python object. + * + * This is a special type index to indicate we are storing an opaque PyObject. + * Such object may interact with callback functions that are registered to support + * python-related operations. + * + * We only translate the objects that we do not recognize into this type index. + * + * \sa TVMFFIObjectCreateOpaque + */ + kTVMFFIOpaquePyObject = 74, + //---------------------------------------------------------------- + // more complex objects + //---------------------------------------------------------------- + kTVMFFIStaticObjectEnd, + // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) + /*! \brief Start of type indices that are allocated at runtime. */ + kTVMFFIDynObjectBegin = 128 +#ifdef __cplusplus +}; +#else +} TVMFFITypeIndex; +#endif + +/*! \brief Handle to Object from C API's pov */ +typedef void *TVMFFIObjectHandle; + +/*! + * \brief bitmask of the object deleter flag. + */ +#ifdef __cplusplus +enum TVMFFIObjectDeleterFlagBitMask : int32_t { +#else +typedef enum { +#endif + /*! + * \brief deleter action when strong reference count becomes zero. + * Need to call destructor of the object but not free the memory block. + */ + kTVMFFIObjectDeleterFlagBitMaskStrong = 1 << 0, + /*! + * \brief deleter action when weak reference count becomes zero. + * Need to free the memory block. 
+ */ + kTVMFFIObjectDeleterFlagBitMaskWeak = 1 << 1, + /*! + * \brief deleter action when both strong and weak reference counts become zero. + * \note This is the most common case. + */ + kTVMFFIObjectDeleterFlagBitMaskBoth = (kTVMFFIObjectDeleterFlagBitMaskStrong | kTVMFFIObjectDeleterFlagBitMaskWeak), +#ifdef __cplusplus +}; +#else +} TVMFFIObjectDeleterFlagBitMask; +#endif + +/*! + * \brief C-based type of all FFI object header that allocates on heap. + */ +typedef struct { + /*! + * \brief Combined strong and weak reference counter of the object. + * + * Strong ref counter is packed into the lower 32 bits. + * Weak ref counter is packed into the upper 32 bits. + * + * It is equivalent to { uint32_t strong_ref_count, uint32_t weak_ref_count } + * in little-endian structure: + * + * - strong_ref_count: `combined_ref_count & 0xFFFFFFFF` + * - weak_ref_count: `(combined_ref_count >> 32) & 0xFFFFFFFF` + * + * Rationale: atomic ops on strong ref counter remains the same as +1/-1, + * this combined ref counter allows us to use u64 atomic once + * instead of a separate atomic read of weak counter during deletion. + * + * The ref counter goes first to align ABI with most intrusive ptr designs. + * It is also likely more efficient as rc operations can be quite common. + */ + uint64_t combined_ref_count; + /*! + * \brief type index of the object. + * \note The type index of Object and Any are shared in FFI. + */ + int32_t type_index; + /*! \brief Extra padding to ensure 8 bytes alignment. */ + uint32_t __padding; +#if !defined(TVM_FFI_DOXYGEN_MODE) + union { +#endif + /*! + * \brief Deleter to be invoked when strong reference counter goes to zero. + * \param self The self object handle. + * \param flags The flags to indicate deletion behavior. + * \sa TVMFFIObjectDeleterFlagBitMask + */ + void (*deleter)(void *self, int flags); + /*! + * \brief Auxiliary field to ensure TVMFFIObject is always 8 bytes aligned. + * \note This helps us to ensure cross platform compatibility.
+ */ + int64_t __ensure_align; +#if !defined(TVM_FFI_DOXYGEN_MODE) + }; +#endif +} TVMFFIObject; + +/*! + * \brief C-based type of all on stack Any value. + * + * Any value can hold on stack values like int, + * as well as reference counted pointers to object. + */ +typedef struct { + /*! + * \brief type index of the object. + * \note The type index of Object and Any are shared in FFI. + */ + int32_t type_index; +#if !defined(TVM_FFI_DOXYGEN_MODE) + union { // 4 bytes +#endif + /*! \brief padding, must set to zero for values other than small string. */ + uint32_t zero_padding; + /*! + * \brief Length of small string, with a max value of 7. + * + * We keep small str to start at next 4 bytes to ensure alignment + * when accessing the small str content. + */ + uint32_t small_str_len; +#if !defined(TVM_FFI_DOXYGEN_MODE) + }; +#endif +#if !defined(TVM_FFI_DOXYGEN_MODE) + union { // 8 bytes +#endif + /*! \brief integers */ + int64_t v_int64; + /*! \brief floating-point numbers */ + double v_float64; + /*! \brief typeless pointers */ + void *v_ptr; + /*! \brief raw C-string */ + const char *v_c_str; + /*! \brief ref counted objects */ + TVMFFIObject *v_obj; + /*! \brief data type */ + DLDataType v_dtype; + /*! \brief device */ + DLDevice v_device; + /*! \brief small string */ + char v_bytes[8]; + /*! \brief uint64 repr mainly used for hashing */ + uint64_t v_uint64; +#if !defined(TVM_FFI_DOXYGEN_MODE) + }; +#endif +} TVMFFIAny; + +/*! + * \brief Byte array data structure used by String and Bytes. + * + * String and Bytes object layout = { TVMFFIObject, TVMFFIByteArray, ... } + * + * \note This byte array data structure layout differs in 32/64 bit platforms, + * as size_t equals the size of the pointer. We use this convention to + * be consistent with std::string and also avoid the need to calculate padding + * for the size field on 32-bit platforms. + * The FFI binding should be careful when treating this ABI. + */ +typedef struct { + /*! \brief The data pointer.
*/ + const char *data; + /*! \brief The size of the data. */ + size_t size; +} TVMFFIByteArray; + +/*! + * \brief Shape cell used in shape object following header. + */ +typedef struct { + /*! \brief The data pointer. */ + const int64_t *data; + /*! \brief The size of the data. */ + size_t size; +} TVMFFIShapeCell; + +/*! + * \brief Mode to update the backtrace of the error. + */ +#ifdef __cplusplus +enum TVMFFIBacktraceUpdateMode : int32_t { +#else +typedef enum { +#endif + kTVMFFIBacktraceUpdateModeReplace = 0, + kTVMFFIBacktraceUpdateModeAppend = 1, +#ifdef __cplusplus +}; +#else +} TVMFFIBacktraceUpdateMode; +#endif + +/*! + * \brief Error cell used in error object following header. + */ +typedef struct { + /*! \brief The kind of the error. */ + TVMFFIByteArray kind; + /*! \brief The message of the error. */ + TVMFFIByteArray message; + /*! + * \brief The backtrace of the error. + * + * The backtrace is in the order of recent call first from the top of the stack + * to the bottom of the stack. This order makes it helpful for appending + * the extra backtrace to the end as we go up when error is propagated. + * + * When printing out, we encourage reverse the order of lines to make it + * align with python style. + */ + TVMFFIByteArray backtrace; + /*! + * \brief Function handle to update the backtrace of the error. + * \param self The self object handle. + * \param backtrace The backtrace to update. + * \param update_mode The mode to update the backtrace, + * can be either kTVMFFIBacktraceUpdateModeReplace, kTVMFFIBacktraceUpdateModeAppend. + */ + void (*update_backtrace)(TVMFFIObjectHandle self, const TVMFFIByteArray *backtrace, + int32_t update_mode); +} TVMFFIErrorCell; + +/*! + * \brief Type that defines C-style safe call convention + * + * Safe call explicitly catches exception on function boundary. + * + * \param handle The function handle + * \param num_args Number of input arguments + * \param args The input arguments to the call. 
+ * \param result Store output result. + * + * IMPORTANT: caller must initialize result->type_index to be kTVMFFINone, + * or any other value smaller than kTVMFFIStaticObjectBegin. + * + * \return The call returns 0 if call is successful. + * It returns non-zero value if there is an error. + * + * Possible return error of the API functions: + * * 0: success + * * -1: error happens, can be retrieved by TVMFFIErrorMoveFromRaised + * * -2: a frontend error occurred and recorded in the frontend. + * + * \note We decided to leverage TVMFFIErrorMoveFromRaised and TVMFFIErrorSetRaised + * for C function error propagation. This design choice, while + * introducing a dependency for TLS runtime, simplifies error + * propagation in chains of calls in compiler codegen, + * as we do not need to propagate error through argument but simply + * set them in the runtime environment. + * + * \sa TVMFFIErrorMoveFromRaised + * \sa TVMFFIErrorSetRaised + * \sa TVMFFIErrorSetRaisedFromCStr + * \sa TVMFFIErrorSetRaisedFromCStrParts + */ +typedef int (*TVMFFISafeCallType)(void *handle, const TVMFFIAny *args, int32_t num_args, + TVMFFIAny *result); + +/*! + * \brief Object cell for function object following header. + */ +typedef struct { + /*! \brief A C API compatible call with exception catching. */ + TVMFFISafeCallType safe_call; + /*! + * \brief A function pointer to an underlying cpp call. + * + * The signature is the same as TVMFFISafeCallType except the return type is void, + * and the function throws exception directly instead of returning error code. + * We use void* here to avoid depending on c++ compiler. + * + * This pointer should be set to NULL for functions that are not originally created in cpp. + * + * \note The caller must assume the same cpp exception catching abi when using this pointer. + * When used across FFI boundaries, always use safe_call. + */ + void *cpp_call; +} TVMFFIFunctionCell; + +/*! + * \brief Object cell for opaque object following header.
+ */ +typedef struct { + /*! \brief The handle of the opaque object, for python it is PyObject* */ + void *handle; +} TVMFFIOpaqueObjectCell; + +//----------------------------------------------------------------------- +// Section: Version API +//----------------------------------------------------------------------- +/*! + * \brief Get the TVM FFI version from the current C ABI. + * + * This function is always stable across all versions of the C ABI. + * + * \param out_version The output version. + */ +TVM_FFI_DLL void TVMFFIGetVersion(TVMFFIVersion *out_version); + +//------------------------------------------------------------ +// Section: Basic object API +//------------------------------------------------------------ +/*! + * \brief Increase the strong reference count of an object handle + * \param obj The object handle. + * \note Internally we increase the reference counter of the object. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIObjectIncRef(TVMFFIObjectHandle obj); + +/*! + * \brief Free an object handle by decreasing strong reference + * \param obj The object handle. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIObjectDecRef(TVMFFIObjectHandle obj); + +/*! + * \brief Create an Opaque object by passing in handle, type_index and deleter. + * + * The opaque object's lifetime is managed as an Object, so it can be retained + * and released like other objects. + * When the opaque object is kTVMFFIOpaquePyObject, it can be converted back to + * the python type when returned or passed as arguments to a python function. + * + * We can support ffi::Function that interacts with these objects, + * most likely callback registered from python. + * + * For language bindings, we only convert types that we do not recognize into this type. + * On the C++ side, the most common way to represent such OpaqueObject is to simply + * use ffi::ObjectRef or ffi::Any. 
+ * + * \param handle The resource handle of the opaque object. + * \param type_index The type index of the object. + * \param deleter deleter to recycle + * \param out The output of the opaque object. + * \return 0 when success, nonzero when failure happens + * + * \note The caller must ensure the type_index is a valid opaque object type index. + * \sa kTVMFFIOpaquePyObject + */ +TVM_FFI_DLL int TVMFFIObjectCreateOpaque(void *handle, int32_t type_index, + void (*deleter)(void *handle), TVMFFIObjectHandle *out); + +/*! + * \brief Convert type key to type index. + * \param type_key The key of the type. + * \param out_tindex the corresponding type index. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFITypeKeyToIndex(const TVMFFIByteArray *type_key, int32_t *out_tindex); + +//----------------------------------------------------------------------- +// Section: Basic function calling API for function implementation +//----------------------------------------------------------------------- +/*! + * \brief Create a FFIFunc by passing in callbacks from a C callback. + * The registered function can then be retrieved by the backend using its name. + * \param self The resource handle of the C callback. + * \param safe_call The C callback implementation. + * \param deleter The deleter to recycle. + * \param out The output of the function. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIFunctionCreate(void *self, TVMFFISafeCallType safe_call, + void (*deleter)(void *self), TVMFFIObjectHandle *out); + +/*! + * \brief Get a global function registered in the system. + * \param name The name of the function. + * \param out The result function pointer, NULL if it does not exist. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIFunctionGetGlobal(const TVMFFIByteArray *name, TVMFFIObjectHandle *out); + +/*! + * \brief Convert an AnyView to an owned Any. + * \param any_view The AnyView to convert. 
+ * \param out The output Any, must be an empty object. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIAnyViewToOwnedAny(const TVMFFIAny *any_view, TVMFFIAny *out); + +/*! + * \brief Call a FFIFunc by passing in arguments. + * \param func The resource handle of the C callback. + * \param args The input arguments to the call. + * \param num_args The number of input arguments. + * \param result The output result, caller must ensure result->type_index is set to kTVMFFINone. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny *args, int32_t num_args, + TVMFFIAny *result); + +/*! + * \brief Move the last error from the environment to the result. + * \param result The result error. + * \note This function clears the error stored in the TLS. + */ +TVM_FFI_DLL void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle *result); + +/*! + * \brief Set a raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. + * \param error The error object handle + */ +TVM_FFI_DLL void TVMFFIErrorSetRaised(TVMFFIObjectHandle error); + +/*! + * \brief Set a raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. + * \param kind The kind of the error. + * \param message The error message. + * \note This is a convenient method for the C API side to set an error directly from a string. + */ +TVM_FFI_DLL void TVMFFIErrorSetRaisedFromCStr(const char *kind, const char *message); + +/*! + * \brief Set a raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. + * + * Rationale: This function can be used by compilers to create error messages by + * concatenating multiple parts of the error message, which can reduce the + * storage size for common parts such as function signatures. 
+ * + * For example, the following are possible error messages from a kernel DSL + * + * - Argument 1 mismatch in `matmul(x: Tensor, y: Tensor, z: Tensor)`, dtype mismatch + * - Argument 2 mismatch in `matmul(x: Tensor, y: Tensor, z: Tensor)`, shape[0] mismatch + * - Argument 2 mismatch in `matmul(x: Tensor, y: Tensor, z: Tensor)`, shape[1] mismatch + * + * Storing each part of the error message as a separate global string can cause quite + * a bit of duplication, especially considering the kinds of error reports we may have. + * Instead, compilers can store error messages in parts, where items like + * `matmul(x: Tensor, y: Tensor, z: Tensor)` can be reused across multiple error messages. + * This API simplifies error reporting for such cases. + * + * \param kind The kind of the error. + * \param message_parts The error message parts, each part can be NULL and will be skipped. + * \param num_parts The number of error message parts. + */ +TVM_FFI_DLL void TVMFFIErrorSetRaisedFromCStrParts(const char *kind, const char **message_parts, + int32_t num_parts); + +/*! + * \brief Create an initial error object. + * \param kind The kind of the error. + * \param message The error message. + * \param backtrace The backtrace of the error. + * \param out The output error object handle. + * \return 0 on success, nonzero on failure(likely MemoryError) + * + * \note This function is different from other functions as it is used in the error handling loop. + * So we do not follow normal error handling patterns. When error happens it will not set + * the error in TLS (since TLS error setting also involves creating an Error object). + * Instead, caller should simply report MemoryError to the logger. 
+ */ +TVM_FFI_DLL int TVMFFIErrorCreate(const TVMFFIByteArray *kind, const TVMFFIByteArray *message, + const TVMFFIByteArray *backtrace, TVMFFIObjectHandle *out); + +//------------------------------------------------------------ +// Section: DLPack support APIs +//------------------------------------------------------------ +/*! + * \brief Produce a managed Tensor from a DLPack tensor. + * \param from The source DLPack tensor. + * \param require_alignment The minimum alignment required of the data + byte_offset. + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \param out The output Tensor handle. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFITensorFromDLPack(DLManagedTensor *from, int32_t require_alignment, + int32_t require_contiguous, TVMFFIObjectHandle *out); + +/*! + * \brief Produce a DLManagedTensor from the array that shares data memory with the array. + * \param from The source array. + * \param out The DLManagedTensor handle. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFITensorToDLPack(TVMFFIObjectHandle from, DLManagedTensor **out); + +/*! + * \brief Produce a managed Tensor from a DLPack tensor. + * \param from The source DLPack tensor. + * \param require_alignment The minimum alignment required of the data + byte_offset. + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \param out The output Tensor handle. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned *from, + int32_t require_alignment, + int32_t require_contiguous, + TVMFFIObjectHandle *out); + +/*! + * \brief Produce a DLManagedTensor from the array that shares data memory with the array. + * \param from The source array. + * \param out The DLManagedTensor handle. + * \return 0 on success, nonzero on failure. 
+ */ +TVM_FFI_DLL int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle from, + DLManagedTensorVersioned **out); +//--------------------------------------------------------------- +// Section: string/bytes support APIs. +// These APIs are used to simplify the string/bytes construction +//--------------------------------------------------------------- +/*! + * \brief Reinterpret the content of TVMFFIByteArray to String. + * \param input The TVMFFIByteArray to convert. + * \param out The output String owned by the caller, maybe a SmallStr or a Str object. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIStringFromByteArray(const TVMFFIByteArray *input, TVMFFIAny *out); + +/*! + * \brief Reinterpret the content of TVMFFIByteArray to Bytes. + * \param input The TVMFFIByteArray to convert. + * \param out The output Bytes owned by the caller, maybe a SmallBytes or a Bytes object. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIBytesFromByteArray(const TVMFFIByteArray *input, TVMFFIAny *out); + +//--------------------------------------------------------------- +// Section: dtype string support APIs. +// These APIs are used to simplify the dtype printings during FFI +//--------------------------------------------------------------- + +/*! + * \brief Convert a string to a DLDataType. + * \param str The string to convert. + * \param out The output DLDataType. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray *str, DLDataType *out); + +/*! +* \brief Convert a DLDataType to a string. +* \param dtype The DLDataType to convert. +* \param out The output string. +* \return 0 on success, nonzero on failure. +* \note out is a String object that needs to be freed by the caller via TVMFFIObjectDecRef. +The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. + +* \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. 
+*/ +TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType *dtype, TVMFFIAny *out); + +//------------------------------------------------------------ +// Section: Type reflection support APIs +// +// The reflection APIs describe the fields and methods of object types. +//------------------------------------------------------------ +/*! + * \brief Getter that can take the address of a field and set the result. + * \param field The raw address of the field. + * \param result Stores the result. + * \return 0 on success, nonzero on failure. + */ +typedef int (*TVMFFIFieldGetter)(void *field, TVMFFIAny *result); + +/*! + * \brief Setter that can take the address of a field and set it to a value. + * \param field The raw address of the field. + * \param value The value to set. + * \return 0 on success, nonzero on failure. + */ +typedef int (*TVMFFIFieldSetter)(void *field, const TVMFFIAny *value); + +/*! + * \brief Function that creates a new instance of the type. + * \param result The new object handle + * \return 0 on success, nonzero on failure. + */ +typedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle *result); + +/*! + * \brief bitmask of the field. + */ +#ifdef __cplusplus +enum TVMFFIFieldFlagBitMask : int32_t { +#else +typedef enum { +#endif + /*! \brief The field is writable. */ + kTVMFFIFieldFlagBitMaskWritable = 1 << 0, + /*! \brief The field has default value. */ + kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1, + /*! \brief The field is a static method. */ + kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2, + /*! + * \brief The field should be ignored when performing structural eq/hash + * + * This is an optional meta-data for structural eq/hash. + */ + kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3, + /*! + * \brief The field enters a def region where var can be defined/matched. + * + * This is an optional meta-data for structural eq/hash. + */ + kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4, +#ifdef __cplusplus +}; +#else +} TVMFFIFieldFlagBitMask; +#endif + +/*! + * \brief Optional meta-data for structural eq/hash.
+ * + * This meta-data is only useful when we want to leverage the information + * to perform richer semantics aware structural comparison and hash. + * It can be safely ignored if such information is not needed. + * + * The meta-data record comparison method in tree node and DAG node. + * + * \code + * x = VarNode() + * v0 = AddNode(x, 1) + * v1 = AddNode(x, 1) + * v2 = AddNode(v0, v0) + * v3 = AddNode(v1, v0) + * \endcode + * + * Consider the construct sequence of AddNode below, + * if AddNode is treated as a tree node, then v2 and v3 + * structural equals to each other, but if AddNode is + * treated as a DAG node, then v2 and v3 does not + * structural equals to each other. + */ +#ifdef __cplusplus +enum TVMFFISEqHashKind : int32_t { +#else +typedef enum { +#endif + /*! \brief Do not support structural eq/hash. */ + kTVMFFISEqHashKindUnsupported = 0, + /*! + * \brief The object is compared as a tree node. + */ + kTVMFFISEqHashKindTreeNode = 1, + /*! + * \brief The object is treated as a free variable that can be mapped + * to another free variable in the definition region. + */ + kTVMFFISEqHashKindFreeVar = 2, + /*! + * \brief The object should be compared as a DAG node. + */ + kTVMFFISEqHashKindDAGNode = 3, + /*! + * \brief The object is treated as a constant tree node. + * + * Same as tree node, but the object does not contain free var + * as any of its nested children. + * + * That means we can use pointer equality for equality. + */ + kTVMFFISEqHashKindConstTreeNode = 4, + /*! + * \brief One can simply use pointer equality for equality. + * + * This is useful for "singleton"-style objects that + * have only one unique copy of each value. + */ + kTVMFFISEqHashKindUniqueInstance = 5, +#ifdef __cplusplus +}; +#else +} TVMFFISEqHashKind; +#endif + +/*! + * \brief Information support for optional object reflection. + */ +typedef struct { + /*! \brief The name of the field. */ + TVMFFIByteArray name; + /*! \brief The docstring about the field. 
*/ + TVMFFIByteArray doc; + /*! \brief The structured metadata of the field in JSON string. */ + TVMFFIByteArray metadata; + /*! + * \brief bitmask flags of the field. + */ + int64_t flags; + /*! \brief The size of the field. */ + int64_t size; + /*! \brief The alignment of the field. */ + int64_t alignment; + /*! \brief The offset of the field. */ + int64_t offset; + /*! \brief The getter to access the field. */ + TVMFFIFieldGetter getter; + /*! + * \brief The setter to access the field. + * \note The setter is set even if the field is readonly for serialization. + */ + TVMFFIFieldSetter setter; + /*! + * \brief The default value of the field, this field hold AnyView, + * valid when flags set kTVMFFIFieldFlagBitMaskHasDefault + */ + TVMFFIAny default_value; + /*! + * \brief Records the static type kind of the field. + * + * Possible values: + * + * - TVMFFITypeIndex::kTVMFFIObject for general objects. + * The value is nullable when kTVMFFIObject is chosen. + * - Static object type kinds such as Map, Dict, String + * - POD type index, note it does not give information about storage size of the field. + * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info + * about the field. + * + * When the value is a type index of Object type, the field is storaged as an ObjectRef. + * + * \note This information maybe helpful in designing serializer. + * As it helps to narrow down the field type so we don't have to + * print type_key for cases like POD types. + * It also helps to provide opportunities to enable short-cut getter to ObjectRef fields. + */ + int32_t field_static_type_index; +} TVMFFIFieldInfo; + +/*! + * \brief Method information that can appear in reflection table. + */ +typedef struct { + /*! \brief The name of the field. */ + TVMFFIByteArray name; + /*! \brief The docstring about the method. 
*/ + TVMFFIByteArray doc; + // Rationale: We separate the docstring from the metadata since docstrings + // can be unstructured and sometimes large, while metadata can be focused + // on storing structured information. + /*! \brief Optional structured metadata of the method in JSON string. */ + TVMFFIByteArray metadata; + /*! \brief bitmask flags of the method. */ + int64_t flags; + /*! + * \brief The method wrapped as ffi::Function, stored as AnyView. + * \note The first argument to the method is always the self for instance methods. + */ + TVMFFIAny method; +} TVMFFIMethodInfo; + +/*! + * \brief Extra information of object type that can be used for reflection. + * + * \note This information is optional and can be used to enable reflection based + * creation of the object. + */ +typedef struct { + /*! \brief The docstring about the object. */ + TVMFFIByteArray doc; + /*! + * \brief An optional function that can create a new empty instance of the type. + * + * When known_fixed_size is non-zero, creator can be called + * with nullptr passed to optional_bytes. + * + * \note Caller must call setter for each field to initialize the object for + * the final object to be in valid state. + * + * \note This field is optional to enable reflection based creation. + */ + TVMFFIObjectCreator creator; + /*! + * \brief Total size of the object struct, if it is fixed and known. + * + * This field is optional and is set to 0 if not registered. + */ + int32_t total_size; + /*! + * \brief Optional meta-data for structural eq/hash. + */ + TVMFFISEqHashKind structural_eq_hash_kind; +} TVMFFITypeMetadata; + +/*! + * \brief Column array that stores extra attributes about types + * + * The attributes are stored in a column array that can be looked up by type index. + * Note that the TypeAttr behaves like type_traits so column[T] does not contain + * attributes from base classes. + * + * \note + * \sa TVMFFITypeRegisterAttr + */ +typedef struct { + /*! \brief The data of the column. 
*/ + const TVMFFIAny *data; + /*! \brief The size of the column. */ + size_t size; +} TVMFFITypeAttrColumn; + +/*! + * \brief Runtime type information for object type checking. + */ +#ifdef __cplusplus +struct TVMFFITypeInfo { +#else +typedef struct TVMFFITypeInfo { +#endif + /*! + * \brief The runtime type index. + * It can be allocated during runtime if the type is dynamic. + */ + int32_t type_index; + /*! \brief number of parent types in the type hierarchy. */ + int32_t type_depth; + /*! \brief the unique type key to identify the type. */ + TVMFFIByteArray type_key; + /*! + * \brief type_ancestors[depth] stores the type info of the ancestor at depth level + * \note To keep things simple, we do not allow multiple inheritance so the + * hierarchy stays as a tree + */ + const struct TVMFFITypeInfo **type_ancestors; + // The following fields are used for reflection + /*! \brief Cached hash value of the type key, used for consistent structural hashing. */ + uint64_t type_key_hash; + /*! \brief number of reflection accessible fields. */ + int32_t num_fields; + /*! \brief number of reflection accessible methods. */ + int32_t num_methods; + /*! \brief The reflection field information. */ + const TVMFFIFieldInfo *fields; + /*! \brief The reflection method information. */ + const TVMFFIMethodInfo *methods; + /*! \brief The extra information of the type. */ + const TVMFFITypeMetadata *metadata; +#ifdef __cplusplus +}; +#else +} TVMFFITypeInfo; +#endif + +/*! + * \brief Register the function to runtime's global table. + * The registered function can then be retrieved by the backend using its name. + * \param name The name of the function. + * \param f The function to be registered. + * \param allow_override Whether to allow overriding an already registered function. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIFunctionSetGlobal(const TVMFFIByteArray *name, TVMFFIObjectHandle f, + int allow_override); + +/*! 
+ * \brief Register the function to runtime's global table with method info. + * This is the same as TVMFFIFunctionSetGlobal but with method info that can provide extra + * metadata used in the runtime. + * \param method_info The method info to be registered. + * \param allow_override Whether to allow overriding an already registered function. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFIFunctionSetGlobalFromMethodInfo(const TVMFFIMethodInfo *method_info, + int allow_override); + +/*! + * \brief Register type field information for runtime reflection. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo *info); + +/*! + * \brief Register type method information for runtime reflection. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo *info); + +/*! + * \brief Register type creator information for runtime reflection. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata *metadata); + +/*! + * \brief Register extra type attributes that can be looked up during runtime. + * \return 0 on success, nonzero on failure. + */ +TVM_FFI_DLL int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray *attr_name, + const TVMFFIAny *attr_value); + +/*! + * \brief Get the type attribute column by name. + * \return The pointer to the type attribute column. + * \return NULL if the attribute was not registered in the system. 
+ */ +TVM_FFI_DLL const TVMFFITypeAttrColumn *TVMFFIGetTypeAttrColumn(const TVMFFIByteArray *attr_name); + +//------------------------------------------------------------ +// Section: Backend noexcept functions for internal use +// +// These functions are used internally and do not throw errors; +// instead the error will be logged and the process aborted. +// These functions are called at startup or exit time, +// so exception handling does not apply. +//------------------------------------------------------------ +/*! + * \brief Get stack backtrace in a string. + * + * The backtrace is in the order of recent call first from the top of the stack + * to the bottom of the stack. This order makes it helpful for appending + * the extra backtrace as we unwind the stack. + * + * When printing out, we encourage reversing the order of lines to make it + * align with python style. + * + * \param filename The current file name. + * \param lineno The current line number + * \param func The current function + * \param cross_ffi_boundary Whether the backtrace is crossing the ffi boundary + * or we should stop at the ffi boundary when detected + * \return The backtrace string + * + * \note filename/func can be nullptr, then this info is skipped, they are useful + * for cases when debug symbols are not available. + */ +TVM_FFI_DLL const TVMFFIByteArray *TVMFFIBacktrace(const char *filename, int lineno, + const char *func, int cross_ffi_boundary); + +/*! + * \brief Initialize the type info during runtime. + * + * When the function is first called for a type, + * it will register the type to the type table in the runtime. + * If static_type_index is negative (dynamic), the function will + * allocate a runtime type index. + * Otherwise, we will populate the type table and return the static index. + * + * \param type_key The type key. + * \param type_depth The type depth. 
+ * \param static_type_index Static type index if any, can be -1, which means this is a dynamic index + * \param num_child_slots Number of slots reserved for its children. + * \param child_slots_can_overflow Whether to allow child to overflow the slots. + * \param parent_type_index Parent type index, pass in -1 if it is root. + * + * \return The allocated type index. + */ +TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray *type_key, + int32_t static_type_index, int32_t type_depth, + int32_t num_child_slots, + int32_t child_slots_can_overflow, + int32_t parent_type_index); + +/*! + * \brief Get dynamic type info by type index. + * \return The type info. + */ +TVM_FFI_DLL const TVMFFITypeInfo *TVMFFIGetTypeInfo(int32_t type_index); + +#ifdef __cplusplus +} // TVM_FFI_EXTERN_C +#endif + +//--------------------------------------------------------------- +// The following API defines static object attribute accessors +// for language bindings. +// +// They are defined in C++ inline functions for cleaner code. +// Note that they only have to do with address offset computation. +// So they can always be reimplemented in bindings when c++ is +// not available or when binding only wants to refer to the dll. +//---------------------------------------------------------------- +#ifdef __cplusplus +/*! + * \brief Get the type index of an object. + * \param obj The object handle. + * \return The type index. + */ +inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) { + return static_cast(obj)->type_index; +} + +/*! + * \brief Get the content of a small string in bytearray format. + * \param value The value to get the content of the small string in bytearray format. + * \return The content of the small string in bytearray format. + */ +inline TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny *value) { + return TVMFFIByteArray{value->v_bytes, static_cast(value->small_str_len)}; +} + +/*! 
+ * \brief Get the data pointer of a bytearray from a string or bytes object. + * \param obj The object handle. + * \return The data pointer. + */ +inline TVMFFIByteArray *TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Get the data pointer of a ErrorInfo from an Error object. + * \param obj The object handle. + * \return The cell pointer. + */ +inline TVMFFIErrorCell *TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Get the data pointer of a function cell from a function object. + * \param obj The object handle. + * \return The cell pointer. + */ +inline TVMFFIFunctionCell *TVMFFIFunctionGetCellPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Get the data pointer of a opaque object cell from a opaque object. + * \param obj The object handle. + * \return The cell pointer. + */ +inline TVMFFIOpaqueObjectCell *TVMFFIOpaqueObjectGetCellPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Get the data pointer of a shape array from a shape object. + * \param obj The object handle. + * \return The cell pointer. + */ +inline TVMFFIShapeCell *TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Get the DLTensor pointer from an Tensor object. + * \param obj The object handle. + * \return The DLTensor pointer. + */ +inline DLTensor *TVMFFITensorGetDLTensorPtr(TVMFFIObjectHandle obj) { + return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); +} + +/*! + * \brief Create a DLDevice from a device type and device id. + * \param device_type The device type. + * \param device_id The device id. + * \return The DLDevice. 
+ */ +inline DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) { + return DLDevice{static_cast(device_type), device_id}; +} +#endif // __cplusplus +#endif // TVM_FFI_C_API_H_ +// NOLINTEND(modernize-use-using,bugprone-reserved-identifier,modernize-deprecated-headers) diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h new file mode 100644 index 000000000..66abd6644 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/cast.h + * \brief Extra value casting helpers + */ +#ifndef TVM_FFI_CAST_H_ +#define TVM_FFI_CAST_H_ + +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Get a reference type from a raw object ptr type + * + * It is always important to get a reference type + * if we want to return a value as reference or keep + * the object alive beyond the scope of the function. 
+ * + * \param ptr The object pointer + * \tparam RefType The reference type + * \tparam ObjectType The object type + * \return The corresponding RefType + */ +template +inline RefType GetRef(const ObjectType* ptr) { + using ContainerType = typename RefType::ContainerType; + static_assert(std::is_base_of_v, + "Can only cast to the ref of same container type"); + + if constexpr (is_optional_type_v || RefType::_type_is_nullable) { + if (ptr == nullptr) { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); + } + } else { + TVM_FFI_ICHECK_NOTNULL(ptr); + } + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned( + const_cast(static_cast(ptr)))); +} + +/*! + * \brief Get an object ptr type from a raw object ptr. + * + * \param ptr The object pointer + * \tparam BaseType The reference type + * \tparam ObjectType The object type + * \return The corresponding RefType + */ +template +inline ObjectPtr GetObjectPtr(ObjectType* ptr) { + static_assert(std::is_base_of_v, + "Can only cast to the ref of same container type"); + return details::ObjectUnsafe::ObjectPtrFromUnowned(ptr); +} +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CAST_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/array.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/array.h new file mode 100644 index 000000000..9f8674b50 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/array.h @@ -0,0 +1,1164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/array.h + * \brief Array type. + * + * tvm::ffi::Array is an erased type that contains a list of content + */ +#ifndef TVM_FFI_CONTAINER_ARRAY_H_ +#define TVM_FFI_CONTAINER_ARRAY_H_ + +#include "../any.h" +#include "../memory.h" +#include "../object.h" +#include "../optional.h" +#include "container_details.h" + +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! \brief Array node content in array */ +class ArrayObj : public Object, public details::InplaceArrayBase { +public: + ~ArrayObj() { + Any *begin = MutableBegin(); + for (int64_t i = 0; i < size_; ++i) { + (begin + i)->Any::~Any(); + } + if (data_deleter_ != nullptr) { + data_deleter_(data_); + } + } + + /*! \return The size of the array */ + size_t size() const { return this->size_; } + + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const Any &at(int64_t i) const { return this->operator[](i); } + + /*! + * \brief Read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const Any &operator[](int64_t i) const { + if (i >= size_) { + TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_; + } + return static_cast(data_)[i]; + } + + /*! \return begin constant iterator */ + const Any *begin() const { return static_cast(data_); } + + /*! \return end constant iterator */ + const Any *end() const { return begin() + size_; } + + /*! 
\brief Release reference to all the elements */ + void clear() { ShrinkBy(size_); } + + /*! + * \brief Set i-th element of the array in-place + * \param i The index + * \param item The value to be set + */ + void SetItem(int64_t i, Any item) { + if (i >= size_) { + TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_; + } + static_cast(data_)[i] = std::move(item); + } + + /*! + * \brief Constructs a container and copy from another + * \param cap The capacity of the container + * \param from Source of the copy + * \return Ref-counted ArrayObj requested + */ + static ObjectPtr CopyFrom(int64_t cap, ArrayObj *from) { + int64_t size = from->size_; + if (size > cap) { + TVM_FFI_THROW(ValueError) << "Not enough capacity"; + } + ObjectPtr p = ArrayObj::Empty(cap); + Any *write = p->MutableBegin(); + Any *read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t &i = p->size_ = 0; i < size; ++i) { + new (write++) Any(*read++); + } + return p; + } + + /*! + * \brief Constructs a container and move from another + * \param cap The capacity of the container + * \param from Source of the move + * \return Ref-counted ArrayObj requested + */ + static ObjectPtr MoveFrom(int64_t cap, ArrayObj *from) { + int64_t size = from->size_; + if (size > cap) { + TVM_FFI_THROW(RuntimeError) << "Not enough capacity"; + } + ObjectPtr p = ArrayObj::Empty(cap); + Any *write = p->MutableBegin(); + Any *read = from->MutableBegin(); + // To ensure exception safety, size is only incremented after the initialization succeeds + for (int64_t &i = p->size_ = 0; i < size; ++i) { + new (write++) Any(std::move(*read++)); + } + from->size_ = 0; + return p; + } + + /*! + * \brief Constructs a container with n elements. 
Each element is a copy of val + * \param n The size of the container + * \param val The init value + * \return Ref-counted ArrayObj requested + */ + static ObjectPtr CreateRepeated(int64_t n, const Any &val) { + ObjectPtr p = ArrayObj::Empty(n); + Any *itr = p->MutableBegin(); + for (int64_t &i = p->size_ = 0; i < n; ++i) { + new (itr++) Any(val); + } + return p; + } + + /// \cond Doxygen_Suppress + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIArray; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIArray, ArrayObj, Object); + /// \endcond + +private: + /*! \return Size of initialized memory, used by InplaceArrayBase. */ + size_t GetSize() const { return this->size_; } + + /*! \return begin mutable iterator */ + Any *MutableBegin() const { return static_cast(this->data_); } + + /*! \return end mutable iterator */ + Any *MutableEnd() const { return MutableBegin() + size_; } + + /*! + * \brief Emplace a new element at the back of the array + * \param idx The index of the element. + * \param args The arguments to construct the new element + */ + template + void EmplaceInit(size_t idx, Args &&...args) { + Any *itr = MutableBegin() + idx; + new (itr) Any(std::forward(args)...); + } + + /*! + * \brief Create an ArrayObj with the given capacity. + * \param n Required capacity + * \return Ref-counted ArrayObj requested + */ + static ObjectPtr Empty(int64_t n = kInitSize) { + ObjectPtr p = make_inplace_array_object(n); + p->capacity_ = n; + p->size_ = 0; + p->data_ = p->AddressOf(0); + return p; + } + + /*! 
+ * \brief Inplace-initialize the elements starting idx from [first, last) + * \param idx The starting point + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return Self + */ + template + ArrayObj *InitRange(int64_t idx, IterType first, IterType last) { + Any *itr = MutableBegin() + idx; + for (; first != last; ++first) { + Any ref = *first; + new (itr++) Any(std::move(ref)); + } + return this; + } + + /*! + * \brief Move elements from right to left, requires src_begin > dst + * \param dst Destination + * \param src_begin The start point of copy (inclusive) + * \param src_end The end point of copy (exclusive) + * \return Self + */ + ArrayObj *MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { + Any *from = MutableBegin() + src_begin; + Any *to = MutableBegin() + dst; + while (src_begin++ != src_end) { + *to++ = std::move(*from++); + } + return this; + } + + /*! + * \brief Move elements from left to right, requires src_begin < dst + * \param dst Destination + * \param src_begin The start point of move (inclusive) + * \param src_end The end point of move (exclusive) + * \return Self + */ + ArrayObj *MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { + Any *from = MutableBegin() + src_end; + Any *to = MutableBegin() + (src_end - src_begin + dst); + while (src_begin++ != src_end) { + *--to = std::move(*--from); + } + return this; + } + + /*! + * \brief Enlarges the size of the array + * \param delta Size enlarged, should be positive + * \param val Default value + * \return Self + */ + ArrayObj *EnlargeBy(int64_t delta, const Any &val = Any()) { + Any *itr = MutableEnd(); + while (delta-- > 0) { + new (itr++) Any(val); + ++size_; + } + return this; + } + + /*! 
+ * \brief Shrinks the size of the array + * \param delta Size shrinked, should be positive + * \return Self + */ + ArrayObj *ShrinkBy(int64_t delta) { + Any *itr = MutableEnd(); + while (delta-- > 0) { + (--itr)->Any::~Any(); + --size_; + } + return this; + } + + /*! \brief Data pointer to the first element of the array */ + void *data_; + /*! \brief Number of elements used */ + int64_t size_; + /*! \brief Number of elements allocated */ + int64_t capacity_; + /*! + * \brief Optional data deleter when data is allocated separately + * and its deletion is not managed by ArrayObj::deleter_. + */ + void (*data_deleter_)(void *) = nullptr; + + /*! \brief Initial size of ArrayObj */ + static constexpr int64_t kInitSize = 4; + + /*! \brief Expansion factor of the Array */ + static constexpr int64_t kIncFactor = 2; + + // CRTP parent class + friend InplaceArrayBase; + + // Reference class + template + friend class Array; + + template + friend class Tuple; + + template + friend struct TypeTraits; + + // To specialize make_object + friend ObjectPtr make_object<>(); +}; + +/*! \brief Helper struct for type-checking + * + * is_valid_iterator::value will be true if IterType can + * be dereferenced into a type that can be stored in an Array, and + * false otherwise. + */ +template +struct is_valid_iterator + : std::bool_constant< + std::is_same_v< + T, std::remove_cv_t())>>> + || std::is_base_of_v< + T, std::remove_cv_t())>>>> { +}; + +template +struct is_valid_iterator, IterType> : is_valid_iterator {}; + +template +struct is_valid_iterator : std::true_type {}; + +/*! + * \brief Check whether IterType is valid iterator for T. + * \tparam T The type. + * \tparam IterType The type of iterator. + */ +template +inline constexpr bool is_valid_iterator_v = is_valid_iterator::value; + +/*! + * \brief Array, container representing a contiguous sequence of ObjectRefs. + * + * Array implements in-place copy-on-write semantics. 
+ * + * As in typical copy-on-write, a method which would typically mutate the array + * instead opaquely copies the underlying container, and then acts on its copy. + * + * If the array has reference count equal to one, we directly update the + * container in place without copying. This is optimization is sound because + * when the reference count is equal to one this reference is guranteed to be + * the sole pointer to the container. + * + * + * operator[] only provides const access, use Set to mutate the content. + * \tparam T The content Value type, must be compatible with tvm::ffi::Any + */ +template >> +class Array : public ObjectRef { +public: + /*! \brief The value type of the array */ + using value_type = T; + // constructors + /*! + * \brief Construct an Array with UnsafeInit + */ + explicit Array(UnsafeInit tag) : ObjectRef(tag) {} + /*! + * \brief default constructor + */ + Array() { data_ = ArrayObj::Empty(); } // NOLINT(modernize-use-equals-default) + /*! + * \brief Move constructor + * \param other The other array + */ + Array(Array &&other) // NOLINT(google-explicit-constructor) + : ObjectRef(std::move(other.data_)) {} + /*! + * \brief Copy constructor + * \param other The other array + */ + Array(const Array &other) : ObjectRef(other.data_) {} // NOLINT(google-explicit-constructor) + /*! + * \brief Constructor from another array + * \param other The other array + * \tparam U The value type of the other array + */ + template >> + Array(Array &&other) // NOLINT(google-explicit-constructor) + : ObjectRef(std::move(other.data_)) {} + /*! + * \brief Constructor from another array + * \param other The other array + * \tparam U The value type of the other array + */ + template >> + Array(const Array &other) // NOLINT(google-explicit-constructor) + : ObjectRef(other.data_) {} + + /*! 
+ * \brief Move assignment from another array + * \param other The other array + */ + TVM_FFI_INLINE Array &operator=(Array &&other) { + data_ = std::move(other.data_); + return *this; + } + /*! + * \brief Assignment from another array + * \param other The other array + */ + TVM_FFI_INLINE Array &operator=(const Array &other) { + data_ = other.data_; + return *this; + } + /*! + * \brief Move assignment from another array + * \param other The other array + * \tparam U The value type of the other array + */ + template >> + TVM_FFI_INLINE Array &operator=(Array &&other) { + data_ = std::move(other.data_); + return *this; + } + /*! + * \brief Assignment from another array + * \param other The other array + * \tparam U The value type of the other array + */ + template >> + TVM_FFI_INLINE Array &operator=(const Array &other) { + data_ = other.data_; + return *this; + } + + /*! + * \brief Constructor from pointer + * \param n the container pointer + */ + explicit Array(ObjectPtr n) : ObjectRef(std::move(n)) {} + + /*! + * \brief Constructor from iterator + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + Array(IterType first, IterType last) { // NOLINT(performance-unnecessary-value-param) + static_assert(is_valid_iterator_v, + "IterType cannot be inserted into a tvm::Array"); + Assign(first, last); + } + + /*! + * \brief constructor from initializer list + * \param init The initializer list + */ + Array(std::initializer_list init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief constructor from vector + * \param init The vector + */ + Array(const std::vector &init) { // NOLINT(*) + Assign(init.begin(), init.end()); + } + + /*! + * \brief Constructs a container with n elements. 
Each element is a copy of val + * \param n The size of the container + * \param val The init value + */ + explicit Array(const size_t n, const T &val) { data_ = ArrayObj::CreateRepeated(n, val); } + +public: + // iterators + /// \cond Doxygen_Suppress + struct ValueConverter { + using ResultType = T; + /*! + * \brief Convert any to T + * \param n The any value to convert + * \return The converted value + */ + static T convert(const Any &n) { return details::AnyUnsafe::CopyFromAnyViewAfterCheck(n); } + }; + /// \endcond + + /*! \brief The iterator type of the array */ + using iterator = details::IterAdapter; + /*! \brief The reverse iterator type of the array */ + using reverse_iterator = details::ReverseIterAdapter; + + /*! \return begin iterator */ + iterator begin() const { return iterator(GetArrayObj()->begin()); } + + /*! \return end iterator */ + iterator end() const { return iterator(GetArrayObj()->end()); } + + /*! \return rbegin iterator */ + reverse_iterator rbegin() const { + // ArrayObj::end() is never nullptr + return reverse_iterator(GetArrayObj()->end() - 1); + } + + /*! \return rend iterator */ + reverse_iterator rend() const { + // ArrayObj::begin() is never nullptr + return reverse_iterator(GetArrayObj()->begin() - 1); + } + +public: + // const methods in std::vector + /*! + * \brief Immutably read i-th element from array. + * \param i The index + * \return the i-th element. + */ + const T operator[](int64_t i) const { + ArrayObj *p = GetArrayObj(); + if (p == nullptr) { + TVM_FFI_THROW(IndexError) << "cannot index a null array"; + } + if (i < 0 || i >= p->size_) { + TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; + } + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->begin() + i)); + } + + /*! \return The size of the array */ + size_t size() const { + ArrayObj *p = GetArrayObj(); + return p == nullptr ? 0 : GetArrayObj()->size_; + } + + /*! 
\return The capacity of the array */ + size_t capacity() const { + ArrayObj *p = GetArrayObj(); + return p == nullptr ? 0 : GetArrayObj()->capacity_; + } + + /*! \return Whether array is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the array */ + const T front() const { + ArrayObj *p = GetArrayObj(); + if (p == nullptr || p->size_ == 0) { + TVM_FFI_THROW(IndexError) << "cannot index a empty array"; + } + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->begin())); + } + + /*! \return The last element of the array */ + const T back() const { + ArrayObj *p = GetArrayObj(); + if (p == nullptr || p->size_ == 0) { + TVM_FFI_THROW(IndexError) << "cannot index a empty array"; + } + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->end() - 1)); + } + +public: + // mutation in std::vector, implements copy-on-write + /*! + * \brief push a new item to the back of the list + * \param item The item to be pushed. + */ + void push_back(const T &item) { + ArrayObj *p = CopyOnWrite(1); + p->EmplaceInit(p->size_++, item); + } + + /*! + * \brief Emplace a new element at the back of the array + * \param args The arguments to construct the new element + */ + template + void emplace_back(Args &&...args) { + ArrayObj *p = CopyOnWrite(1); + p->EmplaceInit(p->size_++, std::forward(args)...); + } + + /*! + * \brief Insert an element into the given position + * \param position An iterator pointing to the insertion point + * \param val The element to insert + */ + void insert(iterator position, const T &val) { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; + } + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayObj()->size_; + auto addr = CopyOnWrite(1) // + ->EnlargeBy(1) // + ->MoveElementsRight(idx + 1, idx, size) // + ->MutableBegin(); + new (addr + idx) Any(val); + } + + /*! 
+ * \brief Insert a range of elements into the given position + * \param position An iterator pointing to the insertion point + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + template + void insert(iterator position, IterType first, IterType last) { + static_assert(is_valid_iterator_v, + "IterType cannot be inserted into a tvm::Array"); + + if (first == last) { + return; + } + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; + } + int64_t idx = std::distance(begin(), position); + int64_t size = GetArrayObj()->size_; + int64_t numel = std::distance(first, last); + CopyOnWrite(numel) + ->EnlargeBy(numel) + ->MoveElementsRight(idx + numel, idx, size) + ->InitRange(idx, first, last); + } + + /*! \brief Remove the last item of the list */ + void pop_back() { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot pop_back a null array"; + } + int64_t size = GetArrayObj()->size_; + if (size == 0) { + TVM_FFI_THROW(RuntimeError) << "cannot pop_back an empty array"; + } + CopyOnWrite()->ShrinkBy(1); + } + + /*! + * \brief Erase an element on the given position + * \param position An iterator pointing to the element to be erased + */ + void erase(iterator position) { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; + } + int64_t st = std::distance(begin(), position); + int64_t size = GetArrayObj()->size_; + if (st < 0 || st >= size) { + TVM_FFI_THROW(RuntimeError) << "cannot erase at index " << st << ", because Array size is " + << size; + } + CopyOnWrite() // + ->MoveElementsLeft(st, st + 1, size) // + ->ShrinkBy(1); + } + + /*! 
+ * \brief Erase a given range of elements + * \param first The begin iterator of the range + * \param last The end iterator of the range + */ + void erase(iterator first, iterator last) { + if (first == last) { + return; + } + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; + } + int64_t size = GetArrayObj()->size_; + int64_t st = std::distance(begin(), first); + int64_t ed = std::distance(begin(), last); + if (st >= ed) { + TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")"; + } + if (st < 0 || st > size || ed < 0 || ed > size) { + TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")" + << ", because array size is " << size; + } + CopyOnWrite() // + ->MoveElementsLeft(st, ed, size) // + ->ShrinkBy(ed - st); + } + + /*! + * \brief Resize the array. + * \param n The new size. + */ + void resize(int64_t n) { + if (n < 0) { + TVM_FFI_THROW(ValueError) << "cannot resize an Array to negative size"; + } + if (data_ == nullptr) { + SwitchContainer(n); + return; + } + int64_t size = GetArrayObj()->size_; + if (size < n) { + CopyOnWrite(n - size)->EnlargeBy(n - size); + } else if (size > n) { + CopyOnWrite()->ShrinkBy(size - n); + } + } + + /*! + * \brief Make sure the list has the capacity of at least n + * \param n lower bound of the capacity + */ + void reserve(int64_t n) { + if (data_ == nullptr || n > GetArrayObj()->capacity_) { + SwitchContainer(n); + } + } + + /*! \brief Release reference to all the elements */ + void clear() { + if (data_ != nullptr) { + ArrayObj *p = CopyOnWrite(); + p->clear(); + } + } + /// \cond Doxygen_Suppress + template + static size_t CalcCapacityImpl() { + return 0; + } + + template + static size_t CalcCapacityImpl(Array value, Args... args) { + return value.size() + CalcCapacityImpl(args...); + } + + template + static size_t CalcCapacityImpl(T value, Args... 
args) { + return 1 + CalcCapacityImpl(args...); + } + + template + static void AgregateImpl(Array &dest) {} // NOLINT(*) + + template + static void AgregateImpl(Array &dest, Array value, Args... args) { // NOLINT(*) + dest.insert(dest.end(), value.begin(), value.end()); + AgregateImpl(dest, args...); + } + + template + static void AgregateImpl(Array &dest, T value, Args... args) { // NOLINT(*) + dest.push_back(value); + AgregateImpl(dest, args...); + } + /// \endcond + +public: + // Array's own methods + + /*! + * \brief set i-th element of the array. + * \param i The index + * \param value The value to be setted. + */ + void Set(int64_t i, T value) { + ArrayObj *p = this->CopyOnWrite(); + if (i < 0 || i >= p->size_) { + TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; + } + *(p->MutableBegin() + i) = std::move(value); + } + + /*! \return The underlying ArrayObj */ + ArrayObj *GetArrayObj() const { return static_cast(data_.get()); } + + /*! + * \brief Helper function to apply a map function onto the array. + * + * \param fmap The transformation function T -> U. + * + * \tparam F The type of the mutation function. + * + * \tparam U The type of the returned array, inferred from the + * return type of F. If overridden by the user, must be something + * that is convertible from the return type of F. + * + * \note This function performs copy on write optimization. If + * `fmap` returns an object of type `T`, and all elements of the + * array are mapped to themselves, then the returned array will be + * the same as the original, and reference counts of the elements in + * the array will not be incremented. + * + * \return The transformed array. + */ + template > + Array Map(F fmap) const { + return Array(MapHelper(data_, fmap)); + } + + /*! + * \brief Helper function to apply fmutate to mutate an array. + * \param fmutate The transformation function T -> T. + * \tparam F the type of the mutation function. 
+ * \note This function performs copy on write optimization. + */ + template >>> + void MutateByApply(F fmutate) { + data_ = MapHelper(std::move(data_), fmutate); + } + + /*! + * \brief reset the array to content from iterator. + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + */ + template + void Assign(IterType first, IterType last) { // NOLINT(performance-unnecessary-value-param) + int64_t cap = std::distance(first, last); + if (cap < 0) { + TVM_FFI_THROW(ValueError) << "cannot construct an Array of negative size"; + } + ArrayObj *p = GetArrayObj(); + if (p != nullptr && data_.unique() && p->capacity_ >= cap) { + // do not have to make new space + p->clear(); + } else { + // create new space + data_ = ArrayObj::Empty(cap); + p = GetArrayObj(); + } + // To ensure exception safety, size is only incremented after the initialization succeeds + Any *itr = p->MutableBegin(); + for (int64_t &i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { + new (itr) Any(*first); + } + } + + /*! + * \brief Copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which ganrantees to be unique) + */ + ArrayObj *CopyOnWrite() { + if (data_ == nullptr) { + return SwitchContainer(ArrayObj::kInitSize); + } + if (!data_.unique()) { + return SwitchContainer(capacity()); + } + return static_cast(data_.get()); + } + + /*! \brief specify container node */ + using ContainerType = ArrayObj; + + /*! + * \brief Agregate arguments into a single Array + * \param args sequence of T or Array elements + * \return Agregated Array + */ + template + static Array Agregate(Args... args) { + Array result; + result.reserve(CalcCapacityImpl(args...)); + AgregateImpl(result, args...); + return result; + } + +private: + /*! 
+ * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. + * \param reserve_extra Number of extra slots needed + * \return ArrayObj pointer to the unique copy + */ + ArrayObj *CopyOnWrite(int64_t reserve_extra) { + ArrayObj *p = GetArrayObj(); + if (p == nullptr) { + // necessary to get around the constexpr address issue before c++17 + const int64_t kInitSize = ArrayObj::kInitSize; + return SwitchContainer(std::max(kInitSize, reserve_extra)); + } + if (p->capacity_ >= p->size_ + reserve_extra) { + return CopyOnWrite(); + } + int64_t cap = p->capacity_ * ArrayObj::kIncFactor; + cap = std::max(cap, p->size_ + reserve_extra); + return SwitchContainer(cap); + } + + /*! + * \brief Move or copy the ArrayObj to new address with the given capacity + * \param capacity The capacity requirement of the new address + */ + ArrayObj *SwitchContainer(int64_t capacity) { + if (data_ == nullptr) { + data_ = ArrayObj::Empty(capacity); + } else if (data_.unique()) { + data_ = ArrayObj::MoveFrom(capacity, GetArrayObj()); + } else { + data_ = ArrayObj::CopyFrom(capacity, GetArrayObj()); + } + return static_cast(data_.get()); + } + + /*! \brief Helper method for mutate/map + * + * A helper function used internally by both `Array::Map` and + * `Array::MutateInPlace`. Given an array of data, apply the + * mapping function to each element, returning the collected array. + * Applies both mutate-in-place and copy-on-write optimizations, if + * possible. + * + * \param data A pointer to the ArrayObj containing input data. + * Passed by value to allow for mutate-in-place optimizations. + * + * \param fmap The mapping function + * + * \tparam F The type of the mutation function. + * + * \tparam U The output type of the mutation function. Inferred + * from the callable type given. Must inherit from ObjectRef. + * + * \return The mapped array. 
Depending on whether mutate-in-place + * or copy-on-write optimizations were applicable, may be the same + * underlying array as the `data` parameter. + */ + template > + static ObjectPtr MapHelper(ObjectPtr data, F fmap) { + if (data == nullptr) { + return nullptr; + } + + TVM_FFI_ICHECK(data->IsInstance()); + + constexpr bool is_same_output_type = std::is_same_v; + + if constexpr (is_same_output_type) { + if (data.unique()) { + // Mutate-in-place path. Only allowed if the output type U is + // the same as type T, we have a mutable this*, and there are + // no other shared copies of the array. + auto arr = static_cast(data.get()); + for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { + T value = details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it); + // reset the original value to nullptr, to ensure unique ownership + it->reset(); + T mapped = fmap(std::move(value)); + *it = std::move(mapped); + } + return data; + } + } + + constexpr bool compatible_types = is_valid_iterator_v || is_valid_iterator_v; + + ObjectPtr output = nullptr; + auto arr = static_cast(data.get()); + + auto it = arr->begin(); + if constexpr (compatible_types) { + // Copy-on-write path, if the output Array might be + // represented by the same underlying array as the existing + // Array. Typically, this is for functions that map `T` to + // `T`, but can also apply to functions that map `T` to + // `Optional`, or that map `T` to a subclass or superclass of + // `T`. + bool all_identical = true; + for (; it != arr->end(); it++) { + U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it)); + if (!(*it).same_as(mapped)) { + // At least one mapped element is different than the + // original. Therefore, prepare the output array, + // consisting of any previous elements that had mapped to + // themselves (if any), and the element that didn't map to + // itself. + // + // We cannot use `U()` as the default object, as `U` may be + // a non-nullable type. 
Since the default `Any()` + // will be overwritten before returning, all objects will be + // of type `U` for the calling scope. + all_identical = false; + output = ArrayObj::CreateRepeated(static_cast(arr->size()), Any()); + output->InitRange(0, arr->begin(), it); + output->SetItem(it - arr->begin(), std::move(mapped)); + it++; + break; + } + } + if (all_identical) { + return data; + } + } else { + // Path for incompatible types. The constexpr check for + // compatible types isn't strictly necessary, as the first + // (*it).same_as(mapped) would return false, but we might as well + // avoid it altogether. + // + // We cannot use `U()` as the default object, as `U` may be a + // non-nullable type. Since the default `Any()` will be + // overwritten before returning, all objects will be of type `U` + // for the calling scope. + output = ArrayObj::CreateRepeated(static_cast(arr->size()), Any()); + } + + // Normal path for incompatible types, or post-copy path for + // copy-on-write instances. + // + // If the types are incompatible, then at this point `output` is + // empty, and `it` points to the first element of the input. + // + // If the types were compatible, then at this point `output` + // contains zero or more elements that mapped to themselves + // followed by the first element that does not map to itself, and + // `it` points to the element just after the first element that + // does not map to itself. Because at least one element has been + // changed, we no longer have the opportunity to avoid a copy, so + // we don't need to check the result. + // + // In both cases, `it` points to the next element to be processed, + // so we can either start or resume the iteration from that point, + // with no further checks on the result. + for (; it != arr->end(); it++) { + U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it)); + output->SetItem(it - arr->begin(), std::move(mapped)); + } + + return output; + } + template + friend class Array; +}; + +/*! 
+ * \brief Concat two Arrays. + * \param lhs first Array to be concatenated. + * \param rhs second Array to be concatenated. + * \return The concatenated Array. Original Arrays are kept unchanged. + */ +template || TypeTraits::convert_enabled>> +inline Array Concat(Array lhs, const Array &rhs) { + for (const auto &x : rhs) { + lhs.push_back(x); + } + return std::move(lhs); +} + +/*! + * \brief Specialize make_object + * \return The empty array object. + */ +template <> +inline ObjectPtr make_object() { + return ArrayObj::Empty(); +} + +// Traits for Array +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public ObjectRefTypeTraitsBase> { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray; + using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { + if (src->type_index != TypeIndex::kTVMFFIArray) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + if constexpr (!std::is_same_v) { + const ArrayObj *n = reinterpret_cast(src->v_obj); + for (size_t i = 0; i < n->size(); i++) { + const Any &any_v = (*n)[static_cast(i)]; + // CheckAnyStrict is cheaper than try_cast + if (details::AnyUnsafe::CheckAnyStrict(any_v)) { + continue; + } + // try see if p is convertible to T + if (any_v.try_cast()) { + continue; + } + // now report the accurate mismatch information + return "Array[index " + std::to_string(i) + ": " + details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; + } + } + TVM_FFI_THROW(InternalError) << "Cannot reach here"; + TVM_FFI_UNREACHABLE(); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + if (src->type_index != TypeIndex::kTVMFFIArray) { + return false; + } + if constexpr (std::is_same_v) { + return true; + } else { + const ArrayObj *n = reinterpret_cast(src->v_obj); + for (const Any &any_v : *n) { + if (!details::AnyUnsafe::CheckAnyStrict(any_v)) { + return false; + } 
+ } + return true; + } + } + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny *src) { + // try to run conversion. + if (src->type_index != TypeIndex::kTVMFFIArray) { + return std::nullopt; + } + if constexpr (!std::is_same_v) { + const ArrayObj *n = reinterpret_cast(src->v_obj); + bool storage_check = [&]() { + for (const Any &any_v : *n) { + if (!details::AnyUnsafe::CheckAnyStrict(any_v)) { + return false; + } + } + return true; + }(); + // fast path, if storage check passes, we can return the array directly. + if (storage_check) { + return CopyFromAnyViewAfterCheck(src); + } + // slow path, try to run a conversion to Array + Array result; + result.reserve(n->size()); + for (const Any &any_v : *n) { + if (auto opt_v = any_v.try_cast()) { + result.push_back(*std::move(opt_v)); + } else { + return std::nullopt; + } + } + return result; + } else { + return CopyFromAnyViewAfterCheck(src); + } + } + + TVM_FFI_INLINE static std::string TypeStr() { return "Array<" + details::Type2Str::v() + ">"; } + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":")" << StaticTypeKey::kTVMFFIArray << R"(","args":[)"; + oss << details::TypeSchema::v(); + oss << "]}"; + return oss.str(); + } +}; + +namespace details { +template +inline constexpr bool type_contains_v, Array> = type_contains_v; +} // namespace details + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_ARRAY_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/container_details.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/container_details.h new file mode 100644 index 000000000..d6102946a --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/container_details.h @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/container_details.h + * \brief Common utilities for typed container types. + */ +#ifndef TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ +#define TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ + +#include "../any.h" +#include "../memory.h" +#include "../object.h" + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { +namespace details { +/*! + * \brief Base template for classes with array like memory layout. + * + * It provides general methods to access the memory. The memory + * layout is ArrayType + [ElemType]. The alignment of ArrayType + * and ElemType is handled by the memory allocator. + * + * \tparam ArrayType The array header type, contains object specific metadata. + * \tparam ElemType The type of objects stored in the array right after + * ArrayType. 
+ * + * \code + * // Example usage of the template to define a simple array wrapper + * class ArrayObj : public tvm::ffi::details::InplaceArrayBase { + * public: + * // Wrap EmplaceInit to initialize the elements + * template + * void Init(Iterator begin, Iterator end) { + * size_t num_elems = std::distance(begin, end); + * auto it = begin; + * this->size = 0; + * for (size_t i = 0; i < num_elems; ++i) { + * InplaceArrayBase::EmplaceInit(i, *it++); + * this->size++; + * } + * } + * } + * + * void test_function() { + * vector fields; + * auto ptr = make_inplace_array_object(fields.size()); + * ptr->Init(fields.begin(), fields.end()); + * + * // Access the 0th element in the array. + * assert(ptr->operator[](0) == fields[0]); + * } + * + * \endcode + */ +template +class InplaceArrayBase { +public: + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Const reference to ElemType at the index. + */ + const ElemType &operator[](size_t idx) const { + size_t size = Self()->GetSize(); + if (idx > size) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; + } + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Reference to ElemType at the index. + */ + ElemType &operator[](size_t idx) { + size_t size = Self()->GetSize(); + if (idx > size) { + TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; + } + return *(reinterpret_cast(AddressOf(idx))); + } + + /*! + * \brief Destroy the Inplace Array Base object + */ + ~InplaceArrayBase() { + if constexpr (!(std::is_standard_layout_v && std::is_trivial_v)) { + size_t size = Self()->GetSize(); + for (size_t i = 0; i < size; ++i) { + ElemType *fp = reinterpret_cast(AddressOf(i)); + fp->ElemType::~ElemType(); + } + } + } + +private: + InplaceArrayBase() = default; + friend ArrayType; + +protected: + /*! 
+ * \brief Construct a value in place with the arguments. + * + * \tparam Args Type parameters of the arguments. + * \param idx Index of the element. + * \param args Arguments to construct the new value. + * + * \note Please make sure ArrayType::GetSize returns 0 before first call of + * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. + */ + template + void EmplaceInit(size_t idx, Args &&...args) { + void *field_ptr = AddressOf(idx); + new (field_ptr) ElemType(std::forward(args)...); + } + + /*! + * \brief Return the self object for the array. + * + * \return Pointer to ArrayType. + */ + inline ArrayType *Self() const { + return static_cast(const_cast(this)); + } + + /*! + * \brief Return the raw pointer to the element at idx. + * + * \param idx The index of the element. + * \return Raw pointer to the element. + */ + void *AddressOf(size_t idx) const { + static_assert( + alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, + "The size and alignment of ArrayType should respect " + "ElemType's alignment."); + + size_t kDataStart = sizeof(ArrayType); + ArrayType *self = Self(); + char *data_start = reinterpret_cast(self) + kDataStart; + return data_start + idx * sizeof(ElemType); + } +}; + +/*! + * \brief iterator adapter that adapts TIter to return another type. + * \tparam Converter a struct that contains converting function + * \tparam TIter the content iterator type. 
+ */
+template <typename Converter, typename TIter>
+class IterAdapter {
+public:
+    using difference_type = typename std::iterator_traits<TIter>::difference_type;
+    using value_type = typename Converter::ResultType;
+    using pointer = const typename Converter::ResultType *;
+    using reference = const typename Converter::ResultType; // returned by value: convert() yields a temporary
+    using iterator_category = typename std::iterator_traits<TIter>::iterator_category;
+
+    explicit IterAdapter(TIter iter) : iter_(iter) {}
+    IterAdapter &operator++() {
+        ++iter_;
+        return *this;
+    }
+    IterAdapter &operator--() {
+        --iter_;
+        return *this;
+    }
+    IterAdapter operator++(int) {
+        IterAdapter copy = *this;
+        ++iter_;
+        return copy;
+    }
+    IterAdapter operator--(int) {
+        IterAdapter copy = *this;
+        --iter_;
+        return copy;
+    }
+
+    IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); }
+
+    IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); }
+
+    IterAdapter &operator+=(difference_type offset) {
+        iter_ += offset;
+        return *this;
+    }
+
+    IterAdapter &operator-=(difference_type offset) {
+        iter_ -= offset;
+        return *this;
+    }
+
+    template <typename T = IterAdapter> // distance is only offered for random-access underlying iterators
+    inline std::enable_if_t<std::is_same_v<iterator_category, std::random_access_iterator_tag>,
+                            typename T::difference_type>
+    operator-(const IterAdapter &rhs) const {
+        return iter_ - rhs.iter_;
+    }
+
+    bool operator==(IterAdapter other) const { return iter_ == other.iter_; }
+    bool operator!=(IterAdapter other) const { return !(*this == other); }
+    reference operator*() const { return Converter::convert(*iter_); }
+
+private:
+    TIter iter_;
+};
+
+/*!
+ * \brief reverse iterator adapter: steps TIter backwards and converts each element to another type.
+ * \tparam Converter a struct that contains converting function
+ * \tparam TIter the content iterator type.
+ */
+template <typename Converter, typename TIter>
+class ReverseIterAdapter {
+public:
+    using difference_type = typename std::iterator_traits<TIter>::difference_type;
+    using value_type = typename Converter::ResultType;
+    using pointer = const typename Converter::ResultType *;
+    using reference = const typename Converter::ResultType; // returned by value: convert() yields a temporary
+    using iterator_category = typename std::iterator_traits<TIter>::iterator_category;
+
+    explicit ReverseIterAdapter(TIter iter) : iter_(iter) {}
+    ReverseIterAdapter &operator++() {
+        --iter_; // advancing the reverse view moves the wrapped iterator backwards
+        return *this;
+    }
+    ReverseIterAdapter &operator--() {
+        ++iter_;
+        return *this;
+    }
+    ReverseIterAdapter operator++(int) {
+        ReverseIterAdapter copy = *this;
+        --iter_;
+        return copy;
+    }
+    ReverseIterAdapter operator--(int) {
+        ReverseIterAdapter copy = *this;
+        ++iter_;
+        return copy;
+    }
+    ReverseIterAdapter operator+(difference_type offset) const {
+        return ReverseIterAdapter(iter_ - offset);
+    }
+
+    template <typename T = ReverseIterAdapter> // distance is only offered for random-access underlying iterators
+    inline std::enable_if_t<std::is_same_v<iterator_category, std::random_access_iterator_tag>,
+                            typename T::difference_type>
+    operator-(const ReverseIterAdapter &rhs) const {
+        return rhs.iter_ - iter_; // operands swapped because the wrapped iterator runs backwards
+    }
+
+    bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; }
+    bool operator!=(ReverseIterAdapter other) const { return !(*this == other); }
+    reference operator*() const { return Converter::convert(*iter_); }
+
+private:
+    TIter iter_;
+};
+
+/*!
+ * \brief Check if T is compatible with Any.
+ *
+ * \tparam T The type to check.
+ * \return True if T is compatible with Any, false otherwise.
+ */
+template <typename T>
+inline constexpr bool storage_enabled_v = std::is_same_v<T, Any> || TypeTraits<T>::storage_enabled;
+
+/*!
+ * \brief Check if all T are compatible with Any.
+ *
+ * \tparam T The types to check.
+ * \return True if every T is compatible with Any, false otherwise.
+ */
+template <typename... T>
+inline constexpr bool all_storage_enabled_v = (storage_enabled_v<T> && ...);
+
+/*!
+ * \brief Check if all T are ObjectRef types.
+ *
+ * \tparam T The types to check.
+ * \return True if every T derives from ObjectRef, false otherwise.
+ */ +template +inline constexpr bool all_object_ref_v = (std::is_base_of_v && ...); +/** + * \brief Check if Any storage of Derived can always be directly used as Base. + * + * \tparam Base The base type. + * \tparam Derived The derived type. + * \return True if Derived's storage can be used as Base's storage, false otherwise. + */ +template +inline constexpr bool type_contains_v = std::is_base_of_v || std::is_same_v; +// special case for Any +template +inline constexpr bool type_contains_v = true; + +/*! + * \brief Create a string of the container type. + * \tparam V The types of the elements in the container. + * \param name The name of the container type. + * \return A string of the container type. + */ +template +std::string ContainerTypeStr(const char *name) { + std::stringstream ss; + // helper to construct concated string of TypeStr + class TypeStrHelper { + public: + TypeStrHelper(std::stringstream &stream) : stream_(stream) {} // NOLINT(*) + + TypeStrHelper &operator<<(const std::string &str) { + if (counter_ > 0) { + stream_ << ", "; + } + stream_ << str; + counter_++; + return *this; + } + + private: + std::stringstream &stream_; // NOLINT(*) + int counter_ = 0; + }; + TypeStrHelper helper(ss); + ss << name << '<'; + (helper << ... << Type2Str::v()); + ss << '>'; + return ss.str(); +} + +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h new file mode 100644 index 000000000..b948fc8e4 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h @@ -0,0 +1,1781 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/map.h + * \brief Runtime Map container types. + */ +#ifndef TVM_FFI_CONTAINER_MAP_H_ +#define TVM_FFI_CONTAINER_MAP_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/// \cond Doxygen_Suppress +#if TVM_FFI_DEBUG_WITH_ABI_CHANGE +#define TVM_FFI_MAP_FAIL_IF_CHANGED() \ + TVM_FFI_ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map"; +#else +#define TVM_FFI_MAP_FAIL_IF_CHANGED() +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE +/// \endcond + +/*! \brief Shared content of all specializations of hash map */ +class MapObj : public Object { + public: + /*! \brief Type of the keys in the hash map */ + using key_type = Any; + /*! \brief Type of the values in the hash map */ + using mapped_type = Any; + /*! \brief Type of value stored in the hash map */ + using KVType = std::pair; + /// \cond Doxygen_Suppress + /*! \brief Type of raw storage of the key-value pair in the hash map */ + struct KVRawStorageType { + TVMFFIAny first; + TVMFFIAny second; + }; + /// \endcond + /*! 
\brief Iterator class */ + class iterator; + + static_assert(std::is_standard_layout_v, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 32, "sizeof(KVType) incorrect"); + + /// \cond Doxygen_Suppress + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIMap; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIMap, MapObj, Object); + /// \endcond + + /*! + * \brief Number of elements in the MapObj + * \return The result + */ + size_t size() const { return size_; } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key); + /*! \return begin iterator */ + iterator begin() const; + /*! \return end iterator */ + iterator end() const; + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const; + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position); + /*! 
+ * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type& key) { erase(find(key)); } + + /// \cond Doxygen_Suppress + class iterator { + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = int64_t; + using value_type = KVType; + using pointer = KVType*; + using reference = KVType&; +/*! \brief Default constructor */ +#if TVM_FFI_DEBUG_WITH_ABI_CHANGE + iterator() : state_marker(0), index(0), self(nullptr) {} +#else + iterator() : index(0), self(nullptr) {} +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { + TVM_FFI_MAP_FAIL_IF_CHANGED() + return index == other.index && self == other.self; + } + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return !(*this == other); } + /*! \brief De-reference iterators */ + pointer operator->() const; + /*! \brief De-reference iterators */ + reference operator*() const { + TVM_FFI_MAP_FAIL_IF_CHANGED() + return *((*this).operator->()); + } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++(); + /*! \brief Prefix self decrement, e.g. --iter */ + iterator& operator--(); + /*! \brief Suffix self increment */ + iterator operator++(int) { + TVM_FFI_MAP_FAIL_IF_CHANGED() + iterator copy = *this; + ++(*this); + return copy; + } + /*! \brief Suffix self decrement */ + iterator operator--(int) { + TVM_FFI_MAP_FAIL_IF_CHANGED() + iterator copy = *this; + --(*this); + return copy; + } + + protected: +#if TVM_FFI_DEBUG_WITH_ABI_CHANGE + uint64_t state_marker; + /*! \brief Construct by value */ + iterator(uint64_t index, const MapObj* self) + : state_marker(self->state_marker), index(index), self(self) {} + +#else + iterator(uint64_t index, const MapObj* self) : index(index), self(self) {} +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + /*! \brief The position on the array */ + uint64_t index; + /*! 
\brief The container it points to */ + const MapObj* self; + + friend class DenseMapObj; + friend class SmallMapObj; + }; + /// \endcond + /*! + * \brief Create an empty container + * \return The object created + */ + static inline ObjectPtr Empty(); + + protected: +#if TVM_FFI_DEBUG_WITH_ABI_CHANGE + uint64_t state_marker; +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static inline ObjectPtr CreateFromRange(IterType first, IterType last); + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static inline void InsertMaybeReHash(KVType&& kv, ObjectPtr* map); + /*! + * \brief Create an empty container with elements copying from another SmallMapObj + * \param from The source container + * \return The object created + */ + static inline ObjectPtr CopyFrom(MapObj* from); + /*! + * \brief data pointer to the data region of the map. + * \note For immutable inplace small map we do not need data_, + * but we keep it here for future compact with mutable container. + */ + void* data_; + /*! \brief number of entries in the container */ + uint64_t size_; + /*! \brief number of slots */ + uint64_t slots_; + /*! + * \brief Small layout tag mask + * \note The most significant bit is used to indicate the small map layout. + */ + static constexpr uint64_t kSmallTagMask = static_cast(1) << 63; + /*! + * \brief Check if the map is a small map + * \return True if the map is a small map + */ + bool IsSmallMap() const { return (slots_ & kSmallTagMask) != 0ull; } + /*! + * \brief Optional data deleter when data is allocated separately + * and its deletion is not managed by MapObj::deleter_. 
+ */ + void (*data_deleter_)(void*) = nullptr; + // Reference class + template + friend class Map; +}; + +/*! \brief A specialization of small-sized hash map */ +class SmallMapObj : public MapObj, + public details::InplaceArrayBase { + private: + static constexpr uint64_t kInitSize = 2; + static constexpr uint64_t kMaxSize = 4; + + public: + using MapObj::iterator; + using MapObj::KVType; + + // Return the number of usable slots for Small layout (mask off tag). + /*! + * \brief Return the number of usable slots for Small layout (mask off tag). + * \return The number of usable slots + */ + uint64_t NumSlots() const { return slots_ & ~kSmallTagMask; } + + ~SmallMapObj() { + KVType* begin = static_cast(data_); + for (uint64_t index = 0; index < size_; ++index) { + // call destructor to destroy the item in `begin + index` + // Explicit call Any::~Any() to destroy the Any object + // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) + (begin + index)->first.Any::~Any(); + (begin + index)->second.Any::~Any(); + } + if (data_deleter_ != nullptr) { + data_deleter_(data_); + } + } + /*! + * \brief Count the number of times a key exists in the SmallMapObj + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type& key) const { return find(key).index < size_; } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { + iterator itr = find(key); + if (itr.index >= size_) { + TVM_FFI_THROW(KeyError) << "key is not in Map"; + } + return itr->second; + } + /*! 
+ * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { + iterator itr = find(key); + if (itr.index >= size_) { + TVM_FFI_THROW(KeyError) << "key is not in Map"; + } + return itr->second; + } + /*! \return begin iterator */ + iterator begin() const { return iterator(0, this); } + /*! \return end iterator */ + iterator end() const { return iterator(size_, this); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + KVType* ptr = static_cast(data_); + for (uint64_t i = 0; i < size_; ++i, ++ptr) { + if (AnyEqual()(ptr->first, key)) { + return iterator(i, this); + } + } + return iterator(size_, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator& position) { Erase(position.index); } + + private: + /*! + * \brief Set the number of slots and attach tags bit. + * \param n The number of slots + */ + void SetSlotsAndSmallLayoutTag(uint64_t n) { slots_ = (n & ~kSmallTagMask) | kSmallTagMask; } + /*! + * \brief Remove a position in SmallMapObj + * \param index The position to be removed + */ + void Erase(const uint64_t index) { + if (index >= size_) { + return; + } + KVType* begin = static_cast(data_); + // call destructor to destroy the item in `begin + index` + // Explicit call Any::~Any() to destroy the Any object + // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) + (begin + index)->first.Any::~Any(); + (begin + index)->second.Any::~Any(); + // IMPORTANT: We do direct raw memmove to bring later items to the current position + // to preserve the order of insertion. 
+ // This works because direct memory copy preserves the Any's move semantics. + if (index + 1 < size_) { + std::memmove(reinterpret_cast(begin + index), + reinterpret_cast(begin + index + 1), + (size_ - index - 1) * sizeof(KVType)); + } + size_ -= 1; + } + /*! + * \brief Create an empty container + * \param n Number of empty slots + * \return The object created + */ + static ObjectPtr Empty(uint64_t n = kInitSize) { + using ::tvm::ffi::make_inplace_array_object; + ObjectPtr p = make_inplace_array_object(n); + p->data_ = p->AddressOf(0); + p->size_ = 0; + p->SetSlotsAndSmallLayoutTag(n); + return p; + } + /*! + * \brief Create an empty container initialized with a given range + * \param n Number of empty slots + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + * \return The object created + */ + template + static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { + ObjectPtr p = Empty(n); + KVType* ptr = static_cast(p->data_); + for (; first != last; ++first, ++p->size_) { + new (ptr++) KVType(*first); + } + return p; + } + /*! + * \brief Create an empty container with elements copying from another SmallMapObj + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(SmallMapObj* from) { + KVType* first = static_cast(from->data_); + KVType* last = first + from->size_; + return CreateFromRange(from->size_, first, last); + } + /*! 
+ * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { + SmallMapObj* map_node = static_cast(map->get()); + iterator itr = map_node->find(kv.first); + if (itr.index < map_node->size_) { + itr->second = kv.second; + return; + } + if (map_node->size_ < map_node->NumSlots()) { + KVType* ptr = static_cast(map_node->data_) + map_node->size_; + new (ptr) KVType(std::move(kv)); + ++map_node->size_; + return; + } + uint64_t next_size = std::max(map_node->NumSlots() * 2, kInitSize); + next_size = std::min(next_size, kMaxSize); + TVM_FFI_ICHECK_GT(next_size, map_node->NumSlots()); + ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); + InsertMaybeReHash(std::move(kv), &new_map); + *map = std::move(new_map); + } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return static_cast(data_) + index; } + /*! \brief A size function used by InplaceArrayBase */ + uint64_t GetSize() const { return size_; } + + protected: + friend class MapObj; + friend class DenseMapObj; + friend class details::InplaceArrayBase; +}; + +/*! \brief A specialization of hash map that implements the idea of array-based hash map. + * Another reference implementation can be found [1]. + * + * A. 
Overview + * + * DenseMapObj makes several improvements over a traditional separate-chaining hash table, + * in terms of cache locality, memory footprint and data organization. + * + * A1. Implicit linked list. For better cache locality, instead of using an explicit linked list + * for each bucket, we store list data in a single array that spans contiguously + * in memory, and then carefully design access patterns to make sure most of them fall into + * a single cache line. + * + * A2. 1-byte metadata. There is only 1 byte of overhead for each slot in the array, used for indexing and + * traversal. It can be divided into 3 parts. + * 1) Reserved codes: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, + * which means the slot is empty but not allowed to be written. + * 2) If not empty or protected, the highest bit indicates whether the data in the slot is + * the head of a linked list (bit clear) or a later element of one (bit set). + * 3) The remaining 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit + * architectures, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when + * dealing with 32-byte key-value pairs. Based on the commonly observed fact that the lists are + * relatively short (length <= 3) in hash maps, we follow [1]'s idea of only allowing the pointer to + * be one of 126 possible values, i.e. if the next element of the i-th slot is the (i + x)-th element, + * then x must be one of the 126 pre-defined values. + * + * A3. Data blocking. We organize the array so that every 16 elements form a data block. + * The 16 bytes of metadata for those 16 elements are stored together, followed by the real data, i.e. + * 16 key-value pairs. + * + * B. Implementation details + * + * B1. Power-of-2 table size and Fibonacci Hashing. We use a power-of-two table size to avoid + * modulo for more efficient arithmetic. To make the hash-to-slot mapping distribute more evenly, + * we use the Fibonacci Hashing [2] trick. + * + * B2.
Traverse a linked list in the array. + * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i + * indicates that it is list head, then we found the head; otherwise the list is empty. No probing + * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we + * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of + * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). + * + * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this + * element is in the linked list, and if not, we put it at the end by probing the next empty + * position in one of the 126 candidate positions. If the linked list does not even exist, but the + * slot for list head has been occupied by another linked list, we should find this intruder another + * place. + * + * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing + * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the + * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list + * head. + * + * [1] https://github.com/skarupke/flat_hash_map + * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ + * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + */ +class DenseMapObj : public MapObj { + private: + /*! \brief The number of elements in a memory block */ + static constexpr int kBlockCap = 16; + /*! \brief Maximum load factor of the hash map */ + static constexpr double kMaxLoadFactor = 0.99; + /*! \brief Binary representation of the metadata of an empty slot */ + static constexpr uint8_t kEmptySlot = static_cast(0b11111111); + /*! \brief Binary representation of the metadata of a protected slot */ + static constexpr uint8_t kProtectedSlot = static_cast(0b11111110); + /*! 
\brief Number of probing choices available */ + static constexpr int kNumJumpDists = 126; + /*! \brief Index indicator to indicate an invalid index */ + static constexpr uint64_t kInvalidIndex = std::numeric_limits::max(); + /*! \brief Head of the implicit linked list */ + struct ListNode; + /*! \brief item type of the dense map, including a kv data and prev/next pointer */ + struct ItemType { + KVType data; + uint64_t prev = kInvalidIndex; + uint64_t next = kInvalidIndex; + + explicit ItemType(KVType&& data) : data(std::move(data)) {} + explicit ItemType(key_type key, mapped_type value) : data(std::move(key), std::move(value)) {} + }; + /*! \brief POD type of a block of memory */ + struct Block { + uint8_t bytes[kBlockCap + kBlockCap * sizeof(ItemType)]; + }; + static_assert(sizeof(Block) == kBlockCap * (sizeof(ItemType) + 1), "sizeof(Block) incorrect"); + static_assert(std::is_standard_layout_v, "Block is not standard layout"); + + /*! + * \brief Deleter for the Block + * \param data The pointer to the Block + */ + static void BlockDeleter(void* data) { delete[] static_cast(data); } + + public: + using MapObj::iterator; + + /*! + * \brief Return the number of usable slots for Dense layout (MSB clear => identity). + * \return The number of usable slots + */ + uint64_t NumSlots() const { return slots_; } + + /*! + * \brief Destroy the DenseMapObj + */ + ~DenseMapObj() { this->Reset(); } + /*! \return The number of elements of the key */ + size_t count(const key_type& key) const { return !Search(key).IsNone(); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type& at(const key_type& key) const { return At(key); } + /*! 
+ * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type& at(const key_type& key) { return At(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type& key) const { + ListNode node = Search(key); + return node.IsNone() ? end() : iterator(node.index, this); + } + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + * \note Does nothing when the iterator is not bound to a map (self == nullptr) or when its + * index is not a valid slot index; the latter covers the end iterator (index == kInvalidIndex). + */ + void erase(const iterator& position) { + uint64_t index = position.index; + // valid slot indices are [0, NumSlots()); `<=` would admit index == NumSlots(), + // which is one past the last slot + if (position.self != nullptr && index < this->NumSlots()) { + Erase(ListNode(index, this)); + } + } + /*! \return begin iterator */ + iterator begin() const { return iterator(iter_list_head_, this); } + /*! \return end iterator */ + iterator end() const { return iterator(kInvalidIndex, this); } + + private: + /*! \brief Return the pointer to the `index`-th Block counted from data_ */ + Block* GetBlock(size_t index) const { return static_cast(data_) + index; } + /*! + * \brief Unlink the entry from iterator list + * \param node The node to be unlinked + * \note This function is usually used before deletion, + * and it does not change data content of the node. + */ + void IterListUnlink(ListNode node) { + // update head and tail of iterator list if needed + if (node.Item().prev == kInvalidIndex) { + iter_list_head_ = node.Item().next; + } else { + ListNode prev_node(node.Item().prev, this); + prev_node.Item().next = node.Item().next; + } + if (node.Item().next == kInvalidIndex) { + iter_list_tail_ = node.Item().prev; + } else { + ListNode next_node(node.Item().next, this); + next_node.Item().prev = node.Item().prev; + } + } + /*! + * \brief Insert the entry into tail of iterator list + * \param node The node to be inserted + * \note this function does not change data content of the node.
+ */ + void IterListPushBack(ListNode node) { + node.Item().prev = iter_list_tail_; + node.Item().next = kInvalidIndex; + if (iter_list_tail_ != kInvalidIndex) { + ListNode prev_node(iter_list_tail_, this); + prev_node.Item().next = node.index; + } + if (iter_list_head_ == kInvalidIndex) { + iter_list_head_ = node.index; + } + iter_list_tail_ = node.index; + } + /*! + * \brief Replace node src by dst in the iter list + * \param src The source node + * \param dst The destination node, must be empty + * \note This function does not change data content of the nodes, + * which needs to be updated by the caller. + */ + void IterListReplaceNodeBy(ListNode src, ListNode dst) { + // set link correctly on the dst + dst.Item().prev = src.Item().prev; + dst.Item().next = src.Item().next; + // update prev and next of dst + if (dst.Item().prev == kInvalidIndex) { + iter_list_head_ = dst.index; + } else { + ListNode prev_node(dst.Item().prev, this); + prev_node.Item().next = dst.index; + } + if (dst.Item().next == kInvalidIndex) { + iter_list_tail_ = dst.index; + } else { + ListNode next_node(dst.Item().next, this); + next_node.Item().prev = dst.index; + } + } + /*! + * \brief Search for the given key + * \param key The key + * \return ListNode that associated with the key + */ + ListNode Search(const key_type& key) const { + if (this->size_ == 0) { + return ListNode(); + } + for (ListNode iter = GetListHead(AnyHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { + if (AnyEqual()(key, iter.Key())) { + return iter; + } + } + return ListNode(); + } + /*! + * \brief Search for the given key, throw exception if not exists + * \param key The key + * \return ListNode that associated with the key + */ + mapped_type& At(const key_type& key) const { + ListNode iter = Search(key); + if (iter.IsNone()) { + TVM_FFI_THROW(KeyError) << "key is not in Map"; + } + return iter.Val(); + } + /*! 
+ * \brief Try to locate or create the entry for a key without rehashing + * \param key The indexing key + * \param result Output: the linked-list entry found or just constructed + * \return true if `result` was set: either a new entry was constructed with a placeholder + * value Any(nullptr) and size_ was incremented, or an entry with an equal key already + * existed, in which case it is unlinked from the iterator list and size_ is unchanged. + * false if there are no slots, the table is full, or TrySpareListHead / GetNextEmpty + * fails; the key is not inserted in that case. + */ + bool TryInsert(const key_type& key, ListNode* result) { + if (slots_ == 0) { + return false; + } + // `iter` is required to be the head of a linked list through which we can iterate + ListNode iter = IndexFromHash(AnyHash()(key)); + // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list + // Case 1: empty + if (iter.IsEmpty()) { + iter.NewHead(ItemType(key, Any(nullptr))); + this->size_ += 1; + *result = iter; + return true; + } + // Case 2: body of an irrelevant list + if (!iter.IsHead()) { + // we move the elements around and construct the single-element linked list + return IsFull() ? false : TrySpareListHead(iter, key, result); + } + // Case 3: head of the relevant list + // we iterate through the linked list until the end + // make sure `iter` is the previous element of `next` + ListNode next = iter; + do { + // found an equal key: no new entry is constructed + if (AnyEqual()(key, next.Key())) { + // the entry is handed back through `result`, so unlink it from the iterator list first + IterListUnlink(next); + *result = next; + return true; + } + // make sure `iter` is the previous element of `next` + iter = next; + } while (next.MoveToNext(this)); + // `iter` is the tail of the linked list + // always check capacity before insertion + if (IsFull()) { + return false; + } + // find the next empty slot + uint8_t jump; + if (!iter.GetNextEmpty(this, &jump, result)) { + return false; + } + result->NewTail(ItemType(key, Any(nullptr))); + // link `iter` to the new tail in `result` + iter.SetJump(jump); + this->size_ += 1; + return true; + } + /*! + * \brief Spare an entry to be the head of a linked list.
+ * As described in B3, during insertion, it is possible that the entire linked list does not + * exist, but the slot of its head has been occupied by other linked lists. In this case, we need + * to spare the slot by moving away the elements to another valid empty one to make insertion + * possible. + * \param target The given entry to be spared + * \param key The indexing key + * \param result The linked-list entry constructed as the head + * \return A boolean, if actual insertion happens + */ + bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { + // `target` is not the head of the linked list + // move the original item of `target` (if any) + // and construct new item on the position `target` + // To make `target` empty, we + // 1) find `w` the previous element of `target` in the linked list + // 2) copy the linked list starting from `r = target` + // 3) paste them after `w` + // read from the linked list after `r` + ListNode r = target; + // write to the tail of `w` + ListNode w = target.FindPrev(this); + // after `target` is moved, we disallow writing to the slot + bool is_first = true; + uint8_t r_meta, jump; + ListNode empty; + do { + // `jump` describes how `w` is jumped to `empty` + // rehash if there is no empty space after `w` + if (!w.GetNextEmpty(this, &jump, &empty)) { + return false; + } + // move `r` to `empty` + // first move the data over + empty.NewTail(ItemType(std::move(r.Data()))); + // then move link list chain of r to empty + // this needs to happen after NewTail so empty's prev/next get updated + IterListReplaceNodeBy(r, empty); + // explicit call destructor to destroy the item in `r` + r.DestructData(); + // clear the metadata of `r` + r_meta = r.Meta(); + if (is_first) { + is_first = false; + r.SetProtected(); + } else { + r.SetEmpty(); + } + // link `w` to `empty`, and move forward + w.SetJump(jump); + w = empty; + // move `r` forward as well + } while (r.MoveToNext(this, r_meta)); + // finally we have done 
moving the linked list + // fill data_ into `target` + target.NewHead(ItemType(key, Any(nullptr))); + this->size_ += 1; + *result = target; + return true; + } + /*! + * \brief Remove a ListNode + * \param iter The node to be removed + */ + void Erase(const ListNode& iter) { + this->size_ -= 1; + if (!iter.HasNext()) { + // `iter` is the last + if (!iter.IsHead()) { + // cut the link if there is any + iter.FindPrev(this).SetJump(0); + } + // unlink the node from iterator list + IterListUnlink(iter); + // IMPORTANT: must explicit call destructor `iter` to avoid memory leak + // This is because we need to recycle iter's data + iter.DestructData(); + // set the meta data to be empty + iter.SetEmpty(); + } else { + ListNode last = iter, prev = iter; + for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { + } + // needs to first unlink iter from the list + IterListUnlink(iter); + // move data from last to iter + iter.Data() = std::move(last.Data()); + // Move link chain of iter to last as we stores last node to the new iter loc. + IterListReplaceNodeBy(last, iter); + // IMPORTANT: must explicit call destructor `last` to avoid memory leak + // likely we don't need this in this particular case because Any move behavior + // keep it here to be safe so code do not depend on specific move behavior of KVType + last.DestructData(); + // set the meta data to be empty + last.SetEmpty(); + prev.SetJump(0); + } + } + /*! 
\brief Clear the container to empty, release all entries and memory acquired */ + void Reset() { + uint64_t n_blocks = CalcNumBlocks(this->NumSlots()); + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr = GetBlock(bi)->bytes; + ItemType* data_ptr = reinterpret_cast(GetBlock(bi)->bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t& meta = *meta_ptr; + // only slots that are neither protected nor empty hold a constructed item + if (meta != kProtectedSlot && meta != kEmptySlot) { + meta = kEmptySlot; + data_ptr->ItemType::~ItemType(); + } + } + } + ReleaseMemory(); + } + /*! \brief Release the memory acquired by the container without destroying the entries stored inside, + * then reset data_, data_deleter_, slots_, size_ and fib_shift_ + */ + void ReleaseMemory() { + if (data_ != nullptr) { + TVM_FFI_ICHECK(data_deleter_ != nullptr); + data_deleter_(data_); + } + data_ = nullptr; + data_deleter_ = nullptr; + slots_ = 0; + size_ = 0; + fib_shift_ = 63; + } + /*! + * \brief Create an empty container + * \param fib_shift The fib shift provided + * \param n_slots Number of slots required, should be power-of-two + * \return The object created + * \note Every metadata byte is initialised to kEmptySlot; no item is constructed here. + */ + static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { + TVM_FFI_ICHECK_GT(n_slots, uint64_t(SmallMapObj::kMaxSize)); + // NOTE(review): callers are expected to pass a power-of-two n_slots, but no check for that + // is visible in this function; only the lower bound above is enforced here. + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(n_slots); + Block* block = new Block[n_blocks]; + p->data_ = block; + // assign the block deleter so that even if the data is re-allocated + // in another shared-lib that may have different malloc/free behavior + // it will still be safe. + p->data_deleter_ = BlockDeleter; + p->SetSlotsAndDenseLayoutTag(n_slots); + p->size_ = 0; + p->fib_shift_ = fib_shift; + p->iter_list_head_ = kInvalidIndex; + p->iter_list_tail_ = kInvalidIndex; + for (uint64_t i = 0; i < n_blocks; ++i, ++block) { + std::fill(block->bytes, block->bytes + kBlockCap, kEmptySlot); + } + return p; + } + /*!
+ * \brief Create an empty container with elements copying from another DenseMapObj + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(DenseMapObj* from) { + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(from->NumSlots()); + p->data_ = new Block[n_blocks]; + // assign block deleter so even if we take re-alloc data + // in another shared-lib that may have different malloc/free behavior + // it will still be safe. + p->data_deleter_ = BlockDeleter; + p->SetSlotsAndDenseLayoutTag(from->NumSlots()); + p->size_ = from->size_; + p->fib_shift_ = from->fib_shift_; + p->iter_list_head_ = from->iter_list_head_; + p->iter_list_tail_ = from->iter_list_tail_; + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t* meta_ptr_from = from->GetBlock(bi)->bytes; + ItemType* data_ptr_from = reinterpret_cast(from->GetBlock(bi)->bytes + kBlockCap); + uint8_t* meta_ptr_to = p->GetBlock(bi)->bytes; + ItemType* data_ptr_to = reinterpret_cast(p->GetBlock(bi)->bytes + kBlockCap); + for (int j = 0; j < kBlockCap; + ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { + uint8_t& meta = *meta_ptr_to = *meta_ptr_from; + TVM_FFI_ICHECK(meta != kProtectedSlot); + if (meta != kEmptySlot) { + new (data_ptr_to) ItemType(*data_ptr_from); + } + } + } + return p; + } + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { + DenseMapObj* map_node = static_cast(map->get()); + ListNode iter; + // Try to insert. 
On success we simply return + if (map_node->TryInsert(kv.first, &iter)) { + iter.Val() = std::move(kv.second); + // append the entry to the tail of the iterator list + map_node->IterListPushBack(iter); + return; + } + TVM_FFI_ICHECK(!map_node->IsSmallMap()); + // Otherwise, start rehash: new table with twice the slots and fib_shift_ reduced by one + ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->NumSlots() * 2); + + // need to insert in the same order as the original map + for (uint64_t index = map_node->iter_list_head_; index != kInvalidIndex;) { + ListNode node(index, map_node); + // try to move the source entry into the new map; note that it may not + // be fully consumed by the call, so its destructor is still called below. + InsertMaybeReHash(std::move(node.Data()), &p); + // advance to the next entry in the iterator list + index = node.Item().next; + // IMPORTANT: node.DestructData() must be called explicitly here to avoid a memory leak. + // The std::move() argument passed to InsertMaybeReHash above may or may not + // actually have been moved out of node.Data(), + // so removing this call will very likely cause a memory leak. + node.DestructData(); + } + InsertMaybeReHash(std::move(kv), &p); + map_node->ReleaseMemory(); + *map = p; + } + /*! + * \brief Check whether the hash table is full + * \return A boolean indicating whether adding one more entry would exceed kMaxLoadFactor + */ + bool IsFull() const { // NOLINTNEXTLINE(bugprone-narrowing-conversions) + return (size_ + 1) > static_cast(NumSlots()) * kMaxLoadFactor; + } + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { + // the end iterator stays at the end + if (index == kInvalidIndex) { + return index; + } + ListNode node(index, this); + return node.Item().next; + } + /*!
+ * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { + // this is the end iterator, we need to return tail. + if (index == kInvalidIndex) { + return iter_list_tail_; + } + // circle around the iterator list, which is OK + ListNode node(index, this); + return node.Item().prev; + } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } + /*! \brief Construct from hash code */ + ListNode IndexFromHash(uint64_t hash_value) const { + return ListNode(FibHash(hash_value, fib_shift_), this); + } + /*! \brief Construct from hash code if the position is head of list */ + ListNode GetListHead(uint64_t hash_value) const { + ListNode node = IndexFromHash(hash_value); + return node.IsHead() ? node : ListNode(); + } + /*! \brief Construct the number of blocks in the hash table */ + static uint64_t CalcNumBlocks(uint64_t n_slots) { return (n_slots + kBlockCap - 1) / kBlockCap; } + /*! + * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. + * \param cap The lower-bound of the required capacity + * \param fib_shift The result shift for Fibonacci Hashing + * \param n_slots The result number of slots + */ + static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { + uint32_t shift = 64; + uint64_t slots = 1; + for (uint64_t c = cap; c; c >>= 1) { + shift -= 1; + slots <<= 1; + } + TVM_FFI_ICHECK_GT(slots, cap); + if (slots < cap * 2) { + *fib_shift = shift - 1; + *n_slots = slots << 1; + } else { + *fib_shift = shift; + *n_slots = slots; + } + } + /*! + * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. + * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. 
+ * \param hash_value The raw hash value + * \param fib_shift The shift in Fibonacci Hashing + * \return An index calculated using Fibonacci Hashing + */ + static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { + constexpr uint64_t coeff = 11400714819323198485ull; + return (coeff * hash_value) >> fib_shift; + } + /*! \brief The implicit in-place linked list used to index a chain */ + struct ListNode { + /*! \brief Construct None */ + ListNode() : index(0), block(nullptr) {} + /*! \brief Construct from position */ + ListNode(uint64_t index, const DenseMapObj* self) + : index(index), block(self->GetBlock(index / kBlockCap)) {} + /*! \brief Metadata on the entry */ + uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } + /*! \brief Data on the entry */ + ItemType& Item() const { + return *(reinterpret_cast(block->bytes + kBlockCap + + (index % kBlockCap) * sizeof(ItemType))); + } + /*! \brief Data on the entry */ + KVType& Data() const { return Item().data; } + /*! \brief Key on the entry */ + key_type& Key() const { return Data().first; } + /*! \brief Value on the entry */ + mapped_type& Val() const { return Data().second; } + /*! \brief If the entry is head of linked list */ + bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } + /*! \brief If the entry is none */ + bool IsNone() const { return block == nullptr; } + /*! \brief If the entry is empty slot */ + bool IsEmpty() const { return Meta() == kEmptySlot; } + /*! \brief If the entry is protected slot */ + bool IsProtected() const { return Meta() == kProtectedSlot; } + /*! \brief Set the entry to be empty */ + void SetEmpty() const { Meta() = kEmptySlot; } + /*! \brief Destruct the item in the entry */ + void DestructData() const { + // explicit call destructor to destroy the item + // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) + (&Data())->first.Any::~Any(); + (&Data())->second.Any::~Any(); + } + /*! 
\brief Set the entry to be protected */ + void SetProtected() const { Meta() = kProtectedSlot; } + /*! \brief Set the entry's jump to its next entry */ + void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } + /*! \brief Construct a head of linked list in-place */ + void NewHead(ItemType v) const { + Meta() = 0b00000000; + new (&Item()) ItemType(std::move(v)); + } + /*! \brief Construct a tail of linked list in-place */ + void NewTail(ItemType v) const { + Meta() = 0b10000000; + new (&Item()) ItemType(std::move(v)); + } + + /*! \brief If the entry has next entry on the linked list */ + bool HasNext() const { return NextProbeLocation(Meta() & 0b01111111) != 0; } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapObj* self, uint8_t meta) { + uint64_t offset = NextProbeLocation(meta & 0b01111111); + if (offset == 0) { + index = 0; + block = nullptr; + return false; + } + // the probing will go to next position and round back to stay within the + // correct range of the slots + index = (index + offset) % self->NumSlots(); + block = self->GetBlock(index / kBlockCap); + return true; + } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapObj* self) { return MoveToNext(self, Meta()); } + /*! \brief Get the previous entry on the linked list */ + ListNode FindPrev(const DenseMapObj* self) const { + // start from the head of the linked list, which must exist + ListNode next = self->IndexFromHash(AnyHash()(Key())); + // `prev` is always the previous item of `next` + ListNode prev = next; + for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { + } + return prev; + } + /*! 
\brief Get the next empty jump */ + bool GetNextEmpty(const DenseMapObj* self, uint8_t* jump, ListNode* result) const { + for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { + // the probing will go to next position and round back to stay within the + // correct range of the slots + ListNode candidate((index + NextProbeLocation(idx)) % self->NumSlots(), self); + if (candidate.IsEmpty()) { + *jump = idx; + *result = candidate; + return true; + } + } + return false; + } + /*! \brief Index on the real array */ + uint64_t index; + /*! \brief Pointer to the actual block */ + Block* block; + }; + + protected: + /*! \brief fib shift in Fibonacci Hashing */ + uint32_t fib_shift_; + /*! \brief the head of iterator list */ + uint64_t iter_list_head_ = kInvalidIndex; + /*! \brief the tail of iterator list */ + uint64_t iter_list_tail_ = kInvalidIndex; + + static uint64_t NextProbeLocation(size_t index) { + /* clang-format off */ + /*! \brief Candidates of probing distance */ + static const uint64_t kNextProbeLocation[kNumJumpDists] { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + // Quadratic probing with triangle numbers. 
See also: + // 1) https://en.wikipedia.org/wiki/Quadratic_probing + // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ + // 3) https://github.com/skarupke/flat_hash_map + 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, + 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, + 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, + 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, + 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, + 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, + 2211, 2278, 2346, 2415, 2485, 2556, 2628, + // larger triangle numbers + 8515, 19110, 42778, 96141, 216153, + 486591, 1092981, 2458653, 5532801, 12442566, + 27993903, 62983476, 141717030, 318844378, 717352503, + 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, + 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, + 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, + 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, + 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, + 457381325854679626, 1029107982097042876, 2315492959180353330, 5209859154120846435, + }; + /* clang-format on */ + return kNextProbeLocation[index]; + } + friend class MapObj; + + private: + /*! + * \brief Set the number of slots and attach tags bit. 
+ * \param n The number of slots + */ + void SetSlotsAndDenseLayoutTag(uint64_t n) { + TVM_FFI_ICHECK(((n & kSmallTagMask) == 0ull)) << "DenseMap expects MSB clear"; + slots_ = n; + } +}; + +/// \cond +#define TVM_FFI_DISPATCH_MAP(base, var, body) \ + { \ + using TSmall = SmallMapObj*; \ + using TDense = DenseMapObj*; \ + if ((base)->IsSmallMap()) { \ + TSmall var = static_cast((base)); \ + body; \ + } else { \ + TDense var = static_cast((base)); \ + body; \ + } \ + } + +#define TVM_FFI_DISPATCH_MAP_CONST(base, var, body) \ + { \ + using TSmall = const SmallMapObj*; \ + using TDense = const DenseMapObj*; \ + if ((base)->IsSmallMap()) { \ + TSmall var = static_cast((base)); \ + body; \ + } else { \ + TDense var = static_cast((base)); \ + body; \ + } \ + } + +inline MapObj::iterator::pointer MapObj::iterator::operator->() const { + TVM_FFI_MAP_FAIL_IF_CHANGED() + TVM_FFI_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); +} + +inline MapObj::iterator& MapObj::iterator::operator++() { + TVM_FFI_MAP_FAIL_IF_CHANGED() + TVM_FFI_DISPATCH_MAP_CONST(self, p, { + index = p->IncItr(index); + return *this; + }); +} + +inline MapObj::iterator& MapObj::iterator::operator--() { + TVM_FFI_MAP_FAIL_IF_CHANGED() + TVM_FFI_DISPATCH_MAP_CONST(self, p, { + index = p->DecItr(index); + return *this; + }); +} + +inline size_t MapObj::count(const key_type& key) const { + TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); +} + +inline const MapObj::mapped_type& MapObj::at(const MapObj::key_type& key) const { + TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); +} + +inline MapObj::mapped_type& MapObj::at(const MapObj::key_type& key) { + TVM_FFI_DISPATCH_MAP(this, p, { return p->at(key); }); +} + +inline MapObj::iterator MapObj::begin() const { + TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); +} + +inline MapObj::iterator MapObj::end() const { + TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->end(); }); +} + +inline MapObj::iterator 
MapObj::find(const MapObj::key_type& key) const { + TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); +} + +inline void MapObj::erase(const MapObj::iterator& position) { + TVM_FFI_DISPATCH_MAP(this, p, { return p->erase(position); }); +} +/// \endcond + +#undef TVM_FFI_DISPATCH_MAP +#undef TVM_FFI_DISPATCH_MAP_CONST + +inline ObjectPtr MapObj::Empty() { return SmallMapObj::Empty(); } + +inline ObjectPtr MapObj::CopyFrom(MapObj* from) { + if (from->IsSmallMap()) { + return SmallMapObj::CopyFrom(static_cast(from)); + } else { + return DenseMapObj::CopyFrom(static_cast(from)); + } +} + +template +inline ObjectPtr MapObj::CreateFromRange(IterType first, IterType last) { + int64_t _cap = std::distance(first, last); + if (_cap < 0) { + return SmallMapObj::Empty(); + } + uint64_t cap = static_cast(_cap); + if (cap < SmallMapObj::kMaxSize) { + if (cap < 2) { + return SmallMapObj::CreateFromRange(cap, first, last); + } + // need to insert to avoid duplicate keys + ObjectPtr obj = SmallMapObj::Empty(cap); + for (; first != last; ++first) { + KVType kv(*first); + SmallMapObj::InsertMaybeReHash(std::move(kv), &obj); + } + return obj; + } else { + uint32_t fib_shift; + uint64_t n_slots; + DenseMapObj::CalcTableSize(cap, &fib_shift, &n_slots); + ObjectPtr obj = DenseMapObj::Empty(fib_shift, n_slots); + for (; first != last; ++first) { + KVType kv(*first); + DenseMapObj::InsertMaybeReHash(std::move(kv), &obj); + } + return obj; + } +} + +inline void MapObj::InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { + MapObj* base = static_cast(map->get()); +#if TVM_FFI_DEBUG_WITH_ABI_CHANGE + base->state_marker++; +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + if (base->IsSmallMap()) { + SmallMapObj* sm = static_cast(base); + if (sm->NumSlots() < SmallMapObj::kMaxSize) { + SmallMapObj::InsertMaybeReHash(std::move(kv), map); + } else if (sm->NumSlots() == SmallMapObj::kMaxSize) { + if (base->size_ < sm->NumSlots()) { + SmallMapObj::InsertMaybeReHash(std::move(kv), map); + } else { 
+ ObjectPtr new_map = MapObj::CreateFromRange(base->begin(), base->end()); + DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map); + *map = std::move(new_map); + } + } + } else { + DenseMapObj::InsertMaybeReHash(std::move(kv), map); + } +} + +/// \cond Doxygen_Suppress +/*! + * \brief Specialize make_object to be deleted for make_object and + * make_object only. + */ +template <> +inline ObjectPtr make_object<>() = delete; +/// \endcond + +/*! + * \brief Map container of NodeRef->NodeRef in DSL graph. + * Map implements copy on write semantics, which means map is mutable + * but a copy will happen on write when the map is referenced in more than one place. + * + * operator[] only provides const access, use Set to mutate the content. + * \tparam K The key NodeRef type. + * \tparam V The value NodeRef type. + */ +template && + details::storage_enabled_v>> +class Map : public ObjectRef { + public: + /*! \brief The key type of the map */ + using key_type = K; + /*! \brief The mapped type of the map */ + using mapped_type = V; + /*! \brief The iterator type of the map */ + class iterator; + /*! + * \brief Construct a Map with UnsafeInit + */ + explicit Map(UnsafeInit tag) : ObjectRef(tag) {} + /*! + * \brief default constructor + */ + Map() { data_ = MapObj::Empty(); } + /*! + * \brief move constructor + * \param other source + */ + Map(Map&& other) // NOLINT(google-explicit-constructor) + : ObjectRef(std::move(other.data_)) {} + /*! + * \brief copy constructor + * \param other source + */ + Map(const Map& other) // NOLINT(google-explicit-constructor) + : ObjectRef(other.data_) {} + + /*! + * \brief Move constructor + * \param other The other map + * \tparam KU The key type of the other map + * \tparam VU The mapped type of the other map + */ + template && + details::type_contains_v>> + Map(Map&& other) // NOLINT(google-explicit-constructor) + : ObjectRef(std::move(other.data_)) {} + + /*! 
+ * \brief Copy constructor + * \param other The other map + * \tparam KU The key type of the other map + * \tparam VU The mapped type of the other map + */ + template && + details::type_contains_v>> + Map(const Map& other) : ObjectRef(other.data_) {} // NOLINT(google-explicit-constructor) + + /*! + * \brief Move assignment + * \param other The other map + */ + Map& operator=(Map&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief Copy assignment + * \param other The other map + */ + Map& operator=(const Map& other) { + data_ = other.data_; + return *this; + } + + /*! + * \brief Move assignment + * \param other The other map + * \tparam KU The key type of the other map + * \tparam VU The mapped type of the other map + */ + template && + details::type_contains_v>> + Map& operator=(Map&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief Copy assignment + * \param other The other map + * \tparam KU The key type of the other map + * \tparam VU The mapped type of the other map + */ + template && + details::type_contains_v>> + Map& operator=(const Map& other) { + data_ = other.data_; + return *this; + } + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Map(ObjectPtr n) : ObjectRef(n) {} + /*! + * \brief constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + Map(IterType begin, IterType end) { + data_ = MapObj::CreateFromRange(begin, end); + } + /*! + * \brief constructor from initializer list + * \param init The initalizer list + */ + Map(std::initializer_list> init) { + data_ = MapObj::CreateFromRange(init.begin(), init.end()); + } + /*! + * \brief constructor from unordered_map + * \param init The unordered_map + */ + template + Map(const std::unordered_map& init) { // NOLINT(*) + data_ = MapObj::CreateFromRange(init.begin(), init.end()); + } + /*! 
+ * \brief Read element from map. + * \param key The key + * \return the corresponding element. + */ + const V at(const K& key) const { + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(GetMapObj()->at(key)); + } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresponding element. + */ + const V operator[](const K& key) const { return this->at(key); } + /*! \return The number of entries in the map (0 if the container pointer is null) */ + size_t size() const { + MapObj* n = GetMapObj(); + return n == nullptr ? 0 : n->size(); + } + /*! \return The number of entries with the given key (0 if the container pointer is null) */ + size_t count(const K& key) const { + MapObj* n = GetMapObj(); + return n == nullptr ? 0 : GetMapObj()->count(key); + } + /*! \return whether the map is empty */ + bool empty() const { return size() == 0; } + /*! \brief Release reference to all the elements */ + void clear() { + MapObj* n = GetMapObj(); + if (n != nullptr) { + data_ = MapObj::Empty(); + } + } + /*! + * \brief Set the value for a key in the Map. + * \param key The index key. + * \param value The value to be set. + */ + void Set(const K& key, const V& value) { + CopyOnWrite(); + MapObj::InsertMaybeReHash(MapObj::KVType(key, value), &data_); + } + /*! \return begin iterator */ + iterator begin() const { return iterator(GetMapObj()->begin()); } + /*! \return end iterator */ + iterator end() const { return iterator(GetMapObj()->end()); } + /*! \return The iterator associated with the key, as returned by MapObj::find */ + iterator find(const K& key) const { return iterator(GetMapObj()->find(key)); } + /*! \return The value associated with the key, std::nullopt if not found */ + std::optional Get(const K& key) const { + MapObj::iterator iter = GetMapObj()->find(key); + if (iter == GetMapObj()->end()) { + return std::nullopt; + } + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(iter->second); + } + + /*! + * \brief Erase the entry associated with the key + * \param key The key + */ + void erase(const K& key) { CopyOnWrite()->erase(key); } + + /*! 
+ * \brief copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which guarantees to be unique) + */ + MapObj* CopyOnWrite() { + if (data_.get() == nullptr) { + data_ = MapObj::Empty(); + } else if (!data_.unique()) { + data_ = MapObj::CopyFrom(GetMapObj()); + } + return GetMapObj(); + } + /*! \brief specify container node */ + using ContainerType = MapObj; + + /// \cond Doxygen_Suppress + /*! \brief Iterator of the hash map */ + class iterator { + public: + using iterator_category = std::bidirectional_iterator_tag; + using difference_type = int64_t; + using value_type = const std::pair; + using pointer = value_type*; + using reference = value_type; + + iterator() : itr() {} + + /*! \brief Compare iterators */ + bool operator==(const iterator& other) const { return itr == other.itr; } + /*! \brief Compare iterators */ + bool operator!=(const iterator& other) const { return itr != other.itr; } + /*! \brief De-reference iterators is not allowed */ + pointer operator->() const = delete; + /*! \brief De-reference iterators */ + reference operator*() const { + auto& kv = *itr; + return std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.first), + details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.second)); + } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator& operator++() { + ++itr; + return *this; + } + /*! \brief Suffix self increment */ + iterator operator++(int) { + iterator copy = *this; + ++(*this); + return copy; + } + + /*! \brief Prefix self decrement, e.g. --iter */ + iterator& operator--() { + --itr; + return *this; + } + /*! 
\brief Suffix self decrement */ + iterator operator--(int) { + iterator copy = *this; + --(*this); + return copy; + } + + private: + iterator(const MapObj::iterator& itr) // NOLINT(*) + : itr(itr) {} + + template + friend class Map; + + MapObj::iterator itr; + }; + /// \endcond + + private: + /*! \brief Return data_ as type of pointer of MapObj */ + MapObj* GetMapObj() const { return static_cast(data_.get()); } + + template + friend class Map; +}; + +/*! + * \brief Merge two Maps. + * \param lhs the first Map to merge (taken by value; every entry of rhs is Set into it). + * \param rhs the second Map to merge. + * \return The merged Map. Original Maps are kept unchanged. + */ +template && + details::storage_enabled_v>> +inline Map Merge(Map lhs, const Map& rhs) { + for (const auto& p : rhs) { + lhs.Set(p.first, p.second); + } + return std::move(lhs); +} + +// Traits for Map +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public ObjectRefTypeTraitsBase> { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap; + using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIMap) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + if constexpr (!std::is_same_v || !std::is_same_v) { + const MapObj* n = reinterpret_cast(src->v_obj); + for (const auto& kv : *n) { + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.first) && + !kv.first.try_cast().has_value()) { + return "Map[some key is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.first) + + ", V]"; + } + } + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.second) && + !kv.second.try_cast().has_value()) { + return "Map[K, some value is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.second) + + "]"; + } + } + } + } + TVM_FFI_THROW(InternalError) << "Cannot reach here"; + TVM_FFI_UNREACHABLE(); + } + + 
TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIMap) return false; + if constexpr (std::is_same_v && std::is_same_v) { + return true; + } else { + const MapObj* n = reinterpret_cast(src->v_obj); + for (const auto& kv : *n) { + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) return false; + } + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) return false; + } + } + return true; + } + } + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIMap) return std::nullopt; + if constexpr (!std::is_same_v || !std::is_same_v) { + const MapObj* n = reinterpret_cast(src->v_obj); + bool storage_check = [&]() { + for (const auto& kv : *n) { + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) return false; + } + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) return false; + } + } + return true; + }(); + // fast path, if storage check passes, we can return the array directly. + if (storage_check) return CopyFromAnyViewAfterCheck(src); + // slow path, we need to create a new map and convert to the target type. 
+ Map ret; + for (const auto& kv : *n) { + auto k = kv.first.try_cast(); + auto v = kv.second.try_cast(); + if (!k.has_value() || !v.has_value()) return std::nullopt; + ret.Set(*std::move(k), *std::move(v)); + } + return ret; + } else { + return CopyFromAnyViewAfterCheck(src); + } + } + + TVM_FFI_INLINE static std::string TypeStr() { + return "Map<" + details::Type2Str::v() + ", " + details::Type2Str::v() + ">"; + } + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":")" << StaticTypeKey::kTVMFFIMap << R"(","args":[)"; + oss << details::TypeSchema::v() << ","; + oss << details::TypeSchema::v(); + oss << "]}"; + return oss.str(); + } +}; + +namespace details { +template +inline constexpr bool type_contains_v, Map> = + type_contains_v && type_contains_v; +} // namespace details + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_MAP_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/shape.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/shape.h new file mode 100644 index 000000000..074fefbc9 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/shape.h @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/shape.h + * \brief Container to store shape of an Tensor. + */ +#ifndef TVM_FFI_CONTAINER_SHAPE_H_ +#define TVM_FFI_CONTAINER_SHAPE_H_ + +#include "../error.h" +#include "../type_traits.h" +#include "array.h" + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Lightweight view non-owning class for shape. + */ +class ShapeView { +public: + /*! \brief Default constructor. */ + ShapeView() : cell_{nullptr, 0} {} + /*! \brief Copy constructor. */ + ShapeView(const ShapeView &other) = default; + /*! \brief Copy assignment operator. */ + ShapeView &operator=(const ShapeView &other) = default; + /*! \brief Move constructor. */ + ShapeView(ShapeView &&other) = default; + /*! \brief Move assignment operator. */ + ShapeView &operator=(ShapeView &&other) = default; + /*! \brief Constructor from data and size. */ + ShapeView(const int64_t *data, size_t size) : cell_{data, size} {} + /*! \brief Constructor from initializer list. */ + ShapeView(const std::initializer_list &other) : cell_{other.begin(), other.size()} {} + /*! \brief Get the data pointer. */ + const int64_t *data() const { return cell_.data; } + /*! \brief Get the size of the shape. */ + size_t size() const { return cell_.size; } + + /*! \brief Get the product of the shape. */ + int64_t Product() const { + int64_t product = 1; + for (size_t i = 0; i < cell_.size; ++i) { + product *= cell_.data[i]; + } + return product; + } + + /*! \brief Get the i-th element of the shape. */ + int64_t operator[](size_t idx) const { return cell_.data[idx]; } + + /*! \return begin iterator */ + const int64_t *begin() const { return cell_.data; } + + /*! \return end iterator */ + const int64_t *end() const { return cell_.data + cell_.size; } + + /*! \return Whether shape tuple is empty */ + bool empty() const { return size() == 0; } + + /*! 
\return The first element of the shape tuple */ + int64_t front() const { return this->at(0); } + + /*! \return The last element of the shape tuple */ + int64_t back() const { return this->at(this->size() - 1); } + + /*! \brief Get the i-th element of the shape. */ + int64_t at(size_t idx) const { + if (idx >= this->size()) { + TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size(); + } + return cell_.data[idx]; + } + +private: + TVMFFIShapeCell cell_; +}; + +/*! \brief An object representing a shape tuple. */ +class ShapeObj : public Object, public TVMFFIShapeCell { +public: + /*! \brief The type of shape index element. */ + using index_type = int64_t; + + /*! \brief Get "numel", meaning the number of elements of an array if the array has this shape */ + int64_t Product() const { + int64_t product = 1; + for (size_t i = 0; i < this->size; ++i) { + product *= this->data[i]; + } + return product; + } + + /// \cond Doxygen_Suppress + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIShape; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIShape, ShapeObj, Object); + /// \endcond +}; + +namespace details { + +class ShapeObjStdImpl : public ShapeObj { +public: + explicit ShapeObjStdImpl(std::vector other) : data_{std::move(other)} { + this->data = data_.data(); + this->size = static_cast(data_.size()); + } + +private: + std::vector data_; +}; + +TVM_FFI_INLINE ObjectPtr MakeEmptyShape(size_t length, int64_t **mutable_data) { + ObjectPtr p = make_inplace_array_object(length); + static_assert(alignof(ShapeObj) % alignof(int64_t) == 0); + static_assert(sizeof(ShapeObj) % alignof(int64_t) == 0); + int64_t *data = reinterpret_cast(reinterpret_cast(p.get()) + sizeof(ShapeObj)); + if (mutable_data) { + *mutable_data = data; + } + p->data = data; + p->size = length; + return p; +} + +// inplace shape allocation +template +TVM_FFI_INLINE ObjectPtr MakeInplaceShape(IterType begin, IterType end) { + size_t length = 
std::distance(begin, end); + int64_t *mutable_data; + ObjectPtr p = MakeEmptyShape(length, &mutable_data); + std::copy(begin, end, mutable_data); + return p; +} + +/*! + * \brief Fill out_strides so the last dimension has stride 1 and each earlier stride is the product of all later extents. + * \param shape The input shape. + * \param out_strides The output strides; must have room for shape.size() entries. + * \note Nothing is returned; the result is written to out_strides only. + */ +TVM_FFI_INLINE void FillStridesFromShape(ShapeView shape, int64_t *out_strides) { + int64_t stride = 1; + for (int64_t i = static_cast(shape.size()) - 1; i >= 0; --i) { + out_strides[i] = stride; + stride *= shape[i]; + } +} + +/*! + * \brief Make a strides shape from a shape view. + * \param shape The input shape. + * \return A newly allocated shape object holding the strides computed by FillStridesFromShape. + */ +TVM_FFI_INLINE ObjectPtr MakeStridesFromShape(ShapeView shape) { + int64_t *strides_data; + ObjectPtr strides = details::MakeEmptyShape(shape.size(), &strides_data); + FillStridesFromShape(shape, strides_data); + return strides; +} + +} // namespace details + +/*! + * \brief Managed reference to shape object. + * + * When possible, use ShapeView instead of Shape to reduce memory allocation. + * Use Shape when you need to have a managed reference to on-heap allocated shape. + * + * \sa ShapeView + */ +class Shape : public ObjectRef { +public: + /*! \brief The type of shape index element. */ + using index_type = ShapeObj::index_type; + + /*! \brief Default constructor */ + Shape() : ObjectRef(details::MakeEmptyShape(0, nullptr)) {} + + /*! + * \brief Constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + Shape(IterType begin, IterType end) : Shape(details::MakeInplaceShape(begin, end)) {} + + /** + * \brief Constructor from Array + * \param shape The Array + * + * \note This constructor will copy the data content. + */ + Shape(Array shape) // NOLINT(*) + : Shape(shape.begin(), shape.end()) {} + + /*! 
+ * \brief constructor from initializer list + * \param shape The initializer list + */ + Shape(std::initializer_list shape) : Shape(shape.begin(), shape.end()) {} + + /*! + * \brief constructor from std::vector + * + * \param other a vector of shape values; it is moved into the newly created shape object. + */ + Shape(std::vector other) // NOLINT(*) + : ObjectRef(make_object(std::move(other))) {} + + /*! + * \brief constructor from shape view. + * \param other The shape view. + */ + Shape(ShapeView other) : Shape(other.begin(), other.end()) {} // NOLINT(*) + + /*! + * \brief Create strides from a shape. + * \param shape The input shape. + * \return The strides. + */ + static Shape StridesFromShape(ShapeView shape) { + return Shape(details::MakeStridesFromShape(shape)); + } + + /*! + * \brief Convert to shape view. + * \return The shape view. + */ + operator ShapeView() const { return ShapeView(data(), size()); } // NOLINT(*) + + /*! + * \brief Return the data pointer + * + * \return const index_type* data pointer + */ + const int64_t *data() const { return get()->data; } + + /*! + * \brief Return the size of the shape tuple + * + * \return size_t shape tuple size + */ + size_t size() const { return get()->size; } + + /*! + * \brief Immutably read i-th element from the shape tuple (no bounds check). + * \param idx The index + * \return the i-th element. + */ + int64_t operator[](size_t idx) const { return this->data()[idx]; } + + /*! + * \brief Immutably read i-th element from the shape tuple; throws IndexError when idx >= size(). + * \param idx The index + * \return the i-th element. + */ + int64_t at(size_t idx) const { + if (idx >= this->size()) { + TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size(); + } + return this->operator[](idx); + } + + /*! \return Whether shape tuple is empty */ + bool empty() const { return size() == 0; } + + /*! \return The first element of the shape tuple */ + int64_t front() const { return this->at(0); } + + /*! 
\return The last element of the shape tuple */ + int64_t back() const { return this->at(this->size() - 1); } + + /*! \return begin iterator */ + const int64_t *begin() const { return get()->data; } + + /*! \return end iterator */ + const int64_t *end() const { return (get()->data + size()); } + + /*! \return The product of the shape tuple */ + int64_t Product() const { return get()->Product(); } + + /// \cond Doxygen_Suppress + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Shape, ObjectRef, ShapeObj); + /// \endcond + +private: + explicit Shape(ObjectPtr ptr) : ObjectRef(std::move(ptr)) {} +}; + +inline std::ostream &operator<<(std::ostream &os, const Shape &shape) { + os << '['; + for (size_t i = 0; i < shape.size(); ++i) { + if (i != 0) { + os << ", "; + } + os << shape[i]; + } + os << ']'; + return os; +} + +// Shape +template <> +inline constexpr bool use_default_type_traits_v = false; + +// Allow auto conversion from Array to Shape, but not from Shape to Array +template <> +struct TypeTraits : public ObjectRefWithFallbackTraitsBase> { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIShape; + TVM_FFI_INLINE static Shape ConvertFallbackValue(Array src) { + return Shape(std::move(src)); + } +}; + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_CONTAINER_SHAPE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tensor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tensor.h new file mode 100644 index 000000000..b9533c118 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tensor.h @@ -0,0 +1,785 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/tensor.h + * \brief Container to store a Tensor. + */ +#ifndef TVM_FFI_CONTAINER_TENSOR_H_ +#define TVM_FFI_CONTAINER_TENSOR_H_ + +#include "../dtype.h" +#include "../error.h" +#include "../type_traits.h" +#include "shape.h" + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +class Tensor; + +/*! + * \brief Check if the device uses direct address, where address of data indicate alignment. + * \param device The input device. + * \return True if the device uses direct address, false otherwise. + */ +inline bool IsDirectAddressDevice(const DLDevice &device) { + return device.device_type <= kDLCUDAHost || device.device_type == kDLCUDAManaged || device.device_type == kDLROCM || device.device_type == kDLROCMHost; +} + +/*! + * \brief check if a DLTensor is contiguous. + * \param arr The input DLTensor. + * \return The check result. + */ +inline bool IsContiguous(const DLTensor &arr) { + if (arr.strides == nullptr) { + return true; + } + int64_t expected_stride = 1; + for (int32_t i = arr.ndim; i != 0; --i) { + int32_t k = i - 1; + if (arr.shape[k] == 1) { + // Skip stride check if shape[k] is 1, where the dimension is contiguous + // regardless of the value of stride. + // + // For example, PyTorch will normalize stride to 1 if shape is 1 when exporting + // to DLPack. 
+ // More context: https://github.com/pytorch/pytorch/pull/83158 + continue; + } + if (arr.strides[k] != expected_stride) { + return false; + } + expected_stride *= arr.shape[k]; + } + return true; +} + +/** + * \brief Check if the data in the DLTensor is aligned to the given alignment. + * \param arr The input DLTensor. + * \param alignment The alignment to check. + * \return True if the data is aligned to the given alignment, false otherwise. + */ +inline bool IsAligned(const DLTensor &arr, size_t alignment) { + if (IsDirectAddressDevice(arr.device)) { + return (reinterpret_cast(static_cast(arr.data) + arr.byte_offset) % alignment == 0); + } else { + return arr.byte_offset % alignment == 0; + } +} + +/*! + * \brief return the total number of bytes needed to store packed data + * + * \param numel the number of elements in the array + * \param dtype the data type of the array + * \return the total number of bytes needed to store packed data + */ +inline size_t GetDataSize(size_t numel, DLDataType dtype) { + // compatible handling sub-byte uint1(bool), which usually stored as uint8_t + // TODO(tqchen): revisit and switch to kDLBool + if (dtype.code == kDLUInt && dtype.bits == 1 && dtype.lanes == 1) { + return numel; + } + // for other sub-byte types, packing is preferred + return (numel * dtype.bits * dtype.lanes + 7) / 8; +} + +/*! + * \brief return the size of data the DLTensor holds, in terms of number of bytes + * + * \param arr the input DLTensor + * \return number of bytes of data in the DLTensor. + */ +inline size_t GetDataSize(const DLTensor &arr) { + size_t size = 1; + for (int i = 0; i < arr.ndim; ++i) { + size *= static_cast(arr.shape[i]); + } + return GetDataSize(size, arr.dtype); +} + +/*! \brief An object representing a Tensor. 
*/ +class TensorObj : public Object, public DLTensor { +public: + /// \cond Doxygen_Suppress + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, Object); + /// \endcond + + /*! + * \brief Move a Tensor to a DLPack managed tensor. + * \return The converted DLPack managed tensor. + */ + DLManagedTensor *ToDLPack() const { + TensorObj *self = const_cast(this); + DLManagedTensor *ret = new DLManagedTensor(); + ret->dl_tensor = *static_cast(self); + ret->manager_ctx = self; + ret->deleter = DLManagedTensorDeleter; + details::ObjectUnsafe::IncRefObjectHandle(self); + return ret; + } + + /*! + * \brief Move a Tensor to a DLPack managed tensor. + * \return The converted DLPack managed tensor. + */ + DLManagedTensorVersioned *ToDLPackVersioned() const { + TensorObj *self = const_cast(this); + DLManagedTensorVersioned *ret = new DLManagedTensorVersioned(); + ret->version.major = DLPACK_MAJOR_VERSION; + ret->version.minor = DLPACK_MINOR_VERSION; + ret->dl_tensor = *static_cast(self); + ret->manager_ctx = self; + ret->deleter = DLManagedTensorDeleter; + details::ObjectUnsafe::IncRefObjectHandle(self); + return ret; + } + +protected: + /*! + * \brief Deleter for DLManagedTensor. + * \param tensor The DLManagedTensor to be deleted. + */ + template + static void DLManagedTensorDeleter(TDLManagedTensor *tensor) { + TensorObj *obj = static_cast(tensor->manager_ctx); + details::ObjectUnsafe::DecRefObjectHandle(obj); + delete tensor; + } + + friend class Tensor; +}; + +namespace details { +/*! + *\brief Helper class to create an TensorObj from an NDAllocator + * + * The underlying allocator needs to be implemented by user. 
+ */ +template +class TensorObjFromNDAlloc : public TensorObj { +public: + using Self = TensorObjFromNDAlloc; + + template + TensorObjFromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device, + ExtraArgs &&...extra_args) + : alloc_(alloc) { + this->device = device; + this->ndim = static_cast(shape.size()); + this->dtype = dtype; + this->byte_offset = 0; + // inplace alloc shape and strides after data structure + this->shape = reinterpret_cast(reinterpret_cast(this) + sizeof(Self)); + this->strides = this->shape + shape.size(); + std::copy(shape.begin(), shape.end(), this->shape); + details::FillStridesFromShape(shape, this->strides); + // call allocator to alloc data + alloc_.AllocData(static_cast(this), std::forward(extra_args)...); + } + + ~TensorObjFromNDAlloc() { alloc_.FreeData(static_cast(this)); } + +private: + TNDAlloc alloc_; +}; + +/*! \brief helper class to import from DLPack legacy DLManagedTensor */ +template +class TensorObjFromDLPack : public TensorObj { +public: + using Self = TensorObjFromDLPack; + + explicit TensorObjFromDLPack(TDLPackManagedTensor *tensor, bool extra_strides_at_tail) + : tensor_(tensor) { + *static_cast(this) = tensor_->dl_tensor; + if (extra_strides_at_tail) { + this->strides = reinterpret_cast(reinterpret_cast(this) + sizeof(Self)); + details::FillStridesFromShape(ShapeView(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim), + this->strides); + } + } + + ~TensorObjFromDLPack() { + // run DLPack deleter if needed. + if (tensor_->deleter != nullptr) { + (*tensor_->deleter)(tensor_); + } + } + +private: + TDLPackManagedTensor *tensor_; +}; +} // namespace details + +/*! + * \brief Managed Tensor (n-dimensional array). + * The tensor is backed by reference counted blocks. + * + * \note This class can be subclassed to implement downstream customized + * Tensor types that are backed by the same TensorObj storage type. + */ +class Tensor : public ObjectRef { +public: + /*! + * \brief Default constructor. 
+ */ + Tensor() = default; + /*! + * \brief Constructor from a ObjectPtr. + * \param n The ObjectPtr. + */ + explicit Tensor(::tvm::ffi::ObjectPtr n) : ObjectRef(std::move(n)) {} + /*! + * \brief Constructor from a UnsafeInit tag. + * \param tag The UnsafeInit tag. + */ + explicit Tensor(::tvm::ffi::UnsafeInit tag) : ObjectRef(tag) {} + /// \cond Doxygen_Suppress + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Tensor) + /// \endcond + /*! + * \brief Get the data pointer of the Tensor. + * \return The data pointer of the Tensor. + */ + void *data_ptr() const { return get()->data; } + + /*! + * \brief Get the device of the Tensor. + * \return The device of the Tensor. + */ + DLDevice device() const { return get()->device; } + + /*! + * \brief Get the number of dimensions in the Tensor. + * \return The number of dimensions in the Tensor. + */ + int32_t ndim() const { return get()->ndim; } + + /*! + * \brief Get the data type of the Tensor. + * \return The data type of the Tensor. + */ + DLDataType dtype() const { return get()->dtype; } + + /*! + * \brief Get the shape of the Tensor. + * \return The shape of the Tensor. + */ + ShapeView shape() const { + const TensorObj *obj = get(); + return tvm::ffi::ShapeView(obj->shape, obj->ndim); + } + + /*! + * \brief Get the strides of the Tensor. + * \return The strides of the Tensor. + */ + ShapeView strides() const { + const TensorObj *obj = get(); + TVM_FFI_ICHECK(obj->strides != nullptr || obj->ndim == 0); + return ShapeView(obj->strides, obj->ndim); + } + + /*! + * \brief Get the size of the idx-th dimension. If the idx is negative, + * it gets the size of last idx-th dimension. + * \param idx The index of the size. + * \return The size of the idx-th dimension. + */ + int64_t size(int64_t idx) const { + const TensorObj *ptr = get(); + return ptr->shape[idx >= 0 ? idx : (ptr->ndim + idx)]; + } + + /*! + * \brief Get the stride of the idx-th dimension. If the idx is negative, + * it gets the stride of last idx-th dimension. 
+ * \param idx The index of the stride. + * \return The stride of the idx-th dimension. + */ + int64_t stride(int64_t idx) const { + const TensorObj *ptr = get(); + return ptr->strides[idx >= 0 ? idx : (ptr->ndim + idx)]; + } + + /*! + * \brief Get the number of elements in the Tensor. + * \return The number of elements in the Tensor. + */ + int64_t numel() const { return this->shape().Product(); } + /*! + * \brief Get the byte offset of the Tensor. + * \return The byte offset of the Tensor. + */ + uint64_t byte_offset() const { return get()->byte_offset; } + /*! + * \brief Check if the Tensor is contiguous. + * \return True if the Tensor is contiguous, false otherwise. + */ + bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); } + /*! + * \brief Check if the Tensor data is aligned to the given alignment. + * \param alignment The alignment to check. + * \return True if the Tensor data is aligned to the given alignment, false otherwise. + */ + bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*get(), alignment); } + /*! + * \brief Create a Tensor from a NDAllocator. + * + * \note When building a kernel library, we always recommend use FromEnvAlloc when possible to + * allocate intermediate Tensors. When a loaded module returns an allocated tensor to the caller, + * we need to keep the module alive before the returned tensors get freed, because its + * deleter is defined within the module. FromNDAlloc can be used by C++ applications and runtimes + * to create Tensors. 
+ * + * Example usage: + * \code + * // CPU Allocator + * struct CPUNDAlloc { + * void AllocData(DLTensor* tensor) { tensor->data = malloc(ffi::GetDataSize(*tensor)); } + * void FreeData(DLTensor* tensor) { free(tensor->data); } + * }; + * + * // CUDA Allocator + * struct CUDANDAlloc { + * void AllocData(DLTensor* tensor) { + * size_t data_size = ffi::GetDataSize(*tensor); + * void* ptr = nullptr; + * cudaError_t err = cudaMalloc(&ptr, data_size); + * TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaMalloc failed: " << cudaGetErrorString(err); + * tensor->data = ptr; + * } + * void FreeData(DLTensor* tensor) { + * if (tensor->data != nullptr) { + * cudaError_t err = cudaFree(tensor->data); + * TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaFree failed: " << cudaGetErrorString(err); + * tensor->data = nullptr; + * } + * } + * }; + * + * // NVSHMEM Allocator + * struct NVSHMEMNDAlloc { + * void AllocData(DLTensor* tensor) { + * size_t size = tvm::ffi::GetDataSize(*tensor); + * tensor->data = nvshmem_malloc(size); + * TVM_FFI_ICHECK_NE(tensor->data, nullptr) << "nvshmem_malloc failed. size: " << size; + * } + * void FreeData(DLTensor* tensor) { nvshmem_free(tensor->data); } + * }; + * + * // Allocator usage + * ffi::Tensor cpu_tensor = ffi::Tensor::FromNDAlloc(CPUNDAlloc(), ...); + * ffi::Tensor cuda_tensor = ffi::Tensor::FromNDAlloc(CUDANDAlloc(), ...); + * ffi::Tensor nvshmem_tensor = ffi::Tensor::FromNDAlloc(NVSHMEMNDAlloc(), ...); + * \endcode + * + * \param alloc The NDAllocator. + * \param shape The shape of the Tensor. + * \param dtype The data type of the Tensor. + * \param device The device of the Tensor. + * \param extra_args Extra arguments to be forwarded to TNDAlloc. + * \return The created Tensor. + * \tparam TNDAlloc The type of the NDAllocator, implements AllocData and FreeData. + * \tparam ExtraArgs Extra arguments to be passed to AllocData. 
+ */ + template + static Tensor FromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device, + ExtraArgs &&...extra_args) { + // inplace alloc shape and strides after data structure (as a result why multiply 2) + size_t num_extra_i64_at_tail = shape.size() * 2; + return Tensor(make_inplace_array_object, int64_t>( + num_extra_i64_at_tail, alloc, shape, dtype, device, + std::forward(extra_args)...)); + } + /*! + * \brief Create a Tensor from the TVMFFIEnvTensorAlloc API + * + * This function can be used together with TVMFFIEnvSetDLPackManagedTensorAllocator + * in the extra/c_env_api.h to create a Tensor from the thread-local environment allocator. + * We explicitly pass TVMFFIEnvTensorAlloc to maintain explicit dependency on extra/c_env_api.h + * + * \code + * + * ffi::Tensor tensor = ffi::Tensor::FromEnvAlloc( + * TVMFFIEnvTensorAlloc, shape, dtype, device + * ); + * + * \endcode + * + * \param env_alloc TVMFFIEnvTensorAlloc function pointer. + * \param shape The shape of the Tensor. + * \param dtype The data type of the Tensor. + * \param device The device of the Tensor. + * \return The created Tensor. + * + * \sa TVMFFIEnvTensorAlloc + */ + static Tensor FromEnvAlloc(int (*env_alloc)(DLTensor *, TVMFFIObjectHandle *), ffi::ShapeView shape, + DLDataType dtype, DLDevice device) { + TVMFFIObjectHandle out; + DLTensor prototype{}; + prototype.device = device; + prototype.dtype = dtype; + prototype.shape = const_cast(shape.data()); + prototype.ndim = static_cast(shape.size()); + TVM_FFI_CHECK_SAFE_CALL(env_alloc(&prototype, &out)); + return Tensor( + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(out))); + } + /*! + * \brief Create a Tensor from a DLPack managed tensor, pre v1.0 API. + * \param tensor The input DLPack managed tensor. + * \param require_alignment The minimum alignment required of the data + byte_offset. + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. 
+ * \note This function will not run any checks on flags. + * \return The created Tensor. + */ + static Tensor FromDLPack(DLManagedTensor *tensor, size_t require_alignment = 0, + bool require_contiguous = false) { + if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment + << " bytes."; + } + if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; + } + if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) { + return Tensor(make_object>( + tensor, /*extra_strides_at_tail=*/false)); + } else { + return Tensor( + make_inplace_array_object, int64_t>( + tensor->dl_tensor.ndim, tensor, /*extra_strides_at_tail=*/true)); + } + } + + /*! + * \brief Create a Tensor from a DLPack managed tensor, post v1.0 API. + * \param tensor The input DLPack managed tensor. + * \param require_alignment The minimum alignment required of the data + byte_offset. + * \param require_contiguous Boolean flag indicating if we need to check for contiguity. + * \return The created Tensor. 
+ */ + static Tensor FromDLPackVersioned(DLManagedTensorVersioned *tensor, size_t require_alignment = 0, + bool require_contiguous = false) { + if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment + << " bytes."; + } + if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { + TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; + } + if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) { + TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet supported"; + } + if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) { + return Tensor(make_object>( + tensor, /*extra_strides_at_tail=*/false)); + } else { + return Tensor( + make_inplace_array_object, + int64_t>(tensor->dl_tensor.ndim, tensor, + /*extra_strides_at_tail=*/true)); + } + } + + /*! + * \brief Convert the Tensor to a DLPack managed tensor. + * \return The converted DLPack managed tensor. + */ + DLManagedTensor *ToDLPack() const { return get_mutable()->ToDLPack(); } + + /*! + * \brief Convert the Tensor to a DLPack managed tensor. + * \return The converted DLPack managed tensor. + */ + DLManagedTensorVersioned *ToDLPackVersioned() const { return get_mutable()->ToDLPackVersioned(); } + /*! + * \brief Get the underlying DLTensor pointer. + * \return The underlying DLTensor pointer. + */ + const DLTensor *GetDLTensorPtr() const { return get(); } + /// \cond Doxygen_Suppress + [[maybe_unused]] static constexpr bool _type_is_nullable = true; + using ContainerType = TensorObj; + /// \endcond + + // the following code are convenient APIs redirections created to provide aten-style api + /*! + * \brief This functions redirects to ndim(). + * \return The number of dimensions in the Tensor. + */ + inline int32_t dim() { return ndim(); } + /*! + * \brief This functions redirects to shape(). + * \return The shape of the Tensor. 
+ */ + inline ShapeView sizes() const { return shape(); } + /*! + * \brief This function redirects to IsContiguous(). + * \return True if the Tensor is contiguous, false otherwise. + */ + inline bool is_contiguous() const { return IsContiguous(); } + +protected: + /*! + * \brief Get const internal container pointer. + * \return a const container pointer. + */ + const TensorObj *get() const { return static_cast(ObjectRef::get()); } + /*! + * \brief Get mutable internal container pointer. + * \return a mutable container pointer. + */ + TensorObj *get_mutable() const { return const_cast(get()); } +}; + +/*! + * \brief A non-owning view of a Tensor. + * + * This class stores a light-weight non-owning view of a Tensor. + * This is useful for accessing tensor data without retaining a strong reference to the Tensor, + * since the caller may not always be able to pass in a strongly referenced tensor. + * + * It is the user's responsibility to ensure + * that the underlying tensor data outlives the `TensorView`. + * This responsibility extends to all data pointed to by the underlying DLTensor. + * This includes not only the tensor elements in data but also the memory for shape and strides. + * + * When exposing a function that only expects a TensorView, we recommend using + * ffi::TensorView as the argument type instead of ffi::Tensor. + */ +class TensorView { +public: + /*! + * \brief Create a TensorView from a Tensor. + * \param tensor The input Tensor. + */ + TensorView(const Tensor &tensor) { // NOLINT(*) + TVM_FFI_ICHECK(tensor.defined()); + tensor_ = *tensor.GetDLTensorPtr(); + } // NOLINT(*) + /*! + * \brief Create a TensorView from a DLTensor. + * \param tensor The input DLTensor. + */ + TensorView(const DLTensor *tensor) { // NOLINT(*) + TVM_FFI_ICHECK(tensor != nullptr); + tensor_ = *tensor; + } + /*! + * \brief Copy constructor. + * \param tensor The input TensorView. + */ + TensorView(const TensorView &tensor) = default; + /*! + * \brief Move constructor. 
+ * \param tensor The input TensorView. + */ + TensorView(TensorView &&tensor) = default; + /*! + * \brief Copy assignment operator. + * \param tensor The input TensorView. + * \return The created TensorView. + */ + TensorView &operator=(const TensorView &tensor) = default; + /*! + * \brief Move assignment operator. + * \param tensor The input TensorView. + * \return The created TensorView. + */ + TensorView &operator=(TensorView &&tensor) = default; + /*! + * \brief Assignment operator from a Tensor. + * \param tensor The input Tensor. + * \return The created TensorView. + */ + TensorView &operator=(const Tensor &tensor) { + TVM_FFI_ICHECK(tensor.defined()); + tensor_ = *tensor.GetDLTensorPtr(); + return *this; + } + + // explicitly delete move constructor + TensorView(Tensor &&tensor) = delete; // NOLINT(*) + // delete move assignment operator from owned tensor + TensorView &operator=(Tensor &&tensor) = delete; + /*! + * \brief Get the data pointer of the Tensor. + * \return The data pointer of the Tensor. + */ + void *data_ptr() const { return tensor_.data; } + /*! + * \brief Get the device of the Tensor. + * \return The device of the Tensor. + */ + DLDevice device() const { return tensor_.device; } + /*! + * \brief Get the number of dimensions in the Tensor. + * \return The number of dimensions in the Tensor. + */ + int32_t ndim() const { return tensor_.ndim; } + /*! + * \brief Get the data type of the Tensor. + * \return The data type of the Tensor. + */ + DLDataType dtype() const { return tensor_.dtype; } + /*! + * \brief Get the shape of the Tensor. + * \return The shape of the Tensor. + */ + ShapeView shape() const { return ShapeView(tensor_.shape, tensor_.ndim); } + + /*! + * \brief Get the number of elements in the Tensor. + * \return The number of elements in the Tensor. + */ + int64_t numel() const { return this->shape().Product(); } + + /*! + * \brief Get the strides of the Tensor. + * \return The strides of the Tensor. 
+ */ + ShapeView strides() const { + TVM_FFI_ICHECK(tensor_.strides != nullptr || tensor_.ndim == 0); + return ShapeView(tensor_.strides, tensor_.ndim); + } + + /*! + * \brief Get the size of the idx-th dimension. If the idx is negative, + * it gets the size of last idx-th dimension. + * \param idx The index of the size. + * \return The size of the idx-th dimension. + */ + int64_t size(int64_t idx) const { return tensor_.shape[idx >= 0 ? idx : tensor_.ndim + idx]; } + + /*! + * \brief Get the stride of the idx-th dimension. If the idx is negative, + * it gets the stride of last idx-th dimension. + * \param idx The index of the stride. + * \return The stride of the idx-th dimension. + */ + int64_t stride(int64_t idx) const { return tensor_.strides[idx >= 0 ? idx : tensor_.ndim + idx]; } + + /*! + * \brief Get the byte offset of the Tensor. + * \return The byte offset of the Tensor. + */ + uint64_t byte_offset() const { return tensor_.byte_offset; } + + /*! + * \brief Check if the Tensor is contiguous. + * \return True if the Tensor is contiguous, false otherwise. + */ + bool IsContiguous() const { return tvm::ffi::IsContiguous(tensor_); } + + // the following code are convenient APIs redirections created to provide aten-style api + /*! + * \brief This functions redirects to ndim(). + * \return The number of dimensions in the Tensor. + */ + inline int32_t dim() { return ndim(); } + /*! + * \brief This functions redirects to shape(). + * \return The shape of the Tensor. + */ + inline ShapeView sizes() const { return shape(); } + /*! + * \brief This functions redirects to IsContiguous(). + * \return True if the Tensor is contiguous, false otherwise. + */ + inline bool is_contiguous() const { return IsContiguous(); } + +private: + DLTensor tensor_; + template + friend struct TypeTraits; +}; + +/*! + * \brief Get the data size of the Tensor. + * \param tensor The input Tensor. + * \return The data size of the Tensor. 
+ */ +inline size_t GetDataSize(const Tensor &tensor) { + return GetDataSize(tensor.numel(), tensor.dtype()); +} + +/*! + * \brief Get the data size of the TensorView. + * \param tensor The input TensorView. + * \return The data size of the TensorView. + */ +inline size_t GetDataSize(const TensorView &tensor) { + return GetDataSize(tensor.numel(), tensor.dtype()); +} + +// TensorView type, allow implicit casting from DLTensor* +// NOTE: we deliberately do not support MoveToAny and MoveFromAny since it does not retain ownership +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr bool storage_enabled = false; + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDLTensorPtr; + + TVM_FFI_INLINE static void CopyToAnyView(const TensorView &src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIDLTensorPtr; + result->zero_padding = 0; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_ptr = const_cast(&(src.tensor_)); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index == TypeIndex::kTVMFFIDLTensorPtr; + } + + TVM_FFI_INLINE static TensorView CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return TensorView(static_cast(src->v_ptr)); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) { + return TensorView(static_cast(src->v_ptr)); + } else if (src->type_index == TypeIndex::kTVMFFITensor) { + return TensorView(TVMFFITensorGetDLTensorPtr(src->v_obj)); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIDLTensorPtr; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIDLTensorPtr) + R"("})"; + } +}; + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_CONTAINER_TENSOR_H_ diff --git 
a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h new file mode 100644 index 000000000..e5eb3cab6 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/tuple.h + * \brief Typed tuple like std::tuple backed by ArrayObj container. + */ +#ifndef TVM_FFI_CONTAINER_TUPLE_H_ +#define TVM_FFI_CONTAINER_TUPLE_H_ + +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Typed tuple like std::tuple backed by ArrayObj container. + * + * Tuple implements in-place copy-on-write semantics. + * + * \tparam Types The types of the tuple elements + */ +template +class Tuple : public ObjectRef { + public: + static_assert(details::all_storage_enabled_v, + "All types used in Tuple<...> must be compatible with Any"); + /*! \brief Default constructor */ + Tuple() : ObjectRef(MakeDefaultTupleNode()) {} + /*! + * \brief Constructor with UnsafeInit + */ + explicit Tuple(UnsafeInit tag) : ObjectRef(tag) {} + /*! 
\brief Copy constructor */ + Tuple(const Tuple& other) : ObjectRef(other) {} + /*! \brief Move constructor */ + Tuple(Tuple&& other) noexcept : ObjectRef(std::move(other)) {} + /*! + * \brief Constructor from another tuple + * \param other The other tuple + * \tparam UTypes The types of the other tuple + * \tparam The enable_if_t type + */ + template && ...), int>> + Tuple(const Tuple& other) : ObjectRef(other) {} // NOLINT(google-explicit-constructor) + + /*! + * \brief Constructor from another tuple + * \param other The other tuple + * \tparam UTypes The types of the other tuple + * \tparam The enable_if_t type + */ + template && ...), int>> + Tuple(Tuple&& other) // NOLINT(google-explicit-constructor) + : ObjectRef(std::move(other)) {} + + /*! + * \brief Constructor from arguments + * \param args The arguments + * \tparam UTypes The types of the other tuple + */ + template , Tuple> && ...))>> + explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward(args)...)) {} + + /*! + * \brief Assignment from another tuple + * \param other The other tuple + * \tparam The enable_if_t type + */ + TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { + data_ = other.data_; + return *this; + } + + /*! + * \brief Assignment from another tuple + * \param other The other tuple + * \tparam The enable_if_t type + */ + TVM_FFI_INLINE Tuple& operator=(Tuple&& other) noexcept { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief Assignment from another tuple + * \param other The other tuple + * \tparam UTypes The types of the other tuple + * \tparam The enable_if_t type + */ + template && ...)>> + TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { + data_ = other.data_; + return *this; + } + + /*! 
+ * \brief Assignment from another tuple + * \param other The other tuple + * \tparam UTypes The types of the other tuple + * \tparam The enable_if_t type + */ + template && ...)>> + TVM_FFI_INLINE Tuple& operator=(Tuple&& other) { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief Get I-th element of the tuple + * + * \tparam I The index of the element to get + * \return The I-th element of the tuple + * \note We use stl style since get usually is like a getter. + */ + template + auto get() const& { + static_assert(I < sizeof...(Types), "Tuple index out of bounds"); + using ReturnType = std::tuple_element_t>; + const Any* ptr = GetArrayObj()->begin() + I; + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*ptr); + } + + /*! + * \brief Move out I-th element of the tuple + * + * \tparam I The index of the element to get + * \return The I-th element of the tuple + * \note We use stl style since get usually is like a getter. + */ + template + auto get() && { + if (!this->unique()) { + // fallback to const& version if not unique + return std::as_const(*this).template get(); + } + static_assert(I < sizeof...(Types), "Tuple index out of bounds"); + using ReturnType = std::tuple_element_t>; + Any* ptr = GetArrayObj()->MutableBegin() + I; + return details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(*ptr)); + } + + /*! + * \brief Set I-th element of the tuple + * + * \param item The item to set + * \tparam I The index of the element to set + * \tparam U The type of the item + * + * \note This function will perform copy on write if underlying + * container is not uniquely owned. + * We use CamelCase since Set can cause copy on write + * and is more complicated than simple field setter. + */ + template + void Set(U&& item) { + static_assert(I < sizeof...(Types), "Tuple index out of bounds"); + using T = std::tuple_element_t>; + this->CopyIfNotUnique(); + Any* ptr = GetArrayObj()->MutableBegin() + I; + *ptr = T(std::forward(item)); + } + + /*! 
\brief specify container node */ + using ContainerType = ArrayObj; + + private: + static ObjectPtr MakeDefaultTupleNode() { + ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); + Any* itr = p->MutableBegin(); + // increase size after each new to ensure exception safety + ((new (itr++) Any(Types()), p->size_++), ...); + return p; + } + + template + static ObjectPtr MakeTupleNode(UTypes&&... args) { + ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); + Any* itr = p->MutableBegin(); + // increase size after each new to ensure exception safety + ((new (itr++) Any(Types(std::forward(args))), p->size_++), ...); + return p; + } + + /*! \brief Copy on write */ + void CopyIfNotUnique() { + if (!data_.unique()) { + ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); + Any* itr = p->MutableBegin(); + const Any* read = GetArrayObj()->begin(); + // increase size after each new to ensure exception safety + for (size_t i = 0; i < sizeof...(Types); ++i) { + new (itr++) Any(*read++); + p->size_++; + } + data_ = std::move(p); + } + } + + /*! 
\return The underlying ArrayObj */ + ArrayObj* GetArrayObj() const { return static_cast(data_.get()); } + + template + friend class Tuple; +}; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public ObjectRefTypeTraitsBase> { + using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIArray) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + const ArrayObj* n = reinterpret_cast(src->v_obj); + if (n->size() != sizeof...(Types)) { + return "Array[size=" + std::to_string(n->size()) + "]"; + } + return GetMismatchTypeInfoHelper<0, Types...>(n->begin()); + } + + template + TVM_FFI_INLINE static std::string GetMismatchTypeInfoHelper(const Any* arr) { + if constexpr (!std::is_same_v) { + const Any& any_v = arr[I]; + if (!details::AnyUnsafe::CheckAnyStrict(any_v) && !(any_v.try_cast().has_value())) { + // now report the accurate mismatch information + return "Array[index " + std::to_string(I) + ": " + + details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; + } + } + if constexpr (sizeof...(Rest) > 0) { + return GetMismatchTypeInfoHelper(arr); + } + TVM_FFI_THROW(InternalError) << "Cannot reach here"; + TVM_FFI_UNREACHABLE(); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIArray) return false; + const ArrayObj* n = reinterpret_cast(src->v_obj); + if (n->size() != sizeof...(Types)) return false; + const TVMFFIAny* ffi_any_arr = reinterpret_cast(n->begin()); + return CheckAnyStrictHelper<0, Types...>(ffi_any_arr); + } + + template + TVM_FFI_INLINE static bool CheckAnyStrictHelper(const TVMFFIAny* src_arr) { + if constexpr (!std::is_same_v) { + if (!TypeTraits::CheckAnyStrict(src_arr + I)) { + return false; + } + } + if constexpr (sizeof...(Rest) > 0) { + return CheckAnyStrictHelper(src_arr); + } + return true; + } + + 
TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { + if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt; + const ArrayObj* n = reinterpret_cast(src->v_obj); + if (n->size() != sizeof...(Types)) return std::nullopt; + // fast path, storage is already in the right type + if (CheckAnyStrict(src)) { + return CopyFromAnyViewAfterCheck(src); + } + // slow path, try to convert to each type to match the tuple storage need. + Array arr = TypeTraits>::CopyFromAnyViewAfterCheck(src); + Any* ptr = arr.CopyOnWrite()->MutableBegin(); + if (TryConvertElements<0, Types...>(ptr)) { + return details::ObjectUnsafe::ObjectRefFromObjectPtr>( + details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); + } + return std::nullopt; + } + + template + TVM_FFI_INLINE static bool TryConvertElements(Any* arr) { + if constexpr (!std::is_same_v) { + if (auto opt_convert = arr[I].try_cast()) { + arr[I] = *std::move(opt_convert); + } else { + return false; + } + } + if constexpr (sizeof...(Rest) > 0) { + return TryConvertElements(std::move(arr)); + } else { + return true; + } + } + + TVM_FFI_INLINE static std::string TypeStr() { + return details::ContainerTypeStr("Tuple"); + } + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":"Tuple","args":[)"; + const char* sep = ""; + ((oss << sep << details::TypeSchema::v(), sep = ","), ...); + oss << "]}"; + return oss.str(); + } +}; + +namespace details { +template +inline constexpr bool type_contains_v, Tuple> = (type_contains_v && ...); +} // namespace details + +/// \cond Doxygen_Suppress + +/// NOTE: ADL friendly get functions +/// Example usage: { using std::get; get<0>(t); } +/// ADL will find the right get function + +/** + * \brief get I-th element of the tuple + * \tparam I The index of the element to get + * \param t The tuple + * \return The I-th element of the tuple + */ +template +inline constexpr auto get(const Tuple& t) + -> std::tuple_element_t> { + return 
t.template get(); +} + +/** + * \brief get I-th element of the tuple + * \tparam I The index of the element to get + * \param t The tuple (rvalue) + * \return The I-th element of the tuple + */ +template +inline constexpr auto get(Tuple&& t) -> std::tuple_element_t> { + return std::move(t).template get(); +} + +/// NOTE: C++17 deduction guide +template +Tuple(UTypes&&...) -> Tuple>...>; + +/// \endcond + +} // namespace ffi +} // namespace tvm + +namespace std { + +template +struct tuple_size<::tvm::ffi::Tuple> + : public std::integral_constant {}; + +template +struct tuple_element> { + using type = std::tuple_element_t>; +}; + +} // namespace std + +#endif // TVM_FFI_CONTAINER_TUPLE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h new file mode 100644 index 000000000..08dc764d5 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/container/variant.h + * \brief Runtime variant container types. 
+ */ +#ifndef TVM_FFI_CONTAINER_VARIANT_H_ +#define TVM_FFI_CONTAINER_VARIANT_H_ + +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace ffi { +namespace details { +/*! + * \brief Base class for Variant. + * + * \tparam all_storage_object Whether all types are derived from ObjectRef. + */ +template +class VariantBase { + public: + TVM_FFI_INLINE bool same_as(const VariantBase& other) const { + return data_.same_as(other.data_); + } + + protected: + template + explicit VariantBase(T other) : data_(std::move(other)) {} + + TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); } + + TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); } + + TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); } + + Any data_; +}; + +// Specialization for all object ref case, backed by ObjectRef. +template <> +class VariantBase : public ObjectRef { + protected: + template + explicit VariantBase(const T& other) : ObjectRef(other) {} + template , VariantBase>>> + explicit VariantBase(T&& other) : ObjectRef(std::forward(other)) {} + explicit VariantBase(UnsafeInit tag) : ObjectRef(tag) {} + explicit VariantBase(Any other) + : ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(other))) {} + + TVM_FFI_INLINE void SetData(ObjectPtr other) { data_ = std::move(other); } + + TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); } + + TVM_FFI_INLINE AnyView ToAnyView() const { + TVMFFIAny any_data; + if (data_ == nullptr) { + any_data.type_index = TypeIndex::kTVMFFINone; + any_data.zero_padding = 0; + any_data.v_int64 = 0; + } else { + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data); + any_data.type_index = data_->type_index(); + any_data.zero_padding = 0; + any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr(data_); + } + return AnyView::CopyFromTVMFFIAny(any_data); + } +}; +} // namespace details + +/*! + * \brief A typed variant container. 
+ * + * When all values are ObjectRef, Variant is backed by ObjectRef, + * otherwise it is backed by Any. + */ +template +class Variant : public details::VariantBase> { + public: + /// \cond Doxygen_Suppress + using TParent = details::VariantBase>; + static_assert(details::all_storage_enabled_v, + "All types used in Variant<...> must be compatible with Any"); + /* + * \brief Helper utility to check if the type can be contained in the variant + */ + template + static constexpr bool variant_contains_v = (details::type_contains_v || ...); + /* \brief Helper utility for SFINAE if the type is part of the variant */ + template + using enable_if_variant_contains_t = std::enable_if_t>; + /// \endcond + /*! + * \brief Copy constructor from another variant + * \param other The other variant + */ + Variant(const Variant& other) : TParent(other.data_) {} + /*! + * \brief Move constructor from another variant + * \param other The other variant + */ + Variant(Variant&& other) noexcept : TParent(std::move(other.data_)) {} + + /*! + * \brief Copy assignment from another variant + * \param other The other variant + */ + TVM_FFI_INLINE Variant& operator=(const Variant& other) { + this->SetData(other.data_); + return *this; + } + + /*! + * \brief Move assignment from another variant + * \param other The other variant + */ + TVM_FFI_INLINE Variant& operator=(Variant&& other) noexcept { + this->SetData(std::move(other.data_)); + return *this; + } + + /*! + * \brief Constructor from a value of type T + * \param other The value to store in the variant + */ + template > + Variant(T other) : TParent(std::move(other)) {} // NOLINT(*) + + /*! + * \brief Assignment from a value of type T (via a temporary Variant built from it) + * \param other The value to assign + */ + template > + TVM_FFI_INLINE Variant& operator=(T other) { + return operator=(Variant(std::move(other))); + } + + /*! + * \brief Try to cast to a type T, return std::nullopt if the cast is not possible. + * \return The casted value, or std::nullopt if the cast is not possible. 
+ * \tparam T The type to cast to. + */ + template > + TVM_FFI_INLINE std::optional as() const { + return this->TParent::ToAnyView().template as(); + } + + /*! + * \brief Shortcut of as Object to cast to a const pointer when T is an Object. + * + * \tparam T The object type. + * \return The requested pointer, returns nullptr if type mismatches. + */ + template >> + TVM_FFI_INLINE const T* as() const { + return this->TParent::ToAnyView().template as().value_or(nullptr); + } + + /*! + * \brief Get the value of the variant in type T, throws an exception if cast fails. + * \return The value of the variant + * \tparam T The type to get. + */ + template > + TVM_FFI_INLINE T get() const& { + return this->TParent::ToAnyView().template cast(); + } + + /*! + * \brief Get the value of the variant in type T, throws an exception if cast fails. + * \return The value of the variant + * \tparam T The type to get. + */ + template > + TVM_FFI_INLINE T get() && { + return std::move(*this).TParent::MoveToAny().template cast(); + } + + /*! + * \brief Get the type key of the variant + * \return The type key of the variant + */ + TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); } + + private: + friend struct TypeTraits>; + friend struct ObjectPtrHash; + friend struct ObjectPtrEqual; + // constructor from any + explicit Variant(Any data) : TParent(std::move(data)) {} + /*! 
+ * \brief Get the object pointer from the variant + * \note This function is only available if all types used in Variant<...> are derived from + * ObjectRef + */ + TVM_FFI_INLINE Object* GetObjectPtrForHashEqual() const { + constexpr bool all_object_v = (std::is_base_of_v && ...); + static_assert(all_object_v, + "All types used in Variant<...> must be derived from ObjectRef " + "to enable ObjectPtrHash/ObjectPtrEqual"); + return this->data_.get(); + } + // re-expose the inherited helpers so that friend classes can use them + using TParent::MoveToAny; + using TParent::ToAnyView; +}; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public TypeTraitsBase { + TVM_FFI_INLINE static void CopyToAnyView(const Variant& src, TVMFFIAny* result) { + *result = src.ToAnyView().CopyToTVMFFIAny(); + } + + TVM_FFI_INLINE static void MoveToAny(Variant src, TVMFFIAny* result) { + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny()); + } + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { + return (TypeTraits::CheckAnyStrict(src) || ...); + } + + TVM_FFI_INLINE static Variant CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return Variant(Any(AnyView::CopyFromTVMFFIAny(*src))); + } + + TVM_FFI_INLINE static Variant MoveFromAnyAfterCheck(TVMFFIAny* src) { + return Variant(details::AnyUnsafe::MoveTVMFFIAnyToAny(src)); + } + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { + // fast path, storage is already in the right type + if (CheckAnyStrict(src)) { + return CopyFromAnyViewAfterCheck(src); + } + // More expensive path, try to convert to each type, in order of declaration + return TryVariantTypes(src); + } + + template + TVM_FFI_INLINE static std::optional> TryVariantTypes(const TVMFFIAny* src) { + if (auto opt_convert = TypeTraits::TryCastFromAnyView(src)) { + 
return Variant(*std::move(opt_convert)); + } + if constexpr (sizeof...(Rest) > 0) { + return TryVariantTypes(src); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return details::ContainerTypeStr("Variant"); } + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":"Variant","args":[)"; + const char* sep = ""; + ((oss << sep << details::TypeSchema::v(), sep = ","), ...); + oss << "]}"; + return oss.str(); + } +}; + +template +TVM_FFI_INLINE size_t ObjectPtrHash::operator()(const Variant& a) const { + return std::hash()(a.GetObjectPtrForHashEqual()); +} + +template +TVM_FFI_INLINE bool ObjectPtrEqual::operator()(const Variant& a, + const Variant& b) const { + return a.GetObjectPtrForHashEqual() == b.GetObjectPtrForHashEqual(); +} + +namespace details { +template +inline constexpr bool type_contains_v, T> = (type_contains_v || ...); +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_VARIANT_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/dtype.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/dtype.h new file mode 100644 index 000000000..32a0da5f9 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/dtype.h @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/dtype.h + * \brief Data type handling. + */ +#ifndef TVM_FFI_DTYPE_H_ +#define TVM_FFI_DTYPE_H_ + +#include "../../dlpack/dlpack.h" +#include "error.h" +#include "function.h" +#include "string.h" +#include "type_traits.h" + +#include +#include + +namespace tvm { +namespace ffi { +/*! + * \brief Extension code beyond the DLDataType. + * + * This enum is always consistent with the DLPack. + */ +enum DLExtDataTypeCode { kDLExtCustomBegin = 129 }; + +namespace details { + +/* + * \brief Convert a DLDataTypeCode to a string. + * \param type_code The DLDataTypeCode to convert. + * \return The name of the type code as a C string. 
+ */ +inline const char *DLDataTypeCodeAsCStr(DLDataTypeCode type_code) { // NOLINT(*) + switch (static_cast(type_code)) { + case kDLInt: { + return "int"; + } + case kDLUInt: { + return "uint"; + } + case kDLFloat: { + return "float"; + } + case kDLOpaqueHandle: { + return "handle"; + } + case kDLBfloat: { + return "bfloat"; + } + case kDLBool: { + return "bool"; + } + case kDLFloat8_e3m4: { + return "float8_e3m4"; + } + case kDLFloat8_e4m3: { + return "float8_e4m3"; + } + case kDLFloat8_e4m3b11fnuz: { + return "float8_e4m3b11fnuz"; + } + case kDLFloat8_e4m3fn: { + return "float8_e4m3fn"; + } + case kDLFloat8_e4m3fnuz: { + return "float8_e4m3fnuz"; + } + case kDLFloat8_e5m2: { + return "float8_e5m2"; + } + case kDLFloat8_e5m2fnuz: { + return "float8_e5m2fnuz"; + } + case kDLFloat8_e8m0fnu: { + return "float8_e8m0fnu"; + } + case kDLFloat6_e2m3fn: { + return "float6_e2m3fn"; + } + case kDLFloat6_e3m2fn: { + return "float6_e3m2fn"; + } + case kDLFloat4_e2m1fn: { + return "float4_e2m1fn"; + } + default: { + if (static_cast(type_code) >= static_cast(DLExtDataTypeCode::kDLExtCustomBegin)) { + return "custom"; + } else { + TVM_FFI_THROW(ValueError) << "DLDataType contains unknown type_code=" + << static_cast(type_code); + } + TVM_FFI_UNREACHABLE(); + } + } +} +} // namespace details + +/*! + * \brief Convert a string to a DLDataType. + * \param str The string to convert. + * \return The DLDataType. + */ +inline DLDataType StringToDLDataType(const String &str) { + DLDataType out; + TVMFFIByteArray data{str.data(), str.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(&data, &out)); + return out; +} + +/*! + * \brief Convert a DLDataType to a string. + * \param dtype The DLDataType to convert. + * \return The string. 
+ */ +inline String DLDataTypeToString(DLDataType dtype) { + TVMFFIAny out; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out)); + return TypeTraits::MoveFromAnyAfterCheck(&out); +} + +// DLDataType +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType; + + TVM_FFI_INLINE static void CopyToAnyView(const DLDataType &src, TVMFFIAny *result) { + // clear padding part to ensure the equality check can always check the v_uint64 part + result->v_uint64 = 0; + result->type_index = TypeIndex::kTVMFFIDataType; + result->zero_padding = 0; + result->v_dtype = src; + } + + TVM_FFI_INLINE static void MoveToAny(DLDataType src, TVMFFIAny *result) { + // clear padding part to ensure the equality check can always check the v_uint64 part + result->v_uint64 = 0; + result->type_index = TypeIndex::kTVMFFIDataType; + result->zero_padding = 0; + result->v_dtype = src; + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index == TypeIndex::kTVMFFIDataType; + } + + TVM_FFI_INLINE static DLDataType CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return src->v_dtype; + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIDataType) { + return src->v_dtype; + } + // enable string to dtype auto conversion + if (auto opt_str = TypeTraits::TryCastFromAnyView(src)) { + return StringToDLDataType(*opt_str); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return ffi::StaticTypeKey::kTVMFFIDataType; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(ffi::StaticTypeKey::kTVMFFIDataType) + R"("})"; + } +}; +} // namespace ffi +} // namespace tvm + +// define DLDataType comparison and printing in root namespace +inline std::ostream &operator<<(std::ostream &os, DLDataType dtype) { // NOLINT(*) + return os << 
tvm::ffi::DLDataTypeToString(dtype); +} + +inline bool operator==(const DLDataType &lhs, const DLDataType &rhs) { + return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; +} + +inline bool operator!=(const DLDataType &lhs, const DLDataType &rhs) { return !(lhs == rhs); } +#endif // TVM_FFI_DTYPE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h new file mode 100644 index 000000000..10639bea3 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h @@ -0,0 +1,90 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +/* + * \file tvm/ffi/endian.h + * \brief Endian detection and handling + */ +#ifndef TVM_FFI_ENDIAN_H_ +#define TVM_FFI_ENDIAN_H_ + +#include +#include + +#ifndef TVM_FFI_IO_USE_LITTLE_ENDIAN +#define TVM_FFI_IO_USE_LITTLE_ENDIAN 1 +#endif + +#ifdef TVM_FFI_CMAKE_LITTLE_ENDIAN +// If compiled with CMake, use CMake's endian detection logic +#define TVM_FFI_LITTLE_ENDIAN TVM_FFI_CMAKE_LITTLE_ENDIAN +#else +#if defined(__APPLE__) || defined(_WIN32) +#define TVM_FFI_LITTLE_ENDIAN 1 +#elif defined(__GLIBC__) || defined(__GNU_LIBRARY__) || defined(__ANDROID__) || \ + defined(__RISCV__) || defined(__MUSL__) +#include +#define TVM_FFI_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN) +#elif defined(__FreeBSD__) || defined(__OpenBSD__) +#include +#define TVM_FFI_LITTLE_ENDIAN (_BYTE_ORDER == _LITTLE_ENDIAN) +#elif defined(__QNX__) +#include +#define TVM_FFI_LITTLE_ENDIAN (BYTE_ORDER == LITTLE_ENDIAN) +#elif defined(__EMSCRIPTEN__) || defined(__hexagon__) +#define TVM_FFI_LITTLE_ENDIAN 1 +#elif defined(__sun) || defined(sun) +#include +#if defined(_LITTLE_ENDIAN) +#define TVM_FFI_LITTLE_ENDIAN 1 +#else +#define TVM_FFI_LITTLE_ENDIAN 0 +#endif +#else +#error "Unable to determine endianness of your machine; use CMake to compile" +#endif +#endif + +/*! \brief whether serialize using little endian */ +#define TVM_FFI_IO_NO_ENDIAN_SWAP (TVM_FFI_LITTLE_ENDIAN == TVM_FFI_IO_USE_LITTLE_ENDIAN) + +namespace tvm { +namespace ffi { +/*! + * \brief A generic inplace byte swapping function. + * \param data The data pointer. + * \param elem_bytes The number of bytes of the data elements + * \param num_elems Number of elements in the data. 
+ * \note Always try to pass in a constant elem_bytes to enable + * compiler optimization + */ +inline void ByteSwap(void* data, size_t elem_bytes, size_t num_elems) { + for (size_t i = 0; i < num_elems; ++i) { + uint8_t* bptr = reinterpret_cast(data) + elem_bytes * i; + for (size_t j = 0; j < elem_bytes / 2; ++j) { + uint8_t v = bptr[elem_bytes - 1 - j]; + bptr[elem_bytes - 1 - j] = bptr[j]; + bptr[j] = v; + } + } +} +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_ENDIAN_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/error.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/error.h new file mode 100644 index 000000000..f310dcbcc --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/error.h @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * \file tvm/ffi/error.h + * \brief Error handling component. + */ +#ifndef TVM_FFI_ERROR_H_ +#define TVM_FFI_ERROR_H_ + +#include "base_details.h" +#include "c_api.h" +#include "memory.h" +#include "object.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +/*! 
+ * \brief Macro defines whether we enable libbacktrace + */ +#ifndef TVM_FFI_USE_LIBBACKTRACE +#define TVM_FFI_USE_LIBBACKTRACE 1 +#endif + +/*! + * \brief Macro defines whether to install signal handler + * and print backtrace during segfault + */ +#ifndef TVM_FFI_BACKTRACE_ON_SEGFAULT +#define TVM_FFI_BACKTRACE_ON_SEGFAULT 1 +#endif + +#ifndef TVM_FFI_ALWAYS_LOG_BEFORE_THROW +#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 0 +#endif + +namespace tvm { +namespace ffi { + +/*! + * \brief Error already set in frontend env. + * + * This error can be thrown by EnvCheckSignals to indicate + * that there is an error set in the frontend environment(e.g. + * python interpreter). The TVM FFI should catch this error + * and return a proper code to tell the frontend caller about + * this fact. + * + * \code + * + * void ExampleLongRunningFunction() { + * if (TVMFFIEnvCheckSignals() != 0) { + * throw ::tvm::ffi::EnvErrorAlreadySet(); + * } + * // do work here + * } + * + * \endcode + */ +struct EnvErrorAlreadySet : public std::exception {}; + +/*! + * \brief Error object class. + */ +class ErrorObj : public Object, public TVMFFIErrorCell { +public: + /// \cond Doxygen_Suppress + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIError, ErrorObj, Object); + /// \endcond +}; + +namespace details { +class ErrorObjFromStd : public ErrorObj { +public: + ErrorObjFromStd(std::string kind, std::string message, std::string backtrace) + : kind_data_(std::move(kind)), + message_data_(std::move(message)), + backtrace_data_(std::move(backtrace)) { + this->kind = TVMFFIByteArray{kind_data_.data(), kind_data_.length()}; + this->message = TVMFFIByteArray{message_data_.data(), message_data_.length()}; + this->backtrace = TVMFFIByteArray{backtrace_data_.data(), backtrace_data_.length()}; + this->update_backtrace = UpdateBacktrace; + } + +private: + /*! + * \brief Update the backtrace of the error object. 
+ * \param self The error object whose backtrace is updated. + * \param backtrace_str The backtrace to update. + * \param update_mode kTVMFFIBacktraceUpdateModeReplace replaces the backtrace, any other value appends to it. + */ + static void UpdateBacktrace(TVMFFIObjectHandle self, const TVMFFIByteArray *backtrace_str, + int32_t update_mode) { + ErrorObjFromStd *obj = static_cast(self); + if (update_mode == kTVMFFIBacktraceUpdateModeReplace) { + obj->backtrace_data_.resize(backtrace_str->size); + std::memcpy(obj->backtrace_data_.data(), backtrace_str->data, backtrace_str->size); + obj->backtrace = TVMFFIByteArray{obj->backtrace_data_.data(), obj->backtrace_data_.length()}; + } else { + obj->backtrace_data_.append(backtrace_str->data, backtrace_str->size); + obj->backtrace = TVMFFIByteArray{obj->backtrace_data_.data(), obj->backtrace_data_.length()}; + } + } + + std::string kind_data_; + std::string message_data_; + std::string backtrace_data_; +}; +} // namespace details + +/*! + * \brief Managed reference to ErrorObj + * \sa Error Object + */ +class Error : public ObjectRef, public std::exception { +public: + /*! + * \brief Constructor + * \param kind The kind of the error. + * \param message The message of the error. + * \param backtrace The backtrace of the error. + */ + Error(std::string kind, std::string message, std::string backtrace) { + data_ = make_object(std::move(kind), std::move(message), + std::move(backtrace)); + } + + /*! + * \brief Constructor + * \param kind The kind of the error. + * \param message The message of the error. + * \param backtrace The backtrace of the error, copied from the given byte array. + */ + Error(std::string kind, std::string message, const TVMFFIByteArray *backtrace) + : Error(std::move(kind), std::move(message), std::string(backtrace->data, backtrace->size)) {} + + /*! + * \brief Get the kind of the error object. + * \return The kind of the error object. 
+ */ + std::string kind() const { + ErrorObj *obj = static_cast(data_.get()); + return std::string(obj->kind.data, obj->kind.size); + } + + /*! + * \brief Get the message of the error object. + * \return The message of the error object. + */ + std::string message() const { + ErrorObj *obj = static_cast(data_.get()); + return std::string(obj->message.data, obj->message.size); + } + + /*! + * \brief Get the backtrace of the error object. + * \return The backtrace of the error object. + * \note Consider using TracebackMostRecentCallLast for a pythonic style traceback. + * + * \sa TracebackMostRecentCallLast + */ + std::string backtrace() const { + ErrorObj *obj = static_cast(data_.get()); + return std::string(obj->backtrace.data, obj->backtrace.size); + } + + /*! + * \brief Get the traceback in the order of most recent call last. + * \note Empty lines are skipped, and any text after the last '\n' of the backtrace is not included. + * \return The traceback of the error object. + */ + std::string TracebackMostRecentCallLast() const { + // add placeholder for the first line + std::vector line_breakers = {-1}; + ErrorObj *obj = static_cast(data_.get()); + for (size_t i = 0; i < obj->backtrace.size; i++) { + if (obj->backtrace.data[i] == '\n') { + line_breakers.push_back(static_cast(i)); + } + } + std::string result; + result.reserve(obj->backtrace.size); + for (size_t i = line_breakers.size() - 1; i > 0; --i) { + int64_t line_start = line_breakers[i - 1] + 1; + int64_t line_end = line_breakers[i]; + if (line_start == line_end) { + continue; + } + result.append(obj->backtrace.data + line_start, line_end - line_start); + result.append("\n"); + } + return result; + } + + /*! + * \brief Update the backtrace of the error object. + * \param backtrace_str The backtrace to update. + * \param update_mode The mode to update the backtrace, + * can be either kTVMFFIBacktraceUpdateModeReplace, kTVMFFIBacktraceUpdateModeAppend. 
+ */ + void UpdateBacktrace(const TVMFFIByteArray *backtrace_str, int32_t update_mode) { + ErrorObj *obj = static_cast(data_.get()); + obj->update_backtrace(obj, backtrace_str, update_mode); + } + + /*! + * \brief Get the full message of the error, including kind, message and traceback. + * \return The full message of the error object. + */ + std::string FullMessage() const { + ErrorObj *obj = static_cast(data_.get()); + return (std::string("Traceback (most recent call last):\n") + TracebackMostRecentCallLast() + std::string(obj->kind.data, obj->kind.size) + std::string(": ") + std::string(obj->message.data, obj->message.size) + '\n'); + } + + /*! + * \brief Get the error message + * \return The error message + * \note To get the full message including kind and traceback, use FullMessage() instead. + */ + const char *what() const noexcept(true) override { + ErrorObj *obj = static_cast(data_.get()); + return obj->message.data; + } + + /// \cond Doxygen_Suppress + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Error, ObjectRef, ErrorObj); + /// \endcond +}; + +namespace details { + +class ErrorBuilder { +public: + explicit ErrorBuilder(std::string kind, std::string backtrace, bool log_before_throw) + : kind_(std::move(kind)), + backtrace_(std::move(backtrace)), + log_before_throw_(log_before_throw) {} + + explicit ErrorBuilder(std::string kind, const TVMFFIByteArray *backtrace, bool log_before_throw) + : ErrorBuilder(std::move(kind), std::string(backtrace->data, backtrace->size), + log_before_throw) {} + +// MSVC disable warning in error builder as it is expected +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4722) +#endif + // throwing destructor: builds the Error from the streamed message and throws it; this path does not need to be fast + [[noreturn]] ~ErrorBuilder() noexcept(false) { + ::tvm::ffi::Error error(std::move(kind_), stream_.str(), std::move(backtrace_)); + if (log_before_throw_) { + std::cerr << error.FullMessage(); + } + throw error; + } +#ifdef _MSC_VER +#pragma 
warning(pop) +#endif + + std::ostringstream &stream() { return stream_; } + +protected: + std::string kind_; + std::ostringstream stream_; + std::string backtrace_; + bool log_before_throw_; +}; + +} // namespace details + +/*! + * \brief Helper macro to throw an error with backtrace and message + * + * \code + * + * void ThrowError() { + * TVM_FFI_THROW(RuntimeError) << "error message"; + * } + * + * \endcode + */ +#define TVM_FFI_THROW(ErrorKind) \ + ::tvm::ffi::details::ErrorBuilder(#ErrorKind, \ + TVMFFIBacktrace(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0), \ + TVM_FFI_ALWAYS_LOG_BEFORE_THROW) \ + .stream() + +/*! + * \brief Explicitly log error in stderr and then throw the error. + * + * \note This is only necessary on startup functions where we know error + * cannot be caught, and it is better to have a clear log message. + * In most cases, we should use TVM_FFI_THROW. + */ +#define TVM_FFI_LOG_AND_THROW(ErrorKind) \ + ::tvm::ffi::details::ErrorBuilder( \ + #ErrorKind, TVMFFIBacktrace(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0), true) \ + .stream() + +// Glog style checks with TVM_FFI prefix +// NOTE: we explicitly avoid glog style generic macros (LOG/CHECK) in tvm ffi +// to avoid potential conflict of downstream users who might have their own GLOG style macros +namespace details { + +template +TVM_FFI_INLINE std::unique_ptr LogCheckFormat(const X &x, const Y &y) { + std::ostringstream os; + os << " (" << x << " vs. " << y << ") "; // CHECK_XX(x, y) requires x and y can be serialized to + // string. Use CHECK(x OP y) otherwise. + return std::make_unique(os.str()); +} + +#define TVM_FFI_CHECK_FUNC(name, op) \ + template \ + TVM_FFI_INLINE std::unique_ptr LogCheck##name(const X &x, const Y &y) { \ + if (x op y) \ + return nullptr; \ + return LogCheckFormat(x, y); \ + } \ + TVM_FFI_INLINE std::unique_ptr LogCheck##name(int x, int y) { \ + return LogCheck##name(x, y); \ + } + +// Inline _Pragma in macros does not work reliably on old version of MSVC and +// GCC. 
We wrap all comparisons in a function so that we can use #pragma to +// silence bad comparison warnings. +#if defined(__GNUC__) || defined(__clang__) // GCC and Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsign-compare" +#elif defined(_MSC_VER) // MSVC +#pragma warning(push) +#pragma warning(disable : 4389) // '==' : signed/unsigned mismatch +#endif + +TVM_FFI_CHECK_FUNC(_LT, <) +TVM_FFI_CHECK_FUNC(_GT, >) +TVM_FFI_CHECK_FUNC(_LE, <=) +TVM_FFI_CHECK_FUNC(_GE, >=) +TVM_FFI_CHECK_FUNC(_EQ, ==) +TVM_FFI_CHECK_FUNC(_NE, !=) + +#if defined(__GNUC__) || defined(__clang__) // GCC and Clang +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) // MSVC +#pragma warning(pop) +#endif +} // namespace details + +#define TVM_FFI_ICHECK_BINARY_OP(name, op, x, y) \ + if (auto __tvm_ffi_log_err = /* NOLINT(bugprone-reserved-identifier) */ \ + ::tvm::ffi::details::LogCheck##name(x, y)) \ + TVM_FFI_THROW(InternalError) << "Check failed: " << #x " " #op " " #y << *__tvm_ffi_log_err \ + << ": " + +#define TVM_FFI_ICHECK(x) \ + if (!(x)) \ + TVM_FFI_THROW(InternalError) << "Check failed: (" #x << ") is false: " + +#define TVM_FFI_CHECK(cond, ErrorKind) \ + if (!(cond)) \ + TVM_FFI_THROW(ErrorKind) << "Check failed: (" #cond << ") is false: " + +#define TVM_FFI_ICHECK_LT(x, y) TVM_FFI_ICHECK_BINARY_OP(_LT, <, x, y) +#define TVM_FFI_ICHECK_GT(x, y) TVM_FFI_ICHECK_BINARY_OP(_GT, >, x, y) +#define TVM_FFI_ICHECK_LE(x, y) TVM_FFI_ICHECK_BINARY_OP(_LE, <=, x, y) +#define TVM_FFI_ICHECK_GE(x, y) TVM_FFI_ICHECK_BINARY_OP(_GE, >=, x, y) +#define TVM_FFI_ICHECK_EQ(x, y) TVM_FFI_ICHECK_BINARY_OP(_EQ, ==, x, y) +#define TVM_FFI_ICHECK_NE(x, y) TVM_FFI_ICHECK_BINARY_OP(_NE, !=, x, y) +#define TVM_FFI_ICHECK_NOTNULL(x) \ + ((x) == nullptr ? 
TVM_FFI_THROW(InternalError) << "Check not null: " #x << ' ', \ + (x) : (x)) // NOLINT(*) +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_ERROR_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h new file mode 100644 index 000000000..b09b3540a --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/base.h + * \brief Base header for Extra API. + * + * The extra APIs contain a minimal set of extra APIs that are not + * required to support essential core functionality. + */ +#ifndef TVM_FFI_EXTRA_BASE_H_ +#define TVM_FFI_EXTRA_BASE_H_ + +#include + +/*! + * \brief Marks the API as extra c++ api that is defined in cc files. + * + * They are implemented in cc files to reduce compile-time overhead. + * The input/output only uses POD/Any/ObjectRef for ABI stability. 
+ * However, these extra APIs may have an issue across MSVC/Itanium ABI, + * + * Related features are also available through reflection based function + * that is fully based on C API + * + * The project aims to minimize the number of extra C++ APIs to keep things + * lightweight and restrict the use to non-core functionalities. + */ +#ifndef TVM_FFI_EXTRA_CXX_API +#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL +#endif + +#endif // TVM_FFI_EXTRA_BASE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h new file mode 100644 index 000000000..ac92e9f84 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * + * \file tvm/ffi/extra/base64.h + * \brief Base64 encoding and decoding utilities + */ +#ifndef TVM_FFI_EXTRA_BASE64_H_ +#define TVM_FFI_EXTRA_BASE64_H_ + +#include + +#include + +namespace tvm { +namespace ffi { +/*! 
+ * \brief Encode a byte array into a base64 string + * \param bytes The byte array to encode + * \return The base64 encoded string, padded with '=' to a multiple of 4 characters + */ +inline String Base64Encode(TVMFFIByteArray bytes) { + // encoding every 3 bytes into 4 characters + constexpr const char kEncodeTable[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string encoded; + encoded.reserve(4 * (bytes.size + 2) / 3); // never less than the output length, 4 * ceil(size / 3) + + for (size_t i = 0; i < (bytes.size / 3) * 3; i += 3) { // complete 3-byte groups only + int32_t buf[3]; + buf[0] = static_cast(static_cast(bytes.data[i])); + buf[1] = static_cast(static_cast(bytes.data[i + 1])); + buf[2] = static_cast(static_cast(bytes.data[i + 2])); + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); + encoded.push_back(kEncodeTable[((buf[1] << 2) | (buf[2] >> 6)) & 0x3F]); + encoded.push_back(kEncodeTable[buf[2] & 0x3F]); + } + if (bytes.size % 3 == 1) { // one leftover byte: 2 characters followed by "==" + int32_t buf[1] = {static_cast(static_cast(bytes.data[bytes.size - 1]))}; + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[(buf[0] << 4) & 0x3F]); + encoded.push_back('='); + encoded.push_back('='); + } else if (bytes.size % 3 == 2) { // two leftover bytes: 3 characters followed by "=" + int32_t buf[2] = {static_cast(static_cast(bytes.data[bytes.size - 2])), + static_cast(static_cast(bytes.data[bytes.size - 1]))}; + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); + encoded.push_back(kEncodeTable[(buf[1] << 2) & 0x3F]); + encoded.push_back('='); + } + return String(encoded); +} + +/*! + * \brief Encode a bytes object into a base64 string + * \param data The bytes object to encode + * \return The base64 encoded string + */ +inline String Base64Encode(const Bytes& data) { + return Base64Encode(TVMFFIByteArray{data.data(), data.size()}); +} + +/*! 
+ * \brief Decode a base64 string into a byte array + * \param bytes The bytes to be decoded; a non-empty input must have a length that is a multiple of 4. Bytes outside the base64 alphabet are not rejected and decode as 0. + * \return The decoded byte array (empty for empty input) + */ +inline Bytes Base64Decode(TVMFFIByteArray bytes) { + constexpr const char kDecodeTable[256] = { // one slot per possible byte value; entries past 'z' are zero-filled so every input byte is a valid index + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' + }; + std::string decoded; + decoded.reserve(bytes.size * 3 / 4); + if (bytes.size == 0) return Bytes(); + TVM_FFI_ICHECK(bytes.size % 4 == 0) << "invalid base64 encoding"; + // leverage this property to simplify decoding + static_assert('=' < sizeof(kDecodeTable) && kDecodeTable[static_cast<size_t>('=')] == 0); + // base64 is always multiple of 4 bytes + for (size_t i = 0; i < bytes.size; i += 4) { + // decode every 4 characters into 24bits, each character contains 6 bits + // note that = is also decoded as 0, which is safe to skip + int32_t buf[4] = { + static_cast<int32_t>(static_cast<uint8_t>(bytes.data[i])), + static_cast<int32_t>(static_cast<uint8_t>(bytes.data[i + 1])), + static_cast<int32_t>(static_cast<uint8_t>(bytes.data[i + 2])), + static_cast<int32_t>(static_cast<uint8_t>(bytes.data[i + 3])), + }; + int32_t value_i24 = (static_cast<int32_t>(kDecodeTable[buf[0]]) << 18) | + (static_cast<int32_t>(kDecodeTable[buf[1]]) << 12) | + (static_cast<int32_t>(kDecodeTable[buf[2]]) << 6) | + static_cast<int32_t>(kDecodeTable[buf[3]]); + // unpack 24bits into 3 bytes, each contains 8 bits + decoded.push_back(static_cast<char>((value_i24 >> 16) & 0xFF)); + if (buf[2] != '=') { + decoded.push_back(static_cast<char>((value_i24 >> 8) & 0xFF)); + } + if (buf[3] != '=') { + decoded.push_back(static_cast<char>(value_i24 & 0xFF)); + } + } + return Bytes(decoded); +} + +/*! 
+ * \brief Decode a base64 string into a byte array + * \param data The base64 encoded string to decode + * \return The decoded byte array + */ +inline Bytes Base64Decode(const String& data) { + return Base64Decode(TVMFFIByteArray{data.data(), data.size()}); +} + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_BASE64_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/c_env_api.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/c_env_api.h new file mode 100644 index 000000000..9f879705c --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/c_env_api.h @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +// NOLINTBEGIN(modernize-use-using) +/*! + * \file tvm/ffi/extra/c_env_api.h + * \brief Extra environment API. + */ +#ifndef TVM_FFI_EXTRA_C_ENV_API_H_ +#define TVM_FFI_EXTRA_C_ENV_API_H_ + +#include "../c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// ---------------------------------------------------------------------------- +// Stream context +// Focuses on a minimalistic thread-local context that records the stream being used. +// We explicitly do not handle allocation/de-allocation of the stream here. 
+// ---------------------------------------------------------------------------- +/*! + * \brief The type of the stream handle. + */ +typedef void *TVMFFIStreamHandle; + +/*! + * \brief FFI function to set the current stream for a device + * + * \param device_type The type of the device. + * \param device_id The id of the device. + * \param stream The stream to set. + * \param opt_out_original_stream Output original stream if the address is not nullptr. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, + TVMFFIStreamHandle stream, + TVMFFIStreamHandle *opt_out_original_stream); + +/*! + * \brief FFI function to get the current stream for a device + * + * \param device_type The type of the device. + * \param device_id The id of the device. + * \return The current stream of the device. + */ +TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id); + +/*! + * \brief Set the current DLPackManagedTensorAllocator in thread-local(TLS) context + * + * \param allocator The allocator to set. + * \param write_to_global_context Whether to also set the allocator to the global context. + * \param opt_out_original_allocator Output original TLS allocator if the address is not nullptr. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvSetDLPackManagedTensorAllocator( + DLPackManagedTensorAllocator allocator, int write_to_global_context, + DLPackManagedTensorAllocator *opt_out_original_allocator); + +/*! + * \brief FFI function to get the current DLPackManagedTensorAllocator stored in context. + * + * This function first queries the global context, and if not found, + * queries the thread-local context. + * + * \return The currently set DLPackManagedTensorAllocator + */ +TVM_FFI_DLL DLPackManagedTensorAllocator TVMFFIEnvGetDLPackManagedTensorAllocator(); + +/*! 
+ * \brief Allocate a tensor from the allocator set in thread-local(TLS) context. + * + * This function redirects to one of the environment allocators. As of now, we only + * support the DLPackManagedTensorAllocator set in thread-local(TLS) context. + * + * \param prototype The prototype DLTensor, only the dtype, ndim, shape, + * and device fields are used, other fields are ignored. + * \param out The output tensor in kTVMFFITensor type. + * \return 0 when success, nonzero when failure happens + * \sa TVMFFIEnvSetDLPackManagedTensorAllocator + */ +TVM_FFI_DLL int TVMFFIEnvTensorAlloc(DLTensor *prototype, TVMFFIObjectHandle *out); + +/*! + * \brief Check if there are any signals raised in the surrounding env. + * \return 0 when success, nonzero when failure happens + * \note Under python this function redirects to PyErr_CheckSignals + */ +TVM_FFI_DLL int TVMFFIEnvCheckSignals(); + +/*! + * \brief Register a symbol from the surrounding env, such as Python + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const char *name, void *symbol); + +// ---------------------------------------------------------------------------- +// Module symbol management in callee side +// ---------------------------------------------------------------------------- +/*! + * \brief FFI function to lookup a function from a module's imports. + * + * This is a helper function that is used by generated code. + * + * \param library_ctx The library context module handle. + * \param func_name The name of the function. + * \param out The result function. + * \note The returned function is a weak reference that is cached/owned by the module. + * \return 0 when no error is thrown, -1 when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, const char *func_name, + TVMFFIObjectHandle *out); + +/*! 
+ * \brief Register a symbol value that will be initialized when a library with the symbol is loaded. + * + * This function can be used to make context functions to be available in the library + * module that wants to avoid an explicit link dependency + * + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvModRegisterContextSymbol(const char *name, void *symbol); + +/*! + * \brief Register a symbol that will be initialized when a system library is loaded. + * + * \param name The name of the symbol. + * \param symbol The symbol to register. + * \return 0 when success, nonzero when failure happens + */ +TVM_FFI_DLL int TVMFFIEnvModRegisterSystemLibSymbol(const char *name, void *symbol); + +#ifdef __cplusplus +} // extern "C" +#endif +#endif // TVM_FFI_EXTRA_C_ENV_API_H_ +// NOLINTEND(modernize-use-using) diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h new file mode 100644 index 000000000..810fa064c --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/cuda/base.h + * \brief CUDA base utilities. + */ +#ifndef TVM_FFI_EXTRA_CUDA_BASE_H_ +#define TVM_FFI_EXTRA_CUDA_BASE_H_ + +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Macro for checking CUDA runtime API errors. + * + * This macro checks the return value of CUDA runtime API calls and throws + * a RuntimeError with detailed error information if the call fails. + * + * \param stmt The CUDA runtime API call to check. + */ +#define TVM_FFI_CHECK_CUDA_ERROR(stmt) \ + do { \ + cudaError_t __err = (stmt); \ + if (__err != cudaSuccess) { \ + const char* __err_name = cudaGetErrorName(__err); \ + const char* __err_str = cudaGetErrorString(__err); \ + TVM_FFI_THROW(RuntimeError) << "CUDA Runtime Error: " << __err_name << " (" \ + << static_cast(__err) << "): " << __err_str; \ + } \ + } while (0) + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_EXTRA_CUDA_BASE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h new file mode 100644 index 000000000..72eadd2ea --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h @@ -0,0 +1,604 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/cuda/cubin_launcher.h + * \brief CUDA CUBIN launcher utility for loading and executing CUDA kernels. + * + * This header provides a lightweight C++ wrapper around CUDA Runtime API + * for loading CUBIN modules and launching kernels. It supports: + * - Loading CUBIN from memory (embedded data) + * - Multi-GPU execution using CUDA primary contexts + * - Kernel parameter management and launch configuration + */ +#ifndef TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_ +#define TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_ + +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief A simple 3D dimension type for CUDA kernel launch configuration. + * + * This struct mimics the behavior of dim3 from CUDA Runtime API and provides + * a compatible interface for kernel launch configuration. It can be constructed + * from 1, 2, or 3 dimensions. + */ +struct dim3 { + /*! \brief X dimension (number of blocks in x-direction or threads in x-direction) */ + unsigned int x; + /*! \brief Y dimension (number of blocks in y-direction or threads in y-direction) */ + unsigned int y; + /*! \brief Z dimension (number of blocks in z-direction or threads in z-direction) */ + unsigned int z; + + /*! \brief Default constructor initializes to (1, 1, 1) */ + dim3() : x(1), y(1), z(1) {} + + /*! \brief Construct with x dimension, y and z default to 1 */ + explicit dim3(unsigned int x_) : x(x_), y(1), z(1) {} + + /*! 
\brief Construct with x and y dimensions, z defaults to 1 */ + dim3(unsigned int x_, unsigned int y_) : x(x_), y(y_), z(1) {} + + /*! \brief Construct with all three dimensions */ + dim3(unsigned int x_, unsigned int y_, unsigned int z_) : x(x_), y(y_), z(z_) {} +}; + +/*! + * \brief Macro to embed a CUBIN module with static initialization. + * + * This macro declares external symbols for embedded CUBIN data and creates + * a singleton struct to manage the CubinModule instance. The CUBIN data + * symbols should be named `__tvm_ffi__cubin_<name>` and `__tvm_ffi__cubin_<name>_end`, + * typically created using objcopy and ld. + * + * \par Creating Embedded CUBIN with TVM-FFI Utilities + * TVM-FFI provides utilities to simplify CUBIN embedding. You have two options: + * + * \par Option 1: CMake Utility (Recommended) + * Use the `tvm_ffi_embed_cubin` CMake function: + * \code{.cmake} + * # Find tvm_ffi package (provides tvm_ffi_embed_cubin utility) + * find_package(tvm_ffi CONFIG REQUIRED) + * find_package(CUDAToolkit REQUIRED) + * + * # Compile CUDA kernel to CUBIN + * tvm_ffi_generate_cubin( + * OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin + * SOURCE src/kernel.cu + * ARCH native # or sm_75, sm_80, etc. 
+ * ) + * + * # Embed CUBIN into C++ object file + * tvm_ffi_embed_cubin( + * OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/mycode_with_cubin.o + * SOURCE src/mycode.cc + * CUBIN ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin + * NAME my_kernels # Must match TVM_FFI_EMBED_CUBIN(my_kernels) in code + * ) + * + * # Link into shared library + * add_library(mylib SHARED ${CMAKE_CURRENT_BINARY_DIR}/mycode_with_cubin.o) + * target_link_libraries(mylib PRIVATE tvm_ffi_header CUDA::cudart) + * \endcode + * + * \par Option 2: Python Utility + * Use the `tvm_ffi.utils.embed_cubin` command-line tool: + * \code{.bash} + * # Step 1: Compile CUDA kernel to CUBIN + * nvcc --cubin -arch=sm_75 kernel.cu -o kernel.cubin + * + * # Step 2: Compile C++ source to object file + * g++ -c -fPIC -std=c++17 -I/path/to/tvm-ffi/include mycode.cc -o mycode.o + * + * # Step 3: Embed CUBIN using Python utility + * python -m tvm_ffi.utils.embed_cubin \ + * --output-obj mycode_with_cubin.o \ + * --input-obj mycode.o \ + * --cubin kernel.cubin \ + * --name my_kernels + * + * # Step 4: Link into shared library + * g++ -o mylib.so -shared mycode_with_cubin.o -lcudart + * \endcode + * + * The utilities automatically handle: + * - Symbol renaming to __tvm_ffi__cubin_<name> format + * - Adding .note.GNU-stack section for security + * - Symbol localization to prevent conflicts + * + * \par Usage in C++ Code + * In your C++ source file, use the embedded CUBIN: + * \code{.cpp} + * #include <tvm/ffi/extra/cuda/cubin_launcher.h> + * + * // Declare the embedded CUBIN module (name must match CMake NAME parameter) + * TVM_FFI_EMBED_CUBIN(my_kernels); + * + * void MyFunction() { + * // Get kernel from embedded CUBIN (cached in static variable for efficiency) + * static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "my_kernel"); + * // Use kernel... 
+ * } + * \endcode + * + * \note CMake Setup: To use the utilities, add to your CMakeLists.txt: + * \code{.cmake} + * find_package(tvm_ffi CONFIG REQUIRED) # Provides tvm_ffi_embed_cubin utility + * \endcode + * + * \par Option 3: Python Integration with load_inline + * When using `tvm_ffi.cpp.load_inline()` with the `embed_cubin` parameter, + * the CUBIN data is automatically embedded using the Python utility internally: + * \code{.py} + * from tvm_ffi import cpp + * from tvm_ffi.cpp import nvrtc + * + * # Compile CUDA source to CUBIN + * cubin_bytes = nvrtc.nvrtc_compile(cuda_source) + * + * # Load with embedded CUBIN - automatically handles embedding + * mod = cpp.load_inline( + * "my_module", + * cuda_sources=cpp_code, + * embed_cubin={"my_kernels": cubin_bytes}, + * extra_ldflags=["-lcudart"] + * ) + * \endcode + * + * \param name The identifier for this embedded CUBIN module (must match the + * symbol names created with objcopy or the key in embed_cubin dict). + * + * \see TVM_FFI_EMBED_CUBIN_GET_KERNEL + * \see CubinModule + * \see CubinKernel + */ +#define TVM_FFI_EMBED_CUBIN(name) \ + extern "C" const char __tvm_ffi__cubin_##name[]; \ + extern "C" const char __tvm_ffi__cubin_##name##_end[]; \ + namespace { \ + struct EmbedCubinModule_##name { \ + tvm::ffi::CubinModule mod{__tvm_ffi__cubin_##name}; \ + static EmbedCubinModule_##name* Global() { \ + static EmbedCubinModule_##name inst; \ + return &inst; \ + } \ + }; \ + } /* anonymous namespace */ + +/*! + * \brief Macro to get a kernel from an embedded CUBIN module. + * + * This macro retrieves a kernel by name from a previously declared embedded + * CUBIN module (using TVM_FFI_EMBED_CUBIN). The result is a CubinKernel object + * that can be used to launch the kernel with specified parameters. 
+ * + * \par Performance Tip + * It's recommended to store the result in a static variable to avoid repeated + * kernel lookups, which improves performance: + * \code{.cpp} + * static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "kernel_name"); + * \endcode + * + * \par Complete Example + * \code{.cpp} + * // Declare embedded CUBIN module + * TVM_FFI_EMBED_CUBIN(my_kernels); + * + * void LaunchKernel(tvm::ffi::TensorView input, tvm::ffi::TensorView output) { + * // Get kernel (cached in static variable for efficiency) + * static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "add_one"); + * + * // Prepare kernel arguments + * void* in_ptr = input.data_ptr(); + * void* out_ptr = output.data_ptr(); + * int64_t n = input.size(0); + * void* args[] = {&in_ptr, &out_ptr, &n}; + * + * // Configure launch + * tvm::ffi::dim3 grid((n + 255) / 256); + * tvm::ffi::dim3 block(256); + * + * // Get stream and launch + * DLDevice device = input.device(); + * cudaStream_t stream = static_cast( + * TVMFFIEnvGetStream(device.device_type, device.device_id)); + * + * cudaError_t result = kernel.Launch(args, grid, block, stream); + * TVM_FFI_CHECK_CUDA_ERROR(result); + * } + * \endcode + * + * \param name The identifier of the embedded CUBIN module (must match the name + * used in TVM_FFI_EMBED_CUBIN). + * \param kernel_name The name of the kernel function as it appears in the CUBIN + * (typically the function name for `extern "C"` kernels). + * \return A CubinKernel object for the specified kernel. + * + * \see TVM_FFI_EMBED_CUBIN + * \see CubinKernel::Launch + */ +#define TVM_FFI_EMBED_CUBIN_GET_KERNEL(name, kernel_name) \ + (EmbedCubinModule_##name::Global()->mod[kernel_name]) + +// Forward declaration +class CubinKernel; + +/*! + * \brief CUDA CUBIN module loader and manager. + * + * This class provides a RAII wrapper around CUDA Runtime API's library management. + * It loads a CUBIN module from memory and manages the library handle automatically. 
+ * The library is unloaded when the CubinModule object is destroyed. + * + * \par Features + * - Load CUBIN from memory (embedded data or runtime-generated) + * - Automatic resource management (RAII pattern) + * - Multi-GPU execution using CUDA primary contexts + * - Retrieve multiple kernels from the same module + * + * \par Example Usage + * \code{.cpp} + * // Load CUBIN from memory + * tvm::ffi::Bytes cubin_data = ...; + * tvm::ffi::CubinModule module(cubin_data); + * + * // Get kernels by name + * tvm::ffi::CubinKernel kernel1 = module["add_one"]; + * tvm::ffi::CubinKernel kernel2 = module.GetKernel("mul_two"); + * + * // Launch kernels + * void* args[] = {...}; + * tvm::ffi::dim3 grid(32), block(256); + * cudaStream_t stream = ...; + * kernel1.Launch(args, grid, block, stream); + * \endcode + * + * \note This class is movable but not copyable. + * \see TVM_FFI_EMBED_CUBIN for embedding CUBIN at compile time + * \see CubinKernel for kernel launching + */ +class CubinModule { + public: + /*! + * \brief Load CUBIN module from memory. + * + * \param bytes CUBIN binary data as a Bytes object. + */ + explicit CubinModule(const Bytes& bytes) { + TVM_FFI_CHECK_CUDA_ERROR( + cudaLibraryLoadData(&library_, bytes.data(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + } + + /*! + * \brief Load CUBIN module from raw memory buffer. + * + * \param code Pointer to CUBIN binary data. + * \note The `code` buffer points to an ELF image. + */ + explicit CubinModule(const char* code) { + TVM_FFI_CHECK_CUDA_ERROR( + cudaLibraryLoadData(&library_, code, nullptr, nullptr, 0, nullptr, nullptr, 0)); + } + + /*! \brief Destructor unloads the library */ + ~CubinModule() { + if (library_ != nullptr) { + cudaLibraryUnload(library_); + } + } + + /*! + * \brief Get a kernel function from the module by name. + * + * \param name Name of the kernel function. + * \return CubinKernel object representing the loaded kernel. + */ + CubinKernel GetKernel(const char* name); + + /*! 
+ * \brief Get a kernel function from the module by name with maximum dynamic shared memory. + * + * \param name Name of the kernel function. + * \param dynamic_smem_max Maximum dynamic shared memory in bytes to set for this kernel. + * -1 means maximum available dynamic shared memory (this declaration has no default argument, so it must be passed explicitly) + * (device max - static shared memory used by kernel). + * \return CubinKernel object representing the loaded kernel. + */ + CubinKernel GetKernelWithMaxDynamicSharedMemory(const char* name, int64_t dynamic_smem_max); + + /*! + * \brief Operator[] for convenient kernel access. + * + * NOTE(review): this was documented as "equivalent to calling GetKernel(name, -1)", but GetKernel as declared in this class takes only `name`; presumably it forwards to GetKernel(name) or GetKernelWithMaxDynamicSharedMemory(name, -1) -- confirm against the definition. + * + * \param name Name of the kernel function. + * \return CubinKernel object representing the loaded kernel. + */ + CubinKernel operator[](const char* name); + + /*! \brief Get the underlying cudaLibrary_t handle */ + cudaLibrary_t GetHandle() const { return library_; } + + // Non-copyable + CubinModule(const CubinModule&) = delete; + CubinModule& operator=(const CubinModule&) = delete; + + /*! + * \brief Move constructor for CubinModule. + * + * Transfers ownership of the CUDA library handle from another CubinModule instance. + * + * \param other The source CubinModule to move from (will be left in an empty state). + */ + CubinModule(CubinModule&& other) noexcept : library_(other.library_) { other.library_ = nullptr; } + + /*! + * \brief Move assignment operator for CubinModule. + * + * Transfers ownership of the CUDA library handle from another CubinModule instance. + * Cleans up any existing library handle in this instance before taking ownership. + * + * \param other The source CubinModule to move from (will be left in an empty state). + * \return Reference to this CubinModule. 
+ */ + CubinModule& operator=(CubinModule&& other) noexcept { + if (this != &other) { + if (library_ != nullptr) { + cudaLibraryUnload(library_); + } + library_ = other.library_; + other.library_ = nullptr; + } + return *this; + } + + private: + cudaLibrary_t library_ = nullptr; +}; + +/*! + * \brief CUDA kernel handle for launching kernels. + * + * This class represents a loaded CUDA kernel function and provides + * methods to launch it with specified grid/block dimensions, arguments, + * and stream configuration. Obtained from CubinModule by kernel name. + * + * \par Usage Pattern + * \code{.cpp} + * // Get kernel from module + * tvm::ffi::CubinKernel kernel = module["kernel_name"]; + * + * // Prepare arguments (must be pointers to actual values) + * void* data_ptr = tensor.data_ptr(); + * int64_t size = tensor.size(0); + * void* args[] = {&data_ptr, &size}; + * + * // Configure launch dimensions + * tvm::ffi::dim3 grid(32); // 32 blocks + * tvm::ffi::dim3 block(256); // 256 threads per block + * + * // Launch on stream + * cudaStream_t stream = ...; + * cudaError_t result = kernel.Launch(args, grid, block, stream); + * TVM_FFI_CHECK_CUDA_ERROR(result); + * \endcode + * + * \note This class is movable but not copyable. + * \see CubinModule for loading CUBIN and getting kernels + * \see dim3 for grid/block dimension specification + */ +class CubinKernel { + public: + /*! + * \brief Construct a CubinKernel from a library and kernel name. + * + * \param library The cudaLibrary_t handle. + * \param name Name of the kernel function. + */ + CubinKernel(cudaLibrary_t library, const char* name) { + TVM_FFI_CHECK_CUDA_ERROR(cudaLibraryGetKernel(&kernel_, library, name)); + } + + /*! \brief Destructor (kernel handle doesn't need explicit cleanup) */ + ~CubinKernel() = default; + + /*! + * \brief Launch the kernel with specified parameters. + * + * This function launches the kernel on the current CUDA context/device using + * the CUDA Runtime API. 
The kernel executes asynchronously on the specified stream. + * + * \par Argument Preparation + * The `args` array must contain pointers to the actual argument values, not the + * values themselves. For example: + * \code{.cpp} + * void* data_ptr = tensor.data_ptr(); + * int64_t size = 100; + * void* args[] = {&data_ptr, &size}; // Note: addresses of the variables + * \endcode + * + * \par Launch Configuration + * Grid and block dimensions determine the kernel's parallelism: + * - Grid: Number of thread blocks (can be 1D, 2D, or 3D) + * - Block: Number of threads per block (can be 1D, 2D, or 3D) + * - Total threads = grid.x * grid.y * grid.z * block.x * block.y * block.z + * + * \par Error Checking + * Always check the returned cudaError_t: + * \code{.cpp} + * cudaError_t result = kernel.Launch(args, grid, block, stream); + * TVM_FFI_CHECK_CUDA_ERROR(result); + * \endcode + * + * \param args Array of pointers to kernel arguments (must point to actual values). + * \param grid Grid dimensions (number of blocks in x, y, z). + * \param block Block dimensions (threads per block in x, y, z). + * \param stream CUDA stream to launch the kernel on (use 0 for default stream). + * \param dyn_smem_bytes Dynamic shared memory size in bytes (default: 0). + * \return cudaError_t error code from cudaLaunchKernel (cudaSuccess on success). + * + * \note The kernel executes asynchronously. Use cudaStreamSynchronize() or + * cudaDeviceSynchronize() to wait for completion if needed. + */ + cudaError_t Launch(void** args, dim3 grid, dim3 block, cudaStream_t stream, + uint32_t dyn_smem_bytes = 0) { + // Cast cudaKernel_t to const void* for use with cudaLaunchKernel + // The Runtime API accepts cudaKernel_t directly as a function pointer + auto kernel = reinterpret_cast(kernel_); + return cudaLaunchKernel(kernel, {grid.x, grid.y, grid.z}, {block.x, block.y, block.z}, args, + dyn_smem_bytes, stream); + } + + /*! 
\brief Get the underlying cudaKernel_t handle */ + cudaKernel_t GetHandle() const { return kernel_; } + + // Non-copyable + CubinKernel(const CubinKernel&) = delete; + CubinKernel& operator=(const CubinKernel&) = delete; + + /*! + * \brief Move constructor for CubinKernel. + * + * Transfers ownership of the CUDA kernel handle from another CubinKernel instance. + * + * \param other The source CubinKernel to move from (will be left in an empty state). + */ + CubinKernel(CubinKernel&& other) noexcept : kernel_(other.kernel_) { other.kernel_ = nullptr; } + + /*! + * \brief Move assignment operator for CubinKernel. + * + * Transfers ownership of the CUDA kernel handle from another CubinKernel instance. + * + * \param other The source CubinKernel to move from (will be left in an empty state). + * \return Reference to this CubinKernel. + */ + CubinKernel& operator=(CubinKernel&& other) noexcept { + if (this != &other) { + kernel_ = other.kernel_; + other.kernel_ = nullptr; + } + return *this; + } + + private: + /*! + * \brief Set maximum dynamic shared memory for this kernel across all devices. + * + * This method configures the maximum dynamic shared memory that can be allocated + * when launching this kernel. It must be called after the kernel is loaded. + * + * \param dynamic_smem_max Maximum dynamic shared memory in bytes to set. + * -1 (default) means maximum available dynamic shared memory, + * which is computed as (device max shared memory - static shared memory). + * For -1, the method queries the kernel's static shared memory usage + * and sets the attribute to the remaining available shared memory. + * + * \note This sets the maximum cap but doesn't force allocation. The actual dynamic + * shared memory used is controlled by the dyn_smem_bytes parameter in Launch(). + * \note This method attempts to set the attribute for all available devices and will + * only throw an error if it fails for ALL devices. 
+ */ + void SetMaxDynamicSharedMemory(int64_t dynamic_smem_max = -1) { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + return; // No devices available, nothing to configure + } + + bool any_success = false; + for (int device_id = 0; device_id < device_count; ++device_id) { + // Query device's maximum shared memory per block + int max_shared_mem = 0; + err = cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlock, device_id); + if (err != cudaSuccess) { + continue; // Skip this device if we can't get its attribute + } + + int shared_mem_to_set; + if (dynamic_smem_max == -1) { + // Query the kernel's static shared memory usage + cudaFuncAttributes func_attr; + + // According to the documentation, we can use cudaFuncGetAttributes to get the attributes of + // cudaKernel_t returned by cudaLibraryGetKernel, just cast the kernel_ to const void* + err = cudaFuncGetAttributes(&func_attr, reinterpret_cast<const void*>(kernel_)); + if (err != cudaSuccess) { + continue; // Skip this device if we can't get kernel attributes + } + + // Calculate available dynamic shared memory: + // device max shared memory - static shared memory used by kernel + int64_t static_shared = static_cast<int64_t>(func_attr.sharedSizeBytes); + int64_t max_shared = static_cast<int64_t>(max_shared_mem); + int64_t available = max_shared - static_shared; + shared_mem_to_set = (available > 0) ? 
static_cast<int>(available) : 0; + } else { + shared_mem_to_set = static_cast<int>(dynamic_smem_max); + } + + // Set the maximum dynamic shared memory size for this device + err = cudaKernelSetAttributeForDevice(kernel_, cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem_to_set, device_id); + if (err == cudaSuccess) { + any_success = true; + } + // Don't error out for individual device failures - user may only use some GPUs + } + + // Only error out if setting failed for ALL devices + if (!any_success && device_count > 0) { + TVM_FFI_THROW(RuntimeError) << "Failed to set dynamic shared memory attribute for any device"; + } + } + + cudaKernel_t kernel_ = nullptr; + + friend class CubinModule; +}; + +// Implementation of CubinModule methods that return CubinKernel +inline CubinKernel CubinModule::GetKernelWithMaxDynamicSharedMemory(const char* name, + int64_t dynamic_smem_max = -1) { + auto kernel = CubinKernel(library_, name); + kernel.SetMaxDynamicSharedMemory(dynamic_smem_max); + return kernel; +} + +inline CubinKernel CubinModule::GetKernel(const char* name) { + auto kernel = CubinKernel(library_, name); + return kernel; +} + +inline CubinKernel CubinModule::operator[](const char* name) { return GetKernel(name); } + +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h new file mode 100644 index 000000000..083580f76 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/cuda/device_guard.h + * \brief Device guard structs. + */ +#ifndef TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_ +#define TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_ + +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief CUDA Device Guard. On construction, it stores the current device index (queried via + * `cudaGetDevice`) and calls `cudaSetDevice` to switch to the target index if it differs. On + * destruction, it sets the current CUDA device back to the original device index. + * + * Example usage: + * \code + * void kernel(ffi::TensorView x) { + * ffi::CUDADeviceGuard guard(x.device().device_id); + * ... + * } + * \endcode + */ +struct CUDADeviceGuard { + CUDADeviceGuard() = delete; + /*! + * \brief Constructor from a device index, and store the original device index. + * \param device_index The device index to guard. + */ + explicit CUDADeviceGuard(int device_index) { + target_device_index_ = device_index; + TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&original_device_index_)); + if (target_device_index_ != original_device_index_) { + TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(device_index)); + } + } + + /*! + * \brief Destructor to set the current device index back to original one if different. 
+ */ + ~CUDADeviceGuard() noexcept(false) { + if (original_device_index_ != target_device_index_) { + TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(original_device_index_)); + } + } + + private: + int original_device_index_; + int target_device_index_; +}; + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h new file mode 100644 index 000000000..24ab2f0d8 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/json.h + * \brief Minimal lightweight JSON parsing and serialization utilities + */ +#ifndef TVM_FFI_EXTRA_JSON_H_ +#define TVM_FFI_EXTRA_JSON_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { +namespace json { + +/*! + * \brief alias Any as json Value. + * + * To keep things lightweight, we simply reuse the ffi::Any system. + */ +using Value = Any; + +/*! + * \brief alias Map as json Object. 
+ * \note We use Map instead of Map to avoid + * the overhead of key checking when doing as conversion, + * the check will be performed at runtime when we read each key + */ +using Object = ffi::Map; + +/*! \brief alias Array as json Array. */ +using Array = ffi::Array; + +/*! + * \brief Parse a JSON string into an Any value. + * + * Besides the standard JSON syntax, this function also supports: + * - Infinity/NaN as JavaScript syntax + * - int64 integer value + * + * If error_msg is not nullptr, the error message will be written to it + * and no exception will be thrown when parsing fails. + * + * \param json_str The JSON string to parse. + * \param error_msg The output error message, can be nullptr. + * + * \return The parsed Any value. + */ +TVM_FFI_EXTRA_CXX_API json::Value Parse(const String& json_str, String* error_msg = nullptr); + +/*! + * \brief Serialize an Any value into a JSON string. + * + * \param value The Any value to serialize. + * \param indent The number of spaces to indent the output. + * If not specified, the output will be compact. + * \return The output JSON string. + */ +TVM_FFI_EXTRA_CXX_API String Stringify(const json::Value& value, + Optional indent = std::nullopt); + +} // namespace json +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_JSON_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h new file mode 100644 index 000000000..6af26c252 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/module.h + * \brief A managed dynamic module in the TVM FFI. + */ +#ifndef TVM_FFI_EXTRA_MODULE_H_ +#define TVM_FFI_EXTRA_MODULE_H_ + +#include +#include +#include +#include + +#include + +namespace tvm { +namespace ffi { + +// forward declare Module +class Module; + +/*! + * \brief A module that can dynamically load ffi::Functions or exportable source code. + * \sa Module + */ +class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { + public: + /*! + * \return The per module type key. + * \note This key is used to for serializing custom modules. + */ + virtual const char* kind() const = 0; + /*! + * \brief Get the property mask of the module. + * \return The property mask of the module. + * + * \sa Module::ModulePropertyMask + */ + virtual int GetPropertyMask() const { return 0b000; } + /*! + * \brief Get a ffi::Function from the module. + * \param name The name of the function. + * \return The function. + */ + virtual Optional GetFunction(const String& name) = 0; + /*! + * \brief Returns true if this module has a definition for a function of \p name. + * + * Note that even if this function returns true the corresponding \p GetFunction result + * may be nullptr if the function is not yet callable without further compilation. + * + * The default implementation just checks if \p GetFunction is non-null. + * \param name The name of the function. 
+ * \return True if the module implements the function, false otherwise. + */ + virtual bool ImplementsFunction(const String& name) { return GetFunction(name).defined(); } + /*! + * \brief Get the docstring of the function, if available. + * \param name The name of the function. + * \return The documentation string if available, nullopt otherwise. + * + * \sa GetFunctionMetadata, TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC + */ + virtual Optional GetFunctionDoc(const String& name) { return std::nullopt; } + // Rationale: We separate the docstring from the metadata since docstrings + // can be unstructured and sometimes large, while metadata can be focused + // on storing structured information. + /*! + * \brief Get the metadata of the function, if available. + * \param name The name of the function. + * \return The metadata as JSON string if available, nullopt otherwise. + * + * \code + * Module mod = Module::LoadFromFile("lib.so"); + * Optional metadata = mod->GetFunctionMetadata("my_func"); + * if (metadata.has_value()) { + * // Parse JSON: {"type_schema": "..."} + * validate_signature(*metadata); + * } + * \endcode + * + * \sa GetFunctionDoc, TVM_FFI_DLL_EXPORT_TYPED_FUNC + */ + virtual Optional GetFunctionMetadata(const String& name) { return std::nullopt; } + /*! + * \brief Write the current module to file with given format (for further compilation). + * + * \param file_name The file to be saved to. + * \param format The format of the file. + * + * \note This function is mainly used by modules that set Module::kCompilationExportable. + */ + virtual void WriteToFile(const String& file_name, const String& format) const { + TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support WriteToFile"; + } + /*! + * \brief Get the possible write formats of the module, when available. + * \return Possible write formats when available. + */ + virtual Array GetWriteFormats() const { return Array(); } + /*! + * \brief Serialize the module to bytes. + * \return The serialized module. 
+ */ + virtual Bytes SaveToBytes() const { + TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support SaveToBytes"; + TVM_FFI_UNREACHABLE(); + } + /*! + * \brief Get the source code of module, when available. + * \param format Format of the source code, can be empty by default. + * \return Possible source code when available, or empty string if not available. + */ + virtual String InspectSource(const String& format) const { return String(); } + /*! + * \brief Import another module. + * \param other The module to import. + */ + virtual void ImportModule(const Module& other); + /*! + * \brief Clear all imported modules. + */ + virtual void ClearImports(); + /*! + * \brief Overloaded function to optionally query from imports. + * \param name The name of the function. + * \param query_imports Whether to query imported modules. + * \return The function. + */ + Optional GetFunction(const String& name, bool query_imports); + /*! + * \brief Overloaded function to optionally query from imports. + * \param name The name of the function. + * \param query_imports Whether to query imported modules. + * \return True if the module implements the function, false otherwise. + */ + bool ImplementsFunction(const String& name, bool query_imports); + /*! + * \brief Get the function docstring of the function if available. + * \param name The name of the function. + * \param query_imports Whether to also query modules imported by this module. + * \return The documentation string if available, nullopt otherwise. + * + * \sa GetFunctionMetadata + */ + Optional GetFunctionDoc(const String& name, bool query_imports); + /*! + * \brief Get the function metadata of the function if available. + * \param name The name of the function. + * \param query_imports Whether to also query modules imported by this module. + * \return The metadata as JSON string if available, nullopt otherwise. 
+ * + * \sa GetFunctionDoc + */ + Optional GetFunctionMetadata(const String& name, bool query_imports); + /*! + * \brief Get the imports of the module. + * \return The imports of the module. + * \note Note the signature is not part of the public API. + */ + const Array& imports() const { return this->imports_; } + + struct InternalUnsafe; + + /// \cond Doxygen_Suppress + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIModule; + static constexpr const bool _type_mutable = true; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIModule, ModuleObj, Object); + /// \endcond + + protected: + friend struct InternalUnsafe; + + /*! + * \brief The modules that this module depends on. + * \note Use ObjectRef to avoid circular dep on Module. + */ + Array imports_; + + private: + /*! + * \brief cache used by TVMFFIModuleLookupFromImports + */ + Map import_lookup_cache_; +}; + +/*! + * \brief Reference to module object. + * + * When invoking a function on a ModuleObj, such as GetFunction, + * use operator-> to get the ModuleObj pointer and invoke the member functions. + * + * \code + * ffi::Module mod = ffi::Module::LoadFromFile("path/to/module.so"); + * ffi::Function func = mod->GetFunction(name); + * \endcode + * + * \sa ModuleObj which contains most of the function implementations. + */ +class Module : public ObjectRef { + public: + /*! + * \brief Property of ffi::Module + */ + enum ModulePropertyMask : int { + /*! + * \brief The module can be serialized to bytes. + * + * This property indicates that module implements SaveToBytes. + * The system also registers a GlobalDef function + * `ffi.Module.load_from_bytes.<kind>` with signature (Bytes) -> Module. + */ + kBinarySerializable = 0b001, + /*! + * \brief The module can directly get runnable functions. + * + * This property indicates that module implements GetFunction that returns + * runnable ffi::Functions. + */ + kRunnable = 0b010, + /*! 
+ * \brief The module can be exported to an object file or source file that can then be compiled. + * + * This property indicates that module implements WriteToFile with a given format + * that can be queried by GetWriteFormats. + * + * Examples include modules that can be exported to .o, .cc, .cu files. + * + * Such modules can be exported, compiled and loaded back as a dynamic library module. + */ + kCompilationExportable = 0b100 + }; + /*! + * \brief Constructor from ObjectPtr. + * \param ptr The object pointer. + */ + explicit Module(const ObjectPtr& ptr) : ObjectRef(ptr) { + TVM_FFI_ICHECK(ptr != nullptr); + } + /*! + * \brief Load a module from file. + * \param file_name The name of the host function module. + * \note This function won't load the import relationship. + * Re-create import relationship by calling ImportModule. + */ + TVM_FFI_EXTRA_CXX_API static Module LoadFromFile(const String& file_name); + /*! + * \brief Query context symbols that are registered via TVMEnvRegisterSymbols. + * \param callback The callback to be called with the symbol name and address. + * \note This helper can be used to implement custom Module that needs to access context symbols. + */ + TVM_FFI_EXTRA_CXX_API static void VisitContextSymbols( + const ffi::TypedFunction& callback); + + /// \cond Doxygen_Suppress + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Module, ObjectRef, ModuleObj); + /// \endcond +}; + +/*! + * \brief Symbols for library module. + */ +namespace symbol { +/*! \brief symbol prefix for tvm ffi related function symbols */ +constexpr const char* tvm_ffi_symbol_prefix = "__tvm_ffi_"; +// Special symbols have one extra _ prefix to avoid conflict with user symbols +/*! + * \brief Default entry function of a library module is tvm_ffi_symbol_prefix + "main" + */ +constexpr const char* tvm_ffi_main = "__tvm_ffi_main"; +/*! \brief Global variable to store context pointer for a library module. */ +constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi__library_ctx"; +/*! 
\brief Global variable to store binary data alongside a library module. */ +constexpr const char* tvm_ffi_library_bin = "__tvm_ffi__library_bin"; +/*! \brief Optional metadata prefix of a symbol. */ +constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi__metadata_"; +/*! \brief Optional documentation prefix of a symbol. */ +constexpr const char* tvm_ffi_doc_prefix = "__tvm_ffi__doc_"; +} // namespace symbol +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_EXTRA_MODULE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h new file mode 100644 index 000000000..b5aa2891a --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/serialization.h + * \brief Reflection-based serialization utilities + */ +#ifndef TVM_FFI_EXTRA_SERIALIZATION_H_ +#define TVM_FFI_EXTRA_SERIALIZATION_H_ + +#include +#include + +namespace tvm { +namespace ffi { + +/** + * \brief Serialize ffi::Any to a JSON that stores the object graph. 
+ * + * The JSON graph structure is stored as follows: + * + * ``` + * { + * "root_index": , // Index of root node in nodes array + * "nodes": [, ...], // Array of serialized nodes + * "metadata": // Optional metadata + * } + * ``` + * + * Each node has the format: `{"type": "", "data": }` + * For object types and strings, the data may contain indices to other nodes. + * For object fields whose static type is known as a primitive type, it is stored directly, + * otherwise, it is stored as a reference to the nodes array by an index. + * + * This function preserves the type and multiple references to the same object, + * which is useful for debugging and serialization. + * + * \param value The ffi::Any value to serialize. + * \param metadata Extra metadata attached to "metadata" field of the JSON object. + * \return The serialized JSON value. + */ +TVM_FFI_EXTRA_CXX_API json::Value ToJSONGraph(const Any& value, const Any& metadata = Any(nullptr)); + +/** + * \brief Deserialize a JSON that stores the object graph to an ffi::Any value. + * + * This function can be used to implement deserialization + * and debugging. + * + * \param value The JSON value to deserialize. + * \return The deserialized object graph. + */ +TVM_FFI_EXTRA_CXX_API Any FromJSONGraph(const json::Value& value); + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_SERIALIZATION_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h new file mode 100644 index 000000000..ec960a85e --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/structural_equal.h + * \brief Structural equal implementation + */ +#ifndef TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ +#define TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { +/*! + * \brief Structural equality comparators + */ +class StructuralEqual { + public: + /** + * \brief Compare two Any values for structural equality. + * \param lhs The left hand side Any object. + * \param rhs The right hand side Any object. + * \param map_free_vars Whether to map free variables. + * \param skip_tensor_content Whether to skip comparing tensor data content, + * useful for cases where we don't care about parameters content + * \return True if the two Any values are structurally equal, false otherwise. + */ + TVM_FFI_EXTRA_CXX_API static bool Equal(const Any& lhs, const Any& rhs, + bool map_free_vars = false, + bool skip_tensor_content = false); + /** + * \brief Get the first mismatch AccessPath pair when running + * structural equal comparison between two Any values. + * + * \param lhs The left hand side Any object. + * \param rhs The right hand side Any object. + * \param map_free_vars Whether to map free variables. 
+ * \param skip_tensor_content Whether to skip comparing tensor data content, + * useful for cases where we don't care about parameters content + * \return If comparison fails, return the first mismatch AccessPath pair, + * otherwise return std::nullopt. + */ + TVM_FFI_EXTRA_CXX_API static Optional GetFirstMismatch( + const Any& lhs, const Any& rhs, bool map_free_vars = false, bool skip_tensor_content = false); + + /* + * \brief Compare two Any values for structural equality. + * \param lhs The left hand side Any object. + * \param rhs The right hand side Any object. + * \return True if the two Any values are structurally equal, false otherwise. + */ + TVM_FFI_INLINE bool operator()(const Any& lhs, const Any& rhs) const { + return Equal(lhs, rhs, false, true); + } +}; + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h new file mode 100644 index 000000000..bfe023c38 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/structural_hash.h + * \brief Structural hash + */ +#ifndef TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ +#define TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ + +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Structural hash + */ +class StructuralHash { + public: + /*! + * \brief Hash an Any value. + * \param value The Any value to hash. + * \param map_free_vars Whether to map free variables. + * \param skip_tensor_content Whether to skip hashing tensor data content, + * useful for cases where we don't care about parameters content. + * \return The hash value. + */ + TVM_FFI_EXTRA_CXX_API static uint64_t Hash(const Any& value, bool map_free_vars = false, + bool skip_tensor_content = false); + /*! + * \brief Functor form of Hash using its default options (map_free_vars = false, skip_tensor_content = false). + * \param value The Any value to hash. + * \return The hash value. + */ + TVM_FFI_INLINE uint64_t operator()(const Any& value) const { return Hash(value); } +}; + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function.h new file mode 100644 index 000000000..4854ecd1d --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function.h @@ -0,0 +1,998 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/function.h + * \brief A managed function in the TVM FFI. + */ +#ifndef TVM_FFI_FUNCTION_H_ +#define TVM_FFI_FUNCTION_H_ + +/*! + * \brief Controls whether DLL exports should include metadata. + * + * When set to 1, exported functions will include additional metadata. + * When set to 0 (default), exports are minimal without metadata. + */ +#ifndef TVM_FFI_DLL_EXPORT_INCLUDE_METADATA +#define TVM_FFI_DLL_EXPORT_INCLUDE_METADATA 0 +#endif + +#include "any.h" +#include "base_details.h" +#include "c_api.h" +#include "error.h" +#include "function_details.h" + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/** + * Helper macro to construct a safe call + * + * \brief Marks the beginning of the safe call that catches exception explicitly + * \sa TVM_FFI_SAFE_CALL_END + * + * \code + * int TVMFFICStyleFunction() { + * TVM_FFI_SAFE_CALL_BEGIN(); + * // c++ code region here + * TVM_FFI_SAFE_CALL_END(); + * } + * \endcode + */ +#define TVM_FFI_SAFE_CALL_BEGIN() \ + try { \ + (void)0 + +/*! + * \brief Marks the end of safe call. + */ +#define TVM_FFI_SAFE_CALL_END() \ + return 0; \ + } \ + catch (const ::tvm::ffi::Error &err) { \ + ::tvm::ffi::details::SetSafeCallRaised(err); \ + return -1; \ + } \ + catch (const ::tvm::ffi::EnvErrorAlreadySet &) { \ + return -2; \ + } \ + catch (const std::exception &ex) { \ + ::tvm::ffi::details::SetSafeCallRaised(::tvm::ffi::Error("InternalError", ex.what(), "")); \ + return -1; \ + } \ + TVM_FFI_UNREACHABLE() + +/*! 
+ * \brief Macro to check a call to TVMFFISafeCallType and raise exception if error happens. + * \param func The function to check. + * + * \code + * // calls TVMFFIFunctionCall and raises exception if error happens + * TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index)); + * \endcode + */ +#define TVM_FFI_CHECK_SAFE_CALL(func) \ + { \ + int ret_code = (func); \ + if (ret_code != 0) { \ + if (ret_code == -2) { \ + throw ::tvm::ffi::EnvErrorAlreadySet(); \ + } \ + throw ::tvm::ffi::details::MoveFromSafeCallRaised(); \ + } \ + } + +/*! + * \brief Object container class that backs ffi::Function + * \note Do not use this class directly, use ffi::Function + */ +class FunctionObj : public Object, public TVMFFIFunctionCell { +public: + /*! \brief Typedef for C++ style calling signature that comes with exception propagation */ + using FCall = void (*)(const FunctionObj *, const AnyView *, int32_t, Any *); + using TVMFFIFunctionCell::cpp_call; + using TVMFFIFunctionCell::safe_call; + /*! + * \brief Call the function in packed format. + * \param args The arguments + * \param num_args The number of arguments + * \param result The return value. + */ + TVM_FFI_INLINE void CallPacked(const AnyView *args, int32_t num_args, Any *result) const { + // if cpp_call is set, use it to call the function, otherwise, redirect to safe_call + // use conditional expression here so the select is branchless + FCall call_ptr = this->cpp_call ? reinterpret_cast(this->cpp_call) : CppCallDedirectToSafeCall; + (*call_ptr)(this, args, num_args, result); + } + /// \cond Doxygen_Suppress + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIFunction, FunctionObj, Object); + /// \endcond + +protected: + /*! \brief Make default constructor protected. 
*/ + FunctionObj() {} + friend class Function; + +private: + static void CppCallDedirectToSafeCall(const FunctionObj *func, const AnyView *args, + int32_t num_args, Any *rv) { + FunctionObj *self = static_cast(const_cast(func)); + TVM_FFI_CHECK_SAFE_CALL(self->safe_call(self, reinterpret_cast(args), + num_args, reinterpret_cast(rv))); + } +}; + +namespace details { +/*! + * \brief Derived object class for constructing FunctionObj backed by a TCallable + * + * This is a helper class that implements the function call interface. + * Invariance: TCallable cannot be const or reference type. + */ +template +class FunctionObjImpl : public FunctionObj { +public: + static_assert(std::is_same_v>>, + "TCallable of FunctionObjImpl cannot be const or reference type"); + + /*! \brief The type of derived object class */ + using TSelf = FunctionObjImpl; + + /*! + * \brief Derived object class for constructing ffi::FunctionObj. + * \param callable The type-erased callable object (rvalue). + */ + explicit FunctionObjImpl(TCallable &&callable) : callable_(std::move(callable)) { + this->safe_call = SafeCall; + this->cpp_call = reinterpret_cast(CppCall); + } + /*! + * \brief Derived object class for constructing ffi::FunctionObj. + * \param callable The type-erased callable object (lvalue). 
+ */ + explicit FunctionObjImpl(const TCallable &callable) : callable_(callable) { + this->safe_call = SafeCall; + this->cpp_call = reinterpret_cast(CppCall); + } + +private: + // implementation of call + static void CppCall(const FunctionObj *func, const AnyView *args, int32_t num_args, Any *result) { + (static_cast(func))->callable_(args, num_args, result); + } + /// \cond Doxygen_Suppress + // Implementing safe call style + static int SafeCall(void *func, const TVMFFIAny *args, int32_t num_args, TVMFFIAny *result) { + TVM_FFI_SAFE_CALL_BEGIN(); + TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin); + FunctionObj *self = static_cast(func); + reinterpret_cast(self->cpp_call)(self, reinterpret_cast(args), num_args, + reinterpret_cast(result)); + TVM_FFI_SAFE_CALL_END(); + } + /// \endcond + /*! \brief Type-erased filed for storing callable object*/ + mutable TCallable callable_; +}; + +/*! + * \brief FunctionObj specialization for raw C style callback where handle and deleter are null. + */ +class ExternCFunctionObjNullHandleImpl : public FunctionObj { +public: + explicit ExternCFunctionObjNullHandleImpl(TVMFFISafeCallType safe_call) { + this->safe_call = safe_call; + this->cpp_call = nullptr; + } +}; + +/*! + * \brief FunctionObj specialization that leverages C-style callback definitions. 
+ */ +class ExternCFunctionObjImpl : public FunctionObj { +public: + ExternCFunctionObjImpl(void *self, TVMFFISafeCallType safe_call, void (*deleter)(void *self)) + : self_(self), safe_call_(safe_call), deleter_(deleter) { + this->safe_call = SafeCall; + this->cpp_call = nullptr; + } + + ~ExternCFunctionObjImpl() { + if (deleter_) { + deleter_(self_); + } + } + +private: + static int32_t SafeCall(void *func, const TVMFFIAny *args, int32_t num_args, TVMFFIAny *rv) { + ExternCFunctionObjImpl *self = reinterpret_cast(func); + return self->safe_call_(self->self_, args, num_args, rv); + } + + void *self_; + TVMFFISafeCallType safe_call_; + void (*deleter_)(void *self); +}; + +// Helper class to set packed arguments +class PackedArgsSetter { +public: + explicit PackedArgsSetter(AnyView *args) : args_(args) {} + + // NOTE: setter needs to be very carefully designed + // such that we do not have temp variable conversion(eg. convert from lvalue to rvalue) + // that is why we need T&& and std::forward here + template + TVM_FFI_INLINE void operator()(size_t i, T &&value) const { + args_[i].operator=(std::forward(value)); + } + +private: + AnyView *args_; +}; +} // namespace details + +/*! + * \brief Represents arguments packed in AnyView array + * \note This class represent packed arguments to ffi::Function + */ +class PackedArgs { +public: + /*! + * \brief Constructor + * \param data The arguments + * \param size The number of arguments + */ + PackedArgs(const AnyView *data, int32_t size) : data_(data), size_(size) {} + + /*! \return size of the arguments */ + int size() const { return size_; } + + /*! \return The arguments */ + const AnyView *data() const { return data_; } + + /*! + * \brief Slice the arguments + * \param begin The begin index + * \param end The end index + * \return The sliced arguments + */ + PackedArgs Slice(int begin, int end = -1) const { + if (end == -1) { + end = size_; + } + return PackedArgs(data_ + begin, end - begin); + } + + /*! 
+ * \brief Get i-th argument + * \param i the index. + * \return the ith argument. + */ + AnyView operator[](int i) const { return data_[i]; } + + /*! + * \brief Fill the arguments into the AnyView array + * \param data The AnyView array to store the packed arguments + * \param args The arguments to be packed + * \note Caller must ensure all args are alive during lifetime of data. + * A common pitfall is to pass in local variables that are immediately + * destroyed after calling Fill. + */ + template + TVM_FFI_INLINE static void Fill(AnyView *data, Args &&...args) { + details::for_each(details::PackedArgsSetter(data), std::forward(args)...); + } + +private: + /*! \brief The arguments */ + const AnyView *data_; + /*! \brief The number of arguments */ + int32_t size_; +}; + +/*! + * \brief ffi::Function is a type-erased function. + * The arguments are passed by "packed format" via AnyView + */ +class Function : public ObjectRef { +public: + /*! \brief Constructor from null */ + Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*) + /*! + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `ffi::Function` + * \param packed_call The packed function signature + * \note legacy purpose, should change to Function::FromPacked for mostfuture use. + */ + template , Function>>> + explicit Function(TCallable &&packed_call) { + *this = FromPacked(std::forward(packed_call)); + } + /*! 
+ * \brief Constructing a packed function from a callable type + * whose signature is consistent with `ffi::Function` + * \param packed_call The packed function signature + */ + template + static Function FromPacked(TCallable &&packed_call) { + static_assert( + std::is_convertible_v> || std::is_convertible_v>, + "tvm::ffi::Function::FromPacked requires input function signature to match packed func " + "format"); + if constexpr (std::is_convertible_v>) { + return FromPackedInternal( + [packed_call = std::forward(packed_call)]( + const AnyView *args, int32_t num_args, Any *rv) mutable -> void { + packed_call(PackedArgs{args, num_args}, rv); + }); + } else { + return FromPackedInternal(std::forward(packed_call)); + } + } + + /*! + * \brief Create ffi::Function from a C style callbacks. + * + * self and deleter can be nullptr if the function do not need closure support + * and corresponds to a raw function pointer. + * + * \param self Resource handle to the function + * \param safe_call The safe_call definition in C. + * \param deleter The deleter to release the resource of self. + * \return The created function. + */ + static Function FromExternC(void *self, TVMFFISafeCallType safe_call, + void (*deleter)(void *self)) { + // the other function coems from a different library + Function func; + if (self == nullptr && deleter == nullptr) { + func.data_ = make_object(safe_call); + } else { + func.data_ = make_object(self, safe_call, deleter); + } + return func; + } + /*! + * \brief Get global function by name + * \param name The function name + * \return The global function. + * \note This function will return std::nullopt if the function is not found. 
+ */ + static std::optional GetGlobal(std::string_view name) { + TVMFFIObjectHandle handle; + TVMFFIByteArray name_arr{name.data(), name.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle)); + if (handle != nullptr) { + return Function( + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); + } else { + return std::nullopt; + } + } + + /*! + * \brief Get global function by name + * \param name The name of the function + * \return The global function + * \note This function will return std::nullopt if the function is not found. + */ + static std::optional GetGlobal(const std::string &name) { + return GetGlobal(std::string_view(name.data(), name.length())); + } + + /*! + * \brief Get global function by name + * \param name The name of the function + * \return The global function + * \note This function will return std::nullopt if the function is not found. + */ + static std::optional GetGlobal(const String &name) { + return GetGlobal(std::string_view(name.data(), name.length())); + } + + /*! + * \brief Get global function by name + * \param name The name of the function + * \return The global function + * \note This function will return std::nullopt if the function is not found. + */ + static std::optional GetGlobal(const char *name) { + return GetGlobal(std::string_view(name)); + } + /*! + * \brief Get global function by name and throw an error if it is not found. + * \param name The name of the function + * \return The global function + * \note This function will throw an error if the function is not found. + */ + static Function GetGlobalRequired(std::string_view name) { + std::optional res = GetGlobal(name); + if (!res.has_value()) { + TVM_FFI_THROW(ValueError) << "Function " << name << " not found"; + } + return *res; + } + + /*! + * \brief Get global function by name + * \param name The name of the function + * \return The global function + * \note This function will throw an error if the function is not found. 
+ */ + static Function GetGlobalRequired(const std::string &name) { + return GetGlobalRequired(std::string_view(name.data(), name.length())); + } + + /*! + * \brief Get global function by name + * \param name The name of the function + * \return The global function + * \note This function will throw an error if the function is not found. + */ + static Function GetGlobalRequired(const String &name) { + return GetGlobalRequired(std::string_view(name.data(), name.length())); + } + + /*! + * \brief Get global function by name + * \param name The name of the function + * \return The global function + * \note This function will throw an error if the function is not found. + */ + static Function GetGlobalRequired(const char *name) { + return GetGlobalRequired(std::string_view(name)); + } + /*! + * \brief Set global function by name + * \param name The name of the function + * \param func The function + * \param override Whether to override when there is duplication. + */ + static void SetGlobal(std::string_view name, + Function func, // NOLINT(performance-unnecessary-value-param) + bool override = false) { + TVMFFIByteArray name_arr{name.data(), name.size()}; + TVM_FFI_CHECK_SAFE_CALL( + TVMFFIFunctionSetGlobal(&name_arr, details::ObjectUnsafe::GetHeader(func.get()), override)); + } + /*! + * \brief List all global names + * \return A vector of all global names + * \note This function do not depend on Array so core do not have container dep. 
+ */ + static std::vector ListGlobalNames() { + Function fname_functor = GetGlobalRequired("ffi.FunctionListGlobalNamesFunctor")().cast(); + std::vector names; + int len = fname_functor(-1).cast(); + names.reserve(len); + for (int i = 0; i < len; ++i) { + names.push_back(fname_functor(i).cast()); + } + return names; + } + /** + * \brief Remove a global function by name + * \param name The name of the function + */ + static void RemoveGlobal(const String &name) { + static Function fremove = GetGlobalRequired("ffi.FunctionRemoveGlobal"); + fremove(name); + } + /*! + * \brief Constructing a packed function from a normal function. + * + * \param callable the internal container of packed function. + */ + template + static Function FromTyped(TCallable &&callable) { + using FuncInfo = details::FunctionInfo>; + // Callable is always captured by value here to avoid possible dangling reference + auto call_packed = [callable = std::forward(callable)]( + const AnyView *args, int32_t num_args, Any *rv) mutable -> void { + details::unpack_call( + std::make_index_sequence{}, nullptr, callable, args, num_args, rv); + }; + return FromPackedInternal(std::move(call_packed)); + } + /*! + * \brief Constructing a packed function from a normal function. + * + * \param callable the internal container of packed function. + * \param name optional name attacked to the function. + */ + template + static Function FromTyped(TCallable &&callable, std::string name) { + using FuncInfo = details::FunctionInfo>; + // Callable is always captured by value here to avoid possible dangling reference + auto call_packed = [callable = std::forward(callable), name = std::move(name)]( + const AnyView *args, int32_t num_args, Any *rv) mutable -> void { + details::unpack_call( + std::make_index_sequence{}, &name, callable, args, num_args, rv); + }; + return FromPackedInternal(std::move(call_packed)); + } + + /*! + * \brief Directly invoke an extern "C" function that follows the TVM FFI SafeCall convention. 
+ * + * This function can be useful to turn an existing exported symbol into a typed function. + * + * \code + * + * // An extern "C" function, matching TVMFFISafeCallType + * extern "C" int __tvm_ffi_add( + * void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny*result + * ); + * + * // redirect an existing symbol into a typed function + * inline int add(int a, int b) { + * return tvm::ffi::Function::InvokeExternC(nullptr, __tvm_ffi_add, a, b).cast(); + * } + * + * \endcode + * + * \tparam Args The types of the arguments to the extern function. + * \param handle The handle argument, for exported symbols this is usually nullptr. + * \param safe_call The function pointer to the extern "C" function. + * \param args The arguments to pass to the function. + * \return The return value, wrapped in a tvm::ffi::Any. + */ + template + TVM_FFI_INLINE static Any InvokeExternC(void *handle, TVMFFISafeCallType safe_call, + Args &&...args) { + const int kNumArgs = sizeof...(Args); + const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; + AnyView args_pack[kArraySize]; + PackedArgs::Fill(args_pack, std::forward(args)...); + Any result; + TVM_FFI_CHECK_SAFE_CALL(safe_call(handle, reinterpret_cast(args_pack), + kNumArgs, reinterpret_cast(&result))); + return result; + } + /*! + * \brief Call function by directly passing in unpacked arguments. + * + * \param args Arguments to be passed. + * \tparam Args arguments to be passed. + * + * \code + * // Example code on how to call packed function + * void CallFFIFunction(tvm::ffi::Function f) { + * // call like normal functions by pass in arguments + * // return value is automatically converted back + * int rvalue = f(1, 2.0); + * } + * \endcode + */ + template + TVM_FFI_INLINE Any operator()(Args &&...args) const { + const int kNumArgs = sizeof...(Args); + const int kArraySize = kNumArgs > 0 ? 
kNumArgs : 1; + AnyView args_pack[kArraySize]; + PackedArgs::Fill(args_pack, std::forward(args)...); + Any result; + static_cast(data_.get())->CallPacked(args_pack, kNumArgs, &result); + return result; + } + /*! + * \brief Call the function in packed format. + * \param args The arguments + * \param num_args The number of arguments + * \param result The return value. + */ + TVM_FFI_INLINE void CallPacked(const AnyView *args, int32_t num_args, Any *result) const { + static_cast(data_.get())->CallPacked(args, num_args, result); + } + /*! + * \brief Call the function in packed format. + * \param args The arguments + * \param result The return value. + */ + TVM_FFI_INLINE void CallPacked(PackedArgs args, Any *result) const { + static_cast(data_.get())->CallPacked(args.data(), args.size(), result); + } + + /*! \return Whether the packed function is nullptr */ + TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; } + /*! \return Whether the packed function is not nullptr */ + TVM_FFI_INLINE bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + + /// \cond Doxygen_Suppress + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, ObjectRef, FunctionObj); + /// \endcond + + class Registry; + +private: + /*! + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `ffi::Function` + * \param packed_call The packed function signature + */ + template + static Function FromPackedInternal(TCallable &&packed_call) { + // We must make TCallable a value type (decay_t) that can hold the callable object + using ObjType = typename details::FunctionObjImpl>; + Function func; + func.data_ = make_object(std::forward(packed_call)); + return func; + } +}; + +/*! + * \brief Please refer to \ref TypedFunctionAnchor "TypedFunction" + */ +template +class TypedFunction; + +/*! + * \anchor TypedFunctionAnchor + * \brief A ffi::Function wrapper to provide typed function signature. 
+ * It is backed by a ffi::Function internally. + * + * TypedFunction enables compile time type checking. + * TypedFunction works with the runtime system: + * - It can be passed as an argument of ffi::Function. + * - It can be assigned to ffi::Any. + * - It can be directly converted to a type-erased ffi::Function. + * + * Developers should prefer TypedFunction over ffi::Function in C++ code + * as it enables compile time checking. + * We can construct a TypedFunction from a lambda function + * with the same signature. + * + * \code + * // user defined lambda function. + * auto addone = [](int x)->int { + * return x + 1; + * }; + * // We can directly convert + * // lambda function to TypedFunction + * TypedFunction ftyped(addone); + * // invoke the function. + * int y = ftyped(1); + * // Can be directly converted to ffi::Function + * ffi::Function packed = ftype; + * \endcode + * \tparam R The return value of the function. + * \tparam Args The argument signature of the function. + */ +template +class TypedFunction { +public: + /*! \brief short hand for this function type */ + using TSelf = TypedFunction; + /*! \brief default constructor */ + TypedFunction() = default; + /*! \brief constructor from null */ + TypedFunction(std::nullptr_t null) {} // NOLINT(*) + /*! + * \brief constructor from a function + * \param packed The function + */ + TypedFunction(Function packed) : packed_(std::move(packed)) {} // NOLINT(*) + /*! + * \brief construct from a lambda function with the same signature. + * + * Example usage: + * \code + * auto typed_lambda = [](int x)->int { return x + 1; } + * // construct from packed function + * TypedFunction ftyped(typed_lambda, "add_one"); + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \param name the name of the lambda function. + * \tparam FLambda the type of the lambda function. 
+ */ + template >>> + TypedFunction(FLambda &&typed_lambda, std::string name) { + packed_ = Function::FromTyped(std::forward(typed_lambda), std::move(name)); + } + /*! + * \brief construct from a lambda function with the same signature. + * + * This version does not take a name. It is highly recommend you use the + * version that takes a name for the lambda. + * + * Example usage: + * \code + * auto typed_lambda = [](int x)->int { return x + 1; } + * // construct from packed function + * TypedFunction ftyped(typed_lambda); + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \tparam FLambda the type of the lambda function. + */ + template > && !std::is_same_v, TSelf>>> + TypedFunction(FLambda &&typed_lambda) { // NOLINT(google-explicit-constructor) + packed_ = Function::FromTyped(std::forward(typed_lambda)); + } + /*! + * \brief copy assignment operator from typed lambda + * + * Example usage: + * \code + * // construct from packed function + * TypedFunction ftyped; + * ftyped = [](int x) { return x + 1; } + * // call the typed version. + * CHECK_EQ(ftyped(1), 2); + * \endcode + * + * \param typed_lambda typed lambda function. + * \tparam FLambda the type of the lambda function. + * \returns reference to self. + */ + template > && !std::is_same_v, TSelf>>> + TSelf &operator=(FLambda &&typed_lambda) { + packed_ = Function::FromTyped(std::forward(typed_lambda)); + return *this; + } + /*! + * \brief copy assignment operator from ffi::Function. + * \param packed The packed function. + * \returns reference to self. + */ + TSelf &operator=(Function packed) { + packed_ = std::move(packed); + return *this; + } + /*! + * \brief Invoke the operator. + * \param args The arguments + * \returns The return value. + */ + TVM_FFI_INLINE R operator()(Args... 
args) const { // NOLINT(performance-unnecessary-value-param) + if constexpr (std::is_same_v) { + packed_(std::forward(args)...); + } else { + Any res = packed_(std::forward(args)...); + if constexpr (std::is_same_v) { + return res; + } else { + return std::move(res).cast(); + } + } + } + /*! + * \brief convert to ffi::Function + * \return the internal ffi::Function + */ + operator Function() const { return packed(); } // NOLINT(google-explicit-constructor) + /*! + * \return reference the internal ffi::Function + */ + const Function &packed() const & { return packed_; } + /*! + * \return r-value reference the internal ffi::Function + */ + constexpr Function &&packed() && { return std::move(packed_); } + /*! \return Whether the packed function is nullptr */ + bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } + /*! \return Whether the packed function is not nullptr */ + bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } + /*! + * \brief Get the type schema of `TypedFunction` in json format. + * \return The type schema of the function in json format. + */ + static std::string TypeSchema() { return details::FuncFunctorImpl::TypeSchema(); } + +private: + /*! 
\brief The internal packed function */ + Function packed_; +}; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction; + + TVM_FFI_INLINE static void CopyToAnyView(const TypedFunction &src, TVMFFIAny *result) { + TypeTraits::CopyToAnyView(src.packed(), result); + } + + TVM_FFI_INLINE static void MoveToAny(TypedFunction src, TVMFFIAny *result) { + // Move from rvalue to trigger TypedFunction rvalue packed() overload + TypeTraits::MoveToAny(std::move(src).packed(), result); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index == TypeIndex::kTVMFFIFunction; + } + + TVM_FFI_INLINE static TypedFunction CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return TypedFunction(TypeTraits::CopyFromAnyViewAfterCheck(src)); + } + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView( + const TVMFFIAny *src) { + std::optional opt = TypeTraits::TryCastFromAnyView(src); + if (opt.has_value()) { + return TypedFunction(*std::move(opt)); + } else { + return std::nullopt; + } + } + + TVM_FFI_INLINE static std::string TypeStr() { return details::FunctionInfo::Sig(); } + TVM_FFI_INLINE static std::string TypeSchema() { return TypedFunction::TypeSchema(); } +}; + +/*! + * \brief helper function to get type index from key + */ +inline int32_t TypeKeyToIndex(std::string_view type_key) { + int32_t type_index; + TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + return type_index; +} + +/// \cond Doxygen_Suppress +// Internal implementation macros used by TVM_FFI_DLL_EXPORT_TYPED_FUNC and related macros. +// These should not be used directly; use the public macros instead. 
+ +// Internal implementation macro that generates the C ABI wrapper function +#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function) \ + extern "C" { \ + TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void *self, const TVMFFIAny *args, \ + int32_t num_args, TVMFFIAny *result) { \ + TVM_FFI_SAFE_CALL_BEGIN(); \ + using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ + static std::string name = #ExportName; \ + ::tvm::ffi::details::unpack_call( \ + std::make_index_sequence{}, &name, Function, \ + reinterpret_cast(args), num_args, \ + reinterpret_cast<::tvm::ffi::Any *>(result)); \ + TVM_FFI_SAFE_CALL_END(); \ + } \ + } +/// \endcond + +/*! + * \brief Export typed function as a SafeCallType symbol that follows the FFI ABI. + * + * This macro exports the function and automatically exports metadata when + * TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is set to a non-zero value. + * + * \param ExportName The symbol name to be exported. + * \param Function The typed function. + * + * \sa ffi::TypedFunction, TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC + * + * \code + * + * int AddOne_(int x) { + * return x + 1; + * } + * + * // Expose the function as "AddOne" + * TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_); + * + * // Expose the function as "SubOne" + * TVM_FFI_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) { + * return x - 1; + * }); + * \endcode + * + * \note The final symbol names are: + * - `__tvm_ffi_` followed by ExportName (function) + * - `__tvm_ffi__metadata_` followed by ExportName (metadata - only when + * TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is set to a non-zero value) + */ +#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA +#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ + TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function) \ + extern "C" { \ + TVM_FFI_DLL_EXPORT int __tvm_ffi__metadata_##ExportName(void *self, const TVMFFIAny *args, \ + int32_t num_args, TVMFFIAny *result) { \ + TVM_FFI_SAFE_CALL_BEGIN(); \ + using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ + std::ostringstream os; \ + os << R"({"type_schema":)" \ + << 
::tvm::ffi::EscapeString(::tvm::ffi::String(FuncInfo::TypeSchema())) << R"(})"; \ + ::tvm::ffi::String str(os.str()); \ + ::tvm::ffi::TypeTraits<::tvm::ffi::String>::MoveToAny(std::move(str), result); \ + TVM_FFI_SAFE_CALL_END(); \ + } \ + } +#else +#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ + TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function) +#endif + +/*! + * \brief Export documentation string for a typed function. + * + * This macro exports a documentation string associated with a function export name. + * The docstring can be used by stub generators and documentation tools. + * This macro only exports the docstring; it does not export the function itself. + * + * \param ExportName The symbol name that the docstring is associated with. + * \param DocString The documentation string (C string literal). + * + * \sa ffi::TypedFunction, TVM_FFI_DLL_EXPORT_TYPED_FUNC + * + * \code + * + * int Add(int a, int b) { + * return a + b; + * } + * + * TVM_FFI_DLL_EXPORT_TYPED_FUNC(add, Add); + * TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC( + * add, + * R"(Add two integers and return the sum. + * + * Parameters + * ---------- + * a : int + * First integer + * b : int + * Second integer + * + * Returns + * ------- + * result : int + * Sum of a and b)"); + * + * \endcode + * + * \note The exported symbol name is `__tvm_ffi__doc_` followed by ExportName (docstring getter function). + * This symbol is only exported when TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is set to a non-zero value. 
+ */ +#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA +#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(ExportName, DocString) \ + extern "C" { \ + TVM_FFI_DLL_EXPORT int __tvm_ffi__doc_##ExportName(void *self, const TVMFFIAny *args, \ + int32_t num_args, TVMFFIAny *result) { \ + TVM_FFI_SAFE_CALL_BEGIN(); \ + ::tvm::ffi::String str(DocString); \ + ::tvm::ffi::TypeTraits<::tvm::ffi::String>::MoveToAny(std::move(str), result); \ + TVM_FFI_SAFE_CALL_END(); \ + } \ + } +#else +#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(ExportName, DocString) +#endif +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_FUNCTION_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function_details.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function_details.h new file mode 100644 index 000000000..38725d800 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function_details.h @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file tvm/ffi/function_details.h + * \brief Implements the function signature reflection + */ +#ifndef TVM_FFI_FUNCTION_DETAILS_H_ +#define TVM_FFI_FUNCTION_DETAILS_H_ + +#include "any.h" +#include "base_details.h" +#include "c_api.h" +#include "error.h" + +#include +#include +#include + +namespace tvm { +namespace ffi { +namespace details { + +template +struct Arg2Str { + template + TVM_FFI_INLINE static void Apply(std::ostream &os) { + using Arg = std::tuple_element_t; + if constexpr (i != 0) { + os << ", "; + } + os << i << ": " << Type2Str::v(); + } + template + TVM_FFI_INLINE static void Run(std::ostream &os, std::index_sequence) { + using TExpander = int[]; + (void)TExpander{0, (Apply(os), 0)...}; + } +}; + +/// NOTE: We only support `T`, `const T`, `const T&` and `T&&` as argument types. +template +static constexpr bool ArgTypeSupported = (!std::is_reference_v) || (std::is_const_v> && std::is_lvalue_reference_v) || (!std::is_const_v> && std::is_rvalue_reference_v); + +template +static constexpr bool ArgSupported = (ArgTypeSupported && (std::is_same_v>, Any> || std::is_same_v>, AnyView> || TypeTraitsNoCR::convert_enabled)); + +// NOTE: return type can only support non-reference managed returns +template +static constexpr bool RetSupported = (std::is_same_v || std::is_void_v || TypeTraits::convert_enabled); + +template +struct FuncFunctorImpl { + using FType = R(Args...); + using ArgType = std::tuple; + using RetType = R; + /*! \brief total number of arguments*/ + static constexpr size_t num_args = sizeof...(Args); + // MSVC is not that friendly to in-template nested bool evaluation +#ifndef _MSC_VER + /*! \brief Whether this function can be converted to ffi::Function via FromTyped */ + static constexpr bool unpacked_supported = (ArgSupported && ...) 
&& (RetSupported); +#endif + TVM_FFI_INLINE static std::string Sig() { + using IdxSeq = std::make_index_sequence; + std::ostringstream ss; + ss << "("; + Arg2Str>::Run(ss, IdxSeq{}); + ss << ") -> " << Type2Str::v(); + return ss.str(); + } + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":")" << StaticTypeKey::kTVMFFIFunction << R"(","args":[)"; + oss << details::TypeSchema::v(); + ((oss << "," << details::TypeSchema::v()), ...); + oss << "]}"; + return oss.str(); + } +}; + +template +struct FunctionInfoHelper; + +template +struct FunctionInfoHelper : FuncFunctorImpl {}; +template +struct FunctionInfoHelper : FuncFunctorImpl {}; + +/*! + * \brief Template class to get function signature of a function or functor. + * \tparam T The function/functor type. + * \note We need a decltype redirection because this helps lambda types. + */ +template +struct FunctionInfo : FunctionInfoHelper {}; +template +struct FunctionInfo : FuncFunctorImpl {}; +template +struct FunctionInfo : FuncFunctorImpl {}; +template +struct FunctionInfo : FuncFunctorImpl {}; +// Support pointer-to-member functions used in reflection (e.g. &Class::method) +template +struct FunctionInfo>> + : FuncFunctorImpl {}; +template +struct FunctionInfo>> + : FuncFunctorImpl {}; + +template +struct FunctionInfo>> + : FuncFunctorImpl {}; +template +struct FunctionInfo>> + : FuncFunctorImpl {}; + +/*! \brief Using static function to output typed function signature */ +using FGetFuncSignature = std::string (*)(); + +/*! + * \brief Auxiliary argument value with context for error reporting + * \tparam Type The expected type of the argument. + * \note We use a template class with non-template operator conversion + * instead of a non-template class with template operator conversion. + * This is because template operator conversion doesn't play well with + * classes with template constructors. + * In this case, it may lead to some unintended compiler errors. 
+ * An example of such a class is `std::optional`. + */ +template +class ArgValueWithContext { +public: + using TypeWithoutCR = std::remove_const_t>; + + /*! + * \brief Bind one argument slot to the context used for error reporting. + * \param args The argument list + * \param arg_index In a function call, this argument is at index arg_index (0-indexed). + * \param optional_name Name of the function being called. Can be nullptr if the function is + * not named. + * \param f_sig Pointer to static function outputting signature of the function being called. + */ + TVM_FFI_INLINE ArgValueWithContext(const AnyView *args, int32_t arg_index, + const std::string *optional_name, FGetFuncSignature f_sig) + : args_(args), arg_index_(arg_index), optional_name_(optional_name), f_sig_(f_sig) {} + + TVM_FFI_INLINE operator TypeWithoutCR() { // NOLINT(google-explicit-constructor) + if constexpr (std::is_same_v) { + return args_[arg_index_]; + } else if constexpr (std::is_same_v) { + return Any(args_[arg_index_]); + } else { + std::optional opt = args_[arg_index_].template try_cast(); + if (!opt.has_value()) { + TVMFFIAny any_data = args_[arg_index_].CopyToTVMFFIAny(); + TVM_FFI_THROW(TypeError) << "Mismatched type on argument #" << arg_index_ + << " when calling: `" + << (optional_name_ == nullptr ? "" : *optional_name_) + << (f_sig_ == nullptr ? "" : (*f_sig_)()) << "`. 
Expected `" + << Type2Str::v() << "` but got `" + << TypeTraits::GetMismatchTypeInfo(&any_data) + << '`'; + } + return *std::move(opt); + } + } + +private: + const AnyView *args_; + int32_t arg_index_; + const std::string *optional_name_; + FGetFuncSignature f_sig_; +}; + +template +TVM_FFI_INLINE void unpack_call(std::index_sequence, const std::string *optional_name, + const F &f, [[maybe_unused]] const AnyView *args, + [[maybe_unused]] int32_t num_args, [[maybe_unused]] Any *rv) { + using FuncInfo = FunctionInfo; + using PackedArgs = typename FuncInfo::ArgType; + FGetFuncSignature f_sig = FuncInfo::Sig; + + // somehow MSVC does not support the static constexpr member in this case, function is fine +#ifndef _MSC_VER + static_assert(FuncInfo::unpacked_supported, "The function signature do not support unpacked"); +#endif + constexpr size_t nargs = sizeof...(Is); + if (nargs != num_args) { + TVM_FFI_THROW(TypeError) << "Mismatched number of arguments when calling: `" + << (optional_name == nullptr ? "" : *optional_name) + << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << nargs + << " but got " << num_args << " arguments"; + } + // use index sequence to do recursion-free unpacking + if constexpr (std::is_same_v) { + f(ArgValueWithContext>{args, Is, optional_name, f_sig}...); + } else { + *rv = R(f(ArgValueWithContext>{args, Is, optional_name, + f_sig}...)); + } +} + +/*! + * \brief Move the safe call raised error to the caller + * \return The error + */ +TVM_FFI_INLINE static Error MoveFromSafeCallRaised() { + TVMFFIObjectHandle handle; + TVMFFIErrorMoveFromRaised(&handle); + // handle is owned by caller + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); +} + +/*! 
+ * \brief Set the safe call raised error + * \param error The error + */ +TVM_FFI_INLINE static void SetSafeCallRaised(const Error &error) { + TVMFFIErrorSetRaised(details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(error)); +} + +template +struct TypeSchemaImpl { + static std::string v() { + using U = std::remove_const_t>; + return TypeTraits::TypeSchema(); + } +}; + +template <> +struct TypeSchemaImpl { + static std::string v() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFINone) + R"("})"; + } +}; + +template <> +struct TypeSchemaImpl { + static std::string v() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIAny) + R"("})"; + } +}; + +template <> +struct TypeSchemaImpl { + static std::string v() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIAny) + R"("})"; + } +}; + +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_FUNCTION_DETAILS_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/memory.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/memory.h new file mode 100644 index 000000000..fd999da2a --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/memory.h @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/memory.h + * \brief Runtime memory management for allocating objects on the heap. + */ +#ifndef TVM_FFI_MEMORY_H_ +#define TVM_FFI_MEMORY_H_ + +#include "object.h" + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { +/*! \brief Deleter function for object */ +using FObjectDeleter = void (*)(void *obj, int flags); + +// Detail implementations after this +// +// The current design allows swapping the +// allocator pattern when necessary. +// +// Possible future allocator optimizations: +// - Arena allocator that gives ownership of memory to arena (deleter = nullptr) +// - Thread-local object pools: one pool per size and alignment requirement. +// - Can specialize by type of object to give the specific allocator to each object. +namespace details { + +/*! + * \brief Allocate aligned memory. + * \param size The size. + * \tparam align The alignment, must be a power of 2. + * \return The pointer to the allocated memory; throws std::bad_alloc on failure. + */ +template +TVM_FFI_INLINE void *AlignedAlloc(size_t size) { + static_assert(align != 0 && (align & (align - 1)) == 0, "align must be a power of 2"); +#ifdef _MSC_VER + // MSVC has to use _aligned_malloc + if (void *ptr = _aligned_malloc(size, align)) { + return ptr; + } + throw std::bad_alloc(); +#else + if constexpr (align <= alignof(std::max_align_t)) { + // malloc guarantees alignment of std::max_align_t + if (void *ptr = std::malloc(size)) { + return ptr; + } + throw std::bad_alloc(); + } else { + void *ptr; + // for other alignments, use posix_memalign + if (posix_memalign(&ptr, align, size) != 0) { + throw std::bad_alloc(); + } + return ptr; + } +#endif +} + +/*! + * \brief Free aligned memory. + * \param data The pointer to the memory to free. 
+ */ +TVM_FFI_INLINE void AlignedFree(void *data) { +#ifdef _MSC_VER + // MSVC has to use _aligned_free + _aligned_free(data); +#else + std::free(data); +#endif +} + +/*! + * \brief Base class of object allocators that implements make. + * Uses the curiously recurring template pattern. + * + * \tparam Derived The derived class. + */ +template +class ObjAllocatorBase { +public: + /*! + * \brief Make a new object using the allocator. + * \tparam T The type to be allocated. + * \tparam Args The constructor signature. + * \param args The arguments. + */ + template + ObjectPtr make_object(Args &&...args) { + using Handler = typename Derived::template Handler; + static_assert(std::is_base_of_v, "make can only be used to create Object"); + T *ptr = Handler::New(static_cast(this), std::forward(args)...); + TVMFFIObject *ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); + ffi_ptr->combined_ref_count = kCombinedRefCountBothOne; + ffi_ptr->type_index = T::RuntimeTypeIndex(); + ffi_ptr->__padding = 0; + ffi_ptr->deleter = Handler::Deleter(); + return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); + } + + /*! + * \tparam ArrayType The type to be allocated. + * \tparam ElemType The type of array element. + * \tparam Args The constructor signature. + * \param num_elems The number of array elements. + * \param args The arguments. 
+ */ + template + ObjectPtr make_inplace_array(size_t num_elems, Args &&...args) { + using Handler = typename Derived::template ArrayHandler; + static_assert(std::is_base_of_v, + "make_inplace_array can only be used to create Object"); + ArrayType *ptr = Handler::New(static_cast(this), num_elems, std::forward(args)...); + TVMFFIObject *ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); + ffi_ptr->combined_ref_count = kCombinedRefCountBothOne; + ffi_ptr->type_index = ArrayType::RuntimeTypeIndex(); + ffi_ptr->__padding = 0; + ffi_ptr->deleter = Handler::Deleter(); + return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); + } + +private: + ObjAllocatorBase() = default; + friend Derived; +}; + +// Simple allocator that uses AlignedAlloc plus placement new, and AlignedFree. +class SimpleObjAllocator : public ObjAllocatorBase { +public: + template + class Handler { + public: + template + static T *New(SimpleObjAllocator *, Args &&...args) { + // NOTE: the first argument is not needed for SimpleObjAllocator + // It is reserved for special allocators that need to recycle + // the object to itself (e.g. in the case of object pool). + // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. 
+ void *data = AlignedAlloc(sizeof(T)); + new (data) T(std::forward(args)...); + return reinterpret_cast(data); + } + + static FObjectDeleter Deleter() { return Deleter_; } + + private: + static void Deleter_(void *objptr, int flags) { + T *tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned(static_cast(objptr)); + if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { + // It is important to do tptr->T::~T(), + // so that we explicitly call the specific destructor + // instead of tptr->~T(), which could dispatch to + // a virtual destructor (which may not be available and is not required). + tptr->T::~T(); + } + if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { + AlignedFree(static_cast(tptr)); + } + } + }; + + // Array handler that uses AlignedAlloc plus placement new, and AlignedFree. + template + class ArrayHandler { + public: + template + static ArrayType *New(SimpleObjAllocator *, size_t num_elems, Args &&...args) { + // NOTE: the first argument is not needed for SimpleObjAllocator + // It is reserved for special allocators that need to recycle + // the object to itself (e.g. in the case of object pool). + // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + + // for now only support elements that align with array header. 
+ static_assert( + alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, + "element alignment constraint"); + size_t size = sizeof(ArrayType) + sizeof(ElemType) * num_elems; + // round up to the nearest multiple of align + constexpr size_t align = alignof(ArrayType); + // C++ standard always guarantees that alignof operator returns a power of 2 + size_t aligned_size = (size + (align - 1)) & ~(align - 1); + void *data = AlignedAlloc(aligned_size); + new (data) ArrayType(std::forward(args)...); + return reinterpret_cast(data); + } + + static FObjectDeleter Deleter() { return Deleter_; } + + private: + static void Deleter_(void *objptr, int flags) { + ArrayType *tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned( + static_cast(objptr)); + if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { + // It is important to do tptr->ArrayType::~ArrayType(), + // so that we explicitly call the specific destructor + // instead of tptr->~ArrayType(), which could dispatch to + // a virtual destructor (which may not be available and is not required). + tptr->ArrayType::~ArrayType(); + } + if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { + AlignedFree(static_cast(tptr)); + } + } + }; +}; +} // namespace details + +/*! + * \brief Allocate an object + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The ObjectPtr to the allocated object. + */ +template +inline ObjectPtr make_object(Args &&...args) { + return details::SimpleObjAllocator().make_object(std::forward(args)...); +} + +/*! + * \brief Allocate an Object with additional ElemType[num_elems] that are stored right after. + * \param num_elems The number of elements in the array. + * \param args arguments to the constructor. + * \tparam ArrayType the array type. + * \tparam ElemType the element type. + * \return The ObjectPtr to the allocated array. 
+ */ +template +inline ObjectPtr make_inplace_array_object(size_t num_elems, Args &&...args) { + return details::SimpleObjAllocator().make_inplace_array( + num_elems, std::forward(args)...); +} + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_MEMORY_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/object.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/object.h new file mode 100644 index 000000000..eb796bf6a --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/object.h @@ -0,0 +1,1207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/object.h + * \brief A managed object in the TVM FFI. + */ +#ifndef TVM_FFI_OBJECT_H_ +#define TVM_FFI_OBJECT_H_ + +#include "base_details.h" +#include "c_api.h" + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief TypeIndex enum, alias of TVMFFITypeIndex. + */ +using TypeIndex = TVMFFITypeIndex; + +/*! + * \brief TypeInfo, alias of TVMFFITypeInfo. + */ +using TypeInfo = TVMFFITypeInfo; + +/*! + * \brief Helper tag to explicitly request unsafe initialization. + * + * Constructing an ObjectRefType with UnsafeInit{} will set the data_ member to nullptr. 
+ * + * When initializing Object fields, ObjectRef fields can be set to UnsafeInit. + * This enables the "construct with UnsafeInit then set all fields" pattern + * when the object does not have a default constructor. + * + * Used for initialization in controlled scenarios where such unsafe + * initialization is known to be safe. + * + * Each ObjectRefType should have a constructor that takes an UnsafeInit tag. + * + * \note As the name suggests, do not use it in normal code paths. + */ +struct UnsafeInit {}; + +/*! + * \brief Known type keys for pre-defined types. + */ +struct StaticTypeKey { + /*! \brief The type key for Any */ + static constexpr const char *kTVMFFIAny = "Any"; + /*! \brief The type key for None */ + static constexpr const char *kTVMFFINone = "None"; + /*! \brief The type key for bool */ + static constexpr const char *kTVMFFIBool = "bool"; + /*! \brief The type key for int */ + static constexpr const char *kTVMFFIInt = "int"; + /*! \brief The type key for float */ + static constexpr const char *kTVMFFIFloat = "float"; + /*! \brief The type key for void* */ + static constexpr const char *kTVMFFIOpaquePtr = "void*"; + /*! \brief The type key for DataType */ + static constexpr const char *kTVMFFIDataType = "DataType"; + /*! \brief The type key for Device */ + static constexpr const char *kTVMFFIDevice = "Device"; + /*! \brief The type key for DLTensor* */ + static constexpr const char *kTVMFFIDLTensorPtr = "DLTensor*"; + /*! \brief The type key for const char* */ + static constexpr const char *kTVMFFIRawStr = "const char*"; + /*! \brief The type key for TVMFFIByteArray* */ + static constexpr const char *kTVMFFIByteArrayPtr = "TVMFFIByteArray*"; + /*! \brief The type key for ObjectRValueRef */ + static constexpr const char *kTVMFFIObjectRValueRef = "ObjectRValueRef"; + /*! \brief The type key for SmallStr */ + static constexpr const char *kTVMFFISmallStr = "ffi.SmallStr"; + /*! 
\brief The type key for SmallBytes */ + static constexpr const char *kTVMFFISmallBytes = "ffi.SmallBytes"; + /*! \brief The type key for Error */ + static constexpr const char *kTVMFFIError = "ffi.Error"; + /*! \brief The type key for Bytes */ + static constexpr const char *kTVMFFIBytes = "ffi.Bytes"; + /*! \brief The type key for String */ + static constexpr const char *kTVMFFIStr = "ffi.String"; + /*! \brief The type key for Shape */ + static constexpr const char *kTVMFFIShape = "ffi.Shape"; + /*! \brief The type key for Tensor */ + static constexpr const char *kTVMFFITensor = "ffi.Tensor"; + /*! \brief The type key for Object */ + static constexpr const char *kTVMFFIObject = "ffi.Object"; + /*! \brief The type key for Function */ + static constexpr const char *kTVMFFIFunction = "ffi.Function"; + /*! \brief The type key for Array */ + static constexpr const char *kTVMFFIArray = "ffi.Array"; + /*! \brief The type key for Map */ + static constexpr const char *kTVMFFIMap = "ffi.Map"; + /*! \brief The type key for Module */ + static constexpr const char *kTVMFFIModule = "ffi.Module"; + /*! \brief The type key for OpaquePyObject */ + static constexpr const char *kTVMFFIOpaquePyObject = "ffi.OpaquePyObject"; +}; + +/*! + * \brief Get type key from type index + * \param type_index The input type index + * \return the type key + */ +inline std::string TypeIndexToTypeKey(int32_t type_index) { + const TypeInfo *type_info = TVMFFIGetTypeInfo(type_index); + return std::string(type_info->type_key.data, type_info->type_key.size); +} + +namespace details { +// Helper to perform +// unsafe operations related to object +struct ObjectUnsafe; + +/*! \brief One counter for weak reference. */ +constexpr uint64_t kCombinedRefCountWeakOne = static_cast(1) << 32; +/*! \brief One counter for strong reference. */ +constexpr uint64_t kCombinedRefCountStrongOne = 1; +/*! \brief One weak plus one strong reference count. 
*/ +constexpr uint64_t kCombinedRefCountBothOne = kCombinedRefCountWeakOne | kCombinedRefCountStrongOne; +/*! \brief Mask to get the lower 32 bits of the combined reference count. */ +constexpr uint64_t kCombinedRefCountMaskUInt32 = (static_cast(1) << 32) - 1; + +/*! + * Check if the type_index is an instance of TargetType. + * + * \tparam TargetType The target object type to be checked. + * + * \param object_type_index The type index to be checked, caller + * ensures that the index is already within the object index range. + * + * \return Whether object_type_index denotes an instance of TargetType. + */ +template +TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index); +} // namespace details + +/*! + * \brief Base class of all object containers. + * + * Sub-class of objects should declare the following static constexpr fields: + * + * - _type_index: + * Static type index of the object, if assigned to TypeIndex::kTVMFFIDynObject + * the type index will be assigned during runtime. + * Runtime type index can be accessed by ObjectType::RuntimeTypeIndex(); + * - _type_key: + * The unique string identifier of the type. + * - _type_final: + * Whether the type is terminal type(there is no subclass of the type in the object system). + * This field is automatically set by macro TVM_FFI_DECLARE_OBJECT_INFO_FINAL + * It is still OK to sub-class a terminal object type T and construct it using make_object. + * But IsInstance check will only show that the object type is T(instead of the sub-class). + * - _type_mutable: + * Whether we would like to expose cast to non-constant pointer + * ObjectType* from Any/AnyView. By default, we set to false so it is not exposed. + * + * The following two fields are necessary for base classes that can be sub-classed. + * + * - _type_child_slots: + * Number of reserved type index slots for child classes. + * Used for runtime optimization for type checking in IsInstance. 
+ * If an object's type_index is within range of [type_index, type_index + _type_child_slots] + * Then the object can be quickly decided as sub-class of the current object class. + * If not, a fallback mechanism is used to check the global type table. + * Recommendation: set to estimate number of children needed. + * + * - _type_child_slots_can_overflow: + * Whether we can add additional child classes even if the number of child classes + * exceeds the _type_child_slots. A fallback mechanism to check type table will be used. + * Recommendation: set to false for optimal runtime speed if we know exact number of children. + * + * Two macros are used to declare helper functions in the object: + * - Use TVM_FFI_DECLARE_OBJECT_INFO for object classes that can be sub-classed. + * - Use TVM_FFI_DECLARE_OBJECT_INFO_FINAL for object classes that cannot be sub-classed. + * + * New objects can be created using make_object function. + * Which will automatically populate the type_index and deleter of the object. + */ +class Object { +protected: + /*! \brief header field that is the common prefix of all objects */ + TVMFFIObject header_; + +public: + Object() { + header_.combined_ref_count = 0; + header_.type_index = 0; + header_.__padding = 0; + header_.__ensure_align = 0; + } + /*! + * Check if the object is an instance of TargetType. + * \tparam TargetType The target type to be checked. + * \return Whether the target type is true. + */ + template + bool IsInstance() const { + return details::IsObjectInstance(header_.type_index); + } + + /*! \return The internal runtime type index of the object. */ + int32_t type_index() const { return header_.type_index; } + + /*! + * \return the type key of the object. + * \note this operation is expensive, can be used for error reporting. 
+ */ + std::string GetTypeKey() const { + // the function checks that the info exists + const TypeInfo *type_info = TVMFFIGetTypeInfo(header_.type_index); + return std::string(type_info->type_key.data, type_info->type_key.size); + } + + /*! + * \return A hash value of the return of GetTypeKey. + */ + uint64_t GetTypeKeyHash() const { + // the function checks that the info exists + const TypeInfo *type_info = TVMFFIGetTypeInfo(header_.type_index); + return type_info->type_key_hash; + } + + /*! + * \brief Get the type key of the corresponding index from runtime. + * \param tindex The type index. + * \return the result. + */ + static std::string TypeIndex2Key(int32_t tindex) { + const TypeInfo *type_info = TVMFFIGetTypeInfo(tindex); + return std::string(type_info->type_key.data, type_info->type_key.size); + } + + /*! + * \return Whether the object.use_count() == 1. + */ + bool unique() const { return use_count() == 1; } + + /*! + * \return The usage count of the cell. + * \note We use STL style naming to be consistent with known API in shared_ptr. + */ + uint64_t use_count() const { + // only need relaxed load of counters +#ifdef _MSC_VER + return ((reinterpret_cast( + &header_.combined_ref_count))[0] // NOLINT(*) + ) + & kCombinedRefCountMaskUInt32; +#else + return __atomic_load_n(&(header_.combined_ref_count), __ATOMIC_RELAXED) & kCombinedRefCountMaskUInt32; +#endif + } + + //---------------------------------------------------------------------------- + // The following fields are configuration flags for subclasses of object + //---------------------------------------------------------------------------- + /*! \brief The type key of the class */ + static constexpr const char *_type_key = StaticTypeKey::kTVMFFIObject; + /*! \brief Whether the class is final */ + static constexpr bool _type_final = false; + /*! \brief Whether allow mutable access to fields */ + static constexpr bool _type_mutable = false; + /*! 
\brief The number of child slots of the class to pre-allocate to this type */ + static constexpr uint32_t _type_child_slots = 0; + /*! + * \brief Whether allow additional children beyond pre-specified by _type_child_slots + */ + static constexpr bool _type_child_slots_can_overflow = true; + /*! \brief The static type index of the class */ + static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject; + /*! \brief The static depth of the class in the object hierarchy */ + static constexpr int32_t _type_depth = 0; + /*! \brief The structural equality and hash kind of the type */ + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUnsupported; + // The following functions are provided by macro + // TVM_FFI_DECLARE_OBJECT_INFO and TVM_FFI_DECLARE_OBJECT_INFO_FINAL + /*! + * \brief Get the runtime allocated type index of the type + * \note Getting this information may need dynamic calls into a global table. + */ + static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } + /*! + * \brief Internal function to get or allocate a runtime index. + */ + static int32_t _GetOrAllocRuntimeTypeIndex() { // NOLINT(bugprone-reserved-identifier) + return TypeIndex::kTVMFFIObject; + } + +private: + // exposing detailed constants to here + static constexpr uint64_t kCombinedRefCountMaskUInt32 = details::kCombinedRefCountMaskUInt32; + static constexpr uint64_t kCombinedRefCountStrongOne = details::kCombinedRefCountStrongOne; + static constexpr uint64_t kCombinedRefCountWeakOne = details::kCombinedRefCountWeakOne; + static constexpr uint64_t kCombinedRefCountBothOne = details::kCombinedRefCountBothOne; + /*! \brief increase strong reference count, the caller must already hold a strong reference */ + void IncRef() { +#ifdef _MSC_VER + _InterlockedIncrement64( + reinterpret_cast(&header_.combined_ref_count)); // NOLINT(*) +#else + __atomic_fetch_add(&(header_.combined_ref_count), 1, __ATOMIC_RELAXED); +#endif + } + /*! 
+ * \brief Try to lock the object to increase the strong reference count, + * the caller must already hold a strong reference. + * \return whether the lock call is successful and object is still alive. + */ + bool TryPromoteWeakPtr() { +#ifdef _MSC_VER + uint64_t old_count = (reinterpret_cast(&header_.combined_ref_count))[0]; // NOLINT(*) + while ((old_count & kCombinedRefCountMaskUInt32) != 0) { + uint64_t new_count = old_count + kCombinedRefCountStrongOne; + uint64_t old_count_loaded = _InterlockedCompareExchange64( + reinterpret_cast(&header_.combined_ref_count), new_count, old_count); + if (old_count == old_count_loaded) { + return true; + } + old_count = old_count_loaded; + } + return false; +#else + uint64_t old_count = __atomic_load_n(&(header_.combined_ref_count), __ATOMIC_RELAXED); + while ((old_count & kCombinedRefCountMaskUInt32) != 0) { + // must do CAS to ensure that we are the only one that increases the reference count + // avoid condition when two threads tries to promote weak to strong at same time + // or when strong deletion happens between the load and the CAS + uint64_t new_count = old_count + kCombinedRefCountStrongOne; + if (__atomic_compare_exchange_n(&(header_.combined_ref_count), &old_count, new_count, true, + __ATOMIC_ACQ_REL, __ATOMIC_RELAXED)) { + return true; + } + } + return false; +#endif + } + + /*! \brief increase weak reference count */ + void IncWeakRef() { +#ifdef _MSC_VER + _InlineInterlockedAdd64( + reinterpret_cast(&header_.combined_ref_count), // NOLINT(*) + kCombinedRefCountWeakOne); +#else + __atomic_fetch_add(&(header_.combined_ref_count), kCombinedRefCountWeakOne, __ATOMIC_RELAXED); +#endif + } + + /*! 
\brief decrease strong reference count and delete the object */ + void DecRef() { +#ifdef _MSC_VER + // use simpler impl in windows to ensure correctness + uint64_t count_before_sub = _InterlockedDecrement64( // + reinterpret_cast(&header_.combined_ref_count) // NOLINT(*) + ) + + 1; + if (count_before_sub == kCombinedRefCountBothOne) { // NOLINT(*) + // fast path: both reference counts will go to zero + if (header_.deleter != nullptr) { + // full barrier is implicit in InterlockedDecrement + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth); + } + } else if ((count_before_sub & kCombinedRefCountMaskUInt32) == kCombinedRefCountStrongOne) { + // strong reference count becomes zero, we need to first do strong deletion + // then decrease weak reference count + // full barrier is implicit in InterlockedAdd + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong); + } + // decrease weak reference count + if (_InlineInterlockedAdd64( // + reinterpret_cast(&header_.combined_ref_count), + -kCombinedRefCountWeakOne) + == 0) { // NOLINT(*) + if (header_.deleter != nullptr) { + // full barrier is implicit in InterlockedAdd + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); + } + } + } +#else + // first do a release, note we only need to acquire for deleter + // optimization: we only need one atomic to tell the common case + // where both reference counts are zero + uint64_t count_before_sub = __atomic_fetch_sub(&(header_.combined_ref_count), + kCombinedRefCountStrongOne, __ATOMIC_RELEASE); + if (count_before_sub == kCombinedRefCountBothOne) { + // common case, we need to delete both the object and the memory block + // only acquire when we need to call deleter + __atomic_thread_fence(__ATOMIC_ACQUIRE); + if (header_.deleter != nullptr) { + // call deleter once + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth); + } + } else if ((count_before_sub &
kCombinedRefCountMaskUInt32) == kCombinedRefCountStrongOne) { + // strong count is already zero + // Slower path: there is still a weak reference left + __atomic_thread_fence(__ATOMIC_ACQUIRE); + // call destructor first, then decrease weak reference count + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong); + } + // now decrease weak reference count + if (__atomic_fetch_sub(&(header_.combined_ref_count), kCombinedRefCountWeakOne, + __ATOMIC_RELEASE) + == kCombinedRefCountWeakOne) { + __atomic_thread_fence(__ATOMIC_ACQUIRE); + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); + } + } + } +#endif + } + + /*! \brief decrease weak reference count */ + void DecWeakRef() { +#ifdef _MSC_VER + if (_InlineInterlockedAdd64( // + reinterpret_cast(&header_.combined_ref_count), // NOLINT(*) + -kCombinedRefCountWeakOne) + == 0) { + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); + } + } +#else + // now decrease weak reference count + if (__atomic_fetch_sub(&(header_.combined_ref_count), kCombinedRefCountWeakOne, + __ATOMIC_RELEASE) + == kCombinedRefCountWeakOne) { + __atomic_thread_fence(__ATOMIC_ACQUIRE); + if (header_.deleter != nullptr) { + header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); + } + } +#endif + } + + // friend classes + template + friend class ObjectPtr; + template + friend class WeakObjectPtr; + friend struct tvm::ffi::details::ObjectUnsafe; +}; + +/*! + * \brief A custom smart pointer for Object. + * \tparam T the content data type. + * \sa make_object + */ +template +class ObjectPtr { +public: + /*! \brief default constructor */ + ObjectPtr() = default; + /*! \brief default constructor */ + ObjectPtr(std::nullptr_t) {} // NOLINT(*) + /*! 
+ * \brief copy constructor + * \param other The value to be moved + */ + ObjectPtr(const ObjectPtr &other) // NOLINT(*) + : ObjectPtr(other.data_) {} + /*! + * \brief copy constructor + * \param other The value to be moved + */ + template + ObjectPtr(const ObjectPtr &other) // NOLINT(*) + : ObjectPtr(other.data_) { + static_assert(std::is_base_of_v, "can only assign of child class ObjectPtr to parent"); + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + ObjectPtr(ObjectPtr &&other) // NOLINT(*) + : data_(other.data_) { + other.data_ = nullptr; + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + template + ObjectPtr(ObjectPtr &&other) // NOLINT(*) + : data_(other.data_) { + static_assert(std::is_base_of_v, "can only assign of child class ObjectPtr to parent"); + other.data_ = nullptr; + } + /*! \brief destructor */ + ~ObjectPtr() { this->reset(); } + /*! + * \brief Swap this array with another Object + * \param other The other Object + */ + void swap(ObjectPtr &other) { // NOLINT(*) + std::swap(data_, other.data_); + } + /*! + * \return Get the content of the pointer + */ + T *get() const { return static_cast(data_); } + /*! + * \return The pointer + */ + T *operator->() const { return get(); } + /*! + * \return The reference + */ + T &operator*() const { // NOLINT(*) + return *get(); + } + /*! + * \brief copy assignment + * \param other The value to be assigned. + * \return reference to self. + */ + ObjectPtr &operator=(const ObjectPtr &other) { // NOLINT(*) + // takes in plane operator to enable copy elison. + // copy-and-swap idiom + ObjectPtr(other).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief move assignment + * \param other The value to be assigned. + * \return reference to self. + */ + ObjectPtr &operator=(ObjectPtr &&other) { // NOLINT(*) + // copy-and-swap idiom + ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + /*! 
+ * \brief nullptr check + * \return result of comparison of internal pointer with nullptr. + */ + explicit operator bool() const { return get() != nullptr; } + /*! \brief release the held strong reference (if any) and reset the ptr to nullptr */ + void reset() { + if (data_ != nullptr) { + data_->DecRef(); + data_ = nullptr; + } + } + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } + /*! \return whether the reference is unique */ + bool unique() const { return data_ != nullptr && data_->use_count() == 1; } + /*! \return Whether two ObjectPtr equal each other (hold the same internal pointer) */ + bool operator==(const ObjectPtr &other) const { return data_ == other.data_; } + /*! \return Whether two ObjectPtr do not equal each other */ + bool operator!=(const ObjectPtr &other) const { return data_ != other.data_; } + /*! \return Whether the pointer is nullptr */ + bool operator==(std::nullptr_t) const { return data_ == nullptr; } + /*! \return Whether the pointer is not nullptr */ + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } + +private: + /*! \brief internal pointer field */ + Object *data_{nullptr}; + /*! + * \brief constructor from Object, calls IncRef on the object when data is not nullptr + * \param data The data pointer + */ + explicit ObjectPtr(Object *data) : data_(data) { + if (data_ != nullptr) { + data_->IncRef(); + } + } + // friend classes + friend class Object; + friend class ObjectRef; + friend struct ObjectPtrHash; + template + friend class ObjectPtr; + template + friend class WeakObjectPtr; + friend struct tvm::ffi::details::ObjectUnsafe; +}; + +/*! + * \brief A custom weak smart pointer for Object. + * \tparam T the content data type. + * \sa make_object + */ +template +class WeakObjectPtr { +public: + /*! \brief default constructor */ + WeakObjectPtr() = default; + /*! \brief constructor from nullptr */ + WeakObjectPtr(std::nullptr_t) {} // NOLINT(*) + /*!
+ * \brief copy constructor + * \param other The value to be moved + */ + WeakObjectPtr(const WeakObjectPtr &other) // NOLINT(*) + : WeakObjectPtr(other.data_) {} + + /*! + * \brief copy constructor + * \param other The value to be moved + */ + WeakObjectPtr(const ObjectPtr &other) // NOLINT(*) + : WeakObjectPtr(other.get()) {} + /*! + * \brief copy constructor + * \param other The value to be moved + */ + template + WeakObjectPtr(const WeakObjectPtr &other) // NOLINT(*) + : WeakObjectPtr(other.data_) { + static_assert(std::is_base_of_v, "can only assign of child class ObjectPtr to parent"); + } + /*! + * \brief copy constructor + * \param other The value to be moved + */ + template + WeakObjectPtr(const ObjectPtr &other) // NOLINT(*) + : WeakObjectPtr(other.data_) { + static_assert(std::is_base_of_v, "can only assign of child class ObjectPtr to parent"); + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + WeakObjectPtr(WeakObjectPtr &&other) // NOLINT(*) + : data_(other.data_) { + other.data_ = nullptr; + } + /*! + * \brief move constructor + * \param other The value to be moved + */ + template + WeakObjectPtr(WeakObjectPtr &&other) // NOLINT(*) + : data_(other.data_) { + static_assert(std::is_base_of_v, "can only assign of child class ObjectPtr to parent"); + other.data_ = nullptr; + } + /*! \brief destructor */ + ~WeakObjectPtr() { this->reset(); } + /*! + * \brief Swap this array with another Object + * \param other The other Object + */ + void swap(WeakObjectPtr &other) { // NOLINT(*) + std::swap(data_, other.data_); + } + + /*! + * \brief copy assignment + * \param other The value to be assigned. + * \return reference to self. + */ + WeakObjectPtr &operator=(const WeakObjectPtr &other) { // NOLINT(*) + // takes in plane operator to enable copy elison. + // copy-and-swap idiom + WeakObjectPtr(other).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief move assignment + * \param other The value to be assigned. 
+ * \return reference to self. + */ + WeakObjectPtr &operator=(WeakObjectPtr &&other) { // NOLINT(*) + // copy-and-swap idiom + WeakObjectPtr(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + + /*! \return The internal object pointer if the object is still alive, otherwise nullptr */ + ObjectPtr lock() const { + if (data_ != nullptr && data_->TryPromoteWeakPtr()) { + ObjectPtr ret; + // TryPromoteWeakPtr already increased the strong reference count, so we don't need to do it again + ret.data_ = data_; + return ret; + } + return nullptr; + } + + /*! \brief release the held weak reference (if any) and reset the ptr to nullptr */ + void reset() { + if (data_ != nullptr) { + data_->DecWeakRef(); + data_ = nullptr; + } + } + + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } + + /*! \return whether the pointer is nullptr or the use count of the object is zero */ + bool expired() const { return data_ == nullptr || data_->use_count() == 0; } + +private: + /*! \brief internal pointer field */ + Object *data_{nullptr}; + + /*! + * \brief constructor from Object, calls IncWeakRef on the object when data is not nullptr + * \param data The data pointer + */ + explicit WeakObjectPtr(Object *data) : data_(data) { + if (data_ != nullptr) { + data_->IncWeakRef(); + } + } + + template + friend class WeakObjectPtr; + friend struct tvm::ffi::details::ObjectUnsafe; +}; + +/*! + * \brief Optional data type in FFI. + * \tparam T The underlying type of the optional. + * + * \note Compared to std::optional, Optional + * takes less storage as it uses nullptr to represent nullopt. + */ +template +class Optional; + +/*! \brief Base class of all object references */ +class ObjectRef { +public: + /*! \brief default constructor */ + ObjectRef() = default; + /*! \brief copy constructor */ + ObjectRef(const ObjectRef &other) = default; + /*! \brief move constructor */ + ObjectRef(ObjectRef &&other) noexcept : data_(std::move(other.data_)) { other.data_ = nullptr; } + /*!
\brief copy assignment */ + ObjectRef &operator=(const ObjectRef &other) = default; + /*! \brief move assignment */ + ObjectRef &operator=(ObjectRef &&other) noexcept { + data_ = std::move(other.data_); + other.data_ = nullptr; + return *this; + } + /*! \brief Constructor from existing object ptr */ + explicit ObjectRef(ObjectPtr data) : data_(std::move(data)) {} + /*! \brief Constructor from UnsafeInit */ + explicit ObjectRef(UnsafeInit) : data_(nullptr) {} + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool same_as(const ObjectRef &other) const { return data_ == other.data_; } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator==(const ObjectRef &other) const { return data_ == other.data_; } + /*! + * \brief Comparator + * \param other Another object ref. + * \return the compare result. + */ + bool operator!=(const ObjectRef &other) const { return data_ != other.data_; } + /*! + * \brief Comparator + * \param other Another object ref by address. + * \return the compare result. + */ + bool operator<(const ObjectRef &other) const { return data_.get() < other.data_.get(); } + /*! + * \return whether the object is defined. + */ + bool defined() const { return data_ != nullptr; } + /*! \return the internal object pointer */ + const Object *get() const { return data_.get(); } + /*! \return the internal object pointer */ + const Object *operator->() const { return get(); } + /*! \return whether the reference is unique */ + bool unique() const { return data_.unique(); } + /*! \return The use count of the ptr, for debug purposes */ + int use_count() const { return data_.use_count(); } + + /*! + * \brief Try to downcast the internal Object to a + * raw pointer of a corresponding type. + * + * The function will return a nullptr if the cast failed. 
+ * + * if (const AddNode *ptr = node_ref.as()) { + * // This is an add node + * } + * + * \tparam ObjectType the target type, must be a subtype of Object + * \return The pointer to the requested type. + */ + template >> + const ObjectType *as() const { + if (data_ != nullptr && data_->IsInstance()) { + return static_cast(data_.get()); + } else { + return nullptr; + } + } + + /*! + * \brief Try to downcast the ObjectRef to std::optional of the requested type. + * + * The function will return a std::nullopt if the cast fails or if the pointer is nullptr. + * + * \tparam ObjectRefType the target type, must be a subtype of ObjectRef + * \return The optional value of the requested type. + */ + template >> + TVM_FFI_INLINE std::optional as() const { + if (data_ != nullptr) { + if (data_->IsInstance()) { + ObjectRefType ref(UnsafeInit{}); + ref.data_ = data_; + return ref; + } else { + return std::nullopt; + } + } else { + return std::nullopt; + } + } + + /*! + * \brief Get the type index of the ObjectRef + * \return The type index of the ObjectRef, or TypeIndex::kTVMFFINone when the reference is nullptr + */ + int32_t type_index() const { + return data_ != nullptr ? data_->type_index() : TypeIndex::kTVMFFINone; + } + + /*! + * \brief Get the type key of the ObjectRef + * \return The type key of the ObjectRef, or StaticTypeKey::kTVMFFINone when the reference is nullptr + */ + std::string GetTypeKey() const { + return data_ != nullptr ? data_->GetTypeKey() : StaticTypeKey::kTVMFFINone; + } + + /*! \brief type alias indicating the container type. */ + using ContainerType = Object; + /*! \brief Whether the reference can point to nullptr */ + static constexpr bool _type_is_nullable = true; + +protected: + /*! \brief Internal pointer that backs the reference. */ + ObjectPtr data_; + /*! \return a mutable internal ptr, can be used by sub-classes. */ + Object *get_mutable() const { return data_.get(); } + // friend classes. + friend struct ObjectPtrHash; + friend struct tvm::ffi::details::ObjectUnsafe; +}; + +// forward declare variant +template +class Variant; + +/*!
\brief ObjectRef hash functor */ +struct ObjectPtrHash { + size_t operator()(const ObjectRef &a) const { return operator()(a.data_); } + + template + size_t operator()(const ObjectPtr &a) const { + return std::hash()(a.get()); + } + + template + TVM_FFI_INLINE size_t operator()(const Variant &a) const; +}; + +/*! \brief ObjectRef equal functor */ +struct ObjectPtrEqual { + bool operator()(const ObjectRef &a, const ObjectRef &b) const { return a.same_as(b); } + + template + bool operator()(const ObjectPtr &a, const ObjectPtr &b) const { + return a == b; + } + + template + TVM_FFI_INLINE bool operator()(const Variant &a, const Variant &b) const; +}; + +/*! + * \brief Helper macro to declare object information with static type index. + * + * For each custom object, you need to call tvm::ffi::reflection::ObjectDef() + * once in your cc file to register the type index with the runtime. + * Alternatively, you can call TypeName::_GetOrAllocRuntimeTypeIndex() once. + * + * \param TypeKey The type key of the current type. + * \param TypeName The name of the current type. 
+ * \param ParentType The name of the ParentType + * + * \see tvm::ffi::reflection::ObjectDef + */ +#define TVM_FFI_DECLARE_OBJECT_INFO_STATIC(TypeKey, TypeName, ParentType) \ + static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ + static int32_t _GetOrAllocRuntimeTypeIndex() { \ + static_assert(!ParentType::_type_final, "ParentType marked as final"); \ + static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || TypeName::_type_child_slots < ParentType::_type_child_slots, \ + "Need to set _type_child_slots when parent specifies it."); \ + TVMFFIByteArray type_key{TypeName::_type_key, \ + std::char_traits::length(TypeName::_type_key)}; \ + static int32_t tindex [[maybe_unused]] = TVMFFITypeGetOrAllocIndex( \ + &type_key, TypeName::_type_index, TypeName::_type_depth, TypeName::_type_child_slots, \ + TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ + return TypeName::_type_index; \ + } \ + static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \ + static constexpr const char *_type_key = TypeKey + +/*! + * \brief Helper macro to declare object information with type key already defined in class. + * + * \param TypeName The name of the current type. 
+ * \param ParentType The name of the ParentType + */ +#define TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(TypeName, ParentType) \ + static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ + static int32_t _GetOrAllocRuntimeTypeIndex() { \ + static_assert(!ParentType::_type_final, "ParentType marked as final"); \ + static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || TypeName::_type_child_slots < ParentType::_type_child_slots, \ + "Need to set _type_child_slots when parent specifies it."); \ + TVMFFIByteArray type_key{TypeName::_type_key, \ + std::char_traits::length(TypeName::_type_key)}; \ + static int32_t tindex = TVMFFITypeGetOrAllocIndex( \ + &type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \ + TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ + return tindex; \ + } \ + static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); } + +/*! + * \brief Helper macro to declare object information with dynamic type index. + * + * For each custom object, you need to call tvm::ffi::reflection::ObjectDef() + * once in your cc file to register the type index with the runtime. + * Alternatively, you can call TypeName::_GetOrAllocRuntimeTypeIndex() once. + * + * \param TypeKey The type key of the current type. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + * \sa tvm::ffi::reflection::ObjectDef + */ +#define TVM_FFI_DECLARE_OBJECT_INFO(TypeKey, TypeName, ParentType) \ + static constexpr const char *_type_key = TypeKey; \ + TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(TypeName, ParentType) + +/*! + * \brief Helper macro to declare object information with dynamic type index and is final. + * + * For each custom object, you need to call tvm::ffi::reflection::ObjectDef() + * once in your cc file to register the type index with the runtime. 
+ * Alternatively, you can call TypeName::_GetOrAllocRuntimeTypeIndex() once. + * + * \param TypeKey The type key of the current type. + * \param TypeName The name of the current type. + * \param ParentType The name of the ParentType + * \sa tvm::ffi::reflection::ObjectDef + */ +#define TVM_FFI_DECLARE_OBJECT_INFO_FINAL(TypeKey, TypeName, ParentType) \ + static const constexpr int _type_child_slots [[maybe_unused]] = 0; \ + static const constexpr bool _type_final [[maybe_unused]] = true; \ + TVM_FFI_DECLARE_OBJECT_INFO(TypeKey, TypeName, ParentType) + +/*! + * \brief Define object reference methods. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + * + * \note This macro also defines the default constructor that puts the ObjectRef + * in undefined state initially. + */ +#define TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + explicit TypeName(::tvm::ffi::ObjectPtr n) : ParentType(std::move(n)) {} \ + explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + using __PtrType = std::conditional_t<(ObjectName::_type_mutable), \ + ObjectName *, /* NOLINT(bugprone-macro-parentheses) */ \ + const ObjectName *>; \ + __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } \ + __PtrType get() const { return static_cast<__PtrType>(data_.get()); } \ + [[maybe_unused]] static constexpr bool _type_is_nullable = true; \ + using ContainerType = ObjectName + +/*! + * \brief Define object reference methods do not have undefined state. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. 
+ */ +#define TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TypeName, ParentType, ObjectName) \ + explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ + TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + using __PtrType = std::conditional_t<(ObjectName::_type_mutable), \ + ObjectName *, /* NOLINT(bugprone-macro-parentheses) */ \ + const ObjectName *>; \ + __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } \ + __PtrType get() const { return static_cast<__PtrType>(data_.get()); } \ + [[maybe_unused]] static constexpr bool _type_is_nullable = false; \ + using ContainerType = ObjectName + +namespace details { + +template +TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) { + static_assert(std::is_base_of_v); + // Everything is a subclass of object. + if constexpr (std::is_same_v) { + return true; + } else if constexpr (TargetType::_type_final) { + // if the target type is a final type + // then we only need to check the equivalence. + return object_type_index == TargetType::RuntimeTypeIndex(); + } else { + // Explicitly enclose in else to eliminate this branch early in compilation. + // if target type is a non-leaf type + // Check if type index falls into the range of reserved slots. + int32_t target_type_index = TargetType::RuntimeTypeIndex(); + int32_t begin = target_type_index; + // The condition will be optimized by constant-folding. + if constexpr (TargetType::_type_child_slots != 0) { + // total_slots = child_slots + 1 (including self) + int32_t end = begin + TargetType::_type_child_slots + 1; + if (object_type_index >= begin && object_type_index < end) { + return true; + } + } else { + if (object_type_index == begin) { + return true; + } + } + if constexpr (TargetType::_type_child_slots_can_overflow) { + // Invariance: parent index is always smaller than the child. 
+ if (object_type_index < target_type_index) { + return false; + } + // Do a runtime lookup of type information + // the function checks that the info exists + const TypeInfo *type_info = TVMFFIGetTypeInfo(object_type_index); + return (type_info->type_depth > TargetType::_type_depth && type_info->type_ancestors[TargetType::_type_depth]->type_index == target_type_index); + } else { + return false; + } + } +} + +/*! + * \brief Namespace to internally manipulate object class. + * \note These functions are only supposed to be used by internal + * implementations and not external users of the tvm::ffi + */ +struct ObjectUnsafe { + // NOTE: get ffi header from an object + TVM_FFI_INLINE static TVMFFIObject *GetHeader(const Object *src) { + return const_cast(&(src->header_)); + } + + template + TVM_FFI_INLINE static int64_t GetObjectOffsetToSubclass() { + return (reinterpret_cast(&(static_cast(nullptr)->header_)) - reinterpret_cast(&(static_cast(nullptr)->header_))); + } + + template + TVM_FFI_INLINE static T ObjectRefFromObjectPtr(const ObjectPtr &ptr) { + T ref(UnsafeInit{}); + ref.data_ = ptr; + return ref; + } + + template + TVM_FFI_INLINE static T ObjectRefFromObjectPtr(ObjectPtr &&ptr) { + T ref(UnsafeInit{}); + ref.data_ = std::move(ptr); + return ref; + } + + template + TVM_FFI_INLINE static ObjectPtr ObjectPtrFromObjectRef(const ObjectRef &ref) { + if constexpr (std::is_same_v) { + return ref.data_; + } else { + return tvm::ffi::ObjectPtr(ref.data_.data_); + } + } + + template + TVM_FFI_INLINE static ObjectPtr ObjectPtrFromObjectRef(ObjectRef &&ref) { + if constexpr (std::is_same_v) { + return std::move(ref.data_); + } else { + ObjectPtr result; + result.data_ = std::move(ref.data_.data_); + ref.data_.data_ = nullptr; + return result; + } + } + + template + TVM_FFI_INLINE static ObjectPtr ObjectPtrFromOwned(Object *raw_ptr) { + tvm::ffi::ObjectPtr ptr; + ptr.data_ = raw_ptr; + return ptr; + } + + template + TVM_FFI_INLINE static ObjectPtr 
ObjectPtrFromOwned(TVMFFIObject *obj_ptr) { + return ObjectPtrFromOwned(reinterpret_cast(obj_ptr)); + } + + template + TVM_FFI_INLINE static T *RawObjectPtrFromUnowned(TVMFFIObject *obj_ptr) { + // NOTE: this is important to first cast to Object* + // then cast back to T* because objptr and tptr may not be the same + // depending on how sub-class allocates the space. + return static_cast(reinterpret_cast(obj_ptr)); + } + + // Create ObjectPtr from unowned ptr + template + TVM_FFI_INLINE static ObjectPtr ObjectPtrFromUnowned(Object *raw_ptr) { + return tvm::ffi::ObjectPtr(raw_ptr); + } + + template + TVM_FFI_INLINE static ObjectPtr ObjectPtrFromUnowned(TVMFFIObject *obj_ptr) { + return tvm::ffi::ObjectPtr(reinterpret_cast(obj_ptr)); + } + + TVM_FFI_INLINE static void DecRefObjectHandle(TVMFFIObjectHandle handle) { + reinterpret_cast(handle)->DecRef(); + } + + TVM_FFI_INLINE static void IncRefObjectHandle(TVMFFIObjectHandle handle) { + reinterpret_cast(handle)->IncRef(); + } + + TVM_FFI_INLINE static Object *RawObjectPtrFromObjectRef(const ObjectRef &src) { + return src.data_.data_; + } + + TVM_FFI_INLINE static TVMFFIObject *TVMFFIObjectPtrFromObjectRef(const ObjectRef &src) { + return GetHeader(src.data_.data_); + } + + template + TVM_FFI_INLINE static TVMFFIObject *TVMFFIObjectPtrFromObjectPtr(const ObjectPtr &src) { + return GetHeader(src.data_); + } + + template + TVM_FFI_INLINE static TVMFFIObject *MoveObjectPtrToTVMFFIObjectPtr(ObjectPtr &&src) { + Object *obj_ptr = src.data_; + src.data_ = nullptr; + return GetHeader(obj_ptr); + } + + TVM_FFI_INLINE static TVMFFIObject *MoveObjectRefToTVMFFIObjectPtr(ObjectRef &&src) { + Object *obj_ptr = src.data_.data_; + src.data_.data_ = nullptr; + return GetHeader(obj_ptr); + } +}; +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_OBJECT_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/optional.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/optional.h new 
file mode 100644 index 000000000..11dcc46a8 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/optional.h @@ -0,0 +1,428 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/optional.h + * \brief Runtime Optional container types. + * \note Optional specializes for T is ObjectRef and used nullptr to indicate nullopt. + */ +#ifndef TVM_FFI_OPTIONAL_H_ +#define TVM_FFI_OPTIONAL_H_ + +#include "error.h" +#include "object.h" +#include "string.h" + +#include +#include +#include + +namespace tvm { +namespace ffi { + +// Note: We place optional in tvm/ffi instead of tvm/ffi/container +// because optional itself is an inherent core component of the FFI system. +/// \cond Doxygen_Suppress +template +inline constexpr bool is_optional_type_v = false; + +template +inline constexpr bool is_optional_type_v> = true; + +// we can safely used ptr based optional for ObjectRef types +// that do not have additional data members and virtual functions. +template +inline constexpr bool use_ptr_based_optional_v = (std::is_base_of_v && !is_optional_type_v); +/// \endcond + +// Specialization for non-ObjectRef types. 
+// simply fallback to std::optional +template +class Optional && !std::is_same_v && !std::is_same_v>> { +public: + // default constructors. + Optional() = default; + // NOLINTBEGIN(google-explicit-constructor) + Optional(const Optional &other) : data_(other.data_) {} + Optional(Optional &&other) noexcept : data_(std::move(other.data_)) {} + Optional(std::optional other) : data_(std::move(other)) {} + Optional(std::nullopt_t) {} + Optional(T other) : data_(std::move(other)) {} + // NOLINTEND(google-explicit-constructor) + + TVM_FFI_INLINE Optional &operator=(const Optional &other) { + data_ = other.data_; + return *this; + } + + TVM_FFI_INLINE Optional &operator=(Optional &&other) noexcept { + data_ = std::move(other.data_); + return *this; + } + + TVM_FFI_INLINE Optional &operator=(T other) { + data_ = std::move(other); + return *this; + } + + TVM_FFI_INLINE Optional &operator=(std::nullopt_t) { + data_ = std::nullopt; + return *this; + } + + TVM_FFI_INLINE const T &value() const & { + if (!data_.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return *data_; + } + + TVM_FFI_INLINE T &&value() && { + if (!data_.has_value()) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return *std::move(data_); + } + + template > + TVM_FFI_INLINE T value_or(U &&default_value) const { + return data_.value_or(std::forward(default_value)); + } + + TVM_FFI_INLINE explicit operator bool() const noexcept { return data_.has_value(); } + + TVM_FFI_INLINE bool has_value() const noexcept { return data_.has_value(); } + + TVM_FFI_INLINE bool operator==(const Optional &other) const { return data_ == other.data_; } + + TVM_FFI_INLINE bool operator!=(const Optional &other) const { return data_ != other.data_; } + + template + TVM_FFI_INLINE bool operator==(const U &other) const { + return data_ == other; + } + template + TVM_FFI_INLINE bool operator!=(const U &other) const { + return data_ != other; + } + + // 
NOLINTBEGIN(bugprone-unchecked-optional-access) + /*! + * \brief Direct access to the value. + * \return the xvalue reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE T &&operator*() && noexcept { return *std::move(data_); } + /*! + * \brief Direct access to the value. + * \return the const reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE const T &operator*() const & noexcept { return *data_; } + // NOLINTEND(bugprone-unchecked-optional-access) + +private: + std::optional data_; +}; + +// Specialization for String type, use nullptr to indicate nullopt +template +class Optional || std::is_same_v>> { +public: + // default constructors. + Optional() = default; + // NOLINTBEGIN(google-explicit-constructor) + Optional(const Optional &other) : data_(other.data_) {} + Optional(Optional &&other) : data_(std::move(other.data_)) {} + Optional(std::nullopt_t) {} + Optional(T other) : data_(std::move(other)) {} + // NOLINTEND(google-explicit-constructor) + + TVM_FFI_INLINE Optional &operator=(const Optional &other) { + data_ = other.data_; + return *this; + } + + TVM_FFI_INLINE Optional &operator=(Optional &&other) { + data_ = std::move(other.data_); + return *this; + } + + TVM_FFI_INLINE Optional &operator=(T other) { + data_ = std::move(other); + return *this; + } + + TVM_FFI_INLINE Optional &operator=(std::nullopt_t) { + T(details::BytesBaseCell(std::nullopt)).swap(data_); + return *this; + } + + TVM_FFI_INLINE const T &value() const & { + if (data_.data_ == std::nullopt) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return data_; + } + + TVM_FFI_INLINE String &&value() && { + if (data_.data_ == std::nullopt) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return std::move(data_); + } + + template + TVM_FFI_INLINE T value_or(U &&default_value) const { + if (data_.data_ == std::nullopt) { + return 
std::forward(default_value); + } + return data_; + } + + TVM_FFI_INLINE explicit operator bool() const noexcept { return data_.data_ != std::nullopt; } + + TVM_FFI_INLINE bool has_value() const noexcept { return data_.data_ != std::nullopt; } + + TVM_FFI_INLINE bool operator==(const Optional &other) const { + if (data_.data_ == std::nullopt) { + return other.data_.data_ == std::nullopt; + } + if (other.data_.data_ == std::nullopt) { + return false; + } + return data_ == other.data_; + } + + TVM_FFI_INLINE bool operator!=(const Optional &other) const { return !(*this == other); } + + template + TVM_FFI_INLINE bool operator==(const U &other) const { + if constexpr (std::is_same_v) { + return data_.data_ == std::nullopt; + } else { + if (data_.data_ == std::nullopt) { + return false; + } + return data_ == other; + } + } + template + TVM_FFI_INLINE bool operator!=(const U &other) const { + if constexpr (std::is_same_v) { + return data_.data_ != std::nullopt; + } else { + if (data_.data_ == std::nullopt) { + return true; + } + return data_ != other; + } + } + + /*! + * \brief Direct access to the value. + * \return the xvalue reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE T &&operator*() && noexcept { return std::move(data_); } + /*! + * \brief Direct access to the value. + * \return the const reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE const T &operator*() const & noexcept { return data_; } + +private: + // this is a private initializer + T data_{details::BytesBaseCell(std::nullopt)}; +}; + +// Specialization for ObjectRef types. +// nullptr is treated as std::nullopt. 
+template +class Optional>> : public ObjectRef { +public: + using ContainerType = typename T::ContainerType; + Optional() = default; + // NOLINTBEGIN(google-explicit-constructor) + Optional(const Optional &other) : ObjectRef(other) {} + Optional(Optional &&other) noexcept : ObjectRef(std::move(other)) {} + explicit Optional(ffi::UnsafeInit tag) : ObjectRef(tag) {} + Optional(std::nullopt_t) {} + Optional(std::optional other) { + if (other.has_value()) { + *this = *std::move(other); + } + } + Optional(T other) : ObjectRef(std::move(other)) {} + // NOLINTEND(google-explicit-constructor) + + TVM_FFI_INLINE Optional &operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + + TVM_FFI_INLINE Optional &operator=(const Optional &other) { + data_ = other.data_; + return *this; + } + + TVM_FFI_INLINE Optional &operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + + TVM_FFI_INLINE Optional &operator=(Optional &&other) { + data_ = std::move(other.data_); + return *this; + } + + TVM_FFI_INLINE T value() const & { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return details::ObjectUnsafe::ObjectRefFromObjectPtr(data_); + } + + TVM_FFI_INLINE T value() && { + if (data_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Back optional access"; + } + return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(data_)); + } + + template > + TVM_FFI_INLINE T value_or(U &&default_value) const { + return data_ != nullptr ? details::ObjectUnsafe::ObjectRefFromObjectPtr(data_) + : T(std::forward(default_value)); + } + + TVM_FFI_INLINE explicit operator bool() const { return data_ != nullptr; } + + TVM_FFI_INLINE bool has_value() const { return data_ != nullptr; } + + /*! + * \brief Direct access to the value. + * \return the const reference to the stored value. 
+ * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE T operator*() const & noexcept { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(data_); + } + + /*! + * \brief Direct access to the value. + * \return the const reference to the stored value. + * \note only use this function after checking has_value() + */ + TVM_FFI_INLINE T operator*() && noexcept { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(data_)); + } + + TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { return !has_value(); } + TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { return has_value(); } + + // operator overloadings + TVM_FFI_INLINE auto operator==(const Optional &other) const { + // support case where sub-class returns a symbolic ref type. + return EQToOptional(other); + } + TVM_FFI_INLINE auto operator!=(const Optional &other) const { return NEToOptional(other); } + + TVM_FFI_INLINE auto operator==(const std::optional &other) const { + // support case where sub-class returns a symbolic ref type. + return EQToOptional(other); + } + TVM_FFI_INLINE auto operator!=(const std::optional &other) const { + return NEToOptional(other); + } + + TVM_FFI_INLINE auto operator==(const T &other) const { + using RetType = decltype(value() == other); + if (same_as(other)) { + return RetType(true); + } + if (has_value()) { + return operator*() == other; + } + return RetType(false); + } + + TVM_FFI_INLINE auto operator!=(const T &other) const { return !(*this == other); } + + template + TVM_FFI_INLINE auto operator==(const U &other) const { + using RetType = decltype(value() == other); + if (!has_value()) { + return RetType(false); + } + return operator*() == other; + } + + template + TVM_FFI_INLINE auto operator!=(const U &other) const { + using RetType = decltype(value() != other); + if (!has_value()) { + return RetType(true); + } + return operator*() != other; + } + + /*! 
+ * \return The internal object pointer with container type of T. + * \note This function do not perform not-null checking. + */ + TVM_FFI_INLINE const ContainerType *get() const { + return static_cast(data_.get()); + } + +private: + template + TVM_FFI_INLINE auto EQToOptional(const U &other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(operator*() == *other); + if (same_as(other)) { + return RetType(true); + } + if (has_value() && other.has_value()) { + return operator*() == *other; + } else { + // one of them is nullptr. + return RetType(false); + } + } + + template + TVM_FFI_INLINE auto NEToOptional(const U &other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(operator*() != *other); + if (same_as(other)) { + return RetType(false); + } + if (has_value() && other.has_value()) { + return operator*() != *other; + } else { + // one of them is nullptr. + return RetType(true); + } + } +}; +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_OPTIONAL_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h new file mode 100644 index 000000000..e4ebc2aa2 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/reflection/access_path.h + * \brief Access steps and access paths describing attribute, array item and map item accesses. + */ +#ifndef TVM_FFI_REFLECTION_ACCESS_PATH_H_ +#define TVM_FFI_REFLECTION_ACCESS_PATH_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace ffi { +namespace reflection { + +/*! + * \brief The kind of the access pattern. + */ +enum class AccessKind : int32_t { + /*! \brief Object attribute access. */ + kAttr = 0, + /*! \brief Array item access. */ + kArrayItem = 1, + /*! \brief Map item access. */ + kMapItem = 2, + // the following three kinds are used for error reporting when + // the supposed access field is not available + /*! \brief Object attribute missing access. */ + kAttrMissing = 3, + /*! \brief Array item missing access. */ + kArrayItemMissing = 4, + /*! \brief Map item missing access. */ + kMapItemMissing = 5, +}; + +class AccessStep; + +/*! + * \brief Represent a single step in object field, map key, array index access. + */ +class AccessStepObj : public Object { + public: + /*! + * \brief The kind of the access pattern. + */ + AccessKind kind; + /*! + * \brief The access key + * \note for array access, it will always be integer + * for field access, it will be string + */ + Any key; + + // default constructor to enable auto-serialization + AccessStepObj() = default; + /*! + * \brief Constructor + * \param kind The kind of the access step. + * \param key The key of the access step. + */ + AccessStepObj(AccessKind kind, Any key) : kind(kind), key(std::move(key)) {} + + /*!
+ * \brief Deep check if two steps are equal. + * \param other The other step to compare with. + * \return True if the two steps are equal, false otherwise. + */ + inline bool StepEqual(const AccessStep& other) const; + + /// \cond Doxygen_Suppress + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessStep", AccessStepObj, Object); + /// \endcond +}; + +/*! + * \brief ObjectRef class of AccessStepObj. + * + * \sa AccessStepObj + */ +class AccessStep : public ObjectRef { + public: + /*! + * \brief Constructor + * \param kind The kind of the access step. + * \param key The key of the access step. + */ + AccessStep(AccessKind kind, Any key) + : ObjectRef(make_object(kind, std::move(key))) {} + + /*! + * \brief Create an access step for a object attribute access. + * \param field_name The name of the field to access. + * \return The access step. + */ + static AccessStep Attr(String field_name) { + return AccessStep(AccessKind::kAttr, std::move(field_name)); + } + + /*! + * \brief Create an access step for a object attribute missing access. + * \param field_name The name of the field to access. + * \return The access step. + */ + static AccessStep AttrMissing(String field_name) { + return AccessStep(AccessKind::kAttrMissing, std::move(field_name)); + } + + /*! + * \brief Create an access step for a array item access. + * \param index The index of the array item to access. + * \return The access step. + */ + static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); } + + /*! + * \brief Create an access step for a array item missing access. + * \param index The index of the array item to access. + * \return The access step. + */ + static AccessStep ArrayItemMissing(int64_t index) { + return AccessStep(AccessKind::kArrayItemMissing, index); + } + + /*! + * \brief Create an access step for a map item access. 
+ * \param key The key of the map item to access. + * \return The access step. + */ + static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, std::move(key)); } + + /*! + * \brief Create an access step for a map item missing access. + * \param key The key of the map item to access. + * \return The access step. + */ + static AccessStep MapItemMissing(Any key = nullptr) { + return AccessStep(AccessKind::kMapItemMissing, std::move(key)); + } + + /// \cond Doxygen_Suppress + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessStep, ObjectRef, AccessStepObj); + /// \endcond +}; + +inline bool AccessStepObj::StepEqual(const AccessStep& other) const { + return this->kind == other->kind && AnyEqual()(this->key, other->key); +} + +// forward declaration +class AccessPath; + +/*! + * \brief Object node of an access path: the last step plus a link to the parent path. + * + * \sa AccessPath + */ +class AccessPathObj : public Object { + public: + /*! + * \brief The parent of the access path. + * + * This parent-pointing tree structure is more space efficient when + * representing multiple paths that share a common prefix. + * + * \note Empty for root. + */ + Optional parent; + /*! + * \brief The current step of the access path. + * \note Empty for root. + */ + Optional step; + /*! + * \brief The current depth of the access path, 0 for root + */ + int32_t depth; + + // default constructor to enable auto-serialization + AccessPathObj() = default; + /*! + * \brief Constructor for the access path. + * \param parent The parent of the access path. + * \param step The current step of the access path. + * \param depth The current depth of the access path. + */ + AccessPathObj(Optional parent, Optional step, int32_t depth) + : parent(std::move(parent)), step(std::move(step)), depth(depth) {} + + /*! + * \brief Get the parent of the access path. + * \return The parent of the access path. + */ + inline Optional GetParent() const; + + /*! + * \brief Extend the access path with a new step.
+ * \param step The step to extend the access path with. + * \return The extended access path. + */ + inline AccessPath Extend(AccessStep step) const; + + /*! + * \brief Extend the access path with an object attribute access. + * \param field_name The name of the field to access. + * \return The extended access path. + */ + inline AccessPath Attr(String field_name) const; + + /*! + * \brief Extend the access path with an object attribute missing access. + * \param field_name The name of the field to access. + * \return The extended access path. + */ + inline AccessPath AttrMissing(String field_name) const; + + /*! + * \brief Extend the access path with an array item access. + * \param index The index of the array item to access. + * \return The extended access path. + */ + inline AccessPath ArrayItem(int64_t index) const; + + /*! + * \brief Extend the access path with an array item missing access. + * \param index The index of the array item to access. + * \return The extended access path. + */ + inline AccessPath ArrayItemMissing(int64_t index) const; + + /*! + * \brief Extend the access path with a map item access. + * \param key The key of the map item to access. + * \return The extended access path. + */ + inline AccessPath MapItem(Any key) const; + + /*! + * \brief Extend the access path with a map item missing access. + * \param key The key of the map item to access. + * \return The extended access path. + */ + inline AccessPath MapItemMissing(Any key) const; + + /*! + * \brief Get the array of steps that corresponds to the access path. + * \return The array of steps that corresponds to the access path. + */ + inline Array ToSteps() const; + + /*! + * \brief Check if two paths are equal by deep comparing the steps. + * \param other The other path to compare with. + * \return True if the two paths are equal, false otherwise. + */ + inline bool PathEqual(const AccessPath& other) const; + + /*! + * \brief Check if this path is a prefix of another path. 
+ * \param other The other path to compare with. + * \return True if this path is a prefix of the other path, false otherwise. + */ + inline bool IsPrefixOf(const AccessPath& other) const; + + /// \cond Doxygen_Suppress + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessPath", AccessPathObj, Object); + /// \endcond + + private: + static bool PathEqual(const AccessPathObj* lhs, const AccessPathObj* rhs) { + // fast path for same pointer + if (lhs == rhs) return true; + if (lhs->depth != rhs->depth) return false; + // do deep equality checks + while (lhs->parent.has_value()) { + TVM_FFI_ICHECK(rhs->parent.has_value()); + TVM_FFI_ICHECK(lhs->step.has_value()); + TVM_FFI_ICHECK(rhs->step.has_value()); + if (!(*lhs->step)->StepEqual(*(rhs->step))) { + return false; + } + lhs = static_cast(lhs->parent.get()); + rhs = static_cast(rhs->parent.get()); + // fast path for same pointer + if (lhs == rhs) return true; + TVM_FFI_ICHECK(lhs != nullptr); + TVM_FFI_ICHECK(rhs != nullptr); + } + return true; + } +}; + +/*! + * \brief ObjectRef class of AccessPath. + * + * \sa AccessPathObj + */ +class AccessPath : public ObjectRef { + public: + /*! + * \brief Create an access path from an iterator range of steps. + * \param begin The beginning of the iterator range. + * \param end The end of the iterator range. + * \return The access path. + */ + template // NOLINTNEXTLINE(performance-unnecessary-value-param) + static AccessPath FromSteps(Iter begin, Iter end) { + AccessPath path = AccessPath::Root(); + for (Iter it = begin; it != end; ++it) { + path = path->Extend(*it); + } + return path; + } + /*! + * \brief Create an access path from an array of steps. + * \param steps The array of steps. + * \return The access path. 
+ */ + static AccessPath FromSteps(const Array& steps) { + AccessPath path = AccessPath::Root(); + for (AccessStep step : steps) { + path = path->Extend(step); + } + return path; + } + + /*! + * \brief Create a root access path. + * \return The root access path. + */ + static AccessPath Root() { + return AccessPath(make_object(std::nullopt, std::nullopt, 0)); + } + + /// \cond Doxygen_Suppress + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessPath, ObjectRef, AccessPathObj); + /// \endcond + + private: + friend class AccessPathObj; + explicit AccessPath(ObjectPtr ptr) : ObjectRef(std::move(ptr)) {} +}; + +/*! + * \brief The pair of access paths. + */ +using AccessPathPair = Tuple; + +inline Optional AccessPathObj::GetParent() const { + if (auto opt_parent = this->parent.as()) { + return opt_parent; + } + return std::nullopt; +} + +inline AccessPath AccessPathObj::Extend(AccessStep step) const { + return AccessPath( + make_object(GetRef(this), std::move(step), this->depth + 1)); +} + +inline AccessPath AccessPathObj::Attr(String field_name) const { + return this->Extend(AccessStep::Attr(std::move(field_name))); +} + +inline AccessPath AccessPathObj::AttrMissing(String field_name) const { + return this->Extend(AccessStep::AttrMissing(std::move(field_name))); +} + +inline AccessPath AccessPathObj::ArrayItem(int64_t index) const { + return this->Extend(AccessStep::ArrayItem(index)); +} + +inline AccessPath AccessPathObj::ArrayItemMissing(int64_t index) const { + return this->Extend(AccessStep::ArrayItemMissing(index)); +} + +inline AccessPath AccessPathObj::MapItem(Any key) const { + return this->Extend(AccessStep::MapItem(std::move(key))); +} + +inline AccessPath AccessPathObj::MapItemMissing(Any key) const { + return this->Extend(AccessStep::MapItemMissing(std::move(key))); +} + +inline Array AccessPathObj::ToSteps() const { + std::vector reverse_steps; + reverse_steps.reserve(this->depth); + const AccessPathObj* current = this; + while 
(current->parent.has_value()) { + TVM_FFI_ICHECK(current->step.has_value()); + reverse_steps.push_back(*(current->step)); + current = static_cast(current->parent.get()); + TVM_FFI_ICHECK(current != nullptr); + } + return Array(reverse_steps.rbegin(), reverse_steps.rend()); +} + +inline bool AccessPathObj::PathEqual(const AccessPath& other) const { + return PathEqual(this, other.get()); +} + +inline bool AccessPathObj::IsPrefixOf(const AccessPath& other) const { + if (this->depth > other->depth) { + return false; + } + const AccessPathObj* rhs_path = other.get(); + while (rhs_path->depth > this->depth) { + TVM_FFI_ICHECK(rhs_path->parent.has_value()); + rhs_path = static_cast(rhs_path->parent.get()); + } + return PathEqual(this, rhs_path); +} + +} // namespace reflection +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h new file mode 100644 index 000000000..b49da5193 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/reflection/accessor.h + * \brief Reflection-based accessor for object fields and methods. + */ +#ifndef TVM_FFI_REFLECTION_ACCESSOR_H_ +#define TVM_FFI_REFLECTION_ACCESSOR_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { +namespace ffi { +namespace reflection { + +/*! + * \brief helper function to get reflection field info by type key and field name + */ +inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const char* field_name) { + int32_t type_index; + TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + const TypeInfo* info = TVMFFIGetTypeInfo(type_index); + for (int32_t i = 0; i < info->num_fields; ++i) { + if (std::strncmp(info->fields[i].name.data, field_name, info->fields[i].name.size) == 0) { + return &(info->fields[i]); + } + } + TVM_FFI_THROW(RuntimeError) << "Cannot find field `" << field_name << "` in " << type_key; + TVM_FFI_UNREACHABLE(); +} + +/*! + * \brief helper wrapper class to obtain a getter. + */ +class FieldGetter { + public: + /*! + * \brief Constructor + * \param field_info The field info. + */ + explicit FieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} + + /*! + * \brief Constructor + * \param type_key The type key. + * \param field_name The name of the field. + */ + explicit FieldGetter(std::string_view type_key, const char* field_name) + : FieldGetter(GetFieldInfo(type_key, field_name)) {} + + /*! + * \brief Get the value of the field + * \param obj_ptr The object pointer. + * \return The value of the field. 
+ */ + Any operator()(const Object* obj_ptr) const { + Any result; + const void* addr = reinterpret_cast(obj_ptr) + field_info_->offset; + TVM_FFI_CHECK_SAFE_CALL( + field_info_->getter(const_cast(addr), reinterpret_cast(&result))); + return result; + } + + Any operator()(const ObjectPtr& obj_ptr) const { return operator()(obj_ptr.get()); } + + Any operator()(const ObjectRef& obj) const { return operator()(obj.get()); } + + private: + const TVMFFIFieldInfo* field_info_; +}; + +/*! + * \brief helper wrapper class to obtain a setter. + */ +class FieldSetter { + public: + /*! + * \brief Constructor + * \param field_info The field info. + */ + explicit FieldSetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} + + /*! + * \brief Constructor + * \param type_key The type key. + * \param field_name The name of the field. + */ + explicit FieldSetter(std::string_view type_key, const char* field_name) + : FieldSetter(GetFieldInfo(type_key, field_name)) {} + + /*! + * \brief Set the value of the field + * \param obj_ptr The object pointer. + * \param value The value to be set. + */ + void operator()(const Object* obj_ptr, AnyView value) const { + const void* addr = reinterpret_cast(obj_ptr) + field_info_->offset; + TVM_FFI_CHECK_SAFE_CALL( + field_info_->setter(const_cast(addr), reinterpret_cast(&value))); + } + + void operator()(const ObjectPtr& obj_ptr, AnyView value) const { + operator()(obj_ptr.get(), value); + } + + void operator()(const ObjectRef& obj, AnyView value) const { operator()(obj.get(), value); } + + private: + const TVMFFIFieldInfo* field_info_; +}; + +/*! + * \brief Helper class to get type attribute column. + */ +class TypeAttrColumn { + public: + /*! + * \brief Constructor + * \param attr_name The name of the type attribute. 
+ */ + explicit TypeAttrColumn(std::string_view attr_name) { + TVMFFIByteArray attr_name_array = {attr_name.data(), attr_name.size()}; + column_ = TVMFFIGetTypeAttrColumn(&attr_name_array); + if (column_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Cannot find type attribute " << attr_name; + } + } + /*! + * \brief Get the type attribute column by type index. + * \param type_index The type index. + * \return The type attribute column. + */ + AnyView operator[](int32_t type_index) const { + size_t tindex = static_cast(type_index); + if (tindex >= column_->size) { + return AnyView(); + } + const AnyView* any_view_data = reinterpret_cast(column_->data); + return any_view_data[tindex]; + } + + private: + const TVMFFITypeAttrColumn* column_; +}; + +/*! + * \brief helper function to get reflection method info by type key and method name + * + * \param type_key The type key. + * \param method_name The name of the method. + * \return The method info. + */ +inline const TVMFFIMethodInfo* GetMethodInfo(std::string_view type_key, const char* method_name) { + int32_t type_index; + TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + const TypeInfo* info = TVMFFIGetTypeInfo(type_index); + for (int32_t i = 0; i < info->num_methods; ++i) { + if (std::strncmp(info->methods[i].name.data, method_name, info->methods[i].name.size) == 0) { + return &(info->methods[i]); + } + } + TVM_FFI_THROW(RuntimeError) << "Cannot find method " << method_name << " in " << type_key; + TVM_FFI_UNREACHABLE(); +} + +/*! + * \brief helper function to get reflection method function by method info + * + * \param type_key The type key. + * \param method_name The name of the method. + * \return The method function. 
+ */ +inline Function GetMethod(std::string_view type_key, const char* method_name) { + const TVMFFIMethodInfo* info = GetMethodInfo(type_key, method_name); + return AnyView::CopyFromTVMFFIAny(info->method).cast(); +} + +/*! + * \brief Visit each field info of the type info and run callback. + * + * \tparam Callback The callback function type. + * + * \param type_info The type info. + * \param callback The callback function. + * + * \note This function calls both the child and parent type info. + */ +template +inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) { + using ResultType = decltype(callback(type_info->fields)); + static_assert(std::is_same_v, "Callback must return void"); + // iterate through acenstors in parent to child order + // skip the first one since it is always the root object + for (int i = 1; i < type_info->type_depth; ++i) { + const TVMFFITypeInfo* parent_info = type_info->type_ancestors[i]; + for (int j = 0; j < parent_info->num_fields; ++j) { + callback(parent_info->fields + j); + } + } + for (int i = 0; i < type_info->num_fields; ++i) { + callback(type_info->fields + i); + } +} + +/*! + * \brief Visit each field info of the type info and run callback which returns bool for early stop. + * + * \tparam Callback The callback function type, which returns bool for early stop. + * + * \param type_info The type info. + * \param callback_with_early_stop The callback function. + * \return true if any of early stop is triggered. + * + * \note This function calls both the child and parent type info and can be used for searching. 
+ */ +template +inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo* type_info, + Callback callback_with_early_stop) { + // iterate through acenstors in parent to child order + // skip the first one since it is always the root object + for (int i = 1; i < type_info->type_depth; ++i) { + const TVMFFITypeInfo* parent_info = type_info->type_ancestors[i]; + for (int j = 0; j < parent_info->num_fields; ++j) { + if (callback_with_early_stop(parent_info->fields + j)) return true; + } + } + for (int i = 0; i < type_info->num_fields; ++i) { + if (callback_with_early_stop(type_info->fields + i)) return true; + } + return false; +} + +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_ACCESSOR_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h new file mode 100644 index 000000000..774eb8b0b --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * \file tvm/ffi/reflection/creator.h + * \brief Reflection-based creator to create objects from type key and fields. + */ +#ifndef TVM_FFI_REFLECTION_CREATOR_H_ +#define TVM_FFI_REFLECTION_CREATOR_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace ffi { +namespace reflection { +/*! + * \brief helper wrapper class of TVMFFITypeInfo to create object based on reflection. + */ +class ObjectCreator { + public: + /*! + * \brief Constructor + * \param type_key The type key. + */ + explicit ObjectCreator(std::string_view type_key) + : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {} + + /*! + * \brief Constructor + * \param type_info The type info. + */ + explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) { + int32_t type_index = type_info->type_index; + if (type_info->metadata == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + << "` does not have reflection registered"; + } + if (type_info->metadata->creator == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + << "` does not support default constructor, " + << "as a result cannot be created via reflection"; + } + } + + /** + * \brief Create an object from a map of fields. + * \param fields The fields of the object. + * \return The created object. 
+ */ + Any operator()(const Map& fields) const { + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle)); + ObjectPtr ptr = + details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + size_t match_field_count = 0; + ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) { + String field_name(field_info->name); + void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; + if (fields.count(field_name) != 0) { + Any field_value = fields[field_name]; + field_info->setter(field_addr, reinterpret_cast(&field_value)); + ++match_field_count; + } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { + field_info->setter(field_addr, &(field_info->default_value)); + } else { + TVM_FFI_THROW(TypeError) << "Required field `" + << String(field_info->name.data, field_info->name.size) + << "` not set in type `" + << String(type_info_->type_key.data, type_info_->type_key.size) + << "`"; + } + }); + if (match_field_count == fields.size()) return ObjectRef(ptr); + // otherwise `fields` has a key that is not a field of the type: find it and report it + auto check_field_name = [&](const String& field_name) { + bool found = false; + ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo* field_info) { + if (field_name.compare(field_info->name) == 0) { + found = true; + return true; + } + return false; + }); + return found; + }; + for (const auto& [field_name, _] : fields) { + if (!check_field_name(field_name)) { + TVM_FFI_THROW(TypeError) << "Type `" + << String(type_info_->type_key.data, type_info_->type_key.size) + << "` does not have field `" << field_name << "`"; + } + } + TVM_FFI_UNREACHABLE(); + } + + private: + const TVMFFITypeInfo* type_info_; +}; +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_CREATOR_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h 
b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h new file mode 100644 index 000000000..3014108c8 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h @@ -0,0 +1,741 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/reflection/registry.h + * \brief Registry of reflection metadata. + */ +#ifndef TVM_FFI_REFLECTION_REGISTRY_H_ +#define TVM_FFI_REFLECTION_REGISTRY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { +/*! \brief Reflection namespace */ +namespace reflection { +/*! + * \brief Types of temporary metadata held in FieldInfoBuilder and MethodInfoBuilder, + * before they are filled into final C metadata + */ +using _MetadataType = std::vector>; // NOLINT(bugprone-reserved-identifier) +/*! + * \brief Builder for TVMFFIFieldInfo + * \sa TVMFFIFieldInfo + */ +struct FieldInfoBuilder : public TVMFFIFieldInfo { + /*! \brief Temporary metadata info to be filled into TVMFFIFieldInfo::metadata */ + _MetadataType metadata_; +}; +/*! 
+ * \brief Builder for TVMFFIMethodInfo + * \sa TVMFFIMethodInfo + */ +struct MethodInfoBuilder : public TVMFFIMethodInfo { + /*! \brief Temporary metadata info to be filled into TVMFFIMethodInfo::metadata */ + _MetadataType metadata_; +}; + +/*! + * \brief Trait that can be used to set information attached to a field or a method. + * \sa DefaultValue, AttachFieldFlag + */ +struct InfoTrait {}; + +/*! \brief User-supplied metadata attached to a field or a method */ +class Metadata : public InfoTrait { + public: + /*! + * \brief Constructor + * \param dict The initial dictionary + */ + Metadata(std::initializer_list> dict) : dict_(dict) {} + /*! + * \brief Move metadata into `FieldInfoBuilder` + * \param info The field info builder. + */ + inline void Apply(FieldInfoBuilder* info) const { this->Apply(&info->metadata_); } + /*! + * \brief Move metadata into `MethodInfoBuilder` + * \param info The method info builder. + */ + inline void Apply(MethodInfoBuilder* info) const { this->Apply(&info->metadata_); } + + private: + friend class GlobalDef; + template + friend class ObjectDef; + /*! + * \brief Move metadata into a vector of key-value pairs. + * \param out The output vector. + */ + inline void Apply(_MetadataType* out) const { + std::copy(std::make_move_iterator(dict_.begin()), std::make_move_iterator(dict_.end()), + std::back_inserter(*out)); + } + /*! \brief Convert the metadata to JSON string */ + static std::string ToJSON(const _MetadataType& metadata) { + using ::tvm::ffi::details::StringObj; + std::ostringstream os; + os << "{"; + bool first = true; + for (const auto& [key, value] : metadata) { + if (!first) { + os << ","; + } + os << "\"" << key << "\":"; + if (std::optional v = value.as()) { + os << *v; + } else if (std::optional v = value.as()) { + os << (*v ? 
"true" : "false"); + } else if (std::optional v = value.as()) { + String escaped = EscapeString(*v); + os << escaped.c_str(); + } else { + TVM_FFI_LOG_AND_THROW(TypeError) << "Metadata can be only int, bool or string, but on key `" + << key << "`, the type is " << value.GetTypeKey(); + } + first = false; + } + os << "}"; + return os.str(); + } + + std::vector> dict_; +}; +/*! + * \brief Trait that can be used to set field default value + */ +class DefaultValue : public InfoTrait { + public: + /*! + * \brief Constructor + * \param value The value to be set + */ + explicit DefaultValue(Any value) : value_(std::move(value)) {} + + /*! + * \brief Apply the default value to the field info + * \param info The field info. + */ + TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { + info->default_value = AnyView(value_).CopyToTVMFFIAny(); + info->flags |= kTVMFFIFieldFlagBitMaskHasDefault; + } + + private: + Any value_; +}; + +/*! + * \brief Trait that can be used to attach field flag + */ +class AttachFieldFlag : public InfoTrait { + public: + /*! + * \brief Attach a field flag to the field + * \param flag The flag to be set + */ + explicit AttachFieldFlag(int32_t flag) : flag_(flag) {} + + /*! + * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef + */ + TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() { + return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef); + } + /*! + * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore + */ + TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() { + return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashIgnore); + } + + /*! + * \brief Apply the field flag to the field info + * \param info The field info. + */ + TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->flags |= flag_; } + + private: + int32_t flag_; +}; + +/*! + * \brief Get the byte offset of a class member field. + * + * \tparam The original class. + * \tparam T the field type. 
+ * + * \param field_ptr A class member pointer + * \returns The byte offset + */ +template +TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) { + int64_t field_offset_to_class = + reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); + return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); +} + +/// \cond Doxygen_Suppress +class ReflectionDefBase { + protected: + template + static int FieldGetter(void* field, TVMFFIAny* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); + TVM_FFI_SAFE_CALL_END(); + } + + template + static int FieldSetter(void* field, const TVMFFIAny* value) { + TVM_FFI_SAFE_CALL_BEGIN(); + if constexpr (std::is_same_v) { + *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value); + } else { + *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); + } + TVM_FFI_SAFE_CALL_END(); + } + + template + static int ObjectCreatorDefault(TVMFFIObjectHandle* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + ObjectPtr obj = make_object(); + *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + TVM_FFI_SAFE_CALL_END(); + } + + template + static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) { + TVM_FFI_SAFE_CALL_BEGIN(); + ObjectPtr obj = make_object(UnsafeInit{}); + *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + TVM_FFI_SAFE_CALL_END(); + } + + template + TVM_FFI_INLINE static void ApplyFieldInfoTrait(FieldInfoBuilder* info, const T& value) { + if constexpr (std::is_base_of_v>) { + value.Apply(info); + } + if constexpr (std::is_same_v, char*>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + + template + TVM_FFI_INLINE static void ApplyMethodInfoTrait(MethodInfoBuilder* info, const T& value) { + if constexpr (std::is_base_of_v>) { + value.Apply(info); + } + if constexpr (std::is_same_v, char*>) { + info->doc = 
TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + + template + TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeMetadata* info, const T& value) { + if constexpr (std::is_same_v, char*>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } + } + + template + TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...)) { + static_assert(std::is_base_of_v || std::is_base_of_v, + "Class must be derived from ObjectRef or Object"); + if constexpr (std::is_base_of_v) { + auto fwrap = [func](Class target, Args... params) -> R { + // call method pointer + return (target.*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, std::move(name)); + } + + if constexpr (std::is_base_of_v) { + auto fwrap = [func](const Class* target, Args... params) -> R { + // call method pointer + return (const_cast(target)->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, std::move(name)); + } + } + + template + TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...) const) { + static_assert(std::is_base_of_v || std::is_base_of_v, + "Class must be derived from ObjectRef or Object"); + if constexpr (std::is_base_of_v) { + auto fwrap = [func](const Class& target, Args... params) -> R { + // call method pointer + return (target.*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, std::move(name)); + } + + if constexpr (std::is_base_of_v) { + auto fwrap = [func](const Class* target, Args... params) -> R { + // call method pointer + return (target->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, std::move(name)); + } + } + + template + TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) { + return ffi::Function::FromTyped(std::forward(func), std::move(name)); + } +}; +/// \endcond + +/*! + * \brief GlobalDef helper to register a global function. 
+ * + * \code + * namespace refl = tvm::ffi::reflection; + * refl::GlobalDef().def("my_ffi_extension.my_function", MyFunction); + * \endcode + */ +class GlobalDef : public ReflectionDefBase { + public: + /*! + * \brief Define a global function. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the function. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring or subclass of InfoTrait. + * + * \return The reflection definition. + */ + template + GlobalDef& def(const char* name, Func&& func, Extra&&... extra) { + using FuncInfo = details::FunctionInfo>; + RegisterFunc(name, ffi::Function::FromTyped(std::forward(func), std::string(name)), + FuncInfo::TypeSchema(), std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a global function in ffi::PackedArgs format. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the function. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring or subclass of InfoTrait. + * + * \return The reflection definition. + */ + template + GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) { + RegisterFunc(name, ffi::Function::FromPacked(func), details::TypeSchemaImpl::v(), + std::forward(extra)...); + return *this; + } + + /*! + * \brief Expose a class method as a global function. + * + * An argument will be added to the first position if the function is not static. + * + * \tparam Class The class type. + * \tparam Func The function type. + * + * \param name The name of the method. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + GlobalDef& def_method(const char* name, Func&& func, Extra&&... 
extra) { + using FuncInfo = details::FunctionInfo>; + RegisterFunc(name, GetMethod(std::string(name), std::forward(func)), + FuncInfo::TypeSchema(), std::forward(extra)...); + return *this; + } + + private: + template // NOLINTNEXTLINE(performance-unnecessary-value-param) + void RegisterFunc(const char* name, ffi::Function func, String type_schema, Extra&&... extra) { + MethodInfoBuilder info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.doc = TVMFFIByteArray{nullptr, 0}; + info.flags = 0; + info.method = AnyView(func).CopyToTVMFFIAny(); + info.metadata_.emplace_back("type_schema", type_schema); + ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); + std::string metadata_str = Metadata::ToJSON(info.metadata_); + info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0)); + } +}; + +/*! + * \brief Helper class to register a constructor method for object types. + * + * This helper is used with `ObjectDef::def()` to register an `__init__` method + * that constructs an object instance with the specified argument types. + * + * \tparam Args The argument types for the constructor. + * + * Example usage: + * \code + * class ExampleObject : public Object { + * public: + * int64_t v_i64; + * int32_t v_i32; + * + * ExampleObject(int64_t v_i64, int32_t v_i32) : v_i64(v_i64), v_i32(v_i32) {} + * TVM_FFI_DECLARE_OBJECT_INFO("example.ExampleObject", ExampleObject, Object); + * }; + * + * // Register the constructor + * refl::ObjectDef() + * .def(refl::init()); + * \endcode + * + * \note The object type is automatically deduced from the `ObjectDef` context. + */ +template +struct init { + // Allow ObjectDef to access the execute function + template + friend class ObjectDef; + + /*! + * \brief Constructor + */ + constexpr init() noexcept = default; + + private: + /*! + * \brief Execute the constructor + * \tparam Class The class type. 
+ * \param args The arguments to be passed to the constructor. + * \return The constructed object wrapped in an `ObjectRef`. + */ + template + static inline ObjectRef execute(Args&&... args) { + return ObjectRef(ffi::make_object(std::forward(args)...)); + } +}; + +/*! + * \brief Helper to register Object's reflection metadata. + * \tparam Class The class type. + * + * \code + * namespace refl = tvm::ffi::reflection; + * refl::ObjectDef().def_ro("my_field", &MyClass::my_field); + * \endcode + */ +template +class ObjectDef : public ReflectionDefBase { + public: + /*! + * \brief Constructor + * \tparam ExtraArgs The extra arguments. + * \param extra_args The extra arguments. + */ + template + explicit ObjectDef(ExtraArgs&&... extra_args) + : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) { + RegisterExtraInfo(std::forward(extra_args)...); + } + + /*! + * \brief Define a readonly field. + * + * \tparam BaseClass The class that declares the field. + * \tparam T The field type. + * \tparam Extra The extra arguments. + * + * \param name The name of the field. + * \param field_ptr The pointer to the field. + * \param extra The extra arguments that can be docstring or default value. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::* field_ptr, Extra&&... extra) { + RegisterField(name, field_ptr, false, std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a read-write field. + * + * \tparam BaseClass The class that declares the field. + * \tparam T The field type. + * \tparam Extra The extra arguments. + * + * \param name The name of the field. + * \param field_ptr The pointer to the field. + * \param extra The extra arguments that can be docstring or default value. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::* field_ptr, Extra&&... 
extra) { + static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields"); + RegisterField(name, field_ptr, true, std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a method. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the method. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... extra) { + RegisterMethod(name, false, std::forward(func), std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a static method. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the method. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) { + RegisterMethod(name, true, std::forward(func), std::forward(extra)...); + return *this; + } + + /*! + * \brief Register a constructor for this object type. + * + * This method registers a static `__init__` method that constructs an instance + * of the object with the specified argument types. The constructor can be invoked + * from Python or other FFI bindings. + * + * \tparam Args The argument types for the constructor. + * \tparam Extra Additional arguments (e.g., docstring). + * + * \param init_func An instance of `init` specifying constructor signature. + * \param extra Optional additional metadata such as docstring. + * + * \return Reference to this `ObjectDef` for method chaining. 
+ * + * Example: + * \code + * refl::ObjectDef() + * .def(refl::init(), "Constructor docstring"); + * \endcode + */ + template + TVM_FFI_INLINE ObjectDef& def([[maybe_unused]] init init_func, Extra&&... extra) { + RegisterMethod(kInitMethodName, true, &init::template execute, + std::forward(extra)...); + return *this; + } + + private: + template + void RegisterExtraInfo(ExtraArgs&&... extra_args) { + TVMFFITypeMetadata info; + info.total_size = sizeof(Class); + info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind; + info.creator = nullptr; + info.doc = TVMFFIByteArray{nullptr, 0}; + if constexpr (std::is_default_constructible_v) { + info.creator = ObjectCreatorDefault; + } else if constexpr (std::is_constructible_v) { + info.creator = ObjectCreatorUnsafeInit; + } + // apply extra info traits + ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info)); + } + + template + void RegisterField(const char* name, T BaseClass::* field_ptr, bool writable, + ExtraArgs&&... 
extra_args) { + static_assert(std::is_base_of_v, "BaseClass must be a base class of Class"); + FieldInfoBuilder info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.field_static_type_index = TypeToFieldStaticTypeIndex::value; + // store byte offset and setter, getter + // so the same setter can be reused for all the same type + info.offset = GetFieldByteOffsetToObject(field_ptr); + info.size = sizeof(T); + info.alignment = alignof(T); + info.flags = 0; + if (writable) { + info.flags |= kTVMFFIFieldFlagBitMaskWritable; + } + info.getter = FieldGetter; + info.setter = FieldSetter; + // initialize default value to nullptr + info.default_value = AnyView(nullptr).CopyToTVMFFIAny(); + info.doc = TVMFFIByteArray{nullptr, 0}; + info.metadata_.emplace_back("type_schema", details::TypeSchema::v()); + // apply field info traits + ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); + // call register + std::string metadata_str = Metadata::ToJSON(info.metadata_); + info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info)); + } + + // register a method + template + void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { + using FuncInfo = details::FunctionInfo>; + MethodInfoBuilder info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.doc = TVMFFIByteArray{nullptr, 0}; + info.flags = 0; + if (is_static) { + info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; + } + // obtain the method function + Function method = GetMethod(std::string(type_key_) + "." 
+ name, std::forward(func)); + info.method = AnyView(method).CopyToTVMFFIAny(); + info.metadata_.emplace_back("type_schema", FuncInfo::TypeSchema()); + // apply method info traits + ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); + std::string metadata_str = Metadata::ToJSON(info.metadata_); + info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); + } + + int32_t type_index_; + const char* type_key_; + static constexpr const char* kInitMethodName = "__ffi_init__"; +}; + +/*! + * \brief Helper to register type attribute. + * \tparam Class The class type. + * \tparam ExtraArgs The extra arguments. + * + * \code + * namespace refl = tvm::ffi::reflection; + * refl::TypeAttrDef().def("func_attr", MyFunc); + * \endcode + * + */ +template >> +class TypeAttrDef : public ReflectionDefBase { + public: + /*! + * \brief Constructor + * \tparam ExtraArgs The extra arguments. + * \param extra_args The extra arguments. + */ + template + explicit TypeAttrDef(ExtraArgs&&... extra_args) + : type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {} + + /*! + * \brief Define a function-valued type attribute. + * + * \tparam Func The function type. + * + * \param name The name of the function. + * \param func The function to be registered. + * + * \return The TypeAttrDef object. + */ + template + TypeAttrDef& def(const char* name, Func&& func) { + TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; + ffi::Function ffi_func = + GetMethod(std::string(type_key_) + "." + name, std::forward(func)); + TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny(); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); + return *this; + } + + /*! + * \brief Define a constant-valued type attribute. + * + * \tparam T The type of the value. + * + * \param name The name of the attribute. + * \param value The value of the attribute. 
+ * + * \return The TypeAttrDef object. + */ + template + TypeAttrDef& attr(const char* name, T value) { + TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; + TVMFFIAny value_any = AnyView(value).CopyToTVMFFIAny(); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); + return *this; + } + + private: + int32_t type_index_; + const char* type_key_; +}; + +/*! + * \brief Ensure the type attribute column is present in the system. + * + * \param name The name of the type attribute. + */ +inline void EnsureTypeAttrColumn(std::string_view name) { + TVMFFIByteArray name_array = {name.data(), name.size()}; + AnyView any_view(nullptr); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_array, + reinterpret_cast(&any_view))); +} + +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_REGISTRY_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h new file mode 100644 index 000000000..aca5840fa --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/rvalue_ref.h + * \brief Helper class to define rvalue reference type. + */ +#ifndef TVM_FFI_RVALUE_REF_H_ +#define TVM_FFI_RVALUE_REF_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Helper class to define rvalue reference type. + * + * By default, FFI passes all values by lvalue reference. + * + * However, we do allow users to intentionally mark a function parameter + * as RValueRef. In such cases, the caller can choose to pass parameter + * wrapped by RValueRef to the function. In which case the parameter + * can be directly moved by the callee. The caller can also choose to pass + * a normal lvalue to the function, in such case a copy will be triggered. + * + * To keep FFI checking overhead minimal, we do not handle case when rvalue + * is passed, but the callee did not declare the parameter as RValueRef. + * + * This design allows us to still leverage move semantics for parameters that + * need copy on write scenarios (and requires a unique copy). + * + * \code + * + * void Example() { + * auto append = Function::FromTyped([](RValueRef> ref, int val) -> Array { + * Array arr = *std::move(ref); + * assert(arr.unique()); + * arr.push_back(val); + * return arr; + * }); + * Array a = Array({1, 2}); + * // as we use rvalue ref to move a into append + * // we keep a single copy of the Array without creating new copies during copy-on-write + * a = append(RValueRef(std::move(a)), 3); + * assert(a.size() == 3); + * } + * + * \endcode + */ +template >> +class RValueRef { + public: + /*! \brief the container type of the rvalue ref */ + using ContainerType = typename TObjRef::ContainerType; + /*! \brief only allow move constructor from rvalue of T */ + explicit RValueRef(TObjRef&& data) + : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} + + /*! 
\brief return the data as rvalue */ + TObjRef operator*() && { return TObjRef(std::move(data_)); } + + private: + mutable ObjectPtr data_; + + template + friend struct TypeTraits; +}; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public TypeTraitsBase { + static constexpr bool storage_enabled = false; + + TVM_FFI_INLINE static void CopyToAnyView(const RValueRef& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIObjectRValueRef; + result->zero_padding = 0; + // store the address of the ObjectPtr, which allows us to move the value + // and set the original ObjectPtr to nullptr + result->v_ptr = &(src.data_); + } + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { + ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); + // object type does not match up, we need to try to convert the object + // in this case we do not move the original rvalue ref since conversion creates a copy + TVMFFIAny tmp_any; + tmp_any.type_index = rvalue_ref->get()->type_index(); + tmp_any.zero_padding = 0; + tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); + return "RValueRef<" + TypeTraits::GetMismatchTypeInfo(&tmp_any) + ">"; + } else { + return TypeTraits::GetMismatchTypeInfo(src); + } + } + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { + // first try rvalue conversion + if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { + ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); + TVMFFIAny tmp_any; + tmp_any.type_index = rvalue_ref->get()->type_index(); + tmp_any.zero_padding = 0; + tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); + // fast path, storage type matches, direct move the rvalue ref + if (TypeTraits::CheckAnyStrict(&tmp_any)) { + return RValueRef( + details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(*rvalue_ref))); + } + if (std::optional opt = 
TypeTraits::TryCastFromAnyView(&tmp_any)) { + // object type does not match up, we need to try to convert the object + // in this case we do not move the original rvalue ref since conversion creates a copy + return RValueRef(*std::move(opt)); + } + return std::nullopt; + } + // try lvalue conversion + if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { + return RValueRef(*std::move(opt)); + } else { + return std::nullopt; + } + } + + TVM_FFI_INLINE static std::string TypeStr() { + return "RValueRef<" + TypeTraits::TypeStr() + ">"; + } + + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":")" << StaticTypeKey::kTVMFFIObjectRValueRef << R"(","args":[)"; + oss << TypeTraits::TypeSchema(); + oss << "]}"; + return oss.str(); + } +}; +} // namespace ffi +} // namespace tvm + +#endif // TVM_FFI_RVALUE_REF_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/string.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/string.h new file mode 100644 index 000000000..ad7230b93 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/string.h @@ -0,0 +1,1102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file tvm/ffi/string.h + * \brief Runtime Bytes and String types. + */ +#ifndef TVM_FFI_STRING_H_ +#define TVM_FFI_STRING_H_ + +#include "base_details.h" +#include "error.h" +#include "memory.h" +#include "object.h" +#include "type_traits.h" + +#include +#include +#include +#include +#include +#include + +// Note: We place string in tvm/ffi instead of tvm/ffi/container +// because string itself needs special handling and is an inherent +// core component for return string handling. +// The following dependency relation holds +// any -> string -> object + +/// \cond Doxygen_Suppress +#ifdef _MSC_VER +#define TVM_FFI_SNPRINTF _snprintf_s +#pragma warning(push) +#pragma warning(disable : 4244) +#pragma warning(disable : 4127) +#pragma warning(disable : 4702) +#else +#define TVM_FFI_SNPRINTF snprintf +#endif +/// \endcond + +namespace tvm { +namespace ffi { +namespace details { +/*! + * \brief Base class for bytes and string objects. + */ +class BytesObjBase : public Object, public TVMFFIByteArray {}; + +/*! + * \brief An object representing bytes. + * \note We use a separate object for bytes to follow Python convention + * and indicate passing of raw bytes. + * Bytes can be converted from/to string. + */ +class BytesObj : public BytesObjBase { +public: + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIBytes; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIBytes, BytesObj, Object); +}; + +/*! \brief An object representing string. This is a POD type. 
*/ +class StringObj : public BytesObjBase { +public: + static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIStr; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIStr, StringObj, Object); +}; + +// String moved from std::string +// without having to trigger a copy +template +class BytesObjStdImpl : public Base { +public: + explicit BytesObjStdImpl(std::string other) : data_{std::move(other)} { + this->data = data_.data(); + this->size = data_.size(); + } + +private: + std::string data_; +}; + +/*! + * \brief Helper cell class that can be used to back small string + * \note Do not use directly, use String or Bytes instead + */ +class BytesBaseCell { +public: + BytesBaseCell() { + // initialize to none + data_.type_index = TypeIndex::kTVMFFINone; + data_.zero_padding = 0; + data_.v_int64 = 0; + } + + explicit BytesBaseCell(std::nullopt_t) { + data_.type_index = TypeIndex::kTVMFFINone; + data_.zero_padding = 0; + data_.v_int64 = 0; + } + + BytesBaseCell(const BytesBaseCell &other) : data_(other.data_) { // NOLINT(*) + if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj); + } + } + + BytesBaseCell(BytesBaseCell &&other) : data_(other.data_) { // NOLINT(*) + other.data_.type_index = TypeIndex::kTVMFFINone; + } + + BytesBaseCell &operator=(const BytesBaseCell &other) { + BytesBaseCell(other).swap(*this); // NOLINT(*) + return *this; + } + + BytesBaseCell &operator=(BytesBaseCell &&other) noexcept { + BytesBaseCell(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + + ~BytesBaseCell() { + if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); + } + } + + /*! + * \brief Check if the cell is null + * \return true if the cell is null, false otherwise + */ + bool operator==(std::nullopt_t) const { return data_.type_index == TypeIndex::kTVMFFINone; } + + /*! 
+ * \brief Check if the cell is not null + * \return true if the cell is not null, false otherwise + */ + bool operator!=(std::nullopt_t) const { return data_.type_index != TypeIndex::kTVMFFINone; } + + /*! + * \brief Swap this String with another string + * \param other The other string + */ + void swap(BytesBaseCell &other) { // NOLINT(*) + std::swap(data_, other.data_); + } + + const char *data() const noexcept { + if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) { + return data_.v_bytes; + } else { + // NOLINTNEXTLINE(clang-analyzer-security.ArrayBound) + return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->data; + } + } + + size_t size() const noexcept { + if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) { + return data_.small_str_len; + } else { + // NOLINTNEXTLINE(clang-analyzer-security.ArrayBound) + return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->size; + } + } + + template + void InitFromStd(std::string &&other, int32_t large_type_index) { + // needs to be reset to none first for exception safety + data_.type_index = TypeIndex::kTVMFFINone; + data_.zero_padding = 0; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_); + ObjectPtr ptr = make_object>(std::move(other)); + data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr)); + data_.type_index = large_type_index; + } + + /*! 
+ * \brief Create a new empty space for a string + * \param size The size of the string + * \param small_type_index The type index for the small string + * \param large_type_index The type index for the large string + * \note always reserve one byte for \0 compactibility + * \return A pointer to the empty space + */ + template + char *InitSpaceForSize(size_t size, int32_t small_type_index, int32_t large_type_index) { + size_t kMaxSmallBytesLen = sizeof(int64_t) - 1; + // first zero the content, this is important for exception safety + data_.type_index = small_type_index; + data_.zero_padding = 0; + if (size <= kMaxSmallBytesLen) { + // set up the size accordingly + data_.small_str_len = static_cast(size); + return data_.v_bytes; + } else { + // allocate from heap + ObjectPtr ptr = make_inplace_array_object(size + 1); + char *dest_data = reinterpret_cast(ptr.get()) + sizeof(LargeObj); + ptr->data = dest_data; + ptr->size = size; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_); + data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr)); + // now reset the type index to str + data_.type_index = large_type_index; + return dest_data; + } + } + + void InitTypeIndex(int32_t type_index) { data_.type_index = type_index; } + + void MoveToAny(TVMFFIAny *result) { + *result = data_; + data_.type_index = TypeIndex::kTVMFFINone; + data_.zero_padding = 0; + data_.v_int64 = 0; + } + + TVMFFIAny CopyToTVMFFIAny() const { return data_; } + + static BytesBaseCell CopyFromAnyView(const TVMFFIAny *src) { + BytesBaseCell result(*src); + if (result.data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + details::ObjectUnsafe::IncRefObjectHandle(result.data_.v_obj); + } + return result; + } + + static BytesBaseCell MoveFromAny(TVMFFIAny *src) { + BytesBaseCell result(*src); + src->type_index = TypeIndex::kTVMFFINone; + src->zero_padding = 0; + src->v_int64 = 0; + return result; + } + +private: + explicit BytesBaseCell(TVMFFIAny data) : data_(data) {} + /*! 
\brief internal backing data */ + TVMFFIAny data_; +}; +} // namespace details + +/*! + * \brief Managed reference of byte array. + */ +class Bytes { +public: + /*! \brief default constructor */ + Bytes() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallBytes); } + /*! + * \brief constructor from size + * + * \param data The data pointer. + * \param size The size of the char array. + */ + Bytes(const char *data, size_t size) { this->InitData(data, size); } + /*! + * \brief constructor from TVMFFIByteArray + * + * \param bytes a char array. + */ + Bytes(TVMFFIByteArray bytes) { // NOLINT(*) + this->InitData(bytes.data, bytes.size); + } + /*! + * \brief constructor from std::string + * + * \param other a char array. + */ + Bytes(const std::string &other) { // NOLINT(*) + this->InitData(other.data(), other.size()); + } + /*! + * \brief constructor from std::string + * + * \param other a char array. + */ + Bytes(std::string &&other) { // NOLINT(*) + data_.InitFromStd(std::move(other), TypeIndex::kTVMFFIBytes); + } + /*! + * \brief Swap this String with another string + * \param other The other string + */ + void swap(Bytes &other) { // NOLINT(*) + std::swap(data_, other.data_); + } + + template + Bytes &operator=(T &&other) { + // copy-and-swap idiom + Bytes(std::forward(other)).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t size() const { return data_.size(); } + /*! + * \brief Return the data pointer + * + * \return const char* data pointer + */ + const char *data() const { return data_.data(); } + /*! + * \brief Convert String to an std::string object + * + * \return std::string + */ + operator std::string() const { // NOLINT(google-explicit-constructor) + return std::string{data(), size()}; + } + + /*! 
+ * \brief Compare two char sequence + * + * \param lhs Pointers to the char array to compare + * \param rhs Pointers to the char array to compare + * \param lhs_count Length of the char array to compare + * \param rhs_count Length of the char array to compare + * \return int zero if both char sequences compare equal. negative if this + * appear before other, positive otherwise. + */ + static int memncmp(const char *lhs, const char *rhs, size_t lhs_count, size_t rhs_count) { + if (lhs == rhs && lhs_count == rhs_count) { + return 0; + } + + for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { + if (lhs[i] < rhs[i]) { + return -1; + } + if (lhs[i] > rhs[i]) { + return 1; + } + } + if (lhs_count < rhs_count) { + return -1; + } else if (lhs_count > rhs_count) { + return 1; + } else { + return 0; + } + } + /*! + * \brief Compare two char sequence for equality + * + * \param lhs Pointers to the char array to compare + * \param rhs Pointers to the char array to compare + * \param lhs_count Length of the char array to compare + * \param rhs_count Length of the char array to compare + * + * \return true if the two char sequences are equal, false otherwise. + */ + static bool memequal(const void *lhs, const void *rhs, size_t lhs_count, size_t rhs_count) { + return lhs_count == rhs_count && (lhs == rhs || std::memcmp(lhs, rhs, lhs_count) == 0); + } + +private: + template + friend struct TypeTraits; + template + friend class Optional; + // internal backing cell + details::BytesBaseCell data_; + // create a new String from TVMFFIAny, must keep private + explicit Bytes(details::BytesBaseCell data) : data_(std::move(data)) {} + char *InitSpaceForSize(size_t size) { + return data_.InitSpaceForSize(size, TypeIndex::kTVMFFISmallBytes, + TypeIndex::kTVMFFIBytes); + } + void InitData(const char *data, size_t size) { + char *dest_data = InitSpaceForSize(size); + std::memcpy(dest_data, data, size); + // mainly to be compat with string + dest_data[size] = '\0'; + } +}; + +/*! 
+ * \brief String container class. + */ +class String { +public: + /*! + * \brief avoid misuse of nullptr + */ + String(std::nullptr_t) = delete; // NOLINT(*) + /*! + * \brief constructor + */ + String() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallStr); } + // constructors from Any + /*! + * \brief Copy constructor + * \param other The other string + */ + String(const String &other) = default; // NOLINT(*) + /*! + * \brief Move constructor + * \param other The other string + */ + String(String &&other) = default; // NOLINT(*) + /*! + * \brief Copy assignment operator + * \param other The other string + */ + String &operator=(const String &other) = default; // NOLINT(*) + /*! + * \brief Move assignment operator + * \param other The other string + */ + String &operator=(String &&other) = default; // NOLINT(*) + + /*! + * \brief Swap this String with another string + * \param other The other string + */ + void swap(String &other) noexcept { // NOLINT(*) + std::swap(data_, other.data_); + } + + /*! + * \brief Copy assignment operator + * \param other The other string + */ + String &operator=(const std::string &other) { + String(other).swap(*this); // NOLINT(*) + return *this; + } + /*! + * \brief Move assignment operator + * \param other The other string + */ + String &operator=(std::string &&other) { + String(std::move(other)).swap(*this); // NOLINT(*) + return *this; + } + + /*! + * \brief Copy assignment operator + * \param other The other string + */ + String &operator=(const char *other) { + String(other).swap(*this); // NOLINT(*) + return *this; + } + + /*! + * \brief constructor from raw string + * + * \param data The data pointer. + * \param size The size of the char array. + */ + String(const char *data, size_t size) { this->InitData(data, size); } + + /*! + * \brief constructor from raw string + * + * \param other a char array. 
+ * \note This constructor is marked as explicit to avoid implicit conversion + * of nullptr value here to string, which then was used in comparison + */ + String(const char *other) { // NOLINT(*) + this->InitData(other, std::char_traits::length(other)); + } + /*! + * \brief Construct a new string object + * \param other The std::string object to be copied + */ + String(const std::string &other) { // NOLINT(*) + this->InitData(other.data(), other.size()); + } + + /*! + * \brief Construct a new string object + * \param other The std::string object to be moved + */ + String(std::string &&other) { // NOLINT(*) + // exception safety, first set to none so if exception is thrown + // destructor works correctly + data_.InitFromStd(std::move(other), TypeIndex::kTVMFFIStr); + } + + /*! + * \brief constructor from TVMFFIByteArray + * + * \param other a TVMFFIByteArray. + */ + explicit String(TVMFFIByteArray other) { this->InitData(other.data, other.size); } + + /*! + * \brief Return the data pointer + * + * \return const char* data pointer + */ + const char *data() const noexcept { return data_.data(); } + + /*! + * \brief Returns a pointer to the char array in the string. + * + * \return const char* + */ + const char *c_str() const noexcept { return data(); } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t size() const noexcept { return data_.size(); } + + /*! + * \brief Compares this String object to other + * + * \param other The String to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const String &other) const { + return Bytes::memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this String object to other + * + * \param other The string to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. 
+ */ + int compare(const std::string &other) const { + return Bytes::memncmp(data(), other.data(), size(), other.size()); + } + + /*! + * \brief Compares this to other + * + * \param other The character array to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const char *other) const { + const char *this_data = data(); + size_t this_size = size(); + for (size_t i = 0; i < this_size; ++i) { + // other is shorter than this + if (other[i] == '\0') { + return 1; + } + if (this_data[i] < other[i]) { + return -1; + } + if (this_data[i] > other[i]) { + return 1; + } + } + // other equals this + if (other[this_size] == '\0') { + return 0; + } + // other longer than this + return -1; + } + + /*! + * \brief Compares this to other + * + * \param other The TVMFFIByteArray to compare with. + * + * \return zero if both char sequences compare equal. negative if this appear + * before other, positive otherwise. + */ + int compare(const TVMFFIByteArray &other) const { + return Bytes::memncmp(data(), other.data, size(), other.size); + } + + /*! + * \brief Return the length of the string + * + * \return size_t string length + */ + size_t length() const { return size(); } + + /*! + * \brief Retun if the string is empty + * + * \return true if empty, false otherwise. + */ + bool empty() const { return size() == 0; } + + /*! + * \brief Read an element. + * \param pos The position at which to read the character. + * + * \return The char at position + */ + char at(size_t pos) const { + if (pos < size()) { + return data()[pos]; + } else { + throw std::out_of_range("tvm::String index out of bounds"); + } + } + + /*! 
+ * \brief Convert String to an std::string object + * + * \return std::string + */ + operator std::string() const { // NOLINT(google-explicit-constructor) + return std::string{data(), size()}; + } + +private: + template + friend struct TypeTraits; + template + friend class Optional; + // internal backing cell + details::BytesBaseCell data_; + // create a new String from TVMFFIAny, must keep private + explicit String(details::BytesBaseCell data) : data_(std::move(data)) {} + /*! + * \brief Create a new empty space for a string + * \param size The size of the string + * \return A pointer to the empty space + */ + char *InitSpaceForSize(size_t size) { + return data_.InitSpaceForSize(size, TypeIndex::kTVMFFISmallStr, + TypeIndex::kTVMFFIStr); + } + void InitData(const char *data, size_t size) { + char *dest_data = InitSpaceForSize(size); + std::memcpy(dest_data, data, size); + dest_data[size] = '\0'; + } + /*! + * \brief Concatenate two char sequences + * + * \param lhs Pointers to the lhs char array + * \param lhs_size The size of the lhs char array + * \param rhs Pointers to the rhs char array + * \param rhs_size The size of the rhs char array + * + * \return The concatenated char sequence + */ + static String Concat(const char *lhs, size_t lhs_size, const char *rhs, size_t rhs_size) { + String ret; + // disable stringop-overflow and restrict warnings + // gcc may produce false positive when we enable dest_data returned from small string path + // Because compiler is not able to detect the condition that the path is only triggered via + // size < kMaxSmallStrLen and can report it as a overflow case. 
+#if (__GNUC__) && !(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstringop-overflow" +#pragma GCC diagnostic ignored "-Wrestrict" +#endif + char *dest_data = ret.InitSpaceForSize(lhs_size + rhs_size); + std::memcpy(dest_data, lhs, lhs_size); + std::memcpy(dest_data + lhs_size, rhs, rhs_size); + // NOLINTNEXTLINE(clang-analyzer-security.ArrayBound) + dest_data[lhs_size + rhs_size] = '\0'; +#if (__GNUC__) && !(__clang__) +#pragma GCC diagnostic pop +#endif + return ret; + } + // Overload + operator + friend String operator+(const String &lhs, const String &rhs); + friend String operator+(const String &lhs, const std::string &rhs); + friend String operator+(const std::string &lhs, const String &rhs); + friend String operator+(const String &lhs, const char *rhs); + friend String operator+(const char *lhs, const String &rhs); +}; + +/*! + * \brief Return an escaped version of the string + * \param value The input string + * \return The escaped string, quoted with double quotes + */ +inline String EscapeString(const String &value) { + std::ostringstream oss; + oss << '"'; + const char *data = value.data(); + const size_t size = value.size(); + for (size_t i = 0; i < size; ++i) { + switch (data[i]) { +/// \cond Doxygen_Suppress +#define TVM_FFI_ESCAPE_CHAR(pattern, val) \ + case pattern: \ + oss << (val); \ + break + TVM_FFI_ESCAPE_CHAR('\"', "\\\""); + TVM_FFI_ESCAPE_CHAR('\\', "\\\\"); + TVM_FFI_ESCAPE_CHAR('/', "\\/"); + TVM_FFI_ESCAPE_CHAR('\b', "\\b"); + TVM_FFI_ESCAPE_CHAR('\f', "\\f"); + TVM_FFI_ESCAPE_CHAR('\n', "\\n"); + TVM_FFI_ESCAPE_CHAR('\r', "\\r"); + TVM_FFI_ESCAPE_CHAR('\t', "\\t"); +#undef TVM_FFI_ESCAPE_CHAR + /// \endcond + default: { + uint8_t u8_val = static_cast(data[i]); + // this is a control character, print as \uXXXX + if (u8_val < 0x20 || u8_val == 0x7f) { + char buffer[8]; + int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "\\u%04x", + static_cast(data[i]) & 0xff); + oss.write(buffer, size); + } else { + oss << 
data[i]; + } + break; + } + } + } + oss << '"'; + return String(oss.str()); +} + +/*! \brief Convert TVMFFIByteArray to std::string_view */ +TVM_FFI_INLINE std::string_view ToStringView(TVMFFIByteArray str) { + return std::string_view(str.data, str.size); +} +/// \cond Doxygen_Suppress + +template <> +inline constexpr bool use_default_type_traits_v = false; + +// specialize to enable implicit conversion from TVMFFIByteArray* +template <> +struct TypeTraits : public TypeTraitsBase { + // bytes can be union type of small bytes and object, so keep it as any + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; + + TVM_FFI_INLINE static void CopyToAnyView(const Bytes &src, TVMFFIAny *result) { + *result = src.data_.CopyToTVMFFIAny(); + } + + TVM_FFI_INLINE static void MoveToAny(Bytes src, TVMFFIAny *result) { + src.data_.MoveToAny(result); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index == TypeIndex::kTVMFFISmallBytes || src->type_index == TypeIndex::kTVMFFIBytes; + } + + TVM_FFI_INLINE static Bytes CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return Bytes(details::BytesBaseCell::CopyFromAnyView(src)); + } + + TVM_FFI_INLINE static Bytes MoveFromAnyAfterCheck(TVMFFIAny *src) { + return Bytes(details::BytesBaseCell::MoveFromAny(src)); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { + return Bytes(*static_cast(src->v_ptr)); + } + if (src->type_index == TypeIndex::kTVMFFISmallBytes || src->type_index == TypeIndex::kTVMFFIBytes) { + return Bytes(details::BytesBaseCell::CopyFromAnyView(src)); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return "bytes"; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIBytes) + R"("})"; + } +}; + +template <> +inline constexpr bool use_default_type_traits_v = false; + 
+// specialize to enable implicit conversion from const char* +template <> +struct TypeTraits : public TypeTraitsBase { + // string can be union type of small string and object, so keep it as any + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; + + TVM_FFI_INLINE static void CopyToAnyView(const String &src, TVMFFIAny *result) { + *result = src.data_.CopyToTVMFFIAny(); + } + + TVM_FFI_INLINE static void MoveToAny(String src, TVMFFIAny *result) { + src.data_.MoveToAny(result); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index == TypeIndex::kTVMFFISmallStr || src->type_index == TypeIndex::kTVMFFIStr; + } + + TVM_FFI_INLINE static String CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return String(details::BytesBaseCell::CopyFromAnyView(src)); + } + + TVM_FFI_INLINE static String MoveFromAnyAfterCheck(TVMFFIAny *src) { + return String(details::BytesBaseCell::MoveFromAny(src)); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIRawStr) { + return String(src->v_c_str); + } + if (src->type_index == TypeIndex::kTVMFFISmallStr || src->type_index == TypeIndex::kTVMFFIStr) { + return String(details::BytesBaseCell::CopyFromAnyView(src)); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return "str"; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIStr) + R"("})"; + } +}; + +// const char*, requirement: not nullable, do not retain ownership +template +struct TypeTraits : public TypeTraitsBase { + // NOTE: only enable implicit conversion into AnyView + static constexpr bool storage_enabled = false; + + TVM_FFI_INLINE static void CopyToAnyView(const char src[N], TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIRawStr; + result->zero_padding = 0; + result->v_c_str = src; + } + + TVM_FFI_INLINE static void 
MoveToAny(const char src[N], TVMFFIAny *result) { + // when we need to move to any, convert to owned object first + TypeTraits::MoveToAny(String(src), result); + } +}; + +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr bool storage_enabled = false; + + TVM_FFI_INLINE static void CopyToAnyView(const char *src, TVMFFIAny *result) { + TVM_FFI_ICHECK_NOTNULL(src); + result->type_index = TypeIndex::kTVMFFIRawStr; + result->zero_padding = 0; + result->v_c_str = src; + } + + TVM_FFI_INLINE static void MoveToAny(const char *src, TVMFFIAny *result) { + // when we need to move to any, convert to owned object first + TypeTraits::MoveToAny(String(src), result); + } + // Do not allow const char* in a container, so we do not need CheckAnyStrict + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIRawStr) { + return static_cast(src->v_c_str); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return "const char*"; } + TVM_FFI_INLINE static std::string TypeSchema() { return R"({"type":"const char*"})"; } +}; + +// TVMFFIByteArray, requirement: not nullable, do not retain ownership +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIByteArrayPtr; + static constexpr bool storage_enabled = false; + + TVM_FFI_INLINE static void CopyToAnyView(TVMFFIByteArray *src, TVMFFIAny *result) { + TVM_FFI_ICHECK_NOTNULL(src); + result->type_index = TypeIndex::kTVMFFIByteArrayPtr; + result->zero_padding = 0; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_ptr = src; + } + + TVM_FFI_INLINE static void MoveToAny(TVMFFIByteArray *src, TVMFFIAny *result) { + TypeTraits::MoveToAny(Bytes(*src), result); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { + return static_cast(src->v_ptr); + } + 
return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIByteArrayPtr; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIByteArrayPtr) + R"("})"; + } +}; + +template <> +inline constexpr bool use_default_type_traits_v = false; + +template <> +struct TypeTraits + : public FallbackOnlyTraitsBase { + TVM_FFI_INLINE static void CopyToAnyView(const std::string &src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIRawStr; + result->zero_padding = 0; + result->v_c_str = src.c_str(); + } + + TVM_FFI_INLINE static void MoveToAny(std::string src, TVMFFIAny *result) { + // when we need to move to any, convert to owned object first + TypeTraits::MoveToAny(String(std::move(src)), result); + } + + TVM_FFI_INLINE static std::string TypeStr() { return "std::string"; } + TVM_FFI_INLINE static std::string TypeSchema() { return R"({"type":"std::string"})"; } + + TVM_FFI_INLINE static std::string ConvertFallbackValue(const char *src) { + return std::string(src); + } + + TVM_FFI_INLINE static std::string ConvertFallbackValue(TVMFFIByteArray *src) { + return std::string(src->data, src->size); + } + + // NOLINTNEXTLINE(performance-unnecessary-value-param) + TVM_FFI_INLINE static std::string ConvertFallbackValue(Bytes src) { + return src.operator std::string(); + } + + // NOLINTNEXTLINE(performance-unnecessary-value-param) + TVM_FFI_INLINE static std::string ConvertFallbackValue(String src) { + return src.operator std::string(); + } +}; + +inline String operator+(const String &lhs, const String &rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String &lhs, const std::string &rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String 
operator+(const std::string &lhs, const String &rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = rhs.size(); + return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const char *lhs, const String &rhs) { + size_t lhs_size = std::strlen(lhs); + size_t rhs_size = rhs.size(); + return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); +} + +inline String operator+(const String &lhs, const char *rhs) { + size_t lhs_size = lhs.size(); + size_t rhs_size = std::strlen(rhs); + return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); +} + +// Overload < operator +inline bool operator<(std::nullptr_t, const String &rhs) = delete; +inline bool operator<(const String &lhs, std::nullptr_t) = delete; + +inline bool operator<(const String &lhs, const std::string &rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const std::string &lhs, const String &rhs) { return rhs.compare(lhs) > 0; } + +inline bool operator<(const String &lhs, const String &rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const String &lhs, const char *rhs) { return lhs.compare(rhs) < 0; } + +inline bool operator<(const char *lhs, const String &rhs) { return rhs.compare(lhs) > 0; } + +// Overload > operator +inline bool operator>(std::nullptr_t, const String &rhs) = delete; +inline bool operator>(const String &lhs, std::nullptr_t) = delete; + +inline bool operator>(const String &lhs, const std::string &rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const std::string &lhs, const String &rhs) { return rhs.compare(lhs) < 0; } + +inline bool operator>(const String &lhs, const String &rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const String &lhs, const char *rhs) { return lhs.compare(rhs) > 0; } + +inline bool operator>(const char *lhs, const String &rhs) { return rhs.compare(lhs) < 0; } + +// Overload <= operator +inline bool operator<=(std::nullptr_t, const String &rhs) = delete; +inline bool 
operator<=(const String &lhs, std::nullptr_t) = delete; + +inline bool operator<=(const String &lhs, const std::string &rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const std::string &lhs, const String &rhs) { return rhs.compare(lhs) >= 0; } + +inline bool operator<=(const String &lhs, const String &rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const String &lhs, const char *rhs) { return lhs.compare(rhs) <= 0; } + +inline bool operator<=(const char *lhs, const String &rhs) { return rhs.compare(lhs) >= 0; } + +// Overload >= operator +inline bool operator>=(std::nullptr_t, const String &rhs) = delete; +inline bool operator>=(const String &lhs, std::nullptr_t) = delete; + +inline bool operator>=(const String &lhs, const std::string &rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const std::string &lhs, const String &rhs) { return rhs.compare(lhs) <= 0; } + +inline bool operator>=(const String &lhs, const String &rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const String &lhs, const char *rhs) { return lhs.compare(rhs) >= 0; } + +inline bool operator>=(const char *lhs, const String &rhs) { return rhs.compare(lhs) <= 0; } + +// delete Overload == operator for nullptr +inline bool operator==(const String &lhs, std::nullptr_t) = delete; +inline bool operator==(std::nullptr_t, const String &rhs) = delete; + +inline bool operator==(const String &lhs, const std::string &rhs) { + return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); +} + +inline bool operator==(const std::string &lhs, const String &rhs) { + return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); +} + +inline bool operator==(const String &lhs, const String &rhs) { + return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); +} + +inline bool operator==(const String &lhs, const char *rhs) { return lhs.compare(rhs) == 0; } + +inline bool operator==(const char *lhs, const String &rhs) { 
return rhs.compare(lhs) == 0; } + +// Overload != operator +inline bool operator!=(const String &lhs, std::nullptr_t) = delete; +inline bool operator!=(std::nullptr_t, const String &rhs) = delete; + +inline bool operator!=(const String &lhs, const std::string &rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const std::string &lhs, const String &rhs) { return rhs.compare(lhs) != 0; } + +inline bool operator!=(const String &lhs, const String &rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const String &lhs, const char *rhs) { return lhs.compare(rhs) != 0; } + +inline bool operator!=(const char *lhs, const String &rhs) { return rhs.compare(lhs) != 0; } + +inline std::ostream &operator<<(std::ostream &out, const String &input) { + out.write(input.data(), static_cast(input.size())); + return out; +} +/// \endcond +} // namespace ffi +} // namespace tvm + +/// \cond Doxygen_Suppress +namespace std { + +template <> +struct hash<::tvm::ffi::Bytes> { + std::size_t operator()(const ::tvm::ffi::Bytes &bytes) const { + return std::hash()(std::string_view(bytes.data(), bytes.size())); + } +}; + +template <> +struct hash<::tvm::ffi::String> { + std::size_t operator()(const ::tvm::ffi::String &str) const { + return std::hash()(std::string_view(str.data(), str.size())); + } +}; +} // namespace std +/// \endcond +#endif // TVM_FFI_STRING_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/type_traits.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/type_traits.h new file mode 100644 index 000000000..d9f3f58a7 --- /dev/null +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/type_traits.h @@ -0,0 +1,828 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/object.h + * \brief A managed object in the TVM FFI. + */ +#ifndef TVM_FFI_TYPE_TRAITS_H_ +#define TVM_FFI_TYPE_TRAITS_H_ + +#include "base_details.h" +#include "c_api.h" +#include "error.h" +#include "object.h" + +#include +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief TypeTraits that specifies the conversion behavior from/to FFI Any. + * + * The function specifications of TypeTraits + * + * - CopyToAnyView: Convert a value T to AnyView + * - MoveToAny: Move a value to Any + * - CheckAnyStrict: Check if a Any stores a result of CopyToAnyView of current T. + * - CopyFromAnyViewAfterCheck: Copy a value T from Any view after we pass CheckAnyStrict. + * - MoveFromAnyAfterCheck: Move a value T from Any storage after we pass CheckAnyStrict. + * - TryCastFromAnyView: Convert a AnyView to a T, we may apply type conversion. + * - GetMismatchTypeInfo: Get the type key of a type when TryCastFromAnyView fails. + * - TypeStr: Get the type key of a type + * + * It is possible that CheckAnyStrict is false but TryCastFromAnyView still works. + * + * For example, when Any x stores int, TypeTraits::CheckAnyStrict(x) will be false, + * but TypeTraits::TryCastFromAnyView(x) will return a corresponding float value + * via type conversion. 
+ * + * CheckAnyStrict is mainly used in recursive container such as Array to + * decide if a new Array needed to be created via recursive conversion, + * or we can use the current container as is when converting to Array. + * + * A container array: Array satisfies the following invariant: + * - `all(TypeTraits::CheckAnyStrict(x) for x in the array)`. + */ +template +struct TypeTraits { + /*! \brief Whether the type is enabled in FFI. */ + static constexpr bool convert_enabled = false; + /*! \brief Whether the type can appear as a storage type in Container */ + static constexpr bool storage_enabled = false; +}; + +/*! + * \brief TypeTraits that removes const and reference keywords. + * \tparam T the original type + */ +template +using TypeTraitsNoCR = TypeTraits>>; + +template +inline constexpr bool use_default_type_traits_v = true; + +struct TypeTraitsBase { + static constexpr bool convert_enabled = true; + static constexpr bool storage_enabled = true; + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; + // get mismatched type when result mismatches the trait. + // this function is called after TryCastFromAnyView fails + // to get more detailed type information in runtime + // especially when the error involves nested container type + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *source) { + return TypeIndexToTypeKey(source->type_index); + } +}; + +/*! + * \brief Trait that maps a type to its field static type index + * \tparam T the type + * \return the field static type index + */ +template +struct TypeToFieldStaticTypeIndex { + /*! \brief The field static type index of the type */ + static constexpr int32_t value = TypeIndex::kTVMFFIAny; +}; + +template +struct TypeToFieldStaticTypeIndex::convert_enabled>> { + static constexpr int32_t value = TypeTraits::field_static_type_index; +}; + +/*! 
+ * \brief Trait that maps a type to its runtime type index + * \tparam T the type + * \return the runtime type index + */ +template +struct TypeToRuntimeTypeIndex { + /*! + * \brief Get the runtime type index of the type + * \return the runtime type index + */ + static int32_t v() { return TypeToFieldStaticTypeIndex::value; } +}; + +template +struct TypeToRuntimeTypeIndex>> { + static int32_t v() { return T::ContainerType::RuntimeTypeIndex(); } +}; + +// None +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFINone; + + TVM_FFI_INLINE static void CopyToAnyView(const std::nullptr_t &, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFINone; + result->zero_padding = 0; + // invariant: the pointer field also equals nullptr + // this will simplify same_as comparisons and hash + result->v_int64 = 0; + } + + TVM_FFI_INLINE static void MoveToAny(std::nullptr_t, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFINone; + result->zero_padding = 0; + // invariant: the pointer field also equals nullptr + // this will simplify same_as comparisons and hash + result->v_int64 = 0; + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index == TypeIndex::kTVMFFINone; + } + + TVM_FFI_INLINE static std::nullptr_t CopyFromAnyViewAfterCheck(const TVMFFIAny *) { + return nullptr; + } + + TVM_FFI_INLINE static std::nullptr_t MoveFromAnyAfterCheck(TVMFFIAny *) { return nullptr; } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return nullptr; + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFINone; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFINone) + R"("})"; + } +}; + +/** + * \brief A type that forbids implicit conversion from int to bool 
+ * + * This type is used to prevent implicit conversion from int to bool. + */ +class StrictBool { +public: + /*! + * \brief Constructor + * \param value The value of the strict bool. + */ + StrictBool(bool value) : value_(value) {} // NOLINT(google-explicit-constructor) + /*! + *\brief Convert the strict bool to bool. + * \return The value of the strict bool. + */ + operator bool() const { return value_; } // NOLINT(google-explicit-constructor) + +private: + bool value_; +}; + +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; + + TVM_FFI_INLINE static void CopyToAnyView(const StrictBool &src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIBool; + result->zero_padding = 0; + result->v_int64 = static_cast(src); + } + + TVM_FFI_INLINE static void MoveToAny(StrictBool src, TVMFFIAny *result) { + CopyToAnyView(src, result); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index == TypeIndex::kTVMFFIBool; + } + + TVM_FFI_INLINE static StrictBool CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return static_cast(src->v_int64); + } + + TVM_FFI_INLINE static StrictBool MoveFromAnyAfterCheck(TVMFFIAny *src) { + // POD type, we can just copy the value + return CopyFromAnyViewAfterCheck(src); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIBool) { + return StrictBool(static_cast(src->v_int64)); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIBool; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIBool) + R"("})"; + } +}; + +// Bool type, allow implicit casting from int +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; + + TVM_FFI_INLINE static 
void CopyToAnyView(const bool &src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIBool; + result->zero_padding = 0; + result->v_int64 = static_cast(src); + } + + TVM_FFI_INLINE static void MoveToAny(bool src, TVMFFIAny *result) { CopyToAnyView(src, result); } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index == TypeIndex::kTVMFFIBool; + } + + TVM_FFI_INLINE static bool CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return static_cast(src->v_int64); + } + + TVM_FFI_INLINE static bool MoveFromAnyAfterCheck(TVMFFIAny *src) { + // POD type, we can just copy the value + return CopyFromAnyViewAfterCheck(src); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { + return static_cast(src->v_int64); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIBool; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIBool) + R"("})"; + } +}; + +// Integer POD values +template +struct TypeTraits>> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; + + TVM_FFI_INLINE static void CopyToAnyView(const Int &src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIInt; + result->zero_padding = 0; + result->v_int64 = static_cast(src); + } + + TVM_FFI_INLINE static void MoveToAny(Int src, TVMFFIAny *result) { CopyToAnyView(src, result); } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny + return src->type_index == TypeIndex::kTVMFFIInt; + } + + TVM_FFI_INLINE static Int CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return static_cast(src->v_int64); + } + + TVM_FFI_INLINE static Int MoveFromAnyAfterCheck(TVMFFIAny 
*src) { + // POD type, we can just copy the value + return CopyFromAnyViewAfterCheck(src); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { + return Int(src->v_int64); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIInt) + R"("})"; + } +}; + +/// \cond Doxygen_Suppress + +// trait to check if a type is an integeral enum +// note that we need this trait so we can confirm underlying_type_t is an integral type +// to avoid potential undefined behavior +template > +constexpr bool is_integeral_enum_v = false; + +template +constexpr bool is_integeral_enum_v = std::is_integral_v>; + +/// \endcond + +// Enum Integer POD values +template +struct TypeTraits>> : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; + + TVM_FFI_INLINE static void CopyToAnyView(const IntEnum &src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIInt; + result->zero_padding = 0; + result->v_int64 = static_cast(src); + } + + TVM_FFI_INLINE static void MoveToAny(IntEnum src, TVMFFIAny *result) { + CopyToAnyView(src, result); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny + return src->type_index == TypeIndex::kTVMFFIInt; + } + + TVM_FFI_INLINE static IntEnum CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return static_cast(src->v_int64); + } + + TVM_FFI_INLINE static IntEnum MoveFromAnyAfterCheck(TVMFFIAny *src) { + // POD type, we can just copy the value + return CopyFromAnyViewAfterCheck(src); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == 
TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { + return static_cast(src->v_int64); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIInt) + R"("})"; + } +}; + +// Float POD values +template +struct TypeTraits>> + : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFloat; + + TVM_FFI_INLINE static void CopyToAnyView(const Float &src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIFloat; + result->zero_padding = 0; + result->v_float64 = static_cast(src); + } + + TVM_FFI_INLINE static void MoveToAny(Float src, TVMFFIAny *result) { CopyToAnyView(src, result); } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny + return src->type_index == TypeIndex::kTVMFFIFloat; + } + + TVM_FFI_INLINE static Float CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return static_cast(src->v_float64); + } + + TVM_FFI_INLINE static Float MoveFromAnyAfterCheck(TVMFFIAny *src) { + // POD type, we can just copy the value + return CopyFromAnyViewAfterCheck(src); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIFloat) { + return Float(src->v_float64); + } else if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { + return Float(src->v_int64); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIFloat; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIFloat) + R"("})"; + } +}; + +// void* +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = 
TypeIndex::kTVMFFIOpaquePtr; + + TVM_FFI_INLINE static void CopyToAnyView(void *src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIOpaquePtr; + result->zero_padding = 0; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_ptr = src; + } + + TVM_FFI_INLINE static void MoveToAny(void *src, TVMFFIAny *result) { CopyToAnyView(src, result); } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny + return src->type_index == TypeIndex::kTVMFFIOpaquePtr; + } + + TVM_FFI_INLINE static void *CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { return src->v_ptr; } + + TVM_FFI_INLINE static void *MoveFromAnyAfterCheck(TVMFFIAny *src) { + // POD type, we can just copy the value + return CopyFromAnyViewAfterCheck(src); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIOpaquePtr) { + return static_cast(src->v_ptr); + } + if (src->type_index == TypeIndex::kTVMFFINone) { + return static_cast(nullptr); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIOpaquePtr; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIOpaquePtr) + R"("})"; + } +}; + +// Device +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDevice; + + TVM_FFI_INLINE static void CopyToAnyView(const DLDevice &src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIDevice; + result->zero_padding = 0; + result->v_device = src; + } + + TVM_FFI_INLINE static void MoveToAny(DLDevice src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIDevice; + result->zero_padding = 0; + result->v_device = src; + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index == 
TypeIndex::kTVMFFIDevice; + } + + TVM_FFI_INLINE static DLDevice CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return src->v_device; + } + + TVM_FFI_INLINE static DLDevice MoveFromAnyAfterCheck(TVMFFIAny *src) { + // POD type, we can just copy the value + return CopyFromAnyViewAfterCheck(src); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIDevice) { + return src->v_device; + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIDevice; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIDevice) + R"("})"; + } +}; + +// DLTensor*, requirement: not nullable, do not retain ownership +template <> +struct TypeTraits : public TypeTraitsBase { + static constexpr bool storage_enabled = false; + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDLTensorPtr; + + TVM_FFI_INLINE static void CopyToAnyView(DLTensor *src, TVMFFIAny *result) { + TVM_FFI_ICHECK_NOTNULL(src); + result->type_index = TypeIndex::kTVMFFIDLTensorPtr; + result->zero_padding = 0; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_ptr = src; + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index == TypeIndex::kTVMFFIDLTensorPtr; + } + + TVM_FFI_INLINE static DLTensor *CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return static_cast(src->v_ptr); + } + + TVM_FFI_INLINE static void MoveToAny(DLTensor *, TVMFFIAny *) { + TVM_FFI_THROW(RuntimeError) + << "DLTensor* cannot be held in Any as it does not retain ownership, use Tensor instead"; + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) { + return static_cast(src->v_ptr); + } else if (src->type_index == TypeIndex::kTVMFFITensor) { + // Conversion from Tensor pointer to DLTensor + // based on 
the assumption that Tensor always follows the TVMFFIObject header + static_assert(sizeof(TVMFFIObject) == 24); + return reinterpret_cast(reinterpret_cast(src->v_obj) + sizeof(TVMFFIObject)); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return "DLTensor*"; } + TVM_FFI_INLINE static std::string TypeSchema() { return R"({"type":"DLTensor*"})"; } +}; + +// Traits for ObjectRef, None to ObjectRef will always fail. +// use std::optional instead for nullable references. +template +struct ObjectRefTypeTraitsBase : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIObject; + using ContainerType = typename TObjRef::ContainerType; + + TVM_FFI_INLINE static void CopyToAnyView(const TObjRef &src, TVMFFIAny *result) { + if constexpr (TObjRef::_type_is_nullable) { + if (!src.defined()) { + TypeTraits::CopyToAnyView(nullptr, result); + return; + } + } + TVMFFIObject *obj_ptr = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(src); + result->type_index = obj_ptr->type_index; + result->zero_padding = 0; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_obj = obj_ptr; + } + + TVM_FFI_INLINE static void MoveToAny(TObjRef src, TVMFFIAny *result) { + if constexpr (TObjRef::_type_is_nullable) { + if (!src.defined()) { + TypeTraits::CopyToAnyView(nullptr, result); + return; + } + } + TVMFFIObject *obj_ptr = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(src)); + result->type_index = obj_ptr->type_index; + result->zero_padding = 0; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_obj = obj_ptr; + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + if constexpr (TObjRef::_type_is_nullable) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return true; + } + } + return (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && details::IsObjectInstance(src->type_index)); + } + + TVM_FFI_INLINE static TObjRef CopyFromAnyViewAfterCheck(const 
TVMFFIAny *src) { + if constexpr (TObjRef::_type_is_nullable) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); + } + } + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); + } + + TVM_FFI_INLINE static TObjRef MoveFromAnyAfterCheck(TVMFFIAny *src) { + if constexpr (TObjRef::_type_is_nullable) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); + } + } + // move out the object pointer + ObjectPtr obj_ptr = details::ObjectUnsafe::ObjectPtrFromOwned(src->v_obj); + // reset the src to nullptr + TypeTraits::MoveToAny(nullptr, src); + return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(obj_ptr)); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if constexpr (TObjRef::_type_is_nullable) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); + } + } + if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { + if (details::IsObjectInstance(src->type_index)) { + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); + } + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return ContainerType::_type_key; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(ContainerType::_type_key) + R"("})"; + } +}; + +template +struct TypeTraits && use_default_type_traits_v>> + : public ObjectRefTypeTraitsBase {}; + +/*! + * \brief Helper class that convert to T only via the FallbackTypes + * + * The conversion will go through the FallbackTypes in the order + * specified in the template parameter. + * \tparam T The type of the target value. + * \tparam FallbackTypes The type of the fallback value. 
+ * \note TypeTraits must be derived from this class and define + * ConvertFallbackValue(FallbackType)->T for each FallbackType + */ +template +struct FallbackOnlyTraitsBase : public TypeTraitsBase { + // disable container for FallbackOnlyTraitsBase + /// \cond Doxygen_Suppress + static constexpr bool storage_enabled = false; + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + return TryFallbackTypes(src); + } + + template + TVM_FFI_INLINE static std::optional TryFallbackTypes(const TVMFFIAny *src) { + static_assert(!std::is_same_v, + "Using bool as FallbackType can cause bug because int will be detected as bool, " + "use tvm::ffi::StrictBool instead"); + if (auto opt_fallback = TypeTraits::TryCastFromAnyView(src)) { + return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); + } + if constexpr (sizeof...(Rest) > 0) { + return TryFallbackTypes(src); + } + return std::nullopt; + } + /// \endcond +}; + +/*! + * \brief Helper class to define ObjectRef that can be auto-converted from a + * fallback type, the Traits must be derived from it + * and define a static methods named ConvertFallbackValue for each + * FallbackType + * + * The conversion will go through the FallbackTypes in the order + * specified in the template parameter. + * \tparam ObjectRefType The type of the ObjectRef. + * \tparam FallbackTypes The type of the fallback value. 
+ */ +template +struct ObjectRefWithFallbackTraitsBase : public ObjectRefTypeTraitsBase { + /// \cond Doxygen_Suppress + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if (auto opt_obj = ObjectRefTypeTraitsBase::TryCastFromAnyView(src)) { + return opt_obj; + } + // apply fallback types in TryCastFromAnyView + return TryFallbackTypes(src); + } + + template + TVM_FFI_INLINE static std::optional TryFallbackTypes(const TVMFFIAny *src) { + static_assert(!std::is_same_v, + "Using bool as FallbackType can cause bug because int will be detected as bool, " + "use tvm::ffi::StrictBool instead"); + if (auto opt_fallback = TypeTraits::TryCastFromAnyView(src)) { + return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); + } + if constexpr (sizeof...(Rest) > 0) { + return TryFallbackTypes(src); + } + return std::nullopt; + } + /// \endcond +}; + +// Traits for weak pointer of object +// NOTE: we require the weak pointer cast from + +template +struct TypeTraits>> + : public TypeTraitsBase { + TVM_FFI_INLINE static void CopyToAnyView(TObject *src, TVMFFIAny *result) { + TVMFFIObject *obj_ptr = details::ObjectUnsafe::GetHeader(src); + result->type_index = obj_ptr->type_index; + result->zero_padding = 0; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_obj = obj_ptr; + } + + TVM_FFI_INLINE static void MoveToAny(TObject *src, TVMFFIAny *result) { + TVMFFIObject *obj_ptr = details::ObjectUnsafe::GetHeader(src); + result->type_index = obj_ptr->type_index; + result->zero_padding = 0; + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); + result->v_obj = obj_ptr; + // needs to increase ref because original weak ptr do not own the code + details::ObjectUnsafe::IncRefObjectHandle(result->v_obj); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && details::IsObjectInstance(src->type_index); + } + + TVM_FFI_INLINE static TObject 
*CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + if constexpr (!std::is_const_v) { + static_assert(TObject::_type_mutable, "TObject must be mutable to enable cast from Any"); + } + return details::ObjectUnsafe::RawObjectPtrFromUnowned(src->v_obj); + } + + TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { + if constexpr (!std::is_const_v) { + static_assert(TObject::_type_mutable, "TObject must be mutable to enable cast from Any"); + } + if (CheckAnyStrict(src)) { + return CopyFromAnyViewAfterCheck(src); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return TObject::_type_key; } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":")" + std::string(TObject::_type_key) + R"("})"; + } +}; + +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public TypeTraitsBase { + TVM_FFI_INLINE static void CopyToAnyView(const Optional &src, TVMFFIAny *result) { + if (src.has_value()) { + TypeTraits::CopyToAnyView(*src, result); + } else { + TypeTraits::CopyToAnyView(nullptr, result); + } + } + + TVM_FFI_INLINE static void MoveToAny(Optional src, TVMFFIAny *result) { + if (src.has_value()) { + TypeTraits::MoveToAny(*std::move(src), result); + } else { + TypeTraits::CopyToAnyView(nullptr, result); + } + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return true; + } + return TypeTraits::CheckAnyStrict(src); + } + + TVM_FFI_INLINE static Optional CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return Optional(std::nullopt); + } + return TypeTraits::CopyFromAnyViewAfterCheck(src); + } + + TVM_FFI_INLINE static Optional MoveFromAnyAfterCheck(TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return Optional(std::nullopt); + } + return TypeTraits::MoveFromAnyAfterCheck(src); + } + + TVM_FFI_INLINE static 
std::optional> TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFINone) { + return Optional(std::nullopt); + } + if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { + return Optional(*std::move(opt)); + } else { + // important to be explicit here + // because nullopt can convert to std::optional(nullopt) which indicate success + // return std::optional>(std::nullopt) to indicate failure + return std::optional>(std::nullopt); + } + } + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { + return TypeTraits::GetMismatchTypeInfo(src); + } + + TVM_FFI_INLINE static std::string TypeStr() { + return "Optional<" + TypeTraits::TypeStr() + ">"; + } + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":"Optional","args":[)" + details::TypeSchema::v() + "]}"; + } +}; +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_TYPE_TRAITS_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh index 18d5da7c3..18b52cf99 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh @@ -17,8 +17,8 @@ #include "utils.h" -#include -#include +#include "dlpack/dlpack.h" +#include "tvm/ffi/extra/c_env_api.h" #include #include diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h index d6892d0dd..e2f0b8420 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h @@ -45,9 +45,9 @@ #include "source_location.h" #endif +#include "dlpack/dlpack.h" #include #include -#include #include #include #include diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index e43f3a18d..1a4160807 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -1314,6 
+1314,52 @@ def per_tensor_dequant_int8_(lib): ] +@OpRegister.operator +def gptq_marlin_gemm_(lib): + lib.infiniopCreateGptqMarlinGemmDescriptor.restype = c_int32 + lib.infiniopCreateGptqMarlinGemmDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + lib.infiniopGetGptqMarlinGemmWorkspaceSize.restype = c_int32 + lib.infiniopGetGptqMarlinGemmWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + lib.infiniopGptqMarlinGemm.restype = c_int32 + lib.infiniopGptqMarlinGemm.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_int64, + c_bool, + c_bool, + c_bool, + c_bool, + c_void_p, + ] + lib.infiniopDestroyGptqMarlinGemmDescriptor.restype = c_int32 + lib.infiniopDestroyGptqMarlinGemmDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def gptq_qyblas_gemm_(lib): lib.infiniopCreateGptqQyblasGemmDescriptor.restype = c_int32 From 71e4b4cf16955e6f3dfb40e7b2c2a51027f3a13f Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 20 May 2026 10:41:44 +0800 Subject: [PATCH 05/10] issue/1083: modified format --- .../sgl_kernel/dlpack/dlpack.h | 562 ++-- .../sgl_kernel/tvm/ffi/cast.h | 40 +- .../sgl_kernel/tvm/ffi/container/map.h | 2986 +++++++++-------- .../sgl_kernel/tvm/ffi/container/tuple.h | 569 ++-- .../sgl_kernel/tvm/ffi/container/variant.h | 450 +-- .../sgl_kernel/tvm/ffi/endian.h | 25 +- .../sgl_kernel/tvm/ffi/extra/base.h | 2 +- .../sgl_kernel/tvm/ffi/extra/base64.h | 150 +- .../sgl_kernel/tvm/ffi/extra/cuda/base.h | 26 +- .../tvm/ffi/extra/cuda/cubin_launcher.h | 578 ++-- .../tvm/ffi/extra/cuda/device_guard.h | 48 +- 
.../sgl_kernel/tvm/ffi/extra/json.h | 12 +- .../sgl_kernel/tvm/ffi/extra/module.h | 428 +-- .../sgl_kernel/tvm/ffi/extra/serialization.h | 10 +- .../tvm/ffi/extra/structural_equal.h | 78 +- .../tvm/ffi/extra/structural_hash.h | 40 +- .../tvm/ffi/reflection/access_path.h | 648 ++-- .../sgl_kernel/tvm/ffi/reflection/accessor.h | 278 +- .../sgl_kernel/tvm/ffi/reflection/creator.h | 153 +- .../sgl_kernel/tvm/ffi/reflection/registry.h | 1088 +++--- .../sgl_kernel/tvm/ffi/rvalue_ref.h | 152 +- 21 files changed, 4169 insertions(+), 4154 deletions(-) diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h index 9a710ebde..a6e2f5c58 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h @@ -32,8 +32,8 @@ #define DLPACK_DLL #endif -#include #include +#include #ifdef __cplusplus extern "C" { @@ -59,10 +59,10 @@ extern "C" { * updates indicate the addition of enumeration values. */ typedef struct { - /*! \brief DLPack major version. */ - uint32_t major; - /*! \brief DLPack minor version. */ - uint32_t minor; + /*! \brief DLPack major version. */ + uint32_t major; + /*! \brief DLPack minor version. */ + uint32_t minor; } DLPackVersion; /*! @@ -73,112 +73,112 @@ typedef enum : int32_t { #else typedef enum { #endif - /*! \brief CPU device */ - kDLCPU = 1, - /*! \brief CUDA GPU device */ - kDLCUDA = 2, - /*! - * \brief Pinned CUDA CPU memory by cudaMallocHost - */ - kDLCUDAHost = 3, - /*! \brief OpenCL devices. */ - kDLOpenCL = 4, - /*! \brief Vulkan buffer for next generation graphics. */ - kDLVulkan = 7, - /*! \brief Metal for Apple GPU. */ - kDLMetal = 8, - /*! \brief Verilog simulator buffer */ - kDLVPI = 9, - /*! \brief ROCm GPUs for AMD GPUs */ - kDLROCM = 10, - /*! - * \brief Pinned ROCm CPU memory allocated by hipMallocHost - */ - kDLROCMHost = 11, - /*! 
- * \brief Reserved extension device type, - * used for quickly test extension device - * The semantics can differ depending on the implementation. - */ - kDLExtDev = 12, - /*! - * \brief CUDA managed/unified memory allocated by cudaMallocManaged - */ - kDLCUDAManaged = 13, - /*! - * \brief Unified shared memory allocated on a oneAPI non-partititioned - * device. Call to oneAPI runtime is required to determine the device - * type, the USM allocation type and the sycl context it is bound to. - * - */ - kDLOneAPI = 14, - /*! \brief GPU support for next generation WebGPU standard. */ - kDLWebGPU = 15, - /*! \brief Qualcomm Hexagon DSP */ - kDLHexagon = 16, - /*! \brief Microsoft MAIA devices */ - kDLMAIA = 17, - /*! \brief AWS Trainium */ - kDLTrn = 18, + /*! \brief CPU device */ + kDLCPU = 1, + /*! \brief CUDA GPU device */ + kDLCUDA = 2, + /*! + * \brief Pinned CUDA CPU memory by cudaMallocHost + */ + kDLCUDAHost = 3, + /*! \brief OpenCL devices. */ + kDLOpenCL = 4, + /*! \brief Vulkan buffer for next generation graphics. */ + kDLVulkan = 7, + /*! \brief Metal for Apple GPU. */ + kDLMetal = 8, + /*! \brief Verilog simulator buffer */ + kDLVPI = 9, + /*! \brief ROCm GPUs for AMD GPUs */ + kDLROCM = 10, + /*! + * \brief Pinned ROCm CPU memory allocated by hipMallocHost + */ + kDLROCMHost = 11, + /*! + * \brief Reserved extension device type, + * used for quickly test extension device + * The semantics can differ depending on the implementation. + */ + kDLExtDev = 12, + /*! + * \brief CUDA managed/unified memory allocated by cudaMallocManaged + */ + kDLCUDAManaged = 13, + /*! + * \brief Unified shared memory allocated on a oneAPI non-partititioned + * device. Call to oneAPI runtime is required to determine the device + * type, the USM allocation type and the sycl context it is bound to. + * + */ + kDLOneAPI = 14, + /*! \brief GPU support for next generation WebGPU standard. */ + kDLWebGPU = 15, + /*! \brief Qualcomm Hexagon DSP */ + kDLHexagon = 16, + /*! 
\brief Microsoft MAIA devices */ + kDLMAIA = 17, + /*! \brief AWS Trainium */ + kDLTrn = 18, } DLDeviceType; /*! * \brief A Device for Tensor and operator. */ typedef struct { - /*! \brief The device type used in the device. */ - DLDeviceType device_type; - /*! - * \brief The device index. - * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0. - */ - int32_t device_id; + /*! \brief The device type used in the device. */ + DLDeviceType device_type; + /*! + * \brief The device index. + * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0. + */ + int32_t device_id; } DLDevice; /*! * \brief The type code options DLDataType. */ typedef enum { - /*! \brief signed integer */ - kDLInt = 0U, - /*! \brief unsigned integer */ - kDLUInt = 1U, - /*! \brief IEEE floating point */ - kDLFloat = 2U, - /*! - * \brief Opaque handle type, reserved for testing purposes. - * Frameworks need to agree on the handle data type for the exchange to be well-defined. - */ - kDLOpaqueHandle = 3U, - /*! \brief bfloat16 */ - kDLBfloat = 4U, - /*! - * \brief complex number - * (C/C++/Python layout: compact struct per complex number) - */ - kDLComplex = 5U, - /*! \brief boolean */ - kDLBool = 6U, - /*! \brief FP8 data types */ - kDLFloat8_e3m4 = 7U, - kDLFloat8_e4m3 = 8U, - kDLFloat8_e4m3b11fnuz = 9U, - kDLFloat8_e4m3fn = 10U, - kDLFloat8_e4m3fnuz = 11U, - kDLFloat8_e5m2 = 12U, - kDLFloat8_e5m2fnuz = 13U, - kDLFloat8_e8m0fnu = 14U, - /*! \brief FP6 data types - * Setting bits != 6 is currently unspecified, and the producer must ensure it is set - * while the consumer must stop importing if the value is unexpected. - */ - kDLFloat6_e2m3fn = 15U, - kDLFloat6_e3m2fn = 16U, - /*! \brief FP4 data types - * Setting bits != 4 is currently unspecified, and the producer must ensure it is set - * while the consumer must stop importing if the value is unexpected. - */ - kDLFloat4_e2m1fn = 17U, + /*! \brief signed integer */ + kDLInt = 0U, + /*! 
\brief unsigned integer */ + kDLUInt = 1U, + /*! \brief IEEE floating point */ + kDLFloat = 2U, + /*! + * \brief Opaque handle type, reserved for testing purposes. + * Frameworks need to agree on the handle data type for the exchange to be well-defined. + */ + kDLOpaqueHandle = 3U, + /*! \brief bfloat16 */ + kDLBfloat = 4U, + /*! + * \brief complex number + * (C/C++/Python layout: compact struct per complex number) + */ + kDLComplex = 5U, + /*! \brief boolean */ + kDLBool = 6U, + /*! \brief FP8 data types */ + kDLFloat8_e3m4 = 7U, + kDLFloat8_e4m3 = 8U, + kDLFloat8_e4m3b11fnuz = 9U, + kDLFloat8_e4m3fn = 10U, + kDLFloat8_e4m3fnuz = 11U, + kDLFloat8_e5m2 = 12U, + kDLFloat8_e5m2fnuz = 13U, + kDLFloat8_e8m0fnu = 14U, + /*! \brief FP6 data types + * Setting bits != 6 is currently unspecified, and the producer must ensure it is set + * while the consumer must stop importing if the value is unexpected. + */ + kDLFloat6_e2m3fn = 15U, + kDLFloat6_e3m2fn = 16U, + /*! \brief FP4 data types + * Setting bits != 4 is currently unspecified, and the producer must ensure it is set + * while the consumer must stop importing if the value is unexpected. + */ + kDLFloat4_e2m1fn = 17U, } DLDataTypeCode; /*! @@ -200,81 +200,81 @@ typedef enum { * for a packed data set D ((D >> (i * bits)) && bit_mask) stores the i-th element. */ typedef struct { - /*! - * \brief Type code of base types. - * We keep it uint8_t instead of DLDataTypeCode for minimal memory - * footprint, but the value should be one of DLDataTypeCode enum values. - * */ - uint8_t code; - /*! - * \brief Number of bits, common choices are 8, 16, 32. - */ - uint8_t bits; - /*! \brief Number of lanes in the type, used for vector types. */ - uint16_t lanes; + /*! + * \brief Type code of base types. + * We keep it uint8_t instead of DLDataTypeCode for minimal memory + * footprint, but the value should be one of DLDataTypeCode enum values. + * */ + uint8_t code; + /*! + * \brief Number of bits, common choices are 8, 16, 32. 
+ */ + uint8_t bits; + /*! \brief Number of lanes in the type, used for vector types. */ + uint16_t lanes; } DLDataType; /*! * \brief Plain C Tensor object, does not manage memory. */ typedef struct { - /*! - * \brief The data pointer points to the allocated data. This will be CUDA - * device pointer or cl_mem handle in OpenCL. It may be opaque on some device - * types. This pointer is always aligned to 256 bytes as in CUDA. The - * `byte_offset` field should be used to point to the beginning of the data. - * - * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow, - * TVM, perhaps others) do not adhere to this 256 byte alignment requirement - * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed - * (after which this note will be updated); at the moment it is recommended - * to not rely on the data pointer being correctly aligned. - * - * For given DLTensor, the size of memory required to store the contents of - * data is calculated as follows: - * - * \code{.c} - * static inline size_t GetDataSize(const DLTensor* t) { - * size_t size = 1; - * for (tvm_index_t i = 0; i < t->ndim; ++i) { - * size *= t->shape[i]; - * } - * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; - * return size; - * } - * \endcode - * - * Note that if the tensor is of size zero, then the data pointer should be - * set to `NULL`. - */ - void* data; - /*! \brief The device of the tensor */ - DLDevice device; - /*! \brief Number of dimensions */ - int32_t ndim; - /*! \brief The data type of the pointer*/ - DLDataType dtype; - /*! - * \brief The shape of the tensor - * - * When ndim == 0, shape can be set to NULL. - */ - int64_t* shape; - /*! - * \brief strides of the tensor (in number of elements, not bytes), - * can not be NULL if ndim != 0, must points to - * an array of ndim elements that specifies the strides, - * so consumer can always rely on strides[dim] being valid for 0 <= dim < ndim. - * - * When ndim == 0, strides can be set to NULL. 
- * - * \note Before DLPack v1.2, strides can be NULL to indicate contiguous data. - * This is not allowed in DLPack v1.2 and later. The rationale - * is to simplify the consumer handling. - */ - int64_t* strides; - /*! \brief The offset in bytes to the beginning pointer to data */ - uint64_t byte_offset; + /*! + * \brief The data pointer points to the allocated data. This will be CUDA + * device pointer or cl_mem handle in OpenCL. It may be opaque on some device + * types. This pointer is always aligned to 256 bytes as in CUDA. The + * `byte_offset` field should be used to point to the beginning of the data. + * + * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow, + * TVM, perhaps others) do not adhere to this 256 byte alignment requirement + * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed + * (after which this note will be updated); at the moment it is recommended + * to not rely on the data pointer being correctly aligned. + * + * For given DLTensor, the size of memory required to store the contents of + * data is calculated as follows: + * + * \code{.c} + * static inline size_t GetDataSize(const DLTensor* t) { + * size_t size = 1; + * for (tvm_index_t i = 0; i < t->ndim; ++i) { + * size *= t->shape[i]; + * } + * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; + * return size; + * } + * \endcode + * + * Note that if the tensor is of size zero, then the data pointer should be + * set to `NULL`. + */ + void *data; + /*! \brief The device of the tensor */ + DLDevice device; + /*! \brief Number of dimensions */ + int32_t ndim; + /*! \brief The data type of the pointer*/ + DLDataType dtype; + /*! + * \brief The shape of the tensor + * + * When ndim == 0, shape can be set to NULL. + */ + int64_t *shape; + /*! 
+ * \brief strides of the tensor (in number of elements, not bytes), + * can not be NULL if ndim != 0, must points to + * an array of ndim elements that specifies the strides, + * so consumer can always rely on strides[dim] being valid for 0 <= dim < ndim. + * + * When ndim == 0, strides can be set to NULL. + * + * \note Before DLPack v1.2, strides can be NULL to indicate contiguous data. + * This is not allowed in DLPack v1.2 and later. The rationale + * is to simplify the consumer handling. + */ + int64_t *strides; + /*! \brief The offset in bytes to the beginning pointer to data */ + uint64_t byte_offset; } DLTensor; /*! @@ -292,19 +292,19 @@ typedef struct { * \sa DLManagedTensorVersioned */ typedef struct DLManagedTensor { - /*! \brief DLTensor which is being memory managed */ - DLTensor dl_tensor; - /*! \brief the context of the original host framework of DLManagedTensor in - * which DLManagedTensor is used in the framework. It can also be NULL. - */ - void * manager_ctx; - /*! - * \brief Destructor - this should be called - * to destruct the manager_ctx which backs the DLManagedTensor. It can be - * NULL if there is no way for the caller to provide a reasonable destructor. - * The destructor deletes the argument self as well. - */ - void (*deleter)(struct DLManagedTensor * self); + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; + /*! \brief the context of the original host framework of DLManagedTensor in + * which DLManagedTensor is used in the framework. It can also be NULL. + */ + void *manager_ctx; + /*! + * \brief Destructor - this should be called + * to destruct the manager_ctx which backs the DLManagedTensor. It can be + * NULL if there is no way for the caller to provide a reasonable destructor. + * The destructor deletes the argument self as well. 
+ */ + void (*deleter)(struct DLManagedTensor *self); } DLManagedTensor; // bit masks used in the DLManagedTensorVersioned @@ -339,39 +339,39 @@ typedef struct DLManagedTensor { * \note This is the current standard DLPack exchange data structure. */ typedef struct DLManagedTensorVersioned { - /*! - * \brief The API and ABI version of the current managed Tensor - */ - DLPackVersion version; - /*! - * \brief the context of the original host framework. - * - * Stores DLManagedTensorVersioned is used in the - * framework. It can also be NULL. - */ - void *manager_ctx; - /*! - * \brief Destructor. - * - * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned. - * It can be NULL if there is no way for the caller to provide a reasonable - * destructor. The destructor deletes the argument self as well. - */ - void (*deleter)(struct DLManagedTensorVersioned *self); - /*! - * \brief Additional bitmask flags information about the tensor. - * - * By default the flags should be set to 0. - * - * \note Future ABI changes should keep everything until this field - * stable, to ensure that deleter can be correctly called. - * - * \sa DLPACK_FLAG_BITMASK_READ_ONLY - * \sa DLPACK_FLAG_BITMASK_IS_COPIED - */ - uint64_t flags; - /*! \brief DLTensor which is being memory managed */ - DLTensor dl_tensor; + /*! + * \brief The API and ABI version of the current managed Tensor + */ + DLPackVersion version; + /*! + * \brief the context of the original host framework. + * + * Stores DLManagedTensorVersioned is used in the + * framework. It can also be NULL. + */ + void *manager_ctx; + /*! + * \brief Destructor. + * + * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned. + * It can be NULL if there is no way for the caller to provide a reasonable + * destructor. The destructor deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensorVersioned *self); + /*! 
+ * \brief Additional bitmask flags information about the tensor. + * + * By default the flags should be set to 0. + * + * \note Future ABI changes should keep everything until this field + * stable, to ensure that deleter can be correctly called. + * + * \sa DLPACK_FLAG_BITMASK_READ_ONLY + * \sa DLPACK_FLAG_BITMASK_IS_COPIED + */ + uint64_t flags; + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; } DLManagedTensorVersioned; //---------------------------------------------------------------------- @@ -400,9 +400,9 @@ typedef struct DLManagedTensorVersioned { * * \sa DLPackExchangeAPI */ -typedef int (*DLPackManagedTensorAllocator)( // - DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, // - void (*SetError)(void* error_ctx, const char* kind, const char* message) // +typedef int (*DLPackManagedTensorAllocator)( // + DLTensor *prototype, DLManagedTensorVersioned **out, void *error_ctx, // + void (*SetError)(void *error_ctx, const char *kind, const char *message) // ); /*! @@ -422,9 +422,9 @@ typedef int (*DLPackManagedTensorAllocator)( * * \sa DLPackExchangeAPI, DLPackCurrentWorkStream */ -typedef int (*DLPackManagedTensorFromPyObjectNoSync)( // - void* py_object, // - DLManagedTensorVersioned** out // +typedef int (*DLPackManagedTensorFromPyObjectNoSync)( // + void *py_object, // + DLManagedTensorVersioned **out // ); /*! @@ -451,9 +451,9 @@ typedef int (*DLPackManagedTensorFromPyObjectNoSync)( // * * \sa DLPackExchangeAPI, DLPackCurrentWorkStream */ -typedef int (*DLPackDLTensorFromPyObjectNoSync)( // - void* py_object, // - DLTensor* out // +typedef int (*DLPackDLTensorFromPyObjectNoSync)( // + void *py_object, // + DLTensor *out // ); /*! 
@@ -477,10 +477,10 @@ typedef int (*DLPackDLTensorFromPyObjectNoSync)( // * * \sa DLPackExchangeAPI */ -typedef int (*DLPackCurrentWorkStream)( // - DLDeviceType device_type, // - int32_t device_id, // - void** out_current_stream // +typedef int (*DLPackCurrentWorkStream)( // + DLDeviceType device_type, // + int32_t device_id, // + void **out_current_stream // ); /*! @@ -500,9 +500,9 @@ typedef int (*DLPackCurrentWorkStream)( // * * \sa DLPackExchangeAPI */ -typedef int (*DLPackManagedTensorToPyObjectNoSync)( // - DLManagedTensorVersioned* tensor, // - void** out_py_object // +typedef int (*DLPackManagedTensorToPyObjectNoSync)( // + DLManagedTensorVersioned *tensor, // + void **out_py_object // ); /*! @@ -510,21 +510,21 @@ typedef int (*DLPackManagedTensorToPyObjectNoSync)( // * \sa DLPackExchangeAPI */ typedef struct DLPackExchangeAPIHeader { - /*! - * \brief The provided DLPack version the consumer must check major version - * compatibility before using this struct. - */ - DLPackVersion version; - /*! - * \brief Optional pointer to an older DLPackExchangeAPI in the chain. - * - * It must be NULL if the framework does not support older versions. - * If the current major version is larger than the one supported by the - * consumer, the consumer may walk this to find an earlier supported version. - * - * \sa DLPackExchangeAPI - */ - struct DLPackExchangeAPIHeader* prev_api; + /*! + * \brief The provided DLPack version the consumer must check major version + * compatibility before using this struct. + */ + DLPackVersion version; + /*! + * \brief Optional pointer to an older DLPackExchangeAPI in the chain. + * + * It must be NULL if the framework does not support older versions. + * If the current major version is larger than the one supported by the + * consumer, the consumer may walk this to find an earlier supported version. + * + * \sa DLPackExchangeAPI + */ + struct DLPackExchangeAPIHeader *prev_api; } DLPackExchangeAPIHeader; /*! 
@@ -597,43 +597,43 @@ typedef struct DLPackExchangeAPIHeader { * to do so in other languages. */ typedef struct DLPackExchangeAPI { - /*! - * \brief The header that remains stable across versions. - */ - DLPackExchangeAPIHeader header; - /*! - * \brief Producer function pointer for DLPackManagedTensorAllocator - * This function must not be NULL. - * \sa DLPackManagedTensorAllocator - */ - DLPackManagedTensorAllocator managed_tensor_allocator; - /*! - * \brief Producer function pointer for DLPackManagedTensorFromPyObject - * This function must be not NULL. - * \sa DLPackManagedTensorFromPyObject - */ - DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync; - /*! - * \brief Producer function pointer for DLPackManagedTensorToPyObject - * This function must be not NULL. - * \sa DLPackManagedTensorToPyObject - */ - DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync; - /*! - * \brief Producer function pointer for DLPackDLTensorFromPyObject - * This function can be NULL when the producer does not support this function. - * \sa DLPackDLTensorFromPyObjectNoSync - */ - DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync; - /*! - * \brief Producer function pointer for DLPackCurrentWorkStream - * This function must be not NULL. - * \sa DLPackCurrentWorkStream - */ - DLPackCurrentWorkStream current_work_stream; + /*! + * \brief The header that remains stable across versions. + */ + DLPackExchangeAPIHeader header; + /*! + * \brief Producer function pointer for DLPackManagedTensorAllocator + * This function must not be NULL. + * \sa DLPackManagedTensorAllocator + */ + DLPackManagedTensorAllocator managed_tensor_allocator; + /*! + * \brief Producer function pointer for DLPackManagedTensorFromPyObject + * This function must be not NULL. + * \sa DLPackManagedTensorFromPyObject + */ + DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync; + /*! 
+ * \brief Producer function pointer for DLPackManagedTensorToPyObject + * This function must be not NULL. + * \sa DLPackManagedTensorToPyObject + */ + DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackDLTensorFromPyObject + * This function can be NULL when the producer does not support this function. + * \sa DLPackDLTensorFromPyObjectNoSync + */ + DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync; + /*! + * \brief Producer function pointer for DLPackCurrentWorkStream + * This function must be not NULL. + * \sa DLPackCurrentWorkStream + */ + DLPackCurrentWorkStream current_work_stream; } DLPackExchangeAPI; #ifdef __cplusplus -} // DLPACK_EXTERN_C +} // DLPACK_EXTERN_C #endif -#endif // DLPACK_DLPACK_H_ +#endif // DLPACK_DLPACK_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h index 66abd6644..6d6513af5 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h @@ -43,21 +43,21 @@ namespace ffi { * \return The corresponding RefType */ template -inline RefType GetRef(const ObjectType* ptr) { - using ContainerType = typename RefType::ContainerType; - static_assert(std::is_base_of_v, - "Can only cast to the ref of same container type"); +inline RefType GetRef(const ObjectType *ptr) { + using ContainerType = typename RefType::ContainerType; + static_assert(std::is_base_of_v, + "Can only cast to the ref of same container type"); - if constexpr (is_optional_type_v || RefType::_type_is_nullable) { - if (ptr == nullptr) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); + if constexpr (is_optional_type_v || RefType::_type_is_nullable) { + if (ptr == nullptr) { + return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); + } + } else { + TVM_FFI_ICHECK_NOTNULL(ptr); } - } else { - 
TVM_FFI_ICHECK_NOTNULL(ptr); - } - return details::ObjectUnsafe::ObjectRefFromObjectPtr( - details::ObjectUnsafe::ObjectPtrFromUnowned( - const_cast(static_cast(ptr)))); + return details::ObjectUnsafe::ObjectRefFromObjectPtr( + details::ObjectUnsafe::ObjectPtrFromUnowned( + const_cast(static_cast(ptr)))); } /*! @@ -69,11 +69,11 @@ inline RefType GetRef(const ObjectType* ptr) { * \return The corresponding RefType */ template -inline ObjectPtr GetObjectPtr(ObjectType* ptr) { - static_assert(std::is_base_of_v, - "Can only cast to the ref of same container type"); - return details::ObjectUnsafe::ObjectPtrFromUnowned(ptr); +inline ObjectPtr GetObjectPtr(ObjectType *ptr) { + static_assert(std::is_base_of_v, + "Can only cast to the ref of same container type"); + return details::ObjectUnsafe::ObjectPtrFromUnowned(ptr); } -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CAST_H_ +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CAST_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h index b948fc8e4..4e5d1a635 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h @@ -43,416 +43,416 @@ namespace ffi { /// \cond Doxygen_Suppress #if TVM_FFI_DEBUG_WITH_ABI_CHANGE #define TVM_FFI_MAP_FAIL_IF_CHANGED() \ - TVM_FFI_ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map"; + TVM_FFI_ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map"; #else #define TVM_FFI_MAP_FAIL_IF_CHANGED() -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE /// \endcond /*! \brief Shared content of all specializations of hash map */ class MapObj : public Object { - public: - /*! \brief Type of the keys in the hash map */ - using key_type = Any; - /*! 
\brief Type of the values in the hash map */ - using mapped_type = Any; - /*! \brief Type of value stored in the hash map */ - using KVType = std::pair; - /// \cond Doxygen_Suppress - /*! \brief Type of raw storage of the key-value pair in the hash map */ - struct KVRawStorageType { - TVMFFIAny first; - TVMFFIAny second; - }; - /// \endcond - /*! \brief Iterator class */ - class iterator; - - static_assert(std::is_standard_layout_v, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 32, "sizeof(KVType) incorrect"); - - /// \cond Doxygen_Suppress - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIMap; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIMap, MapObj, Object); - /// \endcond - - /*! - * \brief Number of elements in the MapObj - * \return The result - */ - size_t size() const { return size_; } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key); - /*! \return begin iterator */ - iterator begin() const; - /*! \return end iterator */ - iterator end() const; - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const; - /*! 
- * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position); - /*! - * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type& key) { erase(find(key)); } - - /// \cond Doxygen_Suppress - class iterator { - public: - using iterator_category = std::forward_iterator_tag; - using difference_type = int64_t; - using value_type = KVType; - using pointer = KVType*; - using reference = KVType&; +public: + /*! \brief Type of the keys in the hash map */ + using key_type = Any; + /*! \brief Type of the values in the hash map */ + using mapped_type = Any; + /*! \brief Type of value stored in the hash map */ + using KVType = std::pair; + /// \cond Doxygen_Suppress + /*! \brief Type of raw storage of the key-value pair in the hash map */ + struct KVRawStorageType { + TVMFFIAny first; + TVMFFIAny second; + }; + /// \endcond + /*! \brief Iterator class */ + class iterator; + + static_assert(std::is_standard_layout_v, "KVType is not standard layout"); + static_assert(sizeof(KVType) == 32, "sizeof(KVType) incorrect"); + + /// \cond Doxygen_Suppress + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIMap; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIMap, MapObj, Object); + /// \endcond + + /*! + * \brief Number of elements in the MapObj + * \return The result + */ + size_t size() const { return size_; } + /*! + * \brief Count the number of times a key exists in the hash map + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type &key) const; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type &at(const key_type &key) const; + /*! 
+ * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type &at(const key_type &key); + /*! \return begin iterator */ + iterator begin() const; + /*! \return end iterator */ + iterator end() const; + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type &key) const; + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator &position); + /*! + * \brief Erase the entry associated with the key, do nothing if not exists + * \param key The indexing key + */ + void erase(const key_type &key) { erase(find(key)); } + + /// \cond Doxygen_Suppress + class iterator { + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = int64_t; + using value_type = KVType; + using pointer = KVType *; + using reference = KVType &; /*! \brief Default constructor */ #if TVM_FFI_DEBUG_WITH_ABI_CHANGE - iterator() : state_marker(0), index(0), self(nullptr) {} + iterator() : state_marker(0), index(0), self(nullptr) {} #else - iterator() : index(0), self(nullptr) {} -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - return index == other.index && self == other.self; - } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return !(*this == other); } - /*! \brief De-reference iterators */ - pointer operator->() const; - /*! \brief De-reference iterators */ - reference operator*() const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - return *((*this).operator->()); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++(); - /*! \brief Prefix self decrement, e.g. 
--iter */ - iterator& operator--(); - /*! \brief Suffix self increment */ - iterator operator++(int) { - TVM_FFI_MAP_FAIL_IF_CHANGED() - iterator copy = *this; - ++(*this); - return copy; - } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - TVM_FFI_MAP_FAIL_IF_CHANGED() - iterator copy = *this; - --(*this); - return copy; - } + iterator() : index(0), self(nullptr) {} +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + /*! \brief Compare iterators */ + bool operator==(const iterator &other) const { + TVM_FFI_MAP_FAIL_IF_CHANGED() + return index == other.index && self == other.self; + } + /*! \brief Compare iterators */ + bool operator!=(const iterator &other) const { return !(*this == other); } + /*! \brief De-reference iterators */ + pointer operator->() const; + /*! \brief De-reference iterators */ + reference operator*() const { + TVM_FFI_MAP_FAIL_IF_CHANGED() + return *((*this).operator->()); + } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator &operator++(); + /*! \brief Prefix self decrement, e.g. --iter */ + iterator &operator--(); + /*! \brief Suffix self increment */ + iterator operator++(int) { + TVM_FFI_MAP_FAIL_IF_CHANGED() + iterator copy = *this; + ++(*this); + return copy; + } + /*! \brief Suffix self decrement */ + iterator operator--(int) { + TVM_FFI_MAP_FAIL_IF_CHANGED() + iterator copy = *this; + --(*this); + return copy; + } - protected: + protected: #if TVM_FFI_DEBUG_WITH_ABI_CHANGE - uint64_t state_marker; - /*! \brief Construct by value */ - iterator(uint64_t index, const MapObj* self) - : state_marker(self->state_marker), index(index), self(self) {} + uint64_t state_marker; + /*! \brief Construct by value */ + iterator(uint64_t index, const MapObj *self) + : state_marker(self->state_marker), index(index), self(self) {} #else - iterator(uint64_t index, const MapObj* self) : index(index), self(self) {} -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! \brief The position on the array */ - uint64_t index; - /*! 
\brief The container it points to */ - const MapObj* self; - - friend class DenseMapObj; - friend class SmallMapObj; - }; - /// \endcond - /*! - * \brief Create an empty container - * \return The object created - */ - static inline ObjectPtr Empty(); - - protected: + iterator(uint64_t index, const MapObj *self) : index(index), self(self) {} +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + /*! \brief The position on the array */ + uint64_t index; + /*! \brief The container it points to */ + const MapObj *self; + + friend class DenseMapObj; + friend class SmallMapObj; + }; + /// \endcond + /*! + * \brief Create an empty container + * \return The object created + */ + static inline ObjectPtr Empty(); + +protected: #if TVM_FFI_DEBUG_WITH_ABI_CHANGE - uint64_t state_marker; -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static inline ObjectPtr CreateFromRange(IterType first, IterType last); - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static inline void InsertMaybeReHash(KVType&& kv, ObjectPtr* map); - /*! - * \brief Create an empty container with elements copying from another SmallMapObj - * \param from The source container - * \return The object created - */ - static inline ObjectPtr CopyFrom(MapObj* from); - /*! - * \brief data pointer to the data region of the map. - * \note For immutable inplace small map we do not need data_, - * but we keep it here for future compact with mutable container. - */ - void* data_; - /*! \brief number of entries in the container */ - uint64_t size_; - /*! \brief number of slots */ - uint64_t slots_; - /*! 
- * \brief Small layout tag mask - * \note The most significant bit is used to indicate the small map layout. - */ - static constexpr uint64_t kSmallTagMask = static_cast(1) << 63; - /*! - * \brief Check if the map is a small map - * \return True if the map is a small map - */ - bool IsSmallMap() const { return (slots_ & kSmallTagMask) != 0ull; } - /*! - * \brief Optional data deleter when data is allocated separately - * and its deletion is not managed by MapObj::deleter_. - */ - void (*data_deleter_)(void*) = nullptr; - // Reference class - template - friend class Map; + uint64_t state_marker; +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + /*! + * \brief Create the map using contents from the given iterators. + * \param first Begin of iterator + * \param last End of iterator + * \tparam IterType The type of iterator + * \return ObjectPtr to the map created + */ + template + static inline ObjectPtr CreateFromRange(IterType first, IterType last); + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static inline void InsertMaybeReHash(KVType &&kv, ObjectPtr *map); + /*! + * \brief Create an empty container with elements copying from another SmallMapObj + * \param from The source container + * \return The object created + */ + static inline ObjectPtr CopyFrom(MapObj *from); + /*! + * \brief data pointer to the data region of the map. + * \note For immutable inplace small map we do not need data_, + * but we keep it here for future compact with mutable container. + */ + void *data_; + /*! \brief number of entries in the container */ + uint64_t size_; + /*! \brief number of slots */ + uint64_t slots_; + /*! + * \brief Small layout tag mask + * \note The most significant bit is used to indicate the small map layout. + */ + static constexpr uint64_t kSmallTagMask = static_cast(1) << 63; + /*! 
+ * \brief Check if the map is a small map + * \return True if the map is a small map + */ + bool IsSmallMap() const { return (slots_ & kSmallTagMask) != 0ull; } + /*! + * \brief Optional data deleter when data is allocated separately + * and its deletion is not managed by MapObj::deleter_. + */ + void (*data_deleter_)(void *) = nullptr; + // Reference class + template + friend class Map; }; /*! \brief A specialization of small-sized hash map */ class SmallMapObj : public MapObj, public details::InplaceArrayBase { - private: - static constexpr uint64_t kInitSize = 2; - static constexpr uint64_t kMaxSize = 4; - - public: - using MapObj::iterator; - using MapObj::KVType; - - // Return the number of usable slots for Small layout (mask off tag). - /*! - * \brief Return the number of usable slots for Small layout (mask off tag). - * \return The number of usable slots - */ - uint64_t NumSlots() const { return slots_ & ~kSmallTagMask; } - - ~SmallMapObj() { - KVType* begin = static_cast(data_); - for (uint64_t index = 0; index < size_; ++index) { - // call destructor to destroy the item in `begin + index` - // Explicit call Any::~Any() to destroy the Any object - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (begin + index)->first.Any::~Any(); - (begin + index)->second.Any::~Any(); - } - if (data_deleter_ != nullptr) { - data_deleter_(data_); +private: + static constexpr uint64_t kInitSize = 2; + static constexpr uint64_t kMaxSize = 4; + +public: + using MapObj::iterator; + using MapObj::KVType; + + // Return the number of usable slots for Small layout (mask off tag). + /*! + * \brief Return the number of usable slots for Small layout (mask off tag). 
+ * \return The number of usable slots + */ + uint64_t NumSlots() const { return slots_ & ~kSmallTagMask; } + + ~SmallMapObj() { + KVType *begin = static_cast(data_); + for (uint64_t index = 0; index < size_; ++index) { + // call destructor to destroy the item in `begin + index` + // Explicit call Any::~Any() to destroy the Any object + // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) + (begin + index)->first.Any::~Any(); + (begin + index)->second.Any::~Any(); + } + if (data_deleter_ != nullptr) { + data_deleter_(data_); + } } - } - /*! - * \brief Count the number of times a key exists in the SmallMapObj - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type& key) const { return find(key).index < size_; } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { - iterator itr = find(key); - if (itr.index >= size_) { - TVM_FFI_THROW(KeyError) << "key is not in Map"; + /*! + * \brief Count the number of times a key exists in the SmallMapObj + * \param key The indexing key + * \return The result, 0 or 1 + */ + size_t count(const key_type &key) const { return find(key).index < size_; } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type &at(const key_type &key) const { + iterator itr = find(key); + if (itr.index >= size_) { + TVM_FFI_THROW(KeyError) << "key is not in Map"; + } + return itr->second; } - return itr->second; - } - /*! 
- * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { - iterator itr = find(key); - if (itr.index >= size_) { - TVM_FFI_THROW(KeyError) << "key is not in Map"; + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type &at(const key_type &key) { + iterator itr = find(key); + if (itr.index >= size_) { + TVM_FFI_THROW(KeyError) << "key is not in Map"; + } + return itr->second; } - return itr->second; - } - /*! \return begin iterator */ - iterator begin() const { return iterator(0, this); } - /*! \return end iterator */ - iterator end() const { return iterator(size_, this); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - KVType* ptr = static_cast(data_); - for (uint64_t i = 0; i < size_; ++i, ++ptr) { - if (AnyEqual()(ptr->first, key)) { - return iterator(i, this); - } + /*! \return begin iterator */ + iterator begin() const { return iterator(0, this); } + /*! \return end iterator */ + iterator end() const { return iterator(size_, this); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type &key) const { + KVType *ptr = static_cast(data_); + for (uint64_t i = 0; i < size_; ++i, ++ptr) { + if (AnyEqual()(ptr->first, key)) { + return iterator(i, this); + } + } + return iterator(size_, this); } - return iterator(size_, this); - } - /*! 
- * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { Erase(position.index); } - - private: - /*! - * \brief Set the number of slots and attach tags bit. - * \param n The number of slots - */ - void SetSlotsAndSmallLayoutTag(uint64_t n) { slots_ = (n & ~kSmallTagMask) | kSmallTagMask; } - /*! - * \brief Remove a position in SmallMapObj - * \param index The position to be removed - */ - void Erase(const uint64_t index) { - if (index >= size_) { - return; + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator &position) { Erase(position.index); } + +private: + /*! + * \brief Set the number of slots and attach tags bit. + * \param n The number of slots + */ + void SetSlotsAndSmallLayoutTag(uint64_t n) { slots_ = (n & ~kSmallTagMask) | kSmallTagMask; } + /*! + * \brief Remove a position in SmallMapObj + * \param index The position to be removed + */ + void Erase(const uint64_t index) { + if (index >= size_) { + return; + } + KVType *begin = static_cast(data_); + // call destructor to destroy the item in `begin + index` + // Explicit call Any::~Any() to destroy the Any object + // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) + (begin + index)->first.Any::~Any(); + (begin + index)->second.Any::~Any(); + // IMPORTANT: We do direct raw memmove to bring later items to the current position + // to preserve the order of insertion. + // This works because direct memory copy preserves the Any's move semantics. 
+ if (index + 1 < size_) { + std::memmove(reinterpret_cast(begin + index), + reinterpret_cast(begin + index + 1), + (size_ - index - 1) * sizeof(KVType)); + } + size_ -= 1; } - KVType* begin = static_cast(data_); - // call destructor to destroy the item in `begin + index` - // Explicit call Any::~Any() to destroy the Any object - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (begin + index)->first.Any::~Any(); - (begin + index)->second.Any::~Any(); - // IMPORTANT: We do direct raw memmove to bring later items to the current position - // to preserve the order of insertion. - // This works because direct memory copy preserves the Any's move semantics. - if (index + 1 < size_) { - std::memmove(reinterpret_cast(begin + index), - reinterpret_cast(begin + index + 1), - (size_ - index - 1) * sizeof(KVType)); + /*! + * \brief Create an empty container + * \param n Number of empty slots + * \return The object created + */ + static ObjectPtr Empty(uint64_t n = kInitSize) { + using ::tvm::ffi::make_inplace_array_object; + ObjectPtr p = make_inplace_array_object(n); + p->data_ = p->AddressOf(0); + p->size_ = 0; + p->SetSlotsAndSmallLayoutTag(n); + return p; } - size_ -= 1; - } - /*! - * \brief Create an empty container - * \param n Number of empty slots - * \return The object created - */ - static ObjectPtr Empty(uint64_t n = kInitSize) { - using ::tvm::ffi::make_inplace_array_object; - ObjectPtr p = make_inplace_array_object(n); - p->data_ = p->AddressOf(0); - p->size_ = 0; - p->SetSlotsAndSmallLayoutTag(n); - return p; - } - /*! 
- * \brief Create an empty container initialized with a given range - * \param n Number of empty slots - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - * \return The object created - */ - template - static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { - ObjectPtr p = Empty(n); - KVType* ptr = static_cast(p->data_); - for (; first != last; ++first, ++p->size_) { - new (ptr++) KVType(*first); + /*! + * \brief Create an empty container initialized with a given range + * \param n Number of empty slots + * \param first begin of iterator + * \param last end of iterator + * \tparam IterType The type of iterator + * \return The object created + */ + template + static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { + ObjectPtr p = Empty(n); + KVType *ptr = static_cast(p->data_); + for (; first != last; ++first, ++p->size_) { + new (ptr++) KVType(*first); + } + return p; } - return p; - } - /*! - * \brief Create an empty container with elements copying from another SmallMapObj - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(SmallMapObj* from) { - KVType* first = static_cast(from->data_); - KVType* last = first + from->size_; - return CreateFromRange(from->size_, first, last); - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { - SmallMapObj* map_node = static_cast(map->get()); - iterator itr = map_node->find(kv.first); - if (itr.index < map_node->size_) { - itr->second = kv.second; - return; + /*! 
+ * \brief Create an empty container with elements copying from another SmallMapObj + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(SmallMapObj *from) { + KVType *first = static_cast(from->data_); + KVType *last = first + from->size_; + return CreateFromRange(from->size_, first, last); } - if (map_node->size_ < map_node->NumSlots()) { - KVType* ptr = static_cast(map_node->data_) + map_node->size_; - new (ptr) KVType(std::move(kv)); - ++map_node->size_; - return; + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(KVType &&kv, ObjectPtr *map) { + SmallMapObj *map_node = static_cast(map->get()); + iterator itr = map_node->find(kv.first); + if (itr.index < map_node->size_) { + itr->second = kv.second; + return; + } + if (map_node->size_ < map_node->NumSlots()) { + KVType *ptr = static_cast(map_node->data_) + map_node->size_; + new (ptr) KVType(std::move(kv)); + ++map_node->size_; + return; + } + uint64_t next_size = std::max(map_node->NumSlots() * 2, kInitSize); + next_size = std::min(next_size, kMaxSize); + TVM_FFI_ICHECK_GT(next_size, map_node->NumSlots()); + ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); + InsertMaybeReHash(std::move(kv), &new_map); + *map = std::move(new_map); } - uint64_t next_size = std::max(map_node->NumSlots() * 2, kInitSize); - next_size = std::min(next_size, kMaxSize); - TVM_FFI_ICHECK_GT(next_size, map_node->NumSlots()); - ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); - InsertMaybeReHash(std::move(kv), &new_map); - *map = std::move(new_map); - } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? 
index + 1 : size_; } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return static_cast(data_) + index; } - /*! \brief A size function used by InplaceArrayBase */ - uint64_t GetSize() const { return size_; } - - protected: - friend class MapObj; - friend class DenseMapObj; - friend class details::InplaceArrayBase; + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType *DeRefItr(uint64_t index) const { return static_cast(data_) + index; } + /*! \brief A size function used by InplaceArrayBase */ + uint64_t GetSize() const { return size_; } + +protected: + friend class MapObj; + friend class DenseMapObj; + friend class details::InplaceArrayBase; }; /*! \brief A specialization of hash map that implements the idea of array-based hash map. @@ -514,666 +514,665 @@ class SmallMapObj : public MapObj, * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ */ class DenseMapObj : public MapObj { - private: - /*! \brief The number of elements in a memory block */ - static constexpr int kBlockCap = 16; - /*! \brief Maximum load factor of the hash map */ - static constexpr double kMaxLoadFactor = 0.99; - /*! 
\brief Binary representation of the metadata of an empty slot */ - static constexpr uint8_t kEmptySlot = static_cast(0b11111111); - /*! \brief Binary representation of the metadata of a protected slot */ - static constexpr uint8_t kProtectedSlot = static_cast(0b11111110); - /*! \brief Number of probing choices available */ - static constexpr int kNumJumpDists = 126; - /*! \brief Index indicator to indicate an invalid index */ - static constexpr uint64_t kInvalidIndex = std::numeric_limits::max(); - /*! \brief Head of the implicit linked list */ - struct ListNode; - /*! \brief item type of the dense map, including a kv data and prev/next pointer */ - struct ItemType { - KVType data; - uint64_t prev = kInvalidIndex; - uint64_t next = kInvalidIndex; - - explicit ItemType(KVType&& data) : data(std::move(data)) {} - explicit ItemType(key_type key, mapped_type value) : data(std::move(key), std::move(value)) {} - }; - /*! \brief POD type of a block of memory */ - struct Block { - uint8_t bytes[kBlockCap + kBlockCap * sizeof(ItemType)]; - }; - static_assert(sizeof(Block) == kBlockCap * (sizeof(ItemType) + 1), "sizeof(Block) incorrect"); - static_assert(std::is_standard_layout_v, "Block is not standard layout"); - - /*! - * \brief Deleter for the Block - * \param data The pointer to the Block - */ - static void BlockDeleter(void* data) { delete[] static_cast(data); } - - public: - using MapObj::iterator; - - /*! - * \brief Return the number of usable slots for Dense layout (MSB clear => identity). - * \return The number of usable slots - */ - uint64_t NumSlots() const { return slots_; } - - /*! - * \brief Destroy the DenseMapObj - */ - ~DenseMapObj() { this->Reset(); } - /*! \return The number of elements of the key */ - size_t count(const key_type& key) const { return !Search(key).IsNone(); } - /*! 
- * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type& at(const key_type& key) const { return At(key); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type& at(const key_type& key) { return At(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type& key) const { - ListNode node = Search(key); - return node.IsNone() ? end() : iterator(node.index, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator& position) { - uint64_t index = position.index; - if (position.self != nullptr && index <= this->NumSlots()) { - Erase(ListNode(index, this)); - } - } - /*! \return begin iterator */ - iterator begin() const { return iterator(iter_list_head_, this); } - /*! \return end iterator */ - iterator end() const { return iterator(kInvalidIndex, this); } - - private: - Block* GetBlock(size_t index) const { return static_cast(data_) + index; } - /*! - * \brief Unlink the entry from iterator list - * \param node The node to be unlinked - * \note This function is usually used before deletion, - * and it does not change data content of the node. 
- */ - void IterListUnlink(ListNode node) { - // update head and tail of iterator list if needed - if (node.Item().prev == kInvalidIndex) { - iter_list_head_ = node.Item().next; - } else { - ListNode prev_node(node.Item().prev, this); - prev_node.Item().next = node.Item().next; - } - if (node.Item().next == kInvalidIndex) { - iter_list_tail_ = node.Item().prev; - } else { - ListNode next_node(node.Item().next, this); - next_node.Item().prev = node.Item().prev; - } - } - /*! - * \brief Insert the entry into tail of iterator list - * \param node The node to be inserted - * \note this function does not change data content of the node. - */ - void IterListPushBack(ListNode node) { - node.Item().prev = iter_list_tail_; - node.Item().next = kInvalidIndex; - if (iter_list_tail_ != kInvalidIndex) { - ListNode prev_node(iter_list_tail_, this); - prev_node.Item().next = node.index; - } - if (iter_list_head_ == kInvalidIndex) { - iter_list_head_ = node.index; - } - iter_list_tail_ = node.index; - } - /*! - * \brief Replace node src by dst in the iter list - * \param src The source node - * \param dst The destination node, must be empty - * \note This function does not change data content of the nodes, - * which needs to be updated by the caller. - */ - void IterListReplaceNodeBy(ListNode src, ListNode dst) { - // set link correctly on the dst - dst.Item().prev = src.Item().prev; - dst.Item().next = src.Item().next; - // update prev and next of dst - if (dst.Item().prev == kInvalidIndex) { - iter_list_head_ = dst.index; - } else { - ListNode prev_node(dst.Item().prev, this); - prev_node.Item().next = dst.index; - } - if (dst.Item().next == kInvalidIndex) { - iter_list_tail_ = dst.index; - } else { - ListNode next_node(dst.Item().next, this); - next_node.Item().prev = dst.index; +private: + /*! \brief The number of elements in a memory block */ + static constexpr int kBlockCap = 16; + /*! 
\brief Maximum load factor of the hash map */ + static constexpr double kMaxLoadFactor = 0.99; + /*! \brief Binary representation of the metadata of an empty slot */ + static constexpr uint8_t kEmptySlot = static_cast(0b11111111); + /*! \brief Binary representation of the metadata of a protected slot */ + static constexpr uint8_t kProtectedSlot = static_cast(0b11111110); + /*! \brief Number of probing choices available */ + static constexpr int kNumJumpDists = 126; + /*! \brief Index indicator to indicate an invalid index */ + static constexpr uint64_t kInvalidIndex = std::numeric_limits::max(); + /*! \brief Head of the implicit linked list */ + struct ListNode; + /*! \brief item type of the dense map, including a kv data and prev/next pointer */ + struct ItemType { + KVType data; + uint64_t prev = kInvalidIndex; + uint64_t next = kInvalidIndex; + + explicit ItemType(KVType &&data) : data(std::move(data)) {} + explicit ItemType(key_type key, mapped_type value) : data(std::move(key), std::move(value)) {} + }; + /*! \brief POD type of a block of memory */ + struct Block { + uint8_t bytes[kBlockCap + kBlockCap * sizeof(ItemType)]; + }; + static_assert(sizeof(Block) == kBlockCap * (sizeof(ItemType) + 1), "sizeof(Block) incorrect"); + static_assert(std::is_standard_layout_v, "Block is not standard layout"); + + /*! + * \brief Deleter for the Block + * \param data The pointer to the Block + */ + static void BlockDeleter(void *data) { delete[] static_cast(data); } + +public: + using MapObj::iterator; + + /*! + * \brief Return the number of usable slots for Dense layout (MSB clear => identity). + * \return The number of usable slots + */ + uint64_t NumSlots() const { return slots_; } + + /*! + * \brief Destroy the DenseMapObj + */ + ~DenseMapObj() { this->Reset(); } + /*! \return The number of elements of the key */ + size_t count(const key_type &key) const { return !Search(key).IsNone(); } + /*! 
+ * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The const reference to the value + */ + const mapped_type &at(const key_type &key) const { return At(key); } + /*! + * \brief Index value associated with a key, throw exception if the key does not exist + * \param key The indexing key + * \return The mutable reference to the value + */ + mapped_type &at(const key_type &key) { return At(key); } + /*! + * \brief Index value associated with a key + * \param key The indexing key + * \return The iterator of the entry associated with the key, end iterator if not exists + */ + iterator find(const key_type &key) const { + ListNode node = Search(key); + return node.IsNone() ? end() : iterator(node.index, this); } - } - /*! - * \brief Search for the given key - * \param key The key - * \return ListNode that associated with the key - */ - ListNode Search(const key_type& key) const { - if (this->size_ == 0) { - return ListNode(); + /*! + * \brief Erase the entry associated with the iterator + * \param position The iterator + */ + void erase(const iterator &position) { + uint64_t index = position.index; + if (position.self != nullptr && index <= this->NumSlots()) { + Erase(ListNode(index, this)); + } } - for (ListNode iter = GetListHead(AnyHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { - if (AnyEqual()(key, iter.Key())) { - return iter; - } + /*! \return begin iterator */ + iterator begin() const { return iterator(iter_list_head_, this); } + /*! \return end iterator */ + iterator end() const { return iterator(kInvalidIndex, this); } + +private: + Block *GetBlock(size_t index) const { return static_cast(data_) + index; } + /*! + * \brief Unlink the entry from iterator list + * \param node The node to be unlinked + * \note This function is usually used before deletion, + * and it does not change data content of the node. 
+ */ + void IterListUnlink(ListNode node) { + // update head and tail of iterator list if needed + if (node.Item().prev == kInvalidIndex) { + iter_list_head_ = node.Item().next; + } else { + ListNode prev_node(node.Item().prev, this); + prev_node.Item().next = node.Item().next; + } + if (node.Item().next == kInvalidIndex) { + iter_list_tail_ = node.Item().prev; + } else { + ListNode next_node(node.Item().next, this); + next_node.Item().prev = node.Item().prev; + } } - return ListNode(); - } - /*! - * \brief Search for the given key, throw exception if not exists - * \param key The key - * \return ListNode that associated with the key - */ - mapped_type& At(const key_type& key) const { - ListNode iter = Search(key); - if (iter.IsNone()) { - TVM_FFI_THROW(KeyError) << "key is not in Map"; + /*! + * \brief Insert the entry into tail of iterator list + * \param node The node to be inserted + * \note this function does not change data content of the node. + */ + void IterListPushBack(ListNode node) { + node.Item().prev = iter_list_tail_; + node.Item().next = kInvalidIndex; + if (iter_list_tail_ != kInvalidIndex) { + ListNode prev_node(iter_list_tail_, this); + prev_node.Item().next = node.index; + } + if (iter_list_head_ == kInvalidIndex) { + iter_list_head_ = node.index; + } + iter_list_tail_ = node.index; } - return iter.Val(); - } - /*! - * \brief Try to insert a key, or do nothing if already exists - * \param key The indexing key - * \param result The linked-list entry found or just constructed - * \return A boolean, indicating if actual insertion happens - */ - bool TryInsert(const key_type& key, ListNode* result) { - if (slots_ == 0) { - return false; + /*! + * \brief Replace node src by dst in the iter list + * \param src The source node + * \param dst The destination node, must be empty + * \note This function does not change data content of the nodes, + * which needs to be updated by the caller. 
+ */ + void IterListReplaceNodeBy(ListNode src, ListNode dst) { + // set link correctly on the dst + dst.Item().prev = src.Item().prev; + dst.Item().next = src.Item().next; + // update prev and next of dst + if (dst.Item().prev == kInvalidIndex) { + iter_list_head_ = dst.index; + } else { + ListNode prev_node(dst.Item().prev, this); + prev_node.Item().next = dst.index; + } + if (dst.Item().next == kInvalidIndex) { + iter_list_tail_ = dst.index; + } else { + ListNode next_node(dst.Item().next, this); + next_node.Item().prev = dst.index; + } } - // required that `iter` to be the head of a linked list through which we can iterator - ListNode iter = IndexFromHash(AnyHash()(key)); - // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list - // Case 1: empty - if (iter.IsEmpty()) { - iter.NewHead(ItemType(key, Any(nullptr))); - this->size_ += 1; - *result = iter; - return true; + /*! + * \brief Search for the given key + * \param key The key + * \return ListNode that associated with the key + */ + ListNode Search(const key_type &key) const { + if (this->size_ == 0) { + return ListNode(); + } + for (ListNode iter = GetListHead(AnyHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { + if (AnyEqual()(key, iter.Key())) { + return iter; + } + } + return ListNode(); } - // Case 2: body of an irrelevant list - if (!iter.IsHead()) { - // we move the elements around and construct the single-element linked list - return IsFull() ? false : TrySpareListHead(iter, key, result); + /*! 
+ * \brief Search for the given key, throw exception if not exists + * \param key The key + * \return ListNode that associated with the key + */ + mapped_type &At(const key_type &key) const { + ListNode iter = Search(key); + if (iter.IsNone()) { + TVM_FFI_THROW(KeyError) << "key is not in Map"; + } + return iter.Val(); } - // Case 3: head of the relevant list - // we iterate through the linked list until the end - // make sure `iter` is the previous element of `next` - ListNode next = iter; - do { - // find equal item, do not insert - if (AnyEqual()(key, next.Key())) { - // we plan to take next, so we need to unlink it from iterator list - IterListUnlink(next); - *result = next; + /*! + * \brief Try to insert a key, or do nothing if already exists + * \param key The indexing key + * \param result The linked-list entry found or just constructed + * \return A boolean, indicating if actual insertion happens + */ + bool TryInsert(const key_type &key, ListNode *result) { + if (slots_ == 0) { + return false; + } + // required that `iter` to be the head of a linked list through which we can iterator + ListNode iter = IndexFromHash(AnyHash()(key)); + // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list + // Case 1: empty + if (iter.IsEmpty()) { + iter.NewHead(ItemType(key, Any(nullptr))); + this->size_ += 1; + *result = iter; + return true; + } + // Case 2: body of an irrelevant list + if (!iter.IsHead()) { + // we move the elements around and construct the single-element linked list + return IsFull() ? 
false : TrySpareListHead(iter, key, result); + } + // Case 3: head of the relevant list + // we iterate through the linked list until the end + // make sure `iter` is the previous element of `next` + ListNode next = iter; + do { + // find equal item, do not insert + if (AnyEqual()(key, next.Key())) { + // we plan to take next, so we need to unlink it from iterator list + IterListUnlink(next); + *result = next; + return true; + } + // make sure `iter` is the previous element of `next` + iter = next; + } while (next.MoveToNext(this)); + // `iter` is the tail of the linked list + // always check capacity before insertion + if (IsFull()) { + return false; + } + // find the next empty slot + uint8_t jump; + if (!iter.GetNextEmpty(this, &jump, result)) { + return false; + } + result->NewTail(ItemType(key, Any(nullptr))); + // link `iter` to `empty`, and move forward + iter.SetJump(jump); + this->size_ += 1; return true; - } - // make sure `iter` is the previous element of `next` - iter = next; - } while (next.MoveToNext(this)); - // `iter` is the tail of the linked list - // always check capacity before insertion - if (IsFull()) { - return false; - } - // find the next empty slot - uint8_t jump; - if (!iter.GetNextEmpty(this, &jump, result)) { - return false; } - result->NewTail(ItemType(key, Any(nullptr))); - // link `iter` to `empty`, and move forward - iter.SetJump(jump); - this->size_ += 1; - return true; - } - /*! - * \brief Spare an entry to be the head of a linked list. - * As described in B3, during insertion, it is possible that the entire linked list does not - * exist, but the slot of its head has been occupied by other linked lists. In this case, we need - * to spare the slot by moving away the elements to another valid empty one to make insertion - * possible. 
- * \param target The given entry to be spared - * \param key The indexing key - * \param result The linked-list entry constructed as the head - * \return A boolean, if actual insertion happens - */ - bool TrySpareListHead(ListNode target, const key_type& key, ListNode* result) { - // `target` is not the head of the linked list - // move the original item of `target` (if any) - // and construct new item on the position `target` - // To make `target` empty, we - // 1) find `w` the previous element of `target` in the linked list - // 2) copy the linked list starting from `r = target` - // 3) paste them after `w` - // read from the linked list after `r` - ListNode r = target; - // write to the tail of `w` - ListNode w = target.FindPrev(this); - // after `target` is moved, we disallow writing to the slot - bool is_first = true; - uint8_t r_meta, jump; - ListNode empty; - do { - // `jump` describes how `w` is jumped to `empty` - // rehash if there is no empty space after `w` - if (!w.GetNextEmpty(this, &jump, &empty)) { - return false; - } - // move `r` to `empty` - // first move the data over - empty.NewTail(ItemType(std::move(r.Data()))); - // then move link list chain of r to empty - // this needs to happen after NewTail so empty's prev/next get updated - IterListReplaceNodeBy(r, empty); - // explicit call destructor to destroy the item in `r` - r.DestructData(); - // clear the metadata of `r` - r_meta = r.Meta(); - if (is_first) { - is_first = false; - r.SetProtected(); - } else { - r.SetEmpty(); - } - // link `w` to `empty`, and move forward - w.SetJump(jump); - w = empty; - // move `r` forward as well - } while (r.MoveToNext(this, r_meta)); - // finally we have done moving the linked list - // fill data_ into `target` - target.NewHead(ItemType(key, Any(nullptr))); - this->size_ += 1; - *result = target; - return true; - } - /*! 
- * \brief Remove a ListNode - * \param iter The node to be removed - */ - void Erase(const ListNode& iter) { - this->size_ -= 1; - if (!iter.HasNext()) { - // `iter` is the last - if (!iter.IsHead()) { - // cut the link if there is any - iter.FindPrev(this).SetJump(0); - } - // unlink the node from iterator list - IterListUnlink(iter); - // IMPORTANT: must explicit call destructor `iter` to avoid memory leak - // This is because we need to recycle iter's data - iter.DestructData(); - // set the meta data to be empty - iter.SetEmpty(); - } else { - ListNode last = iter, prev = iter; - for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { - } - // needs to first unlink iter from the list - IterListUnlink(iter); - // move data from last to iter - iter.Data() = std::move(last.Data()); - // Move link chain of iter to last as we stores last node to the new iter loc. - IterListReplaceNodeBy(last, iter); - // IMPORTANT: must explicit call destructor `last` to avoid memory leak - // likely we don't need this in this particular case because Any move behavior - // keep it here to be safe so code do not depend on specific move behavior of KVType - last.DestructData(); - // set the meta data to be empty - last.SetEmpty(); - prev.SetJump(0); - } - } - /*! \brief Clear the container to empty, release all entries and memory acquired */ - void Reset() { - uint64_t n_blocks = CalcNumBlocks(this->NumSlots()); - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr = GetBlock(bi)->bytes; - ItemType* data_ptr = reinterpret_cast(GetBlock(bi)->bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t& meta = *meta_ptr; - if (meta != kProtectedSlot && meta != kEmptySlot) { - meta = kEmptySlot; - data_ptr->ItemType::~ItemType(); - } - } + /*! + * \brief Spare an entry to be the head of a linked list. 
+ * As described in B3, during insertion, it is possible that the entire linked list does not + * exist, but the slot of its head has been occupied by other linked lists. In this case, we need + * to spare the slot by moving away the elements to another valid empty one to make insertion + * possible. + * \param target The given entry to be spared + * \param key The indexing key + * \param result The linked-list entry constructed as the head + * \return A boolean, if actual insertion happens + */ + bool TrySpareListHead(ListNode target, const key_type &key, ListNode *result) { + // `target` is not the head of the linked list + // move the original item of `target` (if any) + // and construct new item on the position `target` + // To make `target` empty, we + // 1) find `w` the previous element of `target` in the linked list + // 2) copy the linked list starting from `r = target` + // 3) paste them after `w` + // read from the linked list after `r` + ListNode r = target; + // write to the tail of `w` + ListNode w = target.FindPrev(this); + // after `target` is moved, we disallow writing to the slot + bool is_first = true; + uint8_t r_meta, jump; + ListNode empty; + do { + // `jump` describes how `w` is jumped to `empty` + // rehash if there is no empty space after `w` + if (!w.GetNextEmpty(this, &jump, &empty)) { + return false; + } + // move `r` to `empty` + // first move the data over + empty.NewTail(ItemType(std::move(r.Data()))); + // then move link list chain of r to empty + // this needs to happen after NewTail so empty's prev/next get updated + IterListReplaceNodeBy(r, empty); + // explicit call destructor to destroy the item in `r` + r.DestructData(); + // clear the metadata of `r` + r_meta = r.Meta(); + if (is_first) { + is_first = false; + r.SetProtected(); + } else { + r.SetEmpty(); + } + // link `w` to `empty`, and move forward + w.SetJump(jump); + w = empty; + // move `r` forward as well + } while (r.MoveToNext(this, r_meta)); + // finally we have done 
moving the linked list + // fill data_ into `target` + target.NewHead(ItemType(key, Any(nullptr))); + this->size_ += 1; + *result = target; + return true; } - ReleaseMemory(); - } - /*! \brief Release the memory acquired by the container without deleting its entries stored inside - */ - void ReleaseMemory() { - if (data_ != nullptr) { - TVM_FFI_ICHECK(data_deleter_ != nullptr); - data_deleter_(data_); + /*! + * \brief Remove a ListNode + * \param iter The node to be removed + */ + void Erase(const ListNode &iter) { + this->size_ -= 1; + if (!iter.HasNext()) { + // `iter` is the last + if (!iter.IsHead()) { + // cut the link if there is any + iter.FindPrev(this).SetJump(0); + } + // unlink the node from iterator list + IterListUnlink(iter); + // IMPORTANT: must explicit call destructor `iter` to avoid memory leak + // This is because we need to recycle iter's data + iter.DestructData(); + // set the meta data to be empty + iter.SetEmpty(); + } else { + ListNode last = iter, prev = iter; + for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { + } + // needs to first unlink iter from the list + IterListUnlink(iter); + // move data from last to iter + iter.Data() = std::move(last.Data()); + // Move link chain of iter to last as we stores last node to the new iter loc. + IterListReplaceNodeBy(last, iter); + // IMPORTANT: must explicit call destructor `last` to avoid memory leak + // likely we don't need this in this particular case because Any move behavior + // keep it here to be safe so code do not depend on specific move behavior of KVType + last.DestructData(); + // set the meta data to be empty + last.SetEmpty(); + prev.SetJump(0); + } } - data_ = nullptr; - data_deleter_ = nullptr; - slots_ = 0; - size_ = 0; - fib_shift_ = 63; - } - /*! 
- * \brief Create an empty container - * \param fib_shift The fib shift provided - * \param n_slots Number of slots required, should be power-of-two - * \return The object created - */ - static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { - TVM_FFI_ICHECK_GT(n_slots, uint64_t(SmallMapObj::kMaxSize)); - // Ensure even slot count (power-of-two expected by callers; this guard - // makes the method robust if a non-even value slips through). - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(n_slots); - Block* block = new Block[n_blocks]; - p->data_ = block; - // assign block deleter so even if we take re-alloc data - // in another shared-lib that may have different malloc/free behavior - // it will still be safe. - p->data_deleter_ = BlockDeleter; - p->SetSlotsAndDenseLayoutTag(n_slots); - p->size_ = 0; - p->fib_shift_ = fib_shift; - p->iter_list_head_ = kInvalidIndex; - p->iter_list_tail_ = kInvalidIndex; - for (uint64_t i = 0; i < n_blocks; ++i, ++block) { - std::fill(block->bytes, block->bytes + kBlockCap, kEmptySlot); + /*! \brief Clear the container to empty, release all entries and memory acquired */ + void Reset() { + uint64_t n_blocks = CalcNumBlocks(this->NumSlots()); + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t *meta_ptr = GetBlock(bi)->bytes; + ItemType *data_ptr = reinterpret_cast(GetBlock(bi)->bytes + kBlockCap); + for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { + uint8_t &meta = *meta_ptr; + if (meta != kProtectedSlot && meta != kEmptySlot) { + meta = kEmptySlot; + data_ptr->ItemType::~ItemType(); + } + } + } + ReleaseMemory(); } - return p; - } - /*! 
- * \brief Create an empty container with elements copying from another DenseMapObj - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(DenseMapObj* from) { - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(from->NumSlots()); - p->data_ = new Block[n_blocks]; - // assign block deleter so even if we take re-alloc data - // in another shared-lib that may have different malloc/free behavior - // it will still be safe. - p->data_deleter_ = BlockDeleter; - p->SetSlotsAndDenseLayoutTag(from->NumSlots()); - p->size_ = from->size_; - p->fib_shift_ = from->fib_shift_; - p->iter_list_head_ = from->iter_list_head_; - p->iter_list_tail_ = from->iter_list_tail_; - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t* meta_ptr_from = from->GetBlock(bi)->bytes; - ItemType* data_ptr_from = reinterpret_cast(from->GetBlock(bi)->bytes + kBlockCap); - uint8_t* meta_ptr_to = p->GetBlock(bi)->bytes; - ItemType* data_ptr_to = reinterpret_cast(p->GetBlock(bi)->bytes + kBlockCap); - for (int j = 0; j < kBlockCap; - ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { - uint8_t& meta = *meta_ptr_to = *meta_ptr_from; - TVM_FFI_ICHECK(meta != kProtectedSlot); - if (meta != kEmptySlot) { - new (data_ptr_to) ItemType(*data_ptr_from); - } - } + /*! \brief Release the memory acquired by the container without deleting its entries stored inside + */ + void ReleaseMemory() { + if (data_ != nullptr) { + TVM_FFI_ICHECK(data_deleter_ != nullptr); + data_deleter_(data_); + } + data_ = nullptr; + data_deleter_ = nullptr; + slots_ = 0; + size_ = 0; + fib_shift_ = 63; } - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { - DenseMapObj* map_node = static_cast(map->get()); - ListNode iter; - // Try to insert. 
If succeed, we simply return - if (map_node->TryInsert(kv.first, &iter)) { - iter.Val() = std::move(kv.second); - // update the iter list relation - map_node->IterListPushBack(iter); - return; + /*! + * \brief Create an empty container + * \param fib_shift The fib shift provided + * \param n_slots Number of slots required, should be power-of-two + * \return The object created + */ + static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { + TVM_FFI_ICHECK_GT(n_slots, uint64_t(SmallMapObj::kMaxSize)); + // Ensure even slot count (power-of-two expected by callers; this guard + // makes the method robust if a non-even value slips through). + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(n_slots); + Block *block = new Block[n_blocks]; + p->data_ = block; + // assign block deleter so even if we take re-alloc data + // in another shared-lib that may have different malloc/free behavior + // it will still be safe. + p->data_deleter_ = BlockDeleter; + p->SetSlotsAndDenseLayoutTag(n_slots); + p->size_ = 0; + p->fib_shift_ = fib_shift; + p->iter_list_head_ = kInvalidIndex; + p->iter_list_tail_ = kInvalidIndex; + for (uint64_t i = 0; i < n_blocks; ++i, ++block) { + std::fill(block->bytes, block->bytes + kBlockCap, kEmptySlot); + } + return p; } - TVM_FFI_ICHECK(!map_node->IsSmallMap()); - // Otherwise, start rehash - ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->NumSlots() * 2); - - // need to insert in the same order as the original map - for (uint64_t index = map_node->iter_list_head_; index != kInvalidIndex;) { - ListNode node(index, map_node); - // now try move src_data into the new map, note that src may still not - // be fully consumed into the call, but destructor will be called. 
- InsertMaybeReHash(std::move(node.Data()), &p); - // Important, needs to explicit call destructor in case move did remove - // node's internal item - index = node.Item().next; - // IMPORTANT: must explicit call destructor `node` to avoid memory leak - // We must call node.DestructData() here. - // This is because std::move() arguments in IterMaybeReHash may or may not - // explicitly move out the node.Data() - // Remove this call will cause memory leak very likely. - node.DestructData(); + /*! + * \brief Create an empty container with elements copying from another DenseMapObj + * \param from The source container + * \return The object created + */ + static ObjectPtr CopyFrom(DenseMapObj *from) { + ObjectPtr p = make_object(); + uint64_t n_blocks = CalcNumBlocks(from->NumSlots()); + p->data_ = new Block[n_blocks]; + // assign block deleter so even if we take re-alloc data + // in another shared-lib that may have different malloc/free behavior + // it will still be safe. + p->data_deleter_ = BlockDeleter; + p->SetSlotsAndDenseLayoutTag(from->NumSlots()); + p->size_ = from->size_; + p->fib_shift_ = from->fib_shift_; + p->iter_list_head_ = from->iter_list_head_; + p->iter_list_tail_ = from->iter_list_tail_; + for (uint64_t bi = 0; bi < n_blocks; ++bi) { + uint8_t *meta_ptr_from = from->GetBlock(bi)->bytes; + ItemType *data_ptr_from = reinterpret_cast(from->GetBlock(bi)->bytes + kBlockCap); + uint8_t *meta_ptr_to = p->GetBlock(bi)->bytes; + ItemType *data_ptr_to = reinterpret_cast(p->GetBlock(bi)->bytes + kBlockCap); + for (int j = 0; j < kBlockCap; + ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { + uint8_t &meta = *meta_ptr_to = *meta_ptr_from; + TVM_FFI_ICHECK(meta != kProtectedSlot); + if (meta != kEmptySlot) { + new (data_ptr_to) ItemType(*data_ptr_from); + } + } + } + return p; } - InsertMaybeReHash(std::move(kv), &p); - map_node->ReleaseMemory(); - *map = p; - } - /*! 
- * \brief Check whether the hash table is full - * \return A boolean indicating whether hash table is full - */ - bool IsFull() const { // NOLINTNEXTLINE(bugprone-narrowing-conversions) - return (size_ + 1) > static_cast(NumSlots()) * kMaxLoadFactor; - } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { - // keep at the end of iterator - if (index == kInvalidIndex) { - return index; + /*! + * \brief InsertMaybeReHash an entry into the given hash map + * \param kv The entry to be inserted + * \param map The pointer to the map, can be changed if re-hashing happens + */ + static void InsertMaybeReHash(KVType &&kv, ObjectPtr *map) { + DenseMapObj *map_node = static_cast(map->get()); + ListNode iter; + // Try to insert. If succeed, we simply return + if (map_node->TryInsert(kv.first, &iter)) { + iter.Val() = std::move(kv.second); + // update the iter list relation + map_node->IterListPushBack(iter); + return; + } + TVM_FFI_ICHECK(!map_node->IsSmallMap()); + // Otherwise, start rehash + ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->NumSlots() * 2); + + // need to insert in the same order as the original map + for (uint64_t index = map_node->iter_list_head_; index != kInvalidIndex;) { + ListNode node(index, map_node); + // now try move src_data into the new map, note that src may still not + // be fully consumed into the call, but destructor will be called. + InsertMaybeReHash(std::move(node.Data()), &p); + // Important, needs to explicit call destructor in case move did remove + // node's internal item + index = node.Item().next; + // IMPORTANT: must explicit call destructor `node` to avoid memory leak + // We must call node.DestructData() here. + // This is because std::move() arguments in IterMaybeReHash may or may not + // explicitly move out the node.Data() + // Remove this call will cause memory leak very likely. 
+ node.DestructData(); + } + InsertMaybeReHash(std::move(kv), &p); + map_node->ReleaseMemory(); + *map = p; } - ListNode node(index, this); - return node.Item().next; - } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { - // this is the end iterator, we need to return tail. - if (index == kInvalidIndex) { - return iter_list_tail_; + /*! + * \brief Check whether the hash table is full + * \return A boolean indicating whether hash table is full + */ + bool IsFull() const { // NOLINTNEXTLINE(bugprone-narrowing-conversions) + return (size_ + 1) > static_cast(NumSlots()) * kMaxLoadFactor; } - // circle around the iterator list, which is OK - ListNode node(index, this); - return node.Item().prev; - } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType* DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } - /*! \brief Construct from hash code */ - ListNode IndexFromHash(uint64_t hash_value) const { - return ListNode(FibHash(hash_value, fib_shift_), this); - } - /*! \brief Construct from hash code if the position is head of list */ - ListNode GetListHead(uint64_t hash_value) const { - ListNode node = IndexFromHash(hash_value); - return node.IsHead() ? node : ListNode(); - } - /*! \brief Construct the number of blocks in the hash table */ - static uint64_t CalcNumBlocks(uint64_t n_slots) { return (n_slots + kBlockCap - 1) / kBlockCap; } - /*! - * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. 
- * \param cap The lower-bound of the required capacity - * \param fib_shift The result shift for Fibonacci Hashing - * \param n_slots The result number of slots - */ - static void CalcTableSize(uint64_t cap, uint32_t* fib_shift, uint64_t* n_slots) { - uint32_t shift = 64; - uint64_t slots = 1; - for (uint64_t c = cap; c; c >>= 1) { - shift -= 1; - slots <<= 1; + /*! + * \brief Increment the pointer + * \param index The pointer to be incremented + * \return The increased pointer + */ + uint64_t IncItr(uint64_t index) const { + // keep at the end of iterator + if (index == kInvalidIndex) { + return index; + } + ListNode node(index, this); + return node.Item().next; } - TVM_FFI_ICHECK_GT(slots, cap); - if (slots < cap * 2) { - *fib_shift = shift - 1; - *n_slots = slots << 1; - } else { - *fib_shift = shift; - *n_slots = slots; + /*! + * \brief Decrement the pointer + * \param index The pointer to be decremented + * \return The decreased pointer + */ + uint64_t DecItr(uint64_t index) const { + // this is the end iterator, we need to return tail. + if (index == kInvalidIndex) { + return iter_list_tail_; + } + // circle around the iterator list, which is OK + ListNode node(index, this); + return node.Item().prev; } - } - /*! - * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. - * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. - * \param hash_value The raw hash value - * \param fib_shift The shift in Fibonacci Hashing - * \return An index calculated using Fibonacci Hashing - */ - static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { - constexpr uint64_t coeff = 11400714819323198485ull; - return (coeff * hash_value) >> fib_shift; - } - /*! \brief The implicit in-place linked list used to index a chain */ - struct ListNode { - /*! \brief Construct None */ - ListNode() : index(0), block(nullptr) {} - /*! 
\brief Construct from position */ - ListNode(uint64_t index, const DenseMapObj* self) - : index(index), block(self->GetBlock(index / kBlockCap)) {} - /*! \brief Metadata on the entry */ - uint8_t& Meta() const { return *(block->bytes + index % kBlockCap); } - /*! \brief Data on the entry */ - ItemType& Item() const { - return *(reinterpret_cast(block->bytes + kBlockCap + - (index % kBlockCap) * sizeof(ItemType))); + /*! + * \brief De-reference the pointer + * \param index The pointer to be dereferenced + * \return The result + */ + KVType *DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } + /*! \brief Construct from hash code */ + ListNode IndexFromHash(uint64_t hash_value) const { + return ListNode(FibHash(hash_value, fib_shift_), this); } - /*! \brief Data on the entry */ - KVType& Data() const { return Item().data; } - /*! \brief Key on the entry */ - key_type& Key() const { return Data().first; } - /*! \brief Value on the entry */ - mapped_type& Val() const { return Data().second; } - /*! \brief If the entry is head of linked list */ - bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } - /*! \brief If the entry is none */ - bool IsNone() const { return block == nullptr; } - /*! \brief If the entry is empty slot */ - bool IsEmpty() const { return Meta() == kEmptySlot; } - /*! \brief If the entry is protected slot */ - bool IsProtected() const { return Meta() == kProtectedSlot; } - /*! \brief Set the entry to be empty */ - void SetEmpty() const { Meta() = kEmptySlot; } - /*! \brief Destruct the item in the entry */ - void DestructData() const { - // explicit call destructor to destroy the item - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (&Data())->first.Any::~Any(); - (&Data())->second.Any::~Any(); + /*! 
\brief Construct from hash code if the position is head of list */ + ListNode GetListHead(uint64_t hash_value) const { + ListNode node = IndexFromHash(hash_value); + return node.IsHead() ? node : ListNode(); } - /*! \brief Set the entry to be protected */ - void SetProtected() const { Meta() = kProtectedSlot; } - /*! \brief Set the entry's jump to its next entry */ - void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } - /*! \brief Construct a head of linked list in-place */ - void NewHead(ItemType v) const { - Meta() = 0b00000000; - new (&Item()) ItemType(std::move(v)); + /*! \brief Construct the number of blocks in the hash table */ + static uint64_t CalcNumBlocks(uint64_t n_slots) { return (n_slots + kBlockCap - 1) / kBlockCap; } + /*! + * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. + * \param cap The lower-bound of the required capacity + * \param fib_shift The result shift for Fibonacci Hashing + * \param n_slots The result number of slots + */ + static void CalcTableSize(uint64_t cap, uint32_t *fib_shift, uint64_t *n_slots) { + uint32_t shift = 64; + uint64_t slots = 1; + for (uint64_t c = cap; c; c >>= 1) { + shift -= 1; + slots <<= 1; + } + TVM_FFI_ICHECK_GT(slots, cap); + if (slots < cap * 2) { + *fib_shift = shift - 1; + *n_slots = slots << 1; + } else { + *fib_shift = shift; + *n_slots = slots; + } } - /*! \brief Construct a tail of linked list in-place */ - void NewTail(ItemType v) const { - Meta() = 0b10000000; - new (&Item()) ItemType(std::move(v)); + /*! + * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. + * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. 
+ * \param hash_value The raw hash value + * \param fib_shift The shift in Fibonacci Hashing + * \return An index calculated using Fibonacci Hashing + */ + static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { + constexpr uint64_t coeff = 11400714819323198485ull; + return (coeff * hash_value) >> fib_shift; } + /*! \brief The implicit in-place linked list used to index a chain */ + struct ListNode { + /*! \brief Construct None */ + ListNode() : index(0), block(nullptr) {} + /*! \brief Construct from position */ + ListNode(uint64_t index, const DenseMapObj *self) + : index(index), block(self->GetBlock(index / kBlockCap)) {} + /*! \brief Metadata on the entry */ + uint8_t &Meta() const { return *(block->bytes + index % kBlockCap); } + /*! \brief Data on the entry */ + ItemType &Item() const { + return *(reinterpret_cast(block->bytes + kBlockCap + (index % kBlockCap) * sizeof(ItemType))); + } + /*! \brief Data on the entry */ + KVType &Data() const { return Item().data; } + /*! \brief Key on the entry */ + key_type &Key() const { return Data().first; } + /*! \brief Value on the entry */ + mapped_type &Val() const { return Data().second; } + /*! \brief If the entry is head of linked list */ + bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } + /*! \brief If the entry is none */ + bool IsNone() const { return block == nullptr; } + /*! \brief If the entry is empty slot */ + bool IsEmpty() const { return Meta() == kEmptySlot; } + /*! \brief If the entry is protected slot */ + bool IsProtected() const { return Meta() == kProtectedSlot; } + /*! \brief Set the entry to be empty */ + void SetEmpty() const { Meta() = kEmptySlot; } + /*! \brief Destruct the item in the entry */ + void DestructData() const { + // explicit call destructor to destroy the item + // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) + (&Data())->first.Any::~Any(); + (&Data())->second.Any::~Any(); + } + /*! 
\brief Set the entry to be protected */ + void SetProtected() const { Meta() = kProtectedSlot; } + /*! \brief Set the entry's jump to its next entry */ + void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } + /*! \brief Construct a head of linked list in-place */ + void NewHead(ItemType v) const { + Meta() = 0b00000000; + new (&Item()) ItemType(std::move(v)); + } + /*! \brief Construct a tail of linked list in-place */ + void NewTail(ItemType v) const { + Meta() = 0b10000000; + new (&Item()) ItemType(std::move(v)); + } - /*! \brief If the entry has next entry on the linked list */ - bool HasNext() const { return NextProbeLocation(Meta() & 0b01111111) != 0; } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapObj* self, uint8_t meta) { - uint64_t offset = NextProbeLocation(meta & 0b01111111); - if (offset == 0) { - index = 0; - block = nullptr; - return false; - } - // the probing will go to next position and round back to stay within the - // correct range of the slots - index = (index + offset) % self->NumSlots(); - block = self->GetBlock(index / kBlockCap); - return true; - } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapObj* self) { return MoveToNext(self, Meta()); } - /*! \brief Get the previous entry on the linked list */ - ListNode FindPrev(const DenseMapObj* self) const { - // start from the head of the linked list, which must exist - ListNode next = self->IndexFromHash(AnyHash()(Key())); - // `prev` is always the previous item of `next` - ListNode prev = next; - for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { - } - return prev; - } - /*! 
\brief Get the next empty jump */ - bool GetNextEmpty(const DenseMapObj* self, uint8_t* jump, ListNode* result) const { - for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { - // the probing will go to next position and round back to stay within the - // correct range of the slots - ListNode candidate((index + NextProbeLocation(idx)) % self->NumSlots(), self); - if (candidate.IsEmpty()) { - *jump = idx; - *result = candidate; - return true; - } - } - return false; - } - /*! \brief Index on the real array */ - uint64_t index; - /*! \brief Pointer to the actual block */ - Block* block; - }; - - protected: - /*! \brief fib shift in Fibonacci Hashing */ - uint32_t fib_shift_; - /*! \brief the head of iterator list */ - uint64_t iter_list_head_ = kInvalidIndex; - /*! \brief the tail of iterator list */ - uint64_t iter_list_tail_ = kInvalidIndex; - - static uint64_t NextProbeLocation(size_t index) { - /* clang-format off */ + /*! \brief If the entry has next entry on the linked list */ + bool HasNext() const { return NextProbeLocation(Meta() & 0b01111111) != 0; } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapObj *self, uint8_t meta) { + uint64_t offset = NextProbeLocation(meta & 0b01111111); + if (offset == 0) { + index = 0; + block = nullptr; + return false; + } + // the probing will go to next position and round back to stay within the + // correct range of the slots + index = (index + offset) % self->NumSlots(); + block = self->GetBlock(index / kBlockCap); + return true; + } + /*! \brief Move the entry to the next entry on the linked list */ + bool MoveToNext(const DenseMapObj *self) { return MoveToNext(self, Meta()); } + /*! 
\brief Get the previous entry on the linked list */ + ListNode FindPrev(const DenseMapObj *self) const { + // start from the head of the linked list, which must exist + ListNode next = self->IndexFromHash(AnyHash()(Key())); + // `prev` is always the previous item of `next` + ListNode prev = next; + for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { + } + return prev; + } + /*! \brief Get the next empty jump */ + bool GetNextEmpty(const DenseMapObj *self, uint8_t *jump, ListNode *result) const { + for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { + // the probing will go to next position and round back to stay within the + // correct range of the slots + ListNode candidate((index + NextProbeLocation(idx)) % self->NumSlots(), self); + if (candidate.IsEmpty()) { + *jump = idx; + *result = candidate; + return true; + } + } + return false; + } + /*! \brief Index on the real array */ + uint64_t index; + /*! \brief Pointer to the actual block */ + Block *block; + }; + +protected: + /*! \brief fib shift in Fibonacci Hashing */ + uint32_t fib_shift_; + /*! \brief the head of iterator list */ + uint64_t iter_list_head_ = kInvalidIndex; + /*! \brief the tail of iterator list */ + uint64_t iter_list_tail_ = kInvalidIndex; + + static uint64_t NextProbeLocation(size_t index) { + /* clang-format off */ /*! \brief Candidates of probing distance */ static const uint64_t kNextProbeLocation[kNumJumpDists] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, @@ -1199,96 +1198,96 @@ class DenseMapObj : public MapObj { 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, 457381325854679626, 1029107982097042876, 2315492959180353330, 5209859154120846435, }; - /* clang-format on */ - return kNextProbeLocation[index]; - } - friend class MapObj; - - private: - /*! - * \brief Set the number of slots and attach tags bit. 
- * \param n The number of slots - */ - void SetSlotsAndDenseLayoutTag(uint64_t n) { - TVM_FFI_ICHECK(((n & kSmallTagMask) == 0ull)) << "DenseMap expects MSB clear"; - slots_ = n; - } + /* clang-format on */ + return kNextProbeLocation[index]; + } + friend class MapObj; + +private: + /*! + * \brief Set the number of slots and attach tags bit. + * \param n The number of slots + */ + void SetSlotsAndDenseLayoutTag(uint64_t n) { + TVM_FFI_ICHECK(((n & kSmallTagMask) == 0ull)) << "DenseMap expects MSB clear"; + slots_ = n; + } }; /// \cond -#define TVM_FFI_DISPATCH_MAP(base, var, body) \ - { \ - using TSmall = SmallMapObj*; \ - using TDense = DenseMapObj*; \ - if ((base)->IsSmallMap()) { \ - TSmall var = static_cast((base)); \ - body; \ - } else { \ - TDense var = static_cast((base)); \ - body; \ - } \ - } - -#define TVM_FFI_DISPATCH_MAP_CONST(base, var, body) \ - { \ - using TSmall = const SmallMapObj*; \ - using TDense = const DenseMapObj*; \ - if ((base)->IsSmallMap()) { \ - TSmall var = static_cast((base)); \ - body; \ - } else { \ - TDense var = static_cast((base)); \ - body; \ - } \ - } +#define TVM_FFI_DISPATCH_MAP(base, var, body) \ + { \ + using TSmall = SmallMapObj *; \ + using TDense = DenseMapObj *; \ + if ((base)->IsSmallMap()) { \ + TSmall var = static_cast((base)); \ + body; \ + } else { \ + TDense var = static_cast((base)); \ + body; \ + } \ + } + +#define TVM_FFI_DISPATCH_MAP_CONST(base, var, body) \ + { \ + using TSmall = const SmallMapObj *; \ + using TDense = const DenseMapObj *; \ + if ((base)->IsSmallMap()) { \ + TSmall var = static_cast((base)); \ + body; \ + } else { \ + TDense var = static_cast((base)); \ + body; \ + } \ + } inline MapObj::iterator::pointer MapObj::iterator::operator->() const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); + TVM_FFI_MAP_FAIL_IF_CHANGED() + TVM_FFI_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); } -inline MapObj::iterator& 
MapObj::iterator::operator++() { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { - index = p->IncItr(index); - return *this; - }); +inline MapObj::iterator &MapObj::iterator::operator++() { + TVM_FFI_MAP_FAIL_IF_CHANGED() + TVM_FFI_DISPATCH_MAP_CONST(self, p, { + index = p->IncItr(index); + return *this; + }); } -inline MapObj::iterator& MapObj::iterator::operator--() { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { - index = p->DecItr(index); - return *this; - }); +inline MapObj::iterator &MapObj::iterator::operator--() { + TVM_FFI_MAP_FAIL_IF_CHANGED() + TVM_FFI_DISPATCH_MAP_CONST(self, p, { + index = p->DecItr(index); + return *this; + }); } -inline size_t MapObj::count(const key_type& key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); +inline size_t MapObj::count(const key_type &key) const { + TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); } -inline const MapObj::mapped_type& MapObj::at(const MapObj::key_type& key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); +inline const MapObj::mapped_type &MapObj::at(const MapObj::key_type &key) const { + TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); } -inline MapObj::mapped_type& MapObj::at(const MapObj::key_type& key) { - TVM_FFI_DISPATCH_MAP(this, p, { return p->at(key); }); +inline MapObj::mapped_type &MapObj::at(const MapObj::key_type &key) { + TVM_FFI_DISPATCH_MAP(this, p, { return p->at(key); }); } inline MapObj::iterator MapObj::begin() const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); + TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); } inline MapObj::iterator MapObj::end() const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->end(); }); + TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->end(); }); } -inline MapObj::iterator MapObj::find(const MapObj::key_type& key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); +inline 
MapObj::iterator MapObj::find(const MapObj::key_type &key) const { + TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); } -inline void MapObj::erase(const MapObj::iterator& position) { - TVM_FFI_DISPATCH_MAP(this, p, { return p->erase(position); }); +inline void MapObj::erase(const MapObj::iterator &position) { + TVM_FFI_DISPATCH_MAP(this, p, { return p->erase(position); }); } /// \endcond @@ -1297,66 +1296,66 @@ inline void MapObj::erase(const MapObj::iterator& position) { inline ObjectPtr MapObj::Empty() { return SmallMapObj::Empty(); } -inline ObjectPtr MapObj::CopyFrom(MapObj* from) { - if (from->IsSmallMap()) { - return SmallMapObj::CopyFrom(static_cast(from)); - } else { - return DenseMapObj::CopyFrom(static_cast(from)); - } +inline ObjectPtr MapObj::CopyFrom(MapObj *from) { + if (from->IsSmallMap()) { + return SmallMapObj::CopyFrom(static_cast(from)); + } else { + return DenseMapObj::CopyFrom(static_cast(from)); + } } template inline ObjectPtr MapObj::CreateFromRange(IterType first, IterType last) { - int64_t _cap = std::distance(first, last); - if (_cap < 0) { - return SmallMapObj::Empty(); - } - uint64_t cap = static_cast(_cap); - if (cap < SmallMapObj::kMaxSize) { - if (cap < 2) { - return SmallMapObj::CreateFromRange(cap, first, last); + int64_t _cap = std::distance(first, last); + if (_cap < 0) { + return SmallMapObj::Empty(); } - // need to insert to avoid duplicate keys - ObjectPtr obj = SmallMapObj::Empty(cap); - for (; first != last; ++first) { - KVType kv(*first); - SmallMapObj::InsertMaybeReHash(std::move(kv), &obj); - } - return obj; - } else { - uint32_t fib_shift; - uint64_t n_slots; - DenseMapObj::CalcTableSize(cap, &fib_shift, &n_slots); - ObjectPtr obj = DenseMapObj::Empty(fib_shift, n_slots); - for (; first != last; ++first) { - KVType kv(*first); - DenseMapObj::InsertMaybeReHash(std::move(kv), &obj); + uint64_t cap = static_cast(_cap); + if (cap < SmallMapObj::kMaxSize) { + if (cap < 2) { + return 
SmallMapObj::CreateFromRange(cap, first, last); + } + // need to insert to avoid duplicate keys + ObjectPtr obj = SmallMapObj::Empty(cap); + for (; first != last; ++first) { + KVType kv(*first); + SmallMapObj::InsertMaybeReHash(std::move(kv), &obj); + } + return obj; + } else { + uint32_t fib_shift; + uint64_t n_slots; + DenseMapObj::CalcTableSize(cap, &fib_shift, &n_slots); + ObjectPtr obj = DenseMapObj::Empty(fib_shift, n_slots); + for (; first != last; ++first) { + KVType kv(*first); + DenseMapObj::InsertMaybeReHash(std::move(kv), &obj); + } + return obj; } - return obj; - } } -inline void MapObj::InsertMaybeReHash(KVType&& kv, ObjectPtr* map) { - MapObj* base = static_cast(map->get()); +inline void MapObj::InsertMaybeReHash(KVType &&kv, ObjectPtr *map) { + MapObj *base = static_cast(map->get()); #if TVM_FFI_DEBUG_WITH_ABI_CHANGE - base->state_marker++; -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - if (base->IsSmallMap()) { - SmallMapObj* sm = static_cast(base); - if (sm->NumSlots() < SmallMapObj::kMaxSize) { - SmallMapObj::InsertMaybeReHash(std::move(kv), map); - } else if (sm->NumSlots() == SmallMapObj::kMaxSize) { - if (base->size_ < sm->NumSlots()) { - SmallMapObj::InsertMaybeReHash(std::move(kv), map); - } else { - ObjectPtr new_map = MapObj::CreateFromRange(base->begin(), base->end()); - DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map); - *map = std::move(new_map); - } + base->state_marker++; +#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE + if (base->IsSmallMap()) { + SmallMapObj *sm = static_cast(base); + if (sm->NumSlots() < SmallMapObj::kMaxSize) { + SmallMapObj::InsertMaybeReHash(std::move(kv), map); + } else if (sm->NumSlots() == SmallMapObj::kMaxSize) { + if (base->size_ < sm->NumSlots()) { + SmallMapObj::InsertMaybeReHash(std::move(kv), map); + } else { + ObjectPtr new_map = MapObj::CreateFromRange(base->begin(), base->end()); + DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map); + *map = std::move(new_map); + } + } + } else { + 
DenseMapObj::InsertMaybeReHash(std::move(kv), map); } - } else { - DenseMapObj::InsertMaybeReHash(std::move(kv), map); - } } /// \cond Doxygen_Suppress @@ -1378,282 +1377,277 @@ inline ObjectPtr make_object<>() = delete; * \tparam V The value NodeRef type. */ template && - details::storage_enabled_v>> + typename = typename std::enable_if_t && details::storage_enabled_v>> class Map : public ObjectRef { - public: - /*! \brief The key type of the map */ - using key_type = K; - /*! \brief The mapped type of the map */ - using mapped_type = V; - /*! \brief The iterator type of the map */ - class iterator; - /*! - * \brief Construct an Map with UnsafeInit - */ - explicit Map(UnsafeInit tag) : ObjectRef(tag) {} - /*! - * \brief default constructor - */ - Map() { data_ = MapObj::Empty(); } - /*! - * \brief move constructor - * \param other source - */ - Map(Map&& other) // NOLINT(google-explicit-constructor) - : ObjectRef(std::move(other.data_)) {} - /*! - * \brief copy constructor - * \param other source - */ - Map(const Map& other) // NOLINT(google-explicit-constructor) - : ObjectRef(other.data_) {} - - /*! - * \brief Move constructor - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && - details::type_contains_v>> - Map(Map&& other) // NOLINT(google-explicit-constructor) - : ObjectRef(std::move(other.data_)) {} - - /*! - * \brief Copy constructor - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && - details::type_contains_v>> - Map(const Map& other) : ObjectRef(other.data_) {} // NOLINT(google-explicit-constructor) - - /*! - * \brief Move assignment - * \param other The other map - */ - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - - /*! 
- * \brief Copy assignment - * \param other The other map - */ - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - - /*! - * \brief Move assignment - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && - details::type_contains_v>> - Map& operator=(Map&& other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Copy assignment - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && - details::type_contains_v>> - Map& operator=(const Map& other) { - data_ = other.data_; - return *this; - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Map(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Map(IterType begin, IterType end) { - data_ = MapObj::CreateFromRange(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Map(std::initializer_list> init) { - data_ = MapObj::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief constructor from unordered_map - * \param init The unordered_map - */ - template - Map(const std::unordered_map& init) { // NOLINT(*) - data_ = MapObj::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V at(const K& key) const { - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(GetMapObj()->at(key)); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V operator[](const K& key) const { return this->at(key); } - /*! 
\return The size of the array */ - size_t size() const { - MapObj* n = GetMapObj(); - return n == nullptr ? 0 : n->size(); - } - /*! \return The number of elements of the key */ - size_t count(const K& key) const { - MapObj* n = GetMapObj(); - return n == nullptr ? 0 : GetMapObj()->count(key); - } - /*! \return whether array is empty */ - bool empty() const { return size() == 0; } - /*! \brief Release reference to all the elements */ - void clear() { - MapObj* n = GetMapObj(); - if (n != nullptr) { - data_ = MapObj::Empty(); +public: + /*! \brief The key type of the map */ + using key_type = K; + /*! \brief The mapped type of the map */ + using mapped_type = V; + /*! \brief The iterator type of the map */ + class iterator; + /*! + * \brief Construct an Map with UnsafeInit + */ + explicit Map(UnsafeInit tag) : ObjectRef(tag) {} + /*! + * \brief default constructor + */ + Map() { data_ = MapObj::Empty(); } + /*! + * \brief move constructor + * \param other source + */ + Map(Map &&other) // NOLINT(google-explicit-constructor) + : ObjectRef(std::move(other.data_)) {} + /*! + * \brief copy constructor + * \param other source + */ + Map(const Map &other) // NOLINT(google-explicit-constructor) + : ObjectRef(other.data_) {} + + /*! + * \brief Move constructor + * \param other The other map + * \tparam KU The key type of the other map + * \tparam VU The mapped type of the other map + */ + template && details::type_contains_v>> + Map(Map &&other) // NOLINT(google-explicit-constructor) + : ObjectRef(std::move(other.data_)) {} + + /*! + * \brief Copy constructor + * \param other The other map + * \tparam KU The key type of the other map + * \tparam VU The mapped type of the other map + */ + template && details::type_contains_v>> + Map(const Map &other) : ObjectRef(other.data_) {} // NOLINT(google-explicit-constructor) + + /*! 
+ * \brief Move assignment + * \param other The other map + */ + Map &operator=(Map &&other) { + data_ = std::move(other.data_); + return *this; } - } - /*! - * \brief set the Map. - * \param key The index key. - * \param value The value to be setted. - */ - void Set(const K& key, const V& value) { - CopyOnWrite(); - MapObj::InsertMaybeReHash(MapObj::KVType(key, value), &data_); - } - /*! \return begin iterator */ - iterator begin() const { return iterator(GetMapObj()->begin()); } - /*! \return end iterator */ - iterator end() const { return iterator(GetMapObj()->end()); } - /*! \return find the key and returns the associated iterator */ - iterator find(const K& key) const { return iterator(GetMapObj()->find(key)); } - /*! \return The value associated with the key, std::nullopt if not found */ - std::optional Get(const K& key) const { - MapObj::iterator iter = GetMapObj()->find(key); - if (iter == GetMapObj()->end()) { - return std::nullopt; + + /*! + * \brief Copy assignment + * \param other The other map + */ + Map &operator=(const Map &other) { + data_ = other.data_; + return *this; } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(iter->second); - } - - /*! - * \brief Erase the entry associated with the key - * \param key The key - */ - void erase(const K& key) { CopyOnWrite()->erase(key); } - - /*! - * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which guarantees to be unique) - */ - MapObj* CopyOnWrite() { - if (data_.get() == nullptr) { - data_ = MapObj::Empty(); - } else if (!data_.unique()) { - data_ = MapObj::CopyFrom(GetMapObj()); + + /*! 
+ * \brief Move assignment + * \param other The other map + * \tparam KU The key type of the other map + * \tparam VU The mapped type of the other map + */ + template && details::type_contains_v>> + Map &operator=(Map &&other) { + data_ = std::move(other.data_); + return *this; } - return GetMapObj(); - } - /*! \brief specify container node */ - using ContainerType = MapObj; - - /// \cond Doxygen_Suppress - /*! \brief Iterator of the hash map */ - class iterator { - public: - using iterator_category = std::bidirectional_iterator_tag; - using difference_type = int64_t; - using value_type = const std::pair; - using pointer = value_type*; - using reference = value_type; - - iterator() : itr() {} - - /*! \brief Compare iterators */ - bool operator==(const iterator& other) const { return itr == other.itr; } - /*! \brief Compare iterators */ - bool operator!=(const iterator& other) const { return itr != other.itr; } - /*! \brief De-reference iterators is not allowed */ - pointer operator->() const = delete; - /*! \brief De-reference iterators */ - reference operator*() const { - auto& kv = *itr; - return std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.first), - details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.second)); + + /*! + * \brief Copy assignment + * \param other The other map + * \tparam KU The key type of the other map + * \tparam VU The mapped type of the other map + */ + template && details::type_contains_v>> + Map &operator=(const Map &other) { + data_ = other.data_; + return *this; } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator& operator++() { - ++itr; - return *this; + /*! + * \brief constructor from pointer + * \param n the container pointer + */ + explicit Map(ObjectPtr n) : ObjectRef(n) {} + /*! 
+ * \brief constructor from iterator + * \param begin begin of iterator + * \param end end of iterator + * \tparam IterType The type of iterator + */ + template + Map(IterType begin, IterType end) { + data_ = MapObj::CreateFromRange(begin, end); } - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; + /*! + * \brief constructor from initializer list + * \param init The initalizer list + */ + Map(std::initializer_list> init) { + data_ = MapObj::CreateFromRange(init.begin(), init.end()); } - - /*! \brief Prefix self decrement, e.g. --iter */ - iterator& operator--() { - --itr; - return *this; + /*! + * \brief constructor from unordered_map + * \param init The unordered_map + */ + template + Map(const std::unordered_map &init) { // NOLINT(*) + data_ = MapObj::CreateFromRange(init.begin(), init.end()); } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - iterator copy = *this; - --(*this); - return copy; + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + const V at(const K &key) const { + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(GetMapObj()->at(key)); + } + /*! + * \brief Read element from map. + * \param key The key + * \return the corresonding element. + */ + const V operator[](const K &key) const { return this->at(key); } + /*! \return The size of the array */ + size_t size() const { + MapObj *n = GetMapObj(); + return n == nullptr ? 0 : n->size(); + } + /*! \return The number of elements of the key */ + size_t count(const K &key) const { + MapObj *n = GetMapObj(); + return n == nullptr ? 0 : GetMapObj()->count(key); + } + /*! \return whether array is empty */ + bool empty() const { return size() == 0; } + /*! \brief Release reference to all the elements */ + void clear() { + MapObj *n = GetMapObj(); + if (n != nullptr) { + data_ = MapObj::Empty(); + } + } + /*! + * \brief set the Map. + * \param key The index key. 
+ * \param value The value to be setted. + */ + void Set(const K &key, const V &value) { + CopyOnWrite(); + MapObj::InsertMaybeReHash(MapObj::KVType(key, value), &data_); + } + /*! \return begin iterator */ + iterator begin() const { return iterator(GetMapObj()->begin()); } + /*! \return end iterator */ + iterator end() const { return iterator(GetMapObj()->end()); } + /*! \return find the key and returns the associated iterator */ + iterator find(const K &key) const { return iterator(GetMapObj()->find(key)); } + /*! \return The value associated with the key, std::nullopt if not found */ + std::optional Get(const K &key) const { + MapObj::iterator iter = GetMapObj()->find(key); + if (iter == GetMapObj()->end()) { + return std::nullopt; + } + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(iter->second); } - private: - iterator(const MapObj::iterator& itr) // NOLINT(*) - : itr(itr) {} + /*! + * \brief Erase the entry associated with the key + * \param key The key + */ + void erase(const K &key) { CopyOnWrite()->erase(key); } + + /*! + * \brief copy on write semantics + * Do nothing if current handle is the unique copy of the array. + * Otherwise make a new copy of the array to ensure the current handle + * hold a unique copy. + * + * \return Handle to the internal node container(which guarantees to be unique) + */ + MapObj *CopyOnWrite() { + if (data_.get() == nullptr) { + data_ = MapObj::Empty(); + } else if (!data_.unique()) { + data_ = MapObj::CopyFrom(GetMapObj()); + } + return GetMapObj(); + } + /*! \brief specify container node */ + using ContainerType = MapObj; + + /// \cond Doxygen_Suppress + /*! \brief Iterator of the hash map */ + class iterator { + public: + using iterator_category = std::bidirectional_iterator_tag; + using difference_type = int64_t; + using value_type = const std::pair; + using pointer = value_type *; + using reference = value_type; + + iterator() : itr() {} + + /*! 
\brief Compare iterators */ + bool operator==(const iterator &other) const { return itr == other.itr; } + /*! \brief Compare iterators */ + bool operator!=(const iterator &other) const { return itr != other.itr; } + /*! \brief De-reference iterators is not allowed */ + pointer operator->() const = delete; + /*! \brief De-reference iterators */ + reference operator*() const { + auto &kv = *itr; + return std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.first), + details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.second)); + } + /*! \brief Prefix self increment, e.g. ++iter */ + iterator &operator++() { + ++itr; + return *this; + } + /*! \brief Suffix self increment */ + iterator operator++(int) { + iterator copy = *this; + ++(*this); + return copy; + } - template - friend class Map; + /*! \brief Prefix self decrement, e.g. --iter */ + iterator &operator--() { + --itr; + return *this; + } + /*! \brief Suffix self decrement */ + iterator operator--(int) { + iterator copy = *this; + --(*this); + return copy; + } + + private: + iterator(const MapObj::iterator &itr) // NOLINT(*) + : itr(itr) {} - MapObj::iterator itr; - }; - /// \endcond + template + friend class Map; - private: - /*! \brief Return data_ as type of pointer of MapObj */ - MapObj* GetMapObj() const { return static_cast(data_.get()); } + MapObj::iterator itr; + }; + /// \endcond + +private: + /*! \brief Return data_ as type of pointer of MapObj */ + MapObj *GetMapObj() const { return static_cast(data_.get()); } - template - friend class Map; + template + friend class Map; }; /*! @@ -1663,13 +1657,12 @@ class Map : public ObjectRef { * @return The merged Array. Original Maps are kept unchanged. 
*/ template && - details::storage_enabled_v>> -inline Map Merge(Map lhs, const Map& rhs) { - for (const auto& p : rhs) { - lhs.Set(p.first, p.second); - } - return std::move(lhs); + typename = typename std::enable_if_t && details::storage_enabled_v>> +inline Map Merge(Map lhs, const Map &rhs) { + for (const auto &p : rhs) { + lhs.Set(p.first, p.second); + } + return std::move(lhs); } // Traits for Map @@ -1678,104 +1671,115 @@ inline constexpr bool use_default_type_traits_v> = false; template struct TypeTraits> : public ObjectRefTypeTraitsBase> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap; - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap; + using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIMap) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - if constexpr (!std::is_same_v || !std::is_same_v) { - const MapObj* n = reinterpret_cast(src->v_obj); - for (const auto& kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first) && - !kv.first.try_cast().has_value()) { - return "Map[some key is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.first) + - ", V]"; - } - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second) && - !kv.second.try_cast().has_value()) { - return "Map[K, some value is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.second) + - "]"; - } - } - } + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { + if (src->type_index != TypeIndex::kTVMFFIMap) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + if constexpr (!std::is_same_v || !std::is_same_v) { + const MapObj *n = reinterpret_cast(src->v_obj); + for (const auto &kv : *n) { + if constexpr (!std::is_same_v) { + if 
(!details::AnyUnsafe::CheckAnyStrict(kv.first) && !kv.first.try_cast().has_value()) { + return "Map[some key is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.first) + ", V]"; + } + } + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.second) && !kv.second.try_cast().has_value()) { + return "Map[K, some value is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.second) + "]"; + } + } + } + } + TVM_FFI_THROW(InternalError) << "Cannot reach here"; + TVM_FFI_UNREACHABLE(); } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIMap) return false; - if constexpr (std::is_same_v && std::is_same_v) { - return true; - } else { - const MapObj* n = reinterpret_cast(src->v_obj); - for (const auto& kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) return false; + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + if (src->type_index != TypeIndex::kTVMFFIMap) { + return false; } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) return false; + if constexpr (std::is_same_v && std::is_same_v) { + return true; + } else { + const MapObj *n = reinterpret_cast(src->v_obj); + for (const auto &kv : *n) { + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) { + return false; + } + } + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) { + return false; + } + } + } + return true; } - } - return true; } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIMap) return std::nullopt; - if constexpr (!std::is_same_v || !std::is_same_v) { - const MapObj* n = reinterpret_cast(src->v_obj); - bool storage_check = [&]() { - for (const auto& kv : *n) { - if constexpr 
(!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) return false; - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) return false; - } + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index != TypeIndex::kTVMFFIMap) { + return std::nullopt; } - return true; - }(); - // fast path, if storage check passes, we can return the array directly. - if (storage_check) return CopyFromAnyViewAfterCheck(src); - // slow path, we need to create a new map and convert to the target type. - Map ret; - for (const auto& kv : *n) { - auto k = kv.first.try_cast(); - auto v = kv.second.try_cast(); - if (!k.has_value() || !v.has_value()) return std::nullopt; - ret.Set(*std::move(k), *std::move(v)); - } - return ret; - } else { - return CopyFromAnyViewAfterCheck(src); + if constexpr (!std::is_same_v || !std::is_same_v) { + const MapObj *n = reinterpret_cast(src->v_obj); + bool storage_check = [&]() { + for (const auto &kv : *n) { + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) { + return false; + } + } + if constexpr (!std::is_same_v) { + if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) { + return false; + } + } + } + return true; + }(); + // fast path, if storage check passes, we can return the array directly. + if (storage_check) { + return CopyFromAnyViewAfterCheck(src); + } + // slow path, we need to create a new map and convert to the target type. 
+ Map ret; + for (const auto &kv : *n) { + auto k = kv.first.try_cast(); + auto v = kv.second.try_cast(); + if (!k.has_value() || !v.has_value()) { + return std::nullopt; + } + ret.Set(*std::move(k), *std::move(v)); + } + return ret; + } else { + return CopyFromAnyViewAfterCheck(src); + } + } + + TVM_FFI_INLINE static std::string TypeStr() { + return "Map<" + details::Type2Str::v() + ", " + details::Type2Str::v() + ">"; + } + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":")" << StaticTypeKey::kTVMFFIMap << R"(","args":[)"; + oss << details::TypeSchema::v() << ","; + oss << details::TypeSchema::v(); + oss << "]}"; + return oss.str(); } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "Map<" + details::Type2Str::v() + ", " + details::Type2Str::v() + ">"; - } - TVM_FFI_INLINE static std::string TypeSchema() { - std::ostringstream oss; - oss << R"({"type":")" << StaticTypeKey::kTVMFFIMap << R"(","args":[)"; - oss << details::TypeSchema::v() << ","; - oss << details::TypeSchema::v(); - oss << "]}"; - return oss.str(); - } }; namespace details { template -inline constexpr bool type_contains_v, Map> = - type_contains_v && type_contains_v; -} // namespace details +inline constexpr bool type_contains_v, Map> = type_contains_v && type_contains_v; +} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_MAP_H_ +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_MAP_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h index e5eb3cab6..483333195 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h @@ -44,193 +44,191 @@ namespace ffi { */ template class Tuple : public ObjectRef { - public: - static_assert(details::all_storage_enabled_v, - "All types used 
in Tuple<...> must be compatible with Any"); - /*! \brief Default constructor */ - Tuple() : ObjectRef(MakeDefaultTupleNode()) {} - /*! - * \brief Constructor with UnsafeInit - */ - explicit Tuple(UnsafeInit tag) : ObjectRef(tag) {} - /*! \brief Copy constructor */ - Tuple(const Tuple& other) : ObjectRef(other) {} - /*! \brief Move constructor */ - Tuple(Tuple&& other) noexcept : ObjectRef(std::move(other)) {} - /*! - * \brief Constructor from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...), int>> - Tuple(const Tuple& other) : ObjectRef(other) {} // NOLINT(google-explicit-constructor) - - /*! - * \brief Constructor from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...), int>> - Tuple(Tuple&& other) // NOLINT(google-explicit-constructor) - : ObjectRef(std::move(other)) {} - - /*! - * \brief Constructor from arguments - * \param args The arguments - * \tparam UTypes The types of the other tuple - */ - template , Tuple> && ...))>> - explicit Tuple(UTypes&&... args) : ObjectRef(MakeTupleNode(std::forward(args)...)) {} - - /*! - * \brief Assignment from another tuple - * \param other The other tuple - * \tparam The enable_if_t type - */ - TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { - data_ = other.data_; - return *this; - } - - /*! - * \brief Assignment from another tuple - * \param other The other tuple - * \tparam The enable_if_t type - */ - TVM_FFI_INLINE Tuple& operator=(Tuple&& other) noexcept { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Assignment from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...)>> - TVM_FFI_INLINE Tuple& operator=(const Tuple& other) { - data_ = other.data_; - return *this; - } - - /*! 
- * \brief Assignment from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...)>> - TVM_FFI_INLINE Tuple& operator=(Tuple&& other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Get I-th element of the tuple - * - * \tparam I The index of the element to get - * \return The I-th element of the tuple - * \note We use stl style since get usually is like a getter. - */ - template - auto get() const& { - static_assert(I < sizeof...(Types), "Tuple index out of bounds"); - using ReturnType = std::tuple_element_t>; - const Any* ptr = GetArrayObj()->begin() + I; - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*ptr); - } - - /*! - * \brief Move out I-th element of the tuple - * - * \tparam I The index of the element to get - * \return The I-th element of the tuple - * \note We use stl style since get usually is like a getter. - */ - template - auto get() && { - if (!this->unique()) { - // fallback to const& version if not unique - return std::as_const(*this).template get(); +public: + static_assert(details::all_storage_enabled_v, + "All types used in Tuple<...> must be compatible with Any"); + /*! \brief Default constructor */ + Tuple() : ObjectRef(MakeDefaultTupleNode()) {} + /*! + * \brief Constructor with UnsafeInit + */ + explicit Tuple(UnsafeInit tag) : ObjectRef(tag) {} + /*! \brief Copy constructor */ + Tuple(const Tuple &other) : ObjectRef(other) {} + /*! \brief Move constructor */ + Tuple(Tuple &&other) noexcept : ObjectRef(std::move(other)) {} + /*! + * \brief Constructor from another tuple + * \param other The other tuple + * \tparam UTypes The types of the other tuple + * \tparam The enable_if_t type + */ + template && ...), int>> + Tuple(const Tuple &other) : ObjectRef(other) {} // NOLINT(google-explicit-constructor) + + /*! 
+ * \brief Constructor from another tuple + * \param other The other tuple + * \tparam UTypes The types of the other tuple + * \tparam The enable_if_t type + */ + template && ...), int>> + Tuple(Tuple &&other) // NOLINT(google-explicit-constructor) + : ObjectRef(std::move(other)) {} + + /*! + * \brief Constructor from arguments + * \param args The arguments + * \tparam UTypes The types of the other tuple + */ + template , Tuple> && ...))>> + explicit Tuple(UTypes &&...args) : ObjectRef(MakeTupleNode(std::forward(args)...)) {} + + /*! + * \brief Assignment from another tuple + * \param other The other tuple + * \tparam The enable_if_t type + */ + TVM_FFI_INLINE Tuple &operator=(const Tuple &other) { + data_ = other.data_; + return *this; } - static_assert(I < sizeof...(Types), "Tuple index out of bounds"); - using ReturnType = std::tuple_element_t>; - Any* ptr = GetArrayObj()->MutableBegin() + I; - return details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(*ptr)); - } - - /*! - * \brief Set I-th element of the tuple - * - * \param item The item to set - * \tparam I The index of the element to set - * \tparam U The type of the item - * - * \note This function will perform copy on write if underlying - * container is not uniquely owned. - * We use CamelCase since Set can cause copy on write - * and is more complicated than simple field setter. - */ - template - void Set(U&& item) { - static_assert(I < sizeof...(Types), "Tuple index out of bounds"); - using T = std::tuple_element_t>; - this->CopyIfNotUnique(); - Any* ptr = GetArrayObj()->MutableBegin() + I; - *ptr = T(std::forward(item)); - } - - /*! 
\brief specify container node */ - using ContainerType = ArrayObj; - - private: - static ObjectPtr MakeDefaultTupleNode() { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any* itr = p->MutableBegin(); - // increase size after each new to ensure exception safety - ((new (itr++) Any(Types()), p->size_++), ...); - return p; - } - - template - static ObjectPtr MakeTupleNode(UTypes&&... args) { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any* itr = p->MutableBegin(); - // increase size after each new to ensure exception safety - ((new (itr++) Any(Types(std::forward(args))), p->size_++), ...); - return p; - } - - /*! \brief Copy on write */ - void CopyIfNotUnique() { - if (!data_.unique()) { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any* itr = p->MutableBegin(); - const Any* read = GetArrayObj()->begin(); - // increase size after each new to ensure exception safety - for (size_t i = 0; i < sizeof...(Types); ++i) { - new (itr++) Any(*read++); - p->size_++; - } - data_ = std::move(p); + + /*! + * \brief Assignment from another tuple + * \param other The other tuple + * \tparam The enable_if_t type + */ + TVM_FFI_INLINE Tuple &operator=(Tuple &&other) noexcept { + data_ = std::move(other.data_); + return *this; + } + + /*! + * \brief Assignment from another tuple + * \param other The other tuple + * \tparam UTypes The types of the other tuple + * \tparam The enable_if_t type + */ + template && ...)>> + TVM_FFI_INLINE Tuple &operator=(const Tuple &other) { + data_ = other.data_; + return *this; + } + + /*! + * \brief Assignment from another tuple + * \param other The other tuple + * \tparam UTypes The types of the other tuple + * \tparam The enable_if_t type + */ + template && ...)>> + TVM_FFI_INLINE Tuple &operator=(Tuple &&other) { + data_ = std::move(other.data_); + return *this; + } + + /*! 
+ * \brief Get I-th element of the tuple + * + * \tparam I The index of the element to get + * \return The I-th element of the tuple + * \note We use stl style since get usually is like a getter. + */ + template + auto get() const & { + static_assert(I < sizeof...(Types), "Tuple index out of bounds"); + using ReturnType = std::tuple_element_t>; + const Any *ptr = GetArrayObj()->begin() + I; + return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*ptr); } - } - /*! \return The underlying ArrayObj */ - ArrayObj* GetArrayObj() const { return static_cast(data_.get()); } + /*! + * \brief Move out I-th element of the tuple + * + * \tparam I The index of the element to get + * \return The I-th element of the tuple + * \note We use stl style since get usually is like a getter. + */ + template + auto get() && { + if (!this->unique()) { + // fallback to const& version if not unique + return std::as_const(*this).template get(); + } + static_assert(I < sizeof...(Types), "Tuple index out of bounds"); + using ReturnType = std::tuple_element_t>; + Any *ptr = GetArrayObj()->MutableBegin() + I; + return details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(*ptr)); + } + + /*! + * \brief Set I-th element of the tuple + * + * \param item The item to set + * \tparam I The index of the element to set + * \tparam U The type of the item + * + * \note This function will perform copy on write if underlying + * container is not uniquely owned. + * We use CamelCase since Set can cause copy on write + * and is more complicated than simple field setter. + */ + template + void Set(U &&item) { + static_assert(I < sizeof...(Types), "Tuple index out of bounds"); + using T = std::tuple_element_t>; + this->CopyIfNotUnique(); + Any *ptr = GetArrayObj()->MutableBegin() + I; + *ptr = T(std::forward(item)); + } + + /*! 
\brief specify container node */ + using ContainerType = ArrayObj; - template - friend class Tuple; +private: + static ObjectPtr MakeDefaultTupleNode() { + ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); + Any *itr = p->MutableBegin(); + // increase size after each new to ensure exception safety + ((new (itr++) Any(Types()), p->size_++), ...); + return p; + } + + template + static ObjectPtr MakeTupleNode(UTypes &&...args) { + ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); + Any *itr = p->MutableBegin(); + // increase size after each new to ensure exception safety + ((new (itr++) Any(Types(std::forward(args))), p->size_++), ...); + return p; + } + + /*! \brief Copy on write */ + void CopyIfNotUnique() { + if (!data_.unique()) { + ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); + Any *itr = p->MutableBegin(); + const Any *read = GetArrayObj()->begin(); + // increase size after each new to ensure exception safety + for (size_t i = 0; i < sizeof...(Types); ++i) { + new (itr++) Any(*read++); + p->size_++; + } + data_ = std::move(p); + } + } + + /*! 
\return The underlying ArrayObj */ + ArrayObj *GetArrayObj() const { return static_cast(data_.get()); } + + template + friend class Tuple; }; template @@ -238,108 +236,115 @@ inline constexpr bool use_default_type_traits_v> = false; template struct TypeTraits> : public ObjectRefTypeTraitsBase> { - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - const ArrayObj* n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) { - return "Array[size=" + std::to_string(n->size()) + "]"; - } - return GetMismatchTypeInfoHelper<0, Types...>(n->begin()); - } - - template - TVM_FFI_INLINE static std::string GetMismatchTypeInfoHelper(const Any* arr) { - if constexpr (!std::is_same_v) { - const Any& any_v = arr[I]; - if (!details::AnyUnsafe::CheckAnyStrict(any_v) && !(any_v.try_cast().has_value())) { - // now report the accurate mismatch information - return "Array[index " + std::to_string(I) + ": " + - details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; - } + using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { + if (src->type_index != TypeIndex::kTVMFFIArray) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + const ArrayObj *n = reinterpret_cast(src->v_obj); + if (n->size() != sizeof...(Types)) { + return "Array[size=" + std::to_string(n->size()) + "]"; + } + return GetMismatchTypeInfoHelper<0, Types...>(n->begin()); } - if constexpr (sizeof...(Rest) > 0) { - return GetMismatchTypeInfoHelper(arr); + + template + TVM_FFI_INLINE static std::string GetMismatchTypeInfoHelper(const Any *arr) { + if constexpr (!std::is_same_v) { + const Any &any_v = arr[I]; + if (!details::AnyUnsafe::CheckAnyStrict(any_v) && !(any_v.try_cast().has_value())) { + // now report the accurate mismatch 
information + return "Array[index " + std::to_string(I) + ": " + details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; + } + } + if constexpr (sizeof...(Rest) > 0) { + return GetMismatchTypeInfoHelper(arr); + } + TVM_FFI_THROW(InternalError) << "Cannot reach here"; + TVM_FFI_UNREACHABLE(); } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) return false; - const ArrayObj* n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) return false; - const TVMFFIAny* ffi_any_arr = reinterpret_cast(n->begin()); - return CheckAnyStrictHelper<0, Types...>(ffi_any_arr); - } - - template - TVM_FFI_INLINE static bool CheckAnyStrictHelper(const TVMFFIAny* src_arr) { - if constexpr (!std::is_same_v) { - if (!TypeTraits::CheckAnyStrict(src_arr + I)) { - return false; - } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + if (src->type_index != TypeIndex::kTVMFFIArray) { + return false; + } + const ArrayObj *n = reinterpret_cast(src->v_obj); + if (n->size() != sizeof...(Types)) { + return false; + } + const TVMFFIAny *ffi_any_arr = reinterpret_cast(n->begin()); + return CheckAnyStrictHelper<0, Types...>(ffi_any_arr); } - if constexpr (sizeof...(Rest) > 0) { - return CheckAnyStrictHelper(src_arr); + + template + TVM_FFI_INLINE static bool CheckAnyStrictHelper(const TVMFFIAny *src_arr) { + if constexpr (!std::is_same_v) { + if (!TypeTraits::CheckAnyStrict(src_arr + I)) { + return false; + } + } + if constexpr (sizeof...(Rest) > 0) { + return CheckAnyStrictHelper(src_arr); + } + return true; } - return true; - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - if (src->type_index != TypeIndex::kTVMFFIArray) return std::nullopt; - const ArrayObj* n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) return std::nullopt; - // fast path, storage is 
already in the right type - if (CheckAnyStrict(src)) { - return CopyFromAnyViewAfterCheck(src); + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny *src) { + if (src->type_index != TypeIndex::kTVMFFIArray) { + return std::nullopt; + } + const ArrayObj *n = reinterpret_cast(src->v_obj); + if (n->size() != sizeof...(Types)) { + return std::nullopt; + } + // fast path, storage is already in the right type + if (CheckAnyStrict(src)) { + return CopyFromAnyViewAfterCheck(src); + } + // slow path, try to convert to each type to match the tuple storage need. + Array arr = TypeTraits>::CopyFromAnyViewAfterCheck(src); + Any *ptr = arr.CopyOnWrite()->MutableBegin(); + if (TryConvertElements<0, Types...>(ptr)) { + return details::ObjectUnsafe::ObjectRefFromObjectPtr>( + details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); + } + return std::nullopt; } - // slow path, try to convert to each type to match the tuple storage need. - Array arr = TypeTraits>::CopyFromAnyViewAfterCheck(src); - Any* ptr = arr.CopyOnWrite()->MutableBegin(); - if (TryConvertElements<0, Types...>(ptr)) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr>( - details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); + + template + TVM_FFI_INLINE static bool TryConvertElements(Any *arr) { + if constexpr (!std::is_same_v) { + if (auto opt_convert = arr[I].try_cast()) { + arr[I] = *std::move(opt_convert); + } else { + return false; + } + } + if constexpr (sizeof...(Rest) > 0) { + return TryConvertElements(std::move(arr)); + } else { + return true; + } } - return std::nullopt; - } - - template - TVM_FFI_INLINE static bool TryConvertElements(Any* arr) { - if constexpr (!std::is_same_v) { - if (auto opt_convert = arr[I].try_cast()) { - arr[I] = *std::move(opt_convert); - } else { - return false; - } + + TVM_FFI_INLINE static std::string TypeStr() { + return details::ContainerTypeStr("Tuple"); } - if constexpr (sizeof...(Rest) > 0) { - return TryConvertElements(std::move(arr)); - } else { - 
return true; + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":"Tuple","args":[)"; + const char *sep = ""; + ((oss << sep << details::TypeSchema::v(), sep = ","), ...); + oss << "]}"; + return oss.str(); } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return details::ContainerTypeStr("Tuple"); - } - TVM_FFI_INLINE static std::string TypeSchema() { - std::ostringstream oss; - oss << R"({"type":"Tuple","args":[)"; - const char* sep = ""; - ((oss << sep << details::TypeSchema::v(), sep = ","), ...); - oss << "]}"; - return oss.str(); - } }; namespace details { template inline constexpr bool type_contains_v, Tuple> = (type_contains_v && ...); -} // namespace details +} // namespace details /// \cond Doxygen_Suppress @@ -354,9 +359,9 @@ inline constexpr bool type_contains_v, Tuple> = (type_contains * \return The I-th element of the tuple */ template -inline constexpr auto get(const Tuple& t) +inline constexpr auto get(const Tuple &t) -> std::tuple_element_t> { - return t.template get(); + return t.template get(); } /** @@ -366,18 +371,18 @@ inline constexpr auto get(const Tuple& t) * \return The I-th element of the tuple */ template -inline constexpr auto get(Tuple&& t) -> std::tuple_element_t> { - return std::move(t).template get(); +inline constexpr auto get(Tuple &&t) -> std::tuple_element_t> { + return std::move(t).template get(); } /// NOTE: C++17 deduction guide template -Tuple(UTypes&&...) -> Tuple>...>; +Tuple(UTypes &&...) 
-> Tuple>...>; /// \endcond -} // namespace ffi -} // namespace tvm +} // namespace ffi +} // namespace tvm namespace std { @@ -387,9 +392,9 @@ struct tuple_size<::tvm::ffi::Tuple> template struct tuple_element> { - using type = std::tuple_element_t>; + using type = std::tuple_element_t>; }; -} // namespace std +} // namespace std -#endif // TVM_FFI_CONTAINER_TUPLE_H_ +#endif // TVM_FFI_CONTAINER_TUPLE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h index 08dc764d5..e1b91b526 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h @@ -42,57 +42,57 @@ namespace details { */ template class VariantBase { - public: - TVM_FFI_INLINE bool same_as(const VariantBase& other) const { - return data_.same_as(other.data_); - } +public: + TVM_FFI_INLINE bool same_as(const VariantBase &other) const { + return data_.same_as(other.data_); + } - protected: - template - explicit VariantBase(T other) : data_(std::move(other)) {} +protected: + template + explicit VariantBase(T other) : data_(std::move(other)) {} - TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); } + TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); } - TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); } + TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); } - TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); } + TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); } - Any data_; + Any data_; }; // Specialization for all object ref case, backed by ObjectRef. 
template <> class VariantBase : public ObjectRef { - protected: - template - explicit VariantBase(const T& other) : ObjectRef(other) {} - template , VariantBase>>> - explicit VariantBase(T&& other) : ObjectRef(std::forward(other)) {} - explicit VariantBase(UnsafeInit tag) : ObjectRef(tag) {} - explicit VariantBase(Any other) - : ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(other))) {} - - TVM_FFI_INLINE void SetData(ObjectPtr other) { data_ = std::move(other); } - - TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); } - - TVM_FFI_INLINE AnyView ToAnyView() const { - TVMFFIAny any_data; - if (data_ == nullptr) { - any_data.type_index = TypeIndex::kTVMFFINone; - any_data.zero_padding = 0; - any_data.v_int64 = 0; - } else { - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data); - any_data.type_index = data_->type_index(); - any_data.zero_padding = 0; - any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr(data_); +protected: + template + explicit VariantBase(const T &other) : ObjectRef(other) {} + template , VariantBase>>> + explicit VariantBase(T &&other) : ObjectRef(std::forward(other)) {} + explicit VariantBase(UnsafeInit tag) : ObjectRef(tag) {} + explicit VariantBase(Any other) + : ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(other))) {} + + TVM_FFI_INLINE void SetData(ObjectPtr other) { data_ = std::move(other); } + + TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); } + + TVM_FFI_INLINE AnyView ToAnyView() const { + TVMFFIAny any_data; + if (data_ == nullptr) { + any_data.type_index = TypeIndex::kTVMFFINone; + any_data.zero_padding = 0; + any_data.v_int64 = 0; + } else { + TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data); + any_data.type_index = data_->type_index(); + any_data.zero_padding = 0; + any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr(data_); + } + return AnyView::CopyFromTVMFFIAny(any_data); } - return 
AnyView::CopyFromTVMFFIAny(any_data); - } }; -} // namespace details +} // namespace details /*! * \brief A typed variant container. @@ -102,133 +102,133 @@ class VariantBase : public ObjectRef { */ template class Variant : public details::VariantBase> { - public: - /// \cond Doxygen_Suppress - using TParent = details::VariantBase>; - static_assert(details::all_storage_enabled_v, - "All types used in Variant<...> must be compatible with Any"); - /* - * \brief Helper utility to check if the type can be contained in the variant - */ - template - static constexpr bool variant_contains_v = (details::type_contains_v || ...); - /* \brief Helper utility for SFINAE if the type is part of the variant */ - template - using enable_if_variant_contains_t = std::enable_if_t>; - /// \endcond - /*! - * \brief Constructor from another variant - * \param other The other variant - */ - Variant(const Variant& other) : TParent(other.data_) {} - /*! - * \brief Constructor from another variant - * \param other The other variant - */ - Variant(Variant&& other) noexcept : TParent(std::move(other.data_)) {} - - /*! - * \brief Assignment from another variant - * \param other The other variant - */ - TVM_FFI_INLINE Variant& operator=(const Variant& other) { - this->SetData(other.data_); - return *this; - } - - /*! - * \brief Assignment from another variant - * \param other The other variant - */ - TVM_FFI_INLINE Variant& operator=(Variant&& other) noexcept { - this->SetData(std::move(other.data_)); - return *this; - } - - /*! - * \brief Constructor from another variant - * \param other The other variant - */ - template > - Variant(T other) : TParent(std::move(other)) {} // NOLINT(*) - - /*! - * \brief Assignment from another variant - * \param other The other variant - */ - template > - TVM_FFI_INLINE Variant& operator=(T other) { - return operator=(Variant(std::move(other))); - } - - /*! - * \brief Try to cast to a type T, return std::nullopt if the cast is not possible. 
- * \return The casted value, or std::nullopt if the cast is not possible. - * \tparam T The type to cast to. - */ - template > - TVM_FFI_INLINE std::optional as() const { - return this->TParent::ToAnyView().template as(); - } - - /*! - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. - * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. - */ - template >> - TVM_FFI_INLINE const T* as() const { - return this->TParent::ToAnyView().template as().value_or(nullptr); - } - - /*! - * \brief Get the value of the variant in type T, throws an exception if cast fails. - * \return The value of the variant - * \tparam T The type to get. - */ - template > - TVM_FFI_INLINE T get() const& { - return this->TParent::ToAnyView().template cast(); - } - - /*! - * \brief Get the value of the variant in type T, throws an exception if cast fails. - * \return The value of the variant - * \tparam T The type to get. - */ - template > - TVM_FFI_INLINE T get() && { - return std::move(*this).TParent::MoveToAny().template cast(); - } - - /*! - * \brief Get the type key of the variant - * \return The type key of the variant - */ - TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); } - - private: - friend struct TypeTraits>; - friend struct ObjectPtrHash; - friend struct ObjectPtrEqual; - // constructor from any - explicit Variant(Any data) : TParent(std::move(data)) {} - /*! 
- * \brief Get the object pointer from the variant - * \note This function is only available if all types used in Variant<...> are derived from - * ObjectRef - */ - TVM_FFI_INLINE Object* GetObjectPtrForHashEqual() const { - constexpr bool all_object_v = (std::is_base_of_v && ...); - static_assert(all_object_v, - "All types used in Variant<...> must be derived from ObjectRef " - "to enable ObjectPtrHash/ObjectPtrEqual"); - return this->data_.get(); - } - // rexpose to friend class - using TParent::MoveToAny; - using TParent::ToAnyView; +public: + /// \cond Doxygen_Suppress + using TParent = details::VariantBase>; + static_assert(details::all_storage_enabled_v, + "All types used in Variant<...> must be compatible with Any"); + /* + * \brief Helper utility to check if the type can be contained in the variant + */ + template + static constexpr bool variant_contains_v = (details::type_contains_v || ...); + /* \brief Helper utility for SFINAE if the type is part of the variant */ + template + using enable_if_variant_contains_t = std::enable_if_t>; + /// \endcond + /*! + * \brief Constructor from another variant + * \param other The other variant + */ + Variant(const Variant &other) : TParent(other.data_) {} + /*! + * \brief Constructor from another variant + * \param other The other variant + */ + Variant(Variant &&other) noexcept : TParent(std::move(other.data_)) {} + + /*! + * \brief Assignment from another variant + * \param other The other variant + */ + TVM_FFI_INLINE Variant &operator=(const Variant &other) { + this->SetData(other.data_); + return *this; + } + + /*! + * \brief Assignment from another variant + * \param other The other variant + */ + TVM_FFI_INLINE Variant &operator=(Variant &&other) noexcept { + this->SetData(std::move(other.data_)); + return *this; + } + + /*! + * \brief Constructor from another variant + * \param other The other variant + */ + template > + Variant(T other) : TParent(std::move(other)) {} // NOLINT(*) + + /*! 
+ * \brief Assignment from another variant + * \param other The other variant + */ + template > + TVM_FFI_INLINE Variant &operator=(T other) { + return operator=(Variant(std::move(other))); + } + + /*! + * \brief Try to cast to a type T, return std::nullopt if the cast is not possible. + * \return The casted value, or std::nullopt if the cast is not possible. + * \tparam T The type to cast to. + */ + template > + TVM_FFI_INLINE std::optional as() const { + return this->TParent::ToAnyView().template as(); + } + + /*! + * \brief Shortcut of as Object to cast to a const pointer when T is an Object. + * + * \tparam T The object type. + * \return The requested pointer, returns nullptr if type mismatches. + */ + template >> + TVM_FFI_INLINE const T *as() const { + return this->TParent::ToAnyView().template as().value_or(nullptr); + } + + /*! + * \brief Get the value of the variant in type T, throws an exception if cast fails. + * \return The value of the variant + * \tparam T The type to get. + */ + template > + TVM_FFI_INLINE T get() const & { + return this->TParent::ToAnyView().template cast(); + } + + /*! + * \brief Get the value of the variant in type T, throws an exception if cast fails. + * \return The value of the variant + * \tparam T The type to get. + */ + template > + TVM_FFI_INLINE T get() && { + return std::move(*this).TParent::MoveToAny().template cast(); + } + + /*! + * \brief Get the type key of the variant + * \return The type key of the variant + */ + TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); } + +private: + friend struct TypeTraits>; + friend struct ObjectPtrHash; + friend struct ObjectPtrEqual; + // constructor from any + explicit Variant(Any data) : TParent(std::move(data)) {} + /*! 
+ * \brief Get the object pointer from the variant + * \note This function is only available if all types used in Variant<...> are derived from + * ObjectRef + */ + TVM_FFI_INLINE Object *GetObjectPtrForHashEqual() const { + constexpr bool all_object_v = (std::is_base_of_v && ...); + static_assert(all_object_v, + "All types used in Variant<...> must be derived from ObjectRef " + "to enable ObjectPtrHash/ObjectPtrEqual"); + return this->data_.get(); + } + // rexpose to friend class + using TParent::MoveToAny; + using TParent::ToAnyView; }; template @@ -236,76 +236,76 @@ inline constexpr bool use_default_type_traits_v> = false; template struct TypeTraits> : public TypeTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(const Variant& src, TVMFFIAny* result) { - *result = src.ToAnyView().CopyToTVMFFIAny(); - } - - TVM_FFI_INLINE static void MoveToAny(Variant src, TVMFFIAny* result) { - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny()); - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { - return (TypeTraits::CheckAnyStrict(src) || ...); - } - - TVM_FFI_INLINE static Variant CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return Variant(Any(AnyView::CopyFromTVMFFIAny(*src))); - } - - TVM_FFI_INLINE static Variant MoveFromAnyAfterCheck(TVMFFIAny* src) { - return Variant(details::AnyUnsafe::MoveTVMFFIAnyToAny(src)); - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - // fast path, storage is already in the right type - if (CheckAnyStrict(src)) { - return CopyFromAnyViewAfterCheck(src); + TVM_FFI_INLINE static void CopyToAnyView(const Variant &src, TVMFFIAny *result) { + *result = src.ToAnyView().CopyToTVMFFIAny(); + } + + TVM_FFI_INLINE static void MoveToAny(Variant src, TVMFFIAny *result) { + *result = 
details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny()); + } + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { + return (TypeTraits::CheckAnyStrict(src) || ...); } - // More expensive path, try to convert to each type, in order of declaration - return TryVariantTypes(src); - } - - template - TVM_FFI_INLINE static std::optional> TryVariantTypes(const TVMFFIAny* src) { - if (auto opt_convert = TypeTraits::TryCastFromAnyView(src)) { - return Variant(*std::move(opt_convert)); + + TVM_FFI_INLINE static Variant CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { + return Variant(Any(AnyView::CopyFromTVMFFIAny(*src))); } - if constexpr (sizeof...(Rest) > 0) { - return TryVariantTypes(src); + + TVM_FFI_INLINE static Variant MoveFromAnyAfterCheck(TVMFFIAny *src) { + return Variant(details::AnyUnsafe::MoveTVMFFIAnyToAny(src)); + } + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny *src) { + // fast path, storage is already in the right type + if (CheckAnyStrict(src)) { + return CopyFromAnyViewAfterCheck(src); + } + // More expensive path, try to convert to each type, in order of declaration + return TryVariantTypes(src); + } + + template + TVM_FFI_INLINE static std::optional> TryVariantTypes(const TVMFFIAny *src) { + if (auto opt_convert = TypeTraits::TryCastFromAnyView(src)) { + return Variant(*std::move(opt_convert)); + } + if constexpr (sizeof...(Rest) > 0) { + return TryVariantTypes(src); + } + return std::nullopt; + } + + TVM_FFI_INLINE static std::string TypeStr() { return details::ContainerTypeStr("Variant"); } + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":"Variant","args":[)"; + const char *sep = ""; + ((oss << sep << details::TypeSchema::v(), sep = ","), ...); + oss << "]}"; + return oss.str(); } - return std::nullopt; - } - - 
TVM_FFI_INLINE static std::string TypeStr() { return details::ContainerTypeStr("Variant"); } - TVM_FFI_INLINE static std::string TypeSchema() { - std::ostringstream oss; - oss << R"({"type":"Variant","args":[)"; - const char* sep = ""; - ((oss << sep << details::TypeSchema::v(), sep = ","), ...); - oss << "]}"; - return oss.str(); - } }; template -TVM_FFI_INLINE size_t ObjectPtrHash::operator()(const Variant& a) const { - return std::hash()(a.GetObjectPtrForHashEqual()); +TVM_FFI_INLINE size_t ObjectPtrHash::operator()(const Variant &a) const { + return std::hash()(a.GetObjectPtrForHashEqual()); } template -TVM_FFI_INLINE bool ObjectPtrEqual::operator()(const Variant& a, - const Variant& b) const { - return a.GetObjectPtrForHashEqual() == b.GetObjectPtrForHashEqual(); +TVM_FFI_INLINE bool ObjectPtrEqual::operator()(const Variant &a, + const Variant &b) const { + return a.GetObjectPtrForHashEqual() == b.GetObjectPtrForHashEqual(); } namespace details { template inline constexpr bool type_contains_v, T> = (type_contains_v || ...); -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_VARIANT_H_ +} // namespace details +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_CONTAINER_VARIANT_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h index 10639bea3..f8311d673 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h @@ -38,8 +38,7 @@ #else #if defined(__APPLE__) || defined(_WIN32) #define TVM_FFI_LITTLE_ENDIAN 1 -#elif defined(__GLIBC__) || defined(__GNU_LIBRARY__) || defined(__ANDROID__) || \ - defined(__RISCV__) || defined(__MUSL__) +#elif defined(__GLIBC__) || defined(__GNU_LIBRARY__) || defined(__ANDROID__) || defined(__RISCV__) || defined(__MUSL__) #include #define TVM_FFI_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN) #elif 
defined(__FreeBSD__) || defined(__OpenBSD__) @@ -75,16 +74,16 @@ namespace ffi { * \note Always try pass in constant elem_bytes to enable * compiler optimization */ -inline void ByteSwap(void* data, size_t elem_bytes, size_t num_elems) { - for (size_t i = 0; i < num_elems; ++i) { - uint8_t* bptr = reinterpret_cast(data) + elem_bytes * i; - for (size_t j = 0; j < elem_bytes / 2; ++j) { - uint8_t v = bptr[elem_bytes - 1 - j]; - bptr[elem_bytes - 1 - j] = bptr[j]; - bptr[j] = v; +inline void ByteSwap(void *data, size_t elem_bytes, size_t num_elems) { + for (size_t i = 0; i < num_elems; ++i) { + uint8_t *bptr = reinterpret_cast(data) + elem_bytes * i; + for (size_t j = 0; j < elem_bytes / 2; ++j) { + uint8_t v = bptr[elem_bytes - 1 - j]; + bptr[elem_bytes - 1 - j] = bptr[j]; + bptr[j] = v; + } } - } } -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_ENDIAN_H_ +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_ENDIAN_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h index b09b3540a..852db4466 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h @@ -45,4 +45,4 @@ #define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL #endif -#endif // TVM_FFI_EXTRA_BASE_H_ +#endif // TVM_FFI_EXTRA_BASE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h index ac92e9f84..f23314a75 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h @@ -36,37 +36,36 @@ namespace ffi { * \return The base64 encoded string */ inline String Base64Encode(TVMFFIByteArray bytes) { - // encoding every 3 bytes into 4 characters - constexpr const char kEncodeTable[] = - 
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string encoded; - encoded.reserve(4 * (bytes.size + 2) / 3); + // encoding every 3 bytes into 4 characters + constexpr const char kEncodeTable[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + std::string encoded; + encoded.reserve(4 * (bytes.size + 2) / 3); - for (size_t i = 0; i < (bytes.size / 3) * 3; i += 3) { - int32_t buf[3]; - buf[0] = static_cast(static_cast(bytes.data[i])); - buf[1] = static_cast(static_cast(bytes.data[i + 1])); - buf[2] = static_cast(static_cast(bytes.data[i + 2])); - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); - encoded.push_back(kEncodeTable[((buf[1] << 2) | (buf[2] >> 6)) & 0x3F]); - encoded.push_back(kEncodeTable[buf[2] & 0x3F]); - } - if (bytes.size % 3 == 1) { - int32_t buf[1] = {static_cast(static_cast(bytes.data[bytes.size - 1]))}; - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[(buf[0] << 4) & 0x3F]); - encoded.push_back('='); - encoded.push_back('='); - } else if (bytes.size % 3 == 2) { - int32_t buf[2] = {static_cast(static_cast(bytes.data[bytes.size - 2])), - static_cast(static_cast(bytes.data[bytes.size - 1]))}; - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); - encoded.push_back(kEncodeTable[(buf[1] << 2) & 0x3F]); - encoded.push_back('='); - } - return String(encoded); + for (size_t i = 0; i < (bytes.size / 3) * 3; i += 3) { + int32_t buf[3]; + buf[0] = static_cast(static_cast(bytes.data[i])); + buf[1] = static_cast(static_cast(bytes.data[i + 1])); + buf[2] = static_cast(static_cast(bytes.data[i + 2])); + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); + encoded.push_back(kEncodeTable[((buf[1] << 2) | (buf[2] >> 6)) & 0x3F]); + encoded.push_back(kEncodeTable[buf[2] & 0x3F]); 
+ } + if (bytes.size % 3 == 1) { + int32_t buf[1] = {static_cast(static_cast(bytes.data[bytes.size - 1]))}; + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[(buf[0] << 4) & 0x3F]); + encoded.push_back('='); + encoded.push_back('='); + } else if (bytes.size % 3 == 2) { + int32_t buf[2] = {static_cast(static_cast(bytes.data[bytes.size - 2])), + static_cast(static_cast(bytes.data[bytes.size - 1]))}; + encoded.push_back(kEncodeTable[buf[0] >> 2]); + encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); + encoded.push_back(kEncodeTable[(buf[1] << 2) & 0x3F]); + encoded.push_back('='); + } + return String(encoded); } /*! @@ -74,8 +73,8 @@ inline String Base64Encode(TVMFFIByteArray bytes) { * \param data The bytes object to encode * \return The base64 encoded string */ -inline String Base64Encode(const Bytes& data) { - return Base64Encode(TVMFFIByteArray{data.data(), data.size()}); +inline String Base64Encode(const Bytes &data) { + return Base64Encode(TVMFFIByteArray{data.data(), data.size()}); } /*! 
@@ -84,48 +83,47 @@ inline String Base64Encode(const Bytes& data) { * \return The decoded byte array */ inline Bytes Base64Decode(TVMFFIByteArray bytes) { - constexpr const char kDecodeTable[] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 62, // '+' - 0, 0, 0, - 63, // '/' - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' - 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' - 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, - 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' - }; - std::string decoded; - decoded.reserve(bytes.size * 3 / 4); - if (bytes.size == 0) return Bytes(); - TVM_FFI_ICHECK(bytes.size % 4 == 0) << "invalid base64 encoding"; - // leverage this property to simplify decoding - static_assert('=' < sizeof(kDecodeTable) && kDecodeTable[static_cast('=')] == 0); - // base64 is always multiple of 4 bytes - for (size_t i = 0; i < bytes.size; i += 4) { - // decode every 4 characters into 24bits, each character contains 6 bits - // note that = is also decoded as 0, which is safe to skip - int32_t buf[4] = { - static_cast(static_cast(bytes.data[i])), - static_cast(static_cast(bytes.data[i + 1])), - static_cast(static_cast(bytes.data[i + 2])), - static_cast(static_cast(bytes.data[i + 3])), + constexpr const char kDecodeTable[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' }; - int32_t value_i24 = 
(static_cast(kDecodeTable[buf[0]]) << 18) | - (static_cast(kDecodeTable[buf[1]]) << 12) | - (static_cast(kDecodeTable[buf[2]]) << 6) | - static_cast(kDecodeTable[buf[3]]); - // unpack 24bits into 3 bytes, each contains 8 bits - decoded.push_back(static_cast((value_i24 >> 16) & 0xFF)); - if (buf[2] != '=') { - decoded.push_back(static_cast((value_i24 >> 8) & 0xFF)); + std::string decoded; + decoded.reserve(bytes.size * 3 / 4); + if (bytes.size == 0) { + return Bytes(); } - if (buf[3] != '=') { - decoded.push_back(static_cast(value_i24 & 0xFF)); + TVM_FFI_ICHECK(bytes.size % 4 == 0) << "invalid base64 encoding"; + // leverage this property to simplify decoding + static_assert('=' < sizeof(kDecodeTable) && kDecodeTable[static_cast('=')] == 0); + // base64 is always multiple of 4 bytes + for (size_t i = 0; i < bytes.size; i += 4) { + // decode every 4 characters into 24bits, each character contains 6 bits + // note that = is also decoded as 0, which is safe to skip + int32_t buf[4] = { + static_cast(static_cast(bytes.data[i])), + static_cast(static_cast(bytes.data[i + 1])), + static_cast(static_cast(bytes.data[i + 2])), + static_cast(static_cast(bytes.data[i + 3])), + }; + int32_t value_i24 = (static_cast(kDecodeTable[buf[0]]) << 18) | (static_cast(kDecodeTable[buf[1]]) << 12) | (static_cast(kDecodeTable[buf[2]]) << 6) | static_cast(kDecodeTable[buf[3]]); + // unpack 24bits into 3 bytes, each contains 8 bits + decoded.push_back(static_cast((value_i24 >> 16) & 0xFF)); + if (buf[2] != '=') { + decoded.push_back(static_cast((value_i24 >> 8) & 0xFF)); + } + if (buf[3] != '=') { + decoded.push_back(static_cast(value_i24 & 0xFF)); + } } - } - return Bytes(decoded); + return Bytes(decoded); } /*! 
@@ -133,10 +131,10 @@ inline Bytes Base64Decode(TVMFFIByteArray bytes) { * \param data The base64 encoded string to decode * \return The decoded byte array */ -inline Bytes Base64Decode(const String& data) { - return Base64Decode(TVMFFIByteArray{data.data(), data.size()}); +inline Bytes Base64Decode(const String &data) { + return Base64Decode(TVMFFIByteArray{data.data(), data.size()}); } -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_BASE64_H_ +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_BASE64_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h index 810fa064c..be7cad2d8 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h @@ -37,18 +37,18 @@ namespace ffi { * * \param stmt The CUDA runtime API call to check. */ -#define TVM_FFI_CHECK_CUDA_ERROR(stmt) \ - do { \ - cudaError_t __err = (stmt); \ - if (__err != cudaSuccess) { \ - const char* __err_name = cudaGetErrorName(__err); \ - const char* __err_str = cudaGetErrorString(__err); \ - TVM_FFI_THROW(RuntimeError) << "CUDA Runtime Error: " << __err_name << " (" \ - << static_cast(__err) << "): " << __err_str; \ - } \ - } while (0) +#define TVM_FFI_CHECK_CUDA_ERROR(stmt) \ + do { \ + cudaError_t __err = (stmt); \ + if (__err != cudaSuccess) { \ + const char *__err_name = cudaGetErrorName(__err); \ + const char *__err_str = cudaGetErrorString(__err); \ + TVM_FFI_THROW(RuntimeError) << "CUDA Runtime Error: " << __err_name << " (" \ + << static_cast(__err) << "): " << __err_str; \ + } \ + } while (0) -} // namespace ffi -} // namespace tvm +} // namespace ffi +} // namespace tvm -#endif // TVM_FFI_EXTRA_CUDA_BASE_H_ +#endif // TVM_FFI_EXTRA_CUDA_BASE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h 
b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h index 72eadd2ea..10da7e532 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h @@ -49,24 +49,24 @@ namespace ffi { * from 1, 2, or 3 dimensions. */ struct dim3 { - /*! \brief X dimension (number of blocks in x-direction or threads in x-direction) */ - unsigned int x; - /*! \brief Y dimension (number of blocks in y-direction or threads in y-direction) */ - unsigned int y; - /*! \brief Z dimension (number of blocks in z-direction or threads in z-direction) */ - unsigned int z; + /*! \brief X dimension (number of blocks in x-direction or threads in x-direction) */ + unsigned int x; + /*! \brief Y dimension (number of blocks in y-direction or threads in y-direction) */ + unsigned int y; + /*! \brief Z dimension (number of blocks in z-direction or threads in z-direction) */ + unsigned int z; - /*! \brief Default constructor initializes to (1, 1, 1) */ - dim3() : x(1), y(1), z(1) {} + /*! \brief Default constructor initializes to (1, 1, 1) */ + dim3() : x(1), y(1), z(1) {} - /*! \brief Construct with x dimension, y and z default to 1 */ - explicit dim3(unsigned int x_) : x(x_), y(1), z(1) {} + /*! \brief Construct with x dimension, y and z default to 1 */ + explicit dim3(unsigned int x_) : x(x_), y(1), z(1) {} - /*! \brief Construct with x and y dimensions, z defaults to 1 */ - dim3(unsigned int x_, unsigned int y_) : x(x_), y(y_), z(1) {} + /*! \brief Construct with x and y dimensions, z defaults to 1 */ + dim3(unsigned int x_, unsigned int y_) : x(x_), y(y_), z(1) {} - /*! \brief Construct with all three dimensions */ - dim3(unsigned int x_, unsigned int y_, unsigned int z_) : x(x_), y(y_), z(z_) {} + /*! \brief Construct with all three dimensions */ + dim3(unsigned int x_, unsigned int y_, unsigned int z_) : x(x_), y(y_), z(z_) {} }; /*! 
@@ -178,18 +178,18 @@ struct dim3 { * \see CubinModule * \see CubinKernel */ -#define TVM_FFI_EMBED_CUBIN(name) \ - extern "C" const char __tvm_ffi__cubin_##name[]; \ - extern "C" const char __tvm_ffi__cubin_##name##_end[]; \ - namespace { \ - struct EmbedCubinModule_##name { \ - tvm::ffi::CubinModule mod{__tvm_ffi__cubin_##name}; \ - static EmbedCubinModule_##name* Global() { \ - static EmbedCubinModule_##name inst; \ - return &inst; \ - } \ - }; \ - } /* anonymous namespace */ +#define TVM_FFI_EMBED_CUBIN(name) \ + extern "C" const char __tvm_ffi__cubin_##name[]; \ + extern "C" const char __tvm_ffi__cubin_##name##_end[]; \ + namespace { \ + struct EmbedCubinModule_##name { \ + tvm::ffi::CubinModule mod{__tvm_ffi__cubin_##name}; \ + static EmbedCubinModule_##name *Global() { \ + static EmbedCubinModule_##name inst; \ + return &inst; \ + } \ + }; \ + } /* anonymous namespace */ /*! * \brief Macro to get a kernel from an embedded CUBIN module. @@ -244,7 +244,7 @@ struct dim3 { * \see CubinKernel::Launch */ #define TVM_FFI_EMBED_CUBIN_GET_KERNEL(name, kernel_name) \ - (EmbedCubinModule_##name::Global()->mod[kernel_name]) + (EmbedCubinModule_##name::Global()->mod[kernel_name]) // Forward declaration class CubinKernel; @@ -284,102 +284,102 @@ class CubinKernel; * \see CubinKernel for kernel launching */ class CubinModule { - public: - /*! - * \brief Load CUBIN module from memory. - * - * \param bytes CUBIN binary data as a Bytes object. - */ - explicit CubinModule(const Bytes& bytes) { - TVM_FFI_CHECK_CUDA_ERROR( - cudaLibraryLoadData(&library_, bytes.data(), nullptr, nullptr, 0, nullptr, nullptr, 0)); - } - - /*! - * \brief Load CUBIN module from raw memory buffer. - * - * \param code Pointer to CUBIN binary data. - * \note The `code` buffer points to an ELF image. - */ - explicit CubinModule(const char* code) { - TVM_FFI_CHECK_CUDA_ERROR( - cudaLibraryLoadData(&library_, code, nullptr, nullptr, 0, nullptr, nullptr, 0)); - } - - /*! 
\brief Destructor unloads the library */ - ~CubinModule() { - if (library_ != nullptr) { - cudaLibraryUnload(library_); +public: + /*! + * \brief Load CUBIN module from memory. + * + * \param bytes CUBIN binary data as a Bytes object. + */ + explicit CubinModule(const Bytes &bytes) { + TVM_FFI_CHECK_CUDA_ERROR( + cudaLibraryLoadData(&library_, bytes.data(), nullptr, nullptr, 0, nullptr, nullptr, 0)); } - } - - /*! - * \brief Get a kernel function from the module by name. - * - * \param name Name of the kernel function. - * \return CubinKernel object representing the loaded kernel. - */ - CubinKernel GetKernel(const char* name); - - /*! - * \brief Get a kernel function from the module by name with maximum dynamic shared memory. - * - * \param name Name of the kernel function. - * \param dynamic_smem_max Maximum dynamic shared memory in bytes to set for this kernel. - * -1 (default) means maximum available dynamic shared memory - * (device max - static shared memory used by kernel). - * \return CubinKernel object representing the loaded kernel. - */ - CubinKernel GetKernelWithMaxDynamicSharedMemory(const char* name, int64_t dynamic_smem_max); - - /*! - * \brief Operator[] for convenient kernel access. - * - * It's equivalent to calling GetKernel(name, -1). - * - * \param name Name of the kernel function. - * \return CubinKernel object representing the loaded kernel. - */ - CubinKernel operator[](const char* name); - - /*! \brief Get the underlying cudaLibrary_t handle */ - cudaLibrary_t GetHandle() const { return library_; } - - // Non-copyable - CubinModule(const CubinModule&) = delete; - CubinModule& operator=(const CubinModule&) = delete; - - /*! - * \brief Move constructor for CubinModule. - * - * Transfers ownership of the CUDA library handle from another CubinModule instance. - * - * \param other The source CubinModule to move from (will be left in an empty state). 
- */ - CubinModule(CubinModule&& other) noexcept : library_(other.library_) { other.library_ = nullptr; } - - /*! - * \brief Move assignment operator for CubinModule. - * - * Transfers ownership of the CUDA library handle from another CubinModule instance. - * Cleans up any existing library handle in this instance before taking ownership. - * - * \param other The source CubinModule to move from (will be left in an empty state). - * \return Reference to this CubinModule. - */ - CubinModule& operator=(CubinModule&& other) noexcept { - if (this != &other) { - if (library_ != nullptr) { - cudaLibraryUnload(library_); - } - library_ = other.library_; - other.library_ = nullptr; + + /*! + * \brief Load CUBIN module from raw memory buffer. + * + * \param code Pointer to CUBIN binary data. + * \note The `code` buffer points to an ELF image. + */ + explicit CubinModule(const char *code) { + TVM_FFI_CHECK_CUDA_ERROR( + cudaLibraryLoadData(&library_, code, nullptr, nullptr, 0, nullptr, nullptr, 0)); + } + + /*! \brief Destructor unloads the library */ + ~CubinModule() { + if (library_ != nullptr) { + cudaLibraryUnload(library_); + } + } + + /*! + * \brief Get a kernel function from the module by name. + * + * \param name Name of the kernel function. + * \return CubinKernel object representing the loaded kernel. + */ + CubinKernel GetKernel(const char *name); + + /*! + * \brief Get a kernel function from the module by name with maximum dynamic shared memory. + * + * \param name Name of the kernel function. + * \param dynamic_smem_max Maximum dynamic shared memory in bytes to set for this kernel. + * -1 (default) means maximum available dynamic shared memory + * (device max - static shared memory used by kernel). + * \return CubinKernel object representing the loaded kernel. + */ + CubinKernel GetKernelWithMaxDynamicSharedMemory(const char *name, int64_t dynamic_smem_max); + + /*! + * \brief Operator[] for convenient kernel access. 
+ * + * It's equivalent to calling GetKernel(name, -1). + * + * \param name Name of the kernel function. + * \return CubinKernel object representing the loaded kernel. + */ + CubinKernel operator[](const char *name); + + /*! \brief Get the underlying cudaLibrary_t handle */ + cudaLibrary_t GetHandle() const { return library_; } + + // Non-copyable + CubinModule(const CubinModule &) = delete; + CubinModule &operator=(const CubinModule &) = delete; + + /*! + * \brief Move constructor for CubinModule. + * + * Transfers ownership of the CUDA library handle from another CubinModule instance. + * + * \param other The source CubinModule to move from (will be left in an empty state). + */ + CubinModule(CubinModule &&other) noexcept : library_(other.library_) { other.library_ = nullptr; } + + /*! + * \brief Move assignment operator for CubinModule. + * + * Transfers ownership of the CUDA library handle from another CubinModule instance. + * Cleans up any existing library handle in this instance before taking ownership. + * + * \param other The source CubinModule to move from (will be left in an empty state). + * \return Reference to this CubinModule. + */ + CubinModule &operator=(CubinModule &&other) noexcept { + if (this != &other) { + if (library_ != nullptr) { + cudaLibraryUnload(library_); + } + library_ = other.library_; + other.library_ = nullptr; + } + return *this; } - return *this; - } - private: - cudaLibrary_t library_ = nullptr; +private: + cudaLibrary_t library_ = nullptr; }; /*! @@ -414,191 +414,191 @@ class CubinModule { * \see dim3 for grid/block dimension specification */ class CubinKernel { - public: - /*! - * \brief Construct a CubinKernel from a library and kernel name. - * - * \param library The cudaLibrary_t handle. - * \param name Name of the kernel function. - */ - CubinKernel(cudaLibrary_t library, const char* name) { - TVM_FFI_CHECK_CUDA_ERROR(cudaLibraryGetKernel(&kernel_, library, name)); - } - - /*! 
\brief Destructor (kernel handle doesn't need explicit cleanup) */ - ~CubinKernel() = default; - - /*! - * \brief Launch the kernel with specified parameters. - * - * This function launches the kernel on the current CUDA context/device using - * the CUDA Runtime API. The kernel executes asynchronously on the specified stream. - * - * \par Argument Preparation - * The `args` array must contain pointers to the actual argument values, not the - * values themselves. For example: - * \code{.cpp} - * void* data_ptr = tensor.data_ptr(); - * int64_t size = 100; - * void* args[] = {&data_ptr, &size}; // Note: addresses of the variables - * \endcode - * - * \par Launch Configuration - * Grid and block dimensions determine the kernel's parallelism: - * - Grid: Number of thread blocks (can be 1D, 2D, or 3D) - * - Block: Number of threads per block (can be 1D, 2D, or 3D) - * - Total threads = grid.x * grid.y * grid.z * block.x * block.y * block.z - * - * \par Error Checking - * Always check the returned cudaError_t: - * \code{.cpp} - * cudaError_t result = kernel.Launch(args, grid, block, stream); - * TVM_FFI_CHECK_CUDA_ERROR(result); - * \endcode - * - * \param args Array of pointers to kernel arguments (must point to actual values). - * \param grid Grid dimensions (number of blocks in x, y, z). - * \param block Block dimensions (threads per block in x, y, z). - * \param stream CUDA stream to launch the kernel on (use 0 for default stream). - * \param dyn_smem_bytes Dynamic shared memory size in bytes (default: 0). - * \return cudaError_t error code from cudaLaunchKernel (cudaSuccess on success). - * - * \note The kernel executes asynchronously. Use cudaStreamSynchronize() or - * cudaDeviceSynchronize() to wait for completion if needed. 
- */ - cudaError_t Launch(void** args, dim3 grid, dim3 block, cudaStream_t stream, - uint32_t dyn_smem_bytes = 0) { - // Cast cudaKernel_t to const void* for use with cudaLaunchKernel - // The Runtime API accepts cudaKernel_t directly as a function pointer - auto kernel = reinterpret_cast(kernel_); - return cudaLaunchKernel(kernel, {grid.x, grid.y, grid.z}, {block.x, block.y, block.z}, args, - dyn_smem_bytes, stream); - } - - /*! \brief Get the underlying cudaKernel_t handle */ - cudaKernel_t GetHandle() const { return kernel_; } - - // Non-copyable - CubinKernel(const CubinKernel&) = delete; - CubinKernel& operator=(const CubinKernel&) = delete; - - /*! - * \brief Move constructor for CubinKernel. - * - * Transfers ownership of the CUDA kernel handle from another CubinKernel instance. - * - * \param other The source CubinKernel to move from (will be left in an empty state). - */ - CubinKernel(CubinKernel&& other) noexcept : kernel_(other.kernel_) { other.kernel_ = nullptr; } - - /*! - * \brief Move assignment operator for CubinKernel. - * - * Transfers ownership of the CUDA kernel handle from another CubinKernel instance. - * - * \param other The source CubinKernel to move from (will be left in an empty state). - * \return Reference to this CubinKernel. - */ - CubinKernel& operator=(CubinKernel&& other) noexcept { - if (this != &other) { - kernel_ = other.kernel_; - other.kernel_ = nullptr; +public: + /*! + * \brief Construct a CubinKernel from a library and kernel name. + * + * \param library The cudaLibrary_t handle. + * \param name Name of the kernel function. + */ + CubinKernel(cudaLibrary_t library, const char *name) { + TVM_FFI_CHECK_CUDA_ERROR(cudaLibraryGetKernel(&kernel_, library, name)); } - return *this; - } - - private: - /*! - * \brief Set maximum dynamic shared memory for this kernel across all devices. - * - * This method configures the maximum dynamic shared memory that can be allocated - * when launching this kernel. 
It must be called after the kernel is loaded. - * - * \param dynamic_smem_max Maximum dynamic shared memory in bytes to set. - * -1 (default) means maximum available dynamic shared memory, - * which is computed as (device max shared memory - static shared memory). - * For -1, the method queries the kernel's static shared memory usage - * and sets the attribute to the remaining available shared memory. - * - * \note This sets the maximum cap but doesn't force allocation. The actual dynamic - * shared memory used is controlled by the dyn_smem_bytes parameter in Launch(). - * \note This method attempts to set the attribute for all available devices and will - * only throw an error if it fails for ALL devices. - */ - void SetMaxDynamicSharedMemory(int64_t dynamic_smem_max = -1) { - int device_count = 0; - cudaError_t err = cudaGetDeviceCount(&device_count); - if (err != cudaSuccess || device_count == 0) { - return; // No devices available, nothing to configure + + /*! \brief Destructor (kernel handle doesn't need explicit cleanup) */ + ~CubinKernel() = default; + + /*! + * \brief Launch the kernel with specified parameters. + * + * This function launches the kernel on the current CUDA context/device using + * the CUDA Runtime API. The kernel executes asynchronously on the specified stream. + * + * \par Argument Preparation + * The `args` array must contain pointers to the actual argument values, not the + * values themselves. 
For example: + * \code{.cpp} + * void* data_ptr = tensor.data_ptr(); + * int64_t size = 100; + * void* args[] = {&data_ptr, &size}; // Note: addresses of the variables + * \endcode + * + * \par Launch Configuration + * Grid and block dimensions determine the kernel's parallelism: + * - Grid: Number of thread blocks (can be 1D, 2D, or 3D) + * - Block: Number of threads per block (can be 1D, 2D, or 3D) + * - Total threads = grid.x * grid.y * grid.z * block.x * block.y * block.z + * + * \par Error Checking + * Always check the returned cudaError_t: + * \code{.cpp} + * cudaError_t result = kernel.Launch(args, grid, block, stream); + * TVM_FFI_CHECK_CUDA_ERROR(result); + * \endcode + * + * \param args Array of pointers to kernel arguments (must point to actual values). + * \param grid Grid dimensions (number of blocks in x, y, z). + * \param block Block dimensions (threads per block in x, y, z). + * \param stream CUDA stream to launch the kernel on (use 0 for default stream). + * \param dyn_smem_bytes Dynamic shared memory size in bytes (default: 0). + * \return cudaError_t error code from cudaLaunchKernel (cudaSuccess on success). + * + * \note The kernel executes asynchronously. Use cudaStreamSynchronize() or + * cudaDeviceSynchronize() to wait for completion if needed. 
+ */ + cudaError_t Launch(void **args, dim3 grid, dim3 block, cudaStream_t stream, + uint32_t dyn_smem_bytes = 0) { + // Cast cudaKernel_t to const void* for use with cudaLaunchKernel + // The Runtime API accepts cudaKernel_t directly as a function pointer + auto kernel = reinterpret_cast(kernel_); + return cudaLaunchKernel(kernel, {grid.x, grid.y, grid.z}, {block.x, block.y, block.z}, args, + dyn_smem_bytes, stream); } - bool any_success = false; - for (int device_id = 0; device_id < device_count; ++device_id) { - // Query device's maximum shared memory per block - int max_shared_mem = 0; - err = cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlock, device_id); - if (err != cudaSuccess) { - continue; // Skip this device if we can't get its attribute - } - - int shared_mem_to_set; - if (dynamic_smem_max == -1) { - // Query the kernel's static shared memory usage - cudaFuncAttributes func_attr; - - // According to the documentation, we can use cudaFuncGetAttributes to get the attributes of - // cudaKernel_t returned by cudaLibraryGetKernel, just cast the kernel_ to const void* - err = cudaFuncGetAttributes(&func_attr, reinterpret_cast(kernel_)); - if (err != cudaSuccess) { - continue; // Skip this device if we can't get kernel attributes + /*! \brief Get the underlying cudaKernel_t handle */ + cudaKernel_t GetHandle() const { return kernel_; } + + // Non-copyable + CubinKernel(const CubinKernel &) = delete; + CubinKernel &operator=(const CubinKernel &) = delete; + + /*! + * \brief Move constructor for CubinKernel. + * + * Transfers ownership of the CUDA kernel handle from another CubinKernel instance. + * + * \param other The source CubinKernel to move from (will be left in an empty state). + */ + CubinKernel(CubinKernel &&other) noexcept : kernel_(other.kernel_) { other.kernel_ = nullptr; } + + /*! + * \brief Move assignment operator for CubinKernel. + * + * Transfers ownership of the CUDA kernel handle from another CubinKernel instance. 
+ * + * \param other The source CubinKernel to move from (will be left in an empty state). + * \return Reference to this CubinKernel. + */ + CubinKernel &operator=(CubinKernel &&other) noexcept { + if (this != &other) { + kernel_ = other.kernel_; + other.kernel_ = nullptr; } - - // Calculate available dynamic shared memory: - // device max shared memory - static shared memory used by kernel - int64_t static_shared = static_cast(func_attr.sharedSizeBytes); - int64_t max_shared = static_cast(max_shared_mem); - int64_t available = max_shared - static_shared; - shared_mem_to_set = (available > 0) ? static_cast(available) : 0; - } else { - shared_mem_to_set = static_cast(dynamic_smem_max); - } - - // Set the maximum dynamic shared memory size for this device - err = cudaKernelSetAttributeForDevice(kernel_, cudaFuncAttributeMaxDynamicSharedMemorySize, - shared_mem_to_set, device_id); - if (err == cudaSuccess) { - any_success = true; - } - // Don't error out for individual device failures - user may only use some GPUs + return *this; } - // Only error out if setting failed for ALL devices - if (!any_success && device_count > 0) { - TVM_FFI_THROW(RuntimeError) << "Failed to set dynamic shared memory attribute for any device"; +private: + /*! + * \brief Set maximum dynamic shared memory for this kernel across all devices. + * + * This method configures the maximum dynamic shared memory that can be allocated + * when launching this kernel. It must be called after the kernel is loaded. + * + * \param dynamic_smem_max Maximum dynamic shared memory in bytes to set. + * -1 (default) means maximum available dynamic shared memory, + * which is computed as (device max shared memory - static shared memory). + * For -1, the method queries the kernel's static shared memory usage + * and sets the attribute to the remaining available shared memory. + * + * \note This sets the maximum cap but doesn't force allocation. 
The actual dynamic + * shared memory used is controlled by the dyn_smem_bytes parameter in Launch(). + * \note This method attempts to set the attribute for all available devices and will + * only throw an error if it fails for ALL devices. + */ + void SetMaxDynamicSharedMemory(int64_t dynamic_smem_max = -1) { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + return; // No devices available, nothing to configure + } + + bool any_success = false; + for (int device_id = 0; device_id < device_count; ++device_id) { + // Query device's maximum shared memory per block + int max_shared_mem = 0; + err = cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlock, device_id); + if (err != cudaSuccess) { + continue; // Skip this device if we can't get its attribute + } + + int shared_mem_to_set; + if (dynamic_smem_max == -1) { + // Query the kernel's static shared memory usage + cudaFuncAttributes func_attr; + + // According to the documentation, we can use cudaFuncGetAttributes to get the attributes of + // cudaKernel_t returned by cudaLibraryGetKernel, just cast the kernel_ to const void* + err = cudaFuncGetAttributes(&func_attr, reinterpret_cast(kernel_)); + if (err != cudaSuccess) { + continue; // Skip this device if we can't get kernel attributes + } + + // Calculate available dynamic shared memory: + // device max shared memory - static shared memory used by kernel + int64_t static_shared = static_cast(func_attr.sharedSizeBytes); + int64_t max_shared = static_cast(max_shared_mem); + int64_t available = max_shared - static_shared; + shared_mem_to_set = (available > 0) ? 
static_cast(available) : 0; + } else { + shared_mem_to_set = static_cast(dynamic_smem_max); + } + + // Set the maximum dynamic shared memory size for this device + err = cudaKernelSetAttributeForDevice(kernel_, cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem_to_set, device_id); + if (err == cudaSuccess) { + any_success = true; + } + // Don't error out for individual device failures - user may only use some GPUs + } + + // Only error out if setting failed for ALL devices + if (!any_success && device_count > 0) { + TVM_FFI_THROW(RuntimeError) << "Failed to set dynamic shared memory attribute for any device"; + } } - } - cudaKernel_t kernel_ = nullptr; + cudaKernel_t kernel_ = nullptr; - friend class CubinModule; + friend class CubinModule; }; // Implementation of CubinModule methods that return CubinKernel -inline CubinKernel CubinModule::GetKernelWithMaxDynamicSharedMemory(const char* name, +inline CubinKernel CubinModule::GetKernelWithMaxDynamicSharedMemory(const char *name, int64_t dynamic_smem_max = -1) { - auto kernel = CubinKernel(library_, name); - kernel.SetMaxDynamicSharedMemory(dynamic_smem_max); - return kernel; + auto kernel = CubinKernel(library_, name); + kernel.SetMaxDynamicSharedMemory(dynamic_smem_max); + return kernel; } -inline CubinKernel CubinModule::GetKernel(const char* name) { - auto kernel = CubinKernel(library_, name); - return kernel; +inline CubinKernel CubinModule::GetKernel(const char *name) { + auto kernel = CubinKernel(library_, name); + return kernel; } -inline CubinKernel CubinModule::operator[](const char* name) { return GetKernel(name); } +inline CubinKernel CubinModule::operator[](const char *name) { return GetKernel(name); } -} // namespace ffi -} // namespace tvm +} // namespace ffi +} // namespace tvm -#endif // TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_ +#endif // TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h 
b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h index 083580f76..b55b5f3b2 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h @@ -42,33 +42,33 @@ namespace ffi { * \endcode */ struct CUDADeviceGuard { - CUDADeviceGuard() = delete; - /*! - * \brief Constructor from a device index, and store the original device index. - * \param device_index The device index to guard. - */ - explicit CUDADeviceGuard(int device_index) { - target_device_index_ = device_index; - TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&original_device_index_)); - if (target_device_index_ != original_device_index_) { - TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(device_index)); + CUDADeviceGuard() = delete; + /*! + * \brief Constructor from a device index, and store the original device index. + * \param device_index The device index to guard. + */ + explicit CUDADeviceGuard(int device_index) { + target_device_index_ = device_index; + TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&original_device_index_)); + if (target_device_index_ != original_device_index_) { + TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(device_index)); + } } - } - /*! - * \brief Destructor to set the current device index back to original one if different. - */ - ~CUDADeviceGuard() noexcept(false) { - if (original_device_index_ != target_device_index_) { - TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(original_device_index_)); + /*! + * \brief Destructor to set the current device index back to original one if different. 
+ */ + ~CUDADeviceGuard() noexcept(false) { + if (original_device_index_ != target_device_index_) { + TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(original_device_index_)); + } } - } - private: - int original_device_index_; - int target_device_index_; +private: + int original_device_index_; + int target_device_index_; }; -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_ +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h index 24ab2f0d8..c871ae827 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h @@ -65,7 +65,7 @@ using Array = ffi::Array; * * \return The parsed Any value. */ -TVM_FFI_EXTRA_CXX_API json::Value Parse(const String& json_str, String* error_msg = nullptr); +TVM_FFI_EXTRA_CXX_API json::Value Parse(const String &json_str, String *error_msg = nullptr); /*! * \brief Serialize an Any value into a JSON string. @@ -75,10 +75,10 @@ TVM_FFI_EXTRA_CXX_API json::Value Parse(const String& json_str, String* error_ms * If not specified, the output will be compact. * \return The output JSON string. 
*/ -TVM_FFI_EXTRA_CXX_API String Stringify(const json::Value& value, +TVM_FFI_EXTRA_CXX_API String Stringify(const json::Value &value, Optional indent = std::nullopt); -} // namespace json -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_JSON_H_ +} // namespace json +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_JSON_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h index 6af26c252..06fc7849d 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h @@ -41,165 +41,165 @@ class Module; * \sa Module */ class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { - public: - /*! - * \return The per module type key. - * \note This key is used to for serializing custom modules. - */ - virtual const char* kind() const = 0; - /*! - * \brief Get the property mask of the module. - * \return The property mask of the module. - * - * \sa Module::ModulePropertyMask - */ - virtual int GetPropertyMask() const { return 0b000; } - /*! - * \brief Get a ffi::Function from the module. - * \param name The name of the function. - * \return The function. - */ - virtual Optional GetFunction(const String& name) = 0; - /*! - * \brief Returns true if this module has a definition for a function of \p name. - * - * Note that even if this function returns true the corresponding \p GetFunction result - * may be nullptr if the function is not yet callable without further compilation. - * - * The default implementation just checks if \p GetFunction is non-null. - * \param name The name of the function. - * \return True if the module implements the function, false otherwise. - */ - virtual bool ImplementsFunction(const String& name) { return GetFunction(name).defined(); } - /*! - * \brief Get the docstring of the function, if available. 
- * \param name The name of the function. - * \return The documentation string if available, nullopt otherwise. - * - * \sa GetFunctionMetadata, TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC - */ - virtual Optional GetFunctionDoc(const String& name) { return std::nullopt; } - // Rationale: We separate the docstring from the metadata since docstrings - // can be unstructured and sometimes large, while metadata can be focused - // on storing structured information. - /*! - * \brief Get the metadata of the function, if available. - * \param name The name of the function. - * \return The metadata as JSON string if available, nullopt otherwise. - * - * \code - * Module mod = Module::LoadFromFile("lib.so"); - * Optional metadata = mod->GetFunctionMetadata("my_func"); - * if (metadata.has_value()) { - * // Parse JSON: {"type_schema": "..."} - * validate_signature(*metadata); - * } - * \endcode - * - * \sa GetFunctionDoc, TVM_FFI_DLL_EXPORT_TYPED_FUNC - */ - virtual Optional GetFunctionMetadata(const String& name) { return std::nullopt; } - /*! - * \brief Write the current module to file with given format (for further compilation). - * - * \param file_name The file to be saved to. - * \param format The format of the file. - * - * \note This function is mainly used by modules that - */ - virtual void WriteToFile(const String& file_name, const String& format) const { - TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support WriteToFile"; - } - /*! - * \brief Get the possible write formats of the module, when available. - * \return Possible write formats when available. - */ - virtual Array GetWriteFormats() const { return Array(); } - /*! - * \brief Serialize the the module to bytes. - * \return The serialized module. - */ - virtual Bytes SaveToBytes() const { - TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support SaveToBytes"; - TVM_FFI_UNREACHABLE(); - } - /*! - * \brief Get the source code of module, when available. 
- * \param format Format of the source code, can be empty by default. - * \return Possible source code when available, or empty string if not available. - */ - virtual String InspectSource(const String& format) const { return String(); } - /*! - * \brief Import another module. - * \param other The module to import. - */ - virtual void ImportModule(const Module& other); - /*! - * \brief Clear all imported modules. - */ - virtual void ClearImports(); - /*! - * \brief Overloaded function to optionally query from imports. - * \param name The name of the function. - * \param query_imports Whether to query imported modules. - * \return The function. - */ - Optional GetFunction(const String& name, bool query_imports); - /*! - * \brief Overloaded function to optionally query from imports. - * \param name The name of the function. - * \param query_imports Whether to query imported modules. - * \return True if the module implements the function, false otherwise. - */ - bool ImplementsFunction(const String& name, bool query_imports); - /*! - * \brief Get the function docstring of the function if available. - * \param name The name of the function. - * \param query_imports Whether to also query modules imported by this module. - * \return The documentation string if available, nullopt otherwise. - * - * \sa GetFunctionMetadata - */ - Optional GetFunctionDoc(const String& name, bool query_imports); - /*! - * \brief Get the function metadata of the function if available. - * \param name The name of the function. - * \param query_imports Whether to also query modules imported by this module. - * \return The metadata as JSON string if available, nullopt otherwise. - * - * \sa GetFunctionDoc - */ - Optional GetFunctionMetadata(const String& name, bool query_imports); - /*! - * \brief Get the imports of the module. - * \return The imports of the module. - * \note Note the signature is not part of the public API. 
- */ - const Array& imports() const { return this->imports_; } +public: + /*! + * \return The per module type key. + * \note This key is used to for serializing custom modules. + */ + virtual const char *kind() const = 0; + /*! + * \brief Get the property mask of the module. + * \return The property mask of the module. + * + * \sa Module::ModulePropertyMask + */ + virtual int GetPropertyMask() const { return 0b000; } + /*! + * \brief Get a ffi::Function from the module. + * \param name The name of the function. + * \return The function. + */ + virtual Optional GetFunction(const String &name) = 0; + /*! + * \brief Returns true if this module has a definition for a function of \p name. + * + * Note that even if this function returns true the corresponding \p GetFunction result + * may be nullptr if the function is not yet callable without further compilation. + * + * The default implementation just checks if \p GetFunction is non-null. + * \param name The name of the function. + * \return True if the module implements the function, false otherwise. + */ + virtual bool ImplementsFunction(const String &name) { return GetFunction(name).defined(); } + /*! + * \brief Get the docstring of the function, if available. + * \param name The name of the function. + * \return The documentation string if available, nullopt otherwise. + * + * \sa GetFunctionMetadata, TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC + */ + virtual Optional GetFunctionDoc(const String &name) { return std::nullopt; } + // Rationale: We separate the docstring from the metadata since docstrings + // can be unstructured and sometimes large, while metadata can be focused + // on storing structured information. + /*! + * \brief Get the metadata of the function, if available. + * \param name The name of the function. + * \return The metadata as JSON string if available, nullopt otherwise. 
+ * + * \code + * Module mod = Module::LoadFromFile("lib.so"); + * Optional metadata = mod->GetFunctionMetadata("my_func"); + * if (metadata.has_value()) { + * // Parse JSON: {"type_schema": "..."} + * validate_signature(*metadata); + * } + * \endcode + * + * \sa GetFunctionDoc, TVM_FFI_DLL_EXPORT_TYPED_FUNC + */ + virtual Optional GetFunctionMetadata(const String &name) { return std::nullopt; } + /*! + * \brief Write the current module to file with given format (for further compilation). + * + * \param file_name The file to be saved to. + * \param format The format of the file. + * + * \note This function is mainly used by modules that + */ + virtual void WriteToFile(const String &file_name, const String &format) const { + TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support WriteToFile"; + } + /*! + * \brief Get the possible write formats of the module, when available. + * \return Possible write formats when available. + */ + virtual Array GetWriteFormats() const { return Array(); } + /*! + * \brief Serialize the the module to bytes. + * \return The serialized module. + */ + virtual Bytes SaveToBytes() const { + TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support SaveToBytes"; + TVM_FFI_UNREACHABLE(); + } + /*! + * \brief Get the source code of module, when available. + * \param format Format of the source code, can be empty by default. + * \return Possible source code when available, or empty string if not available. + */ + virtual String InspectSource(const String &format) const { return String(); } + /*! + * \brief Import another module. + * \param other The module to import. + */ + virtual void ImportModule(const Module &other); + /*! + * \brief Clear all imported modules. + */ + virtual void ClearImports(); + /*! + * \brief Overloaded function to optionally query from imports. + * \param name The name of the function. + * \param query_imports Whether to query imported modules. + * \return The function. 
+ */ + Optional GetFunction(const String &name, bool query_imports); + /*! + * \brief Overloaded function to optionally query from imports. + * \param name The name of the function. + * \param query_imports Whether to query imported modules. + * \return True if the module implements the function, false otherwise. + */ + bool ImplementsFunction(const String &name, bool query_imports); + /*! + * \brief Get the function docstring of the function if available. + * \param name The name of the function. + * \param query_imports Whether to also query modules imported by this module. + * \return The documentation string if available, nullopt otherwise. + * + * \sa GetFunctionMetadata + */ + Optional GetFunctionDoc(const String &name, bool query_imports); + /*! + * \brief Get the function metadata of the function if available. + * \param name The name of the function. + * \param query_imports Whether to also query modules imported by this module. + * \return The metadata as JSON string if available, nullopt otherwise. + * + * \sa GetFunctionDoc + */ + Optional GetFunctionMetadata(const String &name, bool query_imports); + /*! + * \brief Get the imports of the module. + * \return The imports of the module. + * \note Note the signature is not part of the public API. 
+ */ + const Array &imports() const { return this->imports_; } - struct InternalUnsafe; + struct InternalUnsafe; - /// \cond Doxygen_Suppress - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIModule; - static constexpr const bool _type_mutable = true; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIModule, ModuleObj, Object); - /// \endcond + /// \cond Doxygen_Suppress + static constexpr const int32_t _type_index = TypeIndex::kTVMFFIModule; + static constexpr const bool _type_mutable = true; + static const constexpr bool _type_final = true; + TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIModule, ModuleObj, Object); + /// \endcond - protected: - friend struct InternalUnsafe; +protected: + friend struct InternalUnsafe; - /*! - * \brief The modules that this module depends on. - * \note Use ObjectRef to avoid circular dep on Module. - */ - Array imports_; + /*! + * \brief The modules that this module depends on. + * \note Use ObjectRef to avoid circular dep on Module. + */ + Array imports_; - private: - /*! - * \brief cache used by TVMFFIModuleLookupFromImports - */ - Map import_lookup_cache_; +private: + /*! + * \brief cache used by TVMFFIModuleLookupFromImports + */ + Map import_lookup_cache_; }; /*! @@ -216,63 +216,63 @@ class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { * \sa ModuleObj which contains most of the function implementations. */ class Module : public ObjectRef { - public: - /*! - * \brief Property of ffi::Module - */ - enum ModulePropertyMask : int { +public: /*! - * \brief The module can be serialized to bytes. - * - * This prooperty indicates that module implements SaveToBytes. - * The system also registers a GlobalDef function - * `ffi.Module.load_from_bytes.` with signature (Bytes) -> Module. + * \brief Property of ffi::Module */ - kBinarySerializable = 0b001, + enum ModulePropertyMask : int { + /*! + * \brief The module can be serialized to bytes. 
+ * + * This prooperty indicates that module implements SaveToBytes. + * The system also registers a GlobalDef function + * `ffi.Module.load_from_bytes.` with signature (Bytes) -> Module. + */ + kBinarySerializable = 0b001, + /*! + * \brief The module can directly get runnable functions. + * + * This property indicates that module implements GetFunction that returns + * runnable ffi::Functions. + */ + kRunnable = 0b010, + /*! + * \brief The module can be exported to a object file or source file that then be compiled. + * + * This property indicates that module implements WriteToFile with a given format + * that can be queried by GetLibExportFormat. + * + * Examples include modules that can be exported to .o, .cc, .cu files. + * + * Such modules can be exported, compiled and loaded back as a dynamic library module. + */ + kCompilationExportable = 0b100 + }; /*! - * \brief The module can directly get runnable functions. - * - * This property indicates that module implements GetFunction that returns - * runnable ffi::Functions. + * \brief Constructor from ObjectPtr. + * \param ptr The object pointer. */ - kRunnable = 0b010, + explicit Module(const ObjectPtr &ptr) : ObjectRef(ptr) { + TVM_FFI_ICHECK(ptr != nullptr); + } /*! - * \brief The module can be exported to a object file or source file that then be compiled. - * - * This property indicates that module implements WriteToFile with a given format - * that can be queried by GetLibExportFormat. - * - * Examples include modules that can be exported to .o, .cc, .cu files. - * - * Such modules can be exported, compiled and loaded back as a dynamic library module. + * \brief Load a module from file. + * \param file_name The name of the host function module. + * \note This function won't load the import relationship. + * Re-create import relationship by calling Import. + */ + TVM_FFI_EXTRA_CXX_API static Module LoadFromFile(const String &file_name); + /*! 
+ * \brief Query context symbols that is registered via TVMEnvRegisterSymbols. + * \param callback The callback to be called with the symbol name and address. + * \note This helper can be used to implement custom Module that needs to access context symbols. */ - kCompilationExportable = 0b100 - }; - /*! - * \brief Constructor from ObjectPtr. - * \param ptr The object pointer. - */ - explicit Module(const ObjectPtr& ptr) : ObjectRef(ptr) { - TVM_FFI_ICHECK(ptr != nullptr); - } - /*! - * \brief Load a module from file. - * \param file_name The name of the host function module. - * \note This function won't load the import relationship. - * Re-create import relationship by calling Import. - */ - TVM_FFI_EXTRA_CXX_API static Module LoadFromFile(const String& file_name); - /*! - * \brief Query context symbols that is registered via TVMEnvRegisterSymbols. - * \param callback The callback to be called with the symbol name and address. - * \note This helper can be used to implement custom Module that needs to access context symbols. - */ - TVM_FFI_EXTRA_CXX_API static void VisitContextSymbols( - const ffi::TypedFunction& callback); + TVM_FFI_EXTRA_CXX_API static void VisitContextSymbols( + const ffi::TypedFunction &callback); - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Module, ObjectRef, ModuleObj); - /// \endcond + /// \cond Doxygen_Suppress + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Module, ObjectRef, ModuleObj); + /// \endcond }; /* @@ -280,22 +280,22 @@ class Module : public ObjectRef { */ namespace symbol { /*!\ brief symbol prefix for tvm ffi related function symbols */ -constexpr const char* tvm_ffi_symbol_prefix = "__tvm_ffi_"; +constexpr const char *tvm_ffi_symbol_prefix = "__tvm_ffi_"; // Special symbols have one extra _ prefix to avoid conflict with user symbols /*! 
* \brief Default entry function of a library module is tvm_ffi_symbol_prefix + "main" */ -constexpr const char* tvm_ffi_main = "__tvm_ffi_main"; +constexpr const char *tvm_ffi_main = "__tvm_ffi_main"; /*! \brief Global variable to store context pointer for a library module. */ -constexpr const char* tvm_ffi_library_ctx = "__tvm_ffi__library_ctx"; +constexpr const char *tvm_ffi_library_ctx = "__tvm_ffi__library_ctx"; /*! \brief Global variable to store binary data alongside a library module. */ -constexpr const char* tvm_ffi_library_bin = "__tvm_ffi__library_bin"; +constexpr const char *tvm_ffi_library_bin = "__tvm_ffi__library_bin"; /*! \brief Optional metadata prefix of a symbol. */ -constexpr const char* tvm_ffi_metadata_prefix = "__tvm_ffi__metadata_"; +constexpr const char *tvm_ffi_metadata_prefix = "__tvm_ffi__metadata_"; /*! \brief Optional documentation prefix of a symbol. */ -constexpr const char* tvm_ffi_doc_prefix = "__tvm_ffi__doc_"; -} // namespace symbol -} // namespace ffi -} // namespace tvm +constexpr const char *tvm_ffi_doc_prefix = "__tvm_ffi__doc_"; +} // namespace symbol +} // namespace ffi +} // namespace tvm -#endif // TVM_FFI_EXTRA_MODULE_H_ +#endif // TVM_FFI_EXTRA_MODULE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h index b5aa2891a..3a726504f 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h @@ -54,7 +54,7 @@ namespace ffi { * \param metadata Extra metadata attached to "metadata" field of the JSON object. * \return The serialized JSON value. 
*/ -TVM_FFI_EXTRA_CXX_API json::Value ToJSONGraph(const Any& value, const Any& metadata = Any(nullptr)); +TVM_FFI_EXTRA_CXX_API json::Value ToJSONGraph(const Any &value, const Any &metadata = Any(nullptr)); /** * \brief Deserialize a JSON that stores the object graph to an ffi::Any value. @@ -65,8 +65,8 @@ TVM_FFI_EXTRA_CXX_API json::Value ToJSONGraph(const Any& value, const Any& metad * \param value The JSON value to deserialize. * \return The deserialized object graph. */ -TVM_FFI_EXTRA_CXX_API Any FromJSONGraph(const json::Value& value); +TVM_FFI_EXTRA_CXX_API Any FromJSONGraph(const json::Value &value); -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_SERIALIZATION_H_ +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_SERIALIZATION_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h index ec960a85e..1ee5780d8 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h @@ -34,45 +34,45 @@ namespace ffi { * \brief Structural equality comparators */ class StructuralEqual { - public: - /** - * \brief Compare two Any values for structural equality. - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \param map_free_vars Whether to map free variables. - * \param skip_tensor_content Whether to skip comparingn darray data content, - * useful for cases where we don't care about parameters content - * \return True if the two Any values are structurally equal, false otherwise. - */ - TVM_FFI_EXTRA_CXX_API static bool Equal(const Any& lhs, const Any& rhs, - bool map_free_vars = false, - bool skip_tensor_content = false); - /** - * \brief Get the first mismatch AccessPath pair when running - * structural equal comparison between two Any values. 
- * - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \param map_free_vars Whether to map free variables. - * \param skip_tensor_content Whether to skip comparing tensor data content, - * useful for cases where we don't care about parameters content - * \return If comparison fails, return the first mismatch AccessPath pair, - * otherwise return std::nullopt. - */ - TVM_FFI_EXTRA_CXX_API static Optional GetFirstMismatch( - const Any& lhs, const Any& rhs, bool map_free_vars = false, bool skip_tensor_content = false); +public: + /** + * \brief Compare two Any values for structural equality. + * \param lhs The left hand side Any object. + * \param rhs The right hand side Any object. + * \param map_free_vars Whether to map free variables. + * \param skip_tensor_content Whether to skip comparingn darray data content, + * useful for cases where we don't care about parameters content + * \return True if the two Any values are structurally equal, false otherwise. + */ + TVM_FFI_EXTRA_CXX_API static bool Equal(const Any &lhs, const Any &rhs, + bool map_free_vars = false, + bool skip_tensor_content = false); + /** + * \brief Get the first mismatch AccessPath pair when running + * structural equal comparison between two Any values. + * + * \param lhs The left hand side Any object. + * \param rhs The right hand side Any object. + * \param map_free_vars Whether to map free variables. + * \param skip_tensor_content Whether to skip comparing tensor data content, + * useful for cases where we don't care about parameters content + * \return If comparison fails, return the first mismatch AccessPath pair, + * otherwise return std::nullopt. + */ + TVM_FFI_EXTRA_CXX_API static Optional GetFirstMismatch( + const Any &lhs, const Any &rhs, bool map_free_vars = false, bool skip_tensor_content = false); - /* - * \brief Compare two Any values for structural equality. - * \param lhs The left hand side Any object. 
- * \param rhs The right hand side Any object. - * \return True if the two Any values are structurally equal, false otherwise. - */ - TVM_FFI_INLINE bool operator()(const Any& lhs, const Any& rhs) const { - return Equal(lhs, rhs, false, true); - } + /* + * \brief Compare two Any values for structural equality. + * \param lhs The left hand side Any object. + * \param rhs The right hand side Any object. + * \return True if the two Any values are structurally equal, false otherwise. + */ + TVM_FFI_INLINE bool operator()(const Any &lhs, const Any &rhs) const { + return Equal(lhs, rhs, false, true); + } }; -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h index bfe023c38..b27181b37 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h @@ -33,25 +33,25 @@ namespace ffi { * \brief Structural hash */ class StructuralHash { - public: - /*! - * \brief Hash an Any value. - * \param value The Any value to hash. - * \param map_free_vars Whether to map free variables. - * \param skip_tensor_content Whether to skip comparingn darray data content, - * useful for cases where we don't care about parameters content. - * \return The hash value. - */ - TVM_FFI_EXTRA_CXX_API static uint64_t Hash(const Any& value, bool map_free_vars = false, - bool skip_tensor_content = false); - /*! - * \brief Hash an Any value. - * \param value The Any value to hash. - * \return The hash value. - */ - TVM_FFI_INLINE uint64_t operator()(const Any& value) const { return Hash(value); } +public: + /*! + * \brief Hash an Any value. + * \param value The Any value to hash. 
+ * \param map_free_vars Whether to map free variables. + * \param skip_tensor_content Whether to skip comparingn darray data content, + * useful for cases where we don't care about parameters content. + * \return The hash value. + */ + TVM_FFI_EXTRA_CXX_API static uint64_t Hash(const Any &value, bool map_free_vars = false, + bool skip_tensor_content = false); + /*! + * \brief Hash an Any value. + * \param value The Any value to hash. + * \return The hash value. + */ + TVM_FFI_INLINE uint64_t operator()(const Any &value) const { return Hash(value); } }; -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h index e4ebc2aa2..4d716bce0 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h @@ -42,20 +42,20 @@ namespace reflection { * \brief The kind of the access pattern. */ enum class AccessKind : int32_t { - /*! \brief Object attribute access. */ - kAttr = 0, - /*! \brief Array item access. */ - kArrayItem = 1, - /*! \brief Map item access. */ - kMapItem = 2, - // the following two are used for error reporting when - // the supposed access field is not available - /*! \brief Object attribute missing access. */ - kAttrMissing = 3, - /*! \brief Array item missing access. */ - kArrayItemMissing = 4, - /*! \brief Map item missing access. */ - kMapItemMissing = 5, + /*! \brief Object attribute access. */ + kAttr = 0, + /*! \brief Array item access. */ + kArrayItem = 1, + /*! \brief Map item access. */ + kMapItem = 2, + // the following two are used for error reporting when + // the supposed access field is not available + /*! \brief Object attribute missing access. 
*/ + kAttrMissing = 3, + /*! \brief Array item missing access. */ + kArrayItemMissing = 4, + /*! \brief Map item missing access. */ + kMapItemMissing = 5, }; class AccessStep; @@ -64,38 +64,38 @@ class AccessStep; * \brief Represent a single step in object field, map key, array index access. */ class AccessStepObj : public Object { - public: - /*! - * \brief The kind of the access pattern. - */ - AccessKind kind; - /*! - * \brief The access key - * \note for array access, it will always be integer - * for field access, it will be string - */ - Any key; - - // default constructor to enable auto-serialization - AccessStepObj() = default; - /*! - * \brief Constructor - * \param kind The kind of the access step. - * \param key The key of the access step. - */ - AccessStepObj(AccessKind kind, Any key) : kind(kind), key(std::move(key)) {} - - /*! - * \brief Deep check if two steps are equal. - * \param other The other step to compare with. - * \return True if the two steps are equal, false otherwise. - */ - inline bool StepEqual(const AccessStep& other) const; - - /// \cond Doxygen_Suppress - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessStep", AccessStepObj, Object); - /// \endcond +public: + /*! + * \brief The kind of the access pattern. + */ + AccessKind kind; + /*! + * \brief The access key + * \note for array access, it will always be integer + * for field access, it will be string + */ + Any key; + + // default constructor to enable auto-serialization + AccessStepObj() = default; + /*! + * \brief Constructor + * \param kind The kind of the access step. + * \param key The key of the access step. + */ + AccessStepObj(AccessKind kind, Any key) : kind(kind), key(std::move(key)) {} + + /*! + * \brief Deep check if two steps are equal. + * \param other The other step to compare with. + * \return True if the two steps are equal, false otherwise. 
+ */ + inline bool StepEqual(const AccessStep &other) const; + + /// \cond Doxygen_Suppress + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessStep", AccessStepObj, Object); + /// \endcond }; /*! @@ -104,72 +104,72 @@ class AccessStepObj : public Object { * \sa AccessStepObj */ class AccessStep : public ObjectRef { - public: - /*! - * \brief Constructor - * \param kind The kind of the access step. - * \param key The key of the access step. - */ - AccessStep(AccessKind kind, Any key) - : ObjectRef(make_object(kind, std::move(key))) {} - - /*! - * \brief Create an access step for a object attribute access. - * \param field_name The name of the field to access. - * \return The access step. - */ - static AccessStep Attr(String field_name) { - return AccessStep(AccessKind::kAttr, std::move(field_name)); - } - - /*! - * \brief Create an access step for a object attribute missing access. - * \param field_name The name of the field to access. - * \return The access step. - */ - static AccessStep AttrMissing(String field_name) { - return AccessStep(AccessKind::kAttrMissing, std::move(field_name)); - } - - /*! - * \brief Create an access step for a array item access. - * \param index The index of the array item to access. - * \return The access step. - */ - static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); } - - /*! - * \brief Create an access step for a array item missing access. - * \param index The index of the array item to access. - * \return The access step. - */ - static AccessStep ArrayItemMissing(int64_t index) { - return AccessStep(AccessKind::kArrayItemMissing, index); - } - - /*! - * \brief Create an access step for a map item access. - * \param key The key of the map item to access. - * \return The access step. - */ - static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, std::move(key)); } - - /*! 
- * \brief Create an access step for a map item missing access. - * \param key The key of the map item to access. - * \return The access step. - */ - static AccessStep MapItemMissing(Any key = nullptr) { - return AccessStep(AccessKind::kMapItemMissing, std::move(key)); - } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessStep, ObjectRef, AccessStepObj); - /// \endcond +public: + /*! + * \brief Constructor + * \param kind The kind of the access step. + * \param key The key of the access step. + */ + AccessStep(AccessKind kind, Any key) + : ObjectRef(make_object(kind, std::move(key))) {} + + /*! + * \brief Create an access step for a object attribute access. + * \param field_name The name of the field to access. + * \return The access step. + */ + static AccessStep Attr(String field_name) { + return AccessStep(AccessKind::kAttr, std::move(field_name)); + } + + /*! + * \brief Create an access step for a object attribute missing access. + * \param field_name The name of the field to access. + * \return The access step. + */ + static AccessStep AttrMissing(String field_name) { + return AccessStep(AccessKind::kAttrMissing, std::move(field_name)); + } + + /*! + * \brief Create an access step for a array item access. + * \param index The index of the array item to access. + * \return The access step. + */ + static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); } + + /*! + * \brief Create an access step for a array item missing access. + * \param index The index of the array item to access. + * \return The access step. + */ + static AccessStep ArrayItemMissing(int64_t index) { + return AccessStep(AccessKind::kArrayItemMissing, index); + } + + /*! + * \brief Create an access step for a map item access. + * \param key The key of the map item to access. + * \return The access step. + */ + static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, std::move(key)); } + + /*! 
+ * \brief Create an access step for a map item missing access. + * \param key The key of the map item to access. + * \return The access step. + */ + static AccessStep MapItemMissing(Any key = nullptr) { + return AccessStep(AccessKind::kMapItemMissing, std::move(key)); + } + + /// \cond Doxygen_Suppress + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessStep, ObjectRef, AccessStepObj); + /// \endcond }; -inline bool AccessStepObj::StepEqual(const AccessStep& other) const { - return this->kind == other->kind && AnyEqual()(this->key, other->key); +inline bool AccessStepObj::StepEqual(const AccessStep &other) const { + return this->kind == other->kind && AnyEqual()(this->key, other->key); } // forward declaration @@ -181,139 +181,145 @@ class AccessPath; * \sa AccessPathObj */ class AccessPathObj : public Object { - public: - /*! - * \brief The parent of the access path. - * - * This parent-pointing tree structure is more space efficient when - * representing multiple paths that share a common prefix. - * - * \note Empty for root. - */ - Optional parent; - /*! - * \brief The current of the access path. - * \note Empty for root. - */ - Optional step; - /*! - * \brief The current depth of the access path, 0 for root - */ - int32_t depth; - - // default constructor to enable auto-serialization - AccessPathObj() = default; - /*! - * \brief Constructor for the access path. - * \param parent The parent of the access path. - * \param step The current step of the access path. - * \param depth The current depth of the access path. - */ - AccessPathObj(Optional parent, Optional step, int32_t depth) - : parent(std::move(parent)), step(std::move(step)), depth(depth) {} - - /*! - * \brief Get the parent of the access path. - * \return The parent of the access path. - */ - inline Optional GetParent() const; - - /*! - * \brief Extend the access path with a new step. - * \param step The step to extend the access path with. - * \return The extended access path. 
- */ - inline AccessPath Extend(AccessStep step) const; - - /*! - * \brief Extend the access path with an object attribute access. - * \param field_name The name of the field to access. - * \return The extended access path. - */ - inline AccessPath Attr(String field_name) const; - - /*! - * \brief Extend the access path with an object attribute missing access. - * \param field_name The name of the field to access. - * \return The extended access path. - */ - inline AccessPath AttrMissing(String field_name) const; - - /*! - * \brief Extend the access path with an array item access. - * \param index The index of the array item to access. - * \return The extended access path. - */ - inline AccessPath ArrayItem(int64_t index) const; - - /*! - * \brief Extend the access path with an array item missing access. - * \param index The index of the array item to access. - * \return The extended access path. - */ - inline AccessPath ArrayItemMissing(int64_t index) const; - - /*! - * \brief Extend the access path with a map item access. - * \param key The key of the map item to access. - * \return The extended access path. - */ - inline AccessPath MapItem(Any key) const; - - /*! - * \brief Extend the access path with a map item missing access. - * \param key The key of the map item to access. - * \return The extended access path. - */ - inline AccessPath MapItemMissing(Any key) const; - - /*! - * \brief Get the array of steps that corresponds to the access path. - * \return The array of steps that corresponds to the access path. - */ - inline Array ToSteps() const; - - /*! - * \brief Check if two paths are equal by deep comparing the steps. - * \param other The other path to compare with. - * \return True if the two paths are equal, false otherwise. - */ - inline bool PathEqual(const AccessPath& other) const; - - /*! - * \brief Check if this path is a prefix of another path. - * \param other The other path to compare with. 
- * \return True if this path is a prefix of the other path, false otherwise. - */ - inline bool IsPrefixOf(const AccessPath& other) const; - - /// \cond Doxygen_Suppress - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessPath", AccessPathObj, Object); - /// \endcond - - private: - static bool PathEqual(const AccessPathObj* lhs, const AccessPathObj* rhs) { - // fast path for same pointer - if (lhs == rhs) return true; - if (lhs->depth != rhs->depth) return false; - // do deep equality checks - while (lhs->parent.has_value()) { - TVM_FFI_ICHECK(rhs->parent.has_value()); - TVM_FFI_ICHECK(lhs->step.has_value()); - TVM_FFI_ICHECK(rhs->step.has_value()); - if (!(*lhs->step)->StepEqual(*(rhs->step))) { - return false; - } - lhs = static_cast(lhs->parent.get()); - rhs = static_cast(rhs->parent.get()); - // fast path for same pointer - if (lhs == rhs) return true; - TVM_FFI_ICHECK(lhs != nullptr); - TVM_FFI_ICHECK(rhs != nullptr); +public: + /*! + * \brief The parent of the access path. + * + * This parent-pointing tree structure is more space efficient when + * representing multiple paths that share a common prefix. + * + * \note Empty for root. + */ + Optional parent; + /*! + * \brief The current of the access path. + * \note Empty for root. + */ + Optional step; + /*! + * \brief The current depth of the access path, 0 for root + */ + int32_t depth; + + // default constructor to enable auto-serialization + AccessPathObj() = default; + /*! + * \brief Constructor for the access path. + * \param parent The parent of the access path. + * \param step The current step of the access path. + * \param depth The current depth of the access path. + */ + AccessPathObj(Optional parent, Optional step, int32_t depth) + : parent(std::move(parent)), step(std::move(step)), depth(depth) {} + + /*! + * \brief Get the parent of the access path. + * \return The parent of the access path. 
+ */ + inline Optional GetParent() const; + + /*! + * \brief Extend the access path with a new step. + * \param step The step to extend the access path with. + * \return The extended access path. + */ + inline AccessPath Extend(AccessStep step) const; + + /*! + * \brief Extend the access path with an object attribute access. + * \param field_name The name of the field to access. + * \return The extended access path. + */ + inline AccessPath Attr(String field_name) const; + + /*! + * \brief Extend the access path with an object attribute missing access. + * \param field_name The name of the field to access. + * \return The extended access path. + */ + inline AccessPath AttrMissing(String field_name) const; + + /*! + * \brief Extend the access path with an array item access. + * \param index The index of the array item to access. + * \return The extended access path. + */ + inline AccessPath ArrayItem(int64_t index) const; + + /*! + * \brief Extend the access path with an array item missing access. + * \param index The index of the array item to access. + * \return The extended access path. + */ + inline AccessPath ArrayItemMissing(int64_t index) const; + + /*! + * \brief Extend the access path with a map item access. + * \param key The key of the map item to access. + * \return The extended access path. + */ + inline AccessPath MapItem(Any key) const; + + /*! + * \brief Extend the access path with a map item missing access. + * \param key The key of the map item to access. + * \return The extended access path. + */ + inline AccessPath MapItemMissing(Any key) const; + + /*! + * \brief Get the array of steps that corresponds to the access path. + * \return The array of steps that corresponds to the access path. + */ + inline Array ToSteps() const; + + /*! + * \brief Check if two paths are equal by deep comparing the steps. + * \param other The other path to compare with. + * \return True if the two paths are equal, false otherwise. 
+ */ + inline bool PathEqual(const AccessPath &other) const; + + /*! + * \brief Check if this path is a prefix of another path. + * \param other The other path to compare with. + * \return True if this path is a prefix of the other path, false otherwise. + */ + inline bool IsPrefixOf(const AccessPath &other) const; + + /// \cond Doxygen_Suppress + static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessPath", AccessPathObj, Object); + /// \endcond + +private: + static bool PathEqual(const AccessPathObj *lhs, const AccessPathObj *rhs) { + // fast path for same pointer + if (lhs == rhs) { + return true; + } + if (lhs->depth != rhs->depth) { + return false; + } + // do deep equality checks + while (lhs->parent.has_value()) { + TVM_FFI_ICHECK(rhs->parent.has_value()); + TVM_FFI_ICHECK(lhs->step.has_value()); + TVM_FFI_ICHECK(rhs->step.has_value()); + if (!(*lhs->step)->StepEqual(*(rhs->step))) { + return false; + } + lhs = static_cast(lhs->parent.get()); + rhs = static_cast(rhs->parent.get()); + // fast path for same pointer + if (lhs == rhs) { + return true; + } + TVM_FFI_ICHECK(lhs != nullptr); + TVM_FFI_ICHECK(rhs != nullptr); + } + return true; } - return true; - } }; /*! @@ -322,49 +328,49 @@ class AccessPathObj : public Object { * \sa AccessPathObj */ class AccessPath : public ObjectRef { - public: - /*! - * \brief Create an access path from an iterator range of steps. - * \param begin The beginning of the iterator range. - * \param end The end of the iterator range. - * \return The access path. - */ - template // NOLINTNEXTLINE(performance-unnecessary-value-param) - static AccessPath FromSteps(Iter begin, Iter end) { - AccessPath path = AccessPath::Root(); - for (Iter it = begin; it != end; ++it) { - path = path->Extend(*it); +public: + /*! + * \brief Create an access path from an iterator range of steps. + * \param begin The beginning of the iterator range. 
+ * \param end The end of the iterator range. + * \return The access path. + */ + template // NOLINTNEXTLINE(performance-unnecessary-value-param) + static AccessPath FromSteps(Iter begin, Iter end) { + AccessPath path = AccessPath::Root(); + for (Iter it = begin; it != end; ++it) { + path = path->Extend(*it); + } + return path; } - return path; - } - /*! - * \brief Create an access path from an array of steps. - * \param steps The array of steps. - * \return The access path. - */ - static AccessPath FromSteps(const Array& steps) { - AccessPath path = AccessPath::Root(); - for (AccessStep step : steps) { - path = path->Extend(step); + /*! + * \brief Create an access path from an array of steps. + * \param steps The array of steps. + * \return The access path. + */ + static AccessPath FromSteps(const Array &steps) { + AccessPath path = AccessPath::Root(); + for (AccessStep step : steps) { + path = path->Extend(step); + } + return path; } - return path; - } - - /*! - * \brief Create a root access path. - * \return The root access path. - */ - static AccessPath Root() { - return AccessPath(make_object(std::nullopt, std::nullopt, 0)); - } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessPath, ObjectRef, AccessPathObj); - /// \endcond - - private: - friend class AccessPathObj; - explicit AccessPath(ObjectPtr ptr) : ObjectRef(std::move(ptr)) {} + + /*! + * \brief Create a root access path. + * \return The root access path. + */ + static AccessPath Root() { + return AccessPath(make_object(std::nullopt, std::nullopt, 0)); + } + + /// \cond Doxygen_Suppress + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessPath, ObjectRef, AccessPathObj); + /// \endcond + +private: + friend class AccessPathObj; + explicit AccessPath(ObjectPtr ptr) : ObjectRef(std::move(ptr)) {} }; /*! 
@@ -373,72 +379,72 @@ class AccessPath : public ObjectRef { using AccessPathPair = Tuple; inline Optional AccessPathObj::GetParent() const { - if (auto opt_parent = this->parent.as()) { - return opt_parent; - } - return std::nullopt; + if (auto opt_parent = this->parent.as()) { + return opt_parent; + } + return std::nullopt; } inline AccessPath AccessPathObj::Extend(AccessStep step) const { - return AccessPath( - make_object(GetRef(this), std::move(step), this->depth + 1)); + return AccessPath( + make_object(GetRef(this), std::move(step), this->depth + 1)); } inline AccessPath AccessPathObj::Attr(String field_name) const { - return this->Extend(AccessStep::Attr(std::move(field_name))); + return this->Extend(AccessStep::Attr(std::move(field_name))); } inline AccessPath AccessPathObj::AttrMissing(String field_name) const { - return this->Extend(AccessStep::AttrMissing(std::move(field_name))); + return this->Extend(AccessStep::AttrMissing(std::move(field_name))); } inline AccessPath AccessPathObj::ArrayItem(int64_t index) const { - return this->Extend(AccessStep::ArrayItem(index)); + return this->Extend(AccessStep::ArrayItem(index)); } inline AccessPath AccessPathObj::ArrayItemMissing(int64_t index) const { - return this->Extend(AccessStep::ArrayItemMissing(index)); + return this->Extend(AccessStep::ArrayItemMissing(index)); } inline AccessPath AccessPathObj::MapItem(Any key) const { - return this->Extend(AccessStep::MapItem(std::move(key))); + return this->Extend(AccessStep::MapItem(std::move(key))); } inline AccessPath AccessPathObj::MapItemMissing(Any key) const { - return this->Extend(AccessStep::MapItemMissing(std::move(key))); + return this->Extend(AccessStep::MapItemMissing(std::move(key))); } inline Array AccessPathObj::ToSteps() const { - std::vector reverse_steps; - reverse_steps.reserve(this->depth); - const AccessPathObj* current = this; - while (current->parent.has_value()) { - TVM_FFI_ICHECK(current->step.has_value()); - 
reverse_steps.push_back(*(current->step)); - current = static_cast(current->parent.get()); - TVM_FFI_ICHECK(current != nullptr); - } - return Array(reverse_steps.rbegin(), reverse_steps.rend()); + std::vector reverse_steps; + reverse_steps.reserve(this->depth); + const AccessPathObj *current = this; + while (current->parent.has_value()) { + TVM_FFI_ICHECK(current->step.has_value()); + reverse_steps.push_back(*(current->step)); + current = static_cast(current->parent.get()); + TVM_FFI_ICHECK(current != nullptr); + } + return Array(reverse_steps.rbegin(), reverse_steps.rend()); } -inline bool AccessPathObj::PathEqual(const AccessPath& other) const { - return PathEqual(this, other.get()); +inline bool AccessPathObj::PathEqual(const AccessPath &other) const { + return PathEqual(this, other.get()); } -inline bool AccessPathObj::IsPrefixOf(const AccessPath& other) const { - if (this->depth > other->depth) { - return false; - } - const AccessPathObj* rhs_path = other.get(); - while (rhs_path->depth > this->depth) { - TVM_FFI_ICHECK(rhs_path->parent.has_value()); - rhs_path = static_cast(rhs_path->parent.get()); - } - return PathEqual(this, rhs_path); +inline bool AccessPathObj::IsPrefixOf(const AccessPath &other) const { + if (this->depth > other->depth) { + return false; + } + const AccessPathObj *rhs_path = other.get(); + while (rhs_path->depth > this->depth) { + TVM_FFI_ICHECK(rhs_path->parent.has_value()); + rhs_path = static_cast(rhs_path->parent.get()); + } + return PathEqual(this, rhs_path); } -} // namespace reflection -} // namespace ffi -} // namespace tvm +} // namespace reflection +} // namespace ffi +} // namespace tvm -#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_ +#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h index b49da5193..c77b01679 100644 --- 
a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h @@ -38,132 +38,132 @@ namespace reflection { /*! * \brief helper function to get reflection field info by type key and field name */ -inline const TVMFFIFieldInfo* GetFieldInfo(std::string_view type_key, const char* field_name) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - const TypeInfo* info = TVMFFIGetTypeInfo(type_index); - for (int32_t i = 0; i < info->num_fields; ++i) { - if (std::strncmp(info->fields[i].name.data, field_name, info->fields[i].name.size) == 0) { - return &(info->fields[i]); +inline const TVMFFIFieldInfo *GetFieldInfo(std::string_view type_key, const char *field_name) { + int32_t type_index; + TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + const TypeInfo *info = TVMFFIGetTypeInfo(type_index); + for (int32_t i = 0; i < info->num_fields; ++i) { + if (std::strncmp(info->fields[i].name.data, field_name, info->fields[i].name.size) == 0) { + return &(info->fields[i]); + } } - } - TVM_FFI_THROW(RuntimeError) << "Cannot find field `" << field_name << "` in " << type_key; - TVM_FFI_UNREACHABLE(); + TVM_FFI_THROW(RuntimeError) << "Cannot find field `" << field_name << "` in " << type_key; + TVM_FFI_UNREACHABLE(); } /*! * \brief helper wrapper class to obtain a getter. */ class FieldGetter { - public: - /*! - * \brief Constructor - * \param field_info The field info. - */ - explicit FieldGetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} +public: + /*! + * \brief Constructor + * \param field_info The field info. + */ + explicit FieldGetter(const TVMFFIFieldInfo *field_info) : field_info_(field_info) {} - /*! 
- * \brief Constructor - * \param type_key The type key. - * \param field_name The name of the field. - */ - explicit FieldGetter(std::string_view type_key, const char* field_name) - : FieldGetter(GetFieldInfo(type_key, field_name)) {} + /*! + * \brief Constructor + * \param type_key The type key. + * \param field_name The name of the field. + */ + explicit FieldGetter(std::string_view type_key, const char *field_name) + : FieldGetter(GetFieldInfo(type_key, field_name)) {} - /*! - * \brief Get the value of the field - * \param obj_ptr The object pointer. - * \return The value of the field. - */ - Any operator()(const Object* obj_ptr) const { - Any result; - const void* addr = reinterpret_cast(obj_ptr) + field_info_->offset; - TVM_FFI_CHECK_SAFE_CALL( - field_info_->getter(const_cast(addr), reinterpret_cast(&result))); - return result; - } + /*! + * \brief Get the value of the field + * \param obj_ptr The object pointer. + * \return The value of the field. + */ + Any operator()(const Object *obj_ptr) const { + Any result; + const void *addr = reinterpret_cast(obj_ptr) + field_info_->offset; + TVM_FFI_CHECK_SAFE_CALL( + field_info_->getter(const_cast(addr), reinterpret_cast(&result))); + return result; + } - Any operator()(const ObjectPtr& obj_ptr) const { return operator()(obj_ptr.get()); } + Any operator()(const ObjectPtr &obj_ptr) const { return operator()(obj_ptr.get()); } - Any operator()(const ObjectRef& obj) const { return operator()(obj.get()); } + Any operator()(const ObjectRef &obj) const { return operator()(obj.get()); } - private: - const TVMFFIFieldInfo* field_info_; +private: + const TVMFFIFieldInfo *field_info_; }; /*! * \brief helper wrapper class to obtain a setter. */ class FieldSetter { - public: - /*! - * \brief Constructor - * \param field_info The field info. - */ - explicit FieldSetter(const TVMFFIFieldInfo* field_info) : field_info_(field_info) {} +public: + /*! + * \brief Constructor + * \param field_info The field info. 
+ */ + explicit FieldSetter(const TVMFFIFieldInfo *field_info) : field_info_(field_info) {} - /*! - * \brief Constructor - * \param type_key The type key. - * \param field_name The name of the field. - */ - explicit FieldSetter(std::string_view type_key, const char* field_name) - : FieldSetter(GetFieldInfo(type_key, field_name)) {} + /*! + * \brief Constructor + * \param type_key The type key. + * \param field_name The name of the field. + */ + explicit FieldSetter(std::string_view type_key, const char *field_name) + : FieldSetter(GetFieldInfo(type_key, field_name)) {} - /*! - * \brief Set the value of the field - * \param obj_ptr The object pointer. - * \param value The value to be set. - */ - void operator()(const Object* obj_ptr, AnyView value) const { - const void* addr = reinterpret_cast(obj_ptr) + field_info_->offset; - TVM_FFI_CHECK_SAFE_CALL( - field_info_->setter(const_cast(addr), reinterpret_cast(&value))); - } + /*! + * \brief Set the value of the field + * \param obj_ptr The object pointer. + * \param value The value to be set. + */ + void operator()(const Object *obj_ptr, AnyView value) const { + const void *addr = reinterpret_cast(obj_ptr) + field_info_->offset; + TVM_FFI_CHECK_SAFE_CALL( + field_info_->setter(const_cast(addr), reinterpret_cast(&value))); + } - void operator()(const ObjectPtr& obj_ptr, AnyView value) const { - operator()(obj_ptr.get(), value); - } + void operator()(const ObjectPtr &obj_ptr, AnyView value) const { + operator()(obj_ptr.get(), value); + } - void operator()(const ObjectRef& obj, AnyView value) const { operator()(obj.get(), value); } + void operator()(const ObjectRef &obj, AnyView value) const { operator()(obj.get(), value); } - private: - const TVMFFIFieldInfo* field_info_; +private: + const TVMFFIFieldInfo *field_info_; }; /*! * \brief Helper class to get type attribute column. */ class TypeAttrColumn { - public: - /*! - * \brief Constructor - * \param attr_name The name of the type attribute. 
- */ - explicit TypeAttrColumn(std::string_view attr_name) { - TVMFFIByteArray attr_name_array = {attr_name.data(), attr_name.size()}; - column_ = TVMFFIGetTypeAttrColumn(&attr_name_array); - if (column_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Cannot find type attribute " << attr_name; +public: + /*! + * \brief Constructor + * \param attr_name The name of the type attribute. + */ + explicit TypeAttrColumn(std::string_view attr_name) { + TVMFFIByteArray attr_name_array = {attr_name.data(), attr_name.size()}; + column_ = TVMFFIGetTypeAttrColumn(&attr_name_array); + if (column_ == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Cannot find type attribute " << attr_name; + } } - } - /*! - * \brief Get the type attribute column by type index. - * \param type_index The type index. - * \return The type attribute column. - */ - AnyView operator[](int32_t type_index) const { - size_t tindex = static_cast(type_index); - if (tindex >= column_->size) { - return AnyView(); + /*! + * \brief Get the type attribute column by type index. + * \param type_index The type index. + * \return The type attribute column. + */ + AnyView operator[](int32_t type_index) const { + size_t tindex = static_cast(type_index); + if (tindex >= column_->size) { + return AnyView(); + } + const AnyView *any_view_data = reinterpret_cast(column_->data); + return any_view_data[tindex]; } - const AnyView* any_view_data = reinterpret_cast(column_->data); - return any_view_data[tindex]; - } - private: - const TVMFFITypeAttrColumn* column_; +private: + const TVMFFITypeAttrColumn *column_; }; /*! @@ -173,18 +173,18 @@ class TypeAttrColumn { * \param method_name The name of the method. * \return The method info. 
*/ -inline const TVMFFIMethodInfo* GetMethodInfo(std::string_view type_key, const char* method_name) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - const TypeInfo* info = TVMFFIGetTypeInfo(type_index); - for (int32_t i = 0; i < info->num_methods; ++i) { - if (std::strncmp(info->methods[i].name.data, method_name, info->methods[i].name.size) == 0) { - return &(info->methods[i]); +inline const TVMFFIMethodInfo *GetMethodInfo(std::string_view type_key, const char *method_name) { + int32_t type_index; + TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); + const TypeInfo *info = TVMFFIGetTypeInfo(type_index); + for (int32_t i = 0; i < info->num_methods; ++i) { + if (std::strncmp(info->methods[i].name.data, method_name, info->methods[i].name.size) == 0) { + return &(info->methods[i]); + } } - } - TVM_FFI_THROW(RuntimeError) << "Cannot find method " << method_name << " in " << type_key; - TVM_FFI_UNREACHABLE(); + TVM_FFI_THROW(RuntimeError) << "Cannot find method " << method_name << " in " << type_key; + TVM_FFI_UNREACHABLE(); } /*! @@ -194,9 +194,9 @@ inline const TVMFFIMethodInfo* GetMethodInfo(std::string_view type_key, const ch * \param method_name The name of the method. * \return The method function. */ -inline Function GetMethod(std::string_view type_key, const char* method_name) { - const TVMFFIMethodInfo* info = GetMethodInfo(type_key, method_name); - return AnyView::CopyFromTVMFFIAny(info->method).cast(); +inline Function GetMethod(std::string_view type_key, const char *method_name) { + const TVMFFIMethodInfo *info = GetMethodInfo(type_key, method_name); + return AnyView::CopyFromTVMFFIAny(info->method).cast(); } /*! 
@@ -210,20 +210,20 @@ inline Function GetMethod(std::string_view type_key, const char* method_name) { * \note This function calls both the child and parent type info. */ template -inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) { - using ResultType = decltype(callback(type_info->fields)); - static_assert(std::is_same_v, "Callback must return void"); - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - const TVMFFITypeInfo* parent_info = type_info->type_ancestors[i]; - for (int j = 0; j < parent_info->num_fields; ++j) { - callback(parent_info->fields + j); +inline void ForEachFieldInfo(const TypeInfo *type_info, Callback callback) { + using ResultType = decltype(callback(type_info->fields)); + static_assert(std::is_same_v, "Callback must return void"); + // iterate through acenstors in parent to child order + // skip the first one since it is always the root object + for (int i = 1; i < type_info->type_depth; ++i) { + const TVMFFITypeInfo *parent_info = type_info->type_ancestors[i]; + for (int j = 0; j < parent_info->num_fields; ++j) { + callback(parent_info->fields + j); + } + } + for (int i = 0; i < type_info->num_fields; ++i) { + callback(type_info->fields + i); } - } - for (int i = 0; i < type_info->num_fields; ++i) { - callback(type_info->fields + i); - } } /*! @@ -238,23 +238,27 @@ inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) { * \note This function calls both the child and parent type info and can be used for searching. 
*/ template -inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo* type_info, +inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo *type_info, Callback callback_with_early_stop) { - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - const TVMFFITypeInfo* parent_info = type_info->type_ancestors[i]; - for (int j = 0; j < parent_info->num_fields; ++j) { - if (callback_with_early_stop(parent_info->fields + j)) return true; + // iterate through acenstors in parent to child order + // skip the first one since it is always the root object + for (int i = 1; i < type_info->type_depth; ++i) { + const TVMFFITypeInfo *parent_info = type_info->type_ancestors[i]; + for (int j = 0; j < parent_info->num_fields; ++j) { + if (callback_with_early_stop(parent_info->fields + j)) { + return true; + } + } + } + for (int i = 0; i < type_info->num_fields; ++i) { + if (callback_with_early_stop(type_info->fields + i)) { + return true; + } } - } - for (int i = 0; i < type_info->num_fields; ++i) { - if (callback_with_early_stop(type_info->fields + i)) return true; - } - return false; + return false; } -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_ACCESSOR_H_ +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_ACCESSOR_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h index 774eb8b0b..dcbf3e056 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h @@ -35,86 +35,87 @@ namespace reflection { * \brief helper wrapper class of TVMFFITypeInfo to create object based on reflection. */ class ObjectCreator { - public: - /*! 
- * \brief Constructor - * \param type_key The type key. - */ - explicit ObjectCreator(std::string_view type_key) - : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {} +public: + /*! + * \brief Constructor + * \param type_key The type key. + */ + explicit ObjectCreator(std::string_view type_key) + : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {} - /*! - * \brief Constructor - * \param type_info The type info. - */ - explicit ObjectCreator(const TVMFFITypeInfo* type_info) : type_info_(type_info) { - int32_t type_index = type_info->type_index; - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not have reflection registered"; - } - if (type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support default constructor, " - << "as a result cannot be created via reflection"; + /*! + * \brief Constructor + * \param type_info The type info. + */ + explicit ObjectCreator(const TVMFFITypeInfo *type_info) : type_info_(type_info) { + int32_t type_index = type_info->type_index; + if (type_info->metadata == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + << "` does not have reflection registered"; + } + if (type_info->metadata->creator == nullptr) { + TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) + << "` does not support default constructor, " + << "as a result cannot be created via reflection"; + } } - } - /** - * \brief Create an object from a map of fields. - * \param fields The fields of the object. - * \return The created object. 
- */ - Any operator()(const Map& fields) const { - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle)); - ObjectPtr ptr = - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); - size_t match_field_count = 0; - ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo* field_info) { - String field_name(field_info->name); - void* field_addr = reinterpret_cast(ptr.get()) + field_info->offset; - if (fields.count(field_name) != 0) { - Any field_value = fields[field_name]; - field_info->setter(field_addr, reinterpret_cast(&field_value)); - ++match_field_count; - } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_info->setter(field_addr, &(field_info->default_value)); - } else { - TVM_FFI_THROW(TypeError) << "Required field `" - << String(field_info->name.data, field_info->name.size) - << "` not set in type `" - << String(type_info_->type_key.data, type_info_->type_key.size) - << "`"; - } - }); - if (match_field_count == fields.size()) return ObjectRef(ptr); - // report error that checks if contains extra fields that are not in the type - auto check_field_name = [&](const String& field_name) { - bool found = false; - ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo* field_info) { - if (field_name.compare(field_info->name) == 0) { - found = true; - return true; + /** + * \brief Create an object from a map of fields. + * \param fields The fields of the object. + * \return The created object. 
+ */ + Any operator()(const Map &fields) const { + TVMFFIObjectHandle handle; + TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle)); + ObjectPtr ptr = details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); + size_t match_field_count = 0; + ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo *field_info) { + String field_name(field_info->name); + void *field_addr = reinterpret_cast(ptr.get()) + field_info->offset; + if (fields.count(field_name) != 0) { + Any field_value = fields[field_name]; + field_info->setter(field_addr, reinterpret_cast(&field_value)); + ++match_field_count; + } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { + field_info->setter(field_addr, &(field_info->default_value)); + } else { + TVM_FFI_THROW(TypeError) << "Required field `" + << String(field_info->name.data, field_info->name.size) + << "` not set in type `" + << String(type_info_->type_key.data, type_info_->type_key.size) + << "`"; + } + }); + if (match_field_count == fields.size()) { + return ObjectRef(ptr); + } + // report error that checks if contains extra fields that are not in the type + auto check_field_name = [&](const String &field_name) { + bool found = false; + ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo *field_info) { + if (field_name.compare(field_info->name) == 0) { + found = true; + return true; + } + return false; + }); + return found; + }; + for (const auto &[field_name, _] : fields) { + if (!check_field_name(field_name)) { + TVM_FFI_THROW(TypeError) << "Type `" + << String(type_info_->type_key.data, type_info_->type_key.size) + << "` does not have field `" << field_name << "`"; + } } - return false; - }); - return found; - }; - for (const auto& [field_name, _] : fields) { - if (!check_field_name(field_name)) { - TVM_FFI_THROW(TypeError) << "Type `" - << String(type_info_->type_key.data, type_info_->type_key.size) - << "` does not have field `" << field_name << "`"; - } + TVM_FFI_UNREACHABLE(); } - 
TVM_FFI_UNREACHABLE(); - } - private: - const TVMFFITypeInfo* type_info_; +private: + const TVMFFITypeInfo *type_info_; }; -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_CREATOR_H_ +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_CREATOR_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h index 3014108c8..e1978b1e6 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h @@ -47,22 +47,22 @@ namespace reflection { * \brief Types of temporary metadata hold in FieldInfoBuilder and MethodInfoBuilder, * before they are filled into final C metadata */ -using _MetadataType = std::vector>; // NOLINT(bugprone-reserved-identifier) +using _MetadataType = std::vector>; // NOLINT(bugprone-reserved-identifier) /*! * \brief Builder for TVMFFIFieldInfo * \sa TVMFFIFieldInfo */ struct FieldInfoBuilder : public TVMFFIFieldInfo { - /*! \brief Temporary metadata info to be filled into TVMFFIFieldInfo::metadata */ - _MetadataType metadata_; + /*! \brief Temporary metadata info to be filled into TVMFFIFieldInfo::metadata */ + _MetadataType metadata_; }; /*! * \brief Builder for TVMFFIMethodInfo * \sa TVMFFIMethodInfo */ struct MethodInfoBuilder : public TVMFFIMethodInfo { - /*! \brief Temporary metadata info to be filled into TVMFFIMethodInfo::metadata */ - _MetadataType metadata_; + /*! \brief Temporary metadata info to be filled into TVMFFIMethodInfo::metadata */ + _MetadataType metadata_; }; /*! @@ -73,121 +73,121 @@ struct InfoTrait {}; /*! \brief User-supplied metadata attached to a field or a method */ class Metadata : public InfoTrait { - public: - /*! 
- * \brief Constructor - * \param dict The initial dictionary - */ - Metadata(std::initializer_list> dict) : dict_(dict) {} - /*! - * \brief Move metadata into `FieldInfoBuilder` - * \param info The field info builder. - */ - inline void Apply(FieldInfoBuilder* info) const { this->Apply(&info->metadata_); } - /*! - * \brief Move metadata into `MethodInfoBuilder` - * \param info The method info builder. - */ - inline void Apply(MethodInfoBuilder* info) const { this->Apply(&info->metadata_); } - - private: - friend class GlobalDef; - template - friend class ObjectDef; - /*! - * \brief Move metadata into a vector of key-value pairs. - * \param out The output vector. - */ - inline void Apply(_MetadataType* out) const { - std::copy(std::make_move_iterator(dict_.begin()), std::make_move_iterator(dict_.end()), - std::back_inserter(*out)); - } - /*! \brief Convert the metadata to JSON string */ - static std::string ToJSON(const _MetadataType& metadata) { - using ::tvm::ffi::details::StringObj; - std::ostringstream os; - os << "{"; - bool first = true; - for (const auto& [key, value] : metadata) { - if (!first) { - os << ","; - } - os << "\"" << key << "\":"; - if (std::optional v = value.as()) { - os << *v; - } else if (std::optional v = value.as()) { - os << (*v ? "true" : "false"); - } else if (std::optional v = value.as()) { - String escaped = EscapeString(*v); - os << escaped.c_str(); - } else { - TVM_FFI_LOG_AND_THROW(TypeError) << "Metadata can be only int, bool or string, but on key `" - << key << "`, the type is " << value.GetTypeKey(); - } - first = false; +public: + /*! + * \brief Constructor + * \param dict The initial dictionary + */ + Metadata(std::initializer_list> dict) : dict_(dict) {} + /*! + * \brief Move metadata into `FieldInfoBuilder` + * \param info The field info builder. + */ + inline void Apply(FieldInfoBuilder *info) const { this->Apply(&info->metadata_); } + /*! 
+ * \brief Move metadata into `MethodInfoBuilder` + * \param info The method info builder. + */ + inline void Apply(MethodInfoBuilder *info) const { this->Apply(&info->metadata_); } + +private: + friend class GlobalDef; + template + friend class ObjectDef; + /*! + * \brief Move metadata into a vector of key-value pairs. + * \param out The output vector. + */ + inline void Apply(_MetadataType *out) const { + std::copy(std::make_move_iterator(dict_.begin()), std::make_move_iterator(dict_.end()), + std::back_inserter(*out)); + } + /*! \brief Convert the metadata to JSON string */ + static std::string ToJSON(const _MetadataType &metadata) { + using ::tvm::ffi::details::StringObj; + std::ostringstream os; + os << "{"; + bool first = true; + for (const auto &[key, value] : metadata) { + if (!first) { + os << ","; + } + os << "\"" << key << "\":"; + if (std::optional v = value.as()) { + os << *v; + } else if (std::optional v = value.as()) { + os << (*v ? "true" : "false"); + } else if (std::optional v = value.as()) { + String escaped = EscapeString(*v); + os << escaped.c_str(); + } else { + TVM_FFI_LOG_AND_THROW(TypeError) << "Metadata can be only int, bool or string, but on key `" + << key << "`, the type is " << value.GetTypeKey(); + } + first = false; + } + os << "}"; + return os.str(); } - os << "}"; - return os.str(); - } - std::vector> dict_; + std::vector> dict_; }; /*! * \brief Trait that can be used to set field default value */ class DefaultValue : public InfoTrait { - public: - /*! - * \brief Constructor - * \param value The value to be set - */ - explicit DefaultValue(Any value) : value_(std::move(value)) {} - - /*! - * \brief Apply the default value to the field info - * \param info The field info. - */ - TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { - info->default_value = AnyView(value_).CopyToTVMFFIAny(); - info->flags |= kTVMFFIFieldFlagBitMaskHasDefault; - } - - private: - Any value_; +public: + /*! 
+ * \brief Constructor + * \param value The value to be set + */ + explicit DefaultValue(Any value) : value_(std::move(value)) {} + + /*! + * \brief Apply the default value to the field info + * \param info The field info. + */ + TVM_FFI_INLINE void Apply(TVMFFIFieldInfo *info) const { + info->default_value = AnyView(value_).CopyToTVMFFIAny(); + info->flags |= kTVMFFIFieldFlagBitMaskHasDefault; + } + +private: + Any value_; }; /*! * \brief Trait that can be used to attach field flag */ class AttachFieldFlag : public InfoTrait { - public: - /*! - * \brief Attach a field flag to the field - * \param flag The flag to be set - */ - explicit AttachFieldFlag(int32_t flag) : flag_(flag) {} - - /*! - * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef - */ - TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() { - return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef); - } - /*! - * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore - */ - TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() { - return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashIgnore); - } - - /*! - * \brief Apply the field flag to the field info - * \param info The field info. - */ - TVM_FFI_INLINE void Apply(TVMFFIFieldInfo* info) const { info->flags |= flag_; } - - private: - int32_t flag_; +public: + /*! + * \brief Attach a field flag to the field + * \param flag The flag to be set + */ + explicit AttachFieldFlag(int32_t flag) : flag_(flag) {} + + /*! + * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef + */ + TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() { + return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef); + } + /*! + * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore + */ + TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() { + return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashIgnore); + } + + /*! + * \brief Apply the field flag to the field info + * \param info The field info. 
+ */ + TVM_FFI_INLINE void Apply(TVMFFIFieldInfo *info) const { info->flags |= flag_; } + +private: + int32_t flag_; }; /*! @@ -200,122 +200,121 @@ class AttachFieldFlag : public InfoTrait { * \returns The byteoffset */ template -TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) { - int64_t field_offset_to_class = - reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); - return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); +TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) { + int64_t field_offset_to_class = reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); + return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); } /// \cond Doxygen_Suppress class ReflectionDefBase { - protected: - template - static int FieldGetter(void* field, TVMFFIAny* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); - TVM_FFI_SAFE_CALL_END(); - } - - template - static int FieldSetter(void* field, const TVMFFIAny* value) { - TVM_FFI_SAFE_CALL_BEGIN(); - if constexpr (std::is_same_v) { - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value); - } else { - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); - } - TVM_FFI_SAFE_CALL_END(); - } - - template - static int ObjectCreatorDefault(TVMFFIObjectHandle* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - ObjectPtr obj = make_object(); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); - TVM_FFI_SAFE_CALL_END(); - } - - template - static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle* result) { - TVM_FFI_SAFE_CALL_BEGIN(); - ObjectPtr obj = make_object(UnsafeInit{}); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); - TVM_FFI_SAFE_CALL_END(); - } - - template - TVM_FFI_INLINE static void ApplyFieldInfoTrait(FieldInfoBuilder* info, const T& value) { - if constexpr (std::is_base_of_v>) { - 
value.Apply(info); +protected: + template + static int FieldGetter(void *field, TVMFFIAny *result) { + TVM_FFI_SAFE_CALL_BEGIN(); + *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); + TVM_FFI_SAFE_CALL_END(); } - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + + template + static int FieldSetter(void *field, const TVMFFIAny *value) { + TVM_FFI_SAFE_CALL_BEGIN(); + if constexpr (std::is_same_v) { + *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value); + } else { + *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); + } + TVM_FFI_SAFE_CALL_END(); } - } - template - TVM_FFI_INLINE static void ApplyMethodInfoTrait(MethodInfoBuilder* info, const T& value) { - if constexpr (std::is_base_of_v>) { - value.Apply(info); + template + static int ObjectCreatorDefault(TVMFFIObjectHandle *result) { + TVM_FFI_SAFE_CALL_BEGIN(); + ObjectPtr obj = make_object(); + *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + TVM_FFI_SAFE_CALL_END(); } - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + + template + static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle *result) { + TVM_FFI_SAFE_CALL_BEGIN(); + ObjectPtr obj = make_object(UnsafeInit{}); + *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); + TVM_FFI_SAFE_CALL_END(); } - } - template - TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeMetadata* info, const T& value) { - if constexpr (std::is_same_v, char*>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + template + TVM_FFI_INLINE static void ApplyFieldInfoTrait(FieldInfoBuilder *info, const T &value) { + if constexpr (std::is_base_of_v>) { + value.Apply(info); + } + if constexpr (std::is_same_v, char *>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } } - } - - template - 
TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...)) { - static_assert(std::is_base_of_v || std::is_base_of_v, - "Class must be derived from ObjectRef or Object"); - if constexpr (std::is_base_of_v) { - auto fwrap = [func](Class target, Args... params) -> R { - // call method pointer - return (target.*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, std::move(name)); + + template + TVM_FFI_INLINE static void ApplyMethodInfoTrait(MethodInfoBuilder *info, const T &value) { + if constexpr (std::is_base_of_v>) { + value.Apply(info); + } + if constexpr (std::is_same_v, char *>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } } - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class* target, Args... params) -> R { - // call method pointer - return (const_cast(target)->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, std::move(name)); + template + TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeMetadata *info, const T &value) { + if constexpr (std::is_same_v, char *>) { + info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; + } } - } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...) const) { - static_assert(std::is_base_of_v || std::is_base_of_v, - "Class must be derived from ObjectRef or Object"); - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class& target, Args... params) -> R { - // call method pointer - return (target.*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, std::move(name)); + + template + TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...)) { + static_assert(std::is_base_of_v || std::is_base_of_v, + "Class must be derived from ObjectRef or Object"); + if constexpr (std::is_base_of_v) { + auto fwrap = [func](Class target, Args... 
params) -> R { + // call method pointer + return (target.*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, std::move(name)); + } + + if constexpr (std::is_base_of_v) { + auto fwrap = [func](const Class *target, Args... params) -> R { + // call method pointer + return (const_cast(target)->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, std::move(name)); + } } - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class* target, Args... params) -> R { - // call method pointer - return (target->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, std::move(name)); + template + TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...) const) { + static_assert(std::is_base_of_v || std::is_base_of_v, + "Class must be derived from ObjectRef or Object"); + if constexpr (std::is_base_of_v) { + auto fwrap = [func](const Class &target, Args... params) -> R { + // call method pointer + return (target.*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, std::move(name)); + } + + if constexpr (std::is_base_of_v) { + auto fwrap = [func](const Class *target, Args... params) -> R { + // call method pointer + return (target->*func)(std::forward(params)...); + }; + return ffi::Function::FromTyped(fwrap, std::move(name)); + } } - } - template - TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) { - return ffi::Function::FromTyped(std::forward(func), std::move(name)); - } + template + TVM_FFI_INLINE static Function GetMethod(std::string name, Func &&func) { + return ffi::Function::FromTyped(std::forward(func), std::move(name)); + } }; /// \endcond @@ -328,82 +327,82 @@ class ReflectionDefBase { * \endcode */ class GlobalDef : public ReflectionDefBase { - public: - /*! - * \brief Define a global function. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. 
- * - * \param name The name of the function. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring or subclass of InfoTrait. - * - * \return The reflection definition. - */ - template - GlobalDef& def(const char* name, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; - RegisterFunc(name, ffi::Function::FromTyped(std::forward(func), std::string(name)), - FuncInfo::TypeSchema(), std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a global function in ffi::PackedArgs format. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the function. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring or subclass of InfoTrait. - * - * \return The reflection definition. - */ - template - GlobalDef& def_packed(const char* name, Func func, Extra&&... extra) { - RegisterFunc(name, ffi::Function::FromPacked(func), details::TypeSchemaImpl::v(), - std::forward(extra)...); - return *this; - } - - /*! - * \brief Expose a class method as a global function. - * - * An argument will be added to the first position if the function is not static. - * - * \tparam Class The class type. - * \tparam Func The function type. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - GlobalDef& def_method(const char* name, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; - RegisterFunc(name, GetMethod(std::string(name), std::forward(func)), - FuncInfo::TypeSchema(), std::forward(extra)...); - return *this; - } - - private: - template // NOLINTNEXTLINE(performance-unnecessary-value-param) - void RegisterFunc(const char* name, ffi::Function func, String type_schema, Extra&&... 
extra) { - MethodInfoBuilder info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.doc = TVMFFIByteArray{nullptr, 0}; - info.flags = 0; - info.method = AnyView(func).CopyToTVMFFIAny(); - info.metadata_.emplace_back("type_schema", type_schema); - ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); - std::string metadata_str = Metadata::ToJSON(info.metadata_); - info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0)); - } +public: + /*! + * \brief Define a global function. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the function. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring or subclass of InfoTrait. + * + * \return The reflection definition. + */ + template + GlobalDef &def(const char *name, Func &&func, Extra &&...extra) { + using FuncInfo = details::FunctionInfo>; + RegisterFunc(name, ffi::Function::FromTyped(std::forward(func), std::string(name)), + FuncInfo::TypeSchema(), std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a global function in ffi::PackedArgs format. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the function. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring or subclass of InfoTrait. + * + * \return The reflection definition. + */ + template + GlobalDef &def_packed(const char *name, Func func, Extra &&...extra) { + RegisterFunc(name, ffi::Function::FromPacked(func), details::TypeSchemaImpl::v(), + std::forward(extra)...); + return *this; + } + + /*! + * \brief Expose a class method as a global function. + * + * An argument will be added to the first position if the function is not static. + * + * \tparam Class The class type. 
+ * \tparam Func The function type. + * + * \param name The name of the method. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + GlobalDef &def_method(const char *name, Func &&func, Extra &&...extra) { + using FuncInfo = details::FunctionInfo>; + RegisterFunc(name, GetMethod(std::string(name), std::forward(func)), + FuncInfo::TypeSchema(), std::forward(extra)...); + return *this; + } + +private: + template // NOLINTNEXTLINE(performance-unnecessary-value-param) + void RegisterFunc(const char *name, ffi::Function func, String type_schema, Extra &&...extra) { + MethodInfoBuilder info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.doc = TVMFFIByteArray{nullptr, 0}; + info.flags = 0; + info.method = AnyView(func).CopyToTVMFFIAny(); + info.metadata_.emplace_back("type_schema", type_schema); + ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); + std::string metadata_str = Metadata::ToJSON(info.metadata_); + info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0)); + } }; /*! @@ -434,26 +433,26 @@ class GlobalDef : public ReflectionDefBase { */ template struct init { - // Allow ObjectDef to access the execute function - template - friend class ObjectDef; - - /*! - * \brief Constructor - */ - constexpr init() noexcept = default; - - private: - /*! - * \brief Execute the constructor - * \tparam Class The class type. - * \param args The arguments to be passed to the constructor. - * \return The constructed object wrapped in an `ObjectRef`. - */ - template - static inline ObjectRef execute(Args&&... args) { - return ObjectRef(ffi::make_object(std::forward(args)...)); - } + // Allow ObjectDef to access the execute function + template + friend class ObjectDef; + + /*! 
+ * \brief Constructor + */ + constexpr init() noexcept = default; + +private: + /*! + * \brief Execute the constructor + * \tparam Class The class type. + * \param args The arguments to be passed to the constructor. + * \return The constructed object wrapped in an `ObjectRef`. + */ + template + static inline ObjectRef execute(Args &&...args) { + return ObjectRef(ffi::make_object(std::forward(args)...)); + } }; /*! @@ -467,194 +466,194 @@ struct init { */ template class ObjectDef : public ReflectionDefBase { - public: - /*! - * \brief Constructor - * \tparam ExtraArgs The extra arguments. - * \param extra_args The extra arguments. - */ - template - explicit ObjectDef(ExtraArgs&&... extra_args) - : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) { - RegisterExtraInfo(std::forward(extra_args)...); - } - - /*! - * \brief Define a readonly field. - * - * \tparam Class The class type. - * \tparam T The field type. - * \tparam Extra The extra arguments. - * - * \param name The name of the field. - * \param field_ptr The pointer to the field. - * \param extra The extra arguments that can be docstring or default value. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::* field_ptr, Extra&&... extra) { - RegisterField(name, field_ptr, false, std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a read-write field. - * - * \tparam Class The class type. - * \tparam T The field type. - * \tparam Extra The extra arguments. - * - * \param name The name of the field. - * \param field_ptr The pointer to the field. - * \param extra The extra arguments that can be docstring or default value. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::* field_ptr, Extra&&... 
extra) { - static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields"); - RegisterField(name, field_ptr, true, std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a method. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def(const char* name, Func&& func, Extra&&... extra) { - RegisterMethod(name, false, std::forward(func), std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a static method. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) { - RegisterMethod(name, true, std::forward(func), std::forward(extra)...); - return *this; - } - - /*! - * \brief Register a constructor for this object type. - * - * This method registers a static `__init__` method that constructs an instance - * of the object with the specified argument types. The constructor can be invoked - * from Python or other FFI bindings. - * - * \tparam Args The argument types for the constructor. - * \tparam Extra Additional arguments (e.g., docstring). - * - * \param init_func An instance of `init` specifying constructor signature. - * \param extra Optional additional metadata such as docstring. - * - * \return Reference to this `ObjectDef` for method chaining. 
- * - * Example: - * \code - * refl::ObjectDef() - * .def(refl::init(), "Constructor docstring"); - * \endcode - */ - template - TVM_FFI_INLINE ObjectDef& def([[maybe_unused]] init init_func, Extra&&... extra) { - RegisterMethod(kInitMethodName, true, &init::template execute, - std::forward(extra)...); - return *this; - } - - private: - template - void RegisterExtraInfo(ExtraArgs&&... extra_args) { - TVMFFITypeMetadata info; - info.total_size = sizeof(Class); - info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind; - info.creator = nullptr; - info.doc = TVMFFIByteArray{nullptr, 0}; - if constexpr (std::is_default_constructible_v) { - info.creator = ObjectCreatorDefault; - } else if constexpr (std::is_constructible_v) { - info.creator = ObjectCreatorUnsafeInit; +public: + /*! + * \brief Constructor + * \tparam ExtraArgs The extra arguments. + * \param extra_args The extra arguments. + */ + template + explicit ObjectDef(ExtraArgs &&...extra_args) + : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) { + RegisterExtraInfo(std::forward(extra_args)...); + } + + /*! + * \brief Define a readonly field. + * + * \tparam Class The class type. + * \tparam T The field type. + * \tparam Extra The extra arguments. + * + * \param name The name of the field. + * \param field_ptr The pointer to the field. + * \param extra The extra arguments that can be docstring or default value. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE ObjectDef &def_ro(const char *name, T BaseClass::*field_ptr, Extra &&...extra) { + RegisterField(name, field_ptr, false, std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a read-write field. + * + * \tparam Class The class type. + * \tparam T The field type. + * \tparam Extra The extra arguments. + * + * \param name The name of the field. + * \param field_ptr The pointer to the field. + * \param extra The extra arguments that can be docstring or default value. 
+ * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE ObjectDef &def_rw(const char *name, T BaseClass::*field_ptr, Extra &&...extra) { + static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields"); + RegisterField(name, field_ptr, true, std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a method. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the method. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE ObjectDef &def(const char *name, Func &&func, Extra &&...extra) { + RegisterMethod(name, false, std::forward(func), std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a static method. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the method. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE ObjectDef &def_static(const char *name, Func &&func, Extra &&...extra) { + RegisterMethod(name, true, std::forward(func), std::forward(extra)...); + return *this; + } + + /*! + * \brief Register a constructor for this object type. + * + * This method registers a static `__init__` method that constructs an instance + * of the object with the specified argument types. The constructor can be invoked + * from Python or other FFI bindings. + * + * \tparam Args The argument types for the constructor. + * \tparam Extra Additional arguments (e.g., docstring). + * + * \param init_func An instance of `init` specifying constructor signature. + * \param extra Optional additional metadata such as docstring. + * + * \return Reference to this `ObjectDef` for method chaining. 
+ * + * Example: + * \code + * refl::ObjectDef() + * .def(refl::init(), "Constructor docstring"); + * \endcode + */ + template + TVM_FFI_INLINE ObjectDef &def([[maybe_unused]] init init_func, Extra &&...extra) { + RegisterMethod(kInitMethodName, true, &init::template execute, + std::forward(extra)...); + return *this; } - // apply extra info traits - ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info)); - } - - template - void RegisterField(const char* name, T BaseClass::* field_ptr, bool writable, - ExtraArgs&&... extra_args) { - static_assert(std::is_base_of_v, "BaseClass must be a base class of Class"); - FieldInfoBuilder info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.field_static_type_index = TypeToFieldStaticTypeIndex::value; - // store byte offset and setter, getter - // so the same setter can be reused for all the same type - info.offset = GetFieldByteOffsetToObject(field_ptr); - info.size = sizeof(T); - info.alignment = alignof(T); - info.flags = 0; - if (writable) { - info.flags |= kTVMFFIFieldFlagBitMaskWritable; + +private: + template + void RegisterExtraInfo(ExtraArgs &&...extra_args) { + TVMFFITypeMetadata info; + info.total_size = sizeof(Class); + info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind; + info.creator = nullptr; + info.doc = TVMFFIByteArray{nullptr, 0}; + if constexpr (std::is_default_constructible_v) { + info.creator = ObjectCreatorDefault; + } else if constexpr (std::is_constructible_v) { + info.creator = ObjectCreatorUnsafeInit; + } + // apply extra info traits + ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info)); } - info.getter = FieldGetter; - info.setter = FieldSetter; - // initialize default value to nullptr - info.default_value = AnyView(nullptr).CopyToTVMFFIAny(); - info.doc = TVMFFIByteArray{nullptr, 0}; - 
info.metadata_.emplace_back("type_schema", details::TypeSchema::v()); - // apply field info traits - ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); - // call register - std::string metadata_str = Metadata::ToJSON(info.metadata_); - info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info)); - } - - // register a method - template - void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { - using FuncInfo = details::FunctionInfo>; - MethodInfoBuilder info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.doc = TVMFFIByteArray{nullptr, 0}; - info.flags = 0; - if (is_static) { - info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; + + template + void RegisterField(const char *name, T BaseClass::*field_ptr, bool writable, + ExtraArgs &&...extra_args) { + static_assert(std::is_base_of_v, "BaseClass must be a base class of Class"); + FieldInfoBuilder info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.field_static_type_index = TypeToFieldStaticTypeIndex::value; + // store byte offset and setter, getter + // so the same setter can be reused for all the same type + info.offset = GetFieldByteOffsetToObject(field_ptr); + info.size = sizeof(T); + info.alignment = alignof(T); + info.flags = 0; + if (writable) { + info.flags |= kTVMFFIFieldFlagBitMaskWritable; + } + info.getter = FieldGetter; + info.setter = FieldSetter; + // initialize default value to nullptr + info.default_value = AnyView(nullptr).CopyToTVMFFIAny(); + info.doc = TVMFFIByteArray{nullptr, 0}; + info.metadata_.emplace_back("type_schema", details::TypeSchema::v()); + // apply field info traits + ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); + // call register + std::string metadata_str = Metadata::ToJSON(info.metadata_); + info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; + 
TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info)); + } + + // register a method + template + void RegisterMethod(const char *name, bool is_static, Func &&func, Extra &&...extra) { + using FuncInfo = details::FunctionInfo>; + MethodInfoBuilder info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.doc = TVMFFIByteArray{nullptr, 0}; + info.flags = 0; + if (is_static) { + info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; + } + // obtain the method function + Function method = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); + info.method = AnyView(method).CopyToTVMFFIAny(); + info.metadata_.emplace_back("type_schema", FuncInfo::TypeSchema()); + // apply method info traits + ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); + std::string metadata_str = Metadata::ToJSON(info.metadata_); + info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); } - // obtain the method function - Function method = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); - info.method = AnyView(method).CopyToTVMFFIAny(); - info.metadata_.emplace_back("type_schema", FuncInfo::TypeSchema()); - // apply method info traits - ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); - std::string metadata_str = Metadata::ToJSON(info.metadata_); - info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); - } - - int32_t type_index_; - const char* type_key_; - static constexpr const char* kInitMethodName = "__ffi_init__"; + + int32_t type_index_; + const char *type_key_; + static constexpr const char *kInitMethodName = "__ffi_init__"; }; /*! @@ -670,57 +669,56 @@ class ObjectDef : public ReflectionDefBase { */ template >> class TypeAttrDef : public ReflectionDefBase { - public: - /*! 
- * \brief Constructor - * \tparam ExtraArgs The extra arguments. - * \param extra_args The extra arguments. - */ - template - explicit TypeAttrDef(ExtraArgs&&... extra_args) - : type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {} - - /*! - * \brief Define a function-valued type attribute. - * - * \tparam Func The function type. - * - * \param name The name of the function. - * \param func The function to be registered. - * - * \return The TypeAttrDef object. - */ - template - TypeAttrDef& def(const char* name, Func&& func) { - TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; - ffi::Function ffi_func = - GetMethod(std::string(type_key_) + "." + name, std::forward(func)); - TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny(); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); - return *this; - } - - /*! - * \brief Define a constant-valued type attribute. - * - * \tparam T The type of the value. - * - * \param name The name of the attribute. - * \param value The value of the attribute. - * - * \return The TypeAttrDef object. - */ - template - TypeAttrDef& attr(const char* name, T value) { - TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; - TVMFFIAny value_any = AnyView(value).CopyToTVMFFIAny(); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); - return *this; - } - - private: - int32_t type_index_; - const char* type_key_; +public: + /*! + * \brief Constructor + * \tparam ExtraArgs The extra arguments. + * \param extra_args The extra arguments. + */ + template + explicit TypeAttrDef(ExtraArgs &&...extra_args) + : type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {} + + /*! + * \brief Define a function-valued type attribute. + * + * \tparam Func The function type. + * + * \param name The name of the function. + * \param func The function to be registered. + * + * \return The TypeAttrDef object. 
+ */ + template + TypeAttrDef &def(const char *name, Func &&func) { + TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; + ffi::Function ffi_func = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); + TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny(); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); + return *this; + } + + /*! + * \brief Define a constant-valued type attribute. + * + * \tparam T The type of the value. + * + * \param name The name of the attribute. + * \param value The value of the attribute. + * + * \return The TypeAttrDef object. + */ + template + TypeAttrDef &attr(const char *name, T value) { + TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; + TVMFFIAny value_any = AnyView(value).CopyToTVMFFIAny(); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); + return *this; + } + +private: + int32_t type_index_; + const char *type_key_; }; /*! @@ -729,13 +727,13 @@ class TypeAttrDef : public ReflectionDefBase { * \param name The name of the type attribute. 
*/ inline void EnsureTypeAttrColumn(std::string_view name) { - TVMFFIByteArray name_array = {name.data(), name.size()}; - AnyView any_view(nullptr); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_array, - reinterpret_cast(&any_view))); + TVMFFIByteArray name_array = {name.data(), name.size()}; + AnyView any_view(nullptr); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_array, + reinterpret_cast(&any_view))); } -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_REGISTRY_H_ +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_REFLECTION_REGISTRY_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h index aca5840fa..e12a4d44e 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h @@ -70,21 +70,21 @@ namespace ffi { */ template >> class RValueRef { - public: - /*! \brief the container type of the rvalue ref */ - using ContainerType = typename TObjRef::ContainerType; - /*! \brief only allow move constructor from rvalue of T */ - explicit RValueRef(TObjRef&& data) - : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} +public: + /*! \brief the container type of the rvalue ref */ + using ContainerType = typename TObjRef::ContainerType; + /*! \brief only allow move constructor from rvalue of T */ + explicit RValueRef(TObjRef &&data) + : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} - /*! \brief return the data as rvalue */ - TObjRef operator*() && { return TObjRef(std::move(data_)); } + /*! 
\brief return the data as rvalue */ + TObjRef operator*() && { return TObjRef(std::move(data_)); } - private: - mutable ObjectPtr data_; +private: + mutable ObjectPtr data_; - template - friend struct TypeTraits; + template + friend struct TypeTraits; }; template @@ -92,72 +92,72 @@ inline constexpr bool use_default_type_traits_v> = false; template struct TypeTraits> : public TypeTraitsBase { - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(const RValueRef& src, TVMFFIAny* result) { - result->type_index = TypeIndex::kTVMFFIObjectRValueRef; - result->zero_padding = 0; - // store the address of the ObjectPtr, which allows us to move the value - // and set the original ObjectPtr to nullptr - result->v_ptr = &(src.data_); - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { - if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); - // object type does not match up, we need to try to convert the object - // in this case we do not move the original rvalue ref since conversion creates a copy - TVMFFIAny tmp_any; - tmp_any.type_index = rvalue_ref->get()->type_index(); - tmp_any.zero_padding = 0; - tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); - return "RValueRef<" + TypeTraits::GetMismatchTypeInfo(&tmp_any) + ">"; - } else { - return TypeTraits::GetMismatchTypeInfo(src); + static constexpr bool storage_enabled = false; + + TVM_FFI_INLINE static void CopyToAnyView(const RValueRef &src, TVMFFIAny *result) { + result->type_index = TypeIndex::kTVMFFIObjectRValueRef; + result->zero_padding = 0; + // store the address of the ObjectPtr, which allows us to move the value + // and set the original ObjectPtr to nullptr + result->v_ptr = &(src.data_); } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - // first try rvalue conversion - if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - 
ObjectPtr* rvalue_ref = reinterpret_cast*>(src->v_ptr); - TVMFFIAny tmp_any; - tmp_any.type_index = rvalue_ref->get()->type_index(); - tmp_any.zero_padding = 0; - tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); - // fast path, storage type matches, direct move the rvalue ref - if (TypeTraits::CheckAnyStrict(&tmp_any)) { - return RValueRef( - details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(*rvalue_ref))); - } - if (std::optional opt = TypeTraits::TryCastFromAnyView(&tmp_any)) { - // object type does not match up, we need to try to convert the object - // in this case we do not move the original rvalue ref since conversion creates a copy - return RValueRef(*std::move(opt)); - } - return std::nullopt; + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { + if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { + ObjectPtr *rvalue_ref = reinterpret_cast *>(src->v_ptr); + // object type does not match up, we need to try to convert the object + // in this case we do not move the original rvalue ref since conversion creates a copy + TVMFFIAny tmp_any; + tmp_any.type_index = rvalue_ref->get()->type_index(); + tmp_any.zero_padding = 0; + tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); + return "RValueRef<" + TypeTraits::GetMismatchTypeInfo(&tmp_any) + ">"; + } else { + return TypeTraits::GetMismatchTypeInfo(src); + } + } + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny *src) { + // first try rvalue conversion + if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { + ObjectPtr *rvalue_ref = reinterpret_cast *>(src->v_ptr); + TVMFFIAny tmp_any; + tmp_any.type_index = rvalue_ref->get()->type_index(); + tmp_any.zero_padding = 0; + tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); + // fast path, storage type matches, direct move the rvalue ref + if (TypeTraits::CheckAnyStrict(&tmp_any)) { + return RValueRef( + details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(*rvalue_ref))); + } + if 
(std::optional opt = TypeTraits::TryCastFromAnyView(&tmp_any)) { + // object type does not match up, we need to try to convert the object + // in this case we do not move the original rvalue ref since conversion creates a copy + return RValueRef(*std::move(opt)); + } + return std::nullopt; + } + // try lvalue conversion + if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { + return RValueRef(*std::move(opt)); + } else { + return std::nullopt; + } } - // try lvalue conversion - if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { - return RValueRef(*std::move(opt)); - } else { - return std::nullopt; + + TVM_FFI_INLINE static std::string TypeStr() { + return "RValueRef<" + TypeTraits::TypeStr() + ">"; + } + + TVM_FFI_INLINE static std::string TypeSchema() { + std::ostringstream oss; + oss << R"({"type":")" << StaticTypeKey::kTVMFFIObjectRValueRef << R"(","args":[)"; + oss << TypeTraits::TypeSchema(); + oss << "]}"; + return oss.str(); } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "RValueRef<" + TypeTraits::TypeStr() + ">"; - } - - TVM_FFI_INLINE static std::string TypeSchema() { - std::ostringstream oss; - oss << R"({"type":")" << StaticTypeKey::kTVMFFIObjectRValueRef << R"(","args":[)"; - oss << TypeTraits::TypeSchema(); - oss << "]}"; - return oss.str(); - } }; -} // namespace ffi -} // namespace tvm +} // namespace ffi +} // namespace tvm -#endif // TVM_FFI_RVALUE_REF_H_ +#endif // TVM_FFI_RVALUE_REF_H_ From 598cec9e02a491f7af11f8e247771f4825884c9c Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 20 May 2026 03:17:52 +0000 Subject: [PATCH 06/10] issue/1083: delete tvm --- .../sgl_kernel/dlpack/dlpack.h | 639 ------ .../ops/gptq_marlin_gemm/sgl_kernel/tensor.h | 6 +- .../gptq_marlin_gemm/sgl_kernel/tvm/ffi/any.h | 682 ------- .../sgl_kernel/tvm/ffi/base_details.h | 317 --- .../sgl_kernel/tvm/ffi/c_api.h | 1226 ----------- .../sgl_kernel/tvm/ffi/cast.h | 79 - .../sgl_kernel/tvm/ffi/container/array.h | 1164 ----------- 
.../tvm/ffi/container/container_details.h | 360 ---- .../sgl_kernel/tvm/ffi/container/map.h | 1785 ----------------- .../sgl_kernel/tvm/ffi/container/shape.h | 343 ---- .../sgl_kernel/tvm/ffi/container/tensor.h | 785 -------- .../sgl_kernel/tvm/ffi/container/tuple.h | 400 ---- .../sgl_kernel/tvm/ffi/container/variant.h | 311 --- .../sgl_kernel/tvm/ffi/dtype.h | 199 -- .../sgl_kernel/tvm/ffi/endian.h | 89 - .../sgl_kernel/tvm/ffi/error.h | 398 ---- .../sgl_kernel/tvm/ffi/extra/base.h | 48 - .../sgl_kernel/tvm/ffi/extra/base64.h | 140 -- .../sgl_kernel/tvm/ffi/extra/c_env_api.h | 158 -- .../sgl_kernel/tvm/ffi/extra/cuda/base.h | 54 - .../tvm/ffi/extra/cuda/cubin_launcher.h | 604 ------ .../tvm/ffi/extra/cuda/device_guard.h | 74 - .../sgl_kernel/tvm/ffi/extra/json.h | 84 - .../sgl_kernel/tvm/ffi/extra/module.h | 301 --- .../sgl_kernel/tvm/ffi/extra/serialization.h | 72 - .../tvm/ffi/extra/structural_equal.h | 78 - .../tvm/ffi/extra/structural_hash.h | 57 - .../sgl_kernel/tvm/ffi/function.h | 998 --------- .../sgl_kernel/tvm/ffi/function_details.h | 272 --- .../sgl_kernel/tvm/ffi/memory.h | 274 --- .../sgl_kernel/tvm/ffi/object.h | 1207 ----------- .../sgl_kernel/tvm/ffi/optional.h | 428 ---- .../tvm/ffi/reflection/access_path.h | 450 ----- .../sgl_kernel/tvm/ffi/reflection/accessor.h | 264 --- .../sgl_kernel/tvm/ffi/reflection/creator.h | 121 -- .../sgl_kernel/tvm/ffi/reflection/registry.h | 739 ------- .../sgl_kernel/tvm/ffi/rvalue_ref.h | 163 -- .../sgl_kernel/tvm/ffi/string.h | 1102 ---------- .../sgl_kernel/tvm/ffi/type_traits.h | 828 -------- .../ops/gptq_marlin_gemm/sgl_kernel/utils.cuh | 4 +- .../ops/gptq_marlin_gemm/sgl_kernel/utils.h | 2 +- 41 files changed, 6 insertions(+), 17299 deletions(-) delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/any.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/base_details.h delete mode 100644 
src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/c_api.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/array.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/container_details.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/shape.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tensor.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/dtype.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/error.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/c_env_api.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h delete mode 100644 
src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function_details.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/memory.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/object.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/optional.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/string.h delete mode 100644 src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/type_traits.h diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h deleted file mode 100644 index a6e2f5c58..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/dlpack/dlpack.h +++ /dev/null @@ -1,639 +0,0 @@ -/*! - * Copyright (c) 2017 - by Contributors - * \file dlpack.h - * \brief The common header of DLPack. - */ -#ifndef DLPACK_DLPACK_H_ -#define DLPACK_DLPACK_H_ - -/** - * \brief Compatibility with C++ - */ -#ifdef __cplusplus -#define DLPACK_EXTERN_C extern "C" -#else -#define DLPACK_EXTERN_C -#endif - -/*! \brief The current major version of dlpack */ -#define DLPACK_MAJOR_VERSION 1 - -/*! \brief The current minor version of dlpack */ -#define DLPACK_MINOR_VERSION 2 - -/*! 
\brief DLPACK_DLL prefix for windows */ -#ifdef _WIN32 -#ifdef DLPACK_EXPORTS -#define DLPACK_DLL __declspec(dllexport) -#else -#define DLPACK_DLL __declspec(dllimport) -#endif -#else -#define DLPACK_DLL -#endif - -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -/*! - * \brief The DLPack version. - * - * A change in major version indicates that we have changed the - * data layout of the ABI - DLManagedTensorVersioned. - * - * A change in minor version indicates that we have added new - * code, such as a new device type, but the ABI is kept the same. - * - * If an obtained DLPack tensor has a major version that disagrees - * with the version number specified in this header file - * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter - * (and it is safe to do so). It is not safe to access any other fields - * as the memory layout will have changed. - * - * In the case of a minor version mismatch, the tensor can be safely used as - * long as the consumer knows how to interpret all fields. Minor version - * updates indicate the addition of enumeration values. - */ -typedef struct { - /*! \brief DLPack major version. */ - uint32_t major; - /*! \brief DLPack minor version. */ - uint32_t minor; -} DLPackVersion; - -/*! - * \brief The device type in DLDevice. - */ -#ifdef __cplusplus -typedef enum : int32_t { -#else -typedef enum { -#endif - /*! \brief CPU device */ - kDLCPU = 1, - /*! \brief CUDA GPU device */ - kDLCUDA = 2, - /*! - * \brief Pinned CUDA CPU memory by cudaMallocHost - */ - kDLCUDAHost = 3, - /*! \brief OpenCL devices. */ - kDLOpenCL = 4, - /*! \brief Vulkan buffer for next generation graphics. */ - kDLVulkan = 7, - /*! \brief Metal for Apple GPU. */ - kDLMetal = 8, - /*! \brief Verilog simulator buffer */ - kDLVPI = 9, - /*! \brief ROCm GPUs for AMD GPUs */ - kDLROCM = 10, - /*! - * \brief Pinned ROCm CPU memory allocated by hipMallocHost - */ - kDLROCMHost = 11, - /*! 
- * \brief Reserved extension device type, - * used for quickly test extension device - * The semantics can differ depending on the implementation. - */ - kDLExtDev = 12, - /*! - * \brief CUDA managed/unified memory allocated by cudaMallocManaged - */ - kDLCUDAManaged = 13, - /*! - * \brief Unified shared memory allocated on a oneAPI non-partititioned - * device. Call to oneAPI runtime is required to determine the device - * type, the USM allocation type and the sycl context it is bound to. - * - */ - kDLOneAPI = 14, - /*! \brief GPU support for next generation WebGPU standard. */ - kDLWebGPU = 15, - /*! \brief Qualcomm Hexagon DSP */ - kDLHexagon = 16, - /*! \brief Microsoft MAIA devices */ - kDLMAIA = 17, - /*! \brief AWS Trainium */ - kDLTrn = 18, -} DLDeviceType; - -/*! - * \brief A Device for Tensor and operator. - */ -typedef struct { - /*! \brief The device type used in the device. */ - DLDeviceType device_type; - /*! - * \brief The device index. - * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0. - */ - int32_t device_id; -} DLDevice; - -/*! - * \brief The type code options DLDataType. - */ -typedef enum { - /*! \brief signed integer */ - kDLInt = 0U, - /*! \brief unsigned integer */ - kDLUInt = 1U, - /*! \brief IEEE floating point */ - kDLFloat = 2U, - /*! - * \brief Opaque handle type, reserved for testing purposes. - * Frameworks need to agree on the handle data type for the exchange to be well-defined. - */ - kDLOpaqueHandle = 3U, - /*! \brief bfloat16 */ - kDLBfloat = 4U, - /*! - * \brief complex number - * (C/C++/Python layout: compact struct per complex number) - */ - kDLComplex = 5U, - /*! \brief boolean */ - kDLBool = 6U, - /*! \brief FP8 data types */ - kDLFloat8_e3m4 = 7U, - kDLFloat8_e4m3 = 8U, - kDLFloat8_e4m3b11fnuz = 9U, - kDLFloat8_e4m3fn = 10U, - kDLFloat8_e4m3fnuz = 11U, - kDLFloat8_e5m2 = 12U, - kDLFloat8_e5m2fnuz = 13U, - kDLFloat8_e8m0fnu = 14U, - /*! 
\brief FP6 data types - * Setting bits != 6 is currently unspecified, and the producer must ensure it is set - * while the consumer must stop importing if the value is unexpected. - */ - kDLFloat6_e2m3fn = 15U, - kDLFloat6_e3m2fn = 16U, - /*! \brief FP4 data types - * Setting bits != 4 is currently unspecified, and the producer must ensure it is set - * while the consumer must stop importing if the value is unexpected. - */ - kDLFloat4_e2m1fn = 17U, -} DLDataTypeCode; - -/*! - * \brief The data type the tensor can hold. The data type is assumed to follow the - * native endian-ness. An explicit error message should be raised when attempting to - * export an array with non-native endianness - * - * Examples - * - float: type_code = 2, bits = 32, lanes = 1 - * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4 - * - int8: type_code = 0, bits = 8, lanes = 1 - * - std::complex: type_code = 5, bits = 64, lanes = 1 - * - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits) - * - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory) - * - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory) - * - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory) - * - * When a sub-byte type is packed, DLPack requires the data to be in little bit-endian, i.e., - * for a packed data set D ((D >> (i * bits)) && bit_mask) stores the i-th element. - */ -typedef struct { - /*! - * \brief Type code of base types. - * We keep it uint8_t instead of DLDataTypeCode for minimal memory - * footprint, but the value should be one of DLDataTypeCode enum values. - * */ - uint8_t code; - /*! - * \brief Number of bits, common choices are 8, 16, 32. - */ - uint8_t bits; - /*! \brief Number of lanes in the type, used for vector types. */ - uint16_t lanes; -} DLDataType; - -/*! - * \brief Plain C Tensor object, does not manage memory. - */ -typedef struct { - /*! 
- * \brief The data pointer points to the allocated data. This will be CUDA - * device pointer or cl_mem handle in OpenCL. It may be opaque on some device - * types. This pointer is always aligned to 256 bytes as in CUDA. The - * `byte_offset` field should be used to point to the beginning of the data. - * - * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow, - * TVM, perhaps others) do not adhere to this 256 byte alignment requirement - * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed - * (after which this note will be updated); at the moment it is recommended - * to not rely on the data pointer being correctly aligned. - * - * For given DLTensor, the size of memory required to store the contents of - * data is calculated as follows: - * - * \code{.c} - * static inline size_t GetDataSize(const DLTensor* t) { - * size_t size = 1; - * for (tvm_index_t i = 0; i < t->ndim; ++i) { - * size *= t->shape[i]; - * } - * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; - * return size; - * } - * \endcode - * - * Note that if the tensor is of size zero, then the data pointer should be - * set to `NULL`. - */ - void *data; - /*! \brief The device of the tensor */ - DLDevice device; - /*! \brief Number of dimensions */ - int32_t ndim; - /*! \brief The data type of the pointer*/ - DLDataType dtype; - /*! - * \brief The shape of the tensor - * - * When ndim == 0, shape can be set to NULL. - */ - int64_t *shape; - /*! - * \brief strides of the tensor (in number of elements, not bytes), - * can not be NULL if ndim != 0, must points to - * an array of ndim elements that specifies the strides, - * so consumer can always rely on strides[dim] being valid for 0 <= dim < ndim. - * - * When ndim == 0, strides can be set to NULL. - * - * \note Before DLPack v1.2, strides can be NULL to indicate contiguous data. - * This is not allowed in DLPack v1.2 and later. The rationale - * is to simplify the consumer handling. 
- */ - int64_t *strides; - /*! \brief The offset in bytes to the beginning pointer to data */ - uint64_t byte_offset; -} DLTensor; - -/*! - * \brief C Tensor object, manage memory of DLTensor. This data structure is - * intended to facilitate the borrowing of DLTensor by another framework. It is - * not meant to transfer the tensor. When the borrowing framework doesn't need - * the tensor, it should call the deleter to notify the host that the resource - * is no longer needed. - * - * \note This data structure is used as Legacy DLManagedTensor - * in DLPack exchange and is deprecated after DLPack v0.8 - * Use DLManagedTensorVersioned instead. - * This data structure may get renamed or deleted in future versions. - * - * \sa DLManagedTensorVersioned - */ -typedef struct DLManagedTensor { - /*! \brief DLTensor which is being memory managed */ - DLTensor dl_tensor; - /*! \brief the context of the original host framework of DLManagedTensor in - * which DLManagedTensor is used in the framework. It can also be NULL. - */ - void *manager_ctx; - /*! - * \brief Destructor - this should be called - * to destruct the manager_ctx which backs the DLManagedTensor. It can be - * NULL if there is no way for the caller to provide a reasonable destructor. - * The destructor deletes the argument self as well. - */ - void (*deleter)(struct DLManagedTensor *self); -} DLManagedTensor; - -// bit masks used in the DLManagedTensorVersioned - -/*! \brief bit mask to indicate that the tensor is read only. */ -#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL) - -/*! - * \brief bit mask to indicate that the tensor is a copy made by the producer. - * - * If set, the tensor is considered solely owned throughout its lifetime by the - * consumer, until the producer-provided deleter is invoked. - */ -#define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL) - -/*! - * \brief bit mask to indicate that whether a sub-byte type is packed or padded. 
- * - * The default for sub-byte types (ex: fp4/fp6) is assumed packed. This flag can - * be set by the producer to signal that a tensor of sub-byte type is padded. - */ -#define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL) - -/*! - * \brief A versioned and managed C Tensor object, manage memory of DLTensor. - * - * This data structure is intended to facilitate the borrowing of DLTensor by - * another framework. It is not meant to transfer the tensor. When the borrowing - * framework doesn't need the tensor, it should call the deleter to notify the - * host that the resource is no longer needed. - * - * \note This is the current standard DLPack exchange data structure. - */ -typedef struct DLManagedTensorVersioned { - /*! - * \brief The API and ABI version of the current managed Tensor - */ - DLPackVersion version; - /*! - * \brief the context of the original host framework. - * - * Stores DLManagedTensorVersioned is used in the - * framework. It can also be NULL. - */ - void *manager_ctx; - /*! - * \brief Destructor. - * - * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned. - * It can be NULL if there is no way for the caller to provide a reasonable - * destructor. The destructor deletes the argument self as well. - */ - void (*deleter)(struct DLManagedTensorVersioned *self); - /*! - * \brief Additional bitmask flags information about the tensor. - * - * By default the flags should be set to 0. - * - * \note Future ABI changes should keep everything until this field - * stable, to ensure that deleter can be correctly called. - * - * \sa DLPACK_FLAG_BITMASK_READ_ONLY - * \sa DLPACK_FLAG_BITMASK_IS_COPIED - */ - uint64_t flags; - /*! 
\brief DLTensor which is being memory managed */ - DLTensor dl_tensor; -} DLManagedTensorVersioned; - -//---------------------------------------------------------------------- -// DLPack `__c_dlpack_exchange_api__` fast exchange protocol definitions -//---------------------------------------------------------------------- -/*! - * \brief Request a producer library to create a new tensor. - * - * Create a new `DLManagedTensorVersioned` within the context of the producer - * library. The allocation is defined via the prototype DLTensor. - * - * This function is exposed by the framework through the DLPackExchangeAPI. - * - * \param prototype The prototype DLTensor. Only the dtype, ndim, shape, - * and device fields are used. - * \param out The output DLManagedTensorVersioned. - * \param error_ctx Context for `SetError`. - * \param SetError The function to set the error. - * \return The owning DLManagedTensorVersioned* or NULL on failure. - * SetError is called exactly when NULL is returned (the implementor - * must ensure this). - * \note - As a C function, must not thrown C++ exceptions. - * - Error propagation via SetError to avoid any direct need - * of Python API. Due to this `SetError` may have to ensure the GIL is - * held since it will presumably set a Python error. - * - * \sa DLPackExchangeAPI - */ -typedef int (*DLPackManagedTensorAllocator)( // - DLTensor *prototype, DLManagedTensorVersioned **out, void *error_ctx, // - void (*SetError)(void *error_ctx, const char *kind, const char *message) // -); - -/*! - * \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned. - * - * This function does not perform any stream synchronization. The consumer should query - * DLPackCurrentWorkStream to get the current work stream and launch kernels on it. - * - * This function is exposed by the framework through the DLPackExchangeAPI. - * - * \param py_object The Python object to convert. 
Must have the same type - * as the one the `DLPackExchangeAPI` was discovered from. - * \return The owning DLManagedTensorVersioned* or NULL on failure with a - * Python exception set. If the data cannot be described using DLPack - * this should be a BufferError if possible. - * \note - As a C function, must not thrown C++ exceptions. - * - * \sa DLPackExchangeAPI, DLPackCurrentWorkStream - */ -typedef int (*DLPackManagedTensorFromPyObjectNoSync)( // - void *py_object, // - DLManagedTensorVersioned **out // -); - -/*! - * \brief Exports a PyObject* Tensor/NDArray to a provided DLTensor. - * - * This function provides a faster interface for temporary, non-owning, exchange. - * The producer (implementor) still owns the memory of data, strides, shape. - * The liveness of the DLTensor and the data it views is only guaranteed until - * control is returned. - * - * This function currently assumes that the producer (implementor) can fill - * in the DLTensor shape and strides without the need for temporary allocations. - * - * This function does not perform any stream synchronization. The consumer should query - * DLPackCurrentWorkStream to get the current work stream and launch kernels on it. - * - * This function is exposed by the framework through the DLPackExchangeAPI. - * - * \param py_object The Python object to convert. Must have the same type - * as the one the `DLPackExchangeAPI` was discovered from. - * \param out The output DLTensor, whose space is pre-allocated on stack. - * \return 0 on success, -1 on failure with a Python exception set. - * \note - As a C function, must not thrown C++ exceptions. - * - * \sa DLPackExchangeAPI, DLPackCurrentWorkStream - */ -typedef int (*DLPackDLTensorFromPyObjectNoSync)( // - void *py_object, // - DLTensor *out // -); - -/*! - * \brief Obtain the current work stream of a device. - * - * Obtain the current work stream of a device from the producer framework. 
- * For example, it should map to torch.cuda.current_stream in PyTorch. - * - * When device_type is kDLCPU, the consumer do not have to query the stream - * and the producer can simply return NULL when queried. - * The consumer do not have to do anything on stream sync or setting. - * So CPU only framework can just provide a dummy implementation that - * always set out_current_stream[0] to NULL. - * - * \param device_type The device type. - * \param device_id The device id. - * \param out_current_stream The output current work stream. - * - * \return 0 on success, -1 on failure with a Python exception set. - * \note - As a C function, must not thrown C++ exceptions. - * - * \sa DLPackExchangeAPI - */ -typedef int (*DLPackCurrentWorkStream)( // - DLDeviceType device_type, // - int32_t device_id, // - void **out_current_stream // -); - -/*! - * \brief Imports a DLManagedTensorVersioned to a PyObject* Tensor/NDArray. - * - * Convert an owning DLManagedTensorVersioned* to the Python tensor of the - * producer (implementor) library with the correct type. - * - * This function does not perform any stream synchronization. - * - * This function is exposed by the framework through the DLPackExchangeAPI. - * - * \param tensor The DLManagedTensorVersioned to convert the ownership of the - * tensor is stolen. - * \param out_py_object The output Python object. - * \return 0 on success, -1 on failure with a Python exception set. - * - * \sa DLPackExchangeAPI - */ -typedef int (*DLPackManagedTensorToPyObjectNoSync)( // - DLManagedTensorVersioned *tensor, // - void **out_py_object // -); - -/*! - * \brief DLPackExchangeAPI stable header. - * \sa DLPackExchangeAPI - */ -typedef struct DLPackExchangeAPIHeader { - /*! - * \brief The provided DLPack version the consumer must check major version - * compatibility before using this struct. - */ - DLPackVersion version; - /*! - * \brief Optional pointer to an older DLPackExchangeAPI in the chain. 
- * - * It must be NULL if the framework does not support older versions. - * If the current major version is larger than the one supported by the - * consumer, the consumer may walk this to find an earlier supported version. - * - * \sa DLPackExchangeAPI - */ - struct DLPackExchangeAPIHeader *prev_api; -} DLPackExchangeAPIHeader; - -/*! - * \brief Framework-specific function pointers table for DLPack exchange. - * - * Additionally to `__dlpack__()` we define a C function table sharable by - * Python implementations via `__c_dlpack_exchange_api__`. - * This attribute must be set on the type as a Python integer compatible - * with `PyLong_FromVoidPtr`/`PyLong_AsVoidPtr`. - * - * A consumer library may use a pattern such as: - * - * \code - * - * PyObject *api_obj = type(tensor_obj).__c_dlpack_exchange_api__; // as C-code - * MyDLPackExchangeAPI *api = PyLong_AsVoidPtr(api_obj); - * if (api == NULL && PyErr_Occurred()) { goto handle_error; } - * - * \endcode - * - * Note that this must be defined on the type. The consumer should look up the - * attribute on the type and may cache the result for each unique type. 
- * - * The precise API table is given by: - * \code - * struct MyDLPackExchangeAPI : public DLPackExchangeAPI { - * MyDLPackExchangeAPI() { - * header.version.major = DLPACK_MAJOR_VERSION; - * header.version.minor = DLPACK_MINOR_VERSION; - * header.prev_version_api = nullptr; - * - * managed_tensor_allocator = MyDLPackManagedTensorAllocator; - * managed_tensor_from_py_object_no_sync = MyDLPackManagedTensorFromPyObjectNoSync; - * managed_tensor_to_py_object_no_sync = MyDLPackManagedTensorToPyObjectNoSync; - * dltensor_from_py_object_no_sync = MyDLPackDLTensorFromPyObjectNoSync; - * current_work_stream = MyDLPackCurrentWorkStream; - * } - * - * static const DLPackExchangeAPI* Global() { - * static MyDLPackExchangeAPI inst; - * return &inst; - * } - * }; - * \endcode - * - * Guidelines for leveraging DLPackExchangeAPI: - * - * There are generally two kinds of consumer needs for DLPack exchange: - * - N0: library support, where consumer.kernel(x, y, z) would like to run a kernel - * with the data from x, y, z. The consumer is also expected to run the kernel with the same - * stream context as the producer. For example, when x, y, z is torch.Tensor, - * consumer should query exchange_api->current_work_stream to get the - * current stream and launch the kernel with the same stream. - * This setup is necessary for no synchronization in kernel launch and maximum compatibility - * with CUDA graph capture in the producer. - * This is the desirable behavior for library extension support for frameworks like PyTorch. - * - N1: data ingestion and retention - * - * Note that obj.__dlpack__() API should provide useful ways for N1. - * The primary focus of the current DLPackExchangeAPI is to enable faster exchange N0 - * with the support of the function pointer current_work_stream. - * - * Array/Tensor libraries should statically create and initialize this structure - * then return a pointer to DLPackExchangeAPI as an int value in Tensor/Array. 
- * The DLPackExchangeAPI* must stay alive throughout the lifetime of the process. - * - * One simple way to do so is to create a static instance of DLPackExchangeAPI - * within the framework and return a pointer to it. The following code - * shows an example to do so in C++. It should also be reasonably easy - * to do so in other languages. - */ -typedef struct DLPackExchangeAPI { - /*! - * \brief The header that remains stable across versions. - */ - DLPackExchangeAPIHeader header; - /*! - * \brief Producer function pointer for DLPackManagedTensorAllocator - * This function must not be NULL. - * \sa DLPackManagedTensorAllocator - */ - DLPackManagedTensorAllocator managed_tensor_allocator; - /*! - * \brief Producer function pointer for DLPackManagedTensorFromPyObject - * This function must be not NULL. - * \sa DLPackManagedTensorFromPyObject - */ - DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync; - /*! - * \brief Producer function pointer for DLPackManagedTensorToPyObject - * This function must be not NULL. - * \sa DLPackManagedTensorToPyObject - */ - DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync; - /*! - * \brief Producer function pointer for DLPackDLTensorFromPyObject - * This function can be NULL when the producer does not support this function. - * \sa DLPackDLTensorFromPyObjectNoSync - */ - DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync; - /*! - * \brief Producer function pointer for DLPackCurrentWorkStream - * This function must be not NULL. 
- * \sa DLPackCurrentWorkStream - */ - DLPackCurrentWorkStream current_work_stream; -} DLPackExchangeAPI; - -#ifdef __cplusplus -} // DLPACK_EXTERN_C -#endif -#endif // DLPACK_DLPACK_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h index 308d6fac3..f30492621 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tensor.h @@ -3,9 +3,9 @@ #pragma once #include "utils.h" -#include "dlpack/dlpack.h" -#include "tvm/ffi/container/tensor.h" -#include "tvm/ffi/dtype.h" +#include +#include +#include #include #include diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/any.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/any.h deleted file mode 100644 index 2c79b383b..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/any.h +++ /dev/null @@ -1,682 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/any.h - * \brief Any value support. 
- */ -#ifndef TVM_FFI_ANY_H_ -#define TVM_FFI_ANY_H_ - -#include "c_api.h" -#include "string.h" -#include "type_traits.h" - -#include -#include - -namespace tvm { -namespace ffi { - -class Any; - -namespace details { -// Helper to perform -// unsafe operations related to object -struct AnyUnsafe; -} // namespace details - -/*! - * \brief AnyView allows us to take un-managed reference view of any value. - */ -class AnyView { -protected: - /*! \brief The underlying backing data of the any object */ - TVMFFIAny data_; - // Any can see AnyView - friend class Any; - -public: - // NOTE: the following functions use style - // since they are common functions appearing in FFI. - /*! - * \brief Reset any view to None - */ - void reset() { - data_.type_index = TypeIndex::kTVMFFINone; - // invariance: always set the union padding part to 0 - data_.zero_padding = 0; - data_.v_int64 = 0; - } - /*! - * \brief Swap this AnyView with another AnyView - * \param other The other AnyView - */ - TVM_FFI_INLINE void swap(AnyView &other) noexcept { std::swap(data_, other.data_); } - /*! \return the internal type index */ - TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; } - /*! \brief Default constructor */ - AnyView() { - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - ~AnyView() = default; - // constructors from any view - /*! \brief Copy constructor */ - AnyView(const AnyView &) = default; - /*! \brief Copy assignment operator */ - AnyView &operator=(const AnyView &) = default; - /*! \brief Move constructor */ - AnyView(AnyView &&other) noexcept : data_(other.data_) { - other.data_.type_index = TypeIndex::kTVMFFINone; - other.data_.zero_padding = 0; - other.data_.v_int64 = 0; - } - TVM_FFI_INLINE AnyView &operator=(AnyView &&other) noexcept { - // copy-and-swap idiom - AnyView(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Constructor from a general type. 
- * \tparam T The type to convert from. - * \param other The value to convert from. - */ - template ::convert_enabled>> - AnyView(const T &other) { // NOLINT(*) - TypeTraits::CopyToAnyView(other, &data_); - } - /*! - * \brief Assign from a general type. - * \tparam T The type to convert from. - * \param other The value to convert from. - */ - template ::convert_enabled>> - TVM_FFI_INLINE AnyView &operator=(const T &other) { // NOLINT(*) - // copy-and-swap idiom - AnyView(other).swap(*this); // NOLINT(*) - return *this; - } - - /*! - * \brief Try to see if we can reinterpret the AnyView to as T object. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note This function won't try run type conversion (use try_cast for that purpose). - */ - template ::convert_enabled>> - TVM_FFI_INLINE std::optional as() const { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::CopyFromAnyViewAfterCheck(&data_); - } else { - return std::optional(std::nullopt); - } - } - /*! - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. - * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. - */ - template >> - TVM_FFI_INLINE const T *as() const { - return this->as().value_or(nullptr); - } - - /*! - * \brief Cast to a type T. - * - * \tparam T The type to cast to. - * \return The casted value, or throws an exception if the cast is not possible. - */ - template ::convert_enabled>> - TVM_FFI_INLINE T cast() const { - std::optional opt = TypeTraits::TryCastFromAnyView(&data_); - if (!opt.has_value()) { - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" - << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" - << TypeTraits::TypeStr() << "`"; - } - return *std::move(opt); - } - - /*! - * \brief Try to cast to a type T, return std::nullopt if the cast is not possible. - * - * \tparam T The type to cast to. 
- * \return The casted value, or std::nullopt if the cast is not possible. - */ - template ::convert_enabled>> - TVM_FFI_INLINE std::optional try_cast() const { - return TypeTraits::TryCastFromAnyView(&data_); - } - - // comparison with nullptr - TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { - return data_.type_index == TypeIndex::kTVMFFINone; - } - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { - return data_.type_index != TypeIndex::kTVMFFINone; - } - /*! - * \brief Get the type key of the Any - * \return The type key of the Any - */ - TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); } - // The following functions are only used for testing purposes - /*! - * \return The underlying supporting data of any view - * \note This function is used only for testing purposes. - */ - TVM_FFI_INLINE TVMFFIAny CopyToTVMFFIAny() const { return data_; } - /*! - * \return Create an AnyView from TVMFFIAny - * \param data the underlying ffi data. - */ - TVM_FFI_INLINE static AnyView CopyFromTVMFFIAny(TVMFFIAny data) { - AnyView view; - view.data_ = data; - return view; - } -}; - -namespace details { -/*! - * \brief Helper function to inplace convert any view to any. - * \param data The pointer that represents the format as any view. - * \param extra_any_bytes Indicate that the data may contain extra bytes following - * the TVMFFIAny data structure. This is reserved for future possible optimizations - * of small-string and extended any object. 
- */ -TVM_FFI_INLINE void InplaceConvertAnyViewToAny(TVMFFIAny *data, - [[maybe_unused]] size_t extra_any_bytes = 0) { - if (data->type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(data->v_obj); - } else if (data->type_index >= TypeIndex::kTVMFFIRawStr) { - if (data->type_index == TypeIndex::kTVMFFIRawStr) { - // convert raw string to owned string object - String temp(data->v_c_str); - TypeTraits::MoveToAny(std::move(temp), data); - } else if (data->type_index == TypeIndex::kTVMFFIByteArrayPtr) { - // convert byte array to owned bytes object - Bytes temp(*static_cast(data->v_ptr)); - TypeTraits::MoveToAny(std::move(temp), data); - } else if (data->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - // convert rvalue ref to owned object - Object **obj_addr = static_cast(data->v_ptr); - TVM_FFI_ICHECK(obj_addr[0] != nullptr) << "RValueRef already moved"; - ObjectRef temp(details::ObjectUnsafe::ObjectPtrFromOwned(obj_addr[0])); - // set the rvalue ref to nullptr to avoid double move - obj_addr[0] = nullptr; - TypeTraits::MoveToAny(std::move(temp), data); - } - } -} -} // namespace details - -/*! - * \brief Managed Any that takes strong reference to a value. - * - * \note Develooper invariance: the TVMFFIAny data_ - * in the Any can be safely used in AnyView. - */ -class Any { -protected: - /*! \brief The underlying backing data of the any object */ - TVMFFIAny data_; - -public: - /*! - * \brief Reset any to None - */ - TVM_FFI_INLINE void reset() { - if (data_.type_index >= TVMFFITypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); - } - data_.type_index = TVMFFITypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - /*! - * \brief Swap this Any with another Any - * \param other The other Any - */ - TVM_FFI_INLINE void swap(Any &other) noexcept { std::swap(data_, other.data_); } - /*! 
\return the internal type index */ - TVM_FFI_INLINE int32_t type_index() const noexcept { return data_.type_index; } - /*! - * \brief Default constructor - */ - Any() { - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - /*! - * \brief Destructor - */ - ~Any() { this->reset(); } - /*! - * \brief Constructor from another Any - * \param other The other Any - */ - Any(const Any &other) : data_(other.data_) { - if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj); - } - } - /*! - * \brief Move constructor from another Any - * \param other The other Any - */ - Any(Any &&other) noexcept : data_(other.data_) { - other.data_.type_index = TypeIndex::kTVMFFINone; - other.data_.zero_padding = 0; - other.data_.v_int64 = 0; - } - /*! - * \brief Assign from another Any - * \param other The other Any - */ - TVM_FFI_INLINE Any &operator=(const Any &other) { - // copy-and-swap idiom - Any(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Move assign from another Any - * \param other The other Any - */ - TVM_FFI_INLINE Any &operator=(Any &&other) noexcept { - // copy-and-swap idiom - Any(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Constructor from another AnyView - * \param other The other AnyView - */ - Any(const AnyView &other) : data_(other.data_) { // NOLINT(*) - details::InplaceConvertAnyViewToAny(&data_); - } - /*! - * \brief Assign from another AnyView - * \param other The other AnyView - */ - TVM_FFI_INLINE Any &operator=(const AnyView &other) { - // copy-and-swap idiom - Any(other).swap(*this); // NOLINT(*) - return *this; - } - /*! \brief Any can be converted to AnyView in zero cost. */ - operator AnyView() const { // NOLINT(google-explicit-constructor) - return AnyView::CopyFromTVMFFIAny(data_); - } - /*! 
- * \brief Constructor from a general type - * \tparam T The value type of the other - */ - template ::convert_enabled>> - Any(T other) { // NOLINT(*) - TypeTraits::MoveToAny(std::move(other), &data_); - } - /*! - * \brief Assignment from a general type - * \tparam T The value type of the other - */ - template ::convert_enabled>> - TVM_FFI_INLINE Any &operator=(T other) { // NOLINT(*) - // copy-and-swap idiom - Any(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - /** - * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note This function won't try to run type conversion (use try_cast for that purpose). - */ - template ::storage_enabled || std::is_same_v>> - TVM_FFI_INLINE std::optional as() && { - if constexpr (std::is_same_v) { - return std::move(*this); - } else { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::MoveFromAnyAfterCheck(&data_); - } else { - return std::optional(std::nullopt); - } - } - } - - /** - * \brief Try to reinterpret the Any as a type T, return std::nullopt if it is not possible. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note This function won't try to run type conversion (use try_cast for that purpose). - */ - template ::convert_enabled || std::is_same_v>> - TVM_FFI_INLINE std::optional as() const & { - if constexpr (std::is_same_v) { - return *this; - } else { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::CopyFromAnyViewAfterCheck(&data_); - } else { - return std::optional(std::nullopt); - } - } - } - - /*! - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. - * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. 
- */ - template >> - TVM_FFI_INLINE const T *as() const & { - return this->as().value_or(nullptr); - } - - /** - * \brief Cast to a type T, throw an exception if the cast is not possible. - * - * \tparam T The type to cast to. - */ - template ::convert_enabled>> - TVM_FFI_INLINE T cast() const & { - std::optional opt = TypeTraits::TryCastFromAnyView(&data_); - if (!opt.has_value()) { - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" - << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" - << TypeTraits::TypeStr() << "`"; - } - return *std::move(opt); - } - - /** - * \brief Cast to a type T, throw an exception if the cast is not possible. - * - * \tparam T The type to cast to. - */ - template ::storage_enabled>> - TVM_FFI_INLINE T cast() && { - if (TypeTraits::CheckAnyStrict(&data_)) { - return TypeTraits::MoveFromAnyAfterCheck(&data_); - } - // slow path, try to do fallback convert - std::optional opt = TypeTraits::TryCastFromAnyView(&data_); - if (!opt.has_value()) { - TVM_FFI_THROW(TypeError) << "Cannot convert from type `" - << TypeTraits::GetMismatchTypeInfo(&data_) << "` to `" - << TypeTraits::TypeStr() << "`"; - } - return *std::move(opt); - } - - /** - * \brief Try to cast to a type T. - * - * \tparam T The type to cast to. - * \return The casted value, or std::nullopt if the cast is not possible. - * \note use STL name since it to be more consistent with cast API. - */ - template ::convert_enabled || std::is_same_v>> - TVM_FFI_INLINE std::optional try_cast() const { - if constexpr (std::is_same_v) { - return *this; - } else { - return TypeTraits::TryCastFromAnyView(&data_); - } - } - /*! - * \brief Check if the two Any are same type and value in shallow comparison. - * \param other The other Any - * \return True if the two Any are same type and value, false otherwise. 
- */ - TVM_FFI_INLINE bool same_as(const Any &other) const noexcept { - return data_.type_index == other.data_.type_index && data_.zero_padding == other.data_.zero_padding && data_.v_int64 == other.data_.v_int64; - } - - /*! - * \brief Check if any and ObjectRef are same type and value in shallow comparison. - * \param other The other ObjectRef - * \return True if the two Any are same type and value, false otherwise. - */ - TVM_FFI_INLINE bool same_as(const ObjectRef &other) const noexcept { - if (other.get() != nullptr) { - return (data_.type_index == other->type_index() && reinterpret_cast(data_.v_obj) == other.get()); - } else { - return data_.type_index == TypeIndex::kTVMFFINone; - } - } - - TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { - return data_.type_index == TypeIndex::kTVMFFINone; - } - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { - return data_.type_index != TypeIndex::kTVMFFINone; - } - - /*! - * \brief Get the type key of the Any - * \return The type key of the Any - */ - TVM_FFI_INLINE std::string GetTypeKey() const { return TypeIndexToTypeKey(data_.type_index); } - - friend struct details::AnyUnsafe; - friend struct AnyHash; - friend struct AnyEqual; -}; - -// layout assert to ensure we can freely cast between the two types -static_assert(sizeof(AnyView) == sizeof(TVMFFIAny)); -static_assert(sizeof(Any) == sizeof(TVMFFIAny)); - -namespace details { - -template -struct Type2Str { - static std::string v() { return TypeTraitsNoCR::TypeStr(); } -}; - -template <> -struct Type2Str { - static std::string v() { return "Any"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "Any"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "AnyView"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "AnyView"; } -}; - -template <> -struct Type2Str { - static std::string v() { return "void"; } -}; - -// Extra unsafe method to help any manipulation -struct 
AnyUnsafe : public ObjectUnsafe { - // FFI related operations - TVM_FFI_INLINE static TVMFFIAny MoveAnyToTVMFFIAny(Any &&any) { - TVMFFIAny result = any.data_; - any.data_.type_index = TypeIndex::kTVMFFINone; - any.data_.zero_padding = 0; - any.data_.v_int64 = 0; - return result; - } - - TVM_FFI_INLINE static Any MoveTVMFFIAnyToAny(TVMFFIAny *data) { - Any any; - any.data_ = *data; - data->type_index = TypeIndex::kTVMFFINone; - data->zero_padding = 0; - data->v_int64 = 0; - return any; - } - - template - TVM_FFI_INLINE static bool CheckAnyStrict(const Any &ref) { - return TypeTraits::CheckAnyStrict(&(ref.data_)); - } - - template - TVM_FFI_INLINE static T CopyFromAnyViewAfterCheck(const Any &ref) { - if constexpr (!std::is_same_v) { - return TypeTraits::CopyFromAnyViewAfterCheck(&(ref.data_)); - } else { - return ref; - } - } - - template - TVM_FFI_INLINE static T MoveFromAnyAfterCheck(Any &&ref) { - if constexpr (!std::is_same_v) { - return TypeTraits::MoveFromAnyAfterCheck(&(ref.data_)); - } else { - return std::move(ref); - } - } - - TVM_FFI_INLINE static Object *ObjectPtrFromAnyAfterCheck(const Any &ref) { - return reinterpret_cast(ref.data_.v_obj); - } - - TVM_FFI_INLINE static const TVMFFIAny *TVMFFIAnyPtrFromAny(const Any &ref) { - return &(ref.data_); - } - - template - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const Any &ref) { - return TypeTraits::GetMismatchTypeInfo(&(ref.data_)); - } -}; -} // namespace details - -/*! \brief String-aware Any equal functor */ -struct AnyHash { - /*! - * \brief Calculate the hash code of an Any - * \param a The given Any - * \return Hash code of a, string hash for strings and pointer address otherwise. 
- */ - uint64_t operator()(const Any &src) const { - if (src.data_.type_index == TypeIndex::kTVMFFISmallStr) { - // for small string, we use the same type key hash as normal string - // so heap allocated string and on stack string will have the same hash - return details::StableHashCombine(TypeIndex::kTVMFFIStr, - details::StableHashSmallStrBytes(&src.data_)); - } else if (src.data_.type_index == TypeIndex::kTVMFFISmallBytes) { - // use byte the same type key as bytes - return details::StableHashCombine(TypeIndex::kTVMFFIBytes, - details::StableHashSmallStrBytes(&src.data_)); - } else if (src.data_.type_index == TypeIndex::kTVMFFIStr || src.data_.type_index == TypeIndex::kTVMFFIBytes) { - const details::BytesObjBase *src_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(src); - return details::StableHashCombine(src.data_.type_index, - details::StableHashBytes(src_str->data, src_str->size)); - } else { - return details::StableHashCombine(src.data_.type_index, src.data_.v_uint64); - } - } -}; - -/*! \brief String-aware Any hash functor */ -struct AnyEqual { - /*! - * \brief Check if the two Any are equal - * \param lhs left operand. - * \param rhs right operand - * \return String equality if both are strings, pointer address equality otherwise. 
- */ - bool operator()(const Any &lhs, const Any &rhs) const { - // header with type index - const int64_t *lhs_as_int64 = reinterpret_cast(&lhs.data_); - const int64_t *rhs_as_int64 = reinterpret_cast(&rhs.data_); - static_assert(sizeof(TVMFFIAny) == 16); - // fast path, check byte equality - if (lhs_as_int64[0] == rhs_as_int64[0] && lhs_as_int64[1] == rhs_as_int64[1]) { - return true; - } - // common false case type index match, in this case we only need to pay attention to string - // equality - if (lhs.data_.type_index == rhs.data_.type_index) { - // specialy handle string hash - if (lhs.data_.type_index == TypeIndex::kTVMFFIStr || lhs.data_.type_index == TypeIndex::kTVMFFIBytes) { - const details::BytesObjBase *lhs_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - const details::BytesObjBase *rhs_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size); - } - return false; - } else { - // type_index mismatch, if index is not string, return false - if (lhs.data_.type_index != kTVMFFIStr && lhs.data_.type_index != kTVMFFISmallStr && lhs.data_.type_index != kTVMFFISmallBytes && lhs.data_.type_index != kTVMFFIBytes) { - return false; - } - // small string and normal string comparison - if (lhs.data_.type_index == kTVMFFIStr && rhs.data_.type_index == kTVMFFISmallStr) { - const details::BytesObjBase *lhs_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - return Bytes::memequal(lhs_str->data, rhs.data_.v_bytes, lhs_str->size, - rhs.data_.small_str_len); - } - if (lhs.data_.type_index == kTVMFFISmallStr && rhs.data_.type_index == kTVMFFIStr) { - const details::BytesObjBase *rhs_str = details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs.data_.v_bytes, rhs_str->data, lhs.data_.small_str_len, - rhs_str->size); - } - if (lhs.data_.type_index == kTVMFFIBytes && rhs.data_.type_index == kTVMFFISmallBytes) { - const details::BytesObjBase 
*lhs_bytes = details::AnyUnsafe::CopyFromAnyViewAfterCheck(lhs); - return Bytes::memequal(lhs_bytes->data, rhs.data_.v_bytes, lhs_bytes->size, - rhs.data_.small_str_len); - } - if (lhs.data_.type_index == kTVMFFISmallBytes && rhs.data_.type_index == kTVMFFIBytes) { - const details::BytesObjBase *rhs_bytes = details::AnyUnsafe::CopyFromAnyViewAfterCheck(rhs); - return Bytes::memequal(lhs.data_.v_bytes, rhs_bytes->data, lhs.data_.small_str_len, - rhs_bytes->size); - } - return false; - } - } -}; -} // namespace ffi - -// Expose to the tvm namespace for usability -// Rationale: no ambiguity even in root -using tvm::ffi::Any; -using tvm::ffi::AnyView; - -} // namespace tvm -#endif // TVM_FFI_ANY_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/base_details.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/base_details.h deleted file mode 100644 index 147862117..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/base_details.h +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/base_details.h - * \brief Internal detail utils that can be used by files in tvm/ffi. 
- * \note details headers are for internal use only - * and not to be directly used by user. - */ -#ifndef TVM_FFI_BASE_DETAILS_H_ -#define TVM_FFI_BASE_DETAILS_H_ - -#include "c_api.h" -#include "endian.h" - -#include -#include - -#if defined(_MSC_VER) -#ifndef WIN32_LEAN_AND_MEAN -#define WIN32_LEAN_AND_MEAN -#endif - -#ifndef NOMINMAX -#define NOMINMAX -#endif - -#include -#if (defined(_M_ARM64) || defined(_ARM64_) || defined(_M_ARM64EC)) && !defined(_InlineInterlockedAdd64) -#define _InlineInterlockedAdd64 InterlockedAdd64 -#endif - -#ifdef ERROR -#undef ERROR -#endif - -#endif -/// \cond Doxygen_Suppress - -#if defined(_MSC_VER) -#define TVM_FFI_INLINE [[msvc::forceinline]] inline -#else -#define TVM_FFI_INLINE [[gnu::always_inline]] inline -#endif - -/*! - * \brief Macro helper to force a function not to be inlined. - * It is only used in places that we know not inlining is good, - * e.g. some logging functions. - */ -#if defined(_MSC_VER) -#define TVM_FFI_NO_INLINE [[msvc::noinline]] -#else -#define TVM_FFI_NO_INLINE [[gnu::noinline]] -#endif - -#if defined(_MSC_VER) -#define TVM_FFI_UNREACHABLE() __assume(false) -#else -#define TVM_FFI_UNREACHABLE() __builtin_unreachable() -#endif - -#define TVM_FFI_STR_CONCAT_(__x, __y) __x##__y -#define TVM_FFI_STR_CONCAT(__x, __y) TVM_FFI_STR_CONCAT_(__x, __y) - -#if defined(__GNUC__) || defined(__clang__) -#define TVM_FFI_FUNC_SIG __PRETTY_FUNCTION__ -#elif defined(_MSC_VER) -#define TVM_FFI_FUNC_SIG __FUNCSIG__ -#else -#define TVM_FFI_FUNC_SIG __func__ -#endif - -#if defined(__GNUC__) -// gcc and clang and attribute constructor -/// \cond Doxygen_Suppress -#define TVM_FFI_STATIC_INIT_BLOCK_DEF_(FnName) __attribute__((constructor)) static void FnName() -/// \endcond -/* - * \brief Macro that defines a block that will be called during static initialization. 
- * - * \code - * TVM_FFI_STATIC_INIT_BLOCK() { - * RegisterFunctions(); - * } - * \endcode - */ -#define TVM_FFI_STATIC_INIT_BLOCK() \ - TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__)) - -#else -/// \cond Doxygen_Suppress -// for other compilers, use the variable trick -#define TVM_FFI_STATIC_INIT_BLOCK_DEF_(FnName, RegVar) \ - static void FnName(); \ - [[maybe_unused]] static inline int RegVar = []() { \ - FnName(); \ - return 0; \ - }(); \ - static void FnName() - -#define TVM_FFI_STATIC_INIT_BLOCK() \ - TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__), \ - TVM_FFI_STR_CONCAT(__TVMFFIStaticInitReg, __COUNTER__)) -/// \endcond -#endif - -/* - * \brief Define the default copy/move constructor and assign operator - * \param TypeName The class typename. - */ -#define TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - TypeName(const TypeName &other) = default; /* NOLINT(bugprone-macro-parentheses) */ \ - TypeName(TypeName &&other) noexcept = default; /* NOLINT(bugprone-macro-parentheses) */ \ - TypeName &operator=(const TypeName &other) = default; /* NOLINT(bugprone-macro-parentheses) */ \ - TypeName &operator=(TypeName &&other) noexcept = default; /* NOLINT(bugprone-macro-parentheses)*/ - -/*! - * \brief marks the begining of a C call that logs exception - */ -#define TVM_FFI_LOG_EXCEPTION_CALL_BEGIN() \ - try { \ - (void)0 - -/*! - * \brief Marks the end of a C call that logs exception - */ -#define TVM_FFI_LOG_EXCEPTION_CALL_END(Name) \ - } \ - catch (const std::exception &err) { \ - std::cerr << "Exception caught during " << #Name << ":\n" \ - << err.what() << std::endl; \ - exit(-1); \ - } - -/*! - * \brief Clear the padding parts so we can safely use v_int64 for hash - * and equality check even when the value stored is a pointer. - * - * This macro is used to clear the padding parts for hash and equality check - * in 32bit platform. 
- */ -#define TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result) \ - if constexpr (sizeof(void *) != sizeof(int64_t)) { \ - (result)->v_int64 = 0; \ - } - -namespace tvm { -namespace ffi { -namespace details { - -// for each iterator -struct for_each_dispatcher { - template - static void run(std::index_sequence, const F &f, Args &&...args) { // NOLINT(*) - (f(I, std::forward(args)), ...); - } -}; - -template -void for_each(const F &f, Args &&...args) { // NOLINT(*) - for_each_dispatcher::run(std::index_sequence_for{}, f, std::forward(args)...); -} - -/*! - * \brief hash an object and combines uint64_t key with previous keys - * - * This hash function is stable across platforms. - * - * \param key The left operand. - * \param value The right operand. - * \return the combined result. - */ -template , bool> = true> -TVM_FFI_INLINE uint64_t StableHashCombine(uint64_t key, const T &value) { - // XXX: do not use std::hash in this function. This hash must be stable - // across different platforms and std::hash is implementation dependent. - return key ^ (uint64_t(value) + 0x9e3779b9 + (key << 6) + (key >> 2)); -} - -/*! - * \brief Hash the binary bytes - * \param data The data pointer - * \param size The size of the bytes. - * \return the hash value. 
- */ -TVM_FFI_INLINE uint64_t StableHashBytes(const void *data_ptr, size_t size) { - // NOLINTBEGIN(clang-analyzer-security.ArrayBound) - const char *data = reinterpret_cast(data_ptr); - const constexpr uint64_t kMultiplier = 1099511628211ULL; - const constexpr uint64_t kMod = 2147483647ULL; - union Union { - uint8_t a[8]; - uint64_t b; - } u; - static_assert(sizeof(Union) == sizeof(uint64_t), "sizeof(Union) != sizeof(uint64_t)"); - const char *it = data; - const char *end = it + size; - uint64_t result = 0; - if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { - // if alignment requirement is met, directly use load - if (reinterpret_cast(it) % 8 == 0) { - for (; it + 8 <= end; it += 8) { - u.b = *reinterpret_cast(it); - result = (result * kMultiplier + u.b) % kMod; - } - } else { - // unaligned version - for (; it + 8 <= end; it += 8) { - u.a[0] = it[0]; - u.a[1] = it[1]; - u.a[2] = it[2]; - u.a[3] = it[3]; - u.a[4] = it[4]; - u.a[5] = it[5]; - u.a[6] = it[6]; - u.a[7] = it[7]; - result = (result * kMultiplier + u.b) % kMod; - } - } - } else { - // need endian swap - for (; it + 8 <= end; it += 8) { - u.a[0] = it[7]; - u.a[1] = it[6]; - u.a[2] = it[5]; - u.a[3] = it[4]; - u.a[4] = it[3]; - u.a[5] = it[2]; - u.a[6] = it[1]; - u.a[7] = it[0]; - result = (result * kMultiplier + u.b) % kMod; - } - } - - if (it < end) { - u.b = 0; - uint8_t *a = u.a; - if (it + 4 <= end) { - a[0] = it[0]; - a[1] = it[1]; - a[2] = it[2]; - a[3] = it[3]; - it += 4; - a += 4; - } - if (it + 2 <= end) { - a[0] = it[0]; - a[1] = it[1]; - it += 2; - a += 2; - } - if (it + 1 <= end) { - a[0] = it[0]; - } - if constexpr (!TVM_FFI_IO_NO_ENDIAN_SWAP) { - std::swap(u.a[0], u.a[7]); - std::swap(u.a[1], u.a[6]); - std::swap(u.a[2], u.a[5]); - std::swap(u.a[3], u.a[4]); - } - result = (result * kMultiplier + u.b) % kMod; - } - // NOLINTEND(clang-analyzer-security.ArrayBound) - return result; -} - -/*! - * \brief Same as StableHashBytes, but for small string data. 
- * \param data The data pointer - * \return the hash value. - */ -TVM_FFI_INLINE uint64_t StableHashSmallStrBytes(const TVMFFIAny *data) { - if constexpr (TVM_FFI_IO_NO_ENDIAN_SWAP) { - // fast path, no endian swap, simply hash as uint64_t - const constexpr uint64_t kMod = 2147483647ULL; - return data->v_uint64 % kMod; - } - return StableHashBytes(reinterpret_cast(data), sizeof(data->v_uint64)); -} - -/*! - * \brief Helper to generate a JSON-based type schema for a given type. - * \tparam T The type to generate the schema for. Assuming `T` is not - * const-qualified or reference-qualified. - */ -template -struct TypeSchemaImpl; -/*! - * \brief Helper to generate a JSON-based type schema for a given type. - * \tparam T The type to generate the schema for. - * \note This type removes const and reference qualifiers from `T` before - * passing it to `TypeSchemaImpl`. - */ -template -using TypeSchema = TypeSchemaImpl>>; - -} // namespace details -} // namespace ffi -} // namespace tvm -/// \endcond -#endif // TVM_FFI_BASE_DETAILS_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/c_api.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/c_api.h deleted file mode 100644 index 4b721f66d..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/c_api.h +++ /dev/null @@ -1,1226 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -// NOLINTBEGIN(modernize-use-using,bugprone-reserved-identifier,modernize-deprecated-headers) -/* - * \file tvm/ffi/c_api.h - * \brief This file defines the C convention of the FFI convention - */ -#ifndef TVM_FFI_C_API_H_ -#define TVM_FFI_C_API_H_ - -#include "../../dlpack/dlpack.h" -#include - -// Macros to do weak linking -#ifdef _MSC_VER -#define TVM_FFI_WEAK __declspec(selectany) -#else -#define TVM_FFI_WEAK __attribute__((weak)) -#endif - -// Defines two macros -// TVM_FFI_DLL: marks the function as a DLL export/import -// depending on whether TVM_FFI_EXPORTS is defined -// TVM_FFI_DLL_EXPORT: always marks the function as a DLL export -#if !defined(TVM_FFI_DLL) && defined(__EMSCRIPTEN__) -#include -#define TVM_FFI_DLL EMSCRIPTEN_KEEPALIVE -#define TVM_FFI_DLL_EXPORT EMSCRIPTEN_KEEPALIVE -#endif -#if !defined(TVM_FFI_DLL) && defined(_MSC_VER) -#ifdef TVM_FFI_EXPORTS -#define TVM_FFI_DLL __declspec(dllexport) -#else -#define TVM_FFI_DLL __declspec(dllimport) -#endif -#define TVM_FFI_DLL_EXPORT __declspec(dllexport) -#endif -#ifndef TVM_FFI_DLL -#define TVM_FFI_DLL __attribute__((visibility("default"))) -#define TVM_FFI_DLL_EXPORT __attribute__((visibility("default"))) -#endif - -// NOLINTBEGIN(modernize-macro-to-enum) -/*! \brief TVM FFI major version. */ -#define TVM_FFI_VERSION_MAJOR 0 -/*! \brief TVM FFI minor version. */ -#define TVM_FFI_VERSION_MINOR 1 -/*! \brief TVM FFI patch version. */ -#define TVM_FFI_VERSION_PATCH 4 -// NOLINTEND(modernize-macro-to-enum) - -#ifdef __cplusplus -extern "C" { -#endif - -/*! 
- * \brief TVM FFI version. - */ -typedef struct { - /*! \brief TVM FFI major version. */ - uint32_t major; - /*! \brief TVM FFI minor version. */ - uint32_t minor; - /*! \brief TVM FFI patch version. */ - uint32_t patch; -} TVMFFIVersion; - -#ifdef __cplusplus -enum TVMFFITypeIndex : int32_t { -#else -typedef enum { -#endif - /* - * \brief The root type of all FFI objects. - * - * We include it so TypeIndex captures all possible runtime values. - * `kTVMFFIAny` code will never appear in Any::type_index. - * However, it may appear in field annotations during reflection. - */ - kTVMFFIAny = -1, - // [Section] On-stack POD and special types: [0, kTVMFFIStaticObjectBegin) - // N.B. `kTVMFFIRawStr` is a string backed by a `\0`-terminated char array, - // which is not owned by TVMFFIAny. It is required that the following - // invariant holds: - // - `Any::type_index` is never `kTVMFFIRawStr` - // - `AnyView::type_index` can be `kTVMFFIRawStr` - // - /*! \brief None/nullptr value */ - kTVMFFINone = 0, - /*! \brief POD int value */ - kTVMFFIInt = 1, - /*! \brief POD bool value */ - kTVMFFIBool = 2, - /*! \brief POD float value */ - kTVMFFIFloat = 3, - /*! \brief Opaque pointer object */ - kTVMFFIOpaquePtr = 4, - /*! \brief DLDataType */ - kTVMFFIDataType = 5, - /*! \brief DLDevice */ - kTVMFFIDevice = 6, - /*! \brief DLTensor* */ - kTVMFFIDLTensorPtr = 7, - /*! \brief const char* */ - kTVMFFIRawStr = 8, - /*! \brief TVMFFIByteArray* */ - kTVMFFIByteArrayPtr = 9, - /*! \brief R-value reference to ObjectRef */ - kTVMFFIObjectRValueRef = 10, - /*! \brief Small string on stack */ - kTVMFFISmallStr = 11, - /*! \brief Small bytes on stack */ - kTVMFFISmallBytes = 12, - /*! \brief Start of statically defined objects. */ - kTVMFFIStaticObjectBegin = 64, - /*! - * \brief Object, all objects starts with TVMFFIObject as its header. - * \note We will also add other fields - */ - kTVMFFIObject = 64, - /*! - * \brief String object, layout = { TVMFFIObject, TVMFFIByteArray, ... 
} - */ - kTVMFFIStr = 65, - /*! - * \brief Bytes object, layout = { TVMFFIObject, TVMFFIByteArray, ... } - */ - kTVMFFIBytes = 66, - /*! \brief Error object. */ - kTVMFFIError = 67, - /*! \brief Function object. */ - kTVMFFIFunction = 68, - /*! - * \brief Shape object, layout = { TVMFFIObject, { const int64_t*, size_t }, ... } - */ - kTVMFFIShape = 69, - /*! - * \brief Tensor object, layout = { TVMFFIObject, DLTensor, ... } - */ - kTVMFFITensor = 70, - /*! \brief Array object. */ - kTVMFFIArray = 71, - /*! \brief Map object. */ - kTVMFFIMap = 72, - /*! \brief Runtime dynamic loaded module object. */ - kTVMFFIModule = 73, - /*! - * \brief Opaque python object. - * - * This is a special type index to indicate we are storing an opaque PyObject. - * Such object may interact with callback functions that are registered to support - * python-related operations. - * - * We only translate the objects that we do not recognize into this type index. - * - * \sa TVMFFIObjectCreateOpaque - */ - kTVMFFIOpaquePyObject = 74, - //---------------------------------------------------------------- - // more complex objects - //---------------------------------------------------------------- - kTVMFFIStaticObjectEnd, - // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) - /*! \brief Start of type indices that are allocated at runtime. */ - kTVMFFIDynObjectBegin = 128 -#ifdef __cplusplus -}; -#else -} TVMFFITypeIndex; -#endif - -/*! \brief Handle to Object from C API's pov */ -typedef void *TVMFFIObjectHandle; - -/*! - * \brief bitmask of the object deleter flag. - */ -#ifdef __cplusplus -enum TVMFFIObjectDeleterFlagBitMask : int32_t { -#else -typedef enum { -#endif - /*! - * \brief deleter action when strong reference count becomes zero. - * Need to call destructor of the object but not free the memory block. - */ - kTVMFFIObjectDeleterFlagBitMaskStrong = 1 << 0, - /*! - * \brief deleter action when weak reference count becomes zero. - * Need to free the memory block. 
- */ - kTVMFFIObjectDeleterFlagBitMaskWeak = 1 << 1, - /*! - * \brief deleter action when both strong and weak reference counts become zero. - * \note This is the most common case. - */ - kTVMFFIObjectDeleterFlagBitMaskBoth = (kTVMFFIObjectDeleterFlagBitMaskStrong | kTVMFFIObjectDeleterFlagBitMaskWeak), -#ifdef __cplusplus -}; -#else -} TVMFFIObjectDeleterFlagBitMask; -#endif - -/*! - * \brief C-based type of all FFI object header that allocates on heap. - */ -typedef struct { - /*! - * \brief Combined strong and weak reference counter of the object. - * - * Strong ref counter is packed into the lower 32 bits. - * Weak ref counter is packed into the upper 32 bits. - * - * It is equivalent to { uint32_t strong_ref_count, uint32_t weak_ref_count } - * in little-endian structure: - * - * - strong_ref_count: `combined_ref_count & 0xFFFFFFFF` - * - weak_ref_count: `(combined_ref_count >> 32) & 0xFFFFFFFF` - * - * Rationale: atomic ops on strong ref counter remains the same as +1/-1, - * this combined ref counter allows us to use u64 atomic once - * instead of a separate atomic read of weak counter during deletion. - * - * The ref counter goes first to align ABI with most intrusive ptr designs. - * It is also likely more efficient as rc operations can be quite common. - */ - uint64_t combined_ref_count; - /*! - * \brief type index of the object. - * \note The type index of Object and Any are shared in FFI. - */ - int32_t type_index; - /*! \brief Extra padding to ensure 8 bytes alignment. */ - uint32_t __padding; -#if !defined(TVM_FFI_DOXYGEN_MODE) - union { -#endif - /*! - * \brief Deleter to be invoked when strong reference counter goes to zero. - * \param self The self object handle. - * \param flags The flags to indicate deletion behavior. - * \sa TVMFFIObjectDeleterFlagBitMask - */ - void (*deleter)(void *self, int flags); - /*! - * \brief auxilary field to TVMFFIObject is always 8 bytes aligned. - * \note This helps us to ensure cross platform compatibility. 
- */ - int64_t __ensure_align; -#if !defined(TVM_FFI_DOXYGEN_MODE) - }; -#endif -} TVMFFIObject; - -/*! - * \brief C-based type of all on stack Any value. - * - * Any value can hold on stack values like int, - * as well as reference counted pointers to object. - */ -typedef struct { - /*! - * \brief type index of the object. - * \note The type index of Object and Any are shared in FFI. - */ - int32_t type_index; -#if !defined(TVM_FFI_DOXYGEN_MODE) - union { // 4 bytes -#endif - /*! \brief padding, must set to zero for values other than small string. */ - uint32_t zero_padding; - /*! - * \brief Length of small string, with a max value of 7. - * - * We keep small str to start at next 4 bytes to ensure alignment - * when accessing the small str content. - */ - uint32_t small_str_len; -#if !defined(TVM_FFI_DOXYGEN_MODE) - }; -#endif -#if !defined(TVM_FFI_DOXYGEN_MODE) - union { // 8 bytes -#endif - /*! \brief integers */ - int64_t v_int64; - /*! \brief floating-point numbers */ - double v_float64; - /*! \brief typeless pointers */ - void *v_ptr; - /*! \brief raw C-string */ - const char *v_c_str; - /*! \brief ref counted objects */ - TVMFFIObject *v_obj; - /*! \brief data type */ - DLDataType v_dtype; - /*! \brief device */ - DLDevice v_device; - /*! \brief small string */ - char v_bytes[8]; - /*! \brief uint64 repr mainly used for hashing */ - uint64_t v_uint64; -#if !defined(TVM_FFI_DOXYGEN_MODE) - }; -#endif -} TVMFFIAny; - -/*! - * \brief Byte array data structure used by String and Bytes. - * - * String and Bytes object layout = { TVMFFIObject, TVMFFIByteArray, ... } - * - * \note This byte array data structure layout differs in 32/64 bit platforms. - * as size_t equals to the size of the pointer, use this convetion to - * be consistent with std::string and also avoid need to calculate padding - * for the size field on 32-bit platforms. - * The FFI binding should be careful when treating this ABI. - */ -typedef struct { - /*! \brief The data pointer. 
*/ - const char *data; - /*! \brief The size of the data. */ - size_t size; -} TVMFFIByteArray; - -/*! - * \brief Shape cell used in shape object following header. - */ -typedef struct { - /*! \brief The data pointer. */ - const int64_t *data; - /*! \brief The size of the data. */ - size_t size; -} TVMFFIShapeCell; - -/*! - * \brief Mode to update the backtrace of the error. - */ -#ifdef __cplusplus -enum TVMFFIBacktraceUpdateMode : int32_t { -#else -typedef enum { -#endif - kTVMFFIBacktraceUpdateModeReplace = 0, - kTVMFFIBacktraceUpdateModeAppend = 1, -#ifdef __cplusplus -}; -#else -} TVMFFIBacktraceUpdateMode; -#endif - -/*! - * \brief Error cell used in error object following header. - */ -typedef struct { - /*! \brief The kind of the error. */ - TVMFFIByteArray kind; - /*! \brief The message of the error. */ - TVMFFIByteArray message; - /*! - * \brief The backtrace of the error. - * - * The backtrace is in the order of recent call first from the top of the stack - * to the bottom of the stack. This order makes it helpful for appending - * the extra backtrace to the end as we go up when error is propagated. - * - * When printing out, we encourage reverse the order of lines to make it - * align with python style. - */ - TVMFFIByteArray backtrace; - /*! - * \brief Function handle to update the backtrace of the error. - * \param self The self object handle. - * \param backtrace The backtrace to update. - * \param update_mode The mode to update the backtrace, - * can be either kTVMFFIBacktraceUpdateModeReplace, kTVMFFIBacktraceUpdateModeAppend. - */ - void (*update_backtrace)(TVMFFIObjectHandle self, const TVMFFIByteArray *backtrace, - int32_t update_mode); -} TVMFFIErrorCell; - -/*! - * \brief Type that defines C-style safe call convention - * - * Safe call explicitly catches exception on function boundary. - * - * \param handle The function handle - * \param num_args Number of input arguments - * \param args The input arguments to the call. 
- * \param result Store output result. - * - * IMPORTANT: caller must initialize result->type_index to be kTVMFFINone, - * or any other value smaller than kTVMFFIStaticObjectBegin. - * - * \return The call returns 0 if call is successful. - * It returns non-zero value if there is an error. - * - * Possible return error of the API functions: - * * 0: success - * * -1: error happens, can be retrieved by TVMFFIErrorMoveFromRaised - * * -2: a frontend error occurred and recorded in the frontend. - * - * \note We decided to leverage TVMFFIErrorMoveFromRaised and TVMFFIErrorSetRaised - * for C function error propagation. This design choice, while - * introducing a dependency for TLS runtime, simplifies error - * propgation in chains of calls in compiler codegen. - * As we do not need to propagate error through argument but simply - * set them in the runtime environment. - * - * \sa TVMFFIErrorMoveFromRaised - * \sa TVMFFIErrorSetRaised - * \sa TVMFFIErrorSetRaisedFromCStr - * \sa TVMFFIErrorSetRaisedFromCStrParts - */ -typedef int (*TVMFFISafeCallType)(void *handle, const TVMFFIAny *args, int32_t num_args, - TVMFFIAny *result); - -/*! - * \brief Object cell for function object following header. - */ -typedef struct { - /*! \brief A C API compatible call with exception catching. */ - TVMFFISafeCallType safe_call; - /*! - * \brief A function pointer to an underlying cpp call. - * - * The signature is the same as TVMFFISafeCallType except the return type is void, - * and the function throws exception directly instead of returning error code. - * We use void* here to avoid depending on c++ compiler. - * - * This pointer should be set to NULL for functions that are not originally created in cpp. - * - * \note The caller must assume the same cpp exception catching abi when using this pointer. - * When used across FFI boundaries, always use safe_call. - */ - void *cpp_call; -} TVMFFIFunctionCell; - -/*! - * \brief Object cell for opaque object following header. 
- */ -typedef struct { - /*! \brief The handle of the opaque object, for python it is PyObject* */ - void *handle; -} TVMFFIOpaqueObjectCell; - -//----------------------------------------------------------------------- -// Section: Version API -//----------------------------------------------------------------------- -/*! - * \brief Get the TVM FFI version from the current C ABI. - * - * This function is always stable across all versions of the C ABI. - * - * \param out_version The output version. - */ -TVM_FFI_DLL void TVMFFIGetVersion(TVMFFIVersion *out_version); - -//------------------------------------------------------------ -// Section: Basic object API -//------------------------------------------------------------ -/*! - * \brief Increase the strong reference count of an object handle - * \param obj The object handle. - * \note Internally we increase the reference counter of the object. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIObjectIncRef(TVMFFIObjectHandle obj); - -/*! - * \brief Free an object handle by decreasing strong reference - * \param obj The object handle. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIObjectDecRef(TVMFFIObjectHandle obj); - -/*! - * \brief Create an Opaque object by passing in handle, type_index and deleter. - * - * The opaque object's lifetime is managed as an Object, so it can be retained - * and released like other objects. - * When the opaque object is kTVMFFIOpaquePyObject, it can be converted back to - * the python type when returned or passed as arguments to a python function. - * - * We can support ffi::Function that interacts with these objects, - * most likely callback registered from python. - * - * For language bindings, we only convert types that we do not recognize into this type. - * On the C++ side, the most common way to represent such OpaqueObject is to simply - * use ffi::ObjectRef or ffi::Any. 
- * - * \param handle The resource handle of the opaque object. - * \param type_index The type index of the object. - * \param deleter deleter to recycle - * \param out The output of the opaque object. - * \return 0 when success, nonzero when failure happens - * - * \note The caller must ensure the type_index is a valid opaque object type index. - * \sa kTVMFFIOpaquePyObject - */ -TVM_FFI_DLL int TVMFFIObjectCreateOpaque(void *handle, int32_t type_index, - void (*deleter)(void *handle), TVMFFIObjectHandle *out); - -/*! - * \brief Convert type key to type index. - * \param type_key The key of the type. - * \param out_tindex the corresponding type index. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFITypeKeyToIndex(const TVMFFIByteArray *type_key, int32_t *out_tindex); - -//----------------------------------------------------------------------- -// Section: Basic function calling API for function implementation -//----------------------------------------------------------------------- -/*! - * \brief Create a FFIFunc by passing in callbacks from a C callback. - * The registered function can then be retrieved by the backend using its name. - * \param self The resource handle of the C callback. - * \param safe_call The C callback implementation. - * \param deleter The deleter to recycle. - * \param out The output of the function. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIFunctionCreate(void *self, TVMFFISafeCallType safe_call, - void (*deleter)(void *self), TVMFFIObjectHandle *out); - -/*! - * \brief Get a global function registered in the system. - * \param name The name of the function. - * \param out The result function pointer, NULL if it does not exist. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIFunctionGetGlobal(const TVMFFIByteArray *name, TVMFFIObjectHandle *out); - -/*! - * \brief Convert an AnyView to an owned Any. - * \param any_view The AnyView to convert. 
- * \param out The output Any, must be an empty object. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIAnyViewToOwnedAny(const TVMFFIAny *any_view, TVMFFIAny *out); - -/*! - * \brief Call a FFIFunc by passing in arguments. - * \param func The resource handle of the C callback. - * \param args The input arguments to the call. - * \param num_args The number of input arguments. - * \param result The output result, caller must ensure result->type_index is set to kTVMFFINone. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIFunctionCall(TVMFFIObjectHandle func, TVMFFIAny *args, int32_t num_args, - TVMFFIAny *result); - -/*! - * \brief Move the last error from the environment to the result. - * \param result The result error. - * \note This function clears the error stored in the TLS. - */ -TVM_FFI_DLL void TVMFFIErrorMoveFromRaised(TVMFFIObjectHandle *result); - -/*! - * \brief Set a raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. - * \param error The error object handle - */ -TVM_FFI_DLL void TVMFFIErrorSetRaised(TVMFFIObjectHandle error); - -/*! - * \brief Set a raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. - * \param kind The kind of the error. - * \param message The error message. - * \note This is a convenient method for the C API side to set an error directly from a string. - */ -TVM_FFI_DLL void TVMFFIErrorSetRaisedFromCStr(const char *kind, const char *message); - -/*! - * \brief Set a raised error in TLS, which can be fetched by TVMFFIErrorMoveFromRaised. - * - * Rationale: This function can be used by compilers to create error messages by - * concatenating multiple parts of the error message, which can reduce the - * storage size for common parts such as function signatures. 
- * - * For example, the following are possible error messages from a kernel DSL - * - * - Argument 1 mismatch in `matmul(x: Tensor, y: Tensor, z: Tensor)`, dtype mismatch - * - Argument 2 mismatch in `matmul(x: Tensor, y: Tensor, z: Tensor)`, shape[0] mismatch - * - Argument 2 mismatch in `matmul(x: Tensor, y: Tensor, z: Tensor)`, shape[1] mismatch - * - * Storing each part of the error message as a separate global string can cause quite - * a bit of duplication, especially considering the kinds of error reports we may have. - * Instead, compilers can store error messages in parts, where items like - * `matmul(x: Tensor, y: Tensor, z: Tensor)` can be reused across multiple error messages. - * This API simplifies error reporting for such cases. - * - * \param kind The kind of the error. - * \param message_parts The error message parts, each part can be NULL and will be skipped. - * \param num_parts The number of error message parts. - */ -TVM_FFI_DLL void TVMFFIErrorSetRaisedFromCStrParts(const char *kind, const char **message_parts, - int32_t num_parts); - -/*! - * \brief Create an initial error object. - * \param kind The kind of the error. - * \param message The error message. - * \param backtrace The backtrace of the error. - * \param out The output error object handle. - * \return 0 on success, nonzero on failure(likely MemoryError) - * - * \note This function is different from other functions as it is used in the error handling loop. - * So we do not follow normal error handling patterns. When error happens it will not set - * the error in TLS (since TLS error setting also involves creating an Error object). - * Instead, caller should simply report MemoryError to the logger. 
- */ -TVM_FFI_DLL int TVMFFIErrorCreate(const TVMFFIByteArray *kind, const TVMFFIByteArray *message, - const TVMFFIByteArray *backtrace, TVMFFIObjectHandle *out); - -//------------------------------------------------------------ -// Section: DLPack support APIs -//------------------------------------------------------------ -/*! - * \brief Produce a managed Tensor from a DLPack tensor. - * \param from The source DLPack tensor. - * \param require_alignment The minimum alignment required of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \param out The output Tensor handle. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITensorFromDLPack(DLManagedTensor *from, int32_t require_alignment, - int32_t require_contiguous, TVMFFIObjectHandle *out); - -/*! - * \brief Produce a DLManagedTensor from the array that shares data memory with the array. - * \param from The source array. - * \param out The DLManagedTensor handle. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITensorToDLPack(TVMFFIObjectHandle from, DLManagedTensor **out); - -/*! - * \brief Produce a managed Tensor from a DLPack tensor. - * \param from The source DLPack tensor. - * \param require_alignment The minimum alignment required of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \param out The output Tensor handle. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned *from, - int32_t require_alignment, - int32_t require_contiguous, - TVMFFIObjectHandle *out); - -/*! - * \brief Produce a DLManagedTensor from the array that shares data memory with the array. - * \param from The source array. - * \param out The DLManagedTensor handle. - * \return 0 on success, nonzero on failure. 
- */ -TVM_FFI_DLL int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle from, - DLManagedTensorVersioned **out); -//--------------------------------------------------------------- -// Section: string/bytes support APIs. -// These APIs are used to simplify the string/bytes construction -//--------------------------------------------------------------- -/*! - * \brief Reinterpret the content of TVMFFIByteArray to String. - * \param input The TVMFFIByteArray to convert. - * \param out The output String owned by the caller, maybe a SmallStr or a Str object. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIStringFromByteArray(const TVMFFIByteArray *input, TVMFFIAny *out); - -/*! - * \brief Reinterpret the content of TVMFFIByteArray to Bytes. - * \param input The TVMFFIByteArray to convert. - * \param out The output Bytes owned by the caller, maybe a SmallBytes or a Bytes object. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIBytesFromByteArray(const TVMFFIByteArray *input, TVMFFIAny *out); - -//--------------------------------------------------------------- -// Section: dtype string support APIs. -// These APIs are used to simplify the dtype printings during FFI -//--------------------------------------------------------------- - -/*! - * \brief Convert a string to a DLDataType. - * \param str The string to convert. - * \param out The output DLDataType. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIDataTypeFromString(const TVMFFIByteArray *str, DLDataType *out); - -/*! -* \brief Convert a DLDataType to a string. -* \param dtype The DLDataType to convert. -* \param out The output string. -* \return 0 on success, nonzero on failure. -* \note out is a String object that needs to be freed by the caller via TVMFFIObjectDecRef. -The content of string can be accessed via TVMFFIObjectGetByteArrayPtr. - -* \note The input dtype is a pointer to the DLDataType to avoid ABI compatibility issues. 
-*/ -TVM_FFI_DLL int TVMFFIDataTypeToString(const DLDataType *dtype, TVMFFIAny *out); - -//------------------------------------------------------------ -// Section: Type reflection support APIs -// -// The reflec -//------------------------------------------------------------ -/*! - * \brief Getter that can take the address of a field and set the result. - * \param field The raw address of the field. - * \param result Stores the result. - * \return 0 on success, nonzero on failure. - */ -typedef int (*TVMFFIFieldGetter)(void *field, TVMFFIAny *result); - -/*! - * \brief Getter that can take the address of a field and set it to a value. - * \param field The raw address of the field. - * \param value The value to set. - * \return 0 on success, nonzero on failure. - */ -typedef int (*TVMFFIFieldSetter)(void *field, const TVMFFIAny *value); - -/*! - * \brief Function that creates a new instance of the type. - * \param result The new object handle - * \return 0 on success, nonzero on failure. - */ -typedef int (*TVMFFIObjectCreator)(TVMFFIObjectHandle *result); - -/*! - * \brief bitmask of the field. - */ -#ifdef __cplusplus -enum TVMFFIFieldFlagBitMask : int32_t { -#else -typedef enum { -#endif - /*! \brief The field is writable. */ - kTVMFFIFieldFlagBitMaskWritable = 1 << 0, - /*! \brief The field has default value. */ - kTVMFFIFieldFlagBitMaskHasDefault = 1 << 1, - /*! \brief The field is a static method. */ - kTVMFFIFieldFlagBitMaskIsStaticMethod = 1 << 2, - /*! - * \brief The field should be ignored when performing structural eq/hash - * - * This is an optional meta-data for structural eq/hash. - */ - kTVMFFIFieldFlagBitMaskSEqHashIgnore = 1 << 3, - /*! - * \brief The field enters a def region where var can be defined/matched. - * - * This is an optional meta-data for structural eq/hash. - */ - kTVMFFIFieldFlagBitMaskSEqHashDef = 1 << 4, -#ifdef __cplusplus -}; -#else -} TVMFFIFieldFlagBitMask; -#endif - -/*! - * \brief Optional meta-data for structural eq/hash. 
- * - * This meta-data is only useful when we want to leverage the information - * to perform richer semantics aware structural comparison and hash. - * It can be safely ignored if such information is not needed. - * - * The meta-data record comparison method in tree node and DAG node. - * - * \code - * x = VarNode() - * v0 = AddNode(x, 1) - * v1 = AddNode(x, 1) - * v2 = AddNode(v0, v0) - * v3 = AddNode(v1, v0) - * \endcode - * - * Consider the construct sequence of AddNode below, - * if AddNode is treated as a tree node, then v2 and v3 - * structural equals to each other, but if AddNode is - * treated as a DAG node, then v2 and v3 does not - * structural equals to each other. - */ -#ifdef __cplusplus -enum TVMFFISEqHashKind : int32_t { -#else -typedef enum { -#endif - /*! \brief Do not support structural eq/hash. */ - kTVMFFISEqHashKindUnsupported = 0, - /*! - * \brief The object be compared as a tree node. - */ - kTVMFFISEqHashKindTreeNode = 1, - /*! - * \brief The object is treated as a free variable that can be mapped - * to another free variable in the definition region. - */ - kTVMFFISEqHashKindFreeVar = 2, - /*! - * \brief The field should be compared as a DAG node. - */ - kTVMFFISEqHashKindDAGNode = 3, - /*! - * \brief The object is treated as a constant tree node. - * - * Same as tree node, but the object does not contain free var - * as any of its nested children. - * - * That means we can use pointer equality for equality. - */ - kTVMFFISEqHashKindConstTreeNode = 4, - /*! - * \brief One can simply use pointer equality for equality. - * - * This is useful for "singleton"-style object that can - * is only an unique copy of each value. - */ - kTVMFFISEqHashKindUniqueInstance = 5, -#ifdef __cplusplus -}; -#else -} TVMFFISEqHashKind; -#endif - -/*! - * \brief Information support for optional object reflection. - */ -typedef struct { - /*! \brief The name of the field. */ - TVMFFIByteArray name; - /*! \brief The docstring about the field. 
*/ - TVMFFIByteArray doc; - /*! \brief The structured metadata of the field in JSON string. */ - TVMFFIByteArray metadata; - /*! - * \brief bitmask flags of the field. - */ - int64_t flags; - /*! \brief The size of the field. */ - int64_t size; - /*! \brief The alignment of the field. */ - int64_t alignment; - /*! \brief The offset of the field. */ - int64_t offset; - /*! \brief The getter to access the field. */ - TVMFFIFieldGetter getter; - /*! - * \brief The setter to access the field. - * \note The setter is set even if the field is readonly for serialization. - */ - TVMFFIFieldSetter setter; - /*! - * \brief The default value of the field, this field hold AnyView, - * valid when flags set kTVMFFIFieldFlagBitMaskHasDefault - */ - TVMFFIAny default_value; - /*! - * \brief Records the static type kind of the field. - * - * Possible values: - * - * - TVMFFITypeIndex::kTVMFFIObject for general objects. - * The value is nullable when kTVMFFIObject is chosen. - * - Static object type kinds such as Map, Dict, String - * - POD type index, note it does not give information about storage size of the field. - * - TVMFFITypeIndex::kTVMFFIAny if we don't have specialized info - * about the field. - * - * When the value is a type index of Object type, the field is storaged as an ObjectRef. - * - * \note This information maybe helpful in designing serializer. - * As it helps to narrow down the field type so we don't have to - * print type_key for cases like POD types. - * It also helps to provide opportunities to enable short-cut getter to ObjectRef fields. - */ - int32_t field_static_type_index; -} TVMFFIFieldInfo; - -/*! - * \brief Method information that can appear in reflection table. - */ -typedef struct { - /*! \brief The name of the field. */ - TVMFFIByteArray name; - /*! \brief The docstring about the method. 
*/ - TVMFFIByteArray doc; - // Rationale: We separate the docstring from the metadata since docstrings - // can be unstructured and sometimes large, while metadata can be focused - // on storing structured information. - /*! \brief Optional structured metadata of the method in JSON string. */ - TVMFFIByteArray metadata; - /*! \brief bitmask flags of the method. */ - int64_t flags; - /*! - * \brief The method wrapped as ffi::Function, stored as AnyView. - * \note The first argument to the method is always the self for instance methods. - */ - TVMFFIAny method; -} TVMFFIMethodInfo; - -/*! - * \brief Extra information of object type that can be used for reflection. - * - * \note This information is optional and can be used to enable reflection based - * creation of the object. - */ -typedef struct { - /*! \brief The docstring about the object. */ - TVMFFIByteArray doc; - /*! - * \brief An optional function that can create a new empty instance of the type. - * - * When known_fixed_size is non-zero, creator can be called - * with nullptr passed to optional_bytes. - * - * \note Caller must call setter for each field to initialize the object for - * the final object to be in valid state. - * - * \note This field is optional to enable reflection based creation. - */ - TVMFFIObjectCreator creator; - /*! - * \brief Total size of the object struct, if it is fixed and known. - * - * This field is set optional and set to 0 if not registered. - */ - int32_t total_size; - /*! - * \brief Optional meta-data for structural eq/hash. - */ - TVMFFISEqHashKind structural_eq_hash_kind; -} TVMFFITypeMetadata; - -/*! - * \brief Column array that stores extra attributes about types - * - * The attributes stored in a column array that can be looked up by type index. - * Note that the TypeAttr behaves like type_traits so column[T] so not contain - * attributes from base classes. - * - * \note - * \sa TVMFFIRegisterTypeAttr - */ -typedef struct { - /*! \brief The data of the column. 
*/ - const TVMFFIAny *data; - /*! \brief The size of the column. */ - size_t size; -} TVMFFITypeAttrColumn; - -/*! - * \brief Runtime type information for object type checking. - */ -#ifdef __cplusplus -struct TVMFFITypeInfo { -#else -typedef struct TVMFFITypeInfo { -#endif - /*! - *\brief The runtime type index, - * It can be allocated during runtime if the type is dynamic. - */ - int32_t type_index; - /*! \brief number of parent types in the type hierachy. */ - int32_t type_depth; - /*! \brief the unique type key to identify the type. */ - TVMFFIByteArray type_key; - /*! - * \brief type_ancestors[depth] stores the type_index of the acenstors at depth level - * \note To keep things simple, we do not allow multiple inheritance so the - * hieracy stays as a tree - */ - const struct TVMFFITypeInfo **type_ancestors; - // The following fields are used for reflection - /*! \brief Cached hash value of the type key, used for consistent structural hashing. */ - uint64_t type_key_hash; - /*! \brief number of reflection accessible fields. */ - int32_t num_fields; - /*! \brief number of reflection acccesible methods. */ - int32_t num_methods; - /*! \brief The reflection field information. */ - const TVMFFIFieldInfo *fields; - /*! \brief The reflection method. */ - const TVMFFIMethodInfo *methods; - /*! \brief The extra information of the type. */ - const TVMFFITypeMetadata *metadata; -#ifdef __cplusplus -}; -#else -} TVMFFITypeInfo; -#endif - -/*! - * \brief Register the function to runtime's global table. - * The registered function can then be retrieved by the backend using its name. - * \param name The name of the function. - * \param f The function to be registered. - * \param allow_override Whether to allow overriding an already registered function. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIFunctionSetGlobal(const TVMFFIByteArray *name, TVMFFIObjectHandle f, - int allow_override); - -/*! 
- * \brief Register the function to runtime's global table with method info. - * This is the same as TVMFFIFunctionSetGlobal but with method info that can provide extra - * metadata used in the runtime. - * \param method_info The method info to be registered. - * \param allow_override Whether to allow overriding an already registered function. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFIFunctionSetGlobalFromMethodInfo(const TVMFFIMethodInfo *method_info, - int allow_override); - -/*! - * \brief Register type field information for runtime reflection. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITypeRegisterField(int32_t type_index, const TVMFFIFieldInfo *info); - -/*! - * \brief Register type method information for runtime reflection. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITypeRegisterMethod(int32_t type_index, const TVMFFIMethodInfo *info); - -/*! - * \brief Register type creator information for runtime reflection. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITypeRegisterMetadata(int32_t type_index, const TVMFFITypeMetadata *metadata); - -/*! - * \brief Register extra type attributes that can be looked up during runtime. - * \return 0 on success, nonzero on failure. - */ -TVM_FFI_DLL int TVMFFITypeRegisterAttr(int32_t type_index, const TVMFFIByteArray *attr_name, - const TVMFFIAny *attr_value); - -/*! - * \brief Get the type attribute column by name. - * \return The pointer to the type attribute column. - * \return NULL if the attribute was not registered in the system. 
- */ -TVM_FFI_DLL const TVMFFITypeAttrColumn *TVMFFIGetTypeAttrColumn(const TVMFFIByteArray *attr_name); - -//------------------------------------------------------------ -// Section: Backend noexcept functions for internal use -// -// These functions are used internally and do not throw error -// instead the error will be logged and abort the process -// These are function are being called in startup or exit time -// so exception handling do not apply -//------------------------------------------------------------ -/*! - * \brief Get stack backtrace in a string. - * - * The backtrace is in the order of recent call first from the top of the stack - * to the bottom of the stack. This order makes it helpful for appending - * the extra backtrace as we unwind the stack. - * - * When printing out, we encourage reverse the order of lines to make it - * align with python style. - * - * \param filename The current file name. - * \param lineno The current line number - * \param func The current function - * \param cross_ffi_boundary Whether the backtrace is crossing the ffi boundary - * or we should stop at the ffi boundary when detected - * \return The backtrace string - * - * \note filename/func can be nullptr, then this info is skipped, they are useful - * for cases when debug symbols are not available. - */ -TVM_FFI_DLL const TVMFFIByteArray *TVMFFIBacktrace(const char *filename, int lineno, - const char *func, int cross_ffi_boundary); - -/*! - * \brief Initialize the type info during runtime. - * - * When the function is first called for a type, - * it will register the type to the type table in the runtime. - * If the static_tindex is non-negative, the function will - * allocate a runtime type index. - * Otherwise, we will populate the type table and return the static index. - * - * \param type_key The type key. - * \param type_depth The type depth. 
- * \param static_type_index Static type index if any, can be -1, which means this is a dynamic index - * \param num_child_slots Number of slots reserved for its children. - * \param child_slots_can_overflow Whether to allow child to overflow the slots. - * \param parent_type_index Parent type index, pass in -1 if it is root. - * - * \return The allocated type index. - */ -TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray *type_key, - int32_t static_type_index, int32_t type_depth, - int32_t num_child_slots, - int32_t child_slots_can_overflow, - int32_t parent_type_index); - -/*! - * \brief Get dynamic type info by type index. - * \return The type info. - */ -TVM_FFI_DLL const TVMFFITypeInfo *TVMFFIGetTypeInfo(int32_t type_index); - -#ifdef __cplusplus -} // TVM_FFI_EXTERN_C -#endif - -//--------------------------------------------------------------- -// The following API defines static object attribute accessors -// for language bindings. -// -// They are defined in C++ inline functions for cleaner code. -// Note that they only have to do with address offset computation. -// So they can always be reimplemented in bindings when c++ is -// not available or when binding only wants to refer to the dll. -//---------------------------------------------------------------- -#ifdef __cplusplus -/*! - * \brief Get the type index of an object. - * \param obj The object handle. - * \return The type index. - */ -inline int32_t TVMFFIObjectGetTypeIndex(TVMFFIObjectHandle obj) { - return static_cast(obj)->type_index; -} - -/*! - * \brief Get the content of a small string in bytearray format. - * \param value The value to get the content of the small string in bytearray format. - * \return The content of the small string in bytearray format. - */ -inline TVMFFIByteArray TVMFFISmallBytesGetContentByteArray(const TVMFFIAny *value) { - return TVMFFIByteArray{value->v_bytes, static_cast(value->small_str_len)}; -} - -/*! 
- * \brief Get the data pointer of a bytearray from a string or bytes object. - * \param obj The object handle. - * \return The data pointer. - */ -inline TVMFFIByteArray *TVMFFIBytesGetByteArrayPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a ErrorInfo from an Error object. - * \param obj The object handle. - * \return The cell pointer. - */ -inline TVMFFIErrorCell *TVMFFIErrorGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a function cell from a function object. - * \param obj The object handle. - * \return The cell pointer. - */ -inline TVMFFIFunctionCell *TVMFFIFunctionGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a opaque object cell from a opaque object. - * \param obj The object handle. - * \return The cell pointer. - */ -inline TVMFFIOpaqueObjectCell *TVMFFIOpaqueObjectGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the data pointer of a shape array from a shape object. - * \param obj The object handle. - * \return The cell pointer. - */ -inline TVMFFIShapeCell *TVMFFIShapeGetCellPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Get the DLTensor pointer from an Tensor object. - * \param obj The object handle. - * \return The DLTensor pointer. - */ -inline DLTensor *TVMFFITensorGetDLTensorPtr(TVMFFIObjectHandle obj) { - return reinterpret_cast(reinterpret_cast(obj) + sizeof(TVMFFIObject)); -} - -/*! - * \brief Create a DLDevice from a device type and device id. - * \param device_type The device type. - * \param device_id The device id. - * \return The DLDevice. 
- */ -inline DLDevice TVMFFIDLDeviceFromIntPair(int32_t device_type, int32_t device_id) { - return DLDevice{static_cast(device_type), device_id}; -} -#endif // __cplusplus -#endif // TVM_FFI_C_API_H_ -// NOLINTEND(modernize-use-using,bugprone-reserved-identifier,modernize-deprecated-headers) diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h deleted file mode 100644 index 6d6513af5..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/cast.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/cast.h - * \brief Extra value casting helpers - */ -#ifndef TVM_FFI_CAST_H_ -#define TVM_FFI_CAST_H_ - -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Get a reference type from a raw object ptr type - * - * It is always important to get a reference type - * if we want to return a value as reference or keep - * the object alive beyond the scope of the function. 
- * - * \param ptr The object pointer - * \tparam RefType The reference type - * \tparam ObjectType The object type - * \return The corresponding RefType - */ -template -inline RefType GetRef(const ObjectType *ptr) { - using ContainerType = typename RefType::ContainerType; - static_assert(std::is_base_of_v, - "Can only cast to the ref of same container type"); - - if constexpr (is_optional_type_v || RefType::_type_is_nullable) { - if (ptr == nullptr) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); - } - } else { - TVM_FFI_ICHECK_NOTNULL(ptr); - } - return details::ObjectUnsafe::ObjectRefFromObjectPtr( - details::ObjectUnsafe::ObjectPtrFromUnowned( - const_cast(static_cast(ptr)))); -} - -/*! - * \brief Get an object ptr type from a raw object ptr. - * - * \param ptr The object pointer - * \tparam BaseType The reference type - * \tparam ObjectType The object type - * \return The corresponding RefType - */ -template -inline ObjectPtr GetObjectPtr(ObjectType *ptr) { - static_assert(std::is_base_of_v, - "Can only cast to the ref of same container type"); - return details::ObjectUnsafe::ObjectPtrFromUnowned(ptr); -} -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CAST_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/array.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/array.h deleted file mode 100644 index 9f8674b50..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/array.h +++ /dev/null @@ -1,1164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/array.h - * \brief Array type. - * - * tvm::ffi::Array is an erased type that contains a list of content - */ -#ifndef TVM_FFI_CONTAINER_ARRAY_H_ -#define TVM_FFI_CONTAINER_ARRAY_H_ - -#include "../any.h" -#include "../memory.h" -#include "../object.h" -#include "../optional.h" -#include "container_details.h" - -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! \brief Array node content in array */ -class ArrayObj : public Object, public details::InplaceArrayBase { -public: - ~ArrayObj() { - Any *begin = MutableBegin(); - for (int64_t i = 0; i < size_; ++i) { - (begin + i)->Any::~Any(); - } - if (data_deleter_ != nullptr) { - data_deleter_(data_); - } - } - - /*! \return The size of the array */ - size_t size() const { return this->size_; } - - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const Any &at(int64_t i) const { return this->operator[](i); } - - /*! - * \brief Read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const Any &operator[](int64_t i) const { - if (i >= size_) { - TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_; - } - return static_cast(data_)[i]; - } - - /*! \return begin constant iterator */ - const Any *begin() const { return static_cast(data_); } - - /*! \return end constant iterator */ - const Any *end() const { return begin() + size_; } - - /*! 
\brief Release reference to all the elements */ - void clear() { ShrinkBy(size_); } - - /*! - * \brief Set i-th element of the array in-place - * \param i The index - * \param item The value to be set - */ - void SetItem(int64_t i, Any item) { - if (i >= size_) { - TVM_FFI_THROW(IndexError) << "Index " << i << " out of bounds " << size_; - } - static_cast(data_)[i] = std::move(item); - } - - /*! - * \brief Constructs a container and copy from another - * \param cap The capacity of the container - * \param from Source of the copy - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr CopyFrom(int64_t cap, ArrayObj *from) { - int64_t size = from->size_; - if (size > cap) { - TVM_FFI_THROW(ValueError) << "Not enough capacity"; - } - ObjectPtr p = ArrayObj::Empty(cap); - Any *write = p->MutableBegin(); - Any *read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t &i = p->size_ = 0; i < size; ++i) { - new (write++) Any(*read++); - } - return p; - } - - /*! - * \brief Constructs a container and move from another - * \param cap The capacity of the container - * \param from Source of the move - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr MoveFrom(int64_t cap, ArrayObj *from) { - int64_t size = from->size_; - if (size > cap) { - TVM_FFI_THROW(RuntimeError) << "Not enough capacity"; - } - ObjectPtr p = ArrayObj::Empty(cap); - Any *write = p->MutableBegin(); - Any *read = from->MutableBegin(); - // To ensure exception safety, size is only incremented after the initialization succeeds - for (int64_t &i = p->size_ = 0; i < size; ++i) { - new (write++) Any(std::move(*read++)); - } - from->size_ = 0; - return p; - } - - /*! - * \brief Constructs a container with n elements. 
Each element is a copy of val - * \param n The size of the container - * \param val The init value - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr CreateRepeated(int64_t n, const Any &val) { - ObjectPtr p = ArrayObj::Empty(n); - Any *itr = p->MutableBegin(); - for (int64_t &i = p->size_ = 0; i < n; ++i) { - new (itr++) Any(val); - } - return p; - } - - /// \cond Doxygen_Suppress - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIArray; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIArray, ArrayObj, Object); - /// \endcond - -private: - /*! \return Size of initialized memory, used by InplaceArrayBase. */ - size_t GetSize() const { return this->size_; } - - /*! \return begin mutable iterator */ - Any *MutableBegin() const { return static_cast(this->data_); } - - /*! \return end mutable iterator */ - Any *MutableEnd() const { return MutableBegin() + size_; } - - /*! - * \brief Emplace a new element at the back of the array - * \param idx The index of the element. - * \param args The arguments to construct the new element - */ - template - void EmplaceInit(size_t idx, Args &&...args) { - Any *itr = MutableBegin() + idx; - new (itr) Any(std::forward(args)...); - } - - /*! - * \brief Create an ArrayObj with the given capacity. - * \param n Required capacity - * \return Ref-counted ArrayObj requested - */ - static ObjectPtr Empty(int64_t n = kInitSize) { - ObjectPtr p = make_inplace_array_object(n); - p->capacity_ = n; - p->size_ = 0; - p->data_ = p->AddressOf(0); - return p; - } - - /*! 
- * \brief Inplace-initialize the elements starting idx from [first, last) - * \param idx The starting point - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return Self - */ - template - ArrayObj *InitRange(int64_t idx, IterType first, IterType last) { - Any *itr = MutableBegin() + idx; - for (; first != last; ++first) { - Any ref = *first; - new (itr++) Any(std::move(ref)); - } - return this; - } - - /*! - * \brief Move elements from right to left, requires src_begin > dst - * \param dst Destination - * \param src_begin The start point of copy (inclusive) - * \param src_end The end point of copy (exclusive) - * \return Self - */ - ArrayObj *MoveElementsLeft(int64_t dst, int64_t src_begin, int64_t src_end) { - Any *from = MutableBegin() + src_begin; - Any *to = MutableBegin() + dst; - while (src_begin++ != src_end) { - *to++ = std::move(*from++); - } - return this; - } - - /*! - * \brief Move elements from left to right, requires src_begin < dst - * \param dst Destination - * \param src_begin The start point of move (inclusive) - * \param src_end The end point of move (exclusive) - * \return Self - */ - ArrayObj *MoveElementsRight(int64_t dst, int64_t src_begin, int64_t src_end) { - Any *from = MutableBegin() + src_end; - Any *to = MutableBegin() + (src_end - src_begin + dst); - while (src_begin++ != src_end) { - *--to = std::move(*--from); - } - return this; - } - - /*! - * \brief Enlarges the size of the array - * \param delta Size enlarged, should be positive - * \param val Default value - * \return Self - */ - ArrayObj *EnlargeBy(int64_t delta, const Any &val = Any()) { - Any *itr = MutableEnd(); - while (delta-- > 0) { - new (itr++) Any(val); - ++size_; - } - return this; - } - - /*! 
- * \brief Shrinks the size of the array - * \param delta Size shrinked, should be positive - * \return Self - */ - ArrayObj *ShrinkBy(int64_t delta) { - Any *itr = MutableEnd(); - while (delta-- > 0) { - (--itr)->Any::~Any(); - --size_; - } - return this; - } - - /*! \brief Data pointer to the first element of the array */ - void *data_; - /*! \brief Number of elements used */ - int64_t size_; - /*! \brief Number of elements allocated */ - int64_t capacity_; - /*! - * \brief Optional data deleter when data is allocated separately - * and its deletion is not managed by ArrayObj::deleter_. - */ - void (*data_deleter_)(void *) = nullptr; - - /*! \brief Initial size of ArrayObj */ - static constexpr int64_t kInitSize = 4; - - /*! \brief Expansion factor of the Array */ - static constexpr int64_t kIncFactor = 2; - - // CRTP parent class - friend InplaceArrayBase; - - // Reference class - template - friend class Array; - - template - friend class Tuple; - - template - friend struct TypeTraits; - - // To specialize make_object - friend ObjectPtr make_object<>(); -}; - -/*! \brief Helper struct for type-checking - * - * is_valid_iterator::value will be true if IterType can - * be dereferenced into a type that can be stored in an Array, and - * false otherwise. - */ -template -struct is_valid_iterator - : std::bool_constant< - std::is_same_v< - T, std::remove_cv_t())>>> - || std::is_base_of_v< - T, std::remove_cv_t())>>>> { -}; - -template -struct is_valid_iterator, IterType> : is_valid_iterator {}; - -template -struct is_valid_iterator : std::true_type {}; - -/*! - * \brief Check whether IterType is valid iterator for T. - * \tparam T The type. - * \tparam IterType The type of iterator. - */ -template -inline constexpr bool is_valid_iterator_v = is_valid_iterator::value; - -/*! - * \brief Array, container representing a contiguous sequence of ObjectRefs. - * - * Array implements in-place copy-on-write semantics. 
- * - * As in typical copy-on-write, a method which would typically mutate the array - * instead opaquely copies the underlying container, and then acts on its copy. - * - * If the array has reference count equal to one, we directly update the - * container in place without copying. This is optimization is sound because - * when the reference count is equal to one this reference is guranteed to be - * the sole pointer to the container. - * - * - * operator[] only provides const access, use Set to mutate the content. - * \tparam T The content Value type, must be compatible with tvm::ffi::Any - */ -template >> -class Array : public ObjectRef { -public: - /*! \brief The value type of the array */ - using value_type = T; - // constructors - /*! - * \brief Construct an Array with UnsafeInit - */ - explicit Array(UnsafeInit tag) : ObjectRef(tag) {} - /*! - * \brief default constructor - */ - Array() { data_ = ArrayObj::Empty(); } // NOLINT(modernize-use-equals-default) - /*! - * \brief Move constructor - * \param other The other array - */ - Array(Array &&other) // NOLINT(google-explicit-constructor) - : ObjectRef(std::move(other.data_)) {} - /*! - * \brief Copy constructor - * \param other The other array - */ - Array(const Array &other) : ObjectRef(other.data_) {} // NOLINT(google-explicit-constructor) - /*! - * \brief Constructor from another array - * \param other The other array - * \tparam U The value type of the other array - */ - template >> - Array(Array &&other) // NOLINT(google-explicit-constructor) - : ObjectRef(std::move(other.data_)) {} - /*! - * \brief Constructor from another array - * \param other The other array - * \tparam U The value type of the other array - */ - template >> - Array(const Array &other) // NOLINT(google-explicit-constructor) - : ObjectRef(other.data_) {} - - /*! 
- * \brief Move assignment from another array - * \param other The other array - */ - TVM_FFI_INLINE Array &operator=(Array &&other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief Assignment from another array - * \param other The other array - */ - TVM_FFI_INLINE Array &operator=(const Array &other) { - data_ = other.data_; - return *this; - } - /*! - * \brief Move assignment from another array - * \param other The other array - * \tparam U The value type of the other array - */ - template >> - TVM_FFI_INLINE Array &operator=(Array &&other) { - data_ = std::move(other.data_); - return *this; - } - /*! - * \brief Assignment from another array - * \param other The other array - * \tparam U The value type of the other array - */ - template >> - TVM_FFI_INLINE Array &operator=(const Array &other) { - data_ = other.data_; - return *this; - } - - /*! - * \brief Constructor from pointer - * \param n the container pointer - */ - explicit Array(ObjectPtr n) : ObjectRef(std::move(n)) {} - - /*! - * \brief Constructor from iterator - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - Array(IterType first, IterType last) { // NOLINT(performance-unnecessary-value-param) - static_assert(is_valid_iterator_v, - "IterType cannot be inserted into a tvm::Array"); - Assign(first, last); - } - - /*! - * \brief constructor from initializer list - * \param init The initializer list - */ - Array(std::initializer_list init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief constructor from vector - * \param init The vector - */ - Array(const std::vector &init) { // NOLINT(*) - Assign(init.begin(), init.end()); - } - - /*! - * \brief Constructs a container with n elements. 
Each element is a copy of val - * \param n The size of the container - * \param val The init value - */ - explicit Array(const size_t n, const T &val) { data_ = ArrayObj::CreateRepeated(n, val); } - -public: - // iterators - /// \cond Doxygen_Suppress - struct ValueConverter { - using ResultType = T; - /*! - * \brief Convert any to T - * \param n The any value to convert - * \return The converted value - */ - static T convert(const Any &n) { return details::AnyUnsafe::CopyFromAnyViewAfterCheck(n); } - }; - /// \endcond - - /*! \brief The iterator type of the array */ - using iterator = details::IterAdapter; - /*! \brief The reverse iterator type of the array */ - using reverse_iterator = details::ReverseIterAdapter; - - /*! \return begin iterator */ - iterator begin() const { return iterator(GetArrayObj()->begin()); } - - /*! \return end iterator */ - iterator end() const { return iterator(GetArrayObj()->end()); } - - /*! \return rbegin iterator */ - reverse_iterator rbegin() const { - // ArrayObj::end() is never nullptr - return reverse_iterator(GetArrayObj()->end() - 1); - } - - /*! \return rend iterator */ - reverse_iterator rend() const { - // ArrayObj::begin() is never nullptr - return reverse_iterator(GetArrayObj()->begin() - 1); - } - -public: - // const methods in std::vector - /*! - * \brief Immutably read i-th element from array. - * \param i The index - * \return the i-th element. - */ - const T operator[](int64_t i) const { - ArrayObj *p = GetArrayObj(); - if (p == nullptr) { - TVM_FFI_THROW(IndexError) << "cannot index a null array"; - } - if (i < 0 || i >= p->size_) { - TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->begin() + i)); - } - - /*! \return The size of the array */ - size_t size() const { - ArrayObj *p = GetArrayObj(); - return p == nullptr ? 0 : GetArrayObj()->size_; - } - - /*! 
\return The capacity of the array */ - size_t capacity() const { - ArrayObj *p = GetArrayObj(); - return p == nullptr ? 0 : GetArrayObj()->capacity_; - } - - /*! \return Whether array is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the array */ - const T front() const { - ArrayObj *p = GetArrayObj(); - if (p == nullptr || p->size_ == 0) { - TVM_FFI_THROW(IndexError) << "cannot index a empty array"; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->begin())); - } - - /*! \return The last element of the array */ - const T back() const { - ArrayObj *p = GetArrayObj(); - if (p == nullptr || p->size_ == 0) { - TVM_FFI_THROW(IndexError) << "cannot index a empty array"; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*(p->end() - 1)); - } - -public: - // mutation in std::vector, implements copy-on-write - /*! - * \brief push a new item to the back of the list - * \param item The item to be pushed. - */ - void push_back(const T &item) { - ArrayObj *p = CopyOnWrite(1); - p->EmplaceInit(p->size_++, item); - } - - /*! - * \brief Emplace a new element at the back of the array - * \param args The arguments to construct the new element - */ - template - void emplace_back(Args &&...args) { - ArrayObj *p = CopyOnWrite(1); - p->EmplaceInit(p->size_++, std::forward(args)...); - } - - /*! - * \brief Insert an element into the given position - * \param position An iterator pointing to the insertion point - * \param val The element to insert - */ - void insert(iterator position, const T &val) { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; - } - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayObj()->size_; - auto addr = CopyOnWrite(1) // - ->EnlargeBy(1) // - ->MoveElementsRight(idx + 1, idx, size) // - ->MutableBegin(); - new (addr + idx) Any(val); - } - - /*! 
- * \brief Insert a range of elements into the given position - * \param position An iterator pointing to the insertion point - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - template - void insert(iterator position, IterType first, IterType last) { - static_assert(is_valid_iterator_v, - "IterType cannot be inserted into a tvm::Array"); - - if (first == last) { - return; - } - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot insert a null array"; - } - int64_t idx = std::distance(begin(), position); - int64_t size = GetArrayObj()->size_; - int64_t numel = std::distance(first, last); - CopyOnWrite(numel) - ->EnlargeBy(numel) - ->MoveElementsRight(idx + numel, idx, size) - ->InitRange(idx, first, last); - } - - /*! \brief Remove the last item of the list */ - void pop_back() { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot pop_back a null array"; - } - int64_t size = GetArrayObj()->size_; - if (size == 0) { - TVM_FFI_THROW(RuntimeError) << "cannot pop_back an empty array"; - } - CopyOnWrite()->ShrinkBy(1); - } - - /*! - * \brief Erase an element on the given position - * \param position An iterator pointing to the element to be erased - */ - void erase(iterator position) { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; - } - int64_t st = std::distance(begin(), position); - int64_t size = GetArrayObj()->size_; - if (st < 0 || st >= size) { - TVM_FFI_THROW(RuntimeError) << "cannot erase at index " << st << ", because Array size is " - << size; - } - CopyOnWrite() // - ->MoveElementsLeft(st, st + 1, size) // - ->ShrinkBy(1); - } - - /*! 
- * \brief Erase a given range of elements - * \param first The begin iterator of the range - * \param last The end iterator of the range - */ - void erase(iterator first, iterator last) { - if (first == last) { - return; - } - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "cannot erase a null array"; - } - int64_t size = GetArrayObj()->size_; - int64_t st = std::distance(begin(), first); - int64_t ed = std::distance(begin(), last); - if (st >= ed) { - TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")"; - } - if (st < 0 || st > size || ed < 0 || ed > size) { - TVM_FFI_THROW(IndexError) << "cannot erase array in range [" << st << ", " << ed << ")" - << ", because array size is " << size; - } - CopyOnWrite() // - ->MoveElementsLeft(st, ed, size) // - ->ShrinkBy(ed - st); - } - - /*! - * \brief Resize the array. - * \param n The new size. - */ - void resize(int64_t n) { - if (n < 0) { - TVM_FFI_THROW(ValueError) << "cannot resize an Array to negative size"; - } - if (data_ == nullptr) { - SwitchContainer(n); - return; - } - int64_t size = GetArrayObj()->size_; - if (size < n) { - CopyOnWrite(n - size)->EnlargeBy(n - size); - } else if (size > n) { - CopyOnWrite()->ShrinkBy(size - n); - } - } - - /*! - * \brief Make sure the list has the capacity of at least n - * \param n lower bound of the capacity - */ - void reserve(int64_t n) { - if (data_ == nullptr || n > GetArrayObj()->capacity_) { - SwitchContainer(n); - } - } - - /*! \brief Release reference to all the elements */ - void clear() { - if (data_ != nullptr) { - ArrayObj *p = CopyOnWrite(); - p->clear(); - } - } - /// \cond Doxygen_Suppress - template - static size_t CalcCapacityImpl() { - return 0; - } - - template - static size_t CalcCapacityImpl(Array value, Args... args) { - return value.size() + CalcCapacityImpl(args...); - } - - template - static size_t CalcCapacityImpl(T value, Args... 
args) { - return 1 + CalcCapacityImpl(args...); - } - - template - static void AgregateImpl(Array &dest) {} // NOLINT(*) - - template - static void AgregateImpl(Array &dest, Array value, Args... args) { // NOLINT(*) - dest.insert(dest.end(), value.begin(), value.end()); - AgregateImpl(dest, args...); - } - - template - static void AgregateImpl(Array &dest, T value, Args... args) { // NOLINT(*) - dest.push_back(value); - AgregateImpl(dest, args...); - } - /// \endcond - -public: - // Array's own methods - - /*! - * \brief set i-th element of the array. - * \param i The index - * \param value The value to be setted. - */ - void Set(int64_t i, T value) { - ArrayObj *p = this->CopyOnWrite(); - if (i < 0 || i >= p->size_) { - TVM_FFI_THROW(IndexError) << "indexing " << i << " on an array of size " << p->size_; - } - *(p->MutableBegin() + i) = std::move(value); - } - - /*! \return The underlying ArrayObj */ - ArrayObj *GetArrayObj() const { return static_cast(data_.get()); } - - /*! - * \brief Helper function to apply a map function onto the array. - * - * \param fmap The transformation function T -> U. - * - * \tparam F The type of the mutation function. - * - * \tparam U The type of the returned array, inferred from the - * return type of F. If overridden by the user, must be something - * that is convertible from the return type of F. - * - * \note This function performs copy on write optimization. If - * `fmap` returns an object of type `T`, and all elements of the - * array are mapped to themselves, then the returned array will be - * the same as the original, and reference counts of the elements in - * the array will not be incremented. - * - * \return The transformed array. - */ - template > - Array Map(F fmap) const { - return Array(MapHelper(data_, fmap)); - } - - /*! - * \brief Helper function to apply fmutate to mutate an array. - * \param fmutate The transformation function T -> T. - * \tparam F the type of the mutation function. 
- * \note This function performs copy on write optimization. - */ - template >>> - void MutateByApply(F fmutate) { - data_ = MapHelper(std::move(data_), fmutate); - } - - /*! - * \brief reset the array to content from iterator. - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - */ - template - void Assign(IterType first, IterType last) { // NOLINT(performance-unnecessary-value-param) - int64_t cap = std::distance(first, last); - if (cap < 0) { - TVM_FFI_THROW(ValueError) << "cannot construct an Array of negative size"; - } - ArrayObj *p = GetArrayObj(); - if (p != nullptr && data_.unique() && p->capacity_ >= cap) { - // do not have to make new space - p->clear(); - } else { - // create new space - data_ = ArrayObj::Empty(cap); - p = GetArrayObj(); - } - // To ensure exception safety, size is only incremented after the initialization succeeds - Any *itr = p->MutableBegin(); - for (int64_t &i = p->size_ = 0; i < cap; ++i, ++first, ++itr) { - new (itr) Any(*first); - } - } - - /*! - * \brief Copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which ganrantees to be unique) - */ - ArrayObj *CopyOnWrite() { - if (data_ == nullptr) { - return SwitchContainer(ArrayObj::kInitSize); - } - if (!data_.unique()) { - return SwitchContainer(capacity()); - } - return static_cast(data_.get()); - } - - /*! \brief specify container node */ - using ContainerType = ArrayObj; - - /*! - * \brief Agregate arguments into a single Array - * \param args sequence of T or Array elements - * \return Agregated Array - */ - template - static Array Agregate(Args... args) { - Array result; - result.reserve(CalcCapacityImpl(args...)); - AgregateImpl(result, args...); - return result; - } - -private: - /*! 
- * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements. - * \param reserve_extra Number of extra slots needed - * \return ArrayObj pointer to the unique copy - */ - ArrayObj *CopyOnWrite(int64_t reserve_extra) { - ArrayObj *p = GetArrayObj(); - if (p == nullptr) { - // necessary to get around the constexpr address issue before c++17 - const int64_t kInitSize = ArrayObj::kInitSize; - return SwitchContainer(std::max(kInitSize, reserve_extra)); - } - if (p->capacity_ >= p->size_ + reserve_extra) { - return CopyOnWrite(); - } - int64_t cap = p->capacity_ * ArrayObj::kIncFactor; - cap = std::max(cap, p->size_ + reserve_extra); - return SwitchContainer(cap); - } - - /*! - * \brief Move or copy the ArrayObj to new address with the given capacity - * \param capacity The capacity requirement of the new address - */ - ArrayObj *SwitchContainer(int64_t capacity) { - if (data_ == nullptr) { - data_ = ArrayObj::Empty(capacity); - } else if (data_.unique()) { - data_ = ArrayObj::MoveFrom(capacity, GetArrayObj()); - } else { - data_ = ArrayObj::CopyFrom(capacity, GetArrayObj()); - } - return static_cast(data_.get()); - } - - /*! \brief Helper method for mutate/map - * - * A helper function used internally by both `Array::Map` and - * `Array::MutateInPlace`. Given an array of data, apply the - * mapping function to each element, returning the collected array. - * Applies both mutate-in-place and copy-on-write optimizations, if - * possible. - * - * \param data A pointer to the ArrayObj containing input data. - * Passed by value to allow for mutate-in-place optimizations. - * - * \param fmap The mapping function - * - * \tparam F The type of the mutation function. - * - * \tparam U The output type of the mutation function. Inferred - * from the callable type given. Must inherit from ObjectRef. - * - * \return The mapped array. 
Depending on whether mutate-in-place - * or copy-on-write optimizations were applicable, may be the same - * underlying array as the `data` parameter. - */ - template > - static ObjectPtr MapHelper(ObjectPtr data, F fmap) { - if (data == nullptr) { - return nullptr; - } - - TVM_FFI_ICHECK(data->IsInstance()); - - constexpr bool is_same_output_type = std::is_same_v; - - if constexpr (is_same_output_type) { - if (data.unique()) { - // Mutate-in-place path. Only allowed if the output type U is - // the same as type T, we have a mutable this*, and there are - // no other shared copies of the array. - auto arr = static_cast(data.get()); - for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) { - T value = details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it); - // reset the original value to nullptr, to ensure unique ownership - it->reset(); - T mapped = fmap(std::move(value)); - *it = std::move(mapped); - } - return data; - } - } - - constexpr bool compatible_types = is_valid_iterator_v || is_valid_iterator_v; - - ObjectPtr output = nullptr; - auto arr = static_cast(data.get()); - - auto it = arr->begin(); - if constexpr (compatible_types) { - // Copy-on-write path, if the output Array might be - // represented by the same underlying array as the existing - // Array. Typically, this is for functions that map `T` to - // `T`, but can also apply to functions that map `T` to - // `Optional`, or that map `T` to a subclass or superclass of - // `T`. - bool all_identical = true; - for (; it != arr->end(); it++) { - U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it)); - if (!(*it).same_as(mapped)) { - // At least one mapped element is different than the - // original. Therefore, prepare the output array, - // consisting of any previous elements that had mapped to - // themselves (if any), and the element that didn't map to - // itself. - // - // We cannot use `U()` as the default object, as `U` may be - // a non-nullable type. 
Since the default `Any()` - // will be overwritten before returning, all objects will be - // of type `U` for the calling scope. - all_identical = false; - output = ArrayObj::CreateRepeated(static_cast(arr->size()), Any()); - output->InitRange(0, arr->begin(), it); - output->SetItem(it - arr->begin(), std::move(mapped)); - it++; - break; - } - } - if (all_identical) { - return data; - } - } else { - // Path for incompatible types. The constexpr check for - // compatible types isn't strictly necessary, as the first - // (*it).same_as(mapped) would return false, but we might as well - // avoid it altogether. - // - // We cannot use `U()` as the default object, as `U` may be a - // non-nullable type. Since the default `Any()` will be - // overwritten before returning, all objects will be of type `U` - // for the calling scope. - output = ArrayObj::CreateRepeated(static_cast(arr->size()), Any()); - } - - // Normal path for incompatible types, or post-copy path for - // copy-on-write instances. - // - // If the types are incompatible, then at this point `output` is - // empty, and `it` points to the first element of the input. - // - // If the types were compatible, then at this point `output` - // contains zero or more elements that mapped to themselves - // followed by the first element that does not map to itself, and - // `it` points to the element just after the first element that - // does not map to itself. Because at least one element has been - // changed, we no longer have the opportunity to avoid a copy, so - // we don't need to check the result. - // - // In both cases, `it` points to the next element to be processed, - // so we can either start or resume the iteration from that point, - // with no further checks on the result. - for (; it != arr->end(); it++) { - U mapped = fmap(details::AnyUnsafe::CopyFromAnyViewAfterCheck(*it)); - output->SetItem(it - arr->begin(), std::move(mapped)); - } - - return output; - } - template - friend class Array; -}; - -/*! 
- * \brief Concat two Arrays. - * \param lhs first Array to be concatenated. - * \param rhs second Array to be concatenated. - * \return The concatenated Array. Original Arrays are kept unchanged. - */ -template || TypeTraits::convert_enabled>> -inline Array Concat(Array lhs, const Array &rhs) { - for (const auto &x : rhs) { - lhs.push_back(x); - } - return std::move(lhs); -} - -/*! - * \brief Specialize make_object - * \return The empty array object. - */ -template <> -inline ObjectPtr make_object() { - return ArrayObj::Empty(); -} - -// Traits for Array -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public ObjectRefTypeTraitsBase> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIArray; - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { - if (src->type_index != TypeIndex::kTVMFFIArray) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - if constexpr (!std::is_same_v) { - const ArrayObj *n = reinterpret_cast(src->v_obj); - for (size_t i = 0; i < n->size(); i++) { - const Any &any_v = (*n)[static_cast(i)]; - // CheckAnyStrict is cheaper than try_cast - if (details::AnyUnsafe::CheckAnyStrict(any_v)) { - continue; - } - // try see if p is convertible to T - if (any_v.try_cast()) { - continue; - } - // now report the accurate mismatch information - return "Array[index " + std::to_string(i) + ": " + details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; - } - } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - if (src->type_index != TypeIndex::kTVMFFIArray) { - return false; - } - if constexpr (std::is_same_v) { - return true; - } else { - const ArrayObj *n = reinterpret_cast(src->v_obj); - for (const Any &any_v : *n) { - if (!details::AnyUnsafe::CheckAnyStrict(any_v)) { - return false; - } 
- } - return true; - } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny *src) { - // try to run conversion. - if (src->type_index != TypeIndex::kTVMFFIArray) { - return std::nullopt; - } - if constexpr (!std::is_same_v) { - const ArrayObj *n = reinterpret_cast(src->v_obj); - bool storage_check = [&]() { - for (const Any &any_v : *n) { - if (!details::AnyUnsafe::CheckAnyStrict(any_v)) { - return false; - } - } - return true; - }(); - // fast path, if storage check passes, we can return the array directly. - if (storage_check) { - return CopyFromAnyViewAfterCheck(src); - } - // slow path, try to run a conversion to Array - Array result; - result.reserve(n->size()); - for (const Any &any_v : *n) { - if (auto opt_v = any_v.try_cast()) { - result.push_back(*std::move(opt_v)); - } else { - return std::nullopt; - } - } - return result; - } else { - return CopyFromAnyViewAfterCheck(src); - } - } - - TVM_FFI_INLINE static std::string TypeStr() { return "Array<" + details::Type2Str::v() + ">"; } - TVM_FFI_INLINE static std::string TypeSchema() { - std::ostringstream oss; - oss << R"({"type":")" << StaticTypeKey::kTVMFFIArray << R"(","args":[)"; - oss << details::TypeSchema::v(); - oss << "]}"; - return oss.str(); - } -}; - -namespace details { -template -inline constexpr bool type_contains_v, Array> = type_contains_v; -} // namespace details - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_ARRAY_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/container_details.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/container_details.h deleted file mode 100644 index d6102946a..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/container_details.h +++ /dev/null @@ -1,360 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/container_details.h - * \brief Common utilities for typed container types. - */ -#ifndef TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ -#define TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ - -#include "../any.h" -#include "../memory.h" -#include "../object.h" - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Base template for classes with array like memory layout. - * - * It provides general methods to access the memory. The memory - * layout is ArrayType + [ElemType]. The alignment of ArrayType - * and ElemType is handled by the memory allocator. - * - * \tparam ArrayType The array header type, contains object specific metadata. - * \tparam ElemType The type of objects stored in the array right after - * ArrayType. 
- * - * \code - * // Example usage of the template to define a simple array wrapper - * class ArrayObj : public tvm::ffi::details::InplaceArrayBase { - * public: - * // Wrap EmplaceInit to initialize the elements - * template - * void Init(Iterator begin, Iterator end) { - * size_t num_elems = std::distance(begin, end); - * auto it = begin; - * this->size = 0; - * for (size_t i = 0; i < num_elems; ++i) { - * InplaceArrayBase::EmplaceInit(i, *it++); - * this->size++; - * } - * } - * } - * - * void test_function() { - * vector fields; - * auto ptr = make_inplace_array_object(fields.size()); - * ptr->Init(fields.begin(), fields.end()); - * - * // Access the 0th element in the array. - * assert(ptr->operator[](0) == fields[0]); - * } - * - * \endcode - */ -template -class InplaceArrayBase { -public: - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Const reference to ElemType at the index. - */ - const ElemType &operator[](size_t idx) const { - size_t size = Self()->GetSize(); - if (idx > size) { - TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; - } - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Access element at index - * \param idx The index of the element. - * \return Reference to ElemType at the index. - */ - ElemType &operator[](size_t idx) { - size_t size = Self()->GetSize(); - if (idx > size) { - TVM_FFI_THROW(IndexError) << "Index " << idx << " out of bounds " << size; - } - return *(reinterpret_cast(AddressOf(idx))); - } - - /*! - * \brief Destroy the Inplace Array Base object - */ - ~InplaceArrayBase() { - if constexpr (!(std::is_standard_layout_v && std::is_trivial_v)) { - size_t size = Self()->GetSize(); - for (size_t i = 0; i < size; ++i) { - ElemType *fp = reinterpret_cast(AddressOf(i)); - fp->ElemType::~ElemType(); - } - } - } - -private: - InplaceArrayBase() = default; - friend ArrayType; - -protected: - /*! 
- * \brief Construct a value in place with the arguments. - * - * \tparam Args Type parameters of the arguments. - * \param idx Index of the element. - * \param args Arguments to construct the new value. - * - * \note Please make sure ArrayType::GetSize returns 0 before first call of - * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds. - */ - template - void EmplaceInit(size_t idx, Args &&...args) { - void *field_ptr = AddressOf(idx); - new (field_ptr) ElemType(std::forward(args)...); - } - - /*! - * \brief Return the self object for the array. - * - * \return Pointer to ArrayType. - */ - inline ArrayType *Self() const { - return static_cast(const_cast(this)); - } - - /*! - * \brief Return the raw pointer to the element at idx. - * - * \param idx The index of the element. - * \return Raw pointer to the element. - */ - void *AddressOf(size_t idx) const { - static_assert( - alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, - "The size and alignment of ArrayType should respect " - "ElemType's alignment."); - - size_t kDataStart = sizeof(ArrayType); - ArrayType *self = Self(); - char *data_start = reinterpret_cast(self) + kDataStart; - return data_start + idx * sizeof(ElemType); - } -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. 
- */ -template -class IterAdapter { -public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = const typename Converter::ResultType *; - using reference = const typename Converter::ResultType; - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit IterAdapter(TIter iter) : iter_(iter) {} - IterAdapter &operator++() { - ++iter_; - return *this; - } - IterAdapter &operator--() { - --iter_; - return *this; - } - IterAdapter operator++(int) { - IterAdapter copy = *this; - ++iter_; - return copy; - } - IterAdapter operator--(int) { - IterAdapter copy = *this; - --iter_; - return copy; - } - - IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); } - - IterAdapter operator-(difference_type offset) const { return IterAdapter(iter_ - offset); } - - IterAdapter &operator+=(difference_type offset) { - iter_ += offset; - return *this; - } - - IterAdapter &operator-=(difference_type offset) { - iter_ -= offset; - return *this; - } - - template - inline std::enable_if_t, - typename T::difference_type> - operator-(const IterAdapter &rhs) const { - return iter_ - rhs.iter_; - } - - bool operator==(IterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(IterAdapter other) const { return !(*this == other); } - reference operator*() const { return Converter::convert(*iter_); } - -private: - TIter iter_; -}; - -/*! - * \brief iterator adapter that adapts TIter to return another type. - * \tparam Converter a struct that contains converting function - * \tparam TIter the content iterator type. 
- */ -template -class ReverseIterAdapter { -public: - using difference_type = typename std::iterator_traits::difference_type; - using value_type = typename Converter::ResultType; - using pointer = const typename Converter::ResultType *; - using reference = const typename Converter::ResultType; - using iterator_category = typename std::iterator_traits::iterator_category; - - explicit ReverseIterAdapter(TIter iter) : iter_(iter) {} - ReverseIterAdapter &operator++() { - --iter_; - return *this; - } - ReverseIterAdapter &operator--() { - ++iter_; - return *this; - } - ReverseIterAdapter operator++(int) { - ReverseIterAdapter copy = *this; - --iter_; - return copy; - } - ReverseIterAdapter operator--(int) { - ReverseIterAdapter copy = *this; - ++iter_; - return copy; - } - ReverseIterAdapter operator+(difference_type offset) const { - return ReverseIterAdapter(iter_ - offset); - } - - template - inline std::enable_if_t, - typename T::difference_type> - operator-(const ReverseIterAdapter &rhs) const { - return rhs.iter_ - iter_; - } - - bool operator==(ReverseIterAdapter other) const { return iter_ == other.iter_; } - bool operator!=(ReverseIterAdapter other) const { return !(*this == other); } - reference operator*() const { return Converter::convert(*iter_); } - -private: - TIter iter_; -}; - -/*! - * \brief Check if T is compatible with Any. - * - * \tparam T The type to check. - * \return True if T is compatible with Any, false otherwise. - */ -template -inline constexpr bool storage_enabled_v = std::is_same_v || TypeTraits::storage_enabled; - -/*! - * \brief Check if all T are compatible with Any. - * - * \tparam T The type to check. - * \return True if T is compatible with Any, false otherwise. - */ -template -inline constexpr bool all_storage_enabled_v = (storage_enabled_v && ...); - -/*! - * \brief Check if all T are compatible with Any. - * - * \tparam T The type to check. - * \return True if T is compatible with Any, false otherwise. 
- */ -template -inline constexpr bool all_object_ref_v = (std::is_base_of_v && ...); -/** - * \brief Check if Any storage of Derived can always be directly used as Base. - * - * \tparam Base The base type. - * \tparam Derived The derived type. - * \return True if Derived's storage can be used as Base's storage, false otherwise. - */ -template -inline constexpr bool type_contains_v = std::is_base_of_v || std::is_same_v; -// special case for Any -template -inline constexpr bool type_contains_v = true; - -/*! - * \brief Create a string of the container type. - * \tparam V The types of the elements in the container. - * \param name The name of the container type. - * \return A string of the container type. - */ -template -std::string ContainerTypeStr(const char *name) { - std::stringstream ss; - // helper to construct concated string of TypeStr - class TypeStrHelper { - public: - TypeStrHelper(std::stringstream &stream) : stream_(stream) {} // NOLINT(*) - - TypeStrHelper &operator<<(const std::string &str) { - if (counter_ > 0) { - stream_ << ", "; - } - stream_ << str; - counter_++; - return *this; - } - - private: - std::stringstream &stream_; // NOLINT(*) - int counter_ = 0; - }; - TypeStrHelper helper(ss); - ss << name << '<'; - (helper << ... << Type2Str::v()); - ss << '>'; - return ss.str(); -} - -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_CONTAINER_DETAILS_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h deleted file mode 100644 index 4e5d1a635..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/map.h +++ /dev/null @@ -1,1785 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/map.h - * \brief Runtime Map container types. - */ -#ifndef TVM_FFI_CONTAINER_MAP_H_ -#define TVM_FFI_CONTAINER_MAP_H_ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/// \cond Doxygen_Suppress -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE -#define TVM_FFI_MAP_FAIL_IF_CHANGED() \ - TVM_FFI_ICHECK(state_marker == self->state_marker) << "Concurrent modification of the Map"; -#else -#define TVM_FFI_MAP_FAIL_IF_CHANGED() -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE -/// \endcond - -/*! \brief Shared content of all specializations of hash map */ -class MapObj : public Object { -public: - /*! \brief Type of the keys in the hash map */ - using key_type = Any; - /*! \brief Type of the values in the hash map */ - using mapped_type = Any; - /*! \brief Type of value stored in the hash map */ - using KVType = std::pair; - /// \cond Doxygen_Suppress - /*! \brief Type of raw storage of the key-value pair in the hash map */ - struct KVRawStorageType { - TVMFFIAny first; - TVMFFIAny second; - }; - /// \endcond - /*! 
\brief Iterator class */ - class iterator; - - static_assert(std::is_standard_layout_v, "KVType is not standard layout"); - static_assert(sizeof(KVType) == 32, "sizeof(KVType) incorrect"); - - /// \cond Doxygen_Suppress - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIMap; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIMap, MapObj, Object); - /// \endcond - - /*! - * \brief Number of elements in the MapObj - * \return The result - */ - size_t size() const { return size_; } - /*! - * \brief Count the number of times a key exists in the hash map - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type &key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type &at(const key_type &key) const; - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type &at(const key_type &key); - /*! \return begin iterator */ - iterator begin() const; - /*! \return end iterator */ - iterator end() const; - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type &key) const; - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator &position); - /*! 
- * \brief Erase the entry associated with the key, do nothing if not exists - * \param key The indexing key - */ - void erase(const key_type &key) { erase(find(key)); } - - /// \cond Doxygen_Suppress - class iterator { - public: - using iterator_category = std::forward_iterator_tag; - using difference_type = int64_t; - using value_type = KVType; - using pointer = KVType *; - using reference = KVType &; -/*! \brief Default constructor */ -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - iterator() : state_marker(0), index(0), self(nullptr) {} -#else - iterator() : index(0), self(nullptr) {} -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! \brief Compare iterators */ - bool operator==(const iterator &other) const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - return index == other.index && self == other.self; - } - /*! \brief Compare iterators */ - bool operator!=(const iterator &other) const { return !(*this == other); } - /*! \brief De-reference iterators */ - pointer operator->() const; - /*! \brief De-reference iterators */ - reference operator*() const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - return *((*this).operator->()); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator &operator++(); - /*! \brief Prefix self decrement, e.g. --iter */ - iterator &operator--(); - /*! \brief Suffix self increment */ - iterator operator++(int) { - TVM_FFI_MAP_FAIL_IF_CHANGED() - iterator copy = *this; - ++(*this); - return copy; - } - /*! \brief Suffix self decrement */ - iterator operator--(int) { - TVM_FFI_MAP_FAIL_IF_CHANGED() - iterator copy = *this; - --(*this); - return copy; - } - - protected: -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - uint64_t state_marker; - /*! \brief Construct by value */ - iterator(uint64_t index, const MapObj *self) - : state_marker(self->state_marker), index(index), self(self) {} - -#else - iterator(uint64_t index, const MapObj *self) : index(index), self(self) {} -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! 
\brief The position on the array */ - uint64_t index; - /*! \brief The container it points to */ - const MapObj *self; - - friend class DenseMapObj; - friend class SmallMapObj; - }; - /// \endcond - /*! - * \brief Create an empty container - * \return The object created - */ - static inline ObjectPtr Empty(); - -protected: -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - uint64_t state_marker; -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - /*! - * \brief Create the map using contents from the given iterators. - * \param first Begin of iterator - * \param last End of iterator - * \tparam IterType The type of iterator - * \return ObjectPtr to the map created - */ - template - static inline ObjectPtr CreateFromRange(IterType first, IterType last); - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static inline void InsertMaybeReHash(KVType &&kv, ObjectPtr *map); - /*! - * \brief Create an empty container with elements copying from another SmallMapObj - * \param from The source container - * \return The object created - */ - static inline ObjectPtr CopyFrom(MapObj *from); - /*! - * \brief data pointer to the data region of the map. - * \note For immutable inplace small map we do not need data_, - * but we keep it here for future compact with mutable container. - */ - void *data_; - /*! \brief number of entries in the container */ - uint64_t size_; - /*! \brief number of slots */ - uint64_t slots_; - /*! - * \brief Small layout tag mask - * \note The most significant bit is used to indicate the small map layout. - */ - static constexpr uint64_t kSmallTagMask = static_cast(1) << 63; - /*! - * \brief Check if the map is a small map - * \return True if the map is a small map - */ - bool IsSmallMap() const { return (slots_ & kSmallTagMask) != 0ull; } - /*! 
- * \brief Optional data deleter when data is allocated separately - * and its deletion is not managed by MapObj::deleter_. - */ - void (*data_deleter_)(void *) = nullptr; - // Reference class - template - friend class Map; -}; - -/*! \brief A specialization of small-sized hash map */ -class SmallMapObj : public MapObj, - public details::InplaceArrayBase { -private: - static constexpr uint64_t kInitSize = 2; - static constexpr uint64_t kMaxSize = 4; - -public: - using MapObj::iterator; - using MapObj::KVType; - - // Return the number of usable slots for Small layout (mask off tag). - /*! - * \brief Return the number of usable slots for Small layout (mask off tag). - * \return The number of usable slots - */ - uint64_t NumSlots() const { return slots_ & ~kSmallTagMask; } - - ~SmallMapObj() { - KVType *begin = static_cast(data_); - for (uint64_t index = 0; index < size_; ++index) { - // call destructor to destroy the item in `begin + index` - // Explicit call Any::~Any() to destroy the Any object - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (begin + index)->first.Any::~Any(); - (begin + index)->second.Any::~Any(); - } - if (data_deleter_ != nullptr) { - data_deleter_(data_); - } - } - /*! - * \brief Count the number of times a key exists in the SmallMapObj - * \param key The indexing key - * \return The result, 0 or 1 - */ - size_t count(const key_type &key) const { return find(key).index < size_; } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type &at(const key_type &key) const { - iterator itr = find(key); - if (itr.index >= size_) { - TVM_FFI_THROW(KeyError) << "key is not in Map"; - } - return itr->second; - } - /*! 
- * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type &at(const key_type &key) { - iterator itr = find(key); - if (itr.index >= size_) { - TVM_FFI_THROW(KeyError) << "key is not in Map"; - } - return itr->second; - } - /*! \return begin iterator */ - iterator begin() const { return iterator(0, this); } - /*! \return end iterator */ - iterator end() const { return iterator(size_, this); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type &key) const { - KVType *ptr = static_cast(data_); - for (uint64_t i = 0; i < size_; ++i, ++ptr) { - if (AnyEqual()(ptr->first, key)) { - return iterator(i, this); - } - } - return iterator(size_, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator &position) { Erase(position.index); } - -private: - /*! - * \brief Set the number of slots and attach tags bit. - * \param n The number of slots - */ - void SetSlotsAndSmallLayoutTag(uint64_t n) { slots_ = (n & ~kSmallTagMask) | kSmallTagMask; } - /*! - * \brief Remove a position in SmallMapObj - * \param index The position to be removed - */ - void Erase(const uint64_t index) { - if (index >= size_) { - return; - } - KVType *begin = static_cast(data_); - // call destructor to destroy the item in `begin + index` - // Explicit call Any::~Any() to destroy the Any object - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (begin + index)->first.Any::~Any(); - (begin + index)->second.Any::~Any(); - // IMPORTANT: We do direct raw memmove to bring later items to the current position - // to preserve the order of insertion. 
- // This works because direct memory copy preserves the Any's move semantics. - if (index + 1 < size_) { - std::memmove(reinterpret_cast(begin + index), - reinterpret_cast(begin + index + 1), - (size_ - index - 1) * sizeof(KVType)); - } - size_ -= 1; - } - /*! - * \brief Create an empty container - * \param n Number of empty slots - * \return The object created - */ - static ObjectPtr Empty(uint64_t n = kInitSize) { - using ::tvm::ffi::make_inplace_array_object; - ObjectPtr p = make_inplace_array_object(n); - p->data_ = p->AddressOf(0); - p->size_ = 0; - p->SetSlotsAndSmallLayoutTag(n); - return p; - } - /*! - * \brief Create an empty container initialized with a given range - * \param n Number of empty slots - * \param first begin of iterator - * \param last end of iterator - * \tparam IterType The type of iterator - * \return The object created - */ - template - static ObjectPtr CreateFromRange(uint64_t n, IterType first, IterType last) { - ObjectPtr p = Empty(n); - KVType *ptr = static_cast(p->data_); - for (; first != last; ++first, ++p->size_) { - new (ptr++) KVType(*first); - } - return p; - } - /*! - * \brief Create an empty container with elements copying from another SmallMapObj - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(SmallMapObj *from) { - KVType *first = static_cast(from->data_); - KVType *last = first + from->size_; - return CreateFromRange(from->size_, first, last); - } - /*! 
- * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(KVType &&kv, ObjectPtr *map) { - SmallMapObj *map_node = static_cast(map->get()); - iterator itr = map_node->find(kv.first); - if (itr.index < map_node->size_) { - itr->second = kv.second; - return; - } - if (map_node->size_ < map_node->NumSlots()) { - KVType *ptr = static_cast(map_node->data_) + map_node->size_; - new (ptr) KVType(std::move(kv)); - ++map_node->size_; - return; - } - uint64_t next_size = std::max(map_node->NumSlots() * 2, kInitSize); - next_size = std::min(next_size, kMaxSize); - TVM_FFI_ICHECK_GT(next_size, map_node->NumSlots()); - ObjectPtr new_map = CreateFromRange(next_size, map_node->begin(), map_node->end()); - InsertMaybeReHash(std::move(kv), &new_map); - *map = std::move(new_map); - } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { return index + 1 < size_ ? index + 1 : size_; } - /*! - * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { return index > 0 ? index - 1 : size_; } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType *DeRefItr(uint64_t index) const { return static_cast(data_) + index; } - /*! \brief A size function used by InplaceArrayBase */ - uint64_t GetSize() const { return size_; } - -protected: - friend class MapObj; - friend class DenseMapObj; - friend class details::InplaceArrayBase; -}; - -/*! \brief A specialization of hash map that implements the idea of array-based hash map. - * Another reference implementation can be found [1]. - * - * A. 
Overview - * - * DenseMapObj did several improvements over traditional separate chaining hash, - * in terms of cache locality, memory footprints and data organization. - * - * A1. Implicit linked list. For better cache locality, instead of using linked list - * explicitly for each bucket, we store list data into a single array that spans contiguously - * in memory, and then carefully design access patterns to make sure most of them fall into - * a single cache line. - * - * A2. 1-byte metadata. There is only 1 byte overhead for each slot in the array to indexing and - * traversal. This can be divided in 3 parts. - * 1) Reserved code: (0b11111111)_2 indicates a slot is empty; (0b11111110)_2 indicates protected, - * which means the slot is empty but not allowed to be written. - * 2) If not empty or protected, the highest bit is used to indicate whether data in the slot is - * head of a linked list. - * 3) The rest 7 bits are used as the "next pointer" (i.e. pointer to the next element). On 64-bit - * architecture, an ordinary pointer can take up to 8 bytes, which is not acceptable overhead when - * dealing with 16-byte ObjectRef pairs. Based on a commonly noticed fact that the lists are - * relatively short (length <= 3) in hash maps, we follow [1]'s idea that only allows the pointer to - * be one of the 126 possible values, i.e. if the next element of i-th slot is (i + x)-th element, - * then x must be one of the 126 pre-defined values. - * - * A3. Data blocking. We organize the array in the way that every 16 elements forms a data block. - * The 16-byte metadata of those 16 elements are stored together, followed by the real data, i.e. - * 16 key-value pairs. - * - * B. Implementation details - * - * B1. Power-of-2 table size and Fibonacci Hashing. We use power-of-two as table size to avoid - * modulo for more efficient arithmetics. To make the hash-to-slot mapping distribute more evenly, - * we use the Fibonacci Hashing [2] trick. - * - * B2. 
Traverse a linked list in the array. - * 1) List head. Assume Fibonacci Hashing maps a given key to slot i, if metadata at slot i - * indicates that it is list head, then we found the head; otherwise the list is empty. No probing - * is done in this procedure. 2) Next element. To find the next element of a non-empty slot i, we - * look at the last 7 bits of the metadata at slot i. If they are all zeros, then it is the end of - * list; otherwise, we know that the next element is (i + candidates[the-last-7-bits]). - * - * B3. InsertMaybeReHash an element. Following B2, we first traverse the linked list to see if this - * element is in the linked list, and if not, we put it at the end by probing the next empty - * position in one of the 126 candidate positions. If the linked list does not even exist, but the - * slot for list head has been occupied by another linked list, we should find this intruder another - * place. - * - * B4. Quadratic probing with triangle numbers. In open address hashing, it is provable that probing - * with triangle numbers can traverse power-of-2-sized table [3]. In our algorithm, we follow the - * suggestion in [1] that also use triangle numbers for "next pointer" as well as sparing for list - * head. - * - * [1] https://github.com/skarupke/flat_hash_map - * [2] https://programmingpraxis.com/2018/06/19/fibonacci-hash/ - * [3] https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - */ -class DenseMapObj : public MapObj { -private: - /*! \brief The number of elements in a memory block */ - static constexpr int kBlockCap = 16; - /*! \brief Maximum load factor of the hash map */ - static constexpr double kMaxLoadFactor = 0.99; - /*! \brief Binary representation of the metadata of an empty slot */ - static constexpr uint8_t kEmptySlot = static_cast(0b11111111); - /*! \brief Binary representation of the metadata of a protected slot */ - static constexpr uint8_t kProtectedSlot = static_cast(0b11111110); - /*! 
\brief Number of probing choices available */ - static constexpr int kNumJumpDists = 126; - /*! \brief Index indicator to indicate an invalid index */ - static constexpr uint64_t kInvalidIndex = std::numeric_limits::max(); - /*! \brief Head of the implicit linked list */ - struct ListNode; - /*! \brief item type of the dense map, including a kv data and prev/next pointer */ - struct ItemType { - KVType data; - uint64_t prev = kInvalidIndex; - uint64_t next = kInvalidIndex; - - explicit ItemType(KVType &&data) : data(std::move(data)) {} - explicit ItemType(key_type key, mapped_type value) : data(std::move(key), std::move(value)) {} - }; - /*! \brief POD type of a block of memory */ - struct Block { - uint8_t bytes[kBlockCap + kBlockCap * sizeof(ItemType)]; - }; - static_assert(sizeof(Block) == kBlockCap * (sizeof(ItemType) + 1), "sizeof(Block) incorrect"); - static_assert(std::is_standard_layout_v, "Block is not standard layout"); - - /*! - * \brief Deleter for the Block - * \param data The pointer to the Block - */ - static void BlockDeleter(void *data) { delete[] static_cast(data); } - -public: - using MapObj::iterator; - - /*! - * \brief Return the number of usable slots for Dense layout (MSB clear => identity). - * \return The number of usable slots - */ - uint64_t NumSlots() const { return slots_; } - - /*! - * \brief Destroy the DenseMapObj - */ - ~DenseMapObj() { this->Reset(); } - /*! \return The number of elements of the key */ - size_t count(const key_type &key) const { return !Search(key).IsNone(); } - /*! - * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The const reference to the value - */ - const mapped_type &at(const key_type &key) const { return At(key); } - /*! 
- * \brief Index value associated with a key, throw exception if the key does not exist - * \param key The indexing key - * \return The mutable reference to the value - */ - mapped_type &at(const key_type &key) { return At(key); } - /*! - * \brief Index value associated with a key - * \param key The indexing key - * \return The iterator of the entry associated with the key, end iterator if not exists - */ - iterator find(const key_type &key) const { - ListNode node = Search(key); - return node.IsNone() ? end() : iterator(node.index, this); - } - /*! - * \brief Erase the entry associated with the iterator - * \param position The iterator - */ - void erase(const iterator &position) { - uint64_t index = position.index; - if (position.self != nullptr && index <= this->NumSlots()) { - Erase(ListNode(index, this)); - } - } - /*! \return begin iterator */ - iterator begin() const { return iterator(iter_list_head_, this); } - /*! \return end iterator */ - iterator end() const { return iterator(kInvalidIndex, this); } - -private: - Block *GetBlock(size_t index) const { return static_cast(data_) + index; } - /*! - * \brief Unlink the entry from iterator list - * \param node The node to be unlinked - * \note This function is usually used before deletion, - * and it does not change data content of the node. - */ - void IterListUnlink(ListNode node) { - // update head and tail of iterator list if needed - if (node.Item().prev == kInvalidIndex) { - iter_list_head_ = node.Item().next; - } else { - ListNode prev_node(node.Item().prev, this); - prev_node.Item().next = node.Item().next; - } - if (node.Item().next == kInvalidIndex) { - iter_list_tail_ = node.Item().prev; - } else { - ListNode next_node(node.Item().next, this); - next_node.Item().prev = node.Item().prev; - } - } - /*! - * \brief Insert the entry into tail of iterator list - * \param node The node to be inserted - * \note this function does not change data content of the node. 
- */ - void IterListPushBack(ListNode node) { - node.Item().prev = iter_list_tail_; - node.Item().next = kInvalidIndex; - if (iter_list_tail_ != kInvalidIndex) { - ListNode prev_node(iter_list_tail_, this); - prev_node.Item().next = node.index; - } - if (iter_list_head_ == kInvalidIndex) { - iter_list_head_ = node.index; - } - iter_list_tail_ = node.index; - } - /*! - * \brief Replace node src by dst in the iter list - * \param src The source node - * \param dst The destination node, must be empty - * \note This function does not change data content of the nodes, - * which needs to be updated by the caller. - */ - void IterListReplaceNodeBy(ListNode src, ListNode dst) { - // set link correctly on the dst - dst.Item().prev = src.Item().prev; - dst.Item().next = src.Item().next; - // update prev and next of dst - if (dst.Item().prev == kInvalidIndex) { - iter_list_head_ = dst.index; - } else { - ListNode prev_node(dst.Item().prev, this); - prev_node.Item().next = dst.index; - } - if (dst.Item().next == kInvalidIndex) { - iter_list_tail_ = dst.index; - } else { - ListNode next_node(dst.Item().next, this); - next_node.Item().prev = dst.index; - } - } - /*! - * \brief Search for the given key - * \param key The key - * \return ListNode that associated with the key - */ - ListNode Search(const key_type &key) const { - if (this->size_ == 0) { - return ListNode(); - } - for (ListNode iter = GetListHead(AnyHash()(key)); !iter.IsNone(); iter.MoveToNext(this)) { - if (AnyEqual()(key, iter.Key())) { - return iter; - } - } - return ListNode(); - } - /*! - * \brief Search for the given key, throw exception if not exists - * \param key The key - * \return ListNode that associated with the key - */ - mapped_type &At(const key_type &key) const { - ListNode iter = Search(key); - if (iter.IsNone()) { - TVM_FFI_THROW(KeyError) << "key is not in Map"; - } - return iter.Val(); - } - /*! 
- * \brief Try to insert a key, or do nothing if already exists - * \param key The indexing key - * \param result The linked-list entry found or just constructed - * \return A boolean, indicating if actual insertion happens - */ - bool TryInsert(const key_type &key, ListNode *result) { - if (slots_ == 0) { - return false; - } - // required that `iter` to be the head of a linked list through which we can iterator - ListNode iter = IndexFromHash(AnyHash()(key)); - // `iter` can be: 1) empty; 2) body of an irrelevant list; 3) head of the relevant list - // Case 1: empty - if (iter.IsEmpty()) { - iter.NewHead(ItemType(key, Any(nullptr))); - this->size_ += 1; - *result = iter; - return true; - } - // Case 2: body of an irrelevant list - if (!iter.IsHead()) { - // we move the elements around and construct the single-element linked list - return IsFull() ? false : TrySpareListHead(iter, key, result); - } - // Case 3: head of the relevant list - // we iterate through the linked list until the end - // make sure `iter` is the previous element of `next` - ListNode next = iter; - do { - // find equal item, do not insert - if (AnyEqual()(key, next.Key())) { - // we plan to take next, so we need to unlink it from iterator list - IterListUnlink(next); - *result = next; - return true; - } - // make sure `iter` is the previous element of `next` - iter = next; - } while (next.MoveToNext(this)); - // `iter` is the tail of the linked list - // always check capacity before insertion - if (IsFull()) { - return false; - } - // find the next empty slot - uint8_t jump; - if (!iter.GetNextEmpty(this, &jump, result)) { - return false; - } - result->NewTail(ItemType(key, Any(nullptr))); - // link `iter` to `empty`, and move forward - iter.SetJump(jump); - this->size_ += 1; - return true; - } - /*! - * \brief Spare an entry to be the head of a linked list. 
- * As described in B3, during insertion, it is possible that the entire linked list does not - * exist, but the slot of its head has been occupied by other linked lists. In this case, we need - * to spare the slot by moving away the elements to another valid empty one to make insertion - * possible. - * \param target The given entry to be spared - * \param key The indexing key - * \param result The linked-list entry constructed as the head - * \return A boolean, if actual insertion happens - */ - bool TrySpareListHead(ListNode target, const key_type &key, ListNode *result) { - // `target` is not the head of the linked list - // move the original item of `target` (if any) - // and construct new item on the position `target` - // To make `target` empty, we - // 1) find `w` the previous element of `target` in the linked list - // 2) copy the linked list starting from `r = target` - // 3) paste them after `w` - // read from the linked list after `r` - ListNode r = target; - // write to the tail of `w` - ListNode w = target.FindPrev(this); - // after `target` is moved, we disallow writing to the slot - bool is_first = true; - uint8_t r_meta, jump; - ListNode empty; - do { - // `jump` describes how `w` is jumped to `empty` - // rehash if there is no empty space after `w` - if (!w.GetNextEmpty(this, &jump, &empty)) { - return false; - } - // move `r` to `empty` - // first move the data over - empty.NewTail(ItemType(std::move(r.Data()))); - // then move link list chain of r to empty - // this needs to happen after NewTail so empty's prev/next get updated - IterListReplaceNodeBy(r, empty); - // explicit call destructor to destroy the item in `r` - r.DestructData(); - // clear the metadata of `r` - r_meta = r.Meta(); - if (is_first) { - is_first = false; - r.SetProtected(); - } else { - r.SetEmpty(); - } - // link `w` to `empty`, and move forward - w.SetJump(jump); - w = empty; - // move `r` forward as well - } while (r.MoveToNext(this, r_meta)); - // finally we have done 
moving the linked list - // fill data_ into `target` - target.NewHead(ItemType(key, Any(nullptr))); - this->size_ += 1; - *result = target; - return true; - } - /*! - * \brief Remove a ListNode - * \param iter The node to be removed - */ - void Erase(const ListNode &iter) { - this->size_ -= 1; - if (!iter.HasNext()) { - // `iter` is the last - if (!iter.IsHead()) { - // cut the link if there is any - iter.FindPrev(this).SetJump(0); - } - // unlink the node from iterator list - IterListUnlink(iter); - // IMPORTANT: must explicit call destructor `iter` to avoid memory leak - // This is because we need to recycle iter's data - iter.DestructData(); - // set the meta data to be empty - iter.SetEmpty(); - } else { - ListNode last = iter, prev = iter; - for (last.MoveToNext(this); last.HasNext(); prev = last, last.MoveToNext(this)) { - } - // needs to first unlink iter from the list - IterListUnlink(iter); - // move data from last to iter - iter.Data() = std::move(last.Data()); - // Move link chain of iter to last as we stores last node to the new iter loc. - IterListReplaceNodeBy(last, iter); - // IMPORTANT: must explicit call destructor `last` to avoid memory leak - // likely we don't need this in this particular case because Any move behavior - // keep it here to be safe so code do not depend on specific move behavior of KVType - last.DestructData(); - // set the meta data to be empty - last.SetEmpty(); - prev.SetJump(0); - } - } - /*! 
\brief Clear the container to empty, release all entries and memory acquired */ - void Reset() { - uint64_t n_blocks = CalcNumBlocks(this->NumSlots()); - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t *meta_ptr = GetBlock(bi)->bytes; - ItemType *data_ptr = reinterpret_cast(GetBlock(bi)->bytes + kBlockCap); - for (int j = 0; j < kBlockCap; ++j, ++meta_ptr, ++data_ptr) { - uint8_t &meta = *meta_ptr; - if (meta != kProtectedSlot && meta != kEmptySlot) { - meta = kEmptySlot; - data_ptr->ItemType::~ItemType(); - } - } - } - ReleaseMemory(); - } - /*! \brief Release the memory acquired by the container without deleting its entries stored inside - */ - void ReleaseMemory() { - if (data_ != nullptr) { - TVM_FFI_ICHECK(data_deleter_ != nullptr); - data_deleter_(data_); - } - data_ = nullptr; - data_deleter_ = nullptr; - slots_ = 0; - size_ = 0; - fib_shift_ = 63; - } - /*! - * \brief Create an empty container - * \param fib_shift The fib shift provided - * \param n_slots Number of slots required, should be power-of-two - * \return The object created - */ - static ObjectPtr Empty(uint32_t fib_shift, uint64_t n_slots) { - TVM_FFI_ICHECK_GT(n_slots, uint64_t(SmallMapObj::kMaxSize)); - // Ensure even slot count (power-of-two expected by callers; this guard - // makes the method robust if a non-even value slips through). - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(n_slots); - Block *block = new Block[n_blocks]; - p->data_ = block; - // assign block deleter so even if we take re-alloc data - // in another shared-lib that may have different malloc/free behavior - // it will still be safe. - p->data_deleter_ = BlockDeleter; - p->SetSlotsAndDenseLayoutTag(n_slots); - p->size_ = 0; - p->fib_shift_ = fib_shift; - p->iter_list_head_ = kInvalidIndex; - p->iter_list_tail_ = kInvalidIndex; - for (uint64_t i = 0; i < n_blocks; ++i, ++block) { - std::fill(block->bytes, block->bytes + kBlockCap, kEmptySlot); - } - return p; - } - /*! 
- * \brief Create an empty container with elements copying from another DenseMapObj - * \param from The source container - * \return The object created - */ - static ObjectPtr CopyFrom(DenseMapObj *from) { - ObjectPtr p = make_object(); - uint64_t n_blocks = CalcNumBlocks(from->NumSlots()); - p->data_ = new Block[n_blocks]; - // assign block deleter so even if we take re-alloc data - // in another shared-lib that may have different malloc/free behavior - // it will still be safe. - p->data_deleter_ = BlockDeleter; - p->SetSlotsAndDenseLayoutTag(from->NumSlots()); - p->size_ = from->size_; - p->fib_shift_ = from->fib_shift_; - p->iter_list_head_ = from->iter_list_head_; - p->iter_list_tail_ = from->iter_list_tail_; - for (uint64_t bi = 0; bi < n_blocks; ++bi) { - uint8_t *meta_ptr_from = from->GetBlock(bi)->bytes; - ItemType *data_ptr_from = reinterpret_cast(from->GetBlock(bi)->bytes + kBlockCap); - uint8_t *meta_ptr_to = p->GetBlock(bi)->bytes; - ItemType *data_ptr_to = reinterpret_cast(p->GetBlock(bi)->bytes + kBlockCap); - for (int j = 0; j < kBlockCap; - ++j, ++meta_ptr_from, ++data_ptr_from, ++meta_ptr_to, ++data_ptr_to) { - uint8_t &meta = *meta_ptr_to = *meta_ptr_from; - TVM_FFI_ICHECK(meta != kProtectedSlot); - if (meta != kEmptySlot) { - new (data_ptr_to) ItemType(*data_ptr_from); - } - } - } - return p; - } - /*! - * \brief InsertMaybeReHash an entry into the given hash map - * \param kv The entry to be inserted - * \param map The pointer to the map, can be changed if re-hashing happens - */ - static void InsertMaybeReHash(KVType &&kv, ObjectPtr *map) { - DenseMapObj *map_node = static_cast(map->get()); - ListNode iter; - // Try to insert. 
If succeed, we simply return - if (map_node->TryInsert(kv.first, &iter)) { - iter.Val() = std::move(kv.second); - // update the iter list relation - map_node->IterListPushBack(iter); - return; - } - TVM_FFI_ICHECK(!map_node->IsSmallMap()); - // Otherwise, start rehash - ObjectPtr p = Empty(map_node->fib_shift_ - 1, map_node->NumSlots() * 2); - - // need to insert in the same order as the original map - for (uint64_t index = map_node->iter_list_head_; index != kInvalidIndex;) { - ListNode node(index, map_node); - // now try move src_data into the new map, note that src may still not - // be fully consumed into the call, but destructor will be called. - InsertMaybeReHash(std::move(node.Data()), &p); - // Important, needs to explicit call destructor in case move did remove - // node's internal item - index = node.Item().next; - // IMPORTANT: must explicit call destructor `node` to avoid memory leak - // We must call node.DestructData() here. - // This is because std::move() arguments in IterMaybeReHash may or may not - // explicitly move out the node.Data() - // Remove this call will cause memory leak very likely. - node.DestructData(); - } - InsertMaybeReHash(std::move(kv), &p); - map_node->ReleaseMemory(); - *map = p; - } - /*! - * \brief Check whether the hash table is full - * \return A boolean indicating whether hash table is full - */ - bool IsFull() const { // NOLINTNEXTLINE(bugprone-narrowing-conversions) - return (size_ + 1) > static_cast(NumSlots()) * kMaxLoadFactor; - } - /*! - * \brief Increment the pointer - * \param index The pointer to be incremented - * \return The increased pointer - */ - uint64_t IncItr(uint64_t index) const { - // keep at the end of iterator - if (index == kInvalidIndex) { - return index; - } - ListNode node(index, this); - return node.Item().next; - } - /*! 
- * \brief Decrement the pointer - * \param index The pointer to be decremented - * \return The decreased pointer - */ - uint64_t DecItr(uint64_t index) const { - // this is the end iterator, we need to return tail. - if (index == kInvalidIndex) { - return iter_list_tail_; - } - // circle around the iterator list, which is OK - ListNode node(index, this); - return node.Item().prev; - } - /*! - * \brief De-reference the pointer - * \param index The pointer to be dereferenced - * \return The result - */ - KVType *DeRefItr(uint64_t index) const { return &ListNode(index, this).Data(); } - /*! \brief Construct from hash code */ - ListNode IndexFromHash(uint64_t hash_value) const { - return ListNode(FibHash(hash_value, fib_shift_), this); - } - /*! \brief Construct from hash code if the position is head of list */ - ListNode GetListHead(uint64_t hash_value) const { - ListNode node = IndexFromHash(hash_value); - return node.IsHead() ? node : ListNode(); - } - /*! \brief Construct the number of blocks in the hash table */ - static uint64_t CalcNumBlocks(uint64_t n_slots) { return (n_slots + kBlockCap - 1) / kBlockCap; } - /*! - * \brief Calculate the power-of-2 table size given the lower-bound of required capacity. - * \param cap The lower-bound of the required capacity - * \param fib_shift The result shift for Fibonacci Hashing - * \param n_slots The result number of slots - */ - static void CalcTableSize(uint64_t cap, uint32_t *fib_shift, uint64_t *n_slots) { - uint32_t shift = 64; - uint64_t slots = 1; - for (uint64_t c = cap; c; c >>= 1) { - shift -= 1; - slots <<= 1; - } - TVM_FFI_ICHECK_GT(slots, cap); - if (slots < cap * 2) { - *fib_shift = shift - 1; - *n_slots = slots << 1; - } else { - *fib_shift = shift; - *n_slots = slots; - } - } - /*! - * \brief Fibonacci Hashing, maps a hash code to an index in a power-of-2-sized table. - * See also: https://programmingpraxis.com/2018/06/19/fibonacci-hash/. 
- * \param hash_value The raw hash value - * \param fib_shift The shift in Fibonacci Hashing - * \return An index calculated using Fibonacci Hashing - */ - static uint64_t FibHash(uint64_t hash_value, uint32_t fib_shift) { - constexpr uint64_t coeff = 11400714819323198485ull; - return (coeff * hash_value) >> fib_shift; - } - /*! \brief The implicit in-place linked list used to index a chain */ - struct ListNode { - /*! \brief Construct None */ - ListNode() : index(0), block(nullptr) {} - /*! \brief Construct from position */ - ListNode(uint64_t index, const DenseMapObj *self) - : index(index), block(self->GetBlock(index / kBlockCap)) {} - /*! \brief Metadata on the entry */ - uint8_t &Meta() const { return *(block->bytes + index % kBlockCap); } - /*! \brief Data on the entry */ - ItemType &Item() const { - return *(reinterpret_cast(block->bytes + kBlockCap + (index % kBlockCap) * sizeof(ItemType))); - } - /*! \brief Data on the entry */ - KVType &Data() const { return Item().data; } - /*! \brief Key on the entry */ - key_type &Key() const { return Data().first; } - /*! \brief Value on the entry */ - mapped_type &Val() const { return Data().second; } - /*! \brief If the entry is head of linked list */ - bool IsHead() const { return (Meta() & 0b10000000) == 0b00000000; } - /*! \brief If the entry is none */ - bool IsNone() const { return block == nullptr; } - /*! \brief If the entry is empty slot */ - bool IsEmpty() const { return Meta() == kEmptySlot; } - /*! \brief If the entry is protected slot */ - bool IsProtected() const { return Meta() == kProtectedSlot; } - /*! \brief Set the entry to be empty */ - void SetEmpty() const { Meta() = kEmptySlot; } - /*! \brief Destruct the item in the entry */ - void DestructData() const { - // explicit call destructor to destroy the item - // Favor this over ~KVType as MSVC may not support ~KVType (need the original name) - (&Data())->first.Any::~Any(); - (&Data())->second.Any::~Any(); - } - /*! 
\brief Set the entry to be protected */ - void SetProtected() const { Meta() = kProtectedSlot; } - /*! \brief Set the entry's jump to its next entry */ - void SetJump(uint8_t jump) const { (Meta() &= 0b10000000) |= jump; } - /*! \brief Construct a head of linked list in-place */ - void NewHead(ItemType v) const { - Meta() = 0b00000000; - new (&Item()) ItemType(std::move(v)); - } - /*! \brief Construct a tail of linked list in-place */ - void NewTail(ItemType v) const { - Meta() = 0b10000000; - new (&Item()) ItemType(std::move(v)); - } - - /*! \brief If the entry has next entry on the linked list */ - bool HasNext() const { return NextProbeLocation(Meta() & 0b01111111) != 0; } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapObj *self, uint8_t meta) { - uint64_t offset = NextProbeLocation(meta & 0b01111111); - if (offset == 0) { - index = 0; - block = nullptr; - return false; - } - // the probing will go to next position and round back to stay within the - // correct range of the slots - index = (index + offset) % self->NumSlots(); - block = self->GetBlock(index / kBlockCap); - return true; - } - /*! \brief Move the entry to the next entry on the linked list */ - bool MoveToNext(const DenseMapObj *self) { return MoveToNext(self, Meta()); } - /*! \brief Get the previous entry on the linked list */ - ListNode FindPrev(const DenseMapObj *self) const { - // start from the head of the linked list, which must exist - ListNode next = self->IndexFromHash(AnyHash()(Key())); - // `prev` is always the previous item of `next` - ListNode prev = next; - for (next.MoveToNext(self); index != next.index; prev = next, next.MoveToNext(self)) { - } - return prev; - } - /*! 
\brief Get the next empty jump */ - bool GetNextEmpty(const DenseMapObj *self, uint8_t *jump, ListNode *result) const { - for (uint8_t idx = 1; idx < kNumJumpDists; ++idx) { - // the probing will go to next position and round back to stay within the - // correct range of the slots - ListNode candidate((index + NextProbeLocation(idx)) % self->NumSlots(), self); - if (candidate.IsEmpty()) { - *jump = idx; - *result = candidate; - return true; - } - } - return false; - } - /*! \brief Index on the real array */ - uint64_t index; - /*! \brief Pointer to the actual block */ - Block *block; - }; - -protected: - /*! \brief fib shift in Fibonacci Hashing */ - uint32_t fib_shift_; - /*! \brief the head of iterator list */ - uint64_t iter_list_head_ = kInvalidIndex; - /*! \brief the tail of iterator list */ - uint64_t iter_list_tail_ = kInvalidIndex; - - static uint64_t NextProbeLocation(size_t index) { - /* clang-format off */ - /*! \brief Candidates of probing distance */ - static const uint64_t kNextProbeLocation[kNumJumpDists] { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - // Quadratic probing with triangle numbers. 
See also: - // 1) https://en.wikipedia.org/wiki/Quadratic_probing - // 2) https://fgiesen.wordpress.com/2015/02/22/triangular-numbers-mod-2n/ - // 3) https://github.com/skarupke/flat_hash_map - 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, - 136, 153, 171, 190, 210, 231, 253, 276, 300, 325, - 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, - 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, - 1081, 1128, 1176, 1225, 1275, 1326, 1378, 1431, 1485, 1540, - 1596, 1653, 1711, 1770, 1830, 1891, 1953, 2016, 2080, 2145, - 2211, 2278, 2346, 2415, 2485, 2556, 2628, - // larger triangle numbers - 8515, 19110, 42778, 96141, 216153, - 486591, 1092981, 2458653, 5532801, 12442566, - 27993903, 62983476, 141717030, 318844378, 717352503, - 1614057336, 3631522476, 8170957530, 18384510628, 41364789378, - 93070452520, 209408356380, 471168559170, 1060128894105, 2385289465695, - 5366898840628, 12075518705635, 27169915244790, 61132312065111, 137547689707000, - 309482283181501, 696335127828753, 1566753995631385, 3525196511162271, 7931691992677701, - 17846306936293605, 40154190677507445, 90346928918121501, 203280589587557251, - 457381325854679626, 1029107982097042876, 2315492959180353330, 5209859154120846435, - }; - /* clang-format on */ - return kNextProbeLocation[index]; - } - friend class MapObj; - -private: - /*! - * \brief Set the number of slots and attach tags bit. 
- * \param n The number of slots - */ - void SetSlotsAndDenseLayoutTag(uint64_t n) { - TVM_FFI_ICHECK(((n & kSmallTagMask) == 0ull)) << "DenseMap expects MSB clear"; - slots_ = n; - } -}; - -/// \cond -#define TVM_FFI_DISPATCH_MAP(base, var, body) \ - { \ - using TSmall = SmallMapObj *; \ - using TDense = DenseMapObj *; \ - if ((base)->IsSmallMap()) { \ - TSmall var = static_cast((base)); \ - body; \ - } else { \ - TDense var = static_cast((base)); \ - body; \ - } \ - } - -#define TVM_FFI_DISPATCH_MAP_CONST(base, var, body) \ - { \ - using TSmall = const SmallMapObj *; \ - using TDense = const DenseMapObj *; \ - if ((base)->IsSmallMap()) { \ - TSmall var = static_cast((base)); \ - body; \ - } else { \ - TDense var = static_cast((base)); \ - body; \ - } \ - } - -inline MapObj::iterator::pointer MapObj::iterator::operator->() const { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { return p->DeRefItr(index); }); -} - -inline MapObj::iterator &MapObj::iterator::operator++() { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { - index = p->IncItr(index); - return *this; - }); -} - -inline MapObj::iterator &MapObj::iterator::operator--() { - TVM_FFI_MAP_FAIL_IF_CHANGED() - TVM_FFI_DISPATCH_MAP_CONST(self, p, { - index = p->DecItr(index); - return *this; - }); -} - -inline size_t MapObj::count(const key_type &key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->count(key); }); -} - -inline const MapObj::mapped_type &MapObj::at(const MapObj::key_type &key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->at(key); }); -} - -inline MapObj::mapped_type &MapObj::at(const MapObj::key_type &key) { - TVM_FFI_DISPATCH_MAP(this, p, { return p->at(key); }); -} - -inline MapObj::iterator MapObj::begin() const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->begin(); }); -} - -inline MapObj::iterator MapObj::end() const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->end(); }); -} - -inline MapObj::iterator 
MapObj::find(const MapObj::key_type &key) const { - TVM_FFI_DISPATCH_MAP_CONST(this, p, { return p->find(key); }); -} - -inline void MapObj::erase(const MapObj::iterator &position) { - TVM_FFI_DISPATCH_MAP(this, p, { return p->erase(position); }); -} -/// \endcond - -#undef TVM_FFI_DISPATCH_MAP -#undef TVM_FFI_DISPATCH_MAP_CONST - -inline ObjectPtr MapObj::Empty() { return SmallMapObj::Empty(); } - -inline ObjectPtr MapObj::CopyFrom(MapObj *from) { - if (from->IsSmallMap()) { - return SmallMapObj::CopyFrom(static_cast(from)); - } else { - return DenseMapObj::CopyFrom(static_cast(from)); - } -} - -template -inline ObjectPtr MapObj::CreateFromRange(IterType first, IterType last) { - int64_t _cap = std::distance(first, last); - if (_cap < 0) { - return SmallMapObj::Empty(); - } - uint64_t cap = static_cast(_cap); - if (cap < SmallMapObj::kMaxSize) { - if (cap < 2) { - return SmallMapObj::CreateFromRange(cap, first, last); - } - // need to insert to avoid duplicate keys - ObjectPtr obj = SmallMapObj::Empty(cap); - for (; first != last; ++first) { - KVType kv(*first); - SmallMapObj::InsertMaybeReHash(std::move(kv), &obj); - } - return obj; - } else { - uint32_t fib_shift; - uint64_t n_slots; - DenseMapObj::CalcTableSize(cap, &fib_shift, &n_slots); - ObjectPtr obj = DenseMapObj::Empty(fib_shift, n_slots); - for (; first != last; ++first) { - KVType kv(*first); - DenseMapObj::InsertMaybeReHash(std::move(kv), &obj); - } - return obj; - } -} - -inline void MapObj::InsertMaybeReHash(KVType &&kv, ObjectPtr *map) { - MapObj *base = static_cast(map->get()); -#if TVM_FFI_DEBUG_WITH_ABI_CHANGE - base->state_marker++; -#endif // TVM_FFI_DEBUG_WITH_ABI_CHANGE - if (base->IsSmallMap()) { - SmallMapObj *sm = static_cast(base); - if (sm->NumSlots() < SmallMapObj::kMaxSize) { - SmallMapObj::InsertMaybeReHash(std::move(kv), map); - } else if (sm->NumSlots() == SmallMapObj::kMaxSize) { - if (base->size_ < sm->NumSlots()) { - SmallMapObj::InsertMaybeReHash(std::move(kv), map); - } else { 
- ObjectPtr new_map = MapObj::CreateFromRange(base->begin(), base->end()); - DenseMapObj::InsertMaybeReHash(std::move(kv), &new_map); - *map = std::move(new_map); - } - } - } else { - DenseMapObj::InsertMaybeReHash(std::move(kv), map); - } -} - -/// \cond Doxygen_Suppress -/*! - * \brief Specialize make_object to be deleted for make_object and - * make_object only. - */ -template <> -inline ObjectPtr make_object<>() = delete; -/// \endcond - -/*! - * \brief Map container of NodeRef->NodeRef in DSL graph. - * Map implements copy on write semantics, which means map is mutable - * but copy will happen when array is referenced in more than two places. - * - * operator[] only provide const acces, use Set to mutate the content. - * \tparam K The key NodeRef type. - * \tparam V The value NodeRef type. - */ -template && details::storage_enabled_v>> -class Map : public ObjectRef { -public: - /*! \brief The key type of the map */ - using key_type = K; - /*! \brief The mapped type of the map */ - using mapped_type = V; - /*! \brief The iterator type of the map */ - class iterator; - /*! - * \brief Construct an Map with UnsafeInit - */ - explicit Map(UnsafeInit tag) : ObjectRef(tag) {} - /*! - * \brief default constructor - */ - Map() { data_ = MapObj::Empty(); } - /*! - * \brief move constructor - * \param other source - */ - Map(Map &&other) // NOLINT(google-explicit-constructor) - : ObjectRef(std::move(other.data_)) {} - /*! - * \brief copy constructor - * \param other source - */ - Map(const Map &other) // NOLINT(google-explicit-constructor) - : ObjectRef(other.data_) {} - - /*! - * \brief Move constructor - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && details::type_contains_v>> - Map(Map &&other) // NOLINT(google-explicit-constructor) - : ObjectRef(std::move(other.data_)) {} - - /*! 
- * \brief Copy constructor - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && details::type_contains_v>> - Map(const Map &other) : ObjectRef(other.data_) {} // NOLINT(google-explicit-constructor) - - /*! - * \brief Move assignment - * \param other The other map - */ - Map &operator=(Map &&other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Copy assignment - * \param other The other map - */ - Map &operator=(const Map &other) { - data_ = other.data_; - return *this; - } - - /*! - * \brief Move assignment - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && details::type_contains_v>> - Map &operator=(Map &&other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Copy assignment - * \param other The other map - * \tparam KU The key type of the other map - * \tparam VU The mapped type of the other map - */ - template && details::type_contains_v>> - Map &operator=(const Map &other) { - data_ = other.data_; - return *this; - } - /*! - * \brief constructor from pointer - * \param n the container pointer - */ - explicit Map(ObjectPtr n) : ObjectRef(n) {} - /*! - * \brief constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Map(IterType begin, IterType end) { - data_ = MapObj::CreateFromRange(begin, end); - } - /*! - * \brief constructor from initializer list - * \param init The initalizer list - */ - Map(std::initializer_list> init) { - data_ = MapObj::CreateFromRange(init.begin(), init.end()); - } - /*! - * \brief constructor from unordered_map - * \param init The unordered_map - */ - template - Map(const std::unordered_map &init) { // NOLINT(*) - data_ = MapObj::CreateFromRange(init.begin(), init.end()); - } - /*! 
- * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V at(const K &key) const { - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(GetMapObj()->at(key)); - } - /*! - * \brief Read element from map. - * \param key The key - * \return the corresonding element. - */ - const V operator[](const K &key) const { return this->at(key); } - /*! \return The size of the array */ - size_t size() const { - MapObj *n = GetMapObj(); - return n == nullptr ? 0 : n->size(); - } - /*! \return The number of elements of the key */ - size_t count(const K &key) const { - MapObj *n = GetMapObj(); - return n == nullptr ? 0 : GetMapObj()->count(key); - } - /*! \return whether array is empty */ - bool empty() const { return size() == 0; } - /*! \brief Release reference to all the elements */ - void clear() { - MapObj *n = GetMapObj(); - if (n != nullptr) { - data_ = MapObj::Empty(); - } - } - /*! - * \brief set the Map. - * \param key The index key. - * \param value The value to be setted. - */ - void Set(const K &key, const V &value) { - CopyOnWrite(); - MapObj::InsertMaybeReHash(MapObj::KVType(key, value), &data_); - } - /*! \return begin iterator */ - iterator begin() const { return iterator(GetMapObj()->begin()); } - /*! \return end iterator */ - iterator end() const { return iterator(GetMapObj()->end()); } - /*! \return find the key and returns the associated iterator */ - iterator find(const K &key) const { return iterator(GetMapObj()->find(key)); } - /*! \return The value associated with the key, std::nullopt if not found */ - std::optional Get(const K &key) const { - MapObj::iterator iter = GetMapObj()->find(key); - if (iter == GetMapObj()->end()) { - return std::nullopt; - } - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(iter->second); - } - - /*! - * \brief Erase the entry associated with the key - * \param key The key - */ - void erase(const K &key) { CopyOnWrite()->erase(key); } - - /*! 
- * \brief copy on write semantics - * Do nothing if current handle is the unique copy of the array. - * Otherwise make a new copy of the array to ensure the current handle - * hold a unique copy. - * - * \return Handle to the internal node container(which guarantees to be unique) - */ - MapObj *CopyOnWrite() { - if (data_.get() == nullptr) { - data_ = MapObj::Empty(); - } else if (!data_.unique()) { - data_ = MapObj::CopyFrom(GetMapObj()); - } - return GetMapObj(); - } - /*! \brief specify container node */ - using ContainerType = MapObj; - - /// \cond Doxygen_Suppress - /*! \brief Iterator of the hash map */ - class iterator { - public: - using iterator_category = std::bidirectional_iterator_tag; - using difference_type = int64_t; - using value_type = const std::pair; - using pointer = value_type *; - using reference = value_type; - - iterator() : itr() {} - - /*! \brief Compare iterators */ - bool operator==(const iterator &other) const { return itr == other.itr; } - /*! \brief Compare iterators */ - bool operator!=(const iterator &other) const { return itr != other.itr; } - /*! \brief De-reference iterators is not allowed */ - pointer operator->() const = delete; - /*! \brief De-reference iterators */ - reference operator*() const { - auto &kv = *itr; - return std::make_pair(details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.first), - details::AnyUnsafe::CopyFromAnyViewAfterCheck(kv.second)); - } - /*! \brief Prefix self increment, e.g. ++iter */ - iterator &operator++() { - ++itr; - return *this; - } - /*! \brief Suffix self increment */ - iterator operator++(int) { - iterator copy = *this; - ++(*this); - return copy; - } - - /*! \brief Prefix self decrement, e.g. --iter */ - iterator &operator--() { - --itr; - return *this; - } - /*! 
\brief Suffix self decrement */ - iterator operator--(int) { - iterator copy = *this; - --(*this); - return copy; - } - - private: - iterator(const MapObj::iterator &itr) // NOLINT(*) - : itr(itr) {} - - template - friend class Map; - - MapObj::iterator itr; - }; - /// \endcond - -private: - /*! \brief Return data_ as type of pointer of MapObj */ - MapObj *GetMapObj() const { return static_cast(data_.get()); } - - template - friend class Map; -}; - -/*! - * \brief Merge two Maps. - * \param lhs the first Map to merge. - * \param rhs the second Map to merge. - * @return The merged Array. Original Maps are kept unchanged. - */ -template && details::storage_enabled_v>> -inline Map Merge(Map lhs, const Map &rhs) { - for (const auto &p : rhs) { - lhs.Set(p.first, p.second); - } - return std::move(lhs); -} - -// Traits for Map -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public ObjectRefTypeTraitsBase> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIMap; - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { - if (src->type_index != TypeIndex::kTVMFFIMap) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - if constexpr (!std::is_same_v || !std::is_same_v) { - const MapObj *n = reinterpret_cast(src->v_obj); - for (const auto &kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first) && !kv.first.try_cast().has_value()) { - return "Map[some key is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.first) + ", V]"; - } - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second) && !kv.second.try_cast().has_value()) { - return "Map[K, some value is " + details::AnyUnsafe::GetMismatchTypeInfo(kv.second) + "]"; - } - } - } - } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static 
bool CheckAnyStrict(const TVMFFIAny *src) { - if (src->type_index != TypeIndex::kTVMFFIMap) { - return false; - } - if constexpr (std::is_same_v && std::is_same_v) { - return true; - } else { - const MapObj *n = reinterpret_cast(src->v_obj); - for (const auto &kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) { - return false; - } - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) { - return false; - } - } - } - return true; - } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index != TypeIndex::kTVMFFIMap) { - return std::nullopt; - } - if constexpr (!std::is_same_v || !std::is_same_v) { - const MapObj *n = reinterpret_cast(src->v_obj); - bool storage_check = [&]() { - for (const auto &kv : *n) { - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.first)) { - return false; - } - } - if constexpr (!std::is_same_v) { - if (!details::AnyUnsafe::CheckAnyStrict(kv.second)) { - return false; - } - } - } - return true; - }(); - // fast path, if storage check passes, we can return the array directly. - if (storage_check) { - return CopyFromAnyViewAfterCheck(src); - } - // slow path, we need to create a new map and convert to the target type. 
- Map ret; - for (const auto &kv : *n) { - auto k = kv.first.try_cast(); - auto v = kv.second.try_cast(); - if (!k.has_value() || !v.has_value()) { - return std::nullopt; - } - ret.Set(*std::move(k), *std::move(v)); - } - return ret; - } else { - return CopyFromAnyViewAfterCheck(src); - } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "Map<" + details::Type2Str::v() + ", " + details::Type2Str::v() + ">"; - } - TVM_FFI_INLINE static std::string TypeSchema() { - std::ostringstream oss; - oss << R"({"type":")" << StaticTypeKey::kTVMFFIMap << R"(","args":[)"; - oss << details::TypeSchema::v() << ","; - oss << details::TypeSchema::v(); - oss << "]}"; - return oss.str(); - } -}; - -namespace details { -template -inline constexpr bool type_contains_v, Map> = type_contains_v && type_contains_v; -} // namespace details - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_MAP_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/shape.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/shape.h deleted file mode 100644 index 074fefbc9..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/shape.h +++ /dev/null @@ -1,343 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/shape.h - * \brief Container to store shape of an Tensor. - */ -#ifndef TVM_FFI_CONTAINER_SHAPE_H_ -#define TVM_FFI_CONTAINER_SHAPE_H_ - -#include "../error.h" -#include "../type_traits.h" -#include "array.h" - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Lightweight view non-owning class for shape. - */ -class ShapeView { -public: - /*! \brief Default constructor. */ - ShapeView() : cell_{nullptr, 0} {} - /*! \brief Copy constructor. */ - ShapeView(const ShapeView &other) = default; - /*! \brief Copy assignment operator. */ - ShapeView &operator=(const ShapeView &other) = default; - /*! \brief Move constructor. */ - ShapeView(ShapeView &&other) = default; - /*! \brief Move assignment operator. */ - ShapeView &operator=(ShapeView &&other) = default; - /*! \brief Constructor from data and size. */ - ShapeView(const int64_t *data, size_t size) : cell_{data, size} {} - /*! \brief Constructor from initializer list. */ - ShapeView(const std::initializer_list &other) : cell_{other.begin(), other.size()} {} - /*! \brief Get the data pointer. */ - const int64_t *data() const { return cell_.data; } - /*! \brief Get the size of the shape. */ - size_t size() const { return cell_.size; } - - /*! \brief Get the product of the shape. */ - int64_t Product() const { - int64_t product = 1; - for (size_t i = 0; i < cell_.size; ++i) { - product *= cell_.data[i]; - } - return product; - } - - /*! \brief Get the i-th element of the shape. */ - int64_t operator[](size_t idx) const { return cell_.data[idx]; } - - /*! \return begin iterator */ - const int64_t *begin() const { return cell_.data; } - - /*! \return end iterator */ - const int64_t *end() const { return cell_.data + cell_.size; } - - /*! \return Whether shape tuple is empty */ - bool empty() const { return size() == 0; } - - /*! 
\return The first element of the shape tuple */ - int64_t front() const { return this->at(0); } - - /*! \return The last element of the shape tuple */ - int64_t back() const { return this->at(this->size() - 1); } - - /*! \brief Get the i-th element of the shape. */ - int64_t at(size_t idx) const { - if (idx >= this->size()) { - TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size(); - } - return cell_.data[idx]; - } - -private: - TVMFFIShapeCell cell_; -}; - -/*! \brief An object representing a shape tuple. */ -class ShapeObj : public Object, public TVMFFIShapeCell { -public: - /*! \brief The type of shape index element. */ - using index_type = int64_t; - - /*! \brief Get "numel", meaning the number of elements of an array if the array has this shape */ - int64_t Product() const { - int64_t product = 1; - for (size_t i = 0; i < this->size; ++i) { - product *= this->data[i]; - } - return product; - } - - /// \cond Doxygen_Suppress - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIShape; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIShape, ShapeObj, Object); - /// \endcond -}; - -namespace details { - -class ShapeObjStdImpl : public ShapeObj { -public: - explicit ShapeObjStdImpl(std::vector other) : data_{std::move(other)} { - this->data = data_.data(); - this->size = static_cast(data_.size()); - } - -private: - std::vector data_; -}; - -TVM_FFI_INLINE ObjectPtr MakeEmptyShape(size_t length, int64_t **mutable_data) { - ObjectPtr p = make_inplace_array_object(length); - static_assert(alignof(ShapeObj) % alignof(int64_t) == 0); - static_assert(sizeof(ShapeObj) % alignof(int64_t) == 0); - int64_t *data = reinterpret_cast(reinterpret_cast(p.get()) + sizeof(ShapeObj)); - if (mutable_data) { - *mutable_data = data; - } - p->data = data; - p->size = length; - return p; -} - -// inplace shape allocation -template -TVM_FFI_INLINE ObjectPtr MakeInplaceShape(IterType begin, IterType end) { - size_t length = 
std::distance(begin, end); - int64_t *mutable_data; - ObjectPtr p = MakeEmptyShape(length, &mutable_data); - std::copy(begin, end, mutable_data); - return p; -} - -/*! - * \brief Get the product of a shape. - * \param shape The input shape. - * \param out_strides The output strides. - * \return The product of the shape. - */ -TVM_FFI_INLINE void FillStridesFromShape(ShapeView shape, int64_t *out_strides) { - int64_t stride = 1; - for (int64_t i = static_cast(shape.size()) - 1; i >= 0; --i) { - out_strides[i] = stride; - stride *= shape[i]; - } -} - -/*! - * \brief Make a strides shape from a shape view. - * \param shape The input shape. - * \return The shape. - */ -TVM_FFI_INLINE ObjectPtr MakeStridesFromShape(ShapeView shape) { - int64_t *strides_data; - ObjectPtr strides = details::MakeEmptyShape(shape.size(), &strides_data); - FillStridesFromShape(shape, strides_data); - return strides; -} - -} // namespace details - -/*! - * \brief Managed reference to shape object. - * - * When possible, use ShapeView instead of Shape to reduce memory allocation. - * Use Shape when you need to have a managed reference to on-heap allocated shape. - * - * \sa ShapeView - */ -class Shape : public ObjectRef { -public: - /*! \brief The type of shape index element. */ - using index_type = ShapeObj::index_type; - - /*! \brief Default constructor */ - Shape() : ObjectRef(details::MakeEmptyShape(0, nullptr)) {} - - /*! - * \brief Constructor from iterator - * \param begin begin of iterator - * \param end end of iterator - * \tparam IterType The type of iterator - */ - template - Shape(IterType begin, IterType end) : Shape(details::MakeInplaceShape(begin, end)) {} - - /** - * \brief Constructor from Array - * \param shape The Array - * - * \note This constructor will copy the data content. - */ - Shape(Array shape) // NOLINT(*) - : Shape(shape.begin(), shape.end()) {} - - /*! 
- * \brief constructor from initializer list - * \param shape The initializer list - */ - Shape(std::initializer_list shape) : Shape(shape.begin(), shape.end()) {} - - /*! - * \brief constructor from int64_t [N] - * - * \param other a int64_t array. - */ - Shape(std::vector other) // NOLINT(*) - : ObjectRef(make_object(std::move(other))) {} - - /*! - * \brief constructor from shape view. - * \param other The shape view. - */ - Shape(ShapeView other) : Shape(other.begin(), other.end()) {} // NOLINT(*) - - /*! - * \brief Create a strides from a shape. - * \param shape The input shape. - * \return The strides. - */ - static Shape StridesFromShape(ShapeView shape) { - return Shape(details::MakeStridesFromShape(shape)); - } - - /*! - * \brief Convert to shape view. - * \return The shape view. - */ - operator ShapeView() const { return ShapeView(data(), size()); } // NOLINT(*) - - /*! - * \brief Return the data pointer - * - * \return const index_type* data pointer - */ - const int64_t *data() const { return get()->data; } - - /*! - * \brief Return the size of the shape tuple - * - * \return size_t shape tuple size - */ - size_t size() const { return get()->size; } - - /*! - * \brief Immutably read i-th element from the shape tuple. - * \param idx The index - * \return the i-th element. - */ - int64_t operator[](size_t idx) const { return this->data()[idx]; } - - /*! - * \brief Immutably read i-th element from the shape tuple. - * \param idx The index - * \return the i-th element. - */ - int64_t at(size_t idx) const { - if (idx >= this->size()) { - TVM_FFI_THROW(IndexError) << "indexing " << idx << " on a Shape of size " << this->size(); - } - return this->operator[](idx); - } - - /*! \return Whether shape tuple is empty */ - bool empty() const { return size() == 0; } - - /*! \return The first element of the shape tuple */ - int64_t front() const { return this->at(0); } - - /*! 
\return The last element of the shape tuple */ - int64_t back() const { return this->at(this->size() - 1); } - - /*! \return begin iterator */ - const int64_t *begin() const { return get()->data; } - - /*! \return end iterator */ - const int64_t *end() const { return (get()->data + size()); } - - /*! \return The product of the shape tuple */ - int64_t Product() const { return get()->Product(); } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Shape, ObjectRef, ShapeObj); - /// \endcond - -private: - explicit Shape(ObjectPtr ptr) : ObjectRef(std::move(ptr)) {} -}; - -inline std::ostream &operator<<(std::ostream &os, const Shape &shape) { - os << '['; - for (size_t i = 0; i < shape.size(); ++i) { - if (i != 0) { - os << ", "; - } - os << shape[i]; - } - os << ']'; - return os; -} - -// Shape -template <> -inline constexpr bool use_default_type_traits_v = false; - -// Allow auto conversion from Array to Shape, but not from Shape to Array -template <> -struct TypeTraits : public ObjectRefWithFallbackTraitsBase> { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIShape; - TVM_FFI_INLINE static Shape ConvertFallbackValue(Array src) { - return Shape(std::move(src)); - } -}; - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_CONTAINER_SHAPE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tensor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tensor.h deleted file mode 100644 index b9533c118..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tensor.h +++ /dev/null @@ -1,785 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/tensor.h - * \brief Container to store a Tensor. - */ -#ifndef TVM_FFI_CONTAINER_TENSOR_H_ -#define TVM_FFI_CONTAINER_TENSOR_H_ - -#include "../dtype.h" -#include "../error.h" -#include "../type_traits.h" -#include "shape.h" - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -class Tensor; - -/*! - * \brief Check if the device uses direct address, where address of data indicate alignment. - * \param device The input device. - * \return True if the device uses direct address, false otherwise. - */ -inline bool IsDirectAddressDevice(const DLDevice &device) { - return device.device_type <= kDLCUDAHost || device.device_type == kDLCUDAManaged || device.device_type == kDLROCM || device.device_type == kDLROCMHost; -} - -/*! - * \brief check if a DLTensor is contiguous. - * \param arr The input DLTensor. - * \return The check result. - */ -inline bool IsContiguous(const DLTensor &arr) { - if (arr.strides == nullptr) { - return true; - } - int64_t expected_stride = 1; - for (int32_t i = arr.ndim; i != 0; --i) { - int32_t k = i - 1; - if (arr.shape[k] == 1) { - // Skip stride check if shape[k] is 1, where the dimension is contiguous - // regardless of the value of stride. - // - // For example, PyTorch will normalize stride to 1 if shape is 1 when exporting - // to DLPack. 
- // More context: https://github.com/pytorch/pytorch/pull/83158 - continue; - } - if (arr.strides[k] != expected_stride) { - return false; - } - expected_stride *= arr.shape[k]; - } - return true; -} - -/** - * \brief Check if the data in the DLTensor is aligned to the given alignment. - * \param arr The input DLTensor. - * \param alignment The alignment to check. - * \return True if the data is aligned to the given alignment, false otherwise. - */ -inline bool IsAligned(const DLTensor &arr, size_t alignment) { - if (IsDirectAddressDevice(arr.device)) { - return (reinterpret_cast(static_cast(arr.data) + arr.byte_offset) % alignment == 0); - } else { - return arr.byte_offset % alignment == 0; - } -} - -/*! - * \brief return the total number of bytes needed to store packed data - * - * \param numel the number of elements in the array - * \param dtype the data type of the array - * \return the total number of bytes needed to store packed data - */ -inline size_t GetDataSize(size_t numel, DLDataType dtype) { - // compatible handling sub-byte uint1(bool), which usually stored as uint8_t - // TODO(tqchen): revisit and switch to kDLBool - if (dtype.code == kDLUInt && dtype.bits == 1 && dtype.lanes == 1) { - return numel; - } - // for other sub-byte types, packing is preferred - return (numel * dtype.bits * dtype.lanes + 7) / 8; -} - -/*! - * \brief return the size of data the DLTensor holds, in terms of number of bytes - * - * \param arr the input DLTensor - * \return number of bytes of data in the DLTensor. - */ -inline size_t GetDataSize(const DLTensor &arr) { - size_t size = 1; - for (int i = 0; i < arr.ndim; ++i) { - size *= static_cast(arr.shape[i]); - } - return GetDataSize(size, arr.dtype); -} - -/*! \brief An object representing a Tensor. 
*/ -class TensorObj : public Object, public DLTensor { -public: - /// \cond Doxygen_Suppress - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFITensor; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFITensor, TensorObj, Object); - /// \endcond - - /*! - * \brief Move a Tensor to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensor *ToDLPack() const { - TensorObj *self = const_cast(this); - DLManagedTensor *ret = new DLManagedTensor(); - ret->dl_tensor = *static_cast(self); - ret->manager_ctx = self; - ret->deleter = DLManagedTensorDeleter; - details::ObjectUnsafe::IncRefObjectHandle(self); - return ret; - } - - /*! - * \brief Move a Tensor to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensorVersioned *ToDLPackVersioned() const { - TensorObj *self = const_cast(this); - DLManagedTensorVersioned *ret = new DLManagedTensorVersioned(); - ret->version.major = DLPACK_MAJOR_VERSION; - ret->version.minor = DLPACK_MINOR_VERSION; - ret->dl_tensor = *static_cast(self); - ret->manager_ctx = self; - ret->deleter = DLManagedTensorDeleter; - details::ObjectUnsafe::IncRefObjectHandle(self); - return ret; - } - -protected: - /*! - * \brief Deleter for DLManagedTensor. - * \param tensor The DLManagedTensor to be deleted. - */ - template - static void DLManagedTensorDeleter(TDLManagedTensor *tensor) { - TensorObj *obj = static_cast(tensor->manager_ctx); - details::ObjectUnsafe::DecRefObjectHandle(obj); - delete tensor; - } - - friend class Tensor; -}; - -namespace details { -/*! - *\brief Helper class to create an TensorObj from an NDAllocator - * - * The underlying allocator needs to be implemented by user. 
- */ -template -class TensorObjFromNDAlloc : public TensorObj { -public: - using Self = TensorObjFromNDAlloc; - - template - TensorObjFromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device, - ExtraArgs &&...extra_args) - : alloc_(alloc) { - this->device = device; - this->ndim = static_cast(shape.size()); - this->dtype = dtype; - this->byte_offset = 0; - // inplace alloc shape and strides after data structure - this->shape = reinterpret_cast(reinterpret_cast(this) + sizeof(Self)); - this->strides = this->shape + shape.size(); - std::copy(shape.begin(), shape.end(), this->shape); - details::FillStridesFromShape(shape, this->strides); - // call allocator to alloc data - alloc_.AllocData(static_cast(this), std::forward(extra_args)...); - } - - ~TensorObjFromNDAlloc() { alloc_.FreeData(static_cast(this)); } - -private: - TNDAlloc alloc_; -}; - -/*! \brief helper class to import from DLPack legacy DLManagedTensor */ -template -class TensorObjFromDLPack : public TensorObj { -public: - using Self = TensorObjFromDLPack; - - explicit TensorObjFromDLPack(TDLPackManagedTensor *tensor, bool extra_strides_at_tail) - : tensor_(tensor) { - *static_cast(this) = tensor_->dl_tensor; - if (extra_strides_at_tail) { - this->strides = reinterpret_cast(reinterpret_cast(this) + sizeof(Self)); - details::FillStridesFromShape(ShapeView(tensor_->dl_tensor.shape, tensor_->dl_tensor.ndim), - this->strides); - } - } - - ~TensorObjFromDLPack() { - // run DLPack deleter if needed. - if (tensor_->deleter != nullptr) { - (*tensor_->deleter)(tensor_); - } - } - -private: - TDLPackManagedTensor *tensor_; -}; -} // namespace details - -/*! - * \brief Managed Tensor (n-dimensional array). - * The tensor is backed by reference counted blocks. - * - * \note This class can be subclassed to implement downstream customized - * Tensor types that are backed by the same TensorObj storage type. - */ -class Tensor : public ObjectRef { -public: - /*! - * \brief Default constructor. 
- */ - Tensor() = default; - /*! - * \brief Constructor from a ObjectPtr. - * \param n The ObjectPtr. - */ - explicit Tensor(::tvm::ffi::ObjectPtr n) : ObjectRef(std::move(n)) {} - /*! - * \brief Constructor from a UnsafeInit tag. - * \param tag The UnsafeInit tag. - */ - explicit Tensor(::tvm::ffi::UnsafeInit tag) : ObjectRef(tag) {} - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Tensor) - /// \endcond - /*! - * \brief Get the data pointer of the Tensor. - * \return The data pointer of the Tensor. - */ - void *data_ptr() const { return get()->data; } - - /*! - * \brief Get the device of the Tensor. - * \return The device of the Tensor. - */ - DLDevice device() const { return get()->device; } - - /*! - * \brief Get the number of dimensions in the Tensor. - * \return The number of dimensions in the Tensor. - */ - int32_t ndim() const { return get()->ndim; } - - /*! - * \brief Get the data type of the Tensor. - * \return The data type of the Tensor. - */ - DLDataType dtype() const { return get()->dtype; } - - /*! - * \brief Get the shape of the Tensor. - * \return The shape of the Tensor. - */ - ShapeView shape() const { - const TensorObj *obj = get(); - return tvm::ffi::ShapeView(obj->shape, obj->ndim); - } - - /*! - * \brief Get the strides of the Tensor. - * \return The strides of the Tensor. - */ - ShapeView strides() const { - const TensorObj *obj = get(); - TVM_FFI_ICHECK(obj->strides != nullptr || obj->ndim == 0); - return ShapeView(obj->strides, obj->ndim); - } - - /*! - * \brief Get the size of the idx-th dimension. If the idx is negative, - * it gets the size of last idx-th dimension. - * \param idx The index of the size. - * \return The size of the idx-th dimension. - */ - int64_t size(int64_t idx) const { - const TensorObj *ptr = get(); - return ptr->shape[idx >= 0 ? idx : (ptr->ndim + idx)]; - } - - /*! - * \brief Get the stride of the idx-th dimension. If the idx is negative, - * it gets the stride of last idx-th dimension. 
- * \param idx The index of the stride. - * \return The stride of the idx-th dimension. - */ - int64_t stride(int64_t idx) const { - const TensorObj *ptr = get(); - return ptr->strides[idx >= 0 ? idx : (ptr->ndim + idx)]; - } - - /*! - * \brief Get the number of elements in the Tensor. - * \return The number of elements in the Tensor. - */ - int64_t numel() const { return this->shape().Product(); } - /*! - * \brief Get the byte offset of the Tensor. - * \return The byte offset of the Tensor. - */ - uint64_t byte_offset() const { return get()->byte_offset; } - /*! - * \brief Check if the Tensor is contiguous. - * \return True if the Tensor is contiguous, false otherwise. - */ - bool IsContiguous() const { return tvm::ffi::IsContiguous(*get()); } - /*! - * \brief Check if the Tensor data is aligned to the given alignment. - * \param alignment The alignment to check. - * \return True if the Tensor data is aligned to the given alignment, false otherwise. - */ - bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*get(), alignment); } - /*! - * \brief Create a Tensor from a NDAllocator. - * - * \note When building a kernel library, we always recommend use FromEnvAlloc when possible to - * allocate intermediate Tensors. When a loaded module returns an allocated tensor to the caller, - * we need to keep the module alive before the returned tensors get freed, because its - * deleter is defined within the module. FromNDAlloc can be used by C++ applications and runtimes - * to create Tensors. 
- * - * Example usage: - * \code - * // CPU Allocator - * struct CPUNDAlloc { - * void AllocData(DLTensor* tensor) { tensor->data = malloc(ffi::GetDataSize(*tensor)); } - * void FreeData(DLTensor* tensor) { free(tensor->data); } - * }; - * - * // CUDA Allocator - * struct CUDANDAlloc { - * void AllocData(DLTensor* tensor) { - * size_t data_size = ffi::GetDataSize(*tensor); - * void* ptr = nullptr; - * cudaError_t err = cudaMalloc(&ptr, data_size); - * TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaMalloc failed: " << cudaGetErrorString(err); - * tensor->data = ptr; - * } - * void FreeData(DLTensor* tensor) { - * if (tensor->data != nullptr) { - * cudaError_t err = cudaFree(tensor->data); - * TVM_FFI_ICHECK_EQ(err, cudaSuccess) << "cudaFree failed: " << cudaGetErrorString(err); - * tensor->data = nullptr; - * } - * } - * }; - * - * // NVSHMEM Allocator - * struct NVSHMEMNDAlloc { - * void AllocData(DLTensor* tensor) { - * size_t size = tvm::ffi::GetDataSize(*tensor); - * tensor->data = nvshmem_malloc(size); - * TVM_FFI_ICHECK_NE(tensor->data, nullptr) << "nvshmem_malloc failed. size: " << size; - * } - * void FreeData(DLTensor* tensor) { nvshmem_free(tensor->data); } - * }; - * - * // Allocator usage - * ffi::Tensor cpu_tensor = ffi::Tensor::FromNDAlloc(CPUNDAlloc(), ...); - * ffi::Tensor cuda_tensor = ffi::Tensor::FromNDAlloc(CUDANDAlloc(), ...); - * ffi::Tensor nvshmem_tensor = ffi::Tensor::FromNDAlloc(NVSHMEMNDAlloc(), ...); - * \endcode - * - * \param alloc The NDAllocator. - * \param shape The shape of the Tensor. - * \param dtype The data type of the Tensor. - * \param device The device of the Tensor. - * \param extra_args Extra arguments to be forwarded to TNDAlloc. - * \return The created Tensor. - * \tparam TNDAlloc The type of the NDAllocator, impelments Alloc and Free. - * \tparam ExtraArgs Extra arguments to be passed to Alloc. 
- */ - template - static Tensor FromNDAlloc(TNDAlloc alloc, ffi::ShapeView shape, DLDataType dtype, DLDevice device, - ExtraArgs &&...extra_args) { - // inplace alloc shape and strides after data structure (as a result why multiply 2) - size_t num_extra_i64_at_tail = shape.size() * 2; - return Tensor(make_inplace_array_object, int64_t>( - num_extra_i64_at_tail, alloc, shape, dtype, device, - std::forward(extra_args)...)); - } - /*! - * \brief Create a Tensor from the TVMFFIEnvTensorAlloc API - * - * This function can be used together with TVMFFIEnvSetDLPackManagedTensorAllocator - * in the extra/c_env_api.h to create a Tensor from the thread-local environment allocator. - * We explicitly pass TVMFFIEnvTensorAlloc to maintain explicit dependency on extra/c_env_api.h - * - * \code - * - * ffi::Tensor tensor = ffi::Tensor::FromEnvAlloc( - * TVMFFIEnvTensorAlloc, shape, dtype, device - * ); - * - * \endcode - * - * \param env_alloc TVMFFIEnvTensorAlloc function pointer. - * \param shape The shape of the Tensor. - * \param dtype The data type of the Tensor. - * \param device The device of the Tensor. - * \return The created Tensor. - * - * \sa TVMFFIEnvTensorAlloc - */ - static Tensor FromEnvAlloc(int (*env_alloc)(DLTensor *, TVMFFIObjectHandle *), ffi::ShapeView shape, - DLDataType dtype, DLDevice device) { - TVMFFIObjectHandle out; - DLTensor prototype{}; - prototype.device = device; - prototype.dtype = dtype; - prototype.shape = const_cast(shape.data()); - prototype.ndim = static_cast(shape.size()); - TVM_FFI_CHECK_SAFE_CALL(env_alloc(&prototype, &out)); - return Tensor( - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(out))); - } - /*! - * \brief Create a Tensor from a DLPack managed tensor, pre v1.0 API. - * \param tensor The input DLPack managed tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. 
- * \note This function will not run any checks on flags. - * \return The created Tensor. - */ - static Tensor FromDLPack(DLManagedTensor *tensor, size_t require_alignment = 0, - bool require_contiguous = false) { - if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment - << " bytes."; - } - if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; - } - if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) { - return Tensor(make_object>( - tensor, /*extra_strides_at_tail=*/false)); - } else { - return Tensor( - make_inplace_array_object, int64_t>( - tensor->dl_tensor.ndim, tensor, /*extra_strides_at_tail=*/true)); - } - } - - /*! - * \brief Create a Tensor from a DLPack managed tensor, post v1.0 API. - * \param tensor The input DLPack managed tensor. - * \param require_alignment The minimum alignment requored of the data + byte_offset. - * \param require_contiguous Boolean flag indicating if we need to check for contiguity. - * \return The created Tensor. 
- */ - static Tensor FromDLPackVersioned(DLManagedTensorVersioned *tensor, size_t require_alignment = 0, - bool require_contiguous = false) { - if (require_alignment != 0 && !ffi::IsAligned(tensor->dl_tensor, require_alignment)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Data is not aligned to " << require_alignment - << " bytes."; - } - if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) { - TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous."; - } - if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) { - TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet supported"; - } - if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) { - return Tensor(make_object>( - tensor, /*extra_strides_at_tail=*/false)); - } else { - return Tensor( - make_inplace_array_object, - int64_t>(tensor->dl_tensor.ndim, tensor, - /*extra_strides_at_tail=*/true)); - } - } - - /*! - * \brief Convert the Tensor to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensor *ToDLPack() const { return get_mutable()->ToDLPack(); } - - /*! - * \brief Convert the Tensor to a DLPack managed tensor. - * \return The converted DLPack managed tensor. - */ - DLManagedTensorVersioned *ToDLPackVersioned() const { return get_mutable()->ToDLPackVersioned(); } - /*! - * \brief Get the underlying DLTensor pointer. - * \return The underlying DLTensor pointer. - */ - const DLTensor *GetDLTensorPtr() const { return get(); } - /// \cond Doxygen_Suppress - [[maybe_unused]] static constexpr bool _type_is_nullable = true; - using ContainerType = TensorObj; - /// \endcond - - // the following code are convenient APIs redirections created to provide aten-style api - /*! - * \brief This functions redirects to ndim(). - * \return The number of dimensions in the Tensor. - */ - inline int32_t dim() { return ndim(); } - /*! - * \brief This functions redirects to shape(). - * \return The shape of the Tensor. 
- */ - inline ShapeView sizes() const { return shape(); } - /*! - * \brief This functions redirects to IsContiguous(). - * \return True if the Tensor is contiguous, false otherwise. - */ - inline bool is_contiguous() const { return IsContiguous(); } - -protected: - /*! - * \brief Get const internal container pointer. - * \return a const container pointer. - */ - const TensorObj *get() const { return static_cast(ObjectRef::get()); } - /*! - * \brief Get mutable internal container pointer. - * \return a mutable container pointer. - */ - TensorObj *get_mutable() const { return const_cast(get()); } -}; - -/*! - * \brief A non-owning view of a Tensor. - * - * This class stores a light-weight non-owning view of a Tensor. - * This is useful for accessing tensor data without retaining a strong reference to the Tensor. - * Since the caller may not always be able to pass in a strong referenced tensor. - * - * It is the user's responsibility to ensure - * that the underlying tensor data outlives the `TensorView`. - * This responsibility extends to all data pointed to by the underlying DLTensor. - * This includes not only the tensor elements in data but also the memory for shape and strides. - * - * When exposing a function that expects only expects a TensorView, we recommend using - * ffi::TensorView as the argument type instead of ffi::Tensor. - */ -class TensorView { -public: - /*! - * \brief Create a TensorView from a Tensor. - * \param tensor The input Tensor. - */ - TensorView(const Tensor &tensor) { // NOLINT(*) - TVM_FFI_ICHECK(tensor.defined()); - tensor_ = *tensor.GetDLTensorPtr(); - } // NOLINT(*) - /*! - * \brief Create a TensorView from a DLTensor. - * \param tensor The input DLTensor. - */ - TensorView(const DLTensor *tensor) { // NOLINT(*) - TVM_FFI_ICHECK(tensor != nullptr); - tensor_ = *tensor; - } - /*! - * \brief Copy constructor. - * \param tensor The input TensorView. - */ - TensorView(const TensorView &tensor) = default; - /*! - * \brief Move constructor. 
- * \param tensor The input TensorView. - */ - TensorView(TensorView &&tensor) = default; - /*! - * \brief Copy assignment operator. - * \param tensor The input TensorView. - * \return The created TensorView. - */ - TensorView &operator=(const TensorView &tensor) = default; - /*! - * \brief Move assignment operator. - * \param tensor The input TensorView. - * \return The created TensorView. - */ - TensorView &operator=(TensorView &&tensor) = default; - /*! - * \brief Assignment operator from a Tensor. - * \param tensor The input Tensor. - * \return The created TensorView. - */ - TensorView &operator=(const Tensor &tensor) { - TVM_FFI_ICHECK(tensor.defined()); - tensor_ = *tensor.GetDLTensorPtr(); - return *this; - } - - // explicitly delete move constructor - TensorView(Tensor &&tensor) = delete; // NOLINT(*) - // delete move assignment operator from owned tensor - TensorView &operator=(Tensor &&tensor) = delete; - /*! - * \brief Get the data pointer of the Tensor. - * \return The data pointer of the Tensor. - */ - void *data_ptr() const { return tensor_.data; } - /*! - * \brief Get the device of the Tensor. - * \return The device of the Tensor. - */ - DLDevice device() const { return tensor_.device; } - /*! - * \brief Get the number of dimensions in the Tensor. - * \return The number of dimensions in the Tensor. - */ - int32_t ndim() const { return tensor_.ndim; } - /*! - * \brief Get the data type of the Tensor. - * \return The data type of the Tensor. - */ - DLDataType dtype() const { return tensor_.dtype; } - /*! - * \brief Get the shape of the Tensor. - * \return The shape of the Tensor. - */ - ShapeView shape() const { return ShapeView(tensor_.shape, tensor_.ndim); } - - /*! - * \brief Get the number of elements in the Tensor. - * \return The number of elements in the Tensor. - */ - int64_t numel() const { return this->shape().Product(); } - - /*! - * \brief Get the strides of the Tensor. - * \return The strides of the Tensor. 
- */ - ShapeView strides() const { - TVM_FFI_ICHECK(tensor_.strides != nullptr || tensor_.ndim == 0); - return ShapeView(tensor_.strides, tensor_.ndim); - } - - /*! - * \brief Get the size of the idx-th dimension. If the idx is negative, - * it gets the size of last idx-th dimension. - * \param idx The index of the size. - * \return The size of the idx-th dimension. - */ - int64_t size(int64_t idx) const { return tensor_.shape[idx >= 0 ? idx : tensor_.ndim + idx]; } - - /*! - * \brief Get the stride of the idx-th dimension. If the idx is negative, - * it gets the stride of last idx-th dimension. - * \param idx The index of the stride. - * \return The stride of the idx-th dimension. - */ - int64_t stride(int64_t idx) const { return tensor_.strides[idx >= 0 ? idx : tensor_.ndim + idx]; } - - /*! - * \brief Get the byte offset of the Tensor. - * \return The byte offset of the Tensor. - */ - uint64_t byte_offset() const { return tensor_.byte_offset; } - - /*! - * \brief Check if the Tensor is contiguous. - * \return True if the Tensor is contiguous, false otherwise. - */ - bool IsContiguous() const { return tvm::ffi::IsContiguous(tensor_); } - - // the following code are convenient APIs redirections created to provide aten-style api - /*! - * \brief This functions redirects to ndim(). - * \return The number of dimensions in the Tensor. - */ - inline int32_t dim() { return ndim(); } - /*! - * \brief This functions redirects to shape(). - * \return The shape of the Tensor. - */ - inline ShapeView sizes() const { return shape(); } - /*! - * \brief This functions redirects to IsContiguous(). - * \return True if the Tensor is contiguous, false otherwise. - */ - inline bool is_contiguous() const { return IsContiguous(); } - -private: - DLTensor tensor_; - template - friend struct TypeTraits; -}; - -/*! - * \brief Get the data size of the Tensor. - * \param tensor The input Tensor. - * \return The data size of the Tensor. 
- */ -inline size_t GetDataSize(const Tensor &tensor) { - return GetDataSize(tensor.numel(), tensor.dtype()); -} - -/*! - * \brief Get the data size of the TensorView. - * \param tensor The input TensorView. - * \return The data size of the TensorView. - */ -inline size_t GetDataSize(const TensorView &tensor) { - return GetDataSize(tensor.numel(), tensor.dtype()); -} - -// TensorView type, allow implicit casting from DLTensor* -// NOTE: we deliberately do not support MoveToAny and MoveFromAny since it does not retain ownership -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr bool storage_enabled = false; - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDLTensorPtr; - - TVM_FFI_INLINE static void CopyToAnyView(const TensorView &src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIDLTensorPtr; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_ptr = const_cast(&(src.tensor_)); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index == TypeIndex::kTVMFFIDLTensorPtr; - } - - TVM_FFI_INLINE static TensorView CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return TensorView(static_cast(src->v_ptr)); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) { - return TensorView(static_cast(src->v_ptr)); - } else if (src->type_index == TypeIndex::kTVMFFITensor) { - return TensorView(TVMFFITensorGetDLTensorPtr(src->v_obj)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIDLTensorPtr; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIDLTensorPtr) + R"("})"; - } -}; - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_CONTAINER_TENSOR_H_ diff --git 
a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h deleted file mode 100644 index 483333195..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/tuple.h +++ /dev/null @@ -1,400 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/tuple.h - * \brief Typed tuple like std::tuple backed by ArrayObj container. - */ -#ifndef TVM_FFI_CONTAINER_TUPLE_H_ -#define TVM_FFI_CONTAINER_TUPLE_H_ - -#include - -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Typed tuple like std::tuple backed by ArrayObj container. - * - * Tuple implements in-place copy-on-write semantics. - * - * \tparam Types The types of the tuple elements - */ -template -class Tuple : public ObjectRef { -public: - static_assert(details::all_storage_enabled_v, - "All types used in Tuple<...> must be compatible with Any"); - /*! \brief Default constructor */ - Tuple() : ObjectRef(MakeDefaultTupleNode()) {} - /*! - * \brief Constructor with UnsafeInit - */ - explicit Tuple(UnsafeInit tag) : ObjectRef(tag) {} - /*! 
\brief Copy constructor */ - Tuple(const Tuple &other) : ObjectRef(other) {} - /*! \brief Move constructor */ - Tuple(Tuple &&other) noexcept : ObjectRef(std::move(other)) {} - /*! - * \brief Constructor from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...), int>> - Tuple(const Tuple &other) : ObjectRef(other) {} // NOLINT(google-explicit-constructor) - - /*! - * \brief Constructor from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...), int>> - Tuple(Tuple &&other) // NOLINT(google-explicit-constructor) - : ObjectRef(std::move(other)) {} - - /*! - * \brief Constructor from arguments - * \param args The arguments - * \tparam UTypes The types of the other tuple - */ - template , Tuple> && ...))>> - explicit Tuple(UTypes &&...args) : ObjectRef(MakeTupleNode(std::forward(args)...)) {} - - /*! - * \brief Assignment from another tuple - * \param other The other tuple - * \tparam The enable_if_t type - */ - TVM_FFI_INLINE Tuple &operator=(const Tuple &other) { - data_ = other.data_; - return *this; - } - - /*! - * \brief Assignment from another tuple - * \param other The other tuple - * \tparam The enable_if_t type - */ - TVM_FFI_INLINE Tuple &operator=(Tuple &&other) noexcept { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Assignment from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...)>> - TVM_FFI_INLINE Tuple &operator=(const Tuple &other) { - data_ = other.data_; - return *this; - } - - /*! 
- * \brief Assignment from another tuple - * \param other The other tuple - * \tparam UTypes The types of the other tuple - * \tparam The enable_if_t type - */ - template && ...)>> - TVM_FFI_INLINE Tuple &operator=(Tuple &&other) { - data_ = std::move(other.data_); - return *this; - } - - /*! - * \brief Get I-th element of the tuple - * - * \tparam I The index of the element to get - * \return The I-th element of the tuple - * \note We use stl style since get usually is like a getter. - */ - template - auto get() const & { - static_assert(I < sizeof...(Types), "Tuple index out of bounds"); - using ReturnType = std::tuple_element_t>; - const Any *ptr = GetArrayObj()->begin() + I; - return details::AnyUnsafe::CopyFromAnyViewAfterCheck(*ptr); - } - - /*! - * \brief Move out I-th element of the tuple - * - * \tparam I The index of the element to get - * \return The I-th element of the tuple - * \note We use stl style since get usually is like a getter. - */ - template - auto get() && { - if (!this->unique()) { - // fallback to const& version if not unique - return std::as_const(*this).template get(); - } - static_assert(I < sizeof...(Types), "Tuple index out of bounds"); - using ReturnType = std::tuple_element_t>; - Any *ptr = GetArrayObj()->MutableBegin() + I; - return details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(*ptr)); - } - - /*! - * \brief Set I-th element of the tuple - * - * \param item The item to set - * \tparam I The index of the element to set - * \tparam U The type of the item - * - * \note This function will perform copy on write if underlying - * container is not uniquely owned. - * We use CamelCase since Set can cause copy on write - * and is more complicated than simple field setter. - */ - template - void Set(U &&item) { - static_assert(I < sizeof...(Types), "Tuple index out of bounds"); - using T = std::tuple_element_t>; - this->CopyIfNotUnique(); - Any *ptr = GetArrayObj()->MutableBegin() + I; - *ptr = T(std::forward(item)); - } - - /*! 
\brief specify container node */ - using ContainerType = ArrayObj; - -private: - static ObjectPtr MakeDefaultTupleNode() { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any *itr = p->MutableBegin(); - // increase size after each new to ensure exception safety - ((new (itr++) Any(Types()), p->size_++), ...); - return p; - } - - template - static ObjectPtr MakeTupleNode(UTypes &&...args) { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any *itr = p->MutableBegin(); - // increase size after each new to ensure exception safety - ((new (itr++) Any(Types(std::forward(args))), p->size_++), ...); - return p; - } - - /*! \brief Copy on write */ - void CopyIfNotUnique() { - if (!data_.unique()) { - ObjectPtr p = ArrayObj::Empty(sizeof...(Types)); - Any *itr = p->MutableBegin(); - const Any *read = GetArrayObj()->begin(); - // increase size after each new to ensure exception safety - for (size_t i = 0; i < sizeof...(Types); ++i) { - new (itr++) Any(*read++); - p->size_++; - } - data_ = std::move(p); - } - } - - /*! 
\return The underlying ArrayObj */ - ArrayObj *GetArrayObj() const { return static_cast(data_.get()); } - - template - friend class Tuple; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public ObjectRefTypeTraitsBase> { - using ObjectRefTypeTraitsBase>::CopyFromAnyViewAfterCheck; - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { - if (src->type_index != TypeIndex::kTVMFFIArray) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - const ArrayObj *n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) { - return "Array[size=" + std::to_string(n->size()) + "]"; - } - return GetMismatchTypeInfoHelper<0, Types...>(n->begin()); - } - - template - TVM_FFI_INLINE static std::string GetMismatchTypeInfoHelper(const Any *arr) { - if constexpr (!std::is_same_v) { - const Any &any_v = arr[I]; - if (!details::AnyUnsafe::CheckAnyStrict(any_v) && !(any_v.try_cast().has_value())) { - // now report the accurate mismatch information - return "Array[index " + std::to_string(I) + ": " + details::AnyUnsafe::GetMismatchTypeInfo(any_v) + "]"; - } - } - if constexpr (sizeof...(Rest) > 0) { - return GetMismatchTypeInfoHelper(arr); - } - TVM_FFI_THROW(InternalError) << "Cannot reach here"; - TVM_FFI_UNREACHABLE(); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - if (src->type_index != TypeIndex::kTVMFFIArray) { - return false; - } - const ArrayObj *n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) { - return false; - } - const TVMFFIAny *ffi_any_arr = reinterpret_cast(n->begin()); - return CheckAnyStrictHelper<0, Types...>(ffi_any_arr); - } - - template - TVM_FFI_INLINE static bool CheckAnyStrictHelper(const TVMFFIAny *src_arr) { - if constexpr (!std::is_same_v) { - if (!TypeTraits::CheckAnyStrict(src_arr + I)) { - return false; - } - } - if constexpr (sizeof...(Rest) > 0) { - return CheckAnyStrictHelper(src_arr); - } - return 
true; - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index != TypeIndex::kTVMFFIArray) { - return std::nullopt; - } - const ArrayObj *n = reinterpret_cast(src->v_obj); - if (n->size() != sizeof...(Types)) { - return std::nullopt; - } - // fast path, storage is already in the right type - if (CheckAnyStrict(src)) { - return CopyFromAnyViewAfterCheck(src); - } - // slow path, try to convert to each type to match the tuple storage need. - Array arr = TypeTraits>::CopyFromAnyViewAfterCheck(src); - Any *ptr = arr.CopyOnWrite()->MutableBegin(); - if (TryConvertElements<0, Types...>(ptr)) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr>( - details::ObjectUnsafe::ObjectPtrFromObjectRef(arr)); - } - return std::nullopt; - } - - template - TVM_FFI_INLINE static bool TryConvertElements(Any *arr) { - if constexpr (!std::is_same_v) { - if (auto opt_convert = arr[I].try_cast()) { - arr[I] = *std::move(opt_convert); - } else { - return false; - } - } - if constexpr (sizeof...(Rest) > 0) { - return TryConvertElements(std::move(arr)); - } else { - return true; - } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return details::ContainerTypeStr("Tuple"); - } - TVM_FFI_INLINE static std::string TypeSchema() { - std::ostringstream oss; - oss << R"({"type":"Tuple","args":[)"; - const char *sep = ""; - ((oss << sep << details::TypeSchema::v(), sep = ","), ...); - oss << "]}"; - return oss.str(); - } -}; - -namespace details { -template -inline constexpr bool type_contains_v, Tuple> = (type_contains_v && ...); -} // namespace details - -/// \cond Doxygen_Suppress - -/// NOTE: ADL friendly get functions -/// Example usage: { using std::get; get<0>(t); } -/// ADL will find the right get function - -/** - * \brief get I-th element of the tuple - * \tparam I The index of the element to get - * \param t The tuple - * \return The I-th element of the tuple - */ -template -inline constexpr auto get(const Tuple &t) - -> 
std::tuple_element_t> { - return t.template get(); -} - -/** - * \brief get I-th element of the tuple - * \tparam I The index of the element to get - * \param t The tuple (rvalue) - * \return The I-th element of the tuple - */ -template -inline constexpr auto get(Tuple &&t) -> std::tuple_element_t> { - return std::move(t).template get(); -} - -/// NOTE: C++17 deduction guide -template -Tuple(UTypes &&...) -> Tuple>...>; - -/// \endcond - -} // namespace ffi -} // namespace tvm - -namespace std { - -template -struct tuple_size<::tvm::ffi::Tuple> - : public std::integral_constant {}; - -template -struct tuple_element> { - using type = std::tuple_element_t>; -}; - -} // namespace std - -#endif // TVM_FFI_CONTAINER_TUPLE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h deleted file mode 100644 index e1b91b526..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/container/variant.h +++ /dev/null @@ -1,311 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/container/variant.h - * \brief Runtime variant container types. 
- */ -#ifndef TVM_FFI_CONTAINER_VARIANT_H_ -#define TVM_FFI_CONTAINER_VARIANT_H_ - -#include -#include -#include - -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Base class for Variant. - * - * \tparam all_storage_object Whether all types are derived from ObjectRef. - */ -template -class VariantBase { -public: - TVM_FFI_INLINE bool same_as(const VariantBase &other) const { - return data_.same_as(other.data_); - } - -protected: - template - explicit VariantBase(T other) : data_(std::move(other)) {} - - TVM_FFI_INLINE void SetData(Any other_data) { data_ = std::move(other_data); } - - TVM_FFI_INLINE Any MoveToAny() && { return std::move(data_); } - - TVM_FFI_INLINE AnyView ToAnyView() const { return data_.operator AnyView(); } - - Any data_; -}; - -// Specialization for all object ref case, backed by ObjectRef. -template <> -class VariantBase : public ObjectRef { -protected: - template - explicit VariantBase(const T &other) : ObjectRef(other) {} - template , VariantBase>>> - explicit VariantBase(T &&other) : ObjectRef(std::forward(other)) {} - explicit VariantBase(UnsafeInit tag) : ObjectRef(tag) {} - explicit VariantBase(Any other) - : ObjectRef(details::AnyUnsafe::MoveFromAnyAfterCheck(std::move(other))) {} - - TVM_FFI_INLINE void SetData(ObjectPtr other) { data_ = std::move(other); } - - TVM_FFI_INLINE Any MoveToAny() && { return Any(ObjectRef(std::move(data_))); } - - TVM_FFI_INLINE AnyView ToAnyView() const { - TVMFFIAny any_data; - if (data_ == nullptr) { - any_data.type_index = TypeIndex::kTVMFFINone; - any_data.zero_padding = 0; - any_data.v_int64 = 0; - } else { - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&any_data); - any_data.type_index = data_->type_index(); - any_data.zero_padding = 0; - any_data.v_obj = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectPtr(data_); - } - return AnyView::CopyFromTVMFFIAny(any_data); - } -}; -} // namespace details - -/*! - * \brief A typed variant container. 
- * - * When all values are ObjectRef, Variant is backed by ObjectRef, - * otherwise it is backed by Any. - */ -template -class Variant : public details::VariantBase> { -public: - /// \cond Doxygen_Suppress - using TParent = details::VariantBase>; - static_assert(details::all_storage_enabled_v, - "All types used in Variant<...> must be compatible with Any"); - /* - * \brief Helper utility to check if the type can be contained in the variant - */ - template - static constexpr bool variant_contains_v = (details::type_contains_v || ...); - /* \brief Helper utility for SFINAE if the type is part of the variant */ - template - using enable_if_variant_contains_t = std::enable_if_t>; - /// \endcond - /*! - * \brief Constructor from another variant - * \param other The other variant - */ - Variant(const Variant &other) : TParent(other.data_) {} - /*! - * \brief Constructor from another variant - * \param other The other variant - */ - Variant(Variant &&other) noexcept : TParent(std::move(other.data_)) {} - - /*! - * \brief Assignment from another variant - * \param other The other variant - */ - TVM_FFI_INLINE Variant &operator=(const Variant &other) { - this->SetData(other.data_); - return *this; - } - - /*! - * \brief Assignment from another variant - * \param other The other variant - */ - TVM_FFI_INLINE Variant &operator=(Variant &&other) noexcept { - this->SetData(std::move(other.data_)); - return *this; - } - - /*! - * \brief Constructor from another variant - * \param other The other variant - */ - template > - Variant(T other) : TParent(std::move(other)) {} // NOLINT(*) - - /*! - * \brief Assignment from another variant - * \param other The other variant - */ - template > - TVM_FFI_INLINE Variant &operator=(T other) { - return operator=(Variant(std::move(other))); - } - - /*! - * \brief Try to cast to a type T, return std::nullopt if the cast is not possible. - * \return The casted value, or std::nullopt if the cast is not possible. 
- * \tparam T The type to cast to. - */ - template > - TVM_FFI_INLINE std::optional as() const { - return this->TParent::ToAnyView().template as(); - } - - /*! - * \brief Shortcut of as Object to cast to a const pointer when T is an Object. - * - * \tparam T The object type. - * \return The requested pointer, returns nullptr if type mismatches. - */ - template >> - TVM_FFI_INLINE const T *as() const { - return this->TParent::ToAnyView().template as().value_or(nullptr); - } - - /*! - * \brief Get the value of the variant in type T, throws an exception if cast fails. - * \return The value of the variant - * \tparam T The type to get. - */ - template > - TVM_FFI_INLINE T get() const & { - return this->TParent::ToAnyView().template cast(); - } - - /*! - * \brief Get the value of the variant in type T, throws an exception if cast fails. - * \return The value of the variant - * \tparam T The type to get. - */ - template > - TVM_FFI_INLINE T get() && { - return std::move(*this).TParent::MoveToAny().template cast(); - } - - /*! - * \brief Get the type key of the variant - * \return The type key of the variant - */ - TVM_FFI_INLINE std::string GetTypeKey() const { return this->TParent::ToAnyView().GetTypeKey(); } - -private: - friend struct TypeTraits>; - friend struct ObjectPtrHash; - friend struct ObjectPtrEqual; - // constructor from any - explicit Variant(Any data) : TParent(std::move(data)) {} - /*! 
- * \brief Get the object pointer from the variant - * \note This function is only available if all types used in Variant<...> are derived from - * ObjectRef - */ - TVM_FFI_INLINE Object *GetObjectPtrForHashEqual() const { - constexpr bool all_object_v = (std::is_base_of_v && ...); - static_assert(all_object_v, - "All types used in Variant<...> must be derived from ObjectRef " - "to enable ObjectPtrHash/ObjectPtrEqual"); - return this->data_.get(); - } - // rexpose to friend class - using TParent::MoveToAny; - using TParent::ToAnyView; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(const Variant &src, TVMFFIAny *result) { - *result = src.ToAnyView().CopyToTVMFFIAny(); - } - - TVM_FFI_INLINE static void MoveToAny(Variant src, TVMFFIAny *result) { - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(std::move(src).MoveToAny()); - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { - return TypeTraitsBase::GetMismatchTypeInfo(src); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return (TypeTraits::CheckAnyStrict(src) || ...); - } - - TVM_FFI_INLINE static Variant CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return Variant(Any(AnyView::CopyFromTVMFFIAny(*src))); - } - - TVM_FFI_INLINE static Variant MoveFromAnyAfterCheck(TVMFFIAny *src) { - return Variant(details::AnyUnsafe::MoveTVMFFIAnyToAny(src)); - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny *src) { - // fast path, storage is already in the right type - if (CheckAnyStrict(src)) { - return CopyFromAnyViewAfterCheck(src); - } - // More expensive path, try to convert to each type, in order of declaration - return TryVariantTypes(src); - } - - template - TVM_FFI_INLINE static std::optional> TryVariantTypes(const TVMFFIAny *src) { - if (auto opt_convert = TypeTraits::TryCastFromAnyView(src)) { - 
return Variant(*std::move(opt_convert)); - } - if constexpr (sizeof...(Rest) > 0) { - return TryVariantTypes(src); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return details::ContainerTypeStr("Variant"); } - TVM_FFI_INLINE static std::string TypeSchema() { - std::ostringstream oss; - oss << R"({"type":"Variant","args":[)"; - const char *sep = ""; - ((oss << sep << details::TypeSchema::v(), sep = ","), ...); - oss << "]}"; - return oss.str(); - } -}; - -template -TVM_FFI_INLINE size_t ObjectPtrHash::operator()(const Variant &a) const { - return std::hash()(a.GetObjectPtrForHashEqual()); -} - -template -TVM_FFI_INLINE bool ObjectPtrEqual::operator()(const Variant &a, - const Variant &b) const { - return a.GetObjectPtrForHashEqual() == b.GetObjectPtrForHashEqual(); -} - -namespace details { -template -inline constexpr bool type_contains_v, T> = (type_contains_v || ...); -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_CONTAINER_VARIANT_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/dtype.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/dtype.h deleted file mode 100644 index 32a0da5f9..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/dtype.h +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/dtype.h - * \brief Data type handling. - */ -#ifndef TVM_FFI_DTYPE_H_ -#define TVM_FFI_DTYPE_H_ - -#include "../../dlpack/dlpack.h" -#include "error.h" -#include "function.h" -#include "string.h" -#include "type_traits.h" - -#include -#include - -namespace tvm { -namespace ffi { -/*! - * \brief Extension code beyond the DLDataType. - * - * This class is always consistent with the DLPack. - */ -enum DLExtDataTypeCode { kDLExtCustomBegin = 129 }; - -namespace details { - -/* - * \brief Convert a DLDataTypeCode to a string. - * \param os The output stream. - * \param type_code The DLDataTypeCode to convert. 
- */ -inline const char *DLDataTypeCodeAsCStr(DLDataTypeCode type_code) { // NOLINT(*) - switch (static_cast(type_code)) { - case kDLInt: { - return "int"; - } - case kDLUInt: { - return "uint"; - } - case kDLFloat: { - return "float"; - } - case kDLOpaqueHandle: { - return "handle"; - } - case kDLBfloat: { - return "bfloat"; - } - case kDLBool: { - return "bool"; - } - case kDLFloat8_e3m4: { - return "float8_e3m4"; - } - case kDLFloat8_e4m3: { - return "float8_e4m3"; - } - case kDLFloat8_e4m3b11fnuz: { - return "float8_e4m3b11fnuz"; - } - case kDLFloat8_e4m3fn: { - return "float8_e4m3fn"; - } - case kDLFloat8_e4m3fnuz: { - return "float8_e4m3fnuz"; - } - case kDLFloat8_e5m2: { - return "float8_e5m2"; - } - case kDLFloat8_e5m2fnuz: { - return "float8_e5m2fnuz"; - } - case kDLFloat8_e8m0fnu: { - return "float8_e8m0fnu"; - } - case kDLFloat6_e2m3fn: { - return "float6_e2m3fn"; - } - case kDLFloat6_e3m2fn: { - return "float6_e3m2fn"; - } - case kDLFloat4_e2m1fn: { - return "float4_e2m1fn"; - } - default: { - if (static_cast(type_code) >= static_cast(DLExtDataTypeCode::kDLExtCustomBegin)) { - return "custom"; - } else { - TVM_FFI_THROW(ValueError) << "DLDataType contains unknown type_code=" - << static_cast(type_code); - } - TVM_FFI_UNREACHABLE(); - } - } -} -} // namespace details - -/*! - * \brief Convert a string to a DLDataType. - * \param str The string to convert. - * \return The DLDataType. - */ -inline DLDataType StringToDLDataType(const String &str) { - DLDataType out; - TVMFFIByteArray data{str.data(), str.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeFromString(&data, &out)); - return out; -} - -/*! - * \brief Convert a DLDataType to a string. - * \param dtype The DLDataType to convert. - * \return The string. 
- */ -inline String DLDataTypeToString(DLDataType dtype) { - TVMFFIAny out; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIDataTypeToString(&dtype, &out)); - return TypeTraits::MoveFromAnyAfterCheck(&out); -} - -// DLDataType -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDataType; - - TVM_FFI_INLINE static void CopyToAnyView(const DLDataType &src, TVMFFIAny *result) { - // clear padding part to ensure the equality check can always check the v_uint64 part - result->v_uint64 = 0; - result->type_index = TypeIndex::kTVMFFIDataType; - result->zero_padding = 0; - result->v_dtype = src; - } - - TVM_FFI_INLINE static void MoveToAny(DLDataType src, TVMFFIAny *result) { - // clear padding part to ensure the equality check can always check the v_uint64 part - result->v_uint64 = 0; - result->type_index = TypeIndex::kTVMFFIDataType; - result->zero_padding = 0; - result->v_dtype = src; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index == TypeIndex::kTVMFFIDataType; - } - - TVM_FFI_INLINE static DLDataType CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return src->v_dtype; - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIDataType) { - return src->v_dtype; - } - // enable string to dtype auto conversion - if (auto opt_str = TypeTraits::TryCastFromAnyView(src)) { - return StringToDLDataType(*opt_str); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return ffi::StaticTypeKey::kTVMFFIDataType; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(ffi::StaticTypeKey::kTVMFFIDataType) + R"("})"; - } -}; -} // namespace ffi -} // namespace tvm - -// define DLDataType comparison and printing in root namespace -inline std::ostream &operator<<(std::ostream &os, DLDataType dtype) { // NOLINT(*) - return os << 
tvm::ffi::DLDataTypeToString(dtype); -} - -inline bool operator==(const DLDataType &lhs, const DLDataType &rhs) { - return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; -} - -inline bool operator!=(const DLDataType &lhs, const DLDataType &rhs) { return !(lhs == rhs); } -#endif // TVM_FFI_DTYPE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h deleted file mode 100644 index f8311d673..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/endian.h +++ /dev/null @@ -1,89 +0,0 @@ - -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -/* - * \file tvm/ffi/endian.h - * \brief Endian detection and handling - */ -#ifndef TVM_FFI_ENDIAN_H_ -#define TVM_FFI_ENDIAN_H_ - -#include -#include - -#ifndef TVM_FFI_IO_USE_LITTLE_ENDIAN -#define TVM_FFI_IO_USE_LITTLE_ENDIAN 1 -#endif - -#ifdef TVM_FFI_CMAKE_LITTLE_ENDIAN -// If compiled with CMake, use CMake's endian detection logic -#define TVM_FFI_LITTLE_ENDIAN TVM_FFI_CMAKE_LITTLE_ENDIAN -#else -#if defined(__APPLE__) || defined(_WIN32) -#define TVM_FFI_LITTLE_ENDIAN 1 -#elif defined(__GLIBC__) || defined(__GNU_LIBRARY__) || defined(__ANDROID__) || defined(__RISCV__) || defined(__MUSL__) -#include -#define TVM_FFI_LITTLE_ENDIAN (__BYTE_ORDER == __LITTLE_ENDIAN) -#elif defined(__FreeBSD__) || defined(__OpenBSD__) -#include -#define TVM_FFI_LITTLE_ENDIAN (_BYTE_ORDER == _LITTLE_ENDIAN) -#elif defined(__QNX__) -#include -#define TVM_FFI_LITTLE_ENDIAN (BYTE_ORDER == LITTLE_ENDIAN) -#elif defined(__EMSCRIPTEN__) || defined(__hexagon__) -#define TVM_FFI_LITTLE_ENDIAN 1 -#elif defined(__sun) || defined(sun) -#include -#if defined(_LITTLE_ENDIAN) -#define TVM_FFI_LITTLE_ENDIAN 1 -#else -#define TVM_FFI_LITTLE_ENDIAN 0 -#endif -#else -#error "Unable to determine endianness of your machine; use CMake to compile" -#endif -#endif - -/*! \brief whether serialize using little endian */ -#define TVM_FFI_IO_NO_ENDIAN_SWAP (TVM_FFI_LITTLE_ENDIAN == TVM_FFI_IO_USE_LITTLE_ENDIAN) - -namespace tvm { -namespace ffi { -/*! - * \brief A generic inplace byte swapping function. - * \param data The data pointer. - * \param elem_bytes The number of bytes of the data elements - * \param num_elems Number of elements in the data. 
- * \note Always try pass in constant elem_bytes to enable - * compiler optimization - */ -inline void ByteSwap(void *data, size_t elem_bytes, size_t num_elems) { - for (size_t i = 0; i < num_elems; ++i) { - uint8_t *bptr = reinterpret_cast(data) + elem_bytes * i; - for (size_t j = 0; j < elem_bytes / 2; ++j) { - uint8_t v = bptr[elem_bytes - 1 - j]; - bptr[elem_bytes - 1 - j] = bptr[j]; - bptr[j] = v; - } - } -} -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_ENDIAN_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/error.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/error.h deleted file mode 100644 index f310dcbcc..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/error.h +++ /dev/null @@ -1,398 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/* - * \file tvm/ffi/error.h - * \brief Error handling component. - */ -#ifndef TVM_FFI_ERROR_H_ -#define TVM_FFI_ERROR_H_ - -#include "base_details.h" -#include "c_api.h" -#include "memory.h" -#include "object.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -/*! 
- * \brief Macro defines whether we enable libbacktrace - */ -#ifndef TVM_FFI_USE_LIBBACKTRACE -#define TVM_FFI_USE_LIBBACKTRACE 1 -#endif - -/*! - * \brief Macro defines whether to install signal handler - * and print backtrace during segfault - */ -#ifndef TVM_FFI_BACKTRACE_ON_SEGFAULT -#define TVM_FFI_BACKTRACE_ON_SEGFAULT 1 -#endif - -#ifndef TVM_FFI_ALWAYS_LOG_BEFORE_THROW -#define TVM_FFI_ALWAYS_LOG_BEFORE_THROW 0 -#endif - -namespace tvm { -namespace ffi { - -/*! - * \brief Error already set in frontend env. - * - * This error can be thrown by EnvCheckSignals to indicate - * that there is an error set in the frontend environment(e.g. - * python interpreter). The TVM FFI should catch this error - * and return a proper code to tell the frontend caller about - * this fact. - * - * \code - * - * void ExampleLongRunningFunction() { - * if (TVMFFIEnvCheckSignals() != 0) { - * throw ::tvm::ffi::EnvErrorAlreadySet(); - * } - * // do work here - * } - * - * \endcode - */ -struct EnvErrorAlreadySet : public std::exception {}; - -/*! - * \brief Error object class. - */ -class ErrorObj : public Object, public TVMFFIErrorCell { -public: - /// \cond Doxygen_Suppress - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIError; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIError, ErrorObj, Object); - /// \endcond -}; - -namespace details { -class ErrorObjFromStd : public ErrorObj { -public: - ErrorObjFromStd(std::string kind, std::string message, std::string backtrace) - : kind_data_(std::move(kind)), - message_data_(std::move(message)), - backtrace_data_(std::move(backtrace)) { - this->kind = TVMFFIByteArray{kind_data_.data(), kind_data_.length()}; - this->message = TVMFFIByteArray{message_data_.data(), message_data_.length()}; - this->backtrace = TVMFFIByteArray{backtrace_data_.data(), backtrace_data_.length()}; - this->update_backtrace = UpdateBacktrace; - } - -private: - /*! - * \brief Update the backtrace of the error object. 
- * \param backtrace The backtrace to update. - * \param update_mode The mode to update the backtrace, - * can be either kTVMFFIBacktraceUpdateModeReplace, kTVMFFIBacktraceUpdateModeAppend. - */ - static void UpdateBacktrace(TVMFFIObjectHandle self, const TVMFFIByteArray *backtrace_str, - int32_t update_mode) { - ErrorObjFromStd *obj = static_cast(self); - if (update_mode == kTVMFFIBacktraceUpdateModeReplace) { - obj->backtrace_data_.resize(backtrace_str->size); - std::memcpy(obj->backtrace_data_.data(), backtrace_str->data, backtrace_str->size); - obj->backtrace = TVMFFIByteArray{obj->backtrace_data_.data(), obj->backtrace_data_.length()}; - } else { - obj->backtrace_data_.append(backtrace_str->data, backtrace_str->size); - obj->backtrace = TVMFFIByteArray{obj->backtrace_data_.data(), obj->backtrace_data_.length()}; - } - } - - std::string kind_data_; - std::string message_data_; - std::string backtrace_data_; -}; -} // namespace details - -/*! - * \brief Managed reference to ErrorObj - * \sa Error Object - */ -class Error : public ObjectRef, public std::exception { -public: - /*! - * \brief Constructor - * \param kind The kind of the error. - * \param message The message of the error. - * \param backtrace The backtrace of the error. - */ - Error(std::string kind, std::string message, std::string backtrace) { - data_ = make_object(std::move(kind), std::move(message), - std::move(backtrace)); - } - - /*! - * \brief Constructor - * \param kind The kind of the error. - * \param message The message of the error. - * \param backtrace The backtrace of the error. - */ - Error(std::string kind, std::string message, const TVMFFIByteArray *backtrace) - : Error(std::move(kind), std::move(message), std::string(backtrace->data, backtrace->size)) {} - - /*! - * \brief Get the kind of the error object. - * \return The kind of the error object. 
- */ - std::string kind() const { - ErrorObj *obj = static_cast(data_.get()); - return std::string(obj->kind.data, obj->kind.size); - } - - /*! - * \brief Get the message of the error object. - * \return The message of the error object. - */ - std::string message() const { - ErrorObj *obj = static_cast(data_.get()); - return std::string(obj->message.data, obj->message.size); - } - - /*! - * \brief Get the backtrace of the error object. - * \return The backtrace of the error object. - * \note Consider use TracebackMostRecentCallLast for pythonic style traceback. - * - * \sa TracebackMostRecentCallLast - */ - std::string backtrace() const { - ErrorObj *obj = static_cast(data_.get()); - return std::string(obj->backtrace.data, obj->backtrace.size); - } - - /*! - * \brief Get the traceback in the order of most recent call last. - * - * \return The traceback of the error object. - */ - std::string TracebackMostRecentCallLast() const { - // add placeholder for the first line - std::vector line_breakers = {-1}; - ErrorObj *obj = static_cast(data_.get()); - for (size_t i = 0; i < obj->backtrace.size; i++) { - if (obj->backtrace.data[i] == '\n') { - line_breakers.push_back(static_cast(i)); - } - } - std::string result; - result.reserve(obj->backtrace.size); - for (size_t i = line_breakers.size() - 1; i > 0; --i) { - int64_t line_start = line_breakers[i - 1] + 1; - int64_t line_end = line_breakers[i]; - if (line_start == line_end) { - continue; - } - result.append(obj->backtrace.data + line_start, line_end - line_start); - result.append("\n"); - } - return result; - } - - /*! - * \brief Update the backtrace of the error object. - * \param backtrace_str The backtrace to update. - * \param update_mode The mode to update the backtrace, - * can be either kTVMFFIBacktraceUpdateModeReplace, kTVMFFIBacktraceUpdateModeAppend. 
- */ - void UpdateBacktrace(const TVMFFIByteArray *backtrace_str, int32_t update_mode) { - ErrorObj *obj = static_cast(data_.get()); - obj->update_backtrace(obj, backtrace_str, update_mode); - } - - /*! - * \brief Get the full message of the error, including kind, message and traceback. - * \return The full message of the error object. - */ - std::string FullMessage() const { - ErrorObj *obj = static_cast(data_.get()); - return (std::string("Traceback (most recent call last):\n") + TracebackMostRecentCallLast() + std::string(obj->kind.data, obj->kind.size) + std::string(": ") + std::string(obj->message.data, obj->message.size) + '\n'); - } - - /*! - * \brief Get the error message - * \return The error message - * \note To get the full message including kind and traceback, use FullMessage() instead. - */ - const char *what() const noexcept(true) override { - ErrorObj *obj = static_cast(data_.get()); - return obj->message.data; - } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Error, ObjectRef, ErrorObj); - /// \endcond -}; - -namespace details { - -class ErrorBuilder { -public: - explicit ErrorBuilder(std::string kind, std::string backtrace, bool log_before_throw) - : kind_(std::move(kind)), - backtrace_(std::move(backtrace)), - log_before_throw_(log_before_throw) {} - - explicit ErrorBuilder(std::string kind, const TVMFFIByteArray *backtrace, bool log_before_throw) - : ErrorBuilder(std::move(kind), std::string(backtrace->data, backtrace->size), - log_before_throw) {} - -// MSVC disable warning in error builder as it is exepected -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 4722) -#endif - // avoid inline to reduce binary size, error throw path do not need to be fast - [[noreturn]] ~ErrorBuilder() noexcept(false) { - ::tvm::ffi::Error error(std::move(kind_), stream_.str(), std::move(backtrace_)); - if (log_before_throw_) { - std::cerr << error.FullMessage(); - } - throw error; - } -#ifdef _MSC_VER -#pragma 
warning(pop) -#endif - - std::ostringstream &stream() { return stream_; } - -protected: - std::string kind_; - std::ostringstream stream_; - std::string backtrace_; - bool log_before_throw_; -}; - -} // namespace details - -/*! - * \brief Helper macro to throw an error with backtrace and message - * - * \code - * - * void ThrowError() { - * TVM_FFI_THROW(RuntimeError) << "error message"; - * } - * - * \endcode - */ -#define TVM_FFI_THROW(ErrorKind) \ - ::tvm::ffi::details::ErrorBuilder(#ErrorKind, \ - TVMFFIBacktrace(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0), \ - TVM_FFI_ALWAYS_LOG_BEFORE_THROW) \ - .stream() - -/*! - * \brief Explicitly log error in stderr and then throw the error. - * - * \note This is only necessary on startup functions where we know error - * cannot be caught, and it is better to have a clear log message. - * In most cases, we should use use TVM_FFI_THROW. - */ -#define TVM_FFI_LOG_AND_THROW(ErrorKind) \ - ::tvm::ffi::details::ErrorBuilder( \ - #ErrorKind, TVMFFIBacktrace(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0), true) \ - .stream() - -// Glog style checks with TVM_FFI prefix -// NOTE: we explicitly avoid glog style generic macros (LOG/CHECK) in tvm ffi -// to avoid potential conflict of downstream users who might have their own GLOG style macros -namespace details { - -template -TVM_FFI_INLINE std::unique_ptr LogCheckFormat(const X &x, const Y &y) { - std::ostringstream os; - os << " (" << x << " vs. " << y << ") "; // CHECK_XX(x, y) requires x and y can be serialized to - // string. Use CHECK(x OP y) otherwise. - return std::make_unique(os.str()); -} - -#define TVM_FFI_CHECK_FUNC(name, op) \ - template \ - TVM_FFI_INLINE std::unique_ptr LogCheck##name(const X &x, const Y &y) { \ - if (x op y) \ - return nullptr; \ - return LogCheckFormat(x, y); \ - } \ - TVM_FFI_INLINE std::unique_ptr LogCheck##name(int x, int y) { \ - return LogCheck##name(x, y); \ - } - -// Inline _Pragma in macros does not work reliably on old version of MSVC and -// GCC. 
We wrap all comparisons in a function so that we can use #pragma to -// silence bad comparison warnings. -#if defined(__GNUC__) || defined(__clang__) // GCC and Clang -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wsign-compare" -#elif defined(_MSC_VER) // MSVC -#pragma warning(push) -#pragma warning(disable : 4389) // '==' : signed/unsigned mismatch -#endif - -TVM_FFI_CHECK_FUNC(_LT, <) -TVM_FFI_CHECK_FUNC(_GT, >) -TVM_FFI_CHECK_FUNC(_LE, <=) -TVM_FFI_CHECK_FUNC(_GE, >=) -TVM_FFI_CHECK_FUNC(_EQ, ==) -TVM_FFI_CHECK_FUNC(_NE, !=) - -#if defined(__GNUC__) || defined(__clang__) // GCC and Clang -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) // MSVC -#pragma warning(pop) -#endif -} // namespace details - -#define TVM_FFI_ICHECK_BINARY_OP(name, op, x, y) \ - if (auto __tvm_ffi_log_err = /* NOLINT(bugprone-reserved-identifier) */ \ - ::tvm::ffi::details::LogCheck##name(x, y)) \ - TVM_FFI_THROW(InternalError) << "Check failed: " << #x " " #op " " #y << *__tvm_ffi_log_err \ - << ": " - -#define TVM_FFI_ICHECK(x) \ - if (!(x)) \ - TVM_FFI_THROW(InternalError) << "Check failed: (" #x << ") is false: " - -#define TVM_FFI_CHECK(cond, ErrorKind) \ - if (!(cond)) \ - TVM_FFI_THROW(ErrorKind) << "Check failed: (" #cond << ") is false: " - -#define TVM_FFI_ICHECK_LT(x, y) TVM_FFI_ICHECK_BINARY_OP(_LT, <, x, y) -#define TVM_FFI_ICHECK_GT(x, y) TVM_FFI_ICHECK_BINARY_OP(_GT, >, x, y) -#define TVM_FFI_ICHECK_LE(x, y) TVM_FFI_ICHECK_BINARY_OP(_LE, <=, x, y) -#define TVM_FFI_ICHECK_GE(x, y) TVM_FFI_ICHECK_BINARY_OP(_GE, >=, x, y) -#define TVM_FFI_ICHECK_EQ(x, y) TVM_FFI_ICHECK_BINARY_OP(_EQ, ==, x, y) -#define TVM_FFI_ICHECK_NE(x, y) TVM_FFI_ICHECK_BINARY_OP(_NE, !=, x, y) -#define TVM_FFI_ICHECK_NOTNULL(x) \ - ((x) == nullptr ? 
TVM_FFI_THROW(InternalError) << "Check not null: " #x << ' ', \ - (x) : (x)) // NOLINT(*) -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_ERROR_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h deleted file mode 100644 index 852db4466..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/base.h - * \brief Base header for Extra API. - * - * The extra APIs contains a minmal set of extra APIs that are not - * required to support essential core functionality. - */ -#ifndef TVM_FFI_EXTRA_BASE_H_ -#define TVM_FFI_EXTRA_BASE_H_ - -#include - -/*! - * \brief Marks the API as extra c++ api that is defined in cc files. - * - * They are implemented in cc files to reduce compile-time overhead. - * The input/output only uses POD/Any/ObjectRef for ABI stability. 
- * However, these extra APIs may have an issue across MSVC/Itanium ABI, - * - * Related features are also available through reflection based function - * that is fully based on C API - * - * The project aims to minimize the number of extra C++ APIs to keep things - * lightweight and restrict the use to non-core functionalities. - */ -#ifndef TVM_FFI_EXTRA_CXX_API -#define TVM_FFI_EXTRA_CXX_API TVM_FFI_DLL -#endif - -#endif // TVM_FFI_EXTRA_BASE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h deleted file mode 100644 index f23314a75..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/base64.h +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * - * \file tvm/ffi/extra/base64.h - * \brief Base64 encoding and decoding utilities - */ -#ifndef TVM_FFI_EXTRA_BASE64_H_ -#define TVM_FFI_EXTRA_BASE64_H_ - -#include - -#include - -namespace tvm { -namespace ffi { -/*! 
- * \brief Encode a byte array into a base64 string - * \param bytes The byte array to encode - * \return The base64 encoded string - */ -inline String Base64Encode(TVMFFIByteArray bytes) { - // encoding every 3 bytes into 4 characters - constexpr const char kEncodeTable[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - std::string encoded; - encoded.reserve(4 * (bytes.size + 2) / 3); - - for (size_t i = 0; i < (bytes.size / 3) * 3; i += 3) { - int32_t buf[3]; - buf[0] = static_cast(static_cast(bytes.data[i])); - buf[1] = static_cast(static_cast(bytes.data[i + 1])); - buf[2] = static_cast(static_cast(bytes.data[i + 2])); - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); - encoded.push_back(kEncodeTable[((buf[1] << 2) | (buf[2] >> 6)) & 0x3F]); - encoded.push_back(kEncodeTable[buf[2] & 0x3F]); - } - if (bytes.size % 3 == 1) { - int32_t buf[1] = {static_cast(static_cast(bytes.data[bytes.size - 1]))}; - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[(buf[0] << 4) & 0x3F]); - encoded.push_back('='); - encoded.push_back('='); - } else if (bytes.size % 3 == 2) { - int32_t buf[2] = {static_cast(static_cast(bytes.data[bytes.size - 2])), - static_cast(static_cast(bytes.data[bytes.size - 1]))}; - encoded.push_back(kEncodeTable[buf[0] >> 2]); - encoded.push_back(kEncodeTable[((buf[0] << 4) | (buf[1] >> 4)) & 0x3F]); - encoded.push_back(kEncodeTable[(buf[1] << 2) & 0x3F]); - encoded.push_back('='); - } - return String(encoded); -} - -/*! - * \brief Encode a bytes object into a base64 string - * \param data The bytes object to encode - * \return The base64 encoded string - */ -inline String Base64Encode(const Bytes &data) { - return Base64Encode(TVMFFIByteArray{data.data(), data.size()}); -} - -/*! 
- * \brief Decode a base64 string into a byte array - * \param bytes The bytes to be decoded - * \return The decoded byte array - */ -inline Bytes Base64Decode(TVMFFIByteArray bytes) { - constexpr const char kDecodeTable[] = { - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 62, // '+' - 0, 0, 0, - 63, // '/' - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' - 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' - 0, 0, 0, 0, 0, 0, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, - 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' - }; - std::string decoded; - decoded.reserve(bytes.size * 3 / 4); - if (bytes.size == 0) { - return Bytes(); - } - TVM_FFI_ICHECK(bytes.size % 4 == 0) << "invalid base64 encoding"; - // leverage this property to simplify decoding - static_assert('=' < sizeof(kDecodeTable) && kDecodeTable[static_cast('=')] == 0); - // base64 is always multiple of 4 bytes - for (size_t i = 0; i < bytes.size; i += 4) { - // decode every 4 characters into 24bits, each character contains 6 bits - // note that = is also decoded as 0, which is safe to skip - int32_t buf[4] = { - static_cast(static_cast(bytes.data[i])), - static_cast(static_cast(bytes.data[i + 1])), - static_cast(static_cast(bytes.data[i + 2])), - static_cast(static_cast(bytes.data[i + 3])), - }; - int32_t value_i24 = (static_cast(kDecodeTable[buf[0]]) << 18) | (static_cast(kDecodeTable[buf[1]]) << 12) | (static_cast(kDecodeTable[buf[2]]) << 6) | static_cast(kDecodeTable[buf[3]]); - // unpack 24bits into 3 bytes, each contains 8 bits - decoded.push_back(static_cast((value_i24 >> 16) & 0xFF)); - if (buf[2] != '=') { - decoded.push_back(static_cast((value_i24 >> 8) & 0xFF)); - } - if (buf[3] != '=') { - decoded.push_back(static_cast(value_i24 & 0xFF)); - } - } - return Bytes(decoded); -} - -/*! 
- * \brief Decode a base64 string into a byte array - * \param data The base64 encoded string to decode - * \return The decoded byte array - */ -inline Bytes Base64Decode(const String &data) { - return Base64Decode(TVMFFIByteArray{data.data(), data.size()}); -} - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_BASE64_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/c_env_api.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/c_env_api.h deleted file mode 100644 index 9f879705c..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/c_env_api.h +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -// NOLINTBEGIN(modernize-use-using) -/*! - * \file tvm/ffi/extra/c_env_api.h - * \brief Extra environment API. - */ -#ifndef TVM_FFI_EXTRA_C_ENV_API_H_ -#define TVM_FFI_EXTRA_C_ENV_API_H_ - -#include "../c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -// ---------------------------------------------------------------------------- -// Stream context -// Focusing on minimalistic thread-local context recording stream being used. -// We explicitly not handle allocation/de-allocation of stream here. 
-// ---------------------------------------------------------------------------- -/*! - * \brief The type of the stream handle. - */ -typedef void *TVMFFIStreamHandle; - -/*! - * \brief FFI function to set the current stream for a device - * - * \param device_type The type of the device. - * \param device_id The id of the device. - * \param stream The stream to set. - * \param opt_out_original_stream Output original stream if the address is not nullptr. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvSetStream(int32_t device_type, int32_t device_id, - TVMFFIStreamHandle stream, - TVMFFIStreamHandle *opt_out_original_stream); - -/*! - * \brief FFI function to get the current stream for a device - * - * \param device_type The type of the device. - * \param device_id The id of the device. - * \return The current stream of the device. - */ -TVM_FFI_DLL TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id); - -/*! - * \brief Set the current DLPackManagedTensorAllocator in thread-local(TLS) context - * - * \param allocator The allocator to set. - * \param write_to_global_context Whether to also set the allocator to the global context. - * \param opt_out_original_allocator Output original TLS allocator if the address is not nullptr. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvSetDLPackManagedTensorAllocator( - DLPackManagedTensorAllocator allocator, int write_to_global_context, - DLPackManagedTensorAllocator *opt_out_original_allocator); - -/*! - * \brief FFI function get the current DLPackManagedTensorAllocator stored in context. - * - * This function first queries the global context, and if not found, - * queries the thread-local context. - * - * \return The current setted DLPackManagedTensorAllocator - */ -TVM_FFI_DLL DLPackManagedTensorAllocator TVMFFIEnvGetDLPackManagedTensorAllocator(); - -/*! 
- * \brief Allocate a tensor from the allocator set in thread-local(TLS) context. - * - * This function redirects to one of environment allocator. As of now, we only - * support the DLPackManagedTensorAllocator set in thread-local(TLS) context. - * - * \param prototype The prototype DLTensor, only the dtype, ndim, shape, - * and device fields are used, other fields are ignored. - * \param out The output tensor in kTVMFFITensor type. - * \return 0 when success, nonzero when failure happens - * \sa TVMFFIEnvSetDLPackManagedTensorAllocator - */ -TVM_FFI_DLL int TVMFFIEnvTensorAlloc(DLTensor *prototype, TVMFFIObjectHandle *out); - -/*! - * \brief Check if there are any signals raised in the surrounding env. - * \return 0 when success, nonzero when failure happens - * \note Under python this function redirects to PyErr_CheckSignals - */ -TVM_FFI_DLL int TVMFFIEnvCheckSignals(); - -/*! - * \brief Register a symbol into the from the surrounding env such as python - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvRegisterCAPI(const char *name, void *symbol); - -// ---------------------------------------------------------------------------- -// Module symbol management in callee side -// ---------------------------------------------------------------------------- -/*! - * \brief FFI function to lookup a function from a module's imports. - * - * This is a helper function that is used by generated code. - * - * \param library_ctx The library context module handle. - * \param func_name The name of the function. - * \param out The result function. - * \note The returned function is a weak reference that is cached/owned by the module. - * \return 0 when no error is thrown, -1 when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvModLookupFromImports(TVMFFIObjectHandle library_ctx, const char *func_name, - TVMFFIObjectHandle *out); - -/*! 
- * \brief Register a symbol value that will be initialized when a library with the symbol is loaded. - * - * This function can be used to make context functions to be available in the library - * module that wants to avoid an explicit link dependency - * - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvModRegisterContextSymbol(const char *name, void *symbol); - -/*! - * \brief Register a symbol that will be initialized when a system library is loaded. - * - * \param name The name of the symbol. - * \param symbol The symbol to register. - * \return 0 when success, nonzero when failure happens - */ -TVM_FFI_DLL int TVMFFIEnvModRegisterSystemLibSymbol(const char *name, void *symbol); - -#ifdef __cplusplus -} // extern "C" -#endif -#endif // TVM_FFI_EXTRA_C_ENV_API_H_ -// NOLINTEND(modernize-use-using) diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h deleted file mode 100644 index be7cad2d8..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/base.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/cuda/base.h - * \brief CUDA base utilities. - */ -#ifndef TVM_FFI_EXTRA_CUDA_BASE_H_ -#define TVM_FFI_EXTRA_CUDA_BASE_H_ - -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Macro for checking CUDA runtime API errors. - * - * This macro checks the return value of CUDA runtime API calls and throws - * a RuntimeError with detailed error information if the call fails. - * - * \param stmt The CUDA runtime API call to check. - */ -#define TVM_FFI_CHECK_CUDA_ERROR(stmt) \ - do { \ - cudaError_t __err = (stmt); \ - if (__err != cudaSuccess) { \ - const char *__err_name = cudaGetErrorName(__err); \ - const char *__err_str = cudaGetErrorString(__err); \ - TVM_FFI_THROW(RuntimeError) << "CUDA Runtime Error: " << __err_name << " (" \ - << static_cast(__err) << "): " << __err_str; \ - } \ - } while (0) - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_EXTRA_CUDA_BASE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h deleted file mode 100644 index 10da7e532..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/cubin_launcher.h +++ /dev/null @@ -1,604 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/cuda/cubin_launcher.h - * \brief CUDA CUBIN launcher utility for loading and executing CUDA kernels. - * - * This header provides a lightweight C++ wrapper around CUDA Runtime API - * for loading CUBIN modules and launching kernels. It supports: - * - Loading CUBIN from memory (embedded data) - * - Multi-GPU execution using CUDA primary contexts - * - Kernel parameter management and launch configuration - */ -#ifndef TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_ -#define TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_ - -#include -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief A simple 3D dimension type for CUDA kernel launch configuration. - * - * This struct mimics the behavior of dim3 from CUDA Runtime API and provides - * a compatible interface for kernel launch configuration. It can be constructed - * from 1, 2, or 3 dimensions. - */ -struct dim3 { - /*! \brief X dimension (number of blocks in x-direction or threads in x-direction) */ - unsigned int x; - /*! \brief Y dimension (number of blocks in y-direction or threads in y-direction) */ - unsigned int y; - /*! \brief Z dimension (number of blocks in z-direction or threads in z-direction) */ - unsigned int z; - - /*! \brief Default constructor initializes to (1, 1, 1) */ - dim3() : x(1), y(1), z(1) {} - - /*! \brief Construct with x dimension, y and z default to 1 */ - explicit dim3(unsigned int x_) : x(x_), y(1), z(1) {} - - /*! 
\brief Construct with x and y dimensions, z defaults to 1 */ - dim3(unsigned int x_, unsigned int y_) : x(x_), y(y_), z(1) {} - - /*! \brief Construct with all three dimensions */ - dim3(unsigned int x_, unsigned int y_, unsigned int z_) : x(x_), y(y_), z(z_) {} -}; - -/*! - * \brief Macro to embed a CUBIN module with static initialization. - * - * This macro declares external symbols for embedded CUBIN data and creates - * a singleton struct to manage the CubinModule instance. The CUBIN data - * symbols should be named `__tvm_ffi__cubin_` and `__tvm_ffi__cubin__end`, - * typically created using objcopy and ld. - * - * \par Creating Embedded CUBIN with TVM-FFI Utilities - * TVM-FFI provides utilities to simplify CUBIN embedding. You have two options: - * - * \par Option 1: CMake Utility (Recommended) - * Use the `tvm_ffi_embed_cubin` CMake function: - * \code{.cmake} - * # Find tvm_ffi package (provides tvm_ffi_embed_cubin utility) - * find_package(tvm_ffi CONFIG REQUIRED) - * find_package(CUDAToolkit REQUIRED) - * - * # Compile CUDA kernel to CUBIN - * tvm_ffi_generate_cubin( - * OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin - * SOURCE src/kernel.cu - * ARCH native # or sm_75, sm_80, etc. 
- * ) - * - * # Embed CUBIN into C++ object file - * tvm_ffi_embed_cubin( - * OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/mycode_with_cubin.o - * SOURCE src/mycode.cc - * CUBIN ${CMAKE_CURRENT_BINARY_DIR}/kernel.cubin - * NAME my_kernels # Must match TVM_FFI_EMBED_CUBIN(my_kernels) in code - * ) - * - * # Link into shared library - * add_library(mylib SHARED ${CMAKE_CURRENT_BINARY_DIR}/mycode_with_cubin.o) - * target_link_libraries(mylib PRIVATE tvm_ffi_header CUDA::cudart) - * \endcode - * - * \par Option 2: Python Utility - * Use the `tvm_ffi.utils.embed_cubin` command-line tool: - * \code{.bash} - * # Step 1: Compile CUDA kernel to CUBIN - * nvcc --cubin -arch=sm_75 kernel.cu -o kernel.cubin - * - * # Step 2: Compile C++ source to object file - * g++ -c -fPIC -std=c++17 -I/path/to/tvm-ffi/include mycode.cc -o mycode.o - * - * # Step 3: Embed CUBIN using Python utility - * python -m tvm_ffi.utils.embed_cubin \ - * --output-obj mycode_with_cubin.o \ - * --input-obj mycode.o \ - * --cubin kernel.cubin \ - * --name my_kernels - * - * # Step 4: Link into shared library - * g++ -o mylib.so -shared mycode_with_cubin.o -lcudart - * \endcode - * - * The utilities automatically handle: - * - Symbol renaming to __tvm_ffi__cubin_ format - * - Adding .note.GNU-stack section for security - * - Symbol localization to prevent conflicts - * - * \par Usage in C++ Code - * In your C++ source file, use the embedded CUBIN: - * \code{.cpp} - * #include - * - * // Declare the embedded CUBIN module (name must match CMake NAME parameter) - * TVM_FFI_EMBED_CUBIN(my_kernels); - * - * void MyFunction() { - * // Get kernel from embedded CUBIN (cached in static variable for efficiency) - * static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "my_kernel"); - * // Use kernel... 
- * } - * \endcode - * - * \note CMake Setup: To use the utilities, add to your CMakeLists.txt: - * \code{.cmake} - * find_package(tvm_ffi CONFIG REQUIRED) # Provides tvm_ffi_embed_cubin utility - * \endcode - * - * \par Option 3: Python Integration with load_inline - * When using `tvm_ffi.cpp.load_inline()` with the `embed_cubin` parameter, - * the CUBIN data is automatically embedded using the Python utility internally: - * \code{.py} - * from tvm_ffi import cpp - * from tvm_ffi.cpp import nvrtc - * - * # Compile CUDA source to CUBIN - * cubin_bytes = nvrtc.nvrtc_compile(cuda_source) - * - * # Load with embedded CUBIN - automatically handles embedding - * mod = cpp.load_inline( - * "my_module", - * cuda_sources=cpp_code, - * embed_cubin={"my_kernels": cubin_bytes}, - * extra_ldflags=["-lcudart"] - * ) - * \endcode - * - * \param name The identifier for this embedded CUBIN module (must match the - * symbol names created with objcopy or the key in embed_cubin dict). - * - * \see TVM_FFI_EMBED_CUBIN_GET_KERNEL - * \see CubinModule - * \see CubinKernel - */ -#define TVM_FFI_EMBED_CUBIN(name) \ - extern "C" const char __tvm_ffi__cubin_##name[]; \ - extern "C" const char __tvm_ffi__cubin_##name##_end[]; \ - namespace { \ - struct EmbedCubinModule_##name { \ - tvm::ffi::CubinModule mod{__tvm_ffi__cubin_##name}; \ - static EmbedCubinModule_##name *Global() { \ - static EmbedCubinModule_##name inst; \ - return &inst; \ - } \ - }; \ - } /* anonymous namespace */ - -/*! - * \brief Macro to get a kernel from an embedded CUBIN module. - * - * This macro retrieves a kernel by name from a previously declared embedded - * CUBIN module (using TVM_FFI_EMBED_CUBIN). The result is a CubinKernel object - * that can be used to launch the kernel with specified parameters. 
- * - * \par Performance Tip - * It's recommended to store the result in a static variable to avoid repeated - * kernel lookups, which improves performance: - * \code{.cpp} - * static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "kernel_name"); - * \endcode - * - * \par Complete Example - * \code{.cpp} - * // Declare embedded CUBIN module - * TVM_FFI_EMBED_CUBIN(my_kernels); - * - * void LaunchKernel(tvm::ffi::TensorView input, tvm::ffi::TensorView output) { - * // Get kernel (cached in static variable for efficiency) - * static auto kernel = TVM_FFI_EMBED_CUBIN_GET_KERNEL(my_kernels, "add_one"); - * - * // Prepare kernel arguments - * void* in_ptr = input.data_ptr(); - * void* out_ptr = output.data_ptr(); - * int64_t n = input.size(0); - * void* args[] = {&in_ptr, &out_ptr, &n}; - * - * // Configure launch - * tvm::ffi::dim3 grid((n + 255) / 256); - * tvm::ffi::dim3 block(256); - * - * // Get stream and launch - * DLDevice device = input.device(); - * cudaStream_t stream = static_cast( - * TVMFFIEnvGetStream(device.device_type, device.device_id)); - * - * cudaError_t result = kernel.Launch(args, grid, block, stream); - * TVM_FFI_CHECK_CUDA_ERROR(result); - * } - * \endcode - * - * \param name The identifier of the embedded CUBIN module (must match the name - * used in TVM_FFI_EMBED_CUBIN). - * \param kernel_name The name of the kernel function as it appears in the CUBIN - * (typically the function name for `extern "C"` kernels). - * \return A CubinKernel object for the specified kernel. - * - * \see TVM_FFI_EMBED_CUBIN - * \see CubinKernel::Launch - */ -#define TVM_FFI_EMBED_CUBIN_GET_KERNEL(name, kernel_name) \ - (EmbedCubinModule_##name::Global()->mod[kernel_name]) - -// Forward declaration -class CubinKernel; - -/*! - * \brief CUDA CUBIN module loader and manager. - * - * This class provides a RAII wrapper around CUDA Runtime API's library management. - * It loads a CUBIN module from memory and manages the library handle automatically. 
- * The library is unloaded when the CubinModule object is destroyed. - * - * \par Features - * - Load CUBIN from memory (embedded data or runtime-generated) - * - Automatic resource management (RAII pattern) - * - Multi-GPU execution using CUDA primary contexts - * - Retrieve multiple kernels from the same module - * - * \par Example Usage - * \code{.cpp} - * // Load CUBIN from memory - * tvm::ffi::Bytes cubin_data = ...; - * tvm::ffi::CubinModule module(cubin_data); - * - * // Get kernels by name - * tvm::ffi::CubinKernel kernel1 = module["add_one"]; - * tvm::ffi::CubinKernel kernel2 = module.GetKernel("mul_two"); - * - * // Launch kernels - * void* args[] = {...}; - * tvm::ffi::dim3 grid(32), block(256); - * cudaStream_t stream = ...; - * kernel1.Launch(args, grid, block, stream); - * \endcode - * - * \note This class is movable but not copyable. - * \see TVM_FFI_EMBED_CUBIN for embedding CUBIN at compile time - * \see CubinKernel for kernel launching - */ -class CubinModule { -public: - /*! - * \brief Load CUBIN module from memory. - * - * \param bytes CUBIN binary data as a Bytes object. - */ - explicit CubinModule(const Bytes &bytes) { - TVM_FFI_CHECK_CUDA_ERROR( - cudaLibraryLoadData(&library_, bytes.data(), nullptr, nullptr, 0, nullptr, nullptr, 0)); - } - - /*! - * \brief Load CUBIN module from raw memory buffer. - * - * \param code Pointer to CUBIN binary data. - * \note The `code` buffer points to an ELF image. - */ - explicit CubinModule(const char *code) { - TVM_FFI_CHECK_CUDA_ERROR( - cudaLibraryLoadData(&library_, code, nullptr, nullptr, 0, nullptr, nullptr, 0)); - } - - /*! \brief Destructor unloads the library */ - ~CubinModule() { - if (library_ != nullptr) { - cudaLibraryUnload(library_); - } - } - - /*! - * \brief Get a kernel function from the module by name. - * - * \param name Name of the kernel function. - * \return CubinKernel object representing the loaded kernel. - */ - CubinKernel GetKernel(const char *name); - - /*! 
- * \brief Get a kernel function from the module by name with maximum dynamic shared memory. - * - * \param name Name of the kernel function. - * \param dynamic_smem_max Maximum dynamic shared memory in bytes to set for this kernel. - * -1 (default) means maximum available dynamic shared memory - * (device max - static shared memory used by kernel). - * \return CubinKernel object representing the loaded kernel. - */ - CubinKernel GetKernelWithMaxDynamicSharedMemory(const char *name, int64_t dynamic_smem_max); - - /*! - * \brief Operator[] for convenient kernel access. - * - * It's equivalent to calling GetKernel(name, -1). - * - * \param name Name of the kernel function. - * \return CubinKernel object representing the loaded kernel. - */ - CubinKernel operator[](const char *name); - - /*! \brief Get the underlying cudaLibrary_t handle */ - cudaLibrary_t GetHandle() const { return library_; } - - // Non-copyable - CubinModule(const CubinModule &) = delete; - CubinModule &operator=(const CubinModule &) = delete; - - /*! - * \brief Move constructor for CubinModule. - * - * Transfers ownership of the CUDA library handle from another CubinModule instance. - * - * \param other The source CubinModule to move from (will be left in an empty state). - */ - CubinModule(CubinModule &&other) noexcept : library_(other.library_) { other.library_ = nullptr; } - - /*! - * \brief Move assignment operator for CubinModule. - * - * Transfers ownership of the CUDA library handle from another CubinModule instance. - * Cleans up any existing library handle in this instance before taking ownership. - * - * \param other The source CubinModule to move from (will be left in an empty state). - * \return Reference to this CubinModule. 
- */ - CubinModule &operator=(CubinModule &&other) noexcept { - if (this != &other) { - if (library_ != nullptr) { - cudaLibraryUnload(library_); - } - library_ = other.library_; - other.library_ = nullptr; - } - return *this; - } - -private: - cudaLibrary_t library_ = nullptr; -}; - -/*! - * \brief CUDA kernel handle for launching kernels. - * - * This class represents a loaded CUDA kernel function and provides - * methods to launch it with specified grid/block dimensions, arguments, - * and stream configuration. Obtained from CubinModule by kernel name. - * - * \par Usage Pattern - * \code{.cpp} - * // Get kernel from module - * tvm::ffi::CubinKernel kernel = module["kernel_name"]; - * - * // Prepare arguments (must be pointers to actual values) - * void* data_ptr = tensor.data_ptr(); - * int64_t size = tensor.size(0); - * void* args[] = {&data_ptr, &size}; - * - * // Configure launch dimensions - * tvm::ffi::dim3 grid(32); // 32 blocks - * tvm::ffi::dim3 block(256); // 256 threads per block - * - * // Launch on stream - * cudaStream_t stream = ...; - * cudaError_t result = kernel.Launch(args, grid, block, stream); - * TVM_FFI_CHECK_CUDA_ERROR(result); - * \endcode - * - * \note This class is movable but not copyable. - * \see CubinModule for loading CUBIN and getting kernels - * \see dim3 for grid/block dimension specification - */ -class CubinKernel { -public: - /*! - * \brief Construct a CubinKernel from a library and kernel name. - * - * \param library The cudaLibrary_t handle. - * \param name Name of the kernel function. - */ - CubinKernel(cudaLibrary_t library, const char *name) { - TVM_FFI_CHECK_CUDA_ERROR(cudaLibraryGetKernel(&kernel_, library, name)); - } - - /*! \brief Destructor (kernel handle doesn't need explicit cleanup) */ - ~CubinKernel() = default; - - /*! - * \brief Launch the kernel with specified parameters. - * - * This function launches the kernel on the current CUDA context/device using - * the CUDA Runtime API. 
The kernel executes asynchronously on the specified stream. - * - * \par Argument Preparation - * The `args` array must contain pointers to the actual argument values, not the - * values themselves. For example: - * \code{.cpp} - * void* data_ptr = tensor.data_ptr(); - * int64_t size = 100; - * void* args[] = {&data_ptr, &size}; // Note: addresses of the variables - * \endcode - * - * \par Launch Configuration - * Grid and block dimensions determine the kernel's parallelism: - * - Grid: Number of thread blocks (can be 1D, 2D, or 3D) - * - Block: Number of threads per block (can be 1D, 2D, or 3D) - * - Total threads = grid.x * grid.y * grid.z * block.x * block.y * block.z - * - * \par Error Checking - * Always check the returned cudaError_t: - * \code{.cpp} - * cudaError_t result = kernel.Launch(args, grid, block, stream); - * TVM_FFI_CHECK_CUDA_ERROR(result); - * \endcode - * - * \param args Array of pointers to kernel arguments (must point to actual values). - * \param grid Grid dimensions (number of blocks in x, y, z). - * \param block Block dimensions (threads per block in x, y, z). - * \param stream CUDA stream to launch the kernel on (use 0 for default stream). - * \param dyn_smem_bytes Dynamic shared memory size in bytes (default: 0). - * \return cudaError_t error code from cudaLaunchKernel (cudaSuccess on success). - * - * \note The kernel executes asynchronously. Use cudaStreamSynchronize() or - * cudaDeviceSynchronize() to wait for completion if needed. - */ - cudaError_t Launch(void **args, dim3 grid, dim3 block, cudaStream_t stream, - uint32_t dyn_smem_bytes = 0) { - // Cast cudaKernel_t to const void* for use with cudaLaunchKernel - // The Runtime API accepts cudaKernel_t directly as a function pointer - auto kernel = reinterpret_cast(kernel_); - return cudaLaunchKernel(kernel, {grid.x, grid.y, grid.z}, {block.x, block.y, block.z}, args, - dyn_smem_bytes, stream); - } - - /*! 
\brief Get the underlying cudaKernel_t handle */ - cudaKernel_t GetHandle() const { return kernel_; } - - // Non-copyable - CubinKernel(const CubinKernel &) = delete; - CubinKernel &operator=(const CubinKernel &) = delete; - - /*! - * \brief Move constructor for CubinKernel. - * - * Transfers ownership of the CUDA kernel handle from another CubinKernel instance. - * - * \param other The source CubinKernel to move from (will be left in an empty state). - */ - CubinKernel(CubinKernel &&other) noexcept : kernel_(other.kernel_) { other.kernel_ = nullptr; } - - /*! - * \brief Move assignment operator for CubinKernel. - * - * Transfers ownership of the CUDA kernel handle from another CubinKernel instance. - * - * \param other The source CubinKernel to move from (will be left in an empty state). - * \return Reference to this CubinKernel. - */ - CubinKernel &operator=(CubinKernel &&other) noexcept { - if (this != &other) { - kernel_ = other.kernel_; - other.kernel_ = nullptr; - } - return *this; - } - -private: - /*! - * \brief Set maximum dynamic shared memory for this kernel across all devices. - * - * This method configures the maximum dynamic shared memory that can be allocated - * when launching this kernel. It must be called after the kernel is loaded. - * - * \param dynamic_smem_max Maximum dynamic shared memory in bytes to set. - * -1 (default) means maximum available dynamic shared memory, - * which is computed as (device max shared memory - static shared memory). - * For -1, the method queries the kernel's static shared memory usage - * and sets the attribute to the remaining available shared memory. - * - * \note This sets the maximum cap but doesn't force allocation. The actual dynamic - * shared memory used is controlled by the dyn_smem_bytes parameter in Launch(). - * \note This method attempts to set the attribute for all available devices and will - * only throw an error if it fails for ALL devices. 
- */ - void SetMaxDynamicSharedMemory(int64_t dynamic_smem_max = -1) { - int device_count = 0; - cudaError_t err = cudaGetDeviceCount(&device_count); - if (err != cudaSuccess || device_count == 0) { - return; // No devices available, nothing to configure - } - - bool any_success = false; - for (int device_id = 0; device_id < device_count; ++device_id) { - // Query device's maximum shared memory per block - int max_shared_mem = 0; - err = cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlock, device_id); - if (err != cudaSuccess) { - continue; // Skip this device if we can't get its attribute - } - - int shared_mem_to_set; - if (dynamic_smem_max == -1) { - // Query the kernel's static shared memory usage - cudaFuncAttributes func_attr; - - // According to the documentation, we can use cudaFuncGetAttributes to get the attributes of - // cudaKernel_t returned by cudaLibraryGetKernel, just cast the kernel_ to const void* - err = cudaFuncGetAttributes(&func_attr, reinterpret_cast(kernel_)); - if (err != cudaSuccess) { - continue; // Skip this device if we can't get kernel attributes - } - - // Calculate available dynamic shared memory: - // device max shared memory - static shared memory used by kernel - int64_t static_shared = static_cast(func_attr.sharedSizeBytes); - int64_t max_shared = static_cast(max_shared_mem); - int64_t available = max_shared - static_shared; - shared_mem_to_set = (available > 0) ? 
static_cast(available) : 0; - } else { - shared_mem_to_set = static_cast(dynamic_smem_max); - } - - // Set the maximum dynamic shared memory size for this device - err = cudaKernelSetAttributeForDevice(kernel_, cudaFuncAttributeMaxDynamicSharedMemorySize, - shared_mem_to_set, device_id); - if (err == cudaSuccess) { - any_success = true; - } - // Don't error out for individual device failures - user may only use some GPUs - } - - // Only error out if setting failed for ALL devices - if (!any_success && device_count > 0) { - TVM_FFI_THROW(RuntimeError) << "Failed to set dynamic shared memory attribute for any device"; - } - } - - cudaKernel_t kernel_ = nullptr; - - friend class CubinModule; -}; - -// Implementation of CubinModule methods that return CubinKernel -inline CubinKernel CubinModule::GetKernelWithMaxDynamicSharedMemory(const char *name, - int64_t dynamic_smem_max = -1) { - auto kernel = CubinKernel(library_, name); - kernel.SetMaxDynamicSharedMemory(dynamic_smem_max); - return kernel; -} - -inline CubinKernel CubinModule::GetKernel(const char *name) { - auto kernel = CubinKernel(library_, name); - return kernel; -} - -inline CubinKernel CubinModule::operator[](const char *name) { return GetKernel(name); } - -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h deleted file mode 100644 index b55b5f3b2..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/cuda/device_guard.h +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/cuda/device_guard.h - * \brief Device guard structs. - */ -#ifndef TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_ -#define TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_ - -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief CUDA Device Guard. On construction, it calls `cudaGetDevice` to set the CUDA device to - * target index, and stores the original current device index. And on destruction, it sets the - * current CUDA device back to original device index. - * - * Example usage: - * \code - * void kernel(ffi::TensorView x) { - * ffi::CUDADeviceGuard guard(x.device().device_id); - * ... - * } - * \endcode - */ -struct CUDADeviceGuard { - CUDADeviceGuard() = delete; - /*! - * \brief Constructor from a device index, and store the original device index. - * \param device_index The device index to guard. - */ - explicit CUDADeviceGuard(int device_index) { - target_device_index_ = device_index; - TVM_FFI_CHECK_CUDA_ERROR(cudaGetDevice(&original_device_index_)); - if (target_device_index_ != original_device_index_) { - TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(device_index)); - } - } - - /*! - * \brief Destructor to set the current device index back to original one if different. 
- */ - ~CUDADeviceGuard() noexcept(false) { - if (original_device_index_ != target_device_index_) { - TVM_FFI_CHECK_CUDA_ERROR(cudaSetDevice(original_device_index_)); - } - } - -private: - int original_device_index_; - int target_device_index_; -}; - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h deleted file mode 100644 index c871ae827..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/json.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/json.h - * \brief Minimal lightweight JSON parsing and serialization utilities - */ -#ifndef TVM_FFI_EXTRA_JSON_H_ -#define TVM_FFI_EXTRA_JSON_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace json { - -/*! - * \brief alias Any as json Value. - * - * To keep things lightweight, we simply reuse the ffi::Any system. - */ -using Value = Any; - -/*! - * \brief alias Map as json Object. 
- * \note We use Map instead of Map to avoid - * the overhead of key checking when doing as conversion, - * the check will be performed at runtime when we read each key - */ -using Object = ffi::Map; - -/*! \brief alias Array as json Array. */ -using Array = ffi::Array; - -/*! - * \brief Parse a JSON string into an Any value. - * - * Besides the standard JSON syntax, this function also supports: - * - Infinity/NaN as JavaScript syntax - * - int64 integer value - * - * If error_msg is not nullptr, the error message will be written to it - * and no exception will be thrown when parsing fails. - * - * \param json_str The JSON string to parse. - * \param error_msg The output error message, can be nullptr. - * - * \return The parsed Any value. - */ -TVM_FFI_EXTRA_CXX_API json::Value Parse(const String &json_str, String *error_msg = nullptr); - -/*! - * \brief Serialize an Any value into a JSON string. - * - * \param value The Any value to serialize. - * \param indent The number of spaces to indent the output. - * If not specified, the output will be compact. - * \return The output JSON string. - */ -TVM_FFI_EXTRA_CXX_API String Stringify(const json::Value &value, - Optional indent = std::nullopt); - -} // namespace json -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_JSON_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h deleted file mode 100644 index 06fc7849d..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/module.h +++ /dev/null @@ -1,301 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/module.h - * \brief A managed dynamic module in the TVM FFI. - */ -#ifndef TVM_FFI_EXTRA_MODULE_H_ -#define TVM_FFI_EXTRA_MODULE_H_ - -#include -#include -#include -#include - -#include - -namespace tvm { -namespace ffi { - -// forward declare Module -class Module; - -/*! - * \brief A module that can dynamically load ffi::Functions or exportable source code. - * \sa Module - */ -class TVM_FFI_EXTRA_CXX_API ModuleObj : public Object { -public: - /*! - * \return The per module type key. - * \note This key is used to for serializing custom modules. - */ - virtual const char *kind() const = 0; - /*! - * \brief Get the property mask of the module. - * \return The property mask of the module. - * - * \sa Module::ModulePropertyMask - */ - virtual int GetPropertyMask() const { return 0b000; } - /*! - * \brief Get a ffi::Function from the module. - * \param name The name of the function. - * \return The function. - */ - virtual Optional GetFunction(const String &name) = 0; - /*! - * \brief Returns true if this module has a definition for a function of \p name. - * - * Note that even if this function returns true the corresponding \p GetFunction result - * may be nullptr if the function is not yet callable without further compilation. - * - * The default implementation just checks if \p GetFunction is non-null. - * \param name The name of the function. 
- * \return True if the module implements the function, false otherwise. - */ - virtual bool ImplementsFunction(const String &name) { return GetFunction(name).defined(); } - /*! - * \brief Get the docstring of the function, if available. - * \param name The name of the function. - * \return The documentation string if available, nullopt otherwise. - * - * \sa GetFunctionMetadata, TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC - */ - virtual Optional GetFunctionDoc(const String &name) { return std::nullopt; } - // Rationale: We separate the docstring from the metadata since docstrings - // can be unstructured and sometimes large, while metadata can be focused - // on storing structured information. - /*! - * \brief Get the metadata of the function, if available. - * \param name The name of the function. - * \return The metadata as JSON string if available, nullopt otherwise. - * - * \code - * Module mod = Module::LoadFromFile("lib.so"); - * Optional metadata = mod->GetFunctionMetadata("my_func"); - * if (metadata.has_value()) { - * // Parse JSON: {"type_schema": "..."} - * validate_signature(*metadata); - * } - * \endcode - * - * \sa GetFunctionDoc, TVM_FFI_DLL_EXPORT_TYPED_FUNC - */ - virtual Optional GetFunctionMetadata(const String &name) { return std::nullopt; } - /*! - * \brief Write the current module to file with given format (for further compilation). - * - * \param file_name The file to be saved to. - * \param format The format of the file. - * - * \note This function is mainly used by modules that - */ - virtual void WriteToFile(const String &file_name, const String &format) const { - TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support WriteToFile"; - } - /*! - * \brief Get the possible write formats of the module, when available. - * \return Possible write formats when available. - */ - virtual Array GetWriteFormats() const { return Array(); } - /*! - * \brief Serialize the the module to bytes. - * \return The serialized module. 
- */ - virtual Bytes SaveToBytes() const { - TVM_FFI_THROW(RuntimeError) << "Module[" << kind() << "] does not support SaveToBytes"; - TVM_FFI_UNREACHABLE(); - } - /*! - * \brief Get the source code of module, when available. - * \param format Format of the source code, can be empty by default. - * \return Possible source code when available, or empty string if not available. - */ - virtual String InspectSource(const String &format) const { return String(); } - /*! - * \brief Import another module. - * \param other The module to import. - */ - virtual void ImportModule(const Module &other); - /*! - * \brief Clear all imported modules. - */ - virtual void ClearImports(); - /*! - * \brief Overloaded function to optionally query from imports. - * \param name The name of the function. - * \param query_imports Whether to query imported modules. - * \return The function. - */ - Optional GetFunction(const String &name, bool query_imports); - /*! - * \brief Overloaded function to optionally query from imports. - * \param name The name of the function. - * \param query_imports Whether to query imported modules. - * \return True if the module implements the function, false otherwise. - */ - bool ImplementsFunction(const String &name, bool query_imports); - /*! - * \brief Get the function docstring of the function if available. - * \param name The name of the function. - * \param query_imports Whether to also query modules imported by this module. - * \return The documentation string if available, nullopt otherwise. - * - * \sa GetFunctionMetadata - */ - Optional GetFunctionDoc(const String &name, bool query_imports); - /*! - * \brief Get the function metadata of the function if available. - * \param name The name of the function. - * \param query_imports Whether to also query modules imported by this module. - * \return The metadata as JSON string if available, nullopt otherwise. 
- * - * \sa GetFunctionDoc - */ - Optional GetFunctionMetadata(const String &name, bool query_imports); - /*! - * \brief Get the imports of the module. - * \return The imports of the module. - * \note Note the signature is not part of the public API. - */ - const Array &imports() const { return this->imports_; } - - struct InternalUnsafe; - - /// \cond Doxygen_Suppress - static constexpr const int32_t _type_index = TypeIndex::kTVMFFIModule; - static constexpr const bool _type_mutable = true; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIModule, ModuleObj, Object); - /// \endcond - -protected: - friend struct InternalUnsafe; - - /*! - * \brief The modules that this module depends on. - * \note Use ObjectRef to avoid circular dep on Module. - */ - Array imports_; - -private: - /*! - * \brief cache used by TVMFFIModuleLookupFromImports - */ - Map import_lookup_cache_; -}; - -/*! - * \brief Reference to module object. - * - * When invoking a function on a ModuleObj, such as GetFunction, - * use operator-> to get the ModuleObj pointer and invoke the member functions. - * - * \code - * ffi::Module mod = ffi::Module::LoadFromFile("path/to/module.so"); - * ffi::Function func = mod->GetFunction(name); - * \endcode - * - * \sa ModuleObj which contains most of the function implementations. - */ -class Module : public ObjectRef { -public: - /*! - * \brief Property of ffi::Module - */ - enum ModulePropertyMask : int { - /*! - * \brief The module can be serialized to bytes. - * - * This prooperty indicates that module implements SaveToBytes. - * The system also registers a GlobalDef function - * `ffi.Module.load_from_bytes.` with signature (Bytes) -> Module. - */ - kBinarySerializable = 0b001, - /*! - * \brief The module can directly get runnable functions. - * - * This property indicates that module implements GetFunction that returns - * runnable ffi::Functions. - */ - kRunnable = 0b010, - /*! 
- * \brief The module can be exported to a object file or source file that then be compiled. - * - * This property indicates that module implements WriteToFile with a given format - * that can be queried by GetLibExportFormat. - * - * Examples include modules that can be exported to .o, .cc, .cu files. - * - * Such modules can be exported, compiled and loaded back as a dynamic library module. - */ - kCompilationExportable = 0b100 - }; - /*! - * \brief Constructor from ObjectPtr. - * \param ptr The object pointer. - */ - explicit Module(const ObjectPtr &ptr) : ObjectRef(ptr) { - TVM_FFI_ICHECK(ptr != nullptr); - } - /*! - * \brief Load a module from file. - * \param file_name The name of the host function module. - * \note This function won't load the import relationship. - * Re-create import relationship by calling Import. - */ - TVM_FFI_EXTRA_CXX_API static Module LoadFromFile(const String &file_name); - /*! - * \brief Query context symbols that is registered via TVMEnvRegisterSymbols. - * \param callback The callback to be called with the symbol name and address. - * \note This helper can be used to implement custom Module that needs to access context symbols. - */ - TVM_FFI_EXTRA_CXX_API static void VisitContextSymbols( - const ffi::TypedFunction &callback); - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(Module, ObjectRef, ModuleObj); - /// \endcond -}; - -/* - * \brief Symbols for library module. - */ -namespace symbol { -/*!\ brief symbol prefix for tvm ffi related function symbols */ -constexpr const char *tvm_ffi_symbol_prefix = "__tvm_ffi_"; -// Special symbols have one extra _ prefix to avoid conflict with user symbols -/*! - * \brief Default entry function of a library module is tvm_ffi_symbol_prefix + "main" - */ -constexpr const char *tvm_ffi_main = "__tvm_ffi_main"; -/*! \brief Global variable to store context pointer for a library module. */ -constexpr const char *tvm_ffi_library_ctx = "__tvm_ffi__library_ctx"; -/*! 
\brief Global variable to store binary data alongside a library module. */ -constexpr const char *tvm_ffi_library_bin = "__tvm_ffi__library_bin"; -/*! \brief Optional metadata prefix of a symbol. */ -constexpr const char *tvm_ffi_metadata_prefix = "__tvm_ffi__metadata_"; -/*! \brief Optional documentation prefix of a symbol. */ -constexpr const char *tvm_ffi_doc_prefix = "__tvm_ffi__doc_"; -} // namespace symbol -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_EXTRA_MODULE_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h deleted file mode 100644 index 3a726504f..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/serialization.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/serialization.h - * \brief Reflection-based serialization utilities - */ -#ifndef TVM_FFI_EXTRA_SERIALIZATION_H_ -#define TVM_FFI_EXTRA_SERIALIZATION_H_ - -#include -#include - -namespace tvm { -namespace ffi { - -/** - * \brief Serialize ffi::Any to a JSON that stores the object graph. 
- * - * The JSON graph structure is stored as follows: - * - * ``` - * { - * "root_index": , // Index of root node in nodes array - * "nodes": [, ...], // Array of serialized nodes - * "metadata": // Optional metadata - * } - * ``` - * - * Each node has the format: `{"type": "", "data": }` - * For object types and strings, the data may contain indices to other nodes. - * For object fields whose static type is known as a primitive type, it is stored directly, - * otherwise, it is stored as a reference to the nodes array by an index. - * - * This function preserves the type and multiple references to the same object, - * which is useful for debugging and serialization. - * - * \param value The ffi::Any value to serialize. - * \param metadata Extra metadata attached to "metadata" field of the JSON object. - * \return The serialized JSON value. - */ -TVM_FFI_EXTRA_CXX_API json::Value ToJSONGraph(const Any &value, const Any &metadata = Any(nullptr)); - -/** - * \brief Deserialize a JSON that stores the object graph to an ffi::Any value. - * - * This function can be used to implement deserialization - * and debugging. - * - * \param value The JSON value to deserialize. - * \return The deserialized object graph. - */ -TVM_FFI_EXTRA_CXX_API Any FromJSONGraph(const json::Value &value); - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_SERIALIZATION_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h deleted file mode 100644 index 1ee5780d8..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_equal.h +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/structural_equal.h - * \brief Structural equal implementation - */ -#ifndef TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ -#define TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -/*! - * \brief Structural equality comparators - */ -class StructuralEqual { -public: - /** - * \brief Compare two Any values for structural equality. - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \param map_free_vars Whether to map free variables. - * \param skip_tensor_content Whether to skip comparingn darray data content, - * useful for cases where we don't care about parameters content - * \return True if the two Any values are structurally equal, false otherwise. - */ - TVM_FFI_EXTRA_CXX_API static bool Equal(const Any &lhs, const Any &rhs, - bool map_free_vars = false, - bool skip_tensor_content = false); - /** - * \brief Get the first mismatch AccessPath pair when running - * structural equal comparison between two Any values. - * - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \param map_free_vars Whether to map free variables. 
- * \param skip_tensor_content Whether to skip comparing tensor data content, - * useful for cases where we don't care about parameters content - * \return If comparison fails, return the first mismatch AccessPath pair, - * otherwise return std::nullopt. - */ - TVM_FFI_EXTRA_CXX_API static Optional GetFirstMismatch( - const Any &lhs, const Any &rhs, bool map_free_vars = false, bool skip_tensor_content = false); - - /* - * \brief Compare two Any values for structural equality. - * \param lhs The left hand side Any object. - * \param rhs The right hand side Any object. - * \return True if the two Any values are structurally equal, false otherwise. - */ - TVM_FFI_INLINE bool operator()(const Any &lhs, const Any &rhs) const { - return Equal(lhs, rhs, false, true); - } -}; - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_STRUCTURAL_EQUAL_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h deleted file mode 100644 index b27181b37..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/extra/structural_hash.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/extra/structural_hash.h - * \brief Structural hash - */ -#ifndef TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ -#define TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ - -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Structural hash - */ -class StructuralHash { -public: - /*! - * \brief Hash an Any value. - * \param value The Any value to hash. - * \param map_free_vars Whether to map free variables. - * \param skip_tensor_content Whether to skip comparingn darray data content, - * useful for cases where we don't care about parameters content. - * \return The hash value. - */ - TVM_FFI_EXTRA_CXX_API static uint64_t Hash(const Any &value, bool map_free_vars = false, - bool skip_tensor_content = false); - /*! - * \brief Hash an Any value. - * \param value The Any value to hash. - * \return The hash value. - */ - TVM_FFI_INLINE uint64_t operator()(const Any &value) const { return Hash(value); } -}; - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_EXTRA_STRUCTURAL_HASH_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function.h deleted file mode 100644 index 4854ecd1d..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function.h +++ /dev/null @@ -1,998 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/function.h - * \brief A managed function in the TVM FFI. - */ -#ifndef TVM_FFI_FUNCTION_H_ -#define TVM_FFI_FUNCTION_H_ - -/*! - * \brief Controls whether DLL exports should include metadata. - * - * When set to 1, exported functions will include additional metadata. - * When set to 0 (default), exports are minimal without metadata. - */ -#ifndef TVM_FFI_DLL_EXPORT_INCLUDE_METADATA -#define TVM_FFI_DLL_EXPORT_INCLUDE_METADATA 0 -#endif - -#include "any.h" -#include "base_details.h" -#include "c_api.h" -#include "error.h" -#include "function_details.h" - -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/** - * Helper macro to construct a safe call - * - * \brief Marks the beginning of the safe call that catches exception explicitly - * \sa TVM_FFI_SAFE_CALL_END - * - * \code - * int TVMFFICStyleFunction() { - * TVM_FFI_SAFE_CALL_BEGIN(); - * // c++ code region here - * TVM_FFI_SAFE_CALL_END(); - * } - * \endcode - */ -#define TVM_FFI_SAFE_CALL_BEGIN() \ - try { \ - (void)0 - -/*! - * \brief Marks the end of safe call. - */ -#define TVM_FFI_SAFE_CALL_END() \ - return 0; \ - } \ - catch (const ::tvm::ffi::Error &err) { \ - ::tvm::ffi::details::SetSafeCallRaised(err); \ - return -1; \ - } \ - catch (const ::tvm::ffi::EnvErrorAlreadySet &) { \ - return -2; \ - } \ - catch (const std::exception &ex) { \ - ::tvm::ffi::details::SetSafeCallRaised(::tvm::ffi::Error("InternalError", ex.what(), "")); \ - return -1; \ - } \ - TVM_FFI_UNREACHABLE() - -/*! 
- * \brief Macro to check a call to TVMFFISafeCallType and raise exception if error happens. - * \param func The function to check. - * - * \code - * // calls TVMFFIFunctionCall and raises exception if error happens - * TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_arr, &type_index)); - * \endcode - */ -#define TVM_FFI_CHECK_SAFE_CALL(func) \ - { \ - int ret_code = (func); \ - if (ret_code != 0) { \ - if (ret_code == -2) { \ - throw ::tvm::ffi::EnvErrorAlreadySet(); \ - } \ - throw ::tvm::ffi::details::MoveFromSafeCallRaised(); \ - } \ - } - -/*! - * \brief Object container class that backs ffi::Function - * \note Do not use this class directly, use ffi::Function - */ -class FunctionObj : public Object, public TVMFFIFunctionCell { -public: - /*! \brief Typedef for C++ style calling signature that comes with exception propagation */ - using FCall = void (*)(const FunctionObj *, const AnyView *, int32_t, Any *); - using TVMFFIFunctionCell::cpp_call; - using TVMFFIFunctionCell::safe_call; - /*! - * \brief Call the function in packed format. - * \param args The arguments - * \param num_args The number of arguments - * \param result The return value. - */ - TVM_FFI_INLINE void CallPacked(const AnyView *args, int32_t num_args, Any *result) const { - // if cpp_call is set, use it to call the function, otherwise, redirect to safe_call - // use conditional expression here so the select is branchless - FCall call_ptr = this->cpp_call ? reinterpret_cast(this->cpp_call) : CppCallDedirectToSafeCall; - (*call_ptr)(this, args, num_args, result); - } - /// \cond Doxygen_Suppress - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIFunction; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIFunction, FunctionObj, Object); - /// \endcond - -protected: - /*! \brief Make default constructor protected. 
*/ - FunctionObj() {} - friend class Function; - -private: - static void CppCallDedirectToSafeCall(const FunctionObj *func, const AnyView *args, - int32_t num_args, Any *rv) { - FunctionObj *self = static_cast(const_cast(func)); - TVM_FFI_CHECK_SAFE_CALL(self->safe_call(self, reinterpret_cast(args), - num_args, reinterpret_cast(rv))); - } -}; - -namespace details { -/*! - * \brief Derived object class for constructing FunctionObj backed by a TCallable - * - * This is a helper class that implements the function call interface. - * Invariance: TCallable cannot be const or reference type. - */ -template -class FunctionObjImpl : public FunctionObj { -public: - static_assert(std::is_same_v>>, - "TCallable of FunctionObjImpl cannot be const or reference type"); - - /*! \brief The type of derived object class */ - using TSelf = FunctionObjImpl; - - /*! - * \brief Derived object class for constructing ffi::FunctionObj. - * \param callable The type-erased callable object (rvalue). - */ - explicit FunctionObjImpl(TCallable &&callable) : callable_(std::move(callable)) { - this->safe_call = SafeCall; - this->cpp_call = reinterpret_cast(CppCall); - } - /*! - * \brief Derived object class for constructing ffi::FunctionObj. - * \param callable The type-erased callable object (lvalue). 
- */ - explicit FunctionObjImpl(const TCallable &callable) : callable_(callable) { - this->safe_call = SafeCall; - this->cpp_call = reinterpret_cast(CppCall); - } - -private: - // implementation of call - static void CppCall(const FunctionObj *func, const AnyView *args, int32_t num_args, Any *result) { - (static_cast(func))->callable_(args, num_args, result); - } - /// \cond Doxygen_Suppress - // Implementing safe call style - static int SafeCall(void *func, const TVMFFIAny *args, int32_t num_args, TVMFFIAny *result) { - TVM_FFI_SAFE_CALL_BEGIN(); - TVM_FFI_ICHECK_LT(result->type_index, TypeIndex::kTVMFFIStaticObjectBegin); - FunctionObj *self = static_cast(func); - reinterpret_cast(self->cpp_call)(self, reinterpret_cast(args), num_args, - reinterpret_cast(result)); - TVM_FFI_SAFE_CALL_END(); - } - /// \endcond - /*! \brief Type-erased filed for storing callable object*/ - mutable TCallable callable_; -}; - -/*! - * \brief FunctionObj specialization for raw C style callback where handle and deleter are null. - */ -class ExternCFunctionObjNullHandleImpl : public FunctionObj { -public: - explicit ExternCFunctionObjNullHandleImpl(TVMFFISafeCallType safe_call) { - this->safe_call = safe_call; - this->cpp_call = nullptr; - } -}; - -/*! - * \brief FunctionObj specialization that leverages C-style callback definitions. 
- */ -class ExternCFunctionObjImpl : public FunctionObj { -public: - ExternCFunctionObjImpl(void *self, TVMFFISafeCallType safe_call, void (*deleter)(void *self)) - : self_(self), safe_call_(safe_call), deleter_(deleter) { - this->safe_call = SafeCall; - this->cpp_call = nullptr; - } - - ~ExternCFunctionObjImpl() { - if (deleter_) { - deleter_(self_); - } - } - -private: - static int32_t SafeCall(void *func, const TVMFFIAny *args, int32_t num_args, TVMFFIAny *rv) { - ExternCFunctionObjImpl *self = reinterpret_cast(func); - return self->safe_call_(self->self_, args, num_args, rv); - } - - void *self_; - TVMFFISafeCallType safe_call_; - void (*deleter_)(void *self); -}; - -// Helper class to set packed arguments -class PackedArgsSetter { -public: - explicit PackedArgsSetter(AnyView *args) : args_(args) {} - - // NOTE: setter needs to be very carefully designed - // such that we do not have temp variable conversion(eg. convert from lvalue to rvalue) - // that is why we need T&& and std::forward here - template - TVM_FFI_INLINE void operator()(size_t i, T &&value) const { - args_[i].operator=(std::forward(value)); - } - -private: - AnyView *args_; -}; -} // namespace details - -/*! - * \brief Represents arguments packed in AnyView array - * \note This class represent packed arguments to ffi::Function - */ -class PackedArgs { -public: - /*! - * \brief Constructor - * \param data The arguments - * \param size The number of arguments - */ - PackedArgs(const AnyView *data, int32_t size) : data_(data), size_(size) {} - - /*! \return size of the arguments */ - int size() const { return size_; } - - /*! \return The arguments */ - const AnyView *data() const { return data_; } - - /*! - * \brief Slice the arguments - * \param begin The begin index - * \param end The end index - * \return The sliced arguments - */ - PackedArgs Slice(int begin, int end = -1) const { - if (end == -1) { - end = size_; - } - return PackedArgs(data_ + begin, end - begin); - } - - /*! 
- * \brief Get i-th argument - * \param i the index. - * \return the ith argument. - */ - AnyView operator[](int i) const { return data_[i]; } - - /*! - * \brief Fill the arguments into the AnyView array - * \param data The AnyView array to store the packed arguments - * \param args The arguments to be packed - * \note Caller must ensure all args are alive during lifetime of data. - * A common pitfall is to pass in local variables that are immediately - * destroyed after calling Fill. - */ - template - TVM_FFI_INLINE static void Fill(AnyView *data, Args &&...args) { - details::for_each(details::PackedArgsSetter(data), std::forward(args)...); - } - -private: - /*! \brief The arguments */ - const AnyView *data_; - /*! \brief The number of arguments */ - int32_t size_; -}; - -/*! - * \brief ffi::Function is a type-erased function. - * The arguments are passed by "packed format" via AnyView - */ -class Function : public ObjectRef { -public: - /*! \brief Constructor from null */ - Function(std::nullptr_t) : ObjectRef(nullptr) {} // NOLINT(*) - /*! - * \brief Constructing a packed function from a callable type - * whose signature is consistent with `ffi::Function` - * \param packed_call The packed function signature - * \note legacy purpose, should change to Function::FromPacked for mostfuture use. - */ - template , Function>>> - explicit Function(TCallable &&packed_call) { - *this = FromPacked(std::forward(packed_call)); - } - /*! 
- * \brief Constructing a packed function from a callable type - * whose signature is consistent with `ffi::Function` - * \param packed_call The packed function signature - */ - template - static Function FromPacked(TCallable &&packed_call) { - static_assert( - std::is_convertible_v> || std::is_convertible_v>, - "tvm::ffi::Function::FromPacked requires input function signature to match packed func " - "format"); - if constexpr (std::is_convertible_v>) { - return FromPackedInternal( - [packed_call = std::forward(packed_call)]( - const AnyView *args, int32_t num_args, Any *rv) mutable -> void { - packed_call(PackedArgs{args, num_args}, rv); - }); - } else { - return FromPackedInternal(std::forward(packed_call)); - } - } - - /*! - * \brief Create ffi::Function from a C style callbacks. - * - * self and deleter can be nullptr if the function do not need closure support - * and corresponds to a raw function pointer. - * - * \param self Resource handle to the function - * \param safe_call The safe_call definition in C. - * \param deleter The deleter to release the resource of self. - * \return The created function. - */ - static Function FromExternC(void *self, TVMFFISafeCallType safe_call, - void (*deleter)(void *self)) { - // the other function coems from a different library - Function func; - if (self == nullptr && deleter == nullptr) { - func.data_ = make_object(safe_call); - } else { - func.data_ = make_object(self, safe_call, deleter); - } - return func; - } - /*! - * \brief Get global function by name - * \param name The function name - * \return The global function. - * \note This function will return std::nullopt if the function is not found. 
- */ - static std::optional GetGlobal(std::string_view name) { - TVMFFIObjectHandle handle; - TVMFFIByteArray name_arr{name.data(), name.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionGetGlobal(&name_arr, &handle)); - if (handle != nullptr) { - return Function( - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); - } else { - return std::nullopt; - } - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will return std::nullopt if the function is not found. - */ - static std::optional GetGlobal(const std::string &name) { - return GetGlobal(std::string_view(name.data(), name.length())); - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will return std::nullopt if the function is not found. - */ - static std::optional GetGlobal(const String &name) { - return GetGlobal(std::string_view(name.data(), name.length())); - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will return std::nullopt if the function is not found. - */ - static std::optional GetGlobal(const char *name) { - return GetGlobal(std::string_view(name)); - } - /*! - * \brief Get global function by name and throw an error if it is not found. - * \param name The name of the function - * \return The global function - * \note This function will throw an error if the function is not found. - */ - static Function GetGlobalRequired(std::string_view name) { - std::optional res = GetGlobal(name); - if (!res.has_value()) { - TVM_FFI_THROW(ValueError) << "Function " << name << " not found"; - } - return *res; - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will throw an error if the function is not found. 
- */ - static Function GetGlobalRequired(const std::string &name) { - return GetGlobalRequired(std::string_view(name.data(), name.length())); - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will throw an error if the function is not found. - */ - static Function GetGlobalRequired(const String &name) { - return GetGlobalRequired(std::string_view(name.data(), name.length())); - } - - /*! - * \brief Get global function by name - * \param name The name of the function - * \return The global function - * \note This function will throw an error if the function is not found. - */ - static Function GetGlobalRequired(const char *name) { - return GetGlobalRequired(std::string_view(name)); - } - /*! - * \brief Set global function by name - * \param name The name of the function - * \param func The function - * \param override Whether to override when there is duplication. - */ - static void SetGlobal(std::string_view name, - Function func, // NOLINT(performance-unnecessary-value-param) - bool override = false) { - TVMFFIByteArray name_arr{name.data(), name.size()}; - TVM_FFI_CHECK_SAFE_CALL( - TVMFFIFunctionSetGlobal(&name_arr, details::ObjectUnsafe::GetHeader(func.get()), override)); - } - /*! - * \brief List all global names - * \return A vector of all global names - * \note This function do not depend on Array so core do not have container dep. 
- */ - static std::vector ListGlobalNames() { - Function fname_functor = GetGlobalRequired("ffi.FunctionListGlobalNamesFunctor")().cast(); - std::vector names; - int len = fname_functor(-1).cast(); - names.reserve(len); - for (int i = 0; i < len; ++i) { - names.push_back(fname_functor(i).cast()); - } - return names; - } - /** - * \brief Remove a global function by name - * \param name The name of the function - */ - static void RemoveGlobal(const String &name) { - static Function fremove = GetGlobalRequired("ffi.FunctionRemoveGlobal"); - fremove(name); - } - /*! - * \brief Constructing a packed function from a normal function. - * - * \param callable the internal container of packed function. - */ - template - static Function FromTyped(TCallable &&callable) { - using FuncInfo = details::FunctionInfo>; - // Callable is always captured by value here to avoid possible dangling reference - auto call_packed = [callable = std::forward(callable)]( - const AnyView *args, int32_t num_args, Any *rv) mutable -> void { - details::unpack_call( - std::make_index_sequence{}, nullptr, callable, args, num_args, rv); - }; - return FromPackedInternal(std::move(call_packed)); - } - /*! - * \brief Constructing a packed function from a normal function. - * - * \param callable the internal container of packed function. - * \param name optional name attacked to the function. - */ - template - static Function FromTyped(TCallable &&callable, std::string name) { - using FuncInfo = details::FunctionInfo>; - // Callable is always captured by value here to avoid possible dangling reference - auto call_packed = [callable = std::forward(callable), name = std::move(name)]( - const AnyView *args, int32_t num_args, Any *rv) mutable -> void { - details::unpack_call( - std::make_index_sequence{}, &name, callable, args, num_args, rv); - }; - return FromPackedInternal(std::move(call_packed)); - } - - /*! - * \brief Directly invoke an extern "C" function that follows the TVM FFI SafeCall convention. 
- * - * This function can be useful to turn an existing exported symbol into a typed function. - * - * \code - * - * // An extern "C" function, matching TVMFFISafeCallType - * extern "C" int __tvm_ffi_add( - * void* handle, const TVMFFIAny* args, int32_t num_args, TVMFFIAny*result - * ); - * - * // redirect an existing symbol into a typed function - * inline int add(int a, int b) { - * return tvm::ffi::Function::InvokeExternC(nullptr, __tvm_ffi_add, a, b).cast(); - * } - * - * \endcode - * - * \tparam Args The types of the arguments to the extern function. - * \param handle The handle argument, for exported symbols this is usually nullptr. - * \param safe_call The function pointer to the extern "C" function. - * \param args The arguments to pass to the function. - * \return The return value, wrapped in a tvm::ffi::Any. - */ - template - TVM_FFI_INLINE static Any InvokeExternC(void *handle, TVMFFISafeCallType safe_call, - Args &&...args) { - const int kNumArgs = sizeof...(Args); - const int kArraySize = kNumArgs > 0 ? kNumArgs : 1; - AnyView args_pack[kArraySize]; - PackedArgs::Fill(args_pack, std::forward(args)...); - Any result; - TVM_FFI_CHECK_SAFE_CALL(safe_call(handle, reinterpret_cast(args_pack), - kNumArgs, reinterpret_cast(&result))); - return result; - } - /*! - * \brief Call function by directly passing in unpacked arguments. - * - * \param args Arguments to be passed. - * \tparam Args arguments to be passed. - * - * \code - * // Example code on how to call packed function - * void CallFFIFunction(tvm::ffi::Function f) { - * // call like normal functions by pass in arguments - * // return value is automatically converted back - * int rvalue = f(1, 2.0); - * } - * \endcode - */ - template - TVM_FFI_INLINE Any operator()(Args &&...args) const { - const int kNumArgs = sizeof...(Args); - const int kArraySize = kNumArgs > 0 ? 
kNumArgs : 1; - AnyView args_pack[kArraySize]; - PackedArgs::Fill(args_pack, std::forward(args)...); - Any result; - static_cast(data_.get())->CallPacked(args_pack, kNumArgs, &result); - return result; - } - /*! - * \brief Call the function in packed format. - * \param args The arguments - * \param num_args The number of arguments - * \param result The return value. - */ - TVM_FFI_INLINE void CallPacked(const AnyView *args, int32_t num_args, Any *result) const { - static_cast(data_.get())->CallPacked(args, num_args, result); - } - /*! - * \brief Call the function in packed format. - * \param args The arguments - * \param result The return value. - */ - TVM_FFI_INLINE void CallPacked(PackedArgs args, Any *result) const { - static_cast(data_.get())->CallPacked(args.data(), args.size(), result); - } - - /*! \return Whether the packed function is nullptr */ - TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; } - /*! \return Whether the packed function is not nullptr */ - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const { return data_ != nullptr; } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Function, ObjectRef, FunctionObj); - /// \endcond - - class Registry; - -private: - /*! - * \brief Constructing a packed function from a callable type - * whose signature is consistent with `ffi::Function` - * \param packed_call The packed function signature - */ - template - static Function FromPackedInternal(TCallable &&packed_call) { - // We must make TCallable a value type (decay_t) that can hold the callable object - using ObjType = typename details::FunctionObjImpl>; - Function func; - func.data_ = make_object(std::forward(packed_call)); - return func; - } -}; - -/*! - * \brief Please refer to \ref TypedFunctionAnchor "TypedFunction" - */ -template -class TypedFunction; - -/*! - * \anchor TypedFunctionAnchor - * \brief A ffi::Function wrapper to provide typed function signature. 
- * It is backed by a ffi::Function internally. - * - * TypedFunction enables compile time type checking. - * TypedFunction works with the runtime system: - * - It can be passed as an argument of ffi::Function. - * - It can be assigned to ffi::Any. - * - It can be directly converted to a type-erased ffi::Function. - * - * Developers should prefer TypedFunction over ffi::Function in C++ code - * as it enables compile time checking. - * We can construct a TypedFunction from a lambda function - * with the same signature. - * - * \code - * // user defined lambda function. - * auto addone = [](int x)->int { - * return x + 1; - * }; - * // We can directly convert - * // lambda function to TypedFunction - * TypedFunction ftyped(addone); - * // invoke the function. - * int y = ftyped(1); - * // Can be directly converted to ffi::Function - * ffi::Function packed = ftype; - * \endcode - * \tparam R The return value of the function. - * \tparam Args The argument signature of the function. - */ -template -class TypedFunction { -public: - /*! \brief short hand for this function type */ - using TSelf = TypedFunction; - /*! \brief default constructor */ - TypedFunction() = default; - /*! \brief constructor from null */ - TypedFunction(std::nullptr_t null) {} // NOLINT(*) - /*! - * \brief constructor from a function - * \param packed The function - */ - TypedFunction(Function packed) : packed_(std::move(packed)) {} // NOLINT(*) - /*! - * \brief construct from a lambda function with the same signature. - * - * Example usage: - * \code - * auto typed_lambda = [](int x)->int { return x + 1; } - * // construct from packed function - * TypedFunction ftyped(typed_lambda, "add_one"); - * // call the typed version. - * CHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \param name the name of the lambda function. - * \tparam FLambda the type of the lambda function. 
- */ - template >>> - TypedFunction(FLambda &&typed_lambda, std::string name) { - packed_ = Function::FromTyped(std::forward(typed_lambda), std::move(name)); - } - /*! - * \brief construct from a lambda function with the same signature. - * - * This version does not take a name. It is highly recommend you use the - * version that takes a name for the lambda. - * - * Example usage: - * \code - * auto typed_lambda = [](int x)->int { return x + 1; } - * // construct from packed function - * TypedFunction ftyped(typed_lambda); - * // call the typed version. - * CHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \tparam FLambda the type of the lambda function. - */ - template > && !std::is_same_v, TSelf>>> - TypedFunction(FLambda &&typed_lambda) { // NOLINT(google-explicit-constructor) - packed_ = Function::FromTyped(std::forward(typed_lambda)); - } - /*! - * \brief copy assignment operator from typed lambda - * - * Example usage: - * \code - * // construct from packed function - * TypedFunction ftyped; - * ftyped = [](int x) { return x + 1; } - * // call the typed version. - * CHECK_EQ(ftyped(1), 2); - * \endcode - * - * \param typed_lambda typed lambda function. - * \tparam FLambda the type of the lambda function. - * \returns reference to self. - */ - template > && !std::is_same_v, TSelf>>> - TSelf &operator=(FLambda &&typed_lambda) { - packed_ = Function::FromTyped(std::forward(typed_lambda)); - return *this; - } - /*! - * \brief copy assignment operator from ffi::Function. - * \param packed The packed function. - * \returns reference to self. - */ - TSelf &operator=(Function packed) { - packed_ = std::move(packed); - return *this; - } - /*! - * \brief Invoke the operator. - * \param args The arguments - * \returns The return value. - */ - TVM_FFI_INLINE R operator()(Args... 
args) const { // NOLINT(performance-unnecessary-value-param) - if constexpr (std::is_same_v) { - packed_(std::forward(args)...); - } else { - Any res = packed_(std::forward(args)...); - if constexpr (std::is_same_v) { - return res; - } else { - return std::move(res).cast(); - } - } - } - /*! - * \brief convert to ffi::Function - * \return the internal ffi::Function - */ - operator Function() const { return packed(); } // NOLINT(google-explicit-constructor) - /*! - * \return reference the internal ffi::Function - */ - const Function &packed() const & { return packed_; } - /*! - * \return r-value reference the internal ffi::Function - */ - constexpr Function &&packed() && { return std::move(packed_); } - /*! \return Whether the packed function is nullptr */ - bool operator==(std::nullptr_t null) const { return packed_ == nullptr; } - /*! \return Whether the packed function is not nullptr */ - bool operator!=(std::nullptr_t null) const { return packed_ != nullptr; } - /*! - * \brief Get the type schema of `TypedFunction` in json format. - * \return The type schema of the function in json format. - */ - static std::string TypeSchema() { return details::FuncFunctorImpl::TypeSchema(); } - -private: - /*! 
\brief The internal packed function */ - Function packed_; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFunction; - - TVM_FFI_INLINE static void CopyToAnyView(const TypedFunction &src, TVMFFIAny *result) { - TypeTraits::CopyToAnyView(src.packed(), result); - } - - TVM_FFI_INLINE static void MoveToAny(TypedFunction src, TVMFFIAny *result) { - // Move from rvalue to trigger TypedFunction rvalue packed() overload - TypeTraits::MoveToAny(std::move(src).packed(), result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index == TypeIndex::kTVMFFIFunction; - } - - TVM_FFI_INLINE static TypedFunction CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return TypedFunction(TypeTraits::CopyFromAnyViewAfterCheck(src)); - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView( - const TVMFFIAny *src) { - std::optional opt = TypeTraits::TryCastFromAnyView(src); - if (opt.has_value()) { - return TypedFunction(*std::move(opt)); - } else { - return std::nullopt; - } - } - - TVM_FFI_INLINE static std::string TypeStr() { return details::FunctionInfo::Sig(); } - TVM_FFI_INLINE static std::string TypeSchema() { return TypedFunction::TypeSchema(); } -}; - -/*! - * \brief helper function to get type index from key - */ -inline int32_t TypeKeyToIndex(std::string_view type_key) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - return type_index; -} - -/// \cond Doxygen_Suppress -// Internal implementation macros used by TVM_FFI_DLL_EXPORT_TYPED_FUNC and related macros. -// These should not be used directly; use the public macros instead. 
- -// Internal implementation macro that generates the C ABI wrapper function -#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function) \ - extern "C" { \ - TVM_FFI_DLL_EXPORT int __tvm_ffi_##ExportName(void *self, const TVMFFIAny *args, \ - int32_t num_args, TVMFFIAny *result) { \ - TVM_FFI_SAFE_CALL_BEGIN(); \ - using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ - static std::string name = #ExportName; \ - ::tvm::ffi::details::unpack_call( \ - std::make_index_sequence{}, &name, Function, \ - reinterpret_cast(args), num_args, \ - reinterpret_cast<::tvm::ffi::Any *>(result)); \ - TVM_FFI_SAFE_CALL_END(); \ - } \ - } -/// \endcond - -/*! - * \brief Export typed function as a SafeCallType symbol that follows the FFI ABI. - * - * This macro exports the function and automatically exports metadata when - * TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is defined. - * - * \param ExportName The symbol name to be exported. - * \param Function The typed function. - * - * \sa ffi::TypedFunction, TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC - * - * \code - * - * int AddOne_(int x) { - * return x + 1; - * } - * - * // Expose the function as "AddOne" - * TVM_FFI_DLL_EXPORT_TYPED_FUNC(AddOne, AddOne_); - * - * // Expose the function as "SubOne" - * TVM_FFI_DLL_EXPORT_TYPED_FUNC(SubOne, [](int x) { - * return x - 1; - * }); - * \endcode - * - * \note The final symbol names are: - * - `__tvm_ffi_` (function) - * - `__tvm_ffi__metadata_` (metadata - only when - * TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is defined) - */ -#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA -#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function) \ - extern "C" { \ - TVM_FFI_DLL_EXPORT int __tvm_ffi__metadata_##ExportName(void *self, const TVMFFIAny *args, \ - int32_t num_args, TVMFFIAny *result) { \ - TVM_FFI_SAFE_CALL_BEGIN(); \ - using FuncInfo = ::tvm::ffi::details::FunctionInfo; \ - std::ostringstream os; \ - os << R"({"type_schema":)" \ - << 
::tvm::ffi::EscapeString(::tvm::ffi::String(FuncInfo::TypeSchema())) << R"(})"; \ - ::tvm::ffi::String str(os.str()); \ - ::tvm::ffi::TypeTraits<::tvm::ffi::String>::MoveToAny(std::move(str), result); \ - TVM_FFI_SAFE_CALL_END(); \ - } \ - } -#else -#define TVM_FFI_DLL_EXPORT_TYPED_FUNC(ExportName, Function) \ - TVM_FFI_DLL_EXPORT_TYPED_FUNC_IMPL_(ExportName, Function) -#endif - -/*! - * \brief Export documentation string for a typed function. - * - * This macro exports a documentation string associated with a function export name. - * The docstring can be used by stub generators and documentation tools. - * This macro only exports the docstring; it does not export the function itself. - * - * \param ExportName The symbol name that the docstring is associated with. - * \param DocString The documentation string (C string literal). - * - * \sa ffi::TypedFunction, TVM_FFI_DLL_EXPORT_TYPED_FUNC - * - * \code - * - * int Add(int a, int b) { - * return a + b; - * } - * - * TVM_FFI_DLL_EXPORT_TYPED_FUNC(add, Add); - * TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC( - * add, - * R"(Add two integers and return the sum. - * - * Parameters - * ---------- - * a : int - * First integer - * b : int - * Second integer - * - * Returns - * ------- - * result : int - * Sum of a and b)"); - * - * \endcode - * - * \note The exported symbol name is `__tvm_ffi__doc_` (docstring getter function). - * This symbol is only exported when TVM_FFI_DLL_EXPORT_INCLUDE_METADATA is defined. 
- */ -#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA -#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(ExportName, DocString) \ - extern "C" { \ - TVM_FFI_DLL_EXPORT int __tvm_ffi__doc_##ExportName(void *self, const TVMFFIAny *args, \ - int32_t num_args, TVMFFIAny *result) { \ - TVM_FFI_SAFE_CALL_BEGIN(); \ - ::tvm::ffi::String str(DocString); \ - ::tvm::ffi::TypeTraits<::tvm::ffi::String>::MoveToAny(std::move(str), result); \ - TVM_FFI_SAFE_CALL_END(); \ - } \ - } -#else -#define TVM_FFI_DLL_EXPORT_TYPED_FUNC_DOC(ExportName, DocString) -#endif -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_FUNCTION_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function_details.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function_details.h deleted file mode 100644 index 38725d800..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/function_details.h +++ /dev/null @@ -1,272 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! 
- * \file tvm/ffi/function_details.h - * \brief Implements the funciton signature reflection - */ -#ifndef TVM_FFI_FUNCTION_DETAILS_H_ -#define TVM_FFI_FUNCTION_DETAILS_H_ - -#include "any.h" -#include "base_details.h" -#include "c_api.h" -#include "error.h" - -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace details { - -template -struct Arg2Str { - template - TVM_FFI_INLINE static void Apply(std::ostream &os) { - using Arg = std::tuple_element_t; - if constexpr (i != 0) { - os << ", "; - } - os << i << ": " << Type2Str::v(); - } - template - TVM_FFI_INLINE static void Run(std::ostream &os, std::index_sequence) { - using TExpander = int[]; - (void)TExpander{0, (Apply(os), 0)...}; - } -}; - -/// NOTE: We only support `T`, `const T`, `const T&` and `T&&` as argument types. -template -static constexpr bool ArgTypeSupported = (!std::is_reference_v) || (std::is_const_v> && std::is_lvalue_reference_v) || (!std::is_const_v> && std::is_rvalue_reference_v); - -template -static constexpr bool ArgSupported = (ArgTypeSupported && (std::is_same_v>, Any> || std::is_same_v>, AnyView> || TypeTraitsNoCR::convert_enabled)); - -// NOTE: return type can only support non-reference managed returns -template -static constexpr bool RetSupported = (std::is_same_v || std::is_void_v || TypeTraits::convert_enabled); - -template -struct FuncFunctorImpl { - using FType = R(Args...); - using ArgType = std::tuple; - using RetType = R; - /*! \brief total number of arguments*/ - static constexpr size_t num_args = sizeof...(Args); - // MSVC is not that friendly to in-template nested bool evaluation -#ifndef _MSC_VER - /*! \brief Whether this function can be converted to ffi::Function via FromTyped */ - static constexpr bool unpacked_supported = (ArgSupported && ...) 
&& (RetSupported); -#endif - TVM_FFI_INLINE static std::string Sig() { - using IdxSeq = std::make_index_sequence; - std::ostringstream ss; - ss << "("; - Arg2Str>::Run(ss, IdxSeq{}); - ss << ") -> " << Type2Str::v(); - return ss.str(); - } - TVM_FFI_INLINE static std::string TypeSchema() { - std::ostringstream oss; - oss << R"({"type":")" << StaticTypeKey::kTVMFFIFunction << R"(","args":[)"; - oss << details::TypeSchema::v(); - ((oss << "," << details::TypeSchema::v()), ...); - oss << "]}"; - return oss.str(); - } -}; - -template -struct FunctionInfoHelper; - -template -struct FunctionInfoHelper : FuncFunctorImpl {}; -template -struct FunctionInfoHelper : FuncFunctorImpl {}; - -/*! - * \brief Template class to get function signature of a function or functor. - * \tparam T The function/functor type. - * \note We need a decltype redirection because this helps lambda types. - */ -template -struct FunctionInfo : FunctionInfoHelper {}; -template -struct FunctionInfo : FuncFunctorImpl {}; -template -struct FunctionInfo : FuncFunctorImpl {}; -template -struct FunctionInfo : FuncFunctorImpl {}; -// Support pointer-to-member functions used in reflection (e.g. &Class::method) -template -struct FunctionInfo>> - : FuncFunctorImpl {}; -template -struct FunctionInfo>> - : FuncFunctorImpl {}; - -template -struct FunctionInfo>> - : FuncFunctorImpl {}; -template -struct FunctionInfo>> - : FuncFunctorImpl {}; - -/*! \brief Using static function to output typed function signature */ -using FGetFuncSignature = std::string (*)(); - -/*! - * \brief Auxilary argument value with context for error reporting - * \tparam Type The expected type of the argument. - * \note We use a template class with non-template operator conversion - * instead of a non-template class with template operator conversion. - * This is because template operator conversion doesn't play well with - * classes with template constructors. - * In this case, it may lead to some unintended compiler errors. 
- * An example of class can be `std::optional`. - */ -template -class ArgValueWithContext { -public: - using TypeWithoutCR = std::remove_const_t>; - - /*! - * \brief move constructor from another return value. - * \param args The argument list - * \param arg_index In a function call, this argument is at index arg_index (0-indexed). - * \param optional_name Name of the function being called. Can be nullptr if the function is not. - * \param f_sig Pointer to static function outputting signature of the function being called. - * named. - */ - TVM_FFI_INLINE ArgValueWithContext(const AnyView *args, int32_t arg_index, - const std::string *optional_name, FGetFuncSignature f_sig) - : args_(args), arg_index_(arg_index), optional_name_(optional_name), f_sig_(f_sig) {} - - TVM_FFI_INLINE operator TypeWithoutCR() { // NOLINT(google-explicit-constructor) - if constexpr (std::is_same_v) { - return args_[arg_index_]; - } else if constexpr (std::is_same_v) { - return Any(args_[arg_index_]); - } else { - std::optional opt = args_[arg_index_].template try_cast(); - if (!opt.has_value()) { - TVMFFIAny any_data = args_[arg_index_].CopyToTVMFFIAny(); - TVM_FFI_THROW(TypeError) << "Mismatched type on argument #" << arg_index_ - << " when calling: `" - << (optional_name_ == nullptr ? "" : *optional_name_) - << (f_sig_ == nullptr ? "" : (*f_sig_)()) << "`. 
Expected `" - << Type2Str::v() << "` but got `" - << TypeTraits::GetMismatchTypeInfo(&any_data) - << '`'; - } - return *std::move(opt); - } - } - -private: - const AnyView *args_; - int32_t arg_index_; - const std::string *optional_name_; - FGetFuncSignature f_sig_; -}; - -template -TVM_FFI_INLINE void unpack_call(std::index_sequence, const std::string *optional_name, - const F &f, [[maybe_unused]] const AnyView *args, - [[maybe_unused]] int32_t num_args, [[maybe_unused]] Any *rv) { - using FuncInfo = FunctionInfo; - using PackedArgs = typename FuncInfo::ArgType; - FGetFuncSignature f_sig = FuncInfo::Sig; - - // somehow MSVC does not support the static constexpr member in this case, function is fine -#ifndef _MSC_VER - static_assert(FuncInfo::unpacked_supported, "The function signature do not support unpacked"); -#endif - constexpr size_t nargs = sizeof...(Is); - if (nargs != num_args) { - TVM_FFI_THROW(TypeError) << "Mismatched number of arguments when calling: `" - << (optional_name == nullptr ? "" : *optional_name) - << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << nargs - << " but got " << num_args << " arguments"; - } - // use index sequence to do recursive-less unpacking - if constexpr (std::is_same_v) { - f(ArgValueWithContext>{args, Is, optional_name, f_sig}...); - } else { - *rv = R(f(ArgValueWithContext>{args, Is, optional_name, - f_sig}...)); - } -} - -/*! - * \brief Move the safe call raised error to the caller - * \return The error - */ -TVM_FFI_INLINE static Error MoveFromSafeCallRaised() { - TVMFFIObjectHandle handle; - TVMFFIErrorMoveFromRaised(&handle); - // handle is owned by caller - return details::ObjectUnsafe::ObjectRefFromObjectPtr( - details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle))); -} - -/*! 
- * \brief Set the safe call raised error - * \param error The error - */ -TVM_FFI_INLINE static void SetSafeCallRaised(const Error &error) { - TVMFFIErrorSetRaised(details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(error)); -} - -template -struct TypeSchemaImpl { - static std::string v() { - using U = std::remove_const_t>; - return TypeTraits::TypeSchema(); - } -}; - -template <> -struct TypeSchemaImpl { - static std::string v() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFINone) + R"("})"; - } -}; - -template <> -struct TypeSchemaImpl { - static std::string v() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIAny) + R"("})"; - } -}; - -template <> -struct TypeSchemaImpl { - static std::string v() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIAny) + R"("})"; - } -}; - -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_FUNCTION_DETAILS_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/memory.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/memory.h deleted file mode 100644 index fd999da2a..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/memory.h +++ /dev/null @@ -1,274 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/memory.h - * \brief Runtime memory management to allocate on heap object. - */ -#ifndef TVM_FFI_MEMORY_H_ -#define TVM_FFI_MEMORY_H_ - -#include "object.h" - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -/*! \brief Deleter function for obeject */ -using FObjectDeleter = void (*)(void *obj, int flags); - -// Detail implementations after this -// -// The current design allows swapping the -// allocator pattern when necessary. -// -// Possible future allocator optimizations: -// - Arena allocator that gives ownership of memory to arena (deleter = nullptr) -// - Thread-local object pools: one pool per size and alignment requirement. -// - Can specialize by type of object to give the specific allocator to each object. -namespace details { - -/*! - * \brief Allocate aligned memory. - * \param size The size. - * \tparam align The alignment, must be a power of 2. - * \return The pointer to the allocated memory. - */ -template -TVM_FFI_INLINE void *AlignedAlloc(size_t size) { - static_assert(align != 0 && (align & (align - 1)) == 0, "align must be a power of 2"); -#ifdef _MSC_VER - // MSVC have to use _aligned_malloc - if (void *ptr = _aligned_malloc(size, align)) { - return ptr; - } - throw std::bad_alloc(); -#else - if constexpr (align <= alignof(std::max_align_t)) { - // malloc guarantees alignment of std::max_align_t - if (void *ptr = std::malloc(size)) { - return ptr; - } - throw std::bad_alloc(); - } else { - void *ptr; - // for other alignments, use posix_memalign - if (posix_memalign(&ptr, align, size) != 0) { - throw std::bad_alloc(); - } - return ptr; - } -#endif -} - -/*! - * \brief Free aligned memory. - * \param data The pointer to the memory to free. 
- */ -TVM_FFI_INLINE void AlignedFree(void *data) { -#ifdef _MSC_VER - // MSVC have to use _aligned_free - _aligned_free(data); -#else - std::free(data); -#endif -} - -/*! - * \brief Base class of object allocators that implements make. - * Use curiously recurring template pattern. - * - * \tparam Derived The derived class. - */ -template -class ObjAllocatorBase { -public: - /*! - * \brief Make a new object using the allocator. - * \tparam T The type to be allocated. - * \tparam Args The constructor signature. - * \param args The arguments. - */ - template - ObjectPtr make_object(Args &&...args) { - using Handler = typename Derived::template Handler; - static_assert(std::is_base_of_v, "make can only be used to create Object"); - T *ptr = Handler::New(static_cast(this), std::forward(args)...); - TVMFFIObject *ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - ffi_ptr->combined_ref_count = kCombinedRefCountBothOne; - ffi_ptr->type_index = T::RuntimeTypeIndex(); - ffi_ptr->__padding = 0; - ffi_ptr->deleter = Handler::Deleter(); - return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); - } - - /*! - * \tparam ArrayType The type to be allocated. - * \tparam ElemType The type of array element. - * \tparam Args The constructor signature. - * \param num_elems The number of array elements. - * \param args The arguments. 
- */ - template - ObjectPtr make_inplace_array(size_t num_elems, Args &&...args) { - using Handler = typename Derived::template ArrayHandler; - static_assert(std::is_base_of_v, - "make_inplace_array can only be used to create Object"); - ArrayType *ptr = Handler::New(static_cast(this), num_elems, std::forward(args)...); - TVMFFIObject *ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - ffi_ptr->combined_ref_count = kCombinedRefCountBothOne; - ffi_ptr->type_index = ArrayType::RuntimeTypeIndex(); - ffi_ptr->__padding = 0; - ffi_ptr->deleter = Handler::Deleter(); - return details::ObjectUnsafe::ObjectPtrFromOwned(ptr); - } - -private: - ObjAllocatorBase() = default; - friend Derived; -}; - -// Simple allocator that uses new/delete. -class SimpleObjAllocator : public ObjAllocatorBase { -public: - template - class Handler { - public: - template - static T *New(SimpleObjAllocator *, Args &&...args) { - // NOTE: the first argument is not needed for SimpleObjAllocator - // It is reserved for special allocators that needs to recycle - // the object to itself (e.g. in the case of object pool). - // - // In the case of an object pool, an allocator needs to create - // a special chunk memory that hides reference to the allocator - // and call allocator's release function in the deleter. - - // NOTE2: Use inplace new to allocate - // This is used to get rid of warning when deleting a virtual - // class with non-virtual destructor. - // We are fine here as we captured the right deleter during construction. - // This is also the right way to get storage type for an object pool. 
- void *data = AlignedAlloc(sizeof(T)); - new (data) T(std::forward(args)...); - return reinterpret_cast(data); - } - - static FObjectDeleter Deleter() { return Deleter_; } - - private: - static void Deleter_(void *objptr, int flags) { - T *tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned(static_cast(objptr)); - if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { - // It is important to do tptr->T::~T(), - // so that we explicitly call the specific destructor - // instead of tptr->~T(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->T::~T(); - } - if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { - AlignedFree(static_cast(tptr)); - } - } - }; - - // Array handler that uses new/delete. - template - class ArrayHandler { - public: - template - static ArrayType *New(SimpleObjAllocator *, size_t num_elems, Args &&...args) { - // NOTE: the first argument is not needed for ArrayObjAllocator - // It is reserved for special allocators that needs to recycle - // the object to itself (e.g. in the case of object pool). - // - // In the case of an object pool, an allocator needs to create - // a special chunk memory that hides reference to the allocator - // and call allocator's release function in the deleter. - // NOTE2: Use inplace new to allocate - // This is used to get rid of warning when deleting a virtual - // class with non-virtual destructor. - // We are fine here as we captured the right deleter during construction. - // This is also the right way to get storage type for an object pool. - - // for now only support elements that aligns with array header. 
- static_assert( - alignof(ArrayType) % alignof(ElemType) == 0 && sizeof(ArrayType) % alignof(ElemType) == 0, - "element alignment constraint"); - size_t size = sizeof(ArrayType) + sizeof(ElemType) * num_elems; - // round up to the nearest multiple of align - constexpr size_t align = alignof(ArrayType); - // C++ standard always guarantees that alignof operator returns a power of 2 - size_t aligned_size = (size + (align - 1)) & ~(align - 1); - void *data = AlignedAlloc(aligned_size); - new (data) ArrayType(std::forward(args)...); - return reinterpret_cast(data); - } - - static FObjectDeleter Deleter() { return Deleter_; } - - private: - static void Deleter_(void *objptr, int flags) { - ArrayType *tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned( - static_cast(objptr)); - if (flags & kTVMFFIObjectDeleterFlagBitMaskStrong) { - // It is important to do tptr->ArrayType::~ArrayType(), - // so that we explicitly call the specific destructor - // instead of tptr->~ArrayType(), which could mean the intention - // call a virtual destructor(which may not be available and is not required). - tptr->ArrayType::~ArrayType(); - } - if (flags & kTVMFFIObjectDeleterFlagBitMaskWeak) { - AlignedFree(static_cast(tptr)); - } - } - }; -}; -} // namespace details - -/*! - * \brief Allocate an object - * \param args arguments to the constructor. - * \tparam T the node type. - * \return The ObjectPtr to the allocated object. - */ -template -inline ObjectPtr make_object(Args &&...args) { - return details::SimpleObjAllocator().make_object(std::forward(args)...); -} - -/*! - * \brief Allocate an Object with additional ElemType[num_elems] that are stored right after. - * \param num_elems The number of elements in the array. - * \param args arguments to the constructor. - * \tparam ArrayType the array type. - * \tparam ElemType the element type. - * \return The ObjectPtr to the allocated array. 
- */ -template -inline ObjectPtr make_inplace_array_object(size_t num_elems, Args &&...args) { - return details::SimpleObjAllocator().make_inplace_array( - num_elems, std::forward(args)...); -} - -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_MEMORY_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/object.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/object.h deleted file mode 100644 index eb796bf6a..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/object.h +++ /dev/null @@ -1,1207 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/object.h - * \brief A managed object in the TVM FFI. - */ -#ifndef TVM_FFI_OBJECT_H_ -#define TVM_FFI_OBJECT_H_ - -#include "base_details.h" -#include "c_api.h" - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief TypeIndex enum, alias of TVMFFITypeIndex. - */ -using TypeIndex = TVMFFITypeIndex; - -/*! - * \brief TypeInfo, alias of TVMFFITypeInfo. - */ -using TypeInfo = TVMFFITypeInfo; - -/*! - * \brief Helper tag to explicitly request unsafe initialization. - * - * Constructing an ObjectRefType with UnsafeInit{} will set the data_ member to nullptr. 
- * - * When initializing Object fields, ObjectRef fields can be set to UnsafeInit. - * This enables the "construct with UnsafeInit then set all fields" pattern - * when the object does not have a default constructor. - * - * Used for initialization in controlled scenarios where such unsafe - * initialization is known to be safe. - * - * Each ObjectRefType should have a constructor that takes an UnsafeInit tag. - * - * \note As the name suggests, do not use it in normal code paths. - */ -struct UnsafeInit {}; - -/*! - * \brief Known type keys for pre-defined types. - */ -struct StaticTypeKey { - /*! \brief The type key for Any */ - static constexpr const char *kTVMFFIAny = "Any"; - /*! \brief The type key for None */ - static constexpr const char *kTVMFFINone = "None"; - /*! \brief The type key for bool */ - static constexpr const char *kTVMFFIBool = "bool"; - /*! \brief The type key for int */ - static constexpr const char *kTVMFFIInt = "int"; - /*! \brief The type key for float */ - static constexpr const char *kTVMFFIFloat = "float"; - /*! \brief The type key for void* */ - static constexpr const char *kTVMFFIOpaquePtr = "void*"; - /*! \brief The type key for DataType */ - static constexpr const char *kTVMFFIDataType = "DataType"; - /*! \brief The type key for Device */ - static constexpr const char *kTVMFFIDevice = "Device"; - /*! \brief The type key for DLTensor* */ - static constexpr const char *kTVMFFIDLTensorPtr = "DLTensor*"; - /*! \brief The type key for const char* */ - static constexpr const char *kTVMFFIRawStr = "const char*"; - /*! \brief The type key for TVMFFIByteArray* */ - static constexpr const char *kTVMFFIByteArrayPtr = "TVMFFIByteArray*"; - /*! \brief The type key for ObjectRValueRef */ - static constexpr const char *kTVMFFIObjectRValueRef = "ObjectRValueRef"; - /*! \brief The type key for SmallStr */ - static constexpr const char *kTVMFFISmallStr = "ffi.SmallStr"; - /*! 
\brief The type key for SmallBytes */ - static constexpr const char *kTVMFFISmallBytes = "ffi.SmallBytes"; - /*! \brief The type key for Error */ - static constexpr const char *kTVMFFIError = "ffi.Error"; - /*! \brief The type key for Bytes */ - static constexpr const char *kTVMFFIBytes = "ffi.Bytes"; - /*! \brief The type key for String */ - static constexpr const char *kTVMFFIStr = "ffi.String"; - /*! \brief The type key for Shape */ - static constexpr const char *kTVMFFIShape = "ffi.Shape"; - /*! \brief The type key for Tensor */ - static constexpr const char *kTVMFFITensor = "ffi.Tensor"; - /*! \brief The type key for Object */ - static constexpr const char *kTVMFFIObject = "ffi.Object"; - /*! \brief The type key for Function */ - static constexpr const char *kTVMFFIFunction = "ffi.Function"; - /*! \brief The type key for Array */ - static constexpr const char *kTVMFFIArray = "ffi.Array"; - /*! \brief The type key for Map */ - static constexpr const char *kTVMFFIMap = "ffi.Map"; - /*! \brief The type key for Module */ - static constexpr const char *kTVMFFIModule = "ffi.Module"; - /*! \brief The type key for OpaquePyObject */ - static constexpr const char *kTVMFFIOpaquePyObject = "ffi.OpaquePyObject"; -}; - -/*! - * \brief Get type key from type index - * \param type_index The input type index - * \return the type key - */ -inline std::string TypeIndexToTypeKey(int32_t type_index) { - const TypeInfo *type_info = TVMFFIGetTypeInfo(type_index); - return std::string(type_info->type_key.data, type_info->type_key.size); -} - -namespace details { -// Helper to perform -// unsafe operations related to object -struct ObjectUnsafe; - -/*! \brief One counter for weak reference. */ -constexpr uint64_t kCombinedRefCountWeakOne = static_cast(1) << 32; -/*! \brief One counter for strong reference. */ -constexpr uint64_t kCombinedRefCountStrongOne = 1; -/*! \brief Both reference counts. 
*/ -constexpr uint64_t kCombinedRefCountBothOne = kCombinedRefCountWeakOne | kCombinedRefCountStrongOne; -/*! \brief Mask to get the lower 32 bits of the combined reference count. */ -constexpr uint64_t kCombinedRefCountMaskUInt32 = (static_cast(1) << 32) - 1; - -/*! - * Check if the type_index is an instance of TargetObjectType. - * - * \tparam TargetType The target object type to be checked. - * - * \param object_type_index The type index to be checked, caller - * ensures that the index is already within the object index range. - * - * \return Whether the target type is true. - */ -template -TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index); -} // namespace details - -/*! - * \brief Base class of all object containers. - * - * Sub-class of objects should declare the following static constexpr fields: - * - * - _type_index: - * Static type index of the object, if assigned to TypeIndex::kTVMFFIDynObject - * the type index will be assigned during runtime. - * Runtime type index can be accessed by ObjectType::TypeIndex(); - * - _type_key: - * The unique string identifier of the type. - * - _type_final: - * Whether the type is terminal type(there is no subclass of the type in the object system). - * This field is automatically set by macro TVM_FFI_DECLARE_OBJECT_INFO_FINAL - * It is still OK to sub-class a terminal object type T and construct it using make_object. - * But IsInstance check will only show that the object type is T(instead of the sub-class). - * - _type_mutable: - * Whether we would like to expose cast to non-constant pointer - * ObjectType* from Any/AnyView. By default, we set to false so it is not exposed. - * - * The following two fields are necessary for base classes that can be sub-classed. - * - * - _type_child_slots: - * Number of reserved type index slots for child classes. - * Used for runtime optimization for type checking in IsInstance. 
- * If an object's type_index is within range of [type_index, type_index + _type_child_slots] - * Then the object can be quickly decided as sub-class of the current object class. - * If not, a fallback mechanism is used to check the global type table. - * Recommendation: set to estimate number of children needed. - * - * - _type_child_slots_can_overflow: - * Whether we can add additional child classes even if the number of child classes - * exceeds the _type_child_slots. A fallback mechanism to check type table will be used. - * Recommendation: set to false for optimal runtime speed if we know exact number of children. - * - * Two macros are used to declare helper functions in the object: - * - Use TVM_FFI_DECLARE_OBJECT_INFO for object classes that can be sub-classed. - * - Use TVM_FFI_DECLARE_OBJECT_INFO_FINAL for object classes that cannot be sub-classed. - * - * New objects can be created using make_object function. - * Which will automatically populate the type_index and deleter of the object. - */ -class Object { -protected: - /*! \brief header field that is the common prefix of all objects */ - TVMFFIObject header_; - -public: - Object() { - header_.combined_ref_count = 0; - header_.type_index = 0; - header_.__padding = 0; - header_.__ensure_align = 0; - } - /*! - * Check if the object is an instance of TargetType. - * \tparam TargetType The target type to be checked. - * \return Whether the target type is true. - */ - template - bool IsInstance() const { - return details::IsObjectInstance(header_.type_index); - } - - /*! \return The internal runtime type index of the object. */ - int32_t type_index() const { return header_.type_index; } - - /*! - * \return the type key of the object. - * \note this operation is expensive, can be used for error reporting. 
- */ - std::string GetTypeKey() const { - // the function checks that the info exists - const TypeInfo *type_info = TVMFFIGetTypeInfo(header_.type_index); - return std::string(type_info->type_key.data, type_info->type_key.size); - } - - /*! - * \return A hash value of the return of GetTypeKey. - */ - uint64_t GetTypeKeyHash() const { - // the function checks that the info exists - const TypeInfo *type_info = TVMFFIGetTypeInfo(header_.type_index); - return type_info->type_key_hash; - } - - /*! - * \brief Get the type key of the corresponding index from runtime. - * \param tindex The type index. - * \return the result. - */ - static std::string TypeIndex2Key(int32_t tindex) { - const TypeInfo *type_info = TVMFFIGetTypeInfo(tindex); - return std::string(type_info->type_key.data, type_info->type_key.size); - } - - /*! - * \return Whether the object.use_count() == 1. - */ - bool unique() const { return use_count() == 1; } - - /*! - * \return The usage count of the cell. - * \note We use STL style naming to be consistent with known API in shared_ptr. - */ - uint64_t use_count() const { - // only need relaxed load of counters -#ifdef _MSC_VER - return ((reinterpret_cast( - &header_.combined_ref_count))[0] // NOLINT(*) - ) - & kCombinedRefCountMaskUInt32; -#else - return __atomic_load_n(&(header_.combined_ref_count), __ATOMIC_RELAXED) & kCombinedRefCountMaskUInt32; -#endif - } - - //---------------------------------------------------------------------------- - // The following fields are configuration flags for subclasses of object - //---------------------------------------------------------------------------- - /*! \brief The type key of the class */ - static constexpr const char *_type_key = StaticTypeKey::kTVMFFIObject; - /*! \brief Whether the class is final */ - static constexpr bool _type_final = false; - /*! \brief Whether allow mutable access to fields */ - static constexpr bool _type_mutable = false; - /*! 
\brief The number of child slots of the class to pre-allocate to this type */ - static constexpr uint32_t _type_child_slots = 0; - /*! - * \brief Whether allow additional children beyond pre-specified by _type_child_slots - */ - static constexpr bool _type_child_slots_can_overflow = true; - /*! \brief The static type index of the class */ - static constexpr int32_t _type_index = TypeIndex::kTVMFFIObject; - /*! \brief The static depth of the class in the object hierarchy */ - static constexpr int32_t _type_depth = 0; - /*! \brief The structural equality and hash kind of the type */ - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindUnsupported; - // The following functions are provided by macro - // TVM_FFI_DECLARE_OBJECT_INFO and TVM_FFI_DECLARE_OBJECT_INFO_FINAL - /*! - * \brief Get the runtime allocated type index of the type - * \note Getting this information may need dynamic calls into a global table. - */ - static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } - /*! - * \brief Internal function to get or allocate a runtime index. - */ - static int32_t _GetOrAllocRuntimeTypeIndex() { // NOLINT(bugprone-reserved-identifier) - return TypeIndex::kTVMFFIObject; - } - -private: - // exposing detailed constants to here - static constexpr uint64_t kCombinedRefCountMaskUInt32 = details::kCombinedRefCountMaskUInt32; - static constexpr uint64_t kCombinedRefCountStrongOne = details::kCombinedRefCountStrongOne; - static constexpr uint64_t kCombinedRefCountWeakOne = details::kCombinedRefCountWeakOne; - static constexpr uint64_t kCombinedRefCountBothOne = details::kCombinedRefCountBothOne; - /*! \brief increase strong reference count, the caller must already hold a strong reference */ - void IncRef() { -#ifdef _MSC_VER - _InterlockedIncrement64( - reinterpret_cast(&header_.combined_ref_count)); // NOLINT(*) -#else - __atomic_fetch_add(&(header_.combined_ref_count), 1, __ATOMIC_RELAXED); -#endif - } - /*! 
- * \brief Try to lock the object to increase the strong reference count, - * the caller must already hold a strong reference. - * \return whether the lock call is successful and object is still alive. - */ - bool TryPromoteWeakPtr() { -#ifdef _MSC_VER - uint64_t old_count = (reinterpret_cast(&header_.combined_ref_count))[0]; // NOLINT(*) - while ((old_count & kCombinedRefCountMaskUInt32) != 0) { - uint64_t new_count = old_count + kCombinedRefCountStrongOne; - uint64_t old_count_loaded = _InterlockedCompareExchange64( - reinterpret_cast(&header_.combined_ref_count), new_count, old_count); - if (old_count == old_count_loaded) { - return true; - } - old_count = old_count_loaded; - } - return false; -#else - uint64_t old_count = __atomic_load_n(&(header_.combined_ref_count), __ATOMIC_RELAXED); - while ((old_count & kCombinedRefCountMaskUInt32) != 0) { - // must do CAS to ensure that we are the only one that increases the reference count - // avoid condition when two threads tries to promote weak to strong at same time - // or when strong deletion happens between the load and the CAS - uint64_t new_count = old_count + kCombinedRefCountStrongOne; - if (__atomic_compare_exchange_n(&(header_.combined_ref_count), &old_count, new_count, true, - __ATOMIC_ACQ_REL, __ATOMIC_RELAXED)) { - return true; - } - } - return false; -#endif - } - - /*! \brief increase weak reference count */ - void IncWeakRef() { -#ifdef _MSC_VER - _InlineInterlockedAdd64( - reinterpret_cast(&header_.combined_ref_count), // NOLINT(*) - kCombinedRefCountWeakOne); -#else - __atomic_fetch_add(&(header_.combined_ref_count), kCombinedRefCountWeakOne, __ATOMIC_RELAXED); -#endif - } - - /*! 
\brief decrease strong reference count and delete the object */ - void DecRef() { -#ifdef _MSC_VER - // use simpler impl in windows to ensure correctness - uint64_t count_before_sub = _InterlockedDecrement64( // - reinterpret_cast(&header_.combined_ref_count) // NOLINT(*) - ) - + 1; - if (count_before_sub == kCombinedRefCountBothOne) { // NOLINT(*) - // fast path: both reference counts will go to zero - if (header_.deleter != nullptr) { - // full barrrier is implicit in InterlockedDecrement - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth); - } - } else if ((count_before_sub & kCombinedRefCountMaskUInt32) == kCombinedRefCountStrongOne) { - // strong reference count becomes zero, we need to first do strong deletion - // then decrease weak reference count - // full barrrier is implicit in InterlockedAdd - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong); - } - // decrease weak reference count - if (_InlineInterlockedAdd64( // - reinterpret_cast(&header_.combined_ref_count), - -kCombinedRefCountWeakOne) - == 0) { // NOLINT(*) - if (header_.deleter != nullptr) { - // full barrrier is implicit in InterlockedAdd - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); - } - } - } -#else - // first do a release, note we only need to acquire for deleter - // optimization: we only need one atomic to tell the common case - // where both reference counts are zero - uint64_t count_before_sub = __atomic_fetch_sub(&(header_.combined_ref_count), - kCombinedRefCountStrongOne, __ATOMIC_RELEASE); - if (count_before_sub == kCombinedRefCountBothOne) { - // common case, we need to delete both the object and the memory block - // only acquire when we need to call deleter - __atomic_thread_fence(__ATOMIC_ACQUIRE); - if (header_.deleter != nullptr) { - // call deleter once - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth); - } - } else if ((count_before_sub & 
kCombinedRefCountMaskUInt32) == kCombinedRefCountStrongOne) { - // strong count is already zero - // Slower path: there is still a weak reference left - __atomic_thread_fence(__ATOMIC_ACQUIRE); - // call destructor first, then decrease weak reference count - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskStrong); - } - // now decrease weak reference count - if (__atomic_fetch_sub(&(header_.combined_ref_count), kCombinedRefCountWeakOne, - __ATOMIC_RELEASE) - == kCombinedRefCountWeakOne) { - __atomic_thread_fence(__ATOMIC_ACQUIRE); - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); - } - } - } -#endif - } - - /*! \brief decrease weak reference count */ - void DecWeakRef() { -#ifdef _MSC_VER - if (_InlineInterlockedAdd64( // - reinterpret_cast(&header_.combined_ref_count), // NOLINT(*) - -kCombinedRefCountWeakOne) - == 0) { - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); - } - } -#else - // now decrease weak reference count - if (__atomic_fetch_sub(&(header_.combined_ref_count), kCombinedRefCountWeakOne, - __ATOMIC_RELEASE) - == kCombinedRefCountWeakOne) { - __atomic_thread_fence(__ATOMIC_ACQUIRE); - if (header_.deleter != nullptr) { - header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak); - } - } -#endif - } - - // friend classes - template - friend class ObjectPtr; - template - friend class WeakObjectPtr; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -/*! - * \brief A custom smart pointer for Object. - * \tparam T the content data type. - * \sa make_object - */ -template -class ObjectPtr { -public: - /*! \brief default constructor */ - ObjectPtr() = default; - /*! \brief default constructor */ - ObjectPtr(std::nullptr_t) {} // NOLINT(*) - /*! 
- * \brief copy constructor - * \param other The value to be moved - */ - ObjectPtr(const ObjectPtr &other) // NOLINT(*) - : ObjectPtr(other.data_) {} - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - ObjectPtr(const ObjectPtr &other) // NOLINT(*) - : ObjectPtr(other.data_) { - static_assert(std::is_base_of_v, "can only assign of child class ObjectPtr to parent"); - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - ObjectPtr(ObjectPtr &&other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - template - ObjectPtr(ObjectPtr &&other) // NOLINT(*) - : data_(other.data_) { - static_assert(std::is_base_of_v, "can only assign of child class ObjectPtr to parent"); - other.data_ = nullptr; - } - /*! \brief destructor */ - ~ObjectPtr() { this->reset(); } - /*! - * \brief Swap this array with another Object - * \param other The other Object - */ - void swap(ObjectPtr &other) { // NOLINT(*) - std::swap(data_, other.data_); - } - /*! - * \return Get the content of the pointer - */ - T *get() const { return static_cast(data_); } - /*! - * \return The pointer - */ - T *operator->() const { return get(); } - /*! - * \return The reference - */ - T &operator*() const { // NOLINT(*) - return *get(); - } - /*! - * \brief copy assignment - * \param other The value to be assigned. - * \return reference to self. - */ - ObjectPtr &operator=(const ObjectPtr &other) { // NOLINT(*) - // takes in plane operator to enable copy elison. - // copy-and-swap idiom - ObjectPtr(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignment - * \param other The value to be assigned. - * \return reference to self. - */ - ObjectPtr &operator=(ObjectPtr &&other) { // NOLINT(*) - // copy-and-swap idiom - ObjectPtr(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! 
- * \brief nullptr check - * \return result of comparison of internal pointer with nullptr. - */ - explicit operator bool() const { return get() != nullptr; } - /*! \brief reset the content of ptr to be nullptr */ - void reset() { - if (data_ != nullptr) { - data_->DecRef(); - data_ = nullptr; - } - } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } - /*! \return whether the reference is unique */ - bool unique() const { return data_ != nullptr && data_->use_count() == 1; } - /*! \return Whether two ObjectPtr do not equal each other */ - bool operator==(const ObjectPtr &other) const { return data_ == other.data_; } - /*! \return Whether two ObjectPtr equals each other */ - bool operator!=(const ObjectPtr &other) const { return data_ != other.data_; } - /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t) const { return data_ == nullptr; } - /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t) const { return data_ != nullptr; } - -private: - /*! \brief internal pointer field */ - Object *data_{nullptr}; - /*! - * \brief constructor from Object - * \param data The data pointer - */ - explicit ObjectPtr(Object *data) : data_(data) { - if (data_ != nullptr) { - data_->IncRef(); - } - } - // friend classes - friend class Object; - friend class ObjectRef; - friend struct ObjectPtrHash; - template - friend class ObjectPtr; - template - friend class WeakObjectPtr; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -/*! - * \brief A custom smart pointer for Object. - * \tparam T the content data type. - * \sa make_object - */ -template -class WeakObjectPtr { -public: - /*! \brief default constructor */ - WeakObjectPtr() = default; - /*! \brief default constructor */ - WeakObjectPtr(std::nullptr_t) {} // NOLINT(*) - /*! 
- * \brief copy constructor - * \param other The value to be moved - */ - WeakObjectPtr(const WeakObjectPtr &other) // NOLINT(*) - : WeakObjectPtr(other.data_) {} - - /*! - * \brief copy constructor - * \param other The value to be moved - */ - WeakObjectPtr(const ObjectPtr &other) // NOLINT(*) - : WeakObjectPtr(other.get()) {} - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - WeakObjectPtr(const WeakObjectPtr &other) // NOLINT(*) - : WeakObjectPtr(other.data_) { - static_assert(std::is_base_of_v, "can only assign of child class ObjectPtr to parent"); - } - /*! - * \brief copy constructor - * \param other The value to be moved - */ - template - WeakObjectPtr(const ObjectPtr &other) // NOLINT(*) - : WeakObjectPtr(other.data_) { - static_assert(std::is_base_of_v, "can only assign of child class ObjectPtr to parent"); - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - WeakObjectPtr(WeakObjectPtr &&other) // NOLINT(*) - : data_(other.data_) { - other.data_ = nullptr; - } - /*! - * \brief move constructor - * \param other The value to be moved - */ - template - WeakObjectPtr(WeakObjectPtr &&other) // NOLINT(*) - : data_(other.data_) { - static_assert(std::is_base_of_v, "can only assign of child class ObjectPtr to parent"); - other.data_ = nullptr; - } - /*! \brief destructor */ - ~WeakObjectPtr() { this->reset(); } - /*! - * \brief Swap this array with another Object - * \param other The other Object - */ - void swap(WeakObjectPtr &other) { // NOLINT(*) - std::swap(data_, other.data_); - } - - /*! - * \brief copy assignment - * \param other The value to be assigned. - * \return reference to self. - */ - WeakObjectPtr &operator=(const WeakObjectPtr &other) { // NOLINT(*) - // takes in plane operator to enable copy elison. - // copy-and-swap idiom - WeakObjectPtr(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief move assignment - * \param other The value to be assigned. 
- * \return reference to self. - */ - WeakObjectPtr &operator=(WeakObjectPtr &&other) { // NOLINT(*) - // copy-and-swap idiom - WeakObjectPtr(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - /*! \return The internal object pointer if the object is still alive, otherwise nullptr */ - ObjectPtr lock() const { - if (data_ != nullptr && data_->TryPromoteWeakPtr()) { - ObjectPtr ret; - // we already increase the reference count, so we don't need to do it again - ret.data_ = data_; - return ret; - } - return nullptr; - } - - /*! \brief reset the content of ptr to be nullptr */ - void reset() { - if (data_ != nullptr) { - data_->DecWeakRef(); - data_ = nullptr; - } - } - - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_ != nullptr ? data_->use_count() : 0; } - - /*! \return whether the pointer is nullptr */ - bool expired() const { return data_ == nullptr || data_->use_count() == 0; } - -private: - /*! \brief internal pointer field */ - Object *data_{nullptr}; - - /*! - * \brief constructor from Object - * \param data The data pointer - */ - explicit WeakObjectPtr(Object *data) : data_(data) { - if (data_ != nullptr) { - data_->IncWeakRef(); - } - } - - template - friend class WeakObjectPtr; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -/*! - * \brief Optional data type in FFI. - * \tparam T The underlying type of the optional. - * - * \note Compared to std::optional, Optional - * akes less storage as it used nullptr to represent nullopt. - */ -template -class Optional; - -/*! \brief Base class of all object reference */ -class ObjectRef { -public: - /*! \brief default constructor */ - ObjectRef() = default; - /*! \brief copy constructor */ - ObjectRef(const ObjectRef &other) = default; - /*! \brief move constructor */ - ObjectRef(ObjectRef &&other) noexcept : data_(std::move(other.data_)) { other.data_ = nullptr; } - /*! 
\brief copy assignment */ - ObjectRef &operator=(const ObjectRef &other) = default; - /*! \brief move assignment */ - ObjectRef &operator=(ObjectRef &&other) noexcept { - data_ = std::move(other.data_); - other.data_ = nullptr; - return *this; - } - /*! \brief Constructor from existing object ptr */ - explicit ObjectRef(ObjectPtr data) : data_(std::move(data)) {} - /*! \brief Constructor from UnsafeInit */ - explicit ObjectRef(UnsafeInit) : data_(nullptr) {} - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool same_as(const ObjectRef &other) const { return data_ == other.data_; } - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool operator==(const ObjectRef &other) const { return data_ == other.data_; } - /*! - * \brief Comparator - * \param other Another object ref. - * \return the compare result. - */ - bool operator!=(const ObjectRef &other) const { return data_ != other.data_; } - /*! - * \brief Comparator - * \param other Another object ref by address. - * \return the compare result. - */ - bool operator<(const ObjectRef &other) const { return data_.get() < other.data_.get(); } - /*! - * \return whether the object is defined. - */ - bool defined() const { return data_ != nullptr; } - /*! \return the internal object pointer */ - const Object *get() const { return data_.get(); } - /*! \return the internal object pointer */ - const Object *operator->() const { return get(); } - /*! \return whether the reference is unique */ - bool unique() const { return data_.unique(); } - /*! \return The use count of the ptr, for debug purposes */ - int use_count() const { return data_.use_count(); } - - /*! - * \brief Try to downcast the internal Object to a - * raw pointer of a corresponding type. - * - * The function will return a nullptr if the cast failed. 
- * - * if (const AddNode *ptr = node_ref.as()) { - * // This is an add node - * } - * - * \tparam ObjectType the target type, must be a subtype of Object - * \return The pointer to the requested type. - */ - template >> - const ObjectType *as() const { - if (data_ != nullptr && data_->IsInstance()) { - return static_cast(data_.get()); - } else { - return nullptr; - } - } - - /*! - * \brief Try to downcast the ObjectRef to Optional of the requested type. - * - * The function will return a std::nullopt if the cast or if the pointer is nullptr. - * - * \tparam ObjectRefType the target type, must be a subtype of ObjectRef' - * \return The optional value of the requested type. - */ - template >> - TVM_FFI_INLINE std::optional as() const { - if (data_ != nullptr) { - if (data_->IsInstance()) { - ObjectRefType ref(UnsafeInit{}); - ref.data_ = data_; - return ref; - } else { - return std::nullopt; - } - } else { - return std::nullopt; - } - } - - /*! - * \brief Get the type index of the ObjectRef - * \return The type index of the ObjectRef - */ - int32_t type_index() const { - return data_ != nullptr ? data_->type_index() : TypeIndex::kTVMFFINone; - } - - /*! - * \brief Get the type key of the ObjectRef - * \return The type key of the ObjectRef - */ - std::string GetTypeKey() const { - return data_ != nullptr ? data_->GetTypeKey() : StaticTypeKey::kTVMFFINone; - } - - /*! \brief type indicate the container type. */ - using ContainerType = Object; - /*! \brief Whether the reference can point to nullptr */ - static constexpr bool _type_is_nullable = true; - -protected: - /*! \brief Internal pointer that backs the reference. */ - ObjectPtr data_; - /*! \return return a mutable internal ptr, can be used by sub-classes. */ - Object *get_mutable() const { return data_.get(); } - // friend classes. - friend struct ObjectPtrHash; - friend struct tvm::ffi::details::ObjectUnsafe; -}; - -// forward delcare variant -template -class Variant; - -/*! 
\brief ObjectRef hash functor */ -struct ObjectPtrHash { - size_t operator()(const ObjectRef &a) const { return operator()(a.data_); } - - template - size_t operator()(const ObjectPtr &a) const { - return std::hash()(a.get()); - } - - template - TVM_FFI_INLINE size_t operator()(const Variant &a) const; -}; - -/*! \brief ObjectRef equal functor */ -struct ObjectPtrEqual { - bool operator()(const ObjectRef &a, const ObjectRef &b) const { return a.same_as(b); } - - template - bool operator()(const ObjectPtr &a, const ObjectPtr &b) const { - return a == b; - } - - template - TVM_FFI_INLINE bool operator()(const Variant &a, const Variant &b) const; -}; - -/*! - * \brief Helper macro to declare object information with static type index. - * - * For each custom object, you need to call tvm::ffi::reflection::ObjectDef() - * once in your cc file to register the type index with the runtime. - * Alternatively, you can call TypeName::_GetOrAllocRuntimeTypeIndex() once. - * - * \param TypeKey The type key of the current type. - * \param TypeName The name of the current type. 
- * \param ParentType The name of the ParentType - * - * \see tvm::ffi::reflection::ObjectDef - */ -#define TVM_FFI_DECLARE_OBJECT_INFO_STATIC(TypeKey, TypeName, ParentType) \ - static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ - static int32_t _GetOrAllocRuntimeTypeIndex() { \ - static_assert(!ParentType::_type_final, "ParentType marked as final"); \ - static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it."); \ - TVMFFIByteArray type_key{TypeName::_type_key, \ - std::char_traits::length(TypeName::_type_key)}; \ - static int32_t tindex [[maybe_unused]] = TVMFFITypeGetOrAllocIndex( \ - &type_key, TypeName::_type_index, TypeName::_type_depth, TypeName::_type_child_slots, \ - TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ - return TypeName::_type_index; \ - } \ - static int32_t RuntimeTypeIndex() { return TypeName::_type_index; } \ - static constexpr const char *_type_key = TypeKey - -/*! - * \brief Helper macro to declare object information with type key already defined in class. - * - * \param TypeName The name of the current type. 
- * \param ParentType The name of the ParentType - */ -#define TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(TypeName, ParentType) \ - static constexpr int32_t _type_depth = ParentType::_type_depth + 1; \ - static int32_t _GetOrAllocRuntimeTypeIndex() { \ - static_assert(!ParentType::_type_final, "ParentType marked as final"); \ - static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || TypeName::_type_child_slots < ParentType::_type_child_slots, \ - "Need to set _type_child_slots when parent specifies it."); \ - TVMFFIByteArray type_key{TypeName::_type_key, \ - std::char_traits::length(TypeName::_type_key)}; \ - static int32_t tindex = TVMFFITypeGetOrAllocIndex( \ - &type_key, -1, TypeName::_type_depth, TypeName::_type_child_slots, \ - TypeName::_type_child_slots_can_overflow, ParentType::_GetOrAllocRuntimeTypeIndex()); \ - return tindex; \ - } \ - static int32_t RuntimeTypeIndex() { return _GetOrAllocRuntimeTypeIndex(); } - -/*! - * \brief Helper macro to declare object information with dynamic type index. - * - * For each custom object, you need to call tvm::ffi::reflection::ObjectDef() - * once in your cc file to register the type index with the runtime. - * Alternatively, you can call TypeName::_GetOrAllocRuntimeTypeIndex() once. - * - * \param TypeKey The type key of the current type. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - * \sa tvm::ffi::reflection::ObjectDef - */ -#define TVM_FFI_DECLARE_OBJECT_INFO(TypeKey, TypeName, ParentType) \ - static constexpr const char *_type_key = TypeKey; \ - TVM_FFI_DECLARE_OBJECT_INFO_PREDEFINED_TYPE_KEY(TypeName, ParentType) - -/*! - * \brief Helper macro to declare object information with dynamic type index and is final. - * - * For each custom object, you need to call tvm::ffi::reflection::ObjectDef() - * once in your cc file to register the type index with the runtime. 
- * Alternatively, you can call TypeName::_GetOrAllocRuntimeTypeIndex() once. - * - * \param TypeKey The type key of the current type. - * \param TypeName The name of the current type. - * \param ParentType The name of the ParentType - * \sa tvm::ffi::reflection::ObjectDef - */ -#define TVM_FFI_DECLARE_OBJECT_INFO_FINAL(TypeKey, TypeName, ParentType) \ - static const constexpr int _type_child_slots [[maybe_unused]] = 0; \ - static const constexpr bool _type_final [[maybe_unused]] = true; \ - TVM_FFI_DECLARE_OBJECT_INFO(TypeKey, TypeName, ParentType) - -/*! - * \brief Define object reference methods. - * - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. - * - * \note This macro also defines the default constructor that puts the ObjectRef - * in undefined state initially. - */ -#define TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - explicit TypeName(::tvm::ffi::ObjectPtr n) : ParentType(std::move(n)) {} \ - explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - using __PtrType = std::conditional_t<(ObjectName::_type_mutable), \ - ObjectName *, /* NOLINT(bugprone-macro-parentheses) */ \ - const ObjectName *>; \ - __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } \ - __PtrType get() const { return static_cast<__PtrType>(data_.get()); } \ - [[maybe_unused]] static constexpr bool _type_is_nullable = true; \ - using ContainerType = ObjectName - -/*! - * \brief Define object reference methods do not have undefined state. - * - * \param TypeName The object type name - * \param ParentType The parent type of the objectref - * \param ObjectName The type name of the object. 
- */ -#define TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(TypeName, ParentType, ObjectName) \ - explicit TypeName(::tvm::ffi::UnsafeInit tag) : ParentType(tag) {} \ - TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ - using __PtrType = std::conditional_t<(ObjectName::_type_mutable), \ - ObjectName *, /* NOLINT(bugprone-macro-parentheses) */ \ - const ObjectName *>; \ - __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } \ - __PtrType get() const { return static_cast<__PtrType>(data_.get()); } \ - [[maybe_unused]] static constexpr bool _type_is_nullable = false; \ - using ContainerType = ObjectName - -namespace details { - -template -TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) { - static_assert(std::is_base_of_v); - // Everything is a subclass of object. - if constexpr (std::is_same_v) { - return true; - } else if constexpr (TargetType::_type_final) { - // if the target type is a final type - // then we only need to check the equivalence. - return object_type_index == TargetType::RuntimeTypeIndex(); - } else { - // Explicitly enclose in else to eliminate this branch early in compilation. - // if target type is a non-leaf type - // Check if type index falls into the range of reserved slots. - int32_t target_type_index = TargetType::RuntimeTypeIndex(); - int32_t begin = target_type_index; - // The condition will be optimized by constant-folding. - if constexpr (TargetType::_type_child_slots != 0) { - // total_slots = child_slots + 1 (including self) - int32_t end = begin + TargetType::_type_child_slots + 1; - if (object_type_index >= begin && object_type_index < end) { - return true; - } - } else { - if (object_type_index == begin) { - return true; - } - } - if constexpr (TargetType::_type_child_slots_can_overflow) { - // Invariance: parent index is always smaller than the child. 
- if (object_type_index < target_type_index) { - return false; - } - // Do a runtime lookup of type information - // the function checks that the info exists - const TypeInfo *type_info = TVMFFIGetTypeInfo(object_type_index); - return (type_info->type_depth > TargetType::_type_depth && type_info->type_ancestors[TargetType::_type_depth]->type_index == target_type_index); - } else { - return false; - } - } -} - -/*! - * \brief Namespace to internally manipulate object class. - * \note These functions are only supposed to be used by internal - * implementations and not external users of the tvm::ffi - */ -struct ObjectUnsafe { - // NOTE: get ffi header from an object - TVM_FFI_INLINE static TVMFFIObject *GetHeader(const Object *src) { - return const_cast(&(src->header_)); - } - - template - TVM_FFI_INLINE static int64_t GetObjectOffsetToSubclass() { - return (reinterpret_cast(&(static_cast(nullptr)->header_)) - reinterpret_cast(&(static_cast(nullptr)->header_))); - } - - template - TVM_FFI_INLINE static T ObjectRefFromObjectPtr(const ObjectPtr &ptr) { - T ref(UnsafeInit{}); - ref.data_ = ptr; - return ref; - } - - template - TVM_FFI_INLINE static T ObjectRefFromObjectPtr(ObjectPtr &&ptr) { - T ref(UnsafeInit{}); - ref.data_ = std::move(ptr); - return ref; - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromObjectRef(const ObjectRef &ref) { - if constexpr (std::is_same_v) { - return ref.data_; - } else { - return tvm::ffi::ObjectPtr(ref.data_.data_); - } - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromObjectRef(ObjectRef &&ref) { - if constexpr (std::is_same_v) { - return std::move(ref.data_); - } else { - ObjectPtr result; - result.data_ = std::move(ref.data_.data_); - ref.data_.data_ = nullptr; - return result; - } - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromOwned(Object *raw_ptr) { - tvm::ffi::ObjectPtr ptr; - ptr.data_ = raw_ptr; - return ptr; - } - - template - TVM_FFI_INLINE static ObjectPtr 
ObjectPtrFromOwned(TVMFFIObject *obj_ptr) { - return ObjectPtrFromOwned(reinterpret_cast(obj_ptr)); - } - - template - TVM_FFI_INLINE static T *RawObjectPtrFromUnowned(TVMFFIObject *obj_ptr) { - // NOTE: this is important to first cast to Object* - // then cast back to T* because objptr and tptr may not be the same - // depending on how sub-class allocates the space. - return static_cast(reinterpret_cast(obj_ptr)); - } - - // Create ObjectPtr from unowned ptr - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromUnowned(Object *raw_ptr) { - return tvm::ffi::ObjectPtr(raw_ptr); - } - - template - TVM_FFI_INLINE static ObjectPtr ObjectPtrFromUnowned(TVMFFIObject *obj_ptr) { - return tvm::ffi::ObjectPtr(reinterpret_cast(obj_ptr)); - } - - TVM_FFI_INLINE static void DecRefObjectHandle(TVMFFIObjectHandle handle) { - reinterpret_cast(handle)->DecRef(); - } - - TVM_FFI_INLINE static void IncRefObjectHandle(TVMFFIObjectHandle handle) { - reinterpret_cast(handle)->IncRef(); - } - - TVM_FFI_INLINE static Object *RawObjectPtrFromObjectRef(const ObjectRef &src) { - return src.data_.data_; - } - - TVM_FFI_INLINE static TVMFFIObject *TVMFFIObjectPtrFromObjectRef(const ObjectRef &src) { - return GetHeader(src.data_.data_); - } - - template - TVM_FFI_INLINE static TVMFFIObject *TVMFFIObjectPtrFromObjectPtr(const ObjectPtr &src) { - return GetHeader(src.data_); - } - - template - TVM_FFI_INLINE static TVMFFIObject *MoveObjectPtrToTVMFFIObjectPtr(ObjectPtr &&src) { - Object *obj_ptr = src.data_; - src.data_ = nullptr; - return GetHeader(obj_ptr); - } - - TVM_FFI_INLINE static TVMFFIObject *MoveObjectRefToTVMFFIObjectPtr(ObjectRef &&src) { - Object *obj_ptr = src.data_.data_; - src.data_.data_ = nullptr; - return GetHeader(obj_ptr); - } -}; -} // namespace details -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_OBJECT_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/optional.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/optional.h deleted 
file mode 100644 index 11dcc46a8..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/optional.h +++ /dev/null @@ -1,428 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/optional.h - * \brief Runtime Optional container types. - * \note Optional specializes for T is ObjectRef and used nullptr to indicate nullopt. - */ -#ifndef TVM_FFI_OPTIONAL_H_ -#define TVM_FFI_OPTIONAL_H_ - -#include "error.h" -#include "object.h" -#include "string.h" - -#include -#include -#include - -namespace tvm { -namespace ffi { - -// Note: We place optional in tvm/ffi instead of tvm/ffi/container -// because optional itself is an inherent core component of the FFI system. -/// \cond Doxygen_Suppress -template -inline constexpr bool is_optional_type_v = false; - -template -inline constexpr bool is_optional_type_v> = true; - -// we can safely used ptr based optional for ObjectRef types -// that do not have additional data members and virtual functions. -template -inline constexpr bool use_ptr_based_optional_v = (std::is_base_of_v && !is_optional_type_v); -/// \endcond - -// Specialization for non-ObjectRef types. 
-// simply fallback to std::optional -template -class Optional && !std::is_same_v && !std::is_same_v>> { -public: - // default constructors. - Optional() = default; - // NOLINTBEGIN(google-explicit-constructor) - Optional(const Optional &other) : data_(other.data_) {} - Optional(Optional &&other) noexcept : data_(std::move(other.data_)) {} - Optional(std::optional other) : data_(std::move(other)) {} - Optional(std::nullopt_t) {} - Optional(T other) : data_(std::move(other)) {} - // NOLINTEND(google-explicit-constructor) - - TVM_FFI_INLINE Optional &operator=(const Optional &other) { - data_ = other.data_; - return *this; - } - - TVM_FFI_INLINE Optional &operator=(Optional &&other) noexcept { - data_ = std::move(other.data_); - return *this; - } - - TVM_FFI_INLINE Optional &operator=(T other) { - data_ = std::move(other); - return *this; - } - - TVM_FFI_INLINE Optional &operator=(std::nullopt_t) { - data_ = std::nullopt; - return *this; - } - - TVM_FFI_INLINE const T &value() const & { - if (!data_.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return *data_; - } - - TVM_FFI_INLINE T &&value() && { - if (!data_.has_value()) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return *std::move(data_); - } - - template > - TVM_FFI_INLINE T value_or(U &&default_value) const { - return data_.value_or(std::forward(default_value)); - } - - TVM_FFI_INLINE explicit operator bool() const noexcept { return data_.has_value(); } - - TVM_FFI_INLINE bool has_value() const noexcept { return data_.has_value(); } - - TVM_FFI_INLINE bool operator==(const Optional &other) const { return data_ == other.data_; } - - TVM_FFI_INLINE bool operator!=(const Optional &other) const { return data_ != other.data_; } - - template - TVM_FFI_INLINE bool operator==(const U &other) const { - return data_ == other; - } - template - TVM_FFI_INLINE bool operator!=(const U &other) const { - return data_ != other; - } - - // 
NOLINTBEGIN(bugprone-unchecked-optional-access) - /*! - * \brief Direct access to the value. - * \return the xvalue reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T &&operator*() && noexcept { return *std::move(data_); } - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE const T &operator*() const & noexcept { return *data_; } - // NOLINTEND(bugprone-unchecked-optional-access) - -private: - std::optional data_; -}; - -// Specialization for String type, use nullptr to indicate nullopt -template -class Optional || std::is_same_v>> { -public: - // default constructors. - Optional() = default; - // NOLINTBEGIN(google-explicit-constructor) - Optional(const Optional &other) : data_(other.data_) {} - Optional(Optional &&other) : data_(std::move(other.data_)) {} - Optional(std::nullopt_t) {} - Optional(T other) : data_(std::move(other)) {} - // NOLINTEND(google-explicit-constructor) - - TVM_FFI_INLINE Optional &operator=(const Optional &other) { - data_ = other.data_; - return *this; - } - - TVM_FFI_INLINE Optional &operator=(Optional &&other) { - data_ = std::move(other.data_); - return *this; - } - - TVM_FFI_INLINE Optional &operator=(T other) { - data_ = std::move(other); - return *this; - } - - TVM_FFI_INLINE Optional &operator=(std::nullopt_t) { - T(details::BytesBaseCell(std::nullopt)).swap(data_); - return *this; - } - - TVM_FFI_INLINE const T &value() const & { - if (data_.data_ == std::nullopt) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return data_; - } - - TVM_FFI_INLINE String &&value() && { - if (data_.data_ == std::nullopt) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return std::move(data_); - } - - template - TVM_FFI_INLINE T value_or(U &&default_value) const { - if (data_.data_ == std::nullopt) { - return 
std::forward(default_value); - } - return data_; - } - - TVM_FFI_INLINE explicit operator bool() const noexcept { return data_.data_ != std::nullopt; } - - TVM_FFI_INLINE bool has_value() const noexcept { return data_.data_ != std::nullopt; } - - TVM_FFI_INLINE bool operator==(const Optional &other) const { - if (data_.data_ == std::nullopt) { - return other.data_.data_ == std::nullopt; - } - if (other.data_.data_ == std::nullopt) { - return false; - } - return data_ == other.data_; - } - - TVM_FFI_INLINE bool operator!=(const Optional &other) const { return !(*this == other); } - - template - TVM_FFI_INLINE bool operator==(const U &other) const { - if constexpr (std::is_same_v) { - return data_.data_ == std::nullopt; - } else { - if (data_.data_ == std::nullopt) { - return false; - } - return data_ == other; - } - } - template - TVM_FFI_INLINE bool operator!=(const U &other) const { - if constexpr (std::is_same_v) { - return data_.data_ != std::nullopt; - } else { - if (data_.data_ == std::nullopt) { - return true; - } - return data_ != other; - } - } - - /*! - * \brief Direct access to the value. - * \return the xvalue reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T &&operator*() && noexcept { return std::move(data_); } - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE const T &operator*() const & noexcept { return data_; } - -private: - // this is a private initializer - T data_{details::BytesBaseCell(std::nullopt)}; -}; - -// Specialization for ObjectRef types. -// nullptr is treated as std::nullopt. 
-template -class Optional>> : public ObjectRef { -public: - using ContainerType = typename T::ContainerType; - Optional() = default; - // NOLINTBEGIN(google-explicit-constructor) - Optional(const Optional &other) : ObjectRef(other) {} - Optional(Optional &&other) noexcept : ObjectRef(std::move(other)) {} - explicit Optional(ffi::UnsafeInit tag) : ObjectRef(tag) {} - Optional(std::nullopt_t) {} - Optional(std::optional other) { - if (other.has_value()) { - *this = *std::move(other); - } - } - Optional(T other) : ObjectRef(std::move(other)) {} - // NOLINTEND(google-explicit-constructor) - - TVM_FFI_INLINE Optional &operator=(T other) { - ObjectRef::operator=(std::move(other)); - return *this; - } - - TVM_FFI_INLINE Optional &operator=(const Optional &other) { - data_ = other.data_; - return *this; - } - - TVM_FFI_INLINE Optional &operator=(std::nullptr_t) { - data_ = nullptr; - return *this; - } - - TVM_FFI_INLINE Optional &operator=(Optional &&other) { - data_ = std::move(other.data_); - return *this; - } - - TVM_FFI_INLINE T value() const & { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return details::ObjectUnsafe::ObjectRefFromObjectPtr(data_); - } - - TVM_FFI_INLINE T value() && { - if (data_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Back optional access"; - } - return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(data_)); - } - - template > - TVM_FFI_INLINE T value_or(U &&default_value) const { - return data_ != nullptr ? details::ObjectUnsafe::ObjectRefFromObjectPtr(data_) - : T(std::forward(default_value)); - } - - TVM_FFI_INLINE explicit operator bool() const { return data_ != nullptr; } - - TVM_FFI_INLINE bool has_value() const { return data_ != nullptr; } - - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. 
- * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T operator*() const & noexcept { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(data_); - } - - /*! - * \brief Direct access to the value. - * \return the const reference to the stored value. - * \note only use this function after checking has_value() - */ - TVM_FFI_INLINE T operator*() && noexcept { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(data_)); - } - - TVM_FFI_INLINE bool operator==(std::nullptr_t) const noexcept { return !has_value(); } - TVM_FFI_INLINE bool operator!=(std::nullptr_t) const noexcept { return has_value(); } - - // operator overloadings - TVM_FFI_INLINE auto operator==(const Optional &other) const { - // support case where sub-class returns a symbolic ref type. - return EQToOptional(other); - } - TVM_FFI_INLINE auto operator!=(const Optional &other) const { return NEToOptional(other); } - - TVM_FFI_INLINE auto operator==(const std::optional &other) const { - // support case where sub-class returns a symbolic ref type. - return EQToOptional(other); - } - TVM_FFI_INLINE auto operator!=(const std::optional &other) const { - return NEToOptional(other); - } - - TVM_FFI_INLINE auto operator==(const T &other) const { - using RetType = decltype(value() == other); - if (same_as(other)) { - return RetType(true); - } - if (has_value()) { - return operator*() == other; - } - return RetType(false); - } - - TVM_FFI_INLINE auto operator!=(const T &other) const { return !(*this == other); } - - template - TVM_FFI_INLINE auto operator==(const U &other) const { - using RetType = decltype(value() == other); - if (!has_value()) { - return RetType(false); - } - return operator*() == other; - } - - template - TVM_FFI_INLINE auto operator!=(const U &other) const { - using RetType = decltype(value() != other); - if (!has_value()) { - return RetType(true); - } - return operator*() != other; - } - - /*! 
- * \return The internal object pointer with container type of T. - * \note This function do not perform not-null checking. - */ - TVM_FFI_INLINE const ContainerType *get() const { - return static_cast(data_.get()); - } - -private: - template - TVM_FFI_INLINE auto EQToOptional(const U &other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(operator*() == *other); - if (same_as(other)) { - return RetType(true); - } - if (has_value() && other.has_value()) { - return operator*() == *other; - } else { - // one of them is nullptr. - return RetType(false); - } - } - - template - TVM_FFI_INLINE auto NEToOptional(const U &other) const { - // support case where sub-class returns a symbolic ref type. - using RetType = decltype(operator*() != *other); - if (same_as(other)) { - return RetType(false); - } - if (has_value() && other.has_value()) { - return operator*() != *other; - } else { - // one of them is nullptr. - return RetType(true); - } - } -}; -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_OPTIONAL_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h deleted file mode 100644 index 4d716bce0..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/access_path.h +++ /dev/null @@ -1,450 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/registry.h - * \brief Registry of reflection metadata. - */ -#ifndef TVM_FFI_REFLECTION_ACCESS_PATH_H_ -#define TVM_FFI_REFLECTION_ACCESS_PATH_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { -namespace reflection { - -/*! - * \brief The kind of the access pattern. - */ -enum class AccessKind : int32_t { - /*! \brief Object attribute access. */ - kAttr = 0, - /*! \brief Array item access. */ - kArrayItem = 1, - /*! \brief Map item access. */ - kMapItem = 2, - // the following two are used for error reporting when - // the supposed access field is not available - /*! \brief Object attribute missing access. */ - kAttrMissing = 3, - /*! \brief Array item missing access. */ - kArrayItemMissing = 4, - /*! \brief Map item missing access. */ - kMapItemMissing = 5, -}; - -class AccessStep; - -/*! - * \brief Represent a single step in object field, map key, array index access. - */ -class AccessStepObj : public Object { -public: - /*! - * \brief The kind of the access pattern. - */ - AccessKind kind; - /*! - * \brief The access key - * \note for array access, it will always be integer - * for field access, it will be string - */ - Any key; - - // default constructor to enable auto-serialization - AccessStepObj() = default; - /*! - * \brief Constructor - * \param kind The kind of the access step. - * \param key The key of the access step. - */ - AccessStepObj(AccessKind kind, Any key) : kind(kind), key(std::move(key)) {} - - /*! 
- * \brief Deep check if two steps are equal. - * \param other The other step to compare with. - * \return True if the two steps are equal, false otherwise. - */ - inline bool StepEqual(const AccessStep &other) const; - - /// \cond Doxygen_Suppress - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessStep", AccessStepObj, Object); - /// \endcond -}; - -/*! - * \brief ObjectRef class of AccessStepObj. - * - * \sa AccessStepObj - */ -class AccessStep : public ObjectRef { -public: - /*! - * \brief Constructor - * \param kind The kind of the access step. - * \param key The key of the access step. - */ - AccessStep(AccessKind kind, Any key) - : ObjectRef(make_object(kind, std::move(key))) {} - - /*! - * \brief Create an access step for a object attribute access. - * \param field_name The name of the field to access. - * \return The access step. - */ - static AccessStep Attr(String field_name) { - return AccessStep(AccessKind::kAttr, std::move(field_name)); - } - - /*! - * \brief Create an access step for a object attribute missing access. - * \param field_name The name of the field to access. - * \return The access step. - */ - static AccessStep AttrMissing(String field_name) { - return AccessStep(AccessKind::kAttrMissing, std::move(field_name)); - } - - /*! - * \brief Create an access step for a array item access. - * \param index The index of the array item to access. - * \return The access step. - */ - static AccessStep ArrayItem(int64_t index) { return AccessStep(AccessKind::kArrayItem, index); } - - /*! - * \brief Create an access step for a array item missing access. - * \param index The index of the array item to access. - * \return The access step. - */ - static AccessStep ArrayItemMissing(int64_t index) { - return AccessStep(AccessKind::kArrayItemMissing, index); - } - - /*! - * \brief Create an access step for a map item access. 
- * \param key The key of the map item to access. - * \return The access step. - */ - static AccessStep MapItem(Any key) { return AccessStep(AccessKind::kMapItem, std::move(key)); } - - /*! - * \brief Create an access step for a map item missing access. - * \param key The key of the map item to access. - * \return The access step. - */ - static AccessStep MapItemMissing(Any key = nullptr) { - return AccessStep(AccessKind::kMapItemMissing, std::move(key)); - } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessStep, ObjectRef, AccessStepObj); - /// \endcond -}; - -inline bool AccessStepObj::StepEqual(const AccessStep &other) const { - return this->kind == other->kind && AnyEqual()(this->key, other->key); -} - -// forward declaration -class AccessPath; - -/*! - * \brief ObjectRef class of AccessPathObj. - * - * \sa AccessPathObj - */ -class AccessPathObj : public Object { -public: - /*! - * \brief The parent of the access path. - * - * This parent-pointing tree structure is more space efficient when - * representing multiple paths that share a common prefix. - * - * \note Empty for root. - */ - Optional parent; - /*! - * \brief The current of the access path. - * \note Empty for root. - */ - Optional step; - /*! - * \brief The current depth of the access path, 0 for root - */ - int32_t depth; - - // default constructor to enable auto-serialization - AccessPathObj() = default; - /*! - * \brief Constructor for the access path. - * \param parent The parent of the access path. - * \param step The current step of the access path. - * \param depth The current depth of the access path. - */ - AccessPathObj(Optional parent, Optional step, int32_t depth) - : parent(std::move(parent)), step(std::move(step)), depth(depth) {} - - /*! - * \brief Get the parent of the access path. - * \return The parent of the access path. - */ - inline Optional GetParent() const; - - /*! - * \brief Extend the access path with a new step. 
- * \param step The step to extend the access path with. - * \return The extended access path. - */ - inline AccessPath Extend(AccessStep step) const; - - /*! - * \brief Extend the access path with an object attribute access. - * \param field_name The name of the field to access. - * \return The extended access path. - */ - inline AccessPath Attr(String field_name) const; - - /*! - * \brief Extend the access path with an object attribute missing access. - * \param field_name The name of the field to access. - * \return The extended access path. - */ - inline AccessPath AttrMissing(String field_name) const; - - /*! - * \brief Extend the access path with an array item access. - * \param index The index of the array item to access. - * \return The extended access path. - */ - inline AccessPath ArrayItem(int64_t index) const; - - /*! - * \brief Extend the access path with an array item missing access. - * \param index The index of the array item to access. - * \return The extended access path. - */ - inline AccessPath ArrayItemMissing(int64_t index) const; - - /*! - * \brief Extend the access path with a map item access. - * \param key The key of the map item to access. - * \return The extended access path. - */ - inline AccessPath MapItem(Any key) const; - - /*! - * \brief Extend the access path with a map item missing access. - * \param key The key of the map item to access. - * \return The extended access path. - */ - inline AccessPath MapItemMissing(Any key) const; - - /*! - * \brief Get the array of steps that corresponds to the access path. - * \return The array of steps that corresponds to the access path. - */ - inline Array ToSteps() const; - - /*! - * \brief Check if two paths are equal by deep comparing the steps. - * \param other The other path to compare with. - * \return True if the two paths are equal, false otherwise. - */ - inline bool PathEqual(const AccessPath &other) const; - - /*! - * \brief Check if this path is a prefix of another path. 
- * \param other The other path to compare with. - * \return True if this path is a prefix of the other path, false otherwise. - */ - inline bool IsPrefixOf(const AccessPath &other) const; - - /// \cond Doxygen_Suppress - static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindConstTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ffi.reflection.AccessPath", AccessPathObj, Object); - /// \endcond - -private: - static bool PathEqual(const AccessPathObj *lhs, const AccessPathObj *rhs) { - // fast path for same pointer - if (lhs == rhs) { - return true; - } - if (lhs->depth != rhs->depth) { - return false; - } - // do deep equality checks - while (lhs->parent.has_value()) { - TVM_FFI_ICHECK(rhs->parent.has_value()); - TVM_FFI_ICHECK(lhs->step.has_value()); - TVM_FFI_ICHECK(rhs->step.has_value()); - if (!(*lhs->step)->StepEqual(*(rhs->step))) { - return false; - } - lhs = static_cast(lhs->parent.get()); - rhs = static_cast(rhs->parent.get()); - // fast path for same pointer - if (lhs == rhs) { - return true; - } - TVM_FFI_ICHECK(lhs != nullptr); - TVM_FFI_ICHECK(rhs != nullptr); - } - return true; - } -}; - -/*! - * \brief ObjectRef class of AccessPath. - * - * \sa AccessPathObj - */ -class AccessPath : public ObjectRef { -public: - /*! - * \brief Create an access path from an iterator range of steps. - * \param begin The beginning of the iterator range. - * \param end The end of the iterator range. - * \return The access path. - */ - template // NOLINTNEXTLINE(performance-unnecessary-value-param) - static AccessPath FromSteps(Iter begin, Iter end) { - AccessPath path = AccessPath::Root(); - for (Iter it = begin; it != end; ++it) { - path = path->Extend(*it); - } - return path; - } - /*! - * \brief Create an access path from an array of steps. - * \param steps The array of steps. - * \return The access path. 
- */ - static AccessPath FromSteps(const Array &steps) { - AccessPath path = AccessPath::Root(); - for (AccessStep step : steps) { - path = path->Extend(step); - } - return path; - } - - /*! - * \brief Create a root access path. - * \return The root access path. - */ - static AccessPath Root() { - return AccessPath(make_object(std::nullopt, std::nullopt, 0)); - } - - /// \cond Doxygen_Suppress - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(AccessPath, ObjectRef, AccessPathObj); - /// \endcond - -private: - friend class AccessPathObj; - explicit AccessPath(ObjectPtr ptr) : ObjectRef(std::move(ptr)) {} -}; - -/*! - * \brief The pair of access paths. - */ -using AccessPathPair = Tuple; - -inline Optional AccessPathObj::GetParent() const { - if (auto opt_parent = this->parent.as()) { - return opt_parent; - } - return std::nullopt; -} - -inline AccessPath AccessPathObj::Extend(AccessStep step) const { - return AccessPath( - make_object(GetRef(this), std::move(step), this->depth + 1)); -} - -inline AccessPath AccessPathObj::Attr(String field_name) const { - return this->Extend(AccessStep::Attr(std::move(field_name))); -} - -inline AccessPath AccessPathObj::AttrMissing(String field_name) const { - return this->Extend(AccessStep::AttrMissing(std::move(field_name))); -} - -inline AccessPath AccessPathObj::ArrayItem(int64_t index) const { - return this->Extend(AccessStep::ArrayItem(index)); -} - -inline AccessPath AccessPathObj::ArrayItemMissing(int64_t index) const { - return this->Extend(AccessStep::ArrayItemMissing(index)); -} - -inline AccessPath AccessPathObj::MapItem(Any key) const { - return this->Extend(AccessStep::MapItem(std::move(key))); -} - -inline AccessPath AccessPathObj::MapItemMissing(Any key) const { - return this->Extend(AccessStep::MapItemMissing(std::move(key))); -} - -inline Array AccessPathObj::ToSteps() const { - std::vector reverse_steps; - reverse_steps.reserve(this->depth); - const AccessPathObj *current = this; - while 
(current->parent.has_value()) { - TVM_FFI_ICHECK(current->step.has_value()); - reverse_steps.push_back(*(current->step)); - current = static_cast(current->parent.get()); - TVM_FFI_ICHECK(current != nullptr); - } - return Array(reverse_steps.rbegin(), reverse_steps.rend()); -} - -inline bool AccessPathObj::PathEqual(const AccessPath &other) const { - return PathEqual(this, other.get()); -} - -inline bool AccessPathObj::IsPrefixOf(const AccessPath &other) const { - if (this->depth > other->depth) { - return false; - } - const AccessPathObj *rhs_path = other.get(); - while (rhs_path->depth > this->depth) { - TVM_FFI_ICHECK(rhs_path->parent.has_value()); - rhs_path = static_cast(rhs_path->parent.get()); - } - return PathEqual(this, rhs_path); -} - -} // namespace reflection -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_REFLECTION_ACCESS_PATH_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h deleted file mode 100644 index c77b01679..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/accessor.h +++ /dev/null @@ -1,264 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/accessor.h - * \brief Reflection-based accessor for object fields and methods. - */ -#ifndef TVM_FFI_REFLECTION_ACCESSOR_H_ -#define TVM_FFI_REFLECTION_ACCESSOR_H_ - -#include -#include -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { -namespace reflection { - -/*! - * \brief helper function to get reflection field info by type key and field name - */ -inline const TVMFFIFieldInfo *GetFieldInfo(std::string_view type_key, const char *field_name) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - const TypeInfo *info = TVMFFIGetTypeInfo(type_index); - for (int32_t i = 0; i < info->num_fields; ++i) { - if (std::strncmp(info->fields[i].name.data, field_name, info->fields[i].name.size) == 0) { - return &(info->fields[i]); - } - } - TVM_FFI_THROW(RuntimeError) << "Cannot find field `" << field_name << "` in " << type_key; - TVM_FFI_UNREACHABLE(); -} - -/*! - * \brief helper wrapper class to obtain a getter. - */ -class FieldGetter { -public: - /*! - * \brief Constructor - * \param field_info The field info. - */ - explicit FieldGetter(const TVMFFIFieldInfo *field_info) : field_info_(field_info) {} - - /*! - * \brief Constructor - * \param type_key The type key. - * \param field_name The name of the field. - */ - explicit FieldGetter(std::string_view type_key, const char *field_name) - : FieldGetter(GetFieldInfo(type_key, field_name)) {} - - /*! - * \brief Get the value of the field - * \param obj_ptr The object pointer. - * \return The value of the field. 
- */ - Any operator()(const Object *obj_ptr) const { - Any result; - const void *addr = reinterpret_cast(obj_ptr) + field_info_->offset; - TVM_FFI_CHECK_SAFE_CALL( - field_info_->getter(const_cast(addr), reinterpret_cast(&result))); - return result; - } - - Any operator()(const ObjectPtr &obj_ptr) const { return operator()(obj_ptr.get()); } - - Any operator()(const ObjectRef &obj) const { return operator()(obj.get()); } - -private: - const TVMFFIFieldInfo *field_info_; -}; - -/*! - * \brief helper wrapper class to obtain a setter. - */ -class FieldSetter { -public: - /*! - * \brief Constructor - * \param field_info The field info. - */ - explicit FieldSetter(const TVMFFIFieldInfo *field_info) : field_info_(field_info) {} - - /*! - * \brief Constructor - * \param type_key The type key. - * \param field_name The name of the field. - */ - explicit FieldSetter(std::string_view type_key, const char *field_name) - : FieldSetter(GetFieldInfo(type_key, field_name)) {} - - /*! - * \brief Set the value of the field - * \param obj_ptr The object pointer. - * \param value The value to be set. - */ - void operator()(const Object *obj_ptr, AnyView value) const { - const void *addr = reinterpret_cast(obj_ptr) + field_info_->offset; - TVM_FFI_CHECK_SAFE_CALL( - field_info_->setter(const_cast(addr), reinterpret_cast(&value))); - } - - void operator()(const ObjectPtr &obj_ptr, AnyView value) const { - operator()(obj_ptr.get(), value); - } - - void operator()(const ObjectRef &obj, AnyView value) const { operator()(obj.get(), value); } - -private: - const TVMFFIFieldInfo *field_info_; -}; - -/*! - * \brief Helper class to get type attribute column. - */ -class TypeAttrColumn { -public: - /*! - * \brief Constructor - * \param attr_name The name of the type attribute. 
- */ - explicit TypeAttrColumn(std::string_view attr_name) { - TVMFFIByteArray attr_name_array = {attr_name.data(), attr_name.size()}; - column_ = TVMFFIGetTypeAttrColumn(&attr_name_array); - if (column_ == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Cannot find type attribute " << attr_name; - } - } - /*! - * \brief Get the type attribute column by type index. - * \param type_index The type index. - * \return The type attribute column. - */ - AnyView operator[](int32_t type_index) const { - size_t tindex = static_cast(type_index); - if (tindex >= column_->size) { - return AnyView(); - } - const AnyView *any_view_data = reinterpret_cast(column_->data); - return any_view_data[tindex]; - } - -private: - const TVMFFITypeAttrColumn *column_; -}; - -/*! - * \brief helper function to get reflection method info by type key and method name - * - * \param type_key The type key. - * \param method_name The name of the method. - * \return The method info. - */ -inline const TVMFFIMethodInfo *GetMethodInfo(std::string_view type_key, const char *method_name) { - int32_t type_index; - TVMFFIByteArray type_key_array = {type_key.data(), type_key.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeKeyToIndex(&type_key_array, &type_index)); - const TypeInfo *info = TVMFFIGetTypeInfo(type_index); - for (int32_t i = 0; i < info->num_methods; ++i) { - if (std::strncmp(info->methods[i].name.data, method_name, info->methods[i].name.size) == 0) { - return &(info->methods[i]); - } - } - TVM_FFI_THROW(RuntimeError) << "Cannot find method " << method_name << " in " << type_key; - TVM_FFI_UNREACHABLE(); -} - -/*! - * \brief helper function to get reflection method function by method info - * - * \param type_key The type key. - * \param method_name The name of the method. - * \return The method function. 
- */ -inline Function GetMethod(std::string_view type_key, const char *method_name) { - const TVMFFIMethodInfo *info = GetMethodInfo(type_key, method_name); - return AnyView::CopyFromTVMFFIAny(info->method).cast(); -} - -/*! - * \brief Visit each field info of the type info and run callback. - * - * \tparam Callback The callback function type. - * - * \param type_info The type info. - * \param callback The callback function. - * - * \note This function calls both the child and parent type info. - */ -template -inline void ForEachFieldInfo(const TypeInfo *type_info, Callback callback) { - using ResultType = decltype(callback(type_info->fields)); - static_assert(std::is_same_v, "Callback must return void"); - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - const TVMFFITypeInfo *parent_info = type_info->type_ancestors[i]; - for (int j = 0; j < parent_info->num_fields; ++j) { - callback(parent_info->fields + j); - } - } - for (int i = 0; i < type_info->num_fields; ++i) { - callback(type_info->fields + i); - } -} - -/*! - * \brief Visit each field info of the type info and run callback which returns bool for early stop. - * - * \tparam Callback The callback function type, which returns bool for early stop. - * - * \param type_info The type info. - * \param callback_with_early_stop The callback function. - * \return true if any of early stop is triggered. - * - * \note This function calls both the child and parent type info and can be used for searching. 
- */ -template -inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo *type_info, - Callback callback_with_early_stop) { - // iterate through acenstors in parent to child order - // skip the first one since it is always the root object - for (int i = 1; i < type_info->type_depth; ++i) { - const TVMFFITypeInfo *parent_info = type_info->type_ancestors[i]; - for (int j = 0; j < parent_info->num_fields; ++j) { - if (callback_with_early_stop(parent_info->fields + j)) { - return true; - } - } - } - for (int i = 0; i < type_info->num_fields; ++i) { - if (callback_with_early_stop(type_info->fields + i)) { - return true; - } - } - return false; -} - -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_ACCESSOR_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h deleted file mode 100644 index dcbf3e056..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/creator.h +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! 
- * \file tvm/ffi/reflection/creator.h - * \brief Reflection-based creator to create objects from type key and fields. - */ -#ifndef TVM_FFI_REFLECTION_CREATOR_H_ -#define TVM_FFI_REFLECTION_CREATOR_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -namespace reflection { -/*! - * \brief helper wrapper class of TVMFFITypeInfo to create object based on reflection. - */ -class ObjectCreator { -public: - /*! - * \brief Constructor - * \param type_key The type key. - */ - explicit ObjectCreator(std::string_view type_key) - : ObjectCreator(TVMFFIGetTypeInfo(TypeKeyToIndex(type_key))) {} - - /*! - * \brief Constructor - * \param type_info The type info. - */ - explicit ObjectCreator(const TVMFFITypeInfo *type_info) : type_info_(type_info) { - int32_t type_index = type_info->type_index; - if (type_info->metadata == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not have reflection registered"; - } - if (type_info->metadata->creator == nullptr) { - TVM_FFI_THROW(RuntimeError) << "Type `" << TypeIndexToTypeKey(type_index) - << "` does not support default constructor, " - << "as a result cannot be created via reflection"; - } - } - - /** - * \brief Create an object from a map of fields. - * \param fields The fields of the object. - * \return The created object. 
- */ - Any operator()(const Map &fields) const { - TVMFFIObjectHandle handle; - TVM_FFI_CHECK_SAFE_CALL(type_info_->metadata->creator(&handle)); - ObjectPtr ptr = details::ObjectUnsafe::ObjectPtrFromOwned(static_cast(handle)); - size_t match_field_count = 0; - ForEachFieldInfo(type_info_, [&](const TVMFFIFieldInfo *field_info) { - String field_name(field_info->name); - void *field_addr = reinterpret_cast(ptr.get()) + field_info->offset; - if (fields.count(field_name) != 0) { - Any field_value = fields[field_name]; - field_info->setter(field_addr, reinterpret_cast(&field_value)); - ++match_field_count; - } else if (field_info->flags & kTVMFFIFieldFlagBitMaskHasDefault) { - field_info->setter(field_addr, &(field_info->default_value)); - } else { - TVM_FFI_THROW(TypeError) << "Required field `" - << String(field_info->name.data, field_info->name.size) - << "` not set in type `" - << String(type_info_->type_key.data, type_info_->type_key.size) - << "`"; - } - }); - if (match_field_count == fields.size()) { - return ObjectRef(ptr); - } - // report error that checks if contains extra fields that are not in the type - auto check_field_name = [&](const String &field_name) { - bool found = false; - ForEachFieldInfoWithEarlyStop(type_info_, [&](const TVMFFIFieldInfo *field_info) { - if (field_name.compare(field_info->name) == 0) { - found = true; - return true; - } - return false; - }); - return found; - }; - for (const auto &[field_name, _] : fields) { - if (!check_field_name(field_name)) { - TVM_FFI_THROW(TypeError) << "Type `" - << String(type_info_->type_key.data, type_info_->type_key.size) - << "` does not have field `" << field_name << "`"; - } - } - TVM_FFI_UNREACHABLE(); - } - -private: - const TVMFFITypeInfo *type_info_; -}; -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_CREATOR_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h 
b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h deleted file mode 100644 index e1978b1e6..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/reflection/registry.h +++ /dev/null @@ -1,739 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/reflection/registry.h - * \brief Registry of reflection metadata. - */ -#ifndef TVM_FFI_REFLECTION_REGISTRY_H_ -#define TVM_FFI_REFLECTION_REGISTRY_H_ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace tvm { -namespace ffi { -/*! \brief Reflection namespace */ -namespace reflection { -/*! - * \brief Types of temporary metadata hold in FieldInfoBuilder and MethodInfoBuilder, - * before they are filled into final C metadata - */ -using _MetadataType = std::vector>; // NOLINT(bugprone-reserved-identifier) -/*! - * \brief Builder for TVMFFIFieldInfo - * \sa TVMFFIFieldInfo - */ -struct FieldInfoBuilder : public TVMFFIFieldInfo { - /*! \brief Temporary metadata info to be filled into TVMFFIFieldInfo::metadata */ - _MetadataType metadata_; -}; -/*! 
- * \brief Builder for TVMFFIMethodInfo - * \sa TVMFFIMethodInfo - */ -struct MethodInfoBuilder : public TVMFFIMethodInfo { - /*! \brief Temporary metadata info to be filled into TVMFFIMethodInfo::metadata */ - _MetadataType metadata_; -}; - -/*! - * \brief Trait that can be used to set information attached to a field or a method. - * \sa DefaultValue, AttachFieldFlag - */ -struct InfoTrait {}; - -/*! \brief User-supplied metadata attached to a field or a method */ -class Metadata : public InfoTrait { -public: - /*! - * \brief Constructor - * \param dict The initial dictionary - */ - Metadata(std::initializer_list> dict) : dict_(dict) {} - /*! - * \brief Move metadata into `FieldInfoBuilder` - * \param info The field info builder. - */ - inline void Apply(FieldInfoBuilder *info) const { this->Apply(&info->metadata_); } - /*! - * \brief Move metadata into `MethodInfoBuilder` - * \param info The method info builder. - */ - inline void Apply(MethodInfoBuilder *info) const { this->Apply(&info->metadata_); } - -private: - friend class GlobalDef; - template - friend class ObjectDef; - /*! - * \brief Move metadata into a vector of key-value pairs. - * \param out The output vector. - */ - inline void Apply(_MetadataType *out) const { - std::copy(std::make_move_iterator(dict_.begin()), std::make_move_iterator(dict_.end()), - std::back_inserter(*out)); - } - /*! \brief Convert the metadata to JSON string */ - static std::string ToJSON(const _MetadataType &metadata) { - using ::tvm::ffi::details::StringObj; - std::ostringstream os; - os << "{"; - bool first = true; - for (const auto &[key, value] : metadata) { - if (!first) { - os << ","; - } - os << "\"" << key << "\":"; - if (std::optional v = value.as()) { - os << *v; - } else if (std::optional v = value.as()) { - os << (*v ? 
"true" : "false"); - } else if (std::optional v = value.as()) { - String escaped = EscapeString(*v); - os << escaped.c_str(); - } else { - TVM_FFI_LOG_AND_THROW(TypeError) << "Metadata can be only int, bool or string, but on key `" - << key << "`, the type is " << value.GetTypeKey(); - } - first = false; - } - os << "}"; - return os.str(); - } - - std::vector> dict_; -}; -/*! - * \brief Trait that can be used to set field default value - */ -class DefaultValue : public InfoTrait { -public: - /*! - * \brief Constructor - * \param value The value to be set - */ - explicit DefaultValue(Any value) : value_(std::move(value)) {} - - /*! - * \brief Apply the default value to the field info - * \param info The field info. - */ - TVM_FFI_INLINE void Apply(TVMFFIFieldInfo *info) const { - info->default_value = AnyView(value_).CopyToTVMFFIAny(); - info->flags |= kTVMFFIFieldFlagBitMaskHasDefault; - } - -private: - Any value_; -}; - -/*! - * \brief Trait that can be used to attach field flag - */ -class AttachFieldFlag : public InfoTrait { -public: - /*! - * \brief Attach a field flag to the field - * \param flag The flag to be set - */ - explicit AttachFieldFlag(int32_t flag) : flag_(flag) {} - - /*! - * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashDef - */ - TVM_FFI_INLINE static AttachFieldFlag SEqHashDef() { - return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashDef); - } - /*! - * \brief Attach kTVMFFIFieldFlagBitMaskSEqHashIgnore - */ - TVM_FFI_INLINE static AttachFieldFlag SEqHashIgnore() { - return AttachFieldFlag(kTVMFFIFieldFlagBitMaskSEqHashIgnore); - } - - /*! - * \brief Apply the field flag to the field info - * \param info The field info. - */ - TVM_FFI_INLINE void Apply(TVMFFIFieldInfo *info) const { info->flags |= flag_; } - -private: - int32_t flag_; -}; - -/*! - * \brief Get the byte offset of a class member field. - * - * \tparam The original class. - * \tparam T the field type. 
- * - * \param field_ptr A class member pointer - * \returns The byteoffset - */ -template -TVM_FFI_INLINE int64_t GetFieldByteOffsetToObject(T Class::*field_ptr) { - int64_t field_offset_to_class = reinterpret_cast(&(static_cast(nullptr)->*field_ptr)); - return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass(); -} - -/// \cond Doxygen_Suppress -class ReflectionDefBase { -protected: - template - static int FieldGetter(void *field, TVMFFIAny *result) { - TVM_FFI_SAFE_CALL_BEGIN(); - *result = details::AnyUnsafe::MoveAnyToTVMFFIAny(Any(*reinterpret_cast(field))); - TVM_FFI_SAFE_CALL_END(); - } - - template - static int FieldSetter(void *field, const TVMFFIAny *value) { - TVM_FFI_SAFE_CALL_BEGIN(); - if constexpr (std::is_same_v) { - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value); - } else { - *reinterpret_cast(field) = AnyView::CopyFromTVMFFIAny(*value).cast(); - } - TVM_FFI_SAFE_CALL_END(); - } - - template - static int ObjectCreatorDefault(TVMFFIObjectHandle *result) { - TVM_FFI_SAFE_CALL_BEGIN(); - ObjectPtr obj = make_object(); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); - TVM_FFI_SAFE_CALL_END(); - } - - template - static int ObjectCreatorUnsafeInit(TVMFFIObjectHandle *result) { - TVM_FFI_SAFE_CALL_BEGIN(); - ObjectPtr obj = make_object(UnsafeInit{}); - *result = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(obj)); - TVM_FFI_SAFE_CALL_END(); - } - - template - TVM_FFI_INLINE static void ApplyFieldInfoTrait(FieldInfoBuilder *info, const T &value) { - if constexpr (std::is_base_of_v>) { - value.Apply(info); - } - if constexpr (std::is_same_v, char *>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - TVM_FFI_INLINE static void ApplyMethodInfoTrait(MethodInfoBuilder *info, const T &value) { - if constexpr (std::is_base_of_v>) { - value.Apply(info); - } - if constexpr (std::is_same_v, char *>) { - info->doc = 
TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - TVM_FFI_INLINE static void ApplyExtraInfoTrait(TVMFFITypeMetadata *info, const T &value) { - if constexpr (std::is_same_v, char *>) { - info->doc = TVMFFIByteArray{value, std::char_traits::length(value)}; - } - } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...)) { - static_assert(std::is_base_of_v || std::is_base_of_v, - "Class must be derived from ObjectRef or Object"); - if constexpr (std::is_base_of_v) { - auto fwrap = [func](Class target, Args... params) -> R { - // call method pointer - return (target.*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, std::move(name)); - } - - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class *target, Args... params) -> R { - // call method pointer - return (const_cast(target)->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, std::move(name)); - } - } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...) const) { - static_assert(std::is_base_of_v || std::is_base_of_v, - "Class must be derived from ObjectRef or Object"); - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class &target, Args... params) -> R { - // call method pointer - return (target.*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, std::move(name)); - } - - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class *target, Args... params) -> R { - // call method pointer - return (target->*func)(std::forward(params)...); - }; - return ffi::Function::FromTyped(fwrap, std::move(name)); - } - } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, Func &&func) { - return ffi::Function::FromTyped(std::forward(func), std::move(name)); - } -}; -/// \endcond - -/*! - * \brief GlobalDef helper to register a global function. 
- * - * \code - * namespace refl = tvm::ffi::reflection; - * refl::GlobalDef().def("my_ffi_extension.my_function", MyFunction); - * \endcode - */ -class GlobalDef : public ReflectionDefBase { -public: - /*! - * \brief Define a global function. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the function. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring or subclass of InfoTrait. - * - * \return The reflection definition. - */ - template - GlobalDef &def(const char *name, Func &&func, Extra &&...extra) { - using FuncInfo = details::FunctionInfo>; - RegisterFunc(name, ffi::Function::FromTyped(std::forward(func), std::string(name)), - FuncInfo::TypeSchema(), std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a global function in ffi::PackedArgs format. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the function. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring or subclass of InfoTrait. - * - * \return The reflection definition. - */ - template - GlobalDef &def_packed(const char *name, Func func, Extra &&...extra) { - RegisterFunc(name, ffi::Function::FromPacked(func), details::TypeSchemaImpl::v(), - std::forward(extra)...); - return *this; - } - - /*! - * \brief Expose a class method as a global function. - * - * An argument will be added to the first position if the function is not static. - * - * \tparam Class The class type. - * \tparam Func The function type. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. 
- */ - template - GlobalDef &def_method(const char *name, Func &&func, Extra &&...extra) { - using FuncInfo = details::FunctionInfo>; - RegisterFunc(name, GetMethod(std::string(name), std::forward(func)), - FuncInfo::TypeSchema(), std::forward(extra)...); - return *this; - } - -private: - template // NOLINTNEXTLINE(performance-unnecessary-value-param) - void RegisterFunc(const char *name, ffi::Function func, String type_schema, Extra &&...extra) { - MethodInfoBuilder info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.doc = TVMFFIByteArray{nullptr, 0}; - info.flags = 0; - info.method = AnyView(func).CopyToTVMFFIAny(); - info.metadata_.emplace_back("type_schema", type_schema); - ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); - std::string metadata_str = Metadata::ToJSON(info.metadata_); - info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFIFunctionSetGlobalFromMethodInfo(&info, 0)); - } -}; - -/*! - * \brief Helper class to register a constructor method for object types. - * - * This helper is used with `ObjectDef::def()` to register an `__init__` method - * that constructs an object instance with the specified argument types. - * - * \tparam Args The argument types for the constructor. - * - * Example usage: - * \code - * class ExampleObject : public Object { - * public: - * int64_t v_i64; - * int32_t v_i32; - * - * ExampleObject(int64_t v_i64, int32_t v_i32) : v_i64(v_i64), v_i32(v_i32) {} - * TVM_FFI_DECLARE_OBJECT_INFO("example.ExampleObject", ExampleObject, Object); - * }; - * - * // Register the constructor - * refl::ObjectDef() - * .def(refl::init()); - * \endcode - * - * \note The object type is automatically deduced from the `ObjectDef` context. - */ -template -struct init { - // Allow ObjectDef to access the execute function - template - friend class ObjectDef; - - /*! - * \brief Constructor - */ - constexpr init() noexcept = default; - -private: - /*! 
- * \brief Execute the constructor - * \tparam Class The class type. - * \param args The arguments to be passed to the constructor. - * \return The constructed object wrapped in an `ObjectRef`. - */ - template - static inline ObjectRef execute(Args &&...args) { - return ObjectRef(ffi::make_object(std::forward(args)...)); - } -}; - -/*! - * \brief Helper to register Object's reflection metadata. - * \tparam Class The class type. - * - * \code - * namespace refl = tvm::ffi::reflection; - * refl::ObjectDef().def_ro("my_field", &MyClass::my_field); - * \endcode - */ -template -class ObjectDef : public ReflectionDefBase { -public: - /*! - * \brief Constructor - * \tparam ExtraArgs The extra arguments. - * \param extra_args The extra arguments. - */ - template - explicit ObjectDef(ExtraArgs &&...extra_args) - : type_index_(Class::_GetOrAllocRuntimeTypeIndex()), type_key_(Class::_type_key) { - RegisterExtraInfo(std::forward(extra_args)...); - } - - /*! - * \brief Define a readonly field. - * - * \tparam Class The class type. - * \tparam T The field type. - * \tparam Extra The extra arguments. - * - * \param name The name of the field. - * \param field_ptr The pointer to the field. - * \param extra The extra arguments that can be docstring or default value. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef &def_ro(const char *name, T BaseClass::*field_ptr, Extra &&...extra) { - RegisterField(name, field_ptr, false, std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a read-write field. - * - * \tparam Class The class type. - * \tparam T The field type. - * \tparam Extra The extra arguments. - * - * \param name The name of the field. - * \param field_ptr The pointer to the field. - * \param extra The extra arguments that can be docstring or default value. - * - * \return The reflection definition. 
- */ - template - TVM_FFI_INLINE ObjectDef &def_rw(const char *name, T BaseClass::*field_ptr, Extra &&...extra) { - static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields"); - RegisterField(name, field_ptr, true, std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a method. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef &def(const char *name, Func &&func, Extra &&...extra) { - RegisterMethod(name, false, std::forward(func), std::forward(extra)...); - return *this; - } - - /*! - * \brief Define a static method. - * - * \tparam Func The function type. - * \tparam Extra The extra arguments. - * - * \param name The name of the method. - * \param func The function to be registered. - * \param extra The extra arguments that can be docstring. - * - * \return The reflection definition. - */ - template - TVM_FFI_INLINE ObjectDef &def_static(const char *name, Func &&func, Extra &&...extra) { - RegisterMethod(name, true, std::forward(func), std::forward(extra)...); - return *this; - } - - /*! - * \brief Register a constructor for this object type. - * - * This method registers a static `__init__` method that constructs an instance - * of the object with the specified argument types. The constructor can be invoked - * from Python or other FFI bindings. - * - * \tparam Args The argument types for the constructor. - * \tparam Extra Additional arguments (e.g., docstring). - * - * \param init_func An instance of `init` specifying constructor signature. - * \param extra Optional additional metadata such as docstring. - * - * \return Reference to this `ObjectDef` for method chaining. 
- * - * Example: - * \code - * refl::ObjectDef() - * .def(refl::init(), "Constructor docstring"); - * \endcode - */ - template - TVM_FFI_INLINE ObjectDef &def([[maybe_unused]] init init_func, Extra &&...extra) { - RegisterMethod(kInitMethodName, true, &init::template execute, - std::forward(extra)...); - return *this; - } - -private: - template - void RegisterExtraInfo(ExtraArgs &&...extra_args) { - TVMFFITypeMetadata info; - info.total_size = sizeof(Class); - info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind; - info.creator = nullptr; - info.doc = TVMFFIByteArray{nullptr, 0}; - if constexpr (std::is_default_constructible_v) { - info.creator = ObjectCreatorDefault; - } else if constexpr (std::is_constructible_v) { - info.creator = ObjectCreatorUnsafeInit; - } - // apply extra info traits - ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info)); - } - - template - void RegisterField(const char *name, T BaseClass::*field_ptr, bool writable, - ExtraArgs &&...extra_args) { - static_assert(std::is_base_of_v, "BaseClass must be a base class of Class"); - FieldInfoBuilder info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.field_static_type_index = TypeToFieldStaticTypeIndex::value; - // store byte offset and setter, getter - // so the same setter can be reused for all the same type - info.offset = GetFieldByteOffsetToObject(field_ptr); - info.size = sizeof(T); - info.alignment = alignof(T); - info.flags = 0; - if (writable) { - info.flags |= kTVMFFIFieldFlagBitMaskWritable; - } - info.getter = FieldGetter; - info.setter = FieldSetter; - // initialize default value to nullptr - info.default_value = AnyView(nullptr).CopyToTVMFFIAny(); - info.doc = TVMFFIByteArray{nullptr, 0}; - info.metadata_.emplace_back("type_schema", details::TypeSchema::v()); - // apply field info traits - ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); - // call 
register - std::string metadata_str = Metadata::ToJSON(info.metadata_); - info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info)); - } - - // register a method - template - void RegisterMethod(const char *name, bool is_static, Func &&func, Extra &&...extra) { - using FuncInfo = details::FunctionInfo>; - MethodInfoBuilder info; - info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; - info.doc = TVMFFIByteArray{nullptr, 0}; - info.flags = 0; - if (is_static) { - info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; - } - // obtain the method function - Function method = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); - info.method = AnyView(method).CopyToTVMFFIAny(); - info.metadata_.emplace_back("type_schema", FuncInfo::TypeSchema()); - // apply method info traits - ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); - std::string metadata_str = Metadata::ToJSON(info.metadata_); - info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); - } - - int32_t type_index_; - const char *type_key_; - static constexpr const char *kInitMethodName = "__ffi_init__"; -}; - -/*! - * \brief Helper to register type attribute. - * \tparam Class The class type. - * \tparam ExtraArgs The extra arguments. - * - * \code - * namespace refl = tvm::ffi::reflection; - * refl::TypeAttrDef().def("func_attr", MyFunc); - * \endcode - * - */ -template >> -class TypeAttrDef : public ReflectionDefBase { -public: - /*! - * \brief Constructor - * \tparam ExtraArgs The extra arguments. - * \param extra_args The extra arguments. - */ - template - explicit TypeAttrDef(ExtraArgs &&...extra_args) - : type_index_(Class::RuntimeTypeIndex()), type_key_(Class::_type_key) {} - - /*! - * \brief Define a function-valued type attribute. - * - * \tparam Func The function type. 
- * - * \param name The name of the function. - * \param func The function to be registered. - * - * \return The TypeAttrDef object. - */ - template - TypeAttrDef &def(const char *name, Func &&func) { - TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; - ffi::Function ffi_func = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); - TVMFFIAny value_any = AnyView(ffi_func).CopyToTVMFFIAny(); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); - return *this; - } - - /*! - * \brief Define a constant-valued type attribute. - * - * \tparam T The type of the value. - * - * \param name The name of the attribute. - * \param value The value of the attribute. - * - * \return The TypeAttrDef object. - */ - template - TypeAttrDef &attr(const char *name, T value) { - TVMFFIByteArray name_array = {name, std::char_traits::length(name)}; - TVMFFIAny value_any = AnyView(value).CopyToTVMFFIAny(); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &name_array, &value_any)); - return *this; - } - -private: - int32_t type_index_; - const char *type_key_; -}; - -/*! - * \brief Ensure the type attribute column is presented in the system. - * - * \param name The name of the type attribute. 
- */ -inline void EnsureTypeAttrColumn(std::string_view name) { - TVMFFIByteArray name_array = {name.data(), name.size()}; - AnyView any_view(nullptr); - TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(kTVMFFINone, &name_array, - reinterpret_cast(&any_view))); -} - -} // namespace reflection -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_REFLECTION_REGISTRY_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h deleted file mode 100644 index e12a4d44e..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/rvalue_ref.h +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ffi/rvalue_ref.h - * \brief Helper class to define rvalue reference type. - */ -#ifndef TVM_FFI_RVALUE_REF_H_ -#define TVM_FFI_RVALUE_REF_H_ - -#include -#include - -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief Helper class to define rvalue reference type. - * - * By default, FFI pass all values by lvalue reference. - * - * However, we do allow users to intentionally mark a function parameter - * as RValueRef. 
In such cases, the caller can choose to pass parameter - * wrapped by RValueRef to the function. In which case the parameter - * can be directly moved by the callee. The caller can also choose to pass - * a normal lvalue to the function, in such case a copy will be triggered. - * - * To keep FFI checking overhead minimal, we do not handle case when rvalue - * is passed, but the callee did not declare the parameter as RValueRef. - * - * This design allows us to still leverage move semantics for parameters that - * need copy on write scenarios (and requires an unique copy). - * - * \code - * - * void Example() { - * auto append = Function::FromTyped([](RValueRef> ref, int val) -> Array { - * Array arr = *std::move(ref); - * assert(arr.unique()); - * arr.push_back(val); - * return arr; - * }); - * Array a = Array({1, 2}); - * // as we use rvalue ref to move a into append - * // we keep a single copy of the Array without creating new copies during copy-on-write - * a = append(RvalueRef(std::move(a)), 3); - * assert(a.size() == 3); - * } - * - * \endcode - */ -template >> -class RValueRef { -public: - /*! \brief the container type of the rvalue ref */ - using ContainerType = typename TObjRef::ContainerType; - /*! \brief only allow move constructor from rvalue of T */ - explicit RValueRef(TObjRef &&data) - : data_(details::ObjectUnsafe::ObjectPtrFromObjectRef(std::move(data))) {} - - /*! 
\brief return the data as rvalue */ - TObjRef operator*() && { return TObjRef(std::move(data_)); } - -private: - mutable ObjectPtr data_; - - template - friend struct TypeTraits; -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(const RValueRef &src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIObjectRValueRef; - result->zero_padding = 0; - // store the address of the ObjectPtr, which allows us to move the value - // and set the original ObjectPtr to nullptr - result->v_ptr = &(src.data_); - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - ObjectPtr *rvalue_ref = reinterpret_cast *>(src->v_ptr); - // object type does not match up, we need to try to convert the object - // in this case we do not move the original rvalue ref since conversion creates a copy - TVMFFIAny tmp_any; - tmp_any.type_index = rvalue_ref->get()->type_index(); - tmp_any.zero_padding = 0; - tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); - return "RValueRef<" + TypeTraits::GetMismatchTypeInfo(&tmp_any) + ">"; - } else { - return TypeTraits::GetMismatchTypeInfo(src); - } - } - - TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny *src) { - // first try rvalue conversion - if (src->type_index == TypeIndex::kTVMFFIObjectRValueRef) { - ObjectPtr *rvalue_ref = reinterpret_cast *>(src->v_ptr); - TVMFFIAny tmp_any; - tmp_any.type_index = rvalue_ref->get()->type_index(); - tmp_any.zero_padding = 0; - tmp_any.v_obj = reinterpret_cast(rvalue_ref->get()); - // fast path, storage type matches, direct move the rvalue ref - if (TypeTraits::CheckAnyStrict(&tmp_any)) { - return RValueRef( - details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(*rvalue_ref))); - } - if (std::optional opt = 
TypeTraits::TryCastFromAnyView(&tmp_any)) { - // object type does not match up, we need to try to convert the object - // in this case we do not move the original rvalue ref since conversion creates a copy - return RValueRef(*std::move(opt)); - } - return std::nullopt; - } - // try lvalue conversion - if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { - return RValueRef(*std::move(opt)); - } else { - return std::nullopt; - } - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "RValueRef<" + TypeTraits::TypeStr() + ">"; - } - - TVM_FFI_INLINE static std::string TypeSchema() { - std::ostringstream oss; - oss << R"({"type":")" << StaticTypeKey::kTVMFFIObjectRValueRef << R"(","args":[)"; - oss << TypeTraits::TypeSchema(); - oss << "]}"; - return oss.str(); - } -}; -} // namespace ffi -} // namespace tvm - -#endif // TVM_FFI_RVALUE_REF_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/string.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/string.h deleted file mode 100644 index ad7230b93..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/string.h +++ /dev/null @@ -1,1102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * \file tvm/ffi/string.h - * \brief Runtime Bytes and String types. - */ -#ifndef TVM_FFI_STRING_H_ -#define TVM_FFI_STRING_H_ - -#include "base_details.h" -#include "error.h" -#include "memory.h" -#include "object.h" -#include "type_traits.h" - -#include -#include -#include -#include -#include -#include - -// Note: We place string in tvm/ffi instead of tvm/ffi/container -// because string itself needs special handling and is an inherent -// core component for return string handling. -// The following dependency relation holds -// any -> string -> object - -/// \cond Doxygen_Suppress -#ifdef _MSC_VER -#define TVM_FFI_SNPRINTF _snprintf_s -#pragma warning(push) -#pragma warning(disable : 4244) -#pragma warning(disable : 4127) -#pragma warning(disable : 4702) -#else -#define TVM_FFI_SNPRINTF snprintf -#endif -/// \endcond - -namespace tvm { -namespace ffi { -namespace details { -/*! - * \brief Base class for bytes and string objects. - */ -class BytesObjBase : public Object, public TVMFFIByteArray {}; - -/*! - * \brief An object representing bytes. - * \note We use a separate object for bytes to follow Python convention - * and indicate passing of raw bytes. - * Bytes can be converted from/to string. - */ -class BytesObj : public BytesObjBase { -public: - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIBytes; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIBytes, BytesObj, Object); -}; - -/*! \brief An object representing string. This is a POD type. 
*/ -class StringObj : public BytesObjBase { -public: - static constexpr const uint32_t _type_index = TypeIndex::kTVMFFIStr; - static const constexpr bool _type_final = true; - TVM_FFI_DECLARE_OBJECT_INFO_STATIC(StaticTypeKey::kTVMFFIStr, StringObj, Object); -}; - -// String moved from std::string -// without having to trigger a copy -template -class BytesObjStdImpl : public Base { -public: - explicit BytesObjStdImpl(std::string other) : data_{std::move(other)} { - this->data = data_.data(); - this->size = data_.size(); - } - -private: - std::string data_; -}; - -/*! - * \brief Helper cell class that can be used to back small string - * \note Do not use directly, use String or Bytes instead - */ -class BytesBaseCell { -public: - BytesBaseCell() { - // initialize to none - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - - explicit BytesBaseCell(std::nullopt_t) { - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - - BytesBaseCell(const BytesBaseCell &other) : data_(other.data_) { // NOLINT(*) - if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(data_.v_obj); - } - } - - BytesBaseCell(BytesBaseCell &&other) : data_(other.data_) { // NOLINT(*) - other.data_.type_index = TypeIndex::kTVMFFINone; - } - - BytesBaseCell &operator=(const BytesBaseCell &other) { - BytesBaseCell(other).swap(*this); // NOLINT(*) - return *this; - } - - BytesBaseCell &operator=(BytesBaseCell &&other) noexcept { - BytesBaseCell(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - ~BytesBaseCell() { - if (data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::DecRefObjectHandle(data_.v_obj); - } - } - - /*! - * \brief Check if the cell is null - * \return true if the cell is null, false otherwise - */ - bool operator==(std::nullopt_t) const { return data_.type_index == TypeIndex::kTVMFFINone; } - - /*! 
- * \brief Check if the cell is not null - * \return true if the cell is not null, false otherwise - */ - bool operator!=(std::nullopt_t) const { return data_.type_index != TypeIndex::kTVMFFINone; } - - /*! - * \brief Swap this String with another string - * \param other The other string - */ - void swap(BytesBaseCell &other) { // NOLINT(*) - std::swap(data_, other.data_); - } - - const char *data() const noexcept { - if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - return data_.v_bytes; - } else { - // NOLINTNEXTLINE(clang-analyzer-security.ArrayBound) - return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->data; - } - } - - size_t size() const noexcept { - if (data_.type_index < TypeIndex::kTVMFFIStaticObjectBegin) { - return data_.small_str_len; - } else { - // NOLINTNEXTLINE(clang-analyzer-security.ArrayBound) - return TVMFFIBytesGetByteArrayPtr(data_.v_obj)->size; - } - } - - template - void InitFromStd(std::string &&other, int32_t large_type_index) { - // needs to be reset to none first for exception safety - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_); - ObjectPtr ptr = make_object>(std::move(other)); - data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr)); - data_.type_index = large_type_index; - } - - /*! 
- * \brief Create a new empty space for a string - * \param size The size of the string - * \param small_type_index The type index for the small string - * \param large_type_index The type index for the large string - * \note always reserve one byte for \0 compactibility - * \return A pointer to the empty space - */ - template - char *InitSpaceForSize(size_t size, int32_t small_type_index, int32_t large_type_index) { - size_t kMaxSmallBytesLen = sizeof(int64_t) - 1; - // first zero the content, this is important for exception safety - data_.type_index = small_type_index; - data_.zero_padding = 0; - if (size <= kMaxSmallBytesLen) { - // set up the size accordingly - data_.small_str_len = static_cast(size); - return data_.v_bytes; - } else { - // allocate from heap - ObjectPtr ptr = make_inplace_array_object(size + 1); - char *dest_data = reinterpret_cast(ptr.get()) + sizeof(LargeObj); - ptr->data = dest_data; - ptr->size = size; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(&data_); - data_.v_obj = details::ObjectUnsafe::MoveObjectPtrToTVMFFIObjectPtr(std::move(ptr)); - // now reset the type index to str - data_.type_index = large_type_index; - return dest_data; - } - } - - void InitTypeIndex(int32_t type_index) { data_.type_index = type_index; } - - void MoveToAny(TVMFFIAny *result) { - *result = data_; - data_.type_index = TypeIndex::kTVMFFINone; - data_.zero_padding = 0; - data_.v_int64 = 0; - } - - TVMFFIAny CopyToTVMFFIAny() const { return data_; } - - static BytesBaseCell CopyFromAnyView(const TVMFFIAny *src) { - BytesBaseCell result(*src); - if (result.data_.type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - details::ObjectUnsafe::IncRefObjectHandle(result.data_.v_obj); - } - return result; - } - - static BytesBaseCell MoveFromAny(TVMFFIAny *src) { - BytesBaseCell result(*src); - src->type_index = TypeIndex::kTVMFFINone; - src->zero_padding = 0; - src->v_int64 = 0; - return result; - } - -private: - explicit BytesBaseCell(TVMFFIAny data) : data_(data) {} - /*! 
\brief internal backing data */ - TVMFFIAny data_; -}; -} // namespace details - -/*! - * \brief Managed reference of byte array. - */ -class Bytes { -public: - /*! \brief default constructor */ - Bytes() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallBytes); } - /*! - * \brief constructor from size - * - * \param data The data pointer. - * \param size The size of the char array. - */ - Bytes(const char *data, size_t size) { this->InitData(data, size); } - /*! - * \brief constructor from TVMFFIByteArray - * - * \param bytes a char array. - */ - Bytes(TVMFFIByteArray bytes) { // NOLINT(*) - this->InitData(bytes.data, bytes.size); - } - /*! - * \brief constructor from std::string - * - * \param other a char array. - */ - Bytes(const std::string &other) { // NOLINT(*) - this->InitData(other.data(), other.size()); - } - /*! - * \brief constructor from std::string - * - * \param other a char array. - */ - Bytes(std::string &&other) { // NOLINT(*) - data_.InitFromStd(std::move(other), TypeIndex::kTVMFFIBytes); - } - /*! - * \brief Swap this String with another string - * \param other The other string - */ - void swap(Bytes &other) { // NOLINT(*) - std::swap(data_, other.data_); - } - - template - Bytes &operator=(T &&other) { - // copy-and-swap idiom - Bytes(std::forward(other)).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t size() const { return data_.size(); } - /*! - * \brief Return the data pointer - * - * \return const char* data pointer - */ - const char *data() const { return data_.data(); } - /*! - * \brief Convert String to an std::string object - * - * \return std::string - */ - operator std::string() const { // NOLINT(google-explicit-constructor) - return std::string{data(), size()}; - } - - /*! 
- * \brief Compare two char sequence - * - * \param lhs Pointers to the char array to compare - * \param rhs Pointers to the char array to compare - * \param lhs_count Length of the char array to compare - * \param rhs_count Length of the char array to compare - * \return int zero if both char sequences compare equal. negative if this - * appear before other, positive otherwise. - */ - static int memncmp(const char *lhs, const char *rhs, size_t lhs_count, size_t rhs_count) { - if (lhs == rhs && lhs_count == rhs_count) { - return 0; - } - - for (size_t i = 0; i < lhs_count && i < rhs_count; ++i) { - if (lhs[i] < rhs[i]) { - return -1; - } - if (lhs[i] > rhs[i]) { - return 1; - } - } - if (lhs_count < rhs_count) { - return -1; - } else if (lhs_count > rhs_count) { - return 1; - } else { - return 0; - } - } - /*! - * \brief Compare two char sequence for equality - * - * \param lhs Pointers to the char array to compare - * \param rhs Pointers to the char array to compare - * \param lhs_count Length of the char array to compare - * \param rhs_count Length of the char array to compare - * - * \return true if the two char sequences are equal, false otherwise. - */ - static bool memequal(const void *lhs, const void *rhs, size_t lhs_count, size_t rhs_count) { - return lhs_count == rhs_count && (lhs == rhs || std::memcmp(lhs, rhs, lhs_count) == 0); - } - -private: - template - friend struct TypeTraits; - template - friend class Optional; - // internal backing cell - details::BytesBaseCell data_; - // create a new String from TVMFFIAny, must keep private - explicit Bytes(details::BytesBaseCell data) : data_(std::move(data)) {} - char *InitSpaceForSize(size_t size) { - return data_.InitSpaceForSize(size, TypeIndex::kTVMFFISmallBytes, - TypeIndex::kTVMFFIBytes); - } - void InitData(const char *data, size_t size) { - char *dest_data = InitSpaceForSize(size); - std::memcpy(dest_data, data, size); - // mainly to be compat with string - dest_data[size] = '\0'; - } -}; - -/*! 
- * \brief String container class. - */ -class String { -public: - /*! - * \brief avoid misuse of nullptr - */ - String(std::nullptr_t) = delete; // NOLINT(*) - /*! - * \brief constructor - */ - String() { data_.InitTypeIndex(TypeIndex::kTVMFFISmallStr); } - // constructors from Any - /*! - * \brief Copy constructor - * \param other The other string - */ - String(const String &other) = default; // NOLINT(*) - /*! - * \brief Move constructor - * \param other The other string - */ - String(String &&other) = default; // NOLINT(*) - /*! - * \brief Copy assignment operator - * \param other The other string - */ - String &operator=(const String &other) = default; // NOLINT(*) - /*! - * \brief Move assignment operator - * \param other The other string - */ - String &operator=(String &&other) = default; // NOLINT(*) - - /*! - * \brief Swap this String with another string - * \param other The other string - */ - void swap(String &other) noexcept { // NOLINT(*) - std::swap(data_, other.data_); - } - - /*! - * \brief Copy assignment operator - * \param other The other string - */ - String &operator=(const std::string &other) { - String(other).swap(*this); // NOLINT(*) - return *this; - } - /*! - * \brief Move assignment operator - * \param other The other string - */ - String &operator=(std::string &&other) { - String(std::move(other)).swap(*this); // NOLINT(*) - return *this; - } - - /*! - * \brief Copy assignment operator - * \param other The other string - */ - String &operator=(const char *other) { - String(other).swap(*this); // NOLINT(*) - return *this; - } - - /*! - * \brief constructor from raw string - * - * \param data The data pointer. - * \param size The size of the char array. - */ - String(const char *data, size_t size) { this->InitData(data, size); } - - /*! - * \brief constructor from raw string - * - * \param other a char array. 
- * \note This constructor is marked as explicit to avoid implicit conversion - * of nullptr value here to string, which then was used in comparison - */ - String(const char *other) { // NOLINT(*) - this->InitData(other, std::char_traits::length(other)); - } - /*! - * \brief Construct a new string object - * \param other The std::string object to be copied - */ - String(const std::string &other) { // NOLINT(*) - this->InitData(other.data(), other.size()); - } - - /*! - * \brief Construct a new string object - * \param other The std::string object to be moved - */ - String(std::string &&other) { // NOLINT(*) - // exception safety, first set to none so if exception is thrown - // destructor works correctly - data_.InitFromStd(std::move(other), TypeIndex::kTVMFFIStr); - } - - /*! - * \brief constructor from TVMFFIByteArray - * - * \param other a TVMFFIByteArray. - */ - explicit String(TVMFFIByteArray other) { this->InitData(other.data, other.size); } - - /*! - * \brief Return the data pointer - * - * \return const char* data pointer - */ - const char *data() const noexcept { return data_.data(); } - - /*! - * \brief Returns a pointer to the char array in the string. - * - * \return const char* - */ - const char *c_str() const noexcept { return data(); } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t size() const noexcept { return data_.size(); } - - /*! - * \brief Compares this String object to other - * - * \param other The String to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const String &other) const { - return Bytes::memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this String object to other - * - * \param other The string to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. 
- */ - int compare(const std::string &other) const { - return Bytes::memncmp(data(), other.data(), size(), other.size()); - } - - /*! - * \brief Compares this to other - * - * \param other The character array to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const char *other) const { - const char *this_data = data(); - size_t this_size = size(); - for (size_t i = 0; i < this_size; ++i) { - // other is shorter than this - if (other[i] == '\0') { - return 1; - } - if (this_data[i] < other[i]) { - return -1; - } - if (this_data[i] > other[i]) { - return 1; - } - } - // other equals this - if (other[this_size] == '\0') { - return 0; - } - // other longer than this - return -1; - } - - /*! - * \brief Compares this to other - * - * \param other The TVMFFIByteArray to compare with. - * - * \return zero if both char sequences compare equal. negative if this appear - * before other, positive otherwise. - */ - int compare(const TVMFFIByteArray &other) const { - return Bytes::memncmp(data(), other.data, size(), other.size); - } - - /*! - * \brief Return the length of the string - * - * \return size_t string length - */ - size_t length() const { return size(); } - - /*! - * \brief Retun if the string is empty - * - * \return true if empty, false otherwise. - */ - bool empty() const { return size() == 0; } - - /*! - * \brief Read an element. - * \param pos The position at which to read the character. - * - * \return The char at position - */ - char at(size_t pos) const { - if (pos < size()) { - return data()[pos]; - } else { - throw std::out_of_range("tvm::String index out of bounds"); - } - } - - /*! 
- * \brief Convert String to an std::string object - * - * \return std::string - */ - operator std::string() const { // NOLINT(google-explicit-constructor) - return std::string{data(), size()}; - } - -private: - template - friend struct TypeTraits; - template - friend class Optional; - // internal backing cell - details::BytesBaseCell data_; - // create a new String from TVMFFIAny, must keep private - explicit String(details::BytesBaseCell data) : data_(std::move(data)) {} - /*! - * \brief Create a new empty space for a string - * \param size The size of the string - * \return A pointer to the empty space - */ - char *InitSpaceForSize(size_t size) { - return data_.InitSpaceForSize(size, TypeIndex::kTVMFFISmallStr, - TypeIndex::kTVMFFIStr); - } - void InitData(const char *data, size_t size) { - char *dest_data = InitSpaceForSize(size); - std::memcpy(dest_data, data, size); - dest_data[size] = '\0'; - } - /*! - * \brief Concatenate two char sequences - * - * \param lhs Pointers to the lhs char array - * \param lhs_size The size of the lhs char array - * \param rhs Pointers to the rhs char array - * \param rhs_size The size of the rhs char array - * - * \return The concatenated char sequence - */ - static String Concat(const char *lhs, size_t lhs_size, const char *rhs, size_t rhs_size) { - String ret; - // disable stringop-overflow and restrict warnings - // gcc may produce false positive when we enable dest_data returned from small string path - // Because compiler is not able to detect the condition that the path is only triggered via - // size < kMaxSmallStrLen and can report it as a overflow case. 
-#if (__GNUC__) && !(__clang__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstringop-overflow" -#pragma GCC diagnostic ignored "-Wrestrict" -#endif - char *dest_data = ret.InitSpaceForSize(lhs_size + rhs_size); - std::memcpy(dest_data, lhs, lhs_size); - std::memcpy(dest_data + lhs_size, rhs, rhs_size); - // NOLINTNEXTLINE(clang-analyzer-security.ArrayBound) - dest_data[lhs_size + rhs_size] = '\0'; -#if (__GNUC__) && !(__clang__) -#pragma GCC diagnostic pop -#endif - return ret; - } - // Overload + operator - friend String operator+(const String &lhs, const String &rhs); - friend String operator+(const String &lhs, const std::string &rhs); - friend String operator+(const std::string &lhs, const String &rhs); - friend String operator+(const String &lhs, const char *rhs); - friend String operator+(const char *lhs, const String &rhs); -}; - -/*! - * \brief Return an escaped version of the string - * \param value The input string - * \return The escaped string, quoted with double quotes - */ -inline String EscapeString(const String &value) { - std::ostringstream oss; - oss << '"'; - const char *data = value.data(); - const size_t size = value.size(); - for (size_t i = 0; i < size; ++i) { - switch (data[i]) { -/// \cond Doxygen_Suppress -#define TVM_FFI_ESCAPE_CHAR(pattern, val) \ - case pattern: \ - oss << (val); \ - break - TVM_FFI_ESCAPE_CHAR('\"', "\\\""); - TVM_FFI_ESCAPE_CHAR('\\', "\\\\"); - TVM_FFI_ESCAPE_CHAR('/', "\\/"); - TVM_FFI_ESCAPE_CHAR('\b', "\\b"); - TVM_FFI_ESCAPE_CHAR('\f', "\\f"); - TVM_FFI_ESCAPE_CHAR('\n', "\\n"); - TVM_FFI_ESCAPE_CHAR('\r', "\\r"); - TVM_FFI_ESCAPE_CHAR('\t', "\\t"); -#undef TVM_FFI_ESCAPE_CHAR - /// \endcond - default: { - uint8_t u8_val = static_cast(data[i]); - // this is a control character, print as \uXXXX - if (u8_val < 0x20 || u8_val == 0x7f) { - char buffer[8]; - int size = TVM_FFI_SNPRINTF(buffer, sizeof(buffer), "\\u%04x", - static_cast(data[i]) & 0xff); - oss.write(buffer, size); - } else { - oss << 
data[i]; - } - break; - } - } - } - oss << '"'; - return String(oss.str()); -} - -/*! \brief Convert TVMFFIByteArray to std::string_view */ -TVM_FFI_INLINE std::string_view ToStringView(TVMFFIByteArray str) { - return std::string_view(str.data, str.size); -} -/// \cond Doxygen_Suppress - -template <> -inline constexpr bool use_default_type_traits_v = false; - -// specialize to enable implicit conversion from TVMFFIByteArray* -template <> -struct TypeTraits : public TypeTraitsBase { - // bytes can be union type of small bytes and object, so keep it as any - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; - - TVM_FFI_INLINE static void CopyToAnyView(const Bytes &src, TVMFFIAny *result) { - *result = src.data_.CopyToTVMFFIAny(); - } - - TVM_FFI_INLINE static void MoveToAny(Bytes src, TVMFFIAny *result) { - src.data_.MoveToAny(result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index == TypeIndex::kTVMFFISmallBytes || src->type_index == TypeIndex::kTVMFFIBytes; - } - - TVM_FFI_INLINE static Bytes CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return Bytes(details::BytesBaseCell::CopyFromAnyView(src)); - } - - TVM_FFI_INLINE static Bytes MoveFromAnyAfterCheck(TVMFFIAny *src) { - return Bytes(details::BytesBaseCell::MoveFromAny(src)); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { - return Bytes(*static_cast(src->v_ptr)); - } - if (src->type_index == TypeIndex::kTVMFFISmallBytes || src->type_index == TypeIndex::kTVMFFIBytes) { - return Bytes(details::BytesBaseCell::CopyFromAnyView(src)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "bytes"; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIBytes) + R"("})"; - } -}; - -template <> -inline constexpr bool use_default_type_traits_v = false; - 
-// specialize to enable implicit conversion from const char* -template <> -struct TypeTraits : public TypeTraitsBase { - // string can be union type of small string and object, so keep it as any - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; - - TVM_FFI_INLINE static void CopyToAnyView(const String &src, TVMFFIAny *result) { - *result = src.data_.CopyToTVMFFIAny(); - } - - TVM_FFI_INLINE static void MoveToAny(String src, TVMFFIAny *result) { - src.data_.MoveToAny(result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index == TypeIndex::kTVMFFISmallStr || src->type_index == TypeIndex::kTVMFFIStr; - } - - TVM_FFI_INLINE static String CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return String(details::BytesBaseCell::CopyFromAnyView(src)); - } - - TVM_FFI_INLINE static String MoveFromAnyAfterCheck(TVMFFIAny *src) { - return String(details::BytesBaseCell::MoveFromAny(src)); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIRawStr) { - return String(src->v_c_str); - } - if (src->type_index == TypeIndex::kTVMFFISmallStr || src->type_index == TypeIndex::kTVMFFIStr) { - return String(details::BytesBaseCell::CopyFromAnyView(src)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "str"; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIStr) + R"("})"; - } -}; - -// const char*, requirement: not nullable, do not retain ownership -template -struct TypeTraits : public TypeTraitsBase { - // NOTE: only enable implicit conversion into AnyView - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(const char src[N], TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIRawStr; - result->zero_padding = 0; - result->v_c_str = src; - } - - TVM_FFI_INLINE static void 
MoveToAny(const char src[N], TVMFFIAny *result) { - // when we need to move to any, convert to owned object first - TypeTraits::MoveToAny(String(src), result); - } -}; - -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(const char *src, TVMFFIAny *result) { - TVM_FFI_ICHECK_NOTNULL(src); - result->type_index = TypeIndex::kTVMFFIRawStr; - result->zero_padding = 0; - result->v_c_str = src; - } - - TVM_FFI_INLINE static void MoveToAny(const char *src, TVMFFIAny *result) { - // when we need to move to any, convert to owned object first - TypeTraits::MoveToAny(String(src), result); - } - // Do not allow const char* in a container, so we do not need CheckAnyStrict - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIRawStr) { - return static_cast(src->v_c_str); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "const char*"; } - TVM_FFI_INLINE static std::string TypeSchema() { return R"({"type":"const char*"})"; } -}; - -// TVMFFIByteArray, requirement: not nullable, do not retain ownership -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIByteArrayPtr; - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static void CopyToAnyView(TVMFFIByteArray *src, TVMFFIAny *result) { - TVM_FFI_ICHECK_NOTNULL(src); - result->type_index = TypeIndex::kTVMFFIByteArrayPtr; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_ptr = src; - } - - TVM_FFI_INLINE static void MoveToAny(TVMFFIByteArray *src, TVMFFIAny *result) { - TypeTraits::MoveToAny(Bytes(*src), result); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIByteArrayPtr) { - return static_cast(src->v_ptr); - } - 
return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIByteArrayPtr; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIByteArrayPtr) + R"("})"; - } -}; - -template <> -inline constexpr bool use_default_type_traits_v = false; - -template <> -struct TypeTraits - : public FallbackOnlyTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(const std::string &src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIRawStr; - result->zero_padding = 0; - result->v_c_str = src.c_str(); - } - - TVM_FFI_INLINE static void MoveToAny(std::string src, TVMFFIAny *result) { - // when we need to move to any, convert to owned object first - TypeTraits::MoveToAny(String(std::move(src)), result); - } - - TVM_FFI_INLINE static std::string TypeStr() { return "std::string"; } - TVM_FFI_INLINE static std::string TypeSchema() { return R"({"type":"std::string"})"; } - - TVM_FFI_INLINE static std::string ConvertFallbackValue(const char *src) { - return std::string(src); - } - - TVM_FFI_INLINE static std::string ConvertFallbackValue(TVMFFIByteArray *src) { - return std::string(src->data, src->size); - } - - // NOLINTNEXTLINE(performance-unnecessary-value-param) - TVM_FFI_INLINE static std::string ConvertFallbackValue(Bytes src) { - return src.operator std::string(); - } - - // NOLINTNEXTLINE(performance-unnecessary-value-param) - TVM_FFI_INLINE static std::string ConvertFallbackValue(String src) { - return src.operator std::string(); - } -}; - -inline String operator+(const String &lhs, const String &rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String &lhs, const std::string &rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String 
operator+(const std::string &lhs, const String &rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = rhs.size(); - return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const char *lhs, const String &rhs) { - size_t lhs_size = std::strlen(lhs); - size_t rhs_size = rhs.size(); - return String::Concat(lhs, lhs_size, rhs.data(), rhs_size); -} - -inline String operator+(const String &lhs, const char *rhs) { - size_t lhs_size = lhs.size(); - size_t rhs_size = std::strlen(rhs); - return String::Concat(lhs.data(), lhs_size, rhs, rhs_size); -} - -// Overload < operator -inline bool operator<(std::nullptr_t, const String &rhs) = delete; -inline bool operator<(const String &lhs, std::nullptr_t) = delete; - -inline bool operator<(const String &lhs, const std::string &rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const std::string &lhs, const String &rhs) { return rhs.compare(lhs) > 0; } - -inline bool operator<(const String &lhs, const String &rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const String &lhs, const char *rhs) { return lhs.compare(rhs) < 0; } - -inline bool operator<(const char *lhs, const String &rhs) { return rhs.compare(lhs) > 0; } - -// Overload > operator -inline bool operator>(std::nullptr_t, const String &rhs) = delete; -inline bool operator>(const String &lhs, std::nullptr_t) = delete; - -inline bool operator>(const String &lhs, const std::string &rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const std::string &lhs, const String &rhs) { return rhs.compare(lhs) < 0; } - -inline bool operator>(const String &lhs, const String &rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const String &lhs, const char *rhs) { return lhs.compare(rhs) > 0; } - -inline bool operator>(const char *lhs, const String &rhs) { return rhs.compare(lhs) < 0; } - -// Overload <= operator -inline bool operator<=(std::nullptr_t, const String &rhs) = delete; -inline bool 
operator<=(const String &lhs, std::nullptr_t) = delete; - -inline bool operator<=(const String &lhs, const std::string &rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const std::string &lhs, const String &rhs) { return rhs.compare(lhs) >= 0; } - -inline bool operator<=(const String &lhs, const String &rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const String &lhs, const char *rhs) { return lhs.compare(rhs) <= 0; } - -inline bool operator<=(const char *lhs, const String &rhs) { return rhs.compare(lhs) >= 0; } - -// Overload >= operator -inline bool operator>=(std::nullptr_t, const String &rhs) = delete; -inline bool operator>=(const String &lhs, std::nullptr_t) = delete; - -inline bool operator>=(const String &lhs, const std::string &rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const std::string &lhs, const String &rhs) { return rhs.compare(lhs) <= 0; } - -inline bool operator>=(const String &lhs, const String &rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const String &lhs, const char *rhs) { return lhs.compare(rhs) >= 0; } - -inline bool operator>=(const char *lhs, const String &rhs) { return rhs.compare(lhs) <= 0; } - -// delete Overload == operator for nullptr -inline bool operator==(const String &lhs, std::nullptr_t) = delete; -inline bool operator==(std::nullptr_t, const String &rhs) = delete; - -inline bool operator==(const String &lhs, const std::string &rhs) { - return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); -} - -inline bool operator==(const std::string &lhs, const String &rhs) { - return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); -} - -inline bool operator==(const String &lhs, const String &rhs) { - return Bytes::memequal(lhs.data(), rhs.data(), lhs.size(), rhs.size()); -} - -inline bool operator==(const String &lhs, const char *rhs) { return lhs.compare(rhs) == 0; } - -inline bool operator==(const char *lhs, const String &rhs) { 
return rhs.compare(lhs) == 0; } - -// Overload != operator -inline bool operator!=(const String &lhs, std::nullptr_t) = delete; -inline bool operator!=(std::nullptr_t, const String &rhs) = delete; - -inline bool operator!=(const String &lhs, const std::string &rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const std::string &lhs, const String &rhs) { return rhs.compare(lhs) != 0; } - -inline bool operator!=(const String &lhs, const String &rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const String &lhs, const char *rhs) { return lhs.compare(rhs) != 0; } - -inline bool operator!=(const char *lhs, const String &rhs) { return rhs.compare(lhs) != 0; } - -inline std::ostream &operator<<(std::ostream &out, const String &input) { - out.write(input.data(), static_cast(input.size())); - return out; -} -/// \endcond -} // namespace ffi -} // namespace tvm - -/// \cond Doxygen_Suppress -namespace std { - -template <> -struct hash<::tvm::ffi::Bytes> { - std::size_t operator()(const ::tvm::ffi::Bytes &bytes) const { - return std::hash()(std::string_view(bytes.data(), bytes.size())); - } -}; - -template <> -struct hash<::tvm::ffi::String> { - std::size_t operator()(const ::tvm::ffi::String &str) const { - return std::hash()(std::string_view(str.data(), str.size())); - } -}; -} // namespace std -/// \endcond -#endif // TVM_FFI_STRING_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/type_traits.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/type_traits.h deleted file mode 100644 index d9f3f58a7..000000000 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/tvm/ffi/type_traits.h +++ /dev/null @@ -1,828 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/ffi/object.h - * \brief A managed object in the TVM FFI. - */ -#ifndef TVM_FFI_TYPE_TRAITS_H_ -#define TVM_FFI_TYPE_TRAITS_H_ - -#include "base_details.h" -#include "c_api.h" -#include "error.h" -#include "object.h" - -#include -#include -#include - -namespace tvm { -namespace ffi { - -/*! - * \brief TypeTraits that specifies the conversion behavior from/to FFI Any. - * - * The function specifications of TypeTraits - * - * - CopyToAnyView: Convert a value T to AnyView - * - MoveToAny: Move a value to Any - * - CheckAnyStrict: Check if a Any stores a result of CopyToAnyView of current T. - * - CopyFromAnyViewAfterCheck: Copy a value T from Any view after we pass CheckAnyStrict. - * - MoveFromAnyAfterCheck: Move a value T from Any storage after we pass CheckAnyStrict. - * - TryCastFromAnyView: Convert a AnyView to a T, we may apply type conversion. - * - GetMismatchTypeInfo: Get the type key of a type when TryCastFromAnyView fails. - * - TypeStr: Get the type key of a type - * - * It is possible that CheckAnyStrict is false but TryCastFromAnyView still works. - * - * For example, when Any x stores int, TypeTraits::CheckAnyStrict(x) will be false, - * but TypeTraits::TryCastFromAnyView(x) will return a corresponding float value - * via type conversion. 
- * - * CheckAnyStrict is mainly used in recursive container such as Array to - * decide if a new Array needed to be created via recursive conversion, - * or we can use the current container as is when converting to Array. - * - * A container array: Array satisfies the following invariant: - * - `all(TypeTraits::CheckAnyStrict(x) for x in the array)`. - */ -template -struct TypeTraits { - /*! \brief Whether the type is enabled in FFI. */ - static constexpr bool convert_enabled = false; - /*! \brief Whether the type can appear as a storage type in Container */ - static constexpr bool storage_enabled = false; -}; - -/*! - * \brief TypeTraits that removes const and reference keywords. - * \tparam T the original type - */ -template -using TypeTraitsNoCR = TypeTraits>>; - -template -inline constexpr bool use_default_type_traits_v = true; - -struct TypeTraitsBase { - static constexpr bool convert_enabled = true; - static constexpr bool storage_enabled = true; - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIAny; - // get mismatched type when result mismatches the trait. - // this function is called after TryCastFromAnyView fails - // to get more detailed type information in runtime - // especially when the error involves nested container type - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *source) { - return TypeIndexToTypeKey(source->type_index); - } -}; - -/*! - * \brief Trait that maps a type to its field static type index - * \tparam T the type - * \return the field static type index - */ -template -struct TypeToFieldStaticTypeIndex { - /*! \brief The field static type index of the type */ - static constexpr int32_t value = TypeIndex::kTVMFFIAny; -}; - -template -struct TypeToFieldStaticTypeIndex::convert_enabled>> { - static constexpr int32_t value = TypeTraits::field_static_type_index; -}; - -/*! 
- * \brief Trait that maps a type to its runtime type index - * \tparam T the type - * \return the runtime type index - */ -template -struct TypeToRuntimeTypeIndex { - /*! - * \brief Get the runtime type index of the type - * \return the runtime type index - */ - static int32_t v() { return TypeToFieldStaticTypeIndex::value; } -}; - -template -struct TypeToRuntimeTypeIndex>> { - static int32_t v() { return T::ContainerType::RuntimeTypeIndex(); } -}; - -// None -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFINone; - - TVM_FFI_INLINE static void CopyToAnyView(const std::nullptr_t &, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFINone; - result->zero_padding = 0; - // invariant: the pointer field also equals nullptr - // this will simplify same_as comparisons and hash - result->v_int64 = 0; - } - - TVM_FFI_INLINE static void MoveToAny(std::nullptr_t, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFINone; - result->zero_padding = 0; - // invariant: the pointer field also equals nullptr - // this will simplify same_as comparisons and hash - result->v_int64 = 0; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index == TypeIndex::kTVMFFINone; - } - - TVM_FFI_INLINE static std::nullptr_t CopyFromAnyViewAfterCheck(const TVMFFIAny *) { - return nullptr; - } - - TVM_FFI_INLINE static std::nullptr_t MoveFromAnyAfterCheck(TVMFFIAny *) { return nullptr; } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return nullptr; - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFINone; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFINone) + R"("})"; - } -}; - -/** - * \brief A type that forbids implicit conversion from int to bool 
- * - * This type is used to prevent implicit conversion from int to bool. - */ -class StrictBool { -public: - /*! - * \brief Constructor - * \param value The value of the strict bool. - */ - StrictBool(bool value) : value_(value) {} // NOLINT(google-explicit-constructor) - /*! - *\brief Convert the strict bool to bool. - * \return The value of the strict bool. - */ - operator bool() const { return value_; } // NOLINT(google-explicit-constructor) - -private: - bool value_; -}; - -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; - - TVM_FFI_INLINE static void CopyToAnyView(const StrictBool &src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIBool; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(StrictBool src, TVMFFIAny *result) { - CopyToAnyView(src, result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index == TypeIndex::kTVMFFIBool; - } - - TVM_FFI_INLINE static StrictBool CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static StrictBool MoveFromAnyAfterCheck(TVMFFIAny *src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIBool) { - return StrictBool(static_cast(src->v_int64)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIBool; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIBool) + R"("})"; - } -}; - -// Bool type, allow implicit casting from int -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIBool; - - TVM_FFI_INLINE static 
void CopyToAnyView(const bool &src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIBool; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(bool src, TVMFFIAny *result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index == TypeIndex::kTVMFFIBool; - } - - TVM_FFI_INLINE static bool CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static bool MoveFromAnyAfterCheck(TVMFFIAny *src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { - return static_cast(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIBool; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIBool) + R"("})"; - } -}; - -// Integer POD values -template -struct TypeTraits>> : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; - - TVM_FFI_INLINE static void CopyToAnyView(const Int &src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIInt; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(Int src, TVMFFIAny *result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIInt; - } - - TVM_FFI_INLINE static Int CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static Int MoveFromAnyAfterCheck(TVMFFIAny 
*src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { - return Int(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIInt) + R"("})"; - } -}; - -/// \cond Doxygen_Suppress - -// trait to check if a type is an integeral enum -// note that we need this trait so we can confirm underlying_type_t is an integral type -// to avoid potential undefined behavior -template > -constexpr bool is_integeral_enum_v = false; - -template -constexpr bool is_integeral_enum_v = std::is_integral_v>; - -/// \endcond - -// Enum Integer POD values -template -struct TypeTraits>> : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; - - TVM_FFI_INLINE static void CopyToAnyView(const IntEnum &src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIInt; - result->zero_padding = 0; - result->v_int64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(IntEnum src, TVMFFIAny *result) { - CopyToAnyView(src, result); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIInt; - } - - TVM_FFI_INLINE static IntEnum CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return static_cast(src->v_int64); - } - - TVM_FFI_INLINE static IntEnum MoveFromAnyAfterCheck(TVMFFIAny *src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == 
TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { - return static_cast(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIInt) + R"("})"; - } -}; - -// Float POD values -template -struct TypeTraits>> - : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIFloat; - - TVM_FFI_INLINE static void CopyToAnyView(const Float &src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIFloat; - result->zero_padding = 0; - result->v_float64 = static_cast(src); - } - - TVM_FFI_INLINE static void MoveToAny(Float src, TVMFFIAny *result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIFloat; - } - - TVM_FFI_INLINE static Float CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return static_cast(src->v_float64); - } - - TVM_FFI_INLINE static Float MoveFromAnyAfterCheck(TVMFFIAny *src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIFloat) { - return Float(src->v_float64); - } else if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { - return Float(src->v_int64); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIFloat; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIFloat) + R"("})"; - } -}; - -// void* -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = 
TypeIndex::kTVMFFIOpaquePtr; - - TVM_FFI_INLINE static void CopyToAnyView(void *src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIOpaquePtr; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_ptr = src; - } - - TVM_FFI_INLINE static void MoveToAny(void *src, TVMFFIAny *result) { CopyToAnyView(src, result); } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny - return src->type_index == TypeIndex::kTVMFFIOpaquePtr; - } - - TVM_FFI_INLINE static void *CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { return src->v_ptr; } - - TVM_FFI_INLINE static void *MoveFromAnyAfterCheck(TVMFFIAny *src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIOpaquePtr) { - return static_cast(src->v_ptr); - } - if (src->type_index == TypeIndex::kTVMFFINone) { - return static_cast(nullptr); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIOpaquePtr; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIOpaquePtr) + R"("})"; - } -}; - -// Device -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDevice; - - TVM_FFI_INLINE static void CopyToAnyView(const DLDevice &src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIDevice; - result->zero_padding = 0; - result->v_device = src; - } - - TVM_FFI_INLINE static void MoveToAny(DLDevice src, TVMFFIAny *result) { - result->type_index = TypeIndex::kTVMFFIDevice; - result->zero_padding = 0; - result->v_device = src; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index == 
TypeIndex::kTVMFFIDevice; - } - - TVM_FFI_INLINE static DLDevice CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return src->v_device; - } - - TVM_FFI_INLINE static DLDevice MoveFromAnyAfterCheck(TVMFFIAny *src) { - // POD type, we can just copy the value - return CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIDevice) { - return src->v_device; - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return StaticTypeKey::kTVMFFIDevice; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(StaticTypeKey::kTVMFFIDevice) + R"("})"; - } -}; - -// DLTensor*, requirement: not nullable, do not retain ownership -template <> -struct TypeTraits : public TypeTraitsBase { - static constexpr bool storage_enabled = false; - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIDLTensorPtr; - - TVM_FFI_INLINE static void CopyToAnyView(DLTensor *src, TVMFFIAny *result) { - TVM_FFI_ICHECK_NOTNULL(src); - result->type_index = TypeIndex::kTVMFFIDLTensorPtr; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_ptr = src; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index == TypeIndex::kTVMFFIDLTensorPtr; - } - - TVM_FFI_INLINE static DLTensor *CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - return static_cast(src->v_ptr); - } - - TVM_FFI_INLINE static void MoveToAny(DLTensor *, TVMFFIAny *) { - TVM_FFI_THROW(RuntimeError) - << "DLTensor* cannot be held in Any as it does not retain ownership, use Tensor instead"; - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFIDLTensorPtr) { - return static_cast(src->v_ptr); - } else if (src->type_index == TypeIndex::kTVMFFITensor) { - // Conversion from Tensor pointer to DLTensor - // based on 
the assumption that Tensor always follows the TVMFFIObject header - static_assert(sizeof(TVMFFIObject) == 24); - return reinterpret_cast(reinterpret_cast(src->v_obj) + sizeof(TVMFFIObject)); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return "DLTensor*"; } - TVM_FFI_INLINE static std::string TypeSchema() { return R"({"type":"DLTensor*"})"; } -}; - -// Traits for ObjectRef, None to ObjectRef will always fail. -// use std::optional instead for nullable references. -template -struct ObjectRefTypeTraitsBase : public TypeTraitsBase { - static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIObject; - using ContainerType = typename TObjRef::ContainerType; - - TVM_FFI_INLINE static void CopyToAnyView(const TObjRef &src, TVMFFIAny *result) { - if constexpr (TObjRef::_type_is_nullable) { - if (!src.defined()) { - TypeTraits::CopyToAnyView(nullptr, result); - return; - } - } - TVMFFIObject *obj_ptr = details::ObjectUnsafe::TVMFFIObjectPtrFromObjectRef(src); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - } - - TVM_FFI_INLINE static void MoveToAny(TObjRef src, TVMFFIAny *result) { - if constexpr (TObjRef::_type_is_nullable) { - if (!src.defined()) { - TypeTraits::CopyToAnyView(nullptr, result); - return; - } - } - TVMFFIObject *obj_ptr = details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(src)); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return true; - } - } - return (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && details::IsObjectInstance(src->type_index)); - } - - TVM_FFI_INLINE static TObjRef CopyFromAnyViewAfterCheck(const 
TVMFFIAny *src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); - } - } - return details::ObjectUnsafe::ObjectRefFromObjectPtr( - details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); - } - - TVM_FFI_INLINE static TObjRef MoveFromAnyAfterCheck(TVMFFIAny *src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); - } - } - // move out the object pointer - ObjectPtr obj_ptr = details::ObjectUnsafe::ObjectPtrFromOwned(src->v_obj); - // reset the src to nullptr - TypeTraits::MoveToAny(nullptr, src); - return details::ObjectUnsafe::ObjectRefFromObjectPtr(std::move(obj_ptr)); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if constexpr (TObjRef::_type_is_nullable) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr(nullptr); - } - } - if (src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin) { - if (details::IsObjectInstance(src->type_index)) { - return details::ObjectUnsafe::ObjectRefFromObjectPtr( - details::ObjectUnsafe::ObjectPtrFromUnowned(src->v_obj)); - } - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return ContainerType::_type_key; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(ContainerType::_type_key) + R"("})"; - } -}; - -template -struct TypeTraits && use_default_type_traits_v>> - : public ObjectRefTypeTraitsBase {}; - -/*! - * \brief Helper class that convert to T only via the FallbackTypes - * - * The conversion will go through the FallbackTypes in the order - * specified in the template parameter. - * \tparam T The type of the target value. - * \tparam FallbackTypes The type of the fallback value. 
- * \note TypeTraits must be derived from this class and define - * ConvertFallbackValue(FallbackType)->T for each FallbackType - */ -template -struct FallbackOnlyTraitsBase : public TypeTraitsBase { - // disable container for FallbackOnlyTraitsBase - /// \cond Doxygen_Suppress - static constexpr bool storage_enabled = false; - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - return TryFallbackTypes(src); - } - - template - TVM_FFI_INLINE static std::optional TryFallbackTypes(const TVMFFIAny *src) { - static_assert(!std::is_same_v, - "Using bool as FallbackType can cause bug because int will be detected as bool, " - "use tvm::ffi::StrictBool instead"); - if (auto opt_fallback = TypeTraits::TryCastFromAnyView(src)) { - return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); - } - if constexpr (sizeof...(Rest) > 0) { - return TryFallbackTypes(src); - } - return std::nullopt; - } - /// \endcond -}; - -/*! - * \brief Helper class to define ObjectRef that can be auto-converted from a - * fallback type, the Traits must be derived from it - * and define a static methods named ConvertFallbackValue for each - * FallbackType - * - * The conversion will go through the FallbackTypes in the order - * specified in the template parameter. - * \tparam ObjectRefType The type of the ObjectRef. - * \tparam FallbackTypes The type of the fallback value. 
- */ -template -struct ObjectRefWithFallbackTraitsBase : public ObjectRefTypeTraitsBase { - /// \cond Doxygen_Suppress - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if (auto opt_obj = ObjectRefTypeTraitsBase::TryCastFromAnyView(src)) { - return opt_obj; - } - // apply fallback types in TryCastFromAnyView - return TryFallbackTypes(src); - } - - template - TVM_FFI_INLINE static std::optional TryFallbackTypes(const TVMFFIAny *src) { - static_assert(!std::is_same_v, - "Using bool as FallbackType can cause bug because int will be detected as bool, " - "use tvm::ffi::StrictBool instead"); - if (auto opt_fallback = TypeTraits::TryCastFromAnyView(src)) { - return TypeTraits::ConvertFallbackValue(*std::move(opt_fallback)); - } - if constexpr (sizeof...(Rest) > 0) { - return TryFallbackTypes(src); - } - return std::nullopt; - } - /// \endcond -}; - -// Traits for weak pointer of object -// NOTE: we require the weak pointer cast from - -template -struct TypeTraits>> - : public TypeTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(TObject *src, TVMFFIAny *result) { - TVMFFIObject *obj_ptr = details::ObjectUnsafe::GetHeader(src); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - } - - TVM_FFI_INLINE static void MoveToAny(TObject *src, TVMFFIAny *result) { - TVMFFIObject *obj_ptr = details::ObjectUnsafe::GetHeader(src); - result->type_index = obj_ptr->type_index; - result->zero_padding = 0; - TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result); - result->v_obj = obj_ptr; - // needs to increase ref because original weak ptr do not own the code - details::ObjectUnsafe::IncRefObjectHandle(result->v_obj); - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - return src->type_index >= TypeIndex::kTVMFFIStaticObjectBegin && details::IsObjectInstance(src->type_index); - } - - TVM_FFI_INLINE static TObject 
*CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - if constexpr (!std::is_const_v) { - static_assert(TObject::_type_mutable, "TObject must be mutable to enable cast from Any"); - } - return details::ObjectUnsafe::RawObjectPtrFromUnowned(src->v_obj); - } - - TVM_FFI_INLINE static std::optional TryCastFromAnyView(const TVMFFIAny *src) { - if constexpr (!std::is_const_v) { - static_assert(TObject::_type_mutable, "TObject must be mutable to enable cast from Any"); - } - if (CheckAnyStrict(src)) { - return CopyFromAnyViewAfterCheck(src); - } - return std::nullopt; - } - - TVM_FFI_INLINE static std::string TypeStr() { return TObject::_type_key; } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":")" + std::string(TObject::_type_key) + R"("})"; - } -}; - -template -inline constexpr bool use_default_type_traits_v> = false; - -template -struct TypeTraits> : public TypeTraitsBase { - TVM_FFI_INLINE static void CopyToAnyView(const Optional &src, TVMFFIAny *result) { - if (src.has_value()) { - TypeTraits::CopyToAnyView(*src, result); - } else { - TypeTraits::CopyToAnyView(nullptr, result); - } - } - - TVM_FFI_INLINE static void MoveToAny(Optional src, TVMFFIAny *result) { - if (src.has_value()) { - TypeTraits::MoveToAny(*std::move(src), result); - } else { - TypeTraits::CopyToAnyView(nullptr, result); - } - } - - TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return true; - } - return TypeTraits::CheckAnyStrict(src); - } - - TVM_FFI_INLINE static Optional CopyFromAnyViewAfterCheck(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return Optional(std::nullopt); - } - return TypeTraits::CopyFromAnyViewAfterCheck(src); - } - - TVM_FFI_INLINE static Optional MoveFromAnyAfterCheck(TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return Optional(std::nullopt); - } - return TypeTraits::MoveFromAnyAfterCheck(src); - } - - TVM_FFI_INLINE static 
std::optional> TryCastFromAnyView(const TVMFFIAny *src) { - if (src->type_index == TypeIndex::kTVMFFINone) { - return Optional(std::nullopt); - } - if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { - return Optional(*std::move(opt)); - } else { - // important to be explicit here - // because nullopt can convert to std::optional(nullopt) which indicate success - // return std::optional>(std::nullopt) to indicate failure - return std::optional>(std::nullopt); - } - } - - TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny *src) { - return TypeTraits::GetMismatchTypeInfo(src); - } - - TVM_FFI_INLINE static std::string TypeStr() { - return "Optional<" + TypeTraits::TypeStr() + ">"; - } - TVM_FFI_INLINE static std::string TypeSchema() { - return R"({"type":"Optional","args":[)" + details::TypeSchema::v() + "]}"; - } -}; -} // namespace ffi -} // namespace tvm -#endif // TVM_FFI_TYPE_TRAITS_H_ diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh index 18b52cf99..18d5da7c3 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.cuh @@ -17,8 +17,8 @@ #include "utils.h" -#include "dlpack/dlpack.h" -#include "tvm/ffi/extra/c_env_api.h" +#include +#include #include #include diff --git a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h index e2f0b8420..d6892d0dd 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h +++ b/src/infiniop/ops/gptq_marlin_gemm/sgl_kernel/utils.h @@ -45,9 +45,9 @@ #include "source_location.h" #endif -#include "dlpack/dlpack.h" #include #include +#include #include #include #include From c4b99a8e4cea30da5cfcff6514453775154bdd54 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 20 May 2026 06:25:44 +0000 Subject: [PATCH 07/10] issue/1083: add README.md --- README.md | 14 ++++++++++++++ 1 file changed, 14 
insertions(+) diff --git a/README.md b/README.md index bd0f7fe64..de8e73da8 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,20 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] ``` + ##### 试验功能 -- 编译marlin相关算子 + + ```shell + + # 需要从github上克隆dlpack以及tvm_ffi仓库,克隆命令参考 + git clone git@github.com:dmlc/dlpack.git --recursive + git clone git@github.com:apache/tvm-ffi.git --recursive + + # 设置CPATH + export CPATH=/tvm-ffi/include:$CPATH #用来搜索tvm相关头文件 + export CPATH=/dlpack/include:$CPATH #用来搜索dlpack.h + + ``` + 2. 编译安装 默认安装路径为 `$HOME/.infini`。 From d8a23e0568470174003d1ffa64f2ab9b3e430ba9 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 20 May 2026 08:06:25 +0000 Subject: [PATCH 08/10] issue/1083: add TVM_ROOT --- README.md | 10 ++--- .../nvidia/gptq_marlin_gemm_nvidia.cu | 12 +++--- xmake/nvidia.lua | 39 +++++++++++-------- 3 files changed, 33 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index de8e73da8..9e3b14357 100644 --- a/README.md +++ b/README.md @@ -180,13 +180,11 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] ```shell - # 需要从github上克隆dlpack以及tvm_ffi仓库,克隆命令参考 - git clone git@github.com:dmlc/dlpack.git --recursive - git clone git@github.com:apache/tvm-ffi.git --recursive + # 需要从github上克隆tvm_ffi仓库,克隆命令参考 + git clone https://github.com/apache/tvm-ffi.git --recursive - # 设置CPATH - export CPATH=/tvm-ffi/include:$CPATH #用来搜索tvm相关头文件 - export CPATH=/dlpack/include:$CPATH #用来搜索dlpack.h + # 设置TVM_ROOT + export TVM_ROOT=/tvm-ffi #用来搜索tvm相关头文件 ``` diff --git a/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu index 59271f78b..99d9b315b 100644 --- a/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu +++ b/src/infiniop/ops/gptq_marlin_gemm/nvidia/gptq_marlin_gemm_nvidia.cu @@ -3,13 +3,12 @@ #include "../../../devices/nvidia/nvidia_handle.cuh" #include "../../../devices/nvidia/nvidia_kernel_common.cuh" #include "../gptq_marlin_gemm.h" 
-#include "../sgl_kernel/tensor.h" #include "gptq_marlin_gemm_nvidia.cuh" - -#include "../sgl_kernel/scalar_type.hpp" - +#if defined ENABLE_TVM_API #include "../marlin/kernel.h" #include "../marlin/marlin_template.h" +#include "../sgl_kernel/scalar_type.hpp" +#include "../sgl_kernel/tensor.h" namespace device::marlin { @@ -1044,7 +1043,7 @@ infiniStatus_t gptq_marlin_gemm_kernel(void *c, stream); return INFINI_STATUS_SUCCESS; } - +#endif int getCudaDeviceSMCount() { int dev; cudaGetDevice(&dev); @@ -1114,7 +1113,7 @@ Descriptor::calculate( bool use_fp32_reduce, bool is_zp_float, void *stream_) const { - +#if defined ENABLE_TVM_API cudaStream_t stream = (cudaStream_t)stream_; int64_t M = static_cast(_info.M); int64_t K = static_cast(_info.K); @@ -1133,6 +1132,7 @@ Descriptor::calculate( } else { return INFINI_STATUS_BAD_TENSOR_DTYPE; } +#endif return INFINI_STATUS_SUCCESS; } diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 086e6a924..cf229041a 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -4,22 +4,13 @@ if CUDNN_ROOT ~= nil then end local CUTLASS_ROOT = os.getenv("CUTLASS_ROOT") or os.getenv("CUTLASS_HOME") or os.getenv("CUTLASS_PATH") +local TVM_ROOT = os.getenv("TVM_ROOT") or os.getenv("TVM_HOME") or os.getenv("TVM_PATH") local FLASH_ATTN_ROOT = get_config("flash-attn") local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini") -function parse_sgl_cuda_arch(arch) - - local num = arch:match("sm_(%d+)") - if not num then - return nil - end - - return tonumber(num) * 10 -end - target("infiniop-nvidia") set_kind("static") add_deps("infini-utils") @@ -111,15 +102,31 @@ target("infiniop-nvidia") end local arch_opt = get_config("cuda_arch") - if arch_opt then - local sgl_arch = parse_sgl_cuda_arch(arch_opt) - if sgl_arch then - add_defines("SGL_CUDA_ARCH=" .. 
sgl_arch) - print("SGL_CUDA_ARCH =", sgl_arch) + if TVM_ROOT ~= nil then + add_defines("ENABLE_TVM_API") + add_includedirs(TVM_ROOT, TVM_ROOT .. "/include", TVM_ROOT .. "/3rdparty/dlpack/include/") + function parse_sgl_cuda_arch(arch) + + local num = arch:match("sm_(%d+)") + if not num then + return nil + end + + return tonumber(num) * 10 + end + if arch_opt then + local sgl_arch = parse_sgl_cuda_arch(arch_opt) + if sgl_arch then + add_defines("SGL_CUDA_ARCH=" .. sgl_arch) + print("SGL_CUDA_ARCH =", sgl_arch) + else + print("Invalid cuda_arch:", arch_opt) + end else - print("Invalid cuda_arch:", arch_opt) + error("tvm complie marlin needs cuda_arch") end end + if arch_opt and type(arch_opt) == "string" then for _, arch in ipairs(arch_opt:split(",")) do arch = arch:trim() From 6d82fe47004f68f43fcd003eacd5490ab4c0e101 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 20 May 2026 08:08:38 +0000 Subject: [PATCH 09/10] issue/1083: modified nvidia.lua --- xmake/nvidia.lua | 1 - 1 file changed, 1 deletion(-) diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index cf229041a..6ccb2f358 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -10,7 +10,6 @@ local FLASH_ATTN_ROOT = get_config("flash-attn") local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini") - target("infiniop-nvidia") set_kind("static") add_deps("infini-utils") From 50cc3dab23042181d483ef19dae8f5c6146f5d70 Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Wed, 20 May 2026 08:14:44 +0000 Subject: [PATCH 10/10] issue/1083: add commit id tvm-ffi --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 9e3b14357..21dbc46e8 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,7 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] ```shell # 需要从github上克隆tvm_ffi仓库,克隆命令参考 + ## tvm-ffi commit: 35c99d0ac4cb784862115d0089f60c603acec8f9 git clone https://github.com/apache/tvm-ffi.git --recursive # 设置TVM_ROOT