forked from peterdsharpe/AeroSandbox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
interpolate.py
302 lines (243 loc) · 10.1 KB
/
interpolate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import numpy as _onp
import casadi as _cas
from aerosandbox.numpy.determine_type import is_casadi_type
from aerosandbox.numpy.array import array, zeros_like
from aerosandbox.numpy.conditionals import where
from aerosandbox.numpy.logicals import all, any, logical_or
from typing import Tuple
from scipy import interpolate as _interpolate
def interp(x, xp, fp, left=None, right=None, period=None):
    """
    One-dimensional linear interpolation, analogous to numpy.interp().

    Returns the one-dimensional piecewise linear interpolant to a function with given discrete data points (xp, fp),
    evaluated at x.

    See syntax here: https://numpy.org/doc/stable/reference/generated/numpy.interp.html

    Specific notes:
        * xp is assumed to be sorted in increasing order.
        * If none of x, xp, fp are CasADi types, this dispatches directly to numpy.interp().
        * Otherwise only fp may be a CasADi type; x and xp must be numeric. A scalar x is treated as a
          length-1 vector, and the result is the CasADi output of casadi.interp1d() evaluated at x.
        * As in numpy.interp(), `left` defaults to fp[0] and `right` defaults to fp[-1], so the CasADi path
          clamps (rather than extrapolates) outside of [xp[0], xp[-1]].
        * With `period`, x is wrapped into [0, period) as numpy.interp() does, but on the CasADi path the
          wrap-around segment joining xp[-1] to xp[0] + period is not reproduced, and xp must lie in [0, period].

    Raises:
        NotImplementedError: If x or xp is a CasADi type, or if `period` is given and xp lies outside [0, period].
    """
    if not is_casadi_type([x, xp, fp], recursive=True):
        return _onp.interp(
            x=x,
            xp=xp,
            fp=fp,
            left=left,
            right=right,
            period=period
        )

    ### Only fp may be symbolic: casadi.interp1d() needs numeric sample locations.
    if is_casadi_type([x, xp], recursive=True):
        raise NotImplementedError(
            "Unfortunately, CasADi doesn't yet support a dispatch for x or xp as CasADi types."
        )

    ### Convert to float arrays up front, so that the comparisons below work for lists, ints, NumPy scalars
    ### and 0-d arrays alike (indexing a NumPy scalar raises IndexError, not TypeError).
    xp = _onp.asarray(xp, dtype=float)
    x = _onp.atleast_1d(_onp.asarray(x, dtype=float))

    ### Handle period argument
    if period is not None:
        if any(
                logical_or(
                    xp < 0,
                    xp > period
                )
        ):
            raise NotImplementedError(
                "Haven't yet implemented handling for if xp is outside the period.")  # Not easy to implement because casadi doesn't have a sort feature.
        # x is numeric here. NumPy's mod always lands in [0, period), even for negative x, matching numpy.interp();
        # CasADi's mod is fmod-based and would keep the sign of x.
        x = _onp.mod(x, period)

    ### Do the interpolation
    f = _cas.interp1d(
        xp,
        fp,
        x
    )

    ### Handle left/right. Default to the end values, as numpy.interp() does, instead of linearly extrapolating.
    if left is None:
        left = fp[0]
    if right is None:
        right = fp[-1]

    below = x < xp[0]
    above = x > xp[-1]
    # Only add the conditional nodes when some point actually needs them; x is numeric, so this is decidable here.
    if _onp.any(below):
        f = where(
            below,
            left,
            f
        )
    if _onp.any(above):
        f = where(
            above,
            right,
            f
        )

    return f
