From b1491b8b202a66f7ddec4279aea4d2ff35a58be2 Mon Sep 17 00:00:00 2001 From: ditadi Date: Tue, 12 May 2026 18:25:10 +0100 Subject: [PATCH] refactor(appkit): split SQLWarehouseConnector into submit/get/poll/transform MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `SQLWarehouseConnector.executeStatement` was a single block: submit the SQL, poll until terminal, transform the result rows to JSON. Splitting it into four narrower public APIs lets durable executors compose them without holding the orchestrator open across the wait: - `submitStatement(workspaceClient, input, signal?)` — POST `/sql/statements`, returns the raw initial response. Adds a dedicated `sql.submit` span. - `getStatement(workspaceClient, statementId, signal?)` — GET `/sql/statements/{id}`, single status read. - `pollStatement(workspaceClient, statementId, signal?, timeout?)` — block until the statement reaches a terminal state (SUCCEEDED / FAILED / CANCELED / CLOSED), keeping the timeout and error semantics the old monolithic method had. Two deliberate changes: the sleep between polls now has ±25% jitter and wakes early on abort, and a client-initiated abort marks the span `cancelled` with OK status instead of ERROR. - `transformResult(response)` — projects positional `result.data_array` into name-keyed JSON rows (ARROW_STREAM responses are reduced to a statement handle with `external_links` stripped), no I/O. `executeStatement(...)` is preserved and now composes three of the four publics (`submit` → `poll` → `transform`); `getStatement` is for callers that want a single status read. No private wrapper-only helpers remain. Every error path, abort branch, and status state machine of the old method is exercised by the new per-API test suites (21 new tests against `submit` / `get` / `poll` / `transform`). Motivation (documented inline in JSDoc): durable callers — e.g. a future TaskFlow-based analytics handler — emit a `statement_submitted` event with the warehouse-side statement ID right after `submitStatement` returns, so on crash recovery they can re-attach via `pollStatement` without re-running the SQL. The TaskFlow integration itself is not in this PR. `executeStatement`'s contract is unchanged; analytics (the only external caller) keeps working without modification. The added `sql.submit` span is purely additive for OTLP collectors. 
Verified: pnpm -r typecheck, pnpm build, full pnpm test (122 files, 2276 tests) all green. Signed-off-by: ditadi --- .../src/connectors/sql-warehouse/client.ts | 331 ++++++++---- .../connectors/tests/sql-warehouse.test.ts | 499 ++++++++++++++++-- 2 files changed, 658 insertions(+), 172 deletions(-) diff --git a/packages/appkit/src/connectors/sql-warehouse/client.ts b/packages/appkit/src/connectors/sql-warehouse/client.ts index d0a1c1816..dde7047f2 100644 --- a/packages/appkit/src/connectors/sql-warehouse/client.ts +++ b/packages/appkit/src/connectors/sql-warehouse/client.ts @@ -30,14 +30,28 @@ interface SQLWarehouseConfig { telemetry?: TelemetryOptions; } +/** + * Unified shape returned by {@link SQLWarehouseConnector.transformResult}. + * Same top-level fields as {@link sql.StatementResponse}; `result.data` is + * the name-keyed projection of `result.data_array` for JSON queries. + * `result.external_links` is intentionally absent (pre-signed URLs that + * must not flow downstream). + */ +type SQLTransformedResponse = Omit & { + result?: Omit< + NonNullable, + "external_links" + > & { + data?: Record[]; + }; +}; + export class SQLWarehouseConnector { private readonly name = "sql-warehouse"; private config: SQLWarehouseConfig; - // Lazy-initialized: only created when Arrow format is used private _arrowProcessor: ArrowStreamProcessor | null = null; - // telemetry private readonly telemetry: TelemetryProvider; private readonly telemetryMetrics: { queryCount: Counter; @@ -66,21 +80,11 @@ export class SQLWarehouseConnector { } /** - * Lazily initializes and returns the ArrowStreamProcessor. - * Only created on first Arrow format query to avoid unnecessary allocation. + * Submit a statement, poll if it hasn't reached a terminal state, and + * transform the result. Callers that need to persist the warehouse-side + * `statement_id` between submission and polling can compose + * {@link submitStatement} + {@link pollStatement} directly. 
*/ - private get arrowProcessor(): ArrowStreamProcessor { - if (!this._arrowProcessor) { - this._arrowProcessor = new ArrowStreamProcessor({ - timeout: this.config.timeout || executeStatementDefaults.timeout, - maxConcurrentDownloads: - ArrowStreamProcessor.DEFAULT_MAX_CONCURRENT_DOWNLOADS, - retries: ArrowStreamProcessor.DEFAULT_RETRIES, - }); - } - return this._arrowProcessor; - } - async executeStatement( workspaceClient: WorkspaceClient, input: sql.ExecuteStatementRequest, @@ -89,7 +93,6 @@ export class SQLWarehouseConnector { const startTime = Date.now(); let success = false; - // if signal is aborted, throw an error if (signal?.aborted) { throw ExecutionError.canceled(); } @@ -113,7 +116,6 @@ export class SQLWarehouseConnector { if (signal) { abortHandler = () => { - // abort span if not recording if (!span.isRecording()) return; isAborted = true; span.setAttribute("cancelled", true); @@ -127,73 +129,39 @@ export class SQLWarehouseConnector { } try { - // validate required fields - if (!input.statement) { - throw ValidationError.missingField("statement"); - } - - if (!input.warehouse_id) { - throw ValidationError.missingField("warehouse_id"); - } - - const body: sql.ExecuteStatementRequest = { - statement: input.statement, - parameters: input.parameters, - warehouse_id: input.warehouse_id, - catalog: input.catalog, - schema: input.schema, - wait_timeout: - input.wait_timeout || executeStatementDefaults.wait_timeout, - disposition: - input.disposition || executeStatementDefaults.disposition, - format: input.format || executeStatementDefaults.format, - byte_limit: input.byte_limit, - row_limit: input.row_limit, - on_wait_timeout: - input.on_wait_timeout || executeStatementDefaults.on_wait_timeout, - }; - span.addEvent("statement.submitting", { - "db.warehouse_id": input.warehouse_id, + "db.warehouse_id": input.warehouse_id ?? 
"", }); - const response = - await workspaceClient.statementExecution.executeStatement( - body, - this._createContext(signal), - ); - - if (!response) { - throw ConnectionError.apiFailure("SQL Warehouse"); - } + const response = await this.submitStatement( + workspaceClient, + input, + signal, + ); const status = response.status; const statementId = response.statement_id as string; - span.setAttribute("db.statement_id", statementId); span.addEvent("statement.submitted", { - "db.statement_id": response.statement_id, + "db.statement_id": statementId, "db.status": status?.state, }); - let result: - | sql.StatementResponse - | { result: { statement_id: string; status: sql.StatementStatus } }; + let result: SQLTransformedResponse; switch (status?.state) { case "RUNNING": case "PENDING": span.addEvent("statement.polling_started", { - "db.status": response.status?.state, + "db.status": status?.state, }); - result = await this._pollForStatementResult( + result = await this.pollStatement( workspaceClient, statementId, - this.config.timeout, signal, ); break; case "SUCCEEDED": - result = this._transformDataArray(response); + result = this.transformResult(response); break; case "FAILED": throw ExecutionError.statementFailed(status.error?.message); @@ -207,10 +175,7 @@ export class SQLWarehouseConnector { ); } - const resultData = result.result as any; - const rowCount = - resultData?.data?.length ?? resultData?.data_array?.length ?? 0; - + const rowCount = result.result?.data?.length ?? 
0; if (rowCount > 0) { span.setAttribute("db.result.row_count", rowCount); } @@ -223,13 +188,11 @@ export class SQLWarehouseConnector { }); success = true; - // only set success status if not aborted if (!isAborted) { span.setStatus({ code: SpanStatusCode.OK }); } return result; } catch (error) { - // only record error if not already handled by abort if (!isAborted) { span.recordException(error as Error); span.setStatus({ @@ -250,14 +213,12 @@ export class SQLWarehouseConnector { error instanceof Error ? error.message : String(error), ); } finally { - // remove abort handler if (abortHandler && signal) { signal.removeEventListener("abort", abortHandler); } const duration = Date.now() - startTime; - // end span if not already ended by abort handler if (!isAborted) { span.end(); } @@ -278,12 +239,122 @@ export class SQLWarehouseConnector { ); } - private async _pollForStatementResult( + /** + * Submit a statement and return the raw initial response. May already + * be terminal if the warehouse completes within the request's + * `wait_timeout`; otherwise the caller polls via {@link pollStatement}. 
+ */ + async submitStatement( + workspaceClient: WorkspaceClient, + input: sql.ExecuteStatementRequest, + signal?: AbortSignal, + ): Promise { + if (signal?.aborted) { + throw ExecutionError.canceled(); + } + if (!input.statement) { + throw ValidationError.missingField("statement"); + } + if (!input.warehouse_id) { + throw ValidationError.missingField("warehouse_id"); + } + + const body: sql.ExecuteStatementRequest = { + statement: input.statement, + parameters: input.parameters, + warehouse_id: input.warehouse_id, + catalog: input.catalog, + schema: input.schema, + wait_timeout: input.wait_timeout || executeStatementDefaults.wait_timeout, + disposition: input.disposition || executeStatementDefaults.disposition, + format: input.format || executeStatementDefaults.format, + byte_limit: input.byte_limit, + row_limit: input.row_limit, + on_wait_timeout: + input.on_wait_timeout || executeStatementDefaults.on_wait_timeout, + }; + + return this.telemetry.startActiveSpan( + "sql.submit", + { + kind: SpanKind.CLIENT, + attributes: { + "db.system": "databricks", + "db.warehouse_id": body.warehouse_id || "", + "db.catalog": body.catalog ?? "", + "db.schema": body.schema ?? "", + "db.statement": body.statement?.substring(0, 500) || "", + "db.has_parameters": !!body.parameters, + }, + }, + async (span: Span) => { + try { + const response = + await workspaceClient.statementExecution.executeStatement( + body, + this._createContext(signal), + ); + if (!response) { + throw ConnectionError.apiFailure("SQL Warehouse"); + } + if (response.statement_id) { + span.setAttribute("db.statement_id", response.statement_id); + } + if (response.status?.state) { + span.setAttribute("db.status", response.status.state); + } + span.setStatus({ code: SpanStatusCode.OK }); + return response; + } catch (error) { + // Client-initiated cancel isn't a span error. 
+ if (signal?.aborted) { + span.setAttribute("cancelled", true); + span.setStatus({ code: SpanStatusCode.OK }); + } else { + span.recordException(error as Error); + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error instanceof Error ? error.message : String(error), + }); + } + throw error; + } finally { + span.end(); + } + }, + { name: this.name, includePrefix: true }, + ); + } + + /** Single non-blocking status read for a known statement ID. */ + async getStatement( workspaceClient: WorkspaceClient, statementId: string, - timeout = executeStatementDefaults.timeout, signal?: AbortSignal, - ) { + ): Promise { + if (signal?.aborted) { + throw ExecutionError.canceled(); + } + const response = await workspaceClient.statementExecution.getStatement( + { statement_id: statementId }, + this._createContext(signal), + ); + if (!response) { + throw ConnectionError.apiFailure("SQL Warehouse"); + } + return response; + } + + /** + * Block until the statement reaches a terminal state, then transform + * via {@link transformResult}. + */ + async pollStatement( + workspaceClient: WorkspaceClient, + statementId: string, + signal?: AbortSignal, + timeout = this.config.timeout ?? 
executeStatementDefaults.timeout, + ): Promise { return this.telemetry.startActiveSpan( "sql.poll", { @@ -296,14 +367,13 @@ export class SQLWarehouseConnector { try { const startTime = Date.now(); let delay = 1000; - const maxDelayBetweenPolls = 5000; // max 5 seconds between polls + const maxDelayBetweenPolls = 5000; let pollCount = 0; while (true) { pollCount++; span.setAttribute("db.polling.current_attempt", pollCount); - // check if timeout exceeded const elapsedTime = Date.now() - startTime; if (elapsedTime > timeout) { const error = ExecutionError.statementFailed( @@ -315,10 +385,7 @@ export class SQLWarehouseConnector { } if (signal?.aborted) { - const error = ExecutionError.canceled(); - span.recordException(error); - span.setStatus({ code: SpanStatusCode.ERROR }); - throw error; + throw ExecutionError.canceled(); } span.addEvent("polling.attempt", { @@ -329,9 +396,7 @@ export class SQLWarehouseConnector { const response = await workspaceClient.statementExecution.getStatement( - { - statement_id: statementId, - }, + { statement_id: statementId }, this._createContext(signal), ); if (!response) { @@ -339,7 +404,6 @@ export class SQLWarehouseConnector { } const status = response.status; - span.addEvent("polling.status_check", { "db.status": status?.state, "poll.attempt": pollCount, @@ -348,7 +412,6 @@ export class SQLWarehouseConnector { switch (status?.state) { case "PENDING": case "RUNNING": - // continue polling break; case "SUCCEEDED": span.setAttribute("db.polling.attempts", pollCount); @@ -358,7 +421,7 @@ export class SQLWarehouseConnector { "poll.duration_ms": elapsedTime, }); span.setStatus({ code: SpanStatusCode.OK }); - return this._transformDataArray(response); + return this.transformResult(response); case "FAILED": throw ExecutionError.statementFailed(status.error?.message); case "CANCELED": @@ -371,18 +434,41 @@ export class SQLWarehouseConnector { ); } - // continue polling after delay - await new Promise((resolve) => setTimeout(resolve, delay)); 
+ // ±25% jitter de-syncs concurrent pollers. + const jitterMs = Math.floor(delay * (Math.random() - 0.5) * 0.5); + const sleepMs = Math.max(0, delay + jitterMs); + await new Promise((resolve) => { + if (sleepMs <= 0) { + resolve(); + return; + } + const handle = setTimeout(() => { + signal?.removeEventListener("abort", onAbort); + resolve(); + }, sleepMs); + const onAbort = () => { + clearTimeout(handle); + resolve(); + }; + signal?.addEventListener("abort", onAbort, { once: true }); + }); + if (signal?.aborted) { + throw ExecutionError.canceled(); + } delay = Math.min(delay * 2, maxDelayBetweenPolls); } } catch (error) { - span.recordException(error as Error); - span.setStatus({ - code: SpanStatusCode.ERROR, - message: error instanceof Error ? error.message : String(error), - }); - - // error logging is handled by executeStatement's catch block (gated on isAborted) + // Logging is handled by the caller. + if (signal?.aborted) { + span.setAttribute("cancelled", true); + span.setStatus({ code: SpanStatusCode.OK }); + } else { + span.recordException(error as Error); + span.setStatus({ + code: SpanStatusCode.ERROR, + message: error instanceof Error ? error.message : String(error), + }); + } if (error instanceof AppKitError) { throw error; } @@ -397,24 +483,40 @@ export class SQLWarehouseConnector { ); } - private _transformDataArray(response: sql.StatementResponse) { + /** + * Standard result transform. Returns the same shape in every branch + * (see {@link SQLTransformedResponse}): + * - ARROW_STREAM: top-level `statement_id`/`status` preserved; `manifest` + * and `result.external_links` stripped (pre-signed URLs must not flow + * downstream). Consumer fetches the Arrow buffer via + * {@link getArrowData}. + * - JSON with rows + schema: positional `result.data_array` projected + * into name-keyed `result.data` (JSON-looking STRING values parsed). + * - Otherwise: pass-through. 
+ */ + transformResult(response: sql.StatementResponse): SQLTransformedResponse { if (response.manifest?.format === "ARROW_STREAM") { - return this.updateWithArrowStatus(response); + return { + ...response, + manifest: undefined, + result: { + statement_id: response.statement_id, + status: response.status, + } as SQLTransformedResponse["result"], + }; } if (!response.result?.data_array || !response.manifest?.schema?.columns) { - return response; + return response as SQLTransformedResponse; } const columns = response.manifest.schema.columns; - const transformedData = response.result.data_array.map((row) => { const obj: Record = {}; row.forEach((value, index) => { const column = columns[index]; const columnName = column?.name || `column_${index}`; - // attempt to parse JSON strings for string columns if ( column?.type_name === "STRING" && typeof value === "string" && @@ -424,7 +526,6 @@ export class SQLWarehouseConnector { try { obj[columnName] = JSON.parse(value); } catch { - // if parsing fails, keep as string obj[columnName] = value; } } else { @@ -434,7 +535,6 @@ export class SQLWarehouseConnector { return obj; }); - // remove data_array const { data_array: _data_array, ...restResult } = response.result; return { ...response, @@ -445,20 +545,6 @@ export class SQLWarehouseConnector { }; } - private updateWithArrowStatus(response: sql.StatementResponse): { - result: { statement_id: string; status: sql.StatementStatus }; - } { - return { - result: { - statement_id: response.statement_id as string, - status: { - state: response.status?.state, - error: response.status?.error, - } as sql.StatementStatus, - }, - }; - } - async getArrowData( workspaceClient: WorkspaceClient, jobId: string, @@ -539,7 +625,18 @@ export class SQLWarehouseConnector { ); } - // create context for cancellation token + private get arrowProcessor(): ArrowStreamProcessor { + if (!this._arrowProcessor) { + this._arrowProcessor = new ArrowStreamProcessor({ + timeout: this.config.timeout || 
executeStatementDefaults.timeout, + maxConcurrentDownloads: + ArrowStreamProcessor.DEFAULT_MAX_CONCURRENT_DOWNLOADS, + retries: ArrowStreamProcessor.DEFAULT_RETRIES, + }); + } + return this._arrowProcessor; + } + private _createContext(signal?: AbortSignal) { return new Context({ cancellationToken: { diff --git a/packages/appkit/src/connectors/tests/sql-warehouse.test.ts b/packages/appkit/src/connectors/tests/sql-warehouse.test.ts index 753d58636..b3943779f 100644 --- a/packages/appkit/src/connectors/tests/sql-warehouse.test.ts +++ b/packages/appkit/src/connectors/tests/sql-warehouse.test.ts @@ -1,9 +1,16 @@ +import { + createFailedSQLResponse, + createSuccessfulSQLResponse, +} from "@tools/test-helpers"; import { beforeEach, describe, expect, test, vi } from "vitest"; import { SQLWarehouseConnector } from "../sql-warehouse"; -// Mock telemetry to pass through span callbacks -vi.mock("../../telemetry", () => { - const mockSpan = { +// Pass-through telemetry stub: invokes the span callback with a no-op span. +// `mockSpan` is shared across all spans opened in a test (sql.query, +// sql.submit, sql.poll), so call counts accumulate; assert on call args +// rather than counts when verifying that a specific event fired. 
+const { mockSpan } = vi.hoisted(() => ({ + mockSpan: { end: vi.fn(), setAttribute: vi.fn(), setAttributes: vi.fn(), @@ -12,53 +19,446 @@ vi.mock("../../telemetry", () => { addEvent: vi.fn(), isRecording: vi.fn().mockReturnValue(true), spanContext: vi.fn(), - }; + }, +})); - return { - TelemetryManager: { - getProvider: vi.fn(() => ({ - startActiveSpan: vi - .fn() - .mockImplementation(async (_name, _options, fn) => { - return await fn(mockSpan); - }), - getMeter: vi.fn().mockReturnValue({ - createCounter: vi.fn().mockReturnValue({ add: vi.fn() }), - createHistogram: vi.fn().mockReturnValue({ record: vi.fn() }), +vi.mock("../../telemetry", () => ({ + TelemetryManager: { + getProvider: vi.fn(() => ({ + startActiveSpan: vi + .fn() + .mockImplementation(async (_name, _options, fn) => { + return await fn(mockSpan); }), - })), - }, - SpanKind: { CLIENT: 2 }, - SpanStatusCode: { OK: 1, ERROR: 2 }, + getMeter: vi.fn().mockReturnValue({ + createCounter: vi.fn().mockReturnValue({ add: vi.fn() }), + createHistogram: vi.fn().mockReturnValue({ record: vi.fn() }), + }), + })), + }, + SpanKind: { CLIENT: 2 }, + SpanStatusCode: { OK: 1, ERROR: 2 }, +})); + +/** Minimal `WorkspaceClient` stub with `executeStatement` / `getStatement` mocks. 
*/ +function makeClient() { + const executeStatement = vi.fn(); + const getStatement = vi.fn(); + return { + client: { + statementExecution: { executeStatement, getStatement }, + config: { host: "https://test.databricks.com" }, + } as any, + mocks: { executeStatement, getStatement }, }; -}); +} describe("SQLWarehouseConnector", () => { - describe("error log redaction", () => { - let connector: SQLWarehouseConnector; + let connector: SQLWarehouseConnector; + + beforeEach(() => { + vi.clearAllMocks(); + connector = new SQLWarehouseConnector({ timeout: 5000 }); + }); + + describe("submitStatement", () => { + test("rejects when the statement is missing", async () => { + const { client } = makeClient(); + await expect( + connector.submitStatement(client, { + statement: "", + warehouse_id: "w-1", + }), + ).rejects.toThrow(/statement/); + }); + + test("rejects when the warehouse_id is missing", async () => { + const { client } = makeClient(); + await expect( + connector.submitStatement(client, { + statement: "SELECT 1", + warehouse_id: "", + }), + ).rejects.toThrow(/warehouse_id/); + }); + + test("rejects when the signal is already aborted", async () => { + const { client } = makeClient(); + const ac = new AbortController(); + ac.abort(); + await expect( + connector.submitStatement( + client, + { statement: "SELECT 1", warehouse_id: "w-1" }, + ac.signal, + ), + ).rejects.toThrow(); + }); + + test("returns the raw response on success without polling", async () => { + const { client, mocks } = makeClient(); + const response = createSuccessfulSQLResponse([["a"]], [{ name: "col" }]); + mocks.executeStatement.mockResolvedValueOnce(response); - beforeEach(() => { - vi.clearAllMocks(); - connector = new SQLWarehouseConnector({ timeout: 5000 }); + const result = await connector.submitStatement(client, { + statement: "SELECT 1", + warehouse_id: "w-1", + }); + + expect(result).toBe(response); + expect(mocks.executeStatement).toHaveBeenCalledTimes(1); + 
expect(mocks.getStatement).not.toHaveBeenCalled(); }); - test("should not log the SQL statement on executeStatement error", async () => { + test("propagates a null response as a SQL Warehouse api failure", async () => { + const { client, mocks } = makeClient(); + mocks.executeStatement.mockResolvedValueOnce(null); + + await expect( + connector.submitStatement(client, { + statement: "SELECT 1", + warehouse_id: "w-1", + }), + ).rejects.toThrow(/SQL Warehouse/); + }); + }); + + describe("getStatement", () => { + test("rejects when the signal is already aborted", async () => { + const { client } = makeClient(); + const ac = new AbortController(); + ac.abort(); + await expect( + connector.getStatement(client, "stmt-1", ac.signal), + ).rejects.toThrow(); + }); + + test("returns the raw response", async () => { + const { client, mocks } = makeClient(); + const response = createSuccessfulSQLResponse([["x"]], [{ name: "col" }]); + mocks.getStatement.mockResolvedValueOnce(response); + + const result = await connector.getStatement(client, "stmt-1"); + expect(result).toBe(response); + expect(mocks.getStatement).toHaveBeenCalledWith( + { statement_id: "stmt-1" }, + expect.anything(), + ); + }); + + test("rejects when the response is null", async () => { + const { client, mocks } = makeClient(); + mocks.getStatement.mockResolvedValueOnce(null); + + await expect(connector.getStatement(client, "stmt-1")).rejects.toThrow( + /SQL Warehouse/, + ); + }); + }); + + describe("pollStatement", () => { + test("returns transformed result when status is SUCCEEDED on first poll", async () => { + const { client, mocks } = makeClient(); + mocks.getStatement.mockResolvedValueOnce( + createSuccessfulSQLResponse( + [["alice", "30"]], + [{ name: "name" }, { name: "age" }], + ), + ); + + const result = await connector.pollStatement(client, "stmt-1"); + expect((result as any).result.data).toEqual([ + { name: "alice", age: "30" }, + ]); + }); + + test("throws statementFailed when status is FAILED", 
async () => { + const { client, mocks } = makeClient(); + mocks.getStatement.mockResolvedValueOnce( + createFailedSQLResponse("Table not found"), + ); + + await expect(connector.pollStatement(client, "stmt-1")).rejects.toThrow( + /Table not found/, + ); + }); + + test("throws canceled when status is CANCELED", async () => { + const { client, mocks } = makeClient(); + mocks.getStatement.mockResolvedValueOnce({ + status: { state: "CANCELED" }, + statement_id: "stmt-1", + }); + + await expect(connector.pollStatement(client, "stmt-1")).rejects.toThrow(); + }); + + test("throws when the polling timeout is exceeded", async () => { + // timeout: 0 trips the elapsed-time check on the second iteration. + const tight = new SQLWarehouseConnector({ timeout: 0 }); + const { client, mocks } = makeClient(); + mocks.getStatement.mockResolvedValue({ + status: { state: "RUNNING" }, + statement_id: "stmt-1", + }); + + await expect( + tight.pollStatement(client, "stmt-1", undefined, 0), + ).rejects.toThrow(/Polling timeout exceeded/); + }); + + test("throws when the signal aborts during polling", async () => { + const { client, mocks } = makeClient(); + mocks.getStatement.mockResolvedValueOnce({ + status: { state: "RUNNING" }, + statement_id: "stmt-1", + }); + + const ac = new AbortController(); + ac.abort(); + + await expect( + connector.pollStatement(client, "stmt-1", ac.signal), + ).rejects.toThrow(); + }); + }); + + describe("transformResult", () => { + test("projects data_array into name-keyed rows", () => { + const response = createSuccessfulSQLResponse( + [ + ["alice", "30"], + ["bob", "25"], + ], + [{ name: "name" }, { name: "age" }], + ); + + const result = connector.transformResult(response as any) as any; + expect(result.result.data).toEqual([ + { name: "alice", age: "30" }, + { name: "bob", age: "25" }, + ]); + expect(result.result.data_array).toBeUndefined(); + }); + + test("parses STRING columns whose value looks like JSON", () => { + const response = 
createSuccessfulSQLResponse( + [['{"a":1}']], + [{ name: "payload", type_name: "STRING" }], + ); + + const result = connector.transformResult(response as any) as any; + expect(result.result.data[0].payload).toEqual({ a: 1 }); + }); + + test("keeps the raw string when JSON parsing fails", () => { + const response = createSuccessfulSQLResponse( + [["{not-json"]], + [{ name: "payload", type_name: "STRING" }], + ); + + const result = connector.transformResult(response as any) as any; + expect(result.result.data[0].payload).toBe("{not-json"); + }); + + test("preserves top-level fields and strips external_links for ARROW_STREAM", () => { + const response = { + status: { state: "SUCCEEDED" }, + statement_id: "stmt-arrow-1", + manifest: { format: "ARROW_STREAM" }, + result: { external_links: [{ external_link: "https://signed" }] }, + } as any; + + const result = connector.transformResult(response) as any; + // top-level statement_id/status preserved so consumers don't need + // to know which branch ran. + expect(result.statement_id).toBe("stmt-arrow-1"); + expect(result.status).toEqual({ state: "SUCCEEDED" }); + // manifest dropped (large; not needed by Arrow consumers). + expect(result.manifest).toBeUndefined(); + // result mirrors the handle but never exposes external_links. 
+ expect(result.result.statement_id).toBe("stmt-arrow-1"); + expect(result.result.status).toEqual({ state: "SUCCEEDED" }); + expect(result.result.external_links).toBeUndefined(); + }); + + test("passes the response through when there is no data_array", () => { + const response = { + status: { state: "SUCCEEDED" }, + statement_id: "stmt-1", + } as any; + + const result = connector.transformResult(response); + expect(result).toBe(response); + }); + }); + + describe("executeStatement", () => { + test("transforms inline when submit returns SUCCEEDED", async () => { + const { client, mocks } = makeClient(); + mocks.executeStatement.mockResolvedValueOnce( + createSuccessfulSQLResponse([["a"]], [{ name: "col" }]), + ); + + const result = (await connector.executeStatement(client, { + statement: "SELECT 'a' AS col", + warehouse_id: "w-1", + })) as any; + + expect(result.result.data).toEqual([{ col: "a" }]); + expect(mocks.getStatement).not.toHaveBeenCalled(); + }); + + test("polls when submit returns RUNNING and returns the polled result", async () => { + const { client, mocks } = makeClient(); + mocks.executeStatement.mockResolvedValueOnce({ + status: { state: "RUNNING" }, + statement_id: "stmt-2", + }); + mocks.getStatement.mockResolvedValueOnce( + createSuccessfulSQLResponse([["b"]], [{ name: "col" }]), + ); + + const result = (await connector.executeStatement(client, { + statement: "SELECT 'b' AS col", + warehouse_id: "w-1", + })) as any; + + expect(result.result.data).toEqual([{ col: "b" }]); + expect(mocks.executeStatement).toHaveBeenCalledTimes(1); + expect(mocks.getStatement).toHaveBeenCalledTimes(1); + }); + + test("rejects when the signal is already aborted", async () => { + const { client } = makeClient(); + const ac = new AbortController(); + ac.abort(); + await expect( + connector.executeStatement( + client, + { statement: "SELECT 1", warehouse_id: "w-1" }, + ac.signal, + ), + ).rejects.toThrow(); + }); + }); + + describe("span events", () => { + test("emits 
statement.submitting and statement.submitted on submit path", async () => { + const { client, mocks } = makeClient(); + mocks.executeStatement.mockResolvedValueOnce( + createSuccessfulSQLResponse([["a"]], [{ name: "col" }]), + ); + + await connector.executeStatement(client, { + statement: "SELECT 'a' AS col", + warehouse_id: "w-1", + }); + + const eventNames = mockSpan.addEvent.mock.calls.map((c) => c[0]); + expect(eventNames).toContain("statement.submitting"); + expect(eventNames).toContain("statement.submitted"); + }); + + test("emits statement.polling_started when submit returns RUNNING", async () => { + const { client, mocks } = makeClient(); + mocks.executeStatement.mockResolvedValueOnce({ + status: { state: "RUNNING" }, + statement_id: "stmt-1", + }); + mocks.getStatement.mockResolvedValueOnce( + createSuccessfulSQLResponse([["b"]], [{ name: "col" }]), + ); + + await connector.executeStatement(client, { + statement: "SELECT 'b' AS col", + warehouse_id: "w-1", + }); + + const eventNames = mockSpan.addEvent.mock.calls.map((c) => c[0]); + expect(eventNames).toContain("statement.polling_started"); + }); + + test("does not emit polling_started when submit already returned SUCCEEDED", async () => { + const { client, mocks } = makeClient(); + mocks.executeStatement.mockResolvedValueOnce( + createSuccessfulSQLResponse([["a"]], [{ name: "col" }]), + ); + + await connector.executeStatement(client, { + statement: "SELECT 'a' AS col", + warehouse_id: "w-1", + }); + + const eventNames = mockSpan.addEvent.mock.calls.map((c) => c[0]); + expect(eventNames).not.toContain("statement.polling_started"); + expect(mocks.getStatement).not.toHaveBeenCalled(); + }); + }); + + describe("abort span semantics", () => { + test("submitStatement abort marks span OK + cancelled, not ERROR", async () => { + const { client, mocks } = makeClient(); + const ac = new AbortController(); + // Abort triggers from inside the SDK call — mimics warehouse SDK + // surfacing the cancellation token. 
+ mocks.executeStatement.mockImplementationOnce(async () => { + ac.abort(); + throw new Error("cancelled by token"); + }); + + await expect( + connector.submitStatement( + client, + { statement: "SELECT 1", warehouse_id: "w-1" }, + ac.signal, + ), + ).rejects.toThrow(); + + const cancelledCall = mockSpan.setAttribute.mock.calls.find( + (c) => c[0] === "cancelled" && c[1] === true, + ); + expect(cancelledCall).toBeTruthy(); + // OK = 1 in the stubbed SpanStatusCode. + const okStatus = mockSpan.setStatus.mock.calls.find( + (c) => c[0].code === 1, + ); + expect(okStatus).toBeTruthy(); + }); + + test("pollStatement abort marks span OK + cancelled, not ERROR", async () => { + const { client, mocks } = makeClient(); + const ac = new AbortController(); + mocks.getStatement.mockImplementationOnce(async () => { + ac.abort(); + return { status: { state: "RUNNING" }, statement_id: "stmt-1" }; + }); + + await expect( + connector.pollStatement(client, "stmt-1", ac.signal), + ).rejects.toThrow(); + + const cancelledCall = mockSpan.setAttribute.mock.calls.find( + (c) => c[0] === "cancelled" && c[1] === true, + ); + expect(cancelledCall).toBeTruthy(); + }); + }); + + describe("error log redaction", () => { + test("does not log the SQL statement on executeStatement error", async () => { const errorSpy = vi.spyOn(console, "error").mockImplementation(() => {}); const sensitiveStatement = "SELECT password, ssn FROM users WHERE email = 'admin@test.com'"; - const mockWorkspaceClient = { - statementExecution: { - executeStatement: vi - .fn() - .mockRejectedValue(new Error("warehouse unavailable")), - }, - config: { host: "https://test.databricks.com" }, - }; + const { client, mocks } = makeClient(); + mocks.executeStatement.mockRejectedValue( + new Error("warehouse unavailable"), + ); await expect( - connector.executeStatement(mockWorkspaceClient as any, { + connector.executeStatement(client, { statement: sensitiveStatement, warehouse_id: "test-warehouse", }), @@ -68,10 +468,7 @@ 
describe("SQLWarehouseConnector", () => { .map((call) => call.join(" ")) .join(" "); - // Should log the error message expect(loggedOutput).toContain("warehouse unavailable"); - - // Should NOT log the SQL statement expect(loggedOutput).not.toContain("password"); expect(loggedOutput).not.toContain("ssn"); expect(loggedOutput).not.toContain("admin@test.com"); @@ -79,22 +476,18 @@ describe("SQLWarehouseConnector", () => { errorSpy.mockRestore(); }); - test("should not log the SQL statement on polling error", async () => { + test("does not log the SQL statement on polling error", async () => { const errorSpy = vi.spyOn(console, "error").mockImplementation(() => {}); - const mockWorkspaceClient = { - statementExecution: { - executeStatement: vi.fn().mockResolvedValue({ - statement_id: "stmt-123", - status: { state: "RUNNING" }, - }), - getStatement: vi.fn().mockRejectedValue(new Error("polling timeout")), - }, - config: { host: "https://test.databricks.com" }, - }; + const { client, mocks } = makeClient(); + mocks.executeStatement.mockResolvedValue({ + statement_id: "stmt-123", + status: { state: "RUNNING" }, + }); + mocks.getStatement.mockRejectedValue(new Error("polling timeout")); await expect( - connector.executeStatement(mockWorkspaceClient as any, { + connector.executeStatement(client, { statement: "SELECT secret_data FROM vault", warehouse_id: "test-warehouse", }), @@ -104,12 +497,8 @@ describe("SQLWarehouseConnector", () => { .map((call) => call.join(" ")) .join(" "); - // Errors raised inside polling bubble up to executeStatement's catch, - // which is the single point that logs (gated on isAborted). The poll - // layer no longer logs to avoid double-logging the same failure. + // Polling errors bubble to executeStatement's catch — the single point that logs. expect(loggedOutput).toContain("polling timeout"); - - // Should NOT log the SQL statement expect(loggedOutput).not.toContain("secret_data"); expect(loggedOutput).not.toContain("vault");