Skip to content

Add Binarizer #744

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.Spark.ML.Feature;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Microsoft.Spark.UnitTest.TestUtils;
using Xunit;

namespace Microsoft.Spark.E2ETest.IpcTests.ML.Feature
{
[Collection("Spark E2E Tests")]
public class BinarizerTests : FeatureBaseTests<Binarizer>
{
private readonly SparkSession _spark;

public BinarizerTests(SparkFixture fixture) : base(fixture)
{
_spark = fixture.Spark;
}

[Fact]
public void TestBinarizer()
{
string inputCol = "feature";
DataFrame input = _spark.CreateDataFrame(
new List<GenericRow>
{
new GenericRow(new object[] {0, 0.1}),
new GenericRow(new object[] {1, 0.8}),
new GenericRow(new object[] {2, 0.2})
},
new StructType(new List<StructField>
{
new StructField("id", new IntegerType()), new StructField(inputCol, new DoubleType())
}));
string expectedUid = "theUid";
string outputCol = "binarized_feature";
double threshold = 0.5;
Binarizer binarizer = new Binarizer(expectedUid)
.SetInputCol(inputCol)
.SetOutputCol(outputCol)
.SetThreshold(threshold);
DataFrame output = binarizer.Transform(input);
StructType outputSchema = binarizer.TransformSchema(input.Schema());

Assert.Contains(output.Schema().Fields, (f => f.Name == outputCol));
Assert.Contains(outputSchema.Fields, (f => f.Name == outputCol));
Assert.Equal(inputCol, binarizer.GetInputCol());
Assert.Equal(outputCol, binarizer.GetOutputCol());
Assert.Equal(threshold, binarizer.GetThreshold());

using (var tempDirectory = new TemporaryDirectory())
{
string savePath = Path.Join(tempDirectory.Path, "Binarizer");
binarizer.Save(savePath);

Binarizer loadedBinarizer = Binarizer.Load(savePath);
Assert.Equal(loadedBinarizer.Uid(), binarizer.Uid());
}

Assert.Equal(expectedUid, binarizer.Uid());
}

[Fact]
public void TestBinarizerWithArrayParams()
{
string[] inputCol = new[] {"col1", "col2"};
string[] outputCol = new[] {"feature1", "feature2"};
double[] threshold = new[] {0.5, 0.8};
Binarizer binarizer = new Binarizer()
.SetInputCols(inputCol)
.SetOutputCols(outputCol)
.SetThresholds(threshold);

Assert.Equal(inputCol, binarizer.GetInputCols());
Assert.Equal(outputCol, binarizer.GetOutputCols());
Assert.Equal(threshold, binarizer.GetThresholds());
}
}
}
161 changes: 161 additions & 0 deletions src/csharp/Microsoft.Spark/ML/Feature/Binarizer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.Spark.Interop;
using Microsoft.Spark.Interop.Ipc;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;

