diff --git a/changelog.md b/changelog.md index 3b093e74..76ffd2fc 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,11 @@ Upcoming (TBD) ============== +Features +--------- +* Add more output to the `status` command. + + Documentation --------- * Give example for ANSI prompt colors in `~/.myclirc`. diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py index a2705053..e5043ee5 100644 --- a/mycli/packages/special/dbcommands.py +++ b/mycli/packages/special/dbcommands.py @@ -8,7 +8,13 @@ from mycli import __version__ from mycli.packages.special import iocommands from mycli.packages.special.main import ArgType, special_command -from mycli.packages.special.utils import format_uptime, get_ssl_version +from mycli.packages.special.utils import ( + format_uptime, + get_local_timezone, + get_server_timezone, + get_ssl_cipher, + get_ssl_version, +) from mycli.packages.sqlresult import SQLResult logger = logging.getLogger(__name__) @@ -69,7 +75,7 @@ def status(cur: Cursor, **_) -> list[SQLResult]: try: cur.execute(query) except ProgrammingError: - # Fallback in case query fail, as it does with Mysql 4 + # Fallback in case query fails, as it does with Mysql 4 query = "SHOW STATUS;" logger.debug(query) cur.execute(query) @@ -78,15 +84,24 @@ def status(cur: Cursor, **_) -> list[SQLResult]: query = "SHOW GLOBAL VARIABLES;" logger.debug(query) cur.execute(query) - variables = dict(cur.fetchall()) + global_variables = dict(cur.fetchall()) - # prepare in case keys are bytes, as with Python 3 and Mysql 4 - if isinstance(list(variables)[0], bytes) and isinstance(list(status)[0], bytes): - variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in variables.items()} + query = "SHOW SESSION VARIABLES;" + logger.debug(query) + cur.execute(query) + session_variables = dict(cur.fetchall()) + + # decode in case keys are bytes, as with Mysql 4 + if global_variables and isinstance(list(global_variables)[0], bytes): + global_variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in global_variables.items()} + if session_variables and isinstance(list(session_variables)[0], bytes): + session_variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in session_variables.items()} + if status and isinstance(list(status)[0], bytes): status = {k.decode("utf-8"): v.decode("utf-8") for k, v in status.items()} # Create output buffers. preamble = [] + header = ['Setting', 'Value'] output = [] footer = [] @@ -111,7 +126,6 @@ def status(cur: Cursor, **_) -> list[SQLResult]: else: db = "" user = "" - output.append(("Current database:", db)) output.append(("Current user:", user)) @@ -124,9 +138,16 @@ def status(cur: Cursor, **_) -> list[SQLResult]: pager = "stdout" output.append(("Current pager:", pager)) - output.append(("Server version:", f'{variables["version"]} {variables["version_comment"]}')) - output.append(("Protocol version:", variables["protocol_version"])) - output.append(('SSL/TLS version:', get_ssl_version(cur))) + output.append(("Using delimiter:", iocommands.get_current_delimiter())) + output.append(("Using outfile:", iocommands.tee_file.name if iocommands.tee_file else '')) + + output.append(("Server version:", f'{global_variables["version"]} {global_variables["version_comment"]}')) + output.append(("Protocol version:", global_variables["protocol_version"])) + if cipher := get_ssl_cipher(cur): + output.append(('SSL:', f'Cipher in use is {cipher}')) + else: + output.append(('SSL:', '')) + output.append(('SSL/TLS version:', get_ssl_version(cur) or '')) if getattr(cur.connection, 'unix_socket', None): host_info = cur.connection.host_info @@ -135,23 +156,28 @@ def status(cur: Cursor, **_) -> list[SQLResult]: output.append(("Connection:", host_info)) - query = "SELECT @@character_set_server, @@character_set_database, @@character_set_client, @@character_set_connection LIMIT 1;" - logger.debug(query) - cur.execute(query) - if one := cur.fetchone(): - charset = one - else: - charset = ("", "", "", "") - output.append(("Server characterset:", charset[0])) - output.append(("Db characterset:", charset[1])) - output.append(("Client characterset:", charset[2])) - output.append(("Conn. characterset:", charset[3])) + charset_spec = [ + {'name': 'Server characterset:', 'variable': 'character_set_server'}, + {'name': 'Db characterset:', 'variable': 'character_set_database'}, + {'name': 'Client characterset:', 'variable': 'character_set_client'}, + {'name': 'Conn. characterset:', 'variable': 'character_set_connection'}, + {'name': 'Result characterset:', 'variable': 'character_set_results'}, + ] + for elt in charset_spec: + if elt['variable'] in session_variables: + value = session_variables[elt['variable']] + else: + value = '' + output.append((elt['name'], value)) if getattr(cur.connection, 'unix_socket', None): - output.append(('UNIX socket:', variables['socket'])) + output.append(('UNIX socket:', global_variables['socket'])) else: output.append(('TCP port:', cur.connection.port)) + output.append(('Server timezone:', get_server_timezone(global_variables))) + output.append(('Local timezone:', get_local_timezone())) + if "Uptime" in status: output.append(("Uptime:", format_uptime(status["Uptime"]))) @@ -174,4 +200,4 @@ def status(cur: Cursor, **_) -> list[SQLResult]: footer.append("--------------") - return [SQLResult(preamble="\n".join(preamble), rows=output, postamble="\n".join(footer))] + return [SQLResult(preamble="\n".join(preamble), header=header, rows=output, postamble="\n".join(footer))] diff --git a/mycli/packages/special/utils.py b/mycli/packages/special/utils.py index c395c2c9..fc014323 100644 --- a/mycli/packages/special/utils.py +++ b/mycli/packages/special/utils.py @@ -1,5 +1,7 @@ +import datetime import logging import os +from typing import Any import click import pymysql @@ -110,3 +112,34 @@ def get_ssl_version(cur: Cursor) -> str | None: pass return ssl_version + + +def get_ssl_cipher(cur: Cursor) -> str | None: + query = 'SHOW STATUS LIKE "Ssl_cipher"' + logger.debug(query) + + ssl_cipher = None + + try: + cur.execute(query) + if one := cur.fetchone(): + ssl_cipher = one[1] or None + except pymysql.err.OperationalError: + pass + + return ssl_cipher + + +def get_server_timezone(variables: dict[str, Any]) -> str: + try: + if variables['time_zone'] == 'SYSTEM': + server_tz = variables['system_time_zone'] + else: + server_tz = variables['time_zone'] + return server_tz + except KeyError: + return '' + + +def get_local_timezone() -> str: + return datetime.datetime.now().astimezone().tzname() or '' diff --git a/test/pytests/test_special_dbcommands.py b/test/pytests/test_special_dbcommands.py index e2e0d7f4..2859e654 100644 --- a/test/pytests/test_special_dbcommands.py +++ b/test/pytests/test_special_dbcommands.py @@ -182,6 +182,7 @@ def test_status_uses_global_queries_decodes_bytes_and_formats_stats(monkeypatch) monkeypatch.setattr(dbcommands.platform, 'python_implementation', lambda: 'CPython') monkeypatch.setattr(dbcommands.platform, 'python_version', lambda: '3.14.0') monkeypatch.setattr(dbcommands.iocommands, 'is_pager_enabled', lambda: True) + monkeypatch.setattr(dbcommands, 'get_ssl_cipher', lambda cur: 'TLS_AES_256_GCM_SHA384') monkeypatch.setattr(dbcommands, 'get_ssl_version', lambda cur: 'TLSv1.3') monkeypatch.setattr(dbcommands, 'format_uptime', lambda uptime: f'{uptime} seconds') monkeypatch.setenv('PAGER', 'less -SR') @@ -210,8 +211,14 @@ def test_status_uses_global_queries_decodes_bytes_and_formats_stats(monkeypatch) 'SELECT DATABASE(), USER();': { 'rows': [('test_db', 'test_user')], }, - 'SELECT @@character_set_server, @@character_set_database, @@character_set_client, @@character_set_connection LIMIT 1;': { - 'rows': [('utf8mb4', 'utf8mb4', 'utf8mb4', 'utf8mb4')], + 'SHOW SESSION VARIABLES;': { + 'rows': [ + (b'character_set_server', b'utf8mb4'), + (b'character_set_database', b'utf8mb4'), + (b'character_set_client', b'utf8mb4'), + (b'character_set_connection', b'utf8mb4'), + (b'character_set_results', b'utf8mb4'), + ], }, }, ) @@ -225,6 +232,7 @@ def test_status_uses_global_queries_decodes_bytes_and_formats_stats(monkeypatch) assert ('Current pager:', 'less -SR') in result.rows assert ('Server version:', '8.0.0 Community') in result.rows assert ('Protocol version:', '10') in result.rows + assert ('SSL:', 'Cipher in use is TLS_AES_256_GCM_SHA384') in result.rows assert ('SSL/TLS version:', 'TLSv1.3') in result.rows assert ('Connection:', 'tcp-host via TCP/IP') in result.rows assert ('TCP port:', 3307) in result.rows @@ -264,10 +272,10 @@ def test_status_falls_back_to_show_status_and_handles_empty_selects(monkeypatch) ('socket', '/tmp/mysql.sock'), ], }, - 'SELECT DATABASE(), USER();': { + 'SHOW SESSION VARIABLES;': { 'rows': [], }, - 'SELECT @@character_set_server, @@character_set_database, @@character_set_client, @@character_set_connection LIMIT 1;': { + 'SELECT DATABASE(), USER();': { 'rows': [], }, }, @@ -282,9 +290,9 @@ def test_status_falls_back_to_show_status_and_handles_empty_selects(monkeypatch) assert ('Connection:', 'Localhost via UNIX socket') in result.rows assert ('UNIX socket:', '/tmp/mysql.sock') in result.rows assert ('Server characterset:', '') in result.rows - assert ('Db characterset:', '') in result.rows + assert ('Db characterset:', '') in result.rows assert ('Client characterset:', '') in result.rows - assert ('Conn. characterset:', '') in result.rows + assert ('Conn. characterset:', '') in result.rows assert 'Connections:' not in result.postamble assert '--------------' in result.postamble @@ -307,8 +315,14 @@ def test_status_uses_system_default_pager_when_enabled_without_env(monkeypatch) 'SELECT DATABASE(), USER();': { 'rows': [('db', 'user')], }, - 'SELECT @@character_set_server, @@character_set_database, @@character_set_client, @@character_set_connection LIMIT 1;': { - 'rows': [('utf8', 'utf8', 'utf8', 'utf8')], + 'SHOW SESSION VARIABLES;': { + 'rows': [ + ('character_set_server', 'utf8'), + ('character_set_database', 'utf8'), + ('character_set_client', 'utf8'), + ('character_set_connection', 'utf8'), + ('character_set_results', 'utf8'), + ], }, }, ) diff --git a/test/pytests/test_special_utils.py b/test/pytests/test_special_utils.py index d21f1d25..efea02df 100644 --- a/test/pytests/test_special_utils.py +++ b/test/pytests/test_special_utils.py @@ -12,6 +12,9 @@ from mycli.packages.special.utils import ( CACHED_SSL_VERSION, format_uptime, + get_local_timezone, + get_server_timezone, + get_ssl_cipher, get_ssl_version, get_uptime, get_warning_count, @@ -185,3 +188,92 @@ def test_get_ssl_version_ignores_operational_error() -> None: cur.execute.side_effect = pymysql.err.OperationalError() assert get_ssl_version(cur) is None + + +def test_get_ssl_cipher_returns_value() -> None: + cur = MagicMock() + cur.fetchone.return_value = ('Ssl_cipher', 'TLS_AES_256_GCM_SHA384') + + ssl_cipher = get_ssl_cipher(cur) + + cur.execute.assert_called_once_with('SHOW STATUS LIKE "Ssl_cipher"') + assert ssl_cipher == 'TLS_AES_256_GCM_SHA384' + + +def test_get_ssl_cipher_returns_none_for_missing_row() -> None: + cur = MagicMock() + cur.fetchone.return_value = None + + assert get_ssl_cipher(cur) is None + + +def test_get_ssl_cipher_returns_none_for_empty_value() -> None: + cur = MagicMock() + cur.fetchone.return_value = ('Ssl_cipher', '') + + assert get_ssl_cipher(cur) is None + + +def test_get_ssl_cipher_ignores_operational_error() -> None: + cur = MagicMock() + cur.execute.side_effect = pymysql.err.OperationalError() + + assert get_ssl_cipher(cur) is None + + +def test_get_server_timezone_prefers_system_timezone_when_requested() -> None: + variables = { + 'time_zone': 'SYSTEM', + 'system_time_zone': 'UTC', + } + + assert get_server_timezone(variables) == 'UTC' + + +def test_get_server_timezone_returns_explicit_timezone() -> None: + variables = { + 'time_zone': '+02:00', + 'system_time_zone': 'UTC', + } + + assert get_server_timezone(variables) == '+02:00' + + +def test_get_server_timezone_returns_empty_string_when_keys_are_missing() -> None: + assert get_server_timezone({}) == '' + + +def test_get_local_timezone_returns_tzname(monkeypatch) -> None: + class FakeAwareDatetime: + def tzname(self) -> str: + return 'EDT' + + class FakeDatetime: + @staticmethod + def now() -> 'FakeDatetime': + return FakeDatetime() + + def astimezone(self) -> FakeAwareDatetime: + return FakeAwareDatetime() + + monkeypatch.setattr(mycli.packages.special.utils.datetime, 'datetime', FakeDatetime) + + assert get_local_timezone() == 'EDT' + + +def test_get_local_timezone_returns_empty_string_when_tzname_is_none(monkeypatch) -> None: + class FakeAwareDatetime: + def tzname(self) -> None: + return None + + class FakeDatetime: + @staticmethod + def now() -> 'FakeDatetime': + return FakeDatetime() + + def astimezone(self) -> FakeAwareDatetime: + return FakeAwareDatetime() + + monkeypatch.setattr(mycli.packages.special.utils.datetime, 'datetime', FakeDatetime) + + assert get_local_timezone() == ''