Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package org.apache.spark.sql.graphframes

import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel

/**
 * Central place for reading GraphFrames-specific settings from the active SparkSession's
 * runtime configuration. Every entry is optional: the getters return None when no session
 * is active or the key was never set, letting callers fall back to their own defaults.
 */
object GraphFramesConf {

  // Locale.ROOT keeps case conversion stable regardless of the JVM default locale
  // (e.g. the Turkish dotless-i rule would otherwise corrupt values like "GRAPHX").
  import java.util.Locale

  private val CONNECTED_COMPONENTS_ALGORITHM =
    SQLConf
      .buildConf("spark.graphframes.connectedComponents.algorithm")
      .doc(""" Sets the connected components algorithm to use (default: "graphframes"). Supported algorithms
           | - "graphframes": Uses alternating large star and small star iterations proposed in
           | [[http://dx.doi.org/10.1145/2670979.2670997 Connected Components in MapReduce and Beyond]]
           | with skewed join optimization.
           | - "graphx": Converts the graph to a GraphX graph and then uses the connected components
           | implementation in GraphX.
           | @see org.graphframes.lib.ConnectedComponents.supportedAlgorithms""".stripMargin)
      .version("0.9.0")
      .stringConf
      .createOptional

  private val CONNECTED_COMPONENTS_BROADCAST_THRESHOLD =
    SQLConf
      .buildConf("spark.graphframes.connectedComponents.broadcastthreshold")
      .doc(""" Sets broadcast threshold in propagating component assignments (default: 1000000). If a node
           | degree is greater than this threshold at some iteration, its component assignment will be
           | collected and then broadcasted back to propagate the assignment to its neighbors. Otherwise,
           | the assignment propagation is done by a normal Spark join. This parameter is only used when
           | the algorithm is set to "graphframes".""".stripMargin)
      .version("0.9.0")
      .intConf
      .createOptional

  private val CONNECTED_COMPONENTS_CHECKPOINT_INTERVAL =
    SQLConf
      .buildConf("spark.graphframes.connectedComponents.checkpointinterval")
      .doc(""" Sets checkpoint interval in terms of number of iterations (default: 2). Checkpointing
           | regularly helps recover from failures, clean shuffle files, shorten the lineage of the
           | computation graph, and reduce the complexity of plan optimization. As of Spark 2.0, the
           | complexity of plan optimization would grow exponentially without checkpointing. Hence,
           | disabling or setting longer-than-default checkpoint intervals are not recommended. Checkpoint
           | data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with prefix
           | "connected-components". If the checkpoint directory is not set, this throws a
           | `java.io.IOException`. Set a nonpositive value to disable checkpointing. This parameter is
           | only used when the algorithm is set to "graphframes". Its default value might change in the
           | future.
           | @see `org.apache.spark.SparkContext.setCheckpointDir` in Spark API doc""".stripMargin)
      .version("0.9.0")
      .intConf
      .createOptional

  private val CONNECTED_COMPONENTS_INTERMEDIATE_STORAGE_LEVEL =
    SQLConf
      .buildConf("spark.graphframes.connectedComponents.intermediatestoragelevel")
      .doc("Sets storage level for intermediate datasets that require multiple passes (default: ``MEMORY_AND_DISK``).")
      .version("0.9.0")
      .stringConf
      .createOptional

  /**
   * Reads the raw string value of `entry` from the active session's runtime conf.
   *
   * Returns None when no SparkSession is active or the key is unset. Using `flatMap` over
   * `getActiveSession` avoids the original `.get` call on the Option, which relied on
   * catching the resulting NoSuchElementException for control flow and conflated
   * "no active session" with "key not set".
   */
  private def get(entry: ConfigEntry[_]): Option[String] = {
    SparkSession.getActiveSession.flatMap { session =>
      try {
        Option(session.conf.get(entry.key))
      } catch {
        // RuntimeConfig.get throws NoSuchElementException for keys without a value.
        case _: NoSuchElementException => None
      }
    }
  }

  /** Configured connected-components algorithm name, lower-cased, if set. */
  def getConnectedComponentsAlgorithm: Option[String] =
    get(CONNECTED_COMPONENTS_ALGORITHM).map(_.toLowerCase(Locale.ROOT))

