diff --git a/.gitignore b/.gitignore index 3dafce90..301a677c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /bin/ settings.json integrationtests/integrationtests.test.exe +.vs/ \ No newline at end of file diff --git a/pkg/server/smb/server.go b/pkg/server/smb/server.go index 6340270c..76a74598 100644 --- a/pkg/server/smb/server.go +++ b/pkg/server/smb/server.go @@ -25,6 +25,26 @@ func normalizeWindowsPath(path string) string { return normalizedPath } +func getRootMappingPath(path string) (string, error) { + items := strings.Split(path, "\\") + parts := []string{} + for _, s := range items { + if len(s) > 0 { + parts = append(parts, s) + if len(parts) == 2 { + break + } + } + } + if len(parts) != 2 { + klog.Errorf("remote path (%s) is invalid", path) + return "", fmt.Errorf("remote path (%s) is invalid", path) + } + // parts[0] is a smb host name + // parts[1] is a smb share name + return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil +} + func NewServer(hostAPI smb.API, fsServer *fsserver.Server) (*Server, error) { return &Server{ hostAPI: hostAPI, @@ -43,31 +63,40 @@ func (s *Server) NewSmbGlobalMapping(context context.Context, request *internal. return response, fmt.Errorf("remote path is empty") } - isMapped, err := s.hostAPI.IsSmbMapped(remotePath) + mappingPath, err := getRootMappingPath(remotePath) + if err != nil { + return response, err + } + + isMapped, err := s.hostAPI.IsSmbMapped(mappingPath) if err != nil { isMapped = false } if isMapped { - valid, err := s.fsServer.PathValid(context, remotePath) + klog.V(4).Infof("Remote %s already mapped. Validating...", mappingPath) + + valid, err := s.fsServer.PathValid(context, mappingPath) if err != nil { - klog.Warningf("PathValid(%s) failed with %v, ignore error", remotePath, err) + klog.Warningf("PathValid(%s) failed with %v, ignore error", mappingPath, err) } if !valid { - klog.V(4).Infof("RemotePath %s is not valid, removing now", remotePath) - err := s.hostAPI.RemoveSmbGlobalMapping(remotePath) + klog.V(4).Infof("RemotePath %s is not valid, removing now", mappingPath) + err := s.hostAPI.RemoveSmbGlobalMapping(mappingPath) if err != nil { - klog.Errorf("RemoveSmbGlobalMapping(%s) failed with %v", remotePath, err) + klog.Errorf("RemoveSmbGlobalMapping(%s) failed with %v", mappingPath, err) return response, err } isMapped = false + } else { + klog.V(4).Infof("RemotePath %s is valid", mappingPath) } } if !isMapped { - klog.V(4).Infof("Remote %s not mapped. Mapping now!", remotePath) - err := s.hostAPI.NewSmbGlobalMapping(remotePath, request.Username, request.Password) + klog.V(4).Infof("Remote %s not mapped. Mapping now!", mappingPath) + err := s.hostAPI.NewSmbGlobalMapping(mappingPath, request.Username, request.Password) if err != nil { klog.Errorf("failed NewSmbGlobalMapping %v", err) return response, err @@ -75,6 +104,7 @@ func (s *Server) NewSmbGlobalMapping(context context.Context, request *internal. } if len(localPath) != 0 { + klog.V(4).Infof("ValidatePluginPath: '%s'", localPath) err = s.fsServer.ValidatePluginPath(localPath) if err != nil { klog.Errorf("failed validate plugin path %v", err) @@ -101,11 +131,17 @@ func (s *Server) RemoveSmbGlobalMapping(context context.Context, request *intern return response, fmt.Errorf("remote path is empty") } - err := s.hostAPI.RemoveSmbGlobalMapping(remotePath) + mappingPath, err := getRootMappingPath(remotePath) + if err != nil { + return response, err + } + + err = s.hostAPI.RemoveSmbGlobalMapping(mappingPath) if err != nil { klog.Errorf("failed RemoveSmbGlobalMapping %v", err) return response, err } + klog.V(2).Infof("RemoveSmbGlobalMapping on remote path %q is completed", request.RemotePath) return response, nil } diff --git a/pkg/server/smb/server_test.go b/pkg/server/smb/server_test.go index e2b2e649..5706d0ca 100644 --- a/pkg/server/smb/server_test.go +++ b/pkg/server/smb/server_test.go @@ -79,7 +79,7 @@ func TestNewSmbGlobalMapping(t *testing.T) { expectError: true, }, { - remote: "\\test\\path", + remote: "\\\\hostname\\path", username: "", password: "", version: v1, @@ -111,3 +111,51 @@ func TestNewSmbGlobalMapping(t *testing.T) { } } } + +func TestGetRootMappingPath(t *testing.T) { + testCases := []struct { + remote string + expectResult string + expectError bool + }{ + { + remote: "", + expectResult: "", + expectError: true, + }, + { + remote: "hostname", + expectResult: "", + expectError: true, + }, + { + remote: "\\\\hostname\\path", + expectResult: "\\\\hostname\\path", + expectError: false, + }, + { + remote: "\\\\hostname\\path\\", + expectResult: "\\\\hostname\\path", + expectError: false, + }, + { + remote: "\\\\hostname\\path\\subpath", + expectResult: "\\\\hostname\\path", + expectError: false, + }, + } + for _, tc := range testCases { + result, err := getRootMappingPath(tc.remote) + if tc.expectError && err == nil { + t.Errorf("Expected error but getRootMappingPath returned a nil error") + } + if !tc.expectError { + if err != nil { + t.Errorf("Expected no errors but getRootMappingPath returned error: %v", err) + } + if tc.expectResult != result { + t.Errorf("Expected (%s) but getRootMappingPath returned (%s)", tc.expectResult, result) + } + } + } +}