Commit

releasing version 0.0.67, updating mha, pallas flash attention and some other kernels, debugging lora
erfanzar committed Jun 13, 2024
1 parent 84c91ee commit 784e874
Showing 90 changed files with 1,928 additions and 1,369 deletions.
2 changes: 1 addition & 1 deletion .idea/FXUtils.iml


2 changes: 1 addition & 1 deletion .idea/misc.xml


13 changes: 4 additions & 9 deletions README.md
@@ -1,7 +1,3 @@
<p align="center">
<img src="logo/light-logo.png" alt="Alt text"/>
</p>

# FJFormer

Embark on a journey of paralleled/unparalleled computational prowess with FJFormer - an arsenal of custom Jax Flax
@@ -15,16 +11,15 @@ checkpoint savers, partitioning tools, and other helpful functions.
The goal of FJFormer is to make your life easier when working with Flax and JAX. Whether you are training a new model,
fine-tuning an existing one, or just exploring the capabilities of these powerful frameworks, FJFormer offers

- FlashAttention on `TPU/GPU` 🧬
- BITComputations for 8,6,4 BIT Flax Models 🤏
- Smart Dataset Loading
- Pallas Kernels for GPU,TPU
- BITComputations for 8,6,4 BIT Flax Models
- Built-in functions and Loss functions
- GPU-Pallas triton like implementation of `Softmax`, `FlashAttention`, `RMSNorm`, `LayerNorm`
- Distributed and sharding Model Loaders and Checkpoint Savers
- Monitoring Utils for *TPU/GPU/CPU* memory `foot-print`
- Optimizers
- Special Optimizers with schedulers and Easy to Use
- Partitioning Utils
- LoRA with `XRapture` 🤠
- LoRA

and a lot of these features are fully documented, so I guess FJFormer has something to offer; it's not just a computation backend for [EasyDel](https://github.com/erfanzar/EasyDel).
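
The distributed loading, checkpointing, and partitioning utilities called out in this list sit on top of JAX's native sharding machinery. For orientation, here is a minimal plain-JAX sketch of that workflow; it uses only public `jax.sharding` APIs, since FJFormer's own helper names are not shown in this diff and are not assumed here.

```python
# Minimal plain-JAX sketch of the sharding/partitioning workflow that
# loaders and partitioning utils build on. Only public jax.sharding APIs
# are used; no FJFormer helper names are assumed.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D device mesh over whatever accelerators (or CPU devices) are present.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("dp",))

# Shard a parameter matrix along its first axis across the "dp" mesh axis.
params = jnp.ones((8 * len(devices), 128))
params = jax.device_put(params, NamedSharding(mesh, P("dp", None)))

# jit-compiled code consumes the sharded array directly; XLA inserts any
# collectives that the layout requires.
@jax.jit
def scale(x):
    return x * 2.0

print(scale(params).sharding)
```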
2 changes: 2 additions & 0 deletions docs/generated-bit_quantization-calibration.md
@@ -0,0 +1,2 @@
# bit_quantization.calibration
::: src.fjformer.bit_quantization.calibration
2 changes: 2 additions & 0 deletions docs/generated-bit_quantization-config.md
@@ -0,0 +1,2 @@
# bit_quantization.config
::: src.fjformer.bit_quantization.config
2 changes: 2 additions & 0 deletions docs/generated-bit_quantization-int_numerics.md
@@ -0,0 +1,2 @@
# bit_quantization.int_numerics
::: src.fjformer.bit_quantization.int_numerics
2 changes: 2 additions & 0 deletions docs/generated-bit_quantization-no_numerics.md
@@ -0,0 +1,2 @@
# bit_quantization.no_numerics
::: src.fjformer.bit_quantization.no_numerics
2 changes: 2 additions & 0 deletions docs/generated-bit_quantization-numerics.md
@@ -0,0 +1,2 @@
# bit_quantization.numerics
::: src.fjformer.bit_quantization.numerics
2 changes: 2 additions & 0 deletions docs/generated-bit_quantization-q_dot_general.md
@@ -0,0 +1,2 @@
# bit_quantization.q_dot_general
::: src.fjformer.bit_quantization.q_dot_general
2 changes: 2 additions & 0 deletions docs/generated-bit_quantization-q_flax.md
@@ -0,0 +1,2 @@
# bit_quantization.q_flax
::: src.fjformer.bit_quantization.q_flax
2 changes: 2 additions & 0 deletions docs/generated-bit_quantization-stochastic_rounding.md
@@ -0,0 +1,2 @@
# bit_quantization.stochastic_rounding
::: src.fjformer.bit_quantization.stochastic_rounding
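
The `bit_quantization` modules listed above (calibration, integer numerics, stochastic rounding, `q_dot_general`, and a Flax wrapper) document a quantized dot-general path. As a purely illustrative sketch of the underlying idea, and not the API of these modules, a symmetric int8 matmul in plain `jax.numpy` looks roughly like this:

```python
# Illustrative-only sketch of a symmetric int8 quantized matmul, the kind
# of operation a q_dot_general-style module calibrates and fuses. Plain
# jax.numpy; it does not reproduce fjformer.bit_quantization's actual API.
import jax.numpy as jnp

def quantize_int8(x):
    # Per-tensor symmetric scale: map max |x| onto the int8 range.
    scale = jnp.max(jnp.abs(x)) / 127.0
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale

def int8_matmul(a, b):
    qa, sa = quantize_int8(a)
    qb, sb = quantize_int8(b)
    # Accumulate in int32, then rescale back to float.
    acc = jnp.matmul(qa.astype(jnp.int32), qb.astype(jnp.int32))
    return acc.astype(jnp.float32) * (sa * sb)

a = jnp.linspace(-1.0, 1.0, 8 * 16).reshape(8, 16)
b = jnp.linspace(-0.5, 0.5, 16 * 4).reshape(16, 4)
print(jnp.max(jnp.abs(int8_matmul(a, b) - a @ b)))  # small quantization error
```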
2 changes: 2 additions & 0 deletions docs/generated-pallas_operations-gpu-flash_attention-mha.md
@@ -0,0 +1,2 @@
# pallas_operations.gpu.flash_attention.mha
::: src.fjformer.pallas_operations.gpu.flash_attention.mha
2 changes: 2 additions & 0 deletions docs/generated-pallas_operations-gpu-layer_norm-layer_norm.md
@@ -0,0 +1,2 @@
# pallas_operations.gpu.layer_norm.layer_norm
::: src.fjformer.pallas_operations.gpu.layer_norm.layer_norm
2 changes: 2 additions & 0 deletions docs/generated-pallas_operations-gpu-rms_norm-rms_norm.md
@@ -0,0 +1,2 @@
# pallas_operations.gpu.rms_norm.rms_norm
::: src.fjformer.pallas_operations.gpu.rms_norm.rms_norm
2 changes: 2 additions & 0 deletions docs/generated-pallas_operations-gpu-softmax-softmax.md
@@ -0,0 +1,2 @@
# pallas_operations.gpu.softmax.softmax
::: src.fjformer.pallas_operations.gpu.softmax.softmax
@@ -0,0 +1,2 @@
# pallas_operations.pallas_attention.attention
::: src.fjformer.pallas_operations.pallas_attention.attention
@@ -0,0 +1,2 @@
# pallas_operations.tpu.flash_attention.flash_attention
::: src.fjformer.pallas_operations.tpu.flash_attention.flash_attention
@@ -0,0 +1,2 @@
# pallas_operations.tpu.paged_attention.paged_attention
::: src.fjformer.pallas_operations.tpu.paged_attention.paged_attention
@@ -0,0 +1,2 @@
# pallas_operations.tpu.ring_attention.ring_attention
::: src.fjformer.pallas_operations.tpu.ring_attention.ring_attention
@@ -0,0 +1,2 @@
# pallas_operations.tpu.splash_attention.splash_attention_kernel
::: src.fjformer.pallas_operations.tpu.splash_attention.splash_attention_kernel
@@ -0,0 +1,2 @@
# pallas_operations.tpu.splash_attention.splash_attention_mask
::: src.fjformer.pallas_operations.tpu.splash_attention.splash_attention_mask
@@ -0,0 +1,2 @@
# pallas_operations.tpu.splash_attention.splash_attention_mask_info
::: src.fjformer.pallas_operations.tpu.splash_attention.splash_attention_mask_info
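
The attention kernels documented above (MHA, flash, paged, ring, and splash attention) are fused implementations of scaled dot-product attention. For reference only, here is the naive non-Pallas computation they accelerate, written in plain JAX; it materializes the full attention matrix, which is exactly what the tiled kernels avoid.

```python
# Naive reference for scaled dot-product attention, for orientation only.
# The Pallas kernels above compute the same function with fused, tiled
# memory access; this version builds the full [q_len, kv_len] matrix.
import jax
import jax.numpy as jnp

def reference_attention(q, k, v, mask=None):
    # q: [q_len, d], k/v: [kv_len, d]
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1]).astype(q.dtype)
    if mask is not None:
        scores = jnp.where(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)
    return weights @ v

key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(k_, (128, 64)) for k_ in jax.random.split(key, 3))
print(reference_attention(q, k, v).shape)  # (128, 64)
```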
65 changes: 31 additions & 34 deletions mkdocs.yml
@@ -1,15 +1,13 @@
nav:
- Bits:
- Bits: generated-bits-bits.md
- Calibration: generated-bits-calibration.md
- Config: generated-bits-config.md
- Int Numerics: generated-bits-int_numerics.md
- No Numerics: generated-bits-no_numerics.md
- Numerics: generated-bits-numerics.md
- Q Dot General: generated-bits-q_dot_general.md
- Q Flax: generated-bits-q_flax.md
- Qk: generated-bits-qk.md
- Stochastic Rounding: generated-bits-stochastic_rounding.md
- Bit Quantization:
- Calibration: generated-bit_quantization-calibration.md
- Config: generated-bit_quantization-config.md
- Int Numerics: generated-bit_quantization-int_numerics.md
- No Numerics: generated-bit_quantization-no_numerics.md
- Numerics: generated-bit_quantization-numerics.md
- Q Dot General: generated-bit_quantization-q_dot_general.md
- Q Flax: generated-bit_quantization-q_flax.md
- Stochastic Rounding: generated-bit_quantization-stochastic_rounding.md
- Checkpoint:
- Load: generated-checkpoint-_load.md
- Streamer: generated-checkpoint-streamer.md
@@ -30,29 +28,28 @@ nav:
- Pallas Operations:
- Efficient Attention:
- Efficient Attention: generated-pallas_operations-efficient_attention-efficient_attention.md
- Layer Norm:
- Gpu:
- Layer Norm: generated-pallas_operations-layer_norm-gpu-layer_norm.md
- Pallas Flash Attention:
- Attention: generated-pallas_operations-pallas_flash_attention-attention.md
- Ring Attention:
- Ring Attention: generated-pallas_operations-ring_attention-ring_attention.md
- Rms Norm:
- Gpu:
- Rms Norm: generated-pallas_operations-rms_norm-gpu-rms_norm.md
- Softmax:
- Gpu:
- Softmax: generated-pallas_operations-softmax-gpu-softmax.md
- Splash Attention:
- Tpu:
- Splash Attention Kernel: generated-pallas_operations-splash_attention-tpu-splash_attention_kernel.md
- Splash Attention Mask: generated-pallas_operations-splash_attention-tpu-splash_attention_mask.md
- Splash Attention Mask Info: generated-pallas_operations-splash_attention-tpu-splash_attention_mask_info.md
- Tpu Flash Attention:
- Gpu:
- Jax Flash Attn Gpu: generated-pallas_operations-tpu_flash_attention-gpu-jax_flash_attn_gpu.md
- Tpu:
- Jax Flash Attn Tpu: generated-pallas_operations-tpu_flash_attention-tpu-jax_flash_attn_tpu.md
- Gpu:
- Flash Attention:
- Mha: generated-pallas_operations-gpu-flash_attention-mha.md
- Layer Norm:
- Layer Norm: generated-pallas_operations-gpu-layer_norm-layer_norm.md
- Rms Norm:
- Rms Norm: generated-pallas_operations-gpu-rms_norm-rms_norm.md
- Softmax:
- Softmax: generated-pallas_operations-gpu-softmax-softmax.md
- Pallas Attention:
- Attention: generated-pallas_operations-pallas_attention-attention.md
- Tpu:
- Flash Attention:
- Flash Attention: generated-pallas_operations-tpu-flash_attention-flash_attention.md
- Paged Attention:
- Paged Attention: generated-pallas_operations-tpu-paged_attention-paged_attention.md
- Ring Attention:
- Ring Attention: generated-pallas_operations-tpu-ring_attention-ring_attention.md
- Splash Attention:
- Splash Attention Kernel: generated-pallas_operations-tpu-splash_attention-splash_attention_kernel.md
- Splash Attention Mask: generated-pallas_operations-tpu-splash_attention-splash_attention_mask.md
- Splash Attention Mask Info: generated-pallas_operations-tpu-splash_attention-splash_attention_mask_info.md
- Sharding:
- Sharding: generated-sharding-sharding.md
- T5x Partitioning: generated-sharding-t5x_partitioning.md
25 changes: 10 additions & 15 deletions pyproject.toml
@@ -1,15 +1,13 @@
[project]
name = "FJFormer"
authors = [
{ name = "Erfan Zare Chavoshi", email = "Erfanzare810@gmail.com" }
]
requires-python = ">=3.8"
name = "fjformer"
authors = [{ name = "Erfan Zare Chavoshi", email = "Erfanzare810@gmail.com" }]
requires-python = ">=3.9"
readme = "README.md"
version = "0.0.66"
version = "0.0.67"

dependencies = [
"jax>=0.4.23",
"jaxlib>=0.4.23",
"jax>=0.4.29",
"jaxlib>=0.4.29",
"optax~=0.2.2",
"msgpack~=1.0.7",
"ipython~=8.17.2",
@@ -36,18 +34,15 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
description = "Embark on a journey of paralleled/unparalleled computational prowess with FJFormer - an arsenal of custom Jax Flax Functions and Utils that elevate your AI endeavors to new heights!"

license = { text = "Apache-2.0" }

keywords = [
"JAX", "Torch", "Deep Learning", "Machine Learning", "Flax", "XLA"
]
keywords = ["JAX", "Deep Learning", "Machine Learning", "Flax", "XLA"]

[build-system]
requires = ["setuptools>=46.4.0", "wheel>=0.34.2"]
build-backend = "setuptools.build_meta"
requires = ["flit_core >=3.2,<4"]
build-backend = "flit_core.buildapi"

[tool.setuptools.packages]

4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
jax>=0.4.23
jaxlib>=0.4.23
jax>=0.4.29
jaxlib>=0.4.29
optax~=0.2.2
msgpack~=1.0.7
ipython~=8.17.2
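
The dependency floor moves from `jax`/`jaxlib` 0.4.23 to 0.4.29. A quick way to confirm an existing environment meets the new floor (assuming both packages import cleanly) is:

```python
# Sanity check that the installed jax/jaxlib satisfy the new >=0.4.29 floor.
import jax
import jaxlib

def version_tuple(v: str) -> tuple:
    # Keep only the leading numeric components (handles e.g. "0.4.29.dev").
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)

print("jax:", jax.__version__, "| jaxlib:", jaxlib.__version__)
assert version_tuple(jax.__version__) >= (0, 4, 29)
assert version_tuple(jaxlib.__version__) >= (0, 4, 29)
```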
0 comments on commit 784e874
