Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import com.github.f4b6a3.ulid.Ulid;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
import com.redis.om.spring.annotations.Bloom;
import com.redis.om.spring.annotations.Document;
import com.redis.om.spring.client.RedisModulesClient;
Expand All @@ -43,7 +42,6 @@
import org.springframework.context.annotation.*;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.context.event.EventListener;
import org.springframework.data.annotation.Reference;
import org.springframework.data.geo.Point;
import org.springframework.data.redis.connection.jedis.JedisConnectionFactory;
import org.springframework.data.redis.core.RedisHash;
Expand All @@ -60,9 +58,13 @@
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.OffsetDateTime;
import java.util.*;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.redis.om.spring.util.ObjectUtils.*;
import static com.redis.om.spring.util.ObjectUtils.getBeanDefinitionsFor;
import static com.redis.om.spring.util.ObjectUtils.getDeclaredFieldsTransitively;

@Configuration(proxyBeanMethods = false)
@EnableConfigurationProperties({RedisProperties.class, RedisOMSpringProperties.class})
Expand Down Expand Up @@ -94,6 +96,11 @@ public GsonBuilder gsonBuilder(List<GsonBuilderCustomizer> customizers) {
return builder;
}

@Bean(name = "referenceAwareGsonBuilder")
ReferenceAwareGsonBuilder referenceAwareGsonBuilder(GsonBuilder gsonBuilder, ApplicationContext ac) {
return new ReferenceAwareGsonBuilder(gsonBuilder, ac);
}

