Новости и статьи об искусственном интеллекте и нейросетях. Мы собираем и обрабатываем самую актуальную информацию из мира AI. О проекте

Статьи

Разбор Triton: Умножение матриц по одному ядру

В статье разбирается реализация умножения матриц в Triton с акцентом на оптимизации вроде блочного разбиения и согласованности памяти. Рассматривается иерархия памяти GPU на примере A100 и влияние параллелизации на производительность. Эксперименты показывают, как профилирование помогает выявлять bottlenecks в ядрах.

15 октября 2025 г.
12 мин
7

Умножение матриц безусловно остается наиболее востребованной задачей для графических процессоров. Эта операция служит базовым элементом линейной алгебры и применяется в разнообразных областях, включая графику, симуляции физических процессов и научные вычисления, а также широко распространена в машинном обучении.

В этой публикации мы разберем концептуальную реализацию общего умножения матриц на матрицы (GEMM), одновременно познакомившись с ключевыми идеями оптимизации, такими как разбиение на блоки и согласованность доступа к памяти. В завершение мы реализуем GEMM с использованием Triton!

Простая реализация GEMM

Начнем с базового подхода: требуется умножить две матрицы X и Y с размерами (M,N) и (N,K) соответственно. Результирующая матрица Z=X@Y будет иметь размер (M,K).

Эта задача подразумевает вычисление скалярных произведений для всех комбинаций строк из X и столбцов из Y. Простая версия на NumPy может выглядеть следующим образом:

Хотя такой код прост в написании, чтении и понимании, он крайне неэффективен с точки зрения доступа к памяти и кэширования. Как отмечалось в первой части серии, ключевое правило оптимизации на GPU заключается в снижении объема передач данных.

Тем не менее, предлагаемая реализация загружает строку из X, затем последовательно считывает все K столбцов из Y, вычисляет скалярное произведение и повторяет цикл для каждой строки X. В итоге получается M(K+1) операций чтения.

Простое умножение матриц: фиолетовые и синие плитки обозначают векторы, участвующие в скалярных произведениях на каждом шаге, а зеленые ячейки — вычисленные значения результата.

Как видно из анимации, шаблон доступа к памяти оказывается расточительным, поскольку каждый столбец Y загружается M раз. В качестве сравнения: это подобно тому, как бегать в магазин (глобальная память) за каждым новым ингредиентом для блюда, вместо того чтобы заранее разложить все на кухонном столе (общая память). В идеале стоит минимизировать количество загрузок каждого фрагмента данных и повысить его повторное использование после загрузки. Это определяет два основных направления оптимизации:

  1. Как изменить шаблон доступа, чтобы избежать избыточных загрузок?
  2. Сколько данных загружать за раз и где их размещать в памяти GPU?

Блочное GEMM

Как уже говорилось, наивный метод GEMM приводит к множеству повторных загрузок, что создает лишнюю нагрузку. Желательно загружать каждый фрагмент данных лишь однажды и выполнять все необходимые операции до его удаления из памяти.

Изящное решение этой проблемы — блочное разбиение, при котором большие матрицы делятся на меньшие «блоки» или подматрицы. Рассмотрим матрицы X и Y с размерами (4,6) и (6,4) соответственно; произведение X@Y дает матрицу Z размером (4,4).

Чтобы вычислить первый элемент Z, то есть Z[0,0], необходимо рассчитать скалярное произведение первой строки X и первого столбца Y: Z[0,0] = dot(X[0, :], Y[:, 0]). Скалярное произведение можно разбить на мелкие части, например, по три элемента: Z[0,0] = dot(X[0,0:3], Y[0:3, 0]) + dot(X[0,3:6], Y[3:6, 0]).

Более того, этот метод можно распространить на две размерности и вычислять целую подматрицу (2,2) в Z за раз: Z[0:2, 0:2] = dot(X[0:2, 0:2], Y[0:2, 0:2]) + dot(X[0:2, 2:4], Y[2:4, 0:2]) + dot(X[0:2, 4:6], Y[4:6, 0:2]).

Вот визуальное изображение блочного умножения матриц:

Блочное умножение матриц. Вычисления разделены на несколько «блоков» из X и Y (выделены светло-синим и фиолетовым), каждый из которых содержит несколько подблоков (темно-синий и фиолетовый). В каждом подблоке рассчитываются скалярные произведения (зеленые ячейки в X и Y). Эти произведения накапливаются по подблокам внутри блока для получения выходных значений в Z (накопление показано цветами от оранжевого до зеленого).

