feat: restrict AI tagging to parser-discovered test methods

This commit is contained in:
2026-03-10 20:42:26 +01:00
parent 32ddfa988b
commit a592ce1330
13 changed files with 357 additions and 99 deletions

View File

@@ -20,19 +20,20 @@ import org.egothor.methodatlas.ai.AiProvider;
import org.egothor.methodatlas.ai.AiSuggestionEngine;
import org.egothor.methodatlas.ai.AiSuggestionEngineImpl;
import org.egothor.methodatlas.ai.AiSuggestionException;
import org.egothor.methodatlas.ai.PromptBuilder;
import org.egothor.methodatlas.ai.SuggestionLookup;
import com.github.javaparser.ParserConfiguration;
import com.github.javaparser.ParserConfiguration.LanguageLevel;
import com.github.javaparser.StaticJavaParser;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.PackageDeclaration;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.expr.AnnotationExpr;
import com.github.javaparser.ast.expr.ArrayInitializerExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.MemberValuePair;
import com.github.javaparser.ast.nodeTypes.NodeWithName;
/**
* Command-line application for scanning Java test sources, extracting JUnit
@@ -298,33 +299,63 @@ public class MethodAtlasApp { // NOPMD
private static void processFile(Path path, OutputMode mode, AiOptions aiOptions, AiSuggestionEngine aiEngine) {
try {
CompilationUnit compilationUnit = StaticJavaParser.parse(path);
String packageName = compilationUnit.getPackageDeclaration().map(PackageDeclaration::getNameAsString)
.orElse("");
String packageName = compilationUnit.getPackageDeclaration().map(NodeWithName::getNameAsString).orElse("");
compilationUnit.findAll(ClassOrInterfaceDeclaration.class).forEach(clazz -> {
String className = clazz.getNameAsString();
String fqcn = packageName.isEmpty() ? className : packageName + "." + className;
SuggestionLookup suggestionLookup = resolveSuggestionLookup(clazz, fqcn, aiOptions, aiEngine);
clazz.findAll(MethodDeclaration.class).forEach(method -> {
if (!isJUnitTest(method)) {
return;
}
List<MethodDeclaration> testMethods = findJUnitTestMethods(clazz);
SuggestionLookup suggestionLookup = resolveSuggestionLookup(clazz, fqcn, testMethods, aiOptions,
aiEngine);
for (MethodDeclaration method : testMethods) {
int loc = countLOC(method);
List<String> tags = getTagValues(method);
AiMethodSuggestion suggestion = suggestionLookup.find(method.getNameAsString()).orElse(null);
emit(mode, aiOptions.enabled(), fqcn, method.getNameAsString(), loc, tags, suggestion);
});
}
});
} catch (Exception e) {
if (LOG.isLoggable(Level.WARNING)) {
LOG.log(Level.WARNING, "Failed to parse: {0} due to {1}", new Object[] { path, e.getMessage() });
LOG.log(Level.WARNING, "Failed to parse: " + path, e);
}
}
}
/**
* Returns all JUnit test methods declared within the specified class.
*
* <p>
* The method traverses the supplied {@link ClassOrInterfaceDeclaration} and
* collects all {@link MethodDeclaration} instances that satisfy the
* {@link #isJUnitTest(MethodDeclaration)} predicate.
* </p>
*
* <p>
* The detection logic currently recognizes methods annotated with supported
* JUnit Jupiter test annotations such as {@code @Test},
* {@code @ParameterizedTest}, and {@code @RepeatedTest}. Only methods matching
* these criteria are included in the returned list.
* </p>
*
* <p>
* The returned list preserves the discovery order produced by
* {@link com.github.javaparser.ast.Node#findAll(Class)}, which corresponds to
* the order of method declarations in the source file.
* </p>
*
* @param clazz parsed class declaration whose methods should be inspected
* @return list of JUnit test method declarations contained in the class;
* possibly empty but never {@code null}
*
* @see #isJUnitTest(MethodDeclaration)
*/
private static List<MethodDeclaration> findJUnitTestMethods(ClassOrInterfaceDeclaration clazz) {
return clazz.findAll(MethodDeclaration.class).stream().filter(MethodAtlasApp::isJUnitTest).toList();
}
/**
* Resolves method-level AI suggestions for a parsed class.
*
@@ -342,8 +373,8 @@ public class MethodAtlasApp { // NOPMD
* @return lookup of AI suggestions keyed by method name; never {@code null}
*/
private static SuggestionLookup resolveSuggestionLookup(ClassOrInterfaceDeclaration clazz, String fqcn,
AiOptions aiOptions, AiSuggestionEngine aiEngine) {
if (!aiOptions.enabled() || aiEngine == null) {
List<MethodDeclaration> testMethods, AiOptions aiOptions, AiSuggestionEngine aiEngine) {
if (!aiOptions.enabled() || aiEngine == null || testMethods.isEmpty()) {
return SuggestionLookup.from(null);
}
@@ -356,18 +387,55 @@ public class MethodAtlasApp { // NOPMD
return SuggestionLookup.from(null);
}
List<PromptBuilder.TargetMethod> targetMethods = toTargetMethods(testMethods);
try {
AiClassSuggestion aiClassSuggestion = aiEngine.suggestForClass(fqcn, classSource);
AiClassSuggestion aiClassSuggestion = aiEngine.suggestForClass(fqcn, classSource, targetMethods);
return SuggestionLookup.from(aiClassSuggestion);
} catch (AiSuggestionException e) {
if (LOG.isLoggable(Level.WARNING)) {
LOG.log(Level.WARNING, "AI suggestion failed for class {0}: {1}",
new Object[] { fqcn, e.getMessage() });
LOG.log(Level.WARNING, "AI suggestion failed for class " + fqcn, e);
}
return SuggestionLookup.from(null);
}
}
/**
* Converts parsed JUnit test method declarations into prompt target
* descriptors.
*
* <p>
* The returned {@link PromptBuilder.TargetMethod} objects provide a compact
* representation of the methods that should be analyzed by the AI
* classification prompt. Each descriptor contains the method name together with
* the optional begin and end line numbers derived from the parser source range.
* </p>
*
* <p>
* Line numbers are obtained from {@link MethodDeclaration#getRange()} when
* source position information is available. If the parser did not retain range
* metadata for a method, the corresponding line value is set to {@code null}.
* </p>
*
* <p>
* The resulting list preserves the order of the supplied method declarations.
* </p>
*
* @param testMethods list of parsed JUnit test method declarations
* @return list of prompt target descriptors representing the supplied methods;
* possibly empty but never {@code null}
*
* @see PromptBuilder.TargetMethod
* @see MethodDeclaration#getRange()
*/
private static List<PromptBuilder.TargetMethod> toTargetMethods(List<MethodDeclaration> testMethods) {
return testMethods.stream()
.map(method -> new PromptBuilder.TargetMethod(method.getNameAsString(),
method.getRange().map(range -> range.begin.line).orElse(null),
method.getRange().map(range -> range.end.line).orElse(null)))
.toList();
}
/**
* Creates the AI suggestion engine for the current run.
*

View File

@@ -1,5 +1,7 @@
package org.egothor.methodatlas.ai;
import java.util.List;
/**
* Provider-specific client abstraction used to communicate with external AI
* inference services.
@@ -76,10 +78,12 @@ public interface AiProviderClient {
* objects describing individual test methods.
* </p>
*
* @param fqcn fully qualified name of the analyzed class
* @param classSource complete source code of the class being analyzed
* @param taxonomyText security taxonomy definition guiding the AI
* classification
* @param fqcn fully qualified name of the analyzed class
* @param classSource complete source code of the class being analyzed
* @param taxonomyText security taxonomy definition guiding the AI
* classification
* @param targetMethods deterministically extracted JUnit test methods that must
* be classified
* @return normalized AI classification result
*
* @throws AiSuggestionException if the request fails due to provider errors,
@@ -88,6 +92,6 @@ public interface AiProviderClient {
* @see AiClassSuggestion
* @see AiMethodSuggestion
*/
AiClassSuggestion suggestForClass(String fqcn, String classSource, String taxonomyText)
throws AiSuggestionException;
AiClassSuggestion suggestForClass(String fqcn, String classSource, String taxonomyText,
List<PromptBuilder.TargetMethod> targetMethods) throws AiSuggestionException;
}

