diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala index f94274082..5ca41e0dc 100644 --- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala @@ -27,6 +27,7 @@ import org.graphframes.GraphFrame import org.graphframes.Logging import org.graphframes.WithAlgorithmChoice import org.graphframes.WithCheckpointInterval +import org.graphframes.WithMaxIter import java.io.IOException import java.math.BigDecimal @@ -45,7 +46,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) extends Arguments with Logging with WithAlgorithmChoice - with WithCheckpointInterval { + with WithCheckpointInterval + with WithMaxIter { private var broadcastThreshold: Int = 1000000 setAlgorithm(ALGO_GRAPHFRAMES) @@ -105,7 +107,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame) runInGraphX = algorithm == ALGO_GRAPHX, broadcastThreshold = broadcastThreshold, checkpointInterval = checkpointInterval, - intermediateStorageLevel = intermediateStorageLevel) + intermediateStorageLevel = intermediateStorageLevel, + maxIter = maxIter) } } @@ -205,9 +208,9 @@ object ConnectedComponents extends Logging { new ConnectedComponents(graph).run() } - private def runGraphX(graph: GraphFrame): DataFrame = { + private def runGraphX(graph: GraphFrame, maxIter: Int): DataFrame = { val components = - org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX) + org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX, maxIter) GraphXConversions.fromGraphX(graph, components, vertexNames = Seq(COMPONENT)).vertices } @@ -216,9 +219,10 @@ object ConnectedComponents extends Logging { runInGraphX: Boolean, broadcastThreshold: Int, checkpointInterval: Int, - intermediateStorageLevel: StorageLevel): DataFrame = { + intermediateStorageLevel: StorageLevel, + maxIter: Option[Int]): DataFrame = { if (runInGraphX) { - return runGraphX(graph) + return runGraphX(graph, maxIter.getOrElse(Int.MaxValue)) } val spark = graph.spark diff --git a/src/main/scala/org/graphframes/lib/LabelPropagation.scala b/src/main/scala/org/graphframes/lib/LabelPropagation.scala index 00e20de7d..ff9af4d48 100644 --- a/src/main/scala/org/graphframes/lib/LabelPropagation.scala +++ b/src/main/scala/org/graphframes/lib/LabelPropagation.scala @@ -20,6 +20,7 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} import org.apache.spark.sql.DataFrame import org.graphframes.GraphFrame +import org.graphframes.WithMaxIter /** * Run static Label Propagation for detecting communities in networks. @@ -35,18 +36,9 @@ import org.graphframes.GraphFrame * The resulting DataFrame contains all the original vertex information and one additional column: * - label (`LongType`): label of community affiliation */ -class LabelPropagation private[graphframes] (private val graph: GraphFrame) extends Arguments { - - private var maxIter: Option[Int] = None - - /** - * The max number of iterations of LPA to be performed. Because this is a static implementation, - * the algorithm will run for exactly this many iterations. - */ - def maxIter(value: Int): this.type = { - maxIter = Some(value) - this - } +class LabelPropagation private[graphframes] (private val graph: GraphFrame) + extends Arguments + with WithMaxIter { def run(): DataFrame = { LabelPropagation.run(graph, check(maxIter, "maxIter")) diff --git a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala index 22057b2fc..1632485d9 100644 --- a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala +++ b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala @@ -19,6 +19,7 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} import org.graphframes.GraphFrame +import org.graphframes.WithMaxIter /** * Parallel Personalized PageRank algorithm implementation. @@ -52,10 +53,10 @@ import org.graphframes.GraphFrame * - weight (`DoubleType`): the normalized weight of this edge after running PageRank */ class ParallelPersonalizedPageRank private[graphframes] (private val graph: GraphFrame) - extends Arguments { + extends Arguments + with WithMaxIter { private var resetProb: Option[Double] = Some(0.15) - private var maxIter: Option[Int] = None private var srcIds: Array[Any] = Array() /** Source vertices for a Personalized Page Rank */ @@ -70,12 +71,6 @@ class ParallelPersonalizedPageRank private[graphframes] (private val graph: Grap this } - /** Number of iterations to run */ - def maxIter(value: Int): this.type = { - this.maxIter = Some(value) - this - } - def run(): GraphFrame = { require(maxIter != None, "Max number of iterations maxIter() must be provided") require(srcIds.nonEmpty, "Source vertices Ids sourceIds() must be provided") diff --git a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala index 21492cd6d..61847faf4 100644 --- a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala +++ b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Row import org.graphframes.GraphFrame import org.graphframes.GraphFramesUnreachableException +import org.graphframes.WithMaxIter /** * Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative @@ -39,9 +40,10 @@ import org.graphframes.GraphFramesUnreachableException * Returns a DataFrame with vertex attributes containing the trained model. See the object * (static) members for the names of the output columns. */ -class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends Arguments { +class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) + extends Arguments + with WithMaxIter { private var _rank: Int = 10 - private var _maxIter: Int = 2 private var _minVal: Double = 0.0 private var _maxVal: Double = 5.0 private var _gamma1: Double = 0.007 @@ -56,11 +58,6 @@ class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends A this } - def maxIter(value: Int): this.type = { - _maxIter = value - this - } - def minValue(value: Double): this.type = { _minVal = value this @@ -94,7 +91,7 @@ class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends A def run(): DataFrame = { val conf = new graphxlib.SVDPlusPlus.Conf( rank = _rank, - maxIters = _maxIter, + maxIters = maxIter.getOrElse(2), minVal = _minVal, maxVal = _maxVal, gamma1 = _gamma1, diff --git a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala index 41b2a2d8e..179e3d316 100644 --- a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala +++ b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala @@ -20,6 +20,7 @@ package org.graphframes.lib import org.apache.spark.graphx.{lib => graphxlib} import org.apache.spark.sql.DataFrame import org.graphframes.GraphFrame +import org.graphframes.WithMaxIter /** * Compute the strongly connected component (SCC) of each vertex and return a DataFrame with each @@ -29,14 +30,8 @@ import org.graphframes.GraphFrame * - component (`LongType`): unique ID for this component */ class StronglyConnectedComponents private[graphframes] (private val graph: GraphFrame) - extends Arguments { - - private var maxIter: Option[Int] = None - - def maxIter(value: Int): this.type = { - maxIter = Some(value) - this - } + extends Arguments + with WithMaxIter { def run(): DataFrame = { StronglyConnectedComponents.run(graph, check(maxIter, "maxIter")) diff --git a/src/main/scala/org/graphframes/mixins.scala b/src/main/scala/org/graphframes/mixins.scala index aa8b22afc..6a378903d 100644 --- a/src/main/scala/org/graphframes/mixins.scala +++ b/src/main/scala/org/graphframes/mixins.scala @@ -54,3 +54,15 @@ private[graphframes] trait WithCheckpointInterval extends Logging { */ def getCheckpointInterval: Int = checkpointInterval } + +private[graphframes] trait WithMaxIter { + protected var maxIter: Option[Int] = None + + /** + * The max number of iterations of algorithm to be performed. + */ + def maxIter(value: Int): this.type = { + maxIter = Some(value) + this + } +}