Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 112 additions & 56 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.adk.agents.RunConfig;
import com.google.adk.artifacts.BaseArtifactService;
import com.google.adk.events.Event;
import com.google.adk.events.EventActions;
import com.google.adk.memory.BaseMemoryService;
import com.google.adk.plugins.BasePlugin;
import com.google.adk.plugins.PluginManager;
Expand All @@ -50,7 +51,9 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;

/** The main class for the GenAI Agents runner. */
Expand Down Expand Up @@ -176,56 +179,66 @@ private Single<Event> appendNewMessageToSession(
return this.sessionService.appendEvent(session, event);
}

/** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */
public Flowable<Event> runAsync(
String userId, String sessionId, Content newMessage, RunConfig runConfig) {
return runAsync(userId, sessionId, newMessage, runConfig, /* stateDelta= */ null);
}

/**
* Runs the agent in the standard mode.
* Runs the agent with an invocation-based mode.
*
* <p>TODO: make this the main implementation.
*
* @param userId The ID of the user for the session.
* @param sessionId The ID of the session to run the agent in.
* @param newMessage The new message from the user to process.
* @param runConfig Configuration for the agent run.
* @param stateDelta Optional map of state updates to merge into the session for this run.
* @return A Flowable stream of {@link Event} objects generated by the agent during execution.
*/
public Flowable<Event> runAsync(
String userId, String sessionId, Content newMessage, RunConfig runConfig) {
String userId,
String sessionId,
Content newMessage,
RunConfig runConfig,
@Nullable Map<String, Object> stateDelta) {
Maybe<Session> maybeSession =
this.sessionService.getSession(appName, userId, sessionId, Optional.empty());
return maybeSession
.switchIfEmpty(
Single.error(
new IllegalArgumentException(
String.format("Session not found: %s for user %s", sessionId, userId))))
.flatMapPublisher(session -> this.runAsync(session, newMessage, runConfig));
.flatMapPublisher(session -> this.runAsync(session, newMessage, runConfig, stateDelta));
}

/**
* Asynchronously runs the agent for a given user and session, processing a new message and using
* a default {@link RunConfig}.
*
* <p>This method initiates an agent execution within the specified session, appending the
* provided new message to the session's history. It utilizes a default {@code RunConfig} to
* control execution parameters. The method returns a stream of {@link Event} objects representing
* the agent's activity during the run.
*
* @param userId The ID of the user initiating the session.
* @param sessionId The ID of the session in which the agent will run.
* @param newMessage The new {@link Content} message to be processed by the agent.
* @return A {@link Flowable} emitting {@link Event} objects generated by the agent.
*/
/** See {@link #runAsync(String, String, Content, RunConfig, Map)}. */
public Flowable<Event> runAsync(String userId, String sessionId, Content newMessage) {
return runAsync(userId, sessionId, newMessage, RunConfig.builder().build());
}

/** See {@link #runAsync(Session, Content, RunConfig, Map)}. */
public Flowable<Event> runAsync(Session session, Content newMessage, RunConfig runConfig) {
return runAsync(session, newMessage, runConfig, /* stateDelta= */ null);
}

