Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Upcoming (TBD)
==============

Features
---------
* Add more output to the `status` command.


Documentation
---------
* Give example for ANSI prompt colors in `~/.myclirc`.
Expand Down
72 changes: 49 additions & 23 deletions mycli/packages/special/dbcommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from mycli import __version__
from mycli.packages.special import iocommands
from mycli.packages.special.main import ArgType, special_command
from mycli.packages.special.utils import format_uptime, get_ssl_version
from mycli.packages.special.utils import (
format_uptime,
get_local_timezone,
get_server_timezone,
get_ssl_cipher,
get_ssl_version,
)
from mycli.packages.sqlresult import SQLResult

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -69,7 +75,7 @@ def status(cur: Cursor, **_) -> list[SQLResult]:
try:
cur.execute(query)
except ProgrammingError:
# Fallback in case query fail, as it does with Mysql 4
# Fallback in case query fails, as it does with Mysql 4
query = "SHOW STATUS;"
logger.debug(query)
cur.execute(query)
Expand All @@ -78,15 +84,24 @@ def status(cur: Cursor, **_) -> list[SQLResult]:
query = "SHOW GLOBAL VARIABLES;"
logger.debug(query)
cur.execute(query)
variables = dict(cur.fetchall())
global_variables = dict(cur.fetchall())

# prepare in case keys are bytes, as with Python 3 and Mysql 4
if isinstance(list(variables)[0], bytes) and isinstance(list(status)[0], bytes):
variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in variables.items()}
query = "SHOW SESSION VARIABLES;"
logger.debug(query)
cur.execute(query)
session_variables = dict(cur.fetchall())

# decode in case keys are bytes, as with Mysql 4
if global_variables and isinstance(list(global_variables)[0], bytes):
global_variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in global_variables.items()}
if session_variables and isinstance(list(session_variables)[0], bytes):
session_variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in session_variables.items()}
if status and isinstance(list(status)[0], bytes):
status = {k.decode("utf-8"): v.decode("utf-8") for k, v in status.items()}

# Create output buffers.
preamble = []
header = ['Setting', 'Value']
output = []
footer = []

Expand All @@ -111,7 +126,6 @@ def status(cur: Cursor, **_) -> list[SQLResult]:
else:
db = ""
user = ""

output.append(("Current database:", db))
output.append(("Current user:", user))

Expand All @@ -124,9 +138,16 @@ def status(cur: Cursor, **_) -> list[SQLResult]:
pager = "stdout"
output.append(("Current pager:", pager))

output.append(("Server version:", f'{variables["version"]} {variables["version_comment"]}'))
output.append(("Protocol version:", variables["protocol_version"]))
output.append(('SSL/TLS version:', get_ssl_version(cur)))
output.append(("Using delimiter:", iocommands.get_current_delimiter()))
output.append(("Using outfile:", iocommands.tee_file.name if iocommands.tee_file else ''))

output.append(("Server version:", f'{global_variables["version"]} {global_variables["version_comment"]}'))
output.append(("Protocol version:", global_variables["protocol_version"]))
if cipher := get_ssl_cipher(cur):
output.append(('SSL:', f'Cipher in use is {cipher}'))
else:
output.append(('SSL:', ''))
output.append(('SSL/TLS version:', get_ssl_version(cur) or ''))

if getattr(cur.connection, 'unix_socket', None):
host_info = cur.connection.host_info
Expand All @@ -135,23 +156,28 @@ def status(cur: Cursor, **_) -> list[SQLResult]:

output.append(("Connection:", host_info))

query = "SELECT @@character_set_server, @@character_set_database, @@character_set_client, @@character_set_connection LIMIT 1;"
logger.debug(query)
cur.execute(query)
if one := cur.fetchone():
charset = one
else:
charset = ("", "", "", "")
output.append(("Server characterset:", charset[0]))
output.append(("Db characterset:", charset[1]))
output.append(("Client characterset:", charset[2]))
output.append(("Conn. characterset:", charset[3]))
charset_spec = [
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we let the cli_helpers render format these instead of forcing a format?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand. It does get formatted into a table by cli_helpers.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean the spacing like Db characterset:, adding the spaces in between. But if that is the only way to get the alignment you want then that works.

{'name': 'Server characterset:', 'variable': 'character_set_server'},
{'name': 'Db characterset:', 'variable': 'character_set_database'},
{'name': 'Client characterset:', 'variable': 'character_set_client'},
{'name': 'Conn. characterset:', 'variable': 'character_set_connection'},
{'name': 'Result characterset:', 'variable': 'character_set_results'},
]
for elt in charset_spec:
if elt['variable'] in session_variables:
value = session_variables[elt['variable']]
else:
value = ''
output.append((elt['name'], value))