Приведенная анимация демонстрирует, как данные повторно используются в блочном GEMM. Для каждого блока 2×2 в X и Y вычисляется 4 скалярных произведения, что дает подматрицу (2,2) в Z. Поскольку каждый блок включает 3 подблока, требуется накопить 3 такие подматрицы для финального блока (2,2) в Z. Накопление отображено цветными ячейками в Z.

В аналогии с кухней это означает принести ингредиенты из магазина и разложить их на столе (то есть в малой общей памяти), многократно используя их перед следующим походом в магазин.

Важно, что повторное применение загруженных данных на нескольких этапах позволяет значительно сократить количество операций чтения. Для блоков (2,2) каждая строка X и столбец Y участвует в двух скалярных произведениях. Следовательно, мы выполняем вдвое больше операций с каждым блоком загруженных данных, что примерно уменьшает число загрузок вдвое! Этот принцип обобщается на большие блоки: использование (32,32) сократит загрузки примерно в 32 раза.

Теперь возникает вопрос: «какой размер блоков оптимален»? Чтобы ответить, вспомним, как организована память в современных GPU.

Иерархия памяти GPU

В графических процессорах Nvidia выделяют четыре основных вида памяти. Возьмем за пример A100:

  • Регистры: Самый быстрый и компактный тип памяти на GPU, расположенный непосредственно в каждом потоковом мультипроцессоре (SM). В A100 каждый SM предлагает 256 КБ пространства регистров (65 536 × 32-битных регистров), распределяемых между потоками. Каждый поток получает личные 32-битные регистры для хранения временных переменных и промежуточных результатов, избегая трафика памяти. Однако потребление регистров на поток влияет на заполняемость, поскольку чрезмерное использование ограничивает параллельное выполнение потоков.
  • Кэш L1/Общая память: В A100 каждый SM оснащен 192 КБ SRAM, которую можно настраивать гибко как аппаратный кэш L1 или программируемую общую память. Для критически важных ядер, таких как умножение матриц, мы явно применяем это пространство как общую память для размещения блоков данных рядом с вычислительными блоками, обходя кэш L1. Это обеспечивает точный контроль над повторным использованием данных.
  • Кэш L2: Этот кэш медленнее L1, но значительно больше — около 40 МБ, общих для всех SM в A100. Он функционирует как глобальный кэш для данных и инструкций, снижая обращения к высокозадержанной памяти HBM. Кэш L2 согласован между SM, то есть обновления в одном SM видны другим, что позволяет синхронизировать блоки потоков. Его пропускная способность достигает нескольких терабайт в секунду, служа буфером между быстрой SRAM на чипе и медленной HBM.
  • Память высокой пропускной способности (HBM): Это память устройства с емкостью 40 ГБ или 80 ГБ в зависимости от модели A100. Она обеспечивает чрезвычайно высокую пропускную способность (до 2 ТБ/с в варианте 80 ГБ), но с гораздо большей задержкой, чем у кэшей на чипе. В HBM хранятся большие тензоры, веса моделей и наборы данных во время выполнения. Поскольку доступ к HBM дорог, эффективные ядра стремятся минимизировать перемещение данных и максимально использовать повторное применение на чипе через регистры и общую память.

Как видно, иерархия памяти обычно обменивает емкость на задержку. Поэтому достижение максимальной производительности сводится к эффективной загрузке данных из HBM в общую память и их максимальному повторному использованию.

Иерархия памяти GPU: от самого быстрого/малого (сверху) к самому медленному/большому (снизу).

Выбор размера блока имеет решающее значение. Блоки должны быть достаточно крупными для создания значительного объема параллельных задач, но достаточно малыми, чтобы их данные помещались в общую память и регистры SM. BLOCK_SIZE в 64 — типичная отправная точка, поскольку это кратно размеру варпа (32 потока), гарантируя полную загрузку аппаратных ресурсов.

Параллельное блочное GEMM

Учитывая эти аспекты, логичным продолжением блочного GEMM становится параллелизация вычислений для пар блоков по нескольким блокам потоков, как показано в следующей анимации.

Параллельное блочное умножение матриц. Итерация по блокам заменяется параллельными операциями по нескольким блокам потоков.

Согласованность памяти

Перед тем как перейти к блочному GEMM в Triton, стоит учесть еще один нюанс: согласованность памяти, техника для оптимального использования пропускной способности глобальной памяти. Она достигается, когда последующие потоки в варпе обращаются к последовательным адресам памяти. Представьте библиотекаря, собирающего книги для клиента: если все книги стоят рядом на полке, их можно взять сразу; если же они разбросаны по разным полкам, процесс затянется.

