PythonTips(Part 1)

前言

python cookbook中学到的很多python技巧,基本都是工作中做python开发遇到的实际问题,实用性很强,故记录于此。

数据结构

序列分解

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
>>> p = (4,5)
>>> a,b = p
>>> a
4
>>> b
5
>>> s = 'elssm'
>>> a,b,c,d,e = s
>>> a
'e'
>>> b
'l'
>>> c
's'
>>> d
's'
>>> e
'm'

解压可迭代对象赋值给多个变量

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> record = ('elssm','test@qq.com','13888888888','15888888888')
>>> name,email,*phones = record
>>> name
'elssm'
>>> email
'test@qq.com'
>>> phones
['13888888888', '15888888888']
>>> *front,end = [1,2,3,4,5,6]
>>> front
[1, 2, 3, 4, 5]
>>> end
6

保留有限历史记录

1
2
3
4
5
6
7
8
9
10
11
>>> from collections import deque
>>> q = deque(maxlen=3)
>>> q.append(1)
>>> q.append(2)
>>> q.append(3)
>>> q.append(4)
>>> q
deque([2, 3, 4], maxlen=3)
>>> q.append(5)
>>> q
deque([3, 4, 5], maxlen=3)

deque增删操作

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> q = deque()
>>> q.append(1)
>>> q.append(2)
>>> q.append(3)
>>> q.appendleft(4)
>>> q
deque([4, 1, 2, 3])
>>> q.pop()
3
>>> q.popleft()
4
>>> q
deque([1, 2])

查找最大或最小的N个元素

1
2
3
4
5
6
>>> import heapq
>>> nums = [3,2,5,8,4,7,6,9]
>>> heapq.nlargest(3,nums)
[9, 8, 7]
>>> heapq.nsmallest(3,nums)
[2, 3, 4]

处理更复杂的的数据结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> portfolio = [
... {'name': 'IBM', 'shares': 100, 'price': 91.1},
... {'name': 'AAPL', 'shares': 50, 'price': 543.22},
... {'name': 'FB', 'shares': 200, 'price': 21.09},
... {'name': 'HPQ', 'shares': 35, 'price': 31.75},
... {'name': 'YHOO', 'shares': 45, 'price': 16.35},
... {'name': 'ACME', 'shares': 75, 'price': 115.65}
... ]
>>> cheap = heapq.nsmallest(3, portfolio, key=lambda s: s['price'])
>>> cheap
[{'name': 'YHOO', 'shares': 45, 'price': 16.35}, {'name': 'FB', 'shares': 200, 'price': 21.09}, {'name': 'HPQ', 'shares': 35, 'price': 31.75}]
>>> expensive = heapq.nlargest(3, portfolio, key=lambda s: s['price'])
>>> expensive
[{'name': 'AAPL', 'shares': 50, 'price': 543.22}, {'name': 'ACME', 'shares': 75, 'price': 115.65}, {'name': 'IBM', 'shares': 100, 'price': 91.1}]

字典运算

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> prices = {
... 'ACME': 45.23,
... 'AAPL': 612.78,
... 'IBM': 205.55,
... 'HPQ': 37.20,
... 'FB': 10.75
... }
>>> min_price = min(zip(prices.values(),prices.keys()))
>>> min_price
(10.75, 'FB')
>>> max_price = max(zip(prices.values(),prices.keys()))
>>> max_price
(612.78, 'AAPL')

查找两字典的相同点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> a = {
... 'x' : 1,
... 'y' : 2,
... 'z' : 3
... }
>>>
>>> b = {
... 'w' : 10,
... 'x' : 11,
... 'y' : 2
... }
>>> a.keys() & b.keys()
{'y', 'x'}
>>> a.items() & b.items()
{('y', 2)}

删除序列相同元素并保持顺序

1
2
3
4
5
6
7
8
9
def dedupe(items):
seen = set()
for item in items:
if item not in seen:
yield item
seen.add(item)
>>> a = [1, 5, 2, 1, 9, 1, 5, 10]
>>> list(dedupe(a))
[1, 5, 2, 9, 10]
1
2
3
4
5
6
7
8
9
10
11
12
def dedupe(items, key=None):
seen = set()
for item in items:
val = item if key is None else key(item)
if val not in seen:
yield item
seen.add(val)
>>> a = [ {'x':1, 'y':2}, {'x':1, 'y':3}, {'x':1, 'y':2}, {'x':2, 'y':4}]
>>> list(dedupe(a, key=lambda d: (d['x'],d['y'])))
[{'x': 1, 'y': 2}, {'x': 1, 'y': 3}, {'x': 2, 'y': 4}]
>>> list(dedupe(a, key=lambda d: d['x']))
[{'x': 1, 'y': 2}, {'x': 2, 'y': 4}]

序列中出现次数最多的元素

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> words = [
... 'look', 'into', 'my', 'eyes', 'look', 'into', 'my', 'eyes',
... 'the', 'eyes', 'the', 'eyes', 'the', 'eyes', 'not', 'around', 'the',
... 'eyes', "don't", 'look', 'around', 'the', 'eyes', 'look', 'into',
... 'my', 'eyes', "you're", 'under'
... ]
>>>
>>> from collections import Counter
>>> word_counts = Counter(words)
>>> top_three = word_counts.most_common(3)
>>> top_three
[('eyes', 8), ('the', 5), ('look', 4)]
>>> word_counts['not']
1
>>> word_counts['eyes']
8

通过某个关键字排序字典列表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
>>> rows = [
... {'fname': 'Brian', 'lname': 'Jones', 'uid': 1003},
... {'fname': 'David', 'lname': 'Beazley', 'uid': 1002},
... {'fname': 'John', 'lname': 'Cleese', 'uid': 1001},
... {'fname': 'Big', 'lname': 'Jones', 'uid': 1004}
... ]
>>> from operator import itemgetter
>>> rows_by_fname = sorted(rows,key=itemgetter('fname'))
>>> rows_by_fname
[{'fname': 'Big', 'lname': 'Jones', 'uid': 1004}, {'fname': 'Brian', 'lname': 'Jones', 'uid': 1003}, {'fname': 'David', 'lname': 'Beazley', 'uid': 1002}, {'fname': 'John', 'lname': 'Cleese', 'uid': 1001}]
>>> rows_by_uid = sorted(rows,key=itemgetter('uid'))
>>> rows_by_uid
[{'fname': 'John', 'lname': 'Cleese', 'uid': 1001}, {'fname': 'David', 'lname': 'Beazley', 'uid': 1002}, {'fname': 'Brian', 'lname': 'Jones', 'uid': 1003}, {'fname': 'Big', 'lname': 'Jones', 'uid': 1004}]
>>> min(rows,key=itemgetter('uid'))
{'fname': 'John', 'lname': 'Cleese', 'uid': 1001}
>>> max(rows,key=itemgetter('uid'))
{'fname': 'Big', 'lname': 'Jones', 'uid': 1004}

通过某个字段将记录分组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
>>> rows = [
... {'address': '5412 N CLARK', 'date': '07/01/2012'},
... {'address': '5148 N CLARK', 'date': '07/04/2012'},
... {'address': '5800 E 58TH', 'date': '07/02/2012'},
... {'address': '2122 N CLARK', 'date': '07/03/2012'},
... {'address': '5645 N RAVENSWOOD', 'date': '07/02/2012'},
... {'address': '1060 W ADDISON', 'date': '07/02/2012'},
... {'address': '4801 N BROADWAY', 'date': '07/01/2012'},
... {'address': '1039 W GRANVILLE', 'date': '07/04/2012'},
... ]
>>> from operator import itemgetter
>>> from itertools import groupby
>>> rows.sort(key=itemgetter('date'))
>>> for date,items in groupby(rows,key=itemgetter('date')):
... print(date)
... for i in items:
... print(' ',i)
...
07/01/2012
{'address': '5412 N CLARK', 'date': '07/01/2012'}
{'address': '4801 N BROADWAY', 'date': '07/01/2012'}
07/02/2012
{'address': '5800 E 58TH', 'date': '07/02/2012'}
{'address': '5645 N RAVENSWOOD', 'date': '07/02/2012'}
{'address': '1060 W ADDISON', 'date': '07/02/2012'}
07/03/2012
{'address': '2122 N CLARK', 'date': '07/03/2012'}
07/04/2012
{'address': '5148 N CLARK', 'date': '07/04/2012'}
{'address': '1039 W GRANVILLE', 'date': '07/04/2012'}

