Skip to content

Commit 42193fb

Browse files
committed
Add confugurable cachePath
1 parent 8b137fc commit 42193fb

File tree

9 files changed

+45
-23
lines changed

9 files changed

+45
-23
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ This library aims to require as little configuration as possible, favouring over
4242
| Password | postgres |
4343
| Database | postgres |
4444
| Version | 12.1.0 |
45+
| CachePath | $USER_HOME/.embedded-postgres-go/ |
4546
| RuntimePath | $USER_HOME/.embedded-postgres-go/extracted |
4647
| DataPath | $USER_HOME/.embedded-postgres-go/extracted/data |
4748
| BinariesPath | $USER_HOME/.embedded-postgres-go/extracted |

cache_locator.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@ import (
88

99
// CacheLocator retrieves the location of the Postgres binary cache returning it to location.
1010
// The result of whether this cache is present will be returned to exists.
11-
type CacheLocator func() (location string, exists bool)
11+
type CacheLocator func(cachePath string) (location string, exists bool)
1212

1313
func defaultCacheLocator(versionStrategy VersionStrategy) CacheLocator {
14-
return func() (string, bool) {
15-
cacheDirectory := ".embedded-postgres-go"
16-
if userHome, err := os.UserHomeDir(); err == nil {
17-
cacheDirectory = filepath.Join(userHome, ".embedded-postgres-go")
14+
return func(cacheDirectory string) (string, bool) {
15+
if cacheDirectory == "" {
16+
cacheDirectory = ".embedded-postgres-go"
17+
if userHome, err := os.UserHomeDir(); err == nil {
18+
cacheDirectory = filepath.Join(userHome, ".embedded-postgres-go")
19+
}
1820
}
1921

2022
operatingSystem, architecture, version := versionStrategy()

cache_locator_test.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,19 @@ func Test_defaultCacheLocator_NotExists(t *testing.T) {
1111
return "a", "b", "1.2.3"
1212
})
1313

14-
cacheLocation, exists := locator()
14+
cacheLocation, exists := locator("")
1515

1616
assert.Contains(t, cacheLocation, ".embedded-postgres-go/embedded-postgres-binaries-a-b-1.2.3.txz")
1717
assert.False(t, exists)
1818
}
19+
20+
func Test_defaultCacheLocator_CustomPath(t *testing.T) {
21+
locator := defaultCacheLocator(func() (string, string, PostgresVersion) {
22+
return "a", "b", "1.2.3"
23+
})
24+
25+
cacheLocation, exists := locator("/custom/path")
26+
27+
assert.Equal(t, cacheLocation, "/custom/path/embedded-postgres-binaries-a-b-1.2.3.txz")
28+
assert.False(t, exists)
29+
}

config.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ type Config struct {
1414
database string
1515
username string
1616
password string
17+
cachePath string
1718
runtimePath string
1819
dataPath string
1920
binariesPath string
@@ -82,6 +83,13 @@ func (c Config) RuntimePath(path string) Config {
8283
return c
8384
}
8485

86+
// CachePath sets the path that will be used for storing Postgres binaries archive.
87+
// If this option is not set, ~/.go-embedded-postgres will be used.
88+
func (c Config) CachePath(path string) Config {
89+
c.cachePath = path
90+
return c
91+
}
92+
8593
// DataPath sets the path that will be used for the Postgres data directory.
8694
// If this option is set, a previously initialized data directory will be reused if possible.
8795
func (c Config) DataPath(path string) Config {

embedded_postgres.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func (ep *EmbeddedPostgres) Start() error {
7777

7878
ep.syncedLogger = logger
7979

80-
cacheLocation, cacheExists := ep.cacheLocator()
80+
cacheLocation, cacheExists := ep.cacheLocator(ep.config.cachePath)
8181

8282
if ep.config.runtimePath == "" {
8383
ep.config.runtimePath = filepath.Join(filepath.Dir(cacheLocation), "extracted")

embedded_postgres_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func Test_ErrorWhenPortAlreadyTaken(t *testing.T) {
6666

6767
func Test_ErrorWhenRemoteFetchError(t *testing.T) {
6868
database := NewDatabase()
69-
database.cacheLocator = func() (string, bool) {
69+
database.cacheLocator = func(string) (string, bool) {
7070
return "", false
7171
}
7272
database.remoteFetchStrategy = func() error {
@@ -88,7 +88,7 @@ func Test_ErrorWhenUnableToUnArchiveFile_WrongFormat(t *testing.T) {
8888
Database("beer").
8989
StartTimeout(10 * time.Second))
9090

91-
database.cacheLocator = func() (string, bool) {
91+
database.cacheLocator = func(string) (string, bool) {
9292
return jarFile, true
9393
}
9494

@@ -119,7 +119,7 @@ func Test_ErrorWhenUnableToInitDatabase(t *testing.T) {
119119
RuntimePath(extractPath).
120120
StartTimeout(10 * time.Second))
121121

122-
database.cacheLocator = func() (string, bool) {
122+
database.cacheLocator = func(string) (string, bool) {
123123
return jarFile, true
124124
}
125125

@@ -222,7 +222,7 @@ func Test_ErrorWhenCannotStartPostgresProcess(t *testing.T) {
222222
database := NewDatabase(DefaultConfig().
223223
RuntimePath(extractPath))
224224

225-
database.cacheLocator = func() (string, bool) {
225+
database.cacheLocator = func(string) (string, bool) {
226226
return jarFile, true
227227
}
228228

@@ -360,7 +360,7 @@ func Test_ConcurrentStart(t *testing.T) {
360360
var wg sync.WaitGroup
361361

362362
database := NewDatabase()
363-
cacheLocation, _ := database.cacheLocator()
363+
cacheLocation, _ := database.cacheLocator("")
364364
err := os.RemoveAll(cacheLocation)
365365
require.NoError(t, err)
366366

@@ -644,7 +644,7 @@ func Test_CustomBinariesLocation(t *testing.T) {
644644
}
645645

646646
// Delete cache to make sure unarchive doesn't happen again.
647-
cacheLocation, _ := database.cacheLocator()
647+
cacheLocation, _ := database.cacheLocator("")
648648
if err := os.RemoveAll(cacheLocation); err != nil {
649649
panic(err)
650650
}
@@ -688,13 +688,13 @@ func Test_PrefetchedBinaries(t *testing.T) {
688688
panic(err)
689689
}
690690

691-
cacheLocation, _ := database.cacheLocator()
691+
cacheLocation, _ := database.cacheLocator("")
692692
if err := decompressTarXz(defaultTarReader, cacheLocation, binTempDir); err != nil {
693693
panic(err)
694694
}
695695

696696
// Expect everything to work without cacheLocator and/or remoteFetch abilities.
697-
database.cacheLocator = func() (string, bool) {
697+
database.cacheLocator = func(string) (string, bool) {
698698
return "", false
699699
}
700700
database.remoteFetchStrategy = func() error {

remote_fetch.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator Cach
8080
return errorFetchingPostgres(err)
8181
}
8282

83-
cacheLocation, _ := cacheLocator()
83+
cacheLocation, _ := cacheLocator("")
8484

8585
if err := os.MkdirAll(filepath.Dir(cacheLocation), 0755); err != nil {
8686
return errorExtractingPostgres(err)

remote_fetch_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotExtractSubArchive(t *testing
141141

142142
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
143143
testVersionStrategy(),
144-
func() (s string, b bool) {
144+
func(string) (s string, b bool) {
145145
return filepath.FromSlash("/invalid"), false
146146
})
147147

@@ -181,7 +181,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateCacheDirectory(t *test
181181

182182
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
183183
testVersionStrategy(),
184-
func() (s string, b bool) {
184+
func(string) (s string, b bool) {
185185
return cacheLocation, false
186186
})
187187

@@ -218,7 +218,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test
218218

219219
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
220220
testVersionStrategy(),
221-
func() (s string, b bool) {
221+
func(string) (s string, b bool) {
222222
return "/\\000", false
223223
})
224224

@@ -256,7 +256,7 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenSHA256NotMatch(t *testing.T) {
256256

257257
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
258258
testVersionStrategy(),
259-
func() (s string, b bool) {
259+
func(string) (s string, b bool) {
260260
return cacheLocation, false
261261
})
262262

@@ -295,7 +295,7 @@ func Test_defaultRemoteFetchStrategy(t *testing.T) {
295295

296296
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
297297
testVersionStrategy(),
298-
func() (s string, b bool) {
298+
func(string) (s string, b bool) {
299299
return cacheLocation, false
300300
})
301301

@@ -347,7 +347,7 @@ func Test_defaultRemoteFetchStrategyWithExistingDownload(t *testing.T) {
347347

348348
remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
349349
testVersionStrategy(),
350-
func() (s string, b bool) {
350+
func(string) (s string, b bool) {
351351
return cacheLocation, false
352352
})
353353

test_util_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func testVersionStrategy() VersionStrategy {
5555
}
5656

5757
func testCacheLocator() CacheLocator {
58-
return func() (s string, b bool) {
58+
return func(string) (s string, b bool) {
5959
return "", false
6060
}
6161
}

0 commit comments

Comments
 (0)