-
Notifications
You must be signed in to change notification settings - Fork 262
feat(pregel): automatically skip second join when dst columns not needed #795
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
49a1c17
10cbd48
2a8dac4
d33a821
bf525c3
cee85ee
ce37595
9cb47f1
b97a715
7f32626
be54ce9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ import org.apache.spark.sql.functions.col | |
| import org.apache.spark.sql.functions.explode | ||
| import org.apache.spark.sql.functions.lit | ||
| import org.apache.spark.sql.functions.struct | ||
| import org.apache.spark.sql.graphframes.SparkShims | ||
| import org.graphframes.GraphFrame | ||
| import org.graphframes.GraphFrame.* | ||
| import org.graphframes.Logging | ||
|
|
@@ -397,9 +398,34 @@ class Pregel(val graph: GraphFrame) | |
| ((initialAttributes :+ initialActiveVertexExpression.alias( | ||
| Pregel.ACTIVE_FLAG_COL)) ++ initVertexCols): _*) | ||
|
|
||
| // Automatic optimization: detect if destination vertex state is needed by analyzing | ||
| // the MESSAGE expressions only (not the target ID expressions, since dst.id is always | ||
| // available from the edge). If no message expression references dst.* columns, | ||
| // we can skip the second join entirely. | ||
| // However, if skipMessagesFromNonActiveVertices is enabled, we need dst._pregel_is_active. | ||
| // Additionally, if the only dst field referenced is "id", we can still skip since | ||
| // dst.id is available from the edge's dst column. | ||
| val messageExpressions = sendMsgs.toList.map { case (_, msgExpr) => msgExpr } | ||
| val allDstRefs = messageExpressions.flatMap { expr => | ||
| SparkShims.extractColumnReferences(graph.spark, expr).get(DST) | ||
| } | ||
| val dstPrefixReferenced = allDstRefs.nonEmpty | ||
| val dstFieldsReferenced = allDstRefs.flatten.toSet | ||
| // We need the dst join if: | ||
| // 1. skipMessagesFromNonActiveVertices is enabled (needs dst._pregel_is_active), OR | ||
| // 2. dst is referenced AND fields other than just "id" are accessed | ||
| // (empty set means whole struct access like col("dst"), which also needs the join) | ||
| val needsDstState = skipMessagesFromNonActiveVertices || | ||
| (dstPrefixReferenced && (dstFieldsReferenced.isEmpty || dstFieldsReferenced != Set(ID))) | ||
| if (!needsDstState) { | ||
| logDebug( | ||
| "Optimization: skipping second join (dst state not required by message expressions)") | ||
| } | ||
|
|
||
| val edges = graph.edges | ||
| .select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE)) | ||
| .repartition(col("edge_src"), col("edge_dst")) | ||
| .repartition( | ||
| (if (needsDstState) Seq(col("edge_src"), col("edge_dst")) else Seq(col("edge_src"))): _*) | ||
| .persist(intermediateStorageLevel) | ||
|
|
||
| var iteration = 1 | ||
|
|
@@ -431,13 +457,31 @@ class Pregel(val graph: GraphFrame) | |
| val currRoundPersistent = scala.collection.mutable.Queue[DataFrame]() | ||
| currRoundPersistent.enqueue(currentVertices.persist(intermediateStorageLevel)) | ||
|
|
||
| var tripletsDF = currentVertices | ||
| // Build triplets: start with src vertex state joined with edges | ||
| var srcWithEdges = currentVertices | ||
| .select(struct(srcCols: _*).as(SRC)) | ||
| .join(edges, Pregel.src(ID) === col("edge_src")) | ||
| .join( | ||
| currentVertices.select(struct(dstCols: _*).as(DST)), | ||
| col("edge_dst") === Pregel.dst(ID)) | ||
| .drop(col("edge_src"), col("edge_dst")) | ||
|
|
||
| // Optimization: persist srcWithEdges when skipping dst join to avoid recomputation | ||
| if (!needsDstState) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't understand why we need to persist here, and I'm against it. We already have the edges persisted. Persisting the triplets as well will blow up memory without any benefit. |
||
| srcWithEdges = srcWithEdges.persist(intermediateStorageLevel) | ||
| currRoundPersistent.enqueue(srcWithEdges) | ||
| } | ||
|
|
||
| // Only perform the second join (adding dst vertex state) if needed | ||
| var tripletsDF = if (needsDstState) { | ||
| srcWithEdges | ||
| .join( | ||
| currentVertices.select(struct(dstCols: _*).as(DST)), | ||
| col("edge_dst") === Pregel.dst(ID)) | ||
| .drop(col("edge_src"), col("edge_dst")) | ||
| } else { | ||
| // Skip second join - dst state not needed by any message expression. | ||
| // Create a minimal dst struct with just the id from edge_dst for sendMsgToDst to work. | ||
| srcWithEdges | ||
| .withColumn(DST, struct(col("edge_dst").as(ID))) | ||
| .drop(col("edge_src"), col("edge_dst")) | ||
| } | ||
|
|
||
| if (skipMessagesFromNonActiveVertices) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we do not need the second join, this condition can be moved up and applied before the first join, which would be a big benefit for some algorithms. Can we do it? |
||
| tripletsDF = tripletsDF.filter( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do not need to repartition by both columns in either branch. It was a mistake to add it: it makes a shuffle required on both branches (join by src and join by dst). Let's simplify and optimize by always partitioning by `src` only.