Skip to content

Commit 95186e9

Browse files
authored
Merge branch 'main' into andystaples/add-replay-safe-logging
2 parents e0b9747 + c62329c commit 95186e9

13 files changed

Lines changed: 448 additions & 68 deletions

File tree

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ ADDED
1111

1212
- Added `ReplaySafeLogger` and `OrchestrationContext.create_replay_safe_logger()`
1313
for suppressing duplicate log messages during orchestrator replay
14+
- Added `GrpcChannelOptions` and `GrpcRetryPolicyOptions` for configuring
15+
gRPC transport behavior, including message-size limits, keepalive settings,
16+
and channel-level retry policy service configuration.
17+
- Added optional `channel` and `channel_options` parameters to
18+
`TaskHubGrpcClient`, `AsyncTaskHubGrpcClient`, and `TaskHubGrpcWorker` to
19+
support pre-configured channel passthrough and low-level gRPC channel
20+
customization.
1421
- Added `get_orchestration_history()` and `list_instance_ids()` to the sync and async gRPC clients.
1522
- Added in-memory backend support for `StreamInstanceHistory` and `ListInstanceIds` so local orchestration tests can retrieve history and page terminal instance IDs by completion window.
1623

durabletask-azuremanaged/CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
- Added optional `interceptors`, `channel`, and `channel_options` parameters to
11+
`DurableTaskSchedulerClient`, `AsyncDurableTaskSchedulerClient`, and
12+
`DurableTaskSchedulerWorker` to allow combining custom gRPC interceptors with
13+
DTS defaults and to support pre-configured/customized gRPC channels.
14+
- Added `workerid` gRPC metadata on Durable Task Scheduler worker calls for
15+
improved worker identity and observability.
16+
- Improved sync access token refresh concurrency handling to avoid duplicate
17+
refresh operations under concurrent access.
18+
1019
## v1.4.0
1120

1221
- Updates base dependency to durabletask v1.4.0

durabletask-azuremanaged/durabletask/azuremanaged/client.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
import logging
55

6-
from typing import Optional
6+
from typing import Optional, Sequence
77

8+
import grpc
9+
import grpc.aio
810
from azure.core.credentials import TokenCredential
911
from azure.core.credentials_async import AsyncTokenCredential
1012

@@ -13,6 +15,8 @@
1315
DTSDefaultClientInterceptorImpl,
1416
)
1517
from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient
18+
from durabletask.grpc_options import GrpcChannelOptions
19+
import durabletask.internal.shared as shared
1620
from durabletask.payload.store import PayloadStore
1721

1822

@@ -22,7 +26,10 @@ def __init__(self, *,
2226
host_address: str,
2327
taskhub: str,
2428
token_credential: Optional[TokenCredential],
29+
channel: Optional[grpc.Channel] = None,
2530
secure_channel: bool = True,
31+
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
32+
channel_options: Optional[GrpcChannelOptions] = None,
2633
default_version: Optional[str] = None,
2734
payload_store: Optional[PayloadStore] = None,
2835
log_handler: Optional[logging.Handler] = None,
@@ -31,17 +38,22 @@ def __init__(self, *,
3138
if not taskhub:
3239
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
3340

34-
interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)]
41+
resolved_interceptors: list[shared.ClientInterceptor] = (
42+
list(interceptors) if interceptors is not None else []
43+
)
44+
resolved_interceptors.append(DTSDefaultClientInterceptorImpl(token_credential, taskhub))
3545

3646
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
3747
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
3848
super().__init__(
3949
host_address=host_address,
50+
channel=channel,
4051
secure_channel=secure_channel,
4152
metadata=None,
4253
log_handler=log_handler,
4354
log_formatter=log_formatter,
44-
interceptors=interceptors,
55+
interceptors=resolved_interceptors,
56+
channel_options=channel_options,
4557
default_version=default_version,
4658
payload_store=payload_store)
4759

