Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,50 @@ values := []any{"aztec", "nuke", "", 2, 10}
(given "customdata" is configured with `filter.WithNestedJSONB("customdata", "password", "playerCount")`)


## Order By Support

In addition to filtering, this package also supports converting MongoDB-style sort objects into PostgreSQL ORDER BY clauses using the `ConvertOrderBy` method:

```go
// Convert a sort object to an ORDER BY clause
sortInput := []byte(`{"playerCount": -1, "name": 1}`)
orderBy, err := converter.ConvertOrderBy(sortInput)
if err != nil {
// handle error
}
fmt.Println(orderBy) // "playerCount" DESC, "name" ASC

db.Query("SELECT * FROM games ORDER BY " + orderBy)
```

### Sort Direction Values:
- `1`: Ascending (ASC)
- `-1`: Descending (DESC)

### Return value
The `ConvertOrderBy` method returns a string that can be directly used in an SQL ORDER BY clause. When the input is an empty object or `nil`, it returns an empty string. Keep in mind that the method does not add the `ORDER BY` keyword itself; you need to include it in your SQL query.

### JSONB Field Sorting:
For JSONB fields, the package generates sophisticated ORDER BY clauses that handle both numeric and text sorting:

```go
// With WithNestedJSONB("metadata", "created_at"):
sortInput := []byte(`{"score": -1}`)
orderBy, err := converter.ConvertOrderBy(sortInput)
// Generates: (CASE WHEN jsonb_typeof(metadata->'score') = 'number' THEN (metadata->>'score')::numeric END) DESC NULLS LAST, metadata->>'score' DESC NULLS LAST
```

This ensures proper sorting whether the JSONB field contains numeric or text values.

> [!TIP]
> Always add an `, id ASC` to your ORDER BY clause to ensure a consistent order (where `id` is your primary key).
> ```go
> if orderBy != "" {
> orderBy += ", "
> }
> orderBy += "id ASC"
> ```

## Difference with MongoDB

- The MongoDB query filters don't have the option to compare fields with each other. This package adds the `$field` operator to compare fields with each other.
Expand Down
103 changes: 90 additions & 13 deletions filter/converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
// `column != ANY(...)` does not work, so we need to do `NOT column = ANY(...)` instead.
neg = "NOT "
}
inner = append(inner, fmt.Sprintf("(%s%s = ANY($%d))", neg, c.columnName(key), paramIndex))
inner = append(inner, fmt.Sprintf("(%s%s = ANY($%d))", neg, c.columnName(key, true), paramIndex))
paramIndex++
if c.arrayDriver != nil {
v[operator] = c.arrayDriver(v[operator])
Expand Down Expand Up @@ -245,7 +245,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
//
// EXISTS (SELECT 1 FROM unnest("foo") AS __filter_placeholder WHERE ("__filter_placeholder"::text = $1))
//
inner = append(inner, fmt.Sprintf("EXISTS (SELECT 1 FROM unnest(%s) AS %s WHERE %s)", c.columnName(key), c.placeholderName, innerConditions))
inner = append(inner, fmt.Sprintf("EXISTS (SELECT 1 FROM unnest(%s) AS %s WHERE %s)", c.columnName(key, true), c.placeholderName, innerConditions))
}
values = append(values, innerValues...)
case "$field":
Expand All @@ -254,7 +254,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
return "", nil, fmt.Errorf("invalid value for $field operator (must be string): %v", v[operator])
}

inner = append(inner, fmt.Sprintf("(%s = %s)", c.columnName(key), c.columnName(vv)))
inner = append(inner, fmt.Sprintf("(%s = %s)", c.columnName(key, true), c.columnName(vv, true)))
default:
value := v[operator]
isNumericOperator := false
Expand All @@ -274,8 +274,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
return "", nil, fmt.Errorf("invalid value for %s operator (must be object with $field key only): %v", operator, value)
}

left := c.columnName(key)
right := c.columnName(field)
left := c.columnName(key, true)
right := c.columnName(field, true)

