Новости и статьи об искусственном интеллекте и нейросетях. Мы собираем и обрабатываем самую актуальную информацию из мира AI. О проекте

Новости

Warp 1.10: улучшения для JAX и производительности

Warp 1.10 усиливает интеграцию с JAX, добавляя автоматическое дифференцирование и поддержку нескольких устройств, а также улучшает модель тайлов и производительность. Значительные ускорения касаются вызовов функций из Python и операций BVH, sparse-матриц, FEM в графах CUDA. Убрали устаревший модуль warp.sim в пользу Newton, плюс улучшения языка и исправления ошибок.

24 ноября 2025 г.
12 мин
1

Warp v1.10.0

Warp v1.10 усиливает интеграцию с JAX, добавляя поддержку автоматического дифференцирования и совместимость с многопоточным jax.pmap(). Модель программирования тайлов обогатилась редукциями по осям, индексацией на уровне компонентов и удобными функциями для создания тайлов.

Производительность заметно выросла в ряде направлений: операции BVH теперь позволяют перестраивать структуру на месте для графов CUDA с настраиваемым размером листьев, вызовы встроенных функций из Python ускорены до 70 раз, а дополнительные операции со sparse-матрицами и FEM можно захватывать в графы CUDA.

Улучшения удобства включают отрицательную индексацию и слайсинг для массивов, атомарные побитовые операции, а также новые встроенные функции, такие как функции ошибок и приведение типов.

Важно: В этом выпуске убран модуль warp.sim (помеченный как устаревший с v1.8), который заменил движок физики Newton. Подробности миграции и другие изменения смотрите в разделе Объявления ниже.

Полный перечень обновлений доступен в changelog.

Новые возможности

Автоматическое дифференцирование с JAX (экспериментальное)

Warp теперь работает с экспериментальным автоматическим дифференцированием в JAX, что позволяет ядрам участвовать в процессах вычисления градиентов. Эта функция основана на предыдущих разработках и позволяет использовать jax.grad() для градиентов через ядра Warp, передавая enable_backward=True в jax_kernel().

Основные возможности:

  • Ядра с одним или несколькими выходами: Вычисление градиентов для ядер с одним или больше выходными массивами
  • Автоопределение статичных входов: Скалярные входы автоматически считаются статичными (недифференцируемыми) аргументами
  • Массивы векторов и матриц: Полная поддержка массивов составных типов, таких как wp.vec2 или wp.mat22
  • Выполнение на нескольких устройствах: Совместимость с jax.pmap() для распределенных прямых и обратных проходов по нескольким GPU
import jax
from warp.jax_experimental import jax_kernel
@wp.kernel
def my_kernel(a: wp.array(dtype=float), out: wp.array(dtype=float)):
    i = wp.tid()
    out[i] = a[i] ** 2.0
# Включаем автоматическое дифференцирование
jax_func = jax_kernel(my_kernel, num_outputs=1, enable_backward=True)
# Вычисляем градиенты через ядро
grad_fn = jax.grad(lambda a: jax.numpy.sum(jax_func(a)[0]))
gradient = grad_fn(input_array)
# gradient: [2*a[0], 2*a[1], ...]

Эта функция экспериментальная и имеет ограничения. Полные примеры, детали использования и ограничения описаны в документации по автоматическому дифференцированию JAX.

Поддержка нескольких устройств в JAX с jax.pmap()

Warp полностью совместим с jax.pmap() и jax.shard_map() для параллельного выполнения на нескольких устройствах. Раньше проблемы с выбором устройств мешали работе callable из Warp внутри этих примитивов JAX — JAX вызывал колбэки из разных потоков на разные устройства, но Warp всегда запускался на устройстве по умолчанию. Теперь исправление обеспечивает координацию устройств: извлекаются порядковые номера устройств из XLA FFI, добавлена синхронизация потоков для одновременных колбэков, что позволяет эффективно распределять вычисления.

@wp.kernel
def tile_reduce_axis(x: wp.array2d(dtype=float), y: wp.array(dtype=float)):
    a = wp.tile_load(x, shape=(4, 8), storage="shared")
    # Суммируем по оси 0, уменьшая форму с (4, 8) до (8,)
    b = wp.tile_sum(a, axis=0)
    wp.tile_store(y, b)
