Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 70 additions & 31 deletions zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import sys
import threading
from types import TracebackType # noqa # used in type hints
from typing import Dict, List, Optional, Tuple, Type, Union, cast
from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union, cast

from ._cache import DNSCache
from ._dns import DNSQuestion, DNSQuestionType
Expand All @@ -43,7 +43,7 @@
from ._services.info import ServiceInfo, instance_name_from_service_info
from ._services.registry import ServiceRegistry
from ._updates import RecordUpdate, RecordUpdateListener
from ._utils.asyncio import get_running_loop, shutdown_loop, wait_event_or_timeout
from ._utils.asyncio import await_awaitable, get_running_loop, shutdown_loop, wait_event_or_timeout
from ._utils.name import service_type_name
from ._utils.net import (
IPVersion,
Expand Down Expand Up @@ -74,6 +74,7 @@

_TC_DELAY_RANDOM_INTERVAL = (400, 500)
_CLOSE_TIMEOUT = 3
_REGISTER_BROADCASTS = 3


class AsyncEngine:
Expand Down Expand Up @@ -478,6 +479,27 @@ def register_service(
allow_name_change: bool = False,
cooperating_responders: bool = False,
) -> None:
"""Registers service information to the network with a default TTL.
Zeroconf will then respond to requests for information for that
service. The name of the service may be changed if needed to make
it unique on the network. Additionally multiple cooperating responders
can register the same service on the network for resilience
(if you want this behavior set `cooperating_responders` to `True`)."""
assert self.loop is not None
asyncio.run_coroutine_threadsafe(
await_awaitable(
self.async_register_service(info, ttl, allow_name_change, cooperating_responders)
),
self.loop,
).result(millis_to_seconds(_REGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT)

async def async_register_service(
self,
info: ServiceInfo,
ttl: Optional[int] = None,
allow_name_change: bool = False,
cooperating_responders: bool = False,
) -> Awaitable:
"""Registers service information to the network with a default TTL.
Zeroconf will then respond to requests for information for that
service. The name of the service may be changed if needed to make
Expand All @@ -489,47 +511,41 @@ def register_service(
# Setting TTLs via ServiceInfo is preferred
info.host_ttl = ttl
info.other_ttl = ttl
self.check_service(info, allow_name_change, cooperating_responders)

await self.async_wait_for_start()
await self.async_check_service(info, allow_name_change, cooperating_responders)
self.registry.add(info)
self._broadcast_service(info, _REGISTER_TIME, None)
return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))

def update_service(self, info: ServiceInfo) -> None:
"""Registers service information to the network with a default TTL.
Zeroconf will then respond to requests for information for that
service."""
assert self.loop is not None
asyncio.run_coroutine_threadsafe(await_awaitable(self.async_update_service(info)), self.loop).result(
millis_to_seconds(_REGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT
)

async def async_update_service(self, info: ServiceInfo) -> Awaitable:
"""Registers service information to the network with a default TTL.
Zeroconf will then respond to requests for information for that
service."""
self.registry.update(info)
self._broadcast_service(info, _REGISTER_TIME, None)
return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))

def _broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None:
async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None:
"""Send a broadcasts to announce a service at intervals."""
now = current_time_millis()
next_time = now
i = 0
while i < 3:
if now < next_time:
self.wait(next_time - now)
now = current_time_millis()
continue

self.send_service_broadcast(info, ttl)
i += 1
next_time += interval

def send_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> None:
"""Send a broadcast to announce a service."""
self.send(self.generate_service_broadcast(info, ttl))
for i in range(_REGISTER_BROADCASTS):
if i != 0:
await asyncio.sleep(millis_to_seconds(interval))
self.async_send(self.generate_service_broadcast(info, ttl))

def generate_service_broadcast(self, info: ServiceInfo, ttl: Optional[int]) -> DNSOutgoing:
"""Generate a broadcast to announce a service."""
out = DNSOutgoing(_FLAGS_QR_RESPONSE | _FLAGS_AA)
self._add_broadcast_answer(out, info, ttl)
return out

def send_service_query(self, info: ServiceInfo) -> None:
"""Send a query to lookup a service."""
self.send(self.generate_service_query(info))

def generate_service_query(self, info: ServiceInfo) -> DNSOutgoing: # pylint: disable=no-self-use
"""Generate a query to lookup a service."""
out = DNSOutgoing(_FLAGS_QR_QUERY | _FLAGS_AA)
Expand Down Expand Up @@ -559,9 +575,16 @@ def _add_broadcast_answer( # pylint: disable=no-self-use
out.add_answer_at_time(dns_address, 0)

def unregister_service(self, info: ServiceInfo) -> None:
"""Unregister a service."""
assert self.loop is not None
asyncio.run_coroutine_threadsafe(
await_awaitable(self.async_unregister_service(info)), self.loop
).result(millis_to_seconds(_UNREGISTER_TIME * _REGISTER_BROADCASTS) + _LOADED_SYSTEM_TIMEOUT)

async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
"""Unregister a service."""
self.registry.remove(info)
self._broadcast_service(info, _UNREGISTER_TIME, 0)
return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0))

def generate_unregister_all_services(self) -> Optional[DNSOutgoing]:
"""Generate a DNSOutgoing goodbye for all services and remove them from the registry."""
Expand All @@ -574,6 +597,22 @@ def generate_unregister_all_services(self) -> Optional[DNSOutgoing]:
self.registry.remove(service_infos)
return out

async def async_unregister_all_services(self) -> None:
"""Unregister all registered services.

Unlike async_register_service and async_unregister_service, this
method does not return a future and is always expected to be
awaited since its only called at shutdown.
"""
# Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
out = self.generate_unregister_all_services()
if not out:
return
for i in range(_REGISTER_BROADCASTS):
if i != 0:
await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME))
self.async_send(out)

def unregister_all_services(self) -> None:
"""Unregister all registered services."""
# Send Goodbye packets https://datatracker.ietf.org/doc/html/rfc6762#section-10.1
Expand All @@ -592,7 +631,7 @@ def unregister_all_services(self) -> None:
i += 1
next_time += _UNREGISTER_TIME

def check_service(
async def async_check_service(
self, info: ServiceInfo, allow_name_change: bool, cooperating_responders: bool = False
) -> None:
"""Checks the network for a unique service name, modifying the
Expand All @@ -603,7 +642,7 @@ def check_service(
next_instance_number = 2
next_time = now = current_time_millis()
i = 0
while i < 3:
while i < _REGISTER_BROADCASTS:
# check for a name conflict
while self.cache.current_entry_with_name_and_alias(info.type, info.name):
if not allow_name_change:
Expand All @@ -617,11 +656,11 @@ def check_service(
i = 0

if now < next_time:
self.wait(next_time - now)
await self.async_wait(next_time - now)
now = current_time_millis()
continue

self.send_service_query(info)
self.async_send(self.generate_service_query(info))
i += 1
next_time += _CHECK_TIME

Expand Down
9 changes: 8 additions & 1 deletion zeroconf/_utils/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
import asyncio
import contextlib
import queue
from typing import Any, List, Optional, Set, cast
from typing import Any, Awaitable, List, Optional, Set, cast

# The combined timeouts should be lower than _CLOSE_TIMEOUT + _WAIT_FOR_LOOP_TASKS_TIMEOUT
_TASK_AWAIT_TIMEOUT = 1
_GET_ALL_TASKS_TIMEOUT = 3
_WAIT_FOR_LOOP_TASKS_TIMEOUT = 3 # Must be larger than _TASK_AWAIT_TIMEOUT
Expand Down Expand Up @@ -80,6 +81,12 @@ async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None:
await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT)


async def await_awaitable(aw: Awaitable) -> None:
"""Wait on an awaitable and the task it returns."""
task = await aw
await task


def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None:
"""Wait for pending tasks and stop an event loop."""
pending_tasks = set(
Expand Down
54 changes: 9 additions & 45 deletions zeroconf/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,15 @@

from ._core import Zeroconf
from ._dns import DNSQuestionType
from ._exceptions import NonUniqueNameException
from ._services import ServiceListener
from ._services.browser import _ServiceBrowserBase
from ._services.info import ServiceInfo, instance_name_from_service_info
from ._services.info import ServiceInfo
from ._services.types import ZeroconfServiceTypes
from ._utils.net import IPVersion, InterfaceChoice, InterfacesType
from ._utils.time import millis_to_seconds
from .const import (
_BROWSER_TIME,
_CHECK_TIME,
_MDNS_PORT,
_REGISTER_TIME,
_SERVICE_TYPE_ENUMERATION_NAME,
_UNREGISTER_TIME,
)


Expand Down Expand Up @@ -172,16 +167,11 @@ def __init__(
)
self.async_browsers: Dict[ServiceListener, AsyncServiceBrowser] = {}

async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None:
"""Send a broadcasts to announce a service at intervals."""
for i in range(3):
if i != 0:
await asyncio.sleep(millis_to_seconds(interval))
self.zeroconf.async_send(self.zeroconf.generate_service_broadcast(info, ttl))

async def async_register_service(
self,
info: ServiceInfo,
ttl: Optional[int] = None,
allow_name_change: bool = False,
cooperating_responders: bool = False,
) -> Awaitable:
"""Registers service information to the network with a default TTL.
Expand All @@ -194,10 +184,9 @@ async def async_register_service(
The service will be broadcast in a task. This task is returned
and therefore can be awaited if necessary.
"""
await self.zeroconf.async_wait_for_start()
await self.async_check_service(info, cooperating_responders)
self.zeroconf.registry.add(info)
return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))
return await self.zeroconf.async_register_service(
info, ttl, allow_name_change, cooperating_responders
)

async def async_unregister_all_services(self) -> None:
"""Unregister all registered services.
Expand All @@ -206,39 +195,15 @@ async def async_unregister_all_services(self) -> None:
method does not return a future and is always expected to be
awaited since its only called at shutdown.
"""
out = self.zeroconf.generate_unregister_all_services()
if not out:
return
for i in range(3):
if i != 0:
await asyncio.sleep(millis_to_seconds(_UNREGISTER_TIME))
self.zeroconf.async_send(out)

async def async_check_service(self, info: ServiceInfo, cooperating_responders: bool = False) -> None:
"""Checks the network for a unique service name."""
instance_name_from_service_info(info)
if cooperating_responders:
return
self._raise_on_name_conflict(info)
for i in range(3):
if i != 0:
await asyncio.sleep(millis_to_seconds(_CHECK_TIME))
self.zeroconf.async_send(self.zeroconf.generate_service_query(info))
self._raise_on_name_conflict(info)

def _raise_on_name_conflict(self, info: ServiceInfo) -> None:
"""Raise NonUniqueNameException if the ServiceInfo has a conflict."""
if self.zeroconf.cache.current_entry_with_name_and_alias(info.type, info.name):
raise NonUniqueNameException
await self.zeroconf.async_unregister_all_services()

async def async_unregister_service(self, info: ServiceInfo) -> Awaitable:
"""Unregister a service.

The service will be broadcast in a task. This task is returned
and therefore can be awaited if necessary.
"""
self.zeroconf.registry.remove(info)
return asyncio.ensure_future(self._async_broadcast_service(info, _UNREGISTER_TIME, 0))
return await self.zeroconf.async_unregister_service(info)

async def async_update_service(self, info: ServiceInfo) -> Awaitable:
"""Registers service information to the network with a default TTL.
Expand All @@ -248,8 +213,7 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable:
The service will be broadcast in a task. This task is returned
and therefore can be awaited if necessary.
"""
self.zeroconf.registry.update(info)
return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None))
return await self.zeroconf.async_update_service(info)

async def async_close(self) -> None:
"""Ends the background threads, and prevent this instance from
Expand Down