@@ -88,7 +100,10 @@ def __init__(self, *,
88100
host_address: str,
89101
taskhub: str,
90102
token_credential: Optional[AsyncTokenCredential],
103+
channel: Optional[grpc.aio.Channel] = None,
91104
secure_channel: bool = True,
105+
interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None,
106+
channel_options: Optional[GrpcChannelOptions] = None,
92107
default_version: Optional[str] = None,
93108
payload_store: Optional[PayloadStore] = None,
94109
log_handler: Optional[logging.Handler] = None,
@@ -97,16 +112,21 @@ def __init__(self, *,
97112
if not taskhub:
98113
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
99114

100-
interceptors = [DTSAsyncDefaultClientInterceptorImpl(token_credential, taskhub)]
115+
resolved_interceptors: list[shared.AsyncClientInterceptor] = (
116+
list(interceptors) if interceptors is not None else []
117+
)
118+
resolved_interceptors.append(DTSAsyncDefaultClientInterceptorImpl(token_credential, taskhub))
101119

102120
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
103121
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
104122
super().__init__(
105123
host_address=host_address,
124+
channel=channel,
106125
secure_channel=secure_channel,
107126
metadata=None,
108127
log_handler=log_handler,
109128
log_formatter=log_formatter,
110-
interceptors=interceptors,
129+
interceptors=resolved_interceptors,
130+
channel_options=channel_options,
111131
default_version=default_version,
112132
payload_store=payload_store)

durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33
from datetime import datetime, timedelta, timezone
4+
from threading import Lock
45
from typing import Optional
56

67
from azure.core.credentials import AccessToken, TokenCredential
@@ -20,6 +21,7 @@ def __init__(self, token_credential: Optional[TokenCredential], refresh_interval
2021
self._logger = shared.get_logger("token_manager")
2122

2223
self._credential = token_credential
24+
self._refresh_lock = Lock()
2325

2426
if self._credential is not None:
2527
self._token = self._credential.get_token(self._scope)
@@ -30,7 +32,9 @@ def __init__(self, token_credential: Optional[TokenCredential], refresh_interval
3032

3133
def get_access_token(self) -> Optional[AccessToken]:
3234
if self._token is None or self.is_token_expired():
33-
self.refresh_token()
35+
with self._refresh_lock:
36+
if self._token is None or self.is_token_expired():
37+
self.refresh_token()
3438
return self._token
3539

3640
# Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds.

durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@ class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
2525
StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
2626
interceptor to add additional headers to all calls as needed."""
2727

28-
def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: str):
28+
def __init__(
29+
self,
30+
token_credential: Optional[TokenCredential],
31+
taskhub_name: str,
32+
worker_id: Optional[str] = None):
2933
try:
3034
# Get the version of the azuremanaged package
3135
sdk_version = version('durabletask-azuremanaged')
@@ -35,7 +39,9 @@ def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: st
3539
user_agent = f"durabletask-python/{sdk_version}"
3640
self._metadata = [
3741
("taskhub", taskhub_name),
38-
("x-user-agent", user_agent)] # 'user-agent' is a reserved header in grpc, so we use 'x-user-agent' instead
42+
("x-user-agent", user_agent)] # 'user-agent' is a reserved header; use 'x-user-agent'
43+
if worker_id is not None:
44+
self._metadata.append(("workerid", worker_id))
3945
super().__init__(self._metadata)
4046

4147
self._token_manager = None
@@ -44,7 +50,17 @@ def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: st
4450
self._token_manager = AccessTokenManager(token_credential=self._token_credential)
4551
access_token = self._token_manager.get_access_token()
4652
if access_token is not None:
47-
self._metadata.append(("authorization", f"Bearer {access_token.token}"))
53+
self._upsert_authorization_header(access_token.token)
54+
55+
def _upsert_authorization_header(self, token: str) -> None:
56+
found = False
57+
for i, (key, _) in enumerate(self._metadata):
58+
if key.lower() == "authorization":
59+
self._metadata[i] = ("authorization", f"Bearer {token}")
60+
found = True
61+
break
62+
if not found:
63+
self._metadata.append(("authorization", f"Bearer {token}"))
4864

4965
def _intercept_call(
5066
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
@@ -56,15 +72,7 @@ def _intercept_call(
5672
if self._token_manager is not None:
5773
access_token = self._token_manager.get_access_token()
5874
if access_token is not None:
59-
# Update the existing authorization header
60-
found = False
61-
for i, (key, _) in enumerate(self._metadata):
62-
if key.lower() == "authorization":
63-
self._metadata[i] = ("authorization", f"Bearer {access_token.token}")
64-
found = True
65-
break
66-
if not found:
67-
self._metadata.append(("authorization", f"Bearer {access_token.token}"))
75+
self._upsert_authorization_header(access_token.token)
6876

6977
return super()._intercept_call(client_call_details)
7078

@@ -96,6 +104,16 @@ def __init__(self, token_credential: Optional[AsyncTokenCredential], taskhub_nam
96104
self._token_credential = token_credential
97105
self._token_manager = AsyncAccessTokenManager(token_credential=self._token_credential)
98106

107+
def _upsert_authorization_header(self, token: str) -> None:
108+
found = False
109+
for i, (key, _) in enumerate(self._metadata):
110+
if key.lower() == "authorization":
111+
self._metadata[i] = ("authorization", f"Bearer {token}")
112+
found = True
113+
break
114+
if not found:
115+
self._metadata.append(("authorization", f"Bearer {token}"))
116+
99117
async def _intercept_call(
100118
self, client_call_details: _AsyncClientCallDetails) -> grpc.aio.ClientCallDetails:
101119
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
@@ -106,16 +124,6 @@ async def _intercept_call(
106124
if self._token_manager is not None:
107125
access_token = await self._token_manager.get_access_token()
108126
if access_token is not None:
109-
# Update the existing authorization header, or append one if this
110-
# is the first successful token acquisition (token is lazily
111-
# fetched on the first call since async constructors aren't possible).
112-
found = False
113-
for i, (key, _) in enumerate(self._metadata):
114-
if key.lower() == "authorization":
115-
self._metadata[i] = ("authorization", f"Bearer {access_token.token}")
116-
found = True
117-
break
118-
if not found:
119-
self._metadata.append(("authorization", f"Bearer {access_token.token}"))
127+
self._upsert_authorization_header(access_token.token)
120128

121129
return await super()._intercept_call(client_call_details)

durabletask-azuremanaged/durabletask/azuremanaged/worker.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,19 @@
22
# Licensed under the MIT License.
33

44
import logging
5+
import os
6+
import socket
7+
import uuid
58

6-
from typing import Optional
9+
from typing import Optional, Sequence
710

11+
import grpc
812
from azure.core.credentials import TokenCredential
913

1014
from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \
1115
DTSDefaultClientInterceptorImpl
16+
from durabletask.grpc_options import GrpcChannelOptions
17+
import durabletask.internal.shared as shared
1218
from durabletask.payload.store import PayloadStore
1319
from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker
1420

@@ -64,7 +70,10 @@ def __init__(self, *,
6470
host_address: str,
6571
taskhub: str,
6672
token_credential: Optional[TokenCredential],
73+
channel: Optional[grpc.Channel] = None,
6774
secure_channel: bool = True,
75+
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
76+
channel_options: Optional[GrpcChannelOptions] = None,
6877
concurrency_options: Optional[ConcurrencyOptions] = None,
6978
payload_store: Optional[PayloadStore] = None,
7079
log_handler: Optional[logging.Handler] = None,
@@ -73,17 +82,25 @@ def __init__(self, *,
7382
if not taskhub:
7483
raise ValueError("The taskhub value cannot be empty.")
7584

76-
interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)]
85+
worker_id = f"{socket.gethostname()}:{os.getpid()}:{uuid.uuid4()}"
86+
resolved_interceptors: list[shared.ClientInterceptor] = (
87+
list(interceptors) if interceptors is not None else []
88+
)
89+
resolved_interceptors.append(
90+
DTSDefaultClientInterceptorImpl(token_credential, taskhub, worker_id=worker_id)
91+
)
7792

7893
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
7994
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
8095
super().__init__(
8196
host_address=host_address,
97+
channel=channel,
8298
secure_channel=secure_channel,
8399
metadata=None,
84100
log_handler=log_handler,
85101
log_formatter=log_formatter,
86-
interceptors=interceptors,
102+
interceptors=resolved_interceptors,
103+
channel_options=channel_options,
87104
concurrency_options=concurrency_options,
88105
# DTS natively supports long timers so chunking is unnecessary
89106
maximum_timer_interval=None,

durabletask/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
"""Durable Task SDK for Python"""
55

6+
from durabletask.grpc_options import GrpcChannelOptions, GrpcRetryPolicyOptions
67
from durabletask.payload.store import LargePayloadStorageOptions, PayloadStore
78
from durabletask.worker import (
89
ActivityWorkItemFilter,
@@ -17,6 +18,8 @@
1718
"ActivityWorkItemFilter",
1819
"ConcurrencyOptions",
1920
"EntityWorkItemFilter",
21+
"GrpcChannelOptions",
22+
"GrpcRetryPolicyOptions",
2023
"LargePayloadStorageOptions",
2124
"OrchestrationWorkItemFilter",
2225
"PayloadStore",

0 commit comments

Comments
 (0)