diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..dad4239
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+# exclude notebooks from GitHub language stats
+*.ipynb linguist-vendored
diff --git a/.gitignore b/.gitignore
index beba85d..58b7753 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,7 @@ data/*
wandb
*egg-info
+
+poetry.lock
+
+.env
diff --git a/README.md b/README.md
index eee1095..94d4a24 100644
--- a/README.md
+++ b/README.md
@@ -68,6 +68,24 @@ Please sign-in the Hugging Face account.
huggingface-cli login
```
+## 4. Flash Attention
+Make sure your environment can use the CUDA toolkit. See also [installation-and-features](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) in the flash-attention README.
+
+To use flash-attention, install the following packages.
+```bash
+pip install packaging wheel
+pip uninstall -y ninja && pip install ninja --no-cache-dir
+pip install flash-attn --no-build-isolation
+```
+
+If flash-attn doesn't work, install it from source. ([Related issue](https://github.com/Dao-AILab/flash-attention/issues/821))
+```bash
+cd /path/to/download
+git clone https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention
+python setup.py install
+```
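+
+If the installation succeeded, you can verify it with the minimal sketch below (assuming models are loaded through Hugging Face `transformers`; the model id is only a placeholder).
+```python
+# Check that flash-attn is importable in the active environment.
+import flash_attn
+print(flash_attn.__version__)
+
+import torch
+from transformers import AutoModelForCausalLM
+
+# Placeholder model id; replace it with the model you actually train.
+# Loading with flash attention requires a CUDA GPU and fp16/bf16 weights.
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-2-7b-hf",
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+)
+```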
+
# Training
For training, use the yaml configuration files under the `projects` directory.
diff --git a/docs/README_CN.md b/docs/README_CN.md
index 0a2ac61..e90eb1c 100644
--- a/docs/README_CN.md
+++ b/docs/README_CN.md
@@ -67,6 +67,23 @@ pre-commit install
huggingface-cli login
```
+## 4. Using Flash Attention
+Make sure your environment can use the CUDA toolkit. See also [installation-and-features](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) in the flash-attention README.
+To use flash-attention, install the following packages.
+```bash
+pip install packaging wheel
+pip uninstall -y ninja && pip install ninja --no-cache-dir
+pip install flash-attn --no-build-isolation
+```
+
+If flash-attn doesn't work, install it from source. ([Related issue](https://github.com/Dao-AILab/flash-attention/issues/821))
+```bash
+cd /path/to/download
+git clone https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention
+python setup.py install
+```
+
# Training
For training, use the yaml configuration files under the `projects` directory.
diff --git a/docs/README_JP.md b/docs/README_JP.md
index 4fc628f..b292bb5 100644
--- a/docs/README_JP.md
+++ b/docs/README_JP.md
@@ -67,6 +67,23 @@ To use the Llama-2 model, you must apply for access.
huggingface-cli login
```
+## 4. Using Flash Attention
+Make sure the CUDA toolkit works correctly in your environment. See also [installation-and-features](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features) in the flash-attention README.
+To use flash-attention, install the following packages.
+```bash
+pip install packaging wheel
+pip uninstall -y ninja && pip install ninja --no-cache-dir
+pip install flash-attn --no-build-isolation
+```
+
+If flash-attn doesn't work, install it from source. ([Related issue](https://github.com/Dao-AILab/flash-attention/issues/821))
+```bash
+cd /path/to/download
+git clone https://github.com/Dao-AILab/flash-attention.git
+cd flash-attention
+python setup.py install
+```
+
# Training
For training, use the yaml configuration files under the `projects` directory.
diff --git a/pyproject.toml b/pyproject.toml
index 49fb89e..9eae627 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,9 +9,9 @@ readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.10, <3.13"
pyyaml = "^6.0.1"
-accelerate = "^0.22.0"
+accelerate = "~0.27.2"
datasets = "^2.14.4"
-deepspeed = "^0.10.2"
+deepspeed = "~0.13.2"
einops = "^0.6.1"
evaluate = "^0.4.0"
peft = "^0.5.0"
@@ -19,10 +19,10 @@ protobuf = "^4.24.2"
scikit-learn = "^1.3.0"
scipy = "^1.11.2"
sentencepiece = "^0.1.99"
-torch = ">=2.0.1"
+torch = { url = "https://download.pytorch.org/whl/cu121/torch-2.2.0%2Bcu121-cp310-cp310-linux_x86_64.whl"}
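+# NOTE: the wheel above targets CPython 3.10 on Linux x86_64 only, although
+# `python` allows ">=3.10, <3.13". A sketch (untested; the cp311 filename is
+# assumed to follow the same pattern) of Poetry's multiple-constraints syntax
+# that would pick a wheel per interpreter:
+# torch = [
+#     { url = "https://download.pytorch.org/whl/cu121/torch-2.2.0%2Bcu121-cp310-cp310-linux_x86_64.whl", markers = "python_version == '3.10'" },
+#     { url = "https://download.pytorch.org/whl/cu121/torch-2.2.0%2Bcu121-cp311-cp311-linux_x86_64.whl", markers = "python_version == '3.11'" },
+# ]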
fire = "^0.5.0"
pillow = "^10.0.0"
-transformers = "^4.33.0"
+transformers = "~4.38.1"
isort = "^5.12.0"
black = "^23.7.0"
wandb = "^0.15.9"
@@ -32,6 +32,8 @@ jupyterlab = "^4.0.5"
matplotlib = "^3.7.2"
japanize-matplotlib = "^1.1.3"
pre-commit = "^3.4.0"
+packaging = "^23.2"
+wheel = "^0.42.0"
[build-system]
requires = ["poetry-core"]
diff --git a/requirements.txt b/requirements.txt
index 8a47da6..d734cf1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,6 @@
+--find-links https://download.pytorch.org/whl/torch_stable.html
+torch==2.2.0+cu121
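+# The +cu121 tag pins the CUDA 12.1 build so the torch binary matches the one
+# flash-attn is compiled against; adjust it if your system uses another CUDA toolkit.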
+
PyYAML
accelerate
datasets
@@ -9,7 +12,6 @@ protobuf
scikit-learn
scipy
sentencepiece
-torch>=2.0.1
fire
pillow
transformers