过滤列表元素

1
2
3
4
5
6
7
8
9
10
values = ['1', '2', '-3', '-', '4', 'N/A', '5']
def is_int(val):
try:
x = int(val)
return True
except ValueError:
return False
ivals = list(filter(is_int, values))
print(ivals)
# Outputs ['1', '2', '-3', '4', '5']

从字典中提取子集

1
2
3
4
5
6
7
8
9
10
11
>>> prices = {
... 'ACME': 45.23,
... 'AAPL': 612.78,
... 'IBM': 205.55,
... 'HPQ': 37.20,
... 'FB': 10.75
... }
>>>
>>> p1 = dict((key, value) for key, value in prices.items() if value > 200)
>>> p1
{'AAPL': 612.78, 'IBM': 205.55}

合并多个字典或映射

假设有如下两个字典

1
2
a = {'x': 1, 'z': 3 }
b = {'y': 2, 'z': 4 }

假设必须在两个字典中执行查找操作(比如先从a中找,如果找不到再在b中找)

1
2
3
4
5
from collections import ChainMap
c = ChainMap(a,b)
print(c['x']) # Outputs 1 (from a)
print(c['y']) # Outputs 2 (from b)
print(c['z']) # Outputs 3 (from a)

字符串和文本

使用多个界定符分割字符串

1
2
3
4
>>> line = 'asdf fjdk; afed, fjek,asdf, foo'
>>> import re
>>> re.split(r'[;,\s]\s*',line)
['asdf', 'fjdk', 'afed', 'fjek', 'asdf', 'foo']

函数re.split()是非常实用的,因为它允许你为分隔符指定多个正则模式。 比如,在上面的例子中,分隔符可以是逗号,分号或者是空格,并且后面紧跟着任意个的空格。 只要这个模式被找到,那么匹配的分隔符两边的实体都会被当成是结果中的元素返回。 返回结果为一个字段列表,这个跟str.split()返回值类型是一样的。

用shell通配符匹配字符串

1
2
3
4
5
6
7
>>> from fnmatch import fnmatch,fnmatchcase
>>> fnmatch('foo.txt','*.txt')
True
>>> fnmatch('foo.txt','*.TXT')
True
>>> fnmatchcase('foo.txt','*.TXT')
False

字符串匹配和搜索

1
2
3
4
5
>>> import re
>>> datepat = re.compile(r'\d+/\d+/\d+')
>>> text = 'today is 2023/10/7.tomorrow is 2023/10/8.'
>>> datepat.findall(text)
['2023/10/7', '2023/10/8']

字符串搜索和替换

对于简单的搜索,直接使用str.replace()方法。

1
2
3
>>> text = 'this is a test.'
>>> text.replace('this','that')
'that is a test.'

对于复杂的模式,可以使用re模块中的sub()函数。示例如下:

1
2
3
4
>>> text = 'today is 2023/10/7.tomorrow is 2023/10/8.'
>>> import re
>>> re.sub(r'(\d+)/(\d+)/(\d+)', r'\1-\2-\3', text)
'today is 2023-10-7.tomorrow is 2023-10-8.'

最短匹配模式

1
2
3
4
5
6
7
>>> str_pat = re.compile(r'"(.*)"')
>>> text1 = 'Computer says "no."'
>>> str_pat.findall(text1)
['no.']
>>> text2 = 'Computer says "no." phone says "yes."'
>>> str_pat.findall(text2)
['no." phone says "yes.']

在这个例子中,模式r'\"(.*)\"'的意图是匹配被双引号包含的文本。 但是在正则表达式中操作符是贪婪的,因此匹配操作会查找最长的可能匹配。 于是在第二个例子中搜索text2的时候返回结果并不是正确的。
为了修正这个问题,可以在模式中的
操作符后面加上?修饰符

1
2
3
>>> str_pat = re.compile(r'"(.*?)"')
>>> str_pat.findall(text2)
['no.', 'yes.']

字符串对齐

对于基本的字符串对齐操作,可以使用字符串的ljust(),rjust()center()方法。示例如下:

1
2
3
4
5
6
7
8
>>> text = 'Hello World'
>>> text.ljust(20)
'Hello World '
>>> text.rjust(20)
' Hello World'
>>> text.center(20)
' Hello World '
>>>

所有这些方法都能接受一个可选的填充字符。示例如下:

1
2
3
4
5
>>> text.rjust(20,'=')
'=========Hello World'
>>> text.center(20,'*')
'****Hello World*****'
>>>

函数format()同样可以用来很容易的对齐字符串。 你要做的就是使用<,>或者^字符后面紧跟一个指定的宽度。比如:

1
2
3
4
5
6
7
>>> format(text, '>20')
' Hello World'
>>> format(text, '<20')
'Hello World '
>>> format(text, '^20')
' Hello World '
>>>

如果你想指定一个非空格的填充字符,将它写到对齐字符的前面即可:

1
2
3
4
5
>>> format(text, '=>20s')
'=========Hello World'
>>> format(text, '*^20s')
'****Hello World*****'
>>>

日期和时间

数字的四舍五入

对于简单的舍入运算,使用内置的round(value, ndigits)函数即可。示例如下:

1
2
3
4
5
6
7
8
9
>>> round(1.23, 1)
1.2
>>> round(1.27, 1)
1.3
>>> round(-1.27, 1)
-1.3
>>> round(1.25361,3)
1.254
>>>

注意:当一个值刚好在两个边界的中间的时候,round函数返回离它最近的偶数。 也就是说,对1.5或者2.5的舍入运算都会得到2。

传给 round() 函数的 ndigits 参数可以是负数,这种情况下, 舍入运算会作用在十位、百位、千位等上面。示例如下:

1
2
3
4
5
6
7
8
>>> a = 1627731
>>> round(a, -1)
1627730
>>> round(a, -2)
1627700
>>> round(a, -3)
1628000
>>>

不要将舍入和格式化输出搞混淆了。 如果你的目的只是简单的输出一定宽度的数,你不需要使用round()函数。 而仅仅只需要在格式化的时候指定精度即可。示例如下:

1
2
3
4
5
>>> x = 1.23456
>>> format(x, '0.2f')
'1.23'
>>> format(x, '0.3f')
'1.235'

执行精确的浮点数运算

浮点数的一个普遍问题是它们并不能精确的表示十进制数。 并且,即使是最简单的数学运算也会产生小的误差,比如:

1
2
3
4
5
6
7
>>> a = 4.2
>>> b = 2.1
>>> a + b
6.300000000000001
>>> (a + b) == 6.3
False
>>>

这些错误是由底层CPU和IEEE 754标准通过自己的浮点单位去执行算术时的特征。 由于Python的浮点数据类型使用底层表示存储数据,因此你没办法去避免这样的误差。如果想更加精确可以使用decimal模块:

1
2
3
4
5
6
7
8
9
>>> from decimal import Decimal
>>> a = Decimal('4.2')
>>> b = Decimal('2.1')
>>> a + b
Decimal('6.3')
>>> print(a + b)
6.3
>>> (a + b) == Decimal('6.3')
True

输出进制整数

为了将整数转换为二进制、八进制或十六进制的文本串, 可以分别使用bin(),oct()hex()函数。

1
2
3
4
5
6
7
8
>>> x = 1234
>>> bin(x)
'0b10011010010'
>>> oct(x)
'0o2322'
>>> hex(x)
'0x4d2'
>>>

另外,如果你不想输出0b,0o或者0x的前缀的话,可以使用format()函数。示例如下:

