Skip to content

Commit 1e1dbd3

Browse files
author
William Tang
authored
Python: Replace SQLContext with SparkSession (#424)
* GraphFrame: Use SparkSession instead of SQLContext. * CI: Remove Spark 2.4.8. Spark 2.4.x does not support SparkSession.getActiveSession(). According to https://spark.apache.org/versioning-policy.html, Spark 2.4.x is EOL and its last version was released 5 years ago, so there is little value in continuing to support it.
1 parent 100fb01 commit 1e1dbd3

File tree

10 files changed

+78
-91
lines changed

10 files changed

+78
-91
lines changed

.github/workflows/python-ci.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ jobs:
1616
- spark-version: 3.0.3
1717
scala-version: 2.12.12
1818
python-version: 3.8
19-
- spark-version: 2.4.8
20-
scala-version: 2.11.12
21-
python-version: 3.7
2219
runs-on: ubuntu-20.04
2320
env:
2421
# define Java options for both official sbt and sbt-extras

.github/workflows/scala-ci.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ jobs:
1414
scala-version: 2.12.12
1515
- spark-version: 3.0.3
1616
scala-version: 2.12.12
17-
- spark-version: 2.4.8
18-
scala-version: 2.11.12
1917
runs-on: ubuntu-20.04
2018
env:
2119
# define Java options for both official sbt and sbt-extras

build.sbt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ val defaultScalaVer = sparkBranch match {
1212
case "3.2" => "2.12.15"
1313
case "3.1" => "2.12.15"
1414
case "3.0" => "2.12.15"
15-
case "2.4" => "2.11.12"
1615
case _ => throw new IllegalArgumentException(s"Unsupported Spark version: $sparkVer.")
1716
}
1817
val scalaVer = sys.props.getOrElse("scala.version", defaultScalaVer)

dev/release.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def verify(prompt, interactive):
3939
@click.option("--publish-docs", type=bool, default=PUBLISH_DOCS_DEFAULT, show_default=True,
4040
help="Publish docs to github-pages.")
4141
@click.option("--spark-version", multiple=True, show_default=True,
42-
default=["2.4.8", "3.0.3", "3.1.3", "3.2.2", "3.3.0"])
42+
default=["3.0.3", "3.1.3", "3.2.2", "3.3.0"])
4343
def main(release_version, next_version, publish_to, no_prompt, git_remote, publish_docs,
4444
spark_version):
4545
interactive = not no_prompt

python/graphframes/examples/belief_propagation.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@
1717

1818
import math
1919

20-
from pyspark import SparkConf, SparkContext
21-
from pyspark.sql import SQLContext, functions as sqlfunctions, types
22-
20+
# Import subpackage examples here explicitly so that
21+
# this module can be run directly with spark-submit.
22+
import graphframes.examples
2323
from graphframes import GraphFrame
2424
from graphframes.lib import AggregateMessages as AM
25-
# Import subpackage examples here explicitly so that this module can be
26-
# run directly with spark-submit.
27-
import graphframes.examples
25+
from pyspark.sql import SparkSession, functions as sqlfunctions, types
2826

2927
__all__ = ['BeliefPropagation']
3028

@@ -151,13 +149,11 @@ def _sigmoid(x):
151149

152150
def main():
153151
"""Run the belief propagation algorithm for an example problem."""
154-
# setup context
155-
conf = SparkConf().setAppName("BeliefPropagation example")
156-
sc = SparkContext.getOrCreate(conf)
157-
sql = SQLContext.getOrCreate(sc)
152+
# setup spark session
153+
spark = SparkSession.builder.appName("BeliefPropagation example").getOrCreate()
158154

159155
# create graphical model g of size 3 x 3
160-
g = graphframes.examples.Graphs(sql).gridIsingModel(3)
156+
g = graphframes.examples.Graphs(spark).gridIsingModel(3)
161157
print("Original Ising model:")
162158
g.vertices.show()
163159
g.edges.show()
@@ -171,7 +167,8 @@ def main():
171167
print("Done with BP. Final beliefs after {} iterations:".format(numIter))
172168
beliefs.show()
173169

174-
sc.stop()
170+
spark.stop()
171+
175172

176173
if __name__ == '__main__':
177174
main()

python/graphframes/examples/graphs.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,17 @@
2727
class Graphs(object):
2828
"""Example GraphFrames for testing the API
2929
30-
:param sqlContext: SQLContext
30+
:param spark: SparkSession
3131
"""
3232

