From 7c147e4e6ef78fd56233ed7f3bec8a615e6aa76a Mon Sep 17 00:00:00 2001 From: Aurumzhoom Lee Date: Sat, 25 May 2024 00:44:45 +0800 Subject: [PATCH] fix: search file by supported extension use manually specific config type first --- file.go | 17 +++++- file_test.go | 142 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 file_test.go diff --git a/file.go b/file.go index a54fe5a7a..ea6762987 100644 --- a/file.go +++ b/file.go @@ -29,11 +29,24 @@ func (v *Viper) searchInPath(in string) (filename string) { for _, ext := range SupportedExts { v.logger.Debug("checking if file exists", "file", filepath.Join(in, v.configName+"."+ext)) if b, _ := exists(v.fs, filepath.Join(in, v.configName+"."+ext)); b { - v.logger.Debug("found file", "file", filepath.Join(in, v.configName+"."+ext)) - return filepath.Join(in, v.configName+"."+ext) + // record the first found + if len(filename) < 1 { + filename = filepath.Join(in, v.configName+"."+ext) + } + // if specific configType are same with the current extension type, then return current + if v.configType == ext { + filename = filepath.Join(in, v.configName+"."+ext) + break + } } } + // return only if file exists + if len(filename) > 0 { + v.logger.Debug("found file", "file", filename) + return filename + } + if v.configType != "" { if b, _ := exists(v.fs, filepath.Join(in, v.configName)); b { return filepath.Join(in, v.configName) diff --git a/file_test.go b/file_test.go new file mode 100644 index 000000000..cd8162006 --- /dev/null +++ b/file_test.go @@ -0,0 +1,142 @@ +package viper + +import ( + "fmt" + "testing" + + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/spf13/viper/internal/testutil" +) + +func Test_searchInPath(t *testing.T) { + t.Run("find file without extension", func(t *testing.T) { + fs := afero.NewMemMapFs() + + err := fs.Mkdir(testutil.AbsFilePath(t, "/etc/viper"), 0o777) + require.NoError(t, err) + + file, err := fs.Create(testutil.AbsFilePath(t, "/etc/viper/config")) + require.NoError(t, err) + + _, err = file.WriteString(`key: value`) + require.NoError(t, err) + + file.Close() + + v := New() + + v.SetFs(fs) + v.AddConfigPath("/etc/viper") + v.SetConfigType("yaml") + + err = v.ReadInConfig() + require.NoError(t, err) + + assert.Equal(t, "value", v.Get("key")) + }) + + t.Run("cannot find file with extension", func(t *testing.T) { + fs := afero.NewMemMapFs() + + err := fs.Mkdir(testutil.AbsFilePath(t, "/etc/viper"), 0o777) + require.NoError(t, err) + + v := New() + + v.SetFs(fs) + v.AddConfigPath("/etc/viper") + v.SetConfigType("yaml") + + err = v.ReadInConfig() + require.EqualError(t, err, ConfigFileNotFoundError{v.configName, fmt.Sprintf("%s", v.configPaths)}.Error()) + }) + + t.Run("find file with supported extension", func(t *testing.T) { + fs := afero.NewMemMapFs() + + err := fs.Mkdir(testutil.AbsFilePath(t, "/etc/viper"), 0o777) + require.NoError(t, err) + + file, err := fs.Create(testutil.AbsFilePath(t, "/etc/viper/config.yaml")) + require.NoError(t, err) + + _, err = file.WriteString(`key: value`) + require.NoError(t, err) + + file.Close() + + v := New() + + v.SetFs(fs) + v.AddConfigPath("/etc/viper") + v.SetConfigType("yaml") + + err = v.ReadInConfig() + require.NoError(t, err) + + assert.Equal(t, "value", v.Get("key")) + }) + + t.Run("find file with specific extension from multiple supported files", func(t *testing.T) { + fs := afero.NewMemMapFs() + + err := fs.Mkdir(testutil.AbsFilePath(t, "/etc/viper"), 0o777) + require.NoError(t, err) + + jsonFile, err := fs.Create(testutil.AbsFilePath(t, "/etc/viper/config.json")) + require.NoError(t, err) + + jsonFile.Close() + + file, err := fs.Create(testutil.AbsFilePath(t, "/etc/viper/config.yaml")) + require.NoError(t, err) + + _, err = file.WriteString(`key: value`) + require.NoError(t, err) + + file.Close() + + v := New() + + v.SetFs(fs) + v.AddConfigPath("/etc/viper") + v.SetConfigType("yaml") + + err = v.ReadInConfig() + require.NoError(t, err) + + assert.Equal(t, "value", v.Get("key")) + }) + + t.Run("find file with unsupported extension", func(t *testing.T) { + fs := afero.NewMemMapFs() + + err := fs.Mkdir(testutil.AbsFilePath(t, "/etc/viper"), 0o777) + require.NoError(t, err) + + jsonFile, err := fs.Create(testutil.AbsFilePath(t, "/etc/viper/config.json")) + require.NoError(t, err) + + jsonFile.Close() + + file, err := fs.Create(testutil.AbsFilePath(t, "/etc/viper/config.yaml")) + require.NoError(t, err) + + _, err = file.WriteString(`key: value`) + require.NoError(t, err) + + file.Close() + + v := New() + + v.SetFs(fs) + v.AddConfigPath("/etc/viper") + v.SetConfigType("xyz") + + err = v.ReadInConfig() + require.EqualError(t, err, UnsupportedConfigError("xyz").Error()) + }) +}