x = wp.array(np.arange(32).reshape(4, 8), dtype=float)
# x = [[ 0. 1. 2. 3. 4. 5. 6. 7.]
# [ 8. 9. 10. 11. 12. 13. 14. 15.]
# [16. 17. 18. 19. 20. 21. 22. 23.]
# [24. 25. 26. 27. 28. 29. 30. 31.]]
y = wp.zeros(8, dtype=float)
wp.launch_tiled(tile_reduce_axis, dim=(1,), inputs=[x], outputs=[y], block_dim=32)
# y = [48. 52. 56. 60. 64. 68. 72. 76.] (суммы по столбцам)

Индексация на уровне компонентов

Тайлы составных типов (векторы, матрицы, кватернионы) теперь поддерживают индексацию и присваивание на уровне отдельных компонентов. Можно обращаться к элементам напрямую с помощью расширенного синтаксиса индексации:

  • Компоненты векторов: tile[i][1] извлекает вторую компоненту вектора в позиции i
  • Элементы матриц: tile[i][1, 1] дает доступ к элементу в строке 1, столбце 1 матрицы в позиции i

Такой подход упрощает работу со структурированными данными в тайлах.

Создание тайлов с постоянным значением

Новая функция wp.tile_full() позволяет легко создавать тайлы, заполненные постоянным значением, аналогично np.full() в NumPy:

# Создаем тайл 8x8, заполненный 3.14
tile = wp.tile_full(shape=(8, 8), value=3.14, dtype=float)

Новый пример

Пример example_tile_mcgp.py показывает применение тайлов в методах Монте-Карло: реализует алгоритм walk-on-spheres для решения уравнения Лапласа в объемных областях.

Улучшения производительности

Вызовы встроенных функций из Python

Вызовы встроенных функций Warp из области Python (например, wp.normalize(), wp.transform_identity(), арифметика матриц вроде mat * mat) теперь работают гораздо быстрее благодаря оптимизациям в разрешении перегрузок. Раньше каждый вызов проходил по всем вариантам, пытался связать аргументы и упаковывал параметры в C-типы, пока не находил совпадение. Сейчас Warp кэширует разрешенную перегрузку и стратегию упаковки параметров по типам аргументов с помощью @functools.lru_cache, избавляясь от лишних затрат на повторные вызовы.

В микробенчмарках повторное умножение wp.mat44 в области Python ускорено до 70 раз (~570 мкс → ~8 мкс), а операции вроде wp.transform_identity() — в 3-4 раза (~100 мкс → ~30 мкс). Ускорение зависит от сложности операции: больше выигрыш для тех, где разрешение перегрузок дорогое.

Изменение совместимости: В рамках оптимизации убрана поддержка передачи списков, кортежей и других не-Warp-массивов в встроенные функции. Вызовы вроде wp.normalize([1.0, 2.0, 3.0]) теперь нужно писать как wp.normalize(wp.vec3(1.0, 2.0, 3.0)). Это упрощает путь вызова и убирает дорогую логику распаковки последовательностей, несовместимую с кэшированием.

Настраиваемый размер листа BVH

wp.Bvh и wp.Mesh теперь имеют параметры leaf_size и bvh_leaf_size, чтобы пользователи могли регулировать количество примитивов в каждом листе для оптимизации. Оптимальный размер листа зависит от типа запросов:

  • Запросы пересечений (трассировка лучей, пересечения AABB): Маленькие размеры листьев (например, 1) обычно лучше, чтобы минимизировать лишние проверки примитивов
  • Запросы ближайшей точки: Большие размеры (4-8) повышают скорость, проверяя больше примитивов за раз и снижая затраты на обход
  • Смешанные нагрузки: Средние значения (4) дают баланс

Изменение поведения: По умолчанию leaf_size для wp.Bvh теперь 1 вместо жестко заданных 4, что оптимизировано под пересечения, которые чаще. Для wp.Mesh дефолт bvh_leaf_size остался 4 как компромисс между пересечениями и ближайшими точками. Пользователи, фокусирующиеся на ближайших точках, выиграют от явной установки больших размеров.

