Skip to content

Commit

Permalink
feat(function_data): Add function_data classes from InfrastructureSys…
Browse files Browse the repository at this point in the history
…tems (#33)
  • Loading branch information
jerrypotts authored Jun 28, 2024
1 parent 9cc94f9 commit d17754e
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 0 deletions.
172 changes: 172 additions & 0 deletions src/infrasys/function_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""Defines models for cost functions"""

from infrasys import Component
from typing_extensions import Annotated
from pydantic import Field, model_validator
from pydantic.functional_validators import AfterValidator
from typing import NamedTuple, List
import numpy as np


class XYCoords(NamedTuple):
"""Named tuple used to define (x,y) coordinates."""

x: float
y: float


class LinearFunctionData(Component):
"""Data representation for linear cost function.
Class to represent the underlying data of linear functions. Principally used for
the representation of cost functions `f(x) = proportional_term*x + constant_term`.
"""

name: Annotated[str, Field(frozen=True)] = ""
proportional_term: Annotated[
float, Field(description="the proportional term in the represented function")
]
constant_term: Annotated[
float, Field(description="the constant term in the represented function")
]


class QuadraticFunctionData(Component):
"""Data representation for quadratic cost functions.
Class to represent the underlying data of quadratic functions. Principally used for the
representation of cost functions
`f(x) = quadratic_term*x^2 + proportional_term*x + constant_term`.
"""

name: Annotated[str, Field(frozen=True)] = ""
quadratic_term: Annotated[
float, Field(description="the quadratic term in the represented function")
]
proportional_term: Annotated[
float, Field(description="the proportional term in the represented function")
]
constant_term: Annotated[
float, Field(description="the constant term in the represented function")
]


def validate_piecewise_linear_x(points: List[XYCoords]) -> List[XYCoords]:
"""Validates the x data for PiecewiseLinearData class
Function used to validate given x data for the PiecewiseLinearData class.
X data is checked to ensure there is at least two values of x,
which is the minimum required to generate a cost curve, and is
given in ascending order (e.g. [1, 2, 3], not [1, 3, 2]).
Parameters
----------
points : List[XYCoords]
List of named tuples of (x,y) coordinates for cost function
Returns
----------
points : List[XYCoords]
List of (x,y) data for cost function after successful validation.
"""

x_coords = [p.x for p in points]

if len(x_coords) < 2:
raise ValueError("Must specify at least two x-coordinates")
if not (
x_coords == sorted(x_coords)
or (np.isnan(x_coords[0]) and x_coords[1:] == sorted(x_coords[1:]))
):
raise ValueError(f"Piecewise x-coordinates must be ascending, got {x_coords}")

return points


def validate_piecewise_step_x(x_coords: List[float]) -> List[float]:
"""Validates the x data for PiecewiseStepData class
Function used to validate given x data for the PiecewiseStepData class.
X data is checked to ensure there is at least two values of x,
which is the minimum required to generate a cost curve, and is
given in ascending order (e.g. [1, 2, 3], not [1, 3, 2]).
Parameters
----------
x_coords : List[float]
List of x data for cost function.
Returns
----------
x_coords : List[float]
List of x data for cost function after successful validation.
"""

if len(x_coords) < 2:
raise ValueError("Must specify at least two x-coordinates")
if not (
x_coords == sorted(x_coords)
or (np.isnan(x_coords[0]) and x_coords[1:] == sorted(x_coords[1:]))
):
raise ValueError(f"Piecewise x-coordinates must be ascending, got {x_coords}")

return x_coords


class PiecewiseLinearData(Component):
"""Data representation for piecewise linear cost function.
Class to represent piecewise linear data as a series of points: two points define one
segment, three points define two segments, etc. The curve starts at the first point given,
not the origin. Principally used for the representation of cost functions where the points
store quantities (x, y), such as (MW, USD/h).
"""

name: Annotated[str, Field(frozen=True)] = ""
points: Annotated[
List[XYCoords],
AfterValidator(validate_piecewise_linear_x),
Field(description="list of (x,y) points that define the function"),
]


class PiecewiseStepData(Component):
"""Data representation for piecewise step cost function.
Class to represent a step function as a series of endpoint x-coordinates and segment
y-coordinates: two x-coordinates and one y-coordinate defines a single segment, three
x-coordinates and two y-coordinates define two segments, etc. This can be useful to
represent the derivative of a `PiecewiseLinearData`, where the y-coordinates of this
step function represent the slopes of that piecewise linear function.
Principally used for the representation of cost functions where the points store
quantities (x, dy/dx), such as (MW, USD/MWh).
"""

name: Annotated[str, Field(frozen=True)] = ""
x_coords: Annotated[
List[float],
Field(description="the x-coordinates of the endpoints of the segments"),
]
y_coords: Annotated[
List[float],
Field(
description="the y-coordinates of the segments: `y_coords[1]` is the y-value \
between `x_coords[0]` and `x_coords[1]`, etc. Must have one fewer elements than `x_coords`."
),
]

@model_validator(mode="after")
def validate_piecewise_xy(self):
"""Method to validate the x and y data for PiecewiseStepData class
Model validator used to validate given data for the PiecewiseStepData class.
Calls `validate_piecewise_step_x` to check if `x_coords` is valid, then checks if
the length of `y_coords` is exactly one less than `x_coords`, which is necessary
to define the cost functions correctly.
"""
validate_piecewise_step_x(self.x_coords)

if len(self.y_coords) != len(self.x_coords) - 1:
raise ValueError("Must specify one fewer y-coordinates than x-coordinates")

return self
50 changes: 50 additions & 0 deletions tests/test_function_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from infrasys.function_data import PiecewiseStepData, PiecewiseLinearData, XYCoords
import pytest


def test_xycoords():
test_xy = XYCoords(x=1.0, y=2.0)

# Checking associated types
assert isinstance(test_xy, XYCoords)

assert isinstance(test_xy.x, float)

assert isinstance(test_xy.y, float)


def test_piecewise_linear():
# Check validation for minimum x values
test_coords = [XYCoords(1.0, 2.0)]

with pytest.raises(ValueError):
PiecewiseLinearData(points=test_coords)

# Check validation for ascending x values
test_coords = [XYCoords(1.0, 2.0), XYCoords(4.0, 3.0), XYCoords(3.0, 4.0)]

with pytest.raises(ValueError):
PiecewiseLinearData(points=test_coords)


def test_piecewise_step():
# Check minimum x values
test_x = [2]
test_y = [1]

with pytest.raises(ValueError):
PiecewiseStepData(x_coords=test_x, y_coords=test_y)

# Check ascending x values
test_x = [1, 4, 3]
test_y = [2, 4]

with pytest.raises(ValueError):
PiecewiseStepData(x_coords=test_x, y_coords=test_y)

# Check length of x and y lists
test_x = [1, 2, 3]
test_y = [2, 4, 3]

with pytest.raises(ValueError):
PiecewiseStepData(x_coords=test_x, y_coords=test_y)

0 comments on commit d17754e

Please sign in to comment.