001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math.analysis.interpolation;
018
019 import java.io.Serializable;
020 import java.util.Arrays;
021
022 import org.apache.commons.math.MathException;
023 import org.apache.commons.math.analysis.polynomials.PolynomialSplineFunction;
024
025 /**
026 * Implements the <a href="http://en.wikipedia.org/wiki/Local_regression">
027 * Local Regression Algorithm</a> (also Loess, Lowess) for interpolation of
028 * real univariate functions.
029 * <p/>
030 * For reference, see
031 * <a href="http://www.math.tau.ac.il/~yekutiel/MA seminar/Cleveland 1979.pdf">
032 * William S. Cleveland - Robust Locally Weighted Regression and Smoothing
033 * Scatterplots</a>
034 * <p/>
035 * This class implements both the loess method and serves as an interpolation
036 * adapter to it, allowing to build a spline on the obtained loess fit.
037 *
038 * @version $Revision: 925812 $ $Date: 2010-03-21 11:49:31 -0400 (Sun, 21 Mar 2010) $
039 * @since 2.0
040 */
041 public class LoessInterpolator
042 implements UnivariateRealInterpolator, Serializable {
043
044 /** Default value of the bandwidth parameter. */
045 public static final double DEFAULT_BANDWIDTH = 0.3;
046
047 /** Default value of the number of robustness iterations. */
048 public static final int DEFAULT_ROBUSTNESS_ITERS = 2;
049
050 /**
051 * Default value for accuracy.
052 * @since 2.1
053 */
054 public static final double DEFAULT_ACCURACY = 1e-12;
055
056 /** serializable version identifier. */
057 private static final long serialVersionUID = 5204927143605193821L;
058
059 /**
060 * The bandwidth parameter: when computing the loess fit at
061 * a particular point, this fraction of source points closest
062 * to the current point is taken into account for computing
063 * a least-squares regression.
064 * <p/>
065 * A sensible value is usually 0.25 to 0.5.
066 */
067 private final double bandwidth;
068
069 /**
070 * The number of robustness iterations parameter: this many
071 * robustness iterations are done.
072 * <p/>
073 * A sensible value is usually 0 (just the initial fit without any
074 * robustness iterations) to 4.
075 */
076 private final int robustnessIters;
077
078 /**
079 * If the median residual at a certain robustness iteration
080 * is less than this amount, no more iterations are done.
081 */
082 private final double accuracy;
083
084 /**
085 * Constructs a new {@link LoessInterpolator}
086 * with a bandwidth of {@link #DEFAULT_BANDWIDTH},
087 * {@link #DEFAULT_ROBUSTNESS_ITERS} robustness iterations
088 * and an accuracy of {#link #DEFAULT_ACCURACY}.
089 * See {@link #LoessInterpolator(double, int, double)} for an explanation of
090 * the parameters.
091 */
092 public LoessInterpolator() {
093 this.bandwidth = DEFAULT_BANDWIDTH;
094 this.robustnessIters = DEFAULT_ROBUSTNESS_ITERS;
095 this.accuracy = DEFAULT_ACCURACY;
096 }
097
098 /**
099 * Constructs a new {@link LoessInterpolator}
100 * with given bandwidth and number of robustness iterations.
101 * <p>
102 * Calling this constructor is equivalent to calling {link {@link
103 * #LoessInterpolator(double, int, double) LoessInterpolator(bandwidth,
104 * robustnessIters, LoessInterpolator.DEFAULT_ACCURACY)}
105 * </p>
106 *
107 * @param bandwidth when computing the loess fit at
108 * a particular point, this fraction of source points closest
109 * to the current point is taken into account for computing
110 * a least-squares regression.</br>
111 * A sensible value is usually 0.25 to 0.5, the default value is
112 * {@link #DEFAULT_BANDWIDTH}.
113 * @param robustnessIters This many robustness iterations are done.</br>
114 * A sensible value is usually 0 (just the initial fit without any
115 * robustness iterations) to 4, the default value is
116 * {@link #DEFAULT_ROBUSTNESS_ITERS}.
117 * @throws MathException if bandwidth does not lie in the interval [0,1]
118 * or if robustnessIters is negative.
119 * @see #LoessInterpolator(double, int, double)
120 */
121 public LoessInterpolator(double bandwidth, int robustnessIters) throws MathException {
122 this(bandwidth, robustnessIters, DEFAULT_ACCURACY);
123 }
124
125 /**
126 * Constructs a new {@link LoessInterpolator}
127 * with given bandwidth, number of robustness iterations and accuracy.
128 *
129 * @param bandwidth when computing the loess fit at
130 * a particular point, this fraction of source points closest
131 * to the current point is taken into account for computing
132 * a least-squares regression.</br>
133 * A sensible value is usually 0.25 to 0.5, the default value is
134 * {@link #DEFAULT_BANDWIDTH}.
135 * @param robustnessIters This many robustness iterations are done.</br>
136 * A sensible value is usually 0 (just the initial fit without any
137 * robustness iterations) to 4, the default value is
138 * {@link #DEFAULT_ROBUSTNESS_ITERS}.
139 * @param accuracy If the median residual at a certain robustness iteration
140 * is less than this amount, no more iterations are done.
141 * @throws MathException if bandwidth does not lie in the interval [0,1]
142 * or if robustnessIters is negative.
143 * @see #LoessInterpolator(double, int)
144 * @since 2.1
145 */
146 public LoessInterpolator(double bandwidth, int robustnessIters, double accuracy) throws MathException {
147 if (bandwidth < 0 || bandwidth > 1) {
148 throw new MathException("bandwidth must be in the interval [0,1], but got {0}",
149 bandwidth);
150 }
151 this.bandwidth = bandwidth;
152 if (robustnessIters < 0) {
153 throw new MathException("the number of robustness iterations must " +
154 "be non-negative, but got {0}",
155 robustnessIters);
156 }
157 this.robustnessIters = robustnessIters;
158 this.accuracy = accuracy;
159 }
160
161 /**
162 * Compute an interpolating function by performing a loess fit
163 * on the data at the original abscissae and then building a cubic spline
164 * with a
165 * {@link org.apache.commons.math.analysis.interpolation.SplineInterpolator}
166 * on the resulting fit.
167 *
168 * @param xval the arguments for the interpolation points
169 * @param yval the values for the interpolation points
170 * @return A cubic spline built upon a loess fit to the data at the original abscissae
171 * @throws MathException if some of the following conditions are false:
172 * <ul>
173 * <li> Arguments and values are of the same size that is greater than zero</li>
174 * <li> The arguments are in a strictly increasing order</li>
175 * <li> All arguments and values are finite real numbers</li>
176 * </ul>
177 */
178 public final PolynomialSplineFunction interpolate(
179 final double[] xval, final double[] yval) throws MathException {
180 return new SplineInterpolator().interpolate(xval, smooth(xval, yval));
181 }
182
183 /**
184 * Compute a weighted loess fit on the data at the original abscissae.
185 *
186 * @param xval the arguments for the interpolation points
187 * @param yval the values for the interpolation points
188 * @param weights point weights: coefficients by which the robustness weight of a point is multiplied
189 * @return values of the loess fit at corresponding original abscissae
190 * @throws MathException if some of the following conditions are false:
191 * <ul>
192 * <li> Arguments and values are of the same size that is greater than zero</li>
193 * <li> The arguments are in a strictly increasing order</li>
194 * <li> All arguments and values are finite real numbers</li>
195 * </ul>
196 * @since 2.1
197 */
198 public final double[] smooth(final double[] xval, final double[] yval, final double[] weights)
199 throws MathException {
200 if (xval.length != yval.length) {
201 throw new MathException(
202 "Loess expects the abscissa and ordinate arrays " +
203 "to be of the same size, " +
204 "but got {0} abscissae and {1} ordinatae",
205 xval.length, yval.length);
206 }
207
208 final int n = xval.length;
209
210 if (n == 0) {
211 throw new MathException("Loess expects at least 1 point");
212 }
213
214 checkAllFiniteReal(xval, "all abscissae must be finite real numbers, but {0}-th is {1}");
215 checkAllFiniteReal(yval, "all ordinatae must be finite real numbers, but {0}-th is {1}");
216 checkAllFiniteReal(weights, "all weights must be finite real numbers, but {0}-th is {1}");
217
218 checkStrictlyIncreasing(xval);
219
220 if (n == 1) {
221 return new double[]{yval[0]};
222 }
223
224 if (n == 2) {
225 return new double[]{yval[0], yval[1]};
226 }
227
228 int bandwidthInPoints = (int) (bandwidth * n);
229
230 if (bandwidthInPoints < 2) {
231 throw new MathException(
232 "the bandwidth must be large enough to " +
233 "accomodate at least 2 points. There are {0} " +
234 " data points, and bandwidth must be at least {1} " +
235 " but it is only {2}",
236 n, 2.0 / n, bandwidth);
237 }
238
239 final double[] res = new double[n];
240
241 final double[] residuals = new double[n];
242 final double[] sortedResiduals = new double[n];
243
244 final double[] robustnessWeights = new double[n];
245
246 // Do an initial fit and 'robustnessIters' robustness iterations.
247 // This is equivalent to doing 'robustnessIters+1' robustness iterations
248 // starting with all robustness weights set to 1.
249 Arrays.fill(robustnessWeights, 1);
250
251 for (int iter = 0; iter <= robustnessIters; ++iter) {
252 final int[] bandwidthInterval = {0, bandwidthInPoints - 1};
253 // At each x, compute a local weighted linear regression
254 for (int i = 0; i < n; ++i) {
255 final double x = xval[i];
256
257 // Find out the interval of source points on which
258 // a regression is to be made.
259 if (i > 0) {
260 updateBandwidthInterval(xval, weights, i, bandwidthInterval);
261 }
262
263 final int ileft = bandwidthInterval[0];
264 final int iright = bandwidthInterval[1];
265
266 // Compute the point of the bandwidth interval that is
267 // farthest from x
268 final int edge;
269 if (xval[i] - xval[ileft] > xval[iright] - xval[i]) {
270 edge = ileft;
271 } else {
272 edge = iright;
273 }
274
275 // Compute a least-squares linear fit weighted by
276 // the product of robustness weights and the tricube
277 // weight function.
278 // See http://en.wikipedia.org/wiki/Linear_regression
279 // (section "Univariate linear case")
280 // and http://en.wikipedia.org/wiki/Weighted_least_squares
281 // (section "Weighted least squares")
282 double sumWeights = 0;
283 double sumX = 0;
284 double sumXSquared = 0;
285 double sumY = 0;
286 double sumXY = 0;
287 double denom = Math.abs(1.0 / (xval[edge] - x));
288 for (int k = ileft; k <= iright; ++k) {
289 final double xk = xval[k];
290 final double yk = yval[k];
291 final double dist = (k < i) ? x - xk : xk - x;
292 final double w = tricube(dist * denom) * robustnessWeights[k] * weights[k];
293 final double xkw = xk * w;
294 sumWeights += w;
295 sumX += xkw;
296 sumXSquared += xk * xkw;
297 sumY += yk * w;
298 sumXY += yk * xkw;
299 }
300
301 final double meanX = sumX / sumWeights;
302 final double meanY = sumY / sumWeights;
303 final double meanXY = sumXY / sumWeights;
304 final double meanXSquared = sumXSquared / sumWeights;
305
306 final double beta;
307 if (Math.sqrt(Math.abs(meanXSquared - meanX * meanX)) < accuracy) {
308 beta = 0;
309 } else {
310 beta = (meanXY - meanX * meanY) / (meanXSquared - meanX * meanX);
311 }
312
313 final double alpha = meanY - beta * meanX;
314
315 res[i] = beta * x + alpha;
316 residuals[i] = Math.abs(yval[i] - res[i]);
317 }
318
319 // No need to recompute the robustness weights at the last
320 // iteration, they won't be needed anymore
321 if (iter == robustnessIters) {
322 break;
323 }
324
325 // Recompute the robustness weights.
326
327 // Find the median residual.
328 // An arraycopy and a sort are completely tractable here,
329 // because the preceding loop is a lot more expensive
330 System.arraycopy(residuals, 0, sortedResiduals, 0, n);
331 Arrays.sort(sortedResiduals);
332 final double medianResidual = sortedResiduals[n / 2];
333
334 if (Math.abs(medianResidual) < accuracy) {
335 break;
336 }
337
338 for (int i = 0; i < n; ++i) {
339 final double arg = residuals[i] / (6 * medianResidual);
340 if (arg >= 1) {
341 robustnessWeights[i] = 0;
342 } else {
343 final double w = 1 - arg * arg;
344 robustnessWeights[i] = w * w;
345 }
346 }
347 }
348
349 return res;
350 }
351
352 /**
353 * Compute a loess fit on the data at the original abscissae.
354 *
355 * @param xval the arguments for the interpolation points
356 * @param yval the values for the interpolation points
357 * @return values of the loess fit at corresponding original abscissae
358 * @throws MathException if some of the following conditions are false:
359 * <ul>
360 * <li> Arguments and values are of the same size that is greater than zero</li>
361 * <li> The arguments are in a strictly increasing order</li>
362 * <li> All arguments and values are finite real numbers</li>
363 * </ul>
364 */
365 public final double[] smooth(final double[] xval, final double[] yval)
366 throws MathException {
367 if (xval.length != yval.length) {
368 throw new MathException(
369 "Loess expects the abscissa and ordinate arrays " +
370 "to be of the same size, " +
371 "but got {0} abscissae and {1} ordinatae",
372 xval.length, yval.length);
373 }
374
375 final double[] unitWeights = new double[xval.length];
376 Arrays.fill(unitWeights, 1.0);
377
378 return smooth(xval, yval, unitWeights);
379 }
380
381 /**
382 * Given an index interval into xval that embraces a certain number of
383 * points closest to xval[i-1], update the interval so that it embraces
384 * the same number of points closest to xval[i], ignoring zero weights.
385 *
386 * @param xval arguments array
387 * @param weights weights array
388 * @param i the index around which the new interval should be computed
389 * @param bandwidthInterval a two-element array {left, right} such that: <p/>
390 * <tt>(left==0 or xval[i] - xval[left-1] > xval[right] - xval[i])</tt>
391 * <p/> and also <p/>
392 * <tt>(right==xval.length-1 or xval[right+1] - xval[i] > xval[i] - xval[left])</tt>.
393 * The array will be updated.
394 */
395 private static void updateBandwidthInterval(final double[] xval, final double[] weights,
396 final int i,
397 final int[] bandwidthInterval) {
398 final int left = bandwidthInterval[0];
399 final int right = bandwidthInterval[1];
400
401 // The right edge should be adjusted if the next point to the right
402 // is closer to xval[i] than the leftmost point of the current interval
403 int nextRight = nextNonzero(weights, right);
404 if (nextRight < xval.length && xval[nextRight] - xval[i] < xval[i] - xval[left]) {
405 int nextLeft = nextNonzero(weights, bandwidthInterval[0]);
406 bandwidthInterval[0] = nextLeft;
407 bandwidthInterval[1] = nextRight;
408 }
409 }
410
411 /**
412 * Returns the smallest index j such that j > i && (j==weights.length || weights[j] != 0)
413 * @param weights weights array
414 * @param i the index from which to start search; must be < weights.length
415 * @return the smallest index j such that j > i && (j==weights.length || weights[j] != 0)
416 */
417 private static int nextNonzero(final double[] weights, final int i) {
418 int j = i + 1;
419 while(j < weights.length && weights[j] == 0) {
420 j++;
421 }
422 return j;
423 }
424
425 /**
426 * Compute the
427 * <a href="http://en.wikipedia.org/wiki/Local_regression#Weight_function">tricube</a>
428 * weight function
429 *
430 * @param x the argument
431 * @return (1-|x|^3)^3
432 */
433 private static double tricube(final double x) {
434 final double tmp = 1 - x * x * x;
435 return tmp * tmp * tmp;
436 }
437
438 /**
439 * Check that all elements of an array are finite real numbers.
440 *
441 * @param values the values array
442 * @param pattern pattern of the error message
443 * @throws MathException if one of the values is not a finite real number
444 */
445 private static void checkAllFiniteReal(final double[] values, final String pattern)
446 throws MathException {
447 for (int i = 0; i < values.length; i++) {
448 final double x = values[i];
449 if (Double.isInfinite(x) || Double.isNaN(x)) {
450 throw new MathException(pattern, i, x);
451 }
452 }
453 }
454
455 /**
456 * Check that elements of the abscissae array are in a strictly
457 * increasing order.
458 *
459 * @param xval the abscissae array
460 * @throws MathException if the abscissae array
461 * is not in a strictly increasing order
462 */
463 private static void checkStrictlyIncreasing(final double[] xval)
464 throws MathException {
465 for (int i = 0; i < xval.length; ++i) {
466 if (i >= 1 && xval[i - 1] >= xval[i]) {
467 throw new MathException(
468 "the abscissae array must be sorted in a strictly " +
469 "increasing order, but the {0}-th element is {1} " +
470 "whereas {2}-th is {3}",
471 i - 1, xval[i - 1], i, xval[i]);
472 }
473 }
474 }
475 }