diff --git a/Cargo.lock b/Cargo.lock index 107e1f4f29..2651ce5ce8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1153,6 +1153,7 @@ dependencies = [ "turso", "turso-dbhash", "turso_core", + "turso_ext", "turso_macros", "turso_parser", "turso_sdk_kit", diff --git a/core/connection.rs b/core/connection.rs index 976e2d3b75..b586aefa6e 100644 --- a/core/connection.rs +++ b/core/connection.rs @@ -2864,23 +2864,66 @@ impl Connection { self.syms.read().vtab_modules.keys().cloned().collect() } - /// Returns external (extension) functions: (name, is_aggregate, argc) - pub fn get_syms_functions(&self) -> Vec<(String, bool, i32)> { + /// Returns external (extension) functions: (name, is_aggregate, argc, deterministic) + pub fn get_syms_functions(&self) -> Vec<(String, bool, i32, bool)> { self.syms .read() .functions .values() .map(|f| { - let is_agg = matches!(f.func, function::ExtFunc::Aggregate { .. }); + let is_agg = f.func.is_aggregate(); let argc = match &f.func { function::ExtFunc::Aggregate { argc, .. } => *argc as i32, + function::ExtFunc::ContextScalar { argc, .. } => *argc, function::ExtFunc::Scalar(_) => -1, }; - (f.name.clone(), is_agg, argc) + ( + f.name.clone(), + is_agg, + argc, + function::Deterministic::is_deterministic(f.as_ref()), + ) }) .collect() } + #[allow(clippy::too_many_arguments)] + pub fn register_external_scalar_function( + &self, + name: String, + argc: i32, + deterministic: bool, + context: usize, + callback: crate::ContextScalarFunction, + context_destructor: Option, + value_destructor: Option, + ) { + assert!( + argc >= -1, + "managed scalar argument count must be -1 (variadic) or non-negative" + ); + let normalized_name = crate::util::normalize_ident(&name); + self.syms.write().functions.insert( + normalized_name.clone(), + Arc::new(function::ExternalFunc::new_context_scalar( + normalized_name, + argc, + deterministic, + context, + callback, + context_destructor, + value_destructor, + )), + ); + self.bump_prepare_context_generation(); + } + + pub fn unregister_external_function(&self, name: &str) { + let normalized_name = crate::util::normalize_ident(name); + self.syms.write().functions.remove(&normalized_name); + self.bump_prepare_context_generation(); + } + pub(crate) fn database_ptr(&self) -> usize { Arc::as_ptr(&self.db) as usize } @@ -3391,9 +3434,17 @@ impl SymbolTable { pub fn resolve_function( &self, name: &str, - _arg_count: usize, + arg_count: usize, ) -> Option> { - self.functions.get(name).cloned() + self.functions + .get(name) + .cloned() + .or_else(|| { + self.functions + .get(&crate::util::normalize_ident(name)) + .cloned() + }) + .filter(|func| func.func.matches_arg_count(arg_count)) } pub fn extend(&mut self, other: &SymbolTable) { diff --git a/core/function.rs b/core/function.rs index 9757bf6365..4e99d5e628 100644 --- a/core/function.rs +++ b/core/function.rs @@ -2,9 +2,100 @@ use crate::sync::Arc; use std::fmt; use std::fmt::{Debug, Display}; use strum::IntoEnumIterator; -use turso_ext::{FinalizeFunction, InitAggFunction, ScalarFunction, StepFunction}; +use turso_ext::{ + FinalizeFunction, InitAggFunction, ScalarFunction, StepFunction, Value as ExtValue, +}; + +use crate::{LimboError, Value}; + +pub type ContextScalarFunction = unsafe extern "C" fn( + context: usize, + argc: i32, + argv: *const ExtValue, + result: *mut ContextValue, +); +pub type ContextDestructor = unsafe extern "C" fn(context: usize); +pub type ContextValueDestructor = unsafe extern "C" fn(result: *mut ContextValue); + +#[repr(C)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum ContextValueType { + Null, + Integer, + Float, + Text, + Blob, + Error, +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct ContextValueBytes { + pub ptr: *const u8, + pub len: usize, +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub union ContextValueData { + pub int: i64, + pub float: f64, + pub bytes: ContextValueBytes, +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct ContextValue { + pub value_type: ContextValueType, + pub value: ContextValueData, +} + +impl ContextValue { + pub fn null() -> Self { + Self { + value_type: ContextValueType::Null, + value: ContextValueData { int: 0 }, + } + } -use crate::LimboError; + pub fn into_value(self) -> Result { + // Text/blob/error payloads are callback-owned; copy them before the + // caller invokes the registered value destructor. + match self.value_type { + ContextValueType::Null => Ok(Value::Null), + ContextValueType::Integer => Ok(Value::from_i64(unsafe { self.value.int })), + ContextValueType::Float => Ok(Value::from_f64(unsafe { self.value.float })), + ContextValueType::Text => { + let bytes = unsafe { self.value.bytes }; + if bytes.ptr.is_null() { + return Ok(Value::Null); + } + let slice = unsafe { std::slice::from_raw_parts(bytes.ptr, bytes.len) }; + let text = std::str::from_utf8(slice) + .map_err(|err| LimboError::ExtensionError(err.to_string()))?; + Ok(Value::build_text(text.to_string())) + } + ContextValueType::Blob => { + let bytes = unsafe { self.value.bytes }; + if bytes.ptr.is_null() { + return Ok(Value::Blob(Vec::new())); + } + let slice = unsafe { std::slice::from_raw_parts(bytes.ptr, bytes.len) }; + Ok(Value::Blob(slice.to_vec())) + } + ContextValueType::Error => { + let bytes = unsafe { self.value.bytes }; + if bytes.ptr.is_null() { + return Err(LimboError::ExtensionError(String::new())); + } + let slice = unsafe { std::slice::from_raw_parts(bytes.ptr, bytes.len) }; + let message = std::str::from_utf8(slice) + .map_err(|err| LimboError::ExtensionError(err.to_string()))?; + Err(LimboError::ExtensionError(message.to_string())) + } + } + } +} pub trait Deterministic: std::fmt::Display { fn is_deterministic(&self) -> bool; @@ -17,14 +108,24 @@ pub struct ExternalFunc { impl Deterministic for ExternalFunc { fn is_deterministic(&self) -> bool { - // external functions can be whatever so let's just default to false - false + match self.func { + ExtFunc::ContextScalar { deterministic, .. } => deterministic, + _ => false, + } } } #[derive(Debug, Clone)] pub enum ExtFunc { Scalar(ScalarFunction), + ContextScalar { + context: usize, + argc: i32, + deterministic: bool, + callback: ContextScalarFunction, + context_destructor: Option, + value_destructor: Option, + }, Aggregate { argc: usize, init: InitAggFunction, @@ -40,6 +141,17 @@ impl ExtFunc { } Err(()) } + + pub fn matches_arg_count(&self, arg_count: usize) -> bool { + match self { + Self::ContextScalar { argc, .. } => *argc < 0 || *argc as usize == arg_count, + Self::Scalar(_) | Self::Aggregate { .. } => true, + } + } + + pub fn is_aggregate(&self) -> bool { + matches!(self, Self::Aggregate { .. }) + } } impl ExternalFunc { @@ -65,6 +177,41 @@ impl ExternalFunc { }, } } + + pub fn new_context_scalar( + name: String, + argc: i32, + deterministic: bool, + context: usize, + callback: ContextScalarFunction, + context_destructor: Option, + value_destructor: Option, + ) -> Self { + Self { + name, + func: ExtFunc::ContextScalar { + context, + argc, + deterministic, + callback, + context_destructor, + value_destructor, + }, + } + } +} + +impl Drop for ExternalFunc { + fn drop(&mut self) { + if let ExtFunc::ContextScalar { + context, + context_destructor: Some(context_destructor), + .. + } = self.func + { + unsafe { context_destructor(context) }; + } + } } impl Debug for ExternalFunc { diff --git a/core/lib.rs b/core/lib.rs index 95705dffc1..7215d4963b 100644 --- a/core/lib.rs +++ b/core/lib.rs @@ -127,6 +127,10 @@ use util::parse_schema_rows; pub use connection::{resolve_ext_path, Connection, Row, StepResult, SymbolTable}; pub(crate) use connection::{AtomicTransactionState, TransactionState}; pub use error::{io_error, CompletionError, LimboError}; +pub use function::{ + ContextDestructor, ContextScalarFunction, ContextValue, ContextValueBytes, ContextValueData, + ContextValueDestructor, ContextValueType, +}; #[cfg(all(feature = "fs", target_family = "unix", not(miri)))] pub use io::UnixIO; #[cfg(all(feature = "fs", target_os = "linux", feature = "io_uring", not(miri)))] diff --git a/core/storage/buffer_pool.rs b/core/storage/buffer_pool.rs index 3e3684a18b..a217015a21 100644 --- a/core/storage/buffer_pool.rs +++ b/core/storage/buffer_pool.rs @@ -474,11 +474,11 @@ mod arena { #[cfg(any(not(unix), miri))] mod arena { - pub fn alloc(len: usize) -> *mut u8 { + pub unsafe fn alloc(len: usize) -> *mut u8 { let layout = std::alloc::Layout::from_size_align(len, std::mem::size_of::()).unwrap(); unsafe { std::alloc::alloc_zeroed(layout) } } - pub fn dealloc(ptr: *mut u8, len: usize) { + pub unsafe fn dealloc(ptr: *mut u8, len: usize) { let layout = std::alloc::Layout::from_size_align(len, std::mem::size_of::()).unwrap(); unsafe { std::alloc::dealloc(ptr, layout) }; } diff --git a/core/translate/pragma.rs b/core/translate/pragma.rs index f9de6be8f9..c1272b252a 100644 --- a/core/translate/pragma.rs +++ b/core/translate/pragma.rs @@ -949,14 +949,18 @@ fn query_pragma( } // External (extension) functions - for (name, is_agg, argc) in connection.get_syms_functions() { + for (name, is_agg, argc, deterministic) in connection.get_syms_functions() { let func_type = if is_agg { "a" } else { "s" }; + let mut flags = 0; + if deterministic { + flags |= SQLITE_DETERMINISTIC; + } program.emit_string8(name, base_reg); program.emit_int(0, base_reg + 1); // builtin = 0 program.emit_string8(func_type.to_string(), base_reg + 2); program.emit_string8("utf8".to_string(), base_reg + 3); program.emit_int(argc as i64, base_reg + 4); - program.emit_int(0, base_reg + 5); // flags = 0 for extensions + program.emit_int(flags, base_reg + 5); program.emit_result_row(base_reg, 6); } diff --git a/core/vdbe/execute.rs b/core/vdbe/execute.rs index b6cf53b135..e4e1c8e0f1 100644 --- a/core/vdbe/execute.rs +++ b/core/vdbe/execute.rs @@ -8139,6 +8139,36 @@ pub fn op_function( } } } + ExtFunc::ContextScalar { + context, + callback, + value_destructor, + .. + } => { + let mut ext_values = Vec::with_capacity(arg_count); + if arg_count != 0 { + let register_slice = &state.registers[*start_reg..*start_reg + arg_count]; + for ov in register_slice.iter() { + ext_values.push(ov.get_value().to_ffi()); + } + } + + let argv_ptr = if ext_values.is_empty() { + std::ptr::null() + } else { + ext_values.as_ptr() + }; + let mut result = crate::function::ContextValue::null(); + unsafe { callback(context, arg_count as i32, argv_ptr, &mut result) }; + let value = result.into_value(); + if let Some(value_destructor) = value_destructor { + unsafe { value_destructor(&mut result) }; + } + for ext_value in ext_values { + unsafe { ext_value.__free_internal_type() }; + } + state.registers[*dest].set_value(value?); + } _ => unreachable!("aggregate called in scalar context"), }, crate::function::Func::Math(math_func) => match math_func.arity() { diff --git a/extensions/core/src/types.rs b/extensions/core/src/types.rs index 9f8e3c1f5a..358f25d9e7 100644 --- a/extensions/core/src/types.rs +++ b/extensions/core/src/types.rs @@ -206,7 +206,8 @@ impl TextValue { #[cfg(feature = "core_only")] fn free(self) { if !self.text.is_null() { - let _ = unsafe { Box::from_raw(self.text as *mut u8) }; + let ptr = std::ptr::slice_from_raw_parts_mut(self.text as *mut u8, self.len as usize); + let _ = unsafe { Box::from_raw(ptr) }; } } @@ -273,7 +274,8 @@ impl Blob { #[cfg(feature = "core_only")] fn free(self) { if !self.data.is_null() { - let _ = unsafe { Box::from_raw(self.data as *mut u8) }; + let ptr = std::ptr::slice_from_raw_parts_mut(self.data as *mut u8, self.size as usize); + let _ = unsafe { Box::from_raw(ptr) }; } } } @@ -520,7 +522,8 @@ impl Value { ValueType::Error => { let err_val = Box::from_raw(self.value.error as *mut ErrValue); if !err_val.message.is_null() { - let _ = Box::from_raw(err_val.message); + let message = Box::from_raw(err_val.message); + message.free(); } } _ => {} diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 81ae826a5c..34f9ca8d2c 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -22,6 +22,7 @@ path = "fuzz/mod.rs" anyhow.workspace = true env_logger = { workspace = true } turso_core = { workspace = true, features = ["conn_raw_api"] } +turso_ext.workspace = true turso_sdk_kit = { path = "../sdk-kit" } turso = { workspace = true } tokio = { workspace = true, features = ["full"] } diff --git a/tests/integration/external_apis.rs b/tests/integration/external_apis.rs new file mode 100644 index 0000000000..e50840744b --- /dev/null +++ b/tests/integration/external_apis.rs @@ -0,0 +1,327 @@ +use crate::common::{limbo_exec_rows, ExecRows, TempDatabase}; +use rusqlite::types::Value as SqliteValue; +use serial_test::serial; +use std::sync::{ + atomic::{AtomicUsize, Ordering as AtomicOrdering}, + Arc, +}; +use turso_core::{ + ContextValue, ContextValueBytes, ContextValueData, ContextValueType, LimboError, StepResult, +}; +use turso_ext::{Value as ExtValue, ValueType as ExtValueType}; + +static SCALAR_VALUE_DROPS: AtomicUsize = AtomicUsize::new(0); + +#[derive(Default)] +struct CallbackCounters { + calls: AtomicUsize, + context_drops: AtomicUsize, +} + +struct ScalarContext { + multiplier: i64, + counters: Arc, +} + +fn integer_result(value: i64) -> ContextValue { + ContextValue { + value_type: ContextValueType::Integer, + value: ContextValueData { int: value }, + } +} + +fn float_result(value: f64) -> ContextValue { + ContextValue { + value_type: ContextValueType::Float, + value: ContextValueData { float: value }, + } +} + +fn bytes_result(value_type: ContextValueType, bytes: &'static [u8]) -> ContextValue { + ContextValue { + value_type, + value: ContextValueData { + bytes: ContextValueBytes { + ptr: bytes.as_ptr(), + len: bytes.len(), + }, + }, + } +} + +unsafe extern "C" fn managed_score( + context: usize, + argc: i32, + argv: *const ExtValue, + result: *mut ContextValue, +) { + let ctx = unsafe { &*(context as *const ScalarContext) }; + ctx.counters.calls.fetch_add(1, AtomicOrdering::SeqCst); + let args = if argc <= 0 || argv.is_null() { + &[] + } else { + unsafe { std::slice::from_raw_parts(argv, argc as usize) } + }; + + let int_value = args + .first() + .and_then(ExtValue::to_integer) + .unwrap_or_default(); + let float_value = args + .get(1) + .and_then(ExtValue::to_float) + .map(|value| value as i64) + .unwrap_or_default(); + let text_len = args + .get(2) + .and_then(ExtValue::to_text) + .map(str::len) + .unwrap_or_default() as i64; + let blob_len = args + .get(3) + .and_then(ExtValue::to_blob) + .map(|blob| blob.len()) + .unwrap_or_default() as i64; + let null_count = args + .iter() + .filter(|arg| arg.value_type() == ExtValueType::Null) + .count() as i64; + + unsafe { + *result = integer_result( + (int_value + float_value + text_len + blob_len + null_count) * ctx.multiplier, + ); + } +} + +unsafe extern "C" fn managed_result( + context: usize, + argc: i32, + argv: *const ExtValue, + result: *mut ContextValue, +) { + let ctx = unsafe { &*(context as *const ScalarContext) }; + ctx.counters.calls.fetch_add(1, AtomicOrdering::SeqCst); + let args = if argc <= 0 || argv.is_null() { + &[] + } else { + unsafe { std::slice::from_raw_parts(argv, argc as usize) } + }; + let mode = args.first().and_then(ExtValue::to_text).unwrap_or_default(); + + unsafe { + *result = match mode { + "null" => ContextValue::null(), + "text" => bytes_result(ContextValueType::Text, b"managed-text"), + "blob" => bytes_result(ContextValueType::Blob, b"\x01\x02\xFE"), + "float" => float_result(3.25), + "error" => bytes_result(ContextValueType::Error, b"managed failure"), + _ => bytes_result(ContextValueType::Error, b"unexpected mode"), + }; + } +} + +unsafe extern "C" fn managed_variadic_score( + context: usize, + argc: i32, + argv: *const ExtValue, + result: *mut ContextValue, +) { + let ctx = unsafe { &*(context as *const ScalarContext) }; + ctx.counters.calls.fetch_add(1, AtomicOrdering::SeqCst); + let args = if argc <= 0 || argv.is_null() { + &[] + } else { + unsafe { std::slice::from_raw_parts(argv, argc as usize) } + }; + + let null_count = args + .iter() + .filter(|arg| arg.value_type() == ExtValueType::Null) + .count() as i64; + let score = (argc as i64 * 100) + + args + .first() + .and_then(ExtValue::to_integer) + .unwrap_or_default() + + args + .get(1) + .and_then(ExtValue::to_float) + .map(|value| value as i64) + .unwrap_or_default() + + args + .get(2) + .and_then(ExtValue::to_text) + .map(str::len) + .unwrap_or_default() as i64 + + args + .get(3) + .and_then(ExtValue::to_blob) + .map(|blob| blob.len()) + .unwrap_or_default() as i64 + + null_count; + unsafe { + *result = integer_result(score * ctx.multiplier); + } +} + +unsafe extern "C" fn drop_scalar_context(context: usize) { + let context = unsafe { Box::from_raw(context as *mut ScalarContext) }; + context + .counters + .context_drops + .fetch_add(1, AtomicOrdering::SeqCst); +} + +unsafe extern "C" fn count_scalar_value_drop(_result: *mut ContextValue) { + SCALAR_VALUE_DROPS.fetch_add(1, AtomicOrdering::SeqCst); +} + +fn boxed_scalar_context(multiplier: i64, counters: Arc) -> usize { + Box::into_raw(Box::new(ScalarContext { + multiplier, + counters, + })) as usize +} + +#[turso_macros::test] +#[serial] +fn managed_scalar_callbacks_cover_fixed_args_metadata_and_invalidation( + tmp_db: TempDatabase, +) -> anyhow::Result<()> { + let counters = Arc::new(CallbackCounters::default()); + let conn = tmp_db.connect_limbo(); + + conn.register_external_scalar_function( + "managed_score".to_string(), + 5, + true, + boxed_scalar_context(1, counters.clone()), + managed_score, + Some(drop_scalar_context), + None, + ); + + let score: Vec<(i64,)> = conn.exec_rows("SELECT managed_score(2, 3.5, 'hi', x'010203', NULL)"); + assert_eq!(score, vec![(11,)]); + + let function_list: Vec<(String, i64, String, String, i64, i64)> = + conn.exec_rows("PRAGMA function_list"); + let managed_score_metadata = function_list + .iter() + .find(|(name, _, _, _, _, _)| name == "managed_score") + .expect("managed_score should be listed"); + assert_eq!(managed_score_metadata.2, "s"); + assert_eq!(managed_score_metadata.4, 5); + assert_ne!(managed_score_metadata.5 & 0x800, 0); + + let mut prepared = conn.prepare("SELECT managed_score(1, 2.0, 'a', x'00', NULL)")?; + conn.register_external_scalar_function( + "managed_score".to_string(), + 5, + true, + boxed_scalar_context(10, counters.clone()), + managed_score, + Some(drop_scalar_context), + None, + ); + match prepared.step()? { + StepResult::Row => {} + other => panic!("expected row from managed_score, got {other:?}"), + } + assert_eq!( + prepared + .row() + .expect("row should be available after StepResult::Row") + .get::(0)?, + 60 + ); + drop(prepared); + assert_eq!(counters.context_drops.load(AtomicOrdering::SeqCst), 1); + + assert!(conn.prepare("SELECT managed_score(1)").is_err()); + conn.unregister_external_function("managed_score"); + let err = conn + .prepare("SELECT managed_score(1, 2.0, 'a', x'00', NULL)") + .unwrap_err(); + assert!(err.to_string().contains("no such function")); + assert_eq!(counters.context_drops.load(AtomicOrdering::SeqCst), 2); + assert_eq!(counters.calls.load(AtomicOrdering::SeqCst), 2); + Ok(()) +} + +#[turso_macros::test] +#[serial] +fn managed_scalar_callbacks_convert_results_and_propagate_errors( + tmp_db: TempDatabase, +) -> anyhow::Result<()> { + SCALAR_VALUE_DROPS.store(0, AtomicOrdering::SeqCst); + let counters = Arc::new(CallbackCounters::default()); + let conn = tmp_db.connect_limbo(); + + conn.register_external_scalar_function( + "managed_result".to_string(), + 1, + false, + boxed_scalar_context(1, counters.clone()), + managed_result, + Some(drop_scalar_context), + Some(count_scalar_value_drop), + ); + + assert_eq!( + limbo_exec_rows(&conn, "SELECT managed_result('null')"), + vec![vec![SqliteValue::Null]] + ); + assert_eq!( + limbo_exec_rows(&conn, "SELECT managed_result('text')"), + vec![vec![SqliteValue::Text("managed-text".to_string())]] + ); + assert_eq!( + limbo_exec_rows(&conn, "SELECT managed_result('blob')"), + vec![vec![SqliteValue::Blob(vec![0x01, 0x02, 0xFE])]] + ); + let float_value: Vec<(f64,)> = conn.exec_rows("SELECT managed_result('float')"); + assert_eq!(float_value, vec![(3.25,)]); + + let err = conn.execute("SELECT managed_result('error')").unwrap_err(); + assert!(matches!(err, LimboError::ExtensionError(_))); + assert!(err.to_string().contains("managed failure")); + assert_eq!(SCALAR_VALUE_DROPS.load(AtomicOrdering::SeqCst), 5); + + conn.unregister_external_function("managed_result"); + assert_eq!(counters.context_drops.load(AtomicOrdering::SeqCst), 1); + assert_eq!(counters.calls.load(AtomicOrdering::SeqCst), 5); + Ok(()) +} + +#[turso_macros::test] +#[serial] +fn managed_scalar_variadic_callbacks_receive_callsite_arguments( + tmp_db: TempDatabase, +) -> anyhow::Result<()> { + let counters = Arc::new(CallbackCounters::default()); + let conn = tmp_db.connect_limbo(); + + conn.register_external_scalar_function( + "managed_variadic_score".to_string(), + -1, + true, + boxed_scalar_context(1, counters.clone()), + managed_variadic_score, + Some(drop_scalar_context), + None, + ); + + let score: Vec<(i64,)> = + conn.exec_rows("SELECT managed_variadic_score(1, 3.5, 'A', x'7E57', NULL)"); + assert_eq!(score, vec![(508,)]); + + let no_args: Vec<(i64,)> = conn.exec_rows("SELECT managed_variadic_score()"); + assert_eq!(no_args, vec![(0,)]); + + conn.unregister_external_function("managed_variadic_score"); + assert_eq!(counters.context_drops.load(AtomicOrdering::SeqCst), 1); + assert_eq!(counters.calls.load(AtomicOrdering::SeqCst), 2); + Ok(()) +} diff --git a/tests/integration/mod.rs b/tests/integration/mod.rs index 70fd9d094f..bbb9ab8525 100644 --- a/tests/integration/mod.rs +++ b/tests/integration/mod.rs @@ -4,6 +4,7 @@ mod common; mod conflict_resolution; mod custom_types; mod database; +mod external_apis; mod functions; mod fuzz_transaction; mod index_method;