if getattr(cur.connection, 'unix_socket', None):
output.append(('UNIX socket:', variables['socket']))
output.append(('UNIX socket:', global_variables['socket']))
else:
output.append(('TCP port:', cur.connection.port))

output.append(('Server timezone:', get_server_timezone(global_variables)))
output.append(('Local timezone:', get_local_timezone()))

if "Uptime" in status:
output.append(("Uptime:", format_uptime(status["Uptime"])))

Expand All @@ -174,4 +200,4 @@ def status(cur: Cursor, **_) -> list[SQLResult]:

footer.append("--------------")

return [SQLResult(preamble="\n".join(preamble), rows=output, postamble="\n".join(footer))]
return [SQLResult(preamble="\n".join(preamble), header=header, rows=output, postamble="\n".join(footer))]
33 changes: 33 additions & 0 deletions mycli/packages/special/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import datetime
import logging
import os
from typing import Any

import click
import pymysql
Expand Down Expand Up @@ -110,3 +112,34 @@ def get_ssl_version(cur: Cursor) -> str | None:
pass

return ssl_version


def get_ssl_cipher(cur: Cursor) -> str | None:
query = 'SHOW STATUS LIKE "Ssl_cipher"'
logger.debug(query)

ssl_cipher = None

try:
cur.execute(query)
if one := cur.fetchone():
ssl_cipher = one[1] or None
except pymysql.err.OperationalError:
pass

return ssl_cipher


def get_server_timezone(variables: dict[str, Any]) -> str:
try:
if variables['time_zone'] == 'SYSTEM':
server_tz = variables['system_time_zone']
else:
server_tz = variables['time_zone']
return server_tz
except KeyError:
return ''


