Skip to main content

Python 函数

Python 偏函数

偏函数(Partial Functions)是 Python 中通过 functools.partial 创建的函数,允许固定原函数的部分参数,生成一个新函数。偏函数简化重复调用函数时的参数传递,提高代码复用性和可读性。

基本概念

偏函数的核心是通过 functools.partial 模块固定函数的部分参数。基本语法为:

from functools import partial
new_function = partial(original_function, *args, **kwargs)
  • original_function:要修改的原函数。
  • args:固定的位置参数。
  • kwargs:固定的关键字参数。
  • new_function:返回的新函数,调用时只需传入剩余参数。

示例:固定参数的简单偏函数

from functools import partial

def multiply(x, y):
    return x * y

double = partial(multiply, 2)
print(double(5))  # 输出: 10

此代码固定 multiply 函数的 x=2,生成新函数 double,调用时只需提供 y

固定多个参数

偏函数可以固定多个参数,无论是位置参数还是关键字参数。

示例:固定多个位置参数

from functools import partial

def add(a, b, c):
    return a + b + c

add_to_ten = partial(add, 5, 3)
print(add_to_ten(2))  # 输出: 10

此代码固定 a=5b=3,生成新函数 add_to_ten,只需传入 c

示例:固定关键字参数

from functools import partial

def greet(name, greeting="Hello"):
    return f"{greeting}, {name}!"

say_hi = partial(greet, greeting="Hi")
print(say_hi("Alice"))  # 输出: Hi, Alice!

此代码固定 greeting="Hi",生成新函数 say_hi,只需传入 name

结合其他函数式工具

偏函数可以与 mapfilter 等函数式工具结合,提升代码简洁性。

示例:与 map 结合

from functools import partial

def power(base, exponent):
    return base ** exponent

square = partial(power, exponent=2)
numbers = [1, 2, 3, 4]
squares = list(map(square, numbers))
print(squares)  # 输出: [1, 4, 9, 16]

此代码固定 exponent=2,生成平方函数 square,并用 map 应用于列表。

动态修改偏函数参数

偏函数的参数在调用时可以被覆盖,增加灵活性。

示例:覆盖固定参数

from functools import partial

def divide(a, b):
    return a / b

half = partial(divide, b=2)
print(half(10))  # 输出: 5.0
print(half(10, b=5))  # 输出: 2.0

此代码固定 b=2,但调用时可通过指定 b 覆盖默认值。

注意事项

性能:偏函数的开销极小,但大量创建可能增加内存使用。

替代方案:简单场景可用默认参数或 lambda 表达式,但偏函数更适合复杂函数或需要复用的场景。

# 使用默认参数
def greet(name, greeting="Hello"):
    return f"{greeting}, {name}!"

# 使用偏函数
from functools import partial
say_hi = partial(greet, greeting="Hi")
print(say_hi("Bob"))  # 输出: Hi, Bob!

可读性:偏函数适合简单场景,复杂逻辑建议使用普通函数或 lambda 表达式。

from functools import partial

def complex_calc(x, y, z):
    return x * y + z

calc_with_y = partial(complex_calc, y=2)
print(calc_with_y(5, 3))  # 输出: 13