1   
  2   
  3   
  4   
  5   
  6   
  7   
  8   
  9   
 10   
 11   
 12   
 13   
 14   
 15   
 16   
 17   
 18  from numpy import array, dot 
 19  from pyspark import SparkContext 
 20  from pyspark.mllib._common import \ 
 21      _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ 
 22      _serialize_double_matrix, _deserialize_double_matrix, \ 
 23      _serialize_double_vector, _deserialize_double_vector, \ 
 24      _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ 
 25      _linear_predictor_typecheck 
 28      """A model holding a vector of coefficients and an intercept.""" 
         # NOTE(review): the embedded line numbers jump 25 -> 28 -> 30, so the
         # class header and the __init__ signature appear to have been dropped
         # from this listing -- TODO restore them from the upstream file.
         # Keep both constructor arguments on private attributes; predict()
         # further down in this file reads self._coeff and self._intercept.
 30          self._coeff = coeff 
 31          self._intercept = intercept 
   32   
 34      """A linear regression model. 
 35   
 36      >>> lrmb = LinearRegressionModelBase(array([1.0, 2.0]), 0.1) 
 37      >>> abs(lrmb.predict(array([-1.03, 7.777])) - 14.624) < 1e-6 
 38      True 
 39      """ 
 41          """Predict the value of the dependent variable given a vector x""" 
 42          """containing values for the independent variables.""" 
 43          _linear_predictor_typecheck(x, self._coeff) 
 44          return dot(self._coeff, x) + self._intercept 
   45   
 47      """A linear regression model derived from a least-squares fit. 
 48   
 49      >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) 
 50      >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) 
 51      """ 
         # NOTE(review): the doctest above only checks that training runs
         # without raising -- an assignment produces no output, so nothing
         # about the fitted model is asserted; consider adding predict() checks.
         # The embedded line numbers also skip 46, presumably the class
         # header -- TODO restore.
  52   
 54      @classmethod 
 55 -    def train(cls, data, iterations=100, step=1.0, 
 56                miniBatchFraction=1.0, initialWeights=None): 
  57          """Train a linear regression model on the given data.""" 
 58          sc = data.context 
 59          return _regression_train_wrapper(sc, lambda d, i: 
 60                  sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( 
 61                          d._jrdd, iterations, step, miniBatchFraction, i), 
 62                  LinearRegressionModel, data, initialWeights) 
   63   
 65      """A linear regression model derived from a least-squares fit with an 
 66      l_1 penalty term. 
 67   
 68      >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) 
 69      >>> lrm = LassoWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) 
 70      """ 
         # NOTE(review): the doctest above only checks that training runs
         # without raising -- an assignment produces no output, so nothing
         # about the fitted model is asserted; consider adding predict() checks.
         # The embedded line numbers also skip 64, presumably the class
         # header -- TODO restore.
  71   
 73      @classmethod 
 74 -    def train(cls, data, iterations=100, step=1.0, regParam=1.0, 
 75                miniBatchFraction=1.0, initialWeights=None): 
  76          """Train a Lasso regression model on the given data.""" 
 77          sc = data.context 
 78          return _regression_train_wrapper(sc, lambda d, i: 
 79                  sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(d._jrdd, 
 80                          iterations, step, regParam, miniBatchFraction, i), 
 81                  LassoModel, data, initialWeights) 
   82   
 84      """A linear regression model derived from a least-squares fit with an 
 85      l_2 penalty term. 
 86   
 87      >>> data = array([0.0, 0.0, 1.0, 1.0, 3.0, 2.0, 2.0, 3.0]).reshape(4,2) 
 88      >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), initialWeights=array([1.0])) 
 89      """ 
         # NOTE(review): the doctest above only checks that training runs
         # without raising -- an assignment produces no output, so nothing
         # about the fitted model is asserted; consider adding predict() checks.
         # The embedded line numbers also skip 83, presumably the class
         # header -- TODO restore.
  90   
 92      @classmethod 
 93 -    def train(cls, data, iterations=100, step=1.0, regParam=1.0, 
 94                miniBatchFraction=1.0, initialWeights=None): 
  95          """Train a ridge regression model on the given data.""" 
 96          sc = data.context 
 97          return _regression_train_wrapper(sc, lambda d, i: 
 98                  sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(d._jrdd, 
 99                          iterations, step, regParam, miniBatchFraction, i), 
100                  RidgeRegressionModel, data, initialWeights) 
  101   
103      import doctest 
104      globs = globals().copy() 
105      globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) 
106      (failure_count, test_count) = doctest.testmod(globs=globs, 
107              optionflags=doctest.ELLIPSIS) 
108      globs['sc'].stop() 
109      if failure_count: 
110          exit(-1) 
 111   
112  if __name__ == "__main__": 
113      _test() 
114