Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
431 changes: 240 additions & 191 deletions auth/auth_test.go

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions auth/collection_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type CollectionChannelAPI interface {
SetCollectionChannelHistory(scope, collection string, history TimedSetHistory)

// Returns true if the Principal has access to the given channel.
CanSeeCollectionChannel(scope, collection, channel string) bool
CanSeeCollectionChannel(scope, collection, channel string) (bool, error)

// Retrieve invalidation sequence for a collection
getCollectionChannelInvalSeq(scope, collection string) uint64
Expand All @@ -52,7 +52,7 @@ type CollectionChannelAPI interface {

// If the Principal has access to the given collection's channel, returns the sequence number at which
// access was granted; else returns zero.
canSeeCollectionChannelSince(scope, collection, channel string) uint64
canSeeCollectionChannelSince(scope, collection, channel string) (uint64, error)

// Returns an error if the Principal does not have access to all the channels in the set, for the specified collection.
authorizeAllCollectionChannels(scope, collection string, channels base.Set) error
Expand All @@ -76,23 +76,23 @@ type UserCollectionChannelAPI interface {
SetCollectionJWTChannels(scope, collection string, channels ch.TimedSet, seq uint64)

// Retrieves revoked channels for a collection, based on the given since value
RevokedCollectionChannels(scope, collection string, since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels
RevokedCollectionChannels(scope, collection string, since uint64, lowSeq uint64, triggeredBy uint64) (RevokedChannels, error)

// Obtains the period over which the user had access to the given collection's channel. Either directly or via a role.
CollectionChannelGrantedPeriods(scope, collection, chanName string) ([]GrantHistorySequencePair, error)

// Every channel the user has access to in the collection, including those inherited from Roles.
InheritedCollectionChannels(scope, collection string) ch.TimedSet
InheritedCollectionChannels(scope, collection string) (ch.TimedSet, error)

// Returns a TimedSet containing only the channels from the input set that the user has access
// to for the collection, annotated with the sequence number at which access was granted.
// Returns a string array containing any channels filtered out due to the user not having access
// to them.
FilterToAvailableCollectionChannels(scope, collection string, channels base.Set) (filtered ch.TimedSet, removed []string)
FilterToAvailableCollectionChannels(scope, collection string, channels base.Set) (filtered ch.TimedSet, removed []string, err error)

// If the input set contains the wildcard "*" channel, returns the user's inheritedChannels for the collection;
// else returns the input channel list unaltered.
expandCollectionWildCardChannel(scope, collection string, channels base.Set) base.Set
expandCollectionWildCardChannel(scope, collection string, channels base.Set) (base.Set, error)
}

// PrincipalCollectionAccess defines a common interface for principal access control. This interface is
Expand Down
195 changes: 101 additions & 94 deletions auth/collection_access_test.go

Large diffs are not rendered by default.

36 changes: 36 additions & 0 deletions auth/error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright 2025-Present Couchbase, Inc.
//
// Use of this software is governed by the Business Source License included
// in the file licenses/BSL-Couchbase.txt. As of the Change Date specified
// in that file, in accordance with the Business Source License, use of this
// software will be governed by the Apache License, Version 2.0, included in
// the file licenses/APL2.txt.

package auth

import (
"fmt"
"net/http"

"github.com/couchbase/sync_gateway/base"
)

var (
errLoginRequired = base.HTTPErrorf(http.StatusUnauthorized, "login required")
// errUnauthorized is a generic error message
errUnauthorized = base.HTTPErrorf(http.StatusForbidden, "You are not allowed to see this")
// errUnauthorizedChannels is used when we cannot determine which channels are unauthorized
errUnauthorizedChannels = base.HTTPErrorf(http.StatusForbidden, "Unauthorized to see channels")
// errNotAllowedChannels is used when we can determine which channels are not allowed
errNotAllowedChannels = base.HTTPErrorf(http.StatusForbidden, "You are not allowed to see channels")
)

// newErrUnauthorizedChannels creates an error indicating the user is not authorized to see the specified channels. Used when we can not determine which channels are unauthorized.
func newErrUnauthorizedChannels(channels base.Set) error {
return fmt.Errorf("%w %v", errUnauthorizedChannels, channels)
}

// newErrNotAllowedChannels creates an error indicating the user is not allowed to see the specified channels. Used when we can determine which channels are not allowed.
func newErrNotAllowedChannels[T base.Set | []string](channels T) error {
return fmt.Errorf("%w %v", errNotAllowedChannels, channels)
}
18 changes: 9 additions & 9 deletions auth/principal.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ type Principal interface {
SetSequence(sequence uint64)

// Returns true if the Principal has access to the given channel.
canSeeChannel(channel string) bool
canSeeChannel(channel string) (bool, error)

// If the Principal has access to the given channel, returns the sequence number at which
// access was granted; else returns zero.
canSeeChannelSince(channel string) uint64
canSeeChannelSince(channel string) (uint64, error)

// Returns an error if the Principal does not have access to all the channels in the set.
authorizeAllChannels(channels base.Set) error
Expand All @@ -39,7 +39,7 @@ type Principal interface {

// Returns an appropriate HTTPError for unauthorized access -- a 401 if the receiver is
// the guest user, else 403.
UnauthError(message string) error
UnauthError(err error) error

DocID() string
accessViewKey() string
Expand Down Expand Up @@ -104,7 +104,7 @@ type User interface {
SetPassword(password string) error

// GetRoles returns the set of roles the user belongs to, initializing them if necessary.
GetRoles() []Role
GetRoles() ([]Role, error)

// The set of Roles the user belongs to (including ones given to it by the sync function and by OIDC/JWT)
// Returns nil if invalidated
Expand Down Expand Up @@ -135,25 +135,25 @@ type User interface {

RoleHistory() TimedSetHistory

InitializeRoles()
InitializeRoles() error

revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels
revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) (RevokedChannels, error)

// Obtains the period over which the user had access to the given channel. Either directly or via a role.
channelGrantedPeriods(chanName string) ([]GrantHistorySequencePair, error)

// Every channel the user has access to, including those inherited from Roles.
inheritedChannels() ch.TimedSet
inheritedChannels() (ch.TimedSet, error)

// If the input set contains the wildcard "*" channel, returns the user's InheritedChannels;
// else returns the input channel list unaltered.
expandWildCardChannel(channels base.Set) base.Set
expandWildCardChannel(channels base.Set) (base.Set, error)

// Returns a TimedSet containing only the channels from the input set that the user has access
// to, annotated with the sequence number at which access was granted.
// Returns a string array containing any channels filtered out due to the user not having access
// to them.
filterToAvailableChannels(channels base.Set) (filtered ch.TimedSet, removed []string)
filterToAvailableChannels(channels base.Set) (filtered ch.TimedSet, removed []string, err error)

setRolesSince(ch.TimedSet)

Expand Down
31 changes: 20 additions & 11 deletions auth/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,26 +363,27 @@ func (role *roleImpl) validate() error {

//////// CHANNEL AUTHORIZATION:

func (role *roleImpl) UnauthError(message string) error {
// UnauthError returns the underlying error unless Role is the guest user.
func (role *roleImpl) UnauthError(err error) error {
if role.Name_ == "" {
return base.HTTPErrorf(http.StatusUnauthorized, "login required: %s", message)
return errLoginRequired
}
return base.NewHTTPError(http.StatusForbidden, message)
return err
}

// Returns true if the Role is allowed to access the channel.
// A nil Role means access control is disabled, so the function will return true.
func (role *roleImpl) canSeeChannel(channel string) bool {
return role == nil || role.Channels().Contains(channel) || role.Channels().Contains(ch.UserStarChannel)
func (role *roleImpl) canSeeChannel(channel string) (bool, error) {
return role == nil || role.Channels().Contains(channel) || role.Channels().Contains(ch.UserStarChannel), nil
}

// Returns the sequence number since which the Role has been able to access the channel, else zero.
func (role *roleImpl) canSeeChannelSince(channel string) uint64 {
func (role *roleImpl) canSeeChannelSince(channel string) (uint64, error) {
seq := role.Channels()[channel]
if seq.Sequence == 0 {
seq = role.Channels()[ch.UserStarChannel]
}
return seq.Sequence
return seq.Sequence, nil
}

func (role *roleImpl) authorizeAllChannels(channels base.Set) error {
Expand All @@ -398,15 +399,19 @@ func (role *roleImpl) authorizeAnyChannel(channels base.Set) error {
func authorizeAllChannels(princ Principal, channels base.Set) error {
var forbidden []string
for channel := range channels {
if !princ.canSeeChannel(channel) {
canSee, err := princ.canSeeChannel(channel)
if err != nil {
return err
}
if !canSee {
if forbidden == nil {
forbidden = make([]string, 0, len(channels))
}
forbidden = append(forbidden, channel)
}
}
if forbidden != nil {
return princ.UnauthError(fmt.Sprintf("You are not allowed to see channels %v", forbidden))
return princ.UnauthError(newErrNotAllowedChannels(forbidden))
}
return nil
}
Expand All @@ -416,12 +421,16 @@ func authorizeAllChannels(princ Principal, channels base.Set) error {
func authorizeAnyChannel(princ Principal, channels base.Set) error {
if len(channels) > 0 {
for channel := range channels {
if princ.canSeeChannel(channel) {
canSee, err := princ.canSeeChannel(channel)
if err != nil {
return err
}
if canSee {
return nil
}
}
} else if princ.Channels().Contains(ch.UserStarChannel) {
return nil
}
return princ.UnauthError("You are not allowed to see this")
return princ.UnauthError(errUnauthorized)
}
22 changes: 10 additions & 12 deletions auth/role_collection_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
package auth

import (
"fmt"

"github.com/couchbase/sync_gateway/base"
ch "github.com/couchbase/sync_gateway/channels"
)
Expand Down Expand Up @@ -149,22 +147,22 @@ func (role *roleImpl) SetCollectionChannelHistory(scope, collection string, hist

// Returns true if the Role is allowed to access the channel.
// A nil Role means access control is disabled, so the function will return true.
func (role *roleImpl) CanSeeCollectionChannel(scope, collection, channel string) bool {
func (role *roleImpl) CanSeeCollectionChannel(scope, collection, channel string) (bool, error) {
if base.IsDefaultCollection(scope, collection) {
return role.canSeeChannel(channel)
}

if role == nil {
return true
return true, nil
}
if cc, ok := role.getCollectionAccess(scope, collection); ok {
return cc.CanSeeChannel(channel)
return cc.CanSeeChannel(channel), nil
}
return false
return false, nil
}

// Returns the sequence number since which the Role has been able to access the channel, else zero.
func (role *roleImpl) canSeeCollectionChannelSince(scope, collection, channel string) uint64 {
func (role *roleImpl) canSeeCollectionChannelSince(scope, collection, channel string) (uint64, error) {
if base.IsDefaultCollection(scope, collection) {
return role.canSeeChannelSince(channel)
}
Expand All @@ -174,9 +172,9 @@ func (role *roleImpl) canSeeCollectionChannelSince(scope, collection, channel st
if seq.Sequence == 0 {
seq = cc.Channels()[ch.UserStarChannel]
}
return seq.Sequence
return seq.Sequence, nil
}
return 0
return 0, nil
}

func (role *roleImpl) authorizeAllCollectionChannels(scope, collection string, channels base.Set) error {
Expand All @@ -195,11 +193,11 @@ func (role *roleImpl) authorizeAllCollectionChannels(scope, collection string, c
}
}
if forbidden != nil {
return role.UnauthError(fmt.Sprintf("You are not allowed to see channels %v", forbidden))
return role.UnauthError(newErrNotAllowedChannels(forbidden))
}
return nil
}
return role.UnauthError(fmt.Sprintf("Unauthorized to see channels %v", channels))
return role.UnauthError(newErrUnauthorizedChannels(channels))
}

// Returns an error if the Principal does not have access to any of the channels in the set.
Expand All @@ -219,7 +217,7 @@ func (role *roleImpl) AuthorizeAnyCollectionChannel(scope, collection string, ch
return nil
}
}
return role.UnauthError("You are not allowed to see this")
return role.UnauthError(errUnauthorized)
}

// initChannels grants the specified channels to the role as an admin grant, and performs
Expand Down
Loading
Loading