CoCalc
Sharedjax.ipynbOpen in CoCalc
Author: Harald Schilly

# JAX on CoCalc

Kernel: Python 3 (Ubuntu Linux)

JAX is Autograd and XLA, brought together for high-performance machine learning research

JAX also has a NumPy-compatible interface.

from jax import grad, jit
import jax.numpy as np

def tanh(x):  # Define a function
    """Return the hyperbolic tangent of ``x``.

    The value is computed as ``(1 - y) / (1 + y)`` with ``y = exp(-2 * x)``,
    using the ``exp`` of the ``np`` namespace imported at the top of the
    notebook (``jax.numpy``), so the function can be handed to ``grad``.

    Args:
        x: A scalar or array accepted by ``np.exp``.

    Returns:
        ``(1 - exp(-2x)) / (1 + exp(-2x))``, evaluated element-wise.
    """
    y = np.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function

0.4199743
/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:164: UserWarning: No GPU found, falling back to CPU. warnings.warn('No GPU found, falling back to CPU.')
%timeit grad_tanh(0.1)

4.58 ms ± 551 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
def slow_f(x):
    """Return ``(x * x) * (2.0 + x)``.

    Only ``*`` and ``+`` are applied to ``x``, so the function works for
    plain numbers as well as for array types that implement those operators
    element-wise.

    Args:
        x: A number or array-like value supporting ``*`` and ``+``.

    Returns:
        The product of ``x * x`` and ``2.0 + x``.
    """
    # Element-wise ops see a large benefit from fusion
    a = x * x
    b = 2.0 + x
    return a * b

# Wrap slow_f with jit (imported from jax above); fast_f is called with the
# same argument as slow_f.
fast_f = jit(slow_f)

# 5000 x 5000 array of ones, created through the `np` alias (jax.numpy here),
# used as the input for the timings of fast_f and slow_f.
# NOTE(review): JAX dispatches work asynchronously; timing a call without
# `.block_until_ready()` on the result may under-report the real compute
# time -- confirm before relying on the numbers.
x = np.ones((5000, 5000))

%timeit -n10 -r3 fast_f(x)

146 ms ± 2.95 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
%timeit -n10 -r3 slow_f(x)

466 ms ± 5.45 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
del x


Comparing with Numba

import numba
# Wrap the same slow_f with Numba's jit for comparison.
# NOTE(review): Numba presumably compiles on the first call, so a timing that
# includes that call would also include compilation time -- confirm.
numba_f = numba.jit(slow_f)

# Plain NumPy, imported under its full name because `np` is bound to
# jax.numpy in this notebook; y is the 5000 x 5000 input for numba_f.
import numpy
y = numpy.ones((5000, 5000))

%timeit -n10 -r3 numba_f(y)

612 ms ± 22.5 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)