diff --git a/internal/api/v1/handlers.go b/internal/api/v1/handlers.go index b2f0402..cc75058 100644 --- a/internal/api/v1/handlers.go +++ b/internal/api/v1/handlers.go @@ -18,6 +18,7 @@ import ( "github.com/wild-cloud/wild-central/daemon/internal/instance" "github.com/wild-cloud/wild-central/daemon/internal/operations" "github.com/wild-cloud/wild-central/daemon/internal/secrets" + "github.com/wild-cloud/wild-central/daemon/internal/storage" "github.com/wild-cloud/wild-central/daemon/internal/tools" ) @@ -302,7 +303,7 @@ func (api *API) GetConfig(w http.ResponseWriter, r *http.Request) { } // updateYAMLFile updates a YAML file with the provided key-value pairs -func (api *API) updateYAMLFile(w http.ResponseWriter, r *http.Request, instanceName, fileType string, updateFunc func(string, string, string) error) { +func (api *API) updateYAMLFile(w http.ResponseWriter, r *http.Request, instanceName, fileType string) { if err := api.instance.ValidateInstance(instanceName); err != nil { respondError(w, http.StatusNotFound, fmt.Sprintf("Instance not found: %v", err)) return @@ -327,13 +328,44 @@ func (api *API) updateYAMLFile(w http.ResponseWriter, r *http.Request, instanceN filePath = api.instance.GetInstanceSecretsPath(instanceName) } - // Update each key-value pair - for key, value := range updates { - valueStr := fmt.Sprintf("%v", value) - if err := updateFunc(filePath, key, valueStr); err != nil { - respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update %s key %s: %v", fileType, key, err)) + // Read existing config/secrets file + existingContent, err := storage.ReadFile(filePath) + if err != nil && !os.IsNotExist(err) { + respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to read existing %s: %v", fileType, err)) + return + } + + // Parse existing content or initialize empty map + var existingConfig map[string]interface{} + if len(existingContent) > 0 { + if err := yaml.Unmarshal(existingContent, &existingConfig); err != nil { + respondError(w, http.StatusBadRequest, fmt.Sprintf("Failed to parse existing %s: %v", fileType, err)) return } + } else { + existingConfig = make(map[string]interface{}) + } + + // Merge updates into existing config (shallow merge for top-level keys) + // This preserves unmodified keys while updating specified ones + for key, value := range updates { + existingConfig[key] = value + } + + // Marshal the merged config back to YAML with proper formatting + yamlContent, err := yaml.Marshal(existingConfig) + if err != nil { + respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to marshal YAML: %v", err)) + return + } + + // Write the complete merged YAML content to the file with proper locking + lockPath := filePath + ".lock" + if err := storage.WithLock(lockPath, func() error { + return storage.WriteFile(filePath, yamlContent, 0644) + }); err != nil { + respondError(w, http.StatusInternalServerError, fmt.Sprintf("Failed to update %s: %v", fileType, err)) + return } // Capitalize first letter of fileType for message @@ -351,7 +383,7 @@ func (api *API) updateYAMLFile(w http.ResponseWriter, r *http.Request, instanceN func (api *API) UpdateConfig(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) name := vars["name"] - api.updateYAMLFile(w, r, name, "config", api.config.SetConfigValue) + api.updateYAMLFile(w, r, name, "config") } // GetSecrets retrieves instance secrets (redacted by default) @@ -399,7 +431,7 @@ func (api *API) GetSecrets(w http.ResponseWriter, r *http.Request) { func (api *API) UpdateSecrets(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) name := vars["name"] - api.updateYAMLFile(w, r, name, "secrets", api.secrets.SetSecret) + api.updateYAMLFile(w, r, name, "secrets") } // GetContext retrieves current context diff --git a/internal/api/v1/handlers_config_test.go b/internal/api/v1/handlers_config_test.go new file mode 100644 index 0000000..547f382 --- /dev/null +++ b/internal/api/v1/handlers_config_test.go @@ -0,0 +1,656 @@ +package v1 + +import ( + "bytes" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/gorilla/mux" + "gopkg.in/yaml.v3" + + "github.com/wild-cloud/wild-central/daemon/internal/storage" +) + +func setupTestAPI(t *testing.T) (*API, string) { + tmpDir := t.TempDir() + appsDir := filepath.Join(tmpDir, "apps") + + api, err := NewAPI(tmpDir, appsDir) + if err != nil { + t.Fatalf("Failed to create test API: %v", err) + } + + return api, tmpDir +} + +func createTestInstance(t *testing.T, api *API, name string) { + if err := api.instance.CreateInstance(name); err != nil { + t.Fatalf("Failed to create test instance: %v", err) + } +} + +func TestUpdateYAMLFile_DeltaUpdate(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + configPath := api.instance.GetInstanceConfigPath(instanceName) + + // Create initial config + initialConfig := map[string]interface{}{ + "domain": "old.com", + "email": "admin@old.com", + "cluster": map[string]interface{}{ + "name": "test-cluster", + }, + } + initialYAML, _ := yaml.Marshal(initialConfig) + if err := storage.WriteFile(configPath, initialYAML, 0644); err != nil { + t.Fatalf("Failed to write initial config: %v", err) + } + + // Update only domain + updateData := map[string]interface{}{ + "domain": "new.com", + } + updateYAML, _ := yaml.Marshal(updateData) + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/config", bytes.NewBuffer(updateYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify merged config + resultData, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read result: %v", err) + } + + var result map[string]interface{} + if err := yaml.Unmarshal(resultData, &result); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + + // Domain should be updated + if result["domain"] != "new.com" { + t.Errorf("Expected domain='new.com', got %v", result["domain"]) + } + + // Email should be preserved + if result["email"] != "admin@old.com" { + t.Errorf("Expected email='admin@old.com', got %v", result["email"]) + } + + // Cluster should be preserved + if cluster, ok := result["cluster"].(map[string]interface{}); !ok { + t.Errorf("Cluster not preserved as map") + } else if cluster["name"] != "test-cluster" { + t.Errorf("Cluster name not preserved") + } +} + +func TestUpdateYAMLFile_FullReplacement(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + configPath := api.instance.GetInstanceConfigPath(instanceName) + + // Create initial config + initialConfig := map[string]interface{}{ + "domain": "old.com", + "email": "admin@old.com", + "oldKey": "oldValue", + } + initialYAML, _ := yaml.Marshal(initialConfig) + if err := storage.WriteFile(configPath, initialYAML, 0644); err != nil { + t.Fatalf("Failed to write initial config: %v", err) + } + + // Full replacement + newConfig := map[string]interface{}{ + "domain": "new.com", + "email": "new@new.com", + "newKey": "newValue", + } + newYAML, _ := yaml.Marshal(newConfig) + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/config", bytes.NewBuffer(newYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify result + resultData, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read result: %v", err) + } + + var result map[string]interface{} + if err := yaml.Unmarshal(resultData, &result); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + + // All new values should be present + if result["domain"] != "new.com" { + t.Errorf("Expected domain='new.com', got %v", result["domain"]) + } + if result["email"] != "new@new.com" { + t.Errorf("Expected email='new@new.com', got %v", result["email"]) + } + if result["newKey"] != "newValue" { + t.Errorf("Expected newKey='newValue', got %v", result["newKey"]) + } + + // Old key should still be present (shallow merge) + if result["oldKey"] != "oldValue" { + t.Errorf("Expected oldKey='oldValue', got %v", result["oldKey"]) + } +} + +func TestUpdateYAMLFile_NestedStructure(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + configPath := api.instance.GetInstanceConfigPath(instanceName) + + // Update with nested structure + updateData := map[string]interface{}{ + "cloud": map[string]interface{}{ + "domain": "test.com", + "dns": map[string]interface{}{ + "ip": "1.2.3.4", + "port": 53, + }, + }, + } + updateYAML, _ := yaml.Marshal(updateData) + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/config", bytes.NewBuffer(updateYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify nested structure preserved + resultData, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read result: %v", err) + } + + var result map[string]interface{} + if err := yaml.Unmarshal(resultData, &result); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + + // Verify nested structure is proper YAML, not Go map notation + resultStr := string(resultData) + if bytes.Contains(resultData, []byte("map[")) { + t.Errorf("Result contains Go map notation: %s", resultStr) + } + + // Verify structure is accessible + cloud, ok := result["cloud"].(map[string]interface{}) + if !ok { + t.Fatalf("cloud is not a map: %T", result["cloud"]) + } + + if cloud["domain"] != "test.com" { + t.Errorf("Expected cloud.domain='test.com', got %v", cloud["domain"]) + } + + dns, ok := cloud["dns"].(map[string]interface{}) + if !ok { + t.Fatalf("cloud.dns is not a map: %T", cloud["dns"]) + } + + if dns["ip"] != "1.2.3.4" { + t.Errorf("Expected dns.ip='1.2.3.4', got %v", dns["ip"]) + } + if dns["port"] != 53 { + t.Errorf("Expected dns.port=53, got %v", dns["port"]) + } +} + +func TestUpdateYAMLFile_EmptyFileCreation(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + configPath := api.instance.GetInstanceConfigPath(instanceName) + + // Truncate the config file to make it empty (but still exists) + if err := storage.WriteFile(configPath, []byte(""), 0644); err != nil { + t.Fatalf("Failed to empty config file: %v", err) + } + + // Update should populate empty file + updateData := map[string]interface{}{ + "domain": "new.com", + "email": "admin@new.com", + } + updateYAML, _ := yaml.Marshal(updateData) + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/config", bytes.NewBuffer(updateYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify content + resultData, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read result: %v", err) + } + + var result map[string]interface{} + if err := yaml.Unmarshal(resultData, &result); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + + if result["domain"] != "new.com" { + t.Errorf("Expected domain='new.com', got %v", result["domain"]) + } + if result["email"] != "admin@new.com" { + t.Errorf("Expected email='admin@new.com', got %v", result["email"]) + } +} + +func TestUpdateYAMLFile_EmptyUpdate(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + configPath := api.instance.GetInstanceConfigPath(instanceName) + + // Create initial config + initialConfig := map[string]interface{}{ + "domain": "test.com", + } + initialYAML, _ := yaml.Marshal(initialConfig) + if err := storage.WriteFile(configPath, initialYAML, 0644); err != nil { + t.Fatalf("Failed to write initial config: %v", err) + } + + // Empty update + updateData := map[string]interface{}{} + updateYAML, _ := yaml.Marshal(updateData) + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/config", bytes.NewBuffer(updateYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify file unchanged + resultData, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read result: %v", err) + } + + var result map[string]interface{} + if err := yaml.Unmarshal(resultData, &result); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + + if result["domain"] != "test.com" { + t.Errorf("Expected domain='test.com', got %v", result["domain"]) + } +} + +func TestUpdateYAMLFile_YAMLFormatting(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + configPath := api.instance.GetInstanceConfigPath(instanceName) + + // Update with complex nested structure + updateData := map[string]interface{}{ + "cloud": map[string]interface{}{ + "domain": "test.com", + "dns": map[string]interface{}{ + "ip": "1.2.3.4", + }, + }, + "cluster": map[string]interface{}{ + "nodes": []interface{}{ + map[string]interface{}{ + "name": "node1", + "ip": "10.0.0.1", + }, + map[string]interface{}{ + "name": "node2", + "ip": "10.0.0.2", + }, + }, + }, + } + updateYAML, _ := yaml.Marshal(updateData) + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/config", bytes.NewBuffer(updateYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify YAML formatting + resultData, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read result: %v", err) + } + + resultStr := string(resultData) + + // Should not contain Go map notation + if bytes.Contains(resultData, []byte("map[")) { + t.Errorf("Result contains Go map notation: %s", resultStr) + } + + // Should be valid YAML + var result map[string]interface{} + if err := yaml.Unmarshal(resultData, &result); err != nil { + t.Fatalf("Result is not valid YAML: %v", err) + } + + // Should have proper indentation (check for nested structure indicators) + if !bytes.Contains(resultData, []byte(" ")) { + t.Error("Result appears to lack proper indentation") + } +} + +func TestUpdateYAMLFile_InvalidYAML(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + // Send invalid YAML + invalidYAML := []byte("invalid: yaml: content: [") + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/config", bytes.NewBuffer(invalidYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Expected status 400, got %d", w.Code) + } +} + +func TestUpdateYAMLFile_InvalidInstance(t *testing.T) { + api, _ := setupTestAPI(t) + + updateData := map[string]interface{}{ + "domain": "test.com", + } + updateYAML, _ := yaml.Marshal(updateData) + + req := httptest.NewRequest("PUT", "/api/v1/instances/nonexistent/config", bytes.NewBuffer(updateYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": "nonexistent"} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("Expected status 404, got %d", w.Code) + } +} + +func TestUpdateYAMLFile_FilePermissions(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + configPath := api.instance.GetInstanceConfigPath(instanceName) + + updateData := map[string]interface{}{ + "domain": "test.com", + } + updateYAML, _ := yaml.Marshal(updateData) + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/config", bytes.NewBuffer(updateYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Check file permissions + info, err := os.Stat(configPath) + if err != nil { + t.Fatalf("Failed to stat config file: %v", err) + } + + expectedPerm := os.FileMode(0644) + if info.Mode().Perm() != expectedPerm { + t.Errorf("Expected permissions %v, got %v", expectedPerm, info.Mode().Perm()) + } +} + +func TestUpdateYAMLFile_UpdateSecrets(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + secretsPath := api.instance.GetInstanceSecretsPath(instanceName) + + // Update secrets + updateData := map[string]interface{}{ + "dbPassword": "secret123", + "apiKey": "key456", + } + updateYAML, _ := yaml.Marshal(updateData) + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/secrets", bytes.NewBuffer(updateYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateSecrets(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify secrets file created and contains data + resultData, err := storage.ReadFile(secretsPath) + if err != nil { + t.Fatalf("Failed to read secrets: %v", err) + } + + var result map[string]interface{} + if err := yaml.Unmarshal(resultData, &result); err != nil { + t.Fatalf("Failed to parse secrets: %v", err) + } + + if result["dbPassword"] != "secret123" { + t.Errorf("Expected dbPassword='secret123', got %v", result["dbPassword"]) + } + if result["apiKey"] != "key456" { + t.Errorf("Expected apiKey='key456', got %v", result["apiKey"]) + } +} + +func TestUpdateYAMLFile_ConcurrentUpdates(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + // This test verifies that file locking prevents race conditions + // We'll simulate concurrent updates and verify data integrity + + numUpdates := 10 + done := make(chan bool, numUpdates) + + for i := 0; i < numUpdates; i++ { + go func(index int) { + updateData := map[string]interface{}{ + "counter": index, + } + updateYAML, _ := yaml.Marshal(updateData) + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/config", bytes.NewBuffer(updateYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + done <- w.Code == http.StatusOK + }(i) + } + + // Wait for all updates to complete + successCount := 0 + for i := 0; i < numUpdates; i++ { + if <-done { + successCount++ + } + } + + if successCount != numUpdates { + t.Errorf("Expected %d successful updates, got %d", numUpdates, successCount) + } + + // Verify file is still valid YAML + configPath := api.instance.GetInstanceConfigPath(instanceName) + resultData, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read final config: %v", err) + } + + var result map[string]interface{} + if err := yaml.Unmarshal(resultData, &result); err != nil { + t.Fatalf("Final config is not valid YAML: %v", err) + } +} + +func TestUpdateYAMLFile_PreservesComplexTypes(t *testing.T) { + api, _ := setupTestAPI(t) + instanceName := "test-instance" + createTestInstance(t, api, instanceName) + + configPath := api.instance.GetInstanceConfigPath(instanceName) + + // Create config with various types + updateData := map[string]interface{}{ + "stringValue": "text", + "intValue": 42, + "floatValue": 3.14, + "boolValue": true, + "arrayValue": []interface{}{"a", "b", "c"}, + "mapValue": map[string]interface{}{ + "nested": "value", + }, + "nullValue": nil, + } + updateYAML, _ := yaml.Marshal(updateData) + + req := httptest.NewRequest("PUT", "/api/v1/instances/"+instanceName+"/config", bytes.NewBuffer(updateYAML)) + w := httptest.NewRecorder() + + vars := map[string]string{"name": instanceName} + req = mux.SetURLVars(req, vars) + + api.UpdateConfig(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + // Verify types preserved + resultData, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("Failed to read result: %v", err) + } + + var result map[string]interface{} + if err := yaml.Unmarshal(resultData, &result); err != nil { + t.Fatalf("Failed to parse result: %v", err) + } + + if result["stringValue"] != "text" { + t.Errorf("String value not preserved: %v", result["stringValue"]) + } + if result["intValue"] != 42 { + t.Errorf("Int value not preserved: %v", result["intValue"]) + } + if result["floatValue"] != 3.14 { + t.Errorf("Float value not preserved: %v", result["floatValue"]) + } + if result["boolValue"] != true { + t.Errorf("Bool value not preserved: %v", result["boolValue"]) + } + + arrayValue, ok := result["arrayValue"].([]interface{}) + if !ok { + t.Errorf("Array not preserved as slice: %T", result["arrayValue"]) + } else if len(arrayValue) != 3 { + t.Errorf("Array length not preserved: %d", len(arrayValue)) + } + + mapValue, ok := result["mapValue"].(map[string]interface{}) + if !ok { + t.Errorf("Map not preserved: %T", result["mapValue"]) + } else if mapValue["nested"] != "value" { + t.Errorf("Nested map value not preserved: %v", mapValue["nested"]) + } + + if result["nullValue"] != nil { + t.Errorf("Null value not preserved: %v", result["nullValue"]) + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..7e23045 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,714 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +// Test: LoadGlobalConfig loads valid configuration +func TestLoadGlobalConfig(t *testing.T) { + tests := []struct { + name string + configYAML string + verify func(t *testing.T, config *GlobalConfig) + wantErr bool + }{ + { + name: "loads complete configuration", + configYAML: `wildcloud: + repository: "https://github.com/example/repo" + currentPhase: "setup" + completedPhases: + - "phase1" + - "phase2" +server: + port: 8080 + host: "localhost" +operator: + email: "admin@example.com" +cloud: + dns: + ip: "192.168.1.1" + externalResolver: "8.8.8.8" + router: + ip: "192.168.1.254" + dynamicDns: "example.dyndns.org" + dnsmasq: + interface: "eth0" +cluster: + endpointIp: "192.168.1.100" + nodes: + talos: + version: "v1.8.0" +`, + verify: func(t *testing.T, config *GlobalConfig) { + if config.Wildcloud.Repository != "https://github.com/example/repo" { + t.Error("repository not loaded correctly") + } + if config.Server.Port != 8080 { + t.Error("port not loaded correctly") + } + if config.Cloud.DNS.IP != "192.168.1.1" { + t.Error("DNS IP not loaded correctly") + } + if config.Cluster.EndpointIP != "192.168.1.100" { + t.Error("endpoint IP not loaded correctly") + } + }, + wantErr: false, + }, + { + name: "applies default values", + configYAML: `cloud: + dns: + ip: "192.168.1.1" +cluster: + nodes: + talos: + version: "v1.8.0" +`, + verify: func(t *testing.T, config *GlobalConfig) { + if config.Server.Port != 5055 { + t.Errorf("default port not applied, got %d, want 5055", config.Server.Port) + } + if config.Server.Host != "0.0.0.0" { + t.Errorf("default host not applied, got %q, want %q", config.Server.Host, "0.0.0.0") + } + }, + wantErr: false, + }, + { + name: "preserves custom port and host", + configYAML: `server: + port: 9000 + host: "127.0.0.1" +cloud: + dns: + ip: "192.168.1.1" +cluster: + nodes: + talos: + version: "v1.8.0" +`, + verify: func(t *testing.T, config *GlobalConfig) { + if config.Server.Port != 9000 { + t.Errorf("custom port not preserved, got %d, want 9000", config.Server.Port) + } + if config.Server.Host != "127.0.0.1" { + t.Errorf("custom host not preserved, got %q, want %q", config.Server.Host, "127.0.0.1") + } + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + if err := os.WriteFile(configPath, []byte(tt.configYAML), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + + config, err := LoadGlobalConfig(configPath) + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if config == nil { + t.Fatal("config is nil") + } + + if tt.verify != nil { + tt.verify(t, config) + } + }) + } +} + +// Test: LoadGlobalConfig error cases +func TestLoadGlobalConfig_Errors(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) string + errContains string + }{ + { + name: "non-existent file", + setupFunc: func(t *testing.T) string { + return filepath.Join(t.TempDir(), "nonexistent.yaml") + }, + errContains: "reading config file", + }, + { + name: "invalid yaml", + setupFunc: func(t *testing.T) string { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + content := `invalid: yaml: [[[` + if err := os.WriteFile(configPath, []byte(content), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + return configPath + }, + errContains: "parsing config file", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configPath := tt.setupFunc(t) + _, err := LoadGlobalConfig(configPath) + + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + }) + } +} + +// Test: SaveGlobalConfig saves configuration correctly +func TestSaveGlobalConfig(t *testing.T) { + tests := []struct { + name string + config *GlobalConfig + verify func(t *testing.T, configPath string) + }{ + { + name: "saves complete configuration", + config: &GlobalConfig{ + Wildcloud: struct { + Repository string `yaml:"repository" json:"repository"` + CurrentPhase string `yaml:"currentPhase" json:"currentPhase"` + CompletedPhases []string `yaml:"completedPhases" json:"completedPhases"` + }{ + Repository: "https://github.com/example/repo", + CurrentPhase: "setup", + CompletedPhases: []string{"phase1", "phase2"}, + }, + Server: struct { + Port int `yaml:"port" json:"port"` + Host string `yaml:"host" json:"host"` + }{ + Port: 8080, + Host: "localhost", + }, + }, + verify: func(t *testing.T, configPath string) { + content, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("failed to read saved config: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "repository") { + t.Error("saved config missing repository field") + } + if !strings.Contains(contentStr, "8080") { + t.Error("saved config missing port value") + } + }, + }, + { + name: "saves empty configuration", + config: &GlobalConfig{}, + verify: func(t *testing.T, configPath string) { + if _, err := os.Stat(configPath); os.IsNotExist(err) { + t.Error("config file not created") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "subdir", "config.yaml") + + err := SaveGlobalConfig(tt.config, configPath) + if err != nil { + t.Errorf("SaveGlobalConfig failed: %v", err) + return + } + + // Verify file exists + if _, err := os.Stat(configPath); err != nil { + t.Errorf("config file not created: %v", err) + return + } + + // Verify file permissions + info, err := os.Stat(configPath) + if err != nil { + t.Fatalf("failed to stat config file: %v", err) + } + if info.Mode().Perm() != 0644 { + t.Errorf("expected permissions 0644, got %v", info.Mode().Perm()) + } + + // Verify content can be loaded back + loadedConfig, err := LoadGlobalConfig(configPath) + if err != nil { + t.Errorf("failed to reload saved config: %v", err) + } else if loadedConfig == nil { + t.Error("loaded config is nil") + } + + if tt.verify != nil { + tt.verify(t, configPath) + } + }) + } +} + +// Test: SaveGlobalConfig creates directory +func TestSaveGlobalConfig_CreatesDirectory(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "nested", "dirs", "config.yaml") + + config := &GlobalConfig{} + err := SaveGlobalConfig(config, configPath) + if err != nil { + t.Fatalf("SaveGlobalConfig failed: %v", err) + } + + // Verify nested directories were created + if _, err := os.Stat(filepath.Dir(configPath)); err != nil { + t.Errorf("directory not created: %v", err) + } + + // Verify file exists + if _, err := os.Stat(configPath); err != nil { + t.Errorf("config file not created: %v", err) + } +} + +// Test: GlobalConfig.IsEmpty checks if config is empty +func TestGlobalConfig_IsEmpty(t *testing.T) { + tests := []struct { + name string + config *GlobalConfig + want bool + }{ + { + name: "nil config is empty", + config: nil, + want: true, + }, + { + name: "default config is empty", + config: &GlobalConfig{}, + want: true, + }, + { + name: "config with only DNS IP is empty", + config: &GlobalConfig{ + Cloud: struct { + DNS struct { + IP string `yaml:"ip" json:"ip"` + ExternalResolver string `yaml:"externalResolver" json:"externalResolver"` + } `yaml:"dns" json:"dns"` + Router struct { + IP string `yaml:"ip" json:"ip"` + DynamicDns string `yaml:"dynamicDns" json:"dynamicDns"` + } `yaml:"router" json:"router"` + Dnsmasq struct { + Interface string `yaml:"interface" json:"interface"` + } `yaml:"dnsmasq" json:"dnsmasq"` + }{ + DNS: struct { + IP string `yaml:"ip" json:"ip"` + ExternalResolver string `yaml:"externalResolver" json:"externalResolver"` + }{ + IP: "192.168.1.1", + }, + }, + }, + want: true, + }, + { + name: "config with only Talos version is empty", + config: &GlobalConfig{ + Cluster: struct { + EndpointIP string `yaml:"endpointIp" json:"endpointIp"` + Nodes struct { + Talos struct { + Version string `yaml:"version" json:"version"` + } `yaml:"talos" json:"talos"` + } `yaml:"nodes" json:"nodes"` + }{ + Nodes: struct { + Talos struct { + Version string `yaml:"version" json:"version"` + } `yaml:"talos" json:"talos"` + }{ + Talos: struct { + Version string `yaml:"version" json:"version"` + }{ + Version: "v1.8.0", + }, + }, + }, + }, + want: true, + }, + { + name: "config with both DNS IP and Talos version is not empty", + config: &GlobalConfig{ + Cloud: struct { + DNS struct { + IP string `yaml:"ip" json:"ip"` + ExternalResolver string `yaml:"externalResolver" json:"externalResolver"` + } `yaml:"dns" json:"dns"` + Router struct { + IP string `yaml:"ip" json:"ip"` + DynamicDns string `yaml:"dynamicDns" json:"dynamicDns"` + } `yaml:"router" json:"router"` + Dnsmasq struct { + Interface string `yaml:"interface" json:"interface"` + } `yaml:"dnsmasq" json:"dnsmasq"` + }{ + DNS: struct { + IP string `yaml:"ip" json:"ip"` + ExternalResolver string `yaml:"externalResolver" json:"externalResolver"` + }{ + IP: "192.168.1.1", + }, + }, + Cluster: struct { + EndpointIP string `yaml:"endpointIp" json:"endpointIp"` + Nodes struct { + Talos struct { + Version string `yaml:"version" json:"version"` + } `yaml:"talos" json:"talos"` + } `yaml:"nodes" json:"nodes"` + }{ + Nodes: struct { + Talos struct { + Version string `yaml:"version" json:"version"` + } `yaml:"talos" json:"talos"` + }{ + Talos: struct { + Version string `yaml:"version" json:"version"` + }{ + Version: "v1.8.0", + }, + }, + }, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.config.IsEmpty() + if got != tt.want { + t.Errorf("IsEmpty() = %v, want %v", got, tt.want) + } + }) + } +} + +// Test: LoadCloudConfig loads instance configuration +func TestLoadCloudConfig(t *testing.T) { + tests := []struct { + name string + configYAML string + verify func(t *testing.T, config *InstanceConfig) + wantErr bool + }{ + { + name: "loads complete instance configuration", + configYAML: `cloud: + router: + ip: "192.168.1.254" + dns: + ip: "192.168.1.1" + externalResolver: "8.8.8.8" + dhcpRange: "192.168.1.100,192.168.1.200" + baseDomain: "example.com" + domain: "home" + internalDomain: "internal.example.com" +cluster: + name: "my-cluster" + loadBalancerIp: "192.168.1.10" + nodes: + talos: + version: "v1.8.0" + activeNodes: + - node1: + role: "control" + interface: "eth0" + disk: "/dev/sda" +`, + verify: func(t *testing.T, config *InstanceConfig) { + if config.Cloud.BaseDomain != "example.com" { + t.Error("base domain not loaded correctly") + } + if config.Cluster.Name != "my-cluster" { + t.Error("cluster name not loaded correctly") + } + if config.Cluster.Nodes.Talos.Version != "v1.8.0" { + t.Error("talos version not loaded correctly") + } + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + if err := os.WriteFile(configPath, []byte(tt.configYAML), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + + config, err := LoadCloudConfig(configPath) + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if config == nil { + t.Fatal("config is nil") + } + + if tt.verify != nil { + tt.verify(t, config) + } + }) + } +} + +// Test: LoadCloudConfig error cases +func TestLoadCloudConfig_Errors(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) string + errContains string + }{ + { + name: "non-existent file", + setupFunc: func(t *testing.T) string { + return filepath.Join(t.TempDir(), "nonexistent.yaml") + }, + errContains: "reading config file", + }, + { + name: "invalid yaml", + setupFunc: func(t *testing.T) string { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + content := `invalid: yaml: [[[` + if err := os.WriteFile(configPath, []byte(content), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + return configPath + }, + errContains: "parsing config file", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configPath := tt.setupFunc(t) + _, err := LoadCloudConfig(configPath) + + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + }) + } +} + +// Test: SaveCloudConfig saves instance configuration +func TestSaveCloudConfig(t *testing.T) { + tests := []struct { + name string + config *InstanceConfig + verify func(t *testing.T, configPath string) + }{ + { + name: "saves instance configuration", + config: &InstanceConfig{ + Cloud: struct { + Router struct { + IP string `yaml:"ip" json:"ip"` + } `yaml:"router" json:"router"` + DNS struct { + IP string `yaml:"ip" json:"ip"` + ExternalResolver string `yaml:"externalResolver" json:"externalResolver"` + } `yaml:"dns" json:"dns"` + DHCPRange string `yaml:"dhcpRange" json:"dhcpRange"` + Dnsmasq struct { + Interface string `yaml:"interface" json:"interface"` + } `yaml:"dnsmasq" json:"dnsmasq"` + BaseDomain string `yaml:"baseDomain" json:"baseDomain"` + Domain string `yaml:"domain" json:"domain"` + InternalDomain string `yaml:"internalDomain" json:"internalDomain"` + NFS struct { + MediaPath string `yaml:"mediaPath" json:"mediaPath"` + Host string `yaml:"host" json:"host"` + StorageCapacity string `yaml:"storageCapacity" json:"storageCapacity"` + } `yaml:"nfs" json:"nfs"` + DockerRegistryHost string `yaml:"dockerRegistryHost" json:"dockerRegistryHost"` + Backup struct { + Root string `yaml:"root" json:"root"` + } `yaml:"backup" json:"backup"` + }{ + BaseDomain: "example.com", + Domain: "home", + }, + }, + verify: func(t *testing.T, configPath string) { + content, err := os.ReadFile(configPath) + if err != nil { + t.Fatalf("failed to read saved config: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "example.com") { + t.Error("saved config missing base domain") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "subdir", "config.yaml") + + err := SaveCloudConfig(tt.config, configPath) + if err != nil { + t.Errorf("SaveCloudConfig failed: %v", err) + return + } + + // Verify file exists + if _, err := os.Stat(configPath); err != nil { + t.Errorf("config file not created: %v", err) + return + } + + // Verify content can be loaded back + loadedConfig, err := LoadCloudConfig(configPath) + if err != nil { + t.Errorf("failed to reload saved config: %v", err) + } else if loadedConfig == nil { + t.Error("loaded config is nil") + } + + if tt.verify != nil { + tt.verify(t, configPath) + } + }) + } +} + +// Test: Round-trip save and load preserves data +func TestGlobalConfig_RoundTrip(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + // Create config with all fields + original := &GlobalConfig{ + Wildcloud: struct { + Repository string `yaml:"repository" json:"repository"` + CurrentPhase string `yaml:"currentPhase" json:"currentPhase"` + CompletedPhases []string `yaml:"completedPhases" json:"completedPhases"` + }{ + Repository: "https://github.com/example/repo", + CurrentPhase: "setup", + CompletedPhases: []string{"phase1", "phase2"}, + }, + Server: struct { + Port int `yaml:"port" json:"port"` + Host string `yaml:"host" json:"host"` + }{ + Port: 8080, + Host: "localhost", + }, + Operator: struct { + Email string `yaml:"email" json:"email"` + }{ + Email: "admin@example.com", + }, + } + + // Save config + if err := SaveGlobalConfig(original, configPath); err != nil { + t.Fatalf("SaveGlobalConfig failed: %v", err) + } + + // Load config + loaded, err := LoadGlobalConfig(configPath) + if err != nil { + t.Fatalf("LoadGlobalConfig failed: %v", err) + } + + // Verify all fields match + if loaded.Wildcloud.Repository != original.Wildcloud.Repository { + t.Errorf("repository mismatch: got %q, want %q", loaded.Wildcloud.Repository, original.Wildcloud.Repository) + } + if loaded.Server.Port != original.Server.Port { + t.Errorf("port mismatch: got %d, want %d", loaded.Server.Port, original.Server.Port) + } + if loaded.Operator.Email != original.Operator.Email { + t.Errorf("email mismatch: got %q, want %q", loaded.Operator.Email, original.Operator.Email) + } +} + +// Test: Round-trip save and load for instance config +func TestInstanceConfig_RoundTrip(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + // Create instance config + original := &InstanceConfig{} + original.Cloud.BaseDomain = "example.com" + original.Cloud.Domain = "home" + original.Cluster.Name = "my-cluster" + + // Save config + if err := SaveCloudConfig(original, configPath); err != nil { + t.Fatalf("SaveCloudConfig failed: %v", err) + } + + // Load config + loaded, err := LoadCloudConfig(configPath) + if err != nil { + t.Fatalf("LoadCloudConfig failed: %v", err) + } + + // Verify fields match + if loaded.Cloud.BaseDomain != original.Cloud.BaseDomain { + t.Errorf("base domain mismatch: got %q, want %q", loaded.Cloud.BaseDomain, original.Cloud.BaseDomain) + } + if loaded.Cluster.Name != original.Cluster.Name { + t.Errorf("cluster name mismatch: got %q, want %q", loaded.Cluster.Name, original.Cluster.Name) + } +} diff --git a/internal/config/manager_test.go b/internal/config/manager_test.go new file mode 100644 index 0000000..b898a1e --- /dev/null +++ b/internal/config/manager_test.go @@ -0,0 +1,905 @@ +package config + +import ( + "os" + "path/filepath" + "strings" + "sync" + "testing" + + "github.com/wild-cloud/wild-central/daemon/internal/storage" +) + +// Test: NewManager creates manager successfully +func TestNewManager(t *testing.T) { + m := NewManager() + if m == nil { + t.Fatal("NewManager returned nil") + } + if m.yq == nil { + t.Error("Manager.yq is nil") + } +} + +// Test: EnsureInstanceConfig creates config file with proper structure +func TestEnsureInstanceConfig(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T, instancePath string) + wantErr bool + errContains string + }{ + { + name: "creates config when not exists", + setupFunc: nil, + wantErr: false, + }, + { + name: "returns nil when config exists", + setupFunc: func(t *testing.T, instancePath string) { + configPath := filepath.Join(instancePath, "config.yaml") + content := `baseDomain: "test.local" +domain: "test" +internalDomain: "internal.test" +dhcpRange: "" +backup: + root: "" +nfs: + host: "" + mediaPath: "" +cluster: + name: "" + loadBalancerIp: "" + ipAddressPool: "" + hostnamePrefix: "" + certManager: + cloudflare: + domain: "" + zoneID: "" + externalDns: + ownerId: "" + nodes: + talos: + version: "" + schematicId: "" + control: + vip: "" + activeNodes: [] +` + if err := storage.WriteFile(configPath, []byte(content), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + }, + wantErr: false, + }, + { + name: "returns error when config is invalid yaml", + setupFunc: func(t *testing.T, instancePath string) { + configPath := filepath.Join(instancePath, "config.yaml") + content := `invalid: yaml: content: [[[` + if err := storage.WriteFile(configPath, []byte(content), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + }, + wantErr: true, + errContains: "invalid config file", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + instancePath := t.TempDir() + m := NewManager() + + if tt.setupFunc != nil { + tt.setupFunc(t, instancePath) + } + + err := m.EnsureInstanceConfig(instancePath) + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify config file exists + configPath := filepath.Join(instancePath, "config.yaml") + if !storage.FileExists(configPath) { + t.Error("config file not created") + } + + // Verify config is valid YAML + if err := m.ValidateConfig(configPath); err != nil { + t.Errorf("config validation failed: %v", err) + } + + // Verify config has expected structure + content, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("failed to read config: %v", err) + } + contentStr := string(content) + requiredFields := []string{"baseDomain:", "domain:", "cluster:", "backup:", "nfs:"} + for _, field := range requiredFields { + if !strings.Contains(contentStr, field) { + t.Errorf("config missing required field: %s", field) + } + } + }) + } +} + +// Test: GetConfigValue retrieves values correctly +func TestGetConfigValue(t *testing.T) { + tests := []struct { + name string + configYAML string + key string + want string + wantErr bool + errContains string + }{ + { + name: "get simple string value", + configYAML: `baseDomain: "example.com" +domain: "test" +`, + key: "baseDomain", + want: "example.com", + wantErr: false, + }, + { + name: "get nested value with dot notation", + configYAML: `cluster: + name: "my-cluster" + nodes: + talos: + version: "v1.8.0" +`, + key: "cluster.nodes.talos.version", + want: "v1.8.0", + wantErr: false, + }, + { + name: "get empty string value", + configYAML: `baseDomain: "" +`, + key: "baseDomain", + want: "", + wantErr: false, + }, + { + name: "get non-existent key returns null", + configYAML: `baseDomain: "example.com" +`, + key: "nonexistent", + want: "null", + wantErr: false, + }, + { + name: "get from array", + configYAML: `cluster: + nodes: + activeNodes: + - "node1" + - "node2" +`, + key: "cluster.nodes.activeNodes.[0]", + want: "node1", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + if err := storage.WriteFile(configPath, []byte(tt.configYAML), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + + m := NewManager() + got, err := m.GetConfigValue(configPath, tt.key) + + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +// Test: GetConfigValue error cases +func TestGetConfigValue_Errors(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) string + key string + errContains string + }{ + { + name: "non-existent file", + setupFunc: func(t *testing.T) string { + return filepath.Join(t.TempDir(), "nonexistent.yaml") + }, + key: "baseDomain", + errContains: "config file not found", + }, + { + name: "malformed yaml", + setupFunc: func(t *testing.T) string { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + content := `invalid: yaml: [[[` + if err := storage.WriteFile(configPath, []byte(content), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + return configPath + }, + key: "baseDomain", + errContains: "getting config value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configPath := tt.setupFunc(t) + m := NewManager() + + _, err := m.GetConfigValue(configPath, tt.key) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + }) + } +} + +// Test: SetConfigValue sets values correctly +func TestSetConfigValue(t *testing.T) { + tests := []struct { + name string + initialYAML string + key string + value string + verifyFunc func(t *testing.T, configPath string) + }{ + { + name: "set simple value", + initialYAML: `baseDomain: "" +domain: "" +`, + key: "baseDomain", + value: "example.com", + verifyFunc: func(t *testing.T, configPath string) { + m := NewManager() + got, err := m.GetConfigValue(configPath, "baseDomain") + if err != nil { + t.Fatalf("verify failed: %v", err) + } + if got != "example.com" { + t.Errorf("got %q, want %q", got, "example.com") + } + }, + }, + { + name: "set nested value", + initialYAML: `cluster: + name: "" + nodes: + talos: + version: "" +`, + key: "cluster.nodes.talos.version", + value: "v1.8.0", + verifyFunc: func(t *testing.T, configPath string) { + m := NewManager() + got, err := m.GetConfigValue(configPath, "cluster.nodes.talos.version") + if err != nil { + t.Fatalf("verify failed: %v", err) + } + if got != "v1.8.0" { + t.Errorf("got %q, want %q", got, "v1.8.0") + } + }, + }, + { + name: "update existing value", + initialYAML: `baseDomain: "old.com" +`, + key: "baseDomain", + value: "new.com", + verifyFunc: func(t *testing.T, configPath string) { + m := NewManager() + got, err := m.GetConfigValue(configPath, "baseDomain") + if err != nil { + t.Fatalf("verify failed: %v", err) + } + if got != "new.com" { + t.Errorf("got %q, want %q", got, "new.com") + } + }, + }, + { + name: "create new nested path", + initialYAML: `cluster: {} +`, + key: "cluster.newField", + value: "newValue", + verifyFunc: func(t *testing.T, configPath string) { + m := NewManager() + got, err := m.GetConfigValue(configPath, "cluster.newField") + if err != nil { + t.Fatalf("verify failed: %v", err) + } + if got != "newValue" { + t.Errorf("got %q, want %q", got, "newValue") + } + }, + }, + { + name: "set value with special characters", + initialYAML: `baseDomain: "" +`, + key: "baseDomain", + value: `special"quotes'and\backslashes`, + verifyFunc: func(t *testing.T, configPath string) { + m := NewManager() + got, err := m.GetConfigValue(configPath, "baseDomain") + if err != nil { + t.Fatalf("verify failed: %v", err) + } + if got != `special"quotes'and\backslashes` { + t.Errorf("got %q, want %q", got, `special"quotes'and\backslashes`) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + if err := storage.WriteFile(configPath, []byte(tt.initialYAML), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + + m := NewManager() + if err := m.SetConfigValue(configPath, tt.key, tt.value); err != nil { + t.Errorf("SetConfigValue failed: %v", err) + return + } + + // Verify the value was set correctly + tt.verifyFunc(t, configPath) + + // Verify config is still valid YAML + if err := m.ValidateConfig(configPath); err != nil { + t.Errorf("config validation failed after set: %v", err) + } + }) + } +} + +// Test: SetConfigValue error cases +func TestSetConfigValue_Errors(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) string + key string + value string + errContains string + }{ + { + name: "non-existent file", + setupFunc: func(t *testing.T) string { + return filepath.Join(t.TempDir(), "nonexistent.yaml") + }, + key: "baseDomain", + value: "example.com", + errContains: "config file not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + configPath := tt.setupFunc(t) + m := NewManager() + + err := m.SetConfigValue(configPath, tt.key, tt.value) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + }) + } +} + +// Test: SetConfigValue with concurrent access +func TestSetConfigValue_ConcurrentAccess(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + initialYAML := `counter: "0" +` + if err := storage.WriteFile(configPath, []byte(initialYAML), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + + m := NewManager() + const numGoroutines = 10 + + var wg sync.WaitGroup + errors := make(chan error, numGoroutines) + + // Launch multiple goroutines trying to write different values + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(val int) { + defer wg.Done() + key := "counter" + value := string(rune('0' + val)) + if err := m.SetConfigValue(configPath, key, value); err != nil { + errors <- err + } + }(i) + } + + wg.Wait() + close(errors) + + // Check if any errors occurred + for err := range errors { + t.Errorf("concurrent write error: %v", err) + } + + // Verify config is still valid after concurrent access + if err := m.ValidateConfig(configPath); err != nil { + t.Errorf("config validation failed after concurrent writes: %v", err) + } + + // Verify we can read the value (should be one of the written values) + value, err := m.GetConfigValue(configPath, "counter") + if err != nil { + t.Errorf("failed to read value after concurrent writes: %v", err) + } + if value == "" || value == "null" { + t.Error("counter value is empty after concurrent writes") + } +} + +// Test: EnsureConfigValue sets value only when not set +func TestEnsureConfigValue(t *testing.T) { + tests := []struct { + name string + initialYAML string + key string + value string + expectSet bool + }{ + { + name: "sets value when empty string", + initialYAML: `baseDomain: "" +`, + key: "baseDomain", + value: "example.com", + expectSet: true, + }, + { + name: "sets value when null", + initialYAML: `baseDomain: null +`, + key: "baseDomain", + value: "example.com", + expectSet: true, + }, + { + name: "does not set value when already set", + initialYAML: `baseDomain: "existing.com" +`, + key: "baseDomain", + value: "new.com", + expectSet: false, + }, + { + name: "sets value when key does not exist", + initialYAML: `domain: "test" +`, + key: "baseDomain", + value: "example.com", + expectSet: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + if err := storage.WriteFile(configPath, []byte(tt.initialYAML), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + + m := NewManager() + + // Get initial value + initialVal, _ := m.GetConfigValue(configPath, tt.key) + + // Call EnsureConfigValue + if err := m.EnsureConfigValue(configPath, tt.key, tt.value); err != nil { + t.Errorf("EnsureConfigValue failed: %v", err) + return + } + + // Get final value + finalVal, err := m.GetConfigValue(configPath, tt.key) + if err != nil { + t.Fatalf("GetConfigValue failed: %v", err) + } + + if tt.expectSet { + if finalVal != tt.value { + t.Errorf("expected value to be set to %q, got %q", tt.value, finalVal) + } + } else { + if finalVal != initialVal { + t.Errorf("expected value to remain %q, got %q", initialVal, finalVal) + } + } + + // Call EnsureConfigValue again - should be idempotent + if err := m.EnsureConfigValue(configPath, tt.key, "different.com"); err != nil { + t.Errorf("second EnsureConfigValue failed: %v", err) + return + } + + // Value should not change on second call + secondVal, err := m.GetConfigValue(configPath, tt.key) + if err != nil { + t.Fatalf("GetConfigValue failed: %v", err) + } + if secondVal != finalVal { + t.Errorf("value changed on second ensure: %q -> %q", finalVal, secondVal) + } + }) + } +} + +// Test: ValidateConfig validates YAML correctly +func TestValidateConfig(t *testing.T) { + tests := []struct { + name string + configYAML string + wantErr bool + errContains string + }{ + { + name: "valid yaml", + configYAML: `baseDomain: "example.com" +domain: "test" +cluster: + name: "my-cluster" +`, + wantErr: false, + }, + { + name: "invalid yaml - bad indentation", + configYAML: `baseDomain: "example.com"\n domain: "test"`, + wantErr: true, + errContains: "yaml validation failed", + }, + { + name: "invalid yaml - unclosed bracket", + configYAML: `cluster: { name: "test"`, + wantErr: true, + errContains: "yaml validation failed", + }, + { + name: "empty file", + configYAML: "", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "config.yaml") + + if err := storage.WriteFile(configPath, []byte(tt.configYAML), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + + m := NewManager() + err := m.ValidateConfig(configPath) + + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +// Test: ValidateConfig error cases +func TestValidateConfig_Errors(t *testing.T) { + t.Run("non-existent file", func(t *testing.T) { + tempDir := t.TempDir() + configPath := filepath.Join(tempDir, "nonexistent.yaml") + + m := NewManager() + err := m.ValidateConfig(configPath) + + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "config file not found") { + t.Errorf("error %q does not contain 'config file not found'", err.Error()) + } + }) +} + +// Test: CopyConfig copies configuration correctly +func TestCopyConfig(t *testing.T) { + tests := []struct { + name string + srcYAML string + setupDst func(t *testing.T, dstPath string) + wantErr bool + errContains string + }{ + { + name: "copies config successfully", + srcYAML: `baseDomain: "example.com" +domain: "test" +cluster: + name: "my-cluster" +`, + setupDst: nil, + wantErr: false, + }, + { + name: "creates destination directory", + srcYAML: `baseDomain: "example.com"`, + setupDst: nil, + wantErr: false, + }, + { + name: "overwrites existing destination", + srcYAML: `baseDomain: "new.com" +`, + setupDst: func(t *testing.T, dstPath string) { + oldContent := `baseDomain: "old.com"` + if err := storage.WriteFile(dstPath, []byte(oldContent), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + srcPath := filepath.Join(tempDir, "source.yaml") + dstPath := filepath.Join(tempDir, "subdir", "dest.yaml") + + // Create source file + if err := storage.WriteFile(srcPath, []byte(tt.srcYAML), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + + // Setup destination if needed + if tt.setupDst != nil { + if err := storage.EnsureDir(filepath.Dir(dstPath), 0755); err != nil { + t.Fatalf("setup failed: %v", err) + } + tt.setupDst(t, dstPath) + } + + m := NewManager() + err := m.CopyConfig(srcPath, dstPath) + + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify destination file exists + if !storage.FileExists(dstPath) { + t.Error("destination file not created") + } + + // Verify content matches source + srcContent, err := storage.ReadFile(srcPath) + if err != nil { + t.Fatalf("failed to read source: %v", err) + } + dstContent, err := storage.ReadFile(dstPath) + if err != nil { + t.Fatalf("failed to read destination: %v", err) + } + + if string(srcContent) != string(dstContent) { + t.Error("destination content does not match source") + } + + // Verify destination is valid YAML + if err := m.ValidateConfig(dstPath); err != nil { + t.Errorf("destination config validation failed: %v", err) + } + }) + } +} + +// Test: CopyConfig error cases +func TestCopyConfig_Errors(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T, tempDir string) (srcPath, dstPath string) + errContains string + }{ + { + name: "source file does not exist", + setupFunc: func(t *testing.T, tempDir string) (string, string) { + return filepath.Join(tempDir, "nonexistent.yaml"), + filepath.Join(tempDir, "dest.yaml") + }, + errContains: "source config file not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + srcPath, dstPath := tt.setupFunc(t, tempDir) + + m := NewManager() + err := m.CopyConfig(srcPath, dstPath) + + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + }) + } +} + +// Test: File permissions are preserved +func TestEnsureInstanceConfig_FilePermissions(t *testing.T) { + tempDir := t.TempDir() + m := NewManager() + + if err := m.EnsureInstanceConfig(tempDir); err != nil { + t.Fatalf("EnsureInstanceConfig failed: %v", err) + } + + configPath := filepath.Join(tempDir, "config.yaml") + info, err := os.Stat(configPath) + if err != nil { + t.Fatalf("failed to stat config file: %v", err) + } + + // Verify file has 0644 permissions + if info.Mode().Perm() != 0644 { + t.Errorf("expected permissions 0644, got %v", info.Mode().Perm()) + } +} + +// Test: Idempotent config creation +func TestEnsureInstanceConfig_Idempotent(t *testing.T) { + tempDir := t.TempDir() + m := NewManager() + + // First call creates config + if err := m.EnsureInstanceConfig(tempDir); err != nil { + t.Fatalf("first EnsureInstanceConfig failed: %v", err) + } + + configPath := filepath.Join(tempDir, "config.yaml") + firstContent, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("failed to read config: %v", err) + } + + // Second call should not modify config + if err := m.EnsureInstanceConfig(tempDir); err != nil { + t.Fatalf("second EnsureInstanceConfig failed: %v", err) + } + + secondContent, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("failed to read config: %v", err) + } + + if string(firstContent) != string(secondContent) { + t.Error("config content changed on second call") + } +} + +// Test: Config structure contains all required fields +func TestEnsureInstanceConfig_RequiredFields(t *testing.T) { + tempDir := t.TempDir() + m := NewManager() + + if err := m.EnsureInstanceConfig(tempDir); err != nil { + t.Fatalf("EnsureInstanceConfig failed: %v", err) + } + + configPath := filepath.Join(tempDir, "config.yaml") + content, err := storage.ReadFile(configPath) + if err != nil { + t.Fatalf("failed to read config: %v", err) + } + + contentStr := string(content) + requiredFields := []string{ + "baseDomain:", + "domain:", + "internalDomain:", + "dhcpRange:", + "backup:", + "nfs:", + "cluster:", + "loadBalancerIp:", + "ipAddressPool:", + "hostnamePrefix:", + "certManager:", + "externalDns:", + "nodes:", + "talos:", + "version:", + "schematicId:", + "control:", + "vip:", + "activeNodes:", + } + + for _, field := range requiredFields { + if !strings.Contains(contentStr, field) { + t.Errorf("config missing required field: %s", field) + } + } +} diff --git a/internal/secrets/secrets_test.go b/internal/secrets/secrets_test.go index 572b75b..4cdf2d1 100644 --- a/internal/secrets/secrets_test.go +++ b/internal/secrets/secrets_test.go @@ -3,119 +3,964 @@ package secrets import ( "os" "path/filepath" + "strings" + "sync" "testing" + + "github.com/wild-cloud/wild-central/daemon/internal/storage" ) +// Test: GenerateSecret generates valid secrets func TestGenerateSecret(t *testing.T) { - // Test various lengths - lengths := []int{32, 64, 128} - for _, length := range lengths { - secret, err := GenerateSecret(length) - if err != nil { - t.Fatalf("GenerateSecret(%d) failed: %v", length, err) - } - - if len(secret) != length { - t.Errorf("Expected length %d, got %d", length, len(secret)) - } - - // Verify only alphanumeric characters - for _, c := range secret { - if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) { - t.Errorf("Non-alphanumeric character found: %c", c) - } - } + tests := []struct { + name string + length int + want int + }{ + { + name: "default length", + length: DefaultSecretLength, + want: DefaultSecretLength, + }, + { + name: "custom length 64", + length: 64, + want: 64, + }, + { + name: "custom length 128", + length: 128, + want: 128, + }, + { + name: "zero length defaults to DefaultSecretLength", + length: 0, + want: DefaultSecretLength, + }, + { + name: "negative length defaults to DefaultSecretLength", + length: -1, + want: DefaultSecretLength, + }, } - // Test that secrets are different (not deterministic) - secret1, _ := GenerateSecret(32) - secret2, _ := GenerateSecret(32) - if secret1 == secret2 { - t.Errorf("Generated secrets should be different") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + secret, err := GenerateSecret(tt.length) + if err != nil { + t.Fatalf("GenerateSecret failed: %v", err) + } + + if len(secret) != tt.want { + t.Errorf("got length %d, want %d", len(secret), tt.want) + } + + // Verify only alphanumeric characters + for _, c := range secret { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) { + t.Errorf("non-alphanumeric character found: %c", c) + } + } + }) } } -func TestManager_EnsureSecretsFile(t *testing.T) { - tmpDir := t.TempDir() +// Test: GenerateSecret produces unique values +func TestGenerateSecret_Uniqueness(t *testing.T) { + const numSecrets = 100 + secrets := make(map[string]bool, numSecrets) + + for i := 0; i < numSecrets; i++ { + secret, err := GenerateSecret(32) + if err != nil { + t.Fatalf("GenerateSecret failed: %v", err) + } + + if secrets[secret] { + t.Errorf("duplicate secret generated: %s", secret) + } + secrets[secret] = true + } + + if len(secrets) != numSecrets { + t.Errorf("expected %d unique secrets, got %d", numSecrets, len(secrets)) + } +} + +// Test: NewManager creates manager successfully +func TestNewManager(t *testing.T) { + m := NewManager() + if m == nil { + t.Fatal("NewManager returned nil") + } + if m.yq == nil { + t.Error("Manager.yq is nil") + } +} + +// Test: EnsureSecretsFile creates secrets file with proper structure and permissions +func TestEnsureSecretsFile(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T, instancePath string) + wantErr bool + errContains string + }{ + { + name: "creates secrets file when not exists", + setupFunc: nil, + wantErr: false, + }, + { + name: "returns nil when secrets file exists", + setupFunc: func(t *testing.T, instancePath string) { + secretsPath := filepath.Join(instancePath, "secrets.yaml") + content := `# Wild Cloud Instance Secrets +cluster: + talosSecrets: "" + kubeconfig: "" +certManager: + cloudflare: + apiToken: "" +` + if err := storage.WriteFile(secretsPath, []byte(content), 0600); err != nil { + t.Fatalf("setup failed: %v", err) + } + }, + wantErr: false, + }, + { + name: "corrects permissions on existing file", + setupFunc: func(t *testing.T, instancePath string) { + secretsPath := filepath.Join(instancePath, "secrets.yaml") + content := `# Wild Cloud Instance Secrets +cluster: + talosSecrets: "existing-secret" + kubeconfig: "" +certManager: + cloudflare: + apiToken: "" +` + // Create with wrong permissions + if err := storage.WriteFile(secretsPath, []byte(content), 0644); err != nil { + t.Fatalf("setup failed: %v", err) + } + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + instancePath := t.TempDir() + m := NewManager() + + if tt.setupFunc != nil { + tt.setupFunc(t, instancePath) + } + + err := m.EnsureSecretsFile(instancePath) + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify secrets file exists + secretsPath := filepath.Join(instancePath, "secrets.yaml") + if !storage.FileExists(secretsPath) { + t.Error("secrets file not created") + } + + // Verify permissions are 0600 (secure) + info, err := os.Stat(secretsPath) + if err != nil { + t.Fatalf("failed to stat secrets file: %v", err) + } + if info.Mode().Perm() != 0600 { + t.Errorf("expected permissions 0600, got %o", info.Mode().Perm()) + } + + // Verify file has expected structure + content, err := storage.ReadFile(secretsPath) + if err != nil { + t.Fatalf("failed to read secrets: %v", err) + } + contentStr := string(content) + requiredFields := []string{"cluster:", "certManager:"} + for _, field := range requiredFields { + if !strings.Contains(contentStr, field) { + t.Errorf("secrets missing required field: %s", field) + } + } + }) + } +} + +// Test: GetSecret retrieves secrets correctly +func TestGetSecret(t *testing.T) { + tests := []struct { + name string + secretsYAML string + key string + want string + wantErr bool + errContains string + }{ + { + name: "get simple string value", + secretsYAML: `cluster: + talosSecrets: "my-secret-value" +`, + key: "cluster.talosSecrets", + want: "my-secret-value", + wantErr: false, + }, + { + name: "get nested value with dot notation", + secretsYAML: `certManager: + cloudflare: + apiToken: "cf-token-12345" +`, + key: "certManager.cloudflare.apiToken", + want: "cf-token-12345", + wantErr: false, + }, + { + name: "get non-existent key returns error", + secretsYAML: `cluster: + talosSecrets: "value" +`, + key: "nonexistent", + wantErr: true, + errContains: "secret not found", + }, + { + name: "get empty string value returns error", + secretsYAML: `cluster: + talosSecrets: "" +`, + key: "cluster.talosSecrets", + wantErr: true, + errContains: "secret not found", + }, + { + name: "get null value returns error", + secretsYAML: `cluster: + talosSecrets: null +`, + key: "cluster.talosSecrets", + wantErr: true, + errContains: "secret not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + secretsPath := filepath.Join(tempDir, "secrets.yaml") + + if err := storage.WriteFile(secretsPath, []byte(tt.secretsYAML), 0600); err != nil { + t.Fatalf("setup failed: %v", err) + } + + m := NewManager() + got, err := m.GetSecret(secretsPath, tt.key) + + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +// Test: GetSecret error cases +func TestGetSecret_Errors(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) string + key string + errContains string + }{ + { + name: "non-existent file", + setupFunc: func(t *testing.T) string { + return filepath.Join(t.TempDir(), "nonexistent.yaml") + }, + key: "cluster.talosSecrets", + errContains: "secrets file not found", + }, + { + name: "malformed yaml", + setupFunc: func(t *testing.T) string { + tempDir := t.TempDir() + secretsPath := filepath.Join(tempDir, "secrets.yaml") + content := `invalid: yaml: [[[` + if err := storage.WriteFile(secretsPath, []byte(content), 0600); err != nil { + t.Fatalf("setup failed: %v", err) + } + return secretsPath + }, + key: "cluster.talosSecrets", + errContains: "getting secret", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + secretsPath := tt.setupFunc(t) + m := NewManager() + + _, err := m.GetSecret(secretsPath, tt.key) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + }) + } +} + +// Test: GetSecret does not leak secrets in error messages +func TestGetSecret_NoSecretLeakage(t *testing.T) { + tempDir := t.TempDir() + secretsPath := filepath.Join(tempDir, "secrets.yaml") + + secretValue := "super-secret-password-12345" + secretsYAML := `cluster: + talosSecrets: "` + secretValue + `" +` + if err := storage.WriteFile(secretsPath, []byte(secretsYAML), 0600); err != nil { + t.Fatalf("setup failed: %v", err) + } + m := NewManager() - instancePath := filepath.Join(tmpDir, "test-cloud") - err := os.MkdirAll(instancePath, 0755) - if err != nil { - t.Fatalf("Failed to create instance dir: %v", err) + // Try to get a non-existent key - error should not contain actual secret values + _, err := m.GetSecret(secretsPath, "nonexistent.key") + if err == nil { + t.Fatal("expected error, got nil") } - // Ensure secrets - err = m.EnsureSecretsFile(instancePath) - if err != nil { - t.Fatalf("EnsureSecretsFile failed: %v", err) + // Error message should not contain the secret value + if strings.Contains(err.Error(), secretValue) { + t.Errorf("error message leaked secret value: %v", err) + } +} + +// Test: SetSecret sets secrets correctly +func TestSetSecret(t *testing.T) { + tests := []struct { + name string + initialYAML string + key string + value string + verifyFunc func(t *testing.T, secretsPath string) + }{ + { + name: "set simple value", + initialYAML: `cluster: + talosSecrets: "" +`, + key: "cluster.talosSecrets", + value: "new-secret-value", + verifyFunc: func(t *testing.T, secretsPath string) { + m := NewManager() + got, err := m.GetSecret(secretsPath, "cluster.talosSecrets") + if err != nil { + t.Fatalf("verify failed: %v", err) + } + if got != "new-secret-value" { + t.Errorf("got %q, want %q", got, "new-secret-value") + } + }, + }, + { + name: "set nested value", + initialYAML: `certManager: + cloudflare: + apiToken: "" +`, + key: "certManager.cloudflare.apiToken", + value: "cf-token-xyz", + verifyFunc: func(t *testing.T, secretsPath string) { + m := NewManager() + got, err := m.GetSecret(secretsPath, "certManager.cloudflare.apiToken") + if err != nil { + t.Fatalf("verify failed: %v", err) + } + if got != "cf-token-xyz" { + t.Errorf("got %q, want %q", got, "cf-token-xyz") + } + }, + }, + { + name: "update existing value", + initialYAML: `cluster: + talosSecrets: "old-secret" +`, + key: "cluster.talosSecrets", + value: "new-secret", + verifyFunc: func(t *testing.T, secretsPath string) { + m := NewManager() + got, err := m.GetSecret(secretsPath, "cluster.talosSecrets") + if err != nil { + t.Fatalf("verify failed: %v", err) + } + if got != "new-secret" { + t.Errorf("got %q, want %q", got, "new-secret") + } + }, + }, + { + name: "create new nested path", + initialYAML: `cluster: {} +`, + key: "cluster.newSecret", + value: "newValue", + verifyFunc: func(t *testing.T, secretsPath string) { + m := NewManager() + got, err := m.GetSecret(secretsPath, "cluster.newSecret") + if err != nil { + t.Fatalf("verify failed: %v", err) + } + if got != "newValue" { + t.Errorf("got %q, want %q", got, "newValue") + } + }, + }, + { + name: "set value with special characters", + initialYAML: `cluster: + talosSecrets: "" +`, + key: "cluster.talosSecrets", + value: `special"quotes'and\backslashes`, + verifyFunc: func(t *testing.T, secretsPath string) { + m := NewManager() + got, err := m.GetSecret(secretsPath, "cluster.talosSecrets") + if err != nil { + t.Fatalf("verify failed: %v", err) + } + if got != `special"quotes'and\backslashes` { + t.Errorf("got %q, want %q", got, `special"quotes'and\backslashes`) + } + }, + }, } - secretsPath := filepath.Join(instancePath, "secrets.yaml") + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + secretsPath := filepath.Join(tempDir, "secrets.yaml") - // Verify file exists + if err := storage.WriteFile(secretsPath, []byte(tt.initialYAML), 0600); err != nil { + t.Fatalf("setup failed: %v", err) + } + + m := NewManager() + if err := m.SetSecret(secretsPath, tt.key, tt.value); err != nil { + t.Errorf("SetSecret failed: %v", err) + return + } + + // Verify the value was set correctly + tt.verifyFunc(t, secretsPath) + + // Verify permissions remain secure (0600) + info, err := os.Stat(secretsPath) + if err != nil { + t.Fatalf("failed to stat secrets file: %v", err) + } + if info.Mode().Perm() != 0600 { + t.Errorf("permissions changed after SetSecret: got %o, want 0600", info.Mode().Perm()) + } + }) + } +} + +// Test: SetSecret error cases +func TestSetSecret_Errors(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) string + key string + value string + errContains string + }{ + { + name: "non-existent file", + setupFunc: func(t *testing.T) string { + return filepath.Join(t.TempDir(), "nonexistent.yaml") + }, + key: "cluster.talosSecrets", + value: "secret", + errContains: "secrets file not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + secretsPath := tt.setupFunc(t) + m := NewManager() + + err := m.SetSecret(secretsPath, tt.key, tt.value) + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tt.errContains) { + t.Errorf("error %q does not contain %q", err.Error(), tt.errContains) + } + }) + } +} + +// Test: SetSecret with concurrent access +func TestSetSecret_ConcurrentAccess(t *testing.T) { + tempDir := t.TempDir() + secretsPath := filepath.Join(tempDir, "secrets.yaml") + + initialYAML := `counter: "0" +` + if err := storage.WriteFile(secretsPath, []byte(initialYAML), 0600); err != nil { + t.Fatalf("setup failed: %v", err) + } + + m := NewManager() + const numGoroutines = 10 + + var wg sync.WaitGroup + errors := make(chan error, numGoroutines) + + // Launch multiple goroutines trying to write different values + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(val int) { + defer wg.Done() + key := "counter" + value := string(rune('0' + val)) + if err := m.SetSecret(secretsPath, key, value); err != nil { + errors <- err + } + }(i) + } + + wg.Wait() + close(errors) + + // Check if any errors occurred + for err := range errors { + t.Errorf("concurrent write error: %v", err) + } + + // Verify permissions remain secure after concurrent access info, err := os.Stat(secretsPath) if err != nil { - t.Fatalf("Secrets file not created: %v", err) + t.Fatalf("failed to stat secrets file: %v", err) + } + if info.Mode().Perm() != 0600 { + t.Errorf("permissions changed after concurrent writes: got %o, want 0600", info.Mode().Perm()) } - // Verify permissions are 0600 - mode := info.Mode().Perm() - if mode != 0600 { - t.Errorf("Wrong permissions: got %o, want 0600", mode) - } - - // Test idempotency - calling again should not error - err = m.EnsureSecretsFile(instancePath) + // Verify we can read a value (should be one of the written values) + value, err := m.GetSecret(secretsPath, "counter") if err != nil { - t.Fatalf("EnsureSecretsFile not idempotent: %v", err) + t.Errorf("failed to read value after concurrent writes: %v", err) + } + if value == "" || value == "null" { + t.Error("counter value is empty after concurrent writes") } } -func TestManager_SetAndGetSecret(t *testing.T) { - tmpDir := t.TempDir() +// Test: EnsureSecret generates and sets secret only when not set +func TestEnsureSecret(t *testing.T) { + tests := []struct { + name string + initialYAML string + key string + length int + expectNew bool + }{ + { + name: "generates secret when empty string", + initialYAML: `cluster: + talosSecrets: "" +`, + key: "cluster.talosSecrets", + length: 32, + expectNew: true, + }, + { + name: "generates secret when null", + initialYAML: `cluster: + talosSecrets: null +`, + key: "cluster.talosSecrets", + length: 32, + expectNew: true, + }, + { + name: "does not generate when secret exists", + initialYAML: `cluster: + talosSecrets: "existing-secret" +`, + key: "cluster.talosSecrets", + length: 32, + expectNew: false, + }, + { + name: "generates secret for non-existent key", + initialYAML: `cluster: {} +`, + key: "cluster.newSecret", + length: 64, + expectNew: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + secretsPath := filepath.Join(tempDir, "secrets.yaml") + + if err := storage.WriteFile(secretsPath, []byte(tt.initialYAML), 0600); err != nil { + t.Fatalf("setup failed: %v", err) + } + + m := NewManager() + + // Get initial value if exists + initialVal, _ := m.GetSecret(secretsPath, tt.key) + + // Call EnsureSecret + secret, err := m.EnsureSecret(secretsPath, tt.key, tt.length) + if err != nil { + t.Errorf("EnsureSecret failed: %v", err) + return + } + + // Verify secret is returned + if secret == "" { + t.Error("EnsureSecret returned empty secret") + } + + // Verify length + if tt.expectNew && len(secret) != tt.length { + t.Errorf("expected secret length %d, got %d", tt.length, len(secret)) + } + + // Get final value + finalVal, err := m.GetSecret(secretsPath, tt.key) + if err != nil { + t.Fatalf("GetSecret failed: %v", err) + } + + if tt.expectNew { + // Should have generated new secret + if initialVal != "" && finalVal == initialVal { + t.Errorf("expected new secret, got same value: %q", finalVal) + } + } else { + // Should have kept existing secret + if finalVal != initialVal { + t.Errorf("expected to keep existing secret %q, got %q", initialVal, finalVal) + } + } + + // Call EnsureSecret again - should be idempotent + secret2, err := m.EnsureSecret(secretsPath, tt.key, tt.length) + if err != nil { + t.Errorf("second EnsureSecret failed: %v", err) + return + } + + // Secret should not change on second call + if secret2 != secret { + t.Errorf("secret changed on second ensure: %q -> %q", secret, secret2) + } + }) + } +} + +// Test: GenerateAndStoreSecret convenience function +func TestGenerateAndStoreSecret(t *testing.T) { + tempDir := t.TempDir() + secretsPath := filepath.Join(tempDir, "secrets.yaml") + + initialYAML := `cluster: + talosSecrets: "" +` + if err := storage.WriteFile(secretsPath, []byte(initialYAML), 0600); err != nil { + t.Fatalf("setup failed: %v", err) + } + m := NewManager() - instancePath := filepath.Join(tmpDir, "test-cloud") - err := os.MkdirAll(instancePath, 0755) + // Generate and store secret + secret, err := m.GenerateAndStoreSecret(secretsPath, "cluster.talosSecrets") if err != nil { - t.Fatalf("Failed to create instance dir: %v", err) + t.Fatalf("GenerateAndStoreSecret failed: %v", err) } - secretsPath := filepath.Join(instancePath, "secrets.yaml") - - // Initialize secrets - err = m.EnsureSecretsFile(instancePath) - if err != nil { - t.Fatalf("EnsureSecretsFile failed: %v", err) + // Verify secret length matches default + if len(secret) != DefaultSecretLength { + t.Errorf("expected length %d, got %d", DefaultSecretLength, len(secret)) } - // Set a custom secret (requires yq) - err = m.SetSecret(secretsPath, "customSecret", "myvalue123") - if err != nil { - t.Skipf("SetSecret requires yq: %v", err) - return - } - - // Get the secret back - value, err := m.GetSecret(secretsPath, "customSecret") + // Verify secret is stored + stored, err := m.GetSecret(secretsPath, "cluster.talosSecrets") if err != nil { t.Fatalf("GetSecret failed: %v", err) } - if value != "myvalue123" { - t.Errorf("Secret not retrieved correctly: got %q, want %q", value, "myvalue123") - } - - // Verify permissions still 0600 - info, _ := os.Stat(secretsPath) - if info.Mode().Perm() != 0600 { - t.Errorf("Permissions changed after SetSecret") - } - - // Get non-existent secret should error - _, err = m.GetSecret(secretsPath, "nonExistent") - if err == nil { - t.Fatalf("GetSecret should fail for non-existent secret") + if stored != secret { + t.Errorf("stored secret %q does not match returned secret %q", stored, secret) + } +} + +// Test: DeleteSecret removes secrets correctly +func TestDeleteSecret(t *testing.T) { + tests := []struct { + name string + initialYAML string + key string + wantErr bool + }{ + { + name: "delete existing secret", + initialYAML: `cluster: + talosSecrets: "secret-to-delete" + kubeconfig: "other-secret" +`, + key: "cluster.talosSecrets", + wantErr: false, + }, + { + name: "delete nested secret", + initialYAML: `certManager: + cloudflare: + apiToken: "token-to-delete" +`, + key: "certManager.cloudflare.apiToken", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tempDir := t.TempDir() + secretsPath := filepath.Join(tempDir, "secrets.yaml") + + if err := storage.WriteFile(secretsPath, []byte(tt.initialYAML), 0600); err != nil { + t.Fatalf("setup failed: %v", err) + } + + m := NewManager() + + // Verify secret exists before deletion + _, err := m.GetSecret(secretsPath, tt.key) + if err != nil { + t.Fatalf("secret should exist before deletion: %v", err) + } + + // Delete secret + err = m.DeleteSecret(secretsPath, tt.key) + if tt.wantErr { + if err == nil { + t.Error("expected error, got nil") + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + + // Verify secret no longer exists + _, err = m.GetSecret(secretsPath, tt.key) + if err == nil { + t.Error("secret should not exist after deletion") + } + + // Verify permissions remain secure (0600) + info, err := os.Stat(secretsPath) + if err != nil { + t.Fatalf("failed to stat secrets file: %v", err) + } + if info.Mode().Perm() != 0600 { + t.Errorf("permissions changed after DeleteSecret: got %o, want 0600", info.Mode().Perm()) + } + }) + } +} + +// Test: DeleteSecret error cases +func TestDeleteSecret_Errors(t *testing.T) { + t.Run("non-existent file", func(t *testing.T) { + tempDir := t.TempDir() + secretsPath := filepath.Join(tempDir, "nonexistent.yaml") + + m := NewManager() + err := m.DeleteSecret(secretsPath, "cluster.talosSecrets") + + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), "secrets file not found") { + t.Errorf("error %q does not contain 'secrets file not found'", err.Error()) + } + }) +} + +// Test: File permissions are always 0600 +func TestEnsureSecretsFile_FilePermissions(t *testing.T) { + tempDir := t.TempDir() + m := NewManager() + + if err := m.EnsureSecretsFile(tempDir); err != nil { + t.Fatalf("EnsureSecretsFile failed: %v", err) + } + + secretsPath := filepath.Join(tempDir, "secrets.yaml") + info, err := os.Stat(secretsPath) + if err != nil { + t.Fatalf("failed to stat secrets file: %v", err) + } + + // Verify file has 0600 permissions (read/write for owner only) + if info.Mode().Perm() != 0600 { + t.Errorf("expected permissions 0600, got %o", info.Mode().Perm()) + } +} + +// Test: Idempotent secrets creation +func TestEnsureSecretsFile_Idempotent(t *testing.T) { + tempDir := t.TempDir() + m := NewManager() + + // First call creates secrets + if err := m.EnsureSecretsFile(tempDir); err != nil { + t.Fatalf("first EnsureSecretsFile failed: %v", err) + } + + secretsPath := filepath.Join(tempDir, "secrets.yaml") + firstContent, err := storage.ReadFile(secretsPath) + if err != nil { + t.Fatalf("failed to read secrets: %v", err) + } + + // Second call should not modify secrets + if err := m.EnsureSecretsFile(tempDir); err != nil { + t.Fatalf("second EnsureSecretsFile failed: %v", err) + } + + secondContent, err := storage.ReadFile(secretsPath) + if err != nil { + t.Fatalf("failed to read secrets: %v", err) + } + + if string(firstContent) != string(secondContent) { + t.Error("secrets content changed on second call") + } +} + +// Test: Secrets structure contains required fields +func TestEnsureSecretsFile_RequiredFields(t *testing.T) { + tempDir := t.TempDir() + m := NewManager() + + if err := m.EnsureSecretsFile(tempDir); err != nil { + t.Fatalf("EnsureSecretsFile failed: %v", err) + } + + secretsPath := filepath.Join(tempDir, "secrets.yaml") + content, err := storage.ReadFile(secretsPath) + if err != nil { + t.Fatalf("failed to read secrets: %v", err) + } + + contentStr := string(content) + requiredFields := []string{ + "cluster:", + "talosSecrets:", + "kubeconfig:", + "certManager:", + "cloudflare:", + "apiToken:", + } + + for _, field := range requiredFields { + if !strings.Contains(contentStr, field) { + t.Errorf("secrets missing required field: %s", field) + } + } + + // Verify warning comment exists + if !strings.Contains(contentStr, "WARNING") || !strings.Contains(contentStr, "sensitive") { + t.Error("secrets file missing security warning comment") + } +} + +// Test: Secrets are more restrictive than config +func TestSecretsPermissions_MoreRestrictiveThanConfig(t *testing.T) { + tempDir := t.TempDir() + secretsPath := filepath.Join(tempDir, "secrets.yaml") + configPath := filepath.Join(tempDir, "config.yaml") + + // Create secrets file + m := NewManager() + if err := m.EnsureSecretsFile(tempDir); err != nil { + t.Fatalf("EnsureSecretsFile failed: %v", err) + } + + // Create config file (typically 0644) + configContent := `baseDomain: "example.com"` + if err := storage.WriteFile(configPath, []byte(configContent), 0644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + + // Check permissions + secretsInfo, err := os.Stat(secretsPath) + if err != nil { + t.Fatalf("failed to stat secrets: %v", err) + } + + configInfo, err := os.Stat(configPath) + if err != nil { + t.Fatalf("failed to stat config: %v", err) + } + + secretsPerm := secretsInfo.Mode().Perm() + configPerm := configInfo.Mode().Perm() + + // Secrets (0600) should be more restrictive than config (0644) + if secretsPerm >= configPerm { + t.Errorf("secrets permissions %o should be more restrictive than config %o", secretsPerm, configPerm) + } + + // Secrets should not be group or world readable + if secretsPerm&0077 != 0 { + t.Errorf("secrets file should not have group/world permissions, got %o", secretsPerm) } } diff --git a/internal/storage/storage_test.go b/internal/storage/storage_test.go index a266ee1..fcce005 100644 --- a/internal/storage/storage_test.go +++ b/internal/storage/storage_test.go @@ -1,107 +1,481 @@ package storage import ( + "errors" + "io/fs" "os" "path/filepath" + "sync" + "sync/atomic" "testing" + "time" ) +func TestFileExists(t *testing.T) { + tests := []struct { + name string + setup func(tmpDir string) string + expected bool + }{ + { + name: "existing file returns true", + setup: func(tmpDir string) string { + path := filepath.Join(tmpDir, "test.txt") + if err := os.WriteFile(path, []byte("test"), 0644); err != nil { + t.Fatal(err) + } + return path + }, + expected: true, + }, + { + name: "non-existent file returns false", + setup: func(tmpDir string) string { + return filepath.Join(tmpDir, "nonexistent.txt") + }, + expected: false, + }, + { + name: "directory path returns true", + setup: func(tmpDir string) string { + path := filepath.Join(tmpDir, "testdir") + if err := os.Mkdir(path, 0755); err != nil { + t.Fatal(err) + } + return path + }, + expected: true, + }, + { + name: "empty path returns false", + setup: func(tmpDir string) string { + return "" + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + path := tt.setup(tmpDir) + got := FileExists(path) + if got != tt.expected { + t.Errorf("FileExists(%q) = %v, want %v", path, got, tt.expected) + } + }) + } +} + func TestEnsureDir(t *testing.T) { - tmpDir := t.TempDir() - testDir := filepath.Join(tmpDir, "test", "nested", "dir") - - err := EnsureDir(testDir, 0755) - if err != nil { - t.Fatalf("EnsureDir failed: %v", err) + tests := []struct { + name string + setup func(tmpDir string) (string, os.FileMode) + wantErr bool + }{ + { + name: "creates new directory", + setup: func(tmpDir string) (string, os.FileMode) { + return filepath.Join(tmpDir, "newdir"), 0755 + }, + wantErr: false, + }, + { + name: "idempotent - doesn't error if exists", + setup: func(tmpDir string) (string, os.FileMode) { + path := filepath.Join(tmpDir, "existingdir") + if err := os.Mkdir(path, 0755); err != nil { + t.Fatal(err) + } + return path, 0755 + }, + wantErr: false, + }, + { + name: "creates nested directories", + setup: func(tmpDir string) (string, os.FileMode) { + return filepath.Join(tmpDir, "a", "b", "c", "d"), 0755 + }, + wantErr: false, + }, } - // Verify directory exists - info, err := os.Stat(testDir) - if err != nil { - t.Fatalf("Directory not created: %v", err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + path, perm := tt.setup(tmpDir) + + err := EnsureDir(path, perm) + if (err != nil) != tt.wantErr { + t.Errorf("EnsureDir() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr { + info, err := os.Stat(path) + if err != nil { + t.Errorf("Directory not created: %v", err) + return + } + if !info.IsDir() { + t.Error("Path is not a directory") + } + } + }) } - if !info.IsDir() { - t.Fatalf("Path is not a directory") +} + +func TestReadFile(t *testing.T) { + tests := []struct { + name string + setup func(tmpDir string) string + wantData []byte + wantErr bool + errCheck func(error) bool + }{ + { + name: "read existing file", + setup: func(tmpDir string) string { + path := filepath.Join(tmpDir, "test.txt") + if err := os.WriteFile(path, []byte("test content"), 0644); err != nil { + t.Fatal(err) + } + return path + }, + wantData: []byte("test content"), + wantErr: false, + }, + { + name: "non-existent file", + setup: func(tmpDir string) string { + return filepath.Join(tmpDir, "nonexistent.txt") + }, + wantErr: true, + errCheck: func(err error) bool { + return errors.Is(err, fs.ErrNotExist) + }, + }, + { + name: "empty file", + setup: func(tmpDir string) string { + path := filepath.Join(tmpDir, "empty.txt") + if err := os.WriteFile(path, []byte{}, 0644); err != nil { + t.Fatal(err) + } + return path + }, + wantData: []byte{}, + wantErr: false, + }, } - // Calling again should be idempotent - err = EnsureDir(testDir, 0755) - if err != nil { - t.Fatalf("EnsureDir not idempotent: %v", err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + path := tt.setup(tmpDir) + + got, err := ReadFile(path) + if (err != nil) != tt.wantErr { + t.Errorf("ReadFile() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr && tt.errCheck != nil && !tt.errCheck(err) { + t.Errorf("ReadFile() error type mismatch: %v", err) + } + + if !tt.wantErr && string(got) != string(tt.wantData) { + t.Errorf("ReadFile() = %q, want %q", got, tt.wantData) + } + }) } } func TestWriteFile(t *testing.T) { - tmpDir := t.TempDir() - testFile := filepath.Join(tmpDir, "test.txt") - testData := []byte("test content") - - // Write file - err := WriteFile(testFile, testData, 0644) - if err != nil { - t.Fatalf("WriteFile failed: %v", err) + tests := []struct { + name string + setup func(tmpDir string) (string, []byte, os.FileMode) + validate func(t *testing.T, path string, data []byte, perm os.FileMode) + wantErr bool + }{ + { + name: "write new file", + setup: func(tmpDir string) (string, []byte, os.FileMode) { + return filepath.Join(tmpDir, "new.txt"), []byte("new content"), 0644 + }, + validate: func(t *testing.T, path string, data []byte, perm os.FileMode) { + got, err := os.ReadFile(path) + if err != nil { + t.Errorf("Failed to read written file: %v", err) + } + if string(got) != string(data) { + t.Errorf("Content = %q, want %q", got, data) + } + }, + }, + { + name: "overwrite existing file", + setup: func(tmpDir string) (string, []byte, os.FileMode) { + path := filepath.Join(tmpDir, "existing.txt") + if err := os.WriteFile(path, []byte("old content"), 0644); err != nil { + t.Fatal(err) + } + return path, []byte("new content"), 0644 + }, + validate: func(t *testing.T, path string, data []byte, perm os.FileMode) { + got, err := os.ReadFile(path) + if err != nil { + t.Errorf("Failed to read overwritten file: %v", err) + } + if string(got) != string(data) { + t.Errorf("Content = %q, want %q", got, data) + } + }, + }, + { + name: "correct permissions applied", + setup: func(tmpDir string) (string, []byte, os.FileMode) { + return filepath.Join(tmpDir, "perms.txt"), []byte("test"), 0600 + }, + validate: func(t *testing.T, path string, data []byte, perm os.FileMode) { + info, err := os.Stat(path) + if err != nil { + t.Errorf("Failed to stat file: %v", err) + return + } + if info.Mode().Perm() != perm { + t.Errorf("Permissions = %o, want %o", info.Mode().Perm(), perm) + } + }, + }, } - // Read file back - data, err := os.ReadFile(testFile) - if err != nil { - t.Fatalf("ReadFile failed: %v", err) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + path, data, perm := tt.setup(tmpDir) - if string(data) != string(testData) { - t.Fatalf("Data mismatch: got %q, want %q", string(data), string(testData)) - } -} + err := WriteFile(path, data, perm) + if (err != nil) != tt.wantErr { + t.Errorf("WriteFile() error = %v, wantErr %v", err, tt.wantErr) + return + } -func TestFileExists(t *testing.T) { - tmpDir := t.TempDir() - testFile := filepath.Join(tmpDir, "test.txt") - - // File should not exist initially - if FileExists(testFile) { - t.Fatalf("File should not exist") - } - - // Create file - err := WriteFile(testFile, []byte("test"), 0644) - if err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - - // File should exist now - if !FileExists(testFile) { - t.Fatalf("File should exist") + if !tt.wantErr && tt.validate != nil { + tt.validate(t, path, data, perm) + } + }) } } func TestWithLock(t *testing.T) { - tmpDir := t.TempDir() - lockFile := filepath.Join(tmpDir, "test.lock") - counter := 0 + t.Run("acquires and releases lock", func(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "test.lock") + executed := false - // Execute with lock - err := WithLock(lockFile, func() error { - counter++ - return nil + err := WithLock(lockPath, func() error { + executed = true + return nil + }) + + if err != nil { + t.Errorf("WithLock() error = %v", err) + } + if !executed { + t.Error("Function was not executed") + } }) - if err != nil { - t.Fatalf("WithLock failed: %v", err) - } - if counter != 1 { - t.Fatalf("Function not executed: counter=%d", counter) - } + t.Run("releases lock after executing", func(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "test.lock") - // Should be idempotent - can acquire lock multiple times sequentially - err = WithLock(lockFile, func() error { - counter++ - return nil + err := WithLock(lockPath, func() error { + return nil + }) + if err != nil { + t.Fatalf("First lock failed: %v", err) + } + + err = WithLock(lockPath, func() error { + return nil + }) + if err != nil { + t.Errorf("Second lock failed (lock not released): %v", err) + } }) - if err != nil { - t.Fatalf("WithLock failed on second call: %v", err) + + t.Run("concurrent access blocked", func(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "concurrent.lock") + + var counter atomic.Int32 + var wg sync.WaitGroup + goroutines := 10 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + err := WithLock(lockPath, func() error { + current := counter.Load() + time.Sleep(10 * time.Millisecond) + counter.Store(current + 1) + return nil + }) + if err != nil { + t.Errorf("WithLock() error = %v", err) + } + }() + } + + wg.Wait() + + if counter.Load() != int32(goroutines) { + t.Errorf("Counter = %d, want %d (concurrent access not properly blocked)", counter.Load(), goroutines) + } + }) + + t.Run("lock released on error", func(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "error.lock") + testErr := errors.New("test error") + + err := WithLock(lockPath, func() error { + return testErr + }) + if err != testErr { + t.Errorf("Expected error %v, got %v", testErr, err) + } + + err = WithLock(lockPath, func() error { + return nil + }) + if err != nil { + t.Errorf("Lock not released after error: %v", err) + } + }) + + t.Run("lock released on panic", func(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "panic.lock") + + func() { + defer func() { + if r := recover(); r == nil { + t.Error("Expected panic") + } + }() + _ = WithLock(lockPath, func() error { + panic("test panic") + }) + }() + + err := WithLock(lockPath, func() error { + return nil + }) + if err != nil { + t.Errorf("Lock not released after panic: %v", err) + } + }) +} + +func TestLockManual(t *testing.T) { + t.Run("manual acquire and release", func(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "manual.lock") + + lock, err := AcquireLock(lockPath) + if err != nil { + t.Fatalf("AcquireLock() error = %v", err) + } + + err = lock.Release() + if err != nil { + t.Errorf("Release() error = %v", err) + } + }) + + t.Run("double release is safe", func(t *testing.T) { + tmpDir := t.TempDir() + lockPath := filepath.Join(tmpDir, "double.lock") + + lock, err := AcquireLock(lockPath) + if err != nil { + t.Fatalf("AcquireLock() error = %v", err) + } + + err = lock.Release() + if err != nil { + t.Errorf("First Release() error = %v", err) + } + + err = lock.Release() + if err != nil { + t.Errorf("Second Release() error = %v", err) + } + }) +} + +func TestEnsureFilePermissions(t *testing.T) { + tests := []struct { + name string + setup func(tmpDir string) string + perm os.FileMode + wantErr bool + errCheck func(error) bool + }{ + { + name: "sets permissions on existing file", + setup: func(tmpDir string) string { + path := filepath.Join(tmpDir, "test.txt") + if err := os.WriteFile(path, []byte("test"), 0644); err != nil { + t.Fatal(err) + } + return path + }, + perm: 0600, + wantErr: false, + }, + { + name: "non-existent file returns error", + setup: func(tmpDir string) string { + return filepath.Join(tmpDir, "nonexistent.txt") + }, + perm: 0644, + wantErr: true, + errCheck: func(err error) bool { + return errors.Is(err, fs.ErrNotExist) + }, + }, } - if counter != 2 { - t.Fatalf("Function not executed on second call: counter=%d", counter) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := t.TempDir() + path := tt.setup(tmpDir) + + err := EnsureFilePermissions(path, tt.perm) + if (err != nil) != tt.wantErr { + t.Errorf("EnsureFilePermissions() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr && tt.errCheck != nil && !tt.errCheck(err) { + t.Errorf("EnsureFilePermissions() error type mismatch: %v", err) + } + + if !tt.wantErr { + info, err := os.Stat(path) + if err != nil { + t.Errorf("Failed to stat file: %v", err) + return + } + if info.Mode().Perm() != tt.perm { + t.Errorf("Permissions = %o, want %o", info.Mode().Perm(), tt.perm) + } + } + }) } } diff --git a/internal/tools/kubectl_test.go b/internal/tools/kubectl_test.go new file mode 100644 index 0000000..34a559a --- /dev/null +++ b/internal/tools/kubectl_test.go @@ -0,0 +1,750 @@ +package tools + +import ( + "testing" + "time" +) + +func TestNewKubectl(t *testing.T) { + tests := []struct { + name string + kubeconfigPath string + }{ + { + name: "creates Kubectl with kubeconfig path", + kubeconfigPath: "/path/to/kubeconfig", + }, + { + name: "creates Kubectl with empty path", + kubeconfigPath: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := NewKubectl(tt.kubeconfigPath) + if k == nil { + t.Fatal("NewKubectl() returned nil") + } + if k.kubeconfigPath != tt.kubeconfigPath { + t.Errorf("kubeconfigPath = %q, want %q", k.kubeconfigPath, tt.kubeconfigPath) + } + }) + } +} + +func TestKubectlDeploymentExists(t *testing.T) { + tests := []struct { + name string + depName string + namespace string + skipTest bool + }{ + { + name: "check deployment exists", + depName: "test-deployment", + namespace: "default", + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires kubectl and running cluster") + } + + k := NewKubectl("") + exists := k.DeploymentExists(tt.depName, tt.namespace) + _ = exists // Result depends on actual cluster state + }) + } +} + +func TestKubectlGetPods(t *testing.T) { + tests := []struct { + name string + namespace string + detailed bool + skipTest bool + }{ + { + name: "get pods basic", + namespace: "default", + detailed: false, + skipTest: true, + }, + { + name: "get pods detailed", + namespace: "kube-system", + detailed: true, + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires kubectl and running cluster") + } + + k := NewKubectl("") + pods, err := k.GetPods(tt.namespace, tt.detailed) + + if err == nil { + if pods == nil { + t.Error("GetPods() returned nil slice without error") + } + // Verify pod structure + for i, pod := range pods { + if pod.Name == "" { + t.Errorf("pod[%d].Name is empty", i) + } + if pod.Status == "" { + t.Errorf("pod[%d].Status is empty", i) + } + if pod.Ready == "" { + t.Errorf("pod[%d].Ready is empty", i) + } + if pod.Age == "" { + t.Errorf("pod[%d].Age is empty", i) + } + if tt.detailed && pod.Containers == nil { + t.Errorf("pod[%d].Containers is nil in detailed mode", i) + } + } + } + }) + } +} + +func TestKubectlGetFirstPodName(t *testing.T) { + tests := []struct { + name string + namespace string + skipTest bool + }{ + { + name: "get first pod name", + namespace: "kube-system", + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires kubectl and running cluster") + } + + k := NewKubectl("") + podName, err := k.GetFirstPodName(tt.namespace) + + if err == nil { + if podName == "" { + t.Error("GetFirstPodName() returned empty string without error") + } + } + }) + } +} + +func TestKubectlGetPodContainers(t *testing.T) { + tests := []struct { + name string + namespace string + podName string + skipTest bool + }{ + { + name: "get pod containers", + namespace: "kube-system", + podName: "coredns-123", + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires kubectl and running cluster") + } + + k := NewKubectl("") + containers, err := k.GetPodContainers(tt.namespace, tt.podName) + + if err == nil { + if containers == nil { + t.Error("GetPodContainers() returned nil slice without error") + } + } + }) + } +} + +func TestKubectlGetDeployment(t *testing.T) { + tests := []struct { + name string + depName string + namespace string + skipTest bool + }{ + { + name: "get deployment info", + depName: "test-deployment", + namespace: "default", + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires kubectl and running cluster") + } + + k := NewKubectl("") + depInfo, err := k.GetDeployment(tt.depName, tt.namespace) + + if err == nil { + if depInfo == nil { + t.Error("GetDeployment() returned nil without error") + } + // Desired should be non-negative + if depInfo.Desired < 0 { + t.Errorf("Desired = %d, should be non-negative", depInfo.Desired) + } + } + }) + } +} + +func TestKubectlGetReplicas(t *testing.T) { + tests := []struct { + name string + namespace string + skipTest bool + }{ + { + name: "get replicas for namespace", + namespace: "default", + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires kubectl and running cluster") + } + + k := NewKubectl("") + replicaInfo, err := k.GetReplicas(tt.namespace) + + if err == nil { + if replicaInfo == nil { + t.Error("GetReplicas() returned nil without error") + } + // All values should be non-negative + if replicaInfo.Desired < 0 { + t.Error("Desired < 0") + } + if replicaInfo.Current < 0 { + t.Error("Current < 0") + } + if replicaInfo.Ready < 0 { + t.Error("Ready < 0") + } + if replicaInfo.Available < 0 { + t.Error("Available < 0") + } + } + }) + } +} + +func TestKubectlGetResources(t *testing.T) { + tests := []struct { + name string + namespace string + skipTest bool + }{ + { + name: "get resources for namespace", + namespace: "default", + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires kubectl and running cluster") + } + + k := NewKubectl("") + usage, err := k.GetResources(tt.namespace) + + if err == nil { + if usage == nil { + t.Error("GetResources() returned nil without error") + } + } + }) + } +} + +func TestKubectlGetRecentEvents(t *testing.T) { + tests := []struct { + name string + namespace string + limit int + skipTest bool + }{ + { + name: "get recent events", + namespace: "default", + limit: 10, + skipTest: true, + }, + { + name: "get all events with zero limit", + namespace: "default", + limit: 0, + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires kubectl and running cluster") + } + + k := NewKubectl("") + events, err := k.GetRecentEvents(tt.namespace, tt.limit) + + if err == nil { + if events == nil { + t.Error("GetRecentEvents() returned nil slice without error") + } + if tt.limit > 0 && len(events) > tt.limit { + t.Errorf("len(events) = %d, want <= %d", len(events), tt.limit) + } + } + }) + } +} + +func TestKubectlGetLogs(t *testing.T) { + tests := []struct { + name string + namespace string + podName string + opts LogOptions + skipTest bool + }{ + { + name: "get logs with tail", + namespace: "kube-system", + podName: "coredns-123", + opts: LogOptions{Tail: 100}, + skipTest: true, + }, + { + name: "get logs with container", + namespace: "kube-system", + podName: "coredns-123", + opts: LogOptions{Container: "coredns", Tail: 50}, + skipTest: true, + }, + { + name: "get previous logs", + namespace: "default", + podName: "test-pod", + opts: LogOptions{Previous: true, Tail: 100}, + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires kubectl and running cluster") + } + + k := NewKubectl("") + logs, err := k.GetLogs(tt.namespace, tt.podName, tt.opts) + + if err == nil { + if logs == nil { + t.Error("GetLogs() returned nil slice without error") + } + } + }) + } +} + +func TestKubectlStreamLogs(t *testing.T) { + tests := []struct { + name string + namespace string + podName string + opts LogOptions + skipTest bool + }{ + { + name: "stream logs", + namespace: "default", + podName: "test-pod", + opts: LogOptions{Tail: 10}, + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires kubectl and running cluster") + } + + k := NewKubectl("") + cmd, err := k.StreamLogs(tt.namespace, tt.podName, tt.opts) + + if err == nil { + if cmd == nil { + t.Error("StreamLogs() returned nil command without error") + } + } + }) + } +} + +func TestFormatAge(t *testing.T) { + tests := []struct { + name string + duration time.Duration + want string + }{ + { + name: "seconds", + duration: 45 * time.Second, + want: "45s", + }, + { + name: "minutes", + duration: 5 * time.Minute, + want: "5m", + }, + { + name: "hours", + duration: 3 * time.Hour, + want: "3h", + }, + { + name: "days", + duration: 48 * time.Hour, + want: "2d", + }, + { + name: "less than minute", + duration: 30 * time.Second, + want: "30s", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatAge(tt.duration) + if got != tt.want { + t.Errorf("formatAge(%v) = %q, want %q", tt.duration, got, tt.want) + } + }) + } +} + +func TestParseResourceQuantity(t *testing.T) { + tests := []struct { + name string + quantity string + want int64 + }{ + { + name: "millicores", + quantity: "500m", + want: 500, + }, + { + name: "cores as plain number", + quantity: "2", + want: 2, + }, + { + name: "Ki suffix", + quantity: "100Ki", + want: 100 * 1024, + }, + { + name: "Mi suffix", + quantity: "512Mi", + want: 512 * 1024 * 1024, + }, + { + name: "Gi suffix", + quantity: "2Gi", + want: 2 * 1024 * 1024 * 1024, + }, + { + name: "K suffix", + quantity: "100K", + want: 100 * 1000, + }, + { + name: "M suffix", + quantity: "500M", + want: 500 * 1000 * 1000, + }, + { + name: "G suffix", + quantity: "1G", + want: 1 * 1000 * 1000 * 1000, + }, + { + name: "empty string", + quantity: "", + want: 0, + }, + { + name: "whitespace", + quantity: " ", + want: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseResourceQuantity(tt.quantity) + if got != tt.want { + t.Errorf("parseResourceQuantity(%q) = %d, want %d", tt.quantity, got, tt.want) + } + }) + } +} + +func TestFormatCPU(t *testing.T) { + tests := []struct { + name string + millicores int64 + want string + }{ + { + name: "zero", + millicores: 0, + want: "0", + }, + { + name: "millicores", + millicores: 500, + want: "500m", + }, + { + name: "one core", + millicores: 1000, + want: "1.0", + }, + { + name: "two and half cores", + millicores: 2500, + want: "2.5", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatCPU(tt.millicores) + if got != tt.want { + t.Errorf("formatCPU(%d) = %q, want %q", tt.millicores, got, tt.want) + } + }) + } +} + +func TestFormatMemory(t *testing.T) { + tests := []struct { + name string + bytes int64 + want string + }{ + { + name: "zero", + bytes: 0, + want: "0", + }, + { + name: "bytes", + bytes: 512, + want: "512B", + }, + { + name: "kibibytes", + bytes: 1024, + want: "1.0Ki", + }, + { + name: "mebibytes", + bytes: 1024 * 1024, + want: "1.0Mi", + }, + { + name: "gibibytes", + bytes: 2 * 1024 * 1024 * 1024, + want: "2.0Gi", + }, + { + name: "tebibytes", + bytes: 1024 * 1024 * 1024 * 1024, + want: "1.0Ti", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatMemory(tt.bytes) + if got != tt.want { + t.Errorf("formatMemory(%d) = %q, want %q", tt.bytes, got, tt.want) + } + }) + } +} + +func TestPodInfoStruct(t *testing.T) { + t.Run("PodInfo has required fields", func(t *testing.T) { + pod := PodInfo{ + Name: "test-pod", + Status: "Running", + Ready: "1/1", + Restarts: 0, + Age: "5m", + Node: "node-1", + IP: "10.0.0.1", + } + + if pod.Name != "test-pod" { + t.Errorf("Name = %q, want %q", pod.Name, "test-pod") + } + if pod.Status != "Running" { + t.Errorf("Status = %q, want %q", pod.Status, "Running") + } + if pod.Ready != "1/1" { + t.Errorf("Ready = %q, want %q", pod.Ready, "1/1") + } + if pod.Restarts != 0 { + t.Errorf("Restarts = %d, want %d", pod.Restarts, 0) + } + }) +} + +func TestContainerInfoStruct(t *testing.T) { + t.Run("ContainerInfo has required fields", func(t *testing.T) { + container := ContainerInfo{ + Name: "test-container", + Image: "nginx:latest", + Ready: true, + RestartCount: 0, + State: ContainerState{ + Status: "running", + Since: time.Now(), + }, + } + + if container.Name != "test-container" { + t.Errorf("Name = %q, want %q", container.Name, "test-container") + } + if !container.Ready { + t.Error("Ready should be true") + } + if container.State.Status != "running" { + t.Errorf("State.Status = %q, want %q", container.State.Status, "running") + } + }) +} + +func TestDeploymentInfoStruct(t *testing.T) { + t.Run("DeploymentInfo has required fields", func(t *testing.T) { + dep := DeploymentInfo{ + Desired: 3, + Current: 3, + Ready: 3, + Available: 3, + } + + if dep.Desired != 3 { + t.Errorf("Desired = %d, want %d", dep.Desired, 3) + } + if dep.Current != 3 { + t.Errorf("Current = %d, want %d", dep.Current, 3) + } + }) +} + +func TestResourceMetricStruct(t *testing.T) { + t.Run("ResourceMetric has required fields", func(t *testing.T) { + metric := ResourceMetric{ + Used: "1.5", + Requested: "2.0", + Limit: "4.0", + Percentage: 37.5, + } + + if metric.Used != "1.5" { + t.Errorf("Used = %q, want %q", metric.Used, "1.5") + } + if metric.Percentage != 37.5 { + t.Errorf("Percentage = %f, want %f", metric.Percentage, 37.5) + } + }) +} + +func TestLogOptionsStruct(t *testing.T) { + t.Run("LogOptions has all option fields", func(t *testing.T) { + opts := LogOptions{ + Container: "nginx", + Tail: 100, + Previous: true, + Since: "5m", + SinceSeconds: 300, + } + + if opts.Container != "nginx" { + t.Errorf("Container = %q, want %q", opts.Container, "nginx") + } + if opts.Tail != 100 { + t.Errorf("Tail = %d, want %d", opts.Tail, 100) + } + if !opts.Previous { + t.Error("Previous should be true") + } + }) +} + +func TestKubernetesEventStruct(t *testing.T) { + t.Run("KubernetesEvent has required fields", func(t *testing.T) { + now := time.Now() + event := KubernetesEvent{ + Type: "Warning", + Reason: "BackOff", + Message: "Back-off restarting failed container", + Count: 5, + FirstSeen: now.Add(-5 * time.Minute), + LastSeen: now, + Object: "Pod/test-pod", + } + + if event.Type != "Warning" { + t.Errorf("Type = %q, want %q", event.Type, "Warning") + } + if event.Count != 5 { + t.Errorf("Count = %d, want %d", event.Count, 5) + } + }) +} diff --git a/internal/tools/talosctl_test.go b/internal/tools/talosctl_test.go new file mode 100644 index 0000000..cfc31e3 --- /dev/null +++ b/internal/tools/talosctl_test.go @@ -0,0 +1,558 @@ +package tools + +import ( + "os" + "path/filepath" + "testing" +) + +func TestNewTalosctl(t *testing.T) { + t.Run("creates Talosctl instance without config", func(t *testing.T) { + tc := NewTalosctl() + if tc == nil { + t.Fatal("NewTalosctl() returned nil") + } + if tc.talosconfigPath != "" { + t.Error("talosconfigPath should be empty for NewTalosctl()") + } + }) + + t.Run("creates Talosctl instance with config", func(t *testing.T) { + configPath := "/path/to/talosconfig" + tc := NewTalosconfigWithConfig(configPath) + if tc == nil { + t.Fatal("NewTalosconfigWithConfig() returned nil") + } + if tc.talosconfigPath != configPath { + t.Errorf("talosconfigPath = %q, want %q", tc.talosconfigPath, configPath) + } + }) +} + +func TestTalosconfigBuildArgs(t *testing.T) { + tests := []struct { + name string + talosconfigPath string + baseArgs []string + wantPrefix []string + }{ + { + name: "no talosconfig adds no prefix", + talosconfigPath: "", + baseArgs: []string{"version", "--short"}, + wantPrefix: nil, + }, + { + name: "with talosconfig adds prefix", + talosconfigPath: "/path/to/talosconfig", + baseArgs: []string{"version", "--short"}, + wantPrefix: []string{"--talosconfig", "/path/to/talosconfig"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := &Talosctl{talosconfigPath: tt.talosconfigPath} + got := tc.buildArgs(tt.baseArgs) + + if tt.wantPrefix == nil { + // Should return baseArgs unchanged + if len(got) != len(tt.baseArgs) { + t.Errorf("buildArgs() length = %d, want %d", len(got), len(tt.baseArgs)) + } + for i, arg := range tt.baseArgs { + if i >= len(got) || got[i] != arg { + t.Errorf("buildArgs()[%d] = %q, want %q", i, got[i], arg) + } + } + } else { + // Should have prefix + baseArgs + expectedLen := len(tt.wantPrefix) + len(tt.baseArgs) + if len(got) != expectedLen { + t.Errorf("buildArgs() length = %d, want %d", len(got), expectedLen) + } + // Check prefix + for i, arg := range tt.wantPrefix { + if i >= len(got) || got[i] != arg { + t.Errorf("buildArgs() prefix[%d] = %q, want %q", i, got[i], arg) + } + } + // Check baseArgs follow prefix + for i, arg := range tt.baseArgs { + idx := len(tt.wantPrefix) + i + if idx >= len(got) || got[idx] != arg { + t.Errorf("buildArgs()[%d] = %q, want %q", idx, got[idx], arg) + } + } + } + }) + } +} + +func TestTalosconfigGenConfig(t *testing.T) { + tests := []struct { + name string + clusterName string + endpoint string + outputDir string + skipTest bool + }{ + { + name: "gen config with valid params", + clusterName: "test-cluster", + endpoint: "https://192.168.1.100:6443", + outputDir: "testdata", + skipTest: true, // Skip actual execution + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires talosctl binary") + } + + tmpDir := t.TempDir() + tc := NewTalosctl() + err := tc.GenConfig(tt.clusterName, tt.endpoint, tmpDir) + + // This will fail without talosctl, but tests the method signature + if err == nil { + // If it somehow succeeds, verify files were created + expectedFiles := []string{ + "controlplane.yaml", + "worker.yaml", + "talosconfig", + } + for _, file := range expectedFiles { + path := filepath.Join(tmpDir, file) + if _, err := os.Stat(path); os.IsNotExist(err) { + t.Errorf("Expected file not created: %s", file) + } + } + } + }) + } +} + +func TestTalosconfigApplyConfig(t *testing.T) { + tests := []struct { + name string + nodeIP string + configFile string + insecure bool + talosconfigPath string + skipTest bool + }{ + { + name: "apply config with all params", + nodeIP: "192.168.1.100", + configFile: "/path/to/config.yaml", + insecure: true, + skipTest: true, + }, + { + name: "apply config with talosconfig", + nodeIP: "192.168.1.100", + configFile: "/path/to/config.yaml", + insecure: false, + talosconfigPath: "/path/to/talosconfig", + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires talosctl binary") + } + + tc := NewTalosctl() + err := tc.ApplyConfig(tt.nodeIP, tt.configFile, tt.insecure, tt.talosconfigPath) + + // Will fail without talosctl, but tests method signature + _ = err + }) + } +} + +func TestTalosconfigGetDisks(t *testing.T) { + tests := []struct { + name string + nodeIP string + insecure bool + skipTest bool + }{ + { + name: "get disks in insecure mode", + nodeIP: "192.168.1.100", + insecure: true, + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires talosctl binary and running node") + } + + tc := NewTalosctl() + disks, err := tc.GetDisks(tt.nodeIP, tt.insecure) + + if err == nil { + // If successful, verify return type + if disks == nil { + t.Error("GetDisks() returned nil slice without error") + } + // Each disk should have path and size + for i, disk := range disks { + if disk.Path == "" { + t.Errorf("disk[%d].Path is empty", i) + } + if disk.Size <= 0 { + t.Errorf("disk[%d].Size = %d, want > 0", i, disk.Size) + } + // Size should be > 10GB per filtering + if disk.Size <= 10000000000 { + t.Errorf("disk[%d].Size = %d, should be filtered (> 10GB)", i, disk.Size) + } + } + } + }) + } +} + +func TestTalosconfigGetLinks(t *testing.T) { + tests := []struct { + name string + nodeIP string + insecure bool + skipTest bool + }{ + { + name: "get links in insecure mode", + nodeIP: "192.168.1.100", + insecure: true, + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires talosctl binary and running node") + } + + tc := NewTalosctl() + links, err := tc.GetLinks(tt.nodeIP, tt.insecure) + + if err == nil { + if links == nil { + t.Error("GetLinks() returned nil slice without error") + } + } + }) + } +} + +func TestTalosconfigGetRoutes(t *testing.T) { + tests := []struct { + name string + nodeIP string + insecure bool + skipTest bool + }{ + { + name: "get routes in insecure mode", + nodeIP: "192.168.1.100", + insecure: true, + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires talosctl binary and running node") + } + + tc := NewTalosctl() + routes, err := tc.GetRoutes(tt.nodeIP, tt.insecure) + + if err == nil { + if routes == nil { + t.Error("GetRoutes() returned nil slice without error") + } + } + }) + } +} + +func TestTalosconfigGetDefaultInterface(t *testing.T) { + tests := []struct { + name string + nodeIP string + insecure bool + skipTest bool + }{ + { + name: "get default interface", + nodeIP: "192.168.1.100", + insecure: true, + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires talosctl binary and running node") + } + + tc := NewTalosctl() + iface, err := tc.GetDefaultInterface(tt.nodeIP, tt.insecure) + + if err == nil { + if iface == "" { + t.Error("GetDefaultInterface() returned empty string without error") + } + } + }) + } +} + +func TestTalosconfigGetPhysicalInterface(t *testing.T) { + tests := []struct { + name string + nodeIP string + insecure bool + skipTest bool + }{ + { + name: "get physical interface", + nodeIP: "192.168.1.100", + insecure: true, + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires talosctl binary and running node") + } + + tc := NewTalosctl() + iface, err := tc.GetPhysicalInterface(tt.nodeIP, tt.insecure) + + if err == nil { + if iface == "" { + t.Error("GetPhysicalInterface() returned empty string without error") + } + // Should not be loopback + if iface == "lo" { + t.Error("GetPhysicalInterface() returned loopback interface") + } + } + }) + } +} + +func TestTalosconfigGetVersion(t *testing.T) { + tests := []struct { + name string + nodeIP string + insecure bool + want string // Expected for maintenance mode or version string + skipTest bool + }{ + { + name: "get version in insecure mode", + nodeIP: "192.168.1.100", + insecure: true, + skipTest: true, + }, + { + name: "get version in secure mode", + nodeIP: "192.168.1.100", + insecure: false, + skipTest: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.skipTest { + t.Skip("Skipping test that requires talosctl binary and running node") + } + + tc := NewTalosctl() + version, err := tc.GetVersion(tt.nodeIP, tt.insecure) + + if err == nil { + if version == "" { + t.Error("GetVersion() returned empty string without error") + } + // Version should be either "maintenance" or start with "v" + if version != "maintenance" && version[0] != 'v' { + t.Errorf("GetVersion() = %q, expected 'maintenance' or version starting with 'v'", version) + } + } + }) + } +} + +func TestTalosconfigValidate(t *testing.T) { + t.Run("validate checks for talosctl", func(t *testing.T) { + tc := NewTalosctl() + err := tc.Validate() + + // This will pass if talosctl is installed, fail otherwise + // We can't guarantee talosctl is installed in all test environments + _ = err + }) +} + +func TestDiskInfoStruct(t *testing.T) { + t.Run("DiskInfo has required fields", func(t *testing.T) { + disk := DiskInfo{ + Path: "/dev/sda", + Size: 1000000000000, // 1TB + } + + if disk.Path != "/dev/sda" { + t.Errorf("Path = %q, want %q", disk.Path, "/dev/sda") + } + if disk.Size != 1000000000000 { + t.Errorf("Size = %d, want %d", disk.Size, 1000000000000) + } + }) +} + +func TestTalosconfigResourceJSONParsing(t *testing.T) { + // This test verifies the logic of getResourceJSON without actually calling talosctl + t.Run("getResourceJSON uses correct command structure", func(t *testing.T) { + tc := &Talosctl{talosconfigPath: "/path/to/talosconfig"} + + // We can't easily test the actual command execution without mocking, + // but we can verify buildArgs works correctly + baseArgs := []string{"get", "disks", "--nodes", "192.168.1.100", "-o", "json"} + finalArgs := tc.buildArgs(baseArgs) + + // Should have talosconfig prepended + if len(finalArgs) < 2 || finalArgs[0] != "--talosconfig" { + t.Error("buildArgs() should prepend --talosconfig") + } + }) +} + +func TestTalosconfigInterfaceFiltering(t *testing.T) { + // Test the logic for filtering physical interfaces + tests := []struct { + name string + interfaceName string + linkType string + operState string + shouldAccept bool + }{ + { + name: "eth0 up and ethernet", + interfaceName: "eth0", + linkType: "ether", + operState: "up", + shouldAccept: true, + }, + { + name: "eno1 up and ethernet", + interfaceName: "eno1", + linkType: "ether", + operState: "up", + shouldAccept: true, + }, + { + name: "loopback should be filtered", + interfaceName: "lo", + linkType: "loopback", + operState: "up", + shouldAccept: false, + }, + { + name: "cni interface should be filtered", + interfaceName: "cni0", + linkType: "ether", + operState: "up", + shouldAccept: false, + }, + { + name: "flannel interface should be filtered", + interfaceName: "flannel.1", + linkType: "ether", + operState: "up", + shouldAccept: false, + }, + { + name: "docker interface should be filtered", + interfaceName: "docker0", + linkType: "ether", + operState: "up", + shouldAccept: false, + }, + { + name: "bridge interface should be filtered", + interfaceName: "br-1234", + linkType: "ether", + operState: "up", + shouldAccept: false, + }, + { + name: "veth interface should be filtered", + interfaceName: "veth123", + linkType: "ether", + operState: "up", + shouldAccept: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // This simulates the filtering logic in GetPhysicalInterface + id := tt.interfaceName + linkType := tt.linkType + operState := tt.operState + + shouldAccept := (linkType == "ether" && operState == "up" && + id != "lo" && + (id[:3] == "eth" || id[:2] == "en") && + !containsAny(id, []string{"cni", "flannel", "docker", "br-", "veth"})) + + if shouldAccept != tt.shouldAccept { + t.Errorf("Interface %q filtering = %v, want %v", id, shouldAccept, tt.shouldAccept) + } + }) + } +} + +// Helper function for interface filtering test +func containsAny(s string, substrs []string) bool { + for _, substr := range substrs { + if len(substr) > 0 { + if substr[len(substr)-1] == '-' { + // Prefix match for things like "br-" + if len(s) >= len(substr) && s[:len(substr)] == substr { + return true + } + } else { + // Contains match + if len(s) >= len(substr) { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + } + } + } + } + return false +} diff --git a/internal/tools/yq_test.go b/internal/tools/yq_test.go new file mode 100644 index 0000000..f16e29c --- /dev/null +++ b/internal/tools/yq_test.go @@ -0,0 +1,469 @@ +package tools + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestNewYQ(t *testing.T) { + t.Run("creates YQ instance with default path", func(t *testing.T) { + yq := NewYQ() + if yq == nil { + t.Fatal("NewYQ() returned nil") + } + if yq.yqPath == "" { + t.Error("yqPath should not be empty") + } + }) +} + +func TestYQGet(t *testing.T) { + tests := []struct { + name string + setup func(tmpDir string) (string, string) + expression string + want string + wantErr bool + }{ + { + name: "get simple value", + setup: func(tmpDir string) (string, string) { + yamlContent := `name: test +version: "1.0" +` + filePath := filepath.Join(tmpDir, "test.yaml") + if err := os.WriteFile(filePath, []byte(yamlContent), 0644); err != nil { + t.Fatal(err) + } + return filePath, ".name" + }, + want: "test", + wantErr: false, + }, + { + name: "get nested value", + setup: func(tmpDir string) (string, string) { + yamlContent := `person: + name: John + age: 30 +` + filePath := filepath.Join(tmpDir, "nested.yaml") + if err := os.WriteFile(filePath, []byte(yamlContent), 0644); err != nil { + t.Fatal(err) + } + return filePath, ".person.name" + }, + want: "John", + wantErr: false, + }, + { + name: "non-existent file returns error", + setup: func(tmpDir string) (string, string) { + return filepath.Join(tmpDir, "nonexistent.yaml"), ".name" + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip if yq is not available + if _, err := os.Stat("/usr/bin/yq"); os.IsNotExist(err) { + t.Skip("yq not installed, skipping test") + } + + tmpDir := t.TempDir() + filePath, expression := tt.setup(tmpDir) + + yq := NewYQ() + got, err := yq.Get(filePath, expression) + + if (err != nil) != tt.wantErr { + t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && got != tt.want { + t.Errorf("Get() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestYQSet(t *testing.T) { + tests := []struct { + name string + setup func(tmpDir string) string + expression string + value string + verify func(t *testing.T, filePath string) + wantErr bool + }{ + { + name: "set simple value", + setup: func(tmpDir string) string { + yamlContent := `name: old` + filePath := filepath.Join(tmpDir, "test.yaml") + if err := os.WriteFile(filePath, []byte(yamlContent), 0644); err != nil { + t.Fatal(err) + } + return filePath + }, + expression: ".name", + value: "new", + verify: func(t *testing.T, filePath string) { + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(content), "new") { + t.Errorf("File does not contain expected value 'new': %s", content) + } + }, + wantErr: false, + }, + { + name: "set value with special characters", + setup: func(tmpDir string) string { + yamlContent := `message: hello` + filePath := filepath.Join(tmpDir, "special.yaml") + if err := os.WriteFile(filePath, []byte(yamlContent), 0644); err != nil { + t.Fatal(err) + } + return filePath + }, + expression: ".message", + value: `hello "world"`, + verify: func(t *testing.T, filePath string) { + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatal(err) + } + // Should contain escaped quotes + if !strings.Contains(string(content), "hello") { + t.Errorf("File does not contain expected value: %s", content) + } + }, + wantErr: false, + }, + { + name: "expression without leading dot gets dot prepended", + setup: func(tmpDir string) string { + yamlContent := `key: value` + filePath := filepath.Join(tmpDir, "nodot.yaml") + if err := os.WriteFile(filePath, []byte(yamlContent), 0644); err != nil { + t.Fatal(err) + } + return filePath + }, + expression: "key", + value: "newvalue", + verify: func(t *testing.T, filePath string) { + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(content), "newvalue") { + t.Errorf("File does not contain expected value: %s", content) + } + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip if yq is not available + if _, err := os.Stat("/usr/bin/yq"); os.IsNotExist(err) { + t.Skip("yq not installed, skipping test") + } + + tmpDir := t.TempDir() + filePath := tt.setup(tmpDir) + + yq := NewYQ() + err := yq.Set(filePath, tt.expression, tt.value) + + if (err != nil) != tt.wantErr { + t.Errorf("Set() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && tt.verify != nil { + tt.verify(t, filePath) + } + }) + } +} + +func TestYQDelete(t *testing.T) { + tests := []struct { + name string + setup func(tmpDir string) string + expression string + verify func(t *testing.T, filePath string) + wantErr bool + }{ + { + name: "delete simple key", + setup: func(tmpDir string) string { + yamlContent := `name: test +version: "1.0" +` + filePath := filepath.Join(tmpDir, "delete.yaml") + if err := os.WriteFile(filePath, []byte(yamlContent), 0644); err != nil { + t.Fatal(err) + } + return filePath + }, + expression: ".name", + verify: func(t *testing.T, filePath string) { + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(content), "name:") { + t.Errorf("Key 'name' was not deleted: %s", content) + } + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip if yq is not available + if _, err := os.Stat("/usr/bin/yq"); os.IsNotExist(err) { + t.Skip("yq not installed, skipping test") + } + + tmpDir := t.TempDir() + filePath := tt.setup(tmpDir) + + yq := NewYQ() + err := yq.Delete(filePath, tt.expression) + + if (err != nil) != tt.wantErr { + t.Errorf("Delete() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && tt.verify != nil { + tt.verify(t, filePath) + } + }) + } +} + +func TestYQValidate(t *testing.T) { + tests := []struct { + name string + setup func(tmpDir string) string + wantErr bool + }{ + { + name: "valid YAML", + setup: func(tmpDir string) string { + yamlContent := `name: test +version: "1.0" +nested: + key: value +` + filePath := filepath.Join(tmpDir, "valid.yaml") + if err := os.WriteFile(filePath, []byte(yamlContent), 0644); err != nil { + t.Fatal(err) + } + return filePath + }, + wantErr: false, + }, + { + name: "invalid YAML", + setup: func(tmpDir string) string { + invalidYaml := `name: test + invalid indentation +version: "1.0" +` + filePath := filepath.Join(tmpDir, "invalid.yaml") + if err := os.WriteFile(filePath, []byte(invalidYaml), 0644); err != nil { + t.Fatal(err) + } + return filePath + }, + wantErr: true, + }, + { + name: "non-existent file", + setup: func(tmpDir string) string { + return filepath.Join(tmpDir, "nonexistent.yaml") + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip if yq is not available + if _, err := os.Stat("/usr/bin/yq"); os.IsNotExist(err) { + t.Skip("yq not installed, skipping test") + } + + tmpDir := t.TempDir() + filePath := tt.setup(tmpDir) + + yq := NewYQ() + err := yq.Validate(filePath) + + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestYQExec(t *testing.T) { + tests := []struct { + name string + setup func(tmpDir string) (string, []string) + wantErr bool + }{ + { + name: "exec with valid args", + setup: func(tmpDir string) (string, []string) { + yamlContent := `name: test` + filePath := filepath.Join(tmpDir, "exec.yaml") + if err := os.WriteFile(filePath, []byte(yamlContent), 0644); err != nil { + t.Fatal(err) + } + return filePath, []string{"eval", ".name", filePath} + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip if yq is not available + if _, err := os.Stat("/usr/bin/yq"); os.IsNotExist(err) { + t.Skip("yq not installed, skipping test") + } + + tmpDir := t.TempDir() + _, args := tt.setup(tmpDir) + + yq := NewYQ() + output, err := yq.Exec(args...) + + if (err != nil) != tt.wantErr { + t.Errorf("Exec() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && len(output) == 0 { + t.Error("Exec() returned empty output") + } + }) + } +} + +func TestCleanYQOutput(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "removes trailing newline", + input: "value\n", + want: "value", + }, + { + name: "converts null to empty string", + input: "null", + want: "", + }, + { + name: "removes whitespace", + input: " value \n", + want: "value", + }, + { + name: "handles empty string", + input: "", + want: "", + }, + { + name: "handles multiple newlines", + input: "value\n\n", + want: "value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CleanYQOutput(tt.input) + if got != tt.want { + t.Errorf("CleanYQOutput(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestYQMerge(t *testing.T) { + tests := []struct { + name string + setup func(tmpDir string) (string, string, string) + verify func(t *testing.T, outputPath string) + wantErr bool + }{ + { + name: "merge two files", + setup: func(tmpDir string) (string, string, string) { + file1 := filepath.Join(tmpDir, "file1.yaml") + file2 := filepath.Join(tmpDir, "file2.yaml") + output := filepath.Join(tmpDir, "output.yaml") + + if err := os.WriteFile(file1, []byte("key1: value1\n"), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(file2, []byte("key2: value2\n"), 0644); err != nil { + t.Fatal(err) + } + + return file1, file2, output + }, + verify: func(t *testing.T, outputPath string) { + if _, err := os.Stat(outputPath); os.IsNotExist(err) { + t.Error("Output file was not created") + } + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Skip if yq is not available + if _, err := os.Stat("/usr/bin/yq"); os.IsNotExist(err) { + t.Skip("yq not installed, skipping test") + } + + tmpDir := t.TempDir() + file1, file2, output := tt.setup(tmpDir) + + yq := NewYQ() + err := yq.Merge(file1, file2, output) + + if (err != nil) != tt.wantErr { + t.Errorf("Merge() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if !tt.wantErr && tt.verify != nil { + tt.verify(t, output) + } + }) + } +}