2020 basestring = str
2121
2222from pyspark import SparkContext
23- from pyspark .sql import Column , DataFrame , SQLContext
23+ from pyspark .sql import Column , DataFrame , SparkSession
2424from pyspark .storagelevel import StorageLevel
2525
2626from graphframes .lib import Pregel
2727
2828
def _from_java_gf(jgf, spark):
    """
    (internal) Creates a Python GraphFrame wrapper from a Java GraphFrame.

    :param jgf: py4j proxy object for the JVM-side GraphFrame.
    :param spark: active :class:`SparkSession` used to wrap the Java
        DataFrames returned by ``jgf.vertices()`` and ``jgf.edges()``.
    :return: :class:`GraphFrame` built from the wrapped vertex and edge
        DataFrames.
    """
    # Wrap the JVM DataFrames in Python DataFrames bound to this session,
    # then rebuild the graph on the Python side.
    pv = DataFrame(jgf.vertices(), spark)
    pe = DataFrame(jgf.edges(), spark)
    return GraphFrame(pv, pe)
3838
3939def _java_api (jsc ):
@@ -55,16 +55,16 @@ class GraphFrame(object):
5555
5656 >>> localVertices = [(1,"A"), (2,"B"), (3, "C")]
5757 >>> localEdges = [(1,2,"love"), (2,1,"hate"), (2,3,"follow")]
58- >>> v = sqlContext .createDataFrame(localVertices, ["id", "name"])
59- >>> e = sqlContext .createDataFrame(localEdges, ["src", "dst", "action"])
58+ >>> v = spark .createDataFrame(localVertices, ["id", "name"])
59+ >>> e = spark .createDataFrame(localEdges, ["src", "dst", "action"])
6060 >>> g = GraphFrame(v, e)
6161 """
6262
6363 def __init__ (self , v , e ):
6464 self ._vertices = v
6565 self ._edges = e
66- self ._sqlContext = v . sql_ctx
67- self ._sc = self ._sqlContext ._sc
66+ self ._spark = SparkSession . getActiveSession ()
67+ self ._sc = self ._spark ._sc
6868 self ._jvm_gf_api = _java_api (self ._sc )
6969
7070 self .ID = self ._jvm_gf_api .ID ()
@@ -142,7 +142,7 @@ def outDegrees(self):
142142 :return: DataFrame with new vertices column "outDegree"
143143 """
144144 jdf = self ._jvm_graph .outDegrees ()
145- return DataFrame (jdf , self ._sqlContext )
145+ return DataFrame (jdf , self ._spark )
146146
147147 @property
148148 def inDegrees (self ):
@@ -156,7 +156,7 @@ def inDegrees(self):
156156 :return: DataFrame with new vertices column "inDegree"
157157 """
158158 jdf = self ._jvm_graph .inDegrees ()
159- return DataFrame (jdf , self ._sqlContext )
159+ return DataFrame (jdf , self ._spark )
160160
161161 @property
162162 def degrees (self ):
@@ -170,7 +170,7 @@ def degrees(self):
170170 :return: DataFrame with new vertices column "degree"
171171 """
172172 jdf = self ._jvm_graph .degrees ()
173- return DataFrame (jdf , self ._sqlContext )
173+ return DataFrame (jdf , self ._spark )
174174
175175 @property
176176 def triplets (self ):
@@ -185,7 +185,7 @@ def triplets(self):
185185 :return: DataFrame with columns 'src', 'edge', and 'dst'
186186 """
187187 jdf = self ._jvm_graph .triplets ()
188- return DataFrame (jdf , self ._sqlContext )
188+ return DataFrame (jdf , self ._spark )
189189
190190 @property
191191 def pregel (self ):
@@ -206,7 +206,7 @@ def find(self, pattern):
206206 :return: DataFrame with one Row for each instance of the motif found
207207 """
208208 jdf = self ._jvm_graph .find (pattern )
209- return DataFrame (jdf , self ._sqlContext )
209+ return DataFrame (jdf , self ._spark )
210210
211211 def filterVertices (self , condition ):
212212 """
@@ -222,7 +222,7 @@ def filterVertices(self, condition):
222222 jdf = self ._jvm_graph .filterVertices (condition ._jc )
223223 else :
224224 raise TypeError ("condition should be string or Column" )
225- return _from_java_gf (jdf , self ._sqlContext )
225+ return _from_java_gf (jdf , self ._spark )
226226
227227 def filterEdges (self , condition ):
228228 """
@@ -237,7 +237,7 @@ def filterEdges(self, condition):
237237 jdf = self ._jvm_graph .filterEdges (condition ._jc )
238238 else :
239239 raise TypeError ("condition should be string or Column" )
240- return _from_java_gf (jdf , self ._sqlContext )
240+ return _from_java_gf (jdf , self ._spark )
241241
242242 def dropIsolatedVertices (self ):
243243 """
@@ -246,7 +246,7 @@ def dropIsolatedVertices(self):
246246 :return: GraphFrame with filtered vertices.
247247 """
248248 jdf = self ._jvm_graph .dropIsolatedVertices ()
249- return _from_java_gf (jdf , self ._sqlContext )
249+ return _from_java_gf (jdf , self ._spark )
250250
251251 def bfs (self , fromExpr , toExpr , edgeFilter = None , maxPathLength = 10 ):
252252 """
@@ -263,7 +263,7 @@ def bfs(self, fromExpr, toExpr, edgeFilter=None, maxPathLength=10):
263263 if edgeFilter is not None :
264264 builder .edgeFilter (edgeFilter )
265265 jdf = builder .run ()
266- return DataFrame (jdf , self ._sqlContext )
266+ return DataFrame (jdf , self ._spark )
267267
268268 def aggregateMessages (self , aggCol , sendToSrc = None , sendToDst = None ):
269269 """
@@ -305,7 +305,7 @@ def aggregateMessages(self, aggCol, sendToSrc=None, sendToDst=None):
305305 jdf = builder .agg (aggCol ._jc )
306306 else :
307307 jdf = builder .agg (aggCol )
308- return DataFrame (jdf , self ._sqlContext )
308+ return DataFrame (jdf , self ._spark )
309309
310310 # Standard algorithms
311311
@@ -329,7 +329,7 @@ def connectedComponents(self, algorithm = "graphframes", checkpointInterval = 2,
329329 .setCheckpointInterval (checkpointInterval ) \
330330 .setBroadcastThreshold (broadcastThreshold ) \
331331 .run ()
332- return DataFrame (jdf , self ._sqlContext )
332+ return DataFrame (jdf , self ._spark )
333333
334334 def labelPropagation (self , maxIter ):
335335 """
@@ -341,7 +341,7 @@ def labelPropagation(self, maxIter):
341341 :return: DataFrame with new vertices column "label"
342342 """
343343 jdf = self ._jvm_graph .labelPropagation ().maxIter (maxIter ).run ()
344- return DataFrame (jdf , self ._sqlContext )
344+ return DataFrame (jdf , self ._spark )
345345
346346 def pageRank (self , resetProbability = 0.15 , sourceId = None , maxIter = None ,
347347 tol = None ):
@@ -369,7 +369,7 @@ def pageRank(self, resetProbability = 0.15, sourceId = None, maxIter = None,
369369 assert tol is not None , "Exactly one of maxIter or tol should be set."
370370 builder = builder .tol (tol )
371371 jgf = builder .run ()
372- return _from_java_gf (jgf , self ._sqlContext )
372+ return _from_java_gf (jgf , self ._spark )
373373
374374 def parallelPersonalizedPageRank (self , resetProbability = 0.15 , sourceIds = None ,
375375 maxIter = None ):
@@ -392,7 +392,7 @@ def parallelPersonalizedPageRank(self, resetProbability = 0.15, sourceIds = None
392392 builder = builder .sourceIds (sourceIds )
393393 builder = builder .maxIter (maxIter )
394394 jgf = builder .run ()
395- return _from_java_gf (jgf , self ._sqlContext )
395+ return _from_java_gf (jgf , self ._spark )
396396
397397 def shortestPaths (self , landmarks ):
398398 """
@@ -404,7 +404,7 @@ def shortestPaths(self, landmarks):
404404 :return: DataFrame with new vertices column "distances"
405405 """
406406 jdf = self ._jvm_graph .shortestPaths ().landmarks (landmarks ).run ()
407- return DataFrame (jdf , self ._sqlContext )
407+ return DataFrame (jdf , self ._spark )
408408
409409 def stronglyConnectedComponents (self , maxIter ):
410410 """
@@ -416,7 +416,7 @@ def stronglyConnectedComponents(self, maxIter):
416416 :return: DataFrame with new vertex column "component"
417417 """
418418 jdf = self ._jvm_graph .stronglyConnectedComponents ().maxIter (maxIter ).run ()
419- return DataFrame (jdf , self ._sqlContext )
419+ return DataFrame (jdf , self ._spark )
420420
421421 def svdPlusPlus (self , rank = 10 , maxIter = 2 , minValue = 0.0 , maxValue = 5.0 ,
422422 gamma1 = 0.007 , gamma2 = 0.007 , gamma6 = 0.005 , gamma7 = 0.015 ):
@@ -433,7 +433,7 @@ def svdPlusPlus(self, rank = 10, maxIter = 2, minValue = 0.0, maxValue = 5.0,
433433 builder .gamma1 (gamma1 ).gamma2 (gamma2 ).gamma6 (gamma6 ).gamma7 (gamma7 )
434434 jdf = builder .run ()
435435 loss = builder .loss ()
436- v = DataFrame (jdf , self ._sqlContext )
436+ v = DataFrame (jdf , self ._spark )
437437 return (v , loss )
438438
439439 def triangleCount (self ):
@@ -445,15 +445,15 @@ def triangleCount(self):
445445 :return: DataFrame with new vertex column "count"
446446 """
447447 jdf = self ._jvm_graph .triangleCount ().run ()
448- return DataFrame (jdf , self ._sqlContext )
448+ return DataFrame (jdf , self ._spark )
449449
450450
451451def _test ():
452452 import doctest
453453 import graphframe
454454 globs = graphframe .__dict__ .copy ()
455455 globs ['sc' ] = SparkContext ('local[4]' , 'PythonTest' , batchSize = 2 )
456- globs ['sqlContext ' ] = SQLContext (globs ['sc' ])
456+ globs ['spark ' ] = SparkSession (globs ['sc' ]). builder . getOrCreate ( )
457457 (failure_count , test_count ) = doctest .testmod (
458458 globs = globs , optionflags = doctest .ELLIPSIS | doctest .NORMALIZE_WHITESPACE )
459459 globs ['sc' ].stop ()
0 commit comments