Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ async def test_async_with_sync_passed_in_closed_in_async() -> None:
@pytest.mark.asyncio
async def test_sync_within_event_loop_executor() -> None:
"""Test sync version still works from an executor within an event loop."""

def sync_code():
zc = Zeroconf(interfaces=['127.0.0.1'])
assert zc.get_service_info("_neverused._tcp.local.", "xneverused._neverused._tcp.local.", 10) is None
Expand Down
25 changes: 25 additions & 0 deletions tests/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,31 @@ def test_dns_service_record_hashablity():
assert len(record_set) == 4


def test_dns_nsec_record_hashablity():
"""Test DNSNsec are hashable."""
nsec1 = r.DNSNsec(
'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2, 3]
)
nsec2 = r.DNSNsec(
'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2]
)

record_set = set([nsec1, nsec2])
assert len(record_set) == 2

record_set.add(nsec1)
assert len(record_set) == 2

nsec2_dupe = r.DNSNsec(
'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'irrelevant', [1, 2]
)
assert nsec2 == nsec2_dupe
assert nsec2.__hash__() == nsec2_dupe.__hash__()

record_set.add(nsec2_dupe)
assert len(record_set) == 2


def test_rrset_does_not_consider_ttl():
"""Test DNSRRSet does not consider the ttl in the hash."""

Expand Down
15 changes: 15 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,21 @@ def test_qu_packet_parser():
assert ",QU," in str(parsed.questions[0])


def test_parse_packet_with_nsec_record():
"""Test we can parse a packet with an NSEC record."""
nsec_packet = (
b"\x00\x00\x84\x00\x00\x00\x00\x01\x00\x00\x00\x03\x08_meshcop\x04_udp\x05local\x00\x00\x0c\x00"
b"\x01\x00\x00\x11\x94\x00\x0f\x0cMyHome54 (2)\xc0\x0c\xc0+\x00\x10\x80\x01\x00\x00\x11\x94\x00"
b")\x0bnn=MyHome54\x13xp=695034D148CC4784\x08tv=0.0.0\xc0+\x00!\x80\x01\x00\x00\x00x\x00\x15\x00"
b"\x00\x00\x00\xc0'\x0cMaster-Bed-2\xc0\x1a\xc0+\x00/\x80\x01\x00\x00\x11\x94\x00\t\xc0+\x00\x05"
b"\x00\x00\x80\x00@"
)
parsed = DNSIncoming(nsec_packet)
nsec_record = parsed.answers[3]
assert "nsec," in str(nsec_record)
assert nsec_record.rdtypes == [16, 33]


def test_records_same_packet_share_fate():
"""Test records in the same packet all have the same created time."""
out = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
Expand Down
1 change: 1 addition & 0 deletions zeroconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
DNSAddress,
DNSEntry,
DNSHinfo,
DNSNsec,
DNSPointer,
DNSQuestion,
DNSRecord,
Expand Down
46 changes: 40 additions & 6 deletions zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import enum
import socket
from typing import Any, Dict, Iterable, Optional, TYPE_CHECKING, Tuple, Union, cast
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union, cast

from ._exceptions import AbstractMethodException
from ._utils.net import _is_v6_address
Expand Down Expand Up @@ -116,11 +116,7 @@ class DNSQuestion(DNSEntry):

def answered_by(self, rec: 'DNSRecord') -> bool:
"""Returns true if the question is answered by the record"""
return (
self.class_ == rec.class_
and (self.type == rec.type or self.type == _TYPE_ANY)
and self.name == rec.name
)
return self.class_ == rec.class_ and self.type in (rec.type, _TYPE_ANY) and self.name == rec.name

def __hash__(self) -> int:
return hash((self.name, self.class_, self.type))
Expand Down Expand Up @@ -446,6 +442,44 @@ def __repr__(self) -> str:
return self.to_string("%s:%s" % (self.server, self.port))


class DNSNsec(DNSRecord):

"""A DNS NSEC record"""

__slots__ = ('next', 'rdtypes')

def __init__(
self,
name: str,
type_: int,
class_: int,
ttl: int,
next: str,
rdtypes: List[int],
created: Optional[float] = None,
) -> None:
super().__init__(name, type_, class_, ttl, created)
self.next = next
self.rdtypes = rdtypes

def __eq__(self, other: Any) -> bool:
"""Tests equality on cpu and os"""
return (
isinstance(other, DNSNsec)
and self.next == other.next
and self.rdtypes == other.rdtypes
and DNSEntry.__eq__(self, other)
)

def __hash__(self) -> int:
"""Hash to compare like DNSNSec."""
return hash((*self._entry_tuple(), self.next, *self.rdtypes))

def __repr__(self) -> str:
"""String representation"""
return self.to_string(self.next + "," + "|".join([self.get_type(type_) for type_ in self.rdtypes]))


class DNSRRSet:
"""A set of dns records independent of the ttl."""

Expand Down
28 changes: 27 additions & 1 deletion zeroconf/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Tuple, Union, cast


from ._dns import DNSAddress, DNSHinfo, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText
from ._dns import DNSAddress, DNSHinfo, DNSNsec, DNSPointer, DNSQuestion, DNSRecord, DNSService, DNSText
from ._exceptions import IncomingDecodeError, NamePartTooLongException
from ._logger import QuietLogger, log
from ._utils.struct import int2byte
Expand All @@ -43,6 +43,7 @@
_TYPE_AAAA,
_TYPE_CNAME,
_TYPE_HINFO,
_TYPE_NSEC,
_TYPE_PTR,
_TYPE_SRV,
_TYPE_TXT,
Expand Down Expand Up @@ -201,6 +202,18 @@ def read_others(self) -> None:
rec = DNSAddress(
domain, type_, class_, ttl, self.read_string(16), created=self.now, scope_id=self.scope_id
)
elif type_ == _TYPE_NSEC:
name_start = self.offset
name = self.read_name()
rec = DNSNsec(
domain,
type_,
class_,
ttl,
name,
self.read_bitmap(name_start + length),
self.now,
)
else:
# Try to ignore types we don't know about
# Skip the payload for the resource record so the next
Expand All @@ -210,6 +223,19 @@ def read_others(self) -> None:
if rec is not None:
self.answers.append(rec)

def read_bitmap(self, end: int) -> List[int]:
"""Reads an NSEC bitmap from the packet."""
rdtypes = []
while self.offset < end:
window = self.data[self.offset]
bitmap_length = self.data[self.offset + 1]
for i, byte in enumerate(self.data[self.offset + 2 : self.offset + 2 + bitmap_length]):
for bit in range(0, 8):
if byte & (0x80 >> bit):
rdtypes.append(bit + window * 256 + i * 8)
self.offset += 2 + bitmap_length
return rdtypes

def read_name(self) -> str:
"""Reads a domain name from the packet"""
result = ''
Expand Down
2 changes: 2 additions & 0 deletions zeroconf/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
_TYPE_TXT = 16
_TYPE_AAAA = 28
_TYPE_SRV = 33
_TYPE_NSEC = 47
_TYPE_ANY = 255

# Mapping constants to names
Expand Down Expand Up @@ -136,6 +137,7 @@
_TYPE_AAAA: "quada",
_TYPE_SRV: "srv",
_TYPE_ANY: "any",
_TYPE_NSEC: "nsec",
}

_HAS_A_TO_Z = re.compile(r'[A-Za-z]')
Expand Down