From 34143bc43de572b8e81e07dfcf9883b336730a10 Mon Sep 17 00:00:00 2001 From: Jim Date: Wed, 1 Jun 2022 09:14:22 -0400 Subject: [PATCH] feature (request): add Request.ConnectionID() This allows callers to correlated requests to a connection --- examples/simple/main.go | 134 ++++++++++++++++++++++------------------ request.go | 8 +++ request_test.go | 16 +++++ 3 files changed, 97 insertions(+), 61 deletions(-) diff --git a/examples/simple/main.go b/examples/simple/main.go index b3dff86..f99aab2 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -18,6 +18,9 @@ func main() { Level: hclog.Debug, }) + // a very simple way to track authenticated connections + authenticatedConnections := map[int]struct{}{} + // create a new server s, err := gldap.NewServer(gldap.WithLogger(l), gldap.WithDisablePanicRecovery()) if err != nil { @@ -29,8 +32,8 @@ func main() { if err != nil { log.Fatalf("unable to create router: %s", err.Error()) } - r.Bind(bindHandler) - r.Search(searchHandler, gldap.WithLabel("All Searches")) + r.Bind(bindHandler(authenticatedConnections)) + r.Search(searchHandler(authenticatedConnections), gldap.WithLabel("All Searches")) s.Router(r) go s.Run(":10389") // listen on port 10389 @@ -43,69 +46,78 @@ func main() { } } -func bindHandler(w *gldap.ResponseWriter, r *gldap.Request) { - resp := r.NewBindResponse( - gldap.WithResponseCode(gldap.ResultInvalidCredentials), - ) - defer func() { - w.Write(resp) - }() - - m, err := r.GetSimpleBindMessage() - if err != nil { - log.Printf("not a simple bind message: %s", err) - return - } +func bindHandler(authenticatedConnections map[int]struct{}) func(*gldap.ResponseWriter, *gldap.Request) { + return func(w *gldap.ResponseWriter, r *gldap.Request) { + resp := r.NewBindResponse( + gldap.WithResponseCode(gldap.ResultInvalidCredentials), + ) + defer func() { + w.Write(resp) + }() - if m.UserName == "alice" { - resp.SetResultCode(gldap.ResultSuccess) - log.Println("bind success") - return + m, err := r.GetSimpleBindMessage() + if err != nil { + log.Printf("not a simple bind message: %s", err) + return + } + if m.UserName == "uid=alice" { + authenticatedConnections[r.ConnectionID()] = struct{}{} // mark connection as authenticated + resp.SetResultCode(gldap.ResultSuccess) + log.Println("bind success") + return + } } } -func searchHandler(w *gldap.ResponseWriter, r *gldap.Request) { - resp := r.NewSearchDoneResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject)) - defer func() { - w.Write(resp) - }() - m, err := r.GetSearchMessage() - if err != nil { - log.Printf("not a search message: %s", err) - return - } - log.Printf("search base dn: %s", m.BaseDN) - log.Printf("search scope: %d", m.Scope) - log.Printf("search filter: %s", m.Filter) +func searchHandler(authenticatedConnections map[int]struct{}) func(w *gldap.ResponseWriter, r *gldap.Request) { + return func(w *gldap.ResponseWriter, r *gldap.Request) { + resp := r.NewSearchDoneResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject)) + defer func() { + w.Write(resp) + }() + // check if connection is authenticated + if _, ok := authenticatedConnections[r.ConnectionID()]; !ok { + log.Printf("connection %d is not authorized", r.ConnectionID()) + resp.SetResultCode(gldap.ResultAuthorizationDenied) + return + } + m, err := r.GetSearchMessage() + if err != nil { + log.Printf("not a search message: %s", err) + return + } + log.Printf("search base dn: %s", m.BaseDN) + log.Printf("search scope: %d", m.Scope) + log.Printf("search filter: %s", m.Filter) - if strings.Contains(m.Filter, "uid=alice") || m.BaseDN == "uid=alice,ou=people,cn=example,dc=org" { - entry := r.NewSearchResponseEntry( - "uid=alice,ou=people,cn=example,dc=org", - gldap.WithAttributes(map[string][]string{ - "objectclass": {"top", "person", "organizationalPerson", "inetOrgPerson"}, - "uid": {"alice"}, - "cn": {"alice eve smith"}, - "givenname": {"alice"}, - "sn": {"smith"}, - "ou": {"people"}, - "description": {"friend of Rivest, Shamir and Adleman"}, - "password": {"{SSHA}U3waGJVC7MgXYc0YQe7xv7sSePuTP8zN"}, - }), - ) - entry.AddAttribute("email", []string{"alice@example.org"}) - w.Write(entry) - resp.SetResultCode(gldap.ResultSuccess) - } - if m.BaseDN == "ou=people,cn=example,dc=org" { - entry := r.NewSearchResponseEntry( - "ou=people,cn=example,dc=org", - gldap.WithAttributes(map[string][]string{ - "objectclass": {"organizationalUnit"}, - "ou": {"people"}, - }), - ) - w.Write(entry) - resp.SetResultCode(gldap.ResultSuccess) + if strings.Contains(m.Filter, "uid=alice") || m.BaseDN == "uid=alice,ou=people,cn=example,dc=org" { + entry := r.NewSearchResponseEntry( + "uid=alice,ou=people,cn=example,dc=org", + gldap.WithAttributes(map[string][]string{ + "objectclass": {"top", "person", "organizationalPerson", "inetOrgPerson"}, + "uid": {"alice"}, + "cn": {"alice eve smith"}, + "givenname": {"alice"}, + "sn": {"smith"}, + "ou": {"people"}, + "description": {"friend of Rivest, Shamir and Adleman"}, + "password": {"{SSHA}U3waGJVC7MgXYc0YQe7xv7sSePuTP8zN"}, + }), + ) + entry.AddAttribute("email", []string{"alice@example.org"}) + w.Write(entry) + resp.SetResultCode(gldap.ResultSuccess) + } + if m.BaseDN == "ou=people,cn=example,dc=org" { + entry := r.NewSearchResponseEntry( + "ou=people,cn=example,dc=org", + gldap.WithAttributes(map[string][]string{ + "objectclass": {"organizationalUnit"}, + "ou": {"people"}, + }), + ) + w.Write(entry) + resp.SetResultCode(gldap.ResultSuccess) + } } - return } diff --git a/request.go b/request.go index 838c7a9..fee2b2e 100644 --- a/request.go +++ b/request.go @@ -79,6 +79,14 @@ func newRequest(id int, c *conn, p *packet) (*Request, error) { return r, nil } +// ConnectionID returns the request's connection ID which enables you to know +// "who" (i.e. which connection) made a request. Using the connection ID you +// can do things like ensure a connection performing a search operation has +// successfully authenticated (a.k.a. performed a successful bind operation). +func (r *Request) ConnectionID() int { + return r.conn.connID +} + // NewModifyResponse creates a modify response // Supported options: WithResponseCode, WithDiagnosticMessage, WithMatchedDN func (r *Request) NewModifyResponse(opt ...Option) *ModifyResponse { diff --git a/request_test.go b/request_test.go index b3a626c..878ca65 100644 --- a/request_test.go +++ b/request_test.go @@ -299,3 +299,19 @@ func TestRequest_GetDeleteMessage(t *testing.T) { }) } } + +func TestRequest_GetConnectionID(t *testing.T) { + t.Parallel() + assert, require := assert.New(t), require.New(t) + const ( + requestID = 1 + connID = 2 + ) + conn := &conn{connID: connID} + packet := testSearchRequestPacket(t, + SearchMessage{baseMessage: baseMessage{id: 1}, Filter: "(uid=alice)"}, + ) + req, err := newRequest(requestID, conn, packet) + require.NoError(err) + assert.Equal(connID, req.ConnectionID()) +}