[Minor] Improve performance based on profiling results #1623

Closed
wants to merge 17 commits into from

Changes from all commits
25 changes: 23 additions & 2 deletions .github/workflows/metrics.yml
@@ -11,6 +11,7 @@ on:
- main
- develop
workflow_dispatch:

jobs:
metrics:
runs-on: ubuntu-latest # container: docker://ghcr.io/iterative/cml:0-dvc2-base1
@@ -19,24 +20,32 @@ jobs:
uses: actions/checkout@v3
with:
ref: ${{ github.event.pull_request.head.sha }}

- name: Install Python 3.12
uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Setup NodeJS (for CML)
uses: actions/setup-node@v3 # For CML
with:
node-version: '16'

- name: Setup CML
uses: iterative/setup-cml@v1

- name: Install Poetry
uses: snok/install-poetry@v1

- name: Install Dependencies
run: poetry install --no-interaction --no-root --with=pytest,metrics --without=dev,docs,linters

- name: Install Project
run: poetry install --no-interaction --with=pytest,metrics --without=dev,docs,linters

- name: Train model
run: poetry run pytest tests/test_model_performance.py -n 1 --durations=0

- name: Download metrics from main
uses: dawidd6/action-download-artifact@v2
with:
@@ -45,28 +54,40 @@ jobs:
name: metrics
path: tests/metrics-main/
if_no_artifact_found: warn

- name: Open Benchmark Report
run: echo "## Model Benchmark" >> report.md

- name: Write Benchmark Report
run: poetry run python tests/metrics/compareMetrics.py >> report.md

- name: Publish Report with CML
env:
REPO_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
echo "<details>\n<summary>Model training plots</summary>\n" >> report.md
echo "<details><summary>Model training plots</summary>" >> report.md
echo "" >> report.md
echo "## Model Training" >> report.md
echo "" >> report.md
echo "### PeytonManning" >> report.md
cml asset publish tests/metrics/PeytonManning.svg --md >> report.md
echo "" >> report.md
echo "### YosemiteTemps" >> report.md
cml asset publish tests/metrics/YosemiteTemps.svg --md >> report.md
echo "" >> report.md
echo "### AirPassengers" >> report.md
cml asset publish tests/metrics/AirPassengers.svg --md >> report.md
echo "" >> report.md
echo "### EnergyPriceDaily" >> report.md
cml asset publish tests/metrics/EnergyPriceDaily.svg --md >> report.md
echo "\n</details>" >> report.md
echo "" >> report.md
echo "</details>" >> report.md
echo "" >> report.md
cml comment update --target=pr report.md # Post reports as comments in GitHub PRs
cml check create --title=ModelReport report.md # update status of check in PR

- name: Upload metrics if on main
if: github.ref == 'refs/heads/main'
uses: actions/upload-artifact@v3
with:
name: metrics
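Side note on the workflow edit above: the single `echo` with embedded `\n` escapes is split into separate `echo` calls with explicit blank lines, since bash's builtin `echo` prints `\n` literally by default and GitHub only renders markdown inside an HTML `<details>` block when a blank line follows the opening tag. A hypothetical Python sketch (not part of the PR) of the report.md fragment the commands now append; the image line is a placeholder for the output of `cml asset publish ... --md`:

```python
# Hypothetical sketch: the report.md fragment the workflow's echo commands
# now append. A blank line after the opening <details> tag lets GitHub
# render the markdown headings inside it.
plot_md = "<generated by `cml asset publish <file>.svg --md`>"  # placeholder
fragment = [
    "<details><summary>Model training plots</summary>",
    "",
    "## Model Training",
    "",
    "### PeytonManning",
    plot_md,
    "",
    "</details>",
    "",
]
with open("report.md", "a") as f:
    f.write("\n".join(fragment) + "\n")
```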
26 changes: 20 additions & 6 deletions neuralprophet/time_dataset.py
@@ -129,6 +129,23 @@ def __init__(

# Construct index map
self.sample2index_map, self.length = self.create_sample2index_map(self.df, self.df_tensors)
self.time_offset = torch.tensor(datetime(1900, 1, 1).timestamp())
self.df_tensors["ds_seasonality"] = (self.df_tensors["ds"] - self.time_offset).float() / (3600 * 24.0)

self.precomputed_seasonality_terms = self.precompute_seasonality_terms()

def precompute_seasonality_terms(self):
precomputed_terms = OrderedDict()
if self.config_seasonality is None:
return precomputed_terms

for name, period in self.config_seasonality.periods.items():
if period.resolution > 0:
factor = 2.0 * np.pi / period.period
arrange_tensor = torch.arange(1, period.resolution + 1, dtype=torch.float32)
factor_arrange = factor * arrange_tensor
precomputed_terms[name] = factor_arrange
return precomputed_terms

def __getitem__(self, index):
"""Overrides parent class method to get an item at index.
@@ -333,19 +350,16 @@ def get_sample_lagged_regressors(df_tensors, origin_index, config_lagged_regress


def get_sample_seasonalities(df_tensors, origin_index, n_forecasts, max_lags, n_lags, config_seasonality):

seasonalities = OrderedDict({})
if max_lags == 0:
dates = df_tensors["ds"][origin_index].unsqueeze(0)
dates = df_tensors["ds_seasonality"][origin_index].unsqueeze(0)
else:
dates = df_tensors["ds"][origin_index - n_lags + 1 : origin_index + n_forecasts + 1]

t = (dates - torch.tensor(datetime(1900, 1, 1).timestamp())).float() / (3600 * 24.0)
dates = df_tensors["ds_seasonality"][origin_index - n_lags + 1 : origin_index + n_forecasts + 1]

for name, period in config_seasonality.periods.items():
if period.resolution > 0:
if config_seasonality.computation == "fourier":
factor = 2.0 * np.pi * t[:, None] / period.period
factor = 2.0 * np.pi * dates[:, None] / period.period
sin_terms = torch.sin(factor * torch.arange(1, period.resolution + 1))
cos_terms = torch.cos(factor * torch.arange(1, period.resolution + 1))
features = torch.cat((sin_terms, cos_terms), dim=1)
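The time_dataset.py change moves two pieces of per-sample work out of the hot path: the `ds` column is converted once to days since 1900-01-01 (`ds_seasonality`), and the per-seasonality angular frequencies `2π·k/period` are precomputed. A self-contained sketch of the idea, assuming PyTorch and simplified data structures (the real code reads periods and resolutions from `config_seasonality`):

```python
from collections import OrderedDict
from datetime import datetime

import numpy as np
import torch

TIME_OFFSET = torch.tensor(datetime(1900, 1, 1).timestamp())

def precompute_frequencies(periods):
    """periods maps name -> (period_in_days, resolution); returns the angular
    frequencies 2*pi*k/period for k = 1..resolution per seasonality."""
    terms = OrderedDict()
    for name, (period, resolution) in periods.items():
        if resolution > 0:
            k = torch.arange(1, resolution + 1, dtype=torch.float32)
            terms[name] = 2.0 * np.pi / period * k
    return terms

def fourier_features(ds, freqs):
    """ds holds POSIX timestamps for one sample window; per-sample work is now
    an outer product plus sin/cos, everything else is precomputed."""
    t_days = (ds - TIME_OFFSET).float() / (3600 * 24.0)
    out = {}
    for name, w in freqs.items():
        angle = t_days[:, None] * w  # (window, resolution)
        out[name] = torch.cat((torch.sin(angle), torch.cos(angle)), dim=1)
    return out

freqs = precompute_frequencies({"yearly": (365.25, 6), "weekly": (7.0, 3)})
window = torch.tensor([datetime(2024, 1, d).timestamp() for d in range(1, 8)])
features = fourier_features(window, freqs)  # features["yearly"].shape == (7, 12)
```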
146 changes: 71 additions & 75 deletions neuralprophet/time_net.py
@@ -268,28 +268,34 @@ def __init__(
self.ar_layers = ar_layers
self.max_lags = max_lags
if self.n_lags > 0:
self.ar_net = nn.ModuleList()
ar_net_layers = []
d_inputs = self.n_lags
for d_hidden_i in self.ar_layers:
self.ar_net.append(nn.Linear(d_inputs, d_hidden_i, bias=True))
ar_net_layers.append(nn.Linear(d_inputs, d_hidden_i, bias=True))
ar_net_layers.append(nn.ReLU())
d_inputs = d_hidden_i
# final layer has input size d_inputs and output size equal to no. of forecasts * no. of quantiles
self.ar_net.append(nn.Linear(d_inputs, self.n_forecasts * len(self.quantiles), bias=False))
ar_net_layers.append(nn.Linear(d_inputs, self.n_forecasts * len(self.quantiles), bias=False))
self.ar_net = nn.Sequential(*ar_net_layers)
for lay in self.ar_net:
nn.init.kaiming_normal_(lay.weight, mode="fan_in")
if isinstance(lay, nn.Linear):
nn.init.kaiming_normal_(lay.weight, mode="fan_in")

# Lagged regressors
self.lagged_reg_layers = lagged_reg_layers
self.config_lagged_regressors = config_lagged_regressors
if self.config_lagged_regressors is not None:
self.covar_net = nn.ModuleList()
covar_net_layers = []
d_inputs = sum([covar.n_lags for _, covar in self.config_lagged_regressors.items()])
for d_hidden_i in self.lagged_reg_layers:
self.covar_net.append(nn.Linear(d_inputs, d_hidden_i, bias=True))
covar_net_layers.append(nn.Linear(d_inputs, d_hidden_i, bias=True))
covar_net_layers.append(nn.ReLU())
d_inputs = d_hidden_i
self.covar_net.append(nn.Linear(d_inputs, self.n_forecasts * len(self.quantiles), bias=False))
covar_net_layers.append(nn.Linear(d_inputs, self.n_forecasts * len(self.quantiles), bias=False))
self.covar_net = nn.Sequential(*covar_net_layers)
for lay in self.covar_net:
nn.init.kaiming_normal_(lay.weight, mode="fan_in")
if isinstance(lay, nn.Linear):
nn.init.kaiming_normal_(lay.weight, mode="fan_in")

# Regressors
self.config_regressors = config_regressors
@@ -310,7 +316,9 @@ def __init__(
def ar_weights(self) -> torch.Tensor:
"""sets property auto-regression weights for regularization. Update if AR is modelled differently"""
# TODO: this is wrong for deep networks, use utils_torch.interprete_model
return self.ar_net[0].weight
for layer in self.ar_net:
if isinstance(layer, nn.Linear):
return layer.weight

def get_covar_weights(self, covar_input=None) -> torch.Tensor:
"""
@@ -393,49 +401,50 @@ def _compute_quantile_forecasts_from_diffs(self, diffs: torch.Tensor, predict_mo
dim (batch, n_forecasts, no_quantiles)
final forecasts
"""
if len(self.quantiles) > 1:
# generate the actual quantile forecasts from predicted differences
if any(quantile > 0.5 for quantile in self.quantiles):
quantiles_divider_index = next(i for i, quantile in enumerate(self.quantiles) if quantile > 0.5)
else:
quantiles_divider_index = len(self.quantiles)

n_upper_quantiles = diffs.shape[-1] - quantiles_divider_index
n_lower_quantiles = quantiles_divider_index - 1

out = torch.zeros_like(diffs)
out[:, :, 0] = diffs[:, :, 0] # set the median where 0 is the median quantile index

if n_upper_quantiles > 0: # check if upper quantiles exist
upper_quantile_diffs = diffs[:, :, quantiles_divider_index:]
if predict_mode: # check for quantile crossing and correct them in predict mode
upper_quantile_diffs[:, :, 0] = torch.max(
torch.tensor(0, device=self.device), upper_quantile_diffs[:, :, 0]
)
for i in range(n_upper_quantiles - 1):
next_diff = upper_quantile_diffs[:, :, i + 1]
diff = upper_quantile_diffs[:, :, i]
upper_quantile_diffs[:, :, i + 1] = torch.max(next_diff, diff)
out[:, :, quantiles_divider_index:] = (
upper_quantile_diffs + diffs[:, :, 0].unsqueeze(dim=2).repeat(1, 1, n_upper_quantiles).detach()
) # set the upper quantiles

if n_lower_quantiles > 0: # check if lower quantiles exist
lower_quantile_diffs = diffs[:, :, 1:quantiles_divider_index]
if predict_mode: # check for quantile crossing and correct them in predict mode
lower_quantile_diffs[:, :, -1] = torch.max(
torch.tensor(0, device=self.device), lower_quantile_diffs[:, :, -1]
)
for i in range(n_lower_quantiles - 1, 0, -1):
next_diff = lower_quantile_diffs[:, :, i - 1]
diff = lower_quantile_diffs[:, :, i]
lower_quantile_diffs[:, :, i - 1] = torch.max(next_diff, diff)
lower_quantile_diffs = -lower_quantile_diffs
out[:, :, 1:quantiles_divider_index] = (
lower_quantile_diffs + diffs[:, :, 0].unsqueeze(dim=2).repeat(1, 1, n_lower_quantiles).detach()
) # set the lower quantiles

if len(self.quantiles) <= 1:
return diffs
# generate the actual quantile forecasts from predicted differences
if any(quantile > 0.5 for quantile in self.quantiles):
quantiles_divider_index = next(i for i, quantile in enumerate(self.quantiles) if quantile > 0.5)
else:
out = diffs
quantiles_divider_index = len(self.quantiles)

n_upper_quantiles = diffs.shape[-1] - quantiles_divider_index
n_lower_quantiles = quantiles_divider_index - 1

out = torch.zeros_like(diffs)
out[:, :, 0] = diffs[:, :, 0] # set the median where 0 is the median quantile index

if n_upper_quantiles > 0: # check if upper quantiles exist
upper_quantile_diffs = diffs[:, :, quantiles_divider_index:]
if predict_mode: # check for quantile crossing and correct them in predict mode
upper_quantile_diffs[:, :, 0] = torch.max(
torch.tensor(0, device=self.device), upper_quantile_diffs[:, :, 0]
)
for i in range(n_upper_quantiles - 1):
next_diff = upper_quantile_diffs[:, :, i + 1]
diff = upper_quantile_diffs[:, :, i]
upper_quantile_diffs[:, :, i + 1] = torch.max(next_diff, diff)
out[:, :, quantiles_divider_index:] = (
upper_quantile_diffs + diffs[:, :, 0].unsqueeze(dim=2).repeat(1, 1, n_upper_quantiles).detach()
) # set the upper quantiles

if n_lower_quantiles > 0: # check if lower quantiles exist
lower_quantile_diffs = diffs[:, :, 1:quantiles_divider_index]
if predict_mode: # check for quantile crossing and correct them in predict mode
lower_quantile_diffs[:, :, -1] = torch.max(
torch.tensor(0, device=self.device), lower_quantile_diffs[:, :, -1]
)
for i in range(n_lower_quantiles - 1, 0, -1):
next_diff = lower_quantile_diffs[:, :, i - 1]
diff = lower_quantile_diffs[:, :, i]
lower_quantile_diffs[:, :, i - 1] = torch.max(next_diff, diff)
lower_quantile_diffs = -lower_quantile_diffs
out[:, :, 1:quantiles_divider_index] = (
lower_quantile_diffs + diffs[:, :, 0].unsqueeze(dim=2).repeat(1, 1, n_lower_quantiles).detach()
) # set the lower quantiles

return out

def scalar_features_effects(self, features: torch.Tensor, params: nn.Parameter, indices=None) -> torch.Tensor:
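Beyond the early-return restructuring, the body of `_compute_quantile_forecasts_from_diffs` is unchanged: the network predicts differences from the median, and in predict mode a running max repairs any quantile crossing before the differences are added back onto the median. A toy illustration of that repair on the upper quantiles (2-D shapes for brevity; the real tensors are (batch, n_forecasts, quantiles)):

```python
import torch

# Predicted upper-quantile *differences* from the median for one forecast;
# the middle value crosses below its neighbour.
upper_diffs = torch.tensor([[0.30, 0.25, 0.40]])
upper_diffs[:, 0] = torch.clamp(upper_diffs[:, 0], min=0.0)  # must lie above the median
for i in range(upper_diffs.shape[1] - 1):
    # running max enforces non-decreasing differences
    upper_diffs[:, i + 1] = torch.max(upper_diffs[:, i + 1], upper_diffs[:, i])
print(upper_diffs)  # tensor([[0.3000, 0.3000, 0.4000]])
# adding these to the median then yields properly ordered upper quantiles
```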
@@ -474,14 +483,9 @@ def auto_regression(self, lags: Union[torch.Tensor, float]) -> torch.Tensor:
torch.Tensor
Forecast component of dims: (batch, n_forecasts)
"""
x = lags
for i in range(len(self.ar_layers) + 1):
if i > 0:
x = nn.functional.relu(x)
x = self.ar_net[i](x)

x = self.ar_net(lags)
# segment the last dimension to match the quantiles
x = x.reshape(x.shape[0], self.n_forecasts, len(self.quantiles))
x = x.view(x.shape[0], self.n_forecasts, len(self.quantiles))
return x

def forward_covar_net(self, covariates):
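With `ar_net` a single `nn.Sequential`, `auto_regression` collapses to one forward call plus a `view`; `view` reinterprets the existing storage without copying, which is safe here because a `Linear` layer's output is contiguous (whereas `reshape` may silently fall back to a copy). A small sketch with invented sizes, equally applicable to `forward_covar_net` below:

```python
import torch
import torch.nn as nn

n_lags, n_forecasts, n_quantiles = 10, 3, 2
ar_net = nn.Sequential(
    nn.Linear(n_lags, 16), nn.ReLU(),
    nn.Linear(16, n_forecasts * n_quantiles, bias=False),
)

lags = torch.randn(8, n_lags)                     # (batch, n_lags)
x = ar_net(lags)                                  # (batch, n_forecasts * n_quantiles)
x = x.view(x.shape[0], n_forecasts, n_quantiles)  # zero-copy reinterpretation
assert x.shape == (8, n_forecasts, n_quantiles)
```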
@@ -501,13 +505,9 @@ def forward_covar_net(self, covariates):
x = torch.cat([covar for _, covar in covariates.items()], axis=1)
else:
x = covariates
for i in range(len(self.lagged_reg_layers) + 1):
if i > 0:
x = nn.functional.relu(x)
x = self.covar_net[i](x)

x = self.covar_net(x)
# segment the last dimension to match the quantiles
x = x.reshape(x.shape[0], self.n_forecasts, len(self.quantiles))
x = x.view(x.shape[0], self.n_forecasts, len(self.quantiles))
return x

def forward(self, inputs: Dict, meta: Dict = None, compute_components_flag: bool = False) -> torch.Tensor:
@@ -880,8 +880,7 @@ def _get_time_based_sample_weight(self, t):
end_w = self.config_train.newer_samples_weight
start_t = self.config_train.newer_samples_start
time = (t.detach() - start_t) / (1.0 - start_t)
time = torch.maximum(torch.zeros_like(time), time)
time = torch.minimum(torch.ones_like(time), time) # time = 0 to 1
time = torch.clamp(time, 0.0, 1.0) # time = 0 to 1
time = np.pi * (time - 1.0) # time = -pi to 0
time = 0.5 * torch.cos(time) + 0.5 # time = 0 to 1
# scales end to be end weight times bigger than start weight
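The two-step clipping via `torch.maximum`/`torch.minimum` becomes a single `torch.clamp`, which is equivalent on this range and skips the `zeros_like`/`ones_like` temporaries. A quick sanity check:

```python
import torch

t = torch.linspace(-0.5, 1.5, steps=5)
old = torch.minimum(torch.ones_like(t), torch.maximum(torch.zeros_like(t), t))
new = torch.clamp(t, 0.0, 1.0)
assert torch.equal(old, new)  # both clip into [0, 1]
```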
@@ -1019,24 +1018,21 @@ class DeepNet(nn.Module):
def __init__(self, d_inputs, d_outputs, lagged_reg_layers=[]):
# Perform initialization of the pytorch superclass
super(DeepNet, self).__init__()
self.layers = nn.ModuleList()
layers = []
for d_hidden_i in lagged_reg_layers:
self.layers.append(nn.Linear(d_inputs, d_hidden_i, bias=True))
layers.append(nn.Linear(d_inputs, d_hidden_i, bias=True))
layers.append(nn.ReLU())
d_inputs = d_hidden_i
self.layers.append(nn.Linear(d_inputs, d_outputs, bias=True))
layers.append(nn.Linear(d_inputs, d_outputs, bias=True))
self.layers = nn.Sequential(*layers)
for lay in self.layers:
if isinstance(lay, nn.Linear):
nn.init.kaiming_normal_(lay.weight, mode="fan_in")

def forward(self, x):
"""
This method defines the network layering and activation functions
"""
activation = nn.functional.relu
for i in range(len(self.layers)):
if i > 0:
x = activation(x)
x = self.layers[i](x)
return x
return self.layers(x)

@property
def ar_weights(self):