JAX Quickstart Guide
A quickstart guide for JAX
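JAX Quick Start
This guide is a quick start for working with JAX on Windows. JAX's GPU (CUDA) builds target Linux, so we do everything inside a Linux environment provided by WSL (Windows Subsystem for Linux).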
Step 1: Prepare Linux environment
1.1 Get WSL for Linux
We follow: https://learn.microsoft.com/en-us/windows/wsl/install.
We will get the default Ubuntu distribution. In PowerShell, run as administrator:
> wsl --install
After the install finishes (a restart may be required), set up the username and password for your Linux distribution as the guide describes.
Then, in the Linux shell, update the package lists and upgrade everything to the latest versions:
> sudo apt update && sudo apt upgrade
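As a quick sanity check that later commands really run inside WSL, here is a minimal Python sketch (Ubuntu normally ships `python3`, so it can run before Step 2). It relies on a heuristic rather than an official API: WSL kernels report a release string containing "microsoft", such as `5.15.153.1-microsoft-standard-WSL2`. The function name `running_in_wsl` is hypothetical.

```python
import platform
from typing import Optional


def running_in_wsl(release: Optional[str] = None) -> bool:
    """Heuristic WSL check: WSL kernels put 'microsoft' in their release string.

    `release` defaults to the current kernel's release; passing a string
    makes the check easy to test.
    """
    if release is None:
        release = platform.uname().release
    return "microsoft" in release.lower()


if __name__ == "__main__":
    print("Running under WSL" if running_in_wsl() else "Not running under WSL")
```

Save it as, for example, `check_wsl.py` and run `python3 check_wsl.py` in the Linux shell.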
1.2 Get VS Code for WSL
We follow: https://learn.microsoft.com/en-us/windows/wsl/tutorials/wsl-vscode.
If VS Code isn't on your machine yet, install it on Windows with the official installer, then add the WSL extension.
Navigate to your project folder and open a Linux shell there. From Windows Explorer, Shift + right-click inside the folder and choose "Open Linux shell here".
With the Linux shell open, run:
> code .
The project now opens in VS Code, connected to the WSL environment.
1.3 Get Git
We follow: https://learn.microsoft.com/en-us/windows/wsl/tutorials/wsl-git.
Check whether Git is installed. In your Linux shell, run:
> git --version
Install Git if you don’t have it:
> sudo apt-get install git
Then set your Git identity:
> git config --global user.name "Your Name"
> git config --global user.email "youremail@domain.com"
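The same checks can be scripted. The Python sketch below, with hypothetical helper names, parses the output of `git --version` and reads back the two global settings; it assumes only that `git` is on the PATH and prints a hint otherwise.

```python
import re
import shutil
import subprocess


def parse_git_version(output: str) -> tuple[int, ...]:
    """Turn output such as 'git version 2.43.0' into (2, 43, 0)."""
    match = re.search(r"git version (\d+(?:\.\d+)*)", output)
    if match is None:
        raise ValueError(f"unrecognised git --version output: {output!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def git_identity() -> dict[str, str]:
    """Read user.name and user.email from the global config ('' when unset)."""
    identity = {}
    for key in ("user.name", "user.email"):
        # `git config` exits non-zero for unset keys, so no check=True here.
        result = subprocess.run(
            ["git", "config", "--global", key], capture_output=True, text=True
        )
        identity[key] = result.stdout.strip()
    return identity


if __name__ == "__main__":
    if shutil.which("git") is None:
        print("Git not found: install it with sudo apt-get install git")
    else:
        out = subprocess.run(
            ["git", "--version"], capture_output=True, text=True, check=True
        ).stdout
        print("Git version:", parse_git_version(out))
        print("Identity:", git_identity())
```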
Step 2: Python development setup and Git
2.1 Virtual environment setup
We set up a virtual environment (venv) to keep the project's packages and their versions isolated from the system Python.
Say we want a project called JAXTest on Python 3.12. First, navigate to the project directory. On Windows it looks like C:\MyProjects\JAXTest, which the Linux shell sees as /mnt/c/MyProjects/JAXTest. We can get this Python version from the deadsnakes PPA (if your Ubuntu release already ships Python 3.12, as 24.04 does, the PPA step can be skipped):
> sudo add-apt-repository ppa:deadsnakes/ppa
> sudo apt update
> sudo apt install python3.12-full
Now we can create and activate a virtual environment:
> python3.12 -m venv .venv
> source .venv/bin/activate
We now see a .venv folder within our project. Git should ignore it, so create a .gitignore file in the project root containing the line .venv/.
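If you prefer to manage .gitignore from a script, a minimal Python sketch follows. `ensure_gitignore_entries` is a hypothetical helper: it appends only the entries that are missing, so rerunning it is harmless, and the same call works for any other folder you later want Git to ignore.

```python
import tempfile
from pathlib import Path


def ensure_gitignore_entries(root: Path, entries: list[str]) -> list[str]:
    """Append missing entries to <root>/.gitignore and return the ones added."""
    gitignore = root / ".gitignore"
    text = gitignore.read_text() if gitignore.exists() else ""
    existing = {line.strip() for line in text.splitlines()}
    # dict.fromkeys drops duplicate entries while keeping their order.
    missing = [e for e in dict.fromkeys(entries) if e not in existing]
    if missing:
        # Start on a fresh line if the file doesn't already end with one.
        prefix = "" if not text or text.endswith("\n") else "\n"
        gitignore.write_text(text + prefix + "\n".join(missing) + "\n")
    return missing


if __name__ == "__main__":
    # Demo in a throwaway folder; in the project you would pass Path(".").
    with tempfile.TemporaryDirectory() as tmp:
        print(ensure_gitignore_entries(Path(tmp), [".venv/"]))  # ['.venv/']
        print(ensure_gitignore_entries(Path(tmp), [".venv/"]))  # []
```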
2.2 Sync with GitHub
We now set up GitHub as the remote repository.
Create a GitHub repository online. Follow: https://docs.github.com/en/repositories/creating-and-managing-repositories/quickstart-for-repositories.
Don’t forget to choose a license for your project.
We want to initialize Git in this project directory. In the Linux shell, run:
> git init
We want this branch to be called main, so rename it:
> git branch -m main
We can now link the local repository with the remote:
> git remote add origin https://github.com/Your_Username/JAXTest.git
Because the GitHub repository already contains a commit with the LICENSE file, pull it first so that the local and remote histories line up:
> git pull origin main
We now make a sample README.md file with a short description of the project. Then we stage .gitignore, README.md, and LICENSE and commit them:
> git add .
> git commit -m "Initial commit."
Finally, push to GitHub. The -u flag sets origin/main as the upstream branch, so later syncs need only git push and git pull:
> git push -u origin main
Step 3: JAX setup
3.1 JAX install
We have two options: the CPU or the GPU version of JAX; choose whichever fits your hardware. We follow: https://jax.readthedocs.io/en/latest/installation.html. Since the virtual environment is active, every package pip installs goes into this venv rather than the system Python.
Let us start by bringing pip up to date. In the Linux shell, run:
> pip install --upgrade pip
For a CPU-only setup, install plain JAX:
> pip install --upgrade jax
We install the GPU version of JAX for our NVIDIA GPU. Under WSL this relies on the NVIDIA driver installed on the Windows side; no separate Linux driver is needed:
> pip install --upgrade "jax[cuda12]"
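To confirm the install landed in the venv, a small Python sketch can compare `sys.prefix` with `sys.base_prefix` (they differ inside a venv) and look up installed versions with `importlib.metadata`. The helper names are hypothetical, and `jax-cuda12-plugin` is assumed to be the distribution that the `cuda12` extra pulls in; a package that isn't installed simply reports as missing.

```python
import sys
from importlib import metadata
from typing import Optional


def in_virtualenv() -> bool:
    """True inside a venv: its sys.prefix differs from the base interpreter's."""
    return sys.prefix != sys.base_prefix


def installed_version(dist_name: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("venv active:", in_virtualenv())
    # Assumption: jax[cuda12] brings in the jax-cuda12-plugin distribution.
    for name in ("jax", "jaxlib", "jax-cuda12-plugin"):
        print(f"{name}: {installed_version(name) or 'not installed'}")
```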
3.2 Reading time
- Read, and keep referring back to, the common gotchas of JAX: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html.
- Also read up on good commit practices: https://gist.github.com/luismts/495d982e8c5b1a0ced4a57cf3d93cf60.
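To give a taste of the gotchas page, here is one of the gotchas it covers, in a short example: JAX arrays are immutable, so NumPy-style in-place assignment raises a `TypeError`, and the indexed-update syntax `x.at[i].set(v)` returns a new array instead.

```python
import jax.numpy as jnp
import numpy as np

x_np = np.zeros(3)
x_np[0] = 1.0  # NumPy arrays are mutable, so this updates x_np in place

x = jnp.zeros(3)
try:
    x[0] = 1.0  # JAX arrays are immutable: item assignment raises TypeError
except TypeError as err:
    print("TypeError:", err)

# Functional update: returns a new array and leaves x untouched.
y = x.at[0].set(1.0)
print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]
```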
Step 4: Set up project folder
4.1 README and .gitignore
- Add .gitignore and README.md to the project root folder, as mentioned before.
4.2 Set up folder structure
- Add src/JAXTest as the source folder for the project.
- Add Test/ as the folder for all future unit tests.
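The two folders can also be created from Python with `pathlib`. This is a minimal sketch: `scaffold` is a hypothetical helper, and the empty `__init__.py`, which makes src/JAXTest an importable package, is our own addition rather than part of the steps above.

```python
import tempfile
from pathlib import Path


def scaffold(root: Path, package: str = "JAXTest", tests_dir: str = "Test") -> list[Path]:
    """Create src/<package> and <tests_dir> under root; existing folders are kept."""
    folders = [root / "src" / package, root / tests_dir]
    for folder in folders:
        folder.mkdir(parents=True, exist_ok=True)
    # Assumption: an empty __init__.py so the source folder is a regular package.
    (folders[0] / "__init__.py").touch(exist_ok=True)
    return folders


if __name__ == "__main__":
    # Demo in a throwaway folder; in the project you would pass Path(".").
    with tempfile.TemporaryDirectory() as tmp:
        for folder in scaffold(Path(tmp)):
            print(folder.relative_to(tmp))
```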
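4.3 Update requirements
It’s good practice to document the packages, and their versions, that your project environment uses. In the Linux shell, with your venv activated:
> pip freeze > requirements.txt
We now have a list of every package installed in the venv, including indirect dependencies. If we only want the ones the project uses directly, we can list just those names in requirements.txt by hand and let pip pin their installed versions. With -r, pip freeze prints the packages from the given file first, then a "## The following requirements were added by pip freeze:" line followed by everything else; sed deletes from that line to the end:
> pip freeze -q -r requirements.txt | sed '/freeze/,$ d' > requirements-froze.txt
We can replace the original requirements.txt with requirements-froze.txt for easier documentation:
> mv requirements-froze.txt requirements.txt
But of course, manually adding your installed packages and their versions to requirements.txt could be cleaner still.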
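The pinning step can also be sketched in Python with `importlib.metadata`, which avoids the sed filter. `pin_requirements` is a hypothetical helper; it assumes one bare package name per line (an optional extra such as `jax[cuda12]` is allowed), skips blank lines and comments, and warns about packages that aren't installed.

```python
from importlib import metadata
from typing import Callable, Iterable


def pin_requirements(
    names: Iterable[str], lookup: Callable[[str], str] = metadata.version
) -> list[str]:
    """Return 'name==version' lines for the listed packages that are installed."""
    pinned = []
    for raw in names:
        name = raw.strip()
        if not name or name.startswith("#"):
            continue  # skip blank lines and comments
        dist = name.split("[", 1)[0]  # 'jax[cuda12]' -> look up 'jax'
        try:
            pinned.append(f"{name}=={lookup(dist)}")
        except metadata.PackageNotFoundError:
            print(f"warning: {dist} is not installed; skipping")
    return pinned


if __name__ == "__main__":
    # In the project you might pass Path("requirements.txt").read_text().splitlines().
    for line in pin_requirements(["jax[cuda12]", "# comment", ""]):
        print(line)
```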
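Step 5: Something in JAX
5.1 A little something
We can write the following:
```python
import jax
import jax.numpy as jnp

result = jnp.arange(3)
print(result)  # [0 1 2]
```
That was a little something in JAX that worked. Now let us check whether the GPU is the default device:
```python
print(jax.default_backend())
print(jax.devices())
```
With the CUDA build and a visible GPU, they should print:
```
gpu
[CudaDevice(id=0)]
```
With the CPU-only build you will see cpu and [CpuDevice(id=0)] instead.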
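When the same script should run on machines with and without a GPU, one option is to pick the device explicitly. Below is a minimal sketch with a hypothetical helper, `preferred_device`; it relies on `jax.devices("gpu")` raising a `RuntimeError` when no GPU backend is available, then places an array on the chosen device with `jax.device_put`.

```python
import jax
import jax.numpy as jnp


def preferred_device():
    """Return the first GPU JAX can see, falling back to the first CPU."""
    try:
        return jax.devices("gpu")[0]
    except RuntimeError:  # no GPU backend available
        return jax.devices("cpu")[0]


device = preferred_device()
x = jax.device_put(jnp.arange(3), device)  # place the array on that device
print(device.platform, x.devices())
```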
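5.2 Configs
We can tweak some of JAX's settings. The memory environment variable must be in place before JAX initializes its GPU backend, so we set it before importing jax:
```python
import os

# By default JAX preallocates 75% of GPU memory on first use;
# "false" makes it allocate memory on demand instead.
# Set this before importing jax so the GPU backend picks it up.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax

# Keep 64-bit types disabled (JAX's default): arrays stay float32/int32,
# which uses half the memory and is typically much faster on consumer GPUs.
jax.config.update("jax_enable_x64", False)

# Persist compiled programs on disk so later runs can skip recompiling.
jax.config.update("jax_compilation_cache_dir", "./.jaxcache")
```
By default only programs that take a while to compile are written to the cache, so the folder may stay empty for small experiments. We should now add .jaxcache/ to .gitignore.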
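To see what the `jax_enable_x64` setting means in practice, here is a short check (a demonstration only, not part of the setup). With 64-bit types disabled, which is JAX's default, JAX converts NumPy's float64 data to float32, so very small increments are lost.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Same setting as above: 64-bit types disabled (JAX's default).
jax.config.update("jax_enable_x64", False)

x_np = np.arange(3.0)       # NumPy defaults to float64
x_jax = jnp.asarray(x_np)   # JAX stores it as float32 instead
print(x_np.dtype, x_jax.dtype)  # float64 float32

# Precision consequence: 1 + 1e-8 rounds back to 1 in float32.
print(float(x_jax[1] + 1e-8) == 1.0)  # True
```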