Online Softmax

Original softmax

For a vector $x\in\mathbb R^{N}$, the function $y=\text{softmax}(x)$ is defined as
$$
y_i=\frac{\exp(x_i)}{\sum_{j=1}^N \exp(x_j)}
$$
In practice, however, an extra step is usually added to keep softmax numerically safe:
$$
y_i=\frac{\exp(x_i-\max_{k=1}^N x_k)}{\sum_{j=1}^N \exp(x_j-\max_{k=1}^N x_k)}
$$
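To see why the shift by $\max(x)$ matters: in float32, $\exp(x)$ already overflows to inf for moderately large inputs, while the shifted version stays finite. A minimal illustration (assuming PyTorch is available):

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])
print(torch.exp(x))               # tensor([inf, inf, inf]): the naive formula breaks down
print(torch.softmax(x, dim=0))    # tensor([0.0900, 0.2447, 0.6652]): the stable result
print(torch.exp(x - x.max()) / torch.exp(x - x.max()).sum())  # same values via the shifted formula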
In code, computing $y$ takes three loops (four memory passes). Concretely:

def safe_softmax(x: torch.Tensor) -> torch.Tensor:
    N = x.shape[0]
    m = torch.zeros_like(x)
    d = torch.zeros_like(x)
    y = torch.zeros_like(x)
    # pass 1: running maximum, so m[-1] == max(x)
    m[0] = x[0]
    for k in range(1, N):
        m[k] = torch.max(m[k - 1], x[k])
    # pass 2: accumulate the denominator, so d[-1] == sum_j exp(x_j - max(x))
    d[0] = torch.exp(x[0] - m[-1])
    for j in range(1, N):
        d[j] = d[j - 1] + torch.exp(x[j] - m[-1])
    # pass 3: normalize
    for i in range(0, N):
        y[i] = torch.exp(x[i] - m[-1]) / d[-1]
    return y

Here the softmax denominator is obtained through the recurrence $d_i=d_{i-1}+\exp(x_i - \max(x))$.

To sum up, a naive $\text{softmax}$ implementation needs three loops:

  1. compute $\max (x)$
  2. compute $\sum_{j=1}^N \exp(x_j-\max_{k=1}^N x_k)$
  3. compute $y$

Online softmax

Define $d^\prime\in\mathbb R^N$ by
$$
d^\prime_i = \sum_{j=1}^i \exp(x_j-\max_{k=1}^i x_k)
$$
It is easy to see that $d^\prime_N = d_N$, so as long as we can find a recurrence for $d^\prime_i$, we can still compute the softmax denominator. Writing $m_i=\max_{k=1}^i x_k$ for the running maximum:
$$
\begin{aligned}
d^\prime_i &= \sum_{j=1}^i \exp(x_j-m_i)\\
&= \bigg(\sum_{j=1}^{i-1}\exp(x_j - m_i)\bigg)+ \exp(x_i-m_i)\\
&= \bigg(\sum_{j=1}^{i-1}\exp(x_j - m_{i-1})\times \exp(m_{i-1}-m_i)\bigg) + \exp(x_i-m_i)\\
&= d^\prime_{i-1}\times \exp(m_{i-1}-m_i) + \exp(x_i - m_i)
\end{aligned}
$$
The benefit of this recurrence is that the computation of $\max(x)$ and $d_N$ can be fused into a single loop, so $\text{softmax}$ now needs only two loops (three memory passes):

def online_softmax(x: torch.Tensor) -> torch.Tensor:
    N = x.shape[0]
    m = torch.zeros_like(x)
    d = torch.zeros_like(x)
    y = torch.zeros_like(x)
    m[0] = x[0]
    d[0] = 1
    # pass 1: fused running maximum and denominator; rescale d whenever the maximum changes
    for k in range(1, N):
        m[k] = torch.max(m[k - 1], x[k])
        d[k] = d[k - 1] * torch.exp(m[k - 1] - m[k]) + torch.exp(x[k] - m[k])
    # pass 2: normalize with the final maximum and denominator
    for i in range(0, N):
        y[i] = torch.exp(x[i] - m[-1]) / d[-1]
    return y
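A quick sanity check (assuming the two functions above are in scope), comparing both versions against PyTorch's built-in softmax:

x = torch.randn(1000)
assert torch.allclose(safe_softmax(x), torch.softmax(x, dim=0), atol=1e-6)
assert torch.allclose(online_softmax(x), torch.softmax(x, dim=0), atol=1e-6)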

Parallel online normalizer calculation

Although online softmax reduces the number of memory passes, $d^\prime_j$ depends on $d^\prime_{j-1}$, so it cannot be parallelized directly; we still need to recast the algorithm into a parallelizable form.

Define the binary operation $\oplus$ as
$$
\begin{bmatrix}
m_x\\
d_x
\end{bmatrix}\oplus
\begin{bmatrix}
m_y\\
d_y
\end{bmatrix}=
\begin{bmatrix}
m_{xy}\\
d_x\exp(m_x-m_{xy}) + d_y\exp(m_y-m_{xy})
\end{bmatrix}
$$
where $m_{xy}$ denotes $\max(m_x,m_y)$.

Then, switching to 0-based indexing to match the code below,
$$
\begin{bmatrix}
m_{N-1}\\
d_{N-1}
\end{bmatrix}=\begin{bmatrix}
x_0\\
1
\end{bmatrix}\oplus\begin{bmatrix}
x_1\\
1
\end{bmatrix}\oplus\begin{bmatrix}
x_2\\
1
\end{bmatrix}\oplus\cdots\oplus\begin{bmatrix}
x_{N-1}\\
1
\end{bmatrix}
$$
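As a concrete sketch (plain Python, with a hypothetical `merge` function standing in for $\oplus$), a left-to-right fold over the pairs $(x_i, 1)$ reproduces the running maximum and the denominator:

import math
from functools import reduce

def merge(a, b):
    # the ⊕ operation: a and b are (m, d) pairs
    m = max(a[0], b[0])
    return (m, a[1] * math.exp(a[0] - m) + b[1] * math.exp(b[0] - m))

x = [1.0, 3.0, 2.0, 5.0, 4.0]
m_last, d_last = reduce(merge, [(xi, 1.0) for xi in x])
assert m_last == max(x)
assert abs(d_last - sum(math.exp(xi - max(x)) for xi in x)) < 1e-12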
As long as we can show that $\oplus$ is associative, the reduction can be evaluated in parallel (e.g. as a tree reduction); that is, we need to show
$$
\bigg(\begin{bmatrix}
m_x\\
d_x
\end{bmatrix}\oplus
\begin{bmatrix}
m_y\\
d_y
\end{bmatrix}\bigg)\oplus
\begin{bmatrix}
m_z\\
d_z
\end{bmatrix}=\begin{bmatrix}
m_x\\
d_x
\end{bmatrix}\oplus
\bigg(\begin{bmatrix}
m_y\\
d_y
\end{bmatrix}\oplus
\begin{bmatrix}
m_z\\
d_z
\end{bmatrix}\bigg)
$$
It is easy to verify that both sides have the same first component $m_{xyz}=\max(m_x,m_y,m_z)$ and the same second component
$$
d_x\exp(m_x-m_{xyz}) + d_y\exp(m_y-m_{xyz}) + d_z\exp(m_z-m_{xyz})
$$
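A quick numerical spot check of associativity, reusing the `merge` sketch from above:

import random

for _ in range(1000):
    a, b, c = [(random.uniform(-10, 10), random.uniform(0, 10)) for _ in range(3)]
    lhs = merge(merge(a, b), c)
    rhs = merge(a, merge(b, c))
    assert lhs[0] == rhs[0] and abs(lhs[1] - rhs[1]) < 1e-9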
From this it is also easy to see by induction that
$$
d_{N-1} = 1\cdot\exp(x_0-m_{0,1,\cdots, N-1}) + \cdots + 1\cdot\exp(x_{N-1}-m_{0,1,\cdots, N-1})
$$
where $m_{0,1,\cdots, N-1}=\max(x_0,\cdots,x_{N-1})$, which is exactly the softmax denominator.

This idea is very similar to the monoid used in segment trees: both exploit associativity to merge partial results out of order, and use an identity element to handle the boundaries.
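The same monoid view gives the parallel scheme directly: split the input into chunks, reduce each chunk independently starting from the identity element $(-\infty, 0)$, then merge the per-chunk results. A rough sketch (sequential Python, reusing `merge` and the imports from the sketch above; in a real kernel each chunk would go to a different thread):

def reduce_chunk(xs):
    acc = (float("-inf"), 0.0)   # identity element of the ⊕ merge
    for xi in xs:
        acc = merge(acc, (xi, 1.0))
    return acc

x = [1.0, 3.0, 2.0, 5.0, 4.0, 0.5, 6.0, -1.0]
chunks = [x[i:i + 3] for i in range(0, len(x), 3)]
m_last, d_last = reduce(merge, (reduce_chunk(c) for c in chunks))
assert abs(d_last - sum(math.exp(xi - max(x)) for xi in x)) < 1e-12

The CUDA implementation below does exactly this per row: each thread folds a strided slice of the row into an $(m, d)$ pair, and cub::BlockReduce merges the per-thread partials with the same $\oplus$ operator (merge_md):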

#include "cuda_runtime.h"
#include "device_launch_parameters.h"

#include <cub/cub.cuh>
#include <bits/stdc++.h>
const int N = 1e8;
float host_a[N];
float cpu_out[N];
float gpu_out[N];

std::mt19937 rng(998244353);
int rand_int(int l, int r) {  // [l, r]
    return std::uniform_int_distribution<int>(l, r)(rng);
}

class Timer {
public:
    Timer() {
        reset();
    }

    void reset() {
        start_time = std::chrono::steady_clock::now();
    }

    double elapsedSeconds() const {  // seconds elapsed since the last reset
        auto now = std::chrono::steady_clock::now();
        std::chrono::duration<double> diff = now - start_time;
        return diff.count();
    }
private:
    std::chrono::steady_clock::time_point start_time;
};

void safe_softmax_cpu(const float* x, float* y, int N, int batch_size) {
    for (int b = 0; b < batch_size; b++) {
        auto max_x = *std::max_element(x, x + N);
        auto dn = exp(x[0] - max_x);
        for (int i = 1; i < N; i++) {
            dn += exp(x[i] - max_x);
        }
        for (int i = 0; i < N; i++) {
            y[i] = exp(x[i] - max_x) / dn;
        }
        x += N;
        y += N;
    }
}

struct __align__(8) S {
    float m;
    float d;
};

__device__ __forceinline__ S merge_md(S x, S y) {
    S res;
    res.m = max(x.m, y.m);
    res.d = x.d * exp(x.m - res.m) + y.d * exp(y.m - res.m);
    return res;
}

const int THREADBLOCK_SIZE = 512;
__global__ void online_softmax_kernel(const float* __restrict x, float* __restrict y, int N) {
    int tid = threadIdx.x;
    int gid = blockIdx.x;

    x += gid * N;
    y += gid * N;
    typedef cub::BlockReduce<S, THREADBLOCK_SIZE> BlockReduce;

    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ S md_total;

    S md_partial{ -FLT_MAX, 0.0F };  // identity element of the merge: m = -infinity (here -FLT_MAX), d = 0
    for (int elem_id = tid; elem_id < N; elem_id += THREADBLOCK_SIZE) {
        S new_elem{ x[elem_id],1.0F };
        md_partial = merge_md(md_partial, new_elem);
    }

    S md = BlockReduce(temp_storage).Reduce(md_partial, merge_md);
    if (tid == 0) {
        md_total = md;
    }
    __syncthreads();

    float d_total_inverse = __fdividef(1.0F, md_total.d);
    for (int elem_id = tid; elem_id < N; elem_id += THREADBLOCK_SIZE) {
        y[elem_id] = __expf(x[elem_id] - md_total.m) * d_total_inverse;
    }
}

int main() {
    std::cout << std::fixed << std::setprecision(12);
    int V = 5000, batch_size = 20000;
    for (int i = 0; i < V * batch_size; i++) {
        host_a[i] = rand_int(1, 10);
    }
    float* device_input, * device_output;
    cudaMalloc(&device_input, N * sizeof(float));
    cudaMalloc(&device_output, N * sizeof(float));
    cudaMemcpy(device_input, host_a, N * sizeof(float), cudaMemcpyHostToDevice);
    Timer timer;
    safe_softmax_cpu(host_a, cpu_out, V, batch_size);
    std::cout << "cpu: " << timer.elapsedSeconds() << "\n";

    timer.reset();
    online_softmax_kernel<<<batch_size, THREADBLOCK_SIZE>>>(device_input, device_output, V);  // each block reduces one row of length V
    cudaDeviceSynchronize();  // the launch is asynchronous; wait for the kernel to finish before timing
    std::cout << "gpu: " << timer.elapsedSeconds() << "\n";
    cudaMemcpy(gpu_out, device_output, N * sizeof(float), cudaMemcpyDeviceToHost);

    for (int i = 0; i < N; i++) {
        assert(std::fabs(cpu_out[i] - gpu_out[i]) < 1e-7);
    }
    return 0;
}

Local test results of the code above:

cpu: 0.714705000000
gpu: 0.000924800000