Skip to content

Commit

Permalink
Merge pull request #65 from chrigu/feature/http-proxy
Browse files Browse the repository at this point in the history
Feature/http proxy
  • Loading branch information
Fatal1ty authored Jun 25, 2024
2 parents be9a827 + aba9b96 commit 22e4f41
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Features
* Ability to set priority for notifications
* Ability to set collapse-key for notifications
* Ability to use production or development APNs server
* Support for basic HTTP-Proxies


Installation
Expand Down
6 changes: 6 additions & 0 deletions aioapns/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def __init__(
use_sandbox: bool = False,
no_cert_validation: bool = False,
ssl_context: Optional[SSLContext] = None,
proxy_host: Optional[str] = None,
proxy_port: Optional[int] = None,
err_func: Optional[
Callable[
[NotificationRequest, NotificationResult], Awaitable[None]
Expand All @@ -42,6 +44,8 @@ def __init__(
use_sandbox=use_sandbox,
no_cert_validation=no_cert_validation,
ssl_context=ssl_context,
proxy_host=proxy_host,
proxy_port=proxy_port,
)
elif key and key_id and team_id and topic:
self.pool = APNsKeyConnectionPool(
Expand All @@ -53,6 +57,8 @@ def __init__(
max_connection_attempts=max_connection_attempts,
use_sandbox=use_sandbox,
ssl_context=ssl_context,
proxy_host=proxy_host,
proxy_port=proxy_port,
)
else:
raise ValueError(
Expand Down
150 changes: 142 additions & 8 deletions aioapns/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ def __init__(
max_connections: int = 10,
max_connection_attempts: int = 5,
use_sandbox: bool = False,
proxy_host: Optional[str] = None,
proxy_port: Optional[int] = None,
) -> None:
self.apns_topic = topic
self.max_connections = max_connections
Expand All @@ -342,6 +344,10 @@ def __init__(
self.connections: List[APNsBaseClientProtocol] = []
self._lock = asyncio.Lock()
self.max_connection_attempts = max_connection_attempts
self.ssl_context: Optional[ssl.SSLContext] = None

self.proxy_host = proxy_host
self.proxy_port = proxy_port

async def create_connection(self) -> APNsBaseClientProtocol:
raise NotImplementedError
Expand Down Expand Up @@ -428,6 +434,31 @@ async def send_notification(
logger.error("Failed to send after %d attempts.", attempts)
raise MaxAttemptsExceeded

async def _create_proxy_connection(
self, apns_protocol_factory
) -> APNsBaseClientProtocol:
assert self.proxy_host is not None, "proxy_host must be set"
assert self.proxy_port is not None, "proxy_port must be set"

_, protocol = await self.loop.create_connection(
protocol_factory=partial(
HttpProxyProtocol,
self.protocol_class.APNS_SERVER,
self.protocol_class.APNS_PORT,
self.loop,
self.ssl_context,
apns_protocol_factory,
),
host=self.proxy_host,
port=self.proxy_port,
)
await protocol.apns_connection_ready.wait()

assert (
protocol.apns_protocol is not None
), "protocol.apns_protocol could not be set"
return protocol.apns_protocol


class APNsCertConnectionPool(APNsBaseConnectionPool):
def __init__(
Expand All @@ -439,12 +470,16 @@ def __init__(
use_sandbox: bool = False,
no_cert_validation: bool = False,
ssl_context: Optional[ssl.SSLContext] = None,
proxy_host: Optional[str] = None,
proxy_port: Optional[int] = None,
) -> None:
super(APNsCertConnectionPool, self).__init__(
topic=topic,
max_connections=max_connections,
max_connection_attempts=max_connection_attempts,
use_sandbox=use_sandbox,
proxy_host=proxy_host,
proxy_port=proxy_port,
)

self.cert_file = cert_file
Expand All @@ -463,13 +498,23 @@ def __init__(
self.apns_topic = cert.get_subject().UID

async def create_connection(self) -> APNsBaseClientProtocol:
apns_protocol_factory = partial(
self.protocol_class,
self.apns_topic,
self.loop,
self.discard_connection,
)

if self.proxy_host and self.proxy_port:
return await self._create_proxy_connection(apns_protocol_factory)
else:
return await self._create_connection(apns_protocol_factory)

async def _create_connection(
self, apns_protocol_factory
) -> APNsBaseClientProtocol:
_, protocol = await self.loop.create_connection(
protocol_factory=partial(
self.protocol_class,
self.apns_topic,
self.loop,
self.discard_connection,
),
protocol_factory=apns_protocol_factory,
host=self.protocol_class.APNS_SERVER,
port=self.protocol_class.APNS_PORT,
ssl=self.ssl_context,
Expand All @@ -488,12 +533,16 @@ def __init__(
max_connection_attempts: int = 5,
use_sandbox: bool = False,
ssl_context: Optional[ssl.SSLContext] = None,
proxy_host: Optional[str] = None,
proxy_port: Optional[int] = None,
) -> None:
super(APNsKeyConnectionPool, self).__init__(
topic=topic,
max_connections=max_connections,
max_connection_attempts=max_connection_attempts,
use_sandbox=use_sandbox,
proxy_host=proxy_host,
proxy_port=proxy_port,
)

self.ssl_context = ssl_context or ssl.create_default_context()
Expand All @@ -508,16 +557,101 @@ async def create_connection(self) -> APNsBaseClientProtocol:
auth_provider = JWTAuthorizationHeaderProvider(
key=self.key, key_id=self.key_id, team_id=self.team_id
)
_, protocol = await self.loop.create_connection(
protocol_factory=partial(
apns_protocol_factory = (
partial(
self.protocol_class,
self.apns_topic,
self.loop,
self.discard_connection,
auth_provider,
),
)

if self.proxy_host and self.proxy_port:
return await self._create_proxy_connection(apns_protocol_factory)
else:
return await self._create_connection(apns_protocol_factory)

async def _create_connection(
self, apns_protocol_factory
) -> APNsBaseClientProtocol:
_, protocol = await self.loop.create_connection(
protocol_factory=apns_protocol_factory,
host=self.protocol_class.APNS_SERVER,
port=self.protocol_class.APNS_PORT,
ssl=self.ssl_context,
)
return protocol


class HttpProxyProtocol(asyncio.Protocol):
def __init__(
self,
apns_host: str,
apns_port: int,
loop: asyncio.AbstractEventLoop,
ssl_context: ssl.SSLContext,
protocol_factory,
):
self.apns_host = apns_host
self.apns_port = apns_port
self.buffer = bytearray()
self.loop = loop
self.ssl_context = ssl_context
self.apns_protocol_factory = protocol_factory
self.apns_protocol: Optional[APNsBaseClientProtocol] = None
self.transport = None
self.apns_connection_ready = (
asyncio.Event()
) # Event to signal APNs readiness

def connection_made(self, transport):
logger.debug(
"Proxy connection made.",
)
self.transport = transport
connect_request = (
f"CONNECT {self.apns_host}:{self.apns_port} "
f"HTTP/1.1\r\nHost: "
f"{self.apns_host}\r\nConnection: close\r\n\r\n"
)
self.transport.write(connect_request.encode("utf-8"))

def data_received(self, data):
# Data is usually received in bytes,
# so you might want to decode or process it
logger.debug("Raw data received: %s", data)
self.buffer.extend(data)
# some proxies send "HTTP/1.1 200 Connection established"
# others "HTTP/1.1 200 Connected"
if b"HTTP/1.1 200 Connect" in data:
logger.debug(
"Proxy tunnel established.",
)
asyncio.create_task(self.create_apns_connection())
else:
logger.debug(
"Data received (before APNs connection establishment): %s",
data.decode(),
)

async def create_apns_connection(self):
# Use the existing transport to create a new APNs connection
logger.debug(
"Initiating APNs connection.",
)
sock = self.transport.get_extra_info("socket")
_, self.apns_protocol = await self.loop.create_connection(
self.apns_protocol_factory,
server_hostname=self.apns_host,
ssl=self.ssl_context,
sock=sock,
)
# Signal that APNs connection is ready
self.apns_connection_ready.set()

def connection_lost(self, exc):
logger.debug(
"Proxy connection lost.",
)
self.transport.close()

0 comments on commit 22e4f41

Please sign in to comment.