def get_local_timezone() -> str:
return datetime.datetime.now().astimezone().tzname() or ''
30 changes: 22 additions & 8 deletions test/pytests/test_special_dbcommands.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def test_status_uses_global_queries_decodes_bytes_and_formats_stats(monkeypatch)
monkeypatch.setattr(dbcommands.platform, 'python_implementation', lambda: 'CPython')
monkeypatch.setattr(dbcommands.platform, 'python_version', lambda: '3.14.0')
monkeypatch.setattr(dbcommands.iocommands, 'is_pager_enabled', lambda: True)
monkeypatch.setattr(dbcommands, 'get_ssl_cipher', lambda cur: 'TLS_AES_256_GCM_SHA384')
monkeypatch.setattr(dbcommands, 'get_ssl_version', lambda cur: 'TLSv1.3')
monkeypatch.setattr(dbcommands, 'format_uptime', lambda uptime: f'{uptime} seconds')
monkeypatch.setenv('PAGER', 'less -SR')
Expand Down Expand Up @@ -210,8 +211,14 @@ def test_status_uses_global_queries_decodes_bytes_and_formats_stats(monkeypatch)
'SELECT DATABASE(), USER();': {
'rows': [('test_db', 'test_user')],
},
'SELECT @@character_set_server, @@character_set_database, @@character_set_client, @@character_set_connection LIMIT 1;': {
'rows': [('utf8mb4', 'utf8mb4', 'utf8mb4', 'utf8mb4')],
'SHOW SESSION VARIABLES;': {
'rows': [
(b'character_set_server', b'utf8mb4'),
(b'character_set_database', b'utf8mb4'),
(b'character_set_client', b'utf8mb4'),
(b'character_set_connection', b'utf8mb4'),
(b'character_set_results', b'utf8mb4'),
],
},
},
)
Expand All @@ -225,6 +232,7 @@ def test_status_uses_global_queries_decodes_bytes_and_formats_stats(monkeypatch)
assert ('Current pager:', 'less -SR') in result.rows
assert ('Server version:', '8.0.0 Community') in result.rows
assert ('Protocol version:', '10') in result.rows
assert ('SSL:', 'Cipher in use is TLS_AES_256_GCM_SHA384') in result.rows
assert ('SSL/TLS version:', 'TLSv1.3') in result.rows
assert ('Connection:', 'tcp-host via TCP/IP') in result.rows
assert ('TCP port:', 3307) in result.rows
Expand Down Expand Up @@ -264,10 +272,10 @@ def test_status_falls_back_to_show_status_and_handles_empty_selects(monkeypatch)
('socket', '/tmp/mysql.sock'),
],
},
'SELECT DATABASE(), USER();': {
'SHOW SESSION VARIABLES;': {
'rows': [],
},
'SELECT @@character_set_server, @@character_set_database, @@character_set_client, @@character_set_connection LIMIT 1;': {
'SELECT DATABASE(), USER();': {
'rows': [],
},
},
Expand All @@ -282,9 +290,9 @@ def test_status_falls_back_to_show_status_and_handles_empty_selects(monkeypatch)
assert ('Connection:', 'Localhost via UNIX socket') in result.rows
assert ('UNIX socket:', '/tmp/mysql.sock') in result.rows
assert ('Server characterset:', '') in result.rows
assert ('Db characterset:', '') in result.rows
assert ('Db characterset:', '') in result.rows
assert ('Client characterset:', '') in result.rows
assert ('Conn. characterset:', '') in result.rows
assert ('Conn. characterset:', '') in result.rows
assert 'Connections:' not in result.postamble
assert '--------------' in result.postamble

Expand All @@ -307,8 +315,14 @@ def test_status_uses_system_default_pager_when_enabled_without_env(monkeypatch)
'SELECT DATABASE(), USER();': {
'rows': [('db', 'user')],
},
'SELECT @@character_set_server, @@character_set_database, @@character_set_client, @@character_set_connection LIMIT 1;': {
'rows': [('utf8', 'utf8', 'utf8', 'utf8')],
'SHOW SESSION VARIABLES;': {
'rows': [
('character_set_server', 'utf8'),
('character_set_database', 'utf8'),
('character_set_client', 'utf8'),
('character_set_connection', 'utf8'),
('character_set_results', 'utf8'),
],
},
},
)
Expand Down
92 changes: 92 additions & 0 deletions test/pytests/test_special_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from mycli.packages.special.utils import (
CACHED_SSL_VERSION,
format_uptime,
get_local_timezone,
get_server_timezone,
get_ssl_cipher,
get_ssl_version,
get_uptime,
get_warning_count,
Expand Down Expand Up @@ -185,3 +188,92 @@ def test_get_ssl_version_ignores_operational_error() -> None:
cur.execute.side_effect = pymysql.err.OperationalError()

assert get_ssl_version(cur) is None


def test_get_ssl_cipher_returns_value() -> None:
cur = MagicMock()
cur.fetchone.return_value = ('Ssl_cipher', 'TLS_AES_256_GCM_SHA384')

ssl_cipher = get_ssl_cipher(cur)

cur.execute.assert_called_once_with('SHOW STATUS LIKE "Ssl_cipher"')
assert ssl_cipher == 'TLS_AES_256_GCM_SHA384'


def test_get_ssl_cipher_returns_none_for_missing_row() -> None:
cur = MagicMock()
cur.fetchone.return_value = None

assert get_ssl_cipher(cur) is None


def test_get_ssl_cipher_returns_none_for_empty_value() -> None:
cur = MagicMock()
cur.fetchone.return_value = ('Ssl_cipher', '')

assert get_ssl_cipher(cur) is None


def test_get_ssl_cipher_ignores_operational_error() -> None:
cur = MagicMock()
cur.execute.side_effect = pymysql.err.OperationalError()

assert get_ssl_cipher(cur) is None


def test_get_server_timezone_prefers_system_timezone_when_requested() -> None:
variables = {
'time_zone': 'SYSTEM',
'system_time_zone': 'UTC',
}

assert get_server_timezone(variables) == 'UTC'


def test_get_server_timezone_returns_explicit_timezone() -> None:
variables = {
'time_zone': '+02:00',
'system_time_zone': 'UTC',
}

assert get_server_timezone(variables) == '+02:00'


def test_get_server_timezone_returns_empty_string_when_keys_are_missing() -> None:
assert get_server_timezone({}) == ''


def test_get_local_timezone_returns_tzname(monkeypatch) -> None:
class FakeAwareDatetime:
def tzname(self) -> str:
return 'EDT'

class FakeDatetime:
@staticmethod
def now() -> 'FakeDatetime':
return FakeDatetime()

def astimezone(self) -> FakeAwareDatetime:
return FakeAwareDatetime()

monkeypatch.setattr(mycli.packages.special.utils.datetime, 'datetime', FakeDatetime)

assert get_local_timezone() == 'EDT'


def test_get_local_timezone_returns_empty_string_when_tzname_is_none(monkeypatch) -> None:
class FakeAwareDatetime:
def tzname(self) -> None:
return None

class FakeDatetime:
@staticmethod
def now() -> 'FakeDatetime':
return FakeDatetime()

def astimezone(self) -> FakeAwareDatetime:
return FakeAwareDatetime()

monkeypatch.setattr(mycli.packages.special.utils.datetime, 'datetime', FakeDatetime)

assert get_local_timezone() == ''