diff --git a/README.md b/README.md index c76b43d..7fea1fd 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ This library aims to require as little configuration as possible, favouring over | Version | 12.1.0 | | RuntimePath | $USER_HOME/.embedded-postgres-go/extracted | | DataPath | $USER_HOME/.embedded-postgres-go/extracted/data | +| BinariesPath | $USER_HOME/.embedded-postgres-go/extracted | | Port | 5432 | | StartTimeout | 15 Seconds | @@ -54,6 +55,12 @@ If a persistent data location is required, set *DataPath* to a directory outside If the *RuntimePath* directory is empty or already initialized but with an incompatible postgres version, it will be removed and Postgres reinitialized. +Postgres binaries will be downloaded and placed in *BinaryPath* if `BinaryPath/bin` doesn't exist. +If the directory does exist, whatever binary version is placed there will be used (no version check +is done). +If your test need to run multiple different versions of Postgres for different tests, make sure +*BinaryPath* is a subdirectory of *RuntimePath*. + A single Postgres instance can be created, started and stopped as follows ```go diff --git a/config.go b/config.go index 54bdb67..a2b2542 100644 --- a/config.go +++ b/config.go @@ -15,6 +15,7 @@ type Config struct { password string runtimePath string dataPath string + binariesPath string locale string startTimeout time.Duration logger io.Writer @@ -84,6 +85,13 @@ func (c Config) DataPath(path string) Config { return c } +// BinariesPath sets the path of the pre-downloaded postgres binaries. +// If this option is left unset, the binaries will be downloaded. +func (c Config) BinariesPath(path string) Config { + c.binariesPath = path + return c +} + // Locale sets the default locale for initdb func (c Config) Locale(locale string) Config { c.locale = locale diff --git a/embedded_postgres.go b/embedded_postgres.go index 5b4e142..f5a33d3 100644 --- a/embedded_postgres.go +++ b/embedded_postgres.go @@ -68,37 +68,54 @@ func (ep *EmbeddedPostgres) Start() error { return err } - cacheLocation, exists := ep.cacheLocator() - if !exists { - if err := ep.remoteFetchStrategy(); err != nil { - return err - } + cacheLocation, cacheExists := ep.cacheLocator() + + if ep.config.runtimePath == "" { + ep.config.runtimePath = filepath.Join(filepath.Dir(cacheLocation), "extracted") + } + + if ep.config.dataPath == "" { + ep.config.dataPath = filepath.Join(ep.config.runtimePath, "data") } - binaryExtractLocation := userRuntimePathOrDefault(ep.config.runtimePath, cacheLocation) - if err := os.RemoveAll(binaryExtractLocation); err != nil { - return fmt.Errorf("unable to clean up runtime directory %s with error: %s", binaryExtractLocation, err) + if err := os.RemoveAll(ep.config.runtimePath); err != nil { + return fmt.Errorf("unable to clean up runtime directory %s with error: %s", ep.config.runtimePath, err) } - if err := archiver.NewTarXz().Unarchive(cacheLocation, binaryExtractLocation); err != nil { - return fmt.Errorf("unable to extract postgres archive %s to %s", cacheLocation, binaryExtractLocation) + if ep.config.binariesPath == "" { + ep.config.binariesPath = ep.config.runtimePath + } + + _, binDirErr := os.Stat(filepath.Join(ep.config.binariesPath, "bin")) + if os.IsNotExist(binDirErr) { + if !cacheExists { + if err := ep.remoteFetchStrategy(); err != nil { + return err + } + } + + if err := archiver.NewTarXz().Unarchive(cacheLocation, ep.config.binariesPath); err != nil { + return fmt.Errorf("unable to extract postgres archive %s to %s", cacheLocation, ep.config.binariesPath) + } } - dataLocation := userDataPathOrDefault(ep.config.dataPath, binaryExtractLocation) + if err := os.MkdirAll(ep.config.runtimePath, 0755); err != nil { + return fmt.Errorf("unable to create runtime directory %s with error: %s", ep.config.runtimePath, err) + } - reuseData := ep.config.dataPath != "" && dataDirIsValid(dataLocation, ep.config.version) + reuseData := dataDirIsValid(ep.config.dataPath, ep.config.version) if !reuseData { - if err := os.RemoveAll(dataLocation); err != nil { - return fmt.Errorf("unable to clean up data directory %s with error: %s", dataLocation, err) + if err := os.RemoveAll(ep.config.dataPath); err != nil { + return fmt.Errorf("unable to clean up data directory %s with error: %s", ep.config.dataPath, err) } - if err := ep.initDatabase(binaryExtractLocation, dataLocation, ep.config.username, ep.config.password, ep.config.locale, ep.config.logger); err != nil { + if err := ep.initDatabase(ep.config.binariesPath, ep.config.runtimePath, ep.config.dataPath, ep.config.username, ep.config.password, ep.config.locale, ep.config.logger); err != nil { return err } } - if err := startPostgres(binaryExtractLocation, ep.config); err != nil { + if err := startPostgres(ep.config); err != nil { return err } @@ -106,7 +123,7 @@ func (ep *EmbeddedPostgres) Start() error { if !reuseData { if err := ep.createDatabase(ep.config.port, ep.config.username, ep.config.password, ep.config.database); err != nil { - if stopErr := stopPostgres(binaryExtractLocation, ep.config); stopErr != nil { + if stopErr := stopPostgres(ep.config); stopErr != nil { return fmt.Errorf("unable to stop database casused by error %s", err) } @@ -115,7 +132,7 @@ func (ep *EmbeddedPostgres) Start() error { } if err := healthCheckDatabaseOrTimeout(ep.config); err != nil { - if stopErr := stopPostgres(binaryExtractLocation, ep.config); stopErr != nil { + if stopErr := stopPostgres(ep.config); stopErr != nil { return fmt.Errorf("unable to stop database casused by error %s", err) } @@ -127,13 +144,11 @@ func (ep *EmbeddedPostgres) Start() error { // Stop will try to stop the Postgres process gracefully returning an error when there were any problems. func (ep *EmbeddedPostgres) Stop() error { - cacheLocation, exists := ep.cacheLocator() - if !exists || !ep.started { + if !ep.started { return errors.New("server has not been started") } - binaryExtractLocation := userRuntimePathOrDefault(ep.config.runtimePath, cacheLocation) - if err := stopPostgres(binaryExtractLocation, ep.config); err != nil { + if err := stopPostgres(ep.config); err != nil { return err } @@ -142,10 +157,10 @@ func (ep *EmbeddedPostgres) Stop() error { return nil } -func startPostgres(binaryExtractLocation string, config Config) error { - postgresBinary := filepath.Join(binaryExtractLocation, "bin/pg_ctl") +func startPostgres(config Config) error { + postgresBinary := filepath.Join(config.binariesPath, "bin/pg_ctl") postgresProcess := exec.Command(postgresBinary, "start", "-w", - "-D", userDataPathOrDefault(config.dataPath, binaryExtractLocation), + "-D", config.dataPath, "-o", fmt.Sprintf(`"-p %d"`, config.port)) postgresProcess.Stderr = config.logger postgresProcess.Stdout = config.logger @@ -157,10 +172,10 @@ func startPostgres(binaryExtractLocation string, config Config) error { return nil } -func stopPostgres(binaryExtractLocation string, config Config) error { - postgresBinary := filepath.Join(binaryExtractLocation, "bin/pg_ctl") +func stopPostgres(config Config) error { + postgresBinary := filepath.Join(config.binariesPath, "bin/pg_ctl") postgresProcess := exec.Command(postgresBinary, "stop", "-w", - "-D", userDataPathOrDefault(config.dataPath, binaryExtractLocation)) + "-D", config.dataPath) postgresProcess.Stderr = config.logger postgresProcess.Stdout = config.logger @@ -180,22 +195,6 @@ func ensurePortAvailable(port uint32) error { return nil } -func userRuntimePathOrDefault(userLocation, cacheLocation string) string { - if userLocation != "" { - return userLocation - } - - return filepath.Join(filepath.Dir(cacheLocation), "extracted") -} - -func userDataPathOrDefault(userLocation, runtimeLocation string) string { - if userLocation != "" { - return userLocation - } - - return filepath.Join(runtimeLocation, "data") -} - func dataDirIsValid(dataDir string, version PostgresVersion) bool { pgVersion := filepath.Join(dataDir, "PG_VERSION") diff --git a/embedded_postgres_test.go b/embedded_postgres_test.go index 528b08b..13b0cb1 100644 --- a/embedded_postgres_test.go +++ b/embedded_postgres_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/mholt/archiver/v3" "github.com/stretchr/testify/assert" ) @@ -118,7 +119,7 @@ func Test_ErrorWhenUnableToInitDatabase(t *testing.T) { return jarFile, true } - database.initDatabase = func(binaryExtractLocation, dataLocation, username, password, locale string, logger io.Writer) error { + database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, logger io.Writer) error { return errors.New("ah it did not work") } @@ -221,7 +222,7 @@ func Test_ErrorWhenCannotStartPostgresProcess(t *testing.T) { return jarFile, true } - database.initDatabase = func(binaryExtractLocation, dataLocation, username, password, locale string, logger io.Writer) error { + database.initDatabase = func(binaryExtractLocation, runtimePath, dataLocation, username, password, locale string, logger io.Writer) error { return nil } @@ -424,3 +425,93 @@ func Test_ReuseData(t *testing.T) { shutdownDBAndFail(t, err, database) } } + +func Test_CustomBinariesLocation(t *testing.T) { + tempDir, err := ioutil.TempDir("", "prepare_database_test") + if err != nil { + panic(err) + } + + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + panic(err) + } + }() + + database := NewDatabase(DefaultConfig(). + BinariesPath(tempDir)) + + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + if err := database.Stop(); err != nil { + shutdownDBAndFail(t, err, database) + } + + // Delete cache to make sure unarchive doesn't happen again. + cacheLocation, _ := database.cacheLocator() + if err := os.RemoveAll(cacheLocation); err != nil { + panic(err) + } + + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + if err := database.Stop(); err != nil { + shutdownDBAndFail(t, err, database) + } +} + +func Test_PrefetchedBinaries(t *testing.T) { + binTempDir, err := ioutil.TempDir("", "prepare_database_test_bin") + if err != nil { + panic(err) + } + + runtimeTempDir, err := ioutil.TempDir("", "prepare_database_test_runtime") + if err != nil { + panic(err) + } + + defer func() { + if err := os.RemoveAll(binTempDir); err != nil { + panic(err) + } + + if err := os.RemoveAll(runtimeTempDir); err != nil { + panic(err) + } + }() + + database := NewDatabase(DefaultConfig(). + BinariesPath(binTempDir). + RuntimePath(runtimeTempDir)) + + // Download and unarchive postgres into the bindir. + if err := database.remoteFetchStrategy(); err != nil { + panic(err) + } + + cacheLocation, _ := database.cacheLocator() + if err := archiver.NewTarXz().Unarchive(cacheLocation, binTempDir); err != nil { + panic(err) + } + + // Expect everything to work without cacheLocator and/or remoteFetch abilities. + database.cacheLocator = func() (string, bool) { + return "", false + } + database.remoteFetchStrategy = func() error { + return errors.New("did not work") + } + + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + if err := database.Stop(); err != nil { + shutdownDBAndFail(t, err, database) + } +} diff --git a/prepare_database.go b/prepare_database.go index 76a5051..f644337 100644 --- a/prepare_database.go +++ b/prepare_database.go @@ -14,11 +14,11 @@ import ( "github.com/lib/pq" ) -type initDatabase func(binaryExtractLocation, pgDataDir, username, password, locale string, logger io.Writer) error +type initDatabase func(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger io.Writer) error type createDatabase func(port uint32, username, password, database string) error -func defaultInitDatabase(binaryExtractLocation, pgDataDir, username, password, locale string, logger io.Writer) error { - passwordFile, err := createPasswordFile(binaryExtractLocation, password) +func defaultInitDatabase(binaryExtractLocation, runtimePath, pgDataDir, username, password, locale string, logger io.Writer) error { + passwordFile, err := createPasswordFile(runtimePath, password) if err != nil { return err } @@ -50,8 +50,8 @@ func defaultInitDatabase(binaryExtractLocation, pgDataDir, username, password, l return nil } -func createPasswordFile(binaryExtractLocation, password string) (string, error) { - passwordFileLocation := filepath.Join(binaryExtractLocation, "pwfile") +func createPasswordFile(runtimePath, password string) (string, error) { + passwordFileLocation := filepath.Join(runtimePath, "pwfile") if err := ioutil.WriteFile(passwordFileLocation, []byte(password), 0600); err != nil { return "", fmt.Errorf("unable to write password file to %s", passwordFileLocation) } diff --git a/prepare_database_test.go b/prepare_database_test.go index 170898f..8f02d0a 100644 --- a/prepare_database_test.go +++ b/prepare_database_test.go @@ -11,30 +11,39 @@ import ( ) func Test_defaultInitDatabase_ErrorWhenCannotCreatePasswordFile(t *testing.T) { - err := defaultInitDatabase("path_not_exists", "path_not_exists", "Tom", "Beer", "", os.Stderr) + err := defaultInitDatabase("path_not_exists", "path_not_exists", "path_not_exists", "Tom", "Beer", "", os.Stderr) assert.EqualError(t, err, "unable to write password file to path_not_exists/pwfile") } func Test_defaultInitDatabase_ErrorWhenCannotStartInitDBProcess(t *testing.T) { - tempDir, err := ioutil.TempDir("", "prepare_database_test") + binTempDir, err := ioutil.TempDir("", "prepare_database_test_bin") + if err != nil { + panic(err) + } + + runtimeTempDir, err := ioutil.TempDir("", "prepare_database_test_runtime") if err != nil { panic(err) } defer func() { - if err := os.RemoveAll(tempDir); err != nil { + if err := os.RemoveAll(binTempDir); err != nil { + panic(err) + } + + if err := os.RemoveAll(runtimeTempDir); err != nil { panic(err) } }() - err = defaultInitDatabase(tempDir, filepath.Join(tempDir, "data"), "Tom", "Beer", "", os.Stderr) + err = defaultInitDatabase(binTempDir, runtimeTempDir, filepath.Join(runtimeTempDir, "data"), "Tom", "Beer", "", os.Stderr) assert.EqualError(t, err, fmt.Sprintf("unable to init database using: %s/bin/initdb -A password -U Tom -D %s/data --pwfile=%s/pwfile", - tempDir, - tempDir, - tempDir)) - assert.FileExists(t, filepath.Join(tempDir, "pwfile")) + binTempDir, + runtimeTempDir, + runtimeTempDir)) + assert.FileExists(t, filepath.Join(runtimeTempDir, "pwfile")) } func Test_defaultInitDatabase_ErrorInvalidLocaleSetting(t *testing.T) { @@ -49,7 +58,7 @@ func Test_defaultInitDatabase_ErrorInvalidLocaleSetting(t *testing.T) { } }() - err = defaultInitDatabase(tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "en_XY", os.Stderr) + err = defaultInitDatabase(tempDir, tempDir, filepath.Join(tempDir, "data"), "postgres", "postgres", "en_XY", os.Stderr) assert.EqualError(t, err, fmt.Sprintf("unable to init database using: %s/bin/initdb -A password -U postgres -D %s/data --pwfile=%s/pwfile --locale=en_XY", tempDir,