001 /*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 * http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017 package org.apache.commons.math.stat.regression;
018
019 import org.apache.commons.math.linear.Array2DRowRealMatrix;
020 import org.apache.commons.math.linear.LUDecompositionImpl;
021 import org.apache.commons.math.linear.QRDecomposition;
022 import org.apache.commons.math.linear.QRDecompositionImpl;
023 import org.apache.commons.math.linear.RealMatrix;
024 import org.apache.commons.math.linear.RealVector;
025
026 /**
027 * <p>Implements ordinary least squares (OLS) to estimate the parameters of a
028 * multiple linear regression model.</p>
029 *
030 * <p>OLS assumes the covariance matrix of the error to be diagonal and with
031 * equal variance.</p>
032 * <p>
033 * u ~ N(0, σ<sup>2</sup>I)
034 * </p>
035 *
036 * <p>The regression coefficients, b, satisfy the normal equations:
037 * <p>
038 * X<sup>T</sup> X b = X<sup>T</sup> y
039 * </p>
040 *
041 * <p>To solve the normal equations, this implementation uses QR decomposition
042 * of the X matrix. (See {@link QRDecompositionImpl} for details on the
043 * decomposition algorithm.)
044 * </p>
045 * <p>X<sup>T</sup>X b = X<sup>T</sup> y <br/>
046 * (QR)<sup>T</sup> (QR) b = (QR)<sup>T</sup>y <br/>
047 * R<sup>T</sup> (Q<sup>T</sup>Q) R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
048 * R<sup>T</sup> R b = R<sup>T</sup> Q<sup>T</sup> y <br/>
049 * (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> R b = (R<sup>T</sup>)<sup>-1</sup> R<sup>T</sup> Q<sup>T</sup> y <br/>
050 * R b = Q<sup>T</sup> y
051 * </p>
052 * Given Q and R, the last equation is solved by back-substitution.</p>
053 *
054 * @version $Revision: 825925 $ $Date: 2009-10-16 11:11:47 -0400 (Fri, 16 Oct 2009) $
055 * @since 2.0
056 */
057 public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
058
059 /** Cached QR decomposition of X matrix */
060 private QRDecomposition qr = null;
061
062 /**
063 * Loads model x and y sample data, overriding any previous sample.
064 *
065 * Computes and caches QR decomposition of the X matrix.
066 * @param y the [n,1] array representing the y sample
067 * @param x the [n,k] array representing the x sample
068 * @throws IllegalArgumentException if the x and y array data are not
069 * compatible for the regression
070 */
071 public void newSampleData(double[] y, double[][] x) {
072 validateSampleData(x, y);
073 newYSampleData(y);
074 newXSampleData(x);
075 }
076
077 /**
078 * {@inheritDoc}
079 *
080 * Computes and caches QR decomposition of the X matrix
081 */
082 @Override
083 public void newSampleData(double[] data, int nobs, int nvars) {
084 super.newSampleData(data, nobs, nvars);
085 qr = new QRDecompositionImpl(X);
086 }
087
088 /**
089 * <p>Compute the "hat" matrix.
090 * </p>
091 * <p>The hat matrix is defined in terms of the design matrix X
092 * by X(X<sup>T</sup>X)<sup>-1</sup>X<sup>T</sup>
093 * </p>
094 * <p>The implementation here uses the QR decomposition to compute the
095 * hat matrix as Q I<sub>p</sub>Q<sup>T</sup> where I<sub>p</sub> is the
096 * p-dimensional identity matrix augmented by 0's. This computational
097 * formula is from "The Hat Matrix in Regression and ANOVA",
098 * David C. Hoaglin and Roy E. Welsch,
099 * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
100 *
101 * @return the hat matrix
102 */
103 public RealMatrix calculateHat() {
104 // Create augmented identity matrix
105 RealMatrix Q = qr.getQ();
106 final int p = qr.getR().getColumnDimension();
107 final int n = Q.getColumnDimension();
108 Array2DRowRealMatrix augI = new Array2DRowRealMatrix(n, n);
109 double[][] augIData = augI.getDataRef();
110 for (int i = 0; i < n; i++) {
111 for (int j =0; j < n; j++) {
112 if (i == j && i < p) {
113 augIData[i][j] = 1d;
114 } else {
115 augIData[i][j] = 0d;
116 }
117 }
118 }
119
120 // Compute and return Hat matrix
121 return Q.multiply(augI).multiply(Q.transpose());
122 }
123
124 /**
125 * Loads new x sample data, overriding any previous sample
126 *
127 * @param x the [n,k] array representing the x sample
128 */
129 @Override
130 protected void newXSampleData(double[][] x) {
131 this.X = new Array2DRowRealMatrix(x);
132 qr = new QRDecompositionImpl(X);
133 }
134
135 /**
136 * Calculates regression coefficients using OLS.
137 *
138 * @return beta
139 */
140 @Override
141 protected RealVector calculateBeta() {
142 return qr.getSolver().solve(Y);
143 }
144
145 /**
146 * <p>Calculates the variance on the beta by OLS.
147 * </p>
148 * <p>Var(b) = (X<sup>T</sup>X)<sup>-1</sup>
149 * </p>
150 * <p>Uses QR decomposition to reduce (X<sup>T</sup>X)<sup>-1</sup>
151 * to (R<sup>T</sup>R)<sup>-1</sup>, with only the top p rows of
152 * R included, where p = the length of the beta vector.</p>
153 *
154 * @return The beta variance
155 */
156 @Override
157 protected RealMatrix calculateBetaVariance() {
158 int p = X.getColumnDimension();
159 RealMatrix Raug = qr.getR().getSubMatrix(0, p - 1 , 0, p - 1);
160 RealMatrix Rinv = new LUDecompositionImpl(Raug).getSolver().getInverse();
161 return Rinv.multiply(Rinv.transpose());
162 }
163
164
165 /**
166 * <p>Calculates the variance on the Y by OLS.
167 * </p>
168 * <p> Var(y) = Tr(u<sup>T</sup>u)/(n - k)
169 * </p>
170 * @return The Y variance
171 */
172 @Override
173 protected double calculateYVariance() {
174 RealVector residuals = calculateResiduals();
175 return residuals.dotProduct(residuals) /
176 (X.getRowDimension() - X.getColumnDimension());
177 }
178
179 }