if isNumericOperator {
if c.isNestedColumn(key) {
Expand Down Expand Up @@ -304,9 +304,9 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
}

if isNumericOperator && isNumeric(value) && c.isNestedColumn(key) {
inner = append(inner, fmt.Sprintf("((%s)::numeric %s $%d)", c.columnName(key), op, paramIndex))
inner = append(inner, fmt.Sprintf("((%s)::numeric %s $%d)", c.columnName(key, true), op, paramIndex))
} else {
inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key), op, paramIndex))
inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key, true), op, paramIndex))
}
paramIndex++
values = append(values, value)
Expand All @@ -329,9 +329,9 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
}
}
if isNestedColumn {
conditions = append(conditions, fmt.Sprintf("(jsonb_path_match(%s, 'exists($.%s)') AND %s IS NULL)", c.nestedColumn, key, c.columnName(key)))
conditions = append(conditions, fmt.Sprintf("(jsonb_path_match(%s, 'exists($.%s)') AND %s IS NULL)", c.nestedColumn, key, c.columnName(key, true)))
} else {
conditions = append(conditions, fmt.Sprintf("(%s IS NULL)", c.columnName(key)))
conditions = append(conditions, fmt.Sprintf("(%s IS NULL)", c.columnName(key, true)))
}
default:
// Prevent cryptic errors like:
Expand All @@ -341,9 +341,9 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
}
if isNumeric(value) && c.isNestedColumn(key) {
// If the value is numeric and the column is a nested JSONB column, we need to cast the column to numeric.
conditions = append(conditions, fmt.Sprintf("((%s)::numeric = $%d)", c.columnName(key), paramIndex))
conditions = append(conditions, fmt.Sprintf("((%s)::numeric = $%d)", c.columnName(key, true), paramIndex))
} else {
conditions = append(conditions, fmt.Sprintf("(%s = $%d)", c.columnName(key), paramIndex))
conditions = append(conditions, fmt.Sprintf("(%s = $%d)", c.columnName(key, true), paramIndex))
}
paramIndex++
values = append(values, value)
Expand All @@ -358,7 +358,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
return result, values, nil
}

func (c *Converter) columnName(column string) string {
func (c *Converter) columnName(column string, jsonFieldAsText bool) string {
if column == c.placeholderName {
return fmt.Sprintf(`%q::text`, column)
}
Expand All @@ -370,7 +370,10 @@ func (c *Converter) columnName(column string) string {
return fmt.Sprintf("%q", column)
}
}
return fmt.Sprintf(`%q->>'%s'`, c.nestedColumn, column)
if jsonFieldAsText {
return fmt.Sprintf(`%q->>'%s'`, c.nestedColumn, column)
}
return fmt.Sprintf(`%q->'%s'`, c.nestedColumn, column)
}

func (c *Converter) isColumnAllowed(column string) bool {
Expand Down Expand Up @@ -404,3 +407,77 @@ func (c *Converter) isNestedColumn(column string) bool {
}
return true
}

// ConvertOrderBy converts a JSON object with field names and sort directions
// into a PostgreSQL ORDER BY clause. The JSON object should have keys with values
// of 1 (ASC) or -1 (DESC).
//
// For JSONB fields, it generates clauses that handle both numeric and text sorting.
//
// Example: {"playerCount": -1, "name": 1} -> "playerCount DESC, name ASC"
func (c *Converter) ConvertOrderBy(query []byte) (string, error) {
keyValues, err := objectInOrder(query)
if err != nil {
return "", err
}

parts := make([]string, 0, len(keyValues))

for _, kv := range keyValues {
key, value := kv.Key, kv.Value

if !isValidPostgresIdentifier(key) {
return "", fmt.Errorf("invalid column name: %s", key)
}
if !c.isColumnAllowed(key) {
return "", ColumnNotAllowedError{Column: key}
}

// Convert value to number for direction
var direction string
switch v := value.(type) {
case json.Number:
if num, err := v.Int64(); err == nil {
switch num {
case 1:
direction = "ASC"
case -1:
direction = "DESC"
default:
return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value)
}
} else {
return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value)
}
case float64:
switch v {
case 1:
direction = "ASC"
case -1:
direction = "DESC"
default:
return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value)
}
default:
return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value)
}

var fieldClause string
if c.isNestedColumn(key) {
// For JSONB fields, handle both numeric and text sorting.
// We need to use the raw JSONB reference for jsonb_typeof, but columnName() for the actual sorting
fieldClause = fmt.Sprintf("(CASE WHEN jsonb_typeof(%s) = 'number' THEN (%s)::numeric END) %s NULLS LAST, %s %s NULLS LAST", c.columnName(key, false), c.columnName(key, true), direction, c.columnName(key, true), direction)
} else {
// Regular field.
fieldClause = fmt.Sprintf(`%s %s NULLS LAST`, c.columnName(key, true), direction)
}

parts = append(parts, fieldClause)
}