Чтобы понять применение к нашему случаю, учтите, что матрицы хранятся линейно в памяти, то есть матрица (2,2) представлена как последовательность из 4 соседних элементов. Фреймворки вроде PyTorch используют строковую компоновку, при которой элементы матрицы соседни в памяти по строкам. Например, элементы матрицы (2,2) хранятся как [(0,0), (0,1), (1,0), (1,1)]; заметьте, что элементы одной строки соседни, а столбца имеют шаг в 1 (разделены одним элементом).

PyTorch хранит матрицы в строковой компоновке. Элементы строки соседни в памяти, а столбца — с шагом.

Это позволяет загружать строки с помощью согласованных загрузок, но столбцы не соответствуют этому требованию. Однако для скалярных произведений нужны столбцы Y. Чтобы добиться максимальной производительности, рекомендуется транспонировать Y, чтобы итерировать по строкам вместо столбцов.

Однако простая транспозиция Y не меняет ее компоновку в памяти. Как упоминалось, PyTorch хранит матрицы в плоском массиве. Каждой размерности соответствует атрибут stride, указывающий шаг для перехода к следующему элементу по этой оси. Для матрицы (10,10) шаги равны (10,1): от [0,0] к [1,0] — 10 слотов памяти (одна строка), а к [0,1] — соседний слот.

При транспозиции тензора PyTorch не меняет компоновку в памяти, а лишь пересчитывает шаги. Чтобы транспозиция повлияла на память, требуется вызвать Y.T.contiguous().

Это необходимые шаги для эффективной загрузки столбцов Y, но внутри ядра придется транспонировать загруженные блоки для правильного скалярного произведения: z_block = tl.dot(X_block, Y_block.T).

Представление Y, Y.T и Y.T.contiguous() в блочном виде и компоновке памяти. Транспозиция меняет поведение матрицы, но не ее память. Поэтому добавляется .contiguous() для согласованных чтений по строкам.

Реализация в Triton

Далее опишем ядро без согласованности памяти, чтобы упростить логику и арифметику указателей, а затем подведем итоги изменений для согласованных загрузок столбцов Y.

Начнем с обертки PyTorch вокруг ядра. Нужно извлечь M, N, K из входных матриц и вычислить их шаги, поскольку эти константы пригодятся в ядре. Затем зададим BLOCK_SIZE и объявим grid.

Теперь перейдем к коду самого ядра. Мы используем утилиту Triton make_block_ptr, упрощающую арифметику указателей. Создаем по одному блочному указателю на матрицу, передавая форму матрицы, ее шаги и размер блока. Кроме того, указываем смещение — координаты левого верхнего элемента текущего блока. Для X это (m_idx * BLOCK_SIZE, 0), где m_idx — индекс текущего блока по размерности M.

Далее определяем z_acc — нулевую матрицу, которая будет принимать частичные скалярные произведения при итерации по блокам. Теперь проходим по общей размерности N, загружая блоки (BLOCK_SIZE, BLOCK_SIZE) и накапливая их скалярные произведения в z_acc. Затем сдвигаем блочные указатели по общей размерности с помощью .advance.

Обратите внимание, что при загрузке данных применяются boundary_check и padding_option вместо mask и other, как в предыдущей публикации. Эти параметры специфичны для блочных указателей и определяют оси для проверки выхода за границы (здесь (0,1) для x и y) и обработку недопустимых значений. Мы устанавливаем их в ноль, чтобы игнорировать в скалярном произведении.

Теперь оценим производительность ядра с помощью следующей функции:

def bench(fn: callable, x: torch.Tensor, y: torch.Tensor, repeat: int): flops = [] med_latency = [] for _ in tqdm(range(repeat), desc=f"Benchmarking {fn.__name__}"): latency_ms = triton.testing.do_bench( lambda: fn(x, y), quantiles=[0.5], # get the median latency return_mode="all", ) n_flops = 2 * M * N * K # matmul roughly requires 2*M*N*K operations tflops = n_flops / (latency_ms / 1e3) / 1e12 med_latency.append(latency_ms) flops.append(tflops) flops = np.array(flops) med_latency = np.array(med_latency) print(f"Absolute Error: {torch.sum(torch.abs(X@Y - fn(x, y)))}") print(f"Median Latency: {med_latency.mean():.4f} ± {med_latency.std():.3f} ms") print(f"Throughput: {flops.mean():.4f} ± {flops.std():.3f} TeraFLOPS") M = 8192 N = 6144 K = 4096 X = torch.randn((M, N), device="cuda", dtype=torch.float32) Y = torch.randn((N, K), device="cuda", dtype=torch.float32) bench(block_matmul, X, Y, repeat=10)

Получаем следующие результаты (на GPU T4 в Colab):

Absolute Error: 0.0 # the kernel outputs the correct result! Median Latency: 130.7831 ± 1.794 ms Throughput: 3.1533 ± 0.043 TeraFLOPS

Теперь рассмотрим изменения для согласованных загрузок Y: в основном нужно поменять форму, шаги и смещения при определении блочного указателя для Y. Кроме того, обновляем указатель для движения по размерности столбцов (ранее строк). Полный код этой версии доступен на GitHub.

@triton.jit def coalesced_block_matmul_kernel( X_ptr, X_m_stride, X_n_stride, Y_ptr, Y_k_stride, Y_n_stride, Z_ptr, Z_m_stride, Z_k_stride, M, N, K, BLOCK_SIZE: tl.constexpr, ): ... y_block_ptr = tl.make_block_ptr( base=Y_ptr, # flip the shape, strides and offsets to match Y.T shape=(K, N), strides=(Y_k_stride, Y_n_stride), offsets=(k_idx * BLOCK_SIZE, 0), block_shape=(BLOCK_SIZE, BLOCK_SIZE), order=(0, 1), ) ... for _ in range(0, N, BLOCK_SIZE): ... # loads z_acc += tl.dot(x, y.T) # transpose Y back for dot product x_block_ptr = tl.advance(x_block_ptr, offsets=(0, BLOCK_SIZE)) # advance the block pointer along columns of Y.T (i.e rows of Y) y_block_ptr = tl.advance(y_block_ptr, offsets=(0, BLOCK_SIZE)) tl.store(pointer=z_block_ptr, value=z_acc, boundary_check=(0, 1)) def coalesced_block_matmul(X, Y): Y = Y.T.contiguous() # Y is now (K,N) M, N = X.shape K, _ = Y.shape Z = torch.empty((M, K), device="cuda") x_stride_m, x_stride_n = X.stride() y_stride_k, y_stride_n = Y.stride() z_stride_m, z_stride_k = Z.stride() ... # define BLOCK_SIZE and grid coalesced_block_matmul_kernel[grid]( X, x_stride_m, x_stride_n, Y, y_stride_n, y_stride_k, Z, z_stride_m, z_stride_k, M, N, K, BLOCK_SIZE, ) return Z

Вот результаты бенчмарка для ядра с согласованными загрузками Y:

Absolute Error: 0.0 # Again, the kernel is correct! Median Latency: 261.9420 ± 0.858 ms Throughput: 1.5741 ± 0.005 TeraFLOPS

Поразительно, но пропускная способность второго ядра составляет лишь половину от первого, несмотря на улучшение эффективности загрузок 🤔

Краткий анализ с помощью nsight (профайлера ядер Nvidia, подробнее в будущей публикации) показывает, что транспозиция внутри ядра вызывает «пробку». Конкретно, транспозиция провоцирует конфликты банков, из-за чего потоки простаивают большую часть времени. В частности, планировщик варпов не находит подходящих варпов для запуска в 87,6% случаев, ожидая разрешения конфликта. Кроме того, отчет указывает:

———————– ———– ————–
Metric Name Metric Unit Metric Value
———————– ———– ————–

DRAM Throughput % 8.20
Compute (SM) Throughput % 21.14

Это свидетельствует о том, что ядро ограничено задержкой (то есть ни памятью, ни вычислениями; см. предыдущую публикацию для деталей). В отличие от него, первое ядро ограничено вычислениями (то есть рост вычислений повысит производительность), поскольку пропускная способность вычислений высока по сравнению с DRAM.

———————– ———– ————–
Metric Name Metric Unit Metric Value
———————– ———– ————–

DRAM Throughput % 29.35
Compute (SM) Throughput % 74.39

Заключение

Этот эксперимент подчеркивает значимость профилирования и эмпирической проверки. Даже обоснованные оптимизации, такие как согласованность доступа к памяти, могут создать новые узкие места, если их не оценить тщательно. Первое ядро, будучи проще, было ограничено вычислениями и лучше соответствовало характеристикам аппаратного обеспечения.

В следующих публикациях серии мы реализуем ядро softmax, уделив особое внимание интеграции Triton с autograd PyTorch и профилированию ядер с помощью Nsight.

До скорой встречи! 👋