1
2
3
4
5
6
7
>>> format(x, 'b')
'10011010010'
>>> format(x, 'o')
'2322'
>>> format(x, 'x')
'4d2'
>>>

无穷大和NaN

Python并没有特殊的语法来表示这些特殊的浮点值,但是可以使用float()来创建它们。示例如下:

1
2
3
4
5
6
7
8
9
10
>>> a = float('inf')
>>> b = float('-inf')
>>> c = float('nan')
>>> a
inf
>>> b
-inf
>>> c
nan
>>>

为了测试这些值的存在,使用math.isinf()math.isnan()函数。示例如下:

1
2
3
4
5
>>> math.isinf(a)
True
>>> math.isnan(c)
True
>>>

基本的日期与时间转换

为了执行不同时间单位的转换和计算,可以使用datetime模块,为了表示一个时间段,可以创建一个timedelta实例,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> from datetime import timedelta
>>> a = timedelta(days=2, hours=6)
>>> b = timedelta(hours=4.5)
>>> c = a + b
>>> c.days
2
>>> c.seconds
37800
>>> c.seconds / 3600
10.5
>>> c.total_seconds() / 3600
58.5
>>>

如果想表示指定的日期和时间,先创建一个datetime实例然后使用标准的数学运算来操作,实例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> from datetime import datetime
>>> from datetime import timedelta
>>> a = datetime(2023,10,9)
>>> print(a+timedelta(days=10))
2023-10-19 00:00:00
>>> b = datetime(2023,11,11)
>>> d = b - a
>>> d.days
33
>>> now = datetime.today()
>>> print(now)
2023-10-09 08:58:22.071529
>>> print(now+timedelta(minutes=10))
2023-10-09 09:08:22.071529

计算当前月份的日期范围

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from datetime import datetime, date, timedelta
import calendar


def get_month_range(start_date=None):
if start_date is None:
start_date = date.today().replace(day=1)
_, days_in_month = calendar.monthrange(start_date.year, start_date.month)
end_date = start_date + timedelta(days=days_in_month)
return (start_date, end_date)


a_day = timedelta(days=1)
first_day, last_day = get_month_range()
while first_day < last_day:
print(first_day)
first_day += a_day

上述代码首先计算出当月的第一天的日期,通过date对象的replace()方法将days属性设置为1即可。然后,使用calendar.monthrange()函数来找出该月的总天数。monthrange()函数会返回包含星期和该月天数的元组。一旦该月的天数已知了,那么结束日期就可以通过在开始日期上面加上这个天数获得。 有个需要注意的是结束日期并不包含在这个日期范围内。 这个和Pythonslicerange操作行为保持一致,同样也不包含结尾。为了在日期范围上循环,要使用到标准的数学和比较操作。 比如,可以利用timedelta实例来递增日期,小于号<用来检查一个日期是否在结束日期之前。

字符串转换为日期

1
2
3
4
5
6
7
8
9
10
11
>>> from datetime import datetime
>>> text = '2023-10-9'
>>> y = datetime.strptime(text,"%Y-%m-%d")
>>> y
datetime.datetime(2023, 10, 9, 0, 0)
>>> z = datetime.now()
>>> z
datetime.datetime(2023, 10, 9, 9, 57, 13, 378621)
>>> diff = z-y
>>> diff
datetime.timedelta(seconds=35833, microseconds=378621)

datetime.strptime()方法支持很多的格式化代码, 比如%Y 代表4位数年份,%m代表两位数月份。 还有一点值得注意的是这些格式化占位符也可以反过来使用,将日期输出为指定的格式字符串形式。

1
2
3
4
datetime.datetime(2023, 10, 9, 9, 57, 13, 378621)
>>> nice_z = datetime.strftime(z,'%A %B %d %Y')
>>> nice_z
'Monday October 09 2023'

有一点需要注意的是,strptime()的性能较差,因为它是使用纯python实现,并且必须处理所有的系统本地设置,因此如果在代码中需要解析大量的日期并且已知日期字符串的确切格式,可以自定义实现日期解析函数,示例如下:

1
2
3
4
from datetime import datetime
def parse_ymd(s):
year_s, mon_s, day_s = s.split('-')
return datetime(int(year_s), int(mon_s), int(day_s))

迭代器和生成器

手动遍历迭代器

为了手动的遍历可迭代对象,使用next()函数并在代码中捕获StopIteration异常。

1
2
3
4
5
6
7
8
def manual_iter():
with open('/etc/passwd') as f:
try:
while True:
line = next(f)
print(line, end='')
except StopIteration:
pass
1
2
3
4
5
6
7
8
a = [1, 2, 3, 4, 5]
a = iter(a)
try:
while True:
b = next(a)
print(b, end=' ')
except StopIteration:
pass

使用生成器创建新的迭代模式

如果想实现一种新的迭代模式,使用一个生成器函数来定义它,如下是一个生产某个范围内浮点数的生成器:

1
2
3
4
5
6
7
8
9
10
11
def frange(start, stop, increment):
x = start
while x < stop:
yield x
x += increment


for n in frange(0, 4, 0.5):
print(n)

print(list(frange(0, 4, 0.5)))

一个生成器函数主要特征是它只会回应在迭代中使用到的next操作。 一旦生成器函数返回退出,迭代终止。我们在迭代中通常使用的for语句会自动处理这些细节

反向迭代

使用内置的reversed()函数,如下所示:

1
2
3
4
5
6
7
8
>>> a = [1,2,3,4]
>>> for x in reversed(a):
... print(x)
...
4
3
2
1

迭代器切片

如果想得到一个由迭代器生成的切片对象,但是标准切片操作并不能做到,可以使用itertools.islice()做切片操作,如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import itertools
def count(n):
while True:
yield n
n += 1
c = count(0)
#标准切片报错
c[10:20]
Traceback (most recent call last):
File "test.py", line 11, in <module>
TypeError: 'generator' object is not subscriptable
#使用itertools.islice()操作
for x in itertools.islice(c, 10, 20):
print(x)

迭代器和生成器不能使用标准的切片操作,因为它们的长度事先我们并不知道(并且也没有实现索引)。 函数islice()返回一个可以生成指定元素的迭代器,它通过遍历并丢弃直到切片开始索引位置的所有元素。 然后才开始一个个的返回元素,并直到切片结束索引位置。这里要着重强调的一点是islice()会消耗掉传入的迭代器中的数据。必须考虑到迭代器是不可逆的这个事实。如果你需要之后再次访问这个迭代器的话,那你就得先将它里面的数据放入一个列表中。

排列组合的迭代

如果想迭代遍历一个集合中元素的所有可能的排列组合,itertools提供了三个函数来解决这类问题,其中一个是itertools.permutations(),它接受一个集合并产生一个元组序列,每个元组由集合中所有元素的一个可能排列组成。 也就是说通过打乱集合中元素排列顺序生成一个元组,示例如下:

1
2
3
4
5
6
7
8
9
10
11
>>> from itertools import permutations
>>> items = ["a","b","c"]
>>> for p in permutations(items):
... print(p)
...
('a', 'b', 'c')
('a', 'c', 'b')
('b', 'a', 'c')
('b', 'c', 'a')
('c', 'a', 'b')
('c', 'b', 'a')

如果想得到指定长度的所有排列,可以传递一个可选的长度参数,示例如下:

1
2
3
4
5
6
7
8
9
10
11
>>> from itertools import permutations
>>> items = ["a","b","c"]
>>> for p in permutations(items,2):
... print(p)
...
('a', 'b')
('a', 'c')
('b', 'a')
('b', 'c')
('c', 'a')
('c', 'b')

使用itertools.combinations()可得到输入集合中元素的所有组合,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
>>> from itertools import combinations
>>> items = ["a","b","c"]
>>> for c in combinations(items,3):
... print(c)
...
('a', 'b', 'c')
>>> for c in combinations(items,2):
... print(c)
...
('a', 'b')
('a', 'c')
('b', 'c')
>>> for c in combinations(items,1):
... print(c)
...
('a',)
('b',)
('c',)

对于combinations()来讲,元素顺序已经不重要了,也就是说,('a','b')('b','a')是一样的。

