Google JAX is a Python machine learning framework for transforming numerical functions.[2][3][4] It is described as combining a modified version of autograd[5] (which automatically obtains the gradient function of a function through differentiation) with TensorFlow's XLA (Accelerated Linear Algebra) compiler. It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[6][7] The primary functions of JAX are:

grad: automatic differentiation

jit: just-in-time compilation

vmap: auto-vectorization

pmap: SPMD (single program, multiple data) programming
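The four transformations can be illustrated with a short sketch. It assumes a recent JAX release is installed (for example via `pip install jax`) and that the code runs on a single CPU device. The function `loss` and its inputs are hypothetical and serve only as an illustration. Each transformation takes an ordinary Python function written with the NumPy-like `jax.numpy` API and returns a new function.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function written with the NumPy-like jax.numpy API.
def loss(w, x):
    return jnp.sum((x * w) ** 2)

# grad: returns a new function computing d(loss)/dw (the first argument by default).
grad_loss = jax.grad(loss)

# jit: compiles the gradient function with XLA the first time it is called.
fast_grad_loss = jax.jit(grad_loss)

# vmap: maps loss over the leading axis of x without a Python loop;
# in_axes=(None, 0) means w is shared and x is batched along axis 0.
batched_loss = jax.vmap(loss, in_axes=(None, 0))

w = jnp.array(3.0)
x = jnp.array([1.0, 2.0, 3.0])
xs = jnp.stack([x, 2 * x])  # a batch of two input vectors

print(fast_grad_loss(w, x))   # d/dw of w^2 * sum(x^2) = 2 * 3 * 14 = 84.0
print(batched_loss(w, xs))    # [126. 504.]

# pmap: runs the same function on every available device (SPMD);
# the leading axis must equal the number of local devices.
n = jax.local_device_count()
per_device = jax.pmap(lambda v: v * 2.0)(jnp.arange(n, dtype=jnp.float32))
print(per_device)
```

On a machine with a single device, `pmap` runs one copy of the function, so it only demonstrates the calling convention. Real data parallelism requires several GPUs or TPUs. Newer JAX releases also offer other sharding APIs alongside `pmap`.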