Original softmax
For a vector $x\in\mathbb R^{N}$, the function $y=\text{softmax}(x)$ is defined componentwise as
$$
y_i=\frac{\exp(x_i)}{\sum_{j=1}^N \exp(x_j)}
$$
In practice, however, an extra step is usually applied to make the softmax numerically safe:
$$
y_i=\frac{\exp(x_i-\max_{k=1}^N x_k)}{\sum_{j=1}^N \exp(x_j-\max_{k=1}^N x_k)}
$$
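To see why the shift matters: in float32, $\exp(x)$ overflows to infinity once $x$ exceeds roughly $88$, while the shifted exponents $x_i-\max(x)$ are always $\le 0$. A minimal PyTorch illustration (the concrete values are just an example):

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

# naive softmax: exp(1000) overflows float32, so we get inf / inf = nan
naive = torch.exp(x) / torch.exp(x).sum()
print(naive)    # tensor([nan, nan, nan])

# safe softmax: every exponent is <= 0, so nothing overflows
shifted = x - x.max()
print(torch.exp(shifted) / torch.exp(shifted).sum())    # tensor([0.0900, 0.2447, 0.6652])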
In code, computing $y$ this way takes three passes (four memory accesses per element: $x$ is read in each of the three passes, and $y$ is written once). Concretely:
import torch

def safe_softmax(x: torch.Tensor) -> torch.Tensor:
    N = x.shape[0]
    m = torch.zeros_like(x)
    d = torch.zeros_like(x)
    y = torch.zeros_like(x)
    # pass 1: running maximum, so that m[-1] == max(x)
    m[0] = x[0]
    for k in range(1, N):
        m[k] = torch.max(m[k - 1], x[k])
    # pass 2: accumulate the denominator d[-1] = sum_j exp(x[j] - max(x))
    d[0] = torch.exp(x[0] - m[-1])
    for j in range(1, N):
        d[j] = d[j - 1] + torch.exp(x[j] - m[-1])
    # pass 3: normalize
    for i in range(0, N):
        y[i] = torch.exp(x[i] - m[-1]) / d[-1]
    return y
Here the softmax denominator is obtained through the recurrence $d_i=d_{i-1}+\exp(x_i - \max(x))$.
To summarize, a naive safe softmax implementation needs three separate loops:
- compute $\max (x)$
- compute $\sum_{j=1}^N \exp(x_j-\max_{k=1}^N x_k)$
- compute $y$
Online softmax
Define $d^\prime\in\mathbb R^N$ by
$$
d^\prime_i = \sum_{j=1}^i \exp(x_j-\max_{k=1}^i x_k)
$$
Note that $d^\prime_N = d_N$, so as long as we can find a recurrence for $d^\prime_i$, we can still recover the softmax denominator. Writing $m_i=\max_{k=1}^i x_k$ for the running maximum, we have
$$
\begin{aligned}
d^\prime_i &= \sum_{j=1}^i \exp(x_j-m_i)\\
&= \bigg(\sum_{j=1}^{i-1}\exp(x_j - m_i)\bigg)+ \exp(x_i-m_i)\\
&= \bigg(\sum_{j=1}^{i-1}\exp(x_j - m_{i-1})\times \exp(m_{i-1}-m_i)\bigg) + \exp(x_i-m_i)\\
&= d^\prime_{i-1}\times \exp(m_{i-1}-m_i) + \exp(x_i - m_i)
\end{aligned}
$$
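As a quick sanity check, take $x=(1,3,2)$: $m_1=1,\ d^\prime_1=1$; then $m_2=3,\ d^\prime_2=1\cdot e^{1-3}+e^{3-3}=e^{-2}+1$; finally $m_3=3,\ d^\prime_3=(e^{-2}+1)\cdot e^{0}+e^{2-3}=e^{-2}+e^{-1}+1$, which is exactly $\sum_{j=1}^3\exp(x_j-\max(x))$.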
The payoff of this recurrence is that the computation of $\max(x)$ and of $d_N$ can be fused into a single loop, so softmax now needs only two passes (three memory accesses per element: two reads of $x$ and one write of $y$).
def online_softmax(x: torch.Tensor) -> torch.Tensor:
    N = x.shape[0]
    m = torch.zeros_like(x)
    d = torch.zeros_like(x)
    y = torch.zeros_like(x)
    m[0] = x[0]
    d[0] = 1
    # fused pass: update the running max and rescale the partial denominator
    for k in range(1, N):
        m[k] = torch.max(m[k - 1], x[k])
        d[k] = d[k - 1] * torch.exp(m[k - 1] - m[k]) + torch.exp(x[k] - m[k])
    # final pass: normalize with the global max m[-1] and denominator d[-1]
    for i in range(0, N):
        y[i] = torch.exp(x[i] - m[-1]) / d[-1]
    return y
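A quick consistency check against PyTorch's built-in softmax (a minimal sketch, assuming the two functions above are in scope; the tolerance is arbitrary):

import torch

torch.manual_seed(0)
x = torch.randn(1000)
reference = torch.softmax(x, dim=0)

assert torch.allclose(safe_softmax(x), reference, atol=1e-6)
assert torch.allclose(online_softmax(x), reference, atol=1e-6)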
Parallel online normalizer calculation
Although online softmax reduces the number of memory accesses, $d^\prime_j$ depends on $d^\prime_{j-1}$, so the loop cannot be parallelized directly; we still need to recast the algorithm into a parallelizable form.
Define the binary operation $\oplus$ on $(m, d)$ pairs:
$$
\begin{bmatrix}
m_x\\
d_x
\end{bmatrix}\oplus
\begin{bmatrix}
m_y\\
d_y
\end{bmatrix}=
\begin{bmatrix}
m_{xy}\\
d_x\exp(m_x-m_{xy}) + d_y\exp(m_y-m_{xy})
\end{bmatrix}
$$
where $m_{xy}$ denotes $\max(m_x,m_y)$.
Then (indexing $x$ from $0$ to match the code below)
$$
\begin{bmatrix}
m_{N-1}\\
d_{N-1}
\end{bmatrix} =\begin{bmatrix}
x_0\\
1
\end{bmatrix}\oplus\begin{bmatrix}
x_1\\
1
\end{bmatrix}\oplus\begin{bmatrix}
x_2\\
1
\end{bmatrix}\oplus\cdots\oplus\begin{bmatrix}
x_{N-1}\\
1
\end{bmatrix}
$$
If we can show that $\oplus$ is associative, the reduction can be evaluated in any grouping and therefore in parallel; so it remains to prove
$$
\bigg(\begin{bmatrix}
m_x\\
d_x
\end{bmatrix}\oplus
\begin{bmatrix}
m_y\\
d_y
\end{bmatrix}\bigg)\oplus
\begin{bmatrix}
m_z\\
d_z
\end{bmatrix}=\begin{bmatrix}
m_x\\
d_x
\end{bmatrix}\oplus
\bigg(\begin{bmatrix}
m_y\\
d_y
\end{bmatrix}\oplus
\begin{bmatrix}
m_z\\
d_z
\end{bmatrix}\bigg)
$$
It is straightforward to verify that both sides have the same maximum component $m_{xyz}=\max(m_x,m_y,m_z)$ and the same denominator component
$$
d_x\exp(m_x-m_{xyz}) + d_y\exp(m_y-m_{xyz}) + d_z\exp(m_z-m_{xyz})
$$
From this result it is also easy to see by induction that
$$
d_{N-1} = 1\cdot\exp(x_0-m_{0,1,\cdots, N-1}) + \cdots + 1\cdot\exp(x_{N-1}-m_{0,1,\cdots, N-1}),
$$
where $m_{0,1,\cdots,N-1}=\max_k x_k$ (each input pair contributes $m_j=x_j$ and $d_j=1$), which is exactly the safe-softmax denominator.
This idea feels very similar to the monoid structure used in segment trees: both rely on associativity to merge partial results out of order, and use an identity element to handle the boundaries.
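To make the analogy concrete, here is a small pure-Python sketch (the names merge and tree_reduce are mine, not part of the CUDA code below): the $(m, d)$ pairs are combined both by a left-to-right scan and in a tree-shaped order, and the resulting normalizer is the same.

import math
from functools import reduce

def merge(a, b):
    # the associative operation from the text, acting on (m, d) pairs
    m = max(a[0], b[0])
    return (m, a[1] * math.exp(a[0] - m) + b[1] * math.exp(b[0] - m))

def tree_reduce(pairs):
    # merge neighbours level by level, the way a parallel reduction would
    while len(pairs) > 1:
        nxt = [merge(pairs[i], pairs[i + 1]) for i in range(0, len(pairs) - 1, 2)]
        if len(pairs) % 2:
            nxt.append(pairs[-1])   # odd element is carried up unchanged
        pairs = nxt
    return pairs[0]

x = [1.0, 3.0, 2.0, 0.5, 4.0]
pairs = [(xi, 1.0) for xi in x]

m_seq, d_seq = reduce(merge, pairs)   # sequential scan, as in online_softmax
m_par, d_par = tree_reduce(pairs)     # tree-shaped merge order

assert math.isclose(d_seq, d_par)
assert math.isclose(d_seq, sum(math.exp(xi - max(x)) for xi in x))

This is exactly the reduction that cub::BlockReduce performs with merge_md in the CUDA implementation below.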
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <cub/cub.cuh>
#include <bits/stdc++.h>
const int N = 1e8;
float host_a[N];
float cpu_out[N];
float gpu_out[N];
std::mt19937 rng(998244353);
int rand_int(int l, int r) { // [l, r]
return std::uniform_int_distribution<int>(l, r)(rng);
}
class Timer {
public:
Timer() {
reset();
}
void reset() {
start_time = std::chrono::steady_clock::now();
}
double elapsedSeconds() const { // 返回距离上次reset经过的时间(单位:秒)
auto now = std::chrono::steady_clock::now();
std::chrono::duration<double> diff = now - start_time;
return diff.count();
}
private:
std::chrono::steady_clock::time_point start_time;
};
void safe_softmax_cpu(const float* x, float* y, int N, int batch_size) {
for (int b = 0; b < batch_size; b++) {
auto max_x = *std::max_element(x, x + N);
auto dn = exp(x[0] - max_x);
for (int i = 1; i < N; i++) {
dn += exp(x[i] - max_x);
}
for (int i = 0; i < N; i++) {
y[i] = exp(x[i] - max_x) / dn;
}
x += N;
y += N;
}
}
struct __align__(8) S {
float m;
float d;
};
__device__ __forceinline__ S merge_md(S x, S y) {
S res;
res.m = max(x.m, y.m);
res.d = x.d * exp(x.m - res.m) + y.d * exp(y.m - res.m);
return res;
}
const int THREADBLOCK_SIZE = 512;
__global__ void online_softmax_kernel(const float* __restrict x, float* __restrict y, int N) {
int tid = threadIdx.x;
int gid = blockIdx.x;
x += gid * N;
y += gid * N;
typedef cub::BlockReduce<S, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ S md_total;
S md_partial{ FLT_MIN, 0.0F };
for (int elem_id = tid; elem_id < N; elem_id += THREADBLOCK_SIZE) {
S new_elem{ x[elem_id],1.0F };
md_partial = merge_md(md_partial, new_elem);
}
S md = BlockReduce(temp_storage).Reduce(md_partial, merge_md);
if (tid == 0) {
md_total = md;
}
__syncthreads();
float d_total_inverse = __fdividef(1.0F, md_total.d);
for (int elem_id = tid; elem_id < N; elem_id += THREADBLOCK_SIZE) {
y[elem_id] = __expf(x[elem_id] - md_total.m) * d_total_inverse;
}
}
int main() {
std::cout << std::fixed << std::setprecision(12);
int V = 5000, batch_size = 20000;
for (int i = 0; i < V * batch_size; i++) {
host_a[i] = rand_int(1, 10);
}
float* device_input, * device_output;
cudaMalloc(&device_input, N * sizeof(float));
cudaMalloc(&device_output, N * sizeof(float));
cudaMemcpy(device_input, host_a, N * sizeof(float), cudaMemcpyHostToDevice);
Timer timer;
safe_softmax_cpu(host_a, cpu_out, V, batch_size);
std::cout << "cpu: " << timer.elapsedSeconds() << "\n";
timer.reset();
online_softmax_kernel<<<batch_size, THREADBLOCK_SIZE>>>(device_input, device_output, N);
std::cout << "gpu: " << timer.elapsedSeconds() << "\n";
cudaMemcpy(gpu_out, device_output, N * sizeof(float), cudaMemcpyDeviceToHost);
for (int i = 0; i < N; i++) {
assert(abs(cpu_out[i] - gpu_out[i]) < 1e-7);
}
return 0;
}
Local test results for the code above:
cpu: 0.714705000000
gpu: 0.000924800000