Commit
Fixing flash attention: dropping attention mask and causal support, adding bias support instead
erfanzar committed May 14, 2024
1 parent 09c5beb commit 2ab0ba9
Showing 13 changed files with 74 additions and 172 deletions.
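The commit message above describes narrowing the flash-attention interface: the explicit attention-mask and causal arguments go away, and callers pass an additive bias instead. As a rough illustration of that pattern only (a hedged sketch; mask_to_bias is a hypothetical helper and not part of fjformer), a boolean mask and a causal constraint can be folded into one bias tensor like this:

import jax.numpy as jnp

def mask_to_bias(attention_mask, causal=True, dtype=jnp.float32):
    # attention_mask: [batch, q_len, kv_len] boolean array, True = may attend.
    _, q_len, kv_len = attention_mask.shape
    allowed = attention_mask.astype(bool)
    if causal:
        # Disallow attending to keys that come after the query position.
        causal_mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))
        allowed = allowed & causal_mask[None, :, :]
    # Allowed positions contribute 0 to the attention logits; blocked positions
    # get a large negative value so they vanish after the softmax.
    return jnp.where(allowed, 0.0, jnp.finfo(dtype).min).astype(dtype)

# Example: a [1, 128, 128] additive bias that encodes a causal mask.
bias = mask_to_bias(jnp.ones((1, 128, 128), dtype=bool))

Flash-attention kernels typically expect the bias to be broadcastable against [batch, heads, q_len, kv_len]; whether fjformer's updated kernel does is not shown in this diff.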
6 changes: 2 additions & 4 deletions .idea/FXUtils.iml

Some generated files are not rendered by default.

14 changes: 14 additions & 0 deletions .idea/deployment.xml

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ requires-python = ">=3.8"
 
 readme = "README.md"
 
-version = "0.0.56"
+version = "0.0.57"
 
 dependencies = [
     "jax>=0.4.20",
2 changes: 1 addition & 1 deletion src/fjformer/__init__.py
@@ -68,7 +68,7 @@
 from . import optimizers as optimizers
 from . import linen as linen
 
-__version__ = "0.0.56"
+__version__ = "0.0.57"
 
 __all__ = (
     # Loss and extra function
2 changes: 1 addition & 1 deletion src/fjformer/pallas_operations/__init__.py
@@ -6,7 +6,7 @@
 """
 
 from .efficient_attention import efficient_attention as efficient_attention
-from .flash_attention import (
+from .tpu_flash_attention import (
     flash_attention as tpu_flash_attention,
     mha as gpu_flash_attention,
     BlockSizes
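Since the hunk above keeps the re-exported names (tpu_flash_attention, gpu_flash_attention, BlockSizes), downstream imports of the package should be unaffected by the module rename. Assuming only what this hunk shows, such an import looks like:

# Unaffected by the flash_attention -> tpu_flash_attention module rename,
# because the package __init__ re-exports the same public names.
from fjformer.pallas_operations import (
    tpu_flash_attention,
    gpu_flash_attention,
    BlockSizes,
)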