diff --git a/CHANGELOG.md b/CHANGELOG.md index bf60ed32094e7..14f2a64180c4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add support for Warm Indices Write Block on Flood Watermark breach ([#18375](https://github.com/opensearch-project/OpenSearch/pull/18375)) - Ability to run Code Coverage with Gradle and produce the jacoco reports locally ([#18509](https://github.com/opensearch-project/OpenSearch/issues/18509)) - Introduce SecureHttpTransportParameters experimental API (to complement SecureTransportParameters counterpart) ([#18572](https://github.com/opensearch-project/OpenSearch/issues/18572)) +- Create equivalents of JSM's AccessController in the java agent ([#18346](https://github.com/opensearch-project/OpenSearch/issues/18346)) ### Changed - Update Subject interface to use CheckedRunnable ([#18570](https://github.com/opensearch-project/OpenSearch/issues/18570)) diff --git a/libs/agent-sm/agent-policy/src/main/java/org/opensearch/secure_sm/AccessController.java b/libs/agent-sm/agent-policy/src/main/java/org/opensearch/secure_sm/AccessController.java new file mode 100644 index 0000000000000..915c3630e135e --- /dev/null +++ b/libs/agent-sm/agent-policy/src/main/java/org/opensearch/secure_sm/AccessController.java @@ -0,0 +1,129 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.secure_sm; + +import java.util.concurrent.Callable; +import java.util.function.Supplier; + +/** + * A utility class that provides methods to perform actions in a privileged context. + * + * This class is a replacement for Java's {@code java.security.AccessController} functionality which is marked for + * removal. All new code should use this class instead of the JDK's {@code AccessController}. + * + * Running code in a privileged context will ensure that the code has the necessary permissions + * without traversing through the entire call stack. See {@code org.opensearch.javaagent.StackCallerProtectionDomainChainExtractor} + * + * Example usages: + *
+ * {@code
+ * AccessController.doPrivileged(() -> {
+ *     // code that requires privileges
+ * });
+ * }
+ * 
+ * + * Example usage with a return value and checked exception: + * + *
+ * {@code
+ * T something = AccessController.doPrivilegedChecked(() -> {
+ *     // code that requires privileges and may throw a checked exception
+ *     return something;
+ *     // or
+ *     throw new Exception();
+ * });
+ * }
+ * 
+ */ +public final class AccessController { + /** + * Don't allow instantiation an {@code AccessController} + */ + private AccessController() {} + + /** + * Performs the specified action in a privileged block. + * + *

If the action's {@code run} method throws an (unchecked) + * exception, it will propagate through this method. + * + * @param action the action to be performed + */ + public static void doPrivileged(Runnable action) { + action.run(); + } + + /** + * Performs the specified action. + * + *

If the action's {@code run} method throws an unchecked + * exception, it will propagate through this method. + * + * @param the type of the value returned by the + * PrivilegedExceptionAction's {@code run} method + * + * @param action the action to be performed + * + * @return the value returned by the action's {@code run} method + */ + public static T doPrivileged(Supplier action) { + return action.get(); + } + + /** + * Performs the specified action. + * + *

If the action's {@code run} method throws an unchecked + * exception, it will propagate through this method. + * + * @param the type of the value returned by the + * PrivilegedExceptionAction's {@code run} method + * + * @param action the action to be performed + * + * @return the value returned by the action's {@code run} method + * + * @throws Exception if the specified action's + * {@code call} method threw a checked exception + */ + public static T doPrivilegedChecked(Callable action) throws Exception { + return action.call(); + } + + /** + * Performs the specified action in a privileged block. + * + *

If the action's {@code run} method throws an (unchecked) + * exception, it will propagate through this method. + * + * @param action the action to be performed + * + * @throws T if the specified action's + * {@code call} method threw a checked exception + */ + public static void doPrivilegedChecked(CheckedRunnable action) throws T { + action.run(); + } + + /** + * A functional interface that represents a runnable action that can throw a checked exception. + * + * @param the type of the exception that can be thrown + */ + public interface CheckedRunnable { + + /** + * Executes the action. + * + * @throws E + */ + void run() throws E; + } +} diff --git a/libs/agent-sm/agent-policy/src/main/java/org/opensearch/secure_sm/package-info.java b/libs/agent-sm/agent-policy/src/main/java/org/opensearch/secure_sm/package-info.java new file mode 100644 index 0000000000000..c315e7c6c7244 --- /dev/null +++ b/libs/agent-sm/agent-policy/src/main/java/org/opensearch/secure_sm/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * Classes for running code in a privileged context + */ +package org.opensearch.secure_sm; diff --git a/libs/agent-sm/agent/build.gradle b/libs/agent-sm/agent/build.gradle index 4a700d65730fb..c495067d45ebb 100644 --- a/libs/agent-sm/agent/build.gradle +++ b/libs/agent-sm/agent/build.gradle @@ -14,6 +14,7 @@ dependencies { implementation "net.bytebuddy:byte-buddy:${versions.bytebuddy}" compileOnly "com.google.code.findbugs:jsr305:3.0.2" + testImplementation project(":libs:agent-sm:agent-policy") testImplementation "junit:junit:${versions.junit}" testImplementation "org.hamcrest:hamcrest:${versions.hamcrest}" } diff --git a/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java b/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java index f4a1382254b0f..8c348a29ab69e 100644 --- a/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java +++ b/libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java @@ -11,6 +11,7 @@ import java.lang.StackWalker.StackFrame; import java.security.ProtectionDomain; import java.util.Collection; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -24,6 +25,18 @@ public final class StackCallerProtectionDomainChainExtractor implements Function */ public static final StackCallerProtectionDomainChainExtractor INSTANCE = new StackCallerProtectionDomainChainExtractor(); + private static final StackWalker STACK_WALKER = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE); + /** + * Classes that are used to check if the stack frame is from AccessController. Temporarily supports both the + * AccessController from the JDK (marked for removal) and its replacement in the Java Agent. + */ + private static final Set ACCESS_CONTROLLER_CLASSES = Set.of( + "java.security.AccessController", + "org.opensearch.secure_sm.AccessController" + ); + + private static final Set DO_PRIVILEGED_METHODS = Set.of("doPrivileged", "doPrivilegedChecked"); + /** * Constructor */ @@ -36,7 +49,7 @@ private StackCallerProtectionDomainChainExtractor() {} @Override public Collection apply(Stream frames) { return frames.takeWhile( - frame -> !(frame.getClassName().equals("java.security.AccessController") && frame.getMethodName().equals("doPrivileged")) + frame -> !(ACCESS_CONTROLLER_CLASSES.contains(frame.getClassName()) && DO_PRIVILEGED_METHODS.contains(frame.getMethodName())) ) .map(StackFrame::getDeclaringClass) .map(Class::getProtectionDomain) diff --git a/libs/agent-sm/agent/src/test/java/org/opensearch/javaagent/StackCallerProtectionDomainExtractorTests.java b/libs/agent-sm/agent/src/test/java/org/opensearch/javaagent/StackCallerProtectionDomainExtractorTests.java index 4f26a97d0ff12..eab0711a2288f 100644 --- a/libs/agent-sm/agent/src/test/java/org/opensearch/javaagent/StackCallerProtectionDomainExtractorTests.java +++ b/libs/agent-sm/agent/src/test/java/org/opensearch/javaagent/StackCallerProtectionDomainExtractorTests.java @@ -19,12 +19,14 @@ import java.security.ProtectionDomain; import java.util.List; import java.util.Set; +import java.util.function.Supplier; import java.util.stream.Collectors; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.hasItem; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; public class StackCallerProtectionDomainExtractorTests { @@ -115,4 +117,144 @@ public Void run() { } }); } + + @Test + public void testStackTruncationWithOpenSearchAccessController() { + org.opensearch.secure_sm.AccessController.doPrivileged(() -> { + StackCallerProtectionDomainChainExtractor extractor = StackCallerProtectionDomainChainExtractor.INSTANCE; + Set protectionDomains = (Set) extractor.apply(captureStackFrames().stream()); + assertEquals(1, protectionDomains.size()); + List simpleNames = protectionDomains.stream().map(pd -> { + try { + return pd.getCodeSource().getLocation().toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + }) + .map(URI::getPath) + .map(Paths::get) + .map(Path::getFileName) + .map(Path::toString) + // strip trailing “-VERSION.jar” if present + .map(name -> name.replaceFirst("-\\d[\\d\\.]*\\.jar$", "")) + // otherwise strip “.jar” + .map(name -> name.replaceFirst("\\.jar$", "")) + .toList(); + assertThat( + simpleNames, + containsInAnyOrder( + "test" // from the build/classes/java/test directory + ) + ); + }); + } + + @Test + public void testStackTruncationWithOpenSearchAccessControllerUsingSupplier() { + org.opensearch.secure_sm.AccessController.doPrivileged((Supplier) () -> { + StackCallerProtectionDomainChainExtractor extractor = StackCallerProtectionDomainChainExtractor.INSTANCE; + Set protectionDomains = (Set) extractor.apply(captureStackFrames().stream()); + assertEquals(1, protectionDomains.size()); + List simpleNames = protectionDomains.stream().map(pd -> { + try { + return pd.getCodeSource().getLocation().toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + }) + .map(URI::getPath) + .map(Paths::get) + .map(Path::getFileName) + .map(Path::toString) + // strip trailing “-VERSION.jar” if present + .map(name -> name.replaceFirst("-\\d[\\d\\.]*\\.jar$", "")) + // otherwise strip “.jar” + .map(name -> name.replaceFirst("\\.jar$", "")) + .toList(); + assertThat( + simpleNames, + containsInAnyOrder( + "test" // from the build/classes/java/test directory + ) + ); + return null; + }); + } + + @Test + public void testStackTruncationWithOpenSearchAccessControllerUsingCallable() throws Exception { + org.opensearch.secure_sm.AccessController.doPrivilegedChecked(() -> { + StackCallerProtectionDomainChainExtractor extractor = StackCallerProtectionDomainChainExtractor.INSTANCE; + Set protectionDomains = (Set) extractor.apply(captureStackFrames().stream()); + assertEquals(1, protectionDomains.size()); + List simpleNames = protectionDomains.stream().map(pd -> { + try { + return pd.getCodeSource().getLocation().toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + }) + .map(URI::getPath) + .map(Paths::get) + .map(Path::getFileName) + .map(Path::toString) + // strip trailing “-VERSION.jar” if present + .map(name -> name.replaceFirst("-\\d[\\d\\.]*\\.jar$", "")) + // otherwise strip “.jar” + .map(name -> name.replaceFirst("\\.jar$", "")) + .toList(); + assertThat( + simpleNames, + containsInAnyOrder( + "test" // from the build/classes/java/test directory + ) + ); + return null; + }); + } + + @Test + public void testAccessControllerUsingCallableThrowsException() { + assertThrows(IllegalArgumentException.class, () -> { + org.opensearch.secure_sm.AccessController.doPrivilegedChecked(() -> { throw new IllegalArgumentException("Test exception"); }); + }); + } + + @Test + public void testStackTruncationWithOpenSearchAccessControllerUsingCheckedRunnable() throws IllegalArgumentException { + org.opensearch.secure_sm.AccessController.doPrivilegedChecked(() -> { + StackCallerProtectionDomainChainExtractor extractor = StackCallerProtectionDomainChainExtractor.INSTANCE; + Set protectionDomains = (Set) extractor.apply(captureStackFrames().stream()); + assertEquals(1, protectionDomains.size()); + List simpleNames = protectionDomains.stream().map(pd -> { + try { + return pd.getCodeSource().getLocation().toURI(); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + }) + .map(URI::getPath) + .map(Paths::get) + .map(Path::getFileName) + .map(Path::toString) + // strip trailing “-VERSION.jar” if present + .map(name -> name.replaceFirst("-\\d[\\d\\.]*\\.jar$", "")) + // otherwise strip “.jar” + .map(name -> name.replaceFirst("\\.jar$", "")) + .toList(); + assertThat( + simpleNames, + containsInAnyOrder( + "test" // from the build/classes/java/test directory + ) + ); + }); + } + + @Test + public void testAccessControllerUsingCheckedRunnableThrowsException() { + assertThrows(IllegalArgumentException.class, () -> { + org.opensearch.secure_sm.AccessController.doPrivilegedChecked(() -> { throw new IllegalArgumentException("Test exception"); }); + }); + } } diff --git a/modules/ingest-geoip/src/main/java/org/opensearch/ingest/geoip/GeoIpProcessor.java b/modules/ingest-geoip/src/main/java/org/opensearch/ingest/geoip/GeoIpProcessor.java index b27c0f9fe0b31..62eb49d480df4 100644 --- a/modules/ingest-geoip/src/main/java/org/opensearch/ingest/geoip/GeoIpProcessor.java +++ b/modules/ingest-geoip/src/main/java/org/opensearch/ingest/geoip/GeoIpProcessor.java @@ -51,11 +51,10 @@ import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.Processor; import org.opensearch.ingest.geoip.IngestGeoIpModulePlugin.GeoIpCache; +import org.opensearch.secure_sm.AccessController; import java.io.IOException; import java.net.InetAddress; -import java.security.AccessController; -import java.security.PrivilegedAction; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -221,17 +220,15 @@ Set getProperties() { @SuppressWarnings("removal") private Map retrieveCityGeoData(InetAddress ipAddress) { SpecialPermission.check(); - CityResponse response = AccessController.doPrivileged( - (PrivilegedAction) () -> cache.putIfAbsent(ipAddress, CityResponse.class, ip -> { - try { - return lazyLoader.get().city(ip); - } catch (AddressNotFoundException e) { - throw new AddressNotFoundRuntimeException(e); - } catch (Exception e) { - throw new RuntimeException(e); - } - }) - ); + CityResponse response = AccessController.doPrivileged(() -> cache.putIfAbsent(ipAddress, CityResponse.class, ip -> { + try { + return lazyLoader.get().city(ip); + } catch (AddressNotFoundException e) { + throw new AddressNotFoundRuntimeException(e); + } catch (Exception e) { + throw new RuntimeException(e); + } + })); Country country = response.getCountry(); City city = response.getCity(); @@ -309,17 +306,15 @@ private Map retrieveCityGeoData(InetAddress ipAddress) { @SuppressWarnings("removal") private Map retrieveCountryGeoData(InetAddress ipAddress) { SpecialPermission.check(); - CountryResponse response = AccessController.doPrivileged( - (PrivilegedAction) () -> cache.putIfAbsent(ipAddress, CountryResponse.class, ip -> { - try { - return lazyLoader.get().country(ip); - } catch (AddressNotFoundException e) { - throw new AddressNotFoundRuntimeException(e); - } catch (Exception e) { - throw new RuntimeException(e); - } - }) - ); + CountryResponse response = AccessController.doPrivileged(() -> cache.putIfAbsent(ipAddress, CountryResponse.class, ip -> { + try { + return lazyLoader.get().country(ip); + } catch (AddressNotFoundException e) { + throw new AddressNotFoundRuntimeException(e); + } catch (Exception e) { + throw new RuntimeException(e); + } + })); Country country = response.getCountry(); Continent continent = response.getContinent(); @@ -356,17 +351,15 @@ private Map retrieveCountryGeoData(InetAddress ipAddress) { @SuppressWarnings("removal") private Map retrieveAsnGeoData(InetAddress ipAddress) { SpecialPermission.check(); - AsnResponse response = AccessController.doPrivileged( - (PrivilegedAction) () -> cache.putIfAbsent(ipAddress, AsnResponse.class, ip -> { - try { - return lazyLoader.get().asn(ip); - } catch (AddressNotFoundException e) { - throw new AddressNotFoundRuntimeException(e); - } catch (Exception e) { - throw new RuntimeException(e); - } - }) - ); + AsnResponse response = AccessController.doPrivileged(() -> cache.putIfAbsent(ipAddress, AsnResponse.class, ip -> { + try { + return lazyLoader.get().asn(ip); + } catch (AddressNotFoundException e) { + throw new AddressNotFoundRuntimeException(e); + } catch (Exception e) { + throw new RuntimeException(e); + } + })); Long asn = response.getAutonomousSystemNumber(); String organization_name = response.getAutonomousSystemOrganization();