/**
* Runs the agent in the standard mode using a provided Session object.
* Runs the agent asynchronously using a provided Session object.
*
* @param session The session to run the agent in.
* @param newMessage The new message from the user to process.
* @param runConfig Configuration for the agent run.
* @param stateDelta Optional map of state updates to merge into the session for this run.
* @return A Flowable stream of {@link Event} objects generated by the agent during execution.
*/
public Flowable<Event> runAsync(Session session, Content newMessage, RunConfig runConfig) {
public Flowable<Event> runAsync(
Session session,
Content newMessage,
RunConfig runConfig,
@Nullable Map<String, Object> stateDelta) {
Span span = Telemetry.getTracer().spanBuilder("invocation").startSpan();
try (Scope scope = span.makeCurrent()) {
try (Scope unusedScope = span.makeCurrent()) {
BaseAgent rootAgent = this.agent;
InvocationContext context =
newInvocationContext(
Expand All @@ -234,6 +247,12 @@ public Flowable<Event> runAsync(Session session, Content newMessage, RunConfig r
/* liveRequestQueue= */ Optional.empty(),
runConfig);

// Emit state delta event if provided, using the same invocation ID
Single<Session> sessionSingle =
(stateDelta != null && !stateDelta.isEmpty())
? emitStateDeltaEvent(session, stateDelta, context.invocationId())
: Single.just(session);

Maybe<Event> beforeRunEvent =
this.pluginManager
.runBeforeRunCallback(context)
Expand All @@ -247,42 +266,49 @@ public Flowable<Event> runAsync(Session session, Content newMessage, RunConfig r
.build());

Flowable<Event> agentEvents =
Flowable.defer(
() ->
this.pluginManager
.runOnUserMessageCallback(context, newMessage)
.switchIfEmpty(Single.just(newMessage))
.flatMap(
content ->
(content != null)
? appendNewMessageToSession(
session,
content,
context,
runConfig.saveInputBlobsAsArtifacts())
: Single.just(null))
.flatMapPublisher(
event -> {
InvocationContext contextWithNewMessage =
newInvocationContext(
session, event.content(), Optional.empty(), runConfig);
contextWithNewMessage.agent(this.findAgentToRun(session, rootAgent));
return contextWithNewMessage
.agent()
.runAsync(contextWithNewMessage)
.flatMap(
agentEvent ->
this.sessionService
.appendEvent(session, agentEvent)
.flatMap(
registeredEvent ->
contextWithNewMessage
.pluginManager()
.runOnEventCallback(
contextWithNewMessage, registeredEvent)
.defaultIfEmpty(registeredEvent))
.toFlowable());
}));
sessionSingle.flatMapPublisher(
updatedSession ->
Flowable.defer(
() ->
this.pluginManager
.runOnUserMessageCallback(context, newMessage)
.switchIfEmpty(Single.just(newMessage))
.flatMap(
content ->
(content != null)
? appendNewMessageToSession(
updatedSession,
content,
context,
runConfig.saveInputBlobsAsArtifacts())
: Single.just(null))
.flatMapPublisher(
event -> {
InvocationContext contextWithNewMessage =
newInvocationContext(
updatedSession,
event.content(),
Optional.empty(),
runConfig);
contextWithNewMessage.agent(
this.findAgentToRun(updatedSession, rootAgent));
return contextWithNewMessage
.agent()
.runAsync(contextWithNewMessage)
.flatMap(
agentEvent ->
this.sessionService
.appendEvent(updatedSession, agentEvent)
.flatMap(
registeredEvent ->
contextWithNewMessage
.pluginManager()
.runOnEventCallback(
contextWithNewMessage,
registeredEvent)
.defaultIfEmpty(registeredEvent))
.toFlowable());
})));

return beforeRunEvent
.toFlowable()
Expand All @@ -302,6 +328,36 @@ public Flowable<Event> runAsync(Session session, Content newMessage, RunConfig r
}
}

/**
* Emits a state update event and returns the updated session.
*
* @param session The session to update.
* @param stateDelta The state delta to apply.
* @param invocationId The invocation ID to use for the state delta event.
* @return Single emitting the updated session after applying the state delta.
*/
private Single<Session> emitStateDeltaEvent(
Session session, Map<String, Object> stateDelta, String invocationId) {
ConcurrentHashMap<String, Object> deltaMap = new ConcurrentHashMap<>(stateDelta);

Event stateEvent =
Event.builder()
.id(Event.generateEventId())
.invocationId(invocationId)
.author("user")
.actions(EventActions.builder().stateDelta(deltaMap).build())
.timestamp(System.currentTimeMillis())
.build();

return this.sessionService
.appendEvent(session, stateEvent)
.flatMap(
event ->
this.sessionService
.getSession(session.appName(), session.userId(), session.id(), Optional.empty())
.switchIfEmpty(Single.error(new IllegalStateException("Session not found"))));
}

/**
* Creates an {@link InvocationContext} for a live (streaming) run.
*
Expand Down
Loading