diff --git a/README.md b/README.md index 5d4b0d5..22aa1f9 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,28 @@ async def main(): asyncio.run(main()) ``` +### Gas Sponsorship Modes + +By default, write operations continue to use the existing gas station flow based on +`config.gas_station_api_key` / `config.gas_station_url`. No configuration changes are required. + +To use a local fee payer account instead of gas station signing, pass +`fee_payer_account` in `BaseSDKOptions`: + +```python +from aptos_sdk.account import Account +from decibel import BaseSDKOptions, DecibelWriteDex, TESTNET_CONFIG + +sender = Account.generate() +fee_payer = Account.generate() + +write = DecibelWriteDex( + TESTNET_CONFIG, + sender, + opts=BaseSDKOptions(fee_payer_account=fee_payer), +) +``` + ### WebSocket Streaming ```python diff --git a/src/decibel/_base.py b/src/decibel/_base.py index 9685fa2..9482d91 100644 --- a/src/decibel/_base.py +++ b/src/decibel/_base.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, cast import httpx +from aptos_sdk.account_address import AccountAddress from aptos_sdk.async_client import RestClient from aptos_sdk.authenticator import ( AccountAuthenticator, @@ -31,7 +32,6 @@ if TYPE_CHECKING: from aptos_sdk.account import Account - from aptos_sdk.account_address import AccountAddress from ._constants import DecibelConfig from ._gas_price_manager import GasPriceManager, GasPriceManagerSync @@ -49,12 +49,14 @@ DEFAULT_MAX_GAS_AMOUNT = 200_000 DEFAULT_GAS_ESTIMATE = 100 MAX_GAS_UNITS_LIMIT = 2_000_000 +FEE_PAYER_PLACEHOLDER_ADDRESS = AccountAddress.from_str("0x0") @dataclass class BaseSDKOptions: skip_simulate: bool = False no_fee_payer: bool = False + fee_payer_account: Account | None = None node_api_key: str | None = None gas_price_manager: GasPriceManager | None = None time_delta_ms: int = 0 @@ -64,6 +66,7 @@ class BaseSDKOptions: class BaseSDKOptionsSync: skip_simulate: bool = False no_fee_payer: bool = False + fee_payer_account: Account | None = None node_api_key: str | None = None gas_price_manager: GasPriceManagerSync | None = None time_delta_ms: int = 0 @@ -86,10 +89,14 @@ def __init__( opts = opts or BaseSDKOptions() self._skip_simulate = opts.skip_simulate self._no_fee_payer = opts.no_fee_payer + self._fee_payer_account = opts.fee_payer_account self._node_api_key = opts.node_api_key self._gas_price_manager = opts.gas_price_manager self._time_delta_ms = opts.time_delta_ms + if self._no_fee_payer and self._fee_payer_account is not None: + raise ValueError("no_fee_payer and fee_payer_account cannot be used together") + if config.chain_id is None: logger.warning( "Using default ABI for unknown chain_id, " @@ -156,7 +163,7 @@ async def build_tx( else: gas_unit_price = await self._fetch_gas_price_estimation() - return build_simple_transaction_sync( + transaction = build_simple_transaction_sync( sender=sender, data=data, chain_id=self._chain_id, @@ -167,6 +174,8 @@ async def build_tx( time_delta_ms=self._time_delta_ms, max_gas_amount=max_gas_amount or DEFAULT_MAX_GAS_AMOUNT, ) + self._apply_fee_payer_address_override(transaction) + return transaction async def submit_tx( self, @@ -175,12 +184,16 @@ async def submit_tx( *, txn_submit_timeout: float | None = None, ) -> PendingTransactionResponse: + self._validate_fee_payer_address(transaction) + if self._no_fee_payer: return await self._submit_direct(transaction, sender_authenticator, txn_submit_timeout) return await submit_fee_paid_transaction( self._config, transaction, sender_authenticator, + fee_payer_account=self._fee_payer_account, + node_api_key=self._node_api_key, txn_submit_timeout=txn_submit_timeout, ) @@ -273,6 +286,40 @@ def _sign_transaction( else: return raw_txn.sign(signer.private_key) + def _apply_fee_payer_address_override(self, transaction: SimpleTransaction) -> None: + if self._fee_payer_account is None: + return + + expected_fee_payer = self._fee_payer_account.address() + current_fee_payer = transaction.fee_payer_address + + if current_fee_payer == expected_fee_payer: + return + + if current_fee_payer == FEE_PAYER_PLACEHOLDER_ADDRESS: + transaction.fee_payer_address = expected_fee_payer + return + + if current_fee_payer is None: + raise ValueError( + "transaction.fee_payer_address must be set when fee_payer_account is used" + ) + + raise ValueError("transaction.fee_payer_address does not match fee_payer_account") + + def _validate_fee_payer_address(self, transaction: SimpleTransaction) -> None: + if self._fee_payer_account is None: + return + + expected_fee_payer = self._fee_payer_account.address() + if transaction.fee_payer_address is None: + raise ValueError( + "transaction.fee_payer_address must be set when fee_payer_account is used" + ) + + if transaction.fee_payer_address != expected_fee_payer: + raise ValueError("transaction.fee_payer_address does not match fee_payer_account") + async def _fetch_gas_price_estimation(self) -> int: url = f"{self._config.fullnode_url}/estimate_gas_price" headers = self._build_node_headers() @@ -458,11 +505,15 @@ def __init__( opts = opts or BaseSDKOptionsSync() self._skip_simulate = opts.skip_simulate self._no_fee_payer = opts.no_fee_payer + self._fee_payer_account = opts.fee_payer_account self._node_api_key = opts.node_api_key self._gas_price_manager = opts.gas_price_manager self._time_delta_ms = opts.time_delta_ms self._http_client = opts.http_client + if self._no_fee_payer and self._fee_payer_account is not None: + raise ValueError("no_fee_payer and fee_payer_account cannot be used together") + if config.chain_id is None: logger.warning( "Using default ABI for unknown chain_id, " @@ -525,7 +576,7 @@ def build_tx( else: gas_unit_price = self._fetch_gas_price_estimation() - return build_simple_transaction_sync( + transaction = build_simple_transaction_sync( sender=sender, data=data, chain_id=self._chain_id, @@ -536,6 +587,8 @@ def build_tx( time_delta_ms=self._time_delta_ms, max_gas_amount=max_gas_amount or DEFAULT_MAX_GAS_AMOUNT, ) + self._apply_fee_payer_address_override(transaction) + return transaction def submit_tx( self, @@ -544,12 +597,16 @@ def submit_tx( *, txn_submit_timeout: float | None = None, ) -> PendingTransactionResponse: + self._validate_fee_payer_address(transaction) + if self._no_fee_payer: return self._submit_direct(transaction, sender_authenticator, txn_submit_timeout) return submit_fee_paid_transaction_sync( self._config, transaction, sender_authenticator, + fee_payer_account=self._fee_payer_account, + node_api_key=self._node_api_key, txn_submit_timeout=txn_submit_timeout, ) @@ -640,6 +697,40 @@ def _sign_transaction( else: return raw_txn.sign(signer.private_key) + def _apply_fee_payer_address_override(self, transaction: SimpleTransaction) -> None: + if self._fee_payer_account is None: + return + + expected_fee_payer = self._fee_payer_account.address() + current_fee_payer = transaction.fee_payer_address + + if current_fee_payer == expected_fee_payer: + return + + if current_fee_payer == FEE_PAYER_PLACEHOLDER_ADDRESS: + transaction.fee_payer_address = expected_fee_payer + return + + if current_fee_payer is None: + raise ValueError( + "transaction.fee_payer_address must be set when fee_payer_account is used" + ) + + raise ValueError("transaction.fee_payer_address does not match fee_payer_account") + + def _validate_fee_payer_address(self, transaction: SimpleTransaction) -> None: + if self._fee_payer_account is None: + return + + expected_fee_payer = self._fee_payer_account.address() + if transaction.fee_payer_address is None: + raise ValueError( + "transaction.fee_payer_address must be set when fee_payer_account is used" + ) + + if transaction.fee_payer_address != expected_fee_payer: + raise ValueError("transaction.fee_payer_address does not match fee_payer_account") + def _fetch_gas_price_estimation(self) -> int: url = f"{self._config.fullnode_url}/estimate_gas_price" headers = self._build_node_headers() diff --git a/src/decibel/_fee_pay.py b/src/decibel/_fee_pay.py index 9169400..f6690a7 100644 --- a/src/decibel/_fee_pay.py +++ b/src/decibel/_fee_pay.py @@ -3,10 +3,13 @@ from typing import TYPE_CHECKING, Any, cast import httpx +from aptos_sdk.authenticator import Authenticator, FeePayerAuthenticator from aptos_sdk.bcs import Serializer +from aptos_sdk.transactions import FeePayerRawTransaction, SignedTransaction from pydantic import BaseModel if TYPE_CHECKING: + from aptos_sdk.account import Account from aptos_sdk.authenticator import AccountAuthenticator from ._constants import DecibelConfig @@ -33,9 +36,22 @@ async def submit_fee_paid_transaction( transaction: SimpleTransaction, sender_authenticator: AccountAuthenticator, *, + fee_payer_account: Account | None = None, + node_api_key: str | None = None, client: httpx.AsyncClient | None = None, txn_submit_timeout: float | None = None, ) -> PendingTransactionResponse: + if fee_payer_account is not None: + return await _submit_via_local_fee_payer( + config, + transaction, + sender_authenticator, + fee_payer_account=fee_payer_account, + node_api_key=node_api_key, + client=client, + txn_submit_timeout=txn_submit_timeout, + ) + if config.gas_station_api_key: return await _submit_via_gas_station_api( config, @@ -62,9 +78,22 @@ def submit_fee_paid_transaction_sync( transaction: SimpleTransaction, sender_authenticator: AccountAuthenticator, *, + fee_payer_account: Account | None = None, + node_api_key: str | None = None, client: httpx.Client | None = None, txn_submit_timeout: float | None = None, ) -> PendingTransactionResponse: + if fee_payer_account is not None: + return _submit_via_local_fee_payer_sync( + config, + transaction, + sender_authenticator, + fee_payer_account=fee_payer_account, + node_api_key=node_api_key, + client=client, + txn_submit_timeout=txn_submit_timeout, + ) + if config.gas_station_api_key: return _submit_via_gas_station_api_sync( config, @@ -302,6 +331,141 @@ def _submit_via_legacy_fee_payer_sync( ) +async def _submit_via_local_fee_payer( + config: DecibelConfig, + transaction: SimpleTransaction, + sender_authenticator: AccountAuthenticator, + *, + fee_payer_account: Account, + node_api_key: str | None = None, + client: httpx.AsyncClient | None = None, + txn_submit_timeout: float | None = None, +) -> PendingTransactionResponse: + url = f"{config.fullnode_url}/transactions" + headers = {"Content-Type": "application/x.aptos.signed_transaction+bcs"} + if node_api_key: + headers["x-api-key"] = node_api_key + bcs_bytes = _build_fee_payer_signed_transaction_bytes( + transaction, + sender_authenticator, + fee_payer_account, + ) + + if client is not None: + response = await client.post( + url, + content=bcs_bytes, + headers=headers, + timeout=txn_submit_timeout, + ) + else: + async with httpx.AsyncClient() as temp_client: + response = await temp_client.post( + url, + content=bcs_bytes, + headers=headers, + timeout=txn_submit_timeout, + ) + + if not response.is_success: + raise ValueError( + f"Local fee payer submission failed: {response.status_code} - {response.text}" + ) + + data = cast("dict[str, Any]", response.json()) + raw_txn = transaction.raw_transaction + return PendingTransactionResponse( + hash=str(data.get("hash", "")), + sender=str(raw_txn.sender), + sequence_number=str(raw_txn.sequence_number), + max_gas_amount=str(raw_txn.max_gas_amount), + gas_unit_price=str(raw_txn.gas_unit_price), + expiration_timestamp_secs=str(raw_txn.expiration_timestamps_secs), + ) + + +def _submit_via_local_fee_payer_sync( + config: DecibelConfig, + transaction: SimpleTransaction, + sender_authenticator: AccountAuthenticator, + *, + fee_payer_account: Account, + node_api_key: str | None = None, + client: httpx.Client | None = None, + txn_submit_timeout: float | None = None, +) -> PendingTransactionResponse: + url = f"{config.fullnode_url}/transactions" + headers = {"Content-Type": "application/x.aptos.signed_transaction+bcs"} + if node_api_key: + headers["x-api-key"] = node_api_key + bcs_bytes = _build_fee_payer_signed_transaction_bytes( + transaction, + sender_authenticator, + fee_payer_account, + ) + + if client is not None: + response = client.post( + url, + content=bcs_bytes, + headers=headers, + timeout=txn_submit_timeout, + ) + else: + with httpx.Client() as temp_client: + response = temp_client.post( + url, + content=bcs_bytes, + headers=headers, + timeout=txn_submit_timeout, + ) + + if not response.is_success: + raise ValueError( + f"Local fee payer submission failed: {response.status_code} - {response.text}" + ) + + data = cast("dict[str, Any]", response.json()) + raw_txn = transaction.raw_transaction + return PendingTransactionResponse( + hash=str(data.get("hash", "")), + sender=str(raw_txn.sender), + sequence_number=str(raw_txn.sequence_number), + max_gas_amount=str(raw_txn.max_gas_amount), + gas_unit_price=str(raw_txn.gas_unit_price), + expiration_timestamp_secs=str(raw_txn.expiration_timestamps_secs), + ) + + +def _build_fee_payer_signed_transaction_bytes( + transaction: SimpleTransaction, + sender_authenticator: AccountAuthenticator, + fee_payer_account: Account, +) -> bytes: + if transaction.fee_payer_address is None: + raise ValueError("transaction.fee_payer_address must be set for local fee payer submission") + + fee_payer_address = fee_payer_account.address() + if transaction.fee_payer_address != fee_payer_address: + raise ValueError("transaction.fee_payer_address does not match fee_payer_account") + + fee_payer_raw_txn = FeePayerRawTransaction( + raw_transaction=transaction.raw_transaction, + secondary_signers=[], + fee_payer=fee_payer_address, + ) + fee_payer_authenticator = fee_payer_raw_txn.sign(fee_payer_account.private_key) + + authenticator = Authenticator( + FeePayerAuthenticator( + sender=sender_authenticator, + secondary_signers=[], + fee_payer=(fee_payer_address, fee_payer_authenticator), + ) + ) + return SignedTransaction(transaction.raw_transaction, authenticator).bytes() + + def _get_default_gas_station_url(config: DecibelConfig) -> str: from ._constants import Network diff --git a/tests/test_sponsorship.py b/tests/test_sponsorship.py new file mode 100644 index 0000000..c545f24 --- /dev/null +++ b/tests/test_sponsorship.py @@ -0,0 +1,660 @@ +from __future__ import annotations + +from dataclasses import replace +from types import SimpleNamespace +from typing import Any + +import pytest +from aptos_sdk.account import Account +from aptos_sdk.account_address import AccountAddress + +import decibel._base as base_module +import decibel._fee_pay as fee_pay_module +from decibel._base import BaseSDK, BaseSDKOptions, BaseSDKOptionsSync, BaseSDKSync +from decibel._constants import TESTNET_CONFIG +from decibel._fee_pay import ( + PendingTransactionResponse, + submit_fee_paid_transaction, + submit_fee_paid_transaction_sync, +) +from decibel._transaction_builder import InputEntryFunctionData + + +class FakeResponse: + def __init__( + self, + *, + is_success: bool = True, + payload: dict[str, Any] | None = None, + status_code: int = 200, + text: str = "ok", + ) -> None: + self.is_success = is_success + self._payload = payload or {} + self.status_code = status_code + self.text = text + + def json(self) -> dict[str, Any]: + return self._payload + + +class RecordingAsyncClient: + def __init__(self, response: FakeResponse) -> None: + self._response = response + self.calls: list[dict[str, Any]] = [] + + async def post( + self, + url: str, + *, + content: bytes, + headers: dict[str, str], + timeout: float | None = None, + ) -> FakeResponse: + self.calls.append( + { + "url": url, + "content": content, + "headers": headers, + "timeout": timeout, + } + ) + return self._response + + +class RecordingSyncClient: + def __init__(self, response: FakeResponse) -> None: + self._response = response + self.calls: list[dict[str, Any]] = [] + + def post( + self, + url: str, + *, + content: bytes, + headers: dict[str, str], + timeout: float | None = None, + ) -> FakeResponse: + self.calls.append( + { + "url": url, + "content": content, + "headers": headers, + "timeout": timeout, + } + ) + return self._response + + +def _pending_response(hash_value: str = "0x1") -> PendingTransactionResponse: + return PendingTransactionResponse( + hash=hash_value, + sender="0x1", + sequence_number="1", + max_gas_amount="1", + gas_unit_price="1", + expiration_timestamp_secs="1", + ) + + +def test_base_sdk_rejects_conflicting_fee_payer_options() -> None: + with pytest.raises(ValueError, match="cannot be used together"): + BaseSDK( + TESTNET_CONFIG, + Account.generate(), + BaseSDKOptions( + no_fee_payer=True, + fee_payer_account=Account.generate(), + ), + ) + + +def test_base_sdk_sync_rejects_conflicting_fee_payer_options() -> None: + with pytest.raises(ValueError, match="cannot be used together"): + BaseSDKSync( + TESTNET_CONFIG, + Account.generate(), + BaseSDKOptionsSync( + no_fee_payer=True, + fee_payer_account=Account.generate(), + ), + ) + + +@pytest.mark.asyncio +async def test_submit_tx_uses_direct_path_when_no_fee_payer( + monkeypatch: pytest.MonkeyPatch, +) -> None: + sdk = BaseSDK( + TESTNET_CONFIG, + Account.generate(), + BaseSDKOptions(no_fee_payer=True), + ) + tx = SimpleNamespace() + sender_authenticator = SimpleNamespace() + called = {"direct": False} + + async def fake_submit_direct( + transaction: Any, + authenticator: Any, + txn_submit_timeout: float | None = None, + ) -> PendingTransactionResponse: + called["direct"] = True + assert transaction is tx + assert authenticator is sender_authenticator + assert txn_submit_timeout == 2.5 + return _pending_response("0xdirect") + + async def fake_fee_paid(*args: Any, **kwargs: Any) -> PendingTransactionResponse: + raise AssertionError("fee-paid path should not be used when no_fee_payer=True") + + monkeypatch.setattr(sdk, "_submit_direct", fake_submit_direct) + monkeypatch.setattr(base_module, "submit_fee_paid_transaction", fake_fee_paid) + + response = await sdk.submit_tx(tx, sender_authenticator, txn_submit_timeout=2.5) + assert response.hash == "0xdirect" + assert called["direct"] is True + + +@pytest.mark.asyncio +async def test_submit_tx_passes_fee_payer_account_to_fee_paid_submitter( + monkeypatch: pytest.MonkeyPatch, +) -> None: + expected_fee_payer_account = Account.generate() + sdk = BaseSDK( + TESTNET_CONFIG, + Account.generate(), + BaseSDKOptions(fee_payer_account=expected_fee_payer_account, node_api_key="node-key"), + ) + captured: dict[str, Any] = {} + + async def fake_fee_paid( + config: Any, + transaction: Any, + sender_authenticator: Any, + *, + fee_payer_account: Account | None = None, + node_api_key: str | None = None, + txn_submit_timeout: float | None = None, + ) -> PendingTransactionResponse: + captured["config"] = config + captured["transaction"] = transaction + captured["sender_authenticator"] = sender_authenticator + captured["fee_payer_account"] = fee_payer_account + captured["node_api_key"] = node_api_key + captured["txn_submit_timeout"] = txn_submit_timeout + return _pending_response("0xfee") + + monkeypatch.setattr(base_module, "submit_fee_paid_transaction", fake_fee_paid) + + tx = SimpleNamespace(fee_payer_address=expected_fee_payer_account.address()) + sender_authenticator = SimpleNamespace() + response = await sdk.submit_tx(tx, sender_authenticator, txn_submit_timeout=3.0) + + assert response.hash == "0xfee" + assert captured["config"] == TESTNET_CONFIG + assert captured["transaction"] is tx + assert captured["sender_authenticator"] is sender_authenticator + assert captured["fee_payer_account"] is expected_fee_payer_account + assert captured["node_api_key"] == "node-key" + assert captured["txn_submit_timeout"] == 3.0 + assert tx.fee_payer_address == expected_fee_payer_account.address() + + +def test_submit_tx_sync_passes_fee_payer_account_to_fee_paid_submitter( + monkeypatch: pytest.MonkeyPatch, +) -> None: + expected_fee_payer_account = Account.generate() + sdk = BaseSDKSync( + TESTNET_CONFIG, + Account.generate(), + BaseSDKOptionsSync( + fee_payer_account=expected_fee_payer_account, + node_api_key="node-key", + ), + ) + captured: dict[str, Any] = {} + + def fake_fee_paid( + config: Any, + transaction: Any, + sender_authenticator: Any, + *, + fee_payer_account: Account | None = None, + node_api_key: str | None = None, + txn_submit_timeout: float | None = None, + ) -> PendingTransactionResponse: + captured["config"] = config + captured["transaction"] = transaction + captured["sender_authenticator"] = sender_authenticator + captured["fee_payer_account"] = fee_payer_account + captured["node_api_key"] = node_api_key + captured["txn_submit_timeout"] = txn_submit_timeout + return _pending_response("0xsync-fee") + + monkeypatch.setattr(base_module, "submit_fee_paid_transaction_sync", fake_fee_paid) + + tx = SimpleNamespace(fee_payer_address=expected_fee_payer_account.address()) + sender_authenticator = SimpleNamespace() + response = sdk.submit_tx(tx, sender_authenticator, txn_submit_timeout=4.0) + + assert response.hash == "0xsync-fee" + assert captured["config"] == TESTNET_CONFIG + assert captured["transaction"] is tx + assert captured["sender_authenticator"] is sender_authenticator + assert captured["fee_payer_account"] is expected_fee_payer_account + assert captured["node_api_key"] == "node-key" + assert captured["txn_submit_timeout"] == 4.0 + assert tx.fee_payer_address == expected_fee_payer_account.address() + + +@pytest.mark.asyncio +async def test_send_tx_overrides_fee_payer_address_in_local_mode( + monkeypatch: pytest.MonkeyPatch, +) -> None: + sender = Account.generate() + fee_payer = Account.generate() + sdk = BaseSDK( + TESTNET_CONFIG, + sender, + BaseSDKOptions( + skip_simulate=False, + fee_payer_account=fee_payer, + ), + ) + built_transactions: list[SimpleNamespace] = [] + + async def fake_build_tx( + data: Any, + sender_addr: Any, + *, + max_gas_amount: int | None = None, + gas_unit_price: int | None = None, + ) -> SimpleNamespace: + _ = (data, sender_addr, max_gas_amount, gas_unit_price) + tx = SimpleNamespace(fee_payer_address=fee_payer.address()) + built_transactions.append(tx) + return tx + + async def fake_simulate(transaction: Any) -> dict[str, str]: + _ = transaction + return {"max_gas_amount": "100", "gas_unit_price": "2"} + + def fake_sign(signer: Any, transaction: Any) -> object: + _ = (signer, transaction) + return object() + + async def fake_submit( + transaction: Any, + sender_authenticator: Any, + *, + txn_submit_timeout: float | None = None, + ) -> PendingTransactionResponse: + _ = sender_authenticator + assert txn_submit_timeout == 1.25 + assert transaction.fee_payer_address == fee_payer.address() + return _pending_response("0xsend") + + async def fake_wait( + tx_hash: str, + txn_confirm_timeout: float | None = None, + poll_interval_secs: float = 1.0, + ) -> dict[str, Any]: + _ = (txn_confirm_timeout, poll_interval_secs) + return {"hash": tx_hash, "success": True} + + monkeypatch.setattr(sdk, "build_tx", fake_build_tx) + monkeypatch.setattr(sdk, "_simulate_transaction", fake_simulate) + monkeypatch.setattr(sdk, "_sign_transaction", fake_sign) + monkeypatch.setattr(sdk, "submit_tx", fake_submit) + monkeypatch.setattr(sdk, "_wait_for_transaction", fake_wait) + + result = await sdk._send_tx( + InputEntryFunctionData(function="0x1::module::function"), + txn_submit_timeout=1.25, + ) + assert result["hash"] == "0xsend" + assert len(built_transactions) == 2 + assert all(tx.fee_payer_address == fee_payer.address() for tx in built_transactions) + + +def test_send_tx_sync_overrides_fee_payer_address_in_local_mode( + monkeypatch: pytest.MonkeyPatch, +) -> None: + sender = Account.generate() + fee_payer = Account.generate() + sdk = BaseSDKSync( + TESTNET_CONFIG, + sender, + BaseSDKOptionsSync( + skip_simulate=False, + fee_payer_account=fee_payer, + ), + ) + built_transactions: list[SimpleNamespace] = [] + + def fake_build_tx( + data: Any, + sender_addr: Any, + *, + max_gas_amount: int | None = None, + gas_unit_price: int | None = None, + ) -> SimpleNamespace: + _ = (data, sender_addr, max_gas_amount, gas_unit_price) + tx = SimpleNamespace(fee_payer_address=fee_payer.address()) + built_transactions.append(tx) + return tx + + def fake_simulate(transaction: Any) -> dict[str, str]: + _ = transaction + return {"max_gas_amount": "100", "gas_unit_price": "2"} + + def fake_sign(signer: Any, transaction: Any) -> object: + _ = (signer, transaction) + return object() + + def fake_submit( + transaction: Any, + sender_authenticator: Any, + *, + txn_submit_timeout: float | None = None, + ) -> PendingTransactionResponse: + _ = sender_authenticator + assert txn_submit_timeout == 2.25 + assert transaction.fee_payer_address == fee_payer.address() + return _pending_response("0xsend-sync") + + def fake_wait( + tx_hash: str, + txn_confirm_timeout: float | None = None, + poll_interval_secs: float = 1.0, + ) -> dict[str, Any]: + _ = (txn_confirm_timeout, poll_interval_secs) + return {"hash": tx_hash, "success": True} + + monkeypatch.setattr(sdk, "build_tx", fake_build_tx) + monkeypatch.setattr(sdk, "_simulate_transaction", fake_simulate) + monkeypatch.setattr(sdk, "_sign_transaction", fake_sign) + monkeypatch.setattr(sdk, "submit_tx", fake_submit) + monkeypatch.setattr(sdk, "_wait_for_transaction", fake_wait) + + result = sdk._send_tx( + InputEntryFunctionData(function="0x1::module::function"), + txn_submit_timeout=2.25, + ) + assert result["hash"] == "0xsend-sync" + assert len(built_transactions) == 2 + assert all(tx.fee_payer_address == fee_payer.address() for tx in built_transactions) + + +@pytest.mark.asyncio +async def test_submit_tx_rejects_missing_fee_payer_address_in_local_mode() -> None: + sdk = BaseSDK( + TESTNET_CONFIG, + Account.generate(), + BaseSDKOptions(fee_payer_account=Account.generate()), + ) + + with pytest.raises(ValueError, match="must be set"): + await sdk.submit_tx( + SimpleNamespace(fee_payer_address=None), + SimpleNamespace(), + ) + + +@pytest.mark.asyncio +async def test_submit_tx_rejects_mismatched_fee_payer_address_in_local_mode() -> None: + sdk = BaseSDK( + TESTNET_CONFIG, + Account.generate(), + BaseSDKOptions(fee_payer_account=Account.generate()), + ) + + with pytest.raises(ValueError, match="does not match"): + await sdk.submit_tx( + SimpleNamespace(fee_payer_address=AccountAddress.from_str("0x1")), + SimpleNamespace(), + ) + + +def test_submit_tx_sync_rejects_missing_fee_payer_address_in_local_mode() -> None: + sdk = BaseSDKSync( + TESTNET_CONFIG, + Account.generate(), + BaseSDKOptionsSync(fee_payer_account=Account.generate()), + ) + + with pytest.raises(ValueError, match="must be set"): + sdk.submit_tx( + SimpleNamespace(fee_payer_address=None), + SimpleNamespace(), + ) + + +def test_submit_tx_sync_rejects_mismatched_fee_payer_address_in_local_mode() -> None: + sdk = BaseSDKSync( + TESTNET_CONFIG, + Account.generate(), + BaseSDKOptionsSync(fee_payer_account=Account.generate()), + ) + + with pytest.raises(ValueError, match="does not match"): + sdk.submit_tx( + SimpleNamespace(fee_payer_address=AccountAddress.from_str("0x1")), + SimpleNamespace(), + ) + + +@pytest.mark.asyncio +async def test_fee_pay_async_prefers_local_mode_when_account_provided( + monkeypatch: pytest.MonkeyPatch, +) -> None: + called = {"local": False} + + async def fake_local(*args: Any, **kwargs: Any) -> PendingTransactionResponse: + called["local"] = True + return _pending_response("0xlocal") + + async def fake_gas_station(*args: Any, **kwargs: Any) -> PendingTransactionResponse: + raise AssertionError("gas station path should not be called when fee_payer_account is set") + + monkeypatch.setattr(fee_pay_module, "_submit_via_local_fee_payer", fake_local) + monkeypatch.setattr(fee_pay_module, "_submit_via_gas_station_api", fake_gas_station) + monkeypatch.setattr(fee_pay_module, "_submit_via_legacy_fee_payer", fake_gas_station) + + config = replace(TESTNET_CONFIG, gas_station_api_key="api-key") + response = await submit_fee_paid_transaction( + config, + SimpleNamespace(), + SimpleNamespace(), + fee_payer_account=object(), + ) + assert response.hash == "0xlocal" + assert called["local"] is True + + +@pytest.mark.asyncio +async def test_fee_pay_async_routes_to_gas_station_api(monkeypatch: pytest.MonkeyPatch) -> None: + called = {"api": False} + + async def fake_api(*args: Any, **kwargs: Any) -> PendingTransactionResponse: + called["api"] = True + return _pending_response("0xapi") + + monkeypatch.setattr(fee_pay_module, "_submit_via_gas_station_api", fake_api) + config = replace(TESTNET_CONFIG, gas_station_api_key="api-key") + response = await submit_fee_paid_transaction(config, SimpleNamespace(), SimpleNamespace()) + assert response.hash == "0xapi" + assert called["api"] is True + + +@pytest.mark.asyncio +async def test_fee_pay_async_routes_to_legacy_url(monkeypatch: pytest.MonkeyPatch) -> None: + called = {"legacy": False} + + async def fake_legacy(*args: Any, **kwargs: Any) -> PendingTransactionResponse: + called["legacy"] = True + return _pending_response("0xlegacy") + + monkeypatch.setattr(fee_pay_module, "_submit_via_legacy_fee_payer", fake_legacy) + config = replace(TESTNET_CONFIG, gas_station_api_key=None, gas_station_url="https://fee-payer") + response = await submit_fee_paid_transaction(config, SimpleNamespace(), SimpleNamespace()) + assert response.hash == "0xlegacy" + assert called["legacy"] is True + + +@pytest.mark.asyncio +async def test_fee_pay_async_requires_gas_station_config_when_not_local() -> None: + config = replace(TESTNET_CONFIG, gas_station_api_key=None, gas_station_url=None) + with pytest.raises(ValueError, match="Either gas_station_api_key or gas_station_url"): + await submit_fee_paid_transaction(config, SimpleNamespace(), SimpleNamespace()) + + +def test_fee_pay_sync_prefers_local_mode_when_account_provided( + monkeypatch: pytest.MonkeyPatch, +) -> None: + called = {"local": False} + + def fake_local(*args: Any, **kwargs: Any) -> PendingTransactionResponse: + called["local"] = True + return _pending_response("0xsync-local") + + def fake_gas_station(*args: Any, **kwargs: Any) -> PendingTransactionResponse: + raise AssertionError("gas station path should not be called when fee_payer_account is set") + + monkeypatch.setattr(fee_pay_module, "_submit_via_local_fee_payer_sync", fake_local) + monkeypatch.setattr(fee_pay_module, "_submit_via_gas_station_api_sync", fake_gas_station) + monkeypatch.setattr(fee_pay_module, "_submit_via_legacy_fee_payer_sync", fake_gas_station) + + config = replace(TESTNET_CONFIG, gas_station_api_key="api-key") + response = submit_fee_paid_transaction_sync( + config, + SimpleNamespace(), + SimpleNamespace(), + fee_payer_account=object(), + ) + assert response.hash == "0xsync-local" + assert called["local"] is True + + +@pytest.mark.asyncio +async def test_local_fee_payer_async_submits_to_fullnode(monkeypatch: pytest.MonkeyPatch) -> None: + raw_txn = SimpleNamespace( + sender="0x111", + sequence_number=7, + max_gas_amount=200_000, + gas_unit_price=2, + expiration_timestamps_secs=999, + ) + transaction = SimpleNamespace( + raw_transaction=raw_txn, + fee_payer_address=AccountAddress.from_str("0x2"), + ) + client = RecordingAsyncClient(FakeResponse(payload={"hash": "0xabc"})) + fee_payer_account = Account.generate() + + monkeypatch.setattr( + fee_pay_module, + "_build_fee_payer_signed_transaction_bytes", + lambda *_args: b"signed-bytes", + ) + + response = await submit_fee_paid_transaction( + TESTNET_CONFIG, + transaction, + SimpleNamespace(), + fee_payer_account=fee_payer_account, + node_api_key="node-key", + client=client, + txn_submit_timeout=1.5, + ) + + assert response.hash == "0xabc" + assert response.sender == "0x111" + assert response.sequence_number == "7" + assert response.max_gas_amount == "200000" + assert response.gas_unit_price == "2" + assert response.expiration_timestamp_secs == "999" + assert len(client.calls) == 1 + call = client.calls[0] + assert call["url"] == f"{TESTNET_CONFIG.fullnode_url}/transactions" + assert call["content"] == b"signed-bytes" + assert call["headers"]["Content-Type"] == "application/x.aptos.signed_transaction+bcs" + assert call["headers"]["x-api-key"] == "node-key" + assert call["timeout"] == 1.5 + + +def test_local_fee_payer_sync_submits_to_fullnode(monkeypatch: pytest.MonkeyPatch) -> None: + raw_txn = SimpleNamespace( + sender="0x111", + sequence_number=7, + max_gas_amount=200_000, + gas_unit_price=2, + expiration_timestamps_secs=999, + ) + transaction = SimpleNamespace( + raw_transaction=raw_txn, + fee_payer_address=AccountAddress.from_str("0x2"), + ) + client = RecordingSyncClient(FakeResponse(payload={"hash": "0xsync-abc"})) + fee_payer_account = Account.generate() + + monkeypatch.setattr( + fee_pay_module, + "_build_fee_payer_signed_transaction_bytes", + lambda *_args: b"sync-signed-bytes", + ) + + response = submit_fee_paid_transaction_sync( + TESTNET_CONFIG, + transaction, + SimpleNamespace(), + fee_payer_account=fee_payer_account, + node_api_key="node-key", + client=client, + txn_submit_timeout=2.5, + ) + + assert response.hash == "0xsync-abc" + assert response.sender == "0x111" + assert response.sequence_number == "7" + assert response.max_gas_amount == "200000" + assert response.gas_unit_price == "2" + assert response.expiration_timestamp_secs == "999" + assert len(client.calls) == 1 + call = client.calls[0] + assert call["url"] == f"{TESTNET_CONFIG.fullnode_url}/transactions" + assert call["content"] == b"sync-signed-bytes" + assert call["headers"]["Content-Type"] == "application/x.aptos.signed_transaction+bcs" + assert call["headers"]["x-api-key"] == "node-key" + assert call["timeout"] == 2.5 + + +def test_build_fee_payer_signed_transaction_rejects_mismatched_fee_payer_address() -> None: + fee_payer_account = Account.generate() + transaction = SimpleNamespace( + fee_payer_address=AccountAddress.from_str("0x1"), + raw_transaction=SimpleNamespace(), + ) + + with pytest.raises(ValueError, match="does not match"): + fee_pay_module._build_fee_payer_signed_transaction_bytes( + transaction, + sender_authenticator=SimpleNamespace(), + fee_payer_account=fee_payer_account, + ) + + +def test_build_fee_payer_signed_transaction_rejects_missing_fee_payer_address() -> None: + fee_payer_account = Account.generate() + transaction = SimpleNamespace( + fee_payer_address=None, + raw_transaction=SimpleNamespace(), + ) + + with pytest.raises(ValueError, match="must be set"): + fee_pay_module._build_fee_payer_signed_transaction_bytes( + transaction, + sender_authenticator=SimpleNamespace(), + fee_payer_account=fee_payer_account, + )