在计算组合的时候,一旦元素被选取就会从后选中剔除掉,而函数itertools.combinations_with_replacement()允许同一个元素被选择多次,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> from itertools import combinations_with_replacement
>>> items = ["a","b","c"]
>>> for c in combinations_with_replacement(items,3):
... print(c)
...
('a', 'a', 'a')
('a', 'a', 'b')
('a', 'a', 'c')
('a', 'b', 'b')
('a', 'b', 'c')
('a', 'c', 'c')
('b', 'b', 'b')
('b', 'b', 'c')
('b', 'c', 'c')
('c', 'c', 'c')

序列上索引值迭代

一般内置的enumerate()函数可以很好的处理该问题,示例如下:

1
2
3
4
5
6
7
>>> a = ['a','b','c']
>>> for idx,val in enumerate(a):
... print(idx,val)
...
0 a
1 b
2 c

为了按照传统行号输出,可以传递一个开始参数,示例如下:

1
2
3
4
5
6
7
>>> a = ['a','b','c']
>>> for idx,val in enumerate(a,1):
... print(idx,val)
...
1 a
2 b
3 c

这种情况在你遍历文件时想在错误消息中使用行号定位时非常有用,示例如下:

1
2
3
4
5
6
7
8
9
def parse_data(filename):
with open(filename, 'rt') as f:
for lineno, line in enumerate(f, 1):
fields = line.split()
try:
count = int(fields[1])
...
except ValueError as e:
print('Line {}: Parse error: {}'.format(lineno, e))

还有一点可能并不很重要,但是也值得注意, 有时候当你在一个已经解压后的元组序列上使用enumerate()函数时很容易调入陷阱。 你得像下面正确的方式这样写:

1
2
3
4
5
6
7
8
data = [ (1, 2), (3, 4), (5, 6), (7, 8) ]

# Correct!
for n, (x, y) in enumerate(data):
...
# Error!
for n, x, y in enumerate(data):
...

同时迭代多个序列

为了同时迭代多个序列,可以使用zip()函数,示例如下:

1
2
3
4
5
6
7
8
9
10
>>> a = [1,3,5,7,9]
>>> b =[2,4,6,8,10]
>>> for x,y in zip(a,b):
... print(x,y)
...
1 2
3 4
5 6
7 8
9 10

zip(a, b)会生成一个可返回元组(x, y)的迭代器,其中x来自ay来自b。 一旦其中某个序列到底结尾,迭代宣告结束。 因此迭代长度跟参数中最短序列长度一致。

1
2
3
4
5
6
7
8
>>> a = [1,3,5]
>>> b = ['e','l','s','s','m']
>>> for i in zip(a,b):
... print(i)
...
(1, 'e')
(3, 'l')
(5, 's')

如果不希望和最短序列保持一致,可以使用itertools.zip_longest()函数来代替,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
>>> a = [1,3,5]
>>> b = ['e','l','s','s','m']
>>> from itertools import zip_longest
>>> for i in zip_longest(a,b):
... print(i)
...
(1, 'e')
(3, 'l')
(5, 's')
(None, 's')
(None, 'm')
>>> for i in zip_longest(a,b,fillvalue=0):
... print(i)
...
(1, 'e')
(3, 'l')
(5, 's')
(0, 's')
(0, 'm')

使用zip()函数可以将数据打包并生成一个字典,示例如下:

1
2
3
4
5
>>> key=['name','age','sex']
>>> value = ['elssm','24','男']
>>> s = dict(zip(key,value))
>>> s
{'name': 'elssm', 'age': '24', 'sex': '男'}

最后强调一点就是,zip()会创建一个迭代器来作为结果返回。 如果你需要将结对的值存储在列表中,要使用list()函数。示例如下:

1
2
3
4
5
6
>>> a = [1,3,5,7,9]
>>> b =[2,4,6,8,10]
>>> zip(a,b)
<zip object at 0x00000254B2799640>
>>> list(zip(a,b))
[(1, 2), (3, 4), (5, 6), (7, 8), (9, 10)]

不同集合上元素的迭代

itertools.chain()方法接受一个可迭代对象列表作为输入,并返回一个迭代器,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> from itertools import chain
>>> a = [1,2,3,4]
>>> b = ['x','y','z']
>>> for x in chain(a,b):
... print(x)
...
1
2
3
4
x
y
z

使用chain的一个常见场景是当你相对不同的集合中所有元素执行某些操作时,示例如下:

1
2
3
4
5
6
active_items = set()
inactive_items = set()

# Iterate over all items
for item in chain(active_items, inactive_items):
# Process item

这种解决方案要比使用两个单独的循环处理更加优雅。

展开嵌套的序列

如果想将一个多层嵌套的序列展开成一个单层列表,可以写一个包含yield from语句的递归生成器来处理,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from collections.abc import Iterable


def flatten(item, ignore_types=(str, bytes)):
for x in item:
if isinstance(x, Iterable) and not isinstance(x, ignore_types):
yield from flatten(x)
else:
yield x


items = [1, 2, [3, 4, [5, 6], 7], 8]
for i in flatten(items):
print(i)

在上述代码中,isinstance(x,Iterable)检查某个元素是否是可迭代的,如果是,yield from就会返回所有子例程的值,最终返回结果就是一个没有嵌套的简单序列,额外的参数ignore_types和检测语句isinstance(x,ignore_types)用来将字符串和字节排除在可迭代对象外,防止将它们再展开成单个字符,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from collections.abc import Iterable


def flatten(item, ignore_types=(str, bytes)):
for x in item:
if isinstance(x, Iterable) and not isinstance(x, ignore_types):
yield from flatten(x)
else:
yield x


items = ['Dave', 'Paula', ['Thomas', 'Lewis']]
for i in flatten(items):
print(i)

语句yield from在你想在生成器中调用其他生成器作为子例程时非常有用,如果你不想在代码中使用,那么需要写额外的for循环,示例如下:

1
2
3
4
5
6
7
def flatten(items, ignore_types=(str, bytes)):
for x in items:
if isinstance(x, Iterable) and not isinstance(x, ignore_types):
for i in flatten(x):
yield i
else:
yield x

顺序迭代合并后的排序迭代对象

如果有一系列的排序序列,想将它们合并后得到一个排序序列并在上面迭代遍历。可以使用heapq.merge()函数,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
>>> import heapq
>>> a = [1,3,5,7,9]
>>> b = [2,4,6,8,10]
>>> for c in heapq.merge(a,b):
... print(c)
...
1
2
3
4
5
6
7
8
9
10

数据编码与处理

将字典转换为XML

xml.etree.ElementTree库通常用来做解析工作,其实也可以创建XML文档。如下函数所示:

1
2
3
4
5
6
7
8
from xml.etree.ElementTree import Element,tostring
def dict_to_xml(tag, d):
elem = Element(tag)
for key, val in d.items():
child = Element(key)
child.text = str(val)
elem.append(child)
return elem

函数使用示例如下:

1
2
3
4
5
>>> s = { 'name': 'GOOG', 'shares': 100, 'price':490.1 }
>>> e = dict_to_xml('stock', s)
>>> e
<Element 'stock' at 0x1004b64c8>
>>>

转换结果是一个Element实例。对于I/O操作,使用xml.etree.ElementTree中的tostring()函数很容易就能将它转换成一个字节字符串。示例如下:

1
2
3
4
>>> from xml.etree.ElementTree import tostring
>>> tostring(e)
b'<stock><price>490.1</price><shares>100</shares><name>GOOG</name></stock>'
>>>

如果想给某个元素添加属性值,可以使用set()方法:

1
2
3
4
5
>>> e.set('_id','1234')
>>> tostring(e)
b'<stock _id="1234"><price>490.1</price><shares>100</shares><name>GOOG</name>
</stock>'
>>>

解析和修改XML