33-
def __init__(self, sqlContext):
34-
self._sql = sqlContext
35-
self._sc = sqlContext._sc
33+
def __init__(self, spark):
34+
self._spark = spark
35+
self._sc = spark._sc
3636

3737
def friends(self):
3838
"""A GraphFrame of friends in a (fake) social network."""
39-
sqlContext = self._sql
4039
# Vertex DataFrame
41-
v = sqlContext.createDataFrame([
40+
v = self._spark.createDataFrame([
4241
("a", "Alice", 34),
4342
("b", "Bob", 36),
4443
("c", "Charlie", 30),
@@ -47,7 +46,7 @@ def friends(self):
4746
("f", "Fanny", 36)
4847
], ["id", "name", "age"])
4948
# Edge DataFrame
50-
e = sqlContext.createDataFrame([
49+
e = self._spark.createDataFrame([
5150
("a", "b", "friend"),
5251
("b", "c", "follow"),
5352
("c", "b", "follow"),
@@ -92,7 +91,7 @@ def gridIsingModel(self, n, vStd=1.0, eStd=1.0):
9291
.format(n))
9392

9493
# create coodinates grid
95-
coordinates = self._sql.createDataFrame(
94+
coordinates = self._spark.createDataFrame(
9695
itertools.product(range(n), range(n)),
9796
schema=('i', 'j'))
9897

python/graphframes/graphframe.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,20 @@
2020
basestring = str
2121

2222
from pyspark import SparkContext
23-
from pyspark.sql import Column, DataFrame, SQLContext
23+
from pyspark.sql import Column, DataFrame, SparkSession
2424
from pyspark.storagelevel import StorageLevel
2525

2626
from graphframes.lib import Pregel
2727

2828

29-
def _from_java_gf(jgf, sqlContext):
29+
def _from_java_gf(jgf, spark):
3030
"""
3131
(internal) creates a python GraphFrame wrapper from a java GraphFrame.
3232
3333
:param jgf:
3434
"""
35-
pv = DataFrame(jgf.vertices(), sqlContext)
36-
pe = DataFrame(jgf.edges(), sqlContext)
35+
pv = DataFrame(jgf.vertices(), spark)
36+
pe = DataFrame(jgf.edges(), spark)
3737
return GraphFrame(pv, pe)
3838

3939
def _java_api(jsc):
@@ -55,16 +55,16 @@ class GraphFrame(object):
5555
5656
>>> localVertices = [(1,"A"), (2,"B"), (3, "C")]
5757
>>> localEdges = [(1,2,"love"), (2,1,"hate"), (2,3,"follow")]
58-
>>> v = sqlContext.createDataFrame(localVertices, ["id", "name"])
59-
>>> e = sqlContext.createDataFrame(localEdges, ["src", "dst", "action"])
58+
>>> v = spark.createDataFrame(localVertices, ["id", "name"])
59+
>>> e = spark.createDataFrame(localEdges, ["src", "dst", "action"])
6060
>>> g = GraphFrame(v, e)
6161
"""
6262

6363
def __init__(self, v, e):
6464
self._vertices = v
6565
self._edges = e
66-
self._sqlContext = v.sql_ctx
67-
self._sc = self._sqlContext._sc
66+
self._spark = SparkSession.getActiveSession()
67+
self._sc = self._spark._sc
6868
self._jvm_gf_api = _java_api(self._sc)
6969

7070
self.ID = self._jvm_gf_api.ID()
@@ -142,7 +142,7 @@ def outDegrees(self):
142142
:return: DataFrame with new vertices column "outDegree"
143143
"""
144144
jdf = self._jvm_graph.outDegrees()
145-
return DataFrame(jdf, self._sqlContext)
145+
return DataFrame(jdf, self._spark)
146146

147147
@property
148148
def inDegrees(self):
@@ -156,7 +156,7 @@ def inDegrees(self):
156156
:return: DataFrame with new vertices column "inDegree"
157157
"""
158158
jdf = self._jvm_graph.inDegrees()
159-
return DataFrame(jdf, self._sqlContext)
159+
return DataFrame(jdf, self._spark)
160160

161161
@property
162162
def degrees(self):
@@ -170,7 +170,7 @@ def degrees(self):
170170
:return: DataFrame with new vertices column "degree"
171171
"""
172172
jdf = self._jvm_graph.degrees()
173-
return DataFrame(jdf, self._sqlContext)
173+
return DataFrame(jdf, self._spark)
174174

175175
@property
176176
def triplets(self):
@@ -185,7 +185,7 @@ def triplets(self):
185185
:return: DataFrame with columns 'src', 'edge', and 'dst'
186186
"""
187187
jdf = self._jvm_graph.triplets()
188-
return DataFrame(jdf, self._sqlContext)
188+
return DataFrame(jdf, self._spark)
189189

190190
@property
191191
def pregel(self):
@@ -206,7 +206,7 @@ def find(self, pattern):
206206
:return: DataFrame with one Row for each instance of the motif found
207207
"""
208208
jdf = self._jvm_graph.find(pattern)
209-
return DataFrame(jdf, self._sqlContext)
209+
return DataFrame(jdf, self._spark)
210210

211211
def filterVertices(self, condition):
212212
"""
@@ -222,7 +222,7 @@ def filterVertices(self, condition):
222222
jdf = self._jvm_graph.filterVertices(condition._jc)
223223
else:
224224
raise TypeError("condition should be string or Column")
225-
return _from_java_gf(jdf, self._sqlContext)
225+
return _from_java_gf(jdf, self._spark)
226226

227227
def filterEdges(self, condition):
228228
"""
@@ -237,7 +237,7 @@ def filterEdges(self, condition):
237237
jdf = self._jvm_graph.filterEdges(condition._jc)
238238
else:
239239
raise TypeError("condition should be string or Column")
240-
return _from_java_gf(jdf, self._sqlContext)
240+
return _from_java_gf(jdf, self._spark)
241241

242242
def dropIsolatedVertices(self):
243243
"""
@@ -246,7 +246,7 @@ def dropIsolatedVertices(self):
246246
:return: GraphFrame with filtered vertices.
247247
"""
248248
jdf = self._jvm_graph.dropIsolatedVertices()
249-
return _from_java_gf(jdf, self._sqlContext)
249+
return _from_java_gf(jdf, self._spark)
250250

251251
def bfs(self, fromExpr, toExpr, edgeFilter=None, maxPathLength=10):
252252
"""
@@ -263,7 +263,7 @@ def bfs(self, fromExpr, toExpr, edgeFilter=None, maxPathLength=10):
263263
if edgeFilter is not None:
264264
builder.edgeFilter(edgeFilter)
265265
jdf = builder.run()
266-
return DataFrame(jdf, self._sqlContext)
266+
return DataFrame(jdf, self._spark)
267267

268268
def aggregateMessages(self, aggCol, sendToSrc=None, sendToDst=None):
269269
"""
@@ -305,7 +305,7 @@ def aggregateMessages(self, aggCol, sendToSrc=None, sendToDst=None):
305305
jdf = builder.agg(aggCol._jc)
306306
else:
307307
jdf = builder.agg(aggCol)
308-
return DataFrame(jdf, self._sqlContext)
308+
return DataFrame(jdf, self._spark)
309309

310310
# Standard algorithms
311311

@@ -329,7 +329,7 @@ def connectedComponents(self, algorithm = "graphframes", checkpointInterval = 2,
329329
.setCheckpointInterval(checkpointInterval) \
330330
.setBroadcastThreshold(broadcastThreshold) \
331331
.run()
332-
return DataFrame(jdf, self._sqlContext)
332+
return DataFrame(jdf, self._spark)
333333

334334
def labelPropagation(self, maxIter):
335335
"""
@@ -341,7 +341,7 @@ def labelPropagation(self, maxIter):
341341
:return: DataFrame with new vertices column "label"
342342
"""
343343
jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run()
344-
return DataFrame(jdf, self._sqlContext)
344+
return DataFrame(jdf, self._spark)
345345

346346
def pageRank(self, resetProbability = 0.15, sourceId = None, maxIter = None,
347347
tol = None):
@@ -369,7 +369,7 @@ def pageRank(self, resetProbability = 0.15, sourceId = None, maxIter = None,
369369
assert tol is not None, "Exactly one of maxIter or tol should be set."
370370
builder = builder.tol(tol)
371371
jgf = builder.run()
372-
return _from_java_gf(jgf, self._sqlContext)
372+
return _from_java_gf(jgf, self._spark)
373373

374374
def parallelPersonalizedPageRank(self, resetProbability = 0.15, sourceIds = None,
375375
maxIter = None):
@@ -392,7 +392,7 @@ def parallelPersonalizedPageRank(self, resetProbability = 0.15, sourceIds = None
392392
builder = builder.sourceIds(sourceIds)
393393
builder = builder.maxIter(maxIter)
394394
jgf = builder.run()
395-
return _from_java_gf(jgf, self._sqlContext)
395+
return _from_java_gf(jgf, self._spark)
396396

397397
def shortestPaths(self, landmarks):
398398
"""
@@ -404,7 +404,7 @@ def shortestPaths(self, landmarks):
404404
:return: DataFrame with new vertices column "distances"
405405
"""
406406
jdf = self._jvm_graph.shortestPaths().landmarks(landmarks).run()
407-
return DataFrame(jdf, self._sqlContext)
407+
return DataFrame(jdf, self._spark)
408408

409409
def stronglyConnectedComponents(self, maxIter):
410410
"""
@@ -416,7 +416,7 @@ def stronglyConnectedComponents(self, maxIter):
416416
:return: DataFrame with new vertex column "component"
417417
"""
418418
jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter).run()
419-
return DataFrame(jdf, self._sqlContext)
419+
return DataFrame(jdf, self._spark)
420420

421421
def svdPlusPlus(self, rank = 10, maxIter = 2, minValue = 0.0, maxValue = 5.0,
422422
gamma1 = 0.007, gamma2 = 0.007, gamma6 = 0.005, gamma7 = 0.015):
@@ -433,7 +433,7 @@ def svdPlusPlus(self, rank = 10, maxIter = 2, minValue = 0.0, maxValue = 5.0,
433433
builder.gamma1(gamma1).gamma2(gamma2).gamma6(gamma6).gamma7(gamma7)
434434
jdf = builder.run()
435435
loss = builder.loss()
436-
v = DataFrame(jdf, self._sqlContext)
436+
v = DataFrame(jdf, self._spark)
437437
return (v, loss)
438438

439439
def triangleCount(self):
@@ -445,15 +445,15 @@ def triangleCount(self):
445445
:return: DataFrame with new vertex column "count"
446446
"""
447447
jdf = self._jvm_graph.triangleCount().run()
448-
return DataFrame(jdf, self._sqlContext)
448+
return DataFrame(jdf, self._spark)
449449

450450

451451
def _test():
452452
import doctest
453453
import graphframe
454454
globs = graphframe.__dict__.copy()
455455
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
456-
globs['sqlContext'] = SQLContext(globs['sc'])
456+
globs['spark'] = SparkSession(globs['sc']).builder.getOrCreate()
457457
(failure_count, test_count) = doctest.testmod(
458458
globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
459459
globs['sc'].stop()

python/graphframes/lib/aggregate_messages.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#
1717

1818
from pyspark import SparkContext
19-
from pyspark.sql import DataFrame, functions as sqlfunctions
19+
from pyspark.sql import DataFrame, functions as sqlfunctions, SparkSession
2020

2121

2222
def _java_api(jsc):
@@ -77,7 +77,7 @@ def getCachedDataFrame(df):
7777
WARNING: This is NOT the same as `DataFrame.cache()`.
7878
The original DataFrame will NOT be cached.
7979
"""
80-
sqlContext = df.sql_ctx
81-
jvm_gf_api = _java_api(sqlContext._sc)
80+
spark = SparkSession.getActiveSession()
81+
jvm_gf_api = _java_api(spark._sc)
8282
jdf = jvm_gf_api.aggregateMessages().getCachedDataFrame(df._jdf)
83-
return DataFrame(jdf, sqlContext)
83+
return DataFrame(jdf, spark)

python/graphframes/lib/pregel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
if sys.version > '3':
2020
basestring = str
2121

22-
from pyspark.sql import DataFrame
22+
from pyspark.sql import DataFrame, SparkSession
2323
from pyspark.sql.functions import col
24-
from pyspark.ml.wrapper import JavaWrapper, _jvm
24+
from pyspark.ml.wrapper import JavaWrapper
2525

2626

2727
class Pregel(JavaWrapper):
@@ -169,7 +169,7 @@ def run(self):
169169
170170
:return: the result vertex DataFrame from the final iteration including both original and additional columns.
171171
"""
172-
return DataFrame(self._java_obj.run(), self.graph.vertices.sql_ctx)
172+
return DataFrame(self._java_obj.run(), SparkSession.getActiveSession())
173173

174174
@staticmethod
175175
def msg():

0 commit comments

Comments (0)