diff --git a/redis-om-spring/pom.xml b/redis-om-spring/pom.xml index b8c8d4d76..2d152202e 100644 --- a/redis-om-spring/pom.xml +++ b/redis-om-spring/pom.xml @@ -59,8 +59,8 @@ 17 17 17 - 3.1.0 - 3.1.0 + 3.1.1 + 3.1.1 4.3.2 1.0 1.0.1 diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RediSearchIndexer.java b/redis-om-spring/src/main/java/com/redis/om/spring/RediSearchIndexer.java index c5e461047..99940209d 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RediSearchIndexer.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RediSearchIndexer.java @@ -185,6 +185,10 @@ public boolean indexExistsFor(Class entityClass) { return indexedEntityClasses.contains(entityClass); } + public Schema getSchemaFor(Class entityClass) { + return entityClassToSchema.get(entityClass); + } + private List findIndexFields(java.lang.reflect.Field field, String prefix, boolean isDocument) { List fields = new ArrayList<>(); @@ -200,13 +204,7 @@ private List findIndexFields(java.lang.reflect.Field field, String prefix // @Reference @Indexed fields: Create schema field for the reference entity @Id field // logger.debug("🪲Found @Reference field " + field.getName() + " in " + field.getDeclaringClass().getSimpleName()); - var maybeReferenceIdField = getIdFieldForEntityClass(fieldType); - if (maybeReferenceIdField.isPresent()) { - var idFieldToIndex = maybeReferenceIdField.get(); - createIndexedFieldForReferenceIdField(field, idFieldToIndex, isDocument).ifPresent(fields::add); - } else { - logger.warn("Couldn't find ID field for reference" + field.getName() + " in " + field.getDeclaringClass().getSimpleName()); - } + createIndexedFieldForReferenceIdField(field, isDocument).ifPresent(fields::add); } else if (indexed.schemaFieldType() == SchemaFieldType.AUTODETECT) { // // Any Character class, Enums or Boolean -> Tag Search Field @@ -337,13 +335,11 @@ private Field indexAsTagFieldFor(java.lang.reflect.Field field, boolean isDocume String fieldPostfix = (isDocument && typeInfo.isCollectionLike() && !field.isAnnotationPresent(JsonAdapter.class)) ? "[*]" : ""; - FieldName fieldName = FieldName.of(fieldPrefix + field.getName() + fieldPostfix); + String name = fieldPrefix + field.getName() + fieldPostfix; + String alias = ObjectUtils.isEmpty(ti.alias()) ? QueryUtils.searchIndexFieldAliasFor(field, prefix) : ti.alias(); - if (!ObjectUtils.isEmpty(ti.alias())) { - fieldName = fieldName.as(ti.alias()); - } else { - fieldName = fieldName.as(QueryUtils.searchIndexFieldAliasFor(field, prefix)); - } + FieldName fieldName = FieldName.of(name); + fieldName = fieldName.as(alias); return new TagField(fieldName, ti.separator(), false); } @@ -384,13 +380,10 @@ private Field indexAsVectorFieldFor(java.lang.reflect.Field field, boolean isDoc } } - VectorField vectorField = new VectorField(fieldName, indexed.algorithm(), attributes); + String alias = ObjectUtils.isEmpty(indexed.alias()) ? QueryUtils.searchIndexFieldAliasFor(field, prefix) : indexed.alias(); - if (!ObjectUtils.isEmpty(indexed.alias())) { - vectorField.as(indexed.alias()); - } else { - vectorField.as(QueryUtils.searchIndexFieldAliasFor(field, prefix)); - } + VectorField vectorField = new VectorField(fieldName, indexed.algorithm(), attributes); + vectorField.as(alias); return vectorField; } @@ -431,13 +424,10 @@ private Field indexAsVectorFieldFor(java.lang.reflect.Field field, boolean isDoc } } - VectorField vectorField = new VectorField(fieldName, vi.algorithm(), attributes); + String alias = ObjectUtils.isEmpty(vi.alias()) ? QueryUtils.searchIndexFieldAliasFor(field, prefix) : vi.alias(); - if (!ObjectUtils.isEmpty(vi.alias())) { - vectorField.as(vi.alias()); - } else { - vectorField.as(QueryUtils.searchIndexFieldAliasFor(field, prefix)); - } + VectorField vectorField = new VectorField(fieldName, vi.algorithm(), attributes); + vectorField.as(alias); return vectorField; } @@ -450,22 +440,20 @@ private Field indexAsTagFieldFor(java.lang.reflect.Field field, boolean isDocume String fieldPostfix = (isDocument && typeInfo.isCollectionLike() && !field.isAnnotationPresent(JsonAdapter.class)) ? index : ""; - FieldName fieldName = FieldName.of(fieldPrefix + field.getName() + fieldPostfix); - - fieldName = fieldName.as(QueryUtils.searchIndexFieldAliasFor(field, prefix)); + String name = fieldPrefix + field.getName() + fieldPostfix; + String alias = QueryUtils.searchIndexFieldAliasFor(field, prefix); + FieldName fieldName = FieldName.of(name); + fieldName = fieldName.as(alias); return new TagField(fieldName, separator.isBlank() ? null : separator, sortable); } private Field indexAsTextFieldFor(java.lang.reflect.Field field, boolean isDocument, String prefix, TextIndexed ti) { String fieldPrefix = getFieldPrefix(prefix, isDocument); - FieldName fieldName = FieldName.of(fieldPrefix + field.getName()); - - if (!ObjectUtils.isEmpty(ti.alias())) { - fieldName = fieldName.as(ti.alias()); - } else { - fieldName = fieldName.as(QueryUtils.searchIndexFieldAliasFor(field, prefix)); - } + String name = fieldPrefix + field.getName(); + String alias = ObjectUtils.isEmpty(ti.alias()) ? QueryUtils.searchIndexFieldAliasFor(field, prefix) : ti.alias(); + FieldName fieldName = FieldName.of(name); + fieldName = fieldName.as(alias); String phonetic = ObjectUtils.isEmpty(ti.phonetic()) ? null : ti.phonetic(); @@ -474,13 +462,11 @@ private Field indexAsTextFieldFor(java.lang.reflect.Field field, boolean isDocum private Field indexAsTextFieldFor(java.lang.reflect.Field field, boolean isDocument, String prefix, Searchable ti) { String fieldPrefix = getFieldPrefix(prefix, isDocument); - FieldName fieldName = FieldName.of(fieldPrefix + field.getName()); + String name = fieldPrefix + field.getName(); + String alias = ObjectUtils.isEmpty(ti.alias()) ? QueryUtils.searchIndexFieldAliasFor(field, prefix) : ti.alias(); + FieldName fieldName = FieldName.of(name); + fieldName = fieldName.as(alias); - if (!ObjectUtils.isEmpty(ti.alias())) { - fieldName = fieldName.as(ti.alias()); - } else { - fieldName = fieldName.as(QueryUtils.searchIndexFieldAliasFor(field, prefix)); - } String phonetic = ObjectUtils.isEmpty(ti.phonetic()) ? null : ti.phonetic(); return new TextField(fieldName, ti.weight(), ti.sortable(), ti.nostem(), ti.noindex(), phonetic); @@ -488,13 +474,10 @@ private Field indexAsTextFieldFor(java.lang.reflect.Field field, boolean isDocum private Field indexAsGeoFieldFor(java.lang.reflect.Field field, boolean isDocument, String prefix, GeoIndexed gi) { String fieldPrefix = getFieldPrefix(prefix, isDocument); - FieldName fieldName = FieldName.of(fieldPrefix + field.getName()); - - if (!ObjectUtils.isEmpty(gi.alias())) { - fieldName = fieldName.as(gi.alias()); - } else { - fieldName = fieldName.as(QueryUtils.searchIndexFieldAliasFor(field, prefix)); - } + String name = fieldPrefix + field.getName(); + String alias = ObjectUtils.isEmpty(gi.alias()) ? QueryUtils.searchIndexFieldAliasFor(field, prefix) : gi.alias(); + FieldName fieldName = FieldName.of(name); + fieldName = fieldName.as(alias); return new Field(fieldName, FieldType.GEO); } @@ -502,13 +485,10 @@ private Field indexAsGeoFieldFor(java.lang.reflect.Field field, boolean isDocume private Field indexAsNumericFieldFor(java.lang.reflect.Field field, boolean isDocument, String prefix, NumericIndexed ni) { String fieldPrefix = getFieldPrefix(prefix, isDocument); - FieldName fieldName = FieldName.of(fieldPrefix + field.getName()); - - if (!ObjectUtils.isEmpty(ni.alias())) { - fieldName = fieldName.as(ni.alias()); - } else { - fieldName = fieldName.as(QueryUtils.searchIndexFieldAliasFor(field, prefix)); - } + String name = fieldPrefix + field.getName(); + String alias = ObjectUtils.isEmpty(ni.alias()) ? QueryUtils.searchIndexFieldAliasFor(field, prefix) : ni.alias(); + FieldName fieldName = FieldName.of(name); + fieldName = fieldName.as(alias); return new Field(fieldName, FieldType.NUMERIC); } @@ -516,18 +496,20 @@ private Field indexAsNumericFieldFor(java.lang.reflect.Field field, boolean isDo private Field indexAsNumericFieldFor(java.lang.reflect.Field field, boolean isDocument, String prefix, boolean sortable, boolean noIndex) { String fieldPrefix = getFieldPrefix(prefix, isDocument); - FieldName fieldName = FieldName.of(fieldPrefix + field.getName()); - - fieldName = fieldName.as(QueryUtils.searchIndexFieldAliasFor(field, prefix)); + String name = fieldPrefix + field.getName(); + String alias = QueryUtils.searchIndexFieldAliasFor(field, prefix); + FieldName fieldName = FieldName.of(name); + fieldName = fieldName.as(alias); return new Field(fieldName, FieldType.NUMERIC, sortable, noIndex); } private Field indexAsGeoFieldFor(java.lang.reflect.Field field, boolean isDocument, String prefix) { String fieldPrefix = getFieldPrefix(prefix, isDocument); - FieldName fieldName = FieldName.of(fieldPrefix + field.getName()); - - fieldName = fieldName.as(QueryUtils.searchIndexFieldAliasFor(field, prefix)); + String name = fieldPrefix + field.getName(); + String alias = QueryUtils.searchIndexFieldAliasFor(field, prefix); + FieldName fieldName = FieldName.of(name); + fieldName = fieldName.as(alias); return new Field(fieldName, FieldType.GEO); } @@ -716,7 +698,7 @@ private Optional createIndexedFieldForIdField(Class cl, List fi private Optional createIndexedFieldForReferenceIdField( // java.lang.reflect.Field referenceIdField, // - java.lang.reflect.Field idFieldToIndex, boolean isDocument) { + boolean isDocument) { Optional result; String fieldPrefix = getFieldPrefix("", isDocument); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java b/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java index 0c020b4cf..3217dcc8c 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/RedisModulesConfiguration.java @@ -320,9 +320,10 @@ public CustomRedisKeyValueTemplate getKeyValueTemplate( // @Bean(name = "streamingQueryBuilder") EntityStream streamingQueryBuilder( RedisModulesOperations redisModulesOperations, - @Qualifier("omGsonBuilder") GsonBuilder gsonBuilder + @Qualifier("omGsonBuilder") GsonBuilder gsonBuilder, + RediSearchIndexer indexer ) { - return new EntityStreamImpl(redisModulesOperations, gsonBuilder); + return new EntityStreamImpl(redisModulesOperations, gsonBuilder, indexer); } @EventListener(ContextRefreshedEvent.class) @@ -348,7 +349,6 @@ public void processBloom(ContextRefreshedEvent cre) { try { Class cl = Class.forName(beanDef.getBeanClassName()); for (java.lang.reflect.Field field : getDeclaredFieldsTransitively(cl)) { - // Text if (field.isAnnotationPresent(Bloom.class)) { Bloom bloom = field.getAnnotation(Bloom.class); BloomOperations ops = rmo.opsForBloom(); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/MetamodelGenerator.java b/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/MetamodelGenerator.java index 0fabea0b0..aea681261 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/MetamodelGenerator.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/MetamodelGenerator.java @@ -2,10 +2,7 @@ import com.github.f4b6a3.ulid.Ulid; import com.google.auto.service.AutoService; -import com.redis.om.spring.annotations.Document; -import com.redis.om.spring.annotations.Indexed; -import com.redis.om.spring.annotations.SchemaFieldType; -import com.redis.om.spring.annotations.Searchable; +import com.redis.om.spring.annotations.*; import com.redis.om.spring.metamodel.indexed.*; import com.redis.om.spring.metamodel.nonindexed.*; import com.redis.om.spring.tuple.Pair; @@ -30,8 +27,6 @@ import java.io.IOException; import java.io.Writer; import java.lang.reflect.Field; -import java.math.BigDecimal; -import java.math.BigInteger; import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; @@ -40,8 +35,6 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static java.util.Objects.requireNonNull; - @SupportedAnnotationTypes(value = {"com.redis.om.spring.annotations.Document","org.springframework.data.redis.core.RedisHash"}) @SupportedSourceVersion(SourceVersion.RELEASE_17) @AutoService(Processor.class) @@ -152,7 +145,7 @@ void generateMetaModelClass(final Element annotatedElement) throws IOException { if (i != 0) { sb.append(".getType()"); } - String formattedString = String.format("com.redis.om.spring.util.ObjectUtils.getDeclaredFieldTransitively(%s, \"%s\")", sb.toString(), element.getSimpleName()); + String formattedString = String.format("com.redis.om.spring.util.ObjectUtils.getDeclaredFieldTransitively(%s, \"%s\")", sb, element.getSimpleName()); sb.setLength(0); // clear the builder sb.append(formattedString); } @@ -203,6 +196,11 @@ private List> processFieldMet boolean fieldIsIndexed = (field.getAnnotation(Searchable.class) != null) || (field.getAnnotation(Indexed.class) != null) + || (field.getAnnotation(TextIndexed.class) != null) + || (field.getAnnotation(TagIndexed.class) != null) + || (field.getAnnotation(NumericIndexed.class) != null) + || (field.getAnnotation(GeoIndexed.class) != null) + || (field.getAnnotation(VectorIndexed.class) != null) || (field.getAnnotation(Id.class) != null); String chainedFieldName = chain.stream().map(Element::getSimpleName).collect(Collectors.joining("_")); @@ -229,10 +227,16 @@ private List> processFieldMet var searchable = field.getAnnotation(Searchable.class); var reference = field.getAnnotation(Reference.class); + var textIndexed = field.getAnnotation(TextIndexed.class); + var tagIndexed = field.getAnnotation(TagIndexed.class); + var numericIndexed = field.getAnnotation(NumericIndexed.class); + var geoIndexed = field.getAnnotation(GeoIndexed.class); + var vectorIndexed = field.getAnnotation(VectorIndexed.class); + if (indexed != null && reference != null) { targetInterceptor = ReferenceField.class; } - else if (searchable != null) { + else if (searchable != null || textIndexed != null) { targetInterceptor = TextField.class; } else if (indexed != null || field.getAnnotation(Id.class) != null) { try { @@ -243,7 +247,15 @@ else if (searchable != null) { fieldMetamodelSpec.addAll(processNestedIndexableFields(entity, chain)); } - if (indexed != null && indexed.schemaFieldType() != SchemaFieldType.AUTODETECT) { + if (tagIndexed != null) { + targetInterceptor = TextTagField.class; + } else if (numericIndexed != null) { + targetInterceptor = NumericField.class; + } else if (geoIndexed != null) { + targetInterceptor = GeoField.class; + } else if (vectorIndexed != null) { + targetInterceptor = VectorField.class; + } else if (indexed != null && indexed.schemaFieldType() != SchemaFieldType.AUTODETECT) { // here we do the non autodetect annotated fields switch (indexed.schemaFieldType()) { case TAG -> targetInterceptor = TextTagField.class; @@ -411,12 +423,12 @@ private Triple generateCollectionFie if (i != 0) { sb.append(".getType()"); } - String formattedString = String.format("com.redis.om.spring.util.ObjectUtils.getDeclaredFieldTransitively(%s, \"%s\")", sb.toString(), element.getSimpleName()); + String formattedString = String.format("com.redis.om.spring.util.ObjectUtils.getDeclaredFieldTransitively(%s, \"%s\")", sb, element.getSimpleName()); sb.setLength(0); // clear the buffer sb.append(formattedString); } FieldSpec fieldSpec = ogfs.fieldSpec(); - blockBuilder.addStatement("$L = " + sb.toString(), fieldSpec.name, entity); + blockBuilder.addStatement("$L = " + sb, fieldSpec.name, entity); } for (CodeBlock initCodeBlock : initCodeBlocks) { @@ -456,7 +468,7 @@ private Triple generateCollectionFie TypeName generatedTypeName = ClassName.bestGuess(qualifiedGenEntityName); - return generateFieldMetamodel(chain, chainedFieldName, generatedTypeName, true); + return generateFieldMetamodel(chain, chainedFieldName, generatedTypeName); } private List> processNestedIndexableFields(TypeName entity, @@ -590,7 +602,7 @@ private String findGetter(final Element field, final Map getter final String fieldName = field.getSimpleName().toString(); final String getterPrefix = isGetters.contains(fieldName) ? IS_PREFIX : GET_PREFIX; - final String standardJavaName = javaNameFromExternal(fieldName); + final String standardJavaName = ObjectUtils.javaNameFromExternal(fieldName); final String standardGetterName = getterPrefix + standardJavaName; @@ -616,50 +628,6 @@ private String findGetter(final Element field, final Map getter + ".class, \"" + fieldName + "\");}"; } - /** - * Returns a static field name representation of the specified camel-cased - * string. - * - * @param externalName the string - * @return the static field name representation - */ - public static String staticField(final String externalName) { - requireNonNull(externalName); - return ObjectUtils.toUnderscoreSeparated(javaNameFromExternal(externalName)).toUpperCase(); - } - - public static String javaNameFromExternal(final String externalName) { - requireNonNull(externalName); - return MetamodelGenerator - .replaceIfIllegalJavaIdentifierCharacter(replaceIfJavaUsedWord(nameFromExternal(externalName))); - } - - public static String nameFromExternal(final String externalName) { - requireNonNull(externalName); - String result = ObjectUtils.unQuote(externalName.trim()); // Trim if there are initial spaces or trailing spaces... - /* CamelCase - * http://stackoverflow.com/questions/4050381/regular-expression-for-checking-if - * -capital-letters-are-found-consecutively-in-a [A-Z] -> \p{Lu} [^A-Za-z0-9] -> - * [^\pL0-90-9] */ - result = Stream.of(result.replaceAll("(\\p{Lu}+)", "_$1").split("[^\\pL\\d]")).map(String::toLowerCase) - .map(ObjectUtils::ucfirst).collect(Collectors.joining()); - return result; - } - - public static String replaceIfJavaUsedWord(final String word) { - requireNonNull(word); - // We need to replace regardless of case because we do not know how the returned - // string is to be used - if (JAVA_USED_WORDS_LOWER_CASE.contains(word.toLowerCase())) { - // If it is a java reserved/literal/class, add a "_" at the end to avoid naming - // conflicts - return word + "_"; - } - return word; - } - - public static final Character REPLACEMENT_CHARACTER = '_'; - private Triple generateFieldMetamodel( // TypeName entity, // List chain, // @@ -669,7 +637,7 @@ private Triple generateFieldMetamode boolean fieldIsIndexed, // String collectionPrefix // ) { - String fieldAccessor = staticField(chainFieldName); + String fieldAccessor = ObjectUtils.staticField(chainFieldName); FieldSpec objectField = FieldSpec // .builder(Field.class, chainFieldName).addModifiers(Modifier.PUBLIC, Modifier.STATIC) // @@ -705,10 +673,9 @@ private Triple generateFieldMetamode private Triple generateFieldMetamodel( List chain, // String chainFieldName, // - TypeName interceptor, // - boolean fieldIsIndexed // + TypeName interceptor // ) { - String fieldAccessor = staticField(chainFieldName); + String fieldAccessor = ObjectUtils.staticField(chainFieldName); FieldSpec objectField = FieldSpec.builder(Field.class, chainFieldName).addModifiers(Modifier.PUBLIC, Modifier.STATIC) .build(); @@ -720,7 +687,8 @@ private Triple generateFieldMetamode String searchSchemaAlias = chain.stream().map(e -> e.getSimpleName().toString()).collect(Collectors.joining("_")); CodeBlock aFieldInit = CodeBlock.builder() - .addStatement("$L = new $T(new $T(\"$L\", $L),$L)", fieldAccessor, interceptor, SearchFieldAccessor.class, searchSchemaAlias, chainFieldName, fieldIsIndexed).build(); + .addStatement("$L = new $T(new $T(\"$L\", $L),$L)", fieldAccessor, interceptor, SearchFieldAccessor.class, searchSchemaAlias, chainFieldName, + true).build(); return Tuples.of(ogf, aField, aFieldInit); } @@ -738,65 +706,6 @@ private Pair generateUnboundMetamodelField(TypeName entity return Tuples.of(aField, aFieldInit); } - public static String replaceIfIllegalJavaIdentifierCharacter(final String word) { - requireNonNull(word); - if (word.isEmpty()) { - return REPLACEMENT_CHARACTER.toString(); // No name is translated to REPLACEMENT_CHARACTER only - } - final StringBuilder sb = new StringBuilder(); - for (int i = 0; i < word.length(); i++) { - char c = word.charAt(i); - if (i == 0) { - if (Character.isJavaIdentifierStart(c)) { - // Fine! Just add the first character - sb.append(c); - } else if (Character.isJavaIdentifierPart(c)) { - // Not ok as the first, but ok otherwise. Add the replacement before it - sb.append(REPLACEMENT_CHARACTER).append(c); - } else { - // Cannot be used as a java identifier. Replace it - sb.append(REPLACEMENT_CHARACTER); - } - } else if (Character.isJavaIdentifierPart(c)) { - // Fine! Just add it - sb.append(c); - } else { - // Cannot be used as a java identifier. Replace it - sb.append(REPLACEMENT_CHARACTER); - } - - } - return sb.toString(); - } - - static final Set JAVA_LITERAL_WORDS = Set.of("true", "false", "null"); - - // Java reserved keywords - static final Set JAVA_RESERVED_WORDS = Collections.unmodifiableSet(Stream.of( - // Unused - "const", "goto", - // The real ones... - "abstract", "continue", "for", "new", "switch", "assert", "default", "goto", "package", "synchronized", "boolean", - "do", "if", "private", "this", "break", "double", "implements", "protected", "throw", "byte", "else", "import", - "public", "throws", "case", "enum", "instanceof", "return", "transient", "catch", "extends", "int", "short", - "try", "char", "final", "interface", "static", "void", "class", "finally", "long", "strictfp", "volatile", - "const", "float", "native", "super", "while").collect(Collectors.toSet())); - - static final Set> JAVA_BUILT_IN_CLASSES = Set.of(Boolean.class, Byte.class, Character.class, Double.class, - Float.class, Integer.class, Long.class, Object.class, Short.class, String.class, BigDecimal.class, - BigInteger.class, boolean.class, byte.class, char.class, double.class, float.class, int.class, long.class, - short.class); - - private static final Set JAVA_BUILT_IN_CLASS_WORDS = Collections - .unmodifiableSet(JAVA_BUILT_IN_CLASSES.stream().map(Class::getSimpleName).collect(Collectors.toSet())); - - private static final Set JAVA_USED_WORDS = Collections - .unmodifiableSet(Stream.of(JAVA_LITERAL_WORDS, JAVA_RESERVED_WORDS, JAVA_BUILT_IN_CLASS_WORDS) - .flatMap(Collection::stream).collect(Collectors.toSet())); - - private static final Set JAVA_USED_WORDS_LOWER_CASE = Collections - .unmodifiableSet(JAVA_USED_WORDS.stream().map(String::toLowerCase).collect(Collectors.toSet())); - private boolean isEnum(ProcessingEnvironment processingEnv, TypeMirror typeMirror) { Types typeUtils = processingEnv.getTypeUtils(); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/MetamodelUtils.java b/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/MetamodelUtils.java new file mode 100644 index 000000000..f250bcfdd --- /dev/null +++ b/redis-om-spring/src/main/java/com/redis/om/spring/metamodel/MetamodelUtils.java @@ -0,0 +1,43 @@ +package com.redis.om.spring.metamodel; + +import com.redis.om.spring.util.ObjectUtils; + +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Optional; + +public class MetamodelUtils { + public static MetamodelField getMetamodelForIdField(Class entityClass) { + Optional idField = ObjectUtils.getIdFieldForEntityClass(entityClass); + if (idField.isPresent()) { + try { + Class metamodel = Class.forName(entityClass.getName() + "$"); + String metamodelField = ObjectUtils.staticField(idField.get().getName()); + Field field = metamodel.getField(metamodelField); + return (MetamodelField) field.get(null); + } catch (ClassNotFoundException | NoSuchFieldException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + return null; + } + + public static List> getMetamodelFieldsForProperties(Class entityClass, Collection properties) { + List> result = new ArrayList<>(); + try { + Class metamodel = Class.forName(entityClass.getName() + "$"); + for (var property: properties) { + try { + result.add((MetamodelField) metamodel.getField(ObjectUtils.staticField(property)).get(null)); + } catch (IllegalAccessException | NoSuchFieldException e) { + // NOOP + } + } + } catch (ClassNotFoundException e) { + // NOOP + } + return result; + } +} diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/ops/RedisModulesOperations.java b/redis-om-spring/src/main/java/com/redis/om/spring/ops/RedisModulesOperations.java index a2b744042..e4e5b7953 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/ops/RedisModulesOperations.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/ops/RedisModulesOperations.java @@ -52,4 +52,8 @@ public StringRedisTemplate getTemplate() { public RedisModulesClient getClient() { return client; } + + public GsonBuilder getGsonBuilder() { + return gsonBuilder; + } } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/RedisDocumentRepository.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/RedisDocumentRepository.java index 922065f37..0d532fd03 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/RedisDocumentRepository.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/RedisDocumentRepository.java @@ -5,12 +5,13 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.keyvalue.repository.KeyValueRepository; import org.springframework.data.repository.NoRepositoryBean; +import org.springframework.data.repository.query.QueryByExampleExecutor; import redis.clients.jedis.json.Path; import java.io.IOException; @NoRepositoryBean -public interface RedisDocumentRepository extends KeyValueRepository { +public interface RedisDocumentRepository extends KeyValueRepository, QueryByExampleExecutor { Iterable getIds(); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/RedisEnhancedRepository.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/RedisEnhancedRepository.java index 1ac8a615f..2fce4bf0a 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/RedisEnhancedRepository.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/RedisEnhancedRepository.java @@ -5,9 +5,10 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.keyvalue.repository.KeyValueRepository; import org.springframework.data.repository.NoRepositoryBean; +import org.springframework.data.repository.query.QueryByExampleExecutor; @NoRepositoryBean -public interface RedisEnhancedRepository extends KeyValueRepository { +public interface RedisEnhancedRepository extends KeyValueRepository, QueryByExampleExecutor { Iterable getIds(); diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java index 91bdb2bb9..f6f95ed94 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisDocumentRepository.java @@ -13,18 +13,22 @@ import com.redis.om.spring.ops.json.JSONOperations; import com.redis.om.spring.ops.search.SearchOperations; import com.redis.om.spring.repository.RedisDocumentRepository; +import com.redis.om.spring.search.stream.EntityStream; +import com.redis.om.spring.search.stream.EntityStreamImpl; +import com.redis.om.spring.search.stream.FluentQueryByExample; import com.redis.om.spring.serialization.gson.GsonListOfType; import com.redis.om.spring.util.ObjectUtils; -import org.springframework.beans.*; +import org.springframework.beans.BeanWrapper; +import org.springframework.beans.BeanWrapperImpl; +import org.springframework.beans.PropertyAccessor; +import org.springframework.beans.PropertyAccessorFactory; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.data.annotation.CreatedDate; import org.springframework.data.annotation.LastModifiedDate; import org.springframework.data.annotation.Reference; import org.springframework.data.annotation.Version; -import org.springframework.data.domain.Page; -import org.springframework.data.domain.PageImpl; -import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.*; import org.springframework.data.keyvalue.core.KeyValueOperations; import org.springframework.data.keyvalue.core.mapping.KeyValuePersistentEntity; import org.springframework.data.keyvalue.repository.support.SimpleKeyValueRepository; @@ -35,6 +39,7 @@ import org.springframework.data.redis.core.convert.ReferenceResolverImpl; import org.springframework.data.redis.core.mapping.RedisMappingContext; import org.springframework.data.repository.core.EntityInformation; +import org.springframework.data.repository.query.FluentQuery.FetchableFluentQuery; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; @@ -56,9 +61,12 @@ import java.time.LocalDateTime; import java.util.*; import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.stream.Collectors; import java.util.stream.StreamSupport; import static com.redis.om.spring.util.ObjectUtils.isPrimitiveOfType; +import static com.redis.om.spring.util.ObjectUtils.pageFromSlice; import static redis.clients.jedis.json.JsonProtocol.JsonCommand; public class SimpleRedisDocumentRepository extends SimpleKeyValueRepository @@ -73,13 +81,14 @@ public class SimpleRedisDocumentRepository extends SimpleKeyValueReposito private final ULIDIdentifierGenerator generator; private final RedisOMProperties properties; private final RedisMappingContext mappingContext; + private final EntityStream entityStream; @SuppressWarnings("unchecked") public SimpleRedisDocumentRepository( // EntityInformation metadata, // KeyValueOperations operations, // @Qualifier("redisModulesOperations") RedisModulesOperations rmo, // - RediSearchIndexer keyspaceToIndexMap, // + RediSearchIndexer indexer, // RedisMappingContext mappingContext, GsonBuilder gsonBuilder, RedisOMProperties properties) { @@ -87,37 +96,31 @@ public SimpleRedisDocumentRepository( // this.modulesOperations = (RedisModulesOperations) rmo; this.metadata = metadata; this.operations = operations; - this.indexer = keyspaceToIndexMap; + this.indexer = indexer; this.mappingConverter = new MappingRedisOMConverter(null, new ReferenceResolverImpl(modulesOperations.getTemplate())); this.generator = ULIDIdentifierGenerator.INSTANCE; this.gsonBuilder = gsonBuilder; this.mappingContext = mappingContext; this.properties = properties; + this.entityStream = new EntityStreamImpl(modulesOperations, modulesOperations.getGsonBuilder(), indexer); } @Override public Iterable getIds() { - String keyspace = indexer.getKeyspaceForEntityClass(metadata.getJavaType()); - Optional maybeSearchIndex = indexer.getIndexName(keyspace); List result = List.of(); - if (maybeSearchIndex.isPresent()) { - Gson gson = gsonBuilder.create(); - SearchOperations searchOps = modulesOperations.opsForSearch(maybeSearchIndex.get()); - Optional maybeIdField = ObjectUtils.getIdFieldForEntityClass(metadata.getJavaType()); - String idField = maybeIdField.map(Field::getName).orElse("id"); - - Query query = new Query("*"); - query.limit(0, properties.getRepository().getQuery().getLimit()); - query.returnFields(idField); - SearchResult searchResult = searchOps.search(query); - - result = searchResult.getDocuments().stream() - .map(d -> gson.fromJson(SafeEncoder.encode((byte[])d.get(idField)), metadata.getIdType())) - .toList(); - } + Gson gson = gsonBuilder.create(); + Optional maybeIdField = ObjectUtils.getIdFieldForEntityClass(metadata.getJavaType()); + String idField = maybeIdField.map(Field::getName).orElse("id"); + + Query query = new Query("*"); + query.limit(0, properties.getRepository().getQuery().getLimit()); + query.returnFields(idField); + SearchResult searchResult = getSearchOps().search(query); - return result; + return searchResult.getDocuments().stream() + .map(d -> gson.fromJson(SafeEncoder.encode((byte[])d.get(idField)), metadata.getIdType())) + .toList(); } @Override @@ -137,7 +140,7 @@ public void deleteById(ID id, Path path) { @Override public void updateField(T entity, MetamodelField field, Object value) { - modulesOperations.opsForJSON().set(getKey(metadata.getId(entity)), value, + modulesOperations.opsForJSON().set(getKey(Objects.requireNonNull(metadata.getId(entity))), value, Path.of("$." + field.getSearchAlias())); } @@ -169,12 +172,13 @@ public List saveAll(Iterable entities) { KeyValuePersistentEntity keyValueEntity = mappingConverter.getMappingContext() .getRequiredPersistentEntity(ClassUtils.getUserClass(entity)); - Object id = isNew ? generator.generateIdentifierOfType(keyValueEntity.getIdProperty().getTypeInformation()) - : keyValueEntity.getPropertyAccessor(entity).getProperty(keyValueEntity.getIdProperty()); + Object id = isNew ? generator.generateIdentifierOfType(Objects.requireNonNull(keyValueEntity.getIdProperty()).getTypeInformation()) + : keyValueEntity.getPropertyAccessor(entity).getProperty( + Objects.requireNonNull(keyValueEntity.getIdProperty())); keyValueEntity.getPropertyAccessor(entity).setProperty(keyValueEntity.getIdProperty(), id); String keyspace = keyValueEntity.getKeySpace(); - byte[] objectKey = createKey(keyspace, id.toString()); + byte[] objectKey = createKey(keyspace, Objects.requireNonNull(id).toString()); processAuditAnnotations(entity, isNew); @@ -309,23 +313,21 @@ private Optional getTTLForEntity(Object entity) { Method ttlGetter; try { Field fld = ReflectionUtils.findField(entity.getClass(), settings.getTimeToLivePropertyName()); - ttlGetter = ObjectUtils.getGetterForField(entity.getClass(), fld); - Long ttlPropertyValue = ((Number) ReflectionUtils.invokeMethod(ttlGetter, entity)).longValue(); + ttlGetter = ObjectUtils.getGetterForField(entity.getClass(), Objects.requireNonNull(fld)); + long ttlPropertyValue = ((Number) Objects.requireNonNull(ReflectionUtils.invokeMethod(ttlGetter, entity))).longValue(); ReflectionUtils.invokeMethod(ttlGetter, entity); - if (ttlPropertyValue != null) { - TimeToLive ttl = fld.getAnnotation(TimeToLive.class); - if (!ttl.unit().equals(TimeUnit.SECONDS)) { - return Optional.of(TimeUnit.SECONDS.convert(ttlPropertyValue, ttl.unit())); - } else { - return Optional.of(ttlPropertyValue); - } + TimeToLive ttl = fld.getAnnotation(TimeToLive.class); + if (!ttl.unit().equals(TimeUnit.SECONDS)) { + return Optional.of(TimeUnit.SECONDS.convert(ttlPropertyValue, ttl.unit())); + } else { + return Optional.of(ttlPropertyValue); } } catch (SecurityException | IllegalArgumentException e) { return Optional.empty(); } - } else if (settings != null && settings.getTimeToLive() != null && settings.getTimeToLive() > 0) { + } else if (settings.getTimeToLive() != null && settings.getTimeToLive() > 0) { return Optional.of(settings.getTimeToLive()); } } @@ -358,4 +360,144 @@ private Number getEntityVersion(String key, String versionProperty) { Long[] dbVersionArray = (Long[]) ops.get(key, type, Path.of("$." + versionProperty)); return dbVersionArray != null ? dbVersionArray[0] : null; } + + // ------------------------------------------------------------------------- + // Query By Example Fluent API - QueryByExampleExecutor + // ------------------------------------------------------------------------- + + @Override + public Optional findOne(Example example) { + return entityStream.of(example.getProbeType()).filter(example).findFirst(); + } + + @Override + public Iterable findAll(Example example) { + return entityStream.of(example.getProbeType()).filter(example).collect(Collectors.toList()); + } + + @Override + public Iterable findAll(Example example, Sort sort) { + return entityStream.of(example.getProbeType()).filter(example).sorted(sort).collect(Collectors.toList()); + } + + @Override + public Page findAll(Example example, Pageable pageable) { + return pageFromSlice(entityStream.of(example.getProbeType()).filter(example).getSlice(pageable)); + } + + @Override + public long count(Example example) { + return entityStream.of(example.getProbeType()).filter(example).count(); + } + + @Override + public boolean exists(Example example) { + return count(example) > 0; + } + + // ------------------------------------------------------------------------- + // Query By Example Fluent API - QueryByExampleExecutor + // ------------------------------------------------------------------------- + + @Override + public R findBy(Example example, Function, R> queryFunction) { + Assert.notNull(example, "Example must not be null"); + Assert.notNull(queryFunction, "Query function must not be null"); + + return queryFunction.apply(new FluentQueryByExample<>(example, example.getProbeType(), entityStream, getSearchOps())); + } + + private SearchOperations getSearchOps() { + String keyspace = indexer.getKeyspaceForEntityClass(metadata.getJavaType()); + Optional maybeSearchIndex = indexer.getIndexName(keyspace); + return modulesOperations.opsForSearch(maybeSearchIndex.get()); + } + +// static class FluentQueryByExample implements FluentQuery.FetchableFluentQuery { +// private final SearchStream searchStream; +// private final Class probeType; +// +// private final SearchOperations searchOps; +// +// public FluentQueryByExample( // +// Example example, // +// Class probeType, // +// EntityStream entityStream, // +// SearchOperations searchOps // +// ) { +// this.probeType = probeType; +// this.searchOps = searchOps; +// this.searchStream = entityStream.of(probeType); +// searchStream.filter(example); +// } +// +// @Override +// public FetchableFluentQuery sortBy(Sort sort) { +// searchStream.sorted(sort); +// return this; +// } +// +// @Override +// public FetchableFluentQuery as(Class resultType) { +// throw new UnsupportedOperationException("`as` is not supported on a Redis Repositories"); +// } +// +// @Override +// public FetchableFluentQuery project(Collection properties) { +// List> metamodelFields = MetamodelUtils.getMetamodelFieldsForProperties(probeType, +// properties); +// metamodelFields.forEach(mmf -> searchStream.project((MetamodelField) mmf)); +// return this; +// } +// +// @Override +// public T oneValue() { +// var result = searchStream.collect(Collectors.toList()); +// +// if (org.springframework.util.ObjectUtils.isEmpty(result)) { +// return null; +// } +// +// if (result.size() > 1) { +// throw new IncorrectResultSizeDataAccessException("Query returned non unique result", 1); +// } +// +// return result.iterator().next(); +// } +// +// @Override +// public T firstValue() { +// return searchStream.findFirst().orElse(null); +// } +// +// @Override +// public List all() { +// return searchStream.collect(Collectors.toList()); +// } +// +// @Override +// public Page page(Pageable pageable) { +// Query query = (searchStream.backingQuery().isBlank()) ? new Query() : new Query(searchStream.backingQuery()); +// query.limit(0, 0); +// SearchResult searchResult = searchOps.search(query); +// var count = searchResult.getTotalResults(); +// var pageContents = searchStream.limit(pageable.getPageSize()).skip(pageable.getOffset()).collect(Collectors.toList()); +// return new PageImpl<>(pageContents, pageable, count); +// } +// +// @Override +// public Stream stream() { +// return all().stream(); +// } +// +// @Override +// public long count() { +// return searchStream.count(); +// } +// +// @Override +// public boolean exists() { +// return searchStream.count() > 0; +// } +// } } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java index 4567c86ae..37f457347 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/repository/support/SimpleRedisEnhancedRepository.java @@ -11,6 +11,9 @@ import com.redis.om.spring.ops.RedisModulesOperations; import com.redis.om.spring.ops.search.SearchOperations; import com.redis.om.spring.repository.RedisEnhancedRepository; +import com.redis.om.spring.search.stream.EntityStream; +import com.redis.om.spring.search.stream.EntityStreamImpl; +import com.redis.om.spring.search.stream.FluentQueryByExample; import com.redis.om.spring.util.ObjectUtils; import com.redis.om.spring.vectorize.FeatureExtractor; import org.springframework.beans.factory.annotation.Qualifier; @@ -25,6 +28,7 @@ import org.springframework.data.redis.core.convert.RedisData; import org.springframework.data.redis.core.convert.ReferenceResolverImpl; import org.springframework.data.repository.core.EntityInformation; +import org.springframework.data.repository.query.FluentQuery.FetchableFluentQuery; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import redis.clients.jedis.Jedis; @@ -36,10 +40,12 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.StreamSupport; import static com.redis.om.spring.RedisOMProperties.MAX_SEARCH_RESULTS; +import static com.redis.om.spring.util.ObjectUtils.pageFromSlice; public class SimpleRedisEnhancedRepository extends SimpleKeyValueRepository implements RedisEnhancedRepository { @@ -56,6 +62,8 @@ public class SimpleRedisEnhancedRepository extends SimpleKeyValueReposito private final ULIDIdentifierGenerator generator; private final RedisOMProperties properties; + private final EntityStream entityStream; + @SuppressWarnings("unchecked") public SimpleRedisEnhancedRepository( // EntityInformation metadata, // @@ -77,6 +85,7 @@ public SimpleRedisEnhancedRepository( // this.auditor = new EntityAuditor(modulesOperations.getTemplate()); this.featureExtractor = featureExtractor; this.properties = properties; + this.entityStream = new EntityStreamImpl(modulesOperations, modulesOperations.getGsonBuilder(), indexer); } @SuppressWarnings("unchecked") @@ -269,4 +278,56 @@ public byte[] createKey(String keyspace, String id) { private boolean expires(RedisData data) { return data.getTimeToLive() != null && data.getTimeToLive() > 0L; } + + // ------------------------------------------------------------------------- + // Query By Example Fluent API - QueryByExampleExecutor + // ------------------------------------------------------------------------- + + @Override + public Optional findOne(Example example) { + return entityStream.of(example.getProbeType()).filter(example).findFirst(); + } + + @Override + public Iterable findAll(Example example) { + return entityStream.of(example.getProbeType()).filter(example).collect(Collectors.toList()); + } + + @Override + public Iterable findAll(Example example, Sort sort) { + return entityStream.of(example.getProbeType()).filter(example).sorted(sort).collect(Collectors.toList()); + } + + @Override + public Page findAll(Example example, Pageable pageable) { + return pageFromSlice(entityStream.of(example.getProbeType()).filter(example).getSlice(pageable)); + } + + @Override + public long count(Example example) { + return entityStream.of(example.getProbeType()).filter(example).count(); + } + + @Override + public boolean exists(Example example) { + return count(example) > 0; + } + + // ------------------------------------------------------------------------- + // Query By Example Fluent API - QueryByExampleExecutor + // ------------------------------------------------------------------------- + + @Override + public R findBy(Example example, Function, R> queryFunction) { + Assert.notNull(example, "Example must not be null"); + Assert.notNull(queryFunction, "Query function must not be null"); + + return queryFunction.apply(new FluentQueryByExample<>(example, example.getProbeType(), entityStream, getSearchOps())); + } + + private SearchOperations getSearchOps() { + String keyspace = indexer.getKeyspaceForEntityClass(metadata.getJavaType()); + Optional maybeSearchIndex = indexer.getIndexName(keyspace); + return modulesOperations.opsForSearch(maybeSearchIndex.get()); + } } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/EntityStreamImpl.java b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/EntityStreamImpl.java index d82af7c07..43d924823 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/EntityStreamImpl.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/EntityStreamImpl.java @@ -1,22 +1,27 @@ package com.redis.om.spring.search.stream; import com.google.gson.GsonBuilder; +import com.redis.om.spring.RediSearchIndexer; import com.redis.om.spring.ops.RedisModulesOperations; +import org.springframework.expression.spel.ast.Indexer; public class EntityStreamImpl implements EntityStream { private final RedisModulesOperations modulesOperations; private final GsonBuilder gsonBuilder; + private final RediSearchIndexer indexer; + @SuppressWarnings("unchecked") - public EntityStreamImpl(RedisModulesOperations rmo, GsonBuilder gsonBuilder) { + public EntityStreamImpl(RedisModulesOperations rmo, GsonBuilder gsonBuilder, RediSearchIndexer indexer) { this.modulesOperations = (RedisModulesOperations) rmo; this.gsonBuilder = gsonBuilder; + this.indexer = indexer; } @Override public SearchStream of(Class entityClass) { - return new SearchStreamImpl<>(entityClass, modulesOperations, gsonBuilder); + return new SearchStreamImpl<>(entityClass, modulesOperations, gsonBuilder, indexer); } } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/ExampleToNodeConverter.java b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/ExampleToNodeConverter.java new file mode 100644 index 000000000..c2af172de --- /dev/null +++ b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/ExampleToNodeConverter.java @@ -0,0 +1,244 @@ +package com.redis.om.spring.search.stream; + +import com.redis.om.spring.RediSearchIndexer; +import com.redis.om.spring.repository.query.QueryUtils; +import com.redis.om.spring.search.stream.predicates.jedis.JedisValues; +import com.redis.om.spring.util.ObjectUtils; +import org.springframework.data.domain.Example; +import org.springframework.data.geo.Point; +import redis.clients.jedis.search.Schema; +import redis.clients.jedis.search.querybuilder.Node; +import redis.clients.jedis.search.querybuilder.QueryBuilders; +import redis.clients.jedis.search.querybuilder.QueryNode; +import redis.clients.jedis.search.querybuilder.Values; + +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.Date; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static org.apache.commons.lang3.ObjectUtils.isNotEmpty; + +public class ExampleToNodeConverter { + + private final RediSearchIndexer indexer; + + public ExampleToNodeConverter(RediSearchIndexer indexer) { + this.indexer = indexer; + } + + private static final Pattern SCHEMA_FIELD_NAME_PATTERN = Pattern.compile("Field\\{name='(.*?)'"); + private static Optional getAliasForSchemaField(Schema.Field schemaField) { + Optional alias = Optional.empty(); + Matcher matcher = SCHEMA_FIELD_NAME_PATTERN.matcher(schemaField.toString()); + + if (matcher.find()) { + String name = matcher.group(1); + int aliasStart = name.indexOf("AS"); + if (aliasStart != -1) { + alias = Optional.of(name.substring(aliasStart + 3)); + } + } + + return alias; + } + + public Node processExample(Example example, Node rootNode) { + Class entityClass = example.getProbeType(); + final Schema schema = indexer.getSchemaFor(entityClass); + final boolean matchingAll = example.getMatcher().isAllMatching(); + Set toIgnore = example.getMatcher().getIgnoredPaths(); + + if (schema != null) { + for (Schema.Field schemaField : schema.fields) { + Optional maybeAlias = getAliasForSchemaField(schemaField); + final String fieldName = maybeAlias.orElseGet(() -> schemaField.name.replace("$.", "")); + + if (!toIgnore.contains(fieldName)) { + Object value = ObjectUtils.getValueByPath(example.getProbe(), schemaField.name); + + if (value != null) { + Class cls = value.getClass(); + switch (schemaField.type) { + // + // TAG Index Fields + // + case TAG -> { + if (Iterable.class.isAssignableFrom(value.getClass())) { + Iterable values = (Iterable) value; + values = StreamSupport.stream(values.spliterator(), false) // + .filter(Objects::nonNull).collect(Collectors.toList()); + if (values.iterator().hasNext()) { + QueryNode and = QueryBuilders.intersect(); + for (Object v : values) { + if (!v.toString().isBlank()) and.add(fieldName, "{" + v + "}"); + } + if (matchingAll) { + rootNode = QueryBuilders.intersect(rootNode, and); + } else { + rootNode = QueryBuilders.union(rootNode, and); + } + } + } else { + if (matchingAll) { + rootNode = QueryBuilders.intersect(rootNode).add(fieldName, "{" + value + "}"); + } else { + rootNode = QueryBuilders.union(rootNode).add(fieldName, "{" + value + "}"); + } + } + } + // + // TEXT Index Fields + // + case TEXT -> { + switch (example.getMatcher().getDefaultStringMatcher()) { + case DEFAULT, EXACT -> + rootNode = isNotEmpty(value) ? QueryBuilders.intersect(rootNode).add(fieldName, QueryUtils.escape(value.toString(), false)) : rootNode; + case STARTING -> + rootNode = isNotEmpty(value) ? QueryBuilders.intersect(rootNode).add(fieldName, QueryUtils.escape(value.toString(), false) + "*") : rootNode; + case ENDING -> + rootNode = isNotEmpty(value) ? QueryBuilders.intersect(rootNode).add(fieldName, "*" + QueryUtils.escape(value.toString(), false)) : rootNode; + case CONTAINING -> + rootNode = isNotEmpty(value) ? QueryBuilders.intersect(rootNode).add(fieldName, "*" + QueryUtils.escape(value.toString(), false) + "*") : rootNode; + case REGEX -> { + // NOT SUPPORTED + } + } + } + // + // GEO Index Fields + // + case GEO -> { + double x, y; + if (cls == Point.class) { + Point point = (Point) value; + x = point.getX(); + y = point.getY(); + if (matchingAll) { + rootNode = QueryBuilders.intersect(rootNode).add(fieldName, String.format("[%s %s 0.0001 mi]", x, y)); + } else { + rootNode = QueryBuilders.union(rootNode).add(fieldName, String.format("[%s %s 0.0001 mi]", x, y)); + } + } else if (CharSequence.class.isAssignableFrom(cls)) { + String[] coordinates = value.toString().split(","); + x = Double.parseDouble(coordinates[0]); + y = Double.parseDouble(coordinates[1]); + if (matchingAll) { + rootNode = QueryBuilders.intersect(rootNode).add(fieldName, String.format("[%s %s 0.0001 mi]", x, y)); + } else { + rootNode = QueryBuilders.union(rootNode).add(fieldName, String.format("[%s %s 0.0001 mi]", x, y)); + } + } + } + // + // NUMERIC + // + case NUMERIC -> { + if (Iterable.class.isAssignableFrom(value.getClass())) { + Iterable values = (Iterable) value; + values = StreamSupport.stream(values.spliterator(), false) // + .filter(Objects::nonNull).collect(Collectors.toList()); + + if (values.iterator().hasNext()) { + Class elementClass = values.iterator().next().getClass(); + QueryNode and = QueryBuilders.intersect(); + for (Object v : values) { + if (matchingAll) { + if (elementClass == LocalDate.class) { + and.add(QueryBuilders.intersect(rootNode).add(fieldName, JedisValues.eq((LocalDate) v))); + } else if (elementClass == Date.class) { + and.add(QueryBuilders.intersect(rootNode).add(fieldName, JedisValues.eq((Date) v))); + } else if (elementClass == LocalDateTime.class) { + and.add(QueryBuilders.intersect(rootNode).add(fieldName, JedisValues.eq((LocalDateTime) v))); + } else if (elementClass == Instant.class) { + and.add(QueryBuilders.intersect(rootNode).add(fieldName, JedisValues.eq((Instant) v))); + } else if (elementClass == Integer.class) { + and.add(QueryBuilders.intersect(rootNode).add(fieldName, Values.eq(Integer.parseInt(v.toString())))); + } else if (elementClass == Long.class) { + and.add(QueryBuilders.intersect(rootNode).add(fieldName, Values.eq(Long.parseLong(v.toString())))); + } else if (elementClass == Double.class) { + and.add(QueryBuilders.intersect(rootNode).add(fieldName, Values.eq(Double.parseDouble(v.toString())))); + } + } else { + if (elementClass == LocalDate.class) { + and.add(QueryBuilders.union(rootNode).add(fieldName, JedisValues.eq((LocalDate) v))); + } else if (elementClass == Date.class) { + and.add(QueryBuilders.union(rootNode).add(fieldName, JedisValues.eq((Date) v))); + } else if (elementClass == LocalDateTime.class) { + and.add(QueryBuilders.union(rootNode).add(fieldName, JedisValues.eq((LocalDateTime) v))); + } else if (elementClass == Instant.class) { + and.add(QueryBuilders.union(rootNode).add(fieldName, JedisValues.eq((Instant) v))); + } else if (elementClass == Integer.class) { + and.add(QueryBuilders.union(rootNode).add(fieldName, Values.eq(Integer.parseInt(v.toString())))); + } else if (elementClass == Long.class) { + and.add(QueryBuilders.union(rootNode).add(fieldName, Values.eq(Long.parseLong(v.toString())))); + } else if (elementClass == Double.class) { + and.add(QueryBuilders.union(rootNode).add(fieldName, Values.eq(Double.parseDouble(v.toString())))); + } + } + } + if (matchingAll) { + rootNode = QueryBuilders.intersect(rootNode, and); + } else { + rootNode = QueryBuilders.union(rootNode, and); + } + } + } else { + if (matchingAll) { + if (cls == LocalDate.class) { + rootNode = QueryBuilders.intersect(rootNode).add(fieldName, JedisValues.eq((LocalDate) value)); + } else if (cls == Date.class) { + rootNode = QueryBuilders.intersect(rootNode).add(fieldName, JedisValues.eq((Date) value)); + } else if (cls == LocalDateTime.class) { + rootNode = QueryBuilders.intersect(rootNode).add(fieldName, JedisValues.eq((LocalDateTime) value)); + } else if (cls == Instant.class) { + rootNode = QueryBuilders.intersect(rootNode).add(fieldName, JedisValues.eq((Instant) value)); + } else if (cls == Integer.class) { + rootNode = QueryBuilders.intersect(rootNode).add(fieldName, Values.eq(Integer.parseInt(value.toString()))); + } else if (cls == Long.class) { + rootNode = QueryBuilders.intersect(rootNode).add(fieldName, Values.eq(Long.parseLong(value.toString()))); + } else if (cls == Double.class) { + rootNode = QueryBuilders.intersect(rootNode).add(fieldName, Values.eq(Double.parseDouble(value.toString()))); + } + } else { + if (cls == LocalDate.class) { + rootNode = QueryBuilders.union(rootNode).add(fieldName, JedisValues.eq((LocalDate) value)); + } else if (cls == Date.class) { + rootNode = QueryBuilders.union(rootNode).add(fieldName, JedisValues.eq((Date) value)); + } else if (cls == LocalDateTime.class) { + rootNode = QueryBuilders.union(rootNode).add(fieldName, JedisValues.eq((LocalDateTime) value)); + } else if (cls == Instant.class) { + rootNode = QueryBuilders.union(rootNode).add(fieldName, JedisValues.eq((Instant) value)); + } else if (cls == Integer.class) { + rootNode = QueryBuilders.union(rootNode).add(fieldName, Values.eq(Integer.parseInt(value.toString()))); + } else if (cls == Long.class) { + rootNode = QueryBuilders.union(rootNode).add(fieldName, Values.eq(Long.parseLong(value.toString()))); + } else if (cls == Double.class) { + rootNode = QueryBuilders.union(rootNode).add(fieldName, Values.eq(Double.parseDouble(value.toString()))); + } + } + } + } + // + // VECTOR + // + case VECTOR -> { + //TODO: pending - whether to support Vector fields in QBE + } + } + } + } + } + } + + return rootNode; + } + +} diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/FluentQueryByExample.java b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/FluentQueryByExample.java new file mode 100644 index 000000000..4f2917d4a --- /dev/null +++ b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/FluentQueryByExample.java @@ -0,0 +1,103 @@ +package com.redis.om.spring.search.stream; + +import com.redis.om.spring.metamodel.MetamodelField; +import com.redis.om.spring.metamodel.MetamodelUtils; +import com.redis.om.spring.ops.search.SearchOperations; +import org.springframework.dao.IncorrectResultSizeDataAccessException; +import org.springframework.data.domain.*; +import org.springframework.data.repository.query.FluentQuery; +import redis.clients.jedis.search.Query; +import redis.clients.jedis.search.SearchResult; + +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class FluentQueryByExample implements FluentQuery.FetchableFluentQuery { + private final SearchStream searchStream; + private final Class probeType; + + private final SearchOperations searchOps; + + public FluentQueryByExample( // + Example example, // + Class probeType, // + EntityStream entityStream, // + SearchOperations searchOps // + ) { + this.probeType = probeType; + this.searchOps = searchOps; + this.searchStream = entityStream.of(probeType); + searchStream.filter(example); + } + + @Override + public FetchableFluentQuery sortBy(Sort sort) { + searchStream.sorted(sort); + return this; + } + + @Override + public FetchableFluentQuery as(Class resultType) { + throw new UnsupportedOperationException("`as` is not supported on a Redis Repositories"); + } + + @Override + public FetchableFluentQuery project(Collection properties) { + List> metamodelFields = MetamodelUtils.getMetamodelFieldsForProperties(probeType, + properties); + metamodelFields.forEach(mmf -> searchStream.project((MetamodelField) mmf)); + return this; + } + + @Override + public T oneValue() { + var result = searchStream.collect(Collectors.toList()); + + if (org.springframework.util.ObjectUtils.isEmpty(result)) { + return null; + } + + if (result.size() > 1) { + throw new IncorrectResultSizeDataAccessException("Query returned non unique result", 1); + } + + return result.iterator().next(); + } + + @Override + public T firstValue() { + return searchStream.findFirst().orElse(null); + } + + @Override + public List all() { + return searchStream.collect(Collectors.toList()); + } + + @Override + public Page page(Pageable pageable) { + Query query = (searchStream.backingQuery().isBlank()) ? new Query() : new Query(searchStream.backingQuery()); + query.limit(0, 0); + SearchResult searchResult = searchOps.search(query); + var count = searchResult.getTotalResults(); + var pageContents = searchStream.limit(pageable.getPageSize()).skip(pageable.getOffset()).collect(Collectors.toList()); + return new PageImpl<>(pageContents, pageable, count); + } + + @Override + public Stream stream() { + return all().stream(); + } + + @Override + public long count() { + return searchStream.count(); + } + + @Override + public boolean exists() { + return searchStream.count() > 0; + } +} diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/ReturnFieldsSearchStreamImpl.java b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/ReturnFieldsSearchStreamImpl.java index 31a1638c6..6f774889c 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/ReturnFieldsSearchStreamImpl.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/ReturnFieldsSearchStreamImpl.java @@ -7,13 +7,15 @@ import com.redis.om.spring.search.stream.predicates.SearchFieldPredicate; import com.redis.om.spring.tuple.Tuple; import com.redis.om.spring.tuple.Tuples; +import com.redis.om.spring.util.SearchResultRawResponseToObjectConverter; import com.redis.om.spring.util.ObjectUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.data.annotation.Id; +import org.springframework.data.domain.Example; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Slice; -import org.springframework.data.geo.Point; +import org.springframework.data.domain.Sort; import org.springframework.util.ReflectionUtils; import redis.clients.jedis.search.Document; import redis.clients.jedis.search.Query; @@ -28,8 +30,6 @@ import java.util.function.*; import java.util.stream.*; -import static com.redis.om.spring.util.ObjectUtils.isPrimitiveOfType; - public class ReturnFieldsSearchStreamImpl implements SearchStream { @SuppressWarnings("unused") @@ -112,6 +112,11 @@ public SearchStream filter(String freeText) { throw new UnsupportedOperationException("Filter on free text predicate is not supported on mapped stream"); } + @Override + public SearchStream filter(Example example) { + throw new UnsupportedOperationException("Filter on Example predicate is not supported on mapped stream"); + } + @Override public SearchStream map(Function mapper) { return new WrapperSearchStream<>(resolveStream()).map(mapper); @@ -162,6 +167,11 @@ public SearchStream sorted(Comparator comparator, SortOrder order) return sorted(comparator); } + @Override + public SearchStream sorted(Sort sort) { + throw new UnsupportedOperationException("sorted(Sort) is not supported on a ReturnFieldSearchStream"); + } + @Override public SearchStream peek(Consumer action) { return new WrapperSearchStream<>(resolveStream().peek(action)); @@ -282,7 +292,12 @@ private Stream resolveStream() { if (resultSetHasNonIndexedFields) { SearchResult searchResult = entitySearchStream.getOps().search(query); - List entities = searchResult.getDocuments().stream().map(d -> gson.fromJson(SafeEncoder.encode((byte[])d.get("$")), entitySearchStream.getEntityClass())).toList(); + List entities = searchResult + .getDocuments() // + .stream() // + .map(d -> { // + return gson.fromJson(SafeEncoder.encode((byte[])d.get("$")), entitySearchStream.getEntityClass()); // + }).toList(); results = toResultTuple(entities, returnFields); @@ -309,28 +324,7 @@ private List toResultTuple(SearchResult searchResult, String[] returnFields) String field = foi.getSearchAlias(); Class targetClass = foi.getTargetClass(); var rawValue = props.get(ObjectUtils.isCollection(targetClass) ? "$." + field : field); - Object value = rawValue != null ? SafeEncoder.encode((byte[])rawValue) : null; - - if (value != null) { - if (targetClass == Date.class) { - mappedResults.add(new Date(Long.parseLong(value.toString()))); - } else if (targetClass == Point.class) { - StringTokenizer st = new StringTokenizer(value.toString(), ","); - String lon = st.nextToken(); - String lat = st.nextToken(); - - mappedResults.add(new Point(Double.parseDouble(lon), Double.parseDouble(lat))); - } else if (targetClass == String.class) { - mappedResults.add(value.toString()); - } else if (targetClass == Boolean.class || isPrimitiveOfType(targetClass, Boolean.class)) { - mappedResults.add(value.toString().equals("1")); - } else { - mappedResults.add(gson.fromJson(value.toString(), targetClass)); - } - } else { - mappedResults.add(null); - } - + mappedResults.add(SearchResultRawResponseToObjectConverter.process(rawValue, targetClass, gson)); }); if (returning.size() > 1) { @@ -430,4 +424,20 @@ public Slice getSlice(Pageable pageable) { throw new UnsupportedOperationException("getPage is not supported on a ReturnFieldSearchStream"); } + @Override + public SearchStream project(Function field) { + throw new UnsupportedOperationException("project is not supported on a ReturnFieldSearchStream"); + } + + @SafeVarargs + @Override + public final SearchStream project(MetamodelField... field) { + throw new UnsupportedOperationException("project is not supported on a ReturnFieldSearchStream"); + } + + @Override + public String backingQuery() { + return entitySearchStream.backingQuery(); + } + } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStream.java b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStream.java index c3e07276c..adfd3e481 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStream.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStream.java @@ -4,8 +4,10 @@ import com.redis.om.spring.metamodel.indexed.NumericField; import com.redis.om.spring.ops.search.SearchOperations; import com.redis.om.spring.search.stream.predicates.SearchFieldPredicate; +import org.springframework.data.domain.Example; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Slice; +import org.springframework.data.domain.Sort; import redis.clients.jedis.search.aggr.SortedField.SortOrder; import java.time.Duration; @@ -23,6 +25,8 @@ public interface SearchStream extends BaseStream> { SearchStream filter(String freeText); + SearchStream filter(Example example); + SearchStream map(Function field); Stream map(ToLongFunction mapper); @@ -44,6 +48,7 @@ public interface SearchStream extends BaseStream> { SearchStream sorted(Comparator comparator); SearchStream sorted(Comparator comparator, SortOrder order); + SearchStream sorted(Sort sort); SearchStream peek(Consumer action); @@ -108,4 +113,10 @@ public interface SearchStream extends BaseStream> { SearchOperations getSearchOperations(); Slice getSlice(Pageable pageable); + + SearchStream project(Function field); + @SuppressWarnings("unchecked") + SearchStream project(MetamodelField ...field); + + String backingQuery(); } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStreamImpl.java b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStreamImpl.java index eb65baf7f..e58853a2c 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStreamImpl.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/SearchStreamImpl.java @@ -2,6 +2,7 @@ import com.google.gson.Gson; import com.google.gson.GsonBuilder; +import com.redis.om.spring.RediSearchIndexer; import com.redis.om.spring.annotations.Document; import com.redis.om.spring.convert.MappingRedisOMConverter; import com.redis.om.spring.metamodel.MetamodelField; @@ -15,12 +16,12 @@ import com.redis.om.spring.tuple.AbstractTupleMapper; import com.redis.om.spring.tuple.Pair; import com.redis.om.spring.tuple.TupleMapper; +import com.redis.om.spring.util.SearchResultRawResponseToObjectConverter; import com.redis.om.spring.util.ObjectUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.data.domain.Page; -import org.springframework.data.domain.Pageable; -import org.springframework.data.domain.Slice; +import org.springframework.beans.BeanUtils; +import org.springframework.data.domain.*; import org.springframework.data.domain.Sort.Order; import org.springframework.data.redis.core.convert.ReferenceResolverImpl; import redis.clients.jedis.search.Query; @@ -36,9 +37,13 @@ import java.lang.reflect.Method; import java.time.Duration; import java.util.*; +import java.util.Map.Entry; import java.util.function.*; import java.util.stream.*; +import static com.redis.om.spring.metamodel.MetamodelUtils.getMetamodelForIdField; +import static java.util.stream.Collectors.toCollection; + public class SearchStreamImpl implements SearchStream { @SuppressWarnings("unused") @@ -67,7 +72,12 @@ public class SearchStreamImpl implements SearchStream { private final MappingRedisOMConverter mappingConverter; private int dialect = 1; - public SearchStreamImpl(Class entityClass, RedisModulesOperations modulesOperations, GsonBuilder gsonBuilder) { + private final List> projections = new ArrayList<>(); + + private final ExampleToNodeConverter exampleToNodeConverter; + + public SearchStreamImpl(Class entityClass, RedisModulesOperations modulesOperations, GsonBuilder gsonBuilder, + RediSearchIndexer indexer) { this.modulesOperations = modulesOperations; this.entityClass = entityClass; this.searchIndex = entityClass.getName() + "Idx"; @@ -83,6 +93,7 @@ public SearchStreamImpl(Class entityClass, RedisModulesOperations mod this.isDocument = entityClass.isAnnotationPresent(Document.class); this.mappingConverter = new MappingRedisOMConverter(null, new ReferenceResolverImpl(modulesOperations.getTemplate())); + this.exampleToNodeConverter = new ExampleToNodeConverter<>(indexer); } @Override @@ -113,7 +124,7 @@ public String toString() { public String toString(Parenthesize mode) { return switch(mode) { case NEVER -> toString(); - case ALWAYS, DEFAULT -> String.format("(%s)", toString()); + case ALWAYS, DEFAULT -> String.format("(%s)", this); }; } }; @@ -121,13 +132,18 @@ public String toString(Parenthesize mode) { return this; } + @Override + public SearchStream filter(Example example) { + rootNode = exampleToNodeConverter.processExample(example, rootNode); + return this; + } + public Node processPredicate(SearchFieldPredicate predicate) { return predicate.apply(rootNode); } private Node processPredicate(Predicate predicate) { if (SearchFieldPredicate.class.isAssignableFrom(predicate.getClass())) { - @SuppressWarnings("unchecked") SearchFieldPredicate p = (SearchFieldPredicate) predicate; return processPredicate(p); } @@ -218,6 +234,16 @@ public SearchStream sorted(Comparator comparator, SortOrder order) return this; } + @Override + public SearchStream sorted(Sort sort) { + Optional maybeOrder = sort.stream().sorted().findFirst(); + if (maybeOrder.isPresent()) { + Order order = maybeOrder.get(); + sortBy = new SortedField(order.getProperty(), order.isAscending() ? SortOrder.ASC : SortOrder.DESC); + } + return this; + } + @Override public SearchStream peek(Consumer action) { return new WrapperSearchStream<>(resolveStream().peek(action)); @@ -401,6 +427,13 @@ Query prepareQuery() { if (onlyIds) { query.returnFields(idField.getName()); + } else if (!projections.isEmpty()) { + var returnFields = projections.stream() // + .map(foi -> ObjectUtils.isCollection(foi.getTargetClass()) ? "$." + foi.getSearchAlias() : foi.getSearchAlias()) + .collect(toCollection(ArrayList::new)); + returnFields.add(idField.getName()); + + query.returnFields(returnFields.toArray(String[]::new)); } return query; @@ -411,11 +444,40 @@ private SearchResult executeQuery() { } private List toEntityList(SearchResult searchResult) { - Gson g = getGson(); - if (isDocument) { - return searchResult.getDocuments().stream().map(d -> g.fromJson(SafeEncoder.encode((byte[])d.get("$")), entityClass)).toList(); + if (projections.isEmpty()) { + if (isDocument) { + Gson g = getGson(); + return searchResult.getDocuments().stream() + .map(d -> g.fromJson(SafeEncoder.encode((byte[]) d.get("$")), entityClass)).toList(); + } else { + return searchResult.getDocuments().stream() + .map(d -> (E) ObjectUtils.documentToObject(d, entityClass, mappingConverter)).toList(); + } } else { - return searchResult.getDocuments().stream().map(d -> (E)ObjectUtils.documentToObject(d, entityClass, mappingConverter)).toList(); + List projectedEntities = new ArrayList<>(); + searchResult.getDocuments().forEach(doc -> { + Map props = StreamSupport.stream(doc.getProperties().spliterator(), false) + .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); + + E entity = BeanUtils.instantiateClass(this.entityClass); + projections.forEach(foi -> { + String field = foi.getSearchAlias(); + Class targetClass = foi.getTargetClass(); + + var rawValue = props.get(ObjectUtils.isCollection(targetClass) ? "$." + field : field); + Object processValue = SearchResultRawResponseToObjectConverter.process(rawValue, targetClass, getGson()); + + if (processValue != null) { + try { + foi.getSearchFieldAccessor().getField().set(entity, processValue); + } catch (IllegalAccessException e) { + logger.debug("🧨 couldn't set value on " + field, e); + } + } + }); + projectedEntities.add(entity); + }); + return projectedEntities; } } @@ -520,7 +582,7 @@ public Optional min(NumericField field) { .limit(1) // .toList(String.class, Double.class); - return minByField.isEmpty() ? Optional.empty() : Optional.of(json.get(minByField.get(0).getFirst(), entityClass)); + return minByField.isEmpty() ? Optional.empty() : Optional.ofNullable(json.get(minByField.get(0).getFirst(), entityClass)); } @Override @@ -531,7 +593,7 @@ public Optional max(NumericField field) { .limit(1) // .toList(String.class, Double.class); - return maxByField.isEmpty() ? Optional.empty() : Optional.of(json.get(maxByField.get(0).getFirst(), entityClass)); + return maxByField.isEmpty() ? Optional.empty() : Optional.ofNullable(json.get(maxByField.get(0).getFirst(), entityClass)); } @Override public SearchStream dialect(int dialect) { @@ -554,14 +616,49 @@ public Slice getSlice(Pageable pageable) { } } + @Override + public SearchStream project(Function field) { + if (MetamodelField.class.isAssignableFrom(field.getClass())) { + @SuppressWarnings("unchecked") + MetamodelField foi = (MetamodelField) field; + + projections.add(foi); + } else if (TupleMapper.class.isAssignableFrom(field.getClass())) { + @SuppressWarnings("rawtypes") + AbstractTupleMapper tm = (AbstractTupleMapper) field; + + IntStream.range(0, tm.degree()).forEach(i -> { + @SuppressWarnings("unchecked") + MetamodelField foi = (MetamodelField) tm.get(i); + projections.add(foi); + }); + } + projections.add((MetamodelField) getMetamodelForIdField(this.entityClass)); + return this; + } + + @SuppressWarnings("unchecked") + @Override + public SearchStream project(MetamodelField... fields) { + for (MetamodelField field: fields) { + projections.add((MetamodelField) field); + } + return this; + } + + @Override + public String backingQuery() { + return rootNode.toString(); + } + public boolean isDocument() { return isDocument; } + private Gson getGson() { if (gson == null) { gson = gsonBuilder.create(); } return gson; } - } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/WrapperSearchStream.java b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/WrapperSearchStream.java index 54bd1d082..6a543785f 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/WrapperSearchStream.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/search/stream/WrapperSearchStream.java @@ -4,8 +4,10 @@ import com.redis.om.spring.metamodel.indexed.NumericField; import com.redis.om.spring.ops.search.SearchOperations; import com.redis.om.spring.search.stream.predicates.SearchFieldPredicate; +import org.springframework.data.domain.Example; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Slice; +import org.springframework.data.domain.Sort; import redis.clients.jedis.search.aggr.SortedField.SortOrder; import java.time.Duration; @@ -86,6 +88,12 @@ public SearchStream filter(String freeText) { return this; } + @Override + public SearchStream filter(Example example) { + // NO-OP + return this; + } + @Override public SearchStream map(Function mapper) { return new WrapperSearchStream<>(backingStream.map(mapper)); @@ -136,6 +144,11 @@ public SearchStream sorted(Comparator comparator, SortOrder order) return new WrapperSearchStream<>(backingStream.sorted(comparator)); } + @Override + public SearchStream sorted(Sort sort) { + throw new UnsupportedOperationException("sorted(Sort) is not supported on a WrappedSearchStream"); + } + @Override public SearchStream peek(Consumer action) { return new WrapperSearchStream<>(backingStream.peek(action)); @@ -295,4 +308,20 @@ public Slice getSlice(Pageable pageable) { throw new UnsupportedOperationException("getPage is not supported on a WrappedSearchStream"); } + @Override + public SearchStream project(Function field) { + throw new UnsupportedOperationException("project is not supported on a WrappedSearchStream"); + } + + @SafeVarargs + @Override + public final SearchStream project(MetamodelField... field) { + throw new UnsupportedOperationException("project is not supported on a WrappedSearchStream"); + } + + @Override + public String backingQuery() { + throw new UnsupportedOperationException("backingQuery is not supported on a WrappedSearchStream"); + } + } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java b/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java index 6a8ccd533..e56afce6b 100644 --- a/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java +++ b/redis-om-spring/src/main/java/com/redis/om/spring/util/ObjectUtils.java @@ -11,11 +11,18 @@ import org.springframework.core.ResolvableType; import org.springframework.core.type.filter.AnnotationTypeFilter; import org.springframework.data.annotation.Id; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Slice; +import org.springframework.data.domain.Sort; import org.springframework.data.geo.Distance; import org.springframework.data.redis.connection.RedisGeoCommands.DistanceUnit; import org.springframework.data.redis.core.convert.Bucket; import org.springframework.data.redis.core.convert.RedisData; import org.springframework.data.util.Pair; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.util.ReflectionUtils; import redis.clients.jedis.args.GeoUnit; import redis.clients.jedis.search.Document; @@ -25,10 +32,14 @@ import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.math.BigDecimal; +import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.*; import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static java.util.Objects.requireNonNull; import static org.springframework.util.ClassUtils.resolvePrimitiveIfNecessary; @@ -127,7 +138,7 @@ public static Object getIdFieldForEntity(Object entity) { public static Object getIdFieldForEntity(Field idField, Object entity) { String getterName = "get" + ObjectUtils.ucfirst(idField.getName()); Method getter = ReflectionUtils.findMethod(entity.getClass(), getterName); - return ReflectionUtils.invokeMethod(getter, entity); + return ReflectionUtils.invokeMethod(requireNonNull(getter), entity); } public static Method getGetterForField(Class cls, Field field) { @@ -140,6 +151,12 @@ public static Method getSetterForField(Class cls, Field field) { return ReflectionUtils.findMethod(cls, setterName, field.getType()); } + public static Object getValueForField(Field field, Object entity) { + String getterName = "get" + ObjectUtils.ucfirst(field.getName()); + Method getter = ReflectionUtils.findMethod(entity.getClass(), getterName); + return getter != null ? ReflectionUtils.invokeMethod(getter, entity) : null; + } + /** * Returns the specified text but with the first character uppercase. * @@ -397,7 +414,7 @@ public static byte[] longArrayToByteArray(long[] input) { return floatArrayToByteArray(floats); } - public static Collection instantiateCollection(Type type) { + public static Collection instantiateCollection(Type type) { Class rawType = (Class) ((ParameterizedType) type).getRawType(); if (rawType.isInterface()) { if (List.class.isAssignableFrom(rawType)) { @@ -427,6 +444,241 @@ public static String getKey(String keyspace, Object id) { return String.format(format, keyspace, id); } + public static Page pageFromSlice(Slice slice) { + return new Page<>() { + @Override + public int getTotalPages() { + return -1; + } + + @Override + public long getTotalElements() { + return -1; + } + + @Override + public Page map(Function converter) { + return pageFromSlice(slice.map(converter)); + } + + @Override + public int getNumber() { + return slice.getNumber(); + } + + @Override + public int getSize() { + return slice.getSize(); + } + + @Override + public int getNumberOfElements() { + return slice.getNumberOfElements(); + } + + @Override + public List getContent() { + return slice.getContent(); + } + + @Override + public boolean hasContent() { + return slice.hasContent(); + } + + @Override + public Sort getSort() { + return slice.getSort(); + } + + @Override + public boolean isFirst() { + return slice.isFirst(); + } + + @Override + public boolean isLast() { + return slice.isLast(); + } + + @Override + public boolean hasNext() { + return slice.hasNext(); + } + + @Override + public boolean hasPrevious() { + return slice.hasPrevious(); + } + + @Override + public Pageable nextPageable() { + return slice.nextPageable(); + } + + @Override + public Pageable previousPageable() { + return slice.previousPageable(); + } + + @Override + public Iterator iterator() { + return slice.iterator(); + } + }; + } + + private static final ExpressionParser SPEL_EXPRESSION_PARSER = new SpelExpressionParser(); + + public static Object getValueByPath(Object target, String path) { + // Remove JSONPath prefix + String safeSpelPath = path.replace("$.", ""); + // does the expression have any arrays + boolean hasNestedObject = path.contains("[0:]"); + + Object value = null; + + if (!hasNestedObject) { + safeSpelPath = safeSpelPath // + .replace(".", "?.") // + .replace("[*]", "") // + .replace(".", "?."); + + value = SPEL_EXPRESSION_PARSER.parseExpression(safeSpelPath).getValue(target); + } else { + String[] tempParts = safeSpelPath.split("\\[0:\\]", 2); + String[] parts = tempParts[1].split("\\.", 2); + String leftPath = tempParts[0].replace(".", "?."); + String rightPath = parts[1].replace(".", "?."); + + Expression leftExp = SPEL_EXPRESSION_PARSER.parseExpression(leftPath); + Expression rightExp = SPEL_EXPRESSION_PARSER.parseExpression(rightPath); + Collection left = (Collection) leftExp.getValue(target); + if (left != null && !left.isEmpty()) { + value = flattenCollection(left.stream().map(rightExp::getValue).toList()); + } + } + + return value; + } + + public static Collection flattenCollection(Collection inputCollection) { + List flatList = new ArrayList<>(); + + for (Object element : inputCollection) { + if (element instanceof Collection) { + flatList.addAll(flattenCollection((Collection) element)); + } else { + flatList.add(element); + } + } + + return flatList; + } + + public static String replaceIfIllegalJavaIdentifierCharacter(final String word) { + requireNonNull(word); + if (word.isEmpty()) { + return REPLACEMENT_CHARACTER.toString(); // No name is translated to REPLACEMENT_CHARACTER only + } + final StringBuilder sb = new StringBuilder(); + for (int i = 0; i < word.length(); i++) { + char c = word.charAt(i); + if (i == 0) { + if (Character.isJavaIdentifierStart(c)) { + // Fine! Just add the first character + sb.append(c); + } else if (Character.isJavaIdentifierPart(c)) { + // Not ok as the first, but ok otherwise. Add the replacement before it + sb.append(REPLACEMENT_CHARACTER).append(c); + } else { + // Cannot be used as a java identifier. Replace it + sb.append(REPLACEMENT_CHARACTER); + } + } else if (Character.isJavaIdentifierPart(c)) { + // Fine! Just add it + sb.append(c); + } else { + // Cannot be used as a java identifier. Replace it + sb.append(REPLACEMENT_CHARACTER); + } + + } + return sb.toString(); + } + + static final Set JAVA_LITERAL_WORDS = Set.of("true", "false", "null"); + + // Java reserved keywords + static final Set JAVA_RESERVED_WORDS = Collections.unmodifiableSet(Stream.of( + // Unused + "const", "goto", + // The real ones... + "abstract", "continue", "for", "new", "switch", "assert", "default", "goto", "package", "synchronized", "boolean", + "do", "if", "private", "this", "break", "double", "implements", "protected", "throw", "byte", "else", "import", + "public", "throws", "case", "enum", "instanceof", "return", "transient", "catch", "extends", "int", "short", + "try", "char", "final", "interface", "static", "void", "class", "finally", "long", "strictfp", "volatile", + "const", "float", "native", "super", "while").collect(Collectors.toSet())); + + static final Set> JAVA_BUILT_IN_CLASSES = Set.of(Boolean.class, Byte.class, Character.class, Double.class, + Float.class, Integer.class, Long.class, Object.class, Short.class, String.class, BigDecimal.class, + BigInteger.class, boolean.class, byte.class, char.class, double.class, float.class, int.class, long.class, + short.class); + + private static final Set JAVA_BUILT_IN_CLASS_WORDS = Collections + .unmodifiableSet(JAVA_BUILT_IN_CLASSES.stream().map(Class::getSimpleName).collect(Collectors.toSet())); + + private static final Set JAVA_USED_WORDS = Collections + .unmodifiableSet(Stream.of(JAVA_LITERAL_WORDS, JAVA_RESERVED_WORDS, JAVA_BUILT_IN_CLASS_WORDS) + .flatMap(Collection::stream).collect(Collectors.toSet())); + + private static final Set JAVA_USED_WORDS_LOWER_CASE = Collections + .unmodifiableSet(JAVA_USED_WORDS.stream().map(String::toLowerCase).collect(Collectors.toSet())); + + /** + * Returns a static field name representation of the specified camel-cased + * string. + * + * @param externalName the string + * @return the static field name representation + */ + public static String staticField(final String externalName) { + requireNonNull(externalName); + return ObjectUtils.toUnderscoreSeparated(javaNameFromExternal(externalName)).toUpperCase(); + } + + public static String javaNameFromExternal(final String externalName) { + requireNonNull(externalName); + return ObjectUtils + .replaceIfIllegalJavaIdentifierCharacter(replaceIfJavaUsedWord(nameFromExternal(externalName))); + } + + public static String nameFromExternal(final String externalName) { + requireNonNull(externalName); + String result = ObjectUtils.unQuote(externalName.trim()); // Trim if there are initial spaces or trailing spaces... + /* CamelCase + * http://stackoverflow.com/questions/4050381/regular-expression-for-checking-if + * -capital-letters-are-found-consecutively-in-a [A-Z] -> \p{Lu} [^A-Za-z0-9] -> + * [^\pL0-90-9] */ + result = Stream.of(result.replaceAll("(\\p{Lu}+)", "_$1").split("[^\\pL\\d]")).map(String::toLowerCase) + .map(ObjectUtils::ucfirst).collect(Collectors.joining()); + return result; + } + + public static String replaceIfJavaUsedWord(final String word) { + requireNonNull(word); + // We need to replace regardless of case because we do not know how the returned + // string is to be used + if (JAVA_USED_WORDS_LOWER_CASE.contains(word.toLowerCase())) { + // If it is a java reserved/literal/class, add a "_" at the end to avoid naming + // conflicts + return word + "_"; + } + return word; + } + + public static final Character REPLACEMENT_CHARACTER = '_'; + private ObjectUtils() { } } diff --git a/redis-om-spring/src/main/java/com/redis/om/spring/util/SearchResultRawResponseToObjectConverter.java b/redis-om-spring/src/main/java/com/redis/om/spring/util/SearchResultRawResponseToObjectConverter.java new file mode 100644 index 000000000..c524a3081 --- /dev/null +++ b/redis-om-spring/src/main/java/com/redis/om/spring/util/SearchResultRawResponseToObjectConverter.java @@ -0,0 +1,36 @@ +package com.redis.om.spring.util; + +import com.google.gson.Gson; +import org.springframework.data.geo.Point; +import redis.clients.jedis.util.SafeEncoder; + +import java.util.Date; +import java.util.StringTokenizer; + +import static com.redis.om.spring.util.ObjectUtils.isPrimitiveOfType; + +public class SearchResultRawResponseToObjectConverter { + public static Object process(Object rawValue, Class targetClass, Gson gson) { + Object value = rawValue != null ? SafeEncoder.encode((byte[]) rawValue) : null; + + Object processValue = null; + if (value != null) { + if (targetClass == Date.class) { + processValue = new Date(Long.parseLong(value.toString())); + } else if (targetClass == Point.class) { + StringTokenizer st = new StringTokenizer(value.toString(), ","); + String lon = st.nextToken(); + String lat = st.nextToken(); + + processValue = new Point(Double.parseDouble(lon), Double.parseDouble(lat)); + } else if (targetClass == String.class) { + processValue = value.toString(); + } else if (targetClass == Boolean.class || isPrimitiveOfType(targetClass, Boolean.class)) { + processValue = value.toString().equals("1"); + } else { + processValue = gson.fromJson(value.toString(), targetClass); + } + } + return processValue; + } +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/RedisDocumentQueryByExampleTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/RedisDocumentQueryByExampleTest.java new file mode 100644 index 000000000..72cc59d14 --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/RedisDocumentQueryByExampleTest.java @@ -0,0 +1,506 @@ +package com.redis.om.spring.annotations.document; + +import com.redis.om.spring.AbstractBaseDocumentTest; +import com.redis.om.spring.annotations.document.fixtures.*; +import com.redis.om.spring.search.stream.EntityStream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.dao.IncorrectResultSizeDataAccessException; +import org.springframework.data.domain.*; +import org.springframework.data.domain.ExampleMatcher.StringMatcher; +import org.springframework.data.geo.Point; +import org.springframework.data.repository.query.FluentQuery; +import org.springframework.data.repository.query.FluentQuery.FetchableFluentQuery; + +import java.time.LocalDate; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.junit.jupiter.api.Assertions.assertAll; + +public class RedisDocumentQueryByExampleTest extends AbstractBaseDocumentTest { + @Autowired + MyDocRepository repository; + + @Autowired + EntityStream entityStream; + + @Autowired + CompanyRepository companyRepository; + + String id1; + String id2; + + @BeforeEach + void loadTestData() { + repository.deleteAll(); + Point point1 = new Point(-122.124500, 47.640160); + MyDoc doc1 = MyDoc.of("hello world", point1, point1, 1); + doc1.setTag(Set.of("news", "article")); + + Point point2 = new Point(-122.066540, 37.377690); + MyDoc doc2 = MyDoc.of("hello mundo", point2, point2, 2); + doc2.setTag(Set.of("noticias", "articulo")); + + Point point3 = new Point(-122.066540, 37.377690); + MyDoc doc3 = MyDoc.of("ola mundo", point3, point3, 3); + doc3.setTag(Set.of("noticias", "artigo")); + + Point point4 = new Point(-122.066540, 37.377690); + MyDoc doc4 = MyDoc.of("bonjour le monde", point4, point4, 3); + doc4.setTag(Set.of("actualite", "article")); + + repository.saveAll(List.of(doc1, doc2, doc3, doc4)); + + id1 = doc1.getId(); + id2 = doc2.getId(); + + companyRepository.deleteAll(); + Company redis = Company.of("RedisInc", 2011, LocalDate.of(2021, 5, 1), new Point(-122.066540, 37.377690), + "stack@redis.com"); + redis.setTags(Set.of("RedisTag", "CommonTag")); + redis.setMetaList(Set.of(CompanyMeta.of("RD", 100, Set.of("RedisTag", "CommonTag")))); + + Company microsoft = Company.of("Microsoft", 1975, LocalDate.of(2022, 8, 15), + new Point(-122.124500, 47.640160), "research@microsoft.com"); + microsoft.setTags(Set.of("MsTag", "CommonTag")); + microsoft.setMetaList(Set.of(CompanyMeta.of("MS", 50, Set.of("MsTag", "CommonTag")))); + + companyRepository.saveAll(List.of(redis, microsoft)); + } + + @Test + void testFindOneByExampleById() { + MyDoc template = new MyDoc(); + template.setId(id1); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + assertThat(maybeDoc1.get().getTitle()).isEqualTo("hello world"); + } + + @Test + void testFindOneByExampleWithTextIndexedProperty() { + MyDoc template = new MyDoc(); + template.setTitle("hello world"); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + assertThat(maybeDoc1.get().getTitle()).isEqualTo("hello world"); + } + + @Test + void testFindOneByExampleWithExplicitTagIndexedAnnotation() { + MyDoc template = new MyDoc(); + template.setTag(Set.of("news")); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + MyDoc doc1 = maybeDoc1.get(); + assertThat(doc1.getTitle()).isEqualTo("hello world"); + assertThat(doc1.getTag()).contains("news"); + } + + @Test + void testFindOneByExampleWithExplicitNumericIndexedAnnotation() { + MyDoc template = new MyDoc(); + template.setANumber(1); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + MyDoc doc1 = maybeDoc1.get(); + assertThat(doc1.getTitle()).isEqualTo("hello world"); + assertThat(doc1.getANumber()).isEqualTo(1); + } + + @Test + void testFindOneByExampleWithFieldWithExplicitGeoIndexedAnnotation() { + MyDoc template = new MyDoc(); + template.setLocation(new Point(-122.066540, 37.377690)); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + MyDoc doc1 = maybeDoc1.get(); + assertThat(doc1.getTitle()).isEqualTo("hello mundo"); + assertThat(doc1.getANumber()).isEqualTo(2); + } + + @Test + void testFindOneByExampleWithMultipleFields() { + MyDoc template = new MyDoc(); + template.setANumber(3); + template.setTag(Set.of("noticias")); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + MyDoc doc1 = maybeDoc1.get(); + assertThat(doc1.getTitle()).isEqualTo("ola mundo"); + assertThat(doc1.getTag()).contains("noticias"); + } + + @Test + public void testFindAllByExampleShouldReturnEmptyListIfNotMatching() { + MyDoc template = new MyDoc(); + template.setANumber(42); + + Example example = Example.of(template); + + Iterable noMatches = repository.findAll(example); + assertThat(noMatches).isEmpty(); + } + + @Test + public void testFindAllByExampleShouldReturnAllMatches() { + MyDoc template = new MyDoc(); + template.setTag(Set.of("noticias")); + + Example example = Example.of(template); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(2); + assertThat(allMatches).extracting("title").contains("hello mundo", "ola mundo"); + } + + @Test + public void testFindAllByExampleShouldReturnEverythingWhenSampleIsEmpty() { + MyDoc template = new MyDoc(); + + Example example = Example.of(template); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(4); + } + + @Test + public void testFindAllByExampleUsingAnyMatch() { + MyDoc template = new MyDoc(); + template.setTitle("hello world"); + template.setTag(Set.of("artigo")); + + Example example = Example.of(template, ExampleMatcher.matchingAny()); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(2); + assertThat(allMatches).extracting("title").contains("hello world", "ola mundo"); + } + + @Test + public void testFindAllByExampleUsingAnyMatch2() { + MyDoc template = new MyDoc(); + template.setTitle("hello world"); + template.setANumber(3); + + Example example = Example.of(template, ExampleMatcher.matchingAny()); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(3); + assertThat(allMatches).extracting("title").contains("hello world", "ola mundo", "bonjour le monde"); + } + + @Test + void testFindAllByExampleWithTextPropertyStartingWith() { + MyDoc template = new MyDoc(); + template.setTitle("hello"); + + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.STARTING); + + Example example = Example.of(template, matcher); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(2); + assertThat(allMatches).extracting("title").contains("hello world", "hello mundo"); + } + + @Test + void testFindAllByExampleWithTextPropertyEndingWith() { + MyDoc template = new MyDoc(); + template.setTitle("ndo"); + + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.ENDING); + + Example example = Example.of(template, matcher); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(2); + assertThat(allMatches).extracting("title").contains("ola mundo", "hello mundo"); + } + + @Test + void testFindAllByExampleWithTextPropertyContaining() { + MyDoc template = new MyDoc(); + template.setTitle("llo"); + + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.CONTAINING); + + Example example = Example.of(template, matcher); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(2); + assertThat(allMatches).extracting("title").contains("hello world", "hello mundo"); + } + + @Test + public void testFindAllByExampleWithIgnorePaths() { + MyDoc template = new MyDoc(); + template.setTitle("hello world"); + template.setANumber(3); + + Example example = Example.of(template, ExampleMatcher.matchingAny().withIgnorePaths("aNumber")); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(1); + assertThat(allMatches).extracting("title").contains("hello world"); + } + + @Test + void testFindAllByExampleWithStringValueExampleInNestedField() { + Company redisTemplate = new Company(); + CompanyMeta redisCm = new CompanyMeta(); + redisCm.setStringValue("RD"); + redisTemplate.setMetaList(Set.of(redisCm)); + + Example redisExample = Example.of(redisTemplate); + + Company msTemplate = new Company(); + CompanyMeta msCm = new CompanyMeta(); + msCm.setStringValue("MS"); + msTemplate.setMetaList(Set.of(msCm)); + + Example msExample = Example.of(msTemplate); + + Iterable shouldBeOnlyRedis = companyRepository.findAll(redisExample); + Iterable shouldBeOnlyMS = companyRepository.findAll(msExample); + + assertAll( // + () -> assertThat(shouldBeOnlyRedis).map(Company::getName).containsExactly("RedisInc"), // + () -> assertThat(shouldBeOnlyMS).map(Company::getName).containsExactly("Microsoft") // + ); + } + + @Test + void testFindAllByExampleWithNumericValueExampleInNestedField() { + Company redisTemplate = new Company(); + CompanyMeta redisCm = new CompanyMeta(); + redisCm.setNumberValue(100); + redisTemplate.setMetaList(Set.of(redisCm)); + + Example redisExample = Example.of(redisTemplate); + + Company msTemplate = new Company(); + CompanyMeta msCm = new CompanyMeta(); + msCm.setNumberValue(50); + msTemplate.setMetaList(Set.of(msCm)); + + Example msExample = Example.of(msTemplate); + + Iterable shouldBeOnlyRedis = companyRepository.findAll(redisExample); + Iterable shouldBeOnlyMS = companyRepository.findAll(msExample); + + assertAll( // + () -> assertThat(shouldBeOnlyRedis).map(Company::getName).containsExactly("RedisInc"), // + () -> assertThat(shouldBeOnlyMS).map(Company::getName).containsExactly("Microsoft") // + ); + } + + @Test + void testFindAllByExampleWithTags() { + Company redisTemplate = new Company(); + redisTemplate.setTags(Set.of("RedisTag")); + Example redisExample = Example.of(redisTemplate); + + Company msTemplate = new Company(); + msTemplate.setTags(Set.of("MsTag")); + Example msExample = Example.of(msTemplate); + + Company bothTemplate = new Company(); + bothTemplate.setTags(Set.of("CommonTag")); + Example bothExample = Example.of(bothTemplate); + + Iterable shouldBeOnlyRedis = companyRepository.findAll(redisExample); + Iterable shouldBeOnlyMS = companyRepository.findAll(msExample); + Iterable shouldBeBoth = companyRepository.findAll(bothExample); + + assertAll( // + () -> assertThat(shouldBeOnlyRedis).map(Company::getName).containsExactly("RedisInc"), // + () -> assertThat(shouldBeOnlyMS).map(Company::getName).containsExactly("Microsoft"), // + () -> assertThat(shouldBeBoth).map(Company::getName).containsExactlyInAnyOrder("RedisInc", + "Microsoft") // + ); + } + + @Test + void testFindAllByExampleWithTagsInNestedField() { + Company redisTemplate = new Company(); + CompanyMeta redisCm = new CompanyMeta(); + redisCm.setTagValues(Set.of("RedisTag")); + redisTemplate.setMetaList(Set.of(redisCm)); + Example redisExample = Example.of(redisTemplate); + + Company msTemplate = new Company(); + CompanyMeta msCm = new CompanyMeta(); + msCm.setTagValues(Set.of("MsTag")); + msTemplate.setMetaList(Set.of(msCm)); + Example msExample = Example.of(msTemplate); + + Company bothTemplate = new Company(); + CompanyMeta bothCm = new CompanyMeta(); + bothCm.setTagValues(Set.of("CommonTag")); + bothTemplate.setMetaList(Set.of(bothCm)); + Example bothExample = Example.of(bothTemplate); + + Iterable shouldBeOnlyRedis = companyRepository.findAll(redisExample); + Iterable shouldBeOnlyMS = companyRepository.findAll(msExample); + Iterable shouldBeBoth = companyRepository.findAll(bothExample); + + assertAll( // + () -> assertThat(shouldBeOnlyRedis).map(Company::getName).containsExactly("RedisInc"), // + () -> assertThat(shouldBeOnlyMS).map(Company::getName).containsExactly("Microsoft"), // + () -> assertThat(shouldBeBoth).map(Company::getName).containsExactlyInAnyOrder("RedisInc", + "Microsoft") // + ); + } + + @Test + void testFindByShouldReturnFirstResult() { + MyDoc template = new MyDoc(); + template.setTitle("llo"); + + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.CONTAINING); + + Example example = Example.of(template, matcher); + + MyDoc result = repository.findBy(example, FetchableFluentQuery::firstValue); + assertThat(result).isNotNull().hasFieldOrPropertyWithValue("title", "hello world"); + } + + @Test + void testFindByShouldReturnOneResult() { + MyDoc template = new MyDoc(); + template.setTitle("hello world"); + + Example example = Example.of(template); + + MyDoc result = repository.findBy(example, FetchableFluentQuery::oneValue); + assertThat(result).isNotNull().hasFieldOrPropertyWithValue("title", "hello world"); + + MyDoc moreThanOneMatchTemplate = new MyDoc(); + moreThanOneMatchTemplate.setTitle("llo"); + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.CONTAINING); + + assertThatExceptionOfType(IncorrectResultSizeDataAccessException.class).isThrownBy( + () -> repository.findBy(Example.of(moreThanOneMatchTemplate, matcher), FluentQuery.FetchableFluentQuery::one)); + } + + @Test + void testFindByShouldReturnAll() { + MyDoc template = new MyDoc(); + template.setTitle("llo"); + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.CONTAINING); + + List result = repository.findBy(Example.of(template, matcher), FluentQuery.FetchableFluentQuery::all); + + assertThat(result).hasSize(2); + } + + @Test + void testFindByShouldApplySortAll() { + Company probe = new Company(); + + List result = companyRepository.findBy( // + Example.of(probe), // + it -> it.sortBy(Sort.by("name")).all() // + ); + + assertThat(result).map(Company::getName).containsExactly("Microsoft", "RedisInc"); + + result = companyRepository.findBy( // + Example.of(probe), // + it -> it.sortBy(Sort.by(Sort.Direction.DESC, "name")).all() // + ); + assertThat(result).map(Company::getName).containsExactly("RedisInc", "Microsoft"); + } + + @Test + void findByShouldApplyPagination() { + MyDoc template = new MyDoc(); + template.setLocation(new Point(-122.066540, 37.377690)); + + Page firstPage = repository.findBy(Example.of(template), + it -> it.page(PageRequest.of(0, 2, Sort.by("name")))); + assertThat(firstPage.getTotalElements()).isEqualTo(3); + assertThat(firstPage.getContent().size()).isEqualTo(2); + assertThat(firstPage.getContent().stream().toList()).map(MyDoc::getTitle).containsExactly("hello mundo", "ola mundo"); + + Page secondPage = repository.findBy(Example.of(template), + it -> it.page(PageRequest.of(1, 2, Sort.by("name")))); + + assertThat(secondPage.getTotalElements()).isEqualTo(3); + assertThat(secondPage.getContent().size()).isEqualTo(1); + assertThat(secondPage.getContent().stream().toList()).map(MyDoc::getTitle).containsExactly("bonjour le monde"); + } + + @Test + void testFindByShouldCount() { + MyDoc template = new MyDoc(); + template.setLocation(new Point(-122.066540, 37.377690)); + + long count = repository.findBy(Example.of(template), FluentQuery.FetchableFluentQuery::count); + assertThat(count).isEqualTo(3L); + + template = new MyDoc(); + template.setId(id1); + + count = repository.findBy(Example.of(template), FluentQuery.FetchableFluentQuery::count); + assertThat(count).isEqualTo(1L); + } + + @Test + void testFindByShouldReportExists() { + + MyDoc template = new MyDoc(); + template.setLocation(new Point(-122.066540, 37.377690)); + + boolean exists = repository.findBy(Example.of(template), FluentQuery.FetchableFluentQuery::exists); + assertThat(exists).isTrue(); + + template = new MyDoc(); + template.setId("8675309"); + + exists = repository.findBy(Example.of(template), FluentQuery.FetchableFluentQuery::exists); + assertThat(exists).isFalse(); + } + + @Test + void testFindByShouldApplyProjection() { + MyDoc template = new MyDoc(); + template.setTitle("hello world"); + + Example example = Example.of(template); + + MyDoc doc1 = repository.findBy(example, it -> it.project("aNumber").firstValue()); + assertThat(doc1.getANumber()).isNotNull(); + assertThat(doc1.getTitle()).isNull(); + } +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/DeepNest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/DeepNest.java index 5d3e4ff23..c35861961 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/DeepNest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/DeepNest.java @@ -7,6 +7,7 @@ @Data @RequiredArgsConstructor(staticName = "of") +@NoArgsConstructor @AllArgsConstructor(access = AccessLevel.PROTECTED) @Document public class DeepNest { diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/Doc3.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/Doc3.java new file mode 100644 index 000000000..ce0c86cd4 --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/Doc3.java @@ -0,0 +1,26 @@ +package com.redis.om.spring.annotations.document.fixtures; + +import com.redis.om.spring.annotations.Document; +import com.redis.om.spring.annotations.Searchable; +import lombok.*; +import org.springframework.data.annotation.Id; + +@Data +@RequiredArgsConstructor(staticName = "of") +@AllArgsConstructor(access = AccessLevel.PROTECTED) +@NoArgsConstructor(force = true) +@Document +public class Doc3 { + @Id + private String id; + + @Searchable(sortable = true) + @NonNull + private String first; + + @Searchable(sortable = true) + private String second; + + @Searchable(sortable = true) + private String third; +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/Doc3Repository.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/Doc3Repository.java new file mode 100644 index 000000000..1b821a1c1 --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/document/fixtures/Doc3Repository.java @@ -0,0 +1,8 @@ +package com.redis.om.spring.annotations.document.fixtures; + +import com.redis.om.spring.annotations.document.fixtures.Doc2; +import com.redis.om.spring.annotations.document.fixtures.Doc3; +import com.redis.om.spring.repository.RedisDocumentRepository; + +@SuppressWarnings("unused") public interface Doc3Repository extends RedisDocumentRepository { +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/RedisHashQueryByExampleTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/RedisHashQueryByExampleTest.java new file mode 100644 index 000000000..55fe5f5be --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/RedisHashQueryByExampleTest.java @@ -0,0 +1,440 @@ +package com.redis.om.spring.annotations.hash; + +import com.redis.om.spring.AbstractBaseEnhancedRedisTest; +import com.redis.om.spring.annotations.hash.fixtures.Company; +import com.redis.om.spring.annotations.hash.fixtures.CompanyRepository; +import com.redis.om.spring.annotations.hash.fixtures.MyHash; +import com.redis.om.spring.annotations.hash.fixtures.MyHashRepository; +import com.redis.om.spring.search.stream.EntityStream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.dao.IncorrectResultSizeDataAccessException; +import org.springframework.data.domain.*; +import org.springframework.data.domain.ExampleMatcher.StringMatcher; +import org.springframework.data.geo.Point; +import org.springframework.data.repository.query.FluentQuery; +import org.springframework.data.repository.query.FluentQuery.FetchableFluentQuery; + +import java.time.LocalDate; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.junit.jupiter.api.Assertions.assertAll; + +public class RedisHashQueryByExampleTest extends AbstractBaseEnhancedRedisTest { + @Autowired + MyHashRepository repository; + + @Autowired + CompanyRepository companyRepository; + + @Autowired + EntityStream entityStream; + + String id1; + String id2; + + @BeforeEach + void loadTestData() { + repository.deleteAll(); + Point point1 = new Point(-122.124500, 47.640160); + MyHash hash1 = MyHash.of("hello world", point1, point1, 1); + Set tags = new HashSet<>(); + tags.add("news"); + tags.add("article"); + hash1.setTag(tags); + + Point point2 = new Point(-122.066540, 37.377690); + MyHash hash2 = MyHash.of("hello mundo", point2, point2, 2); + Set tags2 = new HashSet<>(); + tags2.add("noticias"); + tags2.add("articulo"); + hash2.setTag(tags2); + + Point point3 = new Point(-122.066540, 37.377690); + MyHash hash3 = MyHash.of("ola mundo", point3, point3, 3); + Set tags3 = new HashSet<>(); + tags3.add("noticias"); + tags3.add("artigo"); + hash3.setTag(tags3); + + Point point4 = new Point(-122.066540, 37.377690); + MyHash hash4 = MyHash.of("bonjour le monde", point4, point4, 3); + Set tags4 = new HashSet<>(); + tags4.add("actualite"); + tags4.add("article"); + hash4.setTag(tags4); + + repository.saveAll(List.of(hash1, hash2, hash3, hash4)); + + id1 = hash1.getId(); + id2 = hash2.getId(); + + companyRepository.deleteAll(); + Company redis = Company.of("RedisInc", 2011, LocalDate.of(2021, 5, 1), new Point(-122.066540, 37.377690), + "stack@redis.com"); + redis.setTags(Set.of("RedisTag", "CommonTag")); + + Company microsoft = Company.of("Microsoft", 1975, LocalDate.of(2022, 8, 15), + new Point(-122.124500, 47.640160), "research@microsoft.com"); + microsoft.setTags(Set.of("MsTag", "CommonTag")); + + companyRepository.saveAll(List.of(redis, microsoft)); + } + + @Test + void testFindById() { + MyHash template = new MyHash(); + template.setId(id1); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + assertThat(maybeDoc1.get().getTitle()).isEqualTo("hello world"); + } + + @Test + void testFindByTextIndexedProperty() { + MyHash template = new MyHash(); + template.setTitle("hello world"); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + assertThat(maybeDoc1.get().getTitle()).isEqualTo("hello world"); + } + + @Test + void testFindByFieldWithExplicitTagIndexedAnnotation() { + MyHash template = new MyHash(); + template.setTag(Set.of("news")); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + MyHash doc1 = maybeDoc1.get(); + assertThat(doc1.getTitle()).isEqualTo("hello world"); + assertThat(doc1.getTag()).contains("news"); + } + + @Test + void testFindByFieldWithExplicitNumericIndexedAnnotation() { + MyHash template = new MyHash(); + template.setANumber(1); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + MyHash doc1 = maybeDoc1.get(); + assertThat(doc1.getTitle()).isEqualTo("hello world"); + assertThat(doc1.getANumber()).isEqualTo(1); + } + + @Test + void testFindByFieldWithExplicitGeoIndexedAnnotation() { + MyHash template = new MyHash(); + template.setLocation(new Point(-122.066540, 37.377690)); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + MyHash doc1 = maybeDoc1.get(); + assertThat(doc1.getTitle()).isEqualTo("hello mundo"); + assertThat(doc1.getANumber()).isEqualTo(2); + } + + @Test + void testFindByMultipleFields() { + MyHash template = new MyHash(); + template.setANumber(3); + template.setTag(Set.of("noticias")); + + Example example = Example.of(template); + + Optional maybeDoc1 = repository.findOne(example); + assertThat(maybeDoc1).isPresent(); + MyHash doc1 = maybeDoc1.get(); + assertThat(doc1.getTitle()).isEqualTo("ola mundo"); + assertThat(doc1.getTag()).contains("noticias"); + } + + @Test + public void findByExampleShouldReturnEmptyListIfNotMatching() { + MyHash template = new MyHash(); + template.setANumber(42); + + Example example = Example.of(template); + + Iterable noMatches = repository.findAll(example); + assertThat(noMatches).isEmpty(); + } + + @Test + public void findAllByExampleShouldReturnAllMatches() { + MyHash template = new MyHash(); + template.setTag(Set.of("noticias")); + + Example example = Example.of(template); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(2); + assertThat(allMatches).extracting("title").contains("hello mundo", "ola mundo"); + } + + @Test + public void findByExampleShouldReturnEverythingWhenSampleIsEmpty() { + MyHash template = new MyHash(); + + Example example = Example.of(template); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(4); + } + + @Test + public void findsExampleUsingAnyMatch() { + MyHash template = new MyHash(); + template.setTitle("hello world"); + template.setTag(Set.of("artigo")); + + Example example = Example.of(template, ExampleMatcher.matchingAny()); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(2); + assertThat(allMatches).extracting("title").contains("hello world", "ola mundo"); + } + + @Test + public void findsExampleUsingAnyMatch2() { + MyHash template = new MyHash(); + template.setTitle("hello world"); + template.setANumber(3); + + Example example = Example.of(template, ExampleMatcher.matchingAny()); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(3); + assertThat(allMatches).extracting("title").contains("hello world", "ola mundo", "bonjour le monde"); + } + + @Test + void testFindByTextPropertyStartingWith() { + MyHash template = new MyHash(); + template.setTitle("hello"); + + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.STARTING); + + Example example = Example.of(template, matcher); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(2); + assertThat(allMatches).extracting("title").contains("hello world", "hello mundo"); + } + + @Test + void testFindByTextPropertyEndingWith() { + MyHash template = new MyHash(); + template.setTitle("ndo"); + + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.ENDING); + + Example example = Example.of(template, matcher); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(2); + assertThat(allMatches).extracting("title").contains("ola mundo", "hello mundo"); + } + + @Test + void testFindByTextPropertyContaining() { + MyHash template = new MyHash(); + template.setTitle("llo"); + + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.CONTAINING); + + Example example = Example.of(template, matcher); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(2); + assertThat(allMatches).extracting("title").contains("hello world", "hello mundo"); + } + + @Test + public void testFindWithIgnorePaths() { + MyHash template = new MyHash(); + template.setTitle("hello world"); + template.setANumber(3); + + Example example = Example.of(template, ExampleMatcher.matchingAny().withIgnorePaths("aNumber")); + + Iterable allMatches = repository.findAll(example); + assertThat(allMatches).hasSize(1); + assertThat(allMatches).extracting("title").contains("hello world"); + } + + @Test + void testFindAllByExampleWithTags() { + Company redisTemplate = new Company(); + redisTemplate.setTags(Set.of("RedisTag")); + Example redisExample = Example.of(redisTemplate); + + Company msTemplate = new Company(); + msTemplate.setTags(Set.of("MsTag")); + Example msExample = Example.of(msTemplate); + + Company bothTemplate = new Company(); + bothTemplate.setTags(Set.of("CommonTag")); + Example bothExample = Example.of(bothTemplate); + + Iterable shouldBeOnlyRedis = companyRepository.findAll(redisExample); + Iterable shouldBeOnlyMS = companyRepository.findAll(msExample); + Iterable shouldBeBoth = companyRepository.findAll(bothExample); + + assertAll( // + () -> assertThat(shouldBeOnlyRedis).map(Company::getName).containsExactly("RedisInc"), // + () -> assertThat(shouldBeOnlyMS).map(Company::getName).containsExactly("Microsoft"), // + () -> assertThat(shouldBeBoth).map(Company::getName).containsExactlyInAnyOrder("RedisInc", + "Microsoft") // + ); + } + + @Test + void testFindByShouldReturnFirstResult() { + MyHash template = new MyHash(); + template.setTitle("llo"); + + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.CONTAINING); + + Example example = Example.of(template, matcher); + + MyHash result = repository.findBy(example, FetchableFluentQuery::firstValue); + assertThat(result).isNotNull().hasFieldOrPropertyWithValue("title", "hello world"); + } + + @Test + void testFindByShouldReturnOneResult() { + MyHash template = new MyHash(); + template.setTitle("hello world"); + + Example example = Example.of(template); + + MyHash result = repository.findBy(example, FetchableFluentQuery::oneValue); + assertThat(result).isNotNull().hasFieldOrPropertyWithValue("title", "hello world"); + + MyHash moreThanOneMatchTemplate = new MyHash(); + moreThanOneMatchTemplate.setTitle("llo"); + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.CONTAINING); + + assertThatExceptionOfType(IncorrectResultSizeDataAccessException.class).isThrownBy( + () -> repository.findBy(Example.of(moreThanOneMatchTemplate, matcher), FluentQuery.FetchableFluentQuery::one)); + } + + @Test + void testFindByShouldReturnAll() { + MyHash template = new MyHash(); + template.setTitle("llo"); + ExampleMatcher matcher = ExampleMatcher.matching() + .withStringMatcher(StringMatcher.CONTAINING); + + List result = repository.findBy(Example.of(template, matcher), FluentQuery.FetchableFluentQuery::all); + + assertThat(result).hasSize(2); + } + + @Test + void testFindByShouldApplySortAll() { + Company probe = new Company(); + + List result = companyRepository.findBy( // + Example.of(probe), // + it -> it.sortBy(Sort.by("name")).all() // + ); + + assertThat(result).map(Company::getName).containsExactly("Microsoft", "RedisInc"); + + result = companyRepository.findBy( // + Example.of(probe), // + it -> it.sortBy(Sort.by(Sort.Direction.DESC, "name")).all() // + ); + assertThat(result).map(Company::getName).containsExactly("RedisInc", "Microsoft"); + } + + @Test + void findByShouldApplyPagination() { + MyHash template = new MyHash(); + template.setLocation(new Point(-122.066540, 37.377690)); + + Page firstPage = repository.findBy(Example.of(template), + it -> it.page(PageRequest.of(0, 2, Sort.by("name")))); + assertThat(firstPage.getTotalElements()).isEqualTo(3); + assertThat(firstPage.getContent().size()).isEqualTo(2); + assertThat(firstPage.getContent().stream().toList()).map(MyHash::getTitle).containsExactly("hello mundo", "ola mundo"); + + Page secondPage = repository.findBy(Example.of(template), + it -> it.page(PageRequest.of(1, 2, Sort.by("name")))); + + assertThat(secondPage.getTotalElements()).isEqualTo(3); + assertThat(secondPage.getContent().size()).isEqualTo(1); + assertThat(secondPage.getContent().stream().toList()).map(MyHash::getTitle).containsExactly("bonjour le monde"); + } + + @Test + void testFindByShouldCount() { + MyHash template = new MyHash(); + template.setLocation(new Point(-122.066540, 37.377690)); + + long count = repository.findBy(Example.of(template), FluentQuery.FetchableFluentQuery::count); + assertThat(count).isEqualTo(3L); + + template = new MyHash(); + template.setId(id1); + + count = repository.findBy(Example.of(template), FluentQuery.FetchableFluentQuery::count); + assertThat(count).isEqualTo(1L); + } + + @Test + void testFindByShouldReportExists() { + + MyHash template = new MyHash(); + template.setLocation(new Point(-122.066540, 37.377690)); + + boolean exists = repository.findBy(Example.of(template), FluentQuery.FetchableFluentQuery::exists); + assertThat(exists).isTrue(); + + template = new MyHash(); + template.setId("8675309"); + + exists = repository.findBy(Example.of(template), FluentQuery.FetchableFluentQuery::exists); + assertThat(exists).isFalse(); + } + + @Test + void testFindByShouldApplyProjection() { + MyHash template = new MyHash(); + template.setTitle("hello world"); + + Example example = Example.of(template); + + MyHash doc1 = repository.findBy(example, it -> it.project("aNumber").firstValue()); + assertThat(doc1.getANumber()).isNotNull(); + assertThat(doc1.getTitle()).isNull(); + } + + +} diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/fixtures/MyHashRepository.java b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/fixtures/MyHashRepository.java index 4d7af8ea6..5efa9aadb 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/fixtures/MyHashRepository.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/annotations/hash/fixtures/MyHashRepository.java @@ -10,6 +10,7 @@ import redis.clients.jedis.search.SearchResult; import java.util.List; +import java.util.Optional; import java.util.Set; @SuppressWarnings({ "unused", "SpellCheckingInspection", "SpringDataMethodInconsistencyInspection" }) public interface MyHashRepository extends RedisEnhancedRepository, MyHashQueries { diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/search/stream/EntityStreamDocsTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/search/stream/EntityStreamDocsTest.java index 5da4f9d3e..eff232e3d 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/search/stream/EntityStreamDocsTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/search/stream/EntityStreamDocsTest.java @@ -3,6 +3,7 @@ import com.google.common.collect.Iterators; import com.google.common.collect.Sets; import com.redis.om.spring.AbstractBaseDocumentTest; +import com.redis.om.spring.annotations.document.fixtures.Doc3Repository; import com.redis.om.spring.annotations.document.fixtures.*; import com.redis.om.spring.tuple.Fields; import com.redis.om.spring.tuple.Pair; @@ -37,6 +38,8 @@ @Autowired EntityStream entityStream; + @Autowired Doc3Repository doc3Repository; + String redisId; String microsoftId; String teslaId; @@ -111,6 +114,24 @@ nicRepository.saveAll(List.of(niRedis, niMicrosoft, niTesla)); } + + // entity with nullable properties for projection testing + if (doc3Repository.count() == 0) { + var doc31 = Doc3.of("doc3.1"); + doc31.setSecond("doc3.1 second"); + doc31.setThird("doc3.1 third"); + var doc32 = Doc3.of("doc3.2"); + doc32.setSecond("doc3.2 second"); + doc32.setThird("doc3.2 third"); + var doc33 = Doc3.of("doc3.3"); + doc33.setSecond("doc3.3 second"); + doc33.setThird("doc3.3 third"); + var doc34 = Doc3.of("doc3.4"); + doc34.setSecond("doc3.4 second"); + doc34.setThird("doc3.4 third"); + + doc3Repository.saveAll(List.of(doc31, doc32, doc33, doc34)); + } } @Test void testStreamSelectAll() { @@ -2390,4 +2411,24 @@ List names = companies.stream().map(Company::getName).collect(Collectors.toList()); assertThat(names).contains("RedisInc"); } + + @Test void testProjectProperties() { + List docs = entityStream // + .of(Doc3.class) // + .sorted(Doc3$.FIRST, SortOrder.DESC) // + .project(Fields.of(Doc3$.FIRST, Doc3$.THIRD)) // + .collect(Collectors.toList()); + + assertEquals(4, docs.size()); + + docs.forEach(d -> { + // projection fields are not null + assertThat(d.getFirst()).isNotNull(); + assertThat(d.getThird()).isNotNull(); + // non-projection nullable fields are null + assertThat(d.getSecond()).isNull(); + // id is always projected + assertThat(d.getId()).isNotNull(); + }); + } } diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/util/ObjectUtilsTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/util/ObjectUtilsTest.java index fc24663c1..4a5aba8cd 100644 --- a/redis-om-spring/src/test/java/com/redis/om/spring/util/ObjectUtilsTest.java +++ b/redis-om-spring/src/test/java/com/redis/om/spring/util/ObjectUtilsTest.java @@ -1,5 +1,6 @@ package com.redis.om.spring.util; +import com.google.common.collect.Sets; import com.redis.om.spring.AbstractBaseDocumentTest; import com.redis.om.spring.annotations.AutoComplete; import com.redis.om.spring.annotations.AutoCompletePayload; @@ -9,6 +10,10 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Slice; +import org.springframework.data.domain.SliceImpl; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import org.springframework.data.redis.connection.RedisGeoCommands.DistanceUnit; @@ -27,7 +32,7 @@ @SuppressWarnings({ "ConstantConditions", "SpellCheckingInspection" }) class ObjectUtilsTest extends AbstractBaseDocumentTest { @Autowired - CompanyRepository companyRepository; + CompanyRepository companyRepository; @Autowired DocWithCustomNameIdRepository docWithCustomNameIdRepository; @@ -104,9 +109,8 @@ void testGetTargetClassName() throws SecurityException { assertThat(ObjectUtils.getTargetClassName(lofs.getClass().getTypeName())).isEqualTo(ArrayList.class.getTypeName()); assertThat(ObjectUtils.getTargetClassName(inta.getClass().getTypeName())).isEqualTo(int[].class.getTypeName()); assertThat(ObjectUtils.getTargetClassName(typeName)).isEqualTo(boolean.class.getTypeName()); - assertThat( - ObjectUtils.getTargetClassName("java.util.List")) - .isEqualTo(List.class.getTypeName()); + assertThat(ObjectUtils.getTargetClassName( + "java.util.List")).isEqualTo(List.class.getTypeName()); } @Test @@ -258,8 +262,8 @@ void testLcfirst() { void testUnQuote() { assertThat(ObjectUtils.unQuote("\"Spam\"")).isEqualTo("Spam"); assertThat(ObjectUtils.unQuote("Spam")).isEqualTo("Spam"); - assertThat(ObjectUtils.unQuote("\"The quick \\\"brown\\\" fox \\\"jumps\\\" over the lazy dog\"")) - .isEqualTo("The quick \\\"brown\\\" fox \\\"jumps\\\" over the lazy dog"); + assertThat(ObjectUtils.unQuote("\"The quick \\\"brown\\\" fox \\\"jumps\\\" over the lazy dog\"")).isEqualTo( + "The quick \\\"brown\\\" fox \\\"jumps\\\" over the lazy dog"); } @Test @@ -273,9 +277,9 @@ void testToUnderscoreSeparated() { assertThat(ObjectUtils.toUnderscoreSeparated("someValue")).isEqualTo("some_value"); assertThat(ObjectUtils.toUnderscoreSeparated("someOtherValue")).isEqualTo("some_other_value"); } - + @Test - void testIsPropertyAnnotatedWith() { + void testIsPropertyAnnotatedWith() { assertThat(ObjectUtils.isPropertyAnnotatedWith(Address.class, "city", Indexed.class)).isTrue(); assertThat(ObjectUtils.isPropertyAnnotatedWith(Address.class, "city", Searchable.class)).isFalse(); assertThat(ObjectUtils.isPropertyAnnotatedWith(Address.class, "street", Searchable.class)).isTrue(); @@ -288,4 +292,106 @@ void testIsPropertyAnnotatedWith() { assertThat(ObjectUtils.isPropertyAnnotatedWith(Airport.class, "state", Searchable.class)).isFalse(); assertThat(ObjectUtils.isPropertyAnnotatedWith(Airport.class, "nonExistentField", Searchable.class)).isFalse(); } + + @Test + void testGetValueByPath() { + Company redis = Company.of("RedisInc", 2011, LocalDate.of(2021, 5, 1), new Point(-122.066540, 37.377690), + "stack@redis.com"); + redis.setId("8675309"); + redis.setMetaList(Set.of(CompanyMeta.of("RD", 100, Set.of("RedisTag", "CommonTag")))); + redis.setTags(Set.of("fast", "scalable", "reliable", "database", "nosql")); + + Set employees = Sets.newHashSet(Employee.of("Brian Sam-Bodden"), Employee.of("Guy Royse"), + Employee.of("Justin Castilla")); + redis.setEmployees(employees); + + String id = (String) ObjectUtils.getValueByPath(redis, "$.id"); + String name = (String) ObjectUtils.getValueByPath(redis, "$.name"); + Integer yearFounded = (Integer) ObjectUtils.getValueByPath(redis, "$.yearFounded"); + LocalDate lastValuation = (LocalDate) ObjectUtils.getValueByPath(redis, "$.lastValuation"); + Point location = (Point) ObjectUtils.getValueByPath(redis, "$.location"); + Set tags = (Set) ObjectUtils.getValueByPath(redis, "$.tags[*]"); + String email = (String) ObjectUtils.getValueByPath(redis, "$.email"); + boolean publiclyListed = (boolean) ObjectUtils.getValueByPath(redis, "$.publiclyListed"); + Collection metaList_numberValue = (Collection) ObjectUtils.getValueByPath(redis, "$.metaList[0:].numberValue"); + Collection metaList_stringValue = (Collection) ObjectUtils.getValueByPath(redis, "$.metaList[0:].stringValue"); + Collection employees_name = (Collection) ObjectUtils.getValueByPath(redis, "$.employees[0:].name"); + + assertAll( // + () -> assertThat(id).isEqualTo(redis.getId()), + () -> assertThat(name).isEqualTo(redis.getName()), + () -> assertThat(yearFounded).isEqualTo(redis.getYearFounded()), + () -> assertThat(lastValuation).isEqualTo(redis.getLastValuation()), + () -> assertThat(location).isEqualTo(redis.getLocation()), + () -> assertThat(tags).isEqualTo(redis.getTags()), + () -> assertThat(email).isEqualTo(redis.getEmail()), + () -> assertThat(publiclyListed).isEqualTo(redis.isPubliclyListed()), + () -> assertThat(metaList_numberValue).containsExactlyElementsOf(redis.getMetaList().stream().map(CompanyMeta::getNumberValue).toList()), + () -> assertThat(metaList_stringValue).containsExactlyElementsOf(redis.getMetaList().stream().map(CompanyMeta::getStringValue).toList()), + () -> assertThat(employees_name).containsExactlyElementsOf(redis.getEmployees().stream().map(Employee::getName).toList()) + ); + } + + @Test + void testFlattenCollection() { + var nested = List.of(List.of(List.of("a", "b")), List.of("c", "d"), "e", List.of(List.of(List.of("f")))); + var flatten = ObjectUtils.flattenCollection(nested); + assertThat(flatten).containsExactly("a", "b", "c", "d", "e", "f"); + } + + @Test + void testPageFromSlice() { + List strings = List.of("Pantufla", "Mondongo", "Latifundio", "Alcachofa"); + Slice slice = new SliceImpl<>(strings); + + Page page = ObjectUtils.pageFromSlice(slice); + + assertThat(page.getContent()).hasSize(4); + assertThat(page.getContent().get(0)).isEqualTo("Pantufla"); + assertThat(page.getNumber()).isEqualTo(slice.getNumber()); + assertThat(page.getSize()).isEqualTo(slice.getSize()); + assertThat(page.getNumberOfElements()).isEqualTo(slice.getNumberOfElements()); + assertThat(page.getSort()).isEqualTo(slice.getSort()); + assertThat(page.hasContent()).isEqualTo(slice.hasContent()); + assertThat(page.hasNext()).isEqualTo(slice.hasNext()); + assertThat(page.hasPrevious()).isEqualTo(slice.hasPrevious()); + assertThat(page.isFirst()).isEqualTo(slice.isFirst()); + assertThat(page.isLast()).isEqualTo(slice.isLast()); + assertThat(page.nextPageable()).isEqualTo(slice.nextPageable()); + assertThat(page.previousPageable()).isEqualTo(slice.previousPageable()); + assertThat(page.getTotalPages()).isEqualTo(-1); + assertThat(page.getPageable()).isEqualTo(Pageable.ofSize(4)); + } + + @Test + public void testEmptyString() { + String result = ObjectUtils.replaceIfIllegalJavaIdentifierCharacter(""); + assertThat(result).isEqualTo(ObjectUtils.REPLACEMENT_CHARACTER.toString()); + } + + @Test + public void testValidIdentifier() { + String input = "validIdentifier"; + String result = ObjectUtils.replaceIfIllegalJavaIdentifierCharacter(input); + assertThat(result).isEqualTo(input); + } + + @Test + public void testInvalidStartCharacter() { + String result = ObjectUtils.replaceIfIllegalJavaIdentifierCharacter("1invalid"); + assertThat(result).startsWith(ObjectUtils.REPLACEMENT_CHARACTER.toString()); + } + + @Test + public void testInvalidPartCharacter() { + String result = ObjectUtils.replaceIfIllegalJavaIdentifierCharacter("invalid*identifier"); + assertThat(result).isEqualTo("invalid" + ObjectUtils.REPLACEMENT_CHARACTER.toString() + "identifier"); + } + + @Test + public void testCompletelyInvalidString() { + String result = ObjectUtils.replaceIfIllegalJavaIdentifierCharacter("!@#*%^&*()"); + String expected = "__________"; + assertThat(result).isEqualTo(expected); + } } diff --git a/redis-om-spring/src/test/java/com/redis/om/spring/util/SearchResultRawResponseToObjectConverterTest.java b/redis-om-spring/src/test/java/com/redis/om/spring/util/SearchResultRawResponseToObjectConverterTest.java new file mode 100644 index 000000000..e21aa96ca --- /dev/null +++ b/redis-om-spring/src/test/java/com/redis/om/spring/util/SearchResultRawResponseToObjectConverterTest.java @@ -0,0 +1,55 @@ +package com.redis.om.spring.util; + +import com.google.gson.Gson; +import lombok.Data; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.geo.Point; + +import java.util.Date; + +import static org.assertj.core.api.Assertions.assertThat; + +public class SearchResultRawResponseToObjectConverterTest { + @Autowired + private Gson gson; + + @Test + void shouldReturnNullWhenRawValueIsNull() { + assertThat(SearchResultRawResponseToObjectConverter.process(null, String.class, new Gson())).isNull(); + } + + @Test + void shouldProcessStringWhenTargetClassIsString() { + assertThat(SearchResultRawResponseToObjectConverter.process("hello".getBytes(), String.class, new Gson())).isEqualTo("hello"); + } + + @Test + void shouldProcessDateWhenTargetClassIsDate() { + Date date = new Date(); + assertThat(SearchResultRawResponseToObjectConverter.process(String.valueOf(date.getTime()).getBytes(), Date.class, new Gson())).isEqualTo(date); + } + + @Test + void shouldProcessPointWhenTargetClassIsPoint() { + Point point = new Point(12.34, 56.78); + assertThat(SearchResultRawResponseToObjectConverter.process("12.34,56.78".getBytes(), Point.class, new Gson())).isEqualTo(point); + } + + @Test + void shouldProcessBooleanWhenTargetClassIsBoolean() { + assertThat(SearchResultRawResponseToObjectConverter.process("1".getBytes(), Boolean.class, new Gson())).isEqualTo(true); + assertThat(SearchResultRawResponseToObjectConverter.process("0".getBytes(), Boolean.class, new Gson())).isEqualTo(false); + } + + @Data + static class MyClass { + private String name; + } + + @Test + void shouldProcessOtherObjectWhenTargetClassIsNotSpecial() { + MyClass target = new Gson().fromJson("{\"name\": \"Morgan\"}", MyClass.class); + assertThat(SearchResultRawResponseToObjectConverter.process("{\"name\": \"Morgan\"}".getBytes(), MyClass.class, new Gson())).isEqualTo(target); + } +}