diff --git a/auth/auth_test.go b/auth/auth_test.go index c062a74b5b..d9e40be592 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -35,14 +35,18 @@ func NewTestAuthenticator(t testing.TB, dataStore sgbucket.DataStore, channelCom // requireCanSeeChannels asserts that the given principal can see all of the specified channels. func requireCanSeeChannels(t *testing.T, princ Principal, channels ...string) { for _, channel := range channels { - require.True(t, princ.canSeeChannel(channel), "Expected %s to be able to see channel %q", princ.Name(), channel) + canSee, err := princ.canSeeChannel(channel) + require.NoError(t, err) + require.True(t, canSee, "Expected %s to be able to see channel %q", princ.Name(), channel) } } // assertCannotSeeChannels asserts that the given principal cannot see any of the specified channels. func requireCannotSeeChannels(t *testing.T, princ Principal, channels ...string) { for _, channel := range channels { - require.False(t, princ.canSeeChannel(channel), "Expected %s to NOT be able to see channel %q", princ.Name(), channel) + canSee, err := princ.canSeeChannel(channel) + require.NoError(t, err) + require.False(t, canSee, "Expected %s to NOT be able to see channel %q", princ.Name(), channel) } } @@ -558,8 +562,10 @@ func TestRoleInheritance(t *testing.T) { user2, err := auth.GetUser("arthur") require.NoError(t, err) log.Printf("Channels = %s", user2.Channels()) + user2Channels, err := user2.inheritedChannels() + require.NoError(t, err) assert.Equal(t, ch.AtSequence(ch.BaseSetOf(t, "!", "britain"), 1), user2.Channels()) - assert.Equal(t, ch.TimedSet{"!": ch.NewVbSimpleSequence(0x1), "britain": ch.NewVbSimpleSequence(0x1), "dull": ch.NewVbSimpleSequence(0x3), "duller": ch.NewVbSimpleSequence(0x3), "dullest": ch.NewVbSimpleSequence(0x3), "hoopy": ch.NewVbSimpleSequence(0x4), "hoopier": ch.NewVbSimpleSequence(0x4), "hoopiest": ch.NewVbSimpleSequence(0x4)}, user2.inheritedChannels()) + assert.Equal(t, ch.TimedSet{"!": ch.NewVbSimpleSequence(0x1), "britain": ch.NewVbSimpleSequence(0x1), "dull": ch.NewVbSimpleSequence(0x3), "duller": ch.NewVbSimpleSequence(0x3), "dullest": ch.NewVbSimpleSequence(0x3), "hoopy": ch.NewVbSimpleSequence(0x4), "hoopier": ch.NewVbSimpleSequence(0x4), "hoopiest": ch.NewVbSimpleSequence(0x4)}, user2Channels) requireCanSeeChannels(t, user2, "britain", "duller", "hoopy") require.NoError(t, user2.authorizeAllChannels(ch.BaseSetOf(t, "britain", "dull", "hoopiest"))) @@ -1556,7 +1562,8 @@ func TestRevocationScenario1(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) // Get Principals / Rebuild Seq 40 @@ -1569,7 +1576,8 @@ func TestRevocationScenario1(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(25, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(25, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRole(t, auth, "alice", "foo", 45) @@ -1588,7 +1596,8 @@ func TestRevocationScenario1(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(40, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(40, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRoleChannel(t, auth, "foo", "ch1", 85) @@ -1606,7 +1615,8 @@ func TestRevocationScenario1(t *testing.T) { channelHistory, ok := fooPrincipal.ChannelHistory()["ch1"] require.True(t, ok) assert.Equal(t, GrantHistorySequencePair{StartSeq: 75, EndSeq: 85}, channelHistory.Entries[0]) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(80, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(80, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(85), revokedChannelsCombined["ch1"]) } @@ -1651,7 +1661,8 @@ func TestRevocationScenario2(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRole(t, auth, "alice", "foo", 45) @@ -1667,7 +1678,8 @@ func TestRevocationScenario2(t *testing.T) { assert.Equal(t, GrantHistorySequencePair{StartSeq: 20, EndSeq: 45}, userRoleHistory.Entries[0]) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(25, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(25, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(45), revokedChannelsCombined["ch1"]) @@ -1684,7 +1696,8 @@ func TestRevocationScenario2(t *testing.T) { assert.Equal(t, GrantHistorySequencePair{StartSeq: 20, EndSeq: 45}, userRoleHistory.Entries[0]) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(50, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(50, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRoleChannel(t, auth, "foo", "ch1", 85) @@ -1707,7 +1720,8 @@ func TestRevocationScenario2(t *testing.T) { assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(80, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(80, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(85), revokedChannelsCombined["ch1"]) } @@ -1752,7 +1766,8 @@ func TestRevocationScenario3(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(55, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(55, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRole(t, auth, "alice", "foo", 45) @@ -1774,7 +1789,8 @@ func TestRevocationScenario3(t *testing.T) { assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(25, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(25, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(45), revokedChannelsCombined["ch1"]) @@ -1791,7 +1807,8 @@ func TestRevocationScenario3(t *testing.T) { assert.Len(t, fooPrincipal.ChannelHistory(), 1) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(60, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(60, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRoleChannel(t, auth, "foo", "ch1", 85) @@ -1817,7 +1834,8 @@ func TestRevocationScenario3(t *testing.T) { assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(80, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(80, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(85), revokedChannelsCombined["ch1"]) } @@ -1862,7 +1880,8 @@ func TestRevocationScenario4(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRole(t, auth, "alice", "foo", 45) @@ -1881,7 +1900,8 @@ func TestRevocationScenario4(t *testing.T) { assert.Equal(t, GrantHistorySequencePair{StartSeq: 5, EndSeq: 55}, channelHistory.Entries[0]) assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(25, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(25, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(55), revokedChannelsCombined["ch1"]) @@ -1896,7 +1916,8 @@ func TestRevocationScenario4(t *testing.T) { assert.Len(t, fooPrincipal.ChannelHistory(), 1) assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(70, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(70, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRoleChannel(t, auth, "foo", "ch1", 85) @@ -1916,7 +1937,8 @@ func TestRevocationScenario4(t *testing.T) { assert.Equal(t, GrantHistorySequencePair{StartSeq: 5, EndSeq: 55}, channelHistory.Entries[0]) assert.Equal(t, GrantHistorySequencePair{StartSeq: 75, EndSeq: 85}, channelHistory.Entries[1]) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(80, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(80, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(85), revokedChannelsCombined["ch1"]) } @@ -1959,7 +1981,8 @@ func TestRevocationScenario5(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRole(t, auth, "alice", "foo", 45) @@ -1977,7 +2000,8 @@ func TestRevocationScenario5(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(25, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(25, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRoleChannel(t, auth, "foo", "ch1", 85) @@ -1997,7 +2021,8 @@ func TestRevocationScenario5(t *testing.T) { assert.Equal(t, GrantHistorySequencePair{StartSeq: 75, EndSeq: 85}, channelHistory.Entries[0]) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(80, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(80, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(85), revokedChannelsCombined["ch1"]) } @@ -2040,7 +2065,8 @@ func TestRevocationScenario6(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) require.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRole(t, auth, "alice", "foo", 45) @@ -2062,7 +2088,8 @@ func TestRevocationScenario6(t *testing.T) { assert.Equal(t, GrantHistorySequencePair{StartSeq: 5, EndSeq: 55}, channelHistory.Entries[0]) assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(25, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(25, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(55), revokedChannelsCombined["ch1"]) @@ -2082,7 +2109,8 @@ func TestRevocationScenario6(t *testing.T) { require.True(t, ok) assert.Equal(t, GrantHistorySequencePair{StartSeq: 5, EndSeq: 55}, channelHistory.Entries[0]) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(25, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(25, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 1) assert.Equal(t, uint64(55), revokedChannelsCombined["ch1"]) } @@ -2125,7 +2153,8 @@ func TestRevocationScenario7(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRole(t, auth, "alice", "foo", 45) @@ -2153,7 +2182,8 @@ func TestRevocationScenario7(t *testing.T) { assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(25, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(25, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(45), revokedChannelsCombined["ch1"]) @@ -2167,7 +2197,8 @@ func TestRevocationScenario7(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 1) assert.Len(t, fooPrincipal.ChannelHistory(), 1) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(100, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(100, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) } @@ -2208,7 +2239,8 @@ func TestRevocationScenario8(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(50, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(50, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRoleChannel(t, auth, "foo", "ch1", 55) @@ -2230,7 +2262,8 @@ func TestRevocationScenario8(t *testing.T) { assert.Equal(t, GrantHistorySequencePair{StartSeq: 5, EndSeq: 55}, channelHistory.Entries[0]) assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(50, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(50, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) } @@ -2272,7 +2305,8 @@ func TestRevocationScenario9(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.addRole(t, auth, "alice", "foo", 65) @@ -2290,7 +2324,8 @@ func TestRevocationScenario9(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(60, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(60, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) } @@ -2334,7 +2369,8 @@ func TestRevocationScenario10(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.addRoleChannels(t, auth, "foo", "ch1", 75) @@ -2353,7 +2389,8 @@ func TestRevocationScenario10(t *testing.T) { assert.Equal(t, GrantHistorySequencePair{StartSeq: 65, EndSeq: 95}, userRoleHistory.Entries[0]) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(70, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(70, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) } @@ -2398,7 +2435,8 @@ func TestRevocationScenario11(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRoleChannel(t, auth, "foo", "ch1", 85) @@ -2421,7 +2459,8 @@ func TestRevocationScenario11(t *testing.T) { assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(80, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(80, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(85), revokedChannelsCombined["ch1"]) } @@ -2469,7 +2508,8 @@ func TestRevocationScenario12(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) testMockComputer.removeRole(t, auth, "alice", "foo", 95) @@ -2485,7 +2525,8 @@ func TestRevocationScenario12(t *testing.T) { assert.Equal(t, GrantHistorySequencePair{StartSeq: 65, EndSeq: 95}, userRoleHistory.Entries[0]) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(90, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(90, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) } @@ -2533,7 +2574,8 @@ func TestRevocationScenario13(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(5, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(5, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) // Rebuild seq 110 @@ -2545,7 +2587,8 @@ func TestRevocationScenario13(t *testing.T) { assert.Len(t, aliceUserPrincipal.RoleHistory(), 0) assert.Len(t, aliceUserPrincipal.ChannelHistory(), 0) assert.Len(t, fooPrincipal.ChannelHistory(), 0) - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(100, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(100, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) } @@ -2588,7 +2631,8 @@ func TestRevocationScenario14(t *testing.T) { assert.Equal(t, GrantHistorySequencePair{StartSeq: 20, EndSeq: 45}, userRoleHistory.Entries[0]) // Ensure that a since 25 shows the revocation - revokedChannelsCombined := aliceUserPrincipal.revokedChannels(25, 0, 0) + revokedChannelsCombined, err := aliceUserPrincipal.revokedChannels(25, 0, 0) + require.NoError(t, err) require.Contains(t, revokedChannelsCombined, "ch1") assert.Equal(t, uint64(45), revokedChannelsCombined["ch1"]) @@ -2596,7 +2640,8 @@ func TestRevocationScenario14(t *testing.T) { requireCannotSeeChannels(t, aliceUserPrincipal, "ch1") // Ensure a pull from 45 (same seq as revocation) wouldn't send message - revokedChannelsCombined = aliceUserPrincipal.revokedChannels(45, 0, 0) + revokedChannelsCombined, err = aliceUserPrincipal.revokedChannels(45, 0, 0) + require.NoError(t, err) assert.Len(t, revokedChannelsCombined, 0) } @@ -2691,7 +2736,9 @@ func TestObtainChannelsForDeletedRole(t *testing.T) { assert.NoError(t, err) // Successfully able to get inherited channels even though role is missing - assert.Equal(t, []string{"!"}, user.inheritedChannels().AllKeys()) + inheritedChannels, err := user.inheritedChannels() + require.NoError(t, err) + assert.Equal(t, []string{"!"}, inheritedChannels.AllKeys()) }, }, { @@ -2706,7 +2753,9 @@ func TestObtainChannelsForDeletedRole(t *testing.T) { assert.NoError(t, err) // Successfully able to get inherited channels even though role is missing - assert.Equal(t, []string{"!"}, user.inheritedChannels().AllKeys()) + inheritedChannels, err := user.inheritedChannels() + require.NoError(t, err) + assert.Equal(t, []string{"!"}, inheritedChannels.AllKeys()) }, }, } @@ -2962,5 +3011,7 @@ func TestCalculateMaxHistoryEntriesPerGrant(t *testing.T) { // requireExpandWildCardChannel is a helper function to assert that a user's wildcard channel expansion produces the expected result. func requireExpandWildCardChannel(t *testing.T, user User, expectedChannels, channelsToExpand []string) { - assert.Equal(t, base.SetFromArray(expectedChannels), user.expandWildCardChannel(base.SetFromArray(channelsToExpand)), "Expected channels %v to expand to %v", expectedChannels, channelsToExpand) + expandedChannels, err := user.expandWildCardChannel(base.SetFromArray(channelsToExpand)) + require.NoError(t, err) + assert.Equal(t, base.SetFromArray(expectedChannels), expandedChannels, "Expected channels %v to expand to %v", expectedChannels, channelsToExpand) } diff --git a/auth/collection_access.go b/auth/collection_access.go index 62cd3f8200..6ebf5d5e21 100644 --- a/auth/collection_access.go +++ b/auth/collection_access.go @@ -38,7 +38,7 @@ type CollectionChannelAPI interface { SetCollectionChannelHistory(scope, collection string, history TimedSetHistory) // Returns true if the Principal has access to the given channel. - CanSeeCollectionChannel(scope, collection, channel string) bool + CanSeeCollectionChannel(scope, collection, channel string) (bool, error) // Retrieve invalidation sequence for a collection getCollectionChannelInvalSeq(scope, collection string) uint64 @@ -52,7 +52,7 @@ type CollectionChannelAPI interface { // If the Principal has access to the given collection's channel, returns the sequence number at which // access was granted; else returns zero. - canSeeCollectionChannelSince(scope, collection, channel string) uint64 + canSeeCollectionChannelSince(scope, collection, channel string) (uint64, error) // Returns an error if the Principal does not have access to all the channels in the set, for the specified collection. authorizeAllCollectionChannels(scope, collection string, channels base.Set) error @@ -76,23 +76,23 @@ type UserCollectionChannelAPI interface { SetCollectionJWTChannels(scope, collection string, channels ch.TimedSet, seq uint64) // Retrieves revoked channels for a collection, based on the given since value - RevokedCollectionChannels(scope, collection string, since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels + RevokedCollectionChannels(scope, collection string, since uint64, lowSeq uint64, triggeredBy uint64) (RevokedChannels, error) // Obtains the period over which the user had access to the given collection's channel. Either directly or via a role. CollectionChannelGrantedPeriods(scope, collection, chanName string) ([]GrantHistorySequencePair, error) // Every channel the user has access to in the collection, including those inherited from Roles. - InheritedCollectionChannels(scope, collection string) ch.TimedSet + InheritedCollectionChannels(scope, collection string) (ch.TimedSet, error) // Returns a TimedSet containing only the channels from the input set that the user has access // to for the collection, annotated with the sequence number at which access was granted. // Returns a string array containing any channels filtered out due to the user not having access // to them. - FilterToAvailableCollectionChannels(scope, collection string, channels base.Set) (filtered ch.TimedSet, removed []string) + FilterToAvailableCollectionChannels(scope, collection string, channels base.Set) (filtered ch.TimedSet, removed []string, err error) // If the input set contains the wildcard "*" channel, returns the user's inheritedChannels for the collection; // else returns the input channel list unaltered. - expandCollectionWildCardChannel(scope, collection string, channels base.Set) base.Set + expandCollectionWildCardChannel(scope, collection string, channels base.Set) (base.Set, error) } // PrincipalCollectionAccess defines a common interface for principal access control. This interface is diff --git a/auth/collection_access_test.go b/auth/collection_access_test.go index a46a216a1f..8f328c00ed 100644 --- a/auth/collection_access_test.go +++ b/auth/collection_access_test.go @@ -22,14 +22,18 @@ import ( // requireCanSeeCollectionChannels asserts that the principal can see all the specified channels in the given collection func requireCanSeeCollectionChannels(t *testing.T, scope, collection string, princ Principal, channels ...string) { for _, channel := range channels { - require.True(t, princ.CanSeeCollectionChannel(scope, collection, channel), "Expected %s to be able to see channel %q in %s.%s", princ.Name(), channel, scope, collection) + canSee, err := princ.CanSeeCollectionChannel(scope, collection, channel) + require.NoError(t, err) + require.True(t, canSee, "Expected %s to be able to see channel %q in %s.%s", princ.Name(), channel, scope, collection) } } // requireCannotSeeCollectionChannels asserts that the principal cannot see any of the specified channels in the given collection func requireCannotSeeCollectionChannels(t *testing.T, scope, collection string, princ Principal, channels ...string) { for _, channel := range channels { - require.False(t, princ.CanSeeCollectionChannel(scope, collection, channel), "Expected %s to NOT be able to see channel %q in %s.%s", princ.Name(), channel, scope, collection) + canSee, err := princ.CanSeeCollectionChannel(scope, collection, channel) + require.NoError(t, err) + require.False(t, canSee, "Expected %s to NOT be able to see channel %q in %s.%s", princ.Name(), channel, scope, collection) } } @@ -277,5 +281,7 @@ func TestPrincipalConfigSetExplicitChannels(t *testing.T) { // requireExpandCollectionWildCardChannels asserts that the channels will be expanded to the expected channels for the given collection func requireExpandCollectionWildCardChannels(t *testing.T, user User, scope, collection string, expectedChannels []string, channelsToExpand []string) { - require.Equal(t, base.SetFromArray(expectedChannels), user.expandCollectionWildCardChannel(scope, collection, base.SetFromArray(channelsToExpand)), "Expected channels %v for %s.%s from %v on user %s", expectedChannels, scope, collection, channelsToExpand, user.Name()) + expandedChannels, err := user.expandCollectionWildCardChannel(scope, collection, base.SetFromArray(channelsToExpand)) + require.NoError(t, err) + require.Equal(t, base.SetFromArray(expectedChannels), expandedChannels, "Expected channels %v for %s.%s from %v on user %s", expectedChannels, scope, collection, channelsToExpand, user.Name()) } diff --git a/auth/principal.go b/auth/principal.go index 64f69bd895..223e7f34cb 100644 --- a/auth/principal.go +++ b/auth/principal.go @@ -25,11 +25,11 @@ type Principal interface { SetSequence(sequence uint64) // Returns true if the Principal has access to the given channel. - canSeeChannel(channel string) bool + canSeeChannel(channel string) (bool, error) // If the Principal has access to the given channel, returns the sequence number at which // access was granted; else returns zero. - canSeeChannelSince(channel string) uint64 + canSeeChannelSince(channel string) (uint64, error) // Returns an error if the Principal does not have access to all the channels in the set. authorizeAllChannels(channels base.Set) error @@ -104,7 +104,7 @@ type User interface { SetPassword(password string) error // GetRoles returns the set of roles the user belongs to, initializing them if necessary. - GetRoles() []Role + GetRoles() ([]Role, error) // The set of Roles the user belongs to (including ones given to it by the sync function and by OIDC/JWT) // Returns nil if invalidated @@ -135,25 +135,25 @@ type User interface { RoleHistory() TimedSetHistory - InitializeRoles() + InitializeRoles() error - revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels + revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) (RevokedChannels, error) // Obtains the period over which the user had access to the given channel. Either directly or via a role. channelGrantedPeriods(chanName string) ([]GrantHistorySequencePair, error) // Every channel the user has access to, including those inherited from Roles. - inheritedChannels() ch.TimedSet + inheritedChannels() (ch.TimedSet, error) // If the input set contains the wildcard "*" channel, returns the user's InheritedChannels; // else returns the input channel list unaltered. - expandWildCardChannel(channels base.Set) base.Set + expandWildCardChannel(channels base.Set) (base.Set, error) // Returns a TimedSet containing only the channels from the input set that the user has access // to, annotated with the sequence number at which access was granted. // Returns a string array containing any channels filtered out due to the user not having access // to them. - filterToAvailableChannels(channels base.Set) (filtered ch.TimedSet, removed []string) + filterToAvailableChannels(channels base.Set) (filtered ch.TimedSet, removed []string, err error) setRolesSince(ch.TimedSet) diff --git a/auth/role.go b/auth/role.go index 47b5735f2d..cca1ec16e2 100644 --- a/auth/role.go +++ b/auth/role.go @@ -373,17 +373,17 @@ func (role *roleImpl) UnauthError(err error) error { // Returns true if the Role is allowed to access the channel. // A nil Role means access control is disabled, so the function will return true. -func (role *roleImpl) canSeeChannel(channel string) bool { - return role == nil || role.Channels().Contains(channel) || role.Channels().Contains(ch.UserStarChannel) +func (role *roleImpl) canSeeChannel(channel string) (bool, error) { + return role == nil || role.Channels().Contains(channel) || role.Channels().Contains(ch.UserStarChannel), nil } // Returns the sequence number since which the Role has been able to access the channel, else zero. -func (role *roleImpl) canSeeChannelSince(channel string) uint64 { +func (role *roleImpl) canSeeChannelSince(channel string) (uint64, error) { seq := role.Channels()[channel] if seq.Sequence == 0 { seq = role.Channels()[ch.UserStarChannel] } - return seq.Sequence + return seq.Sequence, nil } func (role *roleImpl) authorizeAllChannels(channels base.Set) error { @@ -399,7 +399,11 @@ func (role *roleImpl) authorizeAnyChannel(channels base.Set) error { func authorizeAllChannels(princ Principal, channels base.Set) error { var forbidden []string for channel := range channels { - if !princ.canSeeChannel(channel) { + canSee, err := princ.canSeeChannel(channel) + if err != nil { + return err + } + if !canSee { if forbidden == nil { forbidden = make([]string, 0, len(channels)) } @@ -417,7 +421,11 @@ func authorizeAllChannels(princ Principal, channels base.Set) error { func authorizeAnyChannel(princ Principal, channels base.Set) error { if len(channels) > 0 { for channel := range channels { - if princ.canSeeChannel(channel) { + canSee, err := princ.canSeeChannel(channel) + if err != nil { + return err + } + if canSee { return nil } } diff --git a/auth/role_collection_access.go b/auth/role_collection_access.go index e2be51529a..0c7704c1b5 100644 --- a/auth/role_collection_access.go +++ b/auth/role_collection_access.go @@ -147,22 +147,22 @@ func (role *roleImpl) SetCollectionChannelHistory(scope, collection string, hist // Returns true if the Role is allowed to access the channel. // A nil Role means access control is disabled, so the function will return true. -func (role *roleImpl) CanSeeCollectionChannel(scope, collection, channel string) bool { +func (role *roleImpl) CanSeeCollectionChannel(scope, collection, channel string) (bool, error) { if base.IsDefaultCollection(scope, collection) { return role.canSeeChannel(channel) } if role == nil { - return true + return true, nil } if cc, ok := role.getCollectionAccess(scope, collection); ok { - return cc.CanSeeChannel(channel) + return cc.CanSeeChannel(channel), nil } - return false + return false, nil } // Returns the sequence number since which the Role has been able to access the channel, else zero. -func (role *roleImpl) canSeeCollectionChannelSince(scope, collection, channel string) uint64 { +func (role *roleImpl) canSeeCollectionChannelSince(scope, collection, channel string) (uint64, error) { if base.IsDefaultCollection(scope, collection) { return role.canSeeChannelSince(channel) } @@ -172,9 +172,9 @@ func (role *roleImpl) canSeeCollectionChannelSince(scope, collection, channel st if seq.Sequence == 0 { seq = cc.Channels()[ch.UserStarChannel] } - return seq.Sequence + return seq.Sequence, nil } - return 0 + return 0, nil } func (role *roleImpl) authorizeAllCollectionChannels(scope, collection string, channels base.Set) error { diff --git a/auth/user.go b/auth/user.go index a27ccede0b..199e0d6f22 100644 --- a/auth/user.go +++ b/auth/user.go @@ -274,7 +274,7 @@ func (revokedChannels RevokedChannels) add(chanName string, triggeredBy uint64) } } -func (user *userImpl) revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels { +func (user *userImpl) revokedChannels(since uint64, lowSeq uint64, triggeredBy uint64) (RevokedChannels, error) { return user.RevokedCollectionChannels(base.DefaultScope, base.DefaultCollection, since, lowSeq, triggeredBy) } @@ -289,7 +289,7 @@ func (user *userImpl) revokedChannels(since uint64, lowSeq uint64, triggeredBy u // // Get user: // - Revoke users revoked channels -func (user *userImpl) RevokedCollectionChannels(scope string, collection string, since uint64, lowSeq uint64, triggeredBy uint64) RevokedChannels { +func (user *userImpl) RevokedCollectionChannels(scope string, collection string, since uint64, lowSeq uint64, triggeredBy uint64) (RevokedChannels, error) { // checkSeq represents the value that we use to 'diff' against ie. What channels did the user have at checkSeq but // no longer has. // In the event we have a lowSeq that will be used. @@ -310,7 +310,10 @@ func (user *userImpl) RevokedCollectionChannels(scope string, collection string, // If there has been a revocation somewhere after the since value or we're in an interrupted revocation backfill // at the point a revocation occurred we should return this as a channel to revoke. - accessibleChannels := user.InheritedCollectionChannels(scope, collection) + accessibleChannels, err := user.InheritedCollectionChannels(scope, collection) + if err != nil { + return nil, err + } // Get revoked roles rolesToRevoke := map[string]uint64{} roleHistory := user.RoleHistory() @@ -385,14 +388,19 @@ func (user *userImpl) RevokedCollectionChannels(scope string, collection string, // Iterate over current roles and revoke any revoked channels inside role provided that channel isn't accessible // from another grant - for _, role := range user.GetRolesIncDeleted() { + + roles, err := user.GetRolesIncDeleted() + if err != nil { + return nil, err + } + for _, role := range roles { revokeChannelHistoryProcessing(role) } // Lastly get the revoked channels based off of channel history on the user itself revokeChannelHistoryProcessing(user) - return combinedRevokedChannels + return combinedRevokedChannels, nil } func (user *userImpl) channelGrantedPeriods(chanName string) ([]GrantHistorySequencePair, error) { @@ -432,8 +440,12 @@ func (user *userImpl) CollectionChannelGrantedPeriods(scope, collection, chanNam } } + roles, err := user.GetRoles() + if err != nil { + return nil, err + } // Iterate over current roles - for _, currentRole := range user.GetRoles() { + for _, currentRole := range roles { // Grab pairs from channel history on current roles roleChannelHistory, ok := currentRole.CollectionChannelHistory(scope, collection)[chanName] @@ -567,7 +579,7 @@ func (user *userImpl) SetPassword(password string) error { // ////// CHANNEL ACCESS: -func (user *userImpl) GetRoles() []Role { +func (user *userImpl) GetRoles() ([]Role, error) { if user.roles == nil { roles := make([]Role, 0, len(user.RoleNames())) deletedRoles := make([]Role, 0) @@ -575,7 +587,7 @@ func (user *userImpl) GetRoles() []Role { role, err := user.auth.GetRoleIncDeleted(name) // base.InfofCtx(user.auth.LogCtx, base.KeyAccess, "User %s role %q = %v", base.UD(user.Name_), base.UD(name), base.UD(role)) if err != nil { - panic(fmt.Sprintf("Error getting user role %q: %v", name, err)) + return nil, fmt.Errorf("Error getting user role %q: %w", name, err) } else if role != nil { if role.IsDeleted() { deletedRoles = append(deletedRoles, role) @@ -587,40 +599,67 @@ func (user *userImpl) GetRoles() []Role { user.roles = roles user.deletedRoles = deletedRoles } - return user.roles + return user.roles, nil } -func (user *userImpl) GetRolesIncDeleted() []Role { +func (user *userImpl) GetRolesIncDeleted() ([]Role, error) { // Use GetRoles to retrieve active roles (will fetch roles if needed) - allRoles := user.GetRoles() + allRoles, err := user.GetRoles() + if err != nil { + return nil, err + } allRoles = append(allRoles, user.deletedRoles...) - return allRoles + return allRoles, nil } -func (user *userImpl) InitializeRoles() { - _ = user.GetRoles() +func (user *userImpl) InitializeRoles() error { + _, err := user.GetRoles() + return err } -func (user *userImpl) canSeeChannel(channel string) bool { - if user.roleImpl.canSeeChannel(channel) { - return true +func (user *userImpl) canSeeChannel(channel string) (bool, error) { + canSee, err := user.roleImpl.canSeeChannel(channel) + if err != nil { + return canSee, err + } + if canSee { + return true, nil } - for _, role := range user.GetRoles() { - if role.canSeeChannel(channel) { - return true + roles, err := user.GetRoles() + if err != nil { + return false, err + } + for _, role := range roles { + canSee, err := role.canSeeChannel(channel) + if err != nil { + return false, err + } + if canSee { + return true, nil } } - return false + return false, nil } -func (user *userImpl) canSeeChannelSince(channel string) uint64 { - minSeq := user.roleImpl.canSeeChannelSince(channel) - for _, role := range user.GetRoles() { - if seq := role.canSeeChannelSince(channel); seq > 0 && (seq < minSeq || minSeq == 0) { +func (user *userImpl) canSeeChannelSince(channel string) (uint64, error) { + minSeq, err := user.roleImpl.canSeeChannelSince(channel) + if err != nil { + return 0, err + } + roles, err := user.GetRoles() + if err != nil { + return 0, err + } + for _, role := range roles { + seq, err := role.canSeeChannelSince(channel) + if err != nil { + return 0, err + } + if seq > 0 && (seq < minSeq || minSeq == 0) { minSeq = seq } } - return minSeq + return minSeq, nil } func (user *userImpl) authorizeAllChannels(channels base.Set) error { @@ -631,9 +670,13 @@ func (user *userImpl) authorizeAnyChannel(channels base.Set) error { return authorizeAnyChannel(user, channels) } -func (user *userImpl) inheritedChannels() ch.TimedSet { +func (user *userImpl) inheritedChannels() (ch.TimedSet, error) { channels := user.Channels().Copy() - for _, role := range user.GetRoles() { + roles, err := user.GetRoles() + if err != nil { + return nil, err + } + for _, role := range roles { roleSince := user.RoleNames()[role.Name()] channels.AddAtSequence(role.Channels(), roleSince.Sequence) } @@ -648,48 +691,64 @@ func (user *userImpl) inheritedChannels() ch.TimedSet { } }) - return channels + return channels, nil } // If a channel list contains the all-channel wildcard, replace it with all the user's accessible channels. -func (user *userImpl) expandWildCardChannel(channels base.Set) base.Set { +func (user *userImpl) expandWildCardChannel(channels base.Set) (base.Set, error) { return user.expandCollectionWildCardChannel(base.DefaultScope, base.DefaultCollection, channels) } -func (user *userImpl) expandCollectionWildCardChannel(scope, collection string, channels base.Set) base.Set { +func (user *userImpl) expandCollectionWildCardChannel(scope, collection string, channels base.Set) (base.Set, error) { if channels.Contains(ch.AllChannelWildcard) { - channels = user.InheritedCollectionChannels(scope, collection).AsSet() + allChannels, err := user.InheritedCollectionChannels(scope, collection) + if err != nil { + return nil, err + } + return allChannels.AsSet(), nil } - return channels + return channels, nil } -func (user *userImpl) filterToAvailableChannels(channelNames base.Set) (filtered ch.TimedSet, removed []string) { +func (user *userImpl) filterToAvailableChannels(channelNames base.Set) (filtered ch.TimedSet, removed []string, err error) { return user.FilterToAvailableCollectionChannels(base.DefaultScope, base.DefaultCollection, channelNames) } -func (user *userImpl) FilterToAvailableCollectionChannels(scope, collection string, channelNames base.Set) (filtered ch.TimedSet, removed []string) { +func (user *userImpl) FilterToAvailableCollectionChannels(scope, collection string, channelNames base.Set) (filtered ch.TimedSet, removed []string, err error) { filtered = ch.TimedSet{} for channelName, _ := range channelNames { if channelName == ch.AllChannelWildcard { - return user.InheritedCollectionChannels(scope, collection).Copy(), nil + channels, err := user.InheritedCollectionChannels(scope, collection) + if err != nil { + return nil, nil, err + } + return channels, nil, nil } - added := filtered.AddChannel(channelName, user.canSeeCollectionChannelSince(scope, collection, channelName)) + seq, err := user.canSeeCollectionChannelSince(scope, collection, channelName) + if err != nil { + return nil, nil, err + } + added := filtered.AddChannel(channelName, seq) if !added { removed = append(removed, channelName) } } - return filtered, removed + return filtered, removed, nil } -func (user *userImpl) GetAddedChannels(channels ch.TimedSet) base.Set { +func (user *userImpl) GetAddedChannels(channels ch.TimedSet) (base.Set, error) { + allChannels, err := user.inheritedChannels() + if err != nil { + return nil, err + } output := base.Set{} - for userChannel := range user.inheritedChannels() { + for userChannel := range allChannels { _, found := channels[userChannel] if !found { output[userChannel] = struct{}{} } } - return output + return output, nil } // ////// MARSHALING: diff --git a/auth/user_collection_access.go b/auth/user_collection_access.go index 494111e14d..16ae3c247b 100644 --- a/auth/user_collection_access.go +++ b/auth/user_collection_access.go @@ -37,22 +37,37 @@ func (user *userImpl) SetCollectionJWTChannels(scope, collection string, channel cc.ChannelInvalSeq = invalSeq } -func (user *userImpl) CanSeeCollectionChannel(scope, collection, channel string) bool { - if user.roleImpl.CanSeeCollectionChannel(scope, collection, channel) { - return true +func (user *userImpl) CanSeeCollectionChannel(scope, collection, channel string) (bool, error) { + canSee, err := user.roleImpl.CanSeeCollectionChannel(scope, collection, channel) + if err != nil { + return false, err } - for _, role := range user.GetRoles() { - if role.CanSeeCollectionChannel(scope, collection, channel) { - return true + if canSee { + return true, nil + } + roles, err := user.GetRoles() + if err != nil { + return false, err + } + for _, role := range roles { + canSee, err := role.CanSeeCollectionChannel(scope, collection, channel) + if err != nil { + return false, err + } else if canSee { + return true, nil } } - return false + return false, nil } -func (user *userImpl) InheritedCollectionChannels(scope, collection string) ch.TimedSet { +func (user *userImpl) InheritedCollectionChannels(scope, collection string) (ch.TimedSet, error) { channels := user.CollectionChannels(scope, collection).Copy() - for _, role := range user.GetRoles() { + roles, err := user.GetRoles() + if err != nil { + return nil, err + } + for _, role := range roles { roleSince := user.RoleNames()[role.Name()] channels.AddAtSequence(role.CollectionChannels(scope, collection), roleSince.Sequence) } @@ -68,7 +83,7 @@ func (user *userImpl) InheritedCollectionChannels(scope, collection string) ch.T } }) - return channels + return channels, nil } // Checks for user access to any channel in the set, including access inherited via roles @@ -92,7 +107,11 @@ func (user *userImpl) AuthorizeAnyCollectionChannel(scope, collection string, ch } // Inherited role access - for _, role := range user.GetRoles() { + roles, err := user.GetRoles() + if err != nil { + return err + } + for _, role := range roles { if role.AuthorizeAnyCollectionChannel(scope, collection, channels) == nil { return nil } @@ -101,12 +120,22 @@ func (user *userImpl) AuthorizeAnyCollectionChannel(scope, collection string, ch return user.UnauthError(errUnauthorized) } -func (user *userImpl) canSeeCollectionChannelSince(scope, collection, channel string) uint64 { - minSeq := user.roleImpl.canSeeCollectionChannelSince(scope, collection, channel) - for _, role := range user.GetRoles() { - if seq := role.canSeeCollectionChannelSince(scope, collection, channel); seq > 0 && (seq < minSeq || minSeq == 0) { +func (user *userImpl) canSeeCollectionChannelSince(scope, collection, channel string) (uint64, error) { + minSeq, err := user.roleImpl.canSeeCollectionChannelSince(scope, collection, channel) + if err != nil { + return 0, err + } + roles, err := user.GetRoles() + if err != nil { + return 0, err + } + for _, role := range roles { + seq, err := role.canSeeCollectionChannelSince(scope, collection, channel) + if err != nil { + return 0, err + } else if seq > 0 && (seq < minSeq || minSeq == 0) { minSeq = seq } } - return minSeq + return minSeq, nil } diff --git a/auth/user_test.go b/auth/user_test.go index 18e30d97c0..aef4abb8a1 100644 --- a/auth/user_test.go +++ b/auth/user_test.go @@ -277,9 +277,13 @@ func TestCanSeeChannelSince(t *testing.T) { "video": channels.NewVbSimpleSequence(1)}) for channel := range freeChannels { - assert.Equal(t, uint64(1), user.canSeeChannelSince(channel)) + seq, err := user.canSeeChannelSince(channel) + require.NoError(t, err) + assert.Equal(t, uint64(1), seq, "expected to see channels %q since %q", channel, seq) } - assert.Equal(t, uint64(0), user.canSeeChannelSince("unknown")) + seq, err := user.canSeeChannelSince("unknown") + require.NoError(t, err) + assert.Equal(t, uint64(0), seq) } func TestGetAddedChannels(t *testing.T) { @@ -306,10 +310,10 @@ func TestGetAddedChannels(t *testing.T) { "music": channels.NewVbSimpleSequence(0x5), "video": channels.NewVbSimpleSequence(0x6)}) - addedChannels := user.(*userImpl).GetAddedChannels(channels.TimedSet{ + addedChannels, err := user.(*userImpl).GetAddedChannels(channels.TimedSet{ "ESPN": channels.NewVbSimpleSequence(0x5), "HBO": channels.NewVbSimpleSequence(0x6)}) - + require.NoError(t, err) expectedChannels := channels.BaseSetOf(t, "!", "AMC", "FX", "Hulu", "Netflix", "Spotify", "Youtube") log.Printf("Added Channels: %v", addedChannels) assert.Equal(t, expectedChannels, addedChannels) diff --git a/db/blip_handler.go b/db/blip_handler.go index b6f185dc83..5ffc6595e4 100644 --- a/db/blip_handler.go +++ b/db/blip_handler.go @@ -141,7 +141,10 @@ func (bh *blipHandler) refreshUser() error { return base.NewHTTPError(CBLReconnectErrorCode, err.Error()) } newUser := bc.blipContextDb.User() - newUser.InitializeRoles() + err = newUser.InitializeRoles() + if err != nil { + return base.NewHTTPError(CBLReconnectErrorCode, err.Error()) + } bc.userChangeWaiter.RefreshUserKeys(newUser, bc.blipContextDb.MetadataKeys) // refresh the handler's database with the new BlipSyncContext database bh.db = bh._copyContextDatabase() diff --git a/db/blip_sync_context.go b/db/blip_sync_context.go index ae314f07a9..23cc549fb8 100644 --- a/db/blip_sync_context.go +++ b/db/blip_sync_context.go @@ -67,7 +67,10 @@ func NewBlipSyncContext(ctx context.Context, bc *blip.Context, db *Database, rep if u := db.User(); u != nil { bsc.userName = u.Name() - u.InitializeRoles() + err := u.InitializeRoles() + if err != nil { + return nil, err + } if u.Name() == "" && db.IsGuestReadOnly() { bsc.readOnly = true } diff --git a/db/changes.go b/db/changes.go index 44428963d2..c093b1fefe 100644 --- a/db/changes.go +++ b/db/changes.go @@ -674,14 +674,23 @@ func (col *DatabaseCollectionWithUser) checkForUserUpdates(ctx context.Context, base.DebugfCtx(ctx, base.KeyChanges, "MultiChangesFeed reloading user %+v", base.UD(col.user)) if col.user != nil { - previousChannels = col.user.InheritedCollectionChannels(col.ScopeName, col.Name) + previousChannels, err = col.user.InheritedCollectionChannels(col.ScopeName, col.Name) + if err != nil { + base.WarnfCtx(ctx, "Error getting previous channels for user %q: %v", base.UD(col.user.Name()), err) + return false, 0, nil, err + } previousRoles := col.user.RoleNames() if err := col.ReloadUser(ctx); err != nil { base.WarnfCtx(ctx, "Error reloading user %q: %v", base.UD(col.user.Name()), err) return false, 0, nil, err } // check whether channel set has changed - changedChannels = col.user.InheritedCollectionChannels(col.ScopeName, col.Name).CompareKeys(previousChannels) + channels, err := col.user.InheritedCollectionChannels(col.ScopeName, col.Name) + if err != nil { + base.WarnfCtx(ctx, "Error getting channels for user %q: %v", base.UD(col.user.Name()), err) + return false, 0, nil, err + } + changedChannels = channels.CompareKeys(previousChannels) if len(changedChannels) > 0 { base.DebugfCtx(ctx, base.KeyChanges, "Modified channel set after user reload: %v", base.UD(changedChannels)) } @@ -763,7 +772,17 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex var channelsSince channels.TimedSet if col.user != nil { var channelsRemoved []string - channelsSince, channelsRemoved = col.user.FilterToAvailableCollectionChannels(col.ScopeName, col.Name, chans) + var err error + channelsSince, channelsRemoved, err = col.user.FilterToAvailableCollectionChannels(col.ScopeName, col.Name, chans) + if err != nil { + base.WarnfCtx(ctx, "Error filtering to available channels for user %q: %v", base.UD(col.user.Name()), err) + change := makeErrorEntry("Error filtering channels to user - terminating changes feed") + select { + case output <- &change: + case <-options.ChangesCtx.Done(): + } + return + } if len(channelsRemoved) > 0 { base.InfofCtx(ctx, base.KeyChanges, "Channels %s request without access by user %s", base.UD(channelsRemoved), base.UD(col.user.Name())) } @@ -920,7 +939,16 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex } if options.Revocations && col.user != nil && !options.ActiveOnly { - channelsToRevoke := col.user.RevokedCollectionChannels(col.ScopeName, col.Name, options.Since.Seq, options.Since.LowSeq, options.Since.TriggeredBy) + channelsToRevoke, err := col.user.RevokedCollectionChannels(col.ScopeName, col.Name, options.Since.Seq, options.Since.LowSeq, options.Since.TriggeredBy) + if err != nil { + base.WarnfCtx(ctx, "Error retrieving revoked channels for user %q: %v", base.UD(col.user.Name()), err) + change := makeErrorEntry("Error retrieving revoked channels - terminating changes feed") + select { + case output <- &change: + case <-options.ChangesCtx.Done(): + } + return + } for channel, revokedSeq := range channelsToRevoke { revocationSinceSeq := options.Since.SafeSequence() revokeFrom := uint64(0) @@ -1141,7 +1169,16 @@ func (col *DatabaseCollectionWithUser) SimpleMultiChangesFeed(ctx context.Contex return } if userChanged && col.user != nil { - newChannelsSince, _ := col.user.FilterToAvailableCollectionChannels(col.ScopeName, col.Name, chans) + newChannelsSince, _, err := col.user.FilterToAvailableCollectionChannels(col.ScopeName, col.Name, chans) + if err != nil { + base.WarnfCtx(ctx, "Error filtering to available channels for user %q: %v", base.UD(col.user.Name()), err) + change := makeErrorEntry("Error filtering channels to user - terminating changes feed") + select { + case output <- &change: + case <-options.ChangesCtx.Done(): + } + return + } changedChannels = newChannelsSince.CompareKeys(channelsSince) if len(changedChannels) > 0 { @@ -1315,12 +1352,16 @@ func (db *DatabaseCollectionWithUser) DocIDChangesFeed(ctx context.Context, user // Subroutine that creates a response row for a document: output := make(chan *ChangeEntry, len(explicitDocIds)) + defer close(output) rowMap := make(map[uint64]*ChangeEntry) // Sort results by sequence var sequences base.SortedUint64Slice for _, docID := range explicitDocIds { - row := createChangesEntry(ctx, docID, db, options) + row, err := createChangesEntry(ctx, docID, db, options) + if err != nil { + return nil, err + } if row != nil { rowMap[row.Seq.Seq] = row sequences = append(sequences, row.Seq.Seq) @@ -1339,23 +1380,21 @@ func (db *DatabaseCollectionWithUser) DocIDChangesFeed(ctx context.Context, user } } - close(output) - return output, nil } // createChangesEntry is used when creating a doc ID filtered changes feed -func createChangesEntry(ctx context.Context, docid string, db *DatabaseCollectionWithUser, options ChangesOptions) *ChangeEntry { +func createChangesEntry(ctx context.Context, docid string, db *DatabaseCollectionWithUser, options ChangesOptions) (*ChangeEntry, error) { row := &ChangeEntry{ID: docid} populatedDoc, err := db.GetDocument(ctx, docid, DocUnmarshalSync) if err != nil { base.InfofCtx(ctx, base.KeyChanges, "Unable to get changes for docID %v, caused by %v", base.UD(docid), err) - return nil + return nil, nil } if populatedDoc.Sequence <= options.Since.Seq { - return nil + return nil, nil } versionRequested := options.VersionType @@ -1390,7 +1429,11 @@ func createChangesEntry(ctx context.Context, docid string, db *DatabaseCollectio // - the active revision is in a channel the user can see (removal==nil) // - the doc has been removed from a user's channel later the requested since value (removal.Seq > options.Since.Seq). In this case, we need to send removal:true changes entry for channel, removal := range populatedDoc.Channels { - if db.user.CanSeeCollectionChannel(db.ScopeName, db.Name, channel) && (removal == nil || removal.Seq > options.Since.Seq) { + canSee, err := db.user.CanSeeCollectionChannel(db.ScopeName, db.Name, channel) + if err != nil { + return nil, err + } + if canSee && (removal == nil || removal.Seq > options.Since.Seq) { userCanSeeDocChannel = true // If removal, update removed channels and deleted flag. if removal != nil { @@ -1404,7 +1447,7 @@ func createChangesEntry(ctx context.Context, docid string, db *DatabaseCollectio } if !userCanSeeDocChannel { - return nil + return nil, nil } row.Removed = base.SetFromArray(removedChannels) @@ -1414,7 +1457,7 @@ func createChangesEntry(ctx context.Context, docid string, db *DatabaseCollectio } } - return row + return row, nil } func (options ChangesOptions) String() string { diff --git a/db/crud.go b/db/crud.go index 4c70ff6efc..aab2d90ecd 100644 --- a/db/crud.go +++ b/db/crud.go @@ -1751,8 +1751,11 @@ func (db *DatabaseCollectionWithUser) SyncFnDryrun(ctx context.Context, body Bod return nil, err, nil } - output, err := db.ChannelMapper.MapToChannelsAndAccess(ctx, mutableBody, string(oldDoc._rawBody), metaMap, - MakeUserCtx(db.user, db.ScopeName, db.Name)) + syncOptions, err := MakeUserCtx(db.user, db.ScopeName, db.Name) + if err != nil { + return nil, err, nil + } + output, err := db.ChannelMapper.MapToChannelsAndAccess(ctx, mutableBody, string(oldDoc._rawBody), metaMap, syncOptions) return output, nil, err } @@ -3264,8 +3267,12 @@ func (col *DatabaseCollectionWithUser) getChannelsAndAccess(ctx context.Context, var output *channels.ChannelMapperOutput startTime := time.Now() - output, err = col.ChannelMapper.MapToChannelsAndAccess(ctx, body, oldJson, metaMap, - MakeUserCtx(col.user, col.ScopeName, col.Name)) + var syncOptions map[string]any + syncOptions, err = MakeUserCtx(col.user, col.ScopeName, col.Name) + if err != nil { + return result, access, roles, expiry, oldJson, err + } + output, err = col.ChannelMapper.MapToChannelsAndAccess(ctx, body, oldJson, metaMap, syncOptions) syncFunctionTimeNano := time.Since(startTime).Nanoseconds() col.dbStats().Database().SyncFunctionTime.Add(syncFunctionTimeNano) @@ -3320,15 +3327,19 @@ func (col *DatabaseCollectionWithUser) getChannelsAndAccess(ctx context.Context, } // Creates a userCtx object to be passed to the sync function -func MakeUserCtx(user auth.User, scopeName string, collectionName string) map[string]any { +func MakeUserCtx(user auth.User, scopeName string, collectionName string) (map[string]any, error) { if user == nil { - return nil + return nil, nil + } + channels, err := user.InheritedCollectionChannels(scopeName, collectionName) + if err != nil { + return nil, err } return map[string]any{ "name": user.Name(), "roles": user.RoleNames(), - "channels": user.InheritedCollectionChannels(scopeName, collectionName).AllKeys(), - } + "channels": channels.AllKeys(), + }, nil } // Are the principal and role names in an AccessMap all valid? diff --git a/db/database_test.go b/db/database_test.go index 7e04e53a73..cc10be8878 100644 --- a/db/database_test.go +++ b/db/database_test.go @@ -2147,7 +2147,9 @@ func TestAccessFunctionDb(t *testing.T) { assert.Equal(t, expected, user.CollectionChannels(collection.ScopeName, collection.Name)) expected.AddChannel("CrunchyRoll", 2) - assert.Equal(t, expected, user.InheritedCollectionChannels(collection.ScopeName, collection.Name)) + channels, err := user.InheritedCollectionChannels(collection.ScopeName, collection.Name) + require.NoError(t, err) + assert.Equal(t, expected, channels) } func TestDocIDs(t *testing.T) { diff --git a/db/design_doc.go b/db/design_doc.go index e6721430da..07e2ed1b11 100644 --- a/db/design_doc.go +++ b/db/design_doc.go @@ -268,20 +268,23 @@ func (db *Database) QueryDesignDoc(ctx context.Context, ddocName string, viewNam stripSyncProperty(row) } } - } else { - applyChannelFiltering := options["reduce"] != true && db.GetUserViewsEnabled() - result = filterViewResult(result, db.user, applyChannelFiltering) + return &result, nil + } + applyChannelFiltering := options["reduce"] != true && db.GetUserViewsEnabled() + result, err = filterViewResult(result, db.user, applyChannelFiltering) + if err != nil { + return nil, err } return &result, nil } // Cleans up the Value property, and removes rows that aren't visible to the current user -func filterViewResult(input sgbucket.ViewResult, user auth.User, applyChannelFiltering bool) (result sgbucket.ViewResult) { +func filterViewResult(input sgbucket.ViewResult, user auth.User, applyChannelFiltering bool) (result sgbucket.ViewResult, err error) { hasStarChannel := false var visibleChannels ch.TimedSet if user != nil { // Views only support default collection, so filter based on default collection channels - visibleChannels = user.InheritedCollectionChannels(base.DefaultScope, base.DefaultCollection) + visibleChannels, err = user.InheritedCollectionChannels(base.DefaultScope, base.DefaultCollection) hasStarChannel = !visibleChannels.Contains("*") if !applyChannelFiltering { return // this is an error diff --git a/db/functions/function.go b/db/functions/function.go index 4767f5640f..3a4cb69576 100644 --- a/db/functions/function.go +++ b/db/functions/function.go @@ -281,9 +281,16 @@ func (fn *functionImpl) authorize(user auth.User, args map[string]any) error { for _, channelPattern := range allow.Channels { if channelPattern == channels.AllChannelWildcard { return nil - } else if channel, err := expandPattern(channelPattern, args, user); err != nil { + } + channel, err := expandPattern(channelPattern, args, user) + if err != nil { + return err + } + canSee, err := user.CanSeeCollectionChannel(base.DefaultScope, base.DefaultCollection, channel) + if err != nil { return err - } else if user.CanSeeCollectionChannel(base.DefaultScope, base.DefaultCollection, channel) { + } + if canSee { return nil // User has access to one of the allowed channels } } diff --git a/db/functions/js_function.go b/db/functions/js_function.go index a1c235999d..becdb88112 100644 --- a/db/functions/js_function.go +++ b/db/functions/js_function.go @@ -34,7 +34,11 @@ func (fn *jsInvocation) Iterate() (sgbucket.QueryResultIterator, error) { } func (fn *jsInvocation) Run(ctx context.Context) (any, error) { - return fn.call(ctx, db.MakeUserCtx(fn.db.User(), base.DefaultScope, base.DefaultCollection), fn.args) + syncCtx, err := db.MakeUserCtx(fn.db.User(), base.DefaultScope, base.DefaultCollection) + if err != nil { + return nil, err + } + return fn.call(ctx, syncCtx, fn.args) } func (fn *jsInvocation) call(ctx context.Context, jsArgs ...any) (any, error) { diff --git a/db/util_testing.go b/db/util_testing.go index 4f9085737d..c5bb20f39e 100644 --- a/db/util_testing.go +++ b/db/util_testing.go @@ -493,7 +493,9 @@ func (dbc *DatabaseContext) GetPrincipalForTest(tb testing.TB, name string, isUs info.Name = &name info.ExplicitChannels = princ.CollectionExplicitChannels(base.DefaultScope, base.DefaultCollection).AsSet() if user, ok := princ.(auth.User); ok { - info.Channels = user.InheritedCollectionChannels(base.DefaultScope, base.DefaultCollection).AsSet() + channels, err := user.InheritedCollectionChannels(base.DefaultScope, base.DefaultCollection) + require.NoError(tb, err) + info.Channels = channels.AsSet() email := user.Email() info.Email = &email info.Disabled = base.Ptr(user.Disabled()) diff --git a/rest/admin_api.go b/rest/admin_api.go index 7b28759398..0ce79701f3 100644 --- a/rest/admin_api.go +++ b/rest/admin_api.go @@ -1724,7 +1724,7 @@ func externalUserName(name string) string { } // marshalPrincipal outputs a PrincipalConfig in a format for REST API endpoints. -func marshalPrincipal(database *db.Database, princ auth.Principal, includeDynamicGrantInfo bool) auth.PrincipalConfig { +func marshalPrincipal(database *db.Database, princ auth.Principal, includeDynamicGrantInfo bool) (*auth.PrincipalConfig, error) { name := externalUserName(princ.Name()) info := auth.PrincipalConfig{ Name: &name, @@ -1747,7 +1747,11 @@ func marshalPrincipal(database *db.Database, princ auth.Principal, includeDynami } if includeDynamicGrantInfo { if user, ok := princ.(auth.User); ok { - collectionAccessConfig.Channels_ = user.InheritedCollectionChannels(scopeName, collectionName).AsSet() + channels, err := user.InheritedCollectionChannels(scopeName, collectionName) + if err != nil { + return nil, err + } + collectionAccessConfig.Channels_ = channels.AsSet() collectionAccessConfig.JWTChannels_ = user.CollectionJWTChannels(scopeName, collectionName).AsSet() lastUpdated := collection.JWTLastUpdated if lastUpdated != nil && !lastUpdated.IsZero() { @@ -1769,7 +1773,11 @@ func marshalPrincipal(database *db.Database, princ auth.Principal, includeDynami info.Disabled = base.Ptr(user.Disabled()) info.ExplicitRoleNames = user.ExplicitRoles().AsSet() if includeDynamicGrantInfo { - info.Channels = user.InheritedCollectionChannels(base.DefaultScope, base.DefaultCollection).AsSet() + channels, err := user.InheritedCollectionChannels(base.DefaultScope, base.DefaultCollection) + if err != nil { + return nil, err + } + info.Channels = channels.AsSet() info.RoleNames = user.RoleNames().AllKeys() info.JWTIssuer = base.Ptr(user.JWTIssuer()) info.JWTRoles = user.JWTRoles().AsSet() @@ -1784,7 +1792,7 @@ func marshalPrincipal(database *db.Database, princ auth.Principal, includeDynami info.Channels = princ.Channels().AsSet() } } - return info + return &info, nil } // Handles PUT and POST for a user or a role. @@ -1961,7 +1969,10 @@ func (h *handler) getUserInfo() error { } // If not specified will default to false includeDynamicGrantInfo := h.permissionsResults[PermReadPrincipalAppData.PermissionName] - info := marshalPrincipal(h.db, user, includeDynamicGrantInfo) + info, err := marshalPrincipal(h.db, user, includeDynamicGrantInfo) + if err != nil { + return err + } // If the user's OIDC issuer is no longer valid, remove the OIDC information to avoid confusing users // (it'll get removed permanently the next time the user signs in) if info.JWTIssuer != nil { @@ -2002,7 +2013,10 @@ func (h *handler) getRoleInfo() error { } // If not specified will default to false includeDynamicGrantInfo := h.permissionsResults[PermReadPrincipalAppData.PermissionName] - info := marshalPrincipal(h.db, role, includeDynamicGrantInfo) + info, err := marshalPrincipal(h.db, role, includeDynamicGrantInfo) + if err != nil { + return err + } b, err := base.JSONMarshal(info) if err == nil { base.Audit(h.ctx(), base.AuditIDRoleRead, base.AuditFields{ diff --git a/rest/bulk_api.go b/rest/bulk_api.go index 77007ce322..19157e1c76 100644 --- a/rest/bulk_api.go +++ b/rest/bulk_api.go @@ -97,10 +97,14 @@ func (h *handler) handleAllDocs() error { // Get the set of channels the user has access to; nil if user is admin or has access to user "*" var availableChannels ch.TimedSet if h.user != nil { - availableChannels = h.user.InheritedCollectionChannels(h.collection.ScopeName, h.collection.Name) + var err error + availableChannels, err = h.user.InheritedCollectionChannels(h.collection.ScopeName, h.collection.Name) + if err != nil { + return err + } if availableChannels == nil { - // TODO: CBG-1948 - panic("no channels for user?") + base.AssertfCtx(h.ctx(), "User %q has no channels in handleAllDocs", base.UD(h.user.Name())) + return base.HTTPErrorf(http.StatusInternalServerError, "user has no channels") } if availableChannels.Contains(ch.UserStarChannel) { availableChannels = nil diff --git a/rest/diagnostic_api.go b/rest/diagnostic_api.go index e668504621..d6a8cf12f4 100644 --- a/rest/diagnostic_api.go +++ b/rest/diagnostic_api.go @@ -43,7 +43,10 @@ func (h *handler) getAllUserChannelsResponse(user auth.User) (map[string]map[str for _, dsName := range h.db.DataStoreNames() { keyspace := dsName.ScopeName() + "." + dsName.CollectionName() - currentChannels := user.InheritedCollectionChannels(dsName.ScopeName(), dsName.CollectionName()) + currentChannels, err := user.InheritedCollectionChannels(dsName.ScopeName(), dsName.CollectionName()) + if err != nil { + return nil, fmt.Errorf("error getting user channels: %w", err) + } chanHistory := user.CollectionChannelHistory(dsName.ScopeName(), dsName.CollectionName()) // If no channels aside from public and no channels in history, don't make a key for this keyspace diff --git a/rest/handler.go b/rest/handler.go index 53f1dcd379..0110ca11cc 100644 --- a/rest/handler.go +++ b/rest/handler.go @@ -883,18 +883,21 @@ func needRolesForAudit(db *db.DatabaseContext, domain base.UserIDDomain) bool { } // getSGUserRolesForAudit returns a list of role names for the given user, if audit role filtering is enabled. -func getSGUserRolesForAudit(db *db.DatabaseContext, user auth.User) []string { +func getSGUserRolesForAudit(db *db.DatabaseContext, user auth.User) ([]string, error) { if user == nil { - return nil + return nil, nil } if !needRolesForAudit(db, base.UserDomainSyncGateway) { - return nil + return nil, nil } - roles := user.GetRoles() + roles, err := user.GetRoles() + if err != nil { + return nil, err + } if len(roles) == 0 { - return nil + return nil, nil } roleNames := make([]string, 0, len(roles)) @@ -902,11 +905,11 @@ func getSGUserRolesForAudit(db *db.DatabaseContext, user auth.User) []string { roleNames = append(roleNames, role.Name()) } - return roleNames + return roleNames, nil } -// checkPublicAuth verifies that the current request is authenticated for the given database. -// +// checkPublicAuth verifies that the current request is authenticated for the given database. Returns an HTTPError if +// authentication fails. // NOTE: checkPublicAuth is not used for the admin interface. func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) { @@ -915,40 +918,53 @@ func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) { return nil } - var auditFields base.AuditFields - - // Record Auth stats - defer func(t time.Time) { - delta := time.Since(t).Nanoseconds() - dbCtx.DbStats.Security().TotalAuthTime.Add(delta) - if err != nil { - dbCtx.DbStats.Security().AuthFailedCount.Add(1) - if errors.Is(err, ErrInvalidLogin) { - base.Audit(h.ctx(), base.AuditIDPublicUserAuthenticationFailed, auditFields) - } - } else { - dbCtx.DbStats.Security().AuthSuccessCount.Add(1) - - username := "" - if h.isGuest() { - username = base.GuestUsername - } else if h.user != nil { - username = h.user.Name() - } - roleNames := getSGUserRolesForAudit(dbCtx, h.user) - h.rqCtx = base.UserLogCtx(h.ctx(), username, base.UserDomainSyncGateway, roleNames) - base.Audit(h.ctx(), base.AuditIDPublicUserAuthenticated, auditFields) + start := time.Now() + auditFields, err := h.setUserForPublicAuth(dbCtx) + dbCtx.DbStats.Security().TotalAuthTime.Add(time.Since(start).Nanoseconds()) + if err != nil { + dbCtx.DbStats.Security().AuthFailedCount.Add(1) + if errors.Is(err, ErrInvalidLogin) { + base.Audit(h.ctx(), base.AuditIDPublicUserAuthenticationFailed, auditFields) } - }(time.Now()) + return + } + dbCtx.DbStats.Security().AuthSuccessCount.Add(1) + + username := "" + if h.isGuest() { + username = base.GuestUsername + } else if h.user != nil { + username = h.user.Name() + } + roleNames, err := getSGUserRolesForAudit(dbCtx, h.user) + if err != nil { + base.InfofCtx(h.ctx(), base.KeyHTTP, "Error getting user roles for audit logging: %v", err) + } + h.rqCtx = base.UserLogCtx(h.ctx(), username, base.UserDomainSyncGateway, roleNames) + base.Audit(h.ctx(), base.AuditIDPublicUserAuthenticated, auditFields) + return err +} +// setUserForPublicAuth sets h.user based on the authentication information in the request. Returns an error if the user +// can not authenticate successfully, and returns AuditFields even in the case that there is an error in the request. +// +// Uses: +// +// 1. Bearer token (OIDC JWT) if present and OIDC is enabled +// 2. Basic auth if present and password authentication is not disabled +// 3. Cookie auth if present +// 4. Guest access if enabled +func (h *handler) setUserForPublicAuth(dbCtx *db.DatabaseContext) (base.AuditFields, error) { + var auditFields base.AuditFields // If oidc enabled, check for bearer ID token if dbCtx.Options.OIDCOptions != nil || len(dbCtx.LocalJWTProviders) > 0 { if token := h.getBearerToken(); token != "" { auditFields = base.AuditFields{base.AuditFieldAuthMethod: "bearer"} var updates auth.PrincipalConfig + var err error h.user, updates, err = dbCtx.Authenticator(h.ctx()).AuthenticateUntrustedJWT(token, dbCtx.OIDCProviders, dbCtx.LocalJWTProviders, h.getOIDCCallbackURL) if h.user == nil || err != nil { - return ErrInvalidLogin + return auditFields, ErrInvalidLogin } if issuer := h.user.JWTIssuer(); issuer != "" { auditFields["oidc_issuer"] = issuer @@ -956,15 +972,15 @@ func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) { if changes := checkJWTIssuerStillValid(h.ctx(), dbCtx, h.user); changes != nil { updates = updates.Merge(*changes) } - _, _, err := dbCtx.UpdatePrincipal(h.ctx(), &updates, true, true) + _, _, err = dbCtx.UpdatePrincipal(h.ctx(), &updates, true, true) if err != nil { - return fmt.Errorf("failed to update OIDC user after sign-in: %w", err) + return auditFields, fmt.Errorf("failed to update OIDC user after sign-in: %w", err) } // TODO: could avoid this extra fetch if UpdatePrincipal returned the newly updated principal if updates.Name != nil { h.user, err = dbCtx.Authenticator(h.ctx()).GetUser(*updates.Name) } - return err + return auditFields, err } /* @@ -979,8 +995,7 @@ func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) { provider := dbCtx.Options.OIDCOptions.Providers.GetProviderForIssuer(h.ctx(), issuerUrlForDB(h, dbCtx.Name), testProviderAudiences) if provider != nil && provider.ValidationKey != nil { if base.ValDefault(provider.ClientID, "") == username && *provider.ValidationKey == password { - auditFields = base.AuditFields{base.AuditFieldAuthMethod: "basic"} - return nil + return base.AuditFields{base.AuditFieldAuthMethod: "basic"}, nil } } } @@ -990,47 +1005,49 @@ func (h *handler) checkPublicAuth(dbCtx *db.DatabaseContext) (err error) { // Check basic auth first if !dbCtx.Options.DisablePasswordAuthentication { if userName, password := h.getBasicAuth(); userName != "" { - auditFields = base.AuditFields{base.AuditFieldAuthMethod: "basic"} + auditFields := base.AuditFields{base.AuditFieldAuthMethod: "basic"} + var err error h.user, err = dbCtx.Authenticator(h.ctx()).AuthenticateUser(userName, password) if err != nil { - return err + return auditFields, err } if h.user == nil { auditFields["username"] = userName if dbCtx.Options.SendWWWAuthenticateHeader == nil || *dbCtx.Options.SendWWWAuthenticateHeader { h.response.Header().Set("WWW-Authenticate", wwwAuthenticateHeader) } - return ErrInvalidLogin + return auditFields, ErrInvalidLogin } - return nil + return auditFields, nil } } // Check cookie auditFields = base.AuditFields{base.AuditFieldAuthMethod: "cookie"} + var err error h.user, err = dbCtx.Authenticator(h.ctx()).AuthenticateCookie(h.rq, h.response) if err != nil && h.privs != publicPrivs { - return err + return auditFields, err } else if h.user != nil { - return nil + return auditFields, nil } // No auth given -- check guest access auditFields = base.AuditFields{base.AuditFieldAuthMethod: "guest"} if h.user, err = dbCtx.Authenticator(h.ctx()).GetUser(""); err != nil { - return err + return auditFields, err } if h.privs == regularPrivs && h.user.Disabled() { if dbCtx.Options.SendWWWAuthenticateHeader == nil || *dbCtx.Options.SendWWWAuthenticateHeader { h.response.Header().Set("WWW-Authenticate", wwwAuthenticateHeader) } if h.providedAuthCredentials() { - return ErrInvalidLogin + return auditFields, ErrInvalidLogin } - return ErrLoginRequired + return auditFields, ErrLoginRequired } - return nil + return auditFields, nil } func checkJWTIssuerStillValid(ctx context.Context, dbCtx *db.DatabaseContext, user auth.User) *auth.PrincipalConfig { diff --git a/rest/role_api_test.go b/rest/role_api_test.go index 7cdd1ad948..7e44919ac8 100644 --- a/rest/role_api_test.go +++ b/rest/role_api_test.go @@ -339,13 +339,16 @@ func TestRoleAccessChanges(t *testing.T) { RequireStatus(t, rt.SendRequest("PUT", "/{{.keyspace}}/d1", `{"channel":"delta"}`), 201) // Check user access: - alice, _ := a.GetUser("alice") + alice, err := a.GetUser("alice") + require.NoError(t, err) + chs, err := alice.InheritedCollectionChannels(s, c) + require.NoError(t, err) assert.Equal(t, channels.TimedSet{ "!": channels.NewVbSimpleSequence(1), "alpha": channels.NewVbSimpleSequence(alice.Sequence()), "gamma": channels.NewVbSimpleSequence(roleGrantSequence), - }, alice.InheritedCollectionChannels(s, c)) + }, chs) assert.Equal(t, channels.TimedSet{ @@ -353,13 +356,15 @@ func TestRoleAccessChanges(t *testing.T) { "hipster": channels.NewVbSimpleSequence(roleGrantSequence), }, alice.RoleNames()) - zegpold, _ := a.GetUser("zegpold") + zegpold, err := a.GetUser("zegpold") + require.NoError(t, err) + chs, err = zegpold.InheritedCollectionChannels(s, c) + require.NoError(t, err) assert.Equal(t, - channels.TimedSet{ "!": channels.NewVbSimpleSequence(1), "beta": channels.NewVbSimpleSequence(zegpold.Sequence()), - }, zegpold.InheritedCollectionChannels(s, c)) + }, chs) assert.Equal(t, channels.TimedSet{}, zegpold.RoleNames()) @@ -387,20 +392,26 @@ func TestRoleAccessChanges(t *testing.T) { updatedRoleGrantSequence := rt.GetDocumentSequence("fashion") // Check user access again: - alice, _ = a.GetUser("alice") + alice, err = a.GetUser("alice") + require.NoError(t, err) + chs, err = alice.InheritedCollectionChannels(s, c) + require.NoError(t, err) assert.Equal(t, channels.TimedSet{ "!": channels.NewVbSimpleSequence(0x1), "alpha": channels.NewVbSimpleSequence(alice.Sequence()), - }, alice.InheritedCollectionChannels(s, c)) + }, chs) - zegpold, _ = a.GetUser("zegpold") + zegpold, err = a.GetUser("zegpold") + require.NoError(t, err) + chs, err = zegpold.InheritedCollectionChannels(s, c) + require.NoError(t, err) assert.Equal(t, channels.TimedSet{ "!": channels.NewVbSimpleSequence(0x1), "beta": channels.NewVbSimpleSequence(zegpold.Sequence()), "gamma": channels.NewVbSimpleSequence(updatedRoleGrantSequence), - }, zegpold.InheritedCollectionChannels(s, c)) + }, chs) // The complete _changes feed for zegpold contains docs g1 and b1: cacheWaiter.Wait() diff --git a/rest/user_api_test.go b/rest/user_api_test.go index 33e73a2dd0..53a8d06ff0 100644 --- a/rest/user_api_test.go +++ b/rest/user_api_test.go @@ -631,7 +631,10 @@ func TestObtainUserChannelsForDeletedRoleCasFail(t *testing.T) { triggerCallback = true } - assert.Equal(t, []string{"!"}, user.InheritedCollectionChannels(s, c).AllKeys()) + chs, err := user.InheritedCollectionChannels(s, c) + assert.NoError(t, err) + + assert.Equal(t, []string{"!"}, chs.AllKeys()) // Ensure callback ran assert.False(t, triggerCallback)