View File

@@ -1,5 +1,7 @@
package org.egothor.methodatlas.ai;
import java.util.List;
/**
* High-level AI orchestration contract for security classification of parsed
* test classes.
@@ -52,8 +54,10 @@ public interface AiSuggestionEngine {
* using full class context.
* </p>
*
* @param fqcn fully qualified class name of the parsed test class
* @param classSource complete source code of the class to analyze
* @param fqcn fully qualified class name of the parsed test class
* @param classSource complete source code of the class to analyze
* @param targetMethods deterministically extracted JUnit test methods that must
* be classified
* @return normalized AI classification result for the class and its methods
*
* @throws AiSuggestionException if analysis fails due to provider communication
@@ -63,5 +67,6 @@ public interface AiSuggestionEngine {
* @see AiClassSuggestion
* @see AiMethodSuggestion
*/
AiClassSuggestion suggestForClass(String fqcn, String classSource) throws AiSuggestionException;
AiClassSuggestion suggestForClass(String fqcn, String classSource, List<PromptBuilder.TargetMethod> targetMethods)
throws AiSuggestionException;
}

View File

@@ -2,6 +2,7 @@ package org.egothor.methodatlas.ai;
import java.io.IOException;
import java.nio.file.Files;
import java.util.List;
/**
* Default implementation of {@link AiSuggestionEngine} that coordinates
@@ -73,8 +74,10 @@ public final class AiSuggestionEngineImpl implements AiSuggestionEngine {
* taxonomy text loaded at engine initialization time.
* </p>
*
* @param fqcn fully qualified class name of the analyzed test class
* @param classSource complete source code of the class to analyze
* @param fqcn fully qualified class name of the analyzed test class
* @param classSource complete source code of the class to analyze
* @param targetMethods deterministically extracted JUnit test methods that must
* be classified
* @return normalized AI classification result for the class and its methods
*
* @throws AiSuggestionException if the provider fails to analyze the class or
@@ -84,8 +87,9 @@ public final class AiSuggestionEngineImpl implements AiSuggestionEngine {
* @see AiProviderClient#suggestForClass(String, String, String)
*/
@Override
public AiClassSuggestion suggestForClass(String fqcn, String classSource) throws AiSuggestionException {
return client.suggestForClass(fqcn, classSource, taxonomyText);
public AiClassSuggestion suggestForClass(String fqcn, String classSource,
List<PromptBuilder.TargetMethod> targetMethods) throws AiSuggestionException {
return client.suggestForClass(fqcn, classSource, taxonomyText, targetMethods);
}
/**

View File

@@ -118,9 +118,11 @@ public final class AnthropicClient implements AiProviderClient {
* model, which is then deserialized into an {@link AiClassSuggestion}.
* </p>
*
* @param fqcn fully qualified class name being analyzed
* @param classSource complete source code of the class
* @param taxonomyText taxonomy definition guiding classification
* @param fqcn fully qualified class name being analyzed
* @param classSource complete source code of the class
* @param taxonomyText taxonomy definition guiding classification
* @param targetMethods deterministically extracted JUnit test methods that must
* be classified
*
* @return normalized AI classification result
*
@@ -129,10 +131,10 @@ public final class AnthropicClient implements AiProviderClient {
* invalid content
*/
@Override
public AiClassSuggestion suggestForClass(String fqcn, String classSource, String taxonomyText)
throws AiSuggestionException {
public AiClassSuggestion suggestForClass(String fqcn, String classSource, String taxonomyText,
List<PromptBuilder.TargetMethod> targetMethods) throws AiSuggestionException {
try {
String prompt = PromptBuilder.build(fqcn, classSource, taxonomyText);
String prompt = PromptBuilder.build(fqcn, classSource, taxonomyText, targetMethods);
MessageRequest payload = new MessageRequest(options.modelName(), SYSTEM_PROMPT,
List.of(new ContentMessage("user", List.of(new ContentBlock("text", prompt)))), 0.0, 2_000);

View File

@@ -128,9 +128,11 @@ public final class OllamaClient implements AiProviderClient {
* {@link AiClassSuggestion}, and then normalized before being returned.
* </p>
*
* @param fqcn fully qualified class name being analyzed
* @param classSource complete source code of the class being analyzed
* @param taxonomyText taxonomy definition guiding classification
* @param fqcn fully qualified class name being analyzed
* @param classSource complete source code of the class being analyzed
* @param taxonomyText taxonomy definition guiding classification
* @param targetMethods deterministically extracted JUnit test methods that must
* be classified
* @return normalized AI classification result
*
* @throws AiSuggestionException if the request fails, if the provider returns
@@ -138,10 +140,10 @@ public final class OllamaClient implements AiProviderClient {
* fails
*/
@Override
public AiClassSuggestion suggestForClass(String fqcn, String classSource, String taxonomyText)
throws AiSuggestionException {
public AiClassSuggestion suggestForClass(String fqcn, String classSource, String taxonomyText,
List<PromptBuilder.TargetMethod> targetMethods) throws AiSuggestionException {
try {
String prompt = PromptBuilder.build(fqcn, classSource, taxonomyText);
String prompt = PromptBuilder.build(fqcn, classSource, taxonomyText, targetMethods);
ChatRequest payload = new ChatRequest(options.modelName(),
List.of(new Message("system", SYSTEM_PROMPT), new Message("user", prompt)), false,

View File

@@ -128,10 +128,11 @@ public final class OpenAiCompatibleClient implements AiProviderClient {
* {@link AiClassSuggestion}.
* </p>
*
* @param fqcn fully qualified class name being analyzed
* @param classSource complete source code of the class
* @param taxonomyText taxonomy definition guiding classification
*
* @param fqcn fully qualified class name being analyzed
* @param classSource complete source code of the class
* @param taxonomyText taxonomy definition guiding classification
* @param targetMethods deterministically extracted JUnit test methods that must
* be classified
* @return normalized classification result
*
* @throws AiSuggestionException if the provider request fails, the model
@@ -139,10 +140,10 @@ public final class OpenAiCompatibleClient implements AiProviderClient {
* fails
*/
@Override
public AiClassSuggestion suggestForClass(String fqcn, String classSource, String taxonomyText)
throws AiSuggestionException {
public AiClassSuggestion suggestForClass(String fqcn, String classSource, String taxonomyText,
List<PromptBuilder.TargetMethod> targetMethods) throws AiSuggestionException {
try {
String prompt = PromptBuilder.build(fqcn, classSource, taxonomyText);
String prompt = PromptBuilder.build(fqcn, classSource, taxonomyText, targetMethods);
ChatRequest payload = new ChatRequest(options.modelName(),
List.of(new Message("system", SYSTEM_PROMPT), new Message("user", prompt)), 0.0);

View File

@@ -1,5 +1,9 @@
package org.egothor.methodatlas.ai;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Utility responsible for constructing the prompt supplied to AI providers for
* security classification of JUnit test classes.
@@ -18,6 +22,13 @@ package org.egothor.methodatlas.ai;
* </ul>
*
* <p>
* This revision keeps the full class source as semantic context but removes
* method discovery from the AI model. The caller supplies the exact list of
* JUnit test methods that must be classified, optionally with source line
* anchors.
* </p>
*
* <p>
* The resulting prompt is passed to the configured AI provider and instructs
* the model to produce a deterministic JSON classification result describing
* security relevance and taxonomy tags for individual test methods.
@@ -38,6 +49,23 @@ package org.egothor.methodatlas.ai;
* @see OptimizedSecurityTaxonomy
*/
public final class PromptBuilder {
/**
* Deterministically extracted test method descriptor supplied to the prompt.
*
* @param methodName name of the JUnit test method
* @param beginLine first source line of the method, or {@code null} if unknown
* @param endLine last source line of the method, or {@code null} if unknown
*/
public record TargetMethod(String methodName, Integer beginLine, Integer endLine) {
public TargetMethod {
Objects.requireNonNull(methodName, "methodName");
if (methodName.isBlank()) {
throw new IllegalArgumentException("methodName must not be blank");
}
}
}
/**
* Prevents instantiation of this utility class.
*/
@@ -55,6 +83,7 @@ public final class PromptBuilder {
* <ul>
* <li>task instructions describing the classification objective</li>
* <li>the security taxonomy definition controlling allowed tags</li>
* <li>the exact list of target test methods to classify</li>
* <li>strict output rules enforcing JSON-only responses</li>
* <li>a formal JSON schema describing the expected result structure</li>
* <li>the fully qualified class name of the analyzed test class</li>
@@ -73,22 +102,40 @@ public final class PromptBuilder {
* in chat-based inference APIs.
* </p>
*
* @param fqcn fully qualified class name of the test class being
* analyzed
* @param classSource complete source code of the test class
* @param taxonomyText taxonomy definition guiding classification
* @param fqcn fully qualified class name of the test class being
* analyzed
* @param classSource complete source code of the test class
* @param taxonomyText taxonomy definition guiding classification
* @param targetMethods exact list of deterministically discovered JUnit test
* methods to classify
* @return formatted prompt supplied to the AI provider
*
* @see AiSuggestionEngine#suggestForClass(String, String)
*/
public static String build(String fqcn, String classSource, String taxonomyText) {
public static String build(String fqcn, String classSource, String taxonomyText, List<TargetMethod> targetMethods) {
Objects.requireNonNull(fqcn, "fqcn");
Objects.requireNonNull(classSource, "classSource");
Objects.requireNonNull(taxonomyText, "taxonomyText");
Objects.requireNonNull(targetMethods, "targetMethods");
if (targetMethods.isEmpty()) {
throw new IllegalArgumentException("targetMethods must not be empty");
}
String targetMethodBlock = targetMethods.stream().map(PromptBuilder::formatTargetMethod)
.collect(Collectors.joining("\n"));
String expectedMethodNames = targetMethods.stream().map(TargetMethod::methodName)
.map(name -> "\"" + name + "\"").collect(Collectors.joining(", "));
return """
You are analyzing a single JUnit 5 test class and suggesting security tags.
TASK
- Analyze the WHOLE class for context.
- Return per-method suggestions for JUnit test methods only.
- Classify ONLY the methods explicitly listed in TARGET TEST METHODS.
- Do not invent methods that do not exist.
- Do not classify helper methods, lifecycle methods, nested classes, or any method not listed.
- Be conservative.
- If uncertain, classify the method as securityRelevant=false.
- Ignore pure functional / performance / UX tests unless they explicitly validate a security property.
@@ -96,17 +143,30 @@ public final class PromptBuilder {
CONTROLLED TAXONOMY
%s
TARGET TEST METHODS
The following methods were extracted deterministically by the parser and are the ONLY methods
you are allowed to classify. Use the full class source only as context for understanding them.
%s
OUTPUT RULES
- Return JSON only.
- No markdown.
- No prose outside JSON.
- Return exactly one result for each target method.
- methodName values in the output must exactly match one of:
[%s]
- Do not omit any listed method.
- Do not include any additional methods.
- Tags must come only from this closed set:
security, auth, access-control, crypto, input-validation, injection, data-protection, logging, error-handling, owasp
- If securityRelevant=true, tags MUST include "security".
- Add 1-3 tags total per method.
- displayName must be null when securityRelevant=false.
- If securityRelevant=false, displayName must be null.
- If securityRelevant=false, tags must be [].
- If securityRelevant=true, displayName must match:
SECURITY: <control/property> - <scenario>
- reason should be short and specific.
JSON SHAPE
{
@@ -131,6 +191,17 @@ public final class PromptBuilder {
SOURCE
%s
"""
.formatted(taxonomyText, fqcn, classSource);
.formatted(taxonomyText, targetMethodBlock, expectedMethodNames, fqcn, classSource);
}
private static String formatTargetMethod(TargetMethod targetMethod) {
StringBuilder builder = new StringBuilder("- ").append(targetMethod.methodName());
if (targetMethod.beginLine() != null || targetMethod.endLine() != null) {
builder.append(" [lines ").append(targetMethod.beginLine() == null ? "?" : targetMethod.beginLine())
.append('-').append(targetMethod.endLine() == null ? "?" : targetMethod.endLine()).append(']');
}
return builder.toString();
}
}

View File

@@ -4,6 +4,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mockConstruction;
@@ -39,15 +40,15 @@ class MethodAtlasAppAiTest {
try (MockedConstruction<AiSuggestionEngineImpl> mocked = mockConstruction(AiSuggestionEngineImpl.class,
(mock, context) -> {
when(mock.suggestForClass(eq("com.acme.tests.SampleOneTest"), anyString()))
when(mock.suggestForClass(eq("com.acme.tests.SampleOneTest"), anyString(), any()))
.thenReturn(sampleOneSuggestion());
when(mock.suggestForClass(eq("com.acme.other.AnotherTest"), anyString()))
when(mock.suggestForClass(eq("com.acme.other.AnotherTest"), anyString(), any()))
.thenReturn(anotherSuggestion());
when(mock.suggestForClass(eq("com.acme.security.AccessControlServiceTest"), anyString()))
when(mock.suggestForClass(eq("com.acme.security.AccessControlServiceTest"), anyString(), any()))
.thenReturn(accessControlSuggestion());
when(mock.suggestForClass(eq("com.acme.storage.PathTraversalValidationTest"), anyString()))
when(mock.suggestForClass(eq("com.acme.storage.PathTraversalValidationTest"), anyString(), any()))
.thenReturn(pathTraversalSuggestion());
when(mock.suggestForClass(eq("com.acme.audit.AuditLoggingTest"), anyString()))
when(mock.suggestForClass(eq("com.acme.audit.AuditLoggingTest"), anyString(), any()))
.thenReturn(auditLoggingSuggestion());
})) {
@@ -99,15 +100,15 @@ class MethodAtlasAppAiTest {
try (MockedConstruction<AiSuggestionEngineImpl> mocked = mockConstruction(AiSuggestionEngineImpl.class,
(mock, context) -> {
when(mock.suggestForClass(eq("com.acme.tests.SampleOneTest"), anyString()))
when(mock.suggestForClass(eq("com.acme.tests.SampleOneTest"), anyString(), any()))
.thenReturn(sampleOneSuggestion());
when(mock.suggestForClass(eq("com.acme.other.AnotherTest"), anyString()))
when(mock.suggestForClass(eq("com.acme.other.AnotherTest"), anyString(), any()))
.thenReturn(anotherSuggestion());
when(mock.suggestForClass(eq("com.acme.security.AccessControlServiceTest"), anyString()))
when(mock.suggestForClass(eq("com.acme.security.AccessControlServiceTest"), anyString(), any()))
.thenThrow(new AiSuggestionException("Simulated provider failure"));
when(mock.suggestForClass(eq("com.acme.storage.PathTraversalValidationTest"), anyString()))
when(mock.suggestForClass(eq("com.acme.storage.PathTraversalValidationTest"), anyString(), any()))
.thenReturn(pathTraversalSuggestion());
when(mock.suggestForClass(eq("com.acme.audit.AuditLoggingTest"), anyString()))
when(mock.suggestForClass(eq("com.acme.audit.AuditLoggingTest"), anyString(), any()))
.thenReturn(auditLoggingSuggestion());
})) {
@@ -167,7 +168,7 @@ class MethodAtlasAppAiTest {
assertEquals("", row.get(7));
assertEquals(1, mocked.constructed().size(), "Expected one AI engine instance");
verify(mocked.constructed().get(0), never()).suggestForClass(anyString(), anyString());
verify(mocked.constructed().get(0), never()).suggestForClass(anyString(), anyString(), any());
}
}

View File

@@ -34,22 +34,31 @@ class AiSuggestionEngineImplTest {
"SECURITY: authentication - reject unauthenticated request", List.of("security", "auth"),
"The test verifies that anonymous access is rejected.")));
List<PromptBuilder.TargetMethod> targetMethods = List
.of(new PromptBuilder.TargetMethod("shouldAllowOwnerToReadOwnStatement", null, null),
new PromptBuilder.TargetMethod("shouldAllowAdministratorToReadAnyStatement", null, null),
new PromptBuilder.TargetMethod("shouldDenyForeignUserFromReadingAnotherUsersStatement", null,
null),
new PromptBuilder.TargetMethod("shouldRejectUnauthenticatedRequest", null, null),
new PromptBuilder.TargetMethod("shouldRenderFriendlyAccountLabel", null, null));
AiOptions options = AiOptions.builder().enabled(true).provider(AiProvider.OPENAI).build();
try (MockedStatic<AiProviderFactory> factory = mockStatic(AiProviderFactory.class)) {
factory.when(() -> AiProviderFactory.create(options)).thenReturn(client);
when(client.suggestForClass(eq("com.acme.security.AccessControlServiceTest"),
eq("class AccessControlServiceTest {}"), eq(DefaultSecurityTaxonomy.text()))).thenReturn(expected);
eq("class AccessControlServiceTest {}"), eq(DefaultSecurityTaxonomy.text()), eq(targetMethods)))
.thenReturn(expected);
AiSuggestionEngineImpl engine = new AiSuggestionEngineImpl(options);
AiClassSuggestion actual = engine.suggestForClass("com.acme.security.AccessControlServiceTest",
"class AccessControlServiceTest {}");
"class AccessControlServiceTest {}", targetMethods);
assertSame(expected, actual);
factory.verify(() -> AiProviderFactory.create(options));
verify(client).suggestForClass("com.acme.security.AccessControlServiceTest",
"class AccessControlServiceTest {}", DefaultSecurityTaxonomy.text());
"class AccessControlServiceTest {}", DefaultSecurityTaxonomy.text(), targetMethods);
verifyNoMoreInteractions(client);
}
}
@@ -64,24 +73,30 @@ class AiSuggestionEngineImplTest {
List.of("security", "input-validation", "owasp"),
"The test rejects a classic path traversal payload.")));
List<PromptBuilder.TargetMethod> targetMethods = List.of(
new PromptBuilder.TargetMethod("shouldRejectRelativePathTraversalSequence", null, null),
new PromptBuilder.TargetMethod("shouldRejectNestedTraversalAfterNormalization", null, null),
new PromptBuilder.TargetMethod("shouldAllowSafePathInsideUploadRoot", null, null),
new PromptBuilder.TargetMethod("shouldBuildDownloadFileName", null, null));
AiOptions options = AiOptions.builder().enabled(true).provider(AiProvider.OLLAMA)
.taxonomyMode(AiOptions.TaxonomyMode.OPTIMIZED).build();
try (MockedStatic<AiProviderFactory> factory = mockStatic(AiProviderFactory.class)) {
factory.when(() -> AiProviderFactory.create(options)).thenReturn(client);
when(client.suggestForClass(eq("com.acme.storage.PathTraversalValidationTest"),
eq("class PathTraversalValidationTest {}"), eq(OptimizedSecurityTaxonomy.text())))
.thenReturn(expected);
eq("class PathTraversalValidationTest {}"), eq(OptimizedSecurityTaxonomy.text()),
eq(targetMethods))).thenReturn(expected);
AiSuggestionEngineImpl engine = new AiSuggestionEngineImpl(options);
AiClassSuggestion actual = engine.suggestForClass("com.acme.storage.PathTraversalValidationTest",
"class PathTraversalValidationTest {}");
"class PathTraversalValidationTest {}", targetMethods);
assertSame(expected, actual);
factory.verify(() -> AiProviderFactory.create(options));
verify(client).suggestForClass("com.acme.storage.PathTraversalValidationTest",
"class PathTraversalValidationTest {}", OptimizedSecurityTaxonomy.text());
"class PathTraversalValidationTest {}", OptimizedSecurityTaxonomy.text(), targetMethods);
verifyNoMoreInteractions(client);
}
}
@@ -104,23 +119,29 @@ class AiSuggestionEngineImplTest {
"SECURITY: logging - redact bearer token", List.of("security", "logging"),
"The test ensures credentials are not written to logs.")));
List<PromptBuilder.TargetMethod> targetMethods = List.of(
new PromptBuilder.TargetMethod("shouldWriteAuditEventForPrivilegeChange", null, null),
new PromptBuilder.TargetMethod("shouldNotLogRawBearerToken", null, null),
new PromptBuilder.TargetMethod("shouldNotLogPlaintextPasswordOnAuthenticationFailure", null, null),
new PromptBuilder.TargetMethod("shouldFormatHumanReadableSupportMessage", null, null));
AiOptions options = AiOptions.builder().enabled(true).provider(AiProvider.OPENROUTER).taxonomyFile(taxonomyFile)
.build();
try (MockedStatic<AiProviderFactory> factory = mockStatic(AiProviderFactory.class)) {
factory.when(() -> AiProviderFactory.create(options)).thenReturn(client);
when(client.suggestForClass(eq("com.acme.audit.AuditLoggingTest"), eq("class AuditLoggingTest {}"),
eq(taxonomyText))).thenReturn(expected);
eq(taxonomyText), eq(targetMethods))).thenReturn(expected);
AiSuggestionEngineImpl engine = new AiSuggestionEngineImpl(options);
AiClassSuggestion actual = engine.suggestForClass("com.acme.audit.AuditLoggingTest",
"class AuditLoggingTest {}");
"class AuditLoggingTest {}", targetMethods);
assertSame(expected, actual);
factory.verify(() -> AiProviderFactory.create(options));
verify(client).suggestForClass("com.acme.audit.AuditLoggingTest", "class AuditLoggingTest {}",
taxonomyText);
verify(client).suggestForClass("com.acme.audit.AuditLoggingTest", "class AuditLoggingTest {}", taxonomyText,
targetMethods);
verifyNoMoreInteractions(client);
}
}

View File

@@ -82,6 +82,11 @@ class OllamaClientTest {
}
""";
String taxonomyText = "security, input-validation, owasp";
List<PromptBuilder.TargetMethod> targetMethods = List.of(
new PromptBuilder.TargetMethod("shouldRejectRelativePathTraversalSequence", null, null),
new PromptBuilder.TargetMethod("shouldRejectNestedTraversalAfterNormalization", null, null),
new PromptBuilder.TargetMethod("shouldAllowSafePathInsideUploadRoot", null, null),
new PromptBuilder.TargetMethod("shouldBuildDownloadFileName", null, null));
String responseBody = """
{
@@ -111,7 +116,7 @@ class OllamaClientTest {
.modelName("qwen2.5-coder:7b").baseUrl("http://localhost:11434").build();
OllamaClient client = new OllamaClient(options);
AiClassSuggestion suggestion = client.suggestForClass(fqcn, classSource, taxonomyText);
AiClassSuggestion suggestion = client.suggestForClass(fqcn, classSource, taxonomyText, targetMethods);
assertEquals(fqcn, suggestion.className());
assertEquals(Boolean.TRUE, suggestion.classSecurityRelevant());
@@ -137,12 +142,12 @@ class OllamaClientTest {
assertNotNull(requestBody);
assertTrue(requestBody.contains("\"model\":\"qwen2.5-coder:7b\""));
assertTrue(requestBody.contains("\"stream\":false"));
assertTrue(requestBody.contains("You are a precise software security classification engine."));
assertTrue(requestBody.contains("You classify JUnit 5 tests and return strict JSON only."));
assertTrue(requestBody.contains("\"temperature\":0.0"));
assertTrue(requestBody.contains("FQCN: " + fqcn));
assertTrue(requestBody.contains("PathTraversalValidationTest"));
assertTrue(requestBody.contains("shouldRejectRelativePathTraversalSequence"));
assertTrue(requestBody.contains("shouldRejectNestedTraversalAfterNormalization"));
assertTrue(requestBody.contains("shouldAllowSafePathInsideUploadRoot"));
assertTrue(requestBody.contains("shouldBuildDownloadFileName"));
assertTrue(requestBody.contains(taxonomyText));
}
}
@@ -152,6 +157,12 @@ class OllamaClientTest {
ObjectMapper mapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
String fqcn = "com.acme.audit.AuditLoggingTest";
List<PromptBuilder.TargetMethod> targetMethods = List.of(
new PromptBuilder.TargetMethod("shouldWriteAuditEventForPrivilegeChange", null, null),
new PromptBuilder.TargetMethod("shouldNotLogRawBearerToken", null, null),
new PromptBuilder.TargetMethod("shouldNotLogPlaintextPasswordOnAuthenticationFailure", null, null),
new PromptBuilder.TargetMethod("shouldFormatHumanReadableSupportMessage", null, null));
String responseBody = """
{
"message": {
@@ -176,7 +187,8 @@ class OllamaClientTest {
OllamaClient client = new OllamaClient(options);
AiSuggestionException ex = org.junit.jupiter.api.Assertions.assertThrows(AiSuggestionException.class,
() -> client.suggestForClass(fqcn, "class AuditLoggingTest {}", "security, logging"));
() -> client.suggestForClass(fqcn, "class AuditLoggingTest {}", "security, logging",
targetMethods));
assertEquals("Ollama suggestion failed for " + fqcn, ex.getMessage());
assertInstanceOf(AiSuggestionException.class, ex.getCause());

View File

@@ -54,6 +54,13 @@ class OpenAiCompatibleClientTest {
}
""";
String taxonomyText = "security, auth, access-control";
List<PromptBuilder.TargetMethod> targetMethods = List
.of(new PromptBuilder.TargetMethod("shouldAllowOwnerToReadOwnStatement", null, null),
new PromptBuilder.TargetMethod("shouldAllowAdministratorToReadAnyStatement", null, null),
new PromptBuilder.TargetMethod("shouldDenyForeignUserFromReadingAnotherUsersStatement", null,
null),
new PromptBuilder.TargetMethod("shouldRejectUnauthenticatedRequest", null, null),
new PromptBuilder.TargetMethod("shouldRenderFriendlyAccountLabel", null, null));
String responseBody = """
{
@@ -74,7 +81,7 @@ class OpenAiCompatibleClientTest {
.baseUrl("https://api.openai.com").apiKey("sk-test-value").build();
OpenAiCompatibleClient client = new OpenAiCompatibleClient(options);
AiClassSuggestion suggestion = client.suggestForClass(fqcn, classSource, taxonomyText);
AiClassSuggestion suggestion = client.suggestForClass(fqcn, classSource, taxonomyText, targetMethods);
assertEquals(fqcn, suggestion.className());
assertEquals(Boolean.TRUE, suggestion.classSecurityRelevant());
@@ -100,11 +107,13 @@ class OpenAiCompatibleClientTest {
String requestBody = capturedBody.get();
assertNotNull(requestBody);
assertTrue(requestBody.contains("\"model\":\"gpt-4o-mini\""));
assertTrue(requestBody.contains("You are a precise software security classification engine."));
assertTrue(requestBody.contains("You classify JUnit 5 tests and return strict JSON only."));
assertTrue(requestBody.contains("FQCN: " + fqcn));
assertTrue(requestBody.contains("AccessControlServiceTest"));
assertTrue(requestBody.contains("shouldAllowOwnerToReadOwnStatement"));
assertTrue(requestBody.contains("shouldAllowAdministratorToReadAnyStatement"));
assertTrue(requestBody.contains("shouldDenyForeignUserFromReadingAnotherUsersStatement"));
assertTrue(requestBody.contains("shouldRejectUnauthenticatedRequest"));
assertTrue(requestBody.contains("shouldRenderFriendlyAccountLabel"));
assertTrue(requestBody.contains(taxonomyText));
assertTrue(requestBody.contains("\"temperature\":0.0"));
}
@@ -126,13 +135,19 @@ class OpenAiCompatibleClientTest {
}
""";
List<PromptBuilder.TargetMethod> targetMethods = List.of(
new PromptBuilder.TargetMethod("shouldWriteAuditEventForPrivilegeChange", null, null),
new PromptBuilder.TargetMethod("shouldNotLogRawBearerToken", null, null),
new PromptBuilder.TargetMethod("shouldNotLogPlaintextPasswordOnAuthenticationFailure", null, null),
new PromptBuilder.TargetMethod("shouldFormatHumanReadableSupportMessage", null, null));
try (MockedConstruction<HttpSupport> mocked = mockHttpSupport(mapper, responseBody, null)) {
AiOptions options = AiOptions.builder().enabled(true).provider(AiProvider.OPENROUTER)
.modelName("openai/gpt-4o-mini").baseUrl("https://openrouter.ai/api").apiKey("or-test-key").build();
OpenAiCompatibleClient client = new OpenAiCompatibleClient(options);
AiClassSuggestion suggestion = client.suggestForClass("com.acme.audit.AuditLoggingTest",
"class AuditLoggingTest {}", "security, logging");
"class AuditLoggingTest {}", "security, logging", targetMethods);
assertEquals("com.acme.audit.AuditLoggingTest", suggestion.className());
@@ -157,6 +172,12 @@ class OpenAiCompatibleClientTest {
}
""";
List<PromptBuilder.TargetMethod> targetMethods = List.of(
new PromptBuilder.TargetMethod("shouldWriteAuditEventForPrivilegeChange", null, null),
new PromptBuilder.TargetMethod("shouldNotLogRawBearerToken", null, null),
new PromptBuilder.TargetMethod("shouldNotLogPlaintextPasswordOnAuthenticationFailure", null, null),
new PromptBuilder.TargetMethod("shouldFormatHumanReadableSupportMessage", null, null));
try (MockedConstruction<HttpSupport> mocked = mockHttpSupport(mapper, responseBody, null)) {
AiOptions options = AiOptions.builder().enabled(true).provider(AiProvider.OPENAI).apiKey("sk-test-value")
.build();
@@ -164,7 +185,8 @@ class OpenAiCompatibleClientTest {
OpenAiCompatibleClient client = new OpenAiCompatibleClient(options);
AiSuggestionException ex = org.junit.jupiter.api.Assertions.assertThrows(AiSuggestionException.class,
() -> client.suggestForClass(fqcn, "class AuditLoggingTest {}", "security, logging"));
() -> client.suggestForClass(fqcn, "class AuditLoggingTest {}", "security, logging",
targetMethods));
assertEquals("OpenAI-compatible suggestion failed for " + fqcn, ex.getMessage());
assertInstanceOf(AiSuggestionException.class, ex.getCause());
@@ -189,6 +211,12 @@ class OpenAiCompatibleClientTest {
}
""";
List<PromptBuilder.TargetMethod> targetMethods = List.of(
new PromptBuilder.TargetMethod("shouldWriteAuditEventForPrivilegeChange", null, null),
new PromptBuilder.TargetMethod("shouldNotLogRawBearerToken", null, null),
new PromptBuilder.TargetMethod("shouldNotLogPlaintextPasswordOnAuthenticationFailure", null, null),
new PromptBuilder.TargetMethod("shouldFormatHumanReadableSupportMessage", null, null));
try (MockedConstruction<HttpSupport> mocked = mockHttpSupport(mapper, responseBody, null)) {
AiOptions options = AiOptions.builder().enabled(true).provider(AiProvider.OPENAI).apiKey("sk-test-value")
.build();
@@ -196,7 +224,8 @@ class OpenAiCompatibleClientTest {
OpenAiCompatibleClient client = new OpenAiCompatibleClient(options);
AiSuggestionException ex = org.junit.jupiter.api.Assertions.assertThrows(AiSuggestionException.class,
() -> client.suggestForClass(fqcn, "class AuditLoggingTest {}", "security, logging"));
() -> client.suggestForClass(fqcn, "class AuditLoggingTest {}", "security, logging",
targetMethods));
assertEquals("OpenAI-compatible suggestion failed for " + fqcn, ex.getMessage());
assertInstanceOf(AiSuggestionException.class, ex.getCause());

View File

@@ -1,8 +1,11 @@
package org.egothor.methodatlas.ai;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.util.List;
import org.junit.jupiter.api.Test;
class PromptBuilderTest {
@@ -33,22 +36,31 @@ class PromptBuilderTest {
- logging
""";
String prompt = PromptBuilder.build(fqcn, classSource, taxonomyText);
List<PromptBuilder.TargetMethod> targetMethods = List.of(
new PromptBuilder.TargetMethod("shouldRejectUnauthenticatedRequest", 8, 8),
new PromptBuilder.TargetMethod("shouldAllowOwnerToReadOwnStatement", 11, 11));
String prompt = PromptBuilder.build(fqcn, classSource, taxonomyText, targetMethods);
assertTrue(prompt.contains("FQCN: " + fqcn));
assertTrue(prompt.contains(classSource));
assertTrue(prompt.contains(taxonomyText));
assertTrue(prompt.contains("- shouldRejectUnauthenticatedRequest [lines 8-8]"));
assertTrue(prompt.contains("- shouldAllowOwnerToReadOwnStatement [lines 11-11]"));
}
@Test
void build_containsExpectedTaskInstructions() {
String prompt = PromptBuilder.build("com.acme.audit.AuditLoggingTest", "class AuditLoggingTest {}",
"security, logging");
"security, logging",
List.of(new PromptBuilder.TargetMethod("shouldWriteAuditEventForPrivilegeChange", null, null)));
assertTrue(prompt.contains("You are analyzing a single JUnit 5 test class and suggesting security tags."));
assertTrue(prompt.contains("- Analyze the WHOLE class for context."));
assertTrue(prompt.contains("- Return per-method suggestions for JUnit test methods only."));
assertTrue(prompt.contains("- Classify ONLY the methods explicitly listed in TARGET TEST METHODS."));
assertTrue(prompt.contains("- Do not invent methods that do not exist."));
assertTrue(prompt.contains(
"- Do not classify helper methods, lifecycle methods, nested classes, or any method not listed."));
assertTrue(prompt.contains("- Be conservative."));
assertTrue(prompt.contains("- If uncertain, classify the method as securityRelevant=false."));
}
@@ -56,7 +68,8 @@ class PromptBuilderTest {
@Test
void build_containsClosedTaxonomyRules() {
String prompt = PromptBuilder.build("com.acme.storage.PathTraversalValidationTest",
"class PathTraversalValidationTest {}", "security, input-validation, injection");
"class PathTraversalValidationTest {}", "security, input-validation, injection",
List.of(new PromptBuilder.TargetMethod("shouldRejectRelativePathTraversalSequence", null, null)));
assertTrue(prompt.contains("Tags must come only from this closed set:"));
assertTrue(prompt.contains(
@@ -68,9 +81,10 @@ class PromptBuilderTest {
@Test
void build_containsDisplayNameRules() {
String prompt = PromptBuilder.build("com.acme.security.AccessControlServiceTest",
"class AccessControlServiceTest {}", "security, auth, access-control");
"class AccessControlServiceTest {}", "security, auth, access-control",
List.of(new PromptBuilder.TargetMethod("shouldRejectUnauthenticatedRequest", null, null)));
assertTrue(prompt.contains("displayName must be null when securityRelevant=false."));
assertTrue(prompt.contains("If securityRelevant=false, displayName must be null."));
assertTrue(prompt.contains("If securityRelevant=true, displayName must match:"));
assertTrue(prompt.contains("SECURITY: <control/property> - <scenario>"));
}
@@ -78,7 +92,8 @@ class PromptBuilderTest {
@Test
void build_containsJsonShapeContract() {
String prompt = PromptBuilder.build("com.acme.audit.AuditLoggingTest", "class AuditLoggingTest {}",
"security, logging");
"security, logging",
List.of(new PromptBuilder.TargetMethod("shouldWriteAuditEventForPrivilegeChange", null, null)));
assertTrue(prompt.contains("JSON SHAPE"));
assertTrue(prompt.contains("\"className\": \"string\""));
@@ -105,10 +120,24 @@ class PromptBuilderTest {
""";
String prompt = PromptBuilder.build("com.acme.storage.PathTraversalValidationTest", classSource,
"security, input-validation, injection");
"security, input-validation, injection",
List.of(new PromptBuilder.TargetMethod("shouldRejectRelativePathTraversalSequence", 3, 5)));
assertTrue(prompt.contains("String userInput = \"../etc/passwd\";"));
assertTrue(prompt.contains("void shouldRejectRelativePathTraversalSequence()"));
assertTrue(prompt.contains("- shouldRejectRelativePathTraversalSequence [lines 3-5]"));
}
@Test
void build_includesExpectedMethodNamesConstraint() {
String prompt = PromptBuilder.build("com.acme.tests.SampleOneTest", "class SampleOneTest {}",
"security, crypto", List.of(new PromptBuilder.TargetMethod("alpha", 1, 1),
new PromptBuilder.TargetMethod("beta", 2, 2), new PromptBuilder.TargetMethod("gamma", 3, 3)));
assertTrue(prompt.contains("- methodName values in the output must exactly match one of:"));
assertTrue(prompt.contains("[\"alpha\", \"beta\", \"gamma\"]"));
assertTrue(prompt.contains("- Do not omit any listed method."));
assertTrue(prompt.contains("- Do not include any additional methods."));
}
@Test
@@ -116,10 +145,19 @@ class PromptBuilderTest {
String fqcn = "com.example.X";
String source = "class X {}";
String taxonomy = "security, logging";
List<PromptBuilder.TargetMethod> targetMethods = List.of(new PromptBuilder.TargetMethod("alpha", null, null));
String prompt1 = PromptBuilder.build(fqcn, source, taxonomy);
String prompt2 = PromptBuilder.build(fqcn, source, taxonomy);
String prompt1 = PromptBuilder.build(fqcn, source, taxonomy, targetMethods);
String prompt2 = PromptBuilder.build(fqcn, source, taxonomy, targetMethods);
assertEquals(prompt1, prompt2);
}
@Test
void build_rejectsEmptyTargetMethods() {
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class,
() -> PromptBuilder.build("com.example.X", "class X {}", "security", List.of()));
assertEquals("targetMethods must not be empty", ex.getMessage());
}
}