Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -481,39 +481,34 @@ private List<BasicDBObject> parseGeneratedKeyValues(BasicDBObject generatedData,
List<BasicDBObject> generatedOperationKeyValuePairs = new ArrayList<>();
if (generatedData.containsKey(operationType)) {
Object generatedValue = generatedData.get(operationType);
if (generatedValue instanceof String) {
String generatedKey = generatedValue.toString();
generatedOperationKeyValuePairs.add(new BasicDBObject(generatedKey, value));
} else if (generatedValue instanceof JSONObject) {
JSONObject generatedObj = (JSONObject) generatedValue;
for (String k : generatedObj.keySet()) {
generatedOperationKeyValuePairs.add(new BasicDBObject(k, generatedObj.get(k)));
}
} else if (generatedValue instanceof JSONArray) {
JSONArray generatedArray = (JSONArray) generatedValue;
for (int i = 0; i < generatedArray.length(); i++) {
Object generatedValueAtIndex = generatedArray.get(i);
if(generatedValueAtIndex instanceof String) {
String generatedKey = generatedValueAtIndex.toString();
generatedOperationKeyValuePairs.add(new BasicDBObject(generatedKey, value));
continue;
} else if (generatedValueAtIndex instanceof JSONObject) {
JSONObject generatedObj = (JSONObject) generatedValueAtIndex;
for (String k : generatedObj.keySet()) {
generatedOperationKeyValuePairs.add(new BasicDBObject(k, generatedObj.get(k)));
}
continue;
}
}
} else {
loggerMaker.errorAndAddToDb("operation " + operationType + " returned unexpected type: " + generatedValue.getClass().getName());
}
parseGeneratedValueRecursively(generatedValue, generatedOperationKeyValuePairs, value, null);
} else {
loggerMaker.errorAndAddToDb("operation " + operationType + " not found in generated response");
}
return generatedOperationKeyValuePairs;
}

private void parseGeneratedValueRecursively(Object obj, List<BasicDBObject> result, Object value, String parentKey) {
if (obj instanceof JSONObject) {
JSONObject jsonObject = (JSONObject) obj;
for (String key : jsonObject.keySet()) {
Object nestedValue = jsonObject.get(key);
parseGeneratedValueRecursively(nestedValue, result, value, key);
}
} else if (obj instanceof JSONArray) {
JSONArray jsonArray = (JSONArray) obj;
for (int i = 0; i < jsonArray.length(); i++) {
Object arrayElement = jsonArray.get(i);
parseGeneratedValueRecursively(arrayElement, result, value, null);
}
} else if (obj instanceof String) {
String generatedValue = obj.toString();
result.add(new BasicDBObject(parentKey, generatedValue));
} else {
// For other types, add the key-value pair directly
result.add(new BasicDBObject(parentKey, obj.toString()));
}
}

