Skip to content

Commit 517d9b6

Browse files
authored
new: add the first structure of the project (#1)
1 parent 0e04fd6 commit 517d9b6

27 files changed

+1419
-0
lines changed

.devcontainer/Dockerfile

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04
2+
3+
# Avoid prompts from apt
4+
ENV DEBIAN_FRONTEND=noninteractive
5+
6+
# Install Python and other dependencies
7+
RUN apt-get update && apt-get install -y \
8+
python3.10 \
9+
python3.10-dev \
10+
python3-pip \
11+
python3-venv \
12+
git \
13+
curl \
14+
wget \
15+
build-essential \
16+
&& rm -rf /var/lib/apt/lists/*
17+
18+
# Create symbolic links for python
19+
RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
20+
ln -sf /usr/bin/python3.10 /usr/bin/python3
21+
22+
# Create a non-root user
23+
ARG USERNAME=vscode
24+
ARG USER_UID=1000
25+
ARG USER_GID=$USER_UID
26+
27+
RUN groupadd --gid $USER_GID $USERNAME \
28+
&& useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \
29+
&& apt-get update \
30+
&& apt-get install -y sudo \
31+
&& echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \
32+
&& chmod 0440 /etc/sudoers.d/$USERNAME
33+
34+
# Set up Python environment
35+
RUN python -m pip install --upgrade pip setuptools wheel
36+
37+
# Install JAX with CUDA support
38+
RUN pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
39+
40+
# Install matplotlib for benchmarking script
41+
RUN pip install matplotlib
42+
43+
# Set working directory
44+
WORKDIR /workspace
45+
46+
# Switch to non-root user
47+
USER $USERNAME
48+
49+
# Note: The project dependencies will be installed by the postCreateCommand in devcontainer.json
50+
# which runs `pip install -e '.[dev]'` after the container is created

.devcontainer/README.md

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Development Container for JAX Layers
2+
3+
This directory contains configuration files for setting up a development container with JAX and CUDA support, which is especially useful for Windows users where JAX doesn't natively support CUDA.
4+
5+
## Prerequisites
6+
7+
To use this development container, you need:
8+
9+
1. [Docker Desktop](https://www.docker.com/products/docker-desktop/) installed and configured with WSL 2 backend
10+
2. [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html) installed
11+
3. [Visual Studio Code](https://code.visualstudio.com/) with the [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension
12+
13+
## GPU Support
14+
15+
The container is configured to use all available GPUs. Make sure your NVIDIA drivers are up-to-date and that Docker has access to your GPUs.
16+
17+
## Usage
18+
19+
1. Open the project in Visual Studio Code
20+
2. Click on the green icon in the bottom-left corner of VS Code
21+
3. Select "Reopen in Container" from the menu
22+
4. Wait for the container to build and start (this may take a while the first time)
23+
24+
Once the container is running, you'll have a fully configured development environment with:
25+
26+
- Python 3.10
27+
- CUDA 12.2 with cuDNN 9
28+
- JAX with CUDA support
29+
- All dependencies from pyproject.toml
30+
31+
## Dependency Management
32+
33+
The container installs dependencies directly from your project's `pyproject.toml` file using the `pip install -e '.[dev]'` command, ensuring consistency between your development environment and the container.
34+
35+
## Customization
36+
37+
You can customize the container by modifying:
38+
39+
- `devcontainer.json`: VS Code settings, extensions, and container configuration
40+
- `Dockerfile`: Base image, dependencies, and environment setup
41+
42+
## Troubleshooting
43+
44+
If you encounter issues with GPU access:
45+
46+
1. Verify that Docker Desktop is configured to use WSL 2
47+
2. Check that NVIDIA Container Toolkit is properly installed
48+
3. Ensure your NVIDIA drivers are up-to-date
49+
4. Run `nvidia-smi` in WSL to verify GPU access
50+
5. Check Docker logs for any error messages related to GPU access

.devcontainer/devcontainer.json

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
{
2+
"name": "JAX Layers Development",
3+
"build": {
4+
"dockerfile": "Dockerfile",
5+
"context": ".."
6+
},
7+
"runArgs": [
8+
"--gpus=all"
9+
],
10+
"customizations": {
11+
"vscode": {
12+
"extensions": [
13+
"ms-python.python",
14+
"ms-python.vscode-pylance",
15+
"charliermarsh.ruff",
16+
"matangover.mypy"
17+
],
18+
"settings": {
19+
"python.defaultInterpreterPath": "/usr/local/bin/python",
20+
"python.linting.enabled": true,
21+
"editor.formatOnSave": true,
22+
"editor.codeActionsOnSave": {
23+
"source.organizeImports": "true",
24+
"source.fixAll": "true"
25+
},
26+
"python.formatting.provider": "none",
27+
"[python]": {
28+
"editor.defaultFormatter": "charliermarsh.ruff"
29+
}
30+
}
31+
}
32+
},
33+
"forwardPorts": [],
34+
"postCreateCommand": "pip install -e '.[dev]'",
35+
"remoteUser": "vscode"
36+
}

.devcontainer/verify_jax_cuda.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Script to verify that JAX with CUDA is working correctly.
4+
Run this script after the container is built to confirm GPU access.
5+
"""
6+
7+
import re
8+
import subprocess
9+
import time
10+
11+
import jax
12+
import jax.numpy as jnp
13+
14+
15+
def get_cuda_version():
16+
"""Get the CUDA version from nvcc."""
17+
try:
18+
result = subprocess.run(["nvcc", "--version"], capture_output=True, text=True)
19+
version_match = re.search(r"release (\d+\.\d+)", result.stdout)
20+
if version_match:
21+
return version_match.group(1)
22+
return "Unknown"
23+
except Exception:
24+
return "Unknown"
25+
26+
27+
def main():
28+
print("\n" + "=" * 50)
29+
print("JAX CUDA Verification Script")
30+
print("=" * 50)
31+
32+
# Check JAX version
33+
print(f"JAX version: {jax.__version__}")
34+
35+
# Check CUDA version
36+
cuda_version = get_cuda_version()
37+
print(f"CUDA version: {cuda_version}")
38+
39+
# Check available devices
40+
print("\nAvailable devices:")
41+
for i, device in enumerate(jax.devices()):
42+
print(f" Device {i}: {device}")
43+
44+
# Check if GPU is available
45+
gpu_available = any(d.platform == "gpu" for d in jax.devices())
46+
print(f"\nGPU available: {gpu_available}")
47+
48+
if not gpu_available:
49+
print("\n⚠️ No GPU devices found! JAX is not using CUDA.")
50+
print("Please check your installation and GPU configuration.")
51+
return
52+
53+
# Run a simple benchmark
54+
print("\nRunning simple matrix multiplication benchmark...")
55+
56+
# Create large matrices
57+
n = 5000
58+
print(f"Creating {n}x{n} matrices...")
59+
60+
# CPU benchmark
61+
with jax.devices("cpu")[0]:
62+
x_cpu = jnp.ones((n, n))
63+
y_cpu = jnp.ones((n, n))
64+
65+
# Warm-up
66+
_ = jnp.dot(x_cpu, y_cpu)
67+
jax.block_until_ready(_)
68+
69+
# Benchmark
70+
start = time.time()
71+
result_cpu = jnp.dot(x_cpu, y_cpu)
72+
jax.block_until_ready(result_cpu)
73+
cpu_time = time.time() - start
74+
75+
# GPU benchmark
76+
with jax.devices("gpu")[0]:
77+
x_gpu = jnp.ones((n, n))
78+
y_gpu = jnp.ones((n, n))
79+
80+
# Warm-up
81+
_ = jnp.dot(x_gpu, y_gpu)
82+
jax.block_until_ready(_)
83+
84+
# Benchmark
85+
start = time.time()
86+
result_gpu = jnp.dot(x_gpu, y_gpu)
87+
jax.block_until_ready(result_gpu)
88+
gpu_time = time.time() - start
89+
90+
# Print results
91+
print(f"\nCPU time: {cpu_time:.4f} seconds")
92+
print(f"GPU time: {gpu_time:.4f} seconds")
93+
print(f"Speedup: {cpu_time / gpu_time:.2f}x")
94+
95+
if cpu_time > gpu_time:
96+
print("\n✅ GPU is faster than CPU! JAX with CUDA is working correctly.")
97+
else:
98+
print("\n⚠️ GPU is not faster than CPU. Something might be wrong with the CUDA setup.")
99+
100+
print("\n" + "=" * 50)
101+
102+
103+
if __name__ == "__main__":
104+
main()

.github/workflows/docs.yml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
name: Documentation
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
docs:
11+
runs-on: ubuntu-latest
12+
13+
steps:
14+
- uses: actions/checkout@v4
15+
16+
- name: Set up Python
17+
uses: actions/setup-python@v4
18+
with:
19+
python-version: "3.10"
20+
21+
- name: Install dependencies
22+
run: |
23+
python -m pip install --upgrade pip
24+
pip install -e ".[dev]"
25+
26+
- name: Build documentation
27+
run: |
28+
cd docs
29+
make html
30+
31+
- name: Deploy to GitHub Pages
32+
if: github.event_name == 'push' && github.ref == 'refs/heads/main'
33+
uses: peaceiris/actions-gh-pages@v3
34+
with:
35+
github_token: ${{ secrets.GITHUB_TOKEN }}
36+
publish_dir: ./docs/_build/html

.github/workflows/mypy.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
name: Type Checking
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
mypy:
11+
runs-on: ubuntu-latest
12+
13+
steps:
14+
- uses: actions/checkout@v4
15+
16+
- name: Set up Python
17+
uses: actions/setup-python@v4
18+
with:
19+
python-version: "3.10"
20+
21+
- name: Install dependencies
22+
run: |
23+
python -m pip install --upgrade pip
24+
pip install -e ".[dev]"
25+
pip install mypy
26+
27+
- name: Run mypy
28+
run: |
29+
mypy jax_layers tests

.github/workflows/ruff.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
name: Ruff
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
ruff:
11+
runs-on: ubuntu-latest
12+
13+
steps:
14+
- uses: actions/checkout@v4
15+
16+
- name: Set up Python
17+
uses: actions/setup-python@v4
18+
with:
19+
python-version: "3.10"
20+
21+
- name: Install dependencies
22+
run: |
23+
python -m pip install --upgrade pip
24+
pip install ruff
25+
26+
- name: Run Ruff
27+
run: |
28+
ruff check .
29+
ruff format --check .

.github/workflows/tests.yml

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
branches: [ main ]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
strategy:
13+
matrix:
14+
python-version: ["3.10", "3.11", "3.12"]
15+
16+
steps:
17+
- uses: actions/checkout@v4
18+
19+
- name: Set up Python ${{ matrix.python-version }}
20+
uses: actions/setup-python@v4
21+
with:
22+
python-version: ${{ matrix.python-version }}
23+
24+
- name: Install dependencies
25+
run: |
26+
python -m pip install --upgrade pip
27+
pip install -e ".[dev]"
28+
29+
- name: Run tests
30+
run: |
31+
python tests/run_tests.py

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.12

docs/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
_build/

0 commit comments

Comments
 (0)