使用xml.etree.ElementTree模块可以完成该操作,假设有如下名为pred.xml的文档内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
<?xml version="1.0"?>
<stop>
<id>14791</id>
<nm>Clark &amp; Balmoral</nm>
<sri>
<rt>22</rt>
<d>North Bound</d>
<dd>North Bound</dd>
</sri>
<cr>22</cr>
<pre>
<pt>5 MIN</pt>
<fd>Howard</fd>
<v>1378</v>
<rn>22</rn>
</pre>
<pre>
<pt>15 MIN</pt>
<fd>Howard</fd>
<v>1867</v>
<rn>22</rn>
</pre>
</stop>

下面是一个利用ElementTree来读取这个文档并对它做一些修改的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> from xml.etree.ElementTree import parse,Element
>>> doc = parse('pred.xml')
>>> root = doc.getroot()
root
<Element 'stop' at 0x00000204D519FB30>
>>> root.remove(root.find('sri'))
>>> root.remove(root.find('cr'))
>>> root.getchildren().index(root.find('nm'))
1
>>> e = Element('spam')
>>> e.text = 'this is a test'
>>> root.insert(2,e)
>>> doc.write("newpred.xml",xml_declaration=True)

新生成的newpred.xml文档内容如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
<?xml version='1.0' encoding='us-ascii'?>
<stop>
<id>14791</id>
<nm>Clark &amp; Balmoral</nm>
<spam>This is a test</spam>
<pre>
<pt>5 MIN</pt>
<fd>Howard</fd>
<v>1378</v>
<rn>22</rn>
</pre>
<pre>
<pt>15 MIN</pt>
<fd>Howard</fd>
<v>1867</v>
<rn>22</rn>
</pre>
</stop>

修改一个XML文档结构是很容易的,但是你必须牢记的是所有的修改都是针对父节点元素, 将它作为一个列表来处理。例如,如果你删除某个元素,通过调用父节点的remove()方法从它的直接父节点中删除。 如果你插入或增加新的元素,你同样使用父节点元素的insert()append()方法。 还能对元素使用索引和切片操作,比如element[i]element[i:j]

类与对象

改变对象的字符串显示

要改变一个实例的字符串表示,可以重新定义它的__str__()__repr__()方法,例如:

1
2
3
4
5
6
7
8
9
10
class Pair:
def __init__(self, x, y):
self.x = x
self.y = y

def __repr__(self):
return 'Pair({0.x!r}, {0.y!r})'.format(self)

def __str__(self):
return '({0.x!s}, {0.y!s})'.format(self)

__repr__()方法返回一个实例的代码表示形式,通常用来重新构造这个实例。 内置的repr()函数返回这个字符串,跟我们使用交互式解释器显示的值是一样的。__str__()方法将实例转换为一个字符串,使用str()print()函数会输出这个字符串。

1
2
3
4
5
6
>>> p = Pair(3, 4)
>>> p
Pair(3, 4) # __repr__() output
>>> print(p)
(3, 4) # __str__() output
>>>

上面的format()方法的使用看上去很有趣,格式化代码{0.x}对应的是第1个参数的x属性。 因此,在下面的函数中,0实际上指的就是self本身:

1
2
def __repr__(self):
return 'Pair({0.x!r}, {0.y!r})'.format(self)

作为这种实现的替代,也可以使用%操作符,示例如下:

1
2
def __repr__(self):
return 'Pair(%r, %r)' % (self.x, self.y)

自定义字符串的格式化

为了自定义字符串的格式化,我们需要在类上面定义__format__()方法,实例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
_formats = {
'ymd' : '{d.year}-{d.month}-{d.day}',
'mdy' : '{d.month}/{d.day}/{d.year}',
'dmy' : '{d.day}/{d.month}/{d.year}'
}

class Date:
def __init__(self, year, month, day):
self.year = year
self.month = month
self.day = day

def __format__(self, code):
if code == '':
code = 'ymd'
fmt = _formats[code]
return fmt.format(d=self)

现在Date类的示例可以支持格式化操作,如下所示:

1
2
3
4
5
6
7
8
9
10
>>> d = Date(2023, 10, 11)
>>> format(d)
'2023-10-11'
>>> format(d, 'mdy')
'11/10/2023'
>>> 'The date is {:ymd}'.format(d)
'The date is 2023-10-11'
>>> 'The date is {:mdy}'.format(d)
'The date is 10/11/2023'
>>>

__format__()方法给Python的字符串格式化功能提供了一个钩子。 这里需要着重强调的是格式化代码的解析工作完全由类自己决定。因此,格式化代码可以是任何值。 例如,参考下面来自datetime模块中的代码:

1
2
3
4
5
6
7
>>> from datetime import date
>>> d = date(2023,10,11)
>>> format(d)
'2023-10-11'
>>> format(d,'%A,%B,%d,%Y')
'Wednesday,October,11,2023'
>>>

让对象支持上下文管理协议

为了让一个对象兼容with语句,你需要实现__enter__()__exit__()方法。例如,考虑如下的一个类,它能为我们创建一个网络连接:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from socket import socket, AF_INET, SOCK_STREAM

class LazyConnection:
def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
self.address = address
self.family = family
self.type = type
self.sock = None

def __enter__(self):
if self.sock is not None:
raise RuntimeError('Already connected')
self.sock = socket(self.family, self.type)
self.sock.connect(self.address)
return self.sock

def __exit__(self, exc_ty, exc_val, tb):
self.sock.close()
self.sock = None

这个类的关键特点在于它表示了一个网络连接,但是初始化的时候并不会做任何事情。 连接的建立和关闭是使用with语句自动完成的,例如:

1
2
3
4
5
6
7
8
9
10
11
from functools import partial

conn = LazyConnection(('www.python.org', 80))
# Connection closed
with conn as s:
# conn.__enter__() executes: connection open
s.send(b'GET /index.html HTTP/1.0\r\n')
s.send(b'Host: www.python.org\r\n')
s.send(b'\r\n')
resp = b''.join(iter(partial(s.recv, 8192), b''))
# conn.__exit__() executes: connection closed

编写上下文管理器的主要原理是你的代码会放到with语句块中执行。 当出现with语句的时候,对象的__enter__()方法被触发, 它返回的值会被赋值给as声明的变量。然后,with语句块里面的代码开始执行。 最后,__exit__()方法被触发进行清理工作。

在类中封装属性名

Python程序员不去依赖语言特性去封装数据,而是通过遵循一定的属性和方法命名规约来达到这个效果。 第一个约定是任何以单下划线_开头的名字都应该是内部实现。示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
class A:
def __init__(self):
self._internal = 0 # 内部属性
self.public = 1 # A 公共属性

def public_method(self):
'''
A public method
'''
pass

def _internal_method(self):
pass

还可能会遇到在类定义中使用两个下划线(__)开头的命名。比如:

1
2
3
4
5
6
7
8
9
10
class B:
def __init__(self):
self.__private = 0

def __private_method(self):
pass

def public_method(self):
pass
self.__private_method()

使用双下划线开始会导致访问名称变成其他形式。 比如,在前面的类B中,私有属性会被分别重命名为_B__private_B__private_method。 这时候你可能会问这样重命名的目的是什么,答案就是继承——这种属性通过继承是无法被覆盖的。比如:

1
2
3
4
5
6
7
8
class C(B):
def __init__(self):
super().__init__()
self.__private = 1 # Does not override B.__private

# Does not override B.__private_method()
def __private_method(self):
pass

这里,私有名称__private__private_method被重命名为_C__private_C__private_method,这个跟父类B中的名称是完全不同的。
大多数而言,你应该让你的非公共名称以单下划线开头。但是,如果你清楚你的代码会涉及到子类, 并且有些内部属性应该在子类中隐藏起来,那么才考虑使用双下划线方案。

创建可管理的属性

自定义某个属性的一种简单方法是将它定义为一个property。 例如,下面的代码定义了一个property,增加对一个属性简单的类型检查:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Person:
def __init__(self, first_name):
self._first_name = first_name

# Getter function
@property
def first_name(self):
return self._first_name

# Setter function
@first_name.setter
def first_name(self, value):
if not isinstance(value, str):
raise TypeError('Expected a string')
self._first_name = value

