Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.aikido.agent_api.collectors;

import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.storage.BypassedContextStore;
import dev.aikido.agent_api.storage.HostnamesStore;
import dev.aikido.agent_api.storage.PendingHostnamesStore;
import dev.aikido.agent_api.storage.ServiceConfigStore;
Expand Down Expand Up @@ -34,20 +35,24 @@ public static void report(String hostname, InetAddress[] inetAddresses) {
// store stats
StatisticsStore.registerCall("java.net.InetAddress.getAllByName", OperationKind.OUTGOING_HTTP_OP);

boolean bypassed = BypassedContextStore.isBypassed();

// Consume pending ports recorded by URLCollector for this hostname.
// Removing them here ensures each (hostname, port) pair is counted exactly once.
Set<Integer> ports = PendingHostnamesStore.getAndRemove(hostname);
if (!ports.isEmpty()) {
for (int port : ports) {
HostnamesStore.incrementHits(hostname, port);
if (!bypassed) {
// Bypassed IPs are trusted β€” don't report their outbound hostnames in heartbeats.
if (!ports.isEmpty()) {
for (int port : ports) {
HostnamesStore.incrementHits(hostname, port);
}
} else {
HostnamesStore.incrementHits(hostname, 0);
}
} else {
// We still need to report a hit to the hostname for outbound domain blocking
HostnamesStore.incrementHits(hostname, 0);
}

// Block if the hostname is in the blocked domains list
if (ServiceConfigStore.shouldBlockOutgoingRequest(hostname)) {
if (ServiceConfigStore.shouldBlockOutgoingRequest(hostname) && !bypassed) {
logger.debug("Blocking DNS lookup for domain: %s", hostname);
throw BlockedOutboundException.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import dev.aikido.agent_api.helpers.logging.LogManager;
import dev.aikido.agent_api.helpers.logging.Logger;
import dev.aikido.agent_api.storage.AttackQueue;
import dev.aikido.agent_api.storage.BypassedContextStore;
import dev.aikido.agent_api.storage.PendingHostnamesStore;
import dev.aikido.agent_api.storage.ServiceConfigStore;
import dev.aikido.agent_api.storage.ServiceConfiguration;
Expand Down Expand Up @@ -44,8 +45,10 @@ public static Res report(ContextObject newContext) {
// Flush pending hostnames on every context change to prevent the store from
// growing unboundedly when a thread is reused across multiple requests.
PendingHostnamesStore.clear();
BypassedContextStore.clear();

if (config.isIpBypassed(newContext.getRemoteAddress())) {
BypassedContextStore.setBypassed(true);
return null; // do not set context when the IP address is bypassed (zen = off)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package dev.aikido.agent_api.storage;

/**
* Thread-local flag recording whether the current request's remote IP is in the bypass list.
* Needed because bypassed requests intentionally do not set a context, but for Stored SSRF we still want to skip.
*/
public final class BypassedContextStore {
private BypassedContextStore() {}

private static final ThreadLocal<Boolean> store = ThreadLocal.withInitial(() -> false);
Comment thread
bitterpanda63 marked this conversation as resolved.

public static void setBypassed(boolean bypassed) {
store.set(bypassed);
}

public static boolean isBypassed() {
return store.get();
}

public static void clear() {
store.remove();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import dev.aikido.agent_api.background.Endpoint;
import dev.aikido.agent_api.context.ContextObject;
import dev.aikido.agent_api.storage.BypassedContextStore;
import dev.aikido.agent_api.storage.ServiceConfiguration;

import java.util.List;
Expand All @@ -12,6 +13,11 @@
public final class SkipVulnerabilityScanDecider {
private SkipVulnerabilityScanDecider() {}
public static boolean shouldSkipVulnerabilityScan(ContextObject context, boolean defaultIfNoContext) {
// Check if ip is bypassed, important still for stored ssrf where it runs without a context.
if (BypassedContextStore.isBypassed()) {
return true;
}

if (context == null) {
return defaultIfNoContext;
}
Expand Down
39 changes: 39 additions & 0 deletions agent_api/src/test/java/collectors/DNSRecordCollectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.context.ContextObject;
import dev.aikido.agent_api.storage.AttackQueue;
import dev.aikido.agent_api.storage.BypassedContextStore;
import dev.aikido.agent_api.storage.Hostnames;
import dev.aikido.agent_api.storage.HostnamesStore;
import dev.aikido.agent_api.storage.PendingHostnamesStore;
Expand Down Expand Up @@ -37,6 +38,7 @@ void setup() throws UnknownHostException {
AttackQueue.clear();
HostnamesStore.clear();
PendingHostnamesStore.clear();
BypassedContextStore.clear();
}

@AfterEach
Expand All @@ -45,6 +47,7 @@ public void cleanup() {
PendingHostnamesStore.clear();
Context.set(null);
AttackQueue.clear();
BypassedContextStore.clear();
// Reset domain config
ServiceConfigStore.updateFromAPIResponse(new APIResponse(
true, null, 0L, null, null, null, false, List.of(), true, false, List.of()
Expand Down Expand Up @@ -134,6 +137,42 @@ public void testAllowedDomainNotBlocked() {
);
}

@Test
public void testBlockedDomainNotBlockedWhenIpBypassed() {
ServiceConfigStore.updateFromAPIResponse(new APIResponse(
true, null, 0L, null, null, null,
false, List.of(new Domain("blocked.example.com", "block")), true, true, List.of()
));
BypassedContextStore.setBypassed(true);
assertDoesNotThrow(() ->
DNSRecordCollector.report("blocked.example.com", new InetAddress[]{inetAddress1})
);
}

@Test
public void testHostnamesStoreNotUpdatedWhenBypassed() {
BypassedContextStore.setBypassed(true);
Context.set(new EmptySampleContextObject());

DNSRecordCollector.report("dev.aikido", new InetAddress[]{inetAddress1});

assertEquals(0, HostnamesStore.getHostnamesAsList().length);
}

@Test
public void testHostnamesStoreNotUpdatedWhenBypassedWithPendingPorts() {
PendingHostnamesStore.add("dev.aikido", 80);
PendingHostnamesStore.add("dev.aikido", 443);
BypassedContextStore.setBypassed(true);
Context.set(mock(ContextObject.class));

DNSRecordCollector.report("dev.aikido", new InetAddress[]{inetAddress1});

assertEquals(0, HostnamesStore.getHostnamesAsList().length);
// Pending entries are still consumed even when bypassed so the store doesn't grow unboundedly
assertTrue(PendingHostnamesStore.getPorts("dev.aikido").isEmpty());
}

@Test
public void testUnknownDomainBlockedWhenBlockNewOutgoingRequests() {
ServiceConfigStore.updateFromAPIResponse(new APIResponse(
Expand Down
29 changes: 29 additions & 0 deletions agent_api/src/test/java/collectors/WebRequestCollectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import dev.aikido.agent_api.context.Context;
import dev.aikido.agent_api.context.ContextObject;
import dev.aikido.agent_api.storage.AttackQueue;
import dev.aikido.agent_api.storage.BypassedContextStore;
import dev.aikido.agent_api.storage.ServiceConfigStore;
import dev.aikido.agent_api.storage.statistics.StatisticsStore;
import org.junit.jupiter.api.BeforeEach;
Expand All @@ -34,6 +35,7 @@ void setUp() {
ServiceConfigStore.updateFromAPIResponse(emptyAPIResponse);
ServiceConfigStore.updateFromAPIListsResponse(emptyAPIListsResponse);
AttackQueue.clear();
BypassedContextStore.clear();
}

@Test
Expand Down Expand Up @@ -260,6 +262,33 @@ void testReport_ipBlockedUsingLists_IPv4MappedBypass() {
assertNull(Context.get());
}

@Test
void testReport_bypassedIp_setsBypassedStore() {
List<String> bypassedIps = List.of("192.168.1.1");
ServiceConfigStore.updateFromAPIResponse(new APIResponse(
true, "", getUnixTimeMS(), List.of(), List.of(), bypassedIps, false, null, true, false, List.of()
));

assertFalse(BypassedContextStore.isBypassed());

WebRequestCollector.Res response = WebRequestCollector.report(contextObject);

assertNull(response);
assertNull(Context.get());
assertTrue(BypassedContextStore.isBypassed());
}

@Test
void testReport_nonBypassedIp_clearsBypassedStore() {
BypassedContextStore.setBypassed(true);
assertTrue(BypassedContextStore.isBypassed());

WebRequestCollector.Res response = WebRequestCollector.report(contextObject);

assertNull(response);
assertFalse(BypassedContextStore.isBypassed());
}

@Test
void testReport_ipNotAllowedUsingLists_Ip_Bypassed() {
ReportingApi.APIListsResponse blockedListsRes = new ReportingApi.APIListsResponse(
Expand Down
59 changes: 59 additions & 0 deletions agent_api/src/test/java/storage/BypassedContextStoreTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package storage;

import dev.aikido.agent_api.storage.BypassedContextStore;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.concurrent.atomic.AtomicBoolean;

import static org.junit.jupiter.api.Assertions.*;

public class BypassedContextStoreTest {

@BeforeEach
public void setUp() {
BypassedContextStore.clear();
}

@AfterEach
public void tearDown() {
BypassedContextStore.clear();
}

@Test
public void testDefaultIsFalse() {
assertFalse(BypassedContextStore.isBypassed());
}

@Test
public void testSetBypassed() {
BypassedContextStore.setBypassed(true);
assertTrue(BypassedContextStore.isBypassed());

BypassedContextStore.setBypassed(false);
assertFalse(BypassedContextStore.isBypassed());
}

@Test
public void testClear() {
BypassedContextStore.setBypassed(true);
assertTrue(BypassedContextStore.isBypassed());

BypassedContextStore.clear();
assertFalse(BypassedContextStore.isBypassed());
}

@Test
public void testThreadIsolation() throws InterruptedException {
BypassedContextStore.setBypassed(true);
AtomicBoolean observedInOtherThread = new AtomicBoolean(true);

Thread t = new Thread(() -> observedInOtherThread.set(BypassedContextStore.isBypassed()));
t.start();
t.join();

assertFalse(observedInOtherThread.get());
assertTrue(BypassedContextStore.isBypassed());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import dev.aikido.agent_api.background.Endpoint;
import dev.aikido.agent_api.context.ContextObject;
import dev.aikido.agent_api.storage.BypassedContextStore;
import dev.aikido.agent_api.vulnerabilities.SkipVulnerabilityScanDecider;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import utils.EmptyAPIResponses;
import utils.EmptySampleContextObject;
Expand All @@ -14,6 +17,16 @@
import static org.junit.jupiter.api.Assertions.*;

public class SkipVulnerabilityScanDeciderTest {
@BeforeEach
public void setUp() {
BypassedContextStore.clear();
}

@AfterEach
public void tearDown() {
BypassedContextStore.clear();
}

private List<Endpoint> createEndpoints(boolean protectionForcedOff1, boolean protectionForcedOff2) {
List<Endpoint> endpoints = new ArrayList<>();
endpoints.add(new Endpoint("POST", "/api/login", 3, 1000, Collections.emptyList(), false, protectionForcedOff1, true));
Expand Down Expand Up @@ -157,6 +170,33 @@ public void testShouldSkipVulnerabilityScan_NoConditionsMet() {
));
}

@Test
public void testShouldSkipVulnerabilityScan_BypassedIp_NullContext() {
BypassedContextStore.setBypassed(true);
// Even with defaultIfNoContext=false (the Stored SSRF path), a bypassed IP must skip.
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(null, false));
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(null, true));
}

@Test
public void testShouldSkipVulnerabilityScan_BypassedIp_WithContext() {
EmptyAPIResponses.setEmptyConfigWithEndpointList(createEndpoints(false, false));
ContextObject ctx = new EmptySampleContextObject("", "/api/login", "POST");
// Without bypass flag this context would return false (no forced protection off).
assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(ctx));

BypassedContextStore.setBypassed(true);
ContextObject freshCtx = new EmptySampleContextObject("", "/api/login", "POST");
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(freshCtx));
}

@Test
public void testShouldSkipVulnerabilityScan_NotBypassed_NullContext() {
// Sanity check: default behavior unchanged when bypass flag is not set.
assertFalse(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(null, false));
assertTrue(SkipVulnerabilityScanDecider.shouldSkipVulnerabilityScan(null, true));
}

@Test
public void testUsesCache() {
ContextObject ctx = new EmptySampleContextObject("", "/api/login", "POST");
Expand Down
Loading