diff --git a/webext/app/credential_manager_shim.py b/webext/app/credential_manager_shim.py index 762613b..abdbf6b 100755 --- a/webext/app/credential_manager_shim.py +++ b/webext/app/credential_manager_shim.py @@ -16,6 +16,7 @@ from dbus_next import Variant from dbus_next.aio import MessageBus +from dbus_next.proxy_object import BaseProxyInterface from dbus_next.constants import MessageType from dbus_next.message import Message @@ -25,6 +26,7 @@ APP_ID = "@APP_ID@" DBUS_DOC_FILE = "@DBUS_DOC_FILE@" +INTERFACE: Optional[BaseProxyInterface] = None def getMessage(): @@ -397,20 +399,31 @@ async def get_passkey(interface, options, origin, top_origin): } logging.debug("Sending request to D-Bus API") - rsp = await interface.call_get_credential(["", req]) - if rsp["type"].value != "public-key": + request_event = create_portal_request_message_handler(interface.bus) + req = { + "handle_token": Variant("s", request_event.token), + "public_key": Variant("s", req_json), + } + if top_origin != origin: + req["top_origin"] = Variant("s", top_origin) + _rsp = await interface.call_get_credential("", origin, req) + result = await request_event.wait() + if result["type"].value != "public-key": raise Exception( - f"Invalid credential type received: expected 'public-key', received {rsp['type'].value}" + f"Invalid credential type received: expected 'public-key', received {result['type'].value}" ) response_json = json.loads( - rsp["public_key"].value["authentication_response_json"].value + result["public_key"].value["authentication_response_json"].value ) return response_json -async def run(cmd, options, origin, top_origin): - logging.debug("Executing command") +async def get_interface(): + global INTERFACE + if INTERFACE and INTERFACE.bus.connected: + return INTERFACE + bus = await MessageBus().connect() logging.debug("Connected to bus") import os @@ -438,11 +451,17 @@ async def run(cmd, options, origin, top_origin): "/org/freedesktop/portal/desktop", introspection, ) - interface = proxy_object.get_interface( "org.freedesktop.portal.experimental.Credential" ) + INTERFACE = interface logging.debug(f"Connected to interface at {interface.path}") + return INTERFACE + + +async def run(cmd, options, origin, top_origin): + logging.debug("Executing command") + interface = await get_interface() if cmd == "create": if "publicKey" in options: @@ -463,10 +482,33 @@ async def run(cmd, options, origin, top_origin): f"Could not get unknown credential type: {options.keys()[0]}" ) elif cmd == "getClientCapabilities": - rsp = await interface.call_get_client_capabilities() - response = {} - for name, val in rsp.items(): - response[name] = val.value + conditional_create = await interface.get_conditional_create() + conditional_get = await interface.get_conditional_get() + hybrid_transport = await interface.get_hybrid_transport() + passkey_platform_authenticator = ( + await interface.get_passkey_platform_authenticator() + ) + user_verifying_platform_authenticator = ( + await interface.get_user_verifying_platform_authenticator() + ) + related_origins = await interface.get_related_origins() + signal_all_accepted_credentials = ( + await interface.get_signal_all_accepted_credentials() + ) + signal_current_user_details = await interface.get_signal_current_user_details() + signal_unknown_credential = await interface.get_signal_unknown_credential() + + response = { + "conditional_create": conditional_create, + "conditional_get": conditional_get, + "hybrid_transport": hybrid_transport, + "passkey_platform_authenticator": passkey_platform_authenticator, + "user_verifying_platform_authenticator": user_verifying_platform_authenticator, + "related_origins": related_origins, + "signal_all_accepted_credentials": signal_all_accepted_credentials, + "signal_current_user_details": signal_current_user_details, + "signal_unknown_credential": signal_unknown_credential, + } return response else: raise Exception(f"unknown cmd: {cmd}") @@ -477,6 +519,7 @@ async def run(cmd, options, origin, top_origin): async def main(): logging.info("starting credential_manager_shim") + while not quit.is_set(): logging.debug("starting event loop message") receivedMessage = getMessage()