Skip to content

Commit 5060edb

Browse files
committed
undirected edge
1 parent d801d34 commit 5060edb

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

core/src/main/scala/org/graphframes/GraphFrame.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,11 @@ class GraphFrame private (
373373
* @group motif
374374
*/
375375
def find(pattern: String): DataFrame = {
376-
val VarLengthPattern = """\((\w+)\)-\[(\w*)\*(\d*)\.\.(\d*)\]->\((\w+)\)""".r
376+
val VarLengthPattern = """\((\w+)\)-\[(\w*)\*(\d*)\.\.(\d*)\]-(>?)\((\w+)\)""".r
377+
val UndirectedPattern = """\((\w+)\)-\[(\w*)\]-\((\w+)\)""".r
378+
377379
pattern match {
378-
case VarLengthPattern(src, name, min, max, dst) =>
380+
case VarLengthPattern(src, name, min, max, direction, dst) =>
379381
if (min.isEmpty || max.isEmpty) {
380382
throw new InvalidParseException(
381383
s"Unbounded length patten ${pattern} is not supported! " +
@@ -384,9 +386,22 @@ class GraphFrame private (
384386
val strToSeq: Seq[String] = (min.toInt to max.toInt).reverse.map { hop =>
385387
s"($src)-[$name*$hop]->($dst)"
386388
}
387-
strToSeq
389+
val strToSeqReverse: Seq[String] = if (direction.isEmpty) {
390+
(min.toInt to max.toInt).reverse.map(hop => s"($dst)-[$name*$hop]->($src)")
391+
} else {
392+
Seq.empty[String]
393+
}
394+
395+
(strToSeq ++ strToSeqReverse)
388396
.map(findAugmentedPatterns)
389397
.reduce((a, b) => a.unionByName(b, allowMissingColumns = true))
398+
399+
case UndirectedPattern(src, name, dst) =>
400+
val out: DataFrame = findAugmentedPatterns(s"($src)-[$name]->($dst)")
401+
val in: DataFrame = findAugmentedPatterns(s"($dst)-[$name]->($src)")
402+
403+
out.unionByName(in)
404+
390405
case _ =>
391406
findAugmentedPatterns(pattern)
392407
}

core/src/test/scala/org/graphframes/PatternMatchSuite.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,51 @@ class PatternMatchSuite extends SparkFunSuite with GraphFrameTestSparkContext {
658658
assert(varEdge.except(unionEdge).isEmpty && unionEdge.except(varEdge).isEmpty)
659659
}
660660

661+
test("undirected edge") {
  // "(u)-[]-(v)" carries no arrow, so an edge between u and v should be matched
  // regardless of which endpoint is its source.
  val expected = Set(Row(0L, 1L), Row(0L, 2L))

  val matchesFromZero = g.find("(u)-[]-(v)").where("u.id == 0")
  val res = matchesFromZero.select("u.id", "v.id").collect().toSet

  compareResultToExpected(res, expected)
}
673+
674+
test("undirected with edge name") {
  // With a named edge the edge struct itself is selected, so the expected rows keep
  // each edge's own src/dst orientation even when vertex 0 sits on the dst side.
  val expected = Set(Row(0L, 1L, "friend"), Row(1L, 0L, "follow"), Row(2L, 0L, "unknown"))

  val touchingZero = g.find("(u)-[e]-(v)").where("u.id == 0")
  val res = touchingZero.select("e.src", "e.dst", "e.relationship").collect().toSet

  compareResultToExpected(res, expected)
}
686+
687+
test("undirected var-length pattern") {
  // The undirected variable-length motif is compared against the union of the two
  // directed variants (u -> v and v -> u), all restricted to u.id == 2.
  val res = g.find("(u)-[e*1..3]-(v)").where("u.id == 2")

  val forward = g.find("(u)-[e*1..3]->(v)").where("u.id == 2")
  val backward = g.find("(v)-[e*1..3]->(u)").where("u.id == 2")
  val expected = forward.unionByName(backward, allowMissingColumns = true)

  // Same columns first, then same rows in both directions of the set difference.
  assert(res.schema === expected.schema)
  assert(res.except(expected).isEmpty && expected.except(res).isEmpty)
}
705+
661706
test("stateful predicates via UDFs") {
662707
val chain4 = g
663708
.find("(a)-[ab]->(b); (b)-[bc]->(c); (c)-[cd]->(d)")

0 commit comments

Comments
 (0)