Author: Harald Schilly

JAX on CoCalc

Kernel: Python 3 (Ubuntu Linux)

JAX is Autograd and XLA, brought together for high-performance machine learning research

https://github.com/google/jax

grad computes gradients, and jit performs just-in-time compilation.

JAX also provides a NumPy-compatible interface via jax.numpy.
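
A minimal sketch of that interface (an addition, not one of the executed cells below): jax.numpy mirrors the NumPy API, so typical array code usually only needs the import swapped.

import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)   # same signatures as numpy.arange / reshape
b = jnp.sin(a) + jnp.ones_like(a)   # element-wise ops mirror NumPy
print(b.sum())                      # reductions behave the same way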

In [1]:
from jax import grad, jit
import jax.numpy as np
In [2]:
def tanh(x):  # Define a function
    y = np.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)
In [3]:
grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))
0.4199743
/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:164: UserWarning: No GPU found, falling back to CPU.
  warnings.warn('No GPU found, falling back to CPU.')
In [4]:
%timeit grad_tanh(0.1)
4.58 ms ± 551 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
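
A hedged aside (not executed in this notebook): grad and jit compose, so the gradient function itself can be compiled. The first call traces and compiles; subsequent calls reuse the compiled code and are typically much faster.

fast_grad_tanh = jit(grad(tanh))   # compile the gradient function
fast_grad_tanh(0.1)                # first call compiles, later calls are fast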
In [5]:
def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    a = x * x
    b = 2.0 + x
    return a * b
In [6]:
fast_f = jit(slow_f)
In [7]:
x = np.ones((5000, 5000))
In [8]:
%timeit -n10 -r3 fast_f(x)
146 ms ± 2.95 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
In [9]:
%timeit -n10 -r3 slow_f(x)
466 ms ± 5.45 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
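
One caveat on these timings (an added note, not part of the original run): JAX dispatches work asynchronously, so for a strict measurement the timed expression should block on the result, for example:

%timeit -n10 -r3 fast_f(x).block_until_ready()  # wait for the computation to finish before stopping the timer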
In [10]:
del x

Comparing with Numba

In [11]:
import numba
numba_f = numba.jit(slow_f)
In [12]:
import numpy
y = numpy.ones((5000, 5000))
In [13]:
%timeit -n10 -r3 numba_f(y)
612 ms ± 22.5 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
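
As a sketch for a fairer comparison (not measured here), one could also compile with Numba's nopython mode; the timings above are the only ones actually recorded.

numba_f2 = numba.njit(slow_f)      # nopython mode; assumes slow_f's simple element-wise array expressions are supported
numba_f2(y)                        # first call triggers compilation; time subsequent calls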