Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.BaseToolset;
import com.google.adk.tools.NamedToolPredicate;
import com.google.adk.tools.ToolPredicate;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.client.transport.ServerParameters;
import io.reactivex.rxjava3.core.Flowable;
Expand Down Expand Up @@ -59,47 +59,49 @@ public class McpAsyncToolset implements BaseToolset {

private final McpSessionManager mcpSessionManager;
private final ObjectMapper objectMapper;
private final ToolPredicate toolFilter;
private final Optional<Object> toolFilter;
private final AtomicReference<Mono<List<McpAsyncTool>>> mcpTools = new AtomicReference<>();

/** Builder for McpAsyncToolset */
public static class Builder {
private Object connectionParams = null;
private ObjectMapper objectMapper = null;
private ToolPredicate toolFilter = null;
private Optional<Object> toolFilter = null;

@CanIgnoreReturnValue
public Builder connectionParams(ServerParameters connectionParams) {
this.connectionParams = connectionParams;
return this;
}

@CanIgnoreReturnValue
public Builder connectionParams(SseServerParameters connectionParams) {
this.connectionParams = connectionParams;
return this;
}

@CanIgnoreReturnValue
public Builder objectMapper(ObjectMapper objectMapper) {
this.objectMapper = objectMapper;
return this;
}

public Builder toolFilter(ToolPredicate toolFilter) {
@CanIgnoreReturnValue
public Builder toolFilter(Optional<Object> toolFilter) {
this.toolFilter = toolFilter;
return this;
}

@CanIgnoreReturnValue
public Builder toolFilter(List<String> toolNames) {
this.toolFilter = new NamedToolPredicate(toolNames);
this.toolFilter = Optional.of(new NamedToolPredicate(toolNames));
return this;
}

public McpAsyncToolset build() {
if (objectMapper == null) {
objectMapper = JsonBaseModel.getMapper();
}
if (toolFilter == null) {
toolFilter = (tool, context) -> true;
}
if (connectionParams instanceof ServerParameters setSelectedParams) {
return new McpAsyncToolset(setSelectedParams, objectMapper, toolFilter);
} else if (connectionParams instanceof SseServerParameters sseServerParameters) {
Expand All @@ -116,11 +118,12 @@ public McpAsyncToolset build() {
*
* @param connectionParams The SSE connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter null or an implement for {@link ToolPredicate}, {@link
* com.google.adk.tools.NamedToolPredicate}
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
*/
public McpAsyncToolset(
SseServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) {
SseServerParameters connectionParams,
ObjectMapper objectMapper,
Optional<Object> toolFilter) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
Expand All @@ -133,11 +136,10 @@ public McpAsyncToolset(
*
* @param connectionParams The local server connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter null or an implement for {@link ToolPredicate}, {@link
* com.google.adk.tools.NamedToolPredicate}
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
*/
public McpAsyncToolset(
ServerParameters connectionParams, ObjectMapper objectMapper, ToolPredicate toolFilter) {
ServerParameters connectionParams, ObjectMapper objectMapper, Optional<Object> toolFilter) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
Expand All @@ -150,11 +152,10 @@ public McpAsyncToolset(
*
* @param mcpSessionManager The session manager for MCP connections.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter null or an implement for {@link ToolPredicate}, {@link
* com.google.adk.tools.NamedToolPredicate}
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
*/
public McpAsyncToolset(
McpSessionManager mcpSessionManager, ObjectMapper objectMapper, ToolPredicate toolFilter) {
McpSessionManager mcpSessionManager, ObjectMapper objectMapper, Optional<Object> toolFilter) {
Objects.requireNonNull(mcpSessionManager);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
Expand All @@ -171,10 +172,7 @@ public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
tools.stream()
.filter(
tool ->
isToolSelected(
tool,
Optional.ofNullable(toolFilter),
Optional.ofNullable(readonlyContext)))
isToolSelected(tool, toolFilter, Optional.ofNullable(readonlyContext)))
.toList())
.onErrorResumeNext(
err -> {
Expand Down