diff --git a/Doc/library/collections.rst b/Doc/library/collections.rst index cb9300f072b9e7..0b750bfca61be5 100644 --- a/Doc/library/collections.rst +++ b/Doc/library/collections.rst @@ -730,7 +730,7 @@ stack manipulations such as ``dup``, ``drop``, ``swap``, ``over``, ``pick``, defaultdict(default_factory, iterable, /, **kwargs) Return a new dictionary-like object. :class:`defaultdict` is a subclass of the - built-in :class:`dict` class. It overrides one method and adds one writable + built-in :class:`dict` class. It defines two methods and adds one writable instance variable. The remaining functionality is the same as for the :class:`dict` class and is not documented here. @@ -740,33 +740,39 @@ stack manipulations such as ``dup``, ``drop``, ``swap``, ``over``, ``pick``, arguments. - :class:`defaultdict` objects support the following method in addition to the - standard :class:`dict` operations: + :class:`defaultdict` defines the following methods: - .. method:: __missing__(key, /) + .. method:: __getitem__(key, /) + + Return ``self[key]``. If the item doesn't exist, the :meth:`__missing__` + method is called to create it. - If the :attr:`default_factory` attribute is ``None``, this raises a - :exc:`KeyError` exception with the *key* as argument. + When :term:`free threading` is enabled, the defaultdict is locked while + the key is being looked up and the :meth:`__missing__` method is being + called, thus ensuring that only one default value is generated and + inserted for each missing key. - If :attr:`default_factory` is not ``None``, it is called without arguments - to provide a default value for the given *key*, this value is inserted in - the dictionary for the *key*, and returned. + .. method:: __missing__(key, /) - If calling :attr:`default_factory` raises an exception this exception is - propagated unchanged. + Equivalent to:: - This method is called by the :meth:`~object.__getitem__` method of the - :class:`dict` class when the requested key is not found; whatever it - returns or raises is then returned or raised by :meth:`~object.__getitem__`. + if self.default_factory is None: + raise KeyError(key) + self[key] = value = self.default_factory() + return value - Note that :meth:`__missing__` is *not* called for any operations besides - :meth:`~object.__getitem__`. This means that :meth:`~dict.get` will, like - normal dictionaries, return ``None`` as a default rather than using + Keep in mind that this method is *not* called for any operations besides + ``dd[key]``. This means that ``dd.get(key)`` will, like normal + dictionaries, return ``None`` as a default rather than using :attr:`default_factory`. + A direct call to this method (meaning a call that isn't coming from + :meth:`__getitem__`) can create a :term:`race condition`. To reset an + item to a default value the next time it's accessed, use the + :meth:`~dict.pop` method to safely remove the current value. - :class:`defaultdict` objects support the following instance variable: + :class:`defaultdict` objects support the following instance variable: .. attribute:: default_factory @@ -774,9 +780,14 @@ stack manipulations such as ``dup``, ``drop``, ``swap``, ``over``, ``pick``, it is initialized from the first argument to the constructor, if present, or to ``None``, if absent. + .. versionchanged:: 3.9 - Added merge (``|``) and update (``|=``) operators, specified in - :pep:`584`. + Added merge (``|``) and update (``|=``) operators, specified in + :pep:`584`. + + .. versionchanged:: 3.15 + Added the :meth:`__getitem__` method which is safe to use with + :term:`free threading` enabled. :class:`defaultdict` Examples diff --git a/Include/internal/pycore_critical_section.h b/Include/internal/pycore_critical_section.h index 2a2846b1296b90..415ef273223860 100644 --- a/Include/internal/pycore_critical_section.h +++ b/Include/internal/pycore_critical_section.h @@ -95,6 +95,9 @@ _PyCriticalSection2_BeginSlow(PyThreadState *tstate, PyCriticalSection2 *c, PyMu PyAPI_FUNC(void) _PyCriticalSection_SuspendAll(PyThreadState *tstate); +int +_PyCriticalSection_WarnIfNotHeld(PyObject *op, const char *message); + #ifdef Py_GIL_DISABLED static inline int diff --git a/Include/internal/pycore_dict.h b/Include/internal/pycore_dict.h index 6c6e3b77e69fab..56a29c586f67f5 100644 --- a/Include/internal/pycore_dict.h +++ b/Include/internal/pycore_dict.h @@ -123,6 +123,8 @@ PyAPI_FUNC(Py_ssize_t) _Py_dict_lookup(PyDictObject *mp, PyObject *key, Py_hash_ extern Py_ssize_t _Py_dict_lookup_threadsafe(PyDictObject *mp, PyObject *key, Py_hash_t hash, PyObject **value_addr); extern Py_ssize_t _Py_dict_lookup_threadsafe_stackref(PyDictObject *mp, PyObject *key, Py_hash_t hash, _PyStackRef *value_addr); +extern void _Py_dict_unhashable_type(PyObject *op, PyObject *key); + extern int _PyDict_GetMethodStackRef(PyDictObject *dict, PyObject *name, _PyStackRef *method); // Exported for external JIT support diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py index a193eb10f16d17..4fb01486891f4d 100644 --- a/Lib/test/test_defaultdict.py +++ b/Lib/test/test_defaultdict.py @@ -2,7 +2,9 @@ import copy import pickle +import sys import unittest +import warnings from collections import defaultdict @@ -48,10 +50,16 @@ def test_basic(self): self.assertRaises(TypeError, defaultdict, 1) def test_missing(self): - d1 = defaultdict() - self.assertRaises(KeyError, d1.__missing__, 42) - d1.default_factory = list - self.assertEqual(d1.__missing__(42), []) + with warnings.catch_warnings(record=True, action='always') as w: + d1 = defaultdict() + self.assertRaises(KeyError, d1.__missing__, 42) + d1.default_factory = list + v1 = d1.__missing__(42) + self.assertEqual(v1, []) + v2 = d1.__missing__(42) + self.assertEqual(v2, []) + self.assertIsNot(v2, v1) + self.assertEqual(len(w), 0 if sys._is_gil_enabled() else 3) def test_repr(self): d1 = defaultdict() @@ -186,7 +194,7 @@ def test_union(self): with self.assertRaises(TypeError): i |= None - def test_factory_conflict_with_set_value(self): + def test_reentering_getitem_method(self): key = "conflict_test" count = 0 @@ -201,7 +209,7 @@ def default_factory(): test_dict = defaultdict(default_factory) self.assertEqual(count, 0) - self.assertEqual(test_dict[key], 2) + self.assertEqual(test_dict[key], 1) self.assertEqual(count, 2) def test_repr_recursive_factory(self): diff --git a/Misc/NEWS.d/next/Library/2026-04-21-12-06-41.gh-issue-148242.eCy0eS.rst b/Misc/NEWS.d/next/Library/2026-04-21-12-06-41.gh-issue-148242.eCy0eS.rst new file mode 100644 index 00000000000000..377d206a59a7be --- /dev/null +++ b/Misc/NEWS.d/next/Library/2026-04-21-12-06-41.gh-issue-148242.eCy0eS.rst @@ -0,0 +1,2 @@ +Restore the historical behavior of the :class:`~collections.defaultdict` class, +while keeping it safe to use with :term:`free threading`. diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index 4ff05727ebc8ce..dbc4dc7571d13f 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -1,5 +1,6 @@ #include "Python.h" #include "pycore_call.h" // _PyObject_CallNoArgs() +#include "pycore_critical_section.h" // _PyCriticalSection_WarnIfNotHeld() #include "pycore_dict.h" // _PyDict_GetItem_KnownHash() #include "pycore_long.h" // _PyLong_GetZero() #include "pycore_moduleobject.h" // _PyModule_GetState() @@ -2222,9 +2223,39 @@ typedef struct { static PyType_Spec defdict_spec; +PyDoc_STRVAR(defdict_getitem_doc, +"__getitem__($self, key, /)\n--\n\n\ +Return self[key]. If the item doesn't exist, self.__missing__(key) is called\n\ +to create it.\ +"); + +static PyObject * +defdict_subscript(PyObject *op, PyObject *key) +{ + Py_ssize_t ix; + Py_hash_t hash; + PyObject *value; + + hash = _PyObject_HashFast(key); + if (hash == -1) { + _Py_dict_unhashable_type(op, key); + return NULL; + } + Py_BEGIN_CRITICAL_SECTION(op); + ix = _Py_dict_lookup((PyDictObject *)op, key, hash, &value); + if (value != NULL) { + Py_INCREF(value); + } else if (ix != DKIX_ERROR) { + value = PyObject_CallMethodOneArg(op, &_Py_ID(__missing__), key); + } + Py_END_CRITICAL_SECTION(); + return value; +} + PyDoc_STRVAR(defdict_missing_doc, -"__missing__(key) # Called by __getitem__ for missing key; pseudo-code:\n\ - if self.default_factory is None: raise KeyError((key,))\n\ +"__missing__($self, key, /)\n--\n\n\ + # Called by __getitem__ for missing key. Equivalent to:\n\ + if self.default_factory is None: raise KeyError(key)\n\ self[key] = value = self.default_factory()\n\ return value\n\ "); @@ -2232,26 +2263,26 @@ PyDoc_STRVAR(defdict_missing_doc, static PyObject * defdict_missing(PyObject *op, PyObject *key) { + if (_PyCriticalSection_WarnIfNotHeld(op, + "the defaultdict.__missing__ method should not be called directly; " + "use dd.pop(key, None) to safely trigger a reset to a default value " + "the next time key is accessed") < 0) + return NULL; defdictobject *dd = defdictobject_CAST(op); PyObject *factory = dd->default_factory; PyObject *value; if (factory == NULL || factory == Py_None) { - /* XXX Call dict.__missing__(key) */ - PyObject *tup; - tup = PyTuple_Pack(1, key); - if (!tup) return NULL; - PyErr_SetObject(PyExc_KeyError, tup); - Py_DECREF(tup); + _PyErr_SetKeyError(key); return NULL; } value = _PyObject_CallNoArgs(factory); if (value == NULL) return value; - PyObject *result = NULL; - (void)PyDict_SetDefaultRef(op, key, value, &result); - // 'result' is NULL, or a strong reference to 'value' or 'op[key]' - Py_DECREF(value); - return result; + if (PyObject_SetItem(op, key, value) < 0) { + Py_DECREF(value); + return NULL; + } + return value; } static inline PyObject* @@ -2331,6 +2362,8 @@ defdict_reduce(PyObject *op, PyObject *Py_UNUSED(dummy)) } static PyMethodDef defdict_methods[] = { + {"__getitem__", defdict_subscript, METH_O|METH_COEXIST, + defdict_getitem_doc}, {"__missing__", defdict_missing, METH_O, defdict_missing_doc}, {"copy", defdict_copy, METH_NOARGS, @@ -2511,6 +2544,7 @@ static PyType_Slot defdict_slots[] = { {Py_tp_init, defdict_init}, {Py_tp_alloc, PyType_GenericAlloc}, {Py_tp_free, PyObject_GC_Del}, + {Py_mp_subscript, defdict_subscript}, {0, NULL}, }; diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 09db93b2d31820..1ae4be4b234ac9 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2398,8 +2398,8 @@ PyDict_GetItem(PyObject *op, PyObject *key) "PyDict_GetItemRef() or PyDict_GetItemWithError()"); } -static void -dict_unhashable_type(PyObject *op, PyObject *key) +void +_Py_dict_unhashable_type(PyObject *op, PyObject *key) { PyObject *exc = PyErr_GetRaisedException(); assert(exc != NULL); @@ -2428,7 +2428,7 @@ _PyDict_LookupIndexAndValue(PyDictObject *mp, PyObject *key, PyObject **value) Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type((PyObject*)mp, key); + _Py_dict_unhashable_type((PyObject*)mp, key); return -1; } @@ -2532,7 +2532,7 @@ PyDict_GetItemRef(PyObject *op, PyObject *key, PyObject **result) Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(op, key); + _Py_dict_unhashable_type(op, key); *result = NULL; return -1; } @@ -2548,7 +2548,7 @@ _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject ** Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type((PyObject*)op, key); + _Py_dict_unhashable_type((PyObject*)op, key); *result = NULL; return -1; } @@ -2586,7 +2586,7 @@ PyDict_GetItemWithError(PyObject *op, PyObject *key) } hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(op, key); + _Py_dict_unhashable_type(op, key); return NULL; } @@ -2746,7 +2746,7 @@ setitem_take2_lock_held(PyDictObject *mp, PyObject *key, PyObject *value) { Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type((PyObject*)mp, key); + _Py_dict_unhashable_type((PyObject*)mp, key); Py_DECREF(key); Py_DECREF(value); return -1; @@ -2924,7 +2924,7 @@ PyDict_DelItem(PyObject *op, PyObject *key) assert(key); Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(op, key); + _Py_dict_unhashable_type(op, key); return -1; } @@ -3266,7 +3266,7 @@ pop_lock_held(PyObject *op, PyObject *key, PyObject **result) Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(op, key); + _Py_dict_unhashable_type(op, key); if (result) { *result = NULL; } @@ -3679,7 +3679,7 @@ dict_subscript(PyObject *self, PyObject *key) hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(self, key); + _Py_dict_unhashable_type(self, key); return NULL; } ix = _Py_dict_lookup_threadsafe(mp, key, hash, &value); @@ -4650,7 +4650,7 @@ dict_get_impl(PyDictObject *self, PyObject *key, PyObject *default_value) hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type((PyObject*)self, key); + _Py_dict_unhashable_type((PyObject*)self, key); return NULL; } ix = _Py_dict_lookup_threadsafe(self, key, hash, &val); @@ -4687,7 +4687,7 @@ dict_setdefault_ref_lock_held(PyObject *d, PyObject *key, PyObject *default_valu hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(d, key); + _Py_dict_unhashable_type(d, key); if (result) { *result = NULL; } @@ -5128,7 +5128,7 @@ dict_contains(PyObject *op, PyObject *key) { Py_hash_t hash = _PyObject_HashFast(key); if (hash == -1) { - dict_unhashable_type(op, key); + _Py_dict_unhashable_type(op, key); return -1; } @@ -7234,7 +7234,7 @@ _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value) if (value == NULL) { Py_hash_t hash = _PyObject_HashFast(name); if (hash == -1) { - dict_unhashable_type((PyObject*)dict, name); + _Py_dict_unhashable_type((PyObject*)dict, name); return -1; } return _PyDict_DelItem_KnownHash_LockHeld((PyObject *)dict, name, hash); diff --git a/Python/critical_section.c b/Python/critical_section.c index 98e23eda7cdd77..859e3537ed3319 100644 --- a/Python/critical_section.c +++ b/Python/critical_section.c @@ -201,3 +201,24 @@ PyCriticalSection2_End(PyCriticalSection2 *c) _PyCriticalSection2_End(_PyThreadState_GET(), c); #endif } + +int +_PyCriticalSection_WarnIfNotHeld(PyObject *op, const char *message) +{ +#ifdef Py_GIL_DISABLED + PyMutex *mutex = &_PyObject_CAST(op)->ob_mutex; + PyThreadState *tstate = _PyThreadState_GET(); + uintptr_t prev = tstate->critical_section; + if (prev & _Py_CRITICAL_SECTION_TWO_MUTEXES) { + PyCriticalSection2 *cs = (PyCriticalSection2 *)(prev & ~_Py_CRITICAL_SECTION_MASK); + if (cs == NULL || (cs->_cs_base._cs_mutex != mutex && cs->_cs_mutex2 != mutex)) + return PyErr_WarnEx(NULL, message, 2); + } + else { + PyCriticalSection *cs = (PyCriticalSection *)(prev & ~_Py_CRITICAL_SECTION_MASK); + if (cs == NULL || cs->_cs_mutex != mutex) + return PyErr_WarnEx(NULL, message, 2); + } +#endif + return 0; +}