import os
import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter


class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
    """

    _gateway = None
    _jvm = None
    _writeToFile = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()
    _python_includes = None  # zip and egg files that need to be added to PYTHONPATH

    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
        environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None):
        """
        Create a new SparkContext. At least the master and app name should be set,
        either through the named parameters here or through C{conf}.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param appName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH.  These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object.  Set 1 to disable batching or -1 to use an
               unlimited batch size.
        @param serializer: The serializer for RDDs.
        @param conf: A L{SparkConf} object setting Spark properties.

        >>> from pyspark.context import SparkContext
        >>> sc = SparkContext('local', 'test')

        >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...
        """
        SparkContext._ensure_initialized(self)

        self.environment = environment or {}
        self._conf = conf or SparkConf(_jvm=self._jvm)
        self._batchSize = batchSize  # -1 represents an unlimited batch size
        self._unbatched_serializer = serializer
        if batchSize == 1:
            self.serializer = self._unbatched_serializer
        else:
            self.serializer = BatchedSerializer(self._unbatched_serializer,
                                                batchSize)

        # Set any parameters passed directly to us on the conf
        if master:
            self._conf.setMaster(master)
        if appName:
            self._conf.setAppName(appName)
        if sparkHome:
            self._conf.setSparkHome(sparkHome)
        if environment:
            for key, value in environment.iteritems():
                self._conf.setExecutorEnv(key, value)

        # Check that we have at least the required parameters
        if not self._conf.contains("spark.master"):
            raise Exception("A master URL must be set in your configuration")
        if not self._conf.contains("spark.app.name"):
            raise Exception("An application name must be set in your configuration")

        # Read back our properties from the conf, in case some of them were
        # loaded from the classpath or an external config file
        self.master = self._conf.get("spark.master")
        self.appName = self._conf.get("spark.app.name")
        self.sparkHome = self._conf.get("spark.home", None)
        for (k, v) in self._conf.getAll():
            if k.startswith("spark.executorEnv."):
                varName = k[len("spark.executorEnv."):]
                self.environment[varName] = v

        # Create the Java SparkContext through Py4J
        self._jsc = self._jvm.JavaSparkContext(self._conf._jconf)

        # Create a single Accumulator in Java that we'll send all our updates
        # through; they will be passed back to us through a TCP server
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
                self._jvm.java.util.ArrayList(),
                self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')

        # Broadcast's __reduce__ method stores Broadcast instances here.
        # This allows other code to determine which Broadcast instances have
        # been pickled, so it can determine which Java broadcast objects to
        # send.
        self._pickled_broadcast_vars = set()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
        sys.path.append(root_dir)

        # Deploy any code dependencies specified in the constructor
        self._python_includes = list()
        for path in (pyFiles or []):
            self.addPyFile(path)

        # Create a temporary directory inside spark.local.dir for this context
        local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
        self._temp_dir = \
            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

    @classmethod
    def _ensure_initialized(cls, instance=None):
        """
        Ensure the Py4J gateway to the JVM has been launched and, if an
        instance is given, register it as the single active SparkContext.
        """
        with SparkContext._lock:
            if not SparkContext._gateway:
                SparkContext._gateway = launch_gateway()
                SparkContext._jvm = SparkContext._gateway.jvm
                SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile

            if instance:
                if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
                    raise ValueError("Cannot run multiple SparkContexts at once")
                else:
                    SparkContext._active_spark_context = instance

    @classmethod
    def setSystemProperty(cls, key, value):
        """
        Set a Java system property, such as spark.executor.memory. This must
        be invoked before instantiating SparkContext.
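
        For example (an illustrative call; any valid Spark property key may
        be used)::

            SparkContext.setSystemProperty("spark.executor.memory", "2g")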
        """
        SparkContext._ensure_initialized()
        SparkContext._jvm.java.lang.System.setProperty(key, value)

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
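
        A small illustrative check (with the C{local[4]} master used by this
        module's doctests the value is simply positive):

        >>> sc.defaultParallelism > 0
        True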
        """
        return self._jsc.sc().defaultParallelism()

    def stop(self):
        """
        Shut down the SparkContext.
        """
        if self._jsc:
            self._jsc.stop()
            self._jsc = None
        if self._accumulatorServer:
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            SparkContext._active_spark_context = None

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands.  As an alternative, serialized
        # objects are written to a file and loaded back through readRDDFromFile().
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        # Make sure we distribute data evenly if it's smaller than self.batchSize
        if "__len__" not in dir(c):
            c = list(c)    # Make it a list so we can compute its length
        batchSize = min(len(c) // numSlices, self._batchSize)
        if batchSize > 1:
            serializer = BatchedSerializer(self._unbatched_serializer,
                                           batchSize)
        else:
            serializer = self._unbatched_serializer
        serializer.dump_stream(c, tempFile)
        tempFile.close()
        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self, serializer)

    def textFile(self, name, minSplits=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.
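
        A small illustrative example (using the C{tempdir} created for this
        module's doctests):

        >>> path = os.path.join(tempdir, "sample-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello world!")
        >>> sc.textFile(path).collect()
        [u'Hello world!']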
        """
        minSplits = minSplits or min(self.defaultParallelism, 2)
        return RDD(self._jsc.textFile(name, minSplits), self,
                   UTF8Deserializer())

    def _checkpointFile(self, name, input_deserializer):
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self, input_deserializer)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.

        This supports unions of RDDs with different serialized formats,
        although this forces them to be reserialized using the default
        serializer:

        >>> path = os.path.join(tempdir, "union-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello")
        >>> textFile = sc.textFile(path)
        >>> textFile.collect()
        [u'Hello']
        >>> parallelized = sc.parallelize(["World!"])
        >>> sorted(sc.union([textFile, parallelized]).collect())
        [u'Hello', 'World!']
        """
        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
            rdds = [x._reserialize() for x in rdds]
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self,
                   rdds[0]._jrdd_deserializer)

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a
        L{Broadcast<pyspark.broadcast.Broadcast>}
        object for reading it in distributed functions. The variable will be
        sent to each cluster node only once.
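
        A minimal illustrative example:

        >>> b = sc.broadcast([1, 2, 3, 4, 5])
        >>> b.value
        [1, 2, 3, 4, 5]
        >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
        [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]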
        """
        pickleSer = PickleSerializer()
        pickled = pickleSer.dumps(value)
        jbroadcast = self._jsc.broadcast(bytearray(pickled))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers
        and floating-point numbers if you do not provide one. For other types,
        a custom AccumulatorParam can be used.
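
        A short illustrative example (accumulators also support C{+=} on the
        driver):

        >>> a = sc.accumulator(1)
        >>> a.value
        1
        >>> a += 4
        >>> a.value
        5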
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * fileVal for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future.  The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
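
        For example (an illustrative sketch; C{helpers.py} is a hypothetical
        module shipped alongside the job), C{sc.addPyFile("helpers.py")} makes
        C{import helpers} work inside functions that run on the executors.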
        """
        self.addFile(path)
        (dirname, filename) = os.path.split(path)  # dirname may be directory or HDFS/S3 prefix

        if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
            self._python_includes.append(filename)
            # also put the dependency on the driver's sys.path
            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))

    def setCheckpointDir(self, dirName):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be an HDFS path if running on a cluster.
        """
        self._jsc.sc().setCheckpointDir(dirName)

    def _getJavaStorageLevel(self, storageLevel):
        """
        Returns a Java StorageLevel based on a pyspark.StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise Exception("storageLevel must be of type pyspark.StorageLevel")

        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
        return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory,
            storageLevel.deserialized, storageLevel.replication)


def _test():
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    (failure_count, test_count) = doctest.testmod(globs=globs)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()