namespace Microsoft.Spark.ML.Feature
{
/// <summary>
/// A <see cref="Binarizer"/>, Binarize a column of continuous features given a threshold.
/// </summary>
public class Binarizer : FeatureBase<Binarizer>, IJvmObjectReferenceProvider
{
private static readonly string s_binarizerClassName =
"org.apache.spark.ml.feature.Binarizer";

public Binarizer() : base(s_binarizerClassName)
{
}

public Binarizer(string uid) : base(s_binarizerClassName, uid)
{
}

internal Binarizer(JvmObjectReference jvmObject) : base(jvmObject)
{
}

JvmObjectReference IJvmObjectReferenceProvider.Reference => _jvmObject;

/// <summary>
/// Gets the column that the <see cref="Binarizer"/> should read from
/// </summary>
/// <returns>string, input column</returns>
public string GetInputCol() => (string)(_jvmObject.Invoke("getInputCol"));

/// <summary>
/// Sets the column that the <see cref="Binarizer"/> should read from
/// </summary>
/// <param name="value">The name of the column to as the source</param>
/// <returns>New <see cref="Binarizer"/> object</returns>
public Binarizer SetInputCol(string value) =>
WrapAsBinarizer(_jvmObject.Invoke("setInputCol", value));

/// <summary>
/// Gets the columns that the <see cref="Binarizer"/> should read from
/// </summary>
/// <returns>array of strings, input column</returns>
public string[] GetInputCols() => (string[])(_jvmObject.Invoke("getInputCols"));

/// <summary>
/// Sets the columns that the <see cref="Binarizer"/> should read from
/// </summary>
/// <param name="value">The name of the columns to as the source</param>
/// <returns>New <see cref="Binarizer"/> object</returns>
public Binarizer SetInputCols(string[] value) =>
WrapAsBinarizer(_jvmObject.Invoke("setInputCols", value));

/// <summary>
/// Param for threshold used to <see cref="Binarizer"/> continuous features.
/// </summary>
/// <param name="value">Threshold value</param>
/// <returns>New <see cref="Binarizer"/> object</returns>
public Binarizer SetThreshold(double value) =>
WrapAsBinarizer(_jvmObject.Invoke("setThreshold", value));

/// <summary>
/// Gets threshold used to <see cref="Binarizer"/> continuous features.
/// </summary>
/// <returns>double, the threshold</returns>
public double GetThreshold() => (double)(_jvmObject.Invoke("getThreshold"));

/// <summary>
/// Param for thresholds used to <see cref="Binarizer"/> continuous features.
/// </summary>
/// <param name="value">Threshold values</param>
/// <returns>New <see cref="Binarizer"/> object</returns>
public Binarizer SetThresholds(double[] value) =>
WrapAsBinarizer(_jvmObject.Invoke("setThresholds", value));

/// <summary>
/// Gets thresholds used to <see cref="Binarizer"/> continuous features.
/// </summary>
/// <returns>array of double, the thresholds</returns>
public double[] GetThresholds() => (double[])(_jvmObject.Invoke("getThresholds"));

/// <summary>
/// The <see cref="Binarizer"/> will create a new column in the DataFrame, this is the
/// name of the new column.
/// </summary>
/// <returns>string, the output column</returns>
public string GetOutputCol() => (string)(_jvmObject.Invoke("getOutputCol"));

/// <summary>
/// The <see cref="Binarizer"/> will create a new column in the DataFrame, this is the
/// name of the new column.
/// </summary>
/// <param name="value">The name of the new column</param>
/// <returns>New <see cref="Binarizer"/> object</returns>
public Binarizer SetOutputCol(string value) =>
WrapAsBinarizer(_jvmObject.Invoke("setOutputCol", value));

/// <summary>
/// The <see cref="Binarizer"/> will create a new columns in the DataFrame, this is the
/// name of the new column.
/// </summary>
/// <returns>array of strings, the output column</returns>
public string[] GetOutputCols() => (string[])(_jvmObject.Invoke("getOutputCols"));

/// <summary>
/// The <see cref="Binarizer"/> will create a new columns in the DataFrame, this is the
/// name of the new column.
/// </summary>
/// <param name="value">The name of the new columns</param>
/// <returns>New <see cref="Binarizer"/> object</returns>
public Binarizer SetOutputCols(string[] value) =>
WrapAsBinarizer(_jvmObject.Invoke("setOutputCols", value));

/// <summary>
/// Executes the <see cref="Binarizer"/> and transforms the DataFrame to include the new
/// column
/// </summary>
/// <param name="source">The DataFrame to transform</param>
/// <returns>
/// New <see cref="DataFrame"/> object with the source <see cref="DataFrame"/> transformed
/// </returns>
public DataFrame Transform(DataFrame source) =>
new DataFrame((JvmObjectReference)_jvmObject.Invoke("transform", source));

/// <summary>
/// Executes the <see cref="Binarizer"/> and transforms the schema.
/// </summary>
/// <param name="value">The Schema to be transformed</param>
/// <returns>
/// New <see cref="StructType"/> object with the schema <see cref="StructType"/> transformed.
/// </returns>
public StructType TransformSchema(StructType value) =>
new StructType(
(JvmObjectReference)_jvmObject.Invoke(
"transformSchema",
DataType.FromJson(_jvmObject.Jvm, value.Json)));

/// <summary>
/// Loads the <see cref="Binarizer"/> that was previously saved using Save
/// </summary>
/// <param name="path">The path the previous <see cref="Binarizer"/> was saved to</param>
/// <returns>New <see cref="Binarizer"/> object, loaded from path</returns>
public static Binarizer Load(string path)
{
return WrapAsBinarizer(
SparkEnvironment.JvmBridge.CallStaticJavaMethod(
s_binarizerClassName, "load", path));
}

private static Binarizer WrapAsBinarizer(object obj) =>
new Binarizer((JvmObjectReference)obj);
}
}