diff --git a/docker-compose.dpu.yml b/docker-compose.dpu.yml index 5e281dd..7e9bd16 100644 --- a/docker-compose.dpu.yml +++ b/docker-compose.dpu.yml @@ -53,10 +53,7 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/third_my_cert.pem', '--device-private-key', '/certs/third_private_key.pem', - '--serial-number', 'third-serial-number', - '--status-file-path', '/var/lib/sztp/status.json', - '--result-file-path', '/var/lib/sztp/result.json', - '--sym-link-dir', '/run/sztp'] + '--serial-number', 'third-serial-number'] agent2: <<: *agent @@ -65,10 +62,7 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/second_my_cert.pem', '--device-private-key', '/certs/second_private_key.pem', - '--serial-number', 'second-serial-number', - '--status-file-path', '/var/lib/sztp/status.json', - '--result-file-path', '/var/lib/sztp/result.json', - '--sym-link-dir', '/run/sztp'] + '--serial-number', 'second-serial-number'] agent1: <<: *agent @@ -77,10 +71,7 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/first_my_cert.pem', '--device-private-key', '/certs/first_private_key.pem', - '--serial-number', 'first-serial-number', - '--status-file-path', '/var/lib/sztp/status.json', - '--result-file-path', '/var/lib/sztp/result.json', - '--sym-link-dir', '/run/sztp'] + '--serial-number', 'first-serial-number'] agent4: <<: *agent @@ -89,10 +80,7 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/first_my_cert.pem', '--device-private-key', '/certs/first_private_key.pem', - '--serial-number', 'first-serial-number', - '--status-file-path', '/var/lib/sztp/status.json', - '--result-file-path', '/var/lib/sztp/result.json', - '--sym-link-dir', '/run/sztp'] + '--serial-number', 'first-serial-number'] agent5: <<: *agent @@ -101,10 +89,7 @@ services: '--bootstrap-trust-anchor-cert', '/certs/opi.pem', '--device-end-entity-cert', '/certs/first_my_cert.pem', '--device-private-key', '/certs/first_private_key.pem', - '--serial-number', 'first-serial-number', - '--status-file-path', '/var/lib/sztp/status.json', - '--result-file-path', '/var/lib/sztp/result.json', - '--sym-link-dir', '/run/sztp'] + '--serial-number', 'first-serial-number'] volumes: client-certs: diff --git a/sztp-agent/cmd/daemon.go b/sztp-agent/cmd/daemon.go index 0dc0a0b..7cc4c22 100644 --- a/sztp-agent/cmd/daemon.go +++ b/sztp-agent/cmd/daemon.go @@ -33,15 +33,15 @@ func Daemon() *cobra.Command { deviceEndEntityCert string bootstrapTrustAnchorCert string statusFilePath string - resultFilePath string - symLinkDir string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ Use: "daemon", Short: "Run the daemon command", RunE: func(_ *cobra.Command, _ []string) error { - arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath} + arrayChecker := []string{devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath} if bootstrapURL != "" && dhcpLeaseFile != "" { return fmt.Errorf("'--bootstrap-url' and '--dhcp-lease-file' are mutualy exclusive") } @@ -55,15 +55,6 @@ func Daemon() *cobra.Command { _, err := url.ParseRequestURI(bootstrapURL) cobra.CheckErr(err) } - if statusFilePath == "" { - return fmt.Errorf("'--status-file-path' is required") - } - if resultFilePath == "" { - return fmt.Errorf("'--result-file-path' is required") - } - if symLinkDir == "" { - return fmt.Errorf("'--symlink-dir' is required") - } for _, filePath := range arrayChecker { info, err := os.Stat(filePath) cobra.CheckErr(err) @@ -87,9 +78,9 @@ func Daemon() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "/certs/private_key.pem", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "/certs/my_cert.pem", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "/certs/opi.pem", "Bootstrap server trust anchor Cert") - flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Path to the status file") - flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Path to the result file") - flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Path to the symlink directory") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/disable.go b/sztp-agent/cmd/disable.go index b086aeb..732fb1c 100644 --- a/sztp-agent/cmd/disable.go +++ b/sztp-agent/cmd/disable.go @@ -29,8 +29,8 @@ func Disable() *cobra.Command { deviceEndEntityCert string bootstrapTrustAnchorCert string statusFilePath string - resultFilePath string - symLinkDir string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -53,9 +53,9 @@ func Disable() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") - flags.StringVar(&statusFilePath, "status-file-path", "", "Status file path") - flags.StringVar(&resultFilePath, "result-file-path", "", "Result file path") - flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/enable.go b/sztp-agent/cmd/enable.go index 2c6513d..2a75031 100644 --- a/sztp-agent/cmd/enable.go +++ b/sztp-agent/cmd/enable.go @@ -29,8 +29,8 @@ func Enable() *cobra.Command { deviceEndEntityCert string bootstrapTrustAnchorCert string statusFilePath string - resultFilePath string - symLinkDir string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -53,9 +53,9 @@ func Enable() *cobra.Command { flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") - flags.StringVar(&statusFilePath, "status-file-path", "", "Status file path") - flags.StringVar(&resultFilePath, "result-file-path", "", "Result file path") - flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/run.go b/sztp-agent/cmd/run.go index a76cb79..0178e62 100644 --- a/sztp-agent/cmd/run.go +++ b/sztp-agent/cmd/run.go @@ -33,8 +33,8 @@ func Run() *cobra.Command { deviceEndEntityCert string bootstrapTrustAnchorCert string statusFilePath string - resultFilePath string - symLinkDir string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -55,15 +55,6 @@ func Run() *cobra.Command { _, err := url.ParseRequestURI(bootstrapURL) cobra.CheckErr(err) } - if statusFilePath == "" { - return fmt.Errorf("'--status-file-path' is required") - } - if resultFilePath == "" { - return fmt.Errorf("'--result-file-path' is required") - } - if symLinkDir == "" { - return fmt.Errorf("'--symlink-dir' is required") - } for _, filePath := range arrayChecker { info, err := os.Stat(filePath) cobra.CheckErr(err) @@ -89,7 +80,7 @@ func Run() *cobra.Command { flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "/certs/opi.pem", "Bootstrap server trust anchor Cert") flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") - flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/cmd/status.go b/sztp-agent/cmd/status.go index debbaa4..b8b1e1f 100644 --- a/sztp-agent/cmd/status.go +++ b/sztp-agent/cmd/status.go @@ -29,8 +29,8 @@ func Status() *cobra.Command { deviceEndEntityCert string bootstrapTrustAnchorCert string statusFilePath string - resultFilePath string - symLinkDir string + resultFilePath string + symLinkDir string ) cmd := &cobra.Command{ @@ -49,13 +49,13 @@ func Status() *cobra.Command { flags.StringVar(&bootstrapURL, "bootstrap-url", "", "Bootstrap server URL") flags.StringVar(&serialNumber, "serial-number", "", "Device's serial number") flags.StringVar(&dhcpLeaseFile, "dhcp-lease-file", "/var/lib/dhclient/dhclient.leases", "Device's dhclient leases file") - flags.StringVar(&devicePassword, "device-password", "", "Dehomevice's password") + flags.StringVar(&devicePassword, "device-password", "", "Device's password") flags.StringVar(&devicePrivateKey, "device-private-key", "", "Device's private key") flags.StringVar(&deviceEndEntityCert, "device-end-entity-cert", "", "Device's End Entity cert") flags.StringVar(&bootstrapTrustAnchorCert, "bootstrap-trust-anchor-cert", "", "Bootstrap server trust anchor Cert") - flags.StringVar(&statusFilePath, "status-file-path", "", "Status file path") - flags.StringVar(&resultFilePath, "result-file-path", "", "Result file path") - flags.StringVar(&symLinkDir, "sym-link-dir", "", "Sym Link Directory") + flags.StringVar(&statusFilePath, "status-file-path", "/var/lib/sztp/status.json", "Status file path") + flags.StringVar(&resultFilePath, "result-file-path", "/var/lib/sztp/result.json", "Result file path") + flags.StringVar(&symLinkDir, "sym-link-dir", "/run/sztp", "Sym Link Directory") return cmd } diff --git a/sztp-agent/pkg/secureagent/agent.go b/sztp-agent/pkg/secureagent/agent.go index eaf3383..5c43e7d 100644 --- a/sztp-agent/pkg/secureagent/agent.go +++ b/sztp-agent/pkg/secureagent/agent.go @@ -93,9 +93,9 @@ type Agent struct { BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo // BootstrapServerOnboardingInfo structure BootstrapServerRedirectInfo BootstrapServerRedirectInfo // BootstrapServerRedirectInfo structure HttpClient HttpClient - StatusFilePath string // Path to the status file - ResultFilePath string // Path to the result file - SymLinkDir string // Path to the symlink directory for the status file + StatusFilePath string // Path to the status file + ResultFilePath string // Path to the result file + SymLinkDir string // Path to the symlink directory for the status file } func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, devicePrivateKey, deviceEndEntityCert, bootstrapTrustAnchorCert, statusFilePath, resultFilePath, symLinkDir string, httpClient HttpClient) *Agent { @@ -116,7 +116,7 @@ func NewAgent(bootstrapURL, serialNumber, dhcpLeaseFile, devicePassword, deviceP HttpClient: httpClient, StatusFilePath: statusFilePath, ResultFilePath: resultFilePath, - SymLinkDir: symLinkDir, + SymLinkDir: symLinkDir, } } diff --git a/sztp-agent/pkg/secureagent/agent_test.go b/sztp-agent/pkg/secureagent/agent_test.go index 25d75f9..fec69b9 100644 --- a/sztp-agent/pkg/secureagent/agent_test.go +++ b/sztp-agent/pkg/secureagent/agent_test.go @@ -829,9 +829,9 @@ func TestNewAgent(t *testing.T) { devicePrivateKey string deviceEndEntityCert string bootstrapTrustAnchorCert string - statusFilePath string - resultFilePath string - symLinkDir string + statusFilePath string + resultFilePath string + symLinkDir string } client := http.Client{} tests := []struct { @@ -851,7 +851,7 @@ func TestNewAgent(t *testing.T) { bootstrapTrustAnchorCert: "TestBootstrapTrustCert", statusFilePath: "TestStatusFilePath", resultFilePath: "TestResultFilePath", - symLinkDir: "TestSymLinkDir", + symLinkDir: "TestSymLinkDir", }, want: &Agent{ InputBootstrapURL: "TestBootstrap", @@ -866,7 +866,7 @@ func TestNewAgent(t *testing.T) { DhcpLeaseFile: "TestDhcpLeaseFile", StatusFilePath: "TestStatusFilePath", ResultFilePath: "TestResultFilePath", - SymLinkDir: "TestSymLinkDir", + SymLinkDir: "TestSymLinkDir", HttpClient: &client, }, }, diff --git a/sztp-agent/pkg/secureagent/configuration.go b/sztp-agent/pkg/secureagent/configuration.go index 95b35b6..f6594de 100644 --- a/sztp-agent/pkg/secureagent/configuration.go +++ b/sztp-agent/pkg/secureagent/configuration.go @@ -10,7 +10,7 @@ import ( func (a *Agent) copyConfigurationFile() error { log.Println("[INFO] Starting the Copy Configuration.") _ = a.doReportProgress(ProgressTypeConfigInitiated, "Configuration Initiated") - _ = a.updateAndSaveStatus("config", true, "") + _ = a.updateAndSaveStatus(StageTypeConfig, true, "") // Copy the configuration file to the device file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + "-config") if err != nil { @@ -37,7 +37,7 @@ func (a *Agent) copyConfigurationFile() error { } log.Println("[INFO] Configuration file copied successfully") _ = a.doReportProgress(ProgressTypeConfigComplete, "Configuration Complete") - _ = a.updateAndSaveStatus("config", false, "") + _ = a.updateAndSaveStatus(StageTypeConfig, false, "") return nil } @@ -45,20 +45,24 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { var script, scriptName string var reportStart, reportEnd ProgressType switch typeOf { - case "post": + case POST: script = a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.PostConfigurationScript - scriptName = "post" + scriptName = POST reportStart = ProgressTypePostScriptInitiated reportEnd = ProgressTypePostScriptComplete default: // pre or default script = a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.PreConfigurationScript - scriptName = "pre" + scriptName = PRE reportStart = ProgressTypePreScriptInitiated reportEnd = ProgressTypePreScriptComplete } log.Println("[INFO] Starting the " + scriptName + "-configuration.") _ = a.doReportProgress(reportStart, "Report starting") - _ = a.updateAndSaveStatus(scriptName+"-script", true, "") + if scriptName == PRE { + _ = a.updateAndSaveStatus(StageTypePreScript, true, "") + } else if scriptName == POST { + _ = a.updateAndSaveStatus(StageTypePostScript, true, "") + } // nolint:gosec file, err := os.Create(ARTIFACTS_PATH + a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference + scriptName + "configuration.sh") if err != nil { @@ -92,7 +96,11 @@ func (a *Agent) launchScriptsConfiguration(typeOf string) error { } log.Println(string(out)) // remove it _ = a.doReportProgress(reportEnd, "Report end") - _ = a.updateAndSaveStatus(scriptName+"-script", false, "") + if scriptName == PRE { + _ = a.updateAndSaveStatus(StageTypePreScript, false, "") + } else if scriptName == POST { + _ = a.updateAndSaveStatus(StageTypePostScript, false, "") + } log.Println("[INFO] " + scriptName + "-Configuration script executed successfully") return nil } diff --git a/sztp-agent/pkg/secureagent/daemon.go b/sztp-agent/pkg/secureagent/daemon.go index 066e571..a8abd7b 100644 --- a/sztp-agent/pkg/secureagent/daemon.go +++ b/sztp-agent/pkg/secureagent/daemon.go @@ -37,51 +37,60 @@ func (a *Agent) RunCommandDaemon() error { log.Println("failed to prepare status: ", err) return err } + _ = a.updateAndSaveStatus(StageTypeIsCompleted, true, "") for { err := a.performBootstrapSequence() if err != nil { log.Println("[ERROR] Failed to perform the bootstrap sequence: ", err.Error()) log.Println("[INFO] Retrying in 5 seconds") time.Sleep(5 * time.Second) + _ = a.updateAndSaveStatus(StageTypeIsCompleted, false, err.Error()) continue } + _ = a.updateAndSaveStatus(StageTypeIsCompleted, false, "") return nil } } func (a *Agent) performBootstrapSequence() error { - _ = a.updateAndSaveStatus("bootstrap", true, "") var err error err = a.discoverBootstrapURLs() if err != nil { + _ = a.updateAndSaveStatus(StageTypeParsing, false, err.Error()) return err } err = a.doRequestBootstrapServerOnboardingInfo() if err != nil { + _ = a.updateAndSaveStatus(StageTypeOnboarding, false, err.Error()) return err } err = a.doHandleBootstrapRedirect() if err != nil { + _ = a.updateAndSaveStatus(StageTypeBootImage, false, err.Error()) return err } err = a.downloadAndValidateImage() if err != nil { + _ = a.updateAndSaveStatus(StageTypeBootImage, false, err.Error()) return err } err = a.copyConfigurationFile() if err != nil { + _ = a.updateAndSaveStatus(StageTypeConfig, false, err.Error()) return err } err = a.launchScriptsConfiguration(PRE) if err != nil { + _ = a.updateAndSaveStatus(StageTypePreScript, false, err.Error()) return err } err = a.launchScriptsConfiguration(POST) if err != nil { + _ = a.updateAndSaveStatus(StageTypePostScript, false, err.Error()) return err } _ = a.doReportProgress(ProgressTypeBootstrapComplete, "Bootstrap Complete") - _ = a.updateAndSaveStatus("bootstrap", false, "") + _ = a.updateAndSaveStatus(StageTypeBootstrap, false, "") return nil } @@ -148,7 +157,7 @@ func (a *Agent) doRequestBootstrapServerOnboardingInfo() error { } log.Println("[INFO] Response retrieved successfully") _ = a.doReportProgress(ProgressTypeBootstrapInitiated, "Bootstrap Initiated") - _ = a.updateAndSaveStatus("bootstrap", true, "") + _ = a.updateAndSaveStatus(StageTypeBootstrap, true, "") crypto := res.IetfSztpBootstrapServerOutput.ConveyedInformation newVal, err := base64.StdEncoding.DecodeString(crypto) if err != nil { diff --git a/sztp-agent/pkg/secureagent/image.go b/sztp-agent/pkg/secureagent/image.go index 4466698..7362d77 100644 --- a/sztp-agent/pkg/secureagent/image.go +++ b/sztp-agent/pkg/secureagent/image.go @@ -23,7 +23,7 @@ import ( func (a *Agent) downloadAndValidateImage() error { log.Printf("[INFO] Starting the Download Image: %v", a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI) _ = a.doReportProgress(ProgressTypeBootImageInitiated, "BootImage Initiated") - _ = a.updateAndSaveStatus("boot-image", true, "") + _ = a.updateAndSaveStatus(StageTypeBootImage, true, "") // Download the image from DownloadURI and save it to a file a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.InfoTimestampReference = fmt.Sprintf("%8d", time.Now().Unix()) for i, item := range a.BootstrapServerOnboardingInfo.IetfSztpConveyedInfoOnboardingInformation.BootImage.DownloadURI { @@ -79,7 +79,7 @@ func (a *Agent) downloadAndValidateImage() error { } log.Println("[INFO] Checksum verified successfully") _ = a.doReportProgress(ProgressTypeBootImageComplete, "BootImage Complete") - _ = a.updateAndSaveStatus("boot-image", false, "") + _ = a.updateAndSaveStatus(StageTypeBootImage, false, "") return nil default: return errors.New("unsupported hash algorithm") diff --git a/sztp-agent/pkg/secureagent/progress_test.go b/sztp-agent/pkg/secureagent/progress_test.go index 78def06..a821dcf 100644 --- a/sztp-agent/pkg/secureagent/progress_test.go +++ b/sztp-agent/pkg/secureagent/progress_test.go @@ -142,7 +142,7 @@ func TestAgent_doReportProgress(t *testing.T) { DhcpLeaseFile: "DHCPLEASEFILE", ProgressJSON: ProgressJSON{}, }, - wantErr: true, + wantErr: true, }, } for _, tt := range tests { diff --git a/sztp-agent/pkg/secureagent/status.go b/sztp-agent/pkg/secureagent/status.go index 656471c..6738c27 100644 --- a/sztp-agent/pkg/secureagent/status.go +++ b/sztp-agent/pkg/secureagent/status.go @@ -4,108 +4,166 @@ Copyright (C) 2022-2023 Intel Corporation Copyright (c) 2022 Dell Inc, or its subsidiaries. Copyright (C) 2022 Red Hat. */ - +// nolint // Package secureagent implements the secure agent package secureagent import ( - "encoding/json" "fmt" "log" - "os" "path/filepath" "time" ) +type StageType int64 + +const ( + StageTypeInit StageType = iota + StageTypeDownloadingFile + StageTypePendingReboot + StageTypeParsing + StageTypeOnboarding + StageTypeRedirect + StageTypeBootImage + StageTypePreScript + StageTypeConfig + StageTypePostScript + StageTypeBootstrap + StageTypeIsCompleted +) + +func (s StageType) String() string { + switch s { + case StageTypeInit: + return "init" + case StageTypeDownloadingFile: + return "downloading-file" + case StageTypePendingReboot: + return "pending-reboot" + case StageTypeParsing: + return "parsing" + case StageTypeOnboarding: + return "onboarding" + case StageTypeRedirect: + return "redirect" + case StageTypeBootImage: + return "boot-image" + case StageTypePreScript: + return "pre-script" + case StageTypeConfig: + return "config" + case StageTypePostScript: + return "post-script" + case StageTypeBootstrap: + return "bootstrap" + case StageTypeIsCompleted: + return "is-completed" + default: + return "unknown" + } +} + // Status represents the status of the provisioning process. type Status struct { - Init StageStatus `json:"init"` - DownloadingFile StageStatus `json:"downloading-file"` // not sure if this is needed - PendingReboot StageStatus `json:"pending-reboot"` - Parsing StageStatus `json:"parsing"` - BootImage StageStatus `json:"boot-image"` - PreScript StageStatus `json:"pre-script"` - Config StageStatus `json:"config"` - PostScript StageStatus `json:"post-script"` - Bootstrap StageStatus `json:"bootstrap"` - IsCompleted StageStatus `json:"is-completed"` - Informational string `json:"informational"` - DataSource string `json:"datasource"` - Stage string `json:"stage"` + Init StageStatus `json:"init"` + DownloadingFile StageStatus `json:"downloading-file"` + PendingReboot StageStatus `json:"pending-reboot"` + Parsing StageStatus `json:"parsing"` + Onboarding StageStatus `json:"onboarding"` + Redirect StageStatus `json:"redirect"` + BootImage StageStatus `json:"boot-image"` + PreScript StageStatus `json:"pre-script"` + Config StageStatus `json:"config"` + PostScript StageStatus `json:"post-script"` + Bootstrap StageStatus `json:"bootstrap"` + IsCompleted StageStatus `json:"is-completed"` + Informational string `json:"informational"` + Stage string `json:"stage"` } +// Result represents the result of the provisioning process. type Result struct { - DataSource string `json:"dat asource"` - Errors []string `json:"errors"` + Errors []string `json:"errors"` } +// StageStatus represents the status of a specific stage. type StageStatus struct { - Errors []string `json:"errors"` - Start float64 `json:"start"` - End float64 `json:"end"` -} - -// LoadStatusFile loads the current status.json from the filesystem. -func (a *Agent) loadStatusFile() (*Status, error) { - file, err := os.ReadFile(a.GetStatusFilePath()) - if err != nil { - return nil, err - } - var status Status - err = json.Unmarshal(file, &status) - if err != nil { - return nil, err - } - return &status, nil -} - -func (a *Agent) updateAndSaveStatus(stage string, isStart bool, errMsg string) error { - status, err := a.loadStatusFile() + Errors []string `json:"errors"` + Start float64 `json:"start"` + End float64 `json:"end"` +} + +func (a *Agent) getCurrStatus() (*Status, error) { + var status Status + err := loadFile(a.GetStatusFilePath(), &status) if err != nil { - fmt.Println("Creating a new status file.") - status = a.createNewStatus() + return nil, err } + return &status, nil +} - if err := a.updateStageStatus(status, stage, isStart, errMsg); err != nil { - return err +func (a *Agent) getCurrResult() (*Result, error) { + var result Result + err := loadFile(a.GetResultFilePath(), &result) + if err != nil { + return nil, err } - - return a.saveStatus(status) + return &result, nil } -// createNewStatus initializes a new Status object when status.json doesn't exist. func (a *Agent) createNewStatus() *Status { return &Status{ - DataSource: "ds", - Stage: "", + Stage: "", + IsCompleted: StageStatus{}, } } +// updateAndSaveStatus updates the status object for a specific stage and saves it to the status.json file. +func (a *Agent) updateAndSaveStatus(s StageType, isStart bool, errMsg string) error { + status, err := a.getCurrStatus() + if err != nil { + fmt.Println("Creating a new status file.") + status = a.createNewStatus() + } + + err = a.updateStageStatus(status, s, isStart, errMsg) + if err != nil { + return err + } + + return a.saveStatus(status) +} + // updateStageStatus updates the status object for a specific stage. -func (a *Agent) updateStageStatus(status *Status, stage string, isStart bool, errMsg string) error { +func (a *Agent) updateStageStatus(status *Status, stageType StageType, isStart bool, errMsg string) error { now := float64(time.Now().Unix()) + stage := stageType.String() - switch stage { - case "init": - updateStage(&status.Init, isStart, now, errMsg) - case "downloading-file": - updateStage(&status.DownloadingFile, isStart, now, errMsg) - case "pending-reboot": - updateStage(&status.PendingReboot, isStart, now, errMsg) - case "is-completed": - updateStage(&status.IsCompleted, isStart, now, errMsg) - case "parsing": - updateStage(&status.Parsing, isStart, now, errMsg) - case "boot-image": - updateStage(&status.BootImage, isStart, now, errMsg) - case "pre-script": - updateStage(&status.PreScript, isStart, now, errMsg) - case "config": - updateStage(&status.Config, isStart, now, errMsg) - case "post-script": - updateStage(&status.PostScript, isStart, now, errMsg) - case "bootstrap": - updateStage(&status.Bootstrap, isStart, now, errMsg) + switch stageType { + case StageTypeInit: + a.updateStage(&status.Init, isStart, now, errMsg) + case StageTypeDownloadingFile: + a.updateStage(&status.DownloadingFile, isStart, now, errMsg) + case StageTypePendingReboot: + a.updateStage(&status.PendingReboot, isStart, now, errMsg) + case StageTypeIsCompleted: + a.updateStage(&status.IsCompleted, isStart, now, errMsg) + case StageTypeParsing: + a.updateStage(&status.Parsing, isStart, now, errMsg) + case StageTypeOnboarding: + a.updateStage(&status.Onboarding, isStart, now, errMsg) + case StageTypeRedirect: + a.updateStage(&status.Redirect, isStart, now, errMsg) + case StageTypeBootImage: + a.updateStage(&status.BootImage, isStart, now, errMsg) + case StageTypePreScript: + a.updateStage(&status.PreScript, isStart, now, errMsg) + case StageTypeConfig: + a.updateStage(&status.Config, isStart, now, errMsg) + case StageTypePostScript: + a.updateStage(&status.PostScript, isStart, now, errMsg) + case StageTypeBootstrap: + a.updateStage(&status.Bootstrap, isStart, now, errMsg) default: return fmt.Errorf("unknown stage: %s", stage) @@ -113,15 +171,15 @@ func (a *Agent) updateStageStatus(status *Status, stage string, isStart bool, er // Update the current stage if isStart { - status.Stage = stage + status.Stage = stage + "-in-progress" } else { - status.Stage = "" + status.Stage = stage + "-completed" } return nil } -func updateStage(stageStatus *StageStatus, isStart bool, now float64, errMsg string) { +func (a *Agent) updateStage(stageStatus *StageStatus, isStart bool, now float64, errMsg string) { if isStart { stageStatus.Start = now stageStatus.End = 0 @@ -129,30 +187,47 @@ func updateStage(stageStatus *StageStatus, isStart bool, now float64, errMsg str stageStatus.End = now if errMsg != "" { stageStatus.Errors = append(stageStatus.Errors, errMsg) + err := a.updateAndSaveResult(errMsg) + if err != nil { + fmt.Printf("Failed to update and save result: %v\n", err) + } } } } -// SaveStatusToFile saves the Status object to the status.json file. func (a *Agent) saveStatus(status *Status) error { return saveToFile(status, a.GetStatusFilePath()) } -// SaveResultFile saves the Result object to the result.json file. func (a *Agent) saveResult(result *Result) error { return saveToFile(result, a.GetResultFilePath()) } +func (a *Agent) updateAndSaveResult(errMsg string) error { + result, err := a.getCurrResult() + if err != nil { + fmt.Println("Creating a new result file.") + result = &Result{ + Errors: []string{}, + } + } + + if errMsg != "" { + result.Errors = append(result.Errors, errMsg) + } + + return a.saveResult(result) +} + // RunCommandStatus runs the command in the background func (a *Agent) RunCommandStatus() error { log.Println("RunCommandStatus") - // read the status file and print the status in command line - status, err := a.loadStatusFile() - if err != nil { - log.Println("failed to load status file: ", err) - return err - } - fmt.Printf("Current status: %+v\n", status) + status, err := a.getCurrStatus() + if err != nil { + log.Println("failed to load status file: ", err) + return err + } + fmt.Printf("Current status: %+v\n", status) return nil } @@ -165,31 +240,34 @@ func (a *Agent) prepareStatus() error { return err } - if err := ensureFileExists(a.GetStatusFilePath()); err != nil { - return err - } - if err := ensureFileExists(a.GetResultFilePath()); err != nil { - return err - } + fmt.Println("Status File Path", a.GetStatusFilePath()) + fmt.Println("Result File Path", a.GetResultFilePath()) + + if err := ensureFileExists(a.GetStatusFilePath()); err != nil { + return err + } + if err := ensureFileExists(a.GetResultFilePath()); err != nil { + return err + } - statusSymlinkPath := filepath.Join(a.GetSymLinkDir(), "status.json") - resultSymlinkPath := filepath.Join(a.GetSymLinkDir(), "result.json") + statusSymlinkPath := filepath.Join(a.GetSymLinkDir(), "status.json") + resultSymlinkPath := filepath.Join(a.GetSymLinkDir(), "result.json") - // Create symlinks for status.json and result.json - if err := createSymlink(a.GetStatusFilePath(), statusSymlinkPath); err != nil { - fmt.Printf("Failed to create symlink for status.json: %v\n", err) - return err - } - if err := createSymlink(a.GetResultFilePath(), resultSymlinkPath); err != nil { - fmt.Printf("Failed to create symlink for result.json: %v\n", err) - return err - } + // Create symlinks for status.json and result.json + if err := createSymlink(a.GetStatusFilePath(), statusSymlinkPath); err != nil { + fmt.Printf("Failed to create symlink for status.json: %v\n", err) + return err + } + if err := createSymlink(a.GetResultFilePath(), resultSymlinkPath); err != nil { + fmt.Printf("Failed to create symlink for result.json: %v\n", err) + return err + } - fmt.Println("Symlinks created successfully.") + fmt.Println("Symlinks created successfully.") - if err := a.updateAndSaveStatus("init", true, ""); err != nil { - return err - } + if err := a.updateAndSaveStatus(StageTypeInit, true, ""); err != nil { + return err + } return nil } diff --git a/sztp-agent/pkg/secureagent/status_test.go b/sztp-agent/pkg/secureagent/status_test.go index d16d14c..afdf405 100644 --- a/sztp-agent/pkg/secureagent/status_test.go +++ b/sztp-agent/pkg/secureagent/status_test.go @@ -20,9 +20,9 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON ProgressJSON BootstrapServerOnboardingInfo BootstrapServerOnboardingInfo BootstrapServerRedirectInfo BootstrapServerRedirectInfo - StatusFilePath string - ResultFilePath string - SymLinkDir string + StatusFilePath string + ResultFilePath string + SymLinkDir string } tests := []struct { name string @@ -44,9 +44,9 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON: ProgressJSON{}, BootstrapServerRedirectInfo: BootstrapServerRedirectInfo{}, BootstrapServerOnboardingInfo: BootstrapServerOnboardingInfo{}, - StatusFilePath: "/var/lib/sztp/status.json", - ResultFilePath: "/var/lib/sztp/result.json", - SymLinkDir: "/run/sztp", + StatusFilePath: "/var/lib/sztp/status.json", + ResultFilePath: "/var/lib/sztp/result.json", + SymLinkDir: "/run/sztp", }, }, } @@ -65,11 +65,13 @@ func TestAgent_RunCommandStatus(t *testing.T) { ProgressJSON: tt.fields.ProgressJSON, BootstrapServerOnboardingInfo: tt.fields.BootstrapServerOnboardingInfo, BootstrapServerRedirectInfo: tt.fields.BootstrapServerRedirectInfo, - StatusFilePath: tt.fields.StatusFilePath, - ResultFilePath: tt.fields.ResultFilePath, - SymLinkDir: tt.fields.SymLinkDir, + StatusFilePath: tt.fields.StatusFilePath, + ResultFilePath: tt.fields.ResultFilePath, + SymLinkDir: tt.fields.SymLinkDir, + } + if err := a.prepareStatus(); err != nil { + t.Errorf("prepareStatus() error = %v", err) } - a.prepareStatus() if err := a.RunCommandStatus(); (err != nil) != tt.wantErr { t.Errorf("RunCommandStatus() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/sztp-agent/pkg/secureagent/utils.go b/sztp-agent/pkg/secureagent/utils.go index 0eddb2f..23ad4f6 100644 --- a/sztp-agent/pkg/secureagent/utils.go +++ b/sztp-agent/pkg/secureagent/utils.go @@ -89,14 +89,18 @@ func calculateSHA256File(filePath string) (string, error) { return checkSum, nil } -// saveToFile writes the given data to a specified file path. func saveToFile(data interface{}, filePath string) error { tempPath := filePath + ".tmp" + tempPath = filepath.Clean(tempPath) file, err := os.Create(tempPath) if err != nil { return err } - defer file.Close() + defer func() { + if err := file.Close(); err != nil { + log.Println("[ERROR] Error when closing:", err) + } + }() encoder := json.NewEncoder(file) if err := encoder.Encode(data); err != nil { @@ -107,53 +111,72 @@ func saveToFile(data interface{}, filePath string) error { return os.Rename(tempPath, filePath) } -// EnsureDirExists checks if a directory exists, and creates it if it doesn't. func ensureDirExists(dir string) error { - if _, err := os.Stat(dir); os.IsNotExist(err) { - err := os.MkdirAll(dir, 0755) // Create the directory with appropriate permissions - if err != nil { - return fmt.Errorf("failed to create directory %s: %v", dir, err) - } - } - return nil + if _, err := os.Stat(dir); os.IsNotExist(err) { + err := os.MkdirAll(dir, 0750) // Create the directory with appropriate permissions + if err != nil { + return fmt.Errorf("failed to create directory %s: %v", dir, err) + } + } + return nil } -// EnsureFile ensures that a file exists; creates it if it does not. func ensureFileExists(filePath string) error { - // Ensure the directory exists - dir := filepath.Dir(filePath) - if err := ensureDirExists(dir); err != nil { - return err - } - - // Check if the file already exists - if _, err := os.Stat(filePath); os.IsNotExist(err) { - // File does not exist, create it - file, err := os.Create(filePath) - if err != nil { - return fmt.Errorf("failed to create file %s: %v", filePath, err) - } - defer file.Close() - fmt.Printf("File %s created successfully.\n", filePath) - } else { - fmt.Printf("File %s already exists.\n", filePath) - } - return nil + dir := filepath.Dir(filePath) + if err := ensureDirExists(dir); err != nil { + return err + } + + if _, err := os.Stat(filePath); os.IsNotExist(err) { + filePath = filepath.Clean(filePath) + file, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create file %s: %v", filePath, err) + } + defer func() { + if err := file.Close(); err != nil { + log.Println("[ERROR] Error when closing:", err) + } + }() + fmt.Printf("File %s created successfully.\n", filePath) + } else { + fmt.Printf("File %s already exists.\n", filePath) + } + return nil } -// CreateSymlink creates a symlink for a file from target to link location. func createSymlink(targetFile, linkFile string) error { - // Ensure the directory for the symlink exists - linkDir := filepath.Dir(linkFile) - if err := ensureDirExists(linkDir); err != nil { - return err - } - - // Remove any existing symlink - if _, err := os.Lstat(linkFile); err == nil { - os.Remove(linkFile) - } - - // Create a new symlink - return os.Symlink(targetFile, linkFile) + targetFile = filepath.Clean(targetFile) + linkFile = filepath.Clean(linkFile) + + linkDir := filepath.Dir(linkFile) + if err := ensureDirExists(linkDir); err != nil { + return err + } + + // Check if linkFile exists and is a symlink to targetFile + if existingTarget, err := os.Readlink(linkFile); err == nil { + if existingTarget == targetFile { + return nil // Symlink already points to the target; skip creation + } + // Remove the existing file (even if it's a wrong symlink or regular file) + if err := os.Remove(linkFile); err != nil { + return err + } + } + + return os.Symlink(targetFile, linkFile) +} + +func loadFile(filePath string, v interface{}) error { + filePath = filepath.Clean(filePath) + file, err := os.ReadFile(filePath) + if err != nil { + return err + } + err = json.Unmarshal(file, v) + if err != nil { + return err + } + return nil } diff --git a/sztp-agent/pkg/secureagent/utils_test.go b/sztp-agent/pkg/secureagent/utils_test.go index 42e134f..a3217bf 100644 --- a/sztp-agent/pkg/secureagent/utils_test.go +++ b/sztp-agent/pkg/secureagent/utils_test.go @@ -84,7 +84,11 @@ func Test_saveToFile(t *testing.T) { if err != nil { t.Fatalf("failed to create temp directory: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() filePath := filepath.Join(tempDir, "test.json") data := map[string]string{"key": "value"} @@ -99,11 +103,16 @@ func Test_saveToFile(t *testing.T) { t.Fatalf("file %s was not created", filePath) } + filePath = filepath.Clean(filePath) file, err := os.Open(filePath) if err != nil { t.Fatalf("failed to open the file: %v", err) } - defer file.Close() + defer func() { + if err := file.Close(); err != nil { + t.Fatalf("failed to close the file: %v", err) + } + }() var readData map[string]string decoder := json.NewDecoder(file) @@ -117,12 +126,16 @@ func Test_saveToFile(t *testing.T) { } } -func TestEnsureDirExists(t *testing.T) { +func Test_ensureDirExists(t *testing.T) { tempDir, err := os.MkdirTemp("", "test_ensure_dir_exists") if err != nil { t.Fatalf("failed to create temp directory: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() newDir := filepath.Join(tempDir, "newdir") @@ -145,12 +158,16 @@ func TestEnsureDirExists(t *testing.T) { } } -func TestEnsureFileExists(t *testing.T) { +func Test_ensureFileExists(t *testing.T) { tempDir, err := os.MkdirTemp("", "test_ensure_file_exists") if err != nil { t.Fatalf("failed to create temp directory: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() newFilePath := filepath.Join(tempDir, "newdir", "testfile.txt") @@ -169,17 +186,21 @@ func TestEnsureFileExists(t *testing.T) { } } -func TestCreateSymlink(t *testing.T) { +func Test_createSymlink(t *testing.T) { tempDir, err := os.MkdirTemp("", "test_create_symlink") if err != nil { t.Fatalf("failed to create temp directory: %v", err) } - defer os.RemoveAll(tempDir) + defer func() { + if err := os.RemoveAll(tempDir); err != nil { + t.Fatalf("failed to remove temp directory: %v", err) + } + }() targetFile := filepath.Join(tempDir, "target.txt") linkFile := filepath.Join(tempDir, "link.txt") - err = os.WriteFile(targetFile, []byte("test data"), 0644) + err = os.WriteFile(targetFile, []byte("test data"), 0600) if err != nil { t.Fatalf("failed to create target file: %v", err) } @@ -190,7 +211,6 @@ func TestCreateSymlink(t *testing.T) { } linkInfo, err := os.Lstat(linkFile) - t.Logf("linkInfo: %v", linkInfo) /// if err != nil { t.Fatalf("failed to stat symlink: %v", err) } @@ -207,7 +227,7 @@ func TestCreateSymlink(t *testing.T) { } newTargetFile := filepath.Join(tempDir, "new_target.txt") - err = os.WriteFile(newTargetFile, []byte("new data"), 0644) + err = os.WriteFile(newTargetFile, []byte("new data"), 0600) if err != nil { t.Fatalf("failed to create new target file: %v", err) }