# Deleter function (optional)
@first_name.deleter
def first_name(self):
raise AttributeError("Can't delete attribute")

上述代码中有三个相关联的方法,这三个方法的名字都必须一样。 第一个方法是一个getter函数,它使得first_name成为一个属性。 其他两个方法给first_name属性添加了setterdeleter函数。 需要强调的是只有在first_name属性被创建后, 后面的两个装饰器@first_name.setter@first_name.deleter才能被定义。property的一个关键特征是它看上去跟普通的attribute没什么两样, 但是访问它的时候会自动触发gettersetterdeleter方法。示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> a = Person('Guido')
>>> a.first_name # Calls the getter
'Guido'
>>> a.first_name = 42 # Calls the setter
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "prop.py", line 14, in first_name
raise TypeError('Expected a string')
TypeError: Expected a string
>>> del a.first_name
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: can`t delete attribute
>>>

调用父类方法

为了调用父类(超类)的一个方法,可以使用super()函数,比如:

1
2
3
4
5
6
7
8
class A:
def spam(self):
print('A.spam')

class B(A):
def spam(self):
print('B.spam')
super().spam()

super()函数的一个常见用法是在__init__()方法中确保父类被正确的初始化了:

1
2
3
4
5
6
7
8
class A:
def __init__(self):
self.x = 0

class B(A):
def __init__(self):
super().__init__()
self.y = 1

super()的另外一个常见用法出现在覆盖Python特殊方法的代码中,比如:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Proxy:
def __init__(self, obj):
self._obj = obj

# Delegate attribute lookup to internal obj
def __getattr__(self, name):
return getattr(self._obj, name)

# Delegate attribute assignment
def __setattr__(self, name, value):
if name.startswith('_'):
super().__setattr__(name, value) # Call original __setattr__
else:
setattr(self._obj, name, value)

在上面代码中,__setattr__()的实现包含一个名字检查。如果某个属性名以下划线(_)开头,就通过super()调用原始的__setattr__(), 否则的话就委派给内部的代理对象self._obj去处理。因为就算没有显式的指明某个类的父类,super()仍然可以有效的工作。

子类中扩展property

如下代码定义了一个property:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Person:
def __init__(self, name):
self.name = name

# Getter function
@property
def name(self):
return self._name

# Setter function
@name.setter
def name(self, value):
if not isinstance(value, str):
raise TypeError('Expected a string')
self._name = value

# Deleter function
@name.deleter
def name(self):
raise AttributeError("Can't delete attribute")

下面是一个示例类,它继承自Person并扩展了name属性的功能:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class SubPerson(Person):
@property
def name(self):
print('Getting name')
return super().name

@name.setter
def name(self, value):
print('Setting name to', value)
super(SubPerson, SubPerson).name.__set__(self, value)

@name.deleter
def name(self):
print('Deleting name')
super(SubPerson, SubPerson).name.__delete__(self)

如果仅仅只想扩展property的某一个方法,那么可以像下面这样写:

1
2
3
4
5
class SubPerson(Person):
@Person.name.getter
def name(self):
print('Getting name')
return super().name

创建新的类或实例属性

如果想创建一个全新的实例属性,可以通过一个描述器类的形式来定义它的功能,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class Integer:
def __init__(self, name):
self.name = name

def __get__(self, instance, cls):
if instance is None:
return self
else:
return instance.__dict__[self.name]

def __set__(self, instance, value):
if not isinstance(value, int):
raise TypeError('Expected an int')
instance.__dict__[self.name] = value

def __delete__(self, instance):
del instance.__dict__[self.name]

一个描述器就是一个实现了三个核心的属性访问操作(get,set,delete)的类,分别为(__get__(),__set__(),__delete__())这三个特殊的方法,这些方法接受一个实例作为输入,之后相应的操作实例底层的字典。为了使用一个描述器,需将这个描述器的实例作为类属性放到一个类的定义中,例如:

1
2
3
4
5
6
7
class Point:
x = Integer('x')
y = Integer('y')

def __init__(self, x, y):
self.x = x
self.y = y

这样做所有对描述器属性(如x或y)的访问会被(__get__(),__set__(),__delete__())方法捕获到,示例如下:

1
2
3
4
5
6
7
8
9
10
11
>>> p = Point(2, 3)
>>> p.x # Calls Point.x.__get__(p,Point)
2
>>> p.y = 5 # Calls Point.y.__set__(p, 5)
>>> p.x = 2.3 # Calls Point.x.__set__(p, 2.3)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "descrip.py", line 12, in __set__
raise TypeError('Expected an int')
TypeError: Expected an int
>>>

作为输入,描述器的每一个方法会接受一个操作实例。为了实现请求操作,会相应的操作实例底层的字典(__dict__属性)。 描述器的self.name属性存储了在实例字典中被实际使用到的key

需要注意的是:描述器只能在类级别被定义,而不能为每个实例单独定义。因此,下面的代码是无法工作的:

1
2
3
4
5
6
class Point:
def __init__(self, x, y):
self.x = Integer('x') # 错误,必须是类变量
self.y = Integer('y')
self.x = x
self.y = y

使用延迟计算属性

如果想将一个只读属性定义成一个property,并且只在访问的时候才会计算结果。 但是一旦被访问后,你希望结果值被缓存起来,不用每次都去计算。可以使用一个描述器类,如下所示:

1
2
3
4
5
6
7
8
9
10
11
class lazyproperty:
def __init__(self, func):
self.func = func

def __get__(self, instance, cls):
if instance is None:
return self
else:
value = self.func(instance)
setattr(instance, self.func.__name__, value)
return value

使用上述描述器类示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import math

class Circle:
def __init__(self, radius):
self.radius = radius

@lazyproperty
def area(self):
print('Computing area')
return math.pi * self.radius ** 2

@lazyproperty
def perimeter(self):
print('Computing perimeter')
return 2 * math.pi * self.radius

交互环境演示如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> c = Circle(4.0)
>>> c.radius
4.0
>>> c.area
Computing area
50.26548245743669
>>> c.area
50.26548245743669
>>> c.perimeter
Computing perimeter
25.132741228718345
>>> c.perimeter
25.132741228718345
>>>

可以发现Computing areaComputing perimeter只出现了一次。

这种方案有一个小缺陷就是计算出的值被创建后是可以被修改的。例如:

1
2
3
4
5
6
7
>>> c.area
Computing area
50.26548245743669
>>> c.area = 25
>>> c.area
25
>>>

如果想要修改这个问题可以使用如下方式:

1
2
3
4
5
6
7
8
9
10
11
def lazyproperty(func):
name = '_lazy_' + func.__name__
@property
def lazy(self):
if hasattr(self, name):
return getattr(self, name)
else:
value = func(self)
setattr(self, name, value)
return value
return lazy

这种方式不允许修改:

1
2
3
4
5
6
7
8
9
10
11
>>> c = Circle(4.0)
>>> c.area
Computing area
50.26548245743669
>>> c.area
50.26548245743669
>>> c.area = 25
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: can't set attribute
>>>

简化数据结构的初始化

如果不想写太多__init__()函数,可以在一个基类中写一个公用的__init__()函数:

1
2
3
4
5
6
7
8
9
10
import math

class Structure1:
_fields = []

def __init__(self, *args):
if len(args) != len(self._fields):
raise TypeError('Expected {} arguments'.format(len(self._fields)))
for name, value in zip(self._fields, args):
setattr(self, name, value)

然后让写的类继承上述基类:

1
2
3
4
5
6
7
8
9
10
11
class Stock(Structure1):
_fields = ['name', 'shares', 'price']

class Point(Structure1):
_fields = ['x', 'y']

class Circle(Structure1):
_fields = ['radius']

def area(self):
return math.pi * self.radius ** 2

使用示例如下:

1
2
3
4
5
6
7
8
9
>>> s = Stock('ACME', 50, 91.1)
>>> p = Point(2, 3)
>>> c = Circle(4.5)
>>> s2 = Stock('ACME', 50)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "structure.py", line 6, in __init__
raise TypeError('Expected {} arguments'.format(len(self._fields)))
TypeError: Expected 3 arguments

除此之外,还可以将不在_fields中的名称加入到属性中去,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Structure3:
_fields = []

def __init__(self, *args, **kwargs):
if len(args) != len(self._fields):
raise TypeError('Expected {} arguments'.format(len(self._fields)))

for name, value in zip(self._fields, args):
setattr(self, name, value)

# 设置额外属性值
extra_args = kwargs.keys() - self._fields
for name in extra_args:
setattr(self, name, kwargs.pop(name))

if kwargs:
raise TypeError('Duplicate values for {}'.format(','.join(kwargs)))

if __name__ == '__main__':
class Stock(Structure3):
_fields = ['name', 'shares', 'price']

s1 = Stock('ACME', 50, 91.1)
s2 = Stock('ACME', 50, 91.1, date='8/2/2012')

定义接口或者抽象基类

使用abc模块可以很轻松的定义抽象基类,示例如下:

1
2
3
4
5
6
7
8
9
10
from abc import ABCMeta, abstractmethod

class IStream(metaclass=ABCMeta):
@abstractmethod
def read(self, maxbytes=-1):
pass

@abstractmethod
def write(self, data):
pass

抽象类的一个特点是它不能直接被实例化,如下方式是错误的:

1
a = IStream()

抽象类的目的就是让别的类继承它并实现特定的抽象方法:

1
2
3
4
5
6
class SocketStream(IStream):
def read(self, maxbytes=-1):
pass

def write(self, data):
pass

抽象基类的一个主要用途是在代码中检查某些类是否为特定类型,实现了特定接口:

1
2
3
4
def serialize(obj, stream):
if not isinstance(stream, IStream):
raise TypeError('Expected an IStream')
pass

实现自定义容器

collections定义了很多抽象基类,当你想自定义容器类的时候它们会非常有用。 比如你想让你的类支持迭代,那就让你的类继承collections.Iterable即可:

1
2
3
import collections
class A(collections.Iterable):
pass

不过需要实现collections.Iterable所有的抽象方法,否则会报错:

1
2
3
4
5
>>> a = A()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: Can't instantiate abstract class A with abstract methods __iter__
>>>

示例如下:

1
2
3
4
5
from collections.abc import Iterable
class A(Iterable):
def __iter__(self):
pass
a = A()

collections中很多抽象类会为一些常见容器操作提供默认的实现,这样一来你只需要实现那些你最感兴趣的方法即可。假设你的类继承自collections.MutableSequence,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Items(collections.MutableSequence):
def __init__(self, initial=None):
self._items = list(initial) if initial is not None else []

def __getitem__(self, index):
print('Getting:', index)
return self._items[index]

def __setitem__(self, index, value):
print('Setting:', index, value)
self._items[index] = value

def __delitem__(self, index):
print('Deleting:', index)
del self._items[index]

def insert(self, index, value):
print('Inserting:', index, value)
self._items.insert(index, value)

def __len__(self):
print('Len')
return len(self._items)

如果你创建Items的实例,你会发现它支持几乎所有的核心列表方法(如append()、remove()、count()等)。 使用示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
>>> a = Items([1, 2, 3])
>>> len(a)
Len
3
>>> a.append(4)
Len
Inserting: 3 4
>>> a.append(2)
Len
Inserting: 4 2
>>> a.count(2)
Getting: 0
Getting: 1
Getting: 2
Getting: 3
Getting: 4
Getting: 5
2
>>> a.remove(3)
Getting: 0
Getting: 1
Getting: 2
Deleting: 2
>>>

属性的代理访问

简单来说,代理是一种编程模式,它将某个操作转移给另外一个对象来实现。 最简单的形式可能是像下面这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class A:
def spam(self, x):
pass

def foo(self):
pass


class B1:
"""简单的代理"""

def __init__(self):
self._a = A()

def spam(self, x):
return self._a.spam(x)

def foo(self):
return self._a.foo()

def bar(self):
pass

如果有大量的方法需要代理, 那么使用__getattr__()方法会更好:

1
2
3
4
5
6
7
8
9
10
11
12
class B2:
"""使用__getattr__的代理,代理方法比较多时候"""

def __init__(self):
self._a = A()

def bar(self):
pass

def __getattr__(self, name):
"""这个方法在访问的attribute不存在的时候被调用"""
return getattr(self._a, name)

示例如下:

1
2
3
b = B2()
b.bar() # 调用B.bar() (B2上存在)
b.spam(42) # 调用B.__getattr__('spam')代理到A.spam

另外一个代理例子是实现代理模式,示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
class Proxy:
def __init__(self, obj):
self._obj = obj

def __getattr__(self, name):
print('getattr:', name)
return getattr(self._obj, name)

def __setattr__(self, name, value):
if name.startswith('_'):
super().__setattr__(name, value)
else:
print('setattr:', name, value)
setattr(self._obj, name, value)

def __delattr__(self, name):
if name.startswith('_'):
super().__delattr__(name)
else:
print('delattr:', name)
delattr(self._obj, name)

使用这个代理类时,你只需要用它来包装下其他类即可:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Spam:
def __init__(self, x):
self.x = x

def bar(self, y):
print('Spam.bar:', self.x, y)

# Create an instance
s = Spam(2)
# Create a proxy around it
p = Proxy(s)
# Access the proxy
print(p.x) # Outputs 2
p.bar(3) # Outputs "Spam.bar: 2 3"
p.x = 37 # Changes s.x to 37

通过自定义属性访问方法,你可以用不同方式自定义代理类行为。

代理类有时候可以作为继承的替代方案。例如,一个简单的继承如下:

1
2
3
4
5
6
7
8
9
10
11
12
class A:
def spam(self, x):
print('A.spam', x)
def foo(self):
print('A.foo')

class B(A):
def spam(self, x):
print('B.spam')
super().spam(x)
def bar(self):
print('B.bar')

使用代理的话,就是下面这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class A:
def spam(self, x):
print('A.spam', x)
def foo(self):
print('A.foo')

class B:
def __init__(self):
self._a = A()
def spam(self, x):
print('B.spam', x)
self._a.spam(x)
def bar(self):
print('B.bar')
def __getattr__(self, name):
return getattr(self._a, name)

当实现代理模式时,还有些细节需要注意。首先,__getattr__()实际是一个后备方法,只有在属性不存在时才会调用。因此,如果代理类实例本身有这个属性的话,那么不会触发这个方法的。另外,__setattr__()__delattr__()需要额外的魔法来区分代理实例和被代理实例_obj的属性。一个通常的约定是只代理那些不以下划线_开头的属性。

还有一点需要注意的是,__getattr__()对于大部分以双下划线(__)开始和结尾的属性并不适用。 例如考虑如下的类:

1
2
3
4
5
6
7
8
class ListLike:
"""__getattr__对于双下划线开始和结尾的方法是不能用的,需要一个个去重定义"""

def __init__(self):
self._items = []

def __getattr__(self, name):
return getattr(self._items, name)

如果是创建一个ListLike对象,会发现它支持普通的列表方法,如append()insert(), 但是却不支持len()、元素查找等。例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
>>> a = ListLike()
>>> a.append(2)
>>> a.insert(0, 1)
>>> a.sort()
>>> len(a)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: object of type 'ListLike' has no len()
>>> a[0]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: 'ListLike' object does not support indexing
>>>

为了让它支持这些方法,必须手动实现这些方法代理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class ListLike:
"""__getattr__对于双下划线开始和结尾的方法是不能用的,需要一个个去重定义"""

def __init__(self):
self._items = []

def __getattr__(self, name):
return getattr(self._items, name)

def __len__(self):
return len(self._items)

def __getitem__(self, index):
return self._items[index]

def __setitem__(self, index, value):
self._items[index] = value

def __delitem__(self, index):
del self._items[index]

创建不调用init方法的实例

可以通过__new__()方法创建一个未初始化的实例。例如考虑如下这个类:

1
2
3
4
5
class Date:
def __init__(self, year, month, day):
self.year = year
self.month = month
self.day = day

下面为不调用__init__()方法来创建这个Date实例:

1
2
3
4
5
6
7
8
>>> d = Date.__new__(Date)
>>> d
<__main__.Date object at 0x1006716d0>
>>> d.year
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'Date' object has no attribute 'year'
>>>

结果可以看到,这个Date实例的属性year还不存在,所以你需要手动初始化:

1
2
3
4
5
6
7
8
9
>>> data = {'year':2023, 'month':10, 'day':11}
>>> for key, value in data.items():
... setattr(d, key, value)
...
>>> d.year
2023
>>> d.month
10
>>>

实现状态对象或者状态机

在很多程序中,有些对象会根据状态的不同来执行不同的操作。比如考虑如下的一个连接对象:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Connection:
"""普通方案,好多个判断语句,效率低下~~"""

def __init__(self):
self.state = 'CLOSED'

def read(self):
if self.state != 'OPEN':
raise RuntimeError('Not open')
print('reading')

def write(self, data):
if self.state != 'OPEN':
raise RuntimeError('Not open')
print('writing')

def open(self):
if self.state == 'OPEN':
raise RuntimeError('Already open')
self.state = 'OPEN'

def close(self):
if self.state == 'CLOSED':
raise RuntimeError('Already closed')
self.state = 'CLOSED'

这样写有很多缺点,首先是代码太复杂了,好多的条件判断。其次是执行效率变低, 因为一些常见的操作比如read()、write()每次执行前都需要执行检查。一个更好的办法是为每个状态定义一个对象:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class Connection1:
"""新方案——对每个状态定义一个类"""

def __init__(self):
self.new_state(ClosedConnectionState)

def new_state(self, newstate):
self._state = newstate
# Delegate to the state class

def read(self):
return self._state.read(self)

def write(self, data):
return self._state.write(self, data)

def open(self):
return self._state.open(self)

def close(self):
return self._state.close(self)


# Connection state base class
class ConnectionState:
@staticmethod
def read(conn):
raise NotImplementedError()

@staticmethod
def write(conn, data):
raise NotImplementedError()

@staticmethod
def open(conn):
raise NotImplementedError()

@staticmethod
def close(conn):
raise NotImplementedError()


class ClosedConnectionState(ConnectionState):
@staticmethod
def read(conn):
raise RuntimeError('Not open')

@staticmethod
def write(conn, data):
raise RuntimeError('Not open')

@staticmethod
def open(conn):
conn.new_state(OpenConnectionState)

@staticmethod
def close(conn):
raise RuntimeError('Already closed')


class OpenConnectionState(ConnectionState):
@staticmethod
def read(conn):
print('reading')

@staticmethod
def write(conn, data):
print('writing')

@staticmethod
def open(conn):
raise RuntimeError('Already open')

@staticmethod
def close(conn):
conn.new_state(ClosedConnectionState)

使用示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
>>> c = Connection()
>>> c._state
<class '__main__.ClosedConnectionState'>
>>> c.read()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "example.py", line 10, in read
return self._state.read(self)
File "example.py", line 43, in read
raise RuntimeError('Not open')
RuntimeError: Not open
>>> c.open()
>>> c._state
<class '__main__.OpenConnectionState'>
>>> c.read()
reading
>>> c.write('hello')
writing
>>> c.close()
>>> c._state
<class '__main__.ClosedConnectionState'>
>>>

通过字符串调用对象方法

有一个字符串形式的方法名称,如果想通过它调用某个对象的对应方法。最简单的情况,可以使用getattr()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import math

class Point:
def __init__(self, x, y):
self.x = x
self.y = y

def __repr__(self):
return 'Point({!r:},{!r:})'.format(self.x, self.y)

def distance(self, x, y):
return math.hypot(self.x - x, self.y - y)


p = Point(2, 3)
d = getattr(p, 'distance')(0, 0) # Calls p.distance(0, 0)

另外一种方法是使用operator.methodcaller(),示例如下:

1
2
import operator
operator.methodcaller('distance', 0, 0)(p)

当你需要通过相同的参数多次调用某个方法时,使用operator.methodcaller就很方便了。 比如你需要排序一系列的点,就可以这样做:

1
2
3
4
5
6
7
8
9
points = [
Point(1, 2),
Point(3, 0),
Point(10, -3),
Point(-5, -7),
Point(-1, 8),
Point(3, 2)
]
points.sort(key=operator.methodcaller('distance', 0, 0))

调用一个方法实际上是两步独立操作,第一步是查找属性,第二步是函数调用。 因此,为了调用某个方法,你可以首先通过getattr()来查找到这个属性,然后再去以函数方式调用它即可。
operator.methodcaller()创建一个可调用对象,并同时提供所有必要参数, 然后调用的时候只需要将实例对象传递给它即可,示例如下:

1
2
3
4
5
>>> p = Point(3, 4)
>>> d = operator.methodcaller('distance', 0, 0)
>>> d(p)
5.0
>>>

让类支持比较操作

Python类对每个比较操作都需要实现一个特殊方法来支持。例如为了支持>=操作符,你需要定义一个__ge__()方法。装饰器functools.total_ordering就是用来简化这个处理的。使用它来装饰一个类,你只需定义一个__eq__()方法, 外加其他方法(__lt__, __le__, __gt__, or __ge__)中的一个即可。 然后装饰器会自动为你填充其它比较方法。示例如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from functools import total_ordering

class Room:
def __init__(self, name, length, width):
self.name = name
self.length = length
self.width = width
self.square_feet = self.length * self.width

@total_ordering
class House:
def __init__(self, name, style):
self.name = name
self.style = style
self.rooms = list()

@property
def living_space_footage(self):
return sum(r.square_feet for r in self.rooms)

def add_room(self, room):
self.rooms.append(room)

def __str__(self):
return '{}: {} square foot {}'.format(self.name,
self.living_space_footage,
self.style)

def __eq__(self, other):
return self.living_space_footage == other.living_space_footage

def __lt__(self, other):
return self.living_space_footage < other.living_space_footage

这里我们只是给House类定义了两个方法:__eq__()__lt__(),它就能支持所有的比较操作:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
h1 = House('h1', 'Cape')
h1.add_room(Room('Master Bedroom', 14, 21))
h1.add_room(Room('Living Room', 18, 20))
h1.add_room(Room('Kitchen', 12, 16))
h1.add_room(Room('Office', 12, 12))
h2 = House('h2', 'Ranch')
h2.add_room(Room('Master Bedroom', 14, 21))
h2.add_room(Room('Living Room', 18, 20))
h2.add_room(Room('Kitchen', 12, 16))
h3 = House('h3', 'Split')
h3.add_room(Room('Master Bedroom', 14, 21))
h3.add_room(Room('Living Room', 18, 20))
h3.add_room(Room('Office', 12, 16))
h3.add_room(Room('Kitchen', 15, 17))
houses = [h1, h2, h3]
print('Is h1 bigger than h2?', h1 > h2) # prints True
print('Is h2 smaller than h3?', h2 < h3) # prints True
print('Is h2 greater than or equal to h1?', h2 >= h1) # Prints False
print('Which one is biggest?', max(houses)) # Prints 'h3: 1101-square-foot Split'
print('Which is smallest?', min(houses)) # Prints 'h2: 846-square-foot Ranch'

total_ordering装饰器也没那么神秘。它就是定义了一个从每个比较支持方法到所有需要定义的其他方法的一个映射而已。比如你定义了__le__()方法,那么它就被用来构建所有其他的需要定义的那些特殊方法。 实际上就是在类里面像下面这样定义了一些特殊方法:

1
2
3
4
5
6
7
8
9
10
class House:
def __eq__(self, other):
pass
def __lt__(self, other):
pass
# Methods created by @total_ordering
__le__ = lambda self, other: self < other or self == other
__gt__ = lambda self, other: not (self < other or self == other)
__ge__ = lambda self, other: not (self < other)
__ne__ = lambda self, other: not self == other