Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,86 @@ import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.GetStructField
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import scala.annotation.nowarn
import scala.collection.mutable

object SparkShims {

/**
* Extracts all column references from a Column expression, returning a map from top-level
* prefix to the set of nested field names accessed under that prefix.
*
* For nested column references like "src.id" or "edge.weight", this returns Map("src" ->
* Set("id"), "edge" -> Set("weight")). For top-level references like "src" (the whole struct),
* it returns Map("src" -> Set()).
*
* This handles both unresolved expressions (UnresolvedAttribute, UnresolvedExtractValue) and
* resolved expressions (AttributeReference, GetStructField).
*
* Note: Deeply nested struct access (e.g., "dst.location.city") is not fully parsed. In such
* cases, the prefix is recorded with an empty field set, which causes callers to conservatively
* assume the entire struct is needed. This is the safe/correct fallback behavior.
*
* @param spark
* the SparkSession (unused in Spark 3, included for API compatibility with Spark 4)
* @param expr
* the Column expression to analyze
* @return
* a Map from column prefix to the set of nested field names accessed
*/
@nowarn
def extractColumnReferences(spark: SparkSession, expr: Column): Map[String, Set[String]] = {
val refs = mutable.Map.empty[String, mutable.Set[String]]

def addRef(prefix: String, field: Option[String]): Unit = {
val fields = refs.getOrElseUpdate(prefix, mutable.Set.empty[String])
field.foreach(fields += _)
}

expr.expr.foreach {
// Unresolved: col("src.id") -> UnresolvedAttribute(Seq("src", "id"))
case UnresolvedAttribute(nameParts) if nameParts.nonEmpty =>
addRef(nameParts.head, nameParts.lift(1))

// Unresolved: col("src")("id") -> UnresolvedExtractValue
case UnresolvedExtractValue(child, extraction) =>
child match {
case UnresolvedAttribute(nameParts) if nameParts.nonEmpty =>
extraction match {
case Literal(fieldName: String, _) => addRef(nameParts.head, Some(fieldName))
case Literal(fieldName, _) if fieldName != null =>
// Handle UTF8String (Spark's internal string representation)
addRef(nameParts.head, Some(fieldName.toString))
case _ => addRef(nameParts.head, None) // Unknown field access
}
case _ => // Nested extraction we can't easily parse - conservative fallback
}

// Resolved: AttributeReference for top-level columns
case attr: AttributeReference =>
addRef(attr.name, None)

// Resolved: GetStructField for nested field access like struct.field
// Note: Only handles single-level nesting; deeper nesting falls through to default case
case GetStructField(child, _, Some(fieldName)) =>
child match {
case attr: AttributeReference => addRef(attr.name, Some(fieldName))
case _ => // Deeply nested struct access - conservative fallback (join will be used)
}

case _ => // ignore other expression types
}

refs.map { case (k, v) => k -> v.toSet }.toMap
}

/**
* Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to
* the column itself.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,90 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.GetStructField
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.ClassicConversions.*
import org.apache.spark.sql.classic.DataFrame as ClassicDataFrame
import org.apache.spark.sql.classic.Dataset
import org.apache.spark.sql.classic.ExpressionUtils
import org.apache.spark.sql.classic.SparkSession as ClassicSparkSession

import scala.collection.mutable

object SparkShims {

/**
* Extracts all column references from a Column expression, returning a map from top-level
* prefix to the set of nested field names accessed under that prefix.
*
* For nested column references like "src.id" or "edge.weight", this returns Map("src" ->
* Set("id"), "edge" -> Set("weight")). For top-level references like "src" (the whole struct),
* it returns Map("src" -> Set()).
*
* This handles both unresolved expressions (UnresolvedAttribute, UnresolvedExtractValue) and
* resolved expressions (AttributeReference, GetStructField).
*
* Note: Deeply nested struct access (e.g., "dst.location.city") is not fully parsed. In such
* cases, the prefix is recorded with an empty field set, which causes callers to conservatively
* assume the entire struct is needed. This is the safe/correct fallback behavior.
*
* @param spark
* the SparkSession (needed for expression conversion in Spark 4)
* @param expr
* the Column expression to analyze
* @return
* a Map from column prefix to the set of nested field names accessed
*/
def extractColumnReferences(spark: SparkSession, expr: Column): Map[String, Set[String]] = {
val refs = mutable.Map.empty[String, mutable.Set[String]]

def addRef(prefix: String, field: Option[String]): Unit = {
val fields = refs.getOrElseUpdate(prefix, mutable.Set.empty[String])
field.foreach(fields += _)
}

val converted = spark.asInstanceOf[ClassicSparkSession].converter(expr.node)
converted.foreach {
// Unresolved: col("src.id") -> UnresolvedAttribute(Seq("src", "id"))
case UnresolvedAttribute(nameParts) if nameParts.nonEmpty =>
addRef(nameParts.head, nameParts.lift(1))

// Unresolved: col("src")("id") -> UnresolvedExtractValue
case UnresolvedExtractValue(child, extraction) =>
child match {
case UnresolvedAttribute(nameParts) if nameParts.nonEmpty =>
extraction match {
case Literal(fieldName: String, _) => addRef(nameParts.head, Some(fieldName))
case Literal(fieldName, _) if fieldName != null =>
// Handle UTF8String (Spark's internal string representation)
addRef(nameParts.head, Some(fieldName.toString))
case _ => addRef(nameParts.head, None) // Unknown field access
}
case _ => // Nested extraction we can't easily parse - conservative fallback
}

// Resolved: AttributeReference for top-level columns
case attr: AttributeReference =>
addRef(attr.name, None)

// Resolved: GetStructField for nested field access like struct.field
// Note: Only handles single-level nesting; deeper nesting falls through to default case
case GetStructField(child, _, Some(fieldName)) =>
child match {
case attr: AttributeReference => addRef(attr.name, Some(fieldName))
case _ => // Deeply nested struct access - conservative fallback (join will be used)
}

case _ => // ignore other expression types
}

refs.map { case (k, v) => k -> v.toSet }.toMap
}

/**
* Apply the given SQL expression (such as `id = 3`) to the field in a column, rather than to
* the column itself.
Expand Down
56 changes: 50 additions & 6 deletions core/src/main/scala/org/graphframes/lib/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.struct
import org.apache.spark.sql.graphframes.SparkShims
import org.graphframes.GraphFrame
import org.graphframes.GraphFrame.*
import org.graphframes.Logging
Expand Down Expand Up @@ -397,9 +398,34 @@ class Pregel(val graph: GraphFrame)
((initialAttributes :+ initialActiveVertexExpression.alias(
Pregel.ACTIVE_FLAG_COL)) ++ initVertexCols): _*)

// Automatic optimization: detect if destination vertex state is needed by analyzing
// the MESSAGE expressions only (not the target ID expressions, since dst.id is always
// available from the edge). If no message expression references dst.* columns,
// we can skip the second join entirely.
// However, if skipMessagesFromNonActiveVertices is enabled, we need dst._pregel_is_active.
// Additionally, if the only dst field referenced is "id", we can still skip since
// dst.id is available from the edge's dst column.
val messageExpressions = sendMsgs.toList.map { case (_, msgExpr) => msgExpr }
val allDstRefs = messageExpressions.flatMap { expr =>
SparkShims.extractColumnReferences(graph.spark, expr).get(DST)
}
val dstPrefixReferenced = allDstRefs.nonEmpty
val dstFieldsReferenced = allDstRefs.flatten.toSet
// We need the dst join if:
// 1. skipMessagesFromNonActiveVertices is enabled (needs dst._pregel_is_active), OR
// 2. dst is referenced AND fields other than just "id" are accessed
// (empty set means whole struct access like col("dst"), which also needs the join)
val needsDstState = skipMessagesFromNonActiveVertices ||
(dstPrefixReferenced && (dstFieldsReferenced.isEmpty || dstFieldsReferenced != Set(ID)))
if (!needsDstState) {
logDebug(
"Optimization: skipping second join (dst state not required by message expressions)")
}

val edges = graph.edges
.select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE))
.repartition(col("edge_src"), col("edge_dst"))
.repartition(
(if (needsDstState) Seq(col("edge_src"), col("edge_dst")) else Seq(col("edge_src"))): _*)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do need the repartition by both columns in both branches. It was a mistake to add it: it forces a shuffle on both branches (join by src and join by dst). Let's simplify and optimize by always partitioning by src only.

.persist(intermediateStorageLevel)

var iteration = 1
Expand Down Expand Up @@ -431,13 +457,31 @@ class Pregel(val graph: GraphFrame)
val currRoundPersistent = scala.collection.mutable.Queue[DataFrame]()
currRoundPersistent.enqueue(currentVertices.persist(intermediateStorageLevel))

var tripletsDF = currentVertices
// Build triplets: start with src vertex state joined with edges
var srcWithEdges = currentVertices
.select(struct(srcCols: _*).as(SRC))
.join(edges, Pregel.src(ID) === col("edge_src"))
.join(
currentVertices.select(struct(dstCols: _*).as(DST)),
col("edge_dst") === Pregel.dst(ID))
.drop(col("edge_src"), col("edge_dst"))

// Optimization: persist srcWithEdges when skipping dst join to avoid recomputation
if (!needsDstState) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why we need a persist here, and I'm against it. We already have persisted edges. Adding persisted triplets will blow up memory without any benefit.

srcWithEdges = srcWithEdges.persist(intermediateStorageLevel)
currRoundPersistent.enqueue(srcWithEdges)
}

// Only perform the second join (adding dst vertex state) if needed
var tripletsDF = if (needsDstState) {
srcWithEdges
.join(
currentVertices.select(struct(dstCols: _*).as(DST)),
col("edge_dst") === Pregel.dst(ID))
.drop(col("edge_src"), col("edge_dst"))
} else {
// Skip second join - dst state not needed by any message expression.
// Create a minimal dst struct with just the id from edge_dst for sendMsgToDst to work.
srcWithEdges
.withColumn(DST, struct(col("edge_dst").as(ID)))
.drop(col("edge_src"), col("edge_dst"))
}

if (skipMessagesFromNonActiveVertices) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do not need the second join, this condition can be moved up and applied before the first join, which would be a big benefit for some algorithms. Can we do that?

tripletsDF = tripletsDF.filter(
Expand Down
Loading