private static boolean removeAuthIfNotChanged(RawApi originalRawApi, RawApi testRawApi, String authMechanismHeaderKey, List<CustomAuthType> customAuthTypes) {
boolean removed = false;
Expand Down
45 changes: 45 additions & 0 deletions libs/utils/src/main/java/com/akto/data_actor/ClientActor.java
Original file line number Diff line number Diff line change
Expand Up @@ -3989,6 +3989,51 @@ public String getLLMPromptResponse(JSONObject promptPayload) {
return null;
}

@Override
public String getLLMResponseV2(JSONObject promptPayload) {
try {
JSONObject requestJson = new JSONObject();
requestJson.put("llmPayload", promptPayload);

OriginalHttpRequest request = new OriginalHttpRequest(
url + "/getLLMResponseV2",
"",
"POST",
requestJson.toString(),
buildHeaders(),
""
);

loggerMaker.debug("Sending request to LLM server: {}", requestJson);

OriginalHttpResponse response = ApiExecutor.sendRequest(request, true, null, false, null);

if (response == null) {
loggerMaker.errorAndAddToDb("Response object is null from LLM server", LoggerMaker.LogDb.TESTING);
return null;
}

String responsePayload = response.getBody();

if (response.getStatusCode() != 200) {
loggerMaker.errorAndAddToDb("Non-2xx response in getLLMResponse: " + response.getStatusCode(), LoggerMaker.LogDb.TESTING);
return null;
}

if (responsePayload == null || responsePayload.trim().isEmpty()) {
loggerMaker.errorAndAddToDb("Empty or null response body from LLM server", LoggerMaker.LogDb.TESTING);
return null;
}

loggerMaker.debug("Received response from LLM server: {}", responsePayload);
return responsePayload;

} catch (Exception e) {
loggerMaker.errorAndAddToDb(e, "Exception in getLLMResponse." , LoggerMaker.LogDb.TESTING);
}
return null;
}

public List<SlackWebhook> fetchSlackWebhooks() {
Map<String, List<String>> headers = buildHeaders();
OriginalHttpRequest request = new OriginalHttpRequest(url + "/getSlackWebhooks", "", "POST", "", headers, "");
Expand Down
2 changes: 2 additions & 0 deletions libs/utils/src/main/java/com/akto/data_actor/DataActor.java
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ public abstract class DataActor {

public abstract String getLLMPromptResponse(JSONObject promptPayload);

public abstract String getLLMResponseV2(JSONObject promptPayload);

public abstract List<SlackWebhook> fetchSlackWebhooks();

public abstract void insertMCPAuditDataLog(McpAuditInfo auditInfo);
Expand Down
6 changes: 6 additions & 0 deletions libs/utils/src/main/java/com/akto/data_actor/DbActor.java
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,12 @@ public String getLLMPromptResponse(JSONObject promptPayload) {
return null;
}

@Override
public String getLLMResponseV2(JSONObject promptPayload) {
// no implementation needed.
return null;
}

@Override
public void updateModuleInfo(ModuleInfo moduleInfo) {
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package com.akto.gpt.handlers.gpt_prompts;

import java.util.concurrent.TimeUnit;

import javax.validation.ValidationException;

import com.akto.data_actor.DataActorFactory;
import com.akto.log.LoggerMaker;
import com.akto.log.LoggerMaker.LogDb;
import com.akto.util.http_util.CoreHTTPClient;
import com.mongodb.BasicDBObject;

import okhttp3.OkHttpClient;
import org.json.JSONObject;
import org.json.JSONArray;

public abstract class AzureOpenAIPromptHandler {

static final OkHttpClient client = CoreHTTPClient.client.newBuilder()
.connectTimeout(60, TimeUnit.SECONDS)
.readTimeout(60, TimeUnit.SECONDS)
.writeTimeout(60, TimeUnit.SECONDS)
.build();

static final LoggerMaker logger = new LoggerMaker(AzureOpenAIPromptHandler.class, LogDb.DASHBOARD);

public BasicDBObject handle(BasicDBObject queryData) {
try {
validate(queryData);
String prompt = getPrompt(queryData);
String rawResponse = call(prompt);
BasicDBObject resp = processResponse(rawResponse);
return resp;
} catch (ValidationException exception) {
logger.error("Validation error: " + exception.getMessage());
BasicDBObject resp = new BasicDBObject();
resp.put("error", "Invalid input parameters.");
return resp;
} catch (Exception e) {
logger.error("Error while handling request: " + e);
BasicDBObject resp = new BasicDBObject();
resp.put("error", "Internal server error: " + e.getMessage());
return resp;
}
}

protected abstract void validate(BasicDBObject queryData) throws ValidationException;

protected abstract String getPrompt(BasicDBObject queryData);

protected String call(String prompt) throws Exception {
JSONObject payload = new JSONObject();

// Set model parameters
payload.put("temperature", PromptHandler.temperature);
payload.put("top_p", 0.9);
payload.put("max_tokens", PromptHandler.max_tokens);
payload.put("frequency_penalty", 0);
payload.put("presence_penalty", 0.6);

// Create messages array
JSONArray messages = new JSONArray();
JSONObject systemMessage = new JSONObject();
systemMessage.put("role", "system");
systemMessage.put("content", prompt);
messages.put(systemMessage);

payload.put("messages", messages);
synchronized (PromptHandler.llmLock) {
return DataActorFactory.fetchInstance().getLLMPromptResponse(payload);
}
}

static String cleanJSON(String rawResponse) {
if (rawResponse == null || rawResponse.isEmpty()) {
return "NOT_FOUND";
}

// Truncate at the last closing brace to remove any trailing notes
int lastBrace = rawResponse.lastIndexOf('}');
if (lastBrace != -1) {
rawResponse = rawResponse.substring(0, lastBrace + 1);
}

// Start at the first opening brace to remove any forward notes
int firstBrace = rawResponse.indexOf('{');
if (firstBrace != -1) {
rawResponse = rawResponse.substring(firstBrace);
}
return rawResponse.trim();
}

static String processOutput(String rawResponse) {
try {
return cleanJSON(rawResponse);
} catch (Exception e) {
logger.error("Failed to clean LLM response: " + rawResponse, e);
return "NOT_FOUND";
}
}

protected abstract BasicDBObject processResponse(String rawResponse);

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ public abstract class PromptHandler {

private static final LoggerMaker logger = new LoggerMaker(PromptHandler.class, LogDb.TESTING);
private static final String OLLAMA_MODEL = "llama3:8b";
private static final Double temperature = 0.1;
private static final int max_tokens = 4000;
private static final Object llmLock = new Object();
static final Double temperature = 0.1;
static final int max_tokens = 10000;
static final Object llmLock = new Object();
private static final int CHUNK_SIZE = 10000;
private static final String CONTEXT_DELIMITER = "****";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

import com.mongodb.BasicDBObject;

public class TestExecutorModifier extends PromptHandler {

public class TestExecutorModifier extends AzureOpenAIPromptHandler {

private static final LoggerMaker logger = new LoggerMaker(TestExecutorModifier.class, LogDb.TESTING);
static final int MAX_QUERY_LENGTH = 100000;
Expand Down Expand Up @@ -81,6 +80,10 @@ protected String getPrompt(BasicDBObject queryData) {
.append("- Example: { \"delete_body_param\": \"param1\" }\n")
.append("- Example: { \"modify_header\": {\"header1\": \"value1\"} }\n")
.append("- Example: { \"modify_url\": \"https://example.com/product?id=5 OR 1=1\" }\n")
.append("- Example: { \"add_header\": {\"key\": \"value1\"} }\n")
.append("- Example: { \"add_body_param\": {\"key\": \"value1\"} }\n")
.append("- Give preference to the operation given in the prompt, but not strictly as operation can be of modify nature but key would not be there to modify in the api request shared to you, be intelligent about the operations to be performed.\n")
.append("- Check the key location for choosing the operation to be performed, if key belongs to headers, then choose the operation from the list of operations given for headers, if key belongs to body, then choose the operation from the list of operations given for body, if key belongs to url, then choose the operation from the list of operations given for url, if key belongs to query params, then choose the operation from the list of operations given for query params.\n")
.append("- Example: { \"modify_body_param\": {\"key\": \"value1\"} }\n")
.append("- Return ONLY the JSON or " + _NOT_FOUND + " — nothing else.");
return promptBuilder.toString();
Expand Down
Loading