From 36466509aeff3aa53e08cb90454547537440b172 Mon Sep 17 00:00:00 2001 From: Brandt Keller <43887158+brandtkeller@users.noreply.github.com> Date: Fri, 28 Jun 2024 15:43:00 -0700 Subject: [PATCH] fix(oscal): single model write operations support (#502) * fix(oscal): remove mutli-model write operations * fix(oscall): testing for GetOscalModel() * fix(oscal): remove assessment test file --------- Co-authored-by: Cole (Mike) Winberry <86802655+mike-winberry@users.noreply.github.com> Co-authored-by: Andy Mills <61879371+CloudBeard@users.noreply.github.com> --- src/pkg/common/oscal/complete-schema.go | 76 +++++++++++++++++--- src/pkg/common/oscal/complete-schema_test.go | 72 +++++++++++++++++++ 2 files changed, 138 insertions(+), 10 deletions(-) create mode 100644 src/pkg/common/oscal/complete-schema_test.go diff --git a/src/pkg/common/oscal/complete-schema.go b/src/pkg/common/oscal/complete-schema.go index e8c119fa2..f2421a4f3 100644 --- a/src/pkg/common/oscal/complete-schema.go +++ b/src/pkg/common/oscal/complete-schema.go @@ -3,6 +3,7 @@ package oscal import ( "bytes" "encoding/json" + "fmt" "os" "path/filepath" @@ -32,13 +33,18 @@ func NewOscalModel(data []byte) (*oscalTypes_1_1_2.OscalModels, error) { // supports both json and yaml func WriteOscalModel(filePath string, model *oscalTypes_1_1_2.OscalModels) error { - // if no path or directory add default filename - if filepath.Ext(filePath) == "" { - filePath = filepath.Join(filePath, "oscal.yaml") + modelType, err := GetOscalModel(model) + if err != nil { + return err } - if err := files.IsJsonOrYaml(filePath); err != nil { - return err + // if no path or directory add default filename + if filepath.Ext(filePath) == "" { + filePath = filepath.Join(filePath, fmt.Sprintf("%s.yaml", modelType)) + } else { + if err := files.IsJsonOrYaml(filePath); err != nil { + return err + } } if _, err := os.Stat(filePath); err == nil { @@ -51,9 +57,18 @@ func WriteOscalModel(filePath string, model *oscalTypes_1_1_2.OscalModels) error if err != nil { return err } + + existingModelType, err := GetOscalModel(existingModel) + if err != nil { + return nil + } + + if existingModelType != modelType { + return fmt.Errorf("cannot merge model %s with existing model %s", modelType, existingModelType) + } // Merge the existing model with the new model // re-assign to perform common operations below - model, err = MergeOscalModels(existingModel, model) + model, err = MergeOscalModels(existingModel, model, modelType) if err != nil { return err } @@ -71,7 +86,7 @@ func WriteOscalModel(filePath string, model *oscalTypes_1_1_2.OscalModels) error yamlEncoder.Encode(model) } - err := files.WriteOutput(b.Bytes(), filePath) + err = files.WriteOutput(b.Bytes(), filePath) if err != nil { return err } @@ -82,12 +97,12 @@ func WriteOscalModel(filePath string, model *oscalTypes_1_1_2.OscalModels) error } -func MergeOscalModels(existingModel *oscalTypes_1_1_2.OscalModels, newModel *oscalTypes_1_1_2.OscalModels) (*oscalTypes_1_1_2.OscalModels, error) { +func MergeOscalModels(existingModel *oscalTypes_1_1_2.OscalModels, newModel *oscalTypes_1_1_2.OscalModels, modelType string) (*oscalTypes_1_1_2.OscalModels, error) { var err error // Now to check each model type - currently only component definition and assessment-results apply // Component definition - if existingModel.ComponentDefinition != nil && newModel.ComponentDefinition != nil { + if modelType == "component" { merged, err := MergeComponentDefinitions(existingModel.ComponentDefinition, newModel.ComponentDefinition) if err != nil { return nil, err @@ -99,7 +114,7 @@ func MergeOscalModels(existingModel *oscalTypes_1_1_2.OscalModels, newModel *osc } // Assessment Results - if existingModel.AssessmentResults != nil && newModel.AssessmentResults != nil { + if modelType == "assessment-results" { merged, err := MergeAssessmentResults(existingModel.AssessmentResults, newModel.AssessmentResults) if err != nil { return existingModel, err @@ -112,3 +127,44 @@ func MergeOscalModels(existingModel *oscalTypes_1_1_2.OscalModels, newModel *osc return existingModel, err } + +func GetOscalModel(model *oscalTypes_1_1_2.OscalModels) (modelType string, err error) { + + // Check if one model present and all other nil - is there a better way to do this? + models := make([]string, 0) + + if model.Catalog != nil { + models = append(models, "catalog") + } + + if model.Profile != nil { + models = append(models, "profile") + } + + if model.ComponentDefinition != nil { + models = append(models, "component") + } + + if model.SystemSecurityPlan != nil { + models = append(models, "system-security-plan") + } + + if model.AssessmentPlan != nil { + models = append(models, "assessment-plan") + } + + if model.AssessmentResults != nil { + models = append(models, "assessment-results") + } + + if model.PlanOfActionAndMilestones != nil { + models = append(models, "poam") + } + + if len(models) > 1 { + return "", fmt.Errorf("%v models identified when only oneOf is permitted", len(models)) + } else { + return models[0], nil + } + +} diff --git a/src/pkg/common/oscal/complete-schema_test.go b/src/pkg/common/oscal/complete-schema_test.go new file mode 100644 index 000000000..9b0ef2824 --- /dev/null +++ b/src/pkg/common/oscal/complete-schema_test.go @@ -0,0 +1,72 @@ +package oscal_test + +import ( + oscalTypes_1_1_2 "github.com/defenseunicorns/go-oscal/src/types/oscal-1-1-2" + "github.com/defenseunicorns/lula/src/pkg/common/oscal" + "testing" +) + +func TestGetOscalModel(t *testing.T) { + t.Parallel() + + type TestCase struct { + Model oscalTypes_1_1_2.OscalModels + ModelType string + } + + testCases := []TestCase{ + { + Model: oscalTypes_1_1_2.OscalModels{ + Catalog: &oscalTypes_1_1_2.Catalog{}, + }, + ModelType: "catalog", + }, + { + Model: oscalTypes_1_1_2.OscalModels{ + Profile: &oscalTypes_1_1_2.Profile{}, + }, + ModelType: "profile", + }, + { + Model: oscalTypes_1_1_2.OscalModels{ + ComponentDefinition: &oscalTypes_1_1_2.ComponentDefinition{}, + }, + ModelType: "component", + }, + { + Model: oscalTypes_1_1_2.OscalModels{ + SystemSecurityPlan: &oscalTypes_1_1_2.SystemSecurityPlan{}, + }, + ModelType: "system-security-plan", + }, + { + Model: oscalTypes_1_1_2.OscalModels{ + AssessmentPlan: &oscalTypes_1_1_2.AssessmentPlan{}, + }, + ModelType: "assessment-plan", + }, + { + Model: oscalTypes_1_1_2.OscalModels{ + AssessmentResults: &oscalTypes_1_1_2.AssessmentResults{}, + }, + ModelType: "assessment-results", + }, + { + Model: oscalTypes_1_1_2.OscalModels{ + PlanOfActionAndMilestones: &oscalTypes_1_1_2.PlanOfActionAndMilestones{}, + }, + ModelType: "poam", + }, + } + for _, testCase := range testCases { + actual, err := oscal.GetOscalModel(&testCase.Model) + if err != nil { + t.Fatalf("unexpected error for model %s", testCase.ModelType) + } + expected := testCase.ModelType + if expected != actual { + t.Fatalf("error GetOscalModel: expected: %s | got: %s", expected, actual) + } + } + +}