def is_data_structured(
        x_data_coordinates: Tuple[_onp.ndarray, ...],
        y_data_structured: _onp.ndarray
) -> bool:
    """
    Determines if the shapes of a given dataset are consistent with "structured" (i.e. gridded) data.

    For this to evaluate True, the inputs should be:

        x_data_coordinates: A tuple or list of 1D ndarrays that represent coordinates along each axis of an
            N-dimensional hypercube.

        y_data_structured: The values of some scalar defined on that N-dimensional hypercube, expressed as an
            N-dimensional array. In other words, y_data_structured is evaluated at
            `np.meshgrid(*x_data_coordinates, indexing="ij")`.

    Only shapes are inspected; the coordinate values themselves are not checked.

    Returns: Boolean of whether the above description is true. Inputs of the wrong kind (a non-iterable
    x_data_coordinates, or coordinates / y_data_structured without a `.shape`) yield False rather than raising.
    """
    try:
        # Materialize once: the coordinates are walked twice below, which would silently see an empty
        # sequence the second time if a one-shot iterable (e.g. a generator) were passed.
        coordinate_arrays = tuple(x_data_coordinates)

        for coordinates in coordinate_arrays:
            if len(coordinates.shape) != 1:
                return False

        implied_y_data_shape = tuple(len(coordinates) for coordinates in coordinate_arrays)
        if y_data_structured.shape != implied_y_data_shape:
            return False

    except TypeError:  # if x_data_coordinates is not iterable, for instance
        return False
    except AttributeError:  # if y_data_structured (or a coordinate array) has no shape, for instance
        return False

    return True
