diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index 57d80b48e..4acfeb121 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -25,6 +25,7 @@ import com.google.adk.agents.RunConfig; import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; +import com.google.adk.events.EventActions; import com.google.adk.memory.BaseMemoryService; import com.google.adk.plugins.BasePlugin; import com.google.adk.plugins.PluginManager; @@ -50,7 +51,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import javax.annotation.Nullable; /** The main class for the GenAI Agents runner. */ @@ -176,17 +179,30 @@ private Single appendNewMessageToSession( return this.sessionService.appendEvent(session, event); } + /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ + public Flowable runAsync( + String userId, String sessionId, Content newMessage, RunConfig runConfig) { + return runAsync(userId, sessionId, newMessage, runConfig, /* stateDelta= */ null); + } + /** - * Runs the agent in the standard mode. + * Runs the agent with an invocation-based mode. + * + *

TODO: make this the main implementation. * * @param userId The ID of the user for the session. * @param sessionId The ID of the session to run the agent in. * @param newMessage The new message from the user to process. * @param runConfig Configuration for the agent run. + * @param stateDelta Optional map of state updates to merge into the session for this run. * @return A Flowable stream of {@link Event} objects generated by the agent during execution. */ public Flowable runAsync( - String userId, String sessionId, Content newMessage, RunConfig runConfig) { + String userId, + String sessionId, + Content newMessage, + RunConfig runConfig, + @Nullable Map stateDelta) { Maybe maybeSession = this.sessionService.getSession(appName, userId, sessionId, Optional.empty()); return maybeSession @@ -194,38 +210,35 @@ public Flowable runAsync( Single.error( new IllegalArgumentException( String.format("Session not found: %s for user %s", sessionId, userId)))) - .flatMapPublisher(session -> this.runAsync(session, newMessage, runConfig)); + .flatMapPublisher(session -> this.runAsync(session, newMessage, runConfig, stateDelta)); } - /** - * Asynchronously runs the agent for a given user and session, processing a new message and using - * a default {@link RunConfig}. - * - *

This method initiates an agent execution within the specified session, appending the - * provided new message to the session's history. It utilizes a default {@code RunConfig} to - * control execution parameters. The method returns a stream of {@link Event} objects representing - * the agent's activity during the run. - * - * @param userId The ID of the user initiating the session. - * @param sessionId The ID of the session in which the agent will run. - * @param newMessage The new {@link Content} message to be processed by the agent. - * @return A {@link Flowable} emitting {@link Event} objects generated by the agent. - */ + /** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */ public Flowable runAsync(String userId, String sessionId, Content newMessage) { return runAsync(userId, sessionId, newMessage, RunConfig.builder().build()); } + /** See {@link #runAsync(Session, Content, RunConfig, Map)}. */ + public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { + return runAsync(session, newMessage, runConfig, /* stateDelta= */ null); + } + /** - * Runs the agent in the standard mode using a provided Session object. + * Runs the agent asynchronously using a provided Session object. * * @param session The session to run the agent in. * @param newMessage The new message from the user to process. * @param runConfig Configuration for the agent run. + * @param stateDelta Optional map of state updates to merge into the session for this run. * @return A Flowable stream of {@link Event} objects generated by the agent during execution. */ - public Flowable runAsync(Session session, Content newMessage, RunConfig runConfig) { + public Flowable runAsync( + Session session, + Content newMessage, + RunConfig runConfig, + @Nullable Map stateDelta) { Span span = Telemetry.getTracer().spanBuilder("invocation").startSpan(); - try (Scope scope = span.makeCurrent()) { + try (Scope unusedScope = span.makeCurrent()) { BaseAgent rootAgent = this.agent; InvocationContext context = newInvocationContext( @@ -234,6 +247,12 @@ public Flowable runAsync(Session session, Content newMessage, RunConfig r /* liveRequestQueue= */ Optional.empty(), runConfig); + // Emit state delta event if provided, using the same invocation ID + Single sessionSingle = + (stateDelta != null && !stateDelta.isEmpty()) + ? emitStateDeltaEvent(session, stateDelta, context.invocationId()) + : Single.just(session); + Maybe beforeRunEvent = this.pluginManager .runBeforeRunCallback(context) @@ -247,42 +266,49 @@ public Flowable runAsync(Session session, Content newMessage, RunConfig r .build()); Flowable agentEvents = - Flowable.defer( - () -> - this.pluginManager - .runOnUserMessageCallback(context, newMessage) - .switchIfEmpty(Single.just(newMessage)) - .flatMap( - content -> - (content != null) - ? appendNewMessageToSession( - session, - content, - context, - runConfig.saveInputBlobsAsArtifacts()) - : Single.just(null)) - .flatMapPublisher( - event -> { - InvocationContext contextWithNewMessage = - newInvocationContext( - session, event.content(), Optional.empty(), runConfig); - contextWithNewMessage.agent(this.findAgentToRun(session, rootAgent)); - return contextWithNewMessage - .agent() - .runAsync(contextWithNewMessage) - .flatMap( - agentEvent -> - this.sessionService - .appendEvent(session, agentEvent) - .flatMap( - registeredEvent -> - contextWithNewMessage - .pluginManager() - .runOnEventCallback( - contextWithNewMessage, registeredEvent) - .defaultIfEmpty(registeredEvent)) - .toFlowable()); - })); + sessionSingle.flatMapPublisher( + updatedSession -> + Flowable.defer( + () -> + this.pluginManager + .runOnUserMessageCallback(context, newMessage) + .switchIfEmpty(Single.just(newMessage)) + .flatMap( + content -> + (content != null) + ? appendNewMessageToSession( + updatedSession, + content, + context, + runConfig.saveInputBlobsAsArtifacts()) + : Single.just(null)) + .flatMapPublisher( + event -> { + InvocationContext contextWithNewMessage = + newInvocationContext( + updatedSession, + event.content(), + Optional.empty(), + runConfig); + contextWithNewMessage.agent( + this.findAgentToRun(updatedSession, rootAgent)); + return contextWithNewMessage + .agent() + .runAsync(contextWithNewMessage) + .flatMap( + agentEvent -> + this.sessionService + .appendEvent(updatedSession, agentEvent) + .flatMap( + registeredEvent -> + contextWithNewMessage + .pluginManager() + .runOnEventCallback( + contextWithNewMessage, + registeredEvent) + .defaultIfEmpty(registeredEvent)) + .toFlowable()); + }))); return beforeRunEvent .toFlowable() @@ -302,6 +328,36 @@ public Flowable runAsync(Session session, Content newMessage, RunConfig r } } + /** + * Emits a state update event and returns the updated session. + * + * @param session The session to update. + * @param stateDelta The state delta to apply. + * @param invocationId The invocation ID to use for the state delta event. + * @return Single emitting the updated session after applying the state delta. + */ + private Single emitStateDeltaEvent( + Session session, Map stateDelta, String invocationId) { + ConcurrentHashMap deltaMap = new ConcurrentHashMap<>(stateDelta); + + Event stateEvent = + Event.builder() + .id(Event.generateEventId()) + .invocationId(invocationId) + .author("user") + .actions(EventActions.builder().stateDelta(deltaMap).build()) + .timestamp(System.currentTimeMillis()) + .build(); + + return this.sessionService + .appendEvent(session, stateEvent) + .flatMap( + event -> + this.sessionService + .getSession(session.appName(), session.userId(), session.id(), Optional.empty()) + .switchIfEmpty(Single.error(new IllegalStateException("Session not found")))); + } + /** * Creates an {@link InvocationContext} for a live (streaming) run. * diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index b82452d46..2cb5d3d9e 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -29,6 +29,7 @@ import static org.mockito.Mockito.when; import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.RunConfig; import com.google.adk.events.Event; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; @@ -46,6 +47,8 @@ import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -370,6 +373,143 @@ public void onEventCallback_success() { verify(plugin).onEventCallback(any(), any()); } + @Test + public void runAsync_withStateDelta_mergesStateIntoSession() { + ImmutableMap stateDelta = ImmutableMap.of("key1", "value1", "key2", 42); + + var events = + runner + .runAsync( + "user", + session.id(), + createContent("test message"), + RunConfig.builder().build(), + stateDelta) + .toList() + .blockingGet(); + + // Verify agent runs successfully + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + + // Verify state was merged into session + Session finalSession = + runner + .sessionService() + .getSession("test", "user", session.id(), Optional.empty()) + .blockingGet(); + assertThat(finalSession.state()).containsAtLeastEntriesIn(stateDelta); + } + + @Test + public void runAsync_withEmptyStateDelta_doesNotModifySession() { + ImmutableMap emptyStateDelta = ImmutableMap.of(); + + var events = + runner + .runAsync( + "user", + session.id(), + createContent("test message"), + RunConfig.builder().build(), + emptyStateDelta) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + + // Verify no state events were emitted for empty delta + Session finalSession = + runner + .sessionService() + .getSession("test", "user", session.id(), Optional.empty()) + .blockingGet(); + assertThat(finalSession.state()).isEmpty(); + } + + @Test + public void runAsync_withNullStateDelta_doesNotModifySession() { + var events = + runner + .runAsync( + "user", + session.id(), + createContent("test message"), + RunConfig.builder().build(), + null) + .toList() + .blockingGet(); + + assertThat(simplifyEvents(events)).containsExactly("test agent: from llm"); + + Session finalSession = + runner + .sessionService() + .getSession("test", "user", session.id(), Optional.empty()) + .blockingGet(); + assertThat(finalSession.state()).isEmpty(); + } + + @Test + public void runAsync_withStateDelta_appendsStateEventToHistory() { + var unused = + runner + .runAsync( + "user", + session.id(), + createContent("test message"), + RunConfig.builder().build(), + ImmutableMap.of("testKey", "testValue")) + .toList() + .blockingGet(); + + Session finalSession = + runner + .sessionService() + .getSession("test", "user", session.id(), Optional.empty()) + .blockingGet(); + + assertThat( + finalSession.events().stream() + .anyMatch( + e -> + e.author().equals("user") + && e.actions() != null + && e.actions().stateDelta() != null + && !e.actions().stateDelta().isEmpty())) + .isTrue(); + } + + @Test + public void runAsync_withStateDelta_mergesWithExistingState() { + // Create a new session with initial state + ConcurrentHashMap initialState = new ConcurrentHashMap<>(); + initialState.put("existing_key", "existing_value"); + Session sessionWithState = + runner.sessionService().createSession("test", "user", initialState, null).blockingGet(); + + // Add new state via stateDelta + ImmutableMap newDelta = ImmutableMap.of("new_key", "new_value"); + var unused = + runner + .runAsync( + "user", + sessionWithState.id(), + createContent("test message"), + RunConfig.builder().build(), + newDelta) + .toList() + .blockingGet(); + + // Verify both old and new states are present (merged, not replaced) + Session finalSession = + runner + .sessionService() + .getSession("test", "user", sessionWithState.id(), Optional.empty()) + .blockingGet(); + assertThat(finalSession.state()).containsEntry("existing_key", "existing_value"); + assertThat(finalSession.state()).containsEntry("new_key", "new_value"); + } + private Content createContent(String text) { return Content.builder().parts(Part.builder().text(text).build()).build(); } diff --git a/dev/src/main/java/com/google/adk/web/controller/ExecutionController.java b/dev/src/main/java/com/google/adk/web/controller/ExecutionController.java index 54c23327c..6d5a2764c 100644 --- a/dev/src/main/java/com/google/adk/web/controller/ExecutionController.java +++ b/dev/src/main/java/com/google/adk/web/controller/ExecutionController.java @@ -80,7 +80,8 @@ public List agentRun(@RequestBody AgentRunRequest request) { RunConfig runConfig = RunConfig.builder().setStreamingMode(StreamingMode.NONE).build(); Flowable eventStream = - runner.runAsync(request.userId, request.sessionId, request.newMessage, runConfig); + runner.runAsync( + request.userId, request.sessionId, request.newMessage, runConfig, request.stateDelta); List events = Lists.newArrayList(eventStream.blockingIterable()); log.info("Agent run for session {} generated {} events.", request.sessionId, events.size()); @@ -151,7 +152,12 @@ public SseEmitter agentRunSse(@RequestBody AgentRunRequest request) { .build(); Flowable eventFlowable = - runner.runAsync(request.userId, request.sessionId, request.newMessage, runConfig); + runner.runAsync( + request.userId, + request.sessionId, + request.newMessage, + runConfig, + request.stateDelta); Disposable disposable = eventFlowable diff --git a/dev/src/main/java/com/google/adk/web/dto/AgentRunRequest.java b/dev/src/main/java/com/google/adk/web/dto/AgentRunRequest.java index bc550b687..652de5aac 100644 --- a/dev/src/main/java/com/google/adk/web/dto/AgentRunRequest.java +++ b/dev/src/main/java/com/google/adk/web/dto/AgentRunRequest.java @@ -18,6 +18,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.genai.types.Content; +import java.util.Map; +import javax.annotation.Nullable; /** * Data Transfer Object (DTO) for POST /run and POST /run-sse requests. Contains information needed @@ -39,6 +41,15 @@ public class AgentRunRequest { @JsonProperty("streaming") public boolean streaming = false; + /** + * Optional state delta to merge into the session state before running the agent. This allows + * updating session state dynamically per request, useful for injecting configuration (e.g., + * replay mode settings) without modifying the stored session. + */ + @JsonProperty("stateDelta") + @Nullable + public Map stateDelta; + public AgentRunRequest() {} public String getAppName() { @@ -60,4 +71,9 @@ public Content getNewMessage() { public boolean getStreaming() { return streaming; } + + @Nullable + public Map getStateDelta() { + return stateDelta; + } } diff --git a/dev/src/test/java/com/google/adk/web/dto/AgentRunRequestTest.java b/dev/src/test/java/com/google/adk/web/dto/AgentRunRequestTest.java new file mode 100644 index 000000000..70f4e2282 --- /dev/null +++ b/dev/src/test/java/com/google/adk/web/dto/AgentRunRequestTest.java @@ -0,0 +1,139 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.web.dto; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import com.google.genai.types.Content; +import com.google.genai.types.Part; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public final class AgentRunRequestTest { + + private ObjectMapper objectMapper; + + private static final String BASIC_MESSAGE_JSON = + """ + "appName": "testApp", + "userId": "user123", + "sessionId": "session456", + "newMessage": { + "parts": [ + {"text": "hello"} + ] + } + """; + + private AgentRunRequest newRequest() { + AgentRunRequest request = new AgentRunRequest(); + request.appName = "testApp"; + request.userId = "user123"; + request.sessionId = "session456"; + request.newMessage = Content.builder().parts(Part.builder().text("hello").build()).build(); + return request; + } + + @BeforeEach + public void setUp() { + objectMapper = new ObjectMapper(); + objectMapper.registerModule(new Jdk8Module()); + } + + @Test + public void deserialize_expectedUsage() throws Exception { + String json = "{ %s }".formatted(BASIC_MESSAGE_JSON); + + AgentRunRequest request = objectMapper.readValue(json, AgentRunRequest.class); + + assertThat(request.appName).isEqualTo("testApp"); + assertThat(request.userId).isEqualTo("user123"); + assertThat(request.sessionId).isEqualTo("session456"); + assertThat(request.stateDelta).isNull(); + } + + @Test + public void deserialize_withDeltaState() throws Exception { + String json = + """ + { + %s, + "stateDelta": { + "key1": "value1", + "key2": 42, + "stringVal": "text", + "intVal": 123, + "boolVal": true, + "doubleVal": 45.67, + "nestedObj": {"inner": "value"} + } + } + """ + .formatted(BASIC_MESSAGE_JSON); + + AgentRunRequest request = objectMapper.readValue(json, AgentRunRequest.class); + + assertThat(request.stateDelta).isNotNull(); + assertThat(request.stateDelta).hasSize(7); + assertThat(request.stateDelta).containsEntry("key1", "value1"); + assertThat(request.stateDelta).containsEntry("key2", 42); + assertThat(request.stateDelta).containsEntry("stringVal", "text"); + assertThat(request.stateDelta).containsEntry("intVal", 123); + assertThat(request.stateDelta).containsEntry("boolVal", true); + assertThat(request.stateDelta).containsEntry("doubleVal", 45.67); + assertThat(request.stateDelta).containsKey("nestedObj"); + } + + @Test + public void serialize_withStateDelta_success() throws Exception { + AgentRunRequest request = newRequest(); + + Map stateDelta = new HashMap<>(); + stateDelta.put("key1", "value1"); + stateDelta.put("key2", 42); + stateDelta.put("key3", true); + request.stateDelta = stateDelta; + + String json = objectMapper.writeValueAsString(request); + + JsonNode deltaState = objectMapper.readTree(json).get("stateDelta"); + + assertThat(deltaState.get("key1").asText()).isEqualTo("value1"); + assertThat(deltaState.get("key2").asInt()).isEqualTo(42); + assertThat(deltaState.get("key3").asBoolean()).isTrue(); + } + + @Test + public void serialize_expectedUsage() throws Exception { + AgentRunRequest request = newRequest(); + + String json = objectMapper.writeValueAsString(request); + + JsonNode node = objectMapper.readTree(json); + assertThat(node.get("appName").asText()).isEqualTo("testApp"); + assertThat(node.get("userId").asText()).isEqualTo("user123"); + assertThat(node.get("sessionId").asText()).isEqualTo("session456"); + if (node.has("stateDelta")) { + assertThat(node.get("stateDelta").isNull()).isTrue(); + } + } +} diff --git a/maven_plugin/examples/custom_tools/pom.xml b/maven_plugin/examples/custom_tools/pom.xml index ee96758d4..b663d75f5 100644 --- a/maven_plugin/examples/custom_tools/pom.xml +++ b/maven_plugin/examples/custom_tools/pom.xml @@ -1,5 +1,7 @@ - + 4.0.0 com.example @@ -14,7 +16,7 @@ 11 11 UTF-8 - 0.3.0 + 0.3.1-SNAPSHOT @@ -56,4 +58,4 @@ - + \ No newline at end of file diff --git a/maven_plugin/examples/simple-agent/pom.xml b/maven_plugin/examples/simple-agent/pom.xml index a6fb62be1..a4cc5aeb9 100644 --- a/maven_plugin/examples/simple-agent/pom.xml +++ b/maven_plugin/examples/simple-agent/pom.xml @@ -1,5 +1,7 @@ - + 4.0.0 com.example @@ -14,7 +16,7 @@ 11 11 UTF-8 - 0.3.0 + 0.3.1-SNAPSHOT @@ -60,4 +62,4 @@ - + \ No newline at end of file