用Python手撸一个MDP求解器:从David Silver的PPT到可运行的代码(附完整源码)
从理论到实践用Python构建马尔可夫决策过程求解器马尔可夫决策过程MDP是强化学习的基础框架理解其数学原理和实现细节对掌握现代AI技术至关重要。本文将带您从David Silver的经典课程出发逐步实现一个完整的MDP求解器涵盖状态转移矩阵构建、Bellman方程迭代和策略优化等核心环节。1. MDP基础与Python建模马尔可夫决策过程由五元组(S,A,P,R,γ)构成其中S是有限状态集合A是有限动作集合P是状态转移概率矩阵R是即时奖励函数γ是折扣因子让我们用Python类来建模这个结构class MarkovDecisionProcess: def __init__(self, states, actions, transitions, rewards, gamma0.9): self.states states # 状态列表 self.actions actions # 动作列表 self.transitions transitions # 转移概率字典 self.rewards rewards # 奖励函数 self.gamma gamma # 折扣因子 def get_transition_prob(self, s, a, s_next): 获取状态转移概率P(s|s,a) return self.transitions.get((s, a, s_next), 0.0) def get_reward(self, s, a): 获取即时奖励R(s,a) return self.rewards.get((s, a), 0.0)注意实际实现时应添加输入验证确保概率分布合法和为1和状态/动作有效性2. 状态价值函数的计算状态价值函数V(s)表示从状态s开始遵循特定策略的长期回报期望。根据Bellman期望方程V(s) Σ π(a|s) * [R(s,a) γ * Σ P(s|s,a)V(s)]2.1 迭代法实现def value_iteration(mdp, epsilon1e-6, max_iter1000): V {s: 0 for s in mdp.states} for _ in range(max_iter): delta 0 new_V {} for s in mdp.states: max_value -float(inf) for a in mdp.actions: expected_value 0 for s_next in mdp.states: prob mdp.get_transition_prob(s, a, s_next) expected_value prob * (mdp.get_reward(s, a) mdp.gamma * V[s_next]) if expected_value max_value: max_value expected_value new_V[s] max_value delta max(delta, abs(new_V[s] - V[s])) V new_V if delta epsilon: break return V2.2 矩阵解法对于中等规模的问题我们可以使用线性代数方法直接求解import numpy as np def solve_value_function(mdp, policy): 通过矩阵运算求解固定策略下的价值函数 n len(mdp.states) P np.zeros((n, n)) R np.zeros(n) # 构建转移概率矩阵和奖励向量 for i, s in enumerate(mdp.states): for j, s_next in enumerate(mdp.states): for a in mdp.actions: P[i,j] policy[s][a] * mdp.get_transition_prob(s, a, s_next) for a in mdp.actions: R[i] policy[s][a] * mdp.get_reward(s, a) # 解线性方程组 V R γPV I np.eye(n) V np.linalg.solve(I - mdp.gamma * P, R) return {s: V[i] for i, s in enumerate(mdp.states)}3. 策略优化与Q学习动作价值函数Q(s,a)表示在状态s执行动作a后遵循最优策略的期望回报。我们可以通过策略迭代来优化策略def policy_iteration(mdp, initial_policyNone, max_iter100): if initial_policy is None: # 初始随机策略 policy {s: {a: 1/len(mdp.actions) for a in mdp.actions} for s in mdp.states} else: policy initial_policy for _ in range(max_iter): # 策略评估 V solve_value_function(mdp, policy) # 策略改进 policy_stable True for s in mdp.states: old_action max(policy[s].items(), keylambda x: x[1])[0] # 找出最优动作 q_values {} for a in mdp.actions: q_value 0 for s_next in mdp.states: prob mdp.get_transition_prob(s, a, s_next) q_value prob * (mdp.get_reward(s, a) mdp.gamma * V[s_next]) q_values[a] q_value best_action max(q_values.items(), keylambda x: x[1])[0] # 更新策略 if old_action ! best_action: policy_stable False for a in mdp.actions: policy[s][a] 1.0 if a best_action else 0.0 if policy_stable: break return policy4. 实战案例学生MDP问题让我们实现David Silver课程中的经典学生问题# 定义状态和动作 states [Class1, Class2, Class3, Pass, Pub, Facebook, Sleep] actions [Study, Facebook, Sleep, Quit, Pub] # 定义转移概率 transitions { # (当前状态, 动作, 下一状态): 概率 (Class1, Study, Class2): 0.5, (Class1, Facebook, Facebook): 0.5, (Class2, Study, Class3): 0.8, (Class2, Sleep, Sleep): 0.2, # ...其他转移概率 } # 定义奖励函数 rewards { (Class1, Study): -2, (Class1, Facebook): -1, (Class2, Study): -2, # ...其他奖励 } # 创建MDP实例 student_mdp MarkovDecisionProcess(states, actions, transitions, rewards) # 计算最优价值函数 optimal_V value_iteration(student_mdp) print(最优价值函数:, optimal_V) # 计算最优策略 optimal_policy policy_iteration(student_mdp) print(最优策略:) for s in states: print(f{s}: {max(optimal_policy[s].items(), keylambda x: x[1])[0]})5. 工程实践中的关键问题在实际实现MDP求解器时有几个常见陷阱需要注意稀疏矩阵处理对于大规模状态空间应使用稀疏矩阵存储转移概率from scipy.sparse import dok_matrix def build_sparse_transition_matrix(mdp, policy): n len(mdp.states) P dok_matrix((n, n)) for i, s in enumerate(mdp.states): for a in mdp.actions: for s_next in mdp.states: prob mdp.get_transition_prob(s, a, s_next) if prob 0: P[i, mdp.states.index(s_next)] policy[s][a] * prob return P.tocsr()收敛性检查迭代算法需要合理的停止条件def check_convergence(old_V, new_V, epsilon1e-6): return all(abs(old_V[s] - new_V[s]) epsilon for s in old_V)数值稳定性矩阵求逆可能不稳定可考虑使用迭代法或伪逆def safe_matrix_solve(A, b): try: return np.linalg.solve(A, b) except np.linalg.LinAlgError: return np.linalg.pinv(A) b并行计算优化对于大规模问题可以使用并行计算加速from multiprocessing import Pool def parallel_value_update(args): s, mdp, V args max_value -float(inf) for a in mdp.actions: expected_value 0 for s_next in mdp.states: prob mdp.get_transition_prob(s, a, s_next) expected_value prob * (mdp.get_reward(s, a) mdp.gamma * V[s_next]) if expected_value max_value: max_value expected_value return (s, max_value) def parallel_value_iteration(mdp, epsilon1e-6, max_iter1000): V {s: 0 for s in mdp.states} with Pool() as pool: for _ in range(max_iter): args [(s, mdp, V) for s in mdp.states] new_V dict(pool.map(parallel_value_update, args)) delta max(abs(new_V[s] - V[s]) for s in mdp.states) V new_V if delta epsilon: break return V实现一个健壮的MDP求解器需要考虑算法效率、数值稳定性和代码可维护性的平衡。通过将数学公式转化为模块化的Python代码我们不仅加深了对理论的理解也获得了可以应用于实际问题的工具。