from numpy import array, dot
from math import sqrt
from pyspark import SparkContext
from pyspark.mllib._common import \
    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
    _serialize_double_matrix, _deserialize_double_matrix, \
    _serialize_double_vector, _deserialize_double_vector, \
    _get_initial_weights, _serialize_rating, _regression_train_wrapper
28      """A clustering model derived from the k-means method. 
29   
30      >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) 
31      >>> clusters = KMeans.train(sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random") 
32      >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0])) 
33      True 
34      >>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0])) 
35      True 
36      >>> clusters = KMeans.train(sc.parallelize(data), 2) 
37      """ 
    def __init__(self, centers_):
        self.centers = centers_

42          """Find the cluster to which x belongs in this model.""" 
43          best = 0 
44          best_distance = 1e75 
45          for i in range(0, self.centers.shape[0]): 
46              diff = x - self.centers[i] 
47              distance = sqrt(dot(diff, diff)) 
48              if distance < best_distance: 
49                  best = i 
50                  best_distance = distance 
51          return best 
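
    # Note: an equivalent vectorized form (illustrative only, not how the
    # method above is written) would compare squared distances directly,
    # since sqrt is monotone:
    #     distances = ((self.centers - x) ** 2).sum(axis=1)
    #     return int(distances.argmin())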

class KMeans(object):
    @classmethod
    def train(cls, data, k, maxIterations=100, runs=1,
              initializationMode="k-means||"):
        """Train a k-means clustering model."""
        sc = data.context
        # Serialize the input vectors, hand them to the JVM-side trainer, and
        # deserialize the returned matrix of cluster centers.
        dataBytes = _get_unmangled_double_vector_rdd(data)
        ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd,
                k, maxIterations, runs, initializationMode)
        if len(ans) != 1:
            raise RuntimeError("JVM call result had unexpected length")
        elif type(ans[0]) != bytearray:
            raise RuntimeError("JVM call result had first element of type "
                    + type(ans[0]).__name__ + " which is not bytearray")
        return KMeansModel(_deserialize_double_matrix(ans[0]))
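
# The helper below is an illustrative sketch, not part of the original
# module: given a live SparkContext, it trains a two-cluster model and
# returns the cluster index of a new point, mirroring the doctest above.
# The name `_example_usage` and the variables `points` and `model` are
# chosen here for illustration only.
def _example_usage(sc):
    points = sc.parallelize(
        array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2))
    model = KMeans.train(points, 2, maxIterations=10)
    return model.predict(array([0.5, 0.5]))   # index of the nearest center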

def _test():
    import doctest
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    (failure_count, test_count) = doctest.testmod(globs=globs,
            optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)

if __name__ == "__main__":
    _test()