@Bean(name = "redisModulesClient")
@Lazy
RedisModulesClient redisModulesClient( //
Expand All @@ -109,7 +116,7 @@ RedisModulesClient redisModulesClient( //
RedisModulesOperations<?> redisModulesOperations( //
RedisModulesClient rmc, //
StringRedisTemplate template, //
@Qualifier("omGsonBuilder") GsonBuilder gsonBuilder) {
ReferenceAwareGsonBuilder gsonBuilder) {
return new RedisModulesOperations<>(rmc, template, gsonBuilder);
}

Expand Down Expand Up @@ -154,7 +161,7 @@ public Criteria<Image, byte[]> imageEmbeddingModelCriteria(RedisOMSpringProperti
}

@Bean(name = "djlFaceDetectionTranslator")
public Translator<Image, DetectedObjects> faceDetectionTranslator(RedisOMSpringProperties properties) {
public Translator<Image, DetectedObjects> faceDetectionTranslator() {
double confThresh = 0.85f;
double nmsThresh = 0.45f;
double[] variance = {0.1f, 0.2f};
Expand Down Expand Up @@ -183,15 +190,15 @@ public ZooModel<Image, DetectedObjects> faceDetectionModel(
@Nullable @Qualifier("djlFaceDetectionModelCriteria") Criteria<Image, DetectedObjects> criteria,
RedisOMSpringProperties properties) {
try {
return properties.getDjl().isEnabled() ? ModelZoo.loadModel(criteria) : null;
return properties.getDjl().isEnabled() && (criteria != null) ? ModelZoo.loadModel(criteria) : null;
} catch (IOException | ModelNotFoundException | MalformedModelException ex) {
logger.warn("Error retrieving default DJL face detection model", ex);
return null;
}
}

@Bean(name = "djlFaceEmbeddingTranslator")
public Translator<Image, float[]> faceEmbeddingTranslator(RedisOMSpringProperties properties) {
public Translator<Image, float[]> faceEmbeddingTranslator() {
return new FaceFeatureTranslator();
}

Expand All @@ -214,7 +221,7 @@ public ZooModel<Image, float[]> faceEmbeddingModel(
@Nullable @Qualifier("djlFaceEmbeddingModelCriteria") Criteria<Image, float[]> criteria, //
RedisOMSpringProperties properties) {
try {
return properties.getDjl().isEnabled() ? ModelZoo.loadModel(criteria) : null;
return properties.getDjl().isEnabled() && (criteria != null) ? ModelZoo.loadModel(criteria) : null;
} catch (Exception e) {
logger.warn("Error retrieving default DJL face embeddings model", e);
return null;
Expand All @@ -225,7 +232,7 @@ public ZooModel<Image, float[]> faceEmbeddingModel(
public ZooModel<Image, byte[]> imageModel(
@Nullable @Qualifier("djlImageEmbeddingModelCriteria") Criteria<Image, byte[]> criteria, RedisOMSpringProperties properties)
throws MalformedModelException, ModelNotFoundException, IOException {
return properties.getDjl().isEnabled() ? ModelZoo.loadModel(criteria) : null;
return properties.getDjl().isEnabled() && (criteria != null) ? ModelZoo.loadModel(criteria) : null;
}

@Bean(name = "djlDefaultImagePipeline")
Expand Down Expand Up @@ -255,7 +262,7 @@ public HuggingFaceTokenizer sentenceTokenizer(RedisOMSpringProperties properties
try {
InetAddress.getByName("www.huggingface.co").isReachable(5000);
return HuggingFaceTokenizer.newInstance(properties.getDjl().getSentenceTokenizerModel(), options);
} catch (IOException ex) {
} catch (IOException ioe) {
logger.warn("Error retrieving default DJL sentence tokenizer");
return null;
}
Expand Down Expand Up @@ -319,41 +326,6 @@ EntityStream streamingQueryBuilder(RedisModulesOperations<?> redisModulesOperati
return new EntityStreamImpl(redisModulesOperations, gson);
}

@EventListener(ContextRefreshedEvent.class)
public void registerGsonDocumentReferenceDeserializers(ContextRefreshedEvent cre) {
logger.info("Registering GSon document reference deserializers......");

ApplicationContext ac = cre.getApplicationContext();
GsonBuilder builder = (GsonBuilder) ac.getBean("omGsonBuilder");

Set<BeanDefinition> beanDefs = new HashSet<>(getBeanDefinitionsFor(ac, Document.class));
for (BeanDefinition beanDef : beanDefs) {
try {
Class<?> cl = Class.forName(beanDef.getBeanClassName());
final List<java.lang.reflect.Field> allClassFields = getDeclaredFieldsTransitively(cl);
for (java.lang.reflect.Field field : allClassFields) {
if (field.isAnnotationPresent(Reference.class)) {
logger.info(String.format("Registering reference type adapter for %s", field.getType().getName()));
if (isCollection(field)) {
var maybeCollectionElementType = getCollectionElementType(field);
if (maybeCollectionElementType.isPresent()) {
TypeToken<?> typeToken = TypeToken.getParameterized(field.getType(), maybeCollectionElementType.get());
builder.registerTypeAdapter(typeToken.getType(), new ReferenceDeserializer(field));
} else {
builder.registerTypeAdapter(field.getType(), new ReferenceDeserializer(field));
}
} else {
builder.registerTypeAdapter(field.getType(), new ReferenceDeserializer(field));
}
}
}
} catch (ClassNotFoundException e) {
logger.debug(
String.format("Error processing references in %s: %s", beanDef.getBeanClassName(), e.getMessage()));
}
}
}

@EventListener(ContextRefreshedEvent.class)
public void ensureIndexesAreCreated(ContextRefreshedEvent cre) {
logger.info("Creating Indexes......");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
import com.redis.om.spring.ops.pds.*;
import com.redis.om.spring.ops.search.SearchOperations;
import com.redis.om.spring.ops.search.SearchOperationsImpl;
import com.redis.om.spring.serialization.gson.ReferenceAwareGsonBuilder;
import org.springframework.data.redis.core.StringRedisTemplate;

public class RedisModulesOperations<K> {

private final GsonBuilder gsonBuilder;
private final ReferenceAwareGsonBuilder gsonBuilder;
private final RedisModulesClient client;
private final StringRedisTemplate template;

public RedisModulesOperations(RedisModulesClient client, StringRedisTemplate template, GsonBuilder gsonBuilder) {
public RedisModulesOperations(RedisModulesClient client, StringRedisTemplate template, ReferenceAwareGsonBuilder gsonBuilder) {
this.client = client;
this.template = template;
this.gsonBuilder = gsonBuilder;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package com.redis.om.spring.ops.json;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.redis.om.spring.client.RedisModulesClient;
import com.redis.om.spring.serialization.gson.ReferenceAwareGsonBuilder;
import org.springframework.lang.Nullable;
import redis.clients.jedis.json.JsonSetParams;
import redis.clients.jedis.json.Path;
Expand All @@ -14,10 +13,10 @@

public class JSONOperationsImpl<K> implements JSONOperations<K> {

private final GsonBuilder builder;
private final ReferenceAwareGsonBuilder builder;
final RedisModulesClient client;

public JSONOperationsImpl(RedisModulesClient client, GsonBuilder builder) {
public JSONOperationsImpl(RedisModulesClient client, ReferenceAwareGsonBuilder builder) {
this.client = client;
this.builder = builder;
}
Expand All @@ -35,12 +34,14 @@ public String get(K key) {

@Override
public <T> T get(K key, Class<T> clazz) {
return builder.create().fromJson(client.clientForJSON().jsonGetAsPlainString(key.toString(), Path.ROOT_PATH), clazz);
builder.processEntity(clazz);
return builder.gson().fromJson(client.clientForJSON().jsonGetAsPlainString(key.toString(), Path.ROOT_PATH), clazz);
}

@Override
public <T> T get(K key, Class<T> clazz, Path path) {
return builder.create().fromJson(client.clientForJSON().jsonGetAsPlainString(key.toString(), path), clazz);
builder.processEntity(clazz);
return builder.gson().fromJson(client.clientForJSON().jsonGetAsPlainString(key.toString(), path), clazz);
}

@SafeVarargs
Expand All @@ -56,23 +57,23 @@ public final List<String> mget(K... keys) {

@SafeVarargs @Override
public final <T> List<T> mget(Class<T> clazz, K... keys) {
Gson gson = builder.create();
builder.processEntity(clazz);
return client.clientForJSON().jsonMGet(getKeysAsString(keys))
.stream()
.filter(Objects::nonNull)
.map(jsonArr -> jsonArr.get(0))
.map(Object::toString)
.map(str -> gson.fromJson(str, clazz))
.map(str -> builder.gson().fromJson(str, clazz))
.toList();
}

@SafeVarargs @Override
public final <T> List<T> mget(Path2 path, Class<T> clazz, K... keys) {
Gson gson = builder.create();
builder.processEntity(clazz);
return client.clientForJSON().jsonMGet(path, getKeysAsString(keys))
.stream()
.map(Object::toString)
.map(str -> gson.fromJson(str, clazz))
.map(str -> builder.gson().fromJson(str, clazz))
.toList();
}

Expand All @@ -83,8 +84,7 @@ public void set(K key, Object object, JsonSetParams flag) {

@Override
public void set(K key, Object object) {
Gson gson = builder.create();
client.clientForJSON().jsonSetWithPlainString(key.toString(), Path.ROOT_PATH, gson.toJson(object));
client.clientForJSON().jsonSetWithPlainString(key.toString(), Path.ROOT_PATH, builder.gson().toJson(object));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package com.redis.om.spring.serialization.gson;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
import com.redis.om.spring.ops.json.JSONOperations;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Lazy;
import org.springframework.data.annotation.Reference;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;

import static com.redis.om.spring.util.ObjectUtils.*;

@Component
public class ReferenceAwareGsonBuilder {
private static final Log logger = LogFactory.getLog(ReferenceAwareGsonBuilder.class);
private final List<Type> processedClasses = new ArrayList<>();
private final GsonBuilder builder;
private Gson gson;
private JSONOperations<?> ops;
private final ApplicationContext ac;
private boolean rebuildGson = false;

public ReferenceAwareGsonBuilder(GsonBuilder builder, ApplicationContext ac) {
this.builder = builder;
this.gson = builder.create();
this.ac = ac;
}
public <T> void processEntity(Class<T> clazz) {
if (!processedClasses.contains(clazz)) {
ops = ac.getBean("redisJSONOperations", JSONOperations.class);
final List<java.lang.reflect.Field> allClassFields = getDeclaredFieldsTransitively(clazz);
for (java.lang.reflect.Field field : allClassFields) {
if (field.isAnnotationPresent(Reference.class)) {
logger.info(String.format("Registering reference type adapter for %s", field.getType().getName()));
processField(field);
}
}
processedClasses.add(clazz);
}
}

public Gson gson() {
if (rebuildGson) {
gson = builder.create();
rebuildGson = false;
}
return gson;
}
private void processField(Field field) {
TypeToken<?> typeToken;
if (isCollection(field)) {
var maybeCollectionElementType = getCollectionElementType(field);
if (maybeCollectionElementType.isPresent()) {
typeToken = TypeToken.getParameterized(field.getType(), maybeCollectionElementType.get());
} else {
typeToken = TypeToken.get(field.getType());
}
} else {
typeToken = TypeToken.get(field.getType());
}
builder.registerTypeAdapter(typeToken.getType(), new ReferenceDeserializer(field, ops));
rebuildGson = true;

processEntity(field.getType());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
import com.google.gson.internal.ConstructorConstructor;
import com.google.gson.internal.ObjectConstructor;
import com.google.gson.reflect.TypeToken;
import com.redis.om.spring.ApplicationContextProvider;
import com.redis.om.spring.ops.json.JSONOperations;
import com.redis.om.spring.util.ObjectUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.context.ApplicationContext;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
Expand All @@ -22,10 +20,13 @@ public class ReferenceDeserializer implements JsonDeserializer<Object> {
private final Class<?> type;
private final ObjectConstructor<?> objectConstructor;

public ReferenceDeserializer(Field field) {
private final JSONOperations<String> ops;

public ReferenceDeserializer(Field field, JSONOperations<?> ops) {
this.ops = (JSONOperations<String>) ops;
Map<Type, InstanceCreator<?>> instanceCreators = new HashMap<>();
ConstructorConstructor constructorConstructor = new ConstructorConstructor(instanceCreators, true,
Collections.<ReflectionAccessFilter>emptyList());
Collections.emptyList());
if (ObjectUtils.isCollection(field)) {
Optional<Class<?>> collectionType = ObjectUtils.getCollectionElementClass(field);
if (collectionType.isPresent()) {
Expand All @@ -44,8 +45,6 @@ public Object deserialize(JsonElement json, Type typeOfT, JsonDeserializationCon
throws JsonParseException {
Object reference = null;
JsonObject jsonObject;
ApplicationContext ac = ApplicationContextProvider.getApplicationContext();
JSONOperations<String> ops = (JSONOperations<String>) ac.getBean("redisJSONOperations");
if (json.isJsonPrimitive()) {
String referenceKey = ObjectUtils.unQuote(json.toString());
String referenceJSON = ops.get(referenceKey);
Expand Down