公式介绍
在【CS285】中提到了高斯策略对数概率公式的公式如下:
log
π
θ
(
a
t
∣
s
t
)
=
−
1
2
∥
f
(
s
t
)
−
a
t
∥
Σ
2
+
const
\log \pi_{\theta}(\mathbf{a}_t | \mathbf{s}_t) = -\frac{1}{2} \left\| f(\mathbf{s}_t) - \mathbf{a}_t \right\|_{\Sigma}^2 + \text{const}
logπθ(at∣st)=−21∥f(st)−at∥Σ2+const
符号说明
PDF:Probability Density Function,概率密度函数
推导说明(from DeepSeek-R1-web)
1. 多元高斯分布的概率密度函数
设策略
π
θ
(
a
t
∣
s
t
)
\pi_{\theta}(\mathbf{a}_t | \mathbf{s}_t)
πθ(at∣st)服从均值为
μ
=
f
(
s
t
)
\mu= f(\mathbf{s}_t)
μ=f(st)、协方差矩阵为
Σ
\Sigma
Σ的多元高斯分布,其PDF为:
π
θ
(
a
t
∣
s
t
)
=
1
(
2
π
)
d
∣
Σ
∣
exp
(
−
1
2
(
a
t
−
f
(
s
t
)
)
T
Σ
−
1
(
a
t
−
f
(
s
t
)
)
)
\pi_{\theta}(\mathbf{a}_t | \mathbf{s}_t) = \frac{1}{\sqrt{(2\pi)^{d} |\Sigma|}} \exp\left( -\frac{1}{2} (\mathbf{a}_t - f(\mathbf{s}_t))^T \Sigma^{-1} (\mathbf{a}_t - f(\mathbf{s}_t)) \right)
πθ(at∣st)=(2π)d∣Σ∣1exp(−21(at−f(st))TΣ−1(at−f(st)))
其中
d
d
d是动作
a
t
\mathbf{a}_t
at的维度。
2. 对PDF取对数
对上述公式取自然对数,得到对数概率:
log
π
θ
(
a
t
∣
s
t
)
=
−
1
2
(
a
t
−
f
(
s
t
)
)
T
Σ
−
1
(
a
t
−
f
(
s
t
)
)
−
d
2
log
(
2
π
)
−
1
2
log
∣
Σ
∣
\log \pi_{\theta}(\mathbf{a}_t | \mathbf{s}_t) = -\frac{1}{2} (\mathbf{a}_t - f(\mathbf{s}_t))^T \Sigma^{-1} (\mathbf{a}_t - f(\mathbf{s}_t)) - \frac{d}{2} \log(2\pi) - \frac{1}{2} \log|\Sigma|
logπθ(at∣st)=−21(at−f(st))TΣ−1(at−f(st))−2dlog(2π)−21log∣Σ∣
3. 简化与假设
在强化学习中,通常假设:
- 协方差矩阵 Σ \Sigma Σ 是固定的(例如,设为对角矩阵或常数矩阵),或者与参数 θ \theta θ 无关。
- 常数项对梯度更新无影响:在对策略梯度进行优化时,与 θ \theta θ 无关的项在求导后会消失,因此可以合并为常数。
基于上述假设,将对数概率中的常数项合并:
const
=
−
d
2
log
(
2
π
)
−
1
2
log
∣
Σ
∣
\text{const} = -\frac{d}{2} \log(2\pi) - \frac{1}{2} \log|\Sigma|
const=−2dlog(2π)−21log∣Σ∣
4. 引入马氏距离符号
定义马氏距离(Mahalanobis distance)为:
∥
a
t
−
f
(
s
t
)
∥
Σ
2
=
(
a
t
−
f
(
s
t
)
)
T
Σ
−
1
(
a
t
−
f
(
s
t
)
)
\left\| \mathbf{a}_t - f(\mathbf{s}_t) \right\|_{\Sigma}^2 = (\mathbf{a}_t - f(\mathbf{s}_t))^T \Sigma^{-1} (\mathbf{a}_t - f(\mathbf{s}_t))
∥at−f(st)∥Σ2=(at−f(st))TΣ−1(at−f(st))
代入对数概率公式,得到:
log
π
θ
(
a
t
∣
s
t
)
=
−
1
2
∥
f
(
s
t
)
−
a
t
∥
Σ
2
+
const
\log \pi_{\theta}(\mathbf{a}_t | \mathbf{s}_t) = -\frac{1}{2} \left\| f(\mathbf{s}_t) - \mathbf{a}_t \right\|_{\Sigma}^2 + \text{const}
logπθ(at∣st)=−21∥f(st)−at∥Σ2+const
最终公式
课程中给出的公式正是上述结果,其中:
- f ( s t ) f(\mathbf{s}_t) f(st) 是由参数 θ \theta θ 确定的均值函数(例如神经网络)。
- Σ \Sigma Σ 是固定的协方差矩阵,与 θ \theta θ 无关。
- const \text{const} const 包含所有与 θ \theta θ 无关的常数项。
补充说明
- 协方差矩阵的简化:如果 Σ \Sigma Σ 是对角矩阵或各向同性( Σ = σ 2 I \Sigma = \sigma^2 I Σ=σ2I),则计算马氏距离时只需对每个维度单独计算平方误差。
- 策略梯度的应用:在计算策略梯度 ∇ θ log π θ \nabla_\theta \log \pi_{\theta} ∇θlogπθ 时,常数项 const \text{const} const 的导数为零,因此可以安全忽略。
结论:通过假设协方差矩阵 Σ \Sigma Σ 固定且与参数 θ \theta θ 无关,课程中的公式从多元高斯分布的对数概率密度函数中合理推导得出。