diff --git a/prometheus_client/metrics.py b/prometheus_client/metrics.py index 4c53b26b..ea7740a6 100644 --- a/prometheus_client/metrics.py +++ b/prometheus_client/metrics.py @@ -207,6 +207,10 @@ def remove(self, *labelvalues: Any) -> None: warnings.warn( "Removal of labels has not been implemented in multi-process mode yet.", UserWarning) + if 'PROMETHEUS_REDIS_URL' in os.environ: + warnings.warn( + "Removal of labels has not been implemented in redis mode yet.", + UserWarning) if not self._labelnames: raise ValueError('No label names were set when constructing %s' % self) @@ -226,6 +230,10 @@ def remove_by_labels(self, labels: dict[str, str]) -> None: "Removal of labels has not been implemented in multi-process mode yet.", UserWarning ) + if 'PROMETHEUS_REDIS_URL' in os.environ: + warnings.warn( + "Removal of labels has not been implemented in redis mode yet.", + UserWarning) if not self._labelnames: raise ValueError('No label names were set when constructing %s' % self) @@ -258,6 +266,10 @@ def clear(self) -> None: warnings.warn( "Clearing labels has not been implemented in multi-process mode yet", UserWarning) + if 'PROMETHEUS_REDIS_URL' in os.environ: + warnings.warn( + "Clearing of labels has not been implemented in redis mode yet.", + UserWarning) with self._lock: self._metrics = {} diff --git a/prometheus_client/redis.py b/prometheus_client/redis.py new file mode 100644 index 00000000..d952164b --- /dev/null +++ b/prometheus_client/redis.py @@ -0,0 +1,90 @@ +import os +from datetime import timedelta +from threading import Event, Thread +from typing import Any +from urllib.parse import urlsplit + +from redis import Redis + +# For testing, a pool of otherwise anonymous FakeRedis instances are made +# available by ID +_fake_redis_pool: dict[int, Redis] = {} + + +def redis_client() -> Redis: + """ + Create a redis client for PROMETHEUS_REDIS_URL. + + Configure the redis database via a URL in PROMETHEUS_REDIS_URL of the form + redis://localhost:6379/0 + """ + parsed_url = urlsplit(os.environ["PROMETHEUS_REDIS_URL"]) + assert parsed_url.path.startswith("/") + assert parsed_url.path[1:].isdigit() + port = parsed_url.port or 6379 + db = int(parsed_url.path[1:]) + + if parsed_url.scheme == "fakeredis": + from fakeredis import FakeRedis + + if db not in _fake_redis_pool: + _fake_redis_pool[db] = FakeRedis() + return _fake_redis_pool[db] + + assert parsed_url.scheme == "redis" + assert parsed_url.hostname + return Redis(host=parsed_url.hostname, port=port, db=db) + + +# For each process identifier, a list of keys that should be kept from expiring +_live_metrics: dict[str, set[str]] = {} + + +def _key_expiry() -> timedelta: + """Return the configured expiry for multiprocess keys.""" + return timedelta(seconds=int(os.environ.get("PROMETHEUS_REDIS_REFRESH_TTL", 20))) + + +class KeepMetricsAliveThread(Thread): + """A daemon thread that keeps metrics from expiring as long as we live.""" + + stop: Event + identifier: str + + def __init__(self, identifier: str, *args: Any, **kwargs: Any) -> None: + self.stop = Event() + self.identifier = identifier + super().__init__(*args, **kwargs) + + def run(self) -> None: + delay = int(os.environ.get("PROMETHEUS_REDIS_REFRESH_FREQUENCY", 10)) + expiry = _key_expiry() + client = redis_client() + while not self.stop.wait(delay): + for key in _live_metrics[self.identifier]: + client.expire(key, expiry) + + +_daemon_threads: dict[str, KeepMetricsAliveThread] = {} + + +def _keep_key_from_expiring(identifier: str, key: str) -> None: + """Stop key for process identifier from expiring as long as we are alive.""" + _live_metrics.setdefault(identifier, set()).add(key) + if identifier not in _daemon_threads: + thread = KeepMetricsAliveThread(identifier=identifier, daemon=True) + thread.start() + _daemon_threads[identifier] = thread + + +def mark_process_dead(identifier: str | int) -> None: + """Immediately expire all live* metrics for process identifier.""" + thread = _daemon_threads.pop(str(identifier), None) + if thread is not None: + thread.stop.set() + thread.join() + + keys = _live_metrics.pop(str(identifier), None) + if not keys: + return + redis_client().delete(*keys) diff --git a/prometheus_client/redis_collector.py b/prometheus_client/redis_collector.py new file mode 100644 index 00000000..38844cdc --- /dev/null +++ b/prometheus_client/redis_collector.py @@ -0,0 +1,144 @@ +import json +from collections.abc import Iterable +from typing import cast + +from .metrics_core import Metric +from .redis import redis_client +from .registry import Collector, CollectorRegistry +from .samples import Sample +from .values import MULTIPROCESS_MODE_T + + +class RedisCollector(Collector): + """Collector for redis mode.""" + + def __init__(self, registry: CollectorRegistry | None) -> None: + self._client = redis_client() + if registry: + registry.register(self) + + def _iter_values(self) -> Iterable[tuple[bytes, str]]: + cursor = 0 + while True: + cursor, keys = self._client.scan(cursor=cursor, match="value:*") + values = self._client.mget(keys) + yield from zip(keys, values) + if cursor == 0: + break + + def collect(self) -> Iterable[Metric]: + metrics: dict[str, Metric] = {} + histograms: set[str] = set() + multiprocess: dict[str, MULTIPROCESS_MODE_T] = {} + + for key, value_s in self._iter_values(): + # FIXME: Catch ValueError here, just in case? + prefix_b, typ_b, multiprocess_mode_b, mmap_key = key.split(b":", 3) + assert prefix_b == b"value" + value = float(value_s) + + metric_name, name, labels, help_text = json.loads(mmap_key) + + metric = metrics.get(metric_name) + if metric is None: + typ = typ_b.decode() + metric = Metric(metric_name, help_text, typ) + metrics[metric_name] = metric + + if typ in ("histogram", "gaugehistogram"): + histograms.add(metric_name) + + multiprocess_mode = cast( + MULTIPROCESS_MODE_T, multiprocess_mode_b.decode() + ) + if typ in ("gauge", "gaugehistogram") and multiprocess_mode: + multiprocess[metric_name] = multiprocess_mode + + metric.add_sample(name, labels, value) + + for name, multiprocess_mode in multiprocess.items(): + self._accumulate_multiprocess(metrics[name], multiprocess_mode) + + for name in histograms: + self._fix_histogram(metrics[name]) + + return metrics.values() + + def _accumulate_multiprocess( + self, metric: Metric, multiprocess_mode: MULTIPROCESS_MODE_T + ) -> None: + """Merge metrics from multiple processes using multiprocess_mode.""" + # We deal with live/dead with Redis expiry + if multiprocess_mode.startswith("live"): + multiprocess_mode = cast( + MULTIPROCESS_MODE_T, multiprocess_mode[len("live") :] + ) + if multiprocess_mode == "all": + return + + by_label: dict[tuple[tuple[str, ...], str], Sample] = {} + + for sample in metric.samples: + labels = sample.labels + if "pid" in sample.labels: + labels = labels.copy() + labels.pop("pid") + key = (tuple(labels.values()), sample.name) + value = sample.value + if key in by_label: + current_value = by_label[key].value + if multiprocess_mode == "min" and value > current_value: + continue + if multiprocess_mode == "max" and value < current_value: + continue + if multiprocess_mode == "sum": + value += current_value + if multiprocess_mode == "mostrecent": + raise NotImplementedError( + "The 'mostrecent' modes are not supported in RedisCollector" + ) + by_label[key] = Sample(sample.name, labels, value) + + metric.samples = list(by_label.values()) + + def _fix_histogram(self, metric: Metric) -> None: + """ + Fix-up histogram samples. + + Sort the buckets as expected by a client, and accumulate the values. + The Histogram class is optimized to only increment the bucket that a + value first appears in, not larger ones that would also contain it. + """ + by_label: dict[tuple[tuple[str, ...], str], list[Sample]] = {} + + # Organize into lists of samples by label + for sample in metric.samples: + if "le" in sample.labels: + labels_without_le = sample.labels.copy() + labels_without_le.pop("le") + key = (tuple(labels_without_le.values()), sample.name) + else: + key = (tuple(sample.labels.values()), sample.name) + by_label.setdefault(key, []).append(sample) + + metric.samples = [] + + for (labels, name), samples in sorted(by_label.items()): + if name.endswith("_bucket"): + # Sort buckets within each label + samples.sort(key=lambda sample: float(sample.labels["le"])) + + # Accumulate values into larger buckets + value = 0.0 + for sample in samples: + value += sample.value + metric.samples.append(Sample(sample.name, sample.labels, value)) + + labels_without_le = sample.labels.copy() + labels_without_le.pop("le") + metric.samples.append( + Sample(f"{metric.name}_count", labels_without_le, value) + ) + + else: + metric.samples.extend(samples) diff --git a/prometheus_client/values.py b/prometheus_client/values.py index 6ff85e3b..e6e1f99d 100644 --- a/prometheus_client/values.py +++ b/prometheus_client/values.py @@ -1,11 +1,64 @@ import os -from threading import Lock import warnings +from collections.abc import Callable, Sequence +from datetime import timedelta +from threading import Lock +from typing import Any, Protocol, Literal from .mmap_dict import mmap_key, MmapedDict +from .redis import redis_client, _keep_key_from_expiring, _key_expiry +from .samples import Exemplar + +MULTIPROCESS_MODE_T = Literal[ + "all", + "liveall", + "min", + "livemin", + "max", + "livemax", + "sum", + "livesum", + "mostrecent", + "livemostrecent", + "", +] + + +class Value(Protocol): + """Prometheus Client Metric implementation.""" + + _multiprocess: bool + + def __init__( + self, + typ: str, + metric_name: str, + name: str, + labelnames: Sequence[str], + labelvalues: Sequence[str], + help_text: str, + **kwargs: Any, + ) -> None: + """Initialize a metric.""" + + def inc(self, amount: float) -> None: + """Increment the metric by amount.""" + def set(self, value: float, timestamp: float | None = None) -> None: + """Set the metric to value.""" -class MutexValue: + def get(self) -> float: + """Get the current metric value.""" + + def set_exemplar(self, exemplar: Exemplar) -> None: + """Set an exemplar value.""" + exemplar # For vulture + + def get_exemplar(self) -> Exemplar | None: + """Get any set exemplar value.""" + + +class MutexValue(Value): """A float protected by a mutex.""" _multiprocess = False @@ -52,7 +105,7 @@ def MultiProcessValue(process_identifier=os.getpid): # This avoids the need to also have mutexes in __MmapDict. lock = Lock() - class MmapedValue: + class MmapedValue(Value): """A float protected by a mutex backed by a per-process mmaped file.""" _multiprocess = True @@ -125,12 +178,128 @@ def get_exemplar(self): return MmapedValue -def get_value_class(): +def RedisValue(process_identifier: Callable[[], str | int] = os.getpid) -> type[Value]: + + class RedisValueImpl(Value): + """ + A value implementation that stores data in a redis/valkey database. + + Key scheme: + * value:typ:MMAP_KEY + + When a live multiprocess_mode is used, we set the key to expire after + PROMETHEUS_REDIS_REFRESH_TTL seconds. We launch a daemon thread that + extends the expiry of all our process' keys every + PROMETHEUS_REDIS_REFRESH_FREQUENCY. + """ + + _multiprocess: bool = True + + _typ: str + _metric_name: str + _name: str + _labelnames: list[str] + _labelvalues: list[str] + _help_text: str + _multiprocess_mode: MULTIPROCESS_MODE_T + _expiry: timedelta | None + + _key: str + + def __init__( + self, + typ: str, + metric_name: str, + name: str, + labelnames: Sequence[str], + labelvalues: Sequence[str], + help_text: str, + multiprocess_mode: MULTIPROCESS_MODE_T = "", + **kwargs: Any, + ) -> None: + self._typ = typ + self._metric_name = metric_name + self._name = name + self._labelnames = list(labelnames) + self._labelvalues = list(labelvalues) + self._help_text = help_text + self._multiprocess_mode = multiprocess_mode + self._expiry = None + if multiprocess_mode: + if multiprocess_mode in ("mostrecent", "livemostrecent"): + raise NotImplementedError( + "The 'mostrecent' modes are not supported in RedisValue" + ) + assert typ in ("gauge", "gaugehistogram") + self._labelnames.append("pid") + self._labelvalues.append("") + if multiprocess_mode.startswith("live"): + self._expiry = _key_expiry() + self._update_key(True) + redis_client().set(self._key, 0.0, ex=self._expiry, nx=True) + + def _update_key(self, update: bool = False) -> None: + if self._multiprocess_mode: + assert self._labelnames[-1] == "pid" + new_id = str(process_identifier()) + if new_id != self._labelvalues[-1]: + self._labelvalues[-1] = new_id + update = True + + if update: + key = mmap_key( + self._metric_name, + self._name, + self._labelnames, + self._labelvalues, + self._help_text, + ) + self._key = f"value:{self._typ}:{self._multiprocess_mode}:{key}" + + if self._expiry and update: + _keep_key_from_expiring(self._labelvalues[-1], self._key) + + def inc(self, amount: float) -> None: + self._update_key() + client = redis_client() + client.incrbyfloat(self._key, amount) + if self._expiry: + client.expire(self._key, self._expiry) + + def set(self, value: float, timestamp: float | None = None) -> None: + self._update_key() + # TODO: Implement timestamps + redis_client().set(self._key, value, ex=self._expiry) + + def get(self) -> float: + self._update_key() + value = redis_client().get(self._key) + if value is None: + return 0.0 + return float(value) + + def set_exemplar(self, exemplar: Exemplar) -> None: + # TODO: Implement exemplars for redis. + return + + def get_exemplar(self) -> Exemplar | None: + # TODO: Implement exemplars for redis. + return None + + return RedisValueImpl + + +def get_value_class() -> type[Value]: # Should we enable multi-process mode? # This needs to be chosen before the first metric is constructed, # and as that may be in some arbitrary library the user/admin has # no control over we use an environment variable. - if 'prometheus_multiproc_dir' in os.environ or 'PROMETHEUS_MULTIPROC_DIR' in os.environ: + if "PROMETHEUS_REDIS_URL" in os.environ: + return RedisValue() + elif ( + "prometheus_multiproc_dir" in os.environ + or "PROMETHEUS_MULTIPROC_DIR" in os.environ + ): return MultiProcessValue() else: return MutexValue diff --git a/pyproject.toml b/pyproject.toml index 336cfb4f..915bf863 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ aiohttp = [ django = [ "django", ] +redis = [ + "redis", +] [project.urls] Homepage = "https://github.com/prometheus/client_python" diff --git a/tests/test_redis.py b/tests/test_redis.py new file mode 100644 index 00000000..54648ebd --- /dev/null +++ b/tests/test_redis.py @@ -0,0 +1,623 @@ +import os +import unittest +import warnings +from collections.abc import Sequence +from datetime import timedelta +from time import time +from typing import Any + +import pytest + +from prometheus_client import values +from prometheus_client.core import ( + CollectorRegistry, + Counter, + Gauge, + Histogram, + Sample, + Summary, +) +from prometheus_client.redis import ( + mark_process_dead, + redis_client, + _daemon_threads, + _live_metrics, +) +from prometheus_client.redis_collector import RedisCollector +from prometheus_client.values import ( + MULTIPROCESS_MODE_T, + MutexValue, + RedisValue, + Value, + get_value_class, +) + +pytest.importorskip("redis") + + +class RedisTestCase(unittest.TestCase): + def setUp(self) -> None: + os.environ["PROMETHEUS_REDIS_URL"] = "fakeredis://localhost/42" + values.ValueClass = RedisValue(lambda: 123) + + def tearDown(self) -> None: + for identifier in list(_daemon_threads): + mark_process_dead(identifier) + redis_client().flushdb() + del os.environ["PROMETHEUS_REDIS_URL"] + values.ValueClass = MutexValue + + +class ValueTestCase(RedisTestCase): + def create_value( + self, + metric_name: str = "test", + name: str | None = None, + type_: str = "counter", + labelnames: list[str] | None = None, + labelvalues: list[str] | None = None, + multiprocess_mode: MULTIPROCESS_MODE_T = "", + ) -> Value: + return values.ValueClass( + type_, + metric_name, + name or metric_name + "_total", + labelnames or [], + labelvalues or [], + "Help Text", + multiprocess_mode=multiprocess_mode, + ) + + def test_initializes_value(self) -> None: + value = self.create_value() + self.assertEqual(value.get(), 0.0) + + def test_sets_and_gets_value(self) -> None: + value = self.create_value() + value.set(5) + self.assertEqual(value.get(), 5.0) + + def test_inc_value(self) -> None: + value = self.create_value() + value.inc(3) + value.inc(5) + self.assertEqual(value.get(), 8.0) + + def test_differentiated_by_name(self) -> None: + v1 = self.create_value("value1") + v2 = self.create_value("value2") + v1.set(1) + v2.set(2) + self.assertEqual(v1.get(), 1.0) + self.assertEqual(v2.get(), 2.0) + + def test_differentiated_by_labels(self) -> None: + v1 = self.create_value("value3", labelnames=["a"], labelvalues=["1"]) + v2 = self.create_value("value3", labelnames=["a"], labelvalues=["2"]) + v1.set(1) + v2.set(2) + self.assertEqual(v1.get(), 1.0) + self.assertEqual(v2.get(), 2.0) + + def test_multiprocess_mode_mostrecent(self) -> None: + with self.assertRaises(NotImplementedError): + self.create_value(type_="gauge", multiprocess_mode="mostrecent") + + def test_multiprocess_mode_counter(self) -> None: + with self.assertRaises(AssertionError): + self.create_value(type_="counter", multiprocess_mode="liveall") + + def test_multiprocess_mode(self) -> None: + value = self.create_value(type_="gauge", multiprocess_mode="all") + self.assertEqual(value._labelnames, ["pid"]) + self.assertEqual(value._labelvalues[-1], "123") + self.assertIsNone(value._expiry) + self.assertEqual(redis_client().expiretime(value._key), -1) + + def test_multiprocess_mode_live(self) -> None: + value = self.create_value(type_="gauge", multiprocess_mode="liveall") + unixtime = time() + self.assertEqual(value._labelnames, ["pid"]) + self.assertEqual(value._labelvalues[-1], "123") + self.assertEqual(value._expiry, timedelta(seconds=20)) + expiretime = redis_client().expiretime(value._key) + self.assertGreater(expiretime, unixtime) + self.assertLessEqual(expiretime, unixtime + 20) + + self.assertIn("123", _daemon_threads) + self.assertTrue(_daemon_threads["123"].is_alive()) + self.assertIn("123", _live_metrics) + self.assertIn(value._key, _live_metrics["123"]) + + def test_live_inc_updates_expiry(self) -> None: + value = self.create_value(type_="gauge", multiprocess_mode="liveall") + unixtime = time() + redis_client().persist(value._key) + self.assertEqual(redis_client().expiretime(value._key), -1) + + value.inc(1) + self.assertGreater(redis_client().expiretime(value._key), unixtime) + + def test_live_set_updates_expiry(self) -> None: + value = self.create_value(type_="gauge", multiprocess_mode="liveall") + unixtime = time() + redis_client().persist(value._key) + self.assertEqual(redis_client().expiretime(value._key), -1) + + value.set(1) + self.assertGreater(redis_client().expiretime(value._key), unixtime) + + def test_multiprocess_pid_change(self) -> None: + pid = 1 + values.ValueClass = RedisValue(lambda: pid) + + value = self.create_value(type_="gauge", multiprocess_mode="all") + self.assertEqual(value._labelnames[-1], "pid") + self.assertEqual(value._labelvalues[-1], "1") + value.inc(1) + self.assertEqual(value.get(), 1.0) + + pid = 2 + value.inc(1) + self.assertEqual(value._labelvalues[-1], "2") + self.assertEqual(value.get(), 1.0) + + +class TestRedis(RedisTestCase): + def setUp(self) -> None: + super().setUp() + self.registry = CollectorRegistry(support_collectors_without_names=True) + self.collector = RedisCollector(self.registry) + + def test_counter_adds(self) -> None: + c1 = Counter("c", "help", registry=None) + c2 = Counter("c", "help", registry=None) + self.assertEqual(0, self.registry.get_sample_value("c_total")) + c1.inc(1) + c2.inc(2) + self.assertEqual(3, self.registry.get_sample_value("c_total")) + + def test_summary_adds(self) -> None: + s1 = Summary("s", "help", registry=None) + values.ValueClass = RedisValue(lambda: 456) + s2 = Summary("s", "help", registry=None) + self.assertEqual(0, self.registry.get_sample_value("s_count")) + self.assertEqual(0, self.registry.get_sample_value("s_sum")) + s1.observe(1) + s2.observe(2) + self.assertEqual(2, self.registry.get_sample_value("s_count")) + self.assertEqual(3, self.registry.get_sample_value("s_sum")) + + def test_histogram_adds(self) -> None: + h1 = Histogram("h", "help", registry=None) + values.ValueClass = RedisValue(lambda: 456) + h2 = Histogram("h", "help", registry=None) + self.assertEqual(0, self.registry.get_sample_value("h_count")) + self.assertEqual(0, self.registry.get_sample_value("h_sum")) + self.assertEqual(0, self.registry.get_sample_value("h_bucket", {"le": "5.0"})) + h1.observe(1) + h2.observe(2) + self.assertEqual(2, self.registry.get_sample_value("h_count")) + self.assertEqual(3, self.registry.get_sample_value("h_sum")) + self.assertEqual(2, self.registry.get_sample_value("h_bucket", {"le": "5.0"})) + + def test_gauge_all(self) -> None: + g1 = Gauge("g", "help", registry=None) + values.ValueClass = RedisValue(lambda: 456) + g2 = Gauge("g", "help", registry=None) + self.assertEqual(0, self.registry.get_sample_value("g", {"pid": "123"})) + self.assertEqual(0, self.registry.get_sample_value("g", {"pid": "456"})) + g1.set(1) + g2.set(2) + mark_process_dead(123) + self.assertEqual(1, self.registry.get_sample_value("g", {"pid": "123"})) + self.assertEqual(2, self.registry.get_sample_value("g", {"pid": "456"})) + + def test_gauge_liveall(self) -> None: + g1 = Gauge("g", "help", registry=None, multiprocess_mode="liveall") + values.ValueClass = RedisValue(lambda: 456) + g2 = Gauge("g", "help", registry=None, multiprocess_mode="liveall") + self.assertEqual(0, self.registry.get_sample_value("g", {"pid": "123"})) + self.assertEqual(0, self.registry.get_sample_value("g", {"pid": "456"})) + g1.set(1) + g2.set(2) + self.assertEqual(1, self.registry.get_sample_value("g", {"pid": "123"})) + self.assertEqual(2, self.registry.get_sample_value("g", {"pid": "456"})) + mark_process_dead(123) + self.assertEqual(None, self.registry.get_sample_value("g", {"pid": "123"})) + self.assertEqual(2, self.registry.get_sample_value("g", {"pid": "456"})) + + def test_gauge_min(self) -> None: + g1 = Gauge("g", "help", registry=None, multiprocess_mode="min") + values.ValueClass = RedisValue(lambda: 456) + g2 = Gauge("g", "help", registry=None, multiprocess_mode="min") + self.assertEqual(0, self.registry.get_sample_value("g")) + g1.set(1) + g2.set(2) + self.assertEqual(1, self.registry.get_sample_value("g")) + + def test_gauge_livemin(self): + g1 = Gauge("g", "help", registry=None, multiprocess_mode="livemin") + values.ValueClass = RedisValue(lambda: 456) + g2 = Gauge("g", "help", registry=None, multiprocess_mode="livemin") + self.assertEqual(0, self.registry.get_sample_value("g")) + g1.set(1) + g2.set(2) + self.assertEqual(1, self.registry.get_sample_value("g")) + mark_process_dead(123) + self.assertEqual(2, self.registry.get_sample_value("g")) + + def test_gauge_max(self) -> None: + g1 = Gauge("g", "help", registry=None, multiprocess_mode="max") + values.ValueClass = RedisValue(lambda: 456) + g2 = Gauge("g", "help", registry=None, multiprocess_mode="max") + self.assertEqual(0, self.registry.get_sample_value("g")) + g1.set(1) + g2.set(2) + self.assertEqual(2, self.registry.get_sample_value("g")) + + def test_gauge_livemax(self) -> None: + g1 = Gauge("g", "help", registry=None, multiprocess_mode="livemax") + values.ValueClass = RedisValue(lambda: 456) + g2 = Gauge("g", "help", registry=None, multiprocess_mode="livemax") + self.assertEqual(0, self.registry.get_sample_value("g")) + g1.set(2) + g2.set(1) + self.assertEqual(2, self.registry.get_sample_value("g")) + mark_process_dead(123) + self.assertEqual(1, self.registry.get_sample_value("g")) + + def test_gauge_sum(self) -> None: + g1 = Gauge("g", "help", registry=None, multiprocess_mode="sum") + values.ValueClass = RedisValue(lambda: 456) + g2 = Gauge("g", "help", registry=None, multiprocess_mode="sum") + self.assertEqual(0, self.registry.get_sample_value("g")) + g1.set(1) + g2.set(2) + self.assertEqual(3, self.registry.get_sample_value("g")) + mark_process_dead(123) + self.assertEqual(3, self.registry.get_sample_value("g")) + + def test_gauge_livesum(self) -> None: + g1 = Gauge("g", "help", registry=None, multiprocess_mode="livesum") + values.ValueClass = RedisValue(lambda: 456) + g2 = Gauge("g", "help", registry=None, multiprocess_mode="livesum") + self.assertEqual(0, self.registry.get_sample_value("g")) + g1.set(1) + g2.set(2) + self.assertEqual(3, self.registry.get_sample_value("g")) + mark_process_dead(123) + self.assertEqual(2, self.registry.get_sample_value("g")) + + def xxx_test_gauge_mostrecent(self) -> None: + g1 = Gauge("g", "help", registry=None, multiprocess_mode="mostrecent") + values.ValueClass = RedisValue(lambda: 456) + g2 = Gauge("g", "help", registry=None, multiprocess_mode="mostrecent") + g2.set(2) + g1.set(1) + self.assertEqual(1, self.registry.get_sample_value("g")) + mark_process_dead(123) + self.assertEqual(1, self.registry.get_sample_value("g")) + + def xxx_test_gauge_livemostrecent(self) -> None: + g1 = Gauge("g", "help", registry=None, multiprocess_mode="livemostrecent") + values.ValueClass = RedisValue(lambda: 456) + g2 = Gauge("g", "help", registry=None, multiprocess_mode="livemostrecent") + g2.set(2) + g1.set(1) + self.assertEqual(1, self.registry.get_sample_value("g")) + mark_process_dead(123) + self.assertEqual(2, self.registry.get_sample_value("g")) + + def test_namespace_subsystem(self) -> None: + c1 = Counter("c", "help", registry=None, namespace="ns", subsystem="ss") + c1.inc(1) + self.assertEqual(1, self.registry.get_sample_value("ns_ss_c_total")) + + def test_counter_across_forks(self) -> None: + pid = 0 + values.ValueClass = RedisValue(lambda: pid) + c1 = Counter("c", "help", registry=None) + self.assertEqual(0, self.registry.get_sample_value("c_total")) + c1.inc(1) + c1.inc(1) + pid = 1 + c1.inc(1) + self.assertEqual(3, self.registry.get_sample_value("c_total")) + # Unlike MultiProcessValue, we don't store any local state + self.assertEqual(3, c1._value.get()) + + def test_collect(self) -> None: + pid = 0 + values.ValueClass = RedisValue(lambda: pid) + labels = {i: i for i in "abcd"} + + def add_label(key: str, value: str) -> dict[str, str]: + l = labels.copy() + l[key] = value + return l + + c = Counter("c", "help", labelnames=labels.keys(), registry=None) + g = Gauge("g", "help", labelnames=labels.keys(), registry=None) + h = Histogram("h", "help", labelnames=labels.keys(), registry=None) + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + h.labels(**labels).observe(1) + + pid = 1 + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + h.labels(**labels).observe(5) + + metrics = {m.name: m for m in self.collector.collect()} + + self.assertEqual(metrics["c"].samples, [Sample("c_total", labels, 2.0)]) + metrics["g"].samples.sort(key=lambda x: x[1]["pid"]) + self.assertEqual( + metrics["g"].samples, + [ + Sample("g", add_label("pid", "0"), 1.0), + Sample("g", add_label("pid", "1"), 1.0), + ], + ) + + expected_histogram = [ + Sample("h_bucket", add_label("le", "0.005"), 0.0), + Sample("h_bucket", add_label("le", "0.01"), 0.0), + Sample("h_bucket", add_label("le", "0.025"), 0.0), + Sample("h_bucket", add_label("le", "0.05"), 0.0), + Sample("h_bucket", add_label("le", "0.075"), 0.0), + Sample("h_bucket", add_label("le", "0.1"), 0.0), + Sample("h_bucket", add_label("le", "0.25"), 0.0), + Sample("h_bucket", add_label("le", "0.5"), 0.0), + Sample("h_bucket", add_label("le", "0.75"), 0.0), + Sample("h_bucket", add_label("le", "1.0"), 1.0), + Sample("h_bucket", add_label("le", "2.5"), 1.0), + Sample("h_bucket", add_label("le", "5.0"), 2.0), + Sample("h_bucket", add_label("le", "7.5"), 2.0), + Sample("h_bucket", add_label("le", "10.0"), 2.0), + Sample("h_bucket", add_label("le", "+Inf"), 2.0), + Sample("h_count", labels, 2.0), + Sample("h_sum", labels, 6.0), + ] + + self.assertEqual(metrics["h"].samples, expected_histogram) + + def test_collect_histogram_ordering(self) -> None: + pid = 0 + values.ValueClass = RedisValue(lambda: pid) + + h = Histogram("h", "help", labelnames=["view"], registry=None) + + h.labels(view="view1").observe(1) + + pid = 1 + + h.labels(view="view1").observe(5) + h.labels(view="view2").observe(1) + + metrics = {m.name: m for m in self.collector.collect()} + + expected_histogram = [ + Sample("h_bucket", {"view": "view1", "le": "0.005"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.01"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.025"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.05"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.075"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.1"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.25"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.5"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "0.75"}, 0.0), + Sample("h_bucket", {"view": "view1", "le": "1.0"}, 1.0), + Sample("h_bucket", {"view": "view1", "le": "2.5"}, 1.0), + Sample("h_bucket", {"view": "view1", "le": "5.0"}, 2.0), + Sample("h_bucket", {"view": "view1", "le": "7.5"}, 2.0), + Sample("h_bucket", {"view": "view1", "le": "10.0"}, 2.0), + Sample("h_bucket", {"view": "view1", "le": "+Inf"}, 2.0), + Sample("h_count", {"view": "view1"}, 2.0), + Sample("h_sum", {"view": "view1"}, 6.0), + Sample("h_bucket", {"view": "view2", "le": "0.005"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.01"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.025"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.05"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.075"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.1"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.25"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.5"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "0.75"}, 0.0), + Sample("h_bucket", {"view": "view2", "le": "1.0"}, 1.0), + Sample("h_bucket", {"view": "view2", "le": "2.5"}, 1.0), + Sample("h_bucket", {"view": "view2", "le": "5.0"}, 1.0), + Sample("h_bucket", {"view": "view2", "le": "7.5"}, 1.0), + Sample("h_bucket", {"view": "view2", "le": "10.0"}, 1.0), + Sample("h_bucket", {"view": "view2", "le": "+Inf"}, 1.0), + Sample("h_count", {"view": "view2"}, 1.0), + Sample("h_sum", {"view": "view2"}, 1.0), + ] + + self.assertEqual(metrics["h"].samples, expected_histogram) + + def test_restrict(self) -> None: + pid = 0 + values.ValueClass = RedisValue(lambda: pid) + labels = {i: i for i in "abcd"} + + c = Counter("c", "help", labelnames=labels.keys(), registry=None) + g = Gauge("g", "help", labelnames=labels.keys(), registry=None) + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + + pid = 1 + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + + metrics = { + m.name: m for m in self.registry.restricted_registry(["c_total"]).collect() + } + + self.assertEqual(metrics.keys(), {"c"}) + + self.assertEqual(metrics["c"].samples, [Sample("c_total", labels, 2.0)]) + + def test_collect_preserves_help(self) -> None: + pid = 0 + values.ValueClass = RedisValue(lambda: pid) + labels = {i: i for i in "abcd"} + + c = Counter("c", "c help", labelnames=labels.keys(), registry=None) + g = Gauge("g", "g help", labelnames=labels.keys(), registry=None) + h = Histogram("h", "h help", labelnames=labels.keys(), registry=None) + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + h.labels(**labels).observe(1) + + pid = 1 + + c.labels(**labels).inc(1) + g.labels(**labels).set(1) + h.labels(**labels).observe(5) + + metrics = {m.name: m for m in self.collector.collect()} + + self.assertEqual(metrics["c"].documentation, "c help") + self.assertEqual(metrics["g"].documentation, "g help") + self.assertEqual(metrics["h"].documentation, "h help") + + def test_remove_clear_warning(self) -> None: + with warnings.catch_warnings(record=True) as w: + values.ValueClass = get_value_class() + registry = CollectorRegistry() + collector = RedisCollector(registry) + counter = Counter("c", "help", labelnames=["label"], registry=None) + counter.labels("label").inc() + counter.remove("label") + counter.clear() + assert issubclass(w[0].category, UserWarning) + assert "Removal of labels has not been implemented" in str(w[0].message) + assert issubclass(w[-1].category, UserWarning) + assert "Clearing of labels has not been implemented" in str(w[-1].message) + + def test_child_name_is_built_once_with_namespace_subsystem_unit(self) -> None: + """ + Repro for #1035: + In multiprocess mode, child metrics must NOT rebuild the full name + (namespace/subsystem/unit) a second time. The exported family name should + be built once, and Counter samples should use "_total". + """ + from prometheus_client import Counter + + class CustomCounter(Counter): + def __init__( + self, + name: str, + documentation: str, + labelnames: Sequence[str] = (), + namespace: str = "mydefaultnamespace", + subsystem: str = "mydefaultsubsystem", + unit: str = "", + registry: CollectorRegistry | None = None, + _labelvalues: Sequence[str] | None = None, + ): + # Intentionally provide non-empty defaults to trigger the bug path. + super().__init__( + name=name, + documentation=documentation, + labelnames=labelnames, + namespace=namespace, + subsystem=subsystem, + unit=unit, + registry=registry, + _labelvalues=_labelvalues, + ) + + # Create a Counter with explicit namespace/subsystem/unit + c = CustomCounter( + name="m", + documentation="help", + labelnames=("status", "method"), + namespace="ns", + subsystem="ss", + unit="seconds", # avoid '_total_total' confusion + registry=None, # not registered in local registry in multiprocess mode + ) + + # Create two labeled children + c.labels(status="200", method="GET").inc() + c.labels(status="404", method="POST").inc() + + # Collect from the multiprocess collector initialized in setUp() + metrics = {m.name: m for m in self.collector.collect()} + + # Family name should be built once (no '_total' in family name) + expected_family = "ns_ss_m_seconds" + self.assertIn(expected_family, metrics, f"missing family {expected_family}") + + # Counter samples must use '_total' + mf = metrics[expected_family] + sample_names = {s.name for s in mf.samples} + self.assertTrue( + all(name == expected_family + "_total" for name in sample_names), + f"unexpected sample names: {sample_names}", + ) + + # Ensure no double-built prefix sneaks in (the original bug) + bad_prefix = "mydefaultnamespace_mydefaultsubsystem_" + all_names = {mf.name, *sample_names} + self.assertTrue( + all(not n.startswith(bad_prefix) for n in all_names), + f"found double-built name(s): {[n for n in all_names if n.startswith(bad_prefix)]}", + ) + + def test_child_preserves_parent_context_for_subclasses(self) -> None: + """ + Ensure child metrics preserve parent's namespace/subsystem/unit information + so that subclasses can correctly use these parameters in their logic. + """ + + class ContextAwareCounter(Counter): + def __init__( + self, + name: str, + documentation: str, + labelnames: Sequence[str] = (), + namespace: str = "", + subsystem: str = "", + unit: str = "", + **kwargs: Any, + ): + self.context = { + "namespace": namespace, + "subsystem": subsystem, + "unit": unit, + } + super().__init__( + name, + documentation, + labelnames=labelnames, + namespace=namespace, + subsystem=subsystem, + unit=unit, + **kwargs, + ) + + parent = ContextAwareCounter( + "m", + "help", + labelnames=["status"], + namespace="prod", + subsystem="api", + unit="seconds", + registry=None, + ) + + child = parent.labels(status="200") + + # Verify that child retains parent's context + self.assertEqual(child.context["namespace"], "prod") + self.assertEqual(child.context["subsystem"], "api") + self.assertEqual(child.context["unit"], "seconds") diff --git a/tox.ini b/tox.ini index 992bd0a7..0a6dff73 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,8 @@ deps = pytest pytest-benchmark attrs + fakeredis + redis {py3.9,pypy3.9}: twisted {py3.9,pypy3.9}: aiohttp {py3.9,pypy3.9}: django