Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 73 additions & 61 deletions examples/simple/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ func main() {
Level: hclog.Debug,
})

// a very simple way to track authenticated connections
authenticatedConnections := map[int]struct{}{}

// create a new server
s, err := gldap.NewServer(gldap.WithLogger(l), gldap.WithDisablePanicRecovery())
if err != nil {
Expand All @@ -29,8 +32,8 @@ func main() {
if err != nil {
log.Fatalf("unable to create router: %s", err.Error())
}
r.Bind(bindHandler)
r.Search(searchHandler, gldap.WithLabel("All Searches"))
r.Bind(bindHandler(authenticatedConnections))
r.Search(searchHandler(authenticatedConnections), gldap.WithLabel("All Searches"))
s.Router(r)
go s.Run(":10389") // listen on port 10389

Expand All @@ -43,69 +46,78 @@ func main() {
}
}

func bindHandler(w *gldap.ResponseWriter, r *gldap.Request) {
resp := r.NewBindResponse(
gldap.WithResponseCode(gldap.ResultInvalidCredentials),
)
defer func() {
w.Write(resp)
}()

m, err := r.GetSimpleBindMessage()
if err != nil {
log.Printf("not a simple bind message: %s", err)
return
}
func bindHandler(authenticatedConnections map[int]struct{}) func(*gldap.ResponseWriter, *gldap.Request) {
return func(w *gldap.ResponseWriter, r *gldap.Request) {
resp := r.NewBindResponse(
gldap.WithResponseCode(gldap.ResultInvalidCredentials),
)
defer func() {
w.Write(resp)
}()

if m.UserName == "alice" {
resp.SetResultCode(gldap.ResultSuccess)
log.Println("bind success")
return
m, err := r.GetSimpleBindMessage()
if err != nil {
log.Printf("not a simple bind message: %s", err)
return
}
if m.UserName == "uid=alice" {
authenticatedConnections[r.ConnectionID()] = struct{}{} // mark connection as authenticated
resp.SetResultCode(gldap.ResultSuccess)
log.Println("bind success")
return
}
}
}

func searchHandler(w *gldap.ResponseWriter, r *gldap.Request) {
resp := r.NewSearchDoneResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject))
defer func() {
w.Write(resp)
}()
m, err := r.GetSearchMessage()
if err != nil {
log.Printf("not a search message: %s", err)
return
}
log.Printf("search base dn: %s", m.BaseDN)
log.Printf("search scope: %d", m.Scope)
log.Printf("search filter: %s", m.Filter)
func searchHandler(authenticatedConnections map[int]struct{}) func(w *gldap.ResponseWriter, r *gldap.Request) {
return func(w *gldap.ResponseWriter, r *gldap.Request) {
resp := r.NewSearchDoneResponse(gldap.WithResponseCode(gldap.ResultNoSuchObject))
defer func() {
w.Write(resp)
}()
// check if connection is authenticated
if _, ok := authenticatedConnections[r.ConnectionID()]; !ok {
log.Printf("connection %d is not authorized", r.ConnectionID())
resp.SetResultCode(gldap.ResultAuthorizationDenied)
return
}
m, err := r.GetSearchMessage()
if err != nil {
log.Printf("not a search message: %s", err)
return
}
log.Printf("search base dn: %s", m.BaseDN)
log.Printf("search scope: %d", m.Scope)
log.Printf("search filter: %s", m.Filter)

if strings.Contains(m.Filter, "uid=alice") || m.BaseDN == "uid=alice,ou=people,cn=example,dc=org" {
entry := r.NewSearchResponseEntry(
"uid=alice,ou=people,cn=example,dc=org",
gldap.WithAttributes(map[string][]string{
"objectclass": {"top", "person", "organizationalPerson", "inetOrgPerson"},
"uid": {"alice"},
"cn": {"alice eve smith"},
"givenname": {"alice"},
"sn": {"smith"},
"ou": {"people"},
"description": {"friend of Rivest, Shamir and Adleman"},
"password": {"{SSHA}U3waGJVC7MgXYc0YQe7xv7sSePuTP8zN"},
}),
)
entry.AddAttribute("email", []string{"[email protected]"})
w.Write(entry)
resp.SetResultCode(gldap.ResultSuccess)
}
if m.BaseDN == "ou=people,cn=example,dc=org" {
entry := r.NewSearchResponseEntry(
"ou=people,cn=example,dc=org",
gldap.WithAttributes(map[string][]string{
"objectclass": {"organizationalUnit"},
"ou": {"people"},
}),
)
w.Write(entry)
resp.SetResultCode(gldap.ResultSuccess)
if strings.Contains(m.Filter, "uid=alice") || m.BaseDN == "uid=alice,ou=people,cn=example,dc=org" {
entry := r.NewSearchResponseEntry(
"uid=alice,ou=people,cn=example,dc=org",
gldap.WithAttributes(map[string][]string{
"objectclass": {"top", "person", "organizationalPerson", "inetOrgPerson"},
"uid": {"alice"},
"cn": {"alice eve smith"},
"givenname": {"alice"},
"sn": {"smith"},
"ou": {"people"},
"description": {"friend of Rivest, Shamir and Adleman"},
"password": {"{SSHA}U3waGJVC7MgXYc0YQe7xv7sSePuTP8zN"},
}),
)
entry.AddAttribute("email", []string{"[email protected]"})
w.Write(entry)
resp.SetResultCode(gldap.ResultSuccess)
}
if m.BaseDN == "ou=people,cn=example,dc=org" {
entry := r.NewSearchResponseEntry(
"ou=people,cn=example,dc=org",
gldap.WithAttributes(map[string][]string{
"objectclass": {"organizationalUnit"},
"ou": {"people"},
}),
)
w.Write(entry)
resp.SetResultCode(gldap.ResultSuccess)
}
}
return
}
8 changes: 8 additions & 0 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ func newRequest(id int, c *conn, p *packet) (*Request, error) {
return r, nil
}

// ConnectionID returns the request's connection ID which enables you to know
// "who" (i.e. which connection) made a request. Using the connection ID you
// can do things like ensure a connection performing a search operation has
// successfully authenticated (a.k.a. performed a successful bind operation).
func (r *Request) ConnectionID() int {
return r.conn.connID
}

// NewModifyResponse creates a modify response
// Supported options: WithResponseCode, WithDiagnosticMessage, WithMatchedDN
func (r *Request) NewModifyResponse(opt ...Option) *ModifyResponse {
Expand Down
16 changes: 16 additions & 0 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,19 @@ func TestRequest_GetDeleteMessage(t *testing.T) {
})
}
}

func TestRequest_GetConnectionID(t *testing.T) {
t.Parallel()
assert, require := assert.New(t), require.New(t)
const (
requestID = 1
connID = 2
)
conn := &conn{connID: connID}
packet := testSearchRequestPacket(t,
SearchMessage{baseMessage: baseMessage{id: 1}, Filter: "(uid=alice)"},
)
req, err := newRequest(requestID, conn, packet)
require.NoError(err)
assert.Equal(connID, req.ConnectionID())
}