4 changes: 3 additions & 1 deletion .github/workflows/scala-ci.yml
@@ -36,7 +36,9 @@ jobs:
~/.ivy2/cache
key: sbt-ivy-cache-spark-${{ matrix.spark-version}}-scala-${{ matrix.scala-version }}-java-${{ matrix.java-version }}
- name: Check scalafmt
run: build/sbt root/scalafmtCheckAll
run: build/sbt scalafmtCheckAll
- name: Check scalastyle
run: build/sbt scalafixAll
- name: Build and Test
run: build/sbt -v ++${{ matrix.scala-version }} -Dspark.version=${{ matrix.spark-version }} coverage test coverageReport
- uses: codecov/codecov-action@v3
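The CI workflow now runs the formatting check and Scalafix as separate steps before the build. For running the same checks locally before pushing, one option is an sbt command alias; this is only a sketch (the alias name lintCheck is hypothetical and not part of this PR), shown because `scalafixAll --check` fails instead of rewriting sources:

// Hypothetical convenience alias (not in this PR): run the same checks the CI runs.
addCommandAlias("lintCheck", ";scalafmtCheckAll; scalafixAll --check")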
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
@@ -25,3 +25,11 @@ repos:
language: system
types: [scala]
pass_filenames: false

- id: scalafix
name: scalafix
entry: build/sbt scalafixAll
language: system
types: [scala]
pass_filenames: false

8 changes: 8 additions & 0 deletions .scalafix.conf
@@ -0,0 +1,8 @@
rules = [
RemoveUnused
DisableSyntax
ProcedureSyntax
RedundantSyntax
OrganizeImports
ExplicitResultTypes
]
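These are standard Scalafix rules: RemoveUnused and OrganizeImports clean up imports, ProcedureSyntax and RedundantSyntax modernize old syntax, ExplicitResultTypes adds result types to public members, and DisableSyntax forbids whatever constructs are configured for it. As an illustration only (not code from this PR) of what RemoveUnused, ProcedureSyntax, and RedundantSyntax would rewrite, given the unused-import warnings enabled in build.sbt below:

// Before: unused import, procedure syntax, redundant s-interpolator
import scala.collection.mutable.ListBuffer
def greet(name: String) { println(s"hello, world") }

// After running the rules above:
def greet(name: String): Unit = { println("hello, world") }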
24 changes: 21 additions & 3 deletions build.sbt
@@ -24,6 +24,10 @@ ThisBuild / scalaVersion := scalaVer
ThisBuild / organization := "org.graphframes"
ThisBuild / crossScalaVersions := Seq("2.12.18", "2.13.8")

// Scalafix configuration
ThisBuild / semanticdbEnabled := true
ThisBuild / semanticdbVersion := "4.8.10" // The maximal version that supports both 2.13.8 and 2.12.18

lazy val commonSetting = Seq(
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-graphx" % sparkVer % "provided" cross CrossVersion.for3Use2_13,
@@ -55,7 +59,22 @@ lazy val commonSetting = Seq(
"--add-opens=java.base/java.nio=ALL-UNNAMED",
"--add-opens=java.base/java.lang.invoke=ALL-UNNAMED",
"--add-opens=java.base/java.util=ALL-UNNAMED"),
credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials"))
credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials"),

// Scalafix
scalacOptions ++= Seq(
"-Xlint", // to enforce code quality checks
if (scalaVersion.value.startsWith("2.12")) {
// fail on warning
"-Xfatal-warnings"
} else {
"-Werror" // the same but in 2.13
},
// scalastyle related things
if (scalaVersion.value.startsWith("2.12"))
"-Ywarn-unused-import"
else
"-Wunused:imports"))

lazy val root = (project in file("."))
.settings(
@@ -108,5 +127,4 @@ lazy val connect = (project in file("graphframes-connect"))
case x =>
val oldStrategy = (assembly / assemblyMergeStrategy).value
oldStrategy(x)
}
)
})
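Two notes on the build changes: SemanticDB output is required because RemoveUnused, OrganizeImports, and ExplicitResultTypes are semantic rules, and the `-Ywarn-unused-import` / `-Wunused:imports` flags are what feed RemoveUnused its diagnostics. The version-dependent flags are chosen with `if` expressions inside the Seq; an equivalent formulation (sketch only, not part of this PR) uses CrossVersion.partialVersion:

// Sketch only: the same version-dependent flags via partialVersion.
scalacOptions ++= (CrossVersion.partialVersion(scalaVersion.value) match {
  case Some((2, 12)) => Seq("-Xfatal-warnings", "-Ywarn-unused-import")
  case _             => Seq("-Werror", "-Wunused:imports")
})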
@@ -1,12 +1,10 @@
package org.apache.spark.sql.graphframes

import org.graphframes.connect.proto.GraphFramesAPI

import com.google.protobuf
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.plugin.RelationPlugin

import com.google.protobuf
import org.graphframes.connect.proto.GraphFramesAPI

class GraphFramesConnect extends RelationPlugin {
override def transform(
@@ -2,16 +2,23 @@
// Same about Column helper object.
package org.apache.spark.sql.graphframes

import scala.jdk.CollectionConverters._
import org.graphframes.{GraphFrame, GraphFramesUnreachableException}
import org.graphframes.connect.proto.{ColumnOrExpression, GraphFramesAPI, StringOrLongID}
import com.google.protobuf.ByteString
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.lit
import org.graphframes.GraphFrame
import org.graphframes.GraphFramesUnreachableException
import org.graphframes.connect.proto.ColumnOrExpression
import org.graphframes.connect.proto.ColumnOrExpression.ColOrExprCase
import org.graphframes.connect.proto.GraphFramesAPI
import org.graphframes.connect.proto.GraphFramesAPI.MethodCase
import org.graphframes.connect.proto.StringOrLongID
import org.graphframes.connect.proto.StringOrLongID.IdCase
import org.apache.spark.sql.{Column, DataFrame, Dataset}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.functions.{col, expr, lit}
import com.google.protobuf.ByteString

import scala.jdk.CollectionConverters._

object GraphFramesConnectUtils {
private[graphframes] def parseColumnOrExpression(
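Most of the churn in this file and in the Scala files below is mechanical: OrganizeImports replaces grouped imports with one fully qualified import per line and sorts the result (its default groupedImports = Explode behavior, assuming the .scalafix.conf above leaves that setting unset). Illustration only:

// Before: grouped import
import org.apache.spark.sql.{Column, DataFrame, Dataset}

// After OrganizeImports:
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dataset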
3 changes: 3 additions & 0 deletions project/plugins.sbt
@@ -13,3 +13,6 @@ addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.4")
// Protobuf things needed for the Spark Connect
addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.7")
libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.10.10"

// Scalafix
addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.14.2")
36 changes: 24 additions & 12 deletions src/main/scala/org/graphframes/GraphFrame.scala
@@ -17,19 +17,27 @@

package org.graphframes

import java.util.Random

import scala.reflect.runtime.universe.TypeTag

import org.graphframes.lib._
import org.graphframes.pattern._

import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.graphx.Edge
import org.apache.spark.graphx.Graph
import org.apache.spark.ml.clustering.PowerIterationClustering
import org.apache.spark.sql._
import org.apache.spark.sql.functions.{array, broadcast, col, count, explode, expr, lit, max, monotonically_increasing_id, struct, udf}
import org.apache.spark.sql.functions.array
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.count
import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.monotonically_increasing_id
import org.apache.spark.sql.functions.struct
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.graphframes.lib._
import org.graphframes.pattern._

import java.util.Random
import scala.reflect.runtime.universe.TypeTag

/**
* A representation of a graph using `DataFrame`s.
@@ -189,20 +197,26 @@ class GraphFrame private (
if (hasIntegralIdType) {
val vv = vertices.select(col(ID).cast(LongType), nestAsCol(vertices, ATTR)).rdd.map {
case Row(id: Long, attr: Row) => (id, attr)
case _ => throw new GraphFramesUnreachableException()
}
val ee = edges
.select(col(SRC).cast(LongType), col(DST).cast(LongType), nestAsCol(edges, ATTR))
.rdd
.map { case Row(srcId: Long, dstId: Long, attr: Row) => Edge(srcId, dstId, attr) }
.map {
case Row(srcId: Long, dstId: Long, attr: Row) => Edge(srcId, dstId, attr)
case _ => throw new GraphFramesUnreachableException()
}
Graph(vv, ee)
} else {
// Compute Long vertex IDs
val vv = indexedVertices.select(LONG_ID, ATTR).rdd.map {
case Row(long_id: Long, attr: Row) => (long_id, attr)
case _ => throw new GraphFramesUnreachableException()
}
val ee = indexedEdges.select(LONG_SRC, LONG_DST, ATTR).rdd.map {
case Row(long_src: Long, long_dst: Long, attr: Row) =>
Edge(long_src, long_dst, attr)
case _ => throw new GraphFramesUnreachableException()
}
Graph(vv, ee)
}
@@ -686,8 +700,6 @@ object GraphFrame extends Serializable with Logging {
joinCol: String,
hubs: Set[T],
logPrefix: String): DataFrame = {
val spark = a.sparkSession
import spark.implicits._
if (hubs.isEmpty) {
// No skew. Do regular join.
a.join(b, joinCol)
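The added wildcard branches make the previously partial matches over Row total: a row that does not have the expected shape now fails with the project's own GraphFramesUnreachableException instead of an anonymous scala.MatchError. A minimal sketch of the same defensive pattern (illustration only; the exception here is a stand-in for GraphFramesUnreachableException):

import org.apache.spark.sql.Row

def toPair(row: Row): (Long, Row) = row match {
  case Row(id: Long, attr: Row) => (id, attr)
  case _ => throw new IllegalStateException(s"unexpected row shape: $row")
}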
5 changes: 2 additions & 3 deletions src/main/scala/org/graphframes/GraphFramePythonAPI.scala
@@ -1,13 +1,12 @@
package org.graphframes

import org.apache.spark.sql.DataFrame

import org.graphframes.lib.AggregateMessages
import org.graphframes.examples.Graphs
import org.graphframes.lib.AggregateMessages

private[graphframes] class GraphFramePythonAPI {

def createGraph(v: DataFrame, e: DataFrame) = GraphFrame(v, e)
def createGraph(v: DataFrame, e: DataFrame): GraphFrame = GraphFrame(v, e)

val ID: String = GraphFrame.ID
val SRC: String = GraphFrame.SRC
3 changes: 2 additions & 1 deletion src/main/scala/org/graphframes/Logging.scala
@@ -17,7 +17,8 @@

package org.graphframes

import org.slf4j.{Logger, LoggerFactory}
import org.slf4j.Logger
import org.slf4j.LoggerFactory

// This needs to be accessible to org.apache.spark.graphx.lib.backport
private[org] trait Logging {
16 changes: 11 additions & 5 deletions src/main/scala/org/graphframes/examples/BeliefPropagation.scala
@@ -17,11 +17,17 @@

package org.graphframes.examples

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{Graph, VertexRDD, Edge => GXEdge}
import org.apache.spark.sql.{Column, Row, SparkSession}
import org.apache.spark.sql.functions.{col, lit, sum, udf, when}

import org.apache.spark.graphx.Graph
import org.apache.spark.graphx.VertexRDD
import org.apache.spark.graphx.{Edge => GXEdge}
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.functions.when
import org.graphframes.GraphFrame
import org.graphframes.examples.Graphs.gridIsingModel
import org.graphframes.lib.AggregateMessages
13 changes: 7 additions & 6 deletions src/main/scala/org/graphframes/examples/Graphs.scala
@@ -17,15 +17,16 @@

package org.graphframes.examples

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, randn, udf}

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.randn
import org.apache.spark.sql.functions.udf
import org.graphframes.GraphFrame
import org.graphframes.GraphFrame._

import scala.reflect.runtime.universe.TypeTag

class Graphs private[graphframes] () {
// Note: this cannot be values: we are creating and destroying spark contexts during the tests,
// and turning these into vals means we would hold onto a potentially destroyed spark context.
@@ -103,7 +104,7 @@ class Graphs private[graphframes] () {
v1 <- n until (2 * n)
v2 <- n until (2 * n)
} yield (v1.toLong, v2.toLong, s"$v1-$v2")
val edges = edges1 ++ edges2 :+ (0L, n.toLong, s"0-$n")
val edges = edges1 ++ edges2 ++ Seq((0L, n.toLong, s"0-$n"))
val vertices = (0 until (2 * n)).map { v => (v.toLong, s"$v", v) }
val e = spark.createDataFrame(edges).toDF("src", "dst", "e_attr1")
val v = spark.createDataFrame(vertices).toDF("id", "v_attr1", "v_attr2")
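The change from `:+ (0L, n.toLong, s"0-$n")` to `++ Seq(...)` avoids an -Xlint warning: in infix position, `:+ (a, b, c)` looks like a three-argument call that the compiler adapts into a tuple ("adapting argument list by creating a 3-tuple"), which becomes a hard error under -Xfatal-warnings / -Werror. Illustration only:

val xs = Seq((1L, 2L, "a"))
// xs :+ (3L, 4L, "b")                   // -Xlint: adapting argument list by creating a 3-tuple
val explicit = xs :+ ((3L, 4L, "b"))     // explicit tuple silences the warning
val chosen   = xs ++ Seq((3L, 4L, "b"))  // the form used in this PR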
@@ -2,12 +2,8 @@ package org.graphframes.examples

import java.net.URL
import java.nio.file._
import java.util.Properties

import scala.sys.process._

import org.graphframes.GraphFrame

object LDBCUtils {
private val LDBC_URL_PREFIX = "https://datasets.ldbcouncil.org/graphalytics/"
private val bufferSize = 8192 // 8Kb
@@ -37,7 +33,7 @@ object LDBCUtils {

private def checkZSTD(): Unit = {
try {
s"zstd --version".!
"zstd --version".!
} catch {
case e: Exception =>
throw new RuntimeException(
14 changes: 8 additions & 6 deletions src/main/scala/org/graphframes/lib/AggregateMessages.scala
@@ -17,10 +17,12 @@

package org.graphframes.lib

import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.{Column, DataFrame}

import org.graphframes.{GraphFrame, Logging}
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.graphframes.GraphFrame
import org.graphframes.Logging

/**
* This is a primitive for implementing graph algorithms. This method aggregates messages from the
@@ -101,8 +103,8 @@ class AggregateMessages private[graphframes] (private val g: GraphFrame)
def agg(aggCol: Column): DataFrame = {
require(
msgToSrc.nonEmpty || msgToDst.nonEmpty,
s"To run GraphFrame.aggregateMessages," +
s" messages must be sent to src, dst, or both. Set using sendToSrc(), sendToDst().")
"To run GraphFrame.aggregateMessages," +
" messages must be sent to src, dst, or both. Set using sendToSrc(), sendToDst().")
val triplets = g.triplets
val sentMsgsToSrc = msgToSrc.map { msg =>
val msgsToSrc =
13 changes: 8 additions & 5 deletions src/main/scala/org/graphframes/lib/BFS.scala
@@ -17,12 +17,15 @@

package org.graphframes.lib

import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.{Column, DataFrame, Row}

import org.graphframes.{GraphFrame, Logging}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.graphframes.GraphFrame
import org.graphframes.GraphFrame.nestAsCol
import org.graphframes.Logging

/**
* Breadth-first search (BFS)
@@ -193,7 +196,7 @@ private object BFS extends Logging with Serializable {
// TODO: Avoid crossing paths; i.e., touch each vertex at most once.
val previousVertexChecks = Range(1, iter + 1)
.map(i => paths(s"v$i.id") =!= paths(nextVertex + ".id"))
.foldLeft(paths(s"from.id") =!= paths(nextVertex + ".id"))((c1, c2) => c1 && c2)
.foldLeft(paths("from.id") =!= paths(nextVertex + ".id"))((c1, c2) => c1 && c2)
paths = paths.filter(previousVertexChecks)
}
// Check if done by applying toExpr to column nextVertex
17 changes: 8 additions & 9 deletions src/main/scala/org/graphframes/lib/ConnectedComponents.scala
@@ -17,19 +17,20 @@

package org.graphframes.lib

import java.io.IOException
import java.math.BigDecimal
import java.util.UUID

import org.apache.hadoop.fs.Path

import org.graphframes.{GraphFrame, Logging}
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.storage.StorageLevel
import org.graphframes.GraphFrame
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice

import java.io.IOException
import java.math.BigDecimal
import java.util.UUID

/**
* Connected Components algorithm.
*
@@ -44,8 +45,6 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
with Logging
with WithAlgorithmChoice {

import org.graphframes.lib.ConnectedComponents._

private var broadcastThreshold: Int = 1000000
setAlgorithm(ALGO_GRAPHFRAMES)