  /**
   * Configured broadcast threshold, if set. Throws NumberFormatException when the
   * session value is not an integer.
   */
  def getConnectedComponentsBroadcastThreshold: Option[Int] =
    get(CONNECTED_COMPONENTS_BROADCAST_THRESHOLD).map(_.toInt)

  /** Configured checkpoint interval in iterations, if set. */
  def getConnectedComponentsCheckpointInterval: Option[Int] =
    get(CONNECTED_COMPONENTS_CHECKPOINT_INTERVAL).map(_.toInt)

  /**
   * Configured intermediate storage level, if set. The value is upper-cased before the
   * `StorageLevel.fromString` lookup so users may write e.g. "memory_only".
   */
  def getConnectedComponentsStorageLevel: Option[StorageLevel] =
    get(CONNECTED_COMPONENTS_INTERMEDIATE_STORAGE_LEVEL)
      .map(level => StorageLevel.fromString(level.toUpperCase(Locale.ROOT)))
}
60 changes: 12 additions & 48 deletions src/main/scala/org/graphframes/lib/ConnectedComponents.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.graphframes.GraphFramesConf
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.storage.StorageLevel
import org.graphframes.GraphFrame
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithBroadcastThreshold
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithIntermediateStorageLevel
import org.graphframes.WithMaxIter

import java.io.IOException
Expand All @@ -47,56 +50,17 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
with Logging
with WithAlgorithmChoice
with WithCheckpointInterval
with WithBroadcastThreshold
with WithIntermediateStorageLevel
with WithMaxIter {

// Default broadcast threshold; see setBroadcastThreshold for semantics.
private var broadcastThreshold: Int = 1000000
// Default to the GraphFrames (large-star/small-star) implementation.
setAlgorithm(ALGO_GRAPHFRAMES)

/**
 * Sets the broadcast threshold used while propagating component assignments
 * (default: 1000000). A node whose degree exceeds this value at some iteration has its
 * component assignment collected and then broadcast back to its neighbors; otherwise
 * propagation happens through a regular Spark join. Only consulted when the algorithm
 * is set to "graphframes".
 */
def setBroadcastThreshold(value: Int): this.type = {
  require(value >= 0, s"Broadcast threshold must be non-negative but got $value.")
  this.broadcastThreshold = value
  this
}

// Boxed-Integer overload so the Python bindings can invoke the setter directly.
private[graphframes] def setBroadcastThreshold(value: java.lang.Integer): this.type =
  setBroadcastThreshold(value.toInt)

/**
 * Returns the broadcast threshold used when propagating component assignments.
 *
 * @see
 *   [[org.graphframes.lib.ConnectedComponents.setBroadcastThreshold]]
 */
def getBroadcastThreshold: Int = this.broadcastThreshold

// Boxed-Integer overload so the Python bindings can invoke the setter directly.
private[graphframes] def setCheckpointInterval(value: java.lang.Integer): this.type =
  setCheckpointInterval(value.toInt)

// Caching level applied to datasets that are scanned more than once.
private var intermediateStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK

/**
 * Sets the storage level for intermediate datasets that require multiple passes
 * (default: ``MEMORY_AND_DISK``).
 */
def setIntermediateStorageLevel(value: StorageLevel): this.type = {
  this.intermediateStorageLevel = value
  this
}

/**
 * Returns the storage level used for intermediate datasets that require multiple passes.
 */
def getIntermediateStorageLevel: StorageLevel = this.intermediateStorageLevel
// Seed each knob from the session-level GraphFrames configuration when present,
// falling back to the current (mixin-declared) defaults otherwise.
setAlgorithm(GraphFramesConf.getConnectedComponentsAlgorithm.getOrElse(ALGO_GRAPHFRAMES))
setCheckpointInterval(
GraphFramesConf.getConnectedComponentsCheckpointInterval.getOrElse(checkpointInterval))
setBroadcastThreshold(
GraphFramesConf.getConnectedComponentsBroadcastThreshold.getOrElse(broadcastThreshold))
setIntermediateStorageLevel(
GraphFramesConf.getConnectedComponentsStorageLevel.getOrElse(intermediateStorageLevel))

/**
* Runs the algorithm.
Expand Down
56 changes: 56 additions & 0 deletions src/main/scala/org/graphframes/mixins.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package org.graphframes

import org.apache.spark.storage.StorageLevel

private[graphframes] trait WithAlgorithmChoice {
protected val ALGO_GRAPHX = "graphx"
protected val ALGO_GRAPHFRAMES = "graphframes"
Expand Down Expand Up @@ -49,12 +51,66 @@ private[graphframes] trait WithCheckpointInterval extends Logging {
this
}

// Boxed-Integer overload so the Python bindings can invoke the setter directly.
private[graphframes] def setCheckpointInterval(value: java.lang.Integer): this.type =
  setCheckpointInterval(value.toInt)

/**
 * Returns the checkpoint interval, in iterations.
 *
 * @see [[setCheckpointInterval]]
 */
def getCheckpointInterval: Int = this.checkpointInterval
}

