В предыдущей части серии мы разбирали умножение матриц — операцию, которая встречается во многих областях информатики. Она часто применяется в нейронных сетях для расчета активаций линейных слоев. Однако сами активации трудно анализировать, поскольку их значения и статистика, такие как среднее, дисперсия или диапазон, сильно различаются между слоями. Именно поэтому используются функции активации, например логистическая функция, известная как сигмоида, которая отображает любое вещественное число в диапазон [0; 1].
Функция softmax, или нормализованная экспоненциальная функция, представляет собой многомерное обобщение сигмоиды. Она преобразует вектор исходных оценок, называемых логитами, в распределение вероятностей по M классам. Ее можно рассматривать как взвешенное среднее, которое работает как гладкая функция и легко дифференцируется. Softmax играет ключевую роль в механизмах внимания на основе скалярного произведения, моделировании языка и многоклассовой логистической регрессии.
В этой статье мы разберем:
- Реализацию эффективного ядра softmax в Triton.
- Обратный проход (
autograd). - Оптимизацию: модификаторы кэша и автотюнинг.
Если Triton еще не знаком, стоит заглянуть в предыдущие материалы серии!
Все иллюстрации и анимации созданы автором, если не указано иное.
Определение
Softmax определяется следующим образом:

Нормализация гарантирует, что сумма элементов вектора равна 1, что позволяет интерпретировать его как полноценное распределение вероятностей.
Эта формулировка очень уязвима к переполнению чисел. Максимальное значение, которое может хранить стандартный float16, составляет 65 504, что примерно равно exp(11). Значит, любой входной элемент больше примерно 11 приведет к тому, что exp(z_i) выйдет за пределы представимого диапазона и вызовет переполнение.
Чтобы избежать этой проблемы, обычно вычитают максимальное значение из вектора входа, делая новый максимум равным 0 перед возведением в степень и 1 после.

Простая реализация
Как видно, вычисление softmax требует двух операций редукции: поиска максимума и суммы. Наивный подход предполагает три отдельных прохода по вектору входа: сначала для максимума, потом для суммы и наконец для нормализованных выходов.
Повторяющаяся тема в этой серии по Triton — минимизация медленных обращений к глобальной памяти. Текущая реализация на Numpy требует трех полных чтений вектора входа, что крайне неэффективно.
Онлайн-softmax
К счастью, есть хитрый прием, называемый онлайн-softmax, который позволяет объединить шаги поиска максимума и суммы, сократив чтения памяти до 2.
Сначала определяем сумму экспонент рекурсивно. В следующих равенствах m_i обозначает максимум по x до i-го индекса.

Это равенство позволяет вычислять сумму экспонент итеративно, используя текущий максимум. Благодаря ему можно объединить первый и второй циклы в наивной реализации, рассчитывая максимум и сумму экспонент шаг за шагом.
Алгоритм принимает такой вид:

Это легко перенести на Numpy.
Теперь, когда принципы softmax ясны, перейдем к реализации в Triton. Начнем с простой версии на одном блоке и дойдем до онлайн-варианта на нескольких блоках. В итоге ядро должно работать как модуль PyTorch и поддерживать autograd.
К сожалению, с точки зрения PyTorch ядра Triton выглядят как черные ящики: их операции не отслеживаются autograd. Поэтому нужно самостоятельно реализовать обратный проход и указать, как вычислять градиенты. Вспомним правило цепочки и выведем градиент softmax.
Градиент
Поскольку выходы softmax всегда положительны, удобнее использовать логарифмическую производную для упрощения вычисления градиента. Здесь берем производную от лога выхода и применяем правило цепочки:

Далее перестраиваем члены и следуем этим шагам:

Предположим, есть градиент от вышестоящей функции, например от функции потерь L (скажем, кросс-энтропии). Выражение градиента будет таким:

