Skip to content

Implement ML Features #381. SQLTransformer class and testcase #781

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.IO;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
[Collection("Spark E2E Tests")]
public class SQLTransformerTests : FeatureBaseTests<SQLTransformer>
{
private readonly SparkSession _spark;

public SQLTransformerTests(SparkFixture fixture) : base(fixture)
{
_spark = fixture.Spark;
}

/// <summary>
/// Create a <see cref="DataFrame"/>, create a <see cref="SQLTransformer"/> and test the
/// available methods.
/// </summary>
[Fact]
public void TestSQLTransformer()
{
DataFrame input = _spark.CreateDataFrame(
new List<GenericRow>
{
new GenericRow(new object[] { 0, 1.0, 3.0 }),
new GenericRow(new object[] { 2, 2.0, 5.0 })
},
new StructType(new List<StructField>
{
new StructField("id", new IntegerType()),
new StructField("v1", new DoubleType()),
new StructField("v2", new DoubleType())
}));

string expectedUid = "theUid";
string inputStatement = "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__";

SQLTransformer sqlTransformer = new SQLTransformer(expectedUid)
.SetStatement(inputStatement);

string outputStatement = sqlTransformer.GetStatement();

DataFrame output = sqlTransformer.Transform(input);
StructType outputSchema = sqlTransformer.TransformSchema(input.Schema());

Assert.Contains(output.Schema().Fields, (f => f.Name == "v3"));
Assert.Contains(output.Schema().Fields, (f => f.Name == "v4"));
Assert.Contains(outputSchema.Fields, (f => f.Name == "v3"));
Assert.Contains(outputSchema.Fields, (f => f.Name == "v4"));
Assert.Equal(inputStatement, outputStatement);

using (var tempDirectory = new TemporaryDirectory())
{
string savePath = Path.Join(tempDirectory.Path, "SQLTransformer");
sqlTransformer.Save(savePath);

SQLTransformer loadedsqlTransformer = SQLTransformer.Load(savePath);
Assert.Equal(sqlTransformer.Uid(), loadedsqlTransformer.Uid());
}
Assert.Equal(expectedUid, sqlTransformer.Uid());
}
}
}
97 changes: 97 additions & 0 deletions src/csharp/Microsoft.Spark/ML/Feature/SQLTransformer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;

namespace Microsoft.Spark.ML.Feature
{
/// <summary>
/// <see cref="SQLTransformer"/> implements the transformations which are defined by SQL statement.
/// </summary>
public class SQLTransformer : FeatureBase<SQLTransformer>, IJvmObjectReferenceProvider
{
private static readonly string s_sqlTransformerClassName =
"org.apache.spark.ml.feature.SQLTransformer";

/// <summary>
/// Create a <see cref="SQLTransformer"/> without any parameters.
/// </summary>
public SQLTransformer() : base(s_sqlTransformerClassName)
{
}

/// <summary>
/// Create a <see cref="SQLTransformer"/> with a UID that is used to give the
/// <see cref="SQLTransformer"/> a unique ID.
/// </summary>
/// <param name="uid">An immutable unique ID for the object and its derivatives.</param>
public SQLTransformer(string uid) : base(s_sqlTransformerClassName, uid)
{
}

internal SQLTransformer(JvmObjectReference jvmObject) : base(jvmObject)
{
}

JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;

/// <summary>
/// Executes the <see cref="SQLTransformer"/> and transforms the DataFrame to include the new
/// column.
/// </summary>
/// <param name="source">The DataFrame to transform</param>
/// <returns>
/// New <see cref="DataFrame"/> object with the source <see cref="DataFrame"/> transformed.
/// </returns>
public DataFrame Transform(DataFrame source) =>
new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", source));

/// <summary>
/// Executes the <see cref="SQLTransformer"/> and transforms the schema.
/// </summary>
/// <param name="value">The Schema to be transformed</param>
/// <returns>
/// New <see cref="StructType"/> object with the schema <see cref="StructType"/> transformed.
/// </returns>
public StructType TransformSchema(StructType value) =>
new StructType(
(JvmObjectReference)_jvmObject.Invoke(
"transformSchema",
DataType.FromJson(_jvmObject.Jvm, value.Json)));

/// <summary>
/// Gets the statement.
/// </summary>
/// <returns>Statement</returns>
public string GetStatement() => (string)_jvmObject.Invoke("getStatement");

/// <summary>
/// Sets the statement to <see cref="SQLTransformer"/>.
/// </summary>
/// <param name="statement">SQL Statement</param>
/// <returns>
/// <see cref="SQLTransformer"/> with the statement set.
/// </returns>
public SQLTransformer SetStatement(string statement) =>
WrapAsSQLTransformer((JvmObjectReference)_jvmObject.Invoke("setStatement", statement));

/// <summary>
/// Loads the <see cref="SQLTransformer"/> that was previously saved using Save.
/// </summary>
/// <param name="path">The path the previous <see cref="SQLTransformer"/> was saved to</param>
/// <returns>New <see cref="SQLTransformer"/> object, loaded from path</returns>
public static SQLTransformer Load(string path) =>
WrapAsSQLTransformer(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_sqlTransformerClassName,
"load",
path));

private static SQLTransformer WrapAsSQLTransformer(object obj) =>
new SQLTransformer((JvmObjectReference)obj);
}
}