001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math.stat.regression;
018
019 import org.apache.commons.math.MathRuntimeException;
020 import org.apache.commons.math.linear.RealMatrix;
021 import org.apache.commons.math.linear.Array2DRowRealMatrix;
022 import org.apache.commons.math.linear.RealVector;
023 import org.apache.commons.math.linear.ArrayRealVector;
024
025 /**
026 * Abstract base class for implementations of MultipleLinearRegression.
027 * @version $Revision: 811685 $ $Date: 2009-09-05 13:36:48 -0400 (Sat, 05 Sep 2009) $
028 * @since 2.0
029 */
030 public abstract class AbstractMultipleLinearRegression implements
031 MultipleLinearRegression {
032
033 /** X sample data. */
034 protected RealMatrix X;
035
036 /** Y sample data. */
037 protected RealVector Y;
038
039 /**
040 * Loads model x and y sample data from a flat array of data, overriding any previous sample.
041 * Assumes that rows are concatenated with y values first in each row.
042 *
043 * @param data input data array
044 * @param nobs number of observations (rows)
045 * @param nvars number of independent variables (columns, not counting y)
046 */
047 public void newSampleData(double[] data, int nobs, int nvars) {
048 double[] y = new double[nobs];
049 double[][] x = new double[nobs][nvars + 1];
050 int pointer = 0;
051 for (int i = 0; i < nobs; i++) {
052 y[i] = data[pointer++];
053 x[i][0] = 1.0d;
054 for (int j = 1; j < nvars + 1; j++) {
055 x[i][j] = data[pointer++];
056 }
057 }
058 this.X = new Array2DRowRealMatrix(x);
059 this.Y = new ArrayRealVector(y);
060 }
061
062 /**
063 * Loads new y sample data, overriding any previous sample
064 *
065 * @param y the [n,1] array representing the y sample
066 */
067 protected void newYSampleData(double[] y) {
068 this.Y = new ArrayRealVector(y);
069 }
070
071 /**
072 * Loads new x sample data, overriding any previous sample
073 *
074 * @param x the [n,k] array representing the x sample
075 */
076 protected void newXSampleData(double[][] x) {
077 this.X = new Array2DRowRealMatrix(x);
078 }
079
080 /**
081 * Validates sample data.
082 *
083 * @param x the [n,k] array representing the x sample
084 * @param y the [n,1] array representing the y sample
085 * @throws IllegalArgumentException if the x and y array data are not
086 * compatible for the regression
087 */
088 protected void validateSampleData(double[][] x, double[] y) {
089 if ((x == null) || (y == null) || (x.length != y.length)) {
090 throw MathRuntimeException.createIllegalArgumentException(
091 "dimension mismatch {0} != {1}",
092 (x == null) ? 0 : x.length,
093 (y == null) ? 0 : y.length);
094 } else if ((x.length > 0) && (x[0].length > x.length)) {
095 throw MathRuntimeException.createIllegalArgumentException(
096 "not enough data ({0} rows) for this many predictors ({1} predictors)",
097 x.length, x[0].length);
098 }
099 }
100
101 /**
102 * Validates sample data.
103 *
104 * @param x the [n,k] array representing the x sample
105 * @param covariance the [n,n] array representing the covariance matrix
106 * @throws IllegalArgumentException if the x sample data or covariance
107 * matrix are not compatible for the regression
108 */
109 protected void validateCovarianceData(double[][] x, double[][] covariance) {
110 if (x.length != covariance.length) {
111 throw MathRuntimeException.createIllegalArgumentException(
112 "dimension mismatch {0} != {1}", x.length, covariance.length);
113 }
114 if (covariance.length > 0 && covariance.length != covariance[0].length) {
115 throw MathRuntimeException.createIllegalArgumentException(
116 "a {0}x{1} matrix was provided instead of a square matrix",
117 covariance.length, covariance[0].length);
118 }
119 }
120
121 /**
122 * {@inheritDoc}
123 */
124 public double[] estimateRegressionParameters() {
125 RealVector b = calculateBeta();
126 return b.getData();
127 }
128
129 /**
130 * {@inheritDoc}
131 */
132 public double[] estimateResiduals() {
133 RealVector b = calculateBeta();
134 RealVector e = Y.subtract(X.operate(b));
135 return e.getData();
136 }
137
138 /**
139 * {@inheritDoc}
140 */
141 public double[][] estimateRegressionParametersVariance() {
142 return calculateBetaVariance().getData();
143 }
144
145 /**
146 * {@inheritDoc}
147 */
148 public double[] estimateRegressionParametersStandardErrors() {
149 double[][] betaVariance = estimateRegressionParametersVariance();
150 double sigma = calculateYVariance();
151 int length = betaVariance[0].length;
152 double[] result = new double[length];
153 for (int i = 0; i < length; i++) {
154 result[i] = Math.sqrt(sigma * betaVariance[i][i]);
155 }
156 return result;
157 }
158
159 /**
160 * {@inheritDoc}
161 */
162 public double estimateRegressandVariance() {
163 return calculateYVariance();
164 }
165
166 /**
167 * Calculates the beta of multiple linear regression in matrix notation.
168 *
169 * @return beta
170 */
171 protected abstract RealVector calculateBeta();
172
173 /**
174 * Calculates the beta variance of multiple linear regression in matrix
175 * notation.
176 *
177 * @return beta variance
178 */
179 protected abstract RealMatrix calculateBetaVariance();
180
181 /**
182 * Calculates the Y variance of multiple linear regression.
183 *
184 * @return Y variance
185 */
186 protected abstract double calculateYVariance();
187
188 /**
189 * Calculates the residuals of multiple linear regression in matrix
190 * notation.
191 *
192 * <pre>
193 * u = y - X * b
194 * </pre>
195 *
196 * @return The residuals [n,1] matrix
197 */
198 protected RealVector calculateResiduals() {
199 RealVector b = calculateBeta();
200 return Y.subtract(X.operate(b));
201 }
202
203 }