From fa9508ca3784ee295bcf129349f01e4cd2fd9020 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 26 Sep 2025 08:27:07 -0700 Subject: [PATCH] refactor: Extracting MockApiAnswer from VertexAiSessionServiceTest PiperOrigin-RevId: 811828740 --- .../google/adk/sessions/MockApiAnswer.java | 216 ++++++++++++++++++ .../sessions/VertexAiSessionServiceTest.java | 203 +--------------- 2 files changed, 217 insertions(+), 202 deletions(-) create mode 100644 core/src/test/java/com/google/adk/sessions/MockApiAnswer.java diff --git a/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java new file mode 100644 index 000000000..5e8f3d992 --- /dev/null +++ b/core/src/test/java/com/google/adk/sessions/MockApiAnswer.java @@ -0,0 +1,216 @@ +package com.google.adk.sessions; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; +import com.google.adk.events.Event; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import okhttp3.MediaType; +import okhttp3.ResponseBody; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +/** Mocks the http calls to Vertex AI API. */ +class MockApiAnswer implements Answer { + private static final ObjectMapper mapper = JsonBaseModel.getMapper(); + private static final Pattern LRO_REGEX = Pattern.compile("^operations/([^/]+)$"); + private static final Pattern SESSION_REGEX = + Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)$"); + private static final Pattern SESSIONS_REGEX = + Pattern.compile("^reasoningEngines/([^/]+)/sessions$"); + private static final Pattern SESSIONS_FILTER_REGEX = + Pattern.compile("^reasoningEngines/([^/]+)/sessions\\?filter=user_id=([^/]+)$"); + private static final Pattern APPEND_EVENT_REGEX = + Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+):appendEvent$"); + private static final Pattern EVENTS_REGEX = + Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)/events$"); + private static final MediaType JSON_MEDIA_TYPE = + MediaType.parse("application/json; charset=utf-8"); + + private final Map sessionMap; + private final Map eventMap; + + MockApiAnswer(Map sessionMap, Map eventMap) { + this.sessionMap = sessionMap; + this.eventMap = eventMap; + } + + @Override + public ApiResponse answer(InvocationOnMock invocation) throws Throwable { + String httpMethod = invocation.getArgument(0); + String path = invocation.getArgument(1); + if (httpMethod.equals("POST") && SESSIONS_REGEX.matcher(path).matches()) { + return handleCreateSession(path, invocation); + } else if (httpMethod.equals("GET") && SESSION_REGEX.matcher(path).matches()) { + return handleGetSession(path); + } else if (httpMethod.equals("GET") && SESSIONS_FILTER_REGEX.matcher(path).matches()) { + return handleGetSessions(path); + } else if (httpMethod.equals("POST") && APPEND_EVENT_REGEX.matcher(path).matches()) { + return handleAppendEvent(path, invocation); + } else if (httpMethod.equals("GET") && EVENTS_REGEX.matcher(path).matches()) { + return handleGetEvents(path); + } else if (httpMethod.equals("GET") && LRO_REGEX.matcher(path).matches()) { + return handleGetLro(path); + } else if (httpMethod.equals("DELETE")) { + return handleDeleteSession(path); + } + throw new RuntimeException( + String.format("Unsupported HTTP method: %s, path: %s", httpMethod, path)); + } + + private static ApiResponse responseWithBody(String body) { + return new ApiResponse() { + @Override + public ResponseBody getResponseBody() { + return ResponseBody.create(JSON_MEDIA_TYPE, body); + } + + @Override + public void close() {} + }; + } + + private ApiResponse handleCreateSession(String path, InvocationOnMock invocation) + throws Exception { + String newSessionId = "4"; + Map requestDict = + mapper.readValue( + (String) invocation.getArgument(2), new TypeReference>() {}); + Map newSessionData = new HashMap<>(); + newSessionData.put("name", path + "/" + newSessionId); + newSessionData.put("userId", requestDict.get("userId")); + newSessionData.put("sessionState", requestDict.get("sessionState")); + newSessionData.put("updateTime", "2024-12-12T12:12:12.123456Z"); + + sessionMap.put(newSessionId, mapper.writeValueAsString(newSessionData)); + + return responseWithBody( + String.format( + """ + { + "name": "%s/%s/operations/111", + "done": false + } + """, + path, newSessionId)); + } + + private ApiResponse handleGetSession(String path) throws Exception { + String sessionId = path.substring(path.lastIndexOf('/') + 1); + if (sessionId.contains("/")) { // Ensure it's a direct session ID + return null; + } + String sessionData = sessionMap.get(sessionId); + if (sessionData != null) { + return responseWithBody(sessionData); + } else { + throw new RuntimeException("Session not found: " + sessionId); + } + } + + private ApiResponse handleGetSessions(String path) throws Exception { + Matcher sessionsMatcher = SESSIONS_FILTER_REGEX.matcher(path); + if (!sessionsMatcher.matches()) { + return null; + } + String userId = sessionsMatcher.group(2); + List userSessionsJson = new ArrayList<>(); + for (String sessionJson : sessionMap.values()) { + Map session = + mapper.readValue(sessionJson, new TypeReference>() {}); + if (session.containsKey("userId") && session.get("userId").equals(userId)) { + userSessionsJson.add(sessionJson); + } + } + return responseWithBody( + String.format( + """ + { + "sessions": [%s] + } + """, + String.join(",", userSessionsJson))); + } + + private ApiResponse handleAppendEvent(String path, InvocationOnMock invocation) { + Matcher appendEventMatcher = APPEND_EVENT_REGEX.matcher(path); + if (!appendEventMatcher.matches()) { + return null; + } + String sessionId = appendEventMatcher.group(2); + String eventDataString = eventMap.get(sessionId); + String newEventDataString = (String) invocation.getArgument(2); + try { + ConcurrentMap newEventData = + mapper.readValue( + newEventDataString, new TypeReference>() {}); + + List> eventsData = new ArrayList<>(); + if (eventDataString != null) { + eventsData.addAll( + mapper.readValue( + eventDataString, new TypeReference>>() {})); + } + + newEventData.put( + "name", path.replaceFirst(":appendEvent$", "/events/" + Event.generateEventId())); + + eventsData.add(newEventData); + + eventMap.put(sessionId, mapper.writeValueAsString(eventsData)); + } catch (Exception e) { + throw new RuntimeException(e); + } + return responseWithBody(newEventDataString); + } + + private ApiResponse handleGetEvents(String path) throws Exception { + Matcher matcher = EVENTS_REGEX.matcher(path); + if (!matcher.matches()) { + return null; + } + String sessionId = matcher.group(2); + String eventData = eventMap.get(sessionId); + if (eventData != null) { + return responseWithBody( + String.format( + """ + { + "sessionEvents": %s + } + """, + eventData)); + } else { + // Return an empty list if no events are found for the session + return responseWithBody("{}"); + } + } + + private ApiResponse handleGetLro(String path) { + return responseWithBody( + String.format( + """ + { + "name": "%s", + "done": true + } + """, + path.replace("/operations/111", ""))); // Simulate LRO done + } + + private ApiResponse handleDeleteSession(String path) { + Matcher sessionMatcher = SESSION_REGEX.matcher(path); + if (!sessionMatcher.matches()) { + return null; + } + String sessionIdToDelete = sessionMatcher.group(2); + sessionMap.remove(sessionIdToDelete); + return responseWithBody(""); + } +} diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index cc549d1d7..775b465ff 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -18,7 +18,6 @@ import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Single; import java.time.Instant; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -26,18 +25,12 @@ import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import okhttp3.MediaType; -import okhttp3.ResponseBody; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; /** Unit tests for {@link VertexAiSessionService}. */ @RunWith(JUnit4.class) @@ -151,204 +144,10 @@ private static Session getMockSession() throws Exception { /** Mock for HttpApiClient to mock the http calls to Vertex AI API. */ @Mock private HttpApiClient mockApiClient; - @Mock private ApiResponse mockApiResponse; private VertexAiSessionService vertexAiSessionService; private Map sessionMap = null; private Map eventMap = null; - private static final Pattern LRO_REGEX = Pattern.compile("^operations/([^/]+)$"); - private static final Pattern SESSION_REGEX = - Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)$"); - private static final Pattern SESSIONS_REGEX = - Pattern.compile("^reasoningEngines/([^/]+)/sessions$"); - private static final Pattern SESSIONS_FILTER_REGEX = - Pattern.compile("^reasoningEngines/([^/]+)/sessions\\?filter=user_id=([^/]+)$"); - private static final Pattern APPEND_EVENT_REGEX = - Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+):appendEvent$"); - private static final Pattern EVENTS_REGEX = - Pattern.compile("^reasoningEngines/([^/]+)/sessions/([^/]+)/events$"); - - private static class MockApiAnswer implements Answer { - private final Map sessionMap; - private final Map eventMap; - private final ApiResponse mockApiResponse; - - private MockApiAnswer( - Map sessionMap, Map eventMap, ApiResponse mockApiResponse) { - this.sessionMap = sessionMap; - this.eventMap = eventMap; - this.mockApiResponse = mockApiResponse; - } - - @Override - public ApiResponse answer(InvocationOnMock invocation) throws Throwable { - String httpMethod = invocation.getArgument(0); - String path = invocation.getArgument(1); - if (httpMethod.equals("POST") && SESSIONS_REGEX.matcher(path).matches()) { - return handleCreateSession(path, invocation); - } else if (httpMethod.equals("GET") && SESSION_REGEX.matcher(path).matches()) { - return handleGetSession(path); - } else if (httpMethod.equals("GET") && SESSIONS_FILTER_REGEX.matcher(path).matches()) { - return handleGetSessions(path); - } else if (httpMethod.equals("POST") && APPEND_EVENT_REGEX.matcher(path).matches()) { - return handleAppendEvent(path, invocation); - } else if (httpMethod.equals("GET") && EVENTS_REGEX.matcher(path).matches()) { - return handleGetEvents(path); - } else if (httpMethod.equals("GET") && LRO_REGEX.matcher(path).matches()) { - return handleGetLro(path); - } else if (httpMethod.equals("DELETE")) { - return handleDeleteSession(path); - } - throw new RuntimeException( - String.format("Unsupported HTTP method: %s, path: %s", httpMethod, path)); - } - - private ApiResponse mockApiResponseWithBody(String body) { - when(mockApiResponse.getResponseBody()) - .thenReturn( - ResponseBody.create(MediaType.parse("application/json; charset=utf-8"), body)); - return mockApiResponse; - } - - private ApiResponse handleCreateSession(String path, InvocationOnMock invocation) - throws Exception { - String newSessionId = "4"; - Map requestDict = - mapper.readValue( - (String) invocation.getArgument(2), new TypeReference>() {}); - Map newSessionData = new HashMap<>(); - newSessionData.put("name", path + "/" + newSessionId); - newSessionData.put("userId", requestDict.get("userId")); - newSessionData.put("sessionState", requestDict.get("sessionState")); - newSessionData.put("updateTime", "2024-12-12T12:12:12.123456Z"); - - sessionMap.put(newSessionId, mapper.writeValueAsString(newSessionData)); - - return mockApiResponseWithBody( - String.format( - """ - { - "name": "%s/%s/operations/111", - "done": false - } - """, - path, newSessionId)); - } - - private ApiResponse handleGetSession(String path) throws Exception { - String sessionId = path.substring(path.lastIndexOf('/') + 1); - if (sessionId.contains("/")) { // Ensure it's a direct session ID - return null; - } - String sessionData = sessionMap.get(sessionId); - if (sessionData != null) { - return mockApiResponseWithBody(sessionData); - } else { - throw new RuntimeException("Session not found: " + sessionId); - } - } - - private ApiResponse handleGetSessions(String path) throws Exception { - Matcher sessionsMatcher = SESSIONS_FILTER_REGEX.matcher(path); - if (!sessionsMatcher.matches()) { - return null; - } - String userId = sessionsMatcher.group(2); - List userSessionsJson = new ArrayList<>(); - for (String sessionJson : sessionMap.values()) { - Map session = - mapper.readValue(sessionJson, new TypeReference>() {}); - if (session.containsKey("userId") && session.get("userId").equals(userId)) { - userSessionsJson.add(sessionJson); - } - } - return mockApiResponseWithBody( - String.format( - """ - { - "sessions": [%s] - } - """, - String.join(",", userSessionsJson))); - } - - private ApiResponse handleAppendEvent(String path, InvocationOnMock invocation) { - Matcher appendEventMatcher = APPEND_EVENT_REGEX.matcher(path); - if (!appendEventMatcher.matches()) { - return null; - } - String sessionId = appendEventMatcher.group(2); - String eventDataString = eventMap.get(sessionId); - String newEventDataString = (String) invocation.getArgument(2); - try { - ConcurrentMap newEventData = - mapper.readValue( - newEventDataString, new TypeReference>() {}); - - List> eventsData = new ArrayList<>(); - if (eventDataString != null) { - eventsData.addAll( - mapper.readValue( - eventDataString, new TypeReference>>() {})); - } - - newEventData.put( - "name", path.replaceFirst(":appendEvent$", "/events/" + Event.generateEventId())); - - eventsData.add(newEventData); - - eventMap.put(sessionId, mapper.writeValueAsString(eventsData)); - } catch (Exception e) { - throw new RuntimeException(e); - } - return mockApiResponseWithBody(newEventDataString); - } - - private ApiResponse handleGetEvents(String path) throws Exception { - Matcher matcher = EVENTS_REGEX.matcher(path); - if (!matcher.matches()) { - return null; - } - String sessionId = matcher.group(2); - String eventData = eventMap.get(sessionId); - if (eventData != null) { - return mockApiResponseWithBody( - String.format( - """ - { - "sessionEvents": %s - } - """, - eventData)); - } else { - // Return an empty list if no events are found for the session - return mockApiResponseWithBody("{}"); - } - } - - private ApiResponse handleGetLro(String path) { - return mockApiResponseWithBody( - String.format( - """ - { - "name": "%s", - "done": true - } - """, - path.replace("/operations/111", ""))); // Simulate LRO done - } - - private ApiResponse handleDeleteSession(String path) { - Matcher sessionMatcher = SESSION_REGEX.matcher(path); - if (!sessionMatcher.matches()) { - return null; - } - String sessionIdToDelete = sessionMatcher.group(2); - sessionMap.remove(sessionIdToDelete); - return mockApiResponseWithBody(""); - } - } - @Before public void setUp() throws Exception { sessionMap = @@ -363,7 +162,7 @@ public void setUp() throws Exception { vertexAiSessionService = new VertexAiSessionService("test-project", "test-location", mockApiClient); when(mockApiClient.request(anyString(), anyString(), anyString())) - .thenAnswer(new MockApiAnswer(sessionMap, eventMap, mockApiResponse)); + .thenAnswer(new MockApiAnswer(sessionMap, eventMap)); } @Test