JAX Quick Start
This is a quick start for setting up a JAX project on Windows, using WSL (Windows Subsystem for Linux).
Step 1: Prepare Linux environment
1.1 Get WSL for Linux
We follow: https://learn.microsoft.com/en-us/windows/wsl/install .
This installs the default Ubuntu distribution. In PowerShell, run as Administrator:
wsl --install
Set up your username and password for your Linux distribution according to the guide.
Then update the package lists and upgrade the installed packages:
sudo apt update && sudo apt upgrade
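Ubuntu ships with python3, so a quick way to confirm that a Python process is really running inside WSL, rather than on native Windows, is to look at the kernel release string. This is a heuristic sketch, not an official API: it assumes WSL kernels keep the word "microsoft" in their release name, and running_in_wsl is a hypothetical helper.

```python
import platform


def running_in_wsl() -> bool:
    """Heuristic check: WSL kernels report 'microsoft' in their release string."""
    # Assumption: WSL 2 releases look like '5.15.x-microsoft-standard-WSL2',
    # and WSL 1 releases also contain 'Microsoft'.
    return "microsoft" in platform.uname().release.lower()


if __name__ == "__main__":
    print("Running inside WSL:", running_in_wsl())
```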
1.2 Get VSCode for WSL
We follow: https://learn.microsoft.com/en-us/windows/wsl/tutorials/wsl-vscode .
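If VS Code is not installed on Windows yet, download and run the installer, then add the WSL extension from within VS Code.
Navigate to your project folder in Windows Explorer and open a Linux shell there: hold Shift, right-click inside the folder, and choose "Open Linux shell here".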
With Linux shell open, run:
code .
VS Code now opens the project folder, connected to the Linux environment in WSL.
1.3 Get Git
We follow: https://learn.microsoft.com/en-us/windows/wsl/tutorials/wsl-git .
Check whether Git is installed. In your Linux shell, run:
git --version
Install Git if you don’t have it:
sudo apt-get install git
Set your Git identity:
git config --global user.name "Your Name"
git config --global user.email "youremail@domain.com"
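To double-check that both settings took effect, you can read them back with git config --global --get. The sketch below wraps that call in Python; read_git_config is a hypothetical helper name, and it simply returns None when a key is unset or Git is not on the PATH.

```python
from __future__ import annotations

import subprocess


def read_git_config(key: str) -> str | None:
    """Return a global Git setting, or None if it is unset or Git is missing."""
    try:
        result = subprocess.run(
            ["git", "config", "--global", "--get", key],
            capture_output=True,
            text=True,
            check=False,  # exit status 1 just means "key not set"
        )
    except FileNotFoundError:
        # The git executable is not installed or not on PATH.
        return None
    value = result.stdout.strip()
    return value or None


for key in ("user.name", "user.email"):
    print(f"{key}: {read_git_config(key)!r}")
```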
Step 2: Python development setup and Git
2.1 Virtual environment setup
We set up a virtual environment (venv) so the project's packages are isolated from the system Python and from other projects.
Say we want a project called JAXTest on Python 3.12. First navigate to the project directory, e.g. C:\MyProjects\JAXTest, which the Linux shell sees as /mnt/c/MyProjects/JAXTest. We can now get this Python version from the deadsnakes PPA with:
sudo add-apt-repository ppa:deadsnakes/ppa
sudo apt update
sudo apt install python3.12-full
Now we can create and activate a virtual environment with:
python3.12 -m venv .venv
source .venv/bin/activate
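After activation the shell prompt is usually prefixed with (.venv). To confirm from Python itself that the venv interpreter is the one running, compare sys.prefix with sys.base_prefix, which are standard-library attributes; in_virtualenv is just an illustrative name.

```python
import sys


def in_virtualenv() -> bool:
    """True when the interpreter runs from a venv created by `python -m venv`."""
    # Inside a venv, sys.prefix points at the venv directory while
    # sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


print("Interpreter:", sys.executable)
print("Inside a venv:", in_virtualenv())
```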
We now see a .venv folder within our project. Git should ignore it, so create a .gitignore file in the project root folder containing the line:
.venv/
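If you prefer to script the .gitignore step, here is a small idempotent sketch using pathlib. ensure_gitignore_entries is a hypothetical helper, not part of any tool: it creates the file if needed and only appends entries that are not already there.

```python
from __future__ import annotations

from pathlib import Path


def ensure_gitignore_entries(root: Path, entries: list[str]) -> list[str]:
    """Append any missing entries to root/.gitignore and return the ones added."""
    gitignore = root / ".gitignore"
    text = gitignore.read_text() if gitignore.exists() else ""
    existing = {line.strip() for line in text.splitlines()}
    missing = [e for e in entries if e not in existing]
    if missing:
        # Start on a fresh line if the file does not already end with a newline.
        prefix = "" if text == "" or text.endswith("\n") else "\n"
        with gitignore.open("a") as f:
            f.write(prefix + "\n".join(missing) + "\n")
    return missing
```

Calling ensure_gitignore_entries(Path("."), [".venv/"]) from the project root is safe to repeat; a second call adds nothing.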
2.2 Sync with GitHub
Next, we set up GitHub as the remote repository.
Create a GitHub repository online. Follow: https://docs.github.com/en/repositories/creating-and-managing-repositories/quickstart-for-repositories .
Don’t forget to choose a license for your project.
Initialize a Git repository in the project directory. In the Linux shell, run:
git init
Rename the default branch to main:
git branch -m main
Now link the local repository with the remote (replace Your_Username with your GitHub username):
git remote add origin https://github.com/Your_Username/JAXTest.git
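We now make a sample README.md file with a short description of the project.
If you chose the license while creating the repository on GitHub, the LICENSE file so far exists only on the remote. Pull it in before committing:
git pull origin main
Now stage the new files (.gitignore, README.md, and LICENSE if you created it locally) and commit:
git add .
git commit -m "Initial commit."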
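Push to GitHub. The -u flag sets origin/main as the upstream branch, so afterwards a plain git push or git pull is enough to sync changes:
git push -u origin main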
Step 3: JAX related setup
3.1 JAX install
JAX comes in two builds: CPU-only and GPU. Pick the one that matches your hardware. We follow: https://jax.readthedocs.io/en/latest/installation.html . Because the virtual environment is active, all packages are installed into this venv only.
Let us start with getting pip up to date. In Linux shell, run:
pip install --upgrade pip
For the CPU-only build, pip install --upgrade jax is enough. We install the GPU (CUDA 12) build of JAX for our NVIDIA GPU:
pip install --upgrade "jax[cuda12]"
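To confirm the install, JAX provides a small diagnostic function, jax.print_environment_info(). A one-line check, run inside the venv:

```python
import jax

# Prints a short report of installed versions (jax, jaxlib, numpy, python),
# platform details and, when nvidia-smi is available, its output.
jax.print_environment_info()
```

If the CUDA build is not picked up, JAX typically warns and falls back to the CPU, so the nvidia-smi section of this report is a good first place to look.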
3.2 Reading time
- Read, and keep referring back to, the common gotchas of JAX: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html .
- Also read up on good commit practices: https://gist.github.com/luismts/495d982e8c5b1a0ced4a57cf3d93cf60 .
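As a taste of the gotchas page: JAX arrays are immutable, so NumPy-style in-place assignment is replaced by the .at[...] update syntax, which returns a new array.

```python
import jax.numpy as jnp

x = jnp.zeros(3)
# x[0] = 1.0 would raise a TypeError, because JAX arrays are immutable.
# Functional update: .at[...].set(...) returns a new array and leaves x unchanged.
y = x.at[0].set(1.0)
print(x)  # [0. 0. 0.]
print(y)  # [1. 0. 0.]
```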
Step 4: Set up project folder
4.1 README and ignore
- Add .gitignore and README.md to the project root folder, as mentioned before.
4.2 Set up folder structure
- Add src/JAXTest as the source folder for the project.
- Add Test as the folder for all future unit tests.
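The same layout can be created with pathlib. make_layout is a hypothetical helper, and the empty __init__.py is our own assumption (it makes src/JAXTest an importable package); the steps above only ask for the folders.

```python
from pathlib import Path


def make_layout(root: Path, package: str = "JAXTest") -> None:
    """Create src/<package> and Test under root (hypothetical helper)."""
    src_pkg = root / "src" / package
    tests = root / "Test"
    for folder in (src_pkg, tests):
        # parents=True creates src/ too; exist_ok=True makes reruns harmless.
        folder.mkdir(parents=True, exist_ok=True)
    # Assumption: an empty __init__.py so src/<package> is a regular package.
    (src_pkg / "__init__.py").touch(exist_ok=True)
```

Note that Git does not track empty folders, so Test only shows up in a commit once it contains a file.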
4.3 Update requirements
It’s good practice to record the packages, and their versions, used in the project environment. In the Linux shell, with your venv activated:
pip freeze > requirements.txt
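This lists every package installed in the venv, including indirect dependencies. If you then trim requirements.txt down to the packages you use directly (for example jax), the following command pins those to their installed versions in your order. pip freeze -r appends every other installed package below a line starting with "## The following requirements were added by pip freeze:", and the sed expression deletes from the first line containing "freeze" to the end:
pip freeze -q -r requirements.txt | sed '/freeze/,$ d' > requirements-froze.txt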
We can replace the original requirements.txt with requirements-froze.txt so only one requirements file remains:
mv requirements-froze.txt requirements.txt
Of course, manually listing the packages you install in requirements.txt can be cleaner.
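The sed step can be mirrored in Python, which also makes its one quirk visible: it cuts at the first line containing "freeze" anywhere, so a package whose name contains that word (freezegun, say) would be truncated too. drop_added_by_freeze is a hypothetical helper, and the version numbers in the sample are made up.

```python
from __future__ import annotations


def drop_added_by_freeze(freeze_output: str) -> str:
    """Mimic sed '/freeze/,$ d': keep only lines before the first one containing 'freeze'."""
    kept: list[str] = []
    for line in freeze_output.splitlines():
        if "freeze" in line:
            break  # this line and everything after it is discarded
        kept.append(line)
    return "\n".join(kept) + ("\n" if kept else "")


# Made-up sample in the shape of `pip freeze -r requirements.txt` output.
sample = """jax==1.0.0
numpy==1.0.0
## The following requirements were added by pip freeze:
jaxlib==1.0.0
"""
print(drop_added_by_freeze(sample), end="")
```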
Step 5: Something in JAX
5.1 A little something
We can write the following:
import jax
import jax.numpy as jnp

result = jnp.arange(3)
print(result)
That was a little something in JAX that worked. Now let us check whether the GPU is the default device:
print(jax.default_backend())
print(jax.devices())
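On a machine with one NVIDIA GPU, these should print something like the following (the exact device repr varies between JAX versions):
gpu
[CudaDevice(id=0)]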
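Going one step further, the usual JAX workflow is to write plain jax.numpy functions and compile them with jax.jit. A minimal sketch; sum_of_squares is our own example function, not a library API:

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_of_squares(x: jax.Array) -> jax.Array:
    # Traced and compiled by XLA on the first call for a given shape/dtype;
    # later calls with the same shape/dtype reuse the compiled program.
    return jnp.sum(x ** 2)


x = jnp.arange(3)  # [0 1 2], placed on the default device
print(sum_of_squares(x))  # 0 + 1 + 4 = 5
```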
5.2 Configs
We can tweak some JAX settings:
import os

# Allocate GPU memory on demand instead of preallocating a large share up front.
# JAX reads this when it initializes the GPU backend, so set it before importing jax.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax

# Keep JAX's default 32-bit mode (float64 support disabled).
jax.config.update("jax_enable_x64", False)

# Cache compiled programs on disk so later runs can skip recompiling them.
jax.config.update("jax_compilation_cache_dir", "./.jaxcache")
We should now add .jaxcache/ to .gitignore so the compilation cache is not committed.
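To see what jax_enable_x64 = False means in practice, check the dtypes JAX hands back. A quick sketch: with x64 disabled, a float64 request is downgraded to float32, and JAX emits a UserWarning about the truncation.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", False)  # the default, made explicit

a = jnp.array(1.0)                     # default float dtype
b = jnp.array(1.0, dtype=jnp.float64)  # requested float64, truncated to float32
print(a.dtype, b.dtype)  # float32 float32
```

If a project does need double precision, enable the flag once at startup, before creating any arrays.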