Упрощение левого члена в (9) происходит потому, что δ_ij равно 1 только для i-го элемента, и сумма по j сводится к одному слагаемому.
Реализация в Triton
Softmax на одном блоке
После вывода градиента можно написать ядра для прямого и обратного проходов softmax. Сначала разберем обертку на PyTorch, чтобы понять, как работает версия на одном блоке на высоком уровне. Для тензора входа размером 2D ядра прямого и обратного проходов будут обрабатывать все строки параллельно.
Для простоты зададим BLOCK_SIZE достаточно большим, чтобы захватить все столбцы за раз. Конкретно, возьмем следующую степень двойки, превышающую число столбцов, как требует Triton.
Затем определим сетку как число строк (она также может учитывать размер батча).
Обертка PyTorch для SoftmaxSingleBlock — это класс, наследующий от torch.autograd.Function, с методами forward и backward. Оба метода принимают аргумент ctx, который используем для сохранения выходов softmax во время прямого прохода и повторного использования в обратном.
Ядра довольно просты: загружаем входы строки с той же синтаксисом, что в статье о сложении векторов. Обратите внимание, что BLOCK_SIZE и num_warps вычисляются с помощью функции calculate_settings. Эта функция из библиотеки Unsloth, ее переиспользовали в других библиотеках ядер, таких как LigerKernel (на которой основаны ядра в этой статье). Она дает эвристики для настройки этих переменных:
def calculate_settings(n: int) -> tuple[int, int]: MAX_FUSED_SIZE = 65536 # maximum grid dimension on Nvidia GPUs BLOCK_SIZE = next_power_of_2(n) if BLOCK_SIZE > MAX_FUSED_SIZE: # we remove this assertion in this article raise RuntimeError( f"Cannot launch Triton kernel since n = {n} exceeds " f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}." ) num_warps = 4 if BLOCK_SIZE >= 32768: num_warps = 32 elif BLOCK_SIZE >= 8192: num_warps = 16 elif BLOCK_SIZE >= 2048: num_warps = 8 return BLOCK_SIZE, num_warpsЗатем реализуем обычный softmax для прямого прохода и уравнение (10) для обратного. Новизна по сравнению с предыдущими материалами — использование модификаторов кэша, которые указывают компилятору, как кэшировать и вытеснять данные. Пока сосредоточимся на трех модификаторах:
.ca(Кэш на всех уровнях): Указывает компилятору загрузить данные в кэш L1 и L2, предполагая, что они скоро понадобятся снова. Используйте, когда данные помещаются в L1 (~128–192 КБ на SM в A100) и будут многократно обращаться..cs(Потоковый): Обрабатывает данные как потоковые — они используются один раз и вытесняются из L1, освобождая место..wb(Запись с возвратом): Обычная кэшированная запись, данные остаются в иерархии кэша, если выходы могут пригодиться позже.
В этих ядрах применяем .ca для загрузок, поскольку выполняем несколько операций над данными. Для записи используем .cs в прямом проходе, так как выходы не переиспользуются сразу, и .wb в обратном, поскольку в контексте autograd (правило цепочки) градиенты потребляют downstream-ядра.
Softmax на нескольких блоках
Теперь рассмотрим онлайн-формулировку softmax. В этом разделе реализуем вариант на нескольких блоках. Здесь BLOCK_SIZE < n_cols, то есть загружаем плитку с BLOCK_SIZE элементами за раз, как в тайловом GEMM из прошлого руководства. Возникает вопрос: как выбрать размер блока?
Это отличный повод познакомиться с утилитой autotune в Triton. Она получает список конфигураций, проводит поиск по сетке, определяет и кэширует лучшую для конкретной формы входа. Процесс повторяется при новой форме входа.
Здесь проводим поиск по сетке для размера блока и числа варпов с помощью такой функции:
from itertools import product # --- Multi Block Tuning --- BLOCK_SIZES = [256, 512, 1024, 2048, 4096, 8192] NUM_WARPS = [2, 4, 8, 16] def get_autotune_config( block_sizes: list[int], num_warps: list[int] ) -> list[triton.Config]: return [ triton.Config(kwargs={"BLOCK_SIZE": bs}, num_warps=nw) for (bs, nw) in list(product(block_sizes, num_warps)) ]Теперь декорируем ядра на нескольких блоках с autotune и передаем список конфигураций, key="n_cols" показывает, что оптимальная конфигурация зависит от числа столбцов входа.
Реализация этих ядер концептуально близка к онлайн-softmax, рассмотренному ранее, основные отличия — итерация по плиткам (не по отдельным элементам, как в Numpy), что требует корректировок. Например, добавляем сумму по плитке в обновлении d, а обратное ядро теперь тоже требует двух итераций.
Примечание: обертка PyTorch идентична, кроме строки, где объявляются BLOCK_SIZE и num_warps (поскольку их выбирает autotune).
Тестирование и бенчмаркинг
Теперь можно запустить прямой и обратный проходы для обоих ядер и убедиться, что они совпадают с базовыми значениями PyTorch:
def validate_kernel(kernel_fn: callable) -> None: device = "cuda:0" if torch.cuda.is_available() else "cpu" torch.random.manual_seed(0) # Generate inputs x = torch.randn((256, 512), device=device) # triton input x.requires_grad = True xt = deepcopy(x) # torch input triton_output = kernel_fn(x) torch_output = torch.softmax(xt, dim=1) torch.testing.assert_close(triton_output, torch_output) # test fwd kernel # Setup fake labels y = torch.zeros_like(x) inds = (torch.arange(0, y.shape[0]), torch.randint(0, 3, (y.shape[0],))) y[inds] = 1 # Define loss and run backward pass loss_fn = torch.nn.CrossEntropyLoss() loss = loss_fn(torch_output, y) loss.backward() # Save gradient tensor for later torch_xgrad = xt.grad.detach().clone() triton_loss = loss_fn(triton_output, y) triton_loss.backward() torch.testing.assert_close(x.grad, torch_xgrad) # test grad outputs validate_kernel(softmax_sb) validate_kernel(softmax_mb)Наконец, сравним производительность с базой PyTorch с помощью такого фрагмента:
# --- Source: Triton softmax tutorial --- @triton.testing.perf_report( triton.testing.Benchmark( x_names=["N"], # argument names to use as an x-axis for the plot x_vals=[ 128 * i for i in range(2, 100) ], # different possible values for `x_name` line_arg="provider", # argument name whose value corresponds to a different line in the plot line_vals=[ "triton_single_block", "triton_multi_block", "torch", ], # possible values for `line_arg`` line_names=[ "Triton_single_block", "Triton_multi_block", "Torch", ], # label name for the lines styles=[("blue", "-"), ("green", "-"), ("red", "-")], ylabel="GB/s", # label name for the y-axis plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. args={"M": 4096}, # values for function arguments not in `x_names` and `y_name` ) ) def benchmark(M, N, provider): x = torch.randn(M, N, device=DEVICE, dtype=torch.float32) stream = getattr(torch, DEVICE.type).Stream() getattr(torch, DEVICE.type).set_stream(stream) if provider == "torch": ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1)) if provider == "triton_single_block": torch.cuda.synchronize() ms = triton.testing.do_bench(lambda: softmax_sb(x)) torch.cuda.synchronize() if provider == "triton_multi_block": torch.cuda.synchronize() ms = triton.testing.do_bench(lambda: softmax_mb(x)) torch.cuda.synchronize() gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms) benchmark.run(show_plots=True, print_data=True)Хорошие новости! Ядро на одном блоке стабильно превосходит базу PyTorch, а вариант на нескольких блоках отстает для входов с более чем 6 тысячами столбцов:

