Often, we face a problem that is almost solved by an existing class. For example, suppose I want to use Python to keep track of my grocery shopping. I can use a dict to log the items in my pantry:
pantry = {
    "rice (lbs)" : 2,
    "harissa (jars)" : 1,
    "onions" : 5,
    "lemons" : 3
}
Now suppose I go shopping, and I come back with:
shopping_trip = {
    "rice (lbs)" : 1,
    "onions" : 2,
    "spinach (lbs)" : 1
}
What I'd like to do is add these dicts together in the obvious way, obtaining the dict
{
    "rice (lbs)" : 3,
    "harissa (jars)" : 1,
    "onions" : 7,
    "lemons" : 3,
    "spinach (lbs)" : 1
}
Unfortunately, the built-in dict type doesn't support this kind of operation: trying to add two dicts with + raises a TypeError. For our first example, we will implement a new class that inherits from dict and also supports entrywise addition. Once we're done, the following line will produce the result above:
pantry += shopping_trip
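For contrast, here is a quick check of what the built-in dict does with these two inventories. This is a minimal sketch that assumes Python 3.9 or later (needed for the `|` merge operator): `+` raises a TypeError, and `|` combines the keys but, for a shared key, keeps only the right-hand value instead of summing.

```python
pantry = {"rice (lbs)": 2, "harissa (jars)": 1, "onions": 5, "lemons": 3}
shopping_trip = {"rice (lbs)": 1, "onions": 2, "spinach (lbs)": 1}

# Plain dicts do not define +, so this raises TypeError.
try:
    pantry + shopping_trip
    plus_worked = True
except TypeError as err:
    plus_worked = False
    print("TypeError:", err)

# The merge operator | (Python 3.9+) keeps every key, but for a shared key
# the right-hand value replaces the left-hand one rather than adding to it.
merged = pantry | shopping_trip
print(merged["rice (lbs)"])  # 1, not 3
```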
To write a class classA that inherits from classB, just declare class classA(classB). For example:
class ArithmeticDict(dict):
    pass
Just by declaring the inheritance, this very boring class already does everything that a dict can do. In fact, it IS a dict: every ArithmeticDict object is an instance of the dict class.
x = ArithmeticDict({'a' : 1, 'b' : 2})
x, type(x), isinstance(x, dict)
({'a': 1, 'b': 2}, __main__.ArithmeticDict, True)
All the usual dict methods work:
x.update({'c' : 3})
x, x['a']
({'a': 1, 'b': 2, 'c': 3}, 1)
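One caveat, shown in a small sketch: inherited methods that build a *new* dictionary, such as `copy()`, are dict's own implementations, so they return a plain `dict` rather than an `ArithmeticDict`. Methods like `update()` that modify the object in place don't have this issue. The class here is the same empty subclass as above.

```python
class ArithmeticDict(dict):
    pass

x = ArithmeticDict({'a': 1, 'b': 2})
x.update({'c': 3})

# update() modifies x in place, so x keeps its type...
print(type(x).__name__)   # ArithmeticDict
# ...but copy() is dict's method and always builds a plain dict.
y = x.copy()
print(type(y).__name__)   # dict
print(y == x)             # True: same contents, different type
```

This matters once the subclass gains new methods: anything obtained through `copy()` would silently lose them.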
Pause for a moment: why were we able to do:
a = ArithmeticDict({'a' : 1, 'b' : 2})
and get the expected result?
# behind the scenes: which __init__() method is this?
# ArithmeticDict doesn't define one, so Python uses the one it inherits from dict
ArithmeticDict.__init__ is dict.__init__
True
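A short sketch of what Python does for `ArithmeticDict({'a' : 1, 'b' : 2})`. Since the class defines no `__init__()`, attribute lookup walks the method resolution order (`__mro__`) and finds `dict.__init__`. Calling the class first creates an empty instance with `__new__()` and then initializes it, and the two steps can be done by hand to make this visible.

```python
class ArithmeticDict(dict):
    pass

# The order in which Python searches for methods: the class, then dict, then object.
print(ArithmeticDict.__mro__)
# ArithmeticDict has no __init__ of its own, so the name resolves to dict's.
print(ArithmeticDict.__init__ is dict.__init__)   # True
print('__init__' in ArithmeticDict.__dict__)      # False

# Doing by hand what ArithmeticDict({'a' : 1, 'b' : 2}) does:
a = ArithmeticDict.__new__(ArithmeticDict)   # step 1: an empty ArithmeticDict
dict.__init__(a, {'a' : 1, 'b' : 2})         # step 2: dict's __init__ fills it in
print(a, type(a).__name__)
```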
Of course, this doesn't give us anything new yet. The important part is that we can now define new methods that are available only to ArithmeticDict objects, not to ordinary dicts.
class ArithmeticDict(dict):
    """
    A dictionary class that supports entrywise addition
    """
    def __add__(self, to_add):
        """
        Add two ArithmeticDicts entrywise.
        """
        new = {}
        keys1 = set(self.keys())
        keys2 = set(to_add.keys())
        all_keys = keys1.union(keys2)   # every key that appears in either dict
        for key in all_keys:
            # a key missing from one of the dicts counts as 0
            new[key] = self.get(key, 0) + to_add.get(key, 0)
        return ArithmeticDict(new)
x = ArithmeticDict({'a' : 1, 'b' : 2})
y = ArithmeticDict({'a' : 1, 'b' : 3, 'c' : 7})
x+y
{'a': 2, 'b': 5, 'c': 7}
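Because `all_keys` is a `set`, the keys of the result come out in no particular order, which is why the pantry below prints in a scrambled order. One possible variant, sketched here as an alternative rather than the version used in the rest of these notes, keeps the left operand's keys in their original order and appends new keys from the right operand. It also adds `__radd__`, so that a plain `dict` on the left works too: without it, `{'c': 1} + x` raises a TypeError because `dict` has no `+`.

```python
class ArithmeticDict(dict):
    """A dictionary that supports entrywise addition, keeping key order."""

    def __add__(self, to_add):
        """Add two mappings entrywise; missing keys count as 0."""
        # Start from a copy of self so its keys keep their original order...
        new = ArithmeticDict(self)
        # ...then add each value from to_add, appending keys self lacked.
        for key, value in to_add.items():
            new[key] = new.get(key, 0) + value
        return new

    def __radd__(self, other):
        """Handle plain_dict + ArithmeticDict (values add the same either way)."""
        return self + other


x = ArithmeticDict({'a': 1, 'b': 2})
y = ArithmeticDict({'a': 1, 'b': 3, 'c': 7})
print(x + y)          # {'a': 2, 'b': 5, 'c': 7}
print({'c': 1} + x)   # {'a': 1, 'b': 2, 'c': 1}
```

Note that in the reflected case the key order follows the ArithmeticDict operand, not the plain dict on the left.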
I'm now able to update my pantry:
pantry = {
    "rice (lbs)" : 2,
    "harissa (jars)" : 1,
    "onions" : 5,
    "lemons" : 3
}
shopping_trip = {
    "rice (lbs)" : 1,
    "onions" : 2,
    "spinach (lbs)" : 1
}
pantry = ArithmeticDict(pantry)
pantry
{'rice (lbs)': 2, 'harissa (jars)': 1, 'onions': 5, 'lemons': 3}
shopping_trip = ArithmeticDict(shopping_trip)
shopping_trip
{'rice (lbs)': 1, 'onions': 2, 'spinach (lbs)': 1}
pantry += shopping_trip
# OR pantry = pantry + shopping_trip
pantry
{'spinach (lbs)': 1, 'lemons': 3, 'onions': 7, 'rice (lbs)': 3, 'harissa (jars)': 1}
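The `+=` line works even though the class defines only `__add__`: when there is no `__iadd__`, Python evaluates `pantry + shopping_trip` and rebinds the name `pantry` to the new object, so any other variable still pointing at the old pantry is left unchanged. If in-place updating is preferred, one option is to define `__iadd__` as well. The sketch below is an assumption about how one might do that, not part of the class above; `__add__` is written on top of it so the summing logic lives in one place.

```python
class ArithmeticDict(dict):
    """A dictionary that supports entrywise addition, including in place."""

    def __iadd__(self, to_add):
        # Modify self directly instead of building a new object...
        for key, value in to_add.items():
            self[key] = self.get(key, 0) + value
        # ...and return self, because Python binds the result of += to the name.
        return self

    def __add__(self, to_add):
        # Copy first, then add in place, so neither operand is modified.
        new = ArithmeticDict(self)
        new += to_add
        return new


pantry = ArithmeticDict({"rice (lbs)": 2, "harissa (jars)": 1, "onions": 5, "lemons": 3})
same_pantry = pantry            # a second name for the same object
pantry += {"rice (lbs)": 1, "onions": 2, "spinach (lbs)": 1}
print(pantry)
print(same_pantry is pantry)    # True: updated in place, not replaced
```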