Операции со sparse-матрицами в графах CUDA

Операции со sparse-матрицами в warp.sparse теперь захватываются в графы CUDA без выделений памяти. Операции вроде bsr_axpy(), bsr_assign() и bsr_set_transpose() сохраняют топологию матрицы при masked=True, а bsr_mm() получила параметр max_new_nnz для указания верхней границы новых ненулевых блоков — это позволяет гибко захватывать графы, когда шаблоны разреженности варьируются в известных пределах.

Операции FEM в графах CUDA

Построение геометрии и пространств функций в warp.fem теперь захватывается в графы CUDA с указанием верхних границ размеров разделов: max_cell_count и max_side_count для ExplicitGeometryPartition, max_node_count для make_space_partition(). Кроме того, построение полей и ограничений по умолчанию теперь без синхронизации.

Улучшения языка

Улучшения индексации и слайсинга массивов

Массивы Warp теперь поддерживают отрицательную индексацию и более удобный слайсинг, что делает манипуляции с массивами интуитивными и похожими на NumPy.

Отрицательная индексация: Доступ к элементам с конца массива через отрицательные индексы:

@wp.kernel
def use_negative_indexing(arr: wp.array(dtype=float)):
    last = arr[-1] # Последний элемент
    second_last = arr[-2] # Предпоследний элемент

Расширенный слайсинг массивов: Массивы поддерживают гибкие операции слайсинга в ядрах, включая доступ с шагом. Это работает как с обычными массивами, так и с тайлами:

@wp.kernel
def tile_load_strided(input: wp.array2d(dtype=float), output: wp.array2d(dtype=float)):
    # Загружаем каждый второй элемент из области 16x16 в тайл 8x8
    tile = wp.tile_load(input[::2, ::2], shape=(8, 8))
    wp.tile_store(output, tile)
input = wp.array(np.arange(256).reshape(16, 16), dtype=float)
output = wp.zeros((8, 8), dtype=float)
wp.launch_tiled(tile_load_strided, dim=(1,), inputs=[input, output], block_dim=32)
# output содержит каждый второй элемент из input:
# [[ 0. 2. 4. 6. 8. 10. 12. 14.]
# [ 32. 34. 36. 38. 40. 42. 44. 46.]
# [ 64. 66. 68. 70. 72. 74. 76. 78.]
# ...
# [224. 226. 228. 230. 232. 234. 236. 238.]]

Новые встроенные функции

  • Функции ошибок: Добавлены wp.erf(), wp.erfc(), wp.erfinv() и wp.erfcinv() для вычислений функций ошибок
  • Приведение типов: Добавлена wp.cast() для переинтерпретации значений как других типов с сохранением битов (например, float как int)
  • Атомарные побитовые операции: Добавлены wp.atomic_and(), wp.atomic_or() и wp.atomic_xor() для потокобезопасных побитовых операций над целыми числами
  • Утилиты для sparse-матриц: Добавлены wp.sparse.bsr_row_index() и wp.sparse.bsr_block_index() как функции уровня ядра для быстрого определения строки блока без поиска по сжатому массиву смещений

Исправления ошибок

Выполнение на CPU AArch64 с тайлами

Устранены сегфолты при запуске ядер на основе тайлов на CPU AArch64, что затрагивает платформы вроде NVIDIA Jetson (Thor, Orin), DGX Spark, Grace Hopper и Grace Blackwell. Исправление использует выделение на стеке вместо статической памяти, чтобы обойти ограничения JIT-компилятора LLVM.

Это изменение включено по умолчанию на всех CPU-архитектурах и отключается через wp.config.enable_tiles_in_stack_memory = False. Если проблемы решаются отключением, сообщите о них на GitHub Issues.

Примечание: Это в основном влияет на выполнение операций с тайлами на CPU, что реже в Warp, но полезно для отладки или случаев, когда перенос на GPU дороже вычислений.

Проверка версии нативной библиотеки

