Skip to content

Commit 2771600

Browse files
Calls RemoveSmbGlobalMapping when it necessary.
1 parent a408bba commit 2771600

File tree

5 files changed

+414
-13
lines changed

5 files changed

+414
-13
lines changed

pkg/mounter/refcounter_windows.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//go:build windows
2+
// +build windows
3+
4+
/*
5+
Copyright 2020 The Kubernetes Authors.
6+
7+
Licensed under the Apache License, Version 2.0 (the "License");
8+
you may not use this file except in compliance with the License.
9+
You may obtain a copy of the License at
10+
11+
http://www.apache.org/licenses/LICENSE-2.0
12+
13+
Unless required by applicable law or agreed to in writing, software
14+
distributed under the License is distributed on an "AS IS" BASIS,
15+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
See the License for the specific language governing permissions and
17+
limitations under the License.
18+
*/
19+
20+
package mounter
21+
22+
import (
23+
"crypto/md5"
24+
"fmt"
25+
"os"
26+
"path/filepath"
27+
"strings"
28+
"sync"
29+
)
30+
31+
var basePath = "c:\\csi\\smbmounts"
32+
var mutexes sync.Map
33+
34+
func lock(key string) func() {
35+
value, _ := mutexes.LoadOrStore(key, &sync.Mutex{})
36+
mtx := value.(*sync.Mutex)
37+
mtx.Lock()
38+
39+
return func() { mtx.Unlock() }
40+
}
41+
42+
// getRootMappingPath - returns root of smb share path or empty string if the path is invalid. For example:
43+
//
44+
// \\hostname\share\subpath => \\hostname\share, error is nil
45+
// \\hostname\share => \\hostname\share, error is nil
46+
// \\hostname => '', error is 'remote path (\\hostname) is invalid'
47+
func getRootMappingPath(path string) (string, error) {
48+
items := strings.Split(path, "\\")
49+
parts := []string{}
50+
for _, s := range items {
51+
if len(s) > 0 {
52+
parts = append(parts, s)
53+
if len(parts) == 2 {
54+
break
55+
}
56+
}
57+
}
58+
if len(parts) != 2 {
59+
return "", fmt.Errorf("remote path (%s) is invalid", path)
60+
}
61+
// parts[0] is a smb host name
62+
// parts[1] is a smb share name
63+
return strings.ToLower("\\\\" + parts[0] + "\\" + parts[1]), nil
64+
}
65+
66+
// incementRemotePathReferencesCount - adds new reference between mappingPath and remotePath if it doesn't exist.
67+
// How it works:
68+
// 1. MappingPath contains two components: hostname, sharename
69+
// 2. We create directory in basePath related to each mappingPath. It will be used as container for references.
70+
// Example: c:\\csi\\smbmounts\\hostname\\sharename
71+
// 3. Each reference is a file with name based on MD5 of remotePath. For debug it also will contains remotePath in body of the file.
72+
// So, in incementRemotePathReferencesCount we create the file. In decrementRemotePathReferencesCount we remove the file.
73+
// Example: c:\\csi\\smbmounts\\hostname\\sharename\\092f1413e6c1d03af8b5da6f44619af8
74+
func incementRemotePathReferencesCount(mappingPath, remotePath string) error {
75+
remotePath = strings.TrimSuffix(remotePath, "\\")
76+
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
77+
if err := os.MkdirAll(path, os.ModeDir); err != nil {
78+
return err
79+
}
80+
filePath := filepath.Join(path, getMd5(remotePath))
81+
file, err := os.Create(filePath)
82+
if err != nil {
83+
return err
84+
}
85+
defer func() {
86+
file.Close()
87+
}()
88+
89+
_, err = file.WriteString(remotePath)
90+
return err
91+
}
92+
93+
// decrementRemotePathReferencesCount - removes reference between mappingPath and remotePath.
94+
// See incementRemotePathReferencesCount to understand how references work.
95+
func decrementRemotePathReferencesCount(mappingPath, remotePath string) error {
96+
remotePath = strings.TrimSuffix(remotePath, "\\")
97+
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
98+
if err := os.MkdirAll(path, os.ModeDir); err != nil {
99+
return err
100+
}
101+
filePath := filepath.Join(path, getMd5(remotePath))
102+
return os.Remove(filePath)
103+
}
104+
105+
// getRemotePathReferencesCount - returns count of references between mappingPath and remotePath.
106+
// See incementRemotePathReferencesCount to understand how references work.
107+
func getRemotePathReferencesCount(mappingPath string) int {
108+
path := filepath.Join(basePath, strings.TrimPrefix(mappingPath, "\\\\"))
109+
if os.MkdirAll(path, os.ModeDir) != nil {
110+
return -1
111+
}
112+
files, err := os.ReadDir(path)
113+
if err != nil {
114+
return -1
115+
}
116+
return len(files)
117+
}
118+
119+
func getMd5(path string) string {
120+
data := []byte(strings.ToLower(path))
121+
return fmt.Sprintf("%x", md5.Sum(data))
122+
}
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
/*
2+
Copyright 2020 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package mounter
18+
19+
import (
20+
"os"
21+
"testing"
22+
"time"
23+
24+
"github.com/stretchr/testify/assert"
25+
)
26+
27+
func TestLockUnlock(t *testing.T) {
28+
key := "resource name"
29+
30+
unlock := lock(key)
31+
defer unlock()
32+
33+
_, loaded := mutexes.Load(key)
34+
assert.True(t, loaded)
35+
}
36+
37+
func TestLockLockedResource(t *testing.T) {
38+
locked := true
39+
unlock := lock("a")
40+
go func() {
41+
time.Sleep(500 * time.Microsecond)
42+
locked = false
43+
unlock()
44+
}()
45+
46+
// try to lock already locked resource
47+
unlock2 := lock("a")
48+
defer unlock2()
49+
if locked {
50+
assert.Fail(t, "access to locked resource")
51+
}
52+
}
53+
54+
func TestLockDifferentKeys(t *testing.T) {
55+
unlocka := lock("a")
56+
unlockb := lock("b")
57+
unlocka()
58+
unlockb()
59+
}
60+
61+
func TestGetRootMappingPath(t *testing.T) {
62+
testCases := []struct {
63+
remote string
64+
expectResult string
65+
expectError bool
66+
}{
67+
{
68+
remote: "",
69+
expectResult: "",
70+
expectError: true,
71+
},
72+
{
73+
remote: "hostname",
74+
expectResult: "",
75+
expectError: true,
76+
},
77+
{
78+
remote: "\\\\hostname\\path",
79+
expectResult: "\\\\hostname\\path",
80+
expectError: false,
81+
},
82+
{
83+
remote: "\\\\hostname\\path\\",
84+
expectResult: "\\\\hostname\\path",
85+
expectError: false,
86+
},
87+
{
88+
remote: "\\\\hostname\\path\\subpath",
89+
expectResult: "\\\\hostname\\path",
90+
expectError: false,
91+
},
92+
}
93+
for _, tc := range testCases {
94+
result, err := getRootMappingPath(tc.remote)
95+
if tc.expectError && err == nil {
96+
t.Errorf("Expected error but getRootMappingPath returned a nil error")
97+
}
98+
if !tc.expectError {
99+
if err != nil {
100+
t.Errorf("Expected no errors but getRootMappingPath returned error: %v", err)
101+
}
102+
if tc.expectResult != result {
103+
t.Errorf("Expected (%s) but getRootMappingPath returned (%s)", tc.expectResult, result)
104+
}
105+
}
106+
}
107+
}
108+
109+
func TestRemotePathReferencesCounter(t *testing.T) {
110+
remotePath1 := "\\\\servername\\share\\subpath\\1"
111+
remotePath2 := "\\\\servername\\share\\subpath\\2"
112+
mappingPath, err := getRootMappingPath(remotePath1)
113+
assert.Nil(t, err)
114+
115+
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
116+
os.RemoveAll(basePath)
117+
defer func() {
118+
// cleanup temp folder
119+
os.RemoveAll(basePath)
120+
}()
121+
122+
// by default we have no any files in `mappingPath`. So, `count` should be zero
123+
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
124+
// add reference to `remotePath1`. So, `count` should be equal `1`
125+
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath1))
126+
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
127+
// add reference to `remotePath2`. So, `count` should be equal `2`
128+
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath2))
129+
assert.Equal(t, 2, getRemotePathReferencesCount(mappingPath))
130+
// remove reference to `remotePath1`. So, `count` should be equal `1`
131+
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath1))
132+
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
133+
// remove reference to `remotePath2`. So, `count` should be equal `0`
134+
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath2))
135+
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
136+
}
137+
138+
func TestIncementRemotePathReferencesCount(t *testing.T) {
139+
remotePath := "\\\\servername\\share\\subpath"
140+
mappingPath, err := getRootMappingPath(remotePath)
141+
assert.Nil(t, err)
142+
143+
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
144+
os.RemoveAll(basePath)
145+
defer func() {
146+
// cleanup temp folder
147+
os.RemoveAll(basePath)
148+
}()
149+
150+
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
151+
152+
mappingPathContainer := basePath + "\\servername\\share"
153+
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
154+
t.Error("mapping file container does not exist")
155+
}
156+
157+
reference := mappingPathContainer + "\\" + getMd5(remotePath)
158+
if file, err := os.Stat(reference); os.IsNotExist(err) || file.IsDir() {
159+
t.Error("reference file does not exist")
160+
}
161+
}
162+
163+
func TestDecrementRemotePathReferencesCount(t *testing.T) {
164+
remotePath := "\\\\servername\\share\\subpath"
165+
mappingPath, err := getRootMappingPath(remotePath)
166+
assert.Nil(t, err)
167+
168+
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
169+
os.RemoveAll(basePath)
170+
defer func() {
171+
// cleanup temp folder
172+
os.RemoveAll(basePath)
173+
}()
174+
175+
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
176+
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
177+
178+
mappingPathContainer := basePath + "\\servername\\share"
179+
if dir, err := os.Stat(mappingPathContainer); os.IsNotExist(err) || !dir.IsDir() {
180+
t.Error("mapping file container does not exist")
181+
}
182+
183+
reference := mappingPathContainer + "\\" + getMd5(remotePath)
184+
if _, err := os.Stat(reference); os.IsExist(err) {
185+
t.Error("reference file exists")
186+
}
187+
}
188+
189+
func TestMultiplyCallsOfIncementRemotePathReferencesCount(t *testing.T) {
190+
remotePath := "\\\\servername\\share\\subpath"
191+
mappingPath, err := getRootMappingPath(remotePath)
192+
assert.Nil(t, err)
193+
194+
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
195+
os.RemoveAll(basePath)
196+
defer func() {
197+
// cleanup temp folder
198+
os.RemoveAll(basePath)
199+
}()
200+
201+
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
202+
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
203+
// next calls of `incementMappingPathCount` with the same arguments should be ignored
204+
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
205+
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
206+
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
207+
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
208+
assert.Equal(t, 1, getRemotePathReferencesCount(mappingPath))
209+
}
210+
211+
func TestMultiplyCallsOfDecrementRemotePathReferencesCount(t *testing.T) {
212+
remotePath := "\\\\servername\\share\\subpath"
213+
mappingPath, err := getRootMappingPath(remotePath)
214+
assert.Nil(t, err)
215+
216+
basePath = os.Getenv("TEMP") + "\\TestMappingPathCounter"
217+
os.RemoveAll(basePath)
218+
defer func() {
219+
// cleanup temp folder
220+
os.RemoveAll(basePath)
221+
}()
222+
223+
assert.Zero(t, getRemotePathReferencesCount(mappingPath))
224+
assert.Nil(t, incementRemotePathReferencesCount(mappingPath, remotePath))
225+
assert.Nil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
226+
assert.NotNil(t, decrementRemotePathReferencesCount(mappingPath, remotePath))
227+
}

0 commit comments

Comments
 (0)