implement CSV formatting for distribution report
Two new concepts were needed for the reporting that Sabine requires:

* The number of exemplars in the training set was never tracked before. We
now track it in the subcontext list, which is now stored in the
AMResults class.
* Judgement: correct, tie, incorrect, etc. I used this term because I
found "prediction" to be easily confusable with the predicted class.

There is also some groundwork for a CSV comment header, but since the
distribution report is only a one-row document, the header would be
overkill there. I can always add it in the follow-up Python tool that
collects all of these CSV documents in one place.
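
For reference, a minimal sketch of how the new formatter is driven (mirroring the distribution-output call site shown in the diff below; the DistributionCsvSketch class and its toCsv helper are illustrative names, the results/distribution values are assumed to already exist, and the class sits in the same package because Format is package-private):

    package weka.classifiers.evaluation.output.prediction;

    import weka.classifiers.lazy.AM.AMUtils;
    import weka.classifiers.lazy.AM.data.AMResults;

    class DistributionCsvSketch {
        // Sketch only: format one classification result as the one-row CSV distribution report.
        static String toCsv(AMResults results, double[] distribution, String relationName) {
            DistributionFormatter formatter = new DistributionFormatter(3, AMUtils.LINE_SEPARATOR);
            return formatter.formatDistribution(results, distribution, relationName, Format.CSV);
        }
    }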
garfieldnate committed Feb 11, 2024
1 parent c76c7c4 commit e8d4dab
Showing 10 changed files with 251 additions and 32 deletions.
7 changes: 4 additions & 3 deletions build.gradle
@@ -16,10 +16,11 @@ repositories {
}

dependencies {
implementation group: 'nz.ac.waikato.cms.weka', name: 'weka-dev', version: '3.9.5'
implementation group: 'com.jakewharton.picnic', name: 'picnic', version: '0.5.0'
implementation group: 'com.google.guava', name: 'guava', version: '19.0'
compileOnly 'org.projectlombok:lombok:1.18.20'
implementation group: 'com.jakewharton.picnic', name: 'picnic', version: '0.5.0'
implementation group: 'org.apache.commons', name: 'commons-csv', version: '1.10.0'
implementation group: 'nz.ac.waikato.cms.weka', name: 'weka-dev', version: '3.9.5'
compileOnly 'org.projectlombok:lombok:1.18.20'
annotationProcessor 'org.projectlombok:lombok:1.18.20'

testImplementation group: 'junit', name: 'junit', version: '4.13.2'
@@ -198,11 +198,12 @@ protected void doPrintClassification(Classifier classifier, Instance inst, int i
}

if (getOutputDistribution()) {
DistributionFormatter formatter = new DistributionFormatter(getNumDecimals(), format);
append(formatter.formatDistribution(results, distribution, m_Header));
DistributionFormatter formatter = new DistributionFormatter(getNumDecimals(), AMUtils.LINE_SEPARATOR);
append(formatter.formatDistribution(results, distribution, m_Header.relationName(), format));
append(AMUtils.LINE_SEPARATOR);
}


if (getAnalogicalSet()) {
AnalogicalSetFormatter formatter = new AnalogicalSetFormatter(getNumDecimals(), format);
append("Analogical set:");
@@ -1,43 +1,140 @@
package weka.classifiers.evaluation.output.prediction;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import weka.classifiers.lazy.AM.AMUtils;
import weka.classifiers.lazy.AM.data.AMResults;
import weka.classifiers.lazy.AM.label.Labeler;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;

import java.io.IOException;
import java.io.StringWriter;
import java.math.BigDecimal;
import java.math.MathContext;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import static weka.classifiers.evaluation.output.prediction.Format.getCsvCommentHeader;

public class DistributionFormatter {
private final int numDecimals;
private Format format;
private final String lineSeparator;

/**
* @param numDecimals the number of digits to output after the decimal point
*/
public DistributionFormatter(int numDecimals, Format format) {
public DistributionFormatter(int numDecimals, String lineSeparator) {
this.numDecimals = numDecimals;
this.format = format;
this.lineSeparator = lineSeparator;
}

public String formatDistribution(AMResults results, double[] distribution, Instances m_Header) {
public String formatDistribution(AMResults results, double[] distribution, String relationName, Format format) {
String doubleFormat = String.format("%%.%df", numDecimals);
StringBuilder sb = new StringBuilder();
sb.append("Class probability distribution:").append(AMUtils.LINE_SEPARATOR);
switch(format) {
switch (format) {
case HUMAN: {
StringBuilder sb = new StringBuilder();
sb.append("Class probability distribution:").append(lineSeparator);
Attribute classAttribute = results.getClassifiedEx().classAttribute();
for (int i = 0; i < distribution.length; i++) {
sb.append(m_Header.classAttribute().value(i));
sb.append(classAttribute.value(i));
sb.append(": ");
sb.append(String.format(doubleFormat, distribution[i]));
sb.append(AMUtils.LINE_SEPARATOR);
sb.append(lineSeparator);
}
break;
return sb.toString();
}
case CSV: {
sb.append("TODO");
break;
CsvDoc doc = getCSVDoc(results);

CSVFormat csvFormat = CSVFormat.DEFAULT.builder().setRecordSeparator(lineSeparator).setHeader(doc.headers.toArray(new String[]{})).build();
StringWriter sw = new StringWriter();
// for now this is too much to write for just a single row of output
// sw.write(getCsvCommentHeader(relationName, "Class Probability Distribution"));
// sw.write(lineSeparator);
try (final CSVPrinter printer = new CSVPrinter(sw, csvFormat)) {
for (List<String> entry : doc.entries) {
printer.printRecord(entry);
}
} catch (IOException e) {
return "Error printing results to CSV: " + e;
}
return sw.toString();
}
default: {
throw new IllegalStateException("Unknown formatter: " + format.getOptionString());
}
}
return sb.toString();
}

private static class CsvDoc {
final List<String> headers;
final List<List<String>> entries;

private CsvDoc(List<String> headers, List<List<String>> entries) {
this.headers = headers;
this.entries = entries;
}
}

private CsvDoc getCSVDoc(AMResults results) {
Labeler labeler = results.getLabeler();
List<String> headers = new ArrayList<>();
List<String> values = new ArrayList<>();

headers.add("Judgement");
values.add(results.getJudgement().toString().toLowerCase());

headers.add("Expected");
values.add(results.getExpectedClassName());

Instance classifiedExemplar = results.getClassifiedEx();

// value of each feature
for (int i = 0; i < classifiedExemplar.numAttributes(); i++) {
// skip ignored attributes and the class attribute
if (labeler.isIgnored(i)) {
continue;
}
if (i == classifiedExemplar.classIndex()) {
continue;
}
Attribute classAtt = classifiedExemplar.attribute(i);
headers.add(classAtt.name());
values.add(classifiedExemplar.stringValue(classAtt));
}

// each potential class value
Iterator<Object> classNameIterator = classifiedExemplar.classAttribute().enumerateValues().asIterator();
int classIndex = 1;
while (classNameIterator.hasNext()) {
headers.add("Class " + classIndex);
values.add((String) classNameIterator.next());
classIndex++;
}

results.getClassPointers().forEach((className, pointers) -> {
headers.add(className + "_ptrs");
values.add(pointers.toString());
});
results.getClassLikelihood().forEach((className, likelihood) -> {
headers.add(className + "_pct");
BigDecimal percentage = likelihood.multiply(BigDecimal.valueOf(100)).round(MathContext.DECIMAL32);
values.add(percentage.toString());
});

headers.add("train_size");
values.add("" + results.getSubList().getConsideredExemplarCount());

headers.add("num_feats");
// subtract one for the class
values.add("" + (classifiedExemplar.numAttributes() - 1));

// just one row in this CSV (one exemplar classified)
List<List<String>> entries = new ArrayList<>();
entries.add(values);

return new CsvDoc(headers, entries);
}
}
@@ -1,14 +1,26 @@
package weka.classifiers.evaluation.output.prediction;

import weka.classifiers.lazy.AM.AMUtils;
import weka.classifiers.lazy.AM.Enum2TagUtils.TagInfo;

import java.time.Clock;

/**
* Formatting choices for {@link AnalogicalModelingOutput}
*/
enum Format implements TagInfo {
HUMAN("human", "Human-readable format"),
CSV("csv", "Machine-readable CSV designed for analysis in Excel, Pandas, etc.");

// TODO: name of relation and name of report (distribution, analogical set or gangs)
public static final String getCsvCommentHeader(String relationName, String reportName) {
return "# relation " + relationName + " (" + reportName + ")" + AMUtils.LINE_SEPARATOR +
"# Generated via Weka Analogical Modeling plugin on " + Clock.systemDefaultZone().instant() + AMUtils.LINE_SEPARATOR +
"# This data is in CSV format." + AMUtils.LINE_SEPARATOR +
"# To load in Pandas: TODO" + AMUtils.LINE_SEPARATOR +
"# To load in Excel: TODO" + AMUtils.LINE_SEPARATOR;
}

// string used on command line to indicate the use of this strategy
private final String optionString;
// string which describes comparison strategy for a given entry
2 changes: 1 addition & 1 deletion src/main/java/weka/classifiers/lazy/AM/Enum2TagUtils.java
@@ -5,7 +5,7 @@

/**
* In Weka, configuration with a specific set of possible values is implemented using {@link Tag}. These
* utilities make it possible to use an enum as the set of tags for a given congif parameter.
* utilities make it possible to use an enum as the set of tags for a given config parameter.
*/
public class Enum2TagUtils {
/**
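
For background, the plain-Weka mechanism this helper wraps looks roughly like the following (a hedged sketch using the standard weka.core.Tag and SelectedTag API directly rather than the helper's own methods; the tag ids and strings are illustrative):

    import weka.core.SelectedTag;
    import weka.core.Tag;

    class TagSketch {
        // A config parameter with a fixed set of values is exposed to Weka as a Tag[]
        // plus a SelectedTag; Enum2TagUtils lets an enum such as Format fill this role.
        static final Tag[] FORMAT_TAGS = {
                new Tag(0, "human", "Human-readable format"),
                new Tag(1, "csv", "Machine-readable CSV"),
        };

        static SelectedTag defaultFormat() {
            return new SelectedTag(0, FORMAT_TAGS);
        }
    }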
64 changes: 53 additions & 11 deletions src/main/java/weka/classifiers/lazy/AM/data/AMResults.java
@@ -72,20 +72,22 @@ public class AMResults {

private static final String newline = System.getProperty("line.separator");
private final Labeler labeler;
private final SubcontextList subList;

/**
/**
* @param lattice filled lattice, which contains the data for calculating the analogical set
* @param testItem Exemplar being classified
* @param linear True if counting of pointers should be done linearly; false if quadratically.
* @param labeler The labeler that was used to assign contextual labels; this is made available
* for printing purposes.
*/
public AMResults(Lattice lattice, Instance testItem, boolean linear, Labeler labeler) {
Set<Supracontext> set = lattice.getSupracontexts();
public AMResults(Lattice lattice, SubcontextList subList, Instance testItem, boolean linear, Labeler labeler) {
Set<Supracontext> set = lattice.getSupracontexts();

this.classifiedExemplar = testItem;
this.supraList = set;
this.labeler = labeler;
this.subList = subList;

// find numbers of pointers to individual exemplars
this.exPointerMap = getPointers(set, linear);
@@ -230,13 +232,6 @@ public Map<Instance, BigInteger> getExemplarPointers() {
return exPointerMap;
}

/**
* @return A mapping between a possible class index and its likelihood (decimal probability)
*/
public Map<String, BigDecimal> getClassLikelihoodMap() {
return classLikelihoodMap;
}

/**
* @return The total number of pointers in this analogical set
*/
@@ -252,7 +247,7 @@ public Map<String, BigInteger> getClassPointers() {
}

/**
* @return A mapping between the class value index and its selection probability
* @return A mapping between the class name and its selection probability
*/
public Map<String, BigDecimal> getClassLikelihood() {
return classLikelihoodMap;
@@ -315,4 +310,51 @@ public List<GangEffect> getGangEffects() {
public Labeler getLabeler() {
return labeler;
}

public String getExpectedClassName() {
Instance classifiedEx = getClassifiedEx();
double expectedIndex = classifiedEx.classValue();
return classifiedEx.classAttribute().value((int) expectedIndex);
}


public SubcontextList getSubList() {
return subList;
}

public enum Judgement {
/**
* Only the correct class was predicted
*/
CORRECT,
/**
* The correct class and others were tied in the prediction
*/
TIE,
/**
* The correct class was not predicted
*/
INCORRECT,
/**
* The correct class was not specified in the dataset
*/
UNKNOWN;
}

/**
* @return a judgement of the prediction
*/
public Judgement getJudgement() {
String expected = getExpectedClassName();
if (expected == null) {
return Judgement.UNKNOWN;
}
if (getPredictedClasses().contains(expected)) {
if (getPredictedClasses().size() == 1) {
return Judgement.CORRECT;
}
return Judgement.TIE;
}
return Judgement.INCORRECT;
}
}
10 changes: 10 additions & 0 deletions src/main/java/weka/classifiers/lazy/AM/data/SubcontextList.java
@@ -42,6 +42,8 @@ public class SubcontextList implements Iterable<Subcontext> {
private final Labeler labeler;
private final boolean ignoreFullMatches;

private int consideredExemplarCount;

/**
* @return the number of attributes used to predict an outcome
*/
@@ -76,6 +78,7 @@ void add(Instance exemplar) {
labelToSubcontext.put(label, new Subcontext(label, labeler.getContextString(label)));
}
labelToSubcontext.get(label).add(exemplar);
consideredExemplarCount++;
}

/**
@@ -154,6 +157,13 @@ public int size() {
return labelToSubcontext.size();
}

/**
* @return The number of exemplars accepted into the list, i.e. added and not ignored
*/
public int getConsideredExemplarCount() {
return consideredExemplarCount;
}

/**
* @return The labeler object used to assign incoming data to subcontexts.
*/
@@ -36,7 +36,7 @@ public Label label(Instance data) {
if (isIgnored(i)) continue;
if (i == getTestInstance().classIndex()) continue;
att = getTestInstance().attribute(i);
// use mdc if were are comparing a missing attribute
// use mdc if we are comparing a missing attribute
if (getTestInstance().isMissing(att) || data.isMissing(att)) {
if (!getMissingDataCompare().matches(getTestInstance(), data, att))
// use length-1-index instead of index so that in binary the
@@ -192,7 +192,7 @@ private AMResults classify(Instance testItem) throws InterruptedException, Execu
// 3. record the analogical set and other statistics from the pointers in the
// resulting homogeneous supracontexts
// we save the results for use with AnalogicalModelingOutput
results = new AMResults(lattice, testItem, m_linearCount, labeler);
results = new AMResults(lattice, subList, testItem, m_linearCount, labeler);
return results;
}