Warp теперь проверяет версии на runtime, чтобы выявлять несоответствия между Python-пакетом и нативными библиотеками (например, warp.dll, warp.so). Это помогает диагностировать проблемы от нескольких установок Warp на системе, когда загружаются неверные библиотеки. При несоответствии выдается предупреждение, но выполнение продолжается. Если видите предупреждения, убедитесь, что Warp загружается из нужного места и нет конфликтующих версий.

Объявления

Удаление модуля warp.sim

Модуль warp.sim удален в этом выпуске. Он был помечен устаревшим в Warp v1.8 (июль 2025) и заменен движком физики Newton — независимым пакетом под Linux Foundation с переработанным API для робототехники и обучения роботов.

Миграция: Пользователи warp.sim должны перейти на Newton. Руководство по переходу от warp.sim к Newton — в миграционном гайде Newton. Исходное объявление и обсуждение — в GitHub Discussion #735.

Вопросы по Newton направляйте в раздел Discussions Newton. Существующие issues по warp.sim в репозитории Warp закроют.

JAX FFI теперь по умолчанию

Реализация jax_kernel() по умолчанию теперь на базе Foreign Function Interface (FFI) JAX, которая требуется для JAX 0.8 и новее. Большинству пользователей не нужно менять код: версия на FFI доступна с Warp 1.7 и дает лучшую производительность через захват графов CUDA. Предыдущая кастомная реализация все еще есть как wp.jax_experimental.custom_call.jax_kernel() для старых версий JAX, но она устаревшая и не работает с JAX 0.8+.

Реорганизация внутреннего кода: папка _src

В рамках уточнения публичного API Warp внутренний код переместился в подпачку warp._src. Это помогает разделить публичные API, на которые опираются пользователи, от внутренних деталей, которые могут меняться без уведомления.

Что это значит для пользователей:

  • Без немедленных изменений: Все импорты работают как раньше. Модули вроде warp.context, warp.types и warp.fem доступны по старым путям через совместимые обертки.
  • В стек-трейсах: В ошибках и стек-трейсах могут появляться пути warp._src (например, warp._src.context вместо warp.context).
  • Будущие планы: В следующих релизах определим и зафиксируем публичный API. После этого публичные модули переэкспортируют все символы, обертки уберут. Код, импортирующий из внутренних модулей, придется обновить на публичные API или явно из warp._src.* (с учетом приватности).

Эта реорганизация — первый шаг в многоэтапном процессе стабилизации публичного API. Если возникли проблемы, сообщите на GitHub Issues.

Предстоящие удаления

В v1.11 (планируется на январь 2026) уберут:

  • Построение матриц из векторов-строк: Возможность создавать матрицы, передавая векторы-строки в конструктор (например, wp.mat22(wp.vec2(1, 2), wp.vec2(3, 4))). Используйте wp.matrix_from_rows() или wp.matrix_from_cols(). Депрекация объявлена в v1.9 с удалением в v1.10, но отложена на один цикл. С v1.9 предупреждения в ядрах, а в v1.10 — и в Python.
  • Параметр graph_compatible в jax_callable(): Булев graph_compatible заменен на graph_mode, принимающий значения enum GraphMode (GraphMode.JAX, GraphMode.WARP или GraphMode.NONE).

Поддержка платформ

  • Python 3.14: Warp теперь работает с Python 3.14, расширяя совместимость за пределы 3.13.
  • macOS на Intel (x86_64): Поддержка Intel Mac удалена. Apple Silicon (ARM64) продолжают поддерживаться с CPU-выполнением. Пользователи Intel Mac могут использовать Warp 1.9.x или старше.
  • Python 3.8: Планируется убрать поддержку Python 3.8 (EOL с 2024-10-07) в следующем минорном релизе (#1019).

Благодарности

Спасибо внешним контрибьюторам:

  • За добавление поддержки атомарных побитовых операций
  • За исправление проблем с интеропом JAX на нескольких устройствах
  • За добавление поддержки автоматического дифференцирования JAX
  • За улучшение аннотаций типов для декораторов struct() и overload()
  • За добавление поддержки компиляции основной библиотеки в нескольких процессах
  • За исправления для обработки предстоящих удалений в LLVM