private[graphframes] trait WithBroadcastThreshold extends Logging {

  // Node-degree cutoff above which component assignments are broadcast rather than joined.
  protected var broadcastThreshold: Int = 1000000

  /**
   * Returns the broadcast threshold used when propagating component assignments.
   *
   * @see
   *   [[org.graphframes.lib.ConnectedComponents.setBroadcastThreshold]]
   */
  def getBroadcastThreshold: Int = broadcastThreshold

  /**
   * Sets the broadcast threshold used while propagating component assignments
   * (default: 1000000). A node whose degree exceeds this value at some iteration has its
   * component assignment collected and then broadcast back to its neighbors; otherwise
   * propagation happens through a regular Spark join. Only consulted when the algorithm
   * is set to "graphframes".
   */
  def setBroadcastThreshold(value: Int): this.type = {
    require(value >= 0, s"Broadcast threshold must be non-negative but got $value.")
    this.broadcastThreshold = value
    this
  }

  // Boxed-Integer overload so the Python bindings can invoke the setter directly.
  private[graphframes] def setBroadcastThreshold(value: java.lang.Integer): this.type =
    setBroadcastThreshold(value.toInt)
}

private[graphframes] trait WithIntermediateStorageLevel extends Logging {

  // Caching level applied to datasets that are scanned more than once.
  protected var intermediateStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK

  /**
   * Sets the storage level for intermediate datasets that require multiple passes
   * (default: ``MEMORY_AND_DISK``).
   */
  def setIntermediateStorageLevel(value: StorageLevel): this.type = {
    this.intermediateStorageLevel = value
    this
  }

  /**
   * Returns the storage level used for intermediate datasets that require multiple passes.
   */
  def getIntermediateStorageLevel: StorageLevel = intermediateStorageLevel
}

private[graphframes] trait WithMaxIter {
protected var maxIter: Option[Int] = None

Expand Down
17 changes: 17 additions & 0 deletions src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,23 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
assert(spark.sparkContext.getPersistentRDDs.size === priorCachedDFsSize)
}

test("set configuration from spark conf") {
  // These keys are set on the shared session; restore them in a finally block so the
  // overrides do not leak into other tests in the suite.
  val overriddenKeys = Seq(
    "spark.graphframes.connectedComponents.algorithm",
    "spark.graphframes.connectedComponents.broadcastthreshold",
    "spark.graphframes.connectedComponents.checkpointinterval",
    "spark.graphframes.connectedComponents.intermediatestoragelevel")
  try {
    // Mixed case on purpose: the reader is expected to lower-case the value.
    spark.conf.set("spark.graphframes.connectedComponents.algorithm", "GRAPHX")
    assert(Graphs.friends.connectedComponents.getAlgorithm == "graphx")

    spark.conf.set("spark.graphframes.connectedComponents.broadcastthreshold", "1000")
    assert(Graphs.friends.connectedComponents.getBroadcastThreshold == 1000)

    spark.conf.set("spark.graphframes.connectedComponents.checkpointinterval", "5")
    assert(Graphs.friends.connectedComponents.getCheckpointInterval == 5)

    // Lower case on purpose: the reader upper-cases before StorageLevel.fromString.
    spark.conf.set(
      "spark.graphframes.connectedComponents.intermediatestoragelevel",
      "memory_only")
    assert(
      Graphs.friends.connectedComponents.getIntermediateStorageLevel == StorageLevel.MEMORY_ONLY)
  } finally {
    overriddenKeys.foreach(spark.conf.unset)
  }
}

private def assertComponents[T: ClassTag: TypeTag](
actual: DataFrame,
expected: Set[Set[T]]): Unit = {
Expand Down