if len(parts) == 0 {
return "", nil
}

return strings.Join(parts, ", "), nil
}
131 changes: 131 additions & 0 deletions filter/converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -641,3 +641,134 @@ func TestConverter_AccessControl(t *testing.T) {
t.Run("nested but disallow password, disallow",
f(`{"password": "hacks"}`, no("password"), filter.WithNestedJSONB("meta", "created_at"), filter.WithDisallowColumns("password")))
}

func TestConverter_ConvertOrderBy(t *testing.T) {
tests := []struct {
name string
options []filter.Option
input string
expected string
err error
}{
{
"single field ascending",
[]filter.Option{filter.WithAllowAllColumns()},
`{"playerCount": 1}`,
`"playerCount" ASC NULLS LAST`,
nil,
},
{
"single field descending",
[]filter.Option{filter.WithAllowAllColumns()},
`{"playerCount": -1}`,
`"playerCount" DESC NULLS LAST`,
nil,
},
{
"multiple fields",
[]filter.Option{filter.WithAllowAllColumns()},
`{"playerCount": -1, "name": 1}`,
`"playerCount" DESC NULLS LAST, "name" ASC NULLS LAST`,
nil,
},
{
"nested JSONB single field ascending",
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
`{"map": 1}`,
`(CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) ASC NULLS LAST, "customdata"->>'map' ASC NULLS LAST`,
nil,
},
{
"nested JSONB single field descending",
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
`{"map": -1}`,
`(CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) DESC NULLS LAST, "customdata"->>'map' DESC NULLS LAST`,
nil,
},
{
"nested JSONB multiple fields",
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
`{"map": 1, "bar": -1}`,
`(CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) ASC NULLS LAST, "customdata"->>'map' ASC NULLS LAST, (CASE WHEN jsonb_typeof("customdata"->'bar') = 'number' THEN ("customdata"->>'bar')::numeric END) DESC NULLS LAST, "customdata"->>'bar' DESC NULLS LAST`,
nil,
},
{
"mixed nested and regular fields",
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
`{"created_at": 1, "map": -1}`,
`"created_at" ASC NULLS LAST, (CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) DESC NULLS LAST, "customdata"->>'map' DESC NULLS LAST`,
nil,
},
{
"field name with spaces",
[]filter.Option{filter.WithAllowAllColumns()},
`{"my_field": 1}`,
`"my_field" ASC NULLS LAST`,
nil,
},
{
"empty object",
[]filter.Option{filter.WithAllowAllColumns()},
`{}`,
``,
nil,
},
{
"invalid field name for SQL injection",
[]filter.Option{filter.WithAllowAllColumns()},
`{"my field": 1}`,
``,
fmt.Errorf("invalid column name: my field"),
},
{
"invalid direction value",
[]filter.Option{filter.WithAllowAllColumns()},
`{"playerCount": 2}`,
``,
fmt.Errorf("invalid order direction for field playerCount: 2 (must be 1 or -1)"),
},
{
"invalid direction string",
[]filter.Option{filter.WithAllowAllColumns()},
`{"playerCount": "asc"}`,
``,
fmt.Errorf("invalid order direction for field playerCount: asc (must be 1 or -1)"),
},
{
"disallowed column",
[]filter.Option{filter.WithAllowColumns("name")},
`{"playerCount": 1}`,
``,
filter.ColumnNotAllowedError{Column: "playerCount"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
converter, err := filter.NewConverter(tt.options...)
if err != nil {
t.Fatalf("Failed to create converter: %v", err)
}

result, err := converter.ConvertOrderBy([]byte(tt.input))

if tt.err != nil {
if err == nil {
t.Fatalf("Expected error %v, got nil", tt.err)
}
if err.Error() != tt.err.Error() {
t.Errorf("Expected error %v, got %v", tt.err, err)
}
return
}

if err != nil {
t.Fatalf("Unexpected error: %v", err)
}

if result != tt.expected {
t.Errorf("Expected %q, got %q", tt.expected, result)
}
})
}
}
9 changes: 9 additions & 0 deletions filter/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,12 @@ type ColumnNotAllowedError struct {
func (e ColumnNotAllowedError) Error() string {
return fmt.Sprintf("column not allowed: %s", e.Column)
}

type InvalidOrderDirectionError struct {
Field string
Value any
}

func (e InvalidOrderDirectionError) Error() string {
return fmt.Sprintf("invalid order direction for field %s: %v (must be 1 or -1)", e.Field, e.Value)
}
Loading
Loading