diff --git a/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/SQLTransformerTests.cs b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/SQLTransformerTests.cs
new file mode 100644
index 000000000..3ddfc9624
--- /dev/null
+++ b/src/csharp/Microsoft.Spark.E2ETest/IpcTests/ML/Feature/SQLTransformerTests.cs
@@ -0,0 +1,73 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.IO;
+using Microsoft.Spark.ML.Feature;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+using Microsoft.Spark.UnitTest.TestUtils;
+using Xunit;
+
+namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
+{
+    [Collection("Spark E2E Tests")]
+    public class SQLTransformerTests : FeatureBaseTests<SQLTransformer>
+    {
+        private readonly SparkSession _spark;
+
+        public SQLTransformerTests(SparkFixture fixture) : base(fixture)
+        {
+            _spark = fixture.Spark;
+        }
+
+        /// <summary>
+        /// Create a <see cref="DataFrame"/>, create a <see cref="SQLTransformer"/> and test the
+        /// available methods.
+        /// </summary>
+        [Fact]
+        public void TestSQLTransformer()
+        {
+            DataFrame input = _spark.CreateDataFrame(
+                new List<GenericRow>
+                {
+                    new GenericRow(new object[] { 0, 1.0, 3.0 }),
+                    new GenericRow(new object[] { 2, 2.0, 5.0 })
+                },
+                new StructType(new List<StructField>
+                {
+                    new StructField("id", new IntegerType()),
+                    new StructField("v1", new DoubleType()),
+                    new StructField("v2", new DoubleType())
+                }));
+
+            string expectedUid = "theUid";
+            string inputStatement = "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__";
+
+            SQLTransformer sqlTransformer = new SQLTransformer(expectedUid)
+                .SetStatement(inputStatement);
+
+            string outputStatement = sqlTransformer.GetStatement();
+
+            DataFrame output = sqlTransformer.Transform(input);
+            StructType outputSchema = sqlTransformer.TransformSchema(input.Schema());
+
+            Assert.Contains(output.Schema().Fields, (f => f.Name == "v3"));
+            Assert.Contains(output.Schema().Fields, (f => f.Name == "v4"));
+            Assert.Contains(outputSchema.Fields, (f => f.Name == "v3"));
+            Assert.Contains(outputSchema.Fields, (f => f.Name == "v4"));
+            Assert.Equal(inputStatement, outputStatement);
+
+            using (var tempDirectory = new TemporaryDirectory())
+            {
+                string savePath = Path.Join(tempDirectory.Path, "SQLTransformer");
+                sqlTransformer.Save(savePath);
+
+                SQLTransformer loadedSqlTransformer = SQLTransformer.Load(savePath);
+                Assert.Equal(sqlTransformer.Uid(), loadedSqlTransformer.Uid());
+            }
+            Assert.Equal(expectedUid, sqlTransformer.Uid());
+        }
+    }
+}
diff --git a/src/csharp/Microsoft.Spark/ML/Feature/SQLTransformer.cs b/src/csharp/Microsoft.Spark/ML/Feature/SQLTransformer.cs
new file mode 100644
index 000000000..a4d84570a
--- /dev/null
+++ b/src/csharp/Microsoft.Spark/ML/Feature/SQLTransformer.cs
@@ -0,0 +1,97 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.Spark.Interop;
+using Microsoft.Spark.Interop.Ipc;
+using Microsoft.Spark.Sql;
+using Microsoft.Spark.Sql.Types;
+
+namespace Microsoft.Spark.ML.Feature
+{
+    /// <summary>
+    /// <see cref="SQLTransformer"/> implements the transformations which are defined by a SQL statement.
+    /// </summary>
+    public class SQLTransformer : FeatureBase<SQLTransformer>, IJvmObjectReferenceProvider
+    {
+        private static readonly string s_sqlTransformerClassName =
+            "org.apache.spark.ml.feature.SQLTransformer";
+
+        /// <summary>
+        /// Create a <see cref="SQLTransformer"/> without any parameters.
+        /// </summary>
+        public SQLTransformer() : base(s_sqlTransformerClassName)
+        {
+        }
+
+        /// <summary>
+        /// Create a <see cref="SQLTransformer"/> with a UID that is used to give the
+        /// <see cref="SQLTransformer"/> a unique ID.
+        /// </summary>
+        /// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
+        public SQLTransformer(string uid) : base(s_sqlTransformerClassName, uid)
+        {
+        }
+
+        internal SQLTransformer(JvmObjectReference jvmObject) : base(jvmObject)
+        {
+        }
+
+        JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;
+
+        /// <summary>
+        /// Executes the <see cref="SQLTransformer"/> and transforms the DataFrame to include
+        /// the new column.
+        /// </summary>
+        /// <param name="source">The DataFrame to transform</param>
+        /// <returns>
+        /// New <see cref="DataFrame"/> object with the source <see cref="DataFrame"/> transformed.
+        /// </returns>
+        public DataFrame Transform(DataFrame source) =>
+            new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", source));
+
+        /// <summary>
+        /// Executes the <see cref="SQLTransformer"/> and transforms the schema.
+        /// </summary>
+        /// <param name="value">The Schema to be transformed</param>
+        /// <returns>
+        /// New <see cref="StructType"/> object with the schema transformed.
+        /// </returns>
+        public StructType TransformSchema(StructType value) =>
+            new StructType(
+                (JvmObjectReference)_jvmObject.Invoke(
+                    "transformSchema",
+                    DataType.FromJson(_jvmObject.Jvm, value.Json)));
+
+        /// <summary>
+        /// Gets the statement.
+        /// </summary>
+        /// <returns>Statement</returns>
+        public string GetStatement() => (string)_jvmObject.Invoke("getStatement");
+
+        /// <summary>
+        /// Sets the SQL statement on the <see cref="SQLTransformer"/>.
+        /// </summary>
+        /// <param name="statement">SQL Statement</param>
+        /// <returns>
+        /// <see cref="SQLTransformer"/> with the statement set.
+        /// </returns>
+        public SQLTransformer SetStatement(string statement) =>
+            WrapAsSQLTransformer((JvmObjectReference)_jvmObject.Invoke("setStatement", statement));
+
+        /// <summary>
+        /// Loads the <see cref="SQLTransformer"/> that was previously saved using Save.
+        /// </summary>
+        /// <param name="path">The path the previous <see cref="SQLTransformer"/> was saved to</param>
+        /// <returns>New <see cref="SQLTransformer"/> object, loaded from path</returns>
+        public static SQLTransformer Load(string path) =>
+            WrapAsSQLTransformer(
+                SparkEnvironment.JvmBridge.CallStaticJavaMethod(
+                    s_sqlTransformerClassName,
+                    "load",
+                    path));
+
+        private static SQLTransformer WrapAsSQLTransformer(object obj) =>
+            new SQLTransformer((JvmObjectReference)obj);
+    }
+}
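For reviewers who want to exercise the new API outside the E2E test harness, below is a minimal usage sketch. It mirrors the test above; the program name, app name, and sample data are illustrative and not part of this PR, and it assumes a working .NET for Apache Spark setup (a reachable Spark installation and the Microsoft.Spark worker). Note how the "__THIS__" placeholder in the statement stands in for the input DataFrame when Transform runs.

// Hypothetical stand-alone example for the new SQLTransformer API.
// Session setup and data below are illustrative only.
using System.Collections.Generic;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;

class SQLTransformerExample
{
    static void Main()
    {
        SparkSession spark = SparkSession
            .Builder()
            .AppName("SQLTransformerExample")
            .GetOrCreate();

        // Build a small input DataFrame with the same shape as the E2E test data.
        DataFrame input = spark.CreateDataFrame(
            new List<GenericRow>
            {
                new GenericRow(new object[] { 0, 1.0, 3.0 }),
                new GenericRow(new object[] { 2, 2.0, 5.0 })
            },
            new StructType(new List<StructField>
            {
                new StructField("id", new IntegerType()),
                new StructField("v1", new DoubleType()),
                new StructField("v2", new DoubleType())
            }));

        // __THIS__ is substituted with the input DataFrame when Transform is called.
        SQLTransformer transformer = new SQLTransformer()
            .SetStatement("SELECT *, (v1 + v2) AS v3 FROM __THIS__");

        // Prints id, v1, v2 plus the derived v3 column.
        transformer.Transform(input).Show();

        spark.Stop();
    }
}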