From 0886b9c7c248098b1422a166e182ae499806b9ab Mon Sep 17 00:00:00 2001
From: semyonsinchenko
Date: Sat, 19 Apr 2025 16:53:09 +0200
Subject: [PATCH] Add a new mixin WithMaxIter and propagate it to GraphX CC
---
.../graphframes/lib/ConnectedComponents.scala | 16 ++++++++++------
.../org/graphframes/lib/LabelPropagation.scala | 16 ++++------------
.../lib/ParallelPersonalizedPageRank.scala | 11 +++--------
.../scala/org/graphframes/lib/SVDPlusPlus.scala | 13 +++++--------
.../lib/StronglyConnectedComponents.scala | 11 +++--------
src/main/scala/org/graphframes/mixins.scala | 12 ++++++++++++
6 files changed, 37 insertions(+), 42 deletions(-)
diff --git a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
index f94274082..5ca41e0dc 100644
--- a/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
+++ b/src/main/scala/org/graphframes/lib/ConnectedComponents.scala
@@ -27,6 +27,7 @@ import org.graphframes.GraphFrame
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval
+import org.graphframes.WithMaxIter
import java.io.IOException
import java.math.BigDecimal
@@ -45,7 +46,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
extends Arguments
with Logging
with WithAlgorithmChoice
- with WithCheckpointInterval {
+ with WithCheckpointInterval
+ with WithMaxIter {
private var broadcastThreshold: Int = 1000000
setAlgorithm(ALGO_GRAPHFRAMES)
@@ -105,7 +107,8 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
runInGraphX = algorithm == ALGO_GRAPHX,
broadcastThreshold = broadcastThreshold,
checkpointInterval = checkpointInterval,
- intermediateStorageLevel = intermediateStorageLevel)
+ intermediateStorageLevel = intermediateStorageLevel,
+ maxIter = maxIter)
}
}
@@ -205,9 +208,9 @@ object ConnectedComponents extends Logging {
new ConnectedComponents(graph).run()
}
- private def runGraphX(graph: GraphFrame): DataFrame = {
+ private def runGraphX(graph: GraphFrame, maxIter: Int): DataFrame = {
val components =
- org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX)
+ org.apache.spark.graphx.lib.ConnectedComponents.run(graph.cachedTopologyGraphX, maxIter)
GraphXConversions.fromGraphX(graph, components, vertexNames = Seq(COMPONENT)).vertices
}
@@ -216,9 +219,10 @@ object ConnectedComponents extends Logging {
runInGraphX: Boolean,
broadcastThreshold: Int,
checkpointInterval: Int,
- intermediateStorageLevel: StorageLevel): DataFrame = {
+ intermediateStorageLevel: StorageLevel,
+ maxIter: Option[Int]): DataFrame = {
if (runInGraphX) {
- return runGraphX(graph)
+ return runGraphX(graph, maxIter.getOrElse(Int.MaxValue))
}
val spark = graph.spark
diff --git a/src/main/scala/org/graphframes/lib/LabelPropagation.scala b/src/main/scala/org/graphframes/lib/LabelPropagation.scala
index 00e20de7d..ff9af4d48 100644
--- a/src/main/scala/org/graphframes/lib/LabelPropagation.scala
+++ b/src/main/scala/org/graphframes/lib/LabelPropagation.scala
@@ -20,6 +20,7 @@ package org.graphframes.lib
import org.apache.spark.graphx.{lib => graphxlib}
import org.apache.spark.sql.DataFrame
import org.graphframes.GraphFrame
+import org.graphframes.WithMaxIter
/**
* Run static Label Propagation for detecting communities in networks.
@@ -35,18 +36,9 @@ import org.graphframes.GraphFrame
* The resulting DataFrame contains all the original vertex information and one additional column:
* - label (`LongType`): label of community affiliation
*/
-class LabelPropagation private[graphframes] (private val graph: GraphFrame) extends Arguments {
-
- private var maxIter: Option[Int] = None
-
- /**
- * The max number of iterations of LPA to be performed. Because this is a static implementation,
- * the algorithm will run for exactly this many iterations.
- */
- def maxIter(value: Int): this.type = {
- maxIter = Some(value)
- this
- }
+class LabelPropagation private[graphframes] (private val graph: GraphFrame)
+ extends Arguments
+ with WithMaxIter {
def run(): DataFrame = {
LabelPropagation.run(graph, check(maxIter, "maxIter"))
diff --git a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala
index 22057b2fc..1632485d9 100644
--- a/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala
+++ b/src/main/scala/org/graphframes/lib/ParallelPersonalizedPageRank.scala
@@ -19,6 +19,7 @@ package org.graphframes.lib
import org.apache.spark.graphx.{lib => graphxlib}
import org.graphframes.GraphFrame
+import org.graphframes.WithMaxIter
/**
* Parallel Personalized PageRank algorithm implementation.
@@ -52,10 +53,10 @@ import org.graphframes.GraphFrame
* - weight (`DoubleType`): the normalized weight of this edge after running PageRank
*/
class ParallelPersonalizedPageRank private[graphframes] (private val graph: GraphFrame)
- extends Arguments {
+ extends Arguments
+ with WithMaxIter {
private var resetProb: Option[Double] = Some(0.15)
- private var maxIter: Option[Int] = None
private var srcIds: Array[Any] = Array()
/** Source vertices for a Personalized Page Rank */
@@ -70,12 +71,6 @@ class ParallelPersonalizedPageRank private[graphframes] (private val graph: Grap
this
}
- /** Number of iterations to run */
- def maxIter(value: Int): this.type = {
- this.maxIter = Some(value)
- this
- }
-
def run(): GraphFrame = {
require(maxIter != None, "Max number of iterations maxIter() must be provided")
require(srcIds.nonEmpty, "Source vertices Ids sourceIds() must be provided")
diff --git a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala
index 21492cd6d..61847faf4 100644
--- a/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala
+++ b/src/main/scala/org/graphframes/lib/SVDPlusPlus.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.graphframes.GraphFrame
import org.graphframes.GraphFramesUnreachableException
+import org.graphframes.WithMaxIter
/**
* Implement SVD++ based on "Factorization Meets the Neighborhood: a Multifaceted Collaborative
@@ -39,9 +40,10 @@ import org.graphframes.GraphFramesUnreachableException
* Returns a DataFrame with vertex attributes containing the trained model. See the object
* (static) members for the names of the output columns.
*/
-class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends Arguments {
+class SVDPlusPlus private[graphframes] (private val graph: GraphFrame)
+ extends Arguments
+ with WithMaxIter {
private var _rank: Int = 10
- private var _maxIter: Int = 2
private var _minVal: Double = 0.0
private var _maxVal: Double = 5.0
private var _gamma1: Double = 0.007
@@ -56,11 +58,6 @@ class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends A
this
}
- def maxIter(value: Int): this.type = {
- _maxIter = value
- this
- }
-
def minValue(value: Double): this.type = {
_minVal = value
this
@@ -94,7 +91,7 @@ class SVDPlusPlus private[graphframes] (private val graph: GraphFrame) extends A
def run(): DataFrame = {
val conf = new graphxlib.SVDPlusPlus.Conf(
rank = _rank,
- maxIters = _maxIter,
+ maxIters = maxIter.getOrElse(2),
minVal = _minVal,
maxVal = _maxVal,
gamma1 = _gamma1,
diff --git a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala
index 41b2a2d8e..179e3d316 100644
--- a/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala
+++ b/src/main/scala/org/graphframes/lib/StronglyConnectedComponents.scala
@@ -20,6 +20,7 @@ package org.graphframes.lib
import org.apache.spark.graphx.{lib => graphxlib}
import org.apache.spark.sql.DataFrame
import org.graphframes.GraphFrame
+import org.graphframes.WithMaxIter
/**
* Compute the strongly connected component (SCC) of each vertex and return a DataFrame with each
@@ -29,14 +30,8 @@ import org.graphframes.GraphFrame
* - component (`LongType`): unique ID for this component
*/
class StronglyConnectedComponents private[graphframes] (private val graph: GraphFrame)
- extends Arguments {
-
- private var maxIter: Option[Int] = None
-
- def maxIter(value: Int): this.type = {
- maxIter = Some(value)
- this
- }
+ extends Arguments
+ with WithMaxIter {
def run(): DataFrame = {
StronglyConnectedComponents.run(graph, check(maxIter, "maxIter"))
diff --git a/src/main/scala/org/graphframes/mixins.scala b/src/main/scala/org/graphframes/mixins.scala
index aa8b22afc..6a378903d 100644
--- a/src/main/scala/org/graphframes/mixins.scala
+++ b/src/main/scala/org/graphframes/mixins.scala
@@ -54,3 +54,15 @@ private[graphframes] trait WithCheckpointInterval extends Logging {
*/
def getCheckpointInterval: Int = checkpointInterval
}
+
+private[graphframes] trait WithMaxIter {
+ protected var maxIter: Option[Int] = None
+
+ /**
+ * The maximum number of iterations of the algorithm to be performed.
+ */
+ def maxIter(value: Int): this.type = {
+ maxIter = Some(value)
+ this
+ }
+}