From 391e0493317c2d875400e751c5043eec3d4ef031 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Sun, 19 Oct 2025 10:51:05 -0700 Subject: [PATCH] refactor!: Use RxJava for VertexAiClient In addition to the refactor, use the built in RxJava sleep functionality instead of Thread.sleep(). Also, adding some randomness to the LRO checking on createSession to test out that the LRO logic works. PiperOrigin-RevId: 821356859 --- .../google/adk/sessions/VertexAiClient.java | 196 ++++++++------- .../adk/sessions/VertexAiSessionService.java | 227 ++++++++++-------- 2 files changed, 235 insertions(+), 188 deletions(-) diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java index 21239db3a..d35bbccae 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java @@ -9,12 +9,16 @@ import com.google.common.base.Splitter; import com.google.common.collect.Iterables; import com.google.genai.types.HttpOptions; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; import java.io.IOException; import java.io.UncheckedIOException; import java.util.List; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; import okhttp3.ResponseBody; import org.slf4j.Logger; @@ -46,8 +50,7 @@ final class VertexAiClient { new HttpApiClient(Optional.of(project), Optional.of(location), credentials, httpOptions); } - @Nullable - JsonNode createSession( + Maybe createSession( String reasoningEngineId, String userId, ConcurrentMap state) { ConcurrentHashMap sessionJsonMap = new ConcurrentHashMap<>(); sessionJsonMap.put("userId", userId); @@ -55,95 +58,116 @@ JsonNode createSession( sessionJsonMap.put("sessionState", state); } - String sessId; - String operationId; - try { - String sessionJson = objectMapper.writeValueAsString(sessionJsonMap); - try (ApiResponse apiResponse = - apiClient.request( - "POST", "reasoningEngines/" + reasoningEngineId + "/sessions", sessionJson)) { - logger.debug("Create Session response {}", apiResponse.getResponseBody()); - if (apiResponse == null || apiResponse.getResponseBody() == null) { - return null; - } - - JsonNode jsonResponse = getJsonResponse(apiResponse); - if (jsonResponse == null) { - return null; - } - String sessionName = jsonResponse.get("name").asText(); - List parts = Splitter.on('/').splitToList(sessionName); - sessId = parts.get(parts.size() - 3); - operationId = Iterables.getLast(parts); - } - } catch (IOException e) { - throw new UncheckedIOException(e); - } + return Single.fromCallable(() -> objectMapper.writeValueAsString(sessionJsonMap)) + .flatMap( + sessionJson -> + performApiRequest( + "POST", "reasoningEngines/" + reasoningEngineId + "/sessions", sessionJson)) + .flatMapMaybe( + apiResponse -> { + logger.debug("Create Session response {}", apiResponse.getResponseBody()); + return getJsonResponse(apiResponse); + }) + .flatMap( + jsonResponse -> { + String sessionName = jsonResponse.get("name").asText(); + List parts = Splitter.on('/').splitToList(sessionName); + String sessId = parts.get(parts.size() - 3); + String operationId = Iterables.getLast(parts); + + return pollOperation(operationId, 0).andThen(getSession(reasoningEngineId, sessId)); + }); + } - for (int i = 0; i < MAX_RETRY_ATTEMPTS; i++) { - try (ApiResponse lroResponse = apiClient.request("GET", "operations/" + operationId, "")) { - JsonNode lroJsonResponse = getJsonResponse(lroResponse); - if (lroJsonResponse != null && lroJsonResponse.get("done") != null) { - break; - } - } - try { - SECONDS.sleep(1); - } catch (InterruptedException e) { - logger.warn("Error during sleep", e); - Thread.currentThread().interrupt(); - } + /** + * Polls the status of a long-running operation. + * + * @param operationId The ID of the operation to poll. + * @param attempt The current retry attempt number (starting from 0). + * @return A Completable that completes when the operation is done, or errors with + * TimeoutException if max retries are exceeded. + */ + private Completable pollOperation(String operationId, int attempt) { + if (attempt >= MAX_RETRY_ATTEMPTS) { + return Completable.error( + new TimeoutException("Operation " + operationId + " did not complete in time.")); } - return getSession(reasoningEngineId, sessId); + return performApiRequest("GET", "operations/" + operationId, "") + .flatMapMaybe(VertexAiClient::getJsonResponse) + .flatMapCompletable( + lroJsonResponse -> { + if (lroJsonResponse != null && lroJsonResponse.get("done") != null) { + return Completable.complete(); // Operation is done + } else { + // Not done, retry after a delay + return Completable.timer(1, SECONDS) + .andThen(pollOperation(operationId, attempt + 1)); + } + }); } - JsonNode listSessions(String reasoningEngineId, String userId) { - try (ApiResponse apiResponse = - apiClient.request( + Maybe listSessions(String reasoningEngineId, String userId) { + return performApiRequest( "GET", "reasoningEngines/" + reasoningEngineId + "/sessions?filter=user_id=" + userId, - "")) { - return getJsonResponse(apiResponse); - } + "") + .flatMapMaybe(VertexAiClient::getJsonResponse); } - JsonNode listEvents(String reasoningEngineId, String sessionId) { - try (ApiResponse apiResponse = - apiClient.request( + Maybe listEvents(String reasoningEngineId, String sessionId) { + return performApiRequest( "GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events", - "")) { - logger.debug("List events response {}", apiResponse); - return getJsonResponse(apiResponse); - } + "") + .doOnSuccess(apiResponse -> logger.debug("List events response {}", apiResponse)) + .flatMapMaybe(VertexAiClient::getJsonResponse); } - JsonNode getSession(String reasoningEngineId, String sessionId) { - try (ApiResponse apiResponse = - apiClient.request( - "GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) { - return getJsonResponse(apiResponse); - } + Maybe getSession(String reasoningEngineId, String sessionId) { + return performApiRequest( + "GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "") + .flatMapMaybe(apiResponse -> getJsonResponse(apiResponse)); } - void deleteSession(String reasoningEngineId, String sessionId) { - try (ApiResponse response = - apiClient.request( - "DELETE", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) {} + Completable deleteSession(String reasoningEngineId, String sessionId) { + return performApiRequest( + "DELETE", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "") + .doOnSuccess(ApiResponse::close) + .ignoreElement(); } - void appendEvent(String reasoningEngineId, String sessionId, String eventJson) { - try (ApiResponse response = - apiClient.request( + Completable appendEvent(String reasoningEngineId, String sessionId, String eventJson) { + return performApiRequest( "POST", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + ":appendEvent", - eventJson)) { - if (response.getResponseBody().string().contains("com.google.genai.errors.ClientException")) { - logger.warn("Failed to append event: {}", eventJson); - } - } catch (IOException e) { - throw new UncheckedIOException(e); - } + eventJson) + .flatMapCompletable( + response -> { + try (response) { + ResponseBody responseBody = response.getResponseBody(); + if (responseBody != null) { + String responseString = responseBody.string(); + if (responseString.contains("com.google.genai.errors.ClientException")) { + logger.warn("Failed to append event: {}", eventJson); + } + } + return Completable.complete(); + } catch (IOException e) { + return Completable.error(new UncheckedIOException(e)); + } + }); + } + + /** + * Performs an API request and returns a Single emitting the ApiResponse. + * + *

Note: The caller is responsible for closing the returned {@link ApiResponse}. + */ + private Single performApiRequest(String method, String path, String body) { + return Single.fromCallable( + () -> { + return apiClient.request(method, path, body); + }); } /** @@ -152,19 +176,23 @@ void appendEvent(String reasoningEngineId, String sessionId, String eventJson) { * @throws UncheckedIOException if parsing fails. */ @Nullable - private static JsonNode getJsonResponse(ApiResponse apiResponse) { - if (apiResponse == null || apiResponse.getResponseBody() == null) { - return null; - } + private static Maybe getJsonResponse(ApiResponse apiResponse) { try { - ResponseBody responseBody = apiResponse.getResponseBody(); - String responseString = responseBody.string(); - if (responseString.isEmpty()) { - return null; + if (apiResponse == null || apiResponse.getResponseBody() == null) { + return Maybe.empty(); + } + try { + ResponseBody responseBody = apiResponse.getResponseBody(); + String responseString = responseBody.string(); // Read body here + if (responseString.isEmpty()) { + return Maybe.empty(); + } + return Maybe.just(objectMapper.readTree(responseString)); + } catch (IOException e) { + return Maybe.error(new UncheckedIOException(e)); } - return objectMapper.readTree(responseString); - } catch (IOException e) { - throw new UncheckedIOException(e); + } finally { + apiResponse.close(); } } } diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 2e0934be5..0321ef281 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -78,11 +78,20 @@ public Single createSession( @Nullable String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); - JsonNode getSessionResponseMap = client.createSession(reasoningEngineId, userId, state); + return client + .createSession(reasoningEngineId, userId, state) + .map( + getSessionResponseMap -> + parseSession(getSessionResponseMap, appName, userId, sessionId)) + .toSingle(); + } + + private static Session parseSession( + JsonNode getSessionResponseMap, String appName, String userId, String fallbackSessionId) { String sessId = Optional.ofNullable(getSessionResponseMap.get("name")) .map(name -> Iterables.getLast(Splitter.on('/').splitToList(name.asText()))) - .orElse(sessionId); + .orElse(fallbackSessionId); Instant updateTimestamp = Instant.parse(getSessionResponseMap.get("updateTime").asText()); ConcurrentMap sessionState = null; if (getSessionResponseMap != null && getSessionResponseMap.has("sessionState")) { @@ -93,25 +102,28 @@ public Single createSession( sessionStateNode, new TypeReference>() {}); } } - return Single.just( - Session.builder(sessId) - .appName(appName) - .userId(userId) - .lastUpdateTime(updateTimestamp) - .state(sessionState == null ? new ConcurrentHashMap<>() : sessionState) - .build()); + return Session.builder(sessId) + .appName(appName) + .userId(userId) + .lastUpdateTime(updateTimestamp) + .state(sessionState == null ? new ConcurrentHashMap<>() : sessionState) + .build(); } @Override public Single listSessions(String appName, String userId) { String reasoningEngineId = parseReasoningEngineId(appName); - JsonNode listSessionsResponseMap = client.listSessions(reasoningEngineId, userId); + return client + .listSessions(reasoningEngineId, userId) + .map( + listSessionsResponseMap -> + parseListSessionsResponse(listSessionsResponseMap, appName, userId)) + .defaultIfEmpty(ListSessionsResponse.builder().build()); + } - // Handles empty response case - if (listSessionsResponseMap == null) { - return Single.just(ListSessionsResponse.builder().build()); - } + private ListSessionsResponse parseListSessionsResponse( + JsonNode listSessionsResponseMap, String appName, String userId) { List> apiSessions = objectMapper.convertValue( listSessionsResponseMap.get("sessions"), @@ -131,125 +143,132 @@ public Single listSessions(String appName, String userId) .build(); sessions.add(session); } - return Single.just(ListSessionsResponse.builder().sessions(sessions).build()); + return ListSessionsResponse.builder().sessions(sessions).build(); } @Override public Single listEvents(String appName, String userId, String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); - JsonNode listEventsResponse = client.listEvents(reasoningEngineId, sessionId); - - if (listEventsResponse == null) { - return Single.just(ListEventsResponse.builder().build()); - } + return client + .listEvents(reasoningEngineId, sessionId) + .map(this::parseListEventsResponse) + .defaultIfEmpty(ListEventsResponse.builder().build()); + } + private ListEventsResponse parseListEventsResponse(JsonNode listEventsResponse) { JsonNode sessionEventsNode = listEventsResponse.get("sessionEvents"); if (sessionEventsNode == null || sessionEventsNode.isEmpty()) { - return Single.just(ListEventsResponse.builder().events(new ArrayList<>()).build()); + return ListEventsResponse.builder().events(new ArrayList<>()).build(); } - return Single.just( - ListEventsResponse.builder() - .events( - objectMapper - .convertValue( - sessionEventsNode, - new TypeReference>>() {}) - .stream() - .map(SessionJsonConverter::fromApiEvent) - .collect(toCollection(ArrayList::new))) - .build()); + return ListEventsResponse.builder() + .events( + objectMapper + .convertValue( + sessionEventsNode, new TypeReference>>() {}) + .stream() + .map(SessionJsonConverter::fromApiEvent) + .collect(toCollection(ArrayList::new))) + .build(); } @Override public Maybe getSession( String appName, String userId, String sessionId, Optional config) { String reasoningEngineId = parseReasoningEngineId(appName); - JsonNode getSessionResponseMap = client.getSession(reasoningEngineId, sessionId); + return client + .getSession(reasoningEngineId, sessionId) + .flatMap( + getSessionResponseMap -> { + String sessId = + Optional.ofNullable(getSessionResponseMap.get("name")) + .map(name -> Iterables.getLast(Splitter.on('/').splitToList(name.asText()))) + .orElse(sessionId); + Instant updateTimestamp = + Optional.ofNullable(getSessionResponseMap.get("updateTime")) + .map(updateTime -> Instant.parse(updateTime.asText())) + .orElse(null); - if (getSessionResponseMap == null) { - return Maybe.empty(); - } - - String sessId = - Optional.ofNullable(getSessionResponseMap.get("name")) - .map(name -> Iterables.getLast(Splitter.on('/').splitToList(name.asText()))) - .orElse(sessionId); - Instant updateTimestamp = - Optional.ofNullable(getSessionResponseMap.get("updateTime")) - .map(updateTime -> Instant.parse(updateTime.asText())) - .orElse(null); + ConcurrentMap sessionState = new ConcurrentHashMap<>(); + if (getSessionResponseMap != null && getSessionResponseMap.has("sessionState")) { + sessionState.putAll( + objectMapper.convertValue( + getSessionResponseMap.get("sessionState"), + new TypeReference>() {})); + } - ConcurrentMap sessionState = new ConcurrentHashMap<>(); - if (getSessionResponseMap != null && getSessionResponseMap.has("sessionState")) { - sessionState.putAll( - objectMapper.convertValue( - getSessionResponseMap.get("sessionState"), - new TypeReference>() {})); - } + return listEvents(appName, userId, sessionId) + .map( + response -> { + Session.Builder sessionBuilder = + Session.builder(sessId) + .appName(appName) + .userId(userId) + .lastUpdateTime(updateTimestamp) + .state(sessionState); + List events = response.events(); + if (events.isEmpty()) { + return sessionBuilder.build(); + } + events = filterEvents(events, updateTimestamp, config); + return sessionBuilder.events(events).build(); + }) + .toMaybe(); + }); + } - return listEvents(appName, userId, sessionId) - .map( - response -> { - Session.Builder sessionBuilder = - Session.builder(sessId) - .appName(appName) - .userId(userId) - .lastUpdateTime(updateTimestamp) - .state(sessionState); - List events = response.events(); - if (events.isEmpty()) { - return sessionBuilder.build(); - } - events = - events.stream() - .filter( - event -> - updateTimestamp == null - || Instant.ofEpochMilli(event.timestamp()) - .isBefore(updateTimestamp)) - .sorted(Comparator.comparing(Event::timestamp)) - .collect(toCollection(ArrayList::new)); + private static List filterEvents( + List originalEvents, + @Nullable Instant updateTimestamp, + Optional config) { + List events = + originalEvents.stream() + .filter( + event -> + updateTimestamp == null + || Instant.ofEpochMilli(event.timestamp()).isBefore(updateTimestamp)) + .sorted(Comparator.comparing(Event::timestamp)) + .collect(toCollection(ArrayList::new)); - if (config.isPresent()) { - if (config.get().numRecentEvents().isPresent()) { - int numRecentEvents = config.get().numRecentEvents().get(); - if (events.size() > numRecentEvents) { - events = events.subList(events.size() - numRecentEvents, events.size()); - } - } else if (config.get().afterTimestamp().isPresent()) { - Instant afterTimestamp = config.get().afterTimestamp().get(); - int i = events.size() - 1; - while (i >= 0) { - if (Instant.ofEpochMilli(events.get(i).timestamp()).isBefore(afterTimestamp)) { - break; - } - i -= 1; - } - if (i >= 0) { - events = events.subList(i, events.size()); - } - } - } - return sessionBuilder.events(events).build(); - }) - .toMaybe(); + if (config.isPresent()) { + if (config.get().numRecentEvents().isPresent()) { + int numRecentEvents = config.get().numRecentEvents().get(); + if (events.size() > numRecentEvents) { + events = events.subList(events.size() - numRecentEvents, events.size()); + } + } else if (config.get().afterTimestamp().isPresent()) { + Instant afterTimestamp = config.get().afterTimestamp().get(); + int i = events.size() - 1; + while (i >= 0) { + if (Instant.ofEpochMilli(events.get(i).timestamp()).isBefore(afterTimestamp)) { + break; + } + i -= 1; + } + if (i >= 0) { + events = events.subList(i, events.size()); + } + } + } + return events; } @Override public Completable deleteSession(String appName, String userId, String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); - client.deleteSession(reasoningEngineId, sessionId); - return Completable.complete(); + return client.deleteSession(reasoningEngineId, sessionId); } @Override public Single appendEvent(Session session, Event event) { - BaseSessionService.super.appendEvent(session, event); - String reasoningEngineId = parseReasoningEngineId(session.appName()); - client.appendEvent( - reasoningEngineId, session.id(), SessionJsonConverter.convertEventToJson(event)); - return Single.just(event); + return BaseSessionService.super + .appendEvent(session, event) + .flatMap( + e -> + client + .appendEvent( + reasoningEngineId, session.id(), SessionJsonConverter.convertEventToJson(e)) + .toSingleDefault(e)); } /**