def interpn(
        points: Tuple[_onp.ndarray, ...],
        values: _onp.ndarray,
        xi: _onp.ndarray,
        method: str = "linear",
        bounds_error=True,
        fill_value=_onp.nan
) -> _onp.ndarray:
    """
    Performs multidimensional interpolation on regular grids. Analogue to scipy.interpolate.interpn().

    See syntax here: https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.interpn.html

    If no CasADi types are involved and method is "linear" or "nearest", this dispatches to
    scipy.interpolate.interpn(). Otherwise (any CasADi xi, or method="bspline") a casadi.interpolant is used.

    Args:

        points: The points defining the regular grid in n dimensions. Tuple of coordinates of each axis.

        values: The data on the regular grid in n dimensions. Its shape must be
            (len(points[0]), ..., len(points[-1])).

        xi: The coordinates to sample the gridded data at, with shape (n_points, n_dimensions). On the CasADi
            path, an (n_dimensions, n_points) input is transposed automatically, and a scalar or 1D input is
            treated as a column. The caller's xi is never modified.

        method: The method of interpolation to perform. one of:

            * "bspline" (Note: differentiable and suitable for optimization - made of piecewise-cubics. For other
            applications, other interpolators may be faster. Not monotonicity-preserving - may overshoot.)

            * "linear" (Note: differentiable, but not suitable for use in optimization w/o subgradient treatment due
            to C1-discontinuity)

            * "nearest" (Note: NOT differentiable, don't use in optimization. Fast.)

        bounds_error: If True, when interpolated values are requested outside of the domain of the input data,
            a ValueError is raised; as in SciPy, this takes precedence over fill_value. If False, then fill_value
            is used. Must be False when xi is a cas.MX.

        fill_value: If provided, the value to use for points outside of the interpolation domain. If None,
            values outside the domain are extrapolated on the SciPy path, and clamped to the nearest point of the
            domain on the CasADi path.

    Returns: Interpolated values at input coordinates. On the SciPy path, an ndarray with one entry per point. On
    the CasADi path, a float for a single numeric point, a flat ndarray for several numeric points, or the CasADi
    output (one row per point) for symbolic inputs.

    Raises:
        TypeError: If `points` or `values` contain CasADi types.
        ValueError: For malformed points / values / xi, an unknown method, or out-of-bounds xi with
            bounds_error=True (or any cas.MX xi with bounds_error=True).
    """
    ### Check input types for points and values
    if is_casadi_type([points, values], recursive=True):
        raise TypeError("The underlying dataset (points, values) must consist of NumPy arrays.")

    ### Check dimensions of points
    for points_axis in points:
        if _onp.ndim(points_axis) != 1:
            raise ValueError("`points` must consist of a tuple of 1D ndarrays defining the coordinates of each axis.")

    ### Check dimensions of values (accept any array-like, as SciPy does)
    values = _onp.asarray(values)
    implied_values_shape = tuple(len(points_axis) for points_axis in points)
    if values.shape != implied_values_shape:
        raise ValueError(
            f"The shape of `values` should be {implied_values_shape}, but it is {values.shape}."
        )

    ### NumPy / SciPy implementation
    if (
            not is_casadi_type([points, values, xi], recursive=True)
    ) and (
            method in ("linear", "nearest")
    ):
        xi = _onp.array(xi).reshape((-1, len(implied_values_shape)))
        return _interpolate.interpn(
            points=points,
            values=values,
            xi=xi,
            method=method,
            bounds_error=bounds_error,
            fill_value=fill_value
        )

    if method not in ("linear", "bspline"):
        raise ValueError("Bad value of `method`!")

    ### CasADi implementation
    n_dimensions = len(points)

    ### Work on a private copy of xi: the clamping below edits columns in place, and that must never leak back
    ### into the caller's ndarray (reshape/transpose return views) or CasADi object (__setitem__ mutates it).
    if is_casadi_type(xi, recursive=False):
        xi = type(xi)(xi)
    else:
        # Also promotes scalars, lists and integer arrays to a float ndarray.
        xi = _onp.array(xi, dtype=float)
        if xi.ndim != 2:
            xi = xi.reshape((-1, 1))

    ### Orient xi as (n_points, n_dimensions).
    if n_dimensions not in xi.shape:
        raise ValueError("`xi` must have the shape (n_points, n_dimensions)!")
    if xi.shape[1] != n_dimensions:
        xi = xi.T
    n_points = xi.shape[0]

    ### Calculate the minimum and maximum values along each axis.
    axis_values_min = [_onp.min(axis_values) for axis_values in points]
    axis_values_max = [_onp.max(axis_values) for axis_values in points]

    ### Check bounds_error before any clamping, so that it is honored even when fill_value is None (as in SciPy).
    if bounds_error:
        if isinstance(xi, _cas.MX):
            raise ValueError("Can't have the `bounds_error` flag as True if `xi` is of cas.MX type.")
        for axis in range(n_dimensions):
            if any(
                    logical_or(
                        xi[:, axis] > axis_values_max[axis],
                        xi[:, axis] < axis_values_min[axis]
                    )
            ):
                raise ValueError(
                    f"One of the requested xi is out of bounds in dimension {axis}"
                )

    ### If fill_value is None, project the xi back onto the nearest point in the domain.
    if fill_value is None:
        for axis in range(n_dimensions):
            xi[:, axis] = where(
                xi[:, axis] > axis_values_max[axis],
                axis_values_max[axis],
                xi[:, axis]
            )
            xi[:, axis] = where(
                xi[:, axis] < axis_values_min[axis],
                axis_values_min[axis],
                xi[:, axis]
            )

    ### Do the interpolation
    if method == "bspline" and all(values == 0):
        ### Patch a specific bug in CasADi that occurs when `values` is all zeros.
        ### For more information, see: https://github.com/casadi/casadi/issues/2837
        ### The exact interpolant is identically zero, so build one zero per point directly (the same layout as the
        ### interpolant's output) and still let the fill_value handling below apply.
        fi = _cas.DM.zeros(n_points, 1)
    else:
        values_flattened = _onp.ravel(values, order='F')
        interpolator = _cas.interpolant(
            'Interpolator',
            method,
            points,
            values_flattened
        )
        fi = interpolator(xi.T).T

    ### If fill_value is a scalar, replace all out-of-bounds xi with that value.
    if fill_value is not None:
        for axis in range(n_dimensions):
            fi = where(
                xi[:, axis] > axis_values_max[axis],
                fill_value,
                fi
            )
            fi = where(
                xi[:, axis] < axis_values_min[axis],
                fill_value,
                fi
            )

    ### If DM output (i.e. a numeric value), convert that back to an array
    if isinstance(fi, _cas.DM):
        if fi.shape == (1, 1):
            return float(fi)
        return _onp.array(fi, dtype=float).reshape(-1)

    return fi