При рассмотрении больших входов можно отметить несколько моментов:
- Ядро на нескольких блоках стабилизируется около 900 ГБ/с пропускной способности, обгоняя базу PyTorch для входов с более чем 30 тысячами столбцов.
- Интересно, что вариант на нескольких блоках лидирует для входов свыше 60 тысяч столбцов.
- Несмотря на превышение максимального размера блока в варианте на одном блоке, ядро все равно работает гладко по какой-то причине. Triton автоматически управляет размером блока на заднем плане. Когда
n_colsбольше аппаратного лимита, Triton разбивает вход и итерирует по нему. Однако это медленнее, чем подход на нескольких блоках.
Чтобы улучшить, можно объединить оба подхода в одно ядро, которое явно выбирает оптимальное на основе размера входа. Так получим высокую скорость одного блока для малых входов и большую пропускную способность нескольких блоков для входов свыше 60 тысяч столбцов.

На этом завершается третья часть серии по Triton, спасибо за внимание!
В следующей статье применим онлайн-формулировку softmax в контексте Flash Attention.
До встречи! 👋
Ресурсы:
- Реализация softmax в LigerKernel
- Вывод градиента softmax от Thomas Kurbiel
- Оптимизация ядер GPU: Softmax — Часть 2 от Hugo Rosenkranz-Costa (ядра на Cuda и Triton с акцентом на профилирование и аппаратную оптимизацию)
- От онлайн-softmax к FlashAttention от Zihao Ye