Fast and safe weak value dictionary

AUTHORS:

  • Simon King (2013-10)
  • Nils Bruin (2013-10)
  • Julian Rueth (2014-03-16): improved handling of unhashable objects

Python’s weakref module provides WeakValueDictionary. It behaves similarly to a dictionary, but it does not prevent its values from being garbage collected. To this end, it stores the values through weak references with callback functions: the callback deletes a key-value pair from the dictionary as soon as the value becomes subject to garbage collection.
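For comparison, the standard-library behavior described above can be seen in a few lines of plain Python, without Sage. The class name Value is made up for this illustration.

```python
import gc
import weakref


class Value:
    """Instances of ordinary classes support weak references."""


d = weakref.WeakValueDictionary()
v = Value()
d["a"] = v
print(len(d))  # 1: the value is still strongly referenced by v

del v  # drop the only strong reference to the value
# CPython frees the value right away by reference counting; gc.collect()
# is only a safeguard for implementations without reference counting.
gc.collect()
print(len(d))  # 0: the weak reference's callback removed the "a" entry
```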

However, a problem arises if hash and comparison of the key depend on the value that is being garbage collected:

sage: import weakref
sage: class Vals(object): pass
sage: class Keys:
....:     def __init__(self, val):
....:         self.val = weakref.ref(val)
....:     def __hash__(self):
....:         return hash(self.val())
....:     def __eq__(self, other):
....:         return self.val() == other.val()
....:     def __ne__(self, other):
....:         return self.val() != other.val()
sage: ValList = [Vals() for _ in range(10)]
sage: D = weakref.WeakValueDictionary()
sage: for v in ValList:
....:     D[Keys(v)] = v
sage: len(D)
10
sage: del ValList, v
sage: len(D) > 1
True

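Hence, the defunct items have not been removed from the dictionary. The reason is that by the time the callback runs, the weak reference inside each key is already dead, so the key can no longer reproduce the hash and equality under which it was stored, and the lookup performed by the callback fails to find the item.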

Therefore, Sage provides an alternative implementation, sage.misc.weak_dict.WeakValueDictionary, whose callback removes the defunct item not based on the hash and equality check of the key (this is what fails in the example above), but based on comparison by identity. This is possible because weak references with a callback function are distinct objects even if they point to the same object. Hence, even if the same object O occurs as the value for several keys, each weak reference to O corresponds to a unique key. We see no error messages, and the items are correctly removed:

sage: ValList = [Vals() for _ in range(10)]
sage: import sage.misc.weak_dict
sage: D = sage.misc.weak_dict.WeakValueDictionary()
sage: for v in ValList:
....:     D[Keys(v)] = v
sage: len(D)
10
sage: del ValList
sage: len(D)
1
sage: del v
sage: len(D)
0

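Another problem arises when iterating over the items of a dictionary: if garbage collection occurs during iteration, the content of the dictionary changes, and the iteration breaks for weakref.WeakValueDictionary. In the following setup, garbage collection is disabled, so the five reference cycles deleted from C are not collected yet and all ten items are still present: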

sage: class Cycle:
....:     def __init__(self):
....:         self.selfref = self
sage: C = [Cycle() for n in range(10)]
sage: D = weakref.WeakValueDictionary(enumerate(C))
sage: import gc
sage: gc.disable()
sage: del C[:5]
sage: len(D)
10

With sage.misc.weak_dict.WeakValueDictionary, the behaviour is safer. Note that iteration over a weak value dictionary is non-deterministic, since the lifetime of values (and hence the presence of keys) in the dictionary may depend on when garbage collection occurs. The method implemented here at least postpones dictionary mutations caused by garbage collection callbacks. This means that as long as at least one iterator is active on a dictionary, none of its keys will be deallocated (which could have side effects). Which entries are returned still depends, of course, on when garbage collection occurs. Note that when a key is returned as “present” in the dictionary, there is no guarantee that its value can actually be retrieved: it may have been garbage collected in the meantime.
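The idea of postponing mutations while an iterator is active can be sketched in plain Python. This is not Sage's code: the class DeferringWeakValueDict and its attributes are hypothetical, and only the queuing technique is shown. Callbacks that fire during iteration append to a queue, and the last iterator to finish applies the queued removals.

```python
import weakref


class DeferringWeakValueDict:
    """Sketch (hypothetical class): callbacks that fire while an iterator is
    active are queued and applied when the last iterator finishes."""

    def __init__(self):
        self._data = {}
        self._active_iterators = 0
        self._pending = []  # (key, ref) pairs whose removal was postponed

    def _remove(self, key, ref):
        # Delete only if the entry still holds this very weak reference.
        if self._data.get(key) is ref:
            del self._data[key]

    def __setitem__(self, key, value):
        selfref = weakref.ref(self)  # callbacks must not keep the dict alive

        def callback(ref):
            owner = selfref()
            if owner is None:
                return
            if owner._active_iterators:
                owner._pending.append((key, ref))  # postpone the mutation
            else:
                owner._remove(key, ref)

        self._data[key] = weakref.ref(value, callback)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        self._active_iterators += 1
        try:
            for key, ref in self._data.items():
                if ref() is not None:  # report only keys whose value is alive
                    yield key
        finally:
            self._active_iterators -= 1
            if not self._active_iterators:
                pending, self._pending = self._pending, []
                for key, ref in pending:
                    self._remove(key, ref)


class Obj:
    pass


objs = [Obj() for _ in range(5)]
D = DeferringWeakValueDict()
for i in range(len(objs)):
    D[i] = objs[i]

sizes = []
for key in D:
    if key == 0:
        objs.clear()  # all values die now; their removals are queued
    sizes.append(len(D))
print(sizes)   # [5]: the dictionary did not shrink during the loop
print(len(D))  # 0: the queued removals were applied after the loop
```

Two limitations of this sketch: explicit insertions by the caller during iteration are not handled, and if a loop breaks early, the queue is only flushed once the generator is closed.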

The variant CachedWeakValueDictionary additionally keeps strong references to the most recently added values. This ensures that values are not deleted immediately after they are added to the dictionary. This is mostly useful for implementing cached functions.

Note that Sage’s weak value dictionary is actually a subclass of dict, in contrast to weakref’s weak value dictionary:

sage: issubclass(weakref.WeakValueDictionary, dict)
False
sage: issubclass(sage.misc.weak_dict.WeakValueDictionary, dict)
True

See trac ticket #13394 for a discussion of some of the design considerations.

class sage.misc.weak_dict.CachedWeakValueDictionary

Bases: sage.misc.weak_dict.WeakValueDictionary

This class extends WeakValueDictionary with a strong cache of the most recently added values. It is meant for situations where deleting a value too early can cause significant performance losses, but where keeping a value alive too long does not hurt much. This is typically the case with cached functions.
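A rough approximation of this idea can be built from Python's own weakref.WeakValueDictionary plus a collections.deque with maxlen as the strong cache. The name CachedWeakValueDict and the eviction policy (a first-in, first-out buffer of the most recent assignments) are assumptions of this sketch, not taken from Sage's implementation. With cache=4, the demo prints the same sequence as the Sage example in this section.

```python
import weakref
from collections import deque


class CachedWeakValueDict(weakref.WeakValueDictionary):
    """Sketch (hypothetical name): a weak value dictionary that also keeps
    strong references to the `cache` most recently added values."""

    def __init__(self, cache=16):
        self._strong = deque(maxlen=cache)  # oldest entry falls off the left
        super().__init__()

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self._strong.append(value)  # keeps the value alive for a while


class Test:
    pass


D = CachedWeakValueDict(cache=4)
tmp = Test()
D[0] = tmp
del tmp
print(0 in D)  # True: the strong cache still holds the value

results = []
for i in range(5):
    D[1] = Test()
    results.append(0 in D)
print(results)  # [True, True, True, False, False]
```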

EXAMPLES:

We illustrate the difference between WeakValueDictionary and CachedWeakValueDictionary. An item is removed from a WeakValueDictionary as soon as there are no strong references to its value:

sage: from sage.misc.weak_dict import WeakValueDictionary
sage: D = WeakValueDictionary()
sage: class Test(object): pass
sage: tmp = Test()
sage: D[0] = tmp
sage: 0 in D
True
sage: del tmp
sage: 0 in D
False

So if a cached function repeatedly creates the same temporary object and then discards it (for example, in a helper function called from a loop), this caching does not help at all. With CachedWeakValueDictionary, the most recently added values are not deleted immediately. After enough new values have been added, the item is removed anyway:

sage: from sage.misc.weak_dict import CachedWeakValueDictionary
sage: D = CachedWeakValueDictionary(cache=4)
sage: class Test(object): pass
sage: tmp = Test()
sage: D[0] = tmp
sage: 0 in D
True
sage: del tmp
sage: 0 in D
True
sage: for i in range(5):
....:     D[1] = Test()
....:     print(0 in D)
True
True
True
False
False

class sage.misc.weak_dict.WeakValueDictEraser

Bases: object

Erases items from a sage.misc.weak_dict.WeakValueDictionary when a weak reference becomes invalid.

This is for internal use only. Instances of this class are passed as the callback function when creating a weak reference.
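The general pattern of a callable object used as a weak reference callback can be sketched in plain Python. The classes WeakDict and Eraser are hypothetical and are not Sage's implementation. The eraser holds the dictionary only weakly, so that the callbacks do not keep it alive. Note that this simplified version locates the entry through weakref.KeyedRef's key attribute, so it still hashes the key, which the Sage callback avoids.

```python
import weakref


class WeakDict(dict):
    """Plain dicts cannot be weakly referenced, but dict subclasses can."""


class Eraser:
    """Hypothetical callback object in the spirit of WeakValueDictEraser."""

    def __init__(self, mapping):
        # Hold the dictionary weakly so that callbacks do not keep it alive.
        self.mapping_ref = weakref.ref(mapping)

    def __call__(self, ref):
        mapping = self.mapping_ref()
        if mapping is None:  # the dictionary itself is already gone
            return
        # Remove the entry only if it still stores this very reference.
        if mapping.get(ref.key) is ref:
            del mapping[ref.key]


class Value:
    pass


D = WeakDict()
eraser = Eraser(D)
v = Value()
D[1] = weakref.KeyedRef(v, eraser, 1)  # KeyedRef remembers its key
print(len(D))  # 1
del v
print(len(D))  # 0
```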

EXAMPLES:

sage: from sage.misc.weak_dict import WeakValueDictionary
sage: v = frozenset([1])
sage: D = WeakValueDictionary({1 : v})
sage: len(D)
1
sage: del v
sage: len(D)
0

AUTHOR:

  • Nils Bruin (2013-11)

class sage.misc.weak_dict.WeakValueDictionary

Bases: dict

IMPLEMENTATION:

The WeakValueDictionary inherits from dict. In its implementation, it stores weakrefs to the actual values under the keys. All access routines are wrapped to transparently place and remove these weakrefs.
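A minimal pure-Python sketch of this design follows, under our own assumptions. The class SketchWeakValueDict and the stand-in Field class are hypothetical, and the real class is implemented in Cython with much more care. Values are stored as weak references, and the wrapped access routines unwrap them, treating a dead reference like a missing key. It also imitates get, pop, popitem and setdefault as documented for the methods of this class.

```python
import weakref


class SketchWeakValueDict(dict):
    """Hypothetical dict subclass that stores weak references as dict values.

    Not Sage's implementation: no iteration safety, no identity-based
    removal, and only a few access routines are wrapped.
    """

    def _make_callback(self, key):
        selfref = weakref.ref(self)  # do not keep the dictionary alive

        def remove(ref):
            d = selfref()
            # Delete only if the entry still holds this very weak reference.
            if d is not None and dict.get(d, key) is ref:
                dict.__delitem__(d, key)

        return remove

    def __setitem__(self, key, value):
        ref = weakref.ref(value, self._make_callback(key))
        dict.__setitem__(self, key, ref)

    def __getitem__(self, key):
        value = dict.__getitem__(self, key)()
        if value is None:  # referent is dead but not yet removed
            raise KeyError(key)
        return value

    def __contains__(self, key):
        ref = dict.get(self, key)
        return ref is not None and ref() is not None

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def setdefault(self, key, default):
        try:
            return self[key]
        except KeyError:
            self[key] = default  # default must support weak references
            return default

    def pop(self, key):
        value = self[key]  # raises KeyError for unknown or dead keys
        dict.__delitem__(self, key)
        return value

    def popitem(self):
        while True:
            key, ref = dict.popitem(self)  # raises KeyError when empty
            value = ref()
            if value is not None:  # silently drop entries whose value died
                return key, value


class Field:
    """Stand-in for values such as GF(p)."""

    def __init__(self, p):
        self.p = p


fields = [Field(p) for p in (2, 3, 5, 7)]
D = SketchWeakValueDict()
for i in range(len(fields)):
    D[i] = fields[i]

print(D.get(1).p)              # 3
print(D.get(10, "not found"))  # not found
del fields[1]                  # the value stored under key 1 dies
print(1 in D, len(D))          # False 3
print(D.pop(0).p, 0 in D)      # 2 False
```

Because the sketch subclasses dict, isinstance(D, dict) holds. Unwrapped dict methods such as values() would still expose the raw weak references, which is why every access routine has to be wrapped in a complete implementation.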

NOTE:

In contrast to weakref.WeakValueDictionary in Python’s weakref module, the callback does not need to assume that the dictionary key is a valid Python object when it is called. There is no need to compute the hash or compare the dictionary keys. This is why the example below would not work with weakref.WeakValueDictionary, but does work with sage.misc.weak_dict.WeakValueDictionary.
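Sage works on the dict internals directly in Cython. The following is only a pure-Python analogue of removal by identity, written under our own assumptions; the names _StoredKey and IdentityWeakValueDict are hypothetical. Each entry remembers the hash computed at insertion, so the callback never has to hash or compare the now-defunct key. It relies on the fact that CPython's dict lookup checks key identity before calling __eq__. Run on the problematic keys from the example, all entries are removed.

```python
import weakref


class _StoredKey:
    """Hypothetical wrapper: keeps the key and the hash computed at insertion."""

    __slots__ = ("key", "hash")

    def __init__(self, key):
        self.key = key
        self.hash = hash(key)

    def __hash__(self):
        return self.hash  # never recomputed from a possibly dead key

    def __eq__(self, other):
        # Only reached when identity fails; the stored wrapper itself is
        # found by the dict's identity check without calling this method.
        return self.key is other.key or self.key == other.key


class IdentityWeakValueDict:
    """Sketch, not Sage's implementation: callbacks find entries by identity."""

    def __init__(self):
        self._data = {}

    def __setitem__(self, key, value):
        stored = _StoredKey(key)
        selfref = weakref.ref(self)  # callbacks must not keep the dict alive

        def callback(ref):
            owner = selfref()
            # Delete the entry only if it still holds this very reference.
            if owner is not None and owner._data.get(stored) is ref:
                del owner._data[stored]

        self._data[stored] = weakref.ref(value, callback)

    def __len__(self):
        return len(self._data)


class Vals:
    pass


class Keys:
    """The problematic keys from the example: hash and equality use the value."""

    def __init__(self, val):
        self.val = weakref.ref(val)

    def __hash__(self):
        return hash(self.val())

    def __eq__(self, other):
        return self.val() == other.val()


vals = [Vals() for _ in range(10)]
D = IdentityWeakValueDict()
for val in vals:
    D[Keys(val)] = val
print(len(D))  # 10
del vals, val
print(len(D))  # 0 in CPython: every entry was removed by its callback
```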

EXAMPLES:

sage: import weakref
sage: class Vals(object): pass
sage: class Keys:
....:     def __init__(self, val):
....:         self.val = weakref.ref(val)
....:     def __hash__(self):
....:         return hash(self.val())
....:     def __eq__(self, other):
....:         return self.val() == other.val()
....:     def __ne__(self, other):
....:         return self.val() != other.val()
sage: ValList = [Vals() for _ in range(10)]
sage: import sage.misc.weak_dict
sage: D = sage.misc.weak_dict.WeakValueDictionary()
sage: for v in ValList:
....:     D[Keys(v)] = v
sage: len(D)
10
sage: del ValList
sage: len(D)
1
sage: del v
sage: len(D)
0

get(k, d=None)

Return the stored value for a key, or a default value for unknown keys.

If no default value is given, None is returned.

EXAMPLES:

sage: import sage.misc.weak_dict
sage: L = [GF(p) for p in prime_range(10^3)]
sage: D = sage.misc.weak_dict.WeakValueDictionary(enumerate(L))
sage: 100 in D
True
sage: 200 in D
False
sage: D.get(100, "not found")
Finite Field of size 547
sage: D.get(200, "not found")
'not found'
sage: D.get(200) is None
True

items()

Return the key-value pairs of this dictionary.

EXAMPLES:

sage: import sage.misc.weak_dict
sage: class Vals:
....:     def __init__(self, n):
....:         self.n = n
....:     def __repr__(self):
....:         return "<%s>" % self.n
....:     def __lt__(self, other):
....:         return self.n < other.n
....:     def __eq__(self, other):
....:         return self.n == other.n
....:     def __ne__(self, other):
....:         return self.n != other.n
sage: class Keys(object):
....:     def __init__(self, n):
....:         self.n = n
....:     def __hash__(self):
....:         if self.n % 2:
....:             return int(5)
....:         return int(3)
....:     def __repr__(self):
....:         return "[%s]" % self.n
....:     def __lt__(self, other):
....:         return self.n < other.n
....:     def __eq__(self, other):
....:         return self.n == other.n
....:     def __ne__(self, other):
....:         return self.n != other.n
sage: L = [(Keys(n), Vals(n)) for n in range(10)]
sage: D = sage.misc.weak_dict.WeakValueDictionary(L)

We remove one dictionary item directly. Another item is removed by means of garbage collection. Consequently, eight items remain in the dictionary:

sage: del D[Keys(2)]
sage: del L[5]
sage: sorted(D.items())
[([0], <0>),
 ([1], <1>),
 ([3], <3>),
 ([4], <4>),
 ([6], <6>),
 ([7], <7>),
 ([8], <8>),
 ([9], <9>)]

iteritems()

Iterate over the items of this dictionary.

Warning

Iteration is unsafe if the length of the dictionary changes during the iteration! This can also be caused by garbage collection.
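The danger can be reproduced without Sage. The sketch below builds a plain dict of weak references whose callbacks delete their own entries, a construction made up for this illustration. A value dying in the middle of a loop then breaks iteration exactly as an explicit deletion would, whereas looping over a snapshot such as list(d) is unaffected.

```python
import weakref


class Obj:
    pass


objs = [Obj() for _ in range(5)]
d = {}
for i in range(5):
    # A plain dict of weak references whose callbacks delete their own entry,
    # mimicking a weak value dictionary that mutates itself on collection.
    d[i] = weakref.ref(objs[i], lambda ref, key=i: d.pop(key, None))

error = None
try:
    for key in d:
        if key == 0:
            del objs[-1]  # the value for key 4 dies; its callback shrinks d
except RuntimeError as exc:
    error = exc
print(error)  # dictionary changed size during iteration

# Iterating over a snapshot is not disturbed by removals during the loop.
visited = []
for key in list(d):
    visited.append(key)
    objs.clear()  # all remaining values die during the first pass
print(visited, d)  # [0, 1, 2, 3] {}
```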

EXAMPLES:

sage: import sage.misc.weak_dict
sage: class Vals:
....:     def __init__(self, n):
....:         self.n = n
....:     def __repr__(self):
....:         return "<%s>" % self.n
....:     def __lt__(self, other):
....:         return self.n < other.n
....:     def __eq__(self, other):
....:         return self.n == other.n
....:     def __ne__(self, other):
....:         return self.n != other.n
sage: class Keys(object):
....:     def __init__(self, n):
....:         self.n = n
....:     def __hash__(self):
....:         if self.n % 2:
....:             return int(5)
....:         return int(3)
....:     def __repr__(self):
....:         return "[%s]" % self.n
....:     def __lt__(self, other):
....:         return self.n < other.n
....:     def __eq__(self, other):
....:         return self.n == other.n
....:     def __ne__(self, other):
....:         return self.n != other.n
sage: L = [(Keys(n), Vals(n)) for n in range(10)]
sage: D = sage.misc.weak_dict.WeakValueDictionary(L)

We remove one dictionary item directly. Another item is removed by means of garbage collection. Consequently, eight items remain in the dictionary:

sage: del D[Keys(2)]
sage: del L[5]
sage: for k,v in sorted(D.iteritems()):
....:     print("{} {}".format(k, v))
[0] <0>
[1] <1>
[3] <3>
[4] <4>
[6] <6>
[7] <7>
[8] <8>
[9] <9>

itervalues()

Iterate over the values of this dictionary.

Warning

Iteration is unsafe if the length of the dictionary changes during the iteration! This can also be caused by garbage collection.

EXAMPLES:

sage: import sage.misc.weak_dict
sage: class Vals:
....:     def __init__(self, n):
....:         self.n = n
....:     def __repr__(self):
....:         return "<%s>" % self.n
....:     def __lt__(self, other):
....:         return self.n < other.n
....:     def __eq__(self, other):
....:         return self.n == other.n
....:     def __ne__(self, other):
....:         return self.n != other.n
sage: L = [Vals(n) for n in range(10)]
sage: D = sage.misc.weak_dict.WeakValueDictionary(enumerate(L))

We delete one item from D and one item from the list L. The latter implies that the corresponding item of D is deleted as well. Hence, eight values remain:

sage: del D[2]
sage: del L[5]
sage: for v in sorted(D.itervalues()):
....:     print(v)
<0>
<1>
<3>
<4>
<6>
<7>
<8>
<9>

keys()

Return the list of keys.

EXAMPLES:

sage: import sage.misc.weak_dict
sage: class Vals(object): pass
sage: L = [Vals() for _ in range(10)]
sage: D = sage.misc.weak_dict.WeakValueDictionary(enumerate(L))
sage: del L[4]

One item was deleted from the list L, and hence the corresponding item in the dictionary was deleted as well. Therefore, the key 4 is missing from the list of keys:

sage: sorted(D.keys())
[0, 1, 2, 3, 5, 6, 7, 8, 9]

pop(k)

Return the value for a given key, and delete it from the dictionary.

EXAMPLES:

sage: import sage.misc.weak_dict
sage: L = [GF(p) for p in prime_range(10^3)]
sage: D = sage.misc.weak_dict.WeakValueDictionary(enumerate(L))
sage: 20 in D
True
sage: D.pop(20)
Finite Field of size 73
sage: 20 in D
False
sage: D.pop(20)
Traceback (most recent call last):
...
KeyError: 20

popitem()

Remove and return an arbitrary item from the dictionary.

EXAMPLES:

sage: import sage.misc.weak_dict
sage: D = sage.misc.weak_dict.WeakValueDictionary()
sage: D[1] = ZZ

The dictionary contains only a single item, so it is clear which one will be returned:

sage: D.popitem()
(1, Integer Ring)

Now the dictionary is empty, so the next attempt to pop an item fails with a KeyError:

sage: D.popitem()
Traceback (most recent call last):
...
KeyError: 'popitem(): weak value dictionary is empty'

setdefault(k, default=None)

Return the stored value for a given key; return and store a default value if no previous value is stored.

EXAMPLES:

sage: import sage.misc.weak_dict
sage: L = [(p,GF(p)) for p in prime_range(10)]
sage: D = sage.misc.weak_dict.WeakValueDictionary(L)
sage: len(D)
4

The value for an existing key is returned and not overridden:

sage: D.setdefault(5, ZZ)
Finite Field of size 5
sage: D[5]
Finite Field of size 5

For a non-existing key, the default value is stored and returned:

sage: 4 in D
False
sage: D.setdefault(4, ZZ)
Integer Ring
sage: 4 in D
True
sage: D[4]
Integer Ring
sage: len(D)
5

values()

Return the list of values.

EXAMPLES:

sage: import sage.misc.weak_dict
sage: class Vals:
....:     def __init__(self, n):
....:         self.n = n
....:     def __repr__(self):
....:         return "<%s>" % self.n
....:     def __lt__(self, other):
....:         return self.n < other.n
....:     def __eq__(self, other):
....:         return self.n == other.n
....:     def __ne__(self, other):
....:         return self.n != other.n
sage: L = [Vals(n) for n in range(10)]
sage: D = sage.misc.weak_dict.WeakValueDictionary(enumerate(L))

We delete one item from D and one item from the list L. The latter implies that the corresponding item of D is deleted as well. Hence, eight values remain:

sage: del D[2]
sage: del L[5]
sage: sorted(D.values())
[<0>, <1>, <3>, <4>, <6>, <7>, <8>, <9>]