Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel

object GraphFramesConf {
// Optional switch for iterative algorithms (ConnectedComponents, Pregel-based ones)
// to use Dataset.localCheckpoint instead of reliable checkpoints to persistent storage.
private val USE_LOCAL_CHECKPOINTS =
  SQLConf
    .buildConf("spark.graphframes.useLocalCheckpoints")
    .doc(""" Tells iterative algorithms to use local checkpoints (default: "false").
      | If set to "true", iterative algorithms will use local checkpoints on executors
      | instead of the checkpointing mechanism to the persistent storage.
      | Local checkpoints are faster but can make the whole job more prone to errors,
      | because locally checkpointed data does not survive executor failures.
      | @note This option may become default "true" in the future.
      |""".stripMargin)
    .version("0.9.3")
    .booleanConf
    .createOptional

private val USE_LABELS_AS_COMPONENTS =
SQLConf
.buildConf("spark.graphframes.useLabelsAsComponents")
Expand Down Expand Up @@ -108,4 +120,6 @@ object GraphFramesConf {
case Some(use) => Some(use.toBoolean)
case _ => None
}

def getUseLocalCheckpoints: Option[Boolean] = get(USE_LOCAL_CHECKPOINTS).map(_.toBoolean)
}
42 changes: 26 additions & 16 deletions core/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithBroadcastThreshold
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithIntermediateStorageLevel
import org.graphframes.WithLocalCheckpoints
import org.graphframes.WithMaxIter
import org.graphframes.WithUseLabelsAsComponents

Expand All @@ -54,7 +55,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
with WithBroadcastThreshold
with WithIntermediateStorageLevel
with WithUseLabelsAsComponents
with WithMaxIter {
with WithMaxIter
with WithLocalCheckpoints {

setAlgorithm(GraphFramesConf.getConnectedComponentsAlgorithm.getOrElse(ALGO_GRAPHFRAMES))
setCheckpointInterval(
Expand All @@ -65,6 +67,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
GraphFramesConf.getConnectedComponentsStorageLevel.getOrElse(intermediateStorageLevel))
setUseLabelsAsComponents(
GraphFramesConf.getUseLabelsAsComponents.getOrElse(useLabelsAsComponents))
setUseLocalCheckpoints(GraphFramesConf.getUseLocalCheckpoints.getOrElse(useLocalCheckpoints))

/**
* Runs the algorithm.
Expand All @@ -77,7 +80,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
checkpointInterval = checkpointInterval,
intermediateStorageLevel = intermediateStorageLevel,
useLabelsAsComponents = useLabelsAsComponents,
maxIter = maxIter)
maxIter = maxIter,
useLocalCheckpoints = useLocalCheckpoints)
}
}

Expand Down Expand Up @@ -190,7 +194,8 @@ object ConnectedComponents extends Logging {
checkpointInterval: Int,
intermediateStorageLevel: StorageLevel,
useLabelsAsComponents: Boolean,
maxIter: Option[Int]): DataFrame = {
maxIter: Option[Int],
useLocalCheckpoints: Boolean): DataFrame = {
if (runInGraphX) {
return runGraphX(graph, maxIter.getOrElse(Int.MaxValue))
}
Expand All @@ -208,7 +213,8 @@ object ConnectedComponents extends Logging {
logInfo(s"$logPrefix Start connected components with run ID $runId.")

val shouldCheckpoint = checkpointInterval > 0
val checkpointDir: Option[String] = if (shouldCheckpoint) {
val checkpointDir: Option[String] = if (useLocalCheckpoints) { None }
else if (shouldCheckpoint) {
val dir = sc.getCheckpointDir
.map { d =>
new Path(d, s"$CHECKPOINT_NAME_PREFIX-$runId").toString
Expand Down Expand Up @@ -297,19 +303,23 @@ object ConnectedComponents extends Logging {

// checkpointing
if (shouldCheckpoint && (iteration % checkpointInterval == 0)) {
// TODO: remove this after DataFrame.checkpoint is implemented
val out = s"${checkpointDir.get}/$iteration"
ee.write.parquet(out)
// may hit S3 eventually consistent issue
ee = spark.read.parquet(out)

// remove previous checkpoint
if (iteration > checkpointInterval) {
val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
}
if (useLocalCheckpoints) {
ee = ee.localCheckpoint(eager = true)
} else {
// TODO: remove this after DataFrame.checkpoint is implemented
val out = s"${checkpointDir.get}/$iteration"
ee.write.parquet(out)
// may hit S3 eventually consistent issue
ee = spark.read.parquet(out)

// remove previous checkpoint
if (iteration > checkpointInterval) {
val path = new Path(s"${checkpointDir.get}/${iteration - checkpointInterval}")
path.getFileSystem(sc.hadoopConfiguration).delete(path, true)
}

System.gc() // hint Spark to clean shuffle directories
System.gc() // hint Spark to clean shuffle directories
}
}

ee.persist(intermediateStorageLevel)
Expand Down
14 changes: 11 additions & 3 deletions core/src/main/scala/org/graphframes/lib/LabelPropagation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.functions._
import org.graphframes.GraphFrame
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithLocalCheckpoints
import org.graphframes.WithMaxIter

/**
Expand All @@ -44,14 +45,19 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame)
extends Arguments
with WithAlgorithmChoice
with WithCheckpointInterval
with WithMaxIter {
with WithMaxIter
with WithLocalCheckpoints {

def run(): DataFrame = {
val maxIterChecked = check(maxIter, "maxIter")
algorithm match {
case "graphx" => LabelPropagation.runInGraphX(graph, maxIterChecked)
case "graphframes" =>
LabelPropagation.runInGraphFrames(graph, maxIterChecked, checkpointInterval)
LabelPropagation.runInGraphFrames(
graph,
maxIterChecked,
checkpointInterval,
useLocalCheckpoints = useLocalCheckpoints)
}
}
}
Expand All @@ -74,7 +80,8 @@ private object LabelPropagation {
graph: GraphFrame,
maxIter: Int,
checkpointInterval: Int,
isDirected: Boolean = true): DataFrame = {
isDirected: Boolean = true,
useLocalCheckpoints: Boolean): DataFrame = {
// Overall:
// - Initial labels - IDs
// - Active vertex col (halt voting) - did the label changed?
Expand All @@ -88,6 +95,7 @@ private object LabelPropagation {
.setCheckpointInterval(checkpointInterval)
.setSkipMessagesFromNonActiveVertices(false)
.setUpdateActiveVertexExpression(col(LABEL_ID) =!= keyWithMaxValue(Pregel.msg))
.setUseLocalCheckpoints(useLocalCheckpoints)

if (isDirected) {
pregel = pregel.sendMsgToDst(col(LABEL_ID))
Expand Down
15 changes: 10 additions & 5 deletions core/src/main/scala/org/graphframes/lib/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.functions.struct
import org.graphframes.GraphFrame
import org.graphframes.GraphFrame._
import org.graphframes.Logging
import org.graphframes.WithLocalCheckpoints

import java.io.IOException
import scala.util.control.Breaks.break
Expand Down Expand Up @@ -80,7 +81,7 @@ import scala.util.control.Breaks.breakable
* <a href="https://doi.org/10.1145/1807167.1807184"> Malewicz et al., Pregel: a system for
* large-scale graph processing. </a>
*/
class Pregel(val graph: GraphFrame) extends Logging {
class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints {

private val withVertexColumnList = collection.mutable.ListBuffer.empty[(String, Column, Column)]

Expand Down Expand Up @@ -342,7 +343,7 @@ class Pregel(val graph: GraphFrame) extends Logging {

val shouldCheckpoint = checkpointInterval > 0

if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty) {
if (shouldCheckpoint && graph.spark.sparkContext.getCheckpointDir.isEmpty && !useLocalCheckpoints) {
// Spark Connect workaround
graph.spark.conf.getOption("spark.checkpoint.dir") match {
case Some(d) => graph.spark.sparkContext.setCheckpointDir(d)
Expand Down Expand Up @@ -394,9 +395,13 @@ class Pregel(val graph: GraphFrame) extends Logging {
updateActiveVertexExpression.alias(Pregel.ACTIVE_FLAG_COL)) ++ updateVertexCols): _*)

if (shouldCheckpoint && iteration % checkpointInterval == 0) {
// do checkpoint, use lazy checkpoint because later we will materialize this DF.
newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false)
// TODO: remove last checkpoint file.
if (useLocalCheckpoints) {
newVertexUpdateColDF = newVertexUpdateColDF.localCheckpoint(eager = false)
} else {
// do checkpoint, use lazy checkpoint because later we will materialize this DF.
newVertexUpdateColDF = newVertexUpdateColDF.checkpoint(eager = false)
// TODO: remove last checkpoint file.
}
}
newVertexUpdateColDF.cache()
newVertexUpdateColDF.count() // materialize it
Expand Down
16 changes: 13 additions & 3 deletions core/src/main/scala/org/graphframes/lib/ShortestPaths.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import org.graphframes.GraphFramesUnreachableException
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithLocalCheckpoints

import java.util
import scala.jdk.CollectionConverters._
Expand All @@ -54,7 +55,8 @@ import scala.jdk.CollectionConverters._
class ShortestPaths private[graphframes] (private val graph: GraphFrame)
extends Arguments
with WithAlgorithmChoice
with WithCheckpointInterval {
with WithCheckpointInterval
with WithLocalCheckpoints {
import org.graphframes.lib.ShortestPaths._

private var lmarks: Option[Seq[Any]] = None
Expand All @@ -79,7 +81,12 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame)
val lmarksChecked = check(lmarks, "landmarks")
algorithm match {
case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked)
case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked, checkpointInterval)
case ALGO_GRAPHFRAMES =>
runInGraphFrames(
graph,
lmarksChecked,
checkpointInterval,
useLocalCheckpoints = useLocalCheckpoints)
case _ => throw new GraphFramesUnreachableException()
}
}
Expand Down Expand Up @@ -109,7 +116,8 @@ private object ShortestPaths extends Logging {
graph: GraphFrame,
landmarks: Seq[Any],
checkpointInterval: Int,
isDirected: Boolean = true): DataFrame = {
isDirected: Boolean = true,
useLocalCheckpoints: Boolean): DataFrame = {
logWarn("The GraphFrames based implementation is slow and considered experimental!")
val vertexType = graph.vertices.schema(GraphFrame.ID).dataType

Expand Down Expand Up @@ -202,6 +210,8 @@ private object ShortestPaths extends Logging {
.setUpdateActiveVertexExpression(updateActiveVierticesExpr)
.setStopIfAllNonActiveVertices(true)
.setSkipMessagesFromNonActiveVertices(true)
.setCheckpointInterval(checkpointInterval)
.setUseLocalCheckpoints(useLocalCheckpoints)

// Experimental feature
if (isDirected) {
Expand Down
34 changes: 34 additions & 0 deletions core/src/main/scala/org/graphframes/mixins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,37 @@ private[graphframes] trait WithUseLabelsAsComponents {
*/
def getUseLabelsAsComponents: Boolean = useLabelsAsComponents
}

/**
* Provides support for local checkpoints in Spark computations.
*
* Local checkpoints offer a faster alternative to regular checkpoints as they don't require
* configuration of checkpointDir in persistent storage (like HDFS or S3). While being more
* performant, local checkpoints are less reliable since they don't survive node failures and the
* data is not persisted across multiple nodes.
*/
/**
 * Mixin that lets an algorithm replace reliable checkpoints with Spark local checkpoints.
 *
 * Local checkpoints avoid the need to configure a checkpoint directory on persistent
 * storage (such as HDFS or S3) and are therefore faster to set up and write. The
 * trade-off is reliability: locally checkpointed data is not replicated across nodes
 * and does not survive node failures.
 */
private[graphframes] trait WithLocalCheckpoints {
  // Reliable checkpoints remain the default; callers opt in to local checkpoints.
  protected var useLocalCheckpoints: Boolean = false

  /**
   * Enables or disables local checkpoints (default: false). Local checkpoints are
   * faster but less reliable because they do not survive node failures.
   *
   * @param value
   *   true to use local checkpoints, false for regular checkpoints
   * @return
   *   this instance
   */
  def setUseLocalCheckpoints(value: Boolean): this.type = {
    this.useLocalCheckpoints = value
    this
  }

  /**
   * Indicates whether local checkpoints are used instead of regular checkpoints.
   *
   * @return
   *   true if local checkpoints are enabled, false otherwise
   */
  def getUseLocalCheckpoints: Boolean = this.useLocalCheckpoints
}
Loading