Skip to content

Commit

Permalink
Fix requirements, clean up small test issues (#117)
Browse files Browse the repository at this point in the history
* Fix requirements, clean up small test issues

* Fix CI runner config

* Fix false requirement boundary

* Update python version

* add jax back

* Remove HF test

* remove jit
  • Loading branch information
SamTov authored May 7, 2024
1 parent 1955413 commit ac59d15
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/black.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ jobs:
steps:
- uses: actions/checkout@v2
with:
python-version: '3.10'
python-version: '3.11'
- name: Black Check
uses: psf/[email protected]
8 changes: 4 additions & 4 deletions .github/workflows/doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ jobs:
- name: Setup Python environment
uses: actions/setup-python@v2
with:
python-version: '3.10'
python-version: '3.11'
- name: Install dependencies
run: |
sudo apt install pandoc
pip3 install -r requirements.txt
pip install -r dev-requirements.txt
pip install -r requirements.txt
pip install .
pip3 install h5py --upgrade --no-dependencies
pip3 install cached-property
- name: Build documentation
run: |
cd docs
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/flake8.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ "3.10" ]
python-version: [ "3.11" ]

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/isort.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.10'
python-version: '3.11'
- name: Install isort
run: |
pip install isort==5.10.1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/nbtest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: "3.10"
python-version: "3.11"
- name: Install dev requirements
run: |
pip3 install nbmake
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.11"]

steps:
- uses: actions/checkout@v2
Expand All @@ -22,8 +22,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install -r dev-requirements.txt
pip install -r requirements.txt
- name: Install package
run: |
pip install .
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def setup_class(cls):
Create a model and data for the tests.
The resnet config has a 1 dimensional input and a 2 dimensional output.
"""

resnet_config = ResNetConfig(
num_channels=2,
embedding_size=64,
Expand Down Expand Up @@ -88,3 +87,11 @@ def test_infinite_failure(self):
"""
with pytest.raises(NotImplementedError):
self.model.compute_ntk(self.x, infinite=True)


if __name__ == "__main__":
test_class = TestFlaxHFModule()
test_class.setup_class()

# test_class.test_infinite_failure()
test_class.test_ntk_shape()
2 changes: 1 addition & 1 deletion CI/unit_tests/utils/test_matrix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_unscaled_eigenvalues(self):

values, vectors = compute_eigensystem(matrix, normalize=False)

assert_array_equal(np.real(values), [1, 1])
assert_array_equal(np.real(values), [1.0, 1.0])

def test_scaled_eigenvalues(self):
"""
Expand Down
10 changes: 10 additions & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
isort>=5.13.2
black>=24.4.0
sphinx>=7.3.7
sphinx_copybutton>=0.5.2
sphinx_rtd_theme>=2.0.0
nbsphinx>=0.9.3
pytest>=8.1.1
numpydoc>=1.7.0
flake8>=7.0.0
pre_commit>=3.7.0
43 changes: 16 additions & 27 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,28 +1,17 @@
numpy
matplotlib
sphinx
flake8
black
ipython
numpydoc
optax
sphinx_copybutton
sphinx_rtd_theme
nbsphinx
tensorflow_probability
scipy
scikit-learn
# Temp fix of version of jax and jaxlib until the next release
jax<=0.4.25
jaxlib<=0.4.25
plotly
flax
tqdm
pandas
numpy>=1.26.4
matplotlib>=3.8.4
optax>=0.2.2
tensorflow_probability>=0.24.0
scipy>=1.13.0
scikit-learn>=1.4.2
plotly>=5.21.0
flax>=0.8.2
tqdm>=4.66.2
pandas>=2.2.2
neural-tangents>=0.6.5
tensorflow-datasets
isort
tensorflow
pyyaml
jupyter
transformers
tensorflow-datasets>=4.9.4
tensorflow>=2.16.1
jupyter>=1.0.0
transformers>=4.40.0
jax>=0.4.26
jaxlib>=0.4.26
1 change: 0 additions & 1 deletion znnl/models/jax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
-------
"""

from functools import partial
from typing import Any, Callable, Optional, Sequence, Union

import jax
Expand Down
2 changes: 1 addition & 1 deletion znnl/training_strategies/simple_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def train_model(
state = self.model.model_state

loading_bar = trange(
1, epochs + 1, ncols=100, unit="batch", disable=self.disable_loading_bar
0, epochs, ncols=100, unit="batch", disable=self.disable_loading_bar
)

train_losses = []
Expand Down

0 comments on commit ac59d15

Please sign in to comment.