Skip to content

Commit e0408ac

Browse files
committed
Port Perl unit tests to Python and fix relevant bugs
1 parent e0f731f commit e0408ac

6 files changed

Lines changed: 660 additions & 9 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.pyc

rivescript/python.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def load(self, name, code):
5353
def call(self, rs, name, user, fields):
5454
"""Invoke a previously loaded object."""
5555
# Call the dynamic method.
56+
if not name in self._objects:
57+
return '[ERR: Object Not Found]'
5658
func = self._objects[name]
5759
reply = ''
5860
try:

rivescript/rivescript.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# Common regular expressions.
1717
re_equals = re.compile('\s*=\s*')
1818
re_ws = re.compile('\s+')
19-
re_objend = re.compile('<\s*object')
19+
re_objend = re.compile('^\s*<\s*object')
2020
re_weight = re.compile('\{weight=(\d+)\}')
2121
re_inherit = re.compile('\{inherits=(\d+)\}')
2222
re_wilds = re.compile('[\s\*\#\_]+')
@@ -25,6 +25,10 @@
2525
# Version of RiveScript we support.
2626
rs_version = 2.0
2727

28+
# Exportable constants.
29+
RS_ERR_MATCH = "ERR: No Reply Matched"
30+
RS_ERR_REPLY = "ERR: No Reply Found"
31+
2832

2933
class RiveScript:
3034
"""A RiveScript interpreter for Python 2 and 3."""
@@ -80,7 +84,7 @@ def VERSION(self=None):
8084

8185
def _say(self, message):
8286
if self._debug:
83-
print("[RS]", message)
87+
print("[RS] {}".format(message))
8488
if self._log:
8589
# Log it to the file.
8690
fh = open(self._log, 'a')
@@ -455,7 +459,7 @@ def _parse(self, fname, code):
455459
# Only try to parse a language we support.
456460
ontrig = ''
457461
if lang is None:
458-
self._warn("Trying to parse unknown programming language", fname, fileno)
462+
self._warn("Trying to parse unknown programming language", fname, lineno)
459463
lang = 'python' # Assume it's Python.
460464

461465
# See if we have a defined handler for this language.
@@ -1203,7 +1207,7 @@ def reply(self, user, msg):
12031207

12041208
return reply
12051209

1206-
def _format_message(self, msg):
1210+
def _format_message(self, msg, botreply=False):
12071211
"""Format a user's message for safe processing."""
12081212

12091213
# Make sure the string is Unicode for Python 2.
@@ -1220,6 +1224,10 @@ def _format_message(self, msg):
12201224
# (to protect from obvious XSS attacks).
12211225
if self._utf8:
12221226
msg = re.sub(r'[\\<>]', '', msg)
1227+
1228+
# For the bot's reply, also strip common punctuation.
1229+
if botreply:
1230+
msg = re.sub(r'[.?,!;:@#$%^&*()]', '', msg)
12231231
else:
12241232
# For everything else, strip all non-alphanumerics.
12251233
msg = self._strip_nasties(msg)
@@ -1301,7 +1309,7 @@ def _getreply(self, user, msg, context='normal', step=0):
13011309
lastReply = self._users[user]["__history__"]["reply"][0]
13021310

13031311
# Format the bot's last reply the same way as the human's.
1304-
lastReply = self._format_message(lastReply)
1312+
lastReply = self._format_message(lastReply, botreply=True)
13051313

13061314
self._say("lastReply: " + lastReply)
13071315

@@ -1474,9 +1482,9 @@ def _getreply(self, user, msg, context='normal', step=0):
14741482

14751483
# Still no reply?
14761484
if not foundMatch:
1477-
reply = "ERR: No Reply Matched"
1485+
reply = RS_ERR_MATCH
14781486
elif len(reply) == 0:
1479-
reply = "ERR: No Reply Found"
1487+
reply = RS_ERR_FOUND
14801488

14811489
self._say("Reply: " + reply)
14821490

@@ -1554,7 +1562,7 @@ def _reply_regexp(self, user, regexp):
15541562
# Simple replacements.
15551563
regexp = re.sub(r'\*', r'(.+?)', regexp) # Convert * into (.+?)
15561564
regexp = re.sub(r'#', r'(\d+?)', regexp) # Convert # into (\d+?)
1557-
regexp = re.sub(r'_', r'([A-Za-z]+?)', regexp) # Convert _ into (\w+?)
1565+
regexp = re.sub(r'_', r'(\w+?)', regexp) # Convert _ into (\w+?)
15581566
regexp = re.sub(r'\{weight=\d+\}', '', regexp) # Remove {weight} tags
15591567
regexp = re.sub(r'<zerowidthstar>', r'(.*?)', regexp)
15601568

@@ -1577,6 +1585,9 @@ def _reply_regexp(self, user, regexp):
15771585

15781586
regexp = re.sub(r'\s*\[' + re.escape(match) + '\]\s*', '(?:' + pipes + ')', regexp)
15791587

1588+
# _ wildcards can't match numbers!
1589+
regexp = re.sub(r'\\w', r'[A-Za-z]', regexp)
1590+
15801591
# Filter in arrays.
15811592
arrays = re.findall(r'\@(.+?)\b', regexp)
15821593
for array in arrays:
@@ -1943,7 +1954,6 @@ def _find_trigger_by_inheritence(self, topic, trig, depth=0):
19431954
return match
19441955

19451956
# Don't know what else to do!
1946-
self._warn("User matched a trigger, " + trig + ", but I can't find out what topic it belongs to!")
19471957
return None
19481958

19491959
def _get_topic_tree(self, topic, depth=0):

tests/__init__.py

Whitespace-only changes.

tests/config.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#!/usr/bin/env python
2+
3+
"""Utility functions for the unit tests."""
4+
5+
import unittest
6+
from rivescript import RiveScript
7+
8+
class RiveScriptTestCase(unittest.TestCase):
9+
"""Base class for all RiveScript test cases, with helper functions."""
10+
11+
def setUp(self, **kwargs):
12+
self.rs = None # Local RiveScript bot object
13+
self.username = "localuser"
14+
15+
16+
def tearDown(self):
17+
pass
18+
19+
20+
def new(self, code, **kwargs):
21+
"""Make a bot and stream in the code."""
22+
self.rs = RiveScript(**kwargs)
23+
self.extend(code)
24+
25+
26+
def extend(self, code):
27+
"""Stream code into the bot."""
28+
lines = code.split("\n")
29+
self.rs.stream(lines)
30+
self.rs.sort_replies()
31+
32+
33+
def reply(self, message, expected):
34+
"""Test that the user's message gets the expected response."""
35+
reply = self.rs.reply(self.username, message)
36+
self.assertEqual(reply, expected)
37+
38+
39+
def uservar(self, var, expected):
40+
"""Test the value of a user variable."""
41+
value = self.rs.get_uservar(self.username, var)
42+
self.assertEqual(value, expected)

0 commit comments

Comments
 (0)