Skip to content

Commit 1570e20

Browse files
zsoljorenham
authored andcommitted
add helper to convert nodes to matchers (Instagram#1351)
* add helper to convert nodes to matchers * suppress type error
1 parent edfc8cf commit 1570e20

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

libcst/helpers/matchers.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
#
6+
7+
from dataclasses import fields, is_dataclass, MISSING
8+
9+
from libcst import matchers
10+
from libcst._nodes.base import CSTNode
11+
12+
13+
def node_to_matcher(
14+
node: CSTNode, *, match_syntactic_trivia: bool = False
15+
) -> matchers.BaseMatcherNode:
16+
"""Convert a concrete node to a matcher."""
17+
if not is_dataclass(node):
18+
raise ValueError(f"{node} is not a CSTNode")
19+
20+
attrs = {}
21+
for field in fields(node):
22+
name = field.name
23+
child = getattr(node, name)
24+
if not match_syntactic_trivia and field.name.startswith("whitespace"):
25+
# Not all nodes have whitespace fields, some have multiple, but they all
26+
# start with whitespace*
27+
child = matchers.DoNotCare()
28+
elif field.default is not MISSING and child == field.default:
29+
child = matchers.DoNotCare()
30+
# pyre-ignore[29]: Union[MISSING_TYPE, ...] is not a function.
31+
elif field.default_factory is not MISSING and child == field.default_factory():
32+
child = matchers.DoNotCare()
33+
elif isinstance(child, (list, tuple)):
34+
child = type(child)(
35+
node_to_matcher(item, match_syntactic_trivia=match_syntactic_trivia)
36+
for item in child
37+
)
38+
elif hasattr(matchers, type(child).__name__):
39+
child = node_to_matcher(
40+
child, match_syntactic_trivia=match_syntactic_trivia
41+
)
42+
attrs[name] = child
43+
44+
matcher = getattr(matchers, type(node).__name__)
45+
return matcher(**attrs)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
#
6+
7+
from libcst import parse_expression, parse_statement
8+
from libcst.helpers.matchers import node_to_matcher
9+
from libcst.matchers import matches
10+
from libcst.testing.utils import data_provider, UnitTest
11+
12+
13+
class MatchersTest(UnitTest):
14+
@data_provider(
15+
(
16+
('"some string"',),
17+
("call(some, **kwargs)",),
18+
("a[b.c]",),
19+
("[1 for _ in range(99) if False]",),
20+
)
21+
)
22+
def test_reflexive_expressions(self, code: str) -> None:
23+
node = parse_expression(code)
24+
matcher = node_to_matcher(node)
25+
self.assertTrue(matches(node, matcher))
26+
27+
@data_provider(
28+
(
29+
("def foo(a) -> None: pass",),
30+
("class F: ...",),
31+
("foo: bar",),
32+
)
33+
)
34+
def test_reflexive_statements(self, code: str) -> None:
35+
node = parse_statement(code)
36+
matcher = node_to_matcher(node)
37+
self.assertTrue(matches(node, matcher))
38+
39+
def test_whitespace(self) -> None:
40+
code_ws = parse_expression("(foo , bar )")
41+
code = parse_expression("(foo,bar)")
42+
self.assertTrue(
43+
matches(
44+
code,
45+
node_to_matcher(code_ws),
46+
)
47+
)
48+
self.assertFalse(
49+
matches(
50+
code,
51+
node_to_matcher(code_ws, match_syntactic_trivia=True),
52+
)
53+
)

0 commit comments

Comments
 (0)