Время прочтения: 6 мин.

В процессе написания научной работы, я столкнулся с такой проблемой, как относительно невысокая скорость выполнения вычислений. Из-за этого приходится тратить больше времени или жертвовать точностью вычислений, но что делать, если не хочется идти на компромисс и чем-то жертвовать? Воспользоваться новым фреймворком JAX от google. В связке с различными ускорителями, например, GPU (Graphics Processing Unit, или графический процессор) или TPU (Tensor Processing Unit, или тензорный процессор), он покажет достойный результат.

Поговорим об этом по подробнее. Google JAX – фреймворк машинного обучения, разработанный командой исследователей компании Google, для высокопроизводительных вычислений и исследований. В его основе лежит два компонента:

  • библиотека Autograd от PyTorch
  • компилятор XLA (Accelerated Linear Algebra, или ускоренная линейная алгебра) разработанный компанией TensorFlow

Каковы возможности JAX?

Если обратиться к официальной документации JAX, то фреймворк позиционируется, как упрощенная библиотека NumPy, с гораздо большими возможностями для маневра. В качестве основных отличительных возможностей фреймворка, выделяются следующие функции.

GRAD (производит автоматическое дифференцирование).

Благодаря библиотеке Autograd, фреймворк JAX может использовать, как функции Python, так и NumPy. Приятным дополнением к библиотеке NumPy является возможность нахождения производных третьего порядка, а также расчет прямого и обратного распространения.

JIT (выполняет компиляцию в реальном времени).

Как уже было указано ранее, компилятор XLA, составляет основу фреймворка JAX. Именно он производит ускорение вычислений, при помощи комбинирования математических операций в ядра, которые исполняются на компиляторе с подключением какого-либо ускорителя, например, GPU или TPU. Благодаря такой начинке, мы можем строить сложные и многоуровневые программы не опасаясь, потери времени или снижения скорости выполнения вычислений.

VMAP (осуществляет автоматическую векторизацию).

Векторизация вычислений, позволяет значительно повысить скорость их выполнения. Фреймворк JAX дает такую возможность при помощи функции vmap, которая в сочетании с jit() позволит достичь максимального ускорения.

PMAP (SPMD (single program, multiple data — единая программа, множество данных) программирование является методом, используемым для достижения параллелизма).

JAX предоставляет возможность запрограммировать несколько графических процессоров (GPU) или ядер тензорного процессора (TPU) для работы одновременно. Для этого необходимо использовать функцию pmap. Отличным плюсом рассматриваемого фреймворка является его кроссплатформенность от различного рода ускорителей.

Довольно слов, посмотрим так ли он хорош, как о нем заявляют.

GRAD (производит автоматическую дифференциацию).

Рассмотрим кусок кода, в котором решим простейшую школьную задачу 10-11 класса по алгебре: найти производную в точке.

x = 1. # x – содержит значение точки, в которой и ведется расчет производной;
f = lambda x: x**3 + 35*(x**2) + 152 * x + 15 # определим исходную функцию;
dfdx = grad(f) # расчет производной первого порядка;
d2fdx = grad(dfdx) # расчет производной второго порядка;
d3fdx = grad(d2fdx) # расчет производной третьего порядка;
d4fdx = grad(d3fdx) # расчет производной четвертого порядка;
print("Значение функции в точке x = 1: " + str(f(x))) # вывод значений;
print("Значение производной первого порядка в точке x = 1: " + str(dfdx(x))) # вывод значений;
print("Значение производной второго порядка в точке x = 1: " + str(d2fdx(x))) # вывод значений;
print("Значение производной третьего порядка в точке x = 1: " + str(d3fdx(x))) # вывод значений;
print("Значение производной четвертого порядка в точке x = 1: " + str(d4fdx(x))) # вывод значений.

JIT (компиляция в реальном времени).

Для демонстрации возможностей jit будем использовать следующий кусок кода:

def test_function_for_jit(x, alpha=1.66, lyambda=1.04): # определим некую функцию для проверки возможностей работы jit-компиляции, запуск данной функции, производится без jit-компиляции;
    return lyambda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha) 
test_function_for_jit_with_jit = jit(test_function_for_jit) # определим функцию, которая будет запускаться с jit-компиляцией;
data = random.normal(random.PRNGKey(0), (1000000,)) # создадим набор данных для тестирования функции;
print('non-jit version:') # покажем результат работы функции без jit-компиляции;
%timeit test_function_for_jit(data).block_until_ready()
print('jit version:') # покажем результат работы функции с jit-компиляцией;
%timeit test_function_for_jit_with_jit(data).block_until_ready().
Результаты работы, программы:
non-jit version:4.06 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
jit version:1.12 ms ± 4.51 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Рассмотрим несколько интересных фактов, связанных с фреймворком JAX NumPy.

Совпадение синтаксиса.

Для того чтобы увидеть совпадение синтаксиса, попытаемся построить один и тот же график функции при помощи библиотеки NumPy и фреймворка JAX.

Построение функции при помощи библиотеки NumPy:

x_from_np = np.linspace(0, 10, 1000) # зададим диапазон и количество точек, на которых будем строить функцию.
y_from_np = 2 * np.cos(x_from_np) # зададим функцию, для которой будем строить график.
plt.plot(x_from_np, y_from_np) # построим график.

Построение функции при помощи фреймворка JAX NumPy:

x_from_jaxnp = jnp.linspace(0, 10, 1000) # зададим диапазон и количество точек, на которых будем строить функцию.
y_from_jaxnp = 2 * jnp.cos(x_from_jaxnp) # зададим функцию, для которой будем строить график.
plt.plot(x_from_jaxnp, y_from_jaxnp) # построим график.

Массивы jax не изменяемы.

Рассмотрим пример работы с массивами из библиотеки NumPy.

vector_x_from_np = np.arange(5) # зададим массив заполненный числами от 1 до 5.
print(vector_x_from_np) # Покажем заданный массив.
vector_x_from_np[1] = 23 # Заменим второй элемент массива.
print(vector_x_from_np) # Покажем изменённый массив.

Посмотрим на вывод:

[0 1 2 3 4]
[0 23 2 3 4].

Попытаемся повторить аналогичные действия с JAX:

vector_x_from_jnp = jnp.arange(5) # зададим массив заполненный числами от 1 до 5.
print(vector_x_from_jnp) # Покажем заданный массив.
vector_x_from_jnp[1] = 23 # Заменим второй элемент массива.
print(vector_x_from_jnp) # Покажем изменённый массив.

Результатом выполнения данного куска кода, будет ошибка:

«TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method»

Однако изменить значения, записанные в массив, возможно следующим образом:

vector_y_from_jnp = vector_x_from_jnp.at[1].set(23) # метод, для замены значения в массиве по индексу.
print(vector_x_from_jnp) # Покажем заданный массив.
print(vector_y_from_jnp) # Покажем измененный массив.

Результат работы кода, становится такой же, как при работе с библиотекой NumPy:

[0 1 2 3 4]
[ 0 23 2 3 4].

В заключение хочется сказать, что представленная информация представляет лишь малую верхушку айсберга. Сегодня мы рассмотрели, что такое JAX, в следующем посте рассмотрим возможности фреймворка в машинном обучении.