Skip to content

Commit

Permalink
Working CSV report formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
garfieldnate committed Feb 17, 2025
1 parent 346e8e3 commit e3a48dc
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 112 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ computer scientists (and most machine learning frameworks in general). Examples
'outcome' (class label), and 'variable' (feature). This software uses the CS terminology internally, but user-facing
reports use the AM terminology.

The running time for analogical modeling is exponential in nature and practice, and thus it is not suitable for very
large datasets; exact calculation becomes impractical after about 50 features. Therefore, this tool will automatically
use an approximation algorithm when there are 50 or more features.
The running time for analogical modeling is exponential in the number of features (variables); exact calculation becomes
impractical after about 50 features. Therefore, this tool will automatically use an approximation algorithm when there
are 50 or more features.

## Features

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ protected void doPrintClassification(Classifier classifier, Instance inst, int i
}

if (getGangs()) {
GangEffectsFormatter formatter = new GangEffectsFormatter(getNumDecimals(), format);
GangEffectsFormatter formatter = new GangEffectsFormatter(getNumDecimals(), format, AMUtils.LINE_SEPARATOR);
append("Gang effects:");
append(AMUtils.LINE_SEPARATOR);
append(formatter.formatGangs(results));
Expand Down Expand Up @@ -324,6 +324,13 @@ private Vector<Option> getOptionsOfSuper() {
* -gang
* Output gang effects
* </pre>
* <pre>
* -F &lt;format&gt;
* Format to print reports in. The options are 'human' and 'csv'. 'human' output is a human-readable, text-based
* table. 'csv', or comma-separated values, is intended to be machine-readable (for loading in Excel,
* Pandas, etc.), and contains strictly more data, such as the configuration parameters. Default is 'human'. If
* summary printing is turned on, the summary is always printed in the human-readable format.
* </pre>
*
* * <!-- options-end -->
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import java.util.function.BiFunction;
import java.util.stream.Stream;

import static java.lang.System.lineSeparator;

public class AnalogicalSetFormatter {

private final int numDecimals;
Expand All @@ -50,12 +48,12 @@ private static class TableEntry {


public String formatAnalogicalSet(AMResults results) {
TableSection.Builder bodyBuilder = new TableSection.Builder(); // 🏋️
streamTableEntries(results).forEach(e ->
bodyBuilder.addRow(e.getPercentage(), e.getPointers().toString(), e.getInstanceAtts(), e.getInstanceClass()));

switch (format) {
case HUMAN: {
TableSection.Builder bodyBuilder = new TableSection.Builder(); // 🏋️
streamTableEntries(results, true).forEach(e ->
bodyBuilder.addRow(e.getPercentage(), e.getPointers().toString(), e.getInstanceAtts(), e.getInstanceClass()));
return new Table.Builder().
setTableStyle(
new TableStyle.Builder().
Expand Down Expand Up @@ -91,12 +89,12 @@ public String formatAnalogicalSet(AMResults results) {
}

@NotNull
private Stream<TableEntry> streamTableEntries(AMResults results) {
private Stream<TableEntry> streamTableEntries(AMResults results, boolean addPercentPrefix) {
final Labeler labeler = results.getLabeler();
final BigDecimal totalPointers = new BigDecimal(results.getTotalPointers());

BiFunction<Instance, BigInteger, TableEntry> getTableEntry = (inst, pointers) -> {
String percentage = AMUtils.formatPointerPercentage(pointers, totalPointers, numDecimals);
String percentage = AMUtils.formatPointerPercentage(pointers, totalPointers, numDecimals, addPercentPrefix);
String instanceAtts = labeler.getInstanceAttsString(inst);
String instanceClass = inst.stringValue(inst.classIndex());
return new TableEntry(pointers, percentage, instanceAtts, instanceClass);
Expand All @@ -115,7 +113,7 @@ private AMUtils.CsvDoc getCsvDoc(AMResults results) {
List<String> headers = Arrays.asList("item", "class", "pointers", "percentage");
List<List<String>> entries = new ArrayList<>();

streamTableEntries(results).forEach(e -> entries.add(
streamTableEntries(results, false).forEach(e -> entries.add(
Arrays.asList(e.getInstanceAtts(), e.getInstanceClass(), e.getPointers().toString(), e.getPercentage())));

return new AMUtils.CsvDoc(headers, entries);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private AMUtils.CsvDoc getCsvDoc(AMResults results) {
continue;
}
Attribute classAtt = classifiedExemplar.attribute(i);
headers.add(classAtt.name());
headers.add("f_" + classAtt.name());
values.add(classifiedExemplar.stringValue(classAtt));
}

Expand All @@ -98,6 +98,7 @@ private AMUtils.CsvDoc getCsvDoc(AMResults results) {
while (classNameIterator.hasNext()) {
headers.add("Class " + classIndex);
values.add((String) classNameIterator.next());
classIndex += 1;
}

results.getClassPointers().forEach((className, pointers) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* Formatting choices for {@link AnalogicalModelingOutput}
*/
enum Format implements TagInfo {
HUMAN("human", "Human-readable format"),
HUMAN("human", "Human-readable format (*not* machine-readable!)"),
CSV("csv", "Machine-readable CSV designed for analysis in Excel, Pandas, etc.");

// string used on command line to indicate the use of this strategy
Expand Down
Original file line number Diff line number Diff line change
@@ -1,116 +1,195 @@
package weka.classifiers.evaluation.output.prediction;

import com.jakewharton.picnic.*;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.jetbrains.annotations.NotNull;
import weka.classifiers.lazy.AM.AMUtils;
import weka.classifiers.lazy.AM.data.AMResults;
import weka.classifiers.lazy.AM.data.GangEffect;
import weka.classifiers.lazy.AM.label.Labeler;
import weka.core.Instance;

import java.io.IOException;
import java.io.StringWriter;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Comparator;
import java.util.Map;
import java.util.Set;
import java.util.*;

import static weka.classifiers.lazy.AM.AMUtils.REPORT_TABLE_STYLE;
import static weka.classifiers.lazy.AM.AMUtils.*;

public class GangEffectsFormatter {
private static final CellStyle SUBHEADER_STYLE = new CellStyle.Builder().setBorderTop(true).setBorderBottom(true).build();
private static final CellStyle SUBHEADER_STYLE = new CellStyle.Builder().setBorderTop(true).setBorderBottom(true).build();

private final int numDecimals;
private final int numDecimals;
private Format format;
private final String lineSeparator;

/**
* @param numDecimals the number of digits to output after the decimal point
*/
public GangEffectsFormatter(int numDecimals, Format format) {
public GangEffectsFormatter(int numDecimals, Format format, String lineSeparator) {
this.numDecimals = numDecimals;
this.format = format;
this.lineSeparator = lineSeparator;
}

/**
* Format the provided gang effects nicely for human consumption; the returned string is <em>not</em> intended to be machine-readable.
* Format the provided gang effects using the specified format.
*/
public String formatGangs(AMResults results) {
Labeler labeler = results.getLabeler();
BigDecimal totalPointers = new BigDecimal(results.getTotalPointers());

TableSection.Builder bodyBuilder = new TableSection.Builder(); // 🏋️
for (GangEffect effect : results.getGangEffects()) {
// Subcontext header
bodyBuilder.addRow(getSubcontextHeader(labeler, totalPointers, effect));
effect.getClassToPointers().entrySet().stream().
// sort by count then alphabetically by class name
sorted(
Map.Entry.<String, BigInteger>comparingByValue(Comparator.reverseOrder()).
thenComparing(Map.Entry.comparingByKey())).
forEach(classToPointers -> {
Set<Instance> instances = effect.getClassToInstances().get(classToPointers.getKey());

// Class header
bodyBuilder.addRow(getClassHeader(classToPointers.getKey(), classToPointers.getValue(), totalPointers, instances.size()));

// sort and print instances
instances.stream().map(labeler::getInstanceAttsString).sorted().forEach(s -> bodyBuilder.addRow("", "", "", "", s));
});
}

switch(format) {
switch (format) {
case HUMAN: {
return new Table.Builder().
setTableStyle(
new TableStyle.Builder().
setBorder(true).build()).
setCellStyle(
REPORT_TABLE_STYLE
).setHeader(
new TableSection.Builder().
addRow(
"Percentage", "Pointers", "Num Items", "Class", "Context").build())
.setBody(bodyBuilder.build())
.build()
.toString();
return getHumanFormatted(results);
}
case CSV: {
return "TODO";
CsvDoc doc = getCsvDoc(results);
CSVFormat csvFormat = CSVFormat.DEFAULT.builder().setRecordSeparator(lineSeparator).setHeader(doc.headers.toArray(new String[]{})).build();
StringWriter sw = new StringWriter();
try (final CSVPrinter printer = new CSVPrinter(sw, csvFormat)) {
for (List<String> entry : doc.entries) {
printer.printRecord(entry);
}
} catch (IOException e) {
return "Error printing results to CSV: " + e;
}
return sw.toString();
}
default: {
throw new IllegalStateException("Unknown format " + format.getOptionString());
}
}
}

private Row getClassHeader(String className, BigInteger classPointers, BigDecimal totalPointers, int numInstances) {
return new Row.Builder().
addCell(AMUtils.formatPointerPercentage(classPointers, totalPointers, numDecimals)).
addCell(classPointers.toString()).
addCell(Integer.toString(numInstances)).
addCell(className).
addCell("").
build();
}

@NotNull
private Row getSubcontextHeader(Labeler labeler, BigDecimal totalPointers, GangEffect effect) {
return formatSubheaderRow(
AMUtils.formatPointerPercentage(effect.getTotalPointers(), totalPointers, numDecimals),
effect.getTotalPointers().toString(),
"" + effect.getSubcontext().getExemplars().size(),
"",
labeler.getContextString(effect.getSubcontext().getLabel()));
}

private static Row formatSubheaderRow(String... content) {
Row.Builder row = new Row.Builder();
for (String c : content) {
row.addCell(subheaderCell(c));
}
return row.build();
}

private static Cell subheaderCell(String content) {
return new Cell.Builder(content).setStyle(SUBHEADER_STYLE).build();
}
/**
 * Renders the gang effects as a bordered text table for human readers.
 * <p>
 * Each gang contributes one subcontext header row, followed, for every class in the gang,
 * by a class header row and one row per member instance. Classes are ordered by pointer
 * count (largest first) and then by class name; instance rows are ordered by their
 * attribute string.
 *
 * @param results the classification results whose gang effects should be printed
 * @return the rendered table
 */
@NotNull
private String getHumanFormatted(AMResults results) {
    final Labeler labeler = results.getLabeler();
    final BigDecimal totalPointers = new BigDecimal(results.getTotalPointers());

    // most pointers first; ties broken alphabetically by class name
    final Comparator<Map.Entry<String, BigInteger>> classOrder =
        Map.Entry.<String, BigInteger>comparingByValue(Comparator.reverseOrder()).
            thenComparing(Map.Entry.comparingByKey());

    final TableSection.Builder body = new TableSection.Builder();
    for (GangEffect gang : results.getGangEffects()) {
        // Subcontext header
        body.addRow(getSubcontextHeader(labeler, totalPointers, gang));

        List<Map.Entry<String, BigInteger>> classEntries = new ArrayList<>(gang.getClassToPointers().entrySet());
        classEntries.sort(classOrder);
        for (Map.Entry<String, BigInteger> classEntry : classEntries) {
            String className = classEntry.getKey();
            Set<Instance> members = gang.getClassToInstances().get(className);

            // Class header
            body.addRow(getClassHeader(className, classEntry.getValue(), totalPointers, members.size()));

            // one row per member instance, sorted by attribute string
            List<String> memberLabels = new ArrayList<>(members.size());
            for (Instance member : members) {
                memberLabels.add(labeler.getInstanceAttsString(member));
            }
            Collections.sort(memberLabels);
            for (String memberLabel : memberLabels) {
                body.addRow("", "", "", "", memberLabel);
            }
        }
    }

    TableStyle borderedStyle = new TableStyle.Builder().setBorder(true).build();
    TableSection header = new TableSection.Builder().
        addRow("Percentage", "Pointers", "Num Items", "Class", "Context").
        build();

    return new Table.Builder().
        setTableStyle(borderedStyle).
        setCellStyle(REPORT_TABLE_STYLE).
        setHeader(header).
        setBody(body.build()).
        build().
        toString();
}

/**
 * Builds the table row that summarizes one class within a gang.
 *
 * @param className     name of the class
 * @param classPointers number of pointers this class has within the gang
 * @param totalPointers total pointer count used as the percentage denominator
 * @param numInstances  number of instances of this class within the gang
 * @return a five-cell row: percentage, pointers, instance count, class name, and an empty context cell
 */
private Row getClassHeader(String className, BigInteger classPointers, BigDecimal totalPointers, int numInstances) {
    String percentage = formatPointerPercentage(classPointers, totalPointers, numDecimals, true);

    Row.Builder header = new Row.Builder();
    header.addCell(percentage);
    header.addCell(classPointers.toString());
    header.addCell(Integer.toString(numInstances));
    header.addCell(className);
    // the context column is left blank on class rows
    header.addCell("");
    return header.build();
}

/**
 * Builds the sub-header row that introduces one gang (subcontext).
 *
 * @param labeler       used to render the subcontext's label as a context string
 * @param totalPointers total pointer count used as the percentage denominator
 * @param effect        the gang to describe
 * @return a five-cell sub-header row: percentage, pointers, exemplar count, an empty class cell, and the context
 */
@NotNull
private Row getSubcontextHeader(Labeler labeler, BigDecimal totalPointers, GangEffect effect) {
    BigInteger gangPointers = effect.getTotalPointers();
    String percentage = formatPointerPercentage(gangPointers, totalPointers, numDecimals, true);
    String numExemplars = String.valueOf(effect.getSubcontext().getExemplars().size());
    String context = labeler.getContextString(effect.getSubcontext().getLabel());

    // the class column is left blank on subcontext rows
    return formatSubheaderRow(percentage, gangPointers.toString(), numExemplars, "", context);
}

/**
 * Builds a row in which every cell carries the sub-header style.
 *
 * @param content the cell texts, in column order
 * @return the assembled row
 */
private static Row formatSubheaderRow(String... content) {
    Row.Builder subheader = new Row.Builder();
    Arrays.stream(content).
        map(GangEffectsFormatter::subheaderCell).
        forEach(cell -> subheader.addCell(cell));
    return subheader.build();
}

/**
 * Wraps the given text in a cell styled with {@code SUBHEADER_STYLE}.
 *
 * @param content the text to display in the cell
 * @return the styled cell
 */
private static Cell subheaderCell(String content) {
    Cell.Builder styledCell = new Cell.Builder(content);
    styledCell.setStyle(SUBHEADER_STYLE);
    return styledCell.build();
}

/**
 * Builds the CSV representation of the gang effects: one row per exemplar in each gang.
 * <p>
 * Every row carries:
 * <ul>
 *   <li>gang-level columns: {@code rank}, {@code total_ptrs}, {@code gang_ptrs}, {@code gang_pct},
 *       {@code size}, and the gang's context as {@code GF1..GFn} (context entries rendered by the
 *       labeler with {@code "*"});</li>
 *   <li>for each class present in the gang, {@code <class>_ptrs}, {@code <class>_pct} and
 *       {@code <class>_size}; these columns default to {@code "0"}/{@code "0.0"}/{@code "0"} for
 *       gangs that do not contain the class;</li>
 *   <li>row-level columns: {@code class} (the exemplar's class) and the exemplar's attributes as
 *       {@code F1..Fn}.</li>
 * </ul>
 * Rows are emitted gang by gang; within a gang, classes are ordered by pointer count (largest
 * first) and then by class name.
 * <p>
 * {@code rank} increases by one whenever a gang's pointer total differs from the previous gang's,
 * so gangs with equal totals share a rank. NOTE(review): this presumes
 * {@code results.getGangEffects()} is already ordered by pointer total; confirm against AMResults.
 *
 * @param results the classification results whose gang effects should be exported
 * @return the assembled CSV document
 */
private CsvDoc getCsvDoc(AMResults results) {
    CsvBuilder builder = new CsvBuilder();
    Labeler labeler = results.getLabeler();
    BigInteger totalPointers = results.getTotalPointers();
    // percentage denominator; converted once rather than for every formatted value
    BigDecimal totalPointersDecimal = new BigDecimal(totalPointers);
    int rank = 0;
    BigInteger previousPointers = null;
    for (GangEffect effect : results.getGangEffects()) {
        BigInteger totalEffectPointers = effect.getTotalPointers();
        if (!totalEffectPointers.equals(previousPointers)) {
            rank += 1;
            previousPointers = totalEffectPointers;
        }

        // Data shared by every row of this gang
        Map<String, String> gangRowData = new HashMap<>();
        gangRowData.put("rank", Integer.toString(rank));
        gangRowData.put("total_ptrs", totalPointers.toString());
        gangRowData.put("gang_ptrs", totalEffectPointers.toString());
        gangRowData.put("gang_pct", formatPointerPercentage(totalEffectPointers, totalPointersDecimal, numDecimals, false));
        gangRowData.put("size", Integer.toString(effect.getSubcontext().getExemplars().size()));

        // The context depends only on the gang, so it is looked up once per gang
        List<String> contextLabelList = labeler.getContextList(effect.getSubcontext().getLabel(), "*");
        for (int i = 0; i < contextLabelList.size(); i++) {
            gangRowData.put("GF" + (i + 1), contextLabelList.get(i));
        }

        // sort by count then alphabetically by class name
        List<Map.Entry<String, BigInteger>> sortedClasses = new ArrayList<>(effect.getClassToPointers().entrySet());
        sortedClasses.sort(
            Map.Entry.<String, BigInteger>comparingByValue(Comparator.reverseOrder()).
                thenComparing(Map.Entry.comparingByKey()));

        // First pass: record the per-class columns for ALL classes of the gang before any row is
        // emitted. Filling them while emitting rows made rows of later classes contain the columns
        // of earlier classes but not vice versa, so rows of one gang disagreed with each other.
        for (Map.Entry<String, BigInteger> classToPointers : sortedClasses) {
            String className = classToPointers.getKey();
            BigInteger classPointers = classToPointers.getValue();
            Set<Instance> instances = effect.getClassToInstances().get(className);

            String classPtrsColumn = className + "_ptrs";
            gangRowData.put(classPtrsColumn, classPointers.toString());
            builder.setDefault(classPtrsColumn, "0");

            String classPctColumn = className + "_pct";
            gangRowData.put(classPctColumn, formatPointerPercentage(classPointers, totalPointersDecimal, numDecimals, false));
            builder.setDefault(classPctColumn, "0.0");

            String classNumInstancesColumn = className + "_size";
            gangRowData.put(classNumInstancesColumn, Integer.toString(instances.size()));
            builder.setDefault(classNumInstancesColumn, "0");
        }

        // Second pass: one row per exemplar, in the same class order as before
        for (Map.Entry<String, BigInteger> classToPointers : sortedClasses) {
            String className = classToPointers.getKey();
            for (Instance instance : effect.getClassToInstances().get(className)) {
                Map<String, String> rowData = new HashMap<>(gangRowData);
                rowData.put("class", className);
                List<String> atts = labeler.getInstanceAttsList(instance);
                for (int i = 0; i < atts.size(); i++) {
                    rowData.put("F" + (i + 1), atts.get(i));
                }
                builder.addEntry(rowData);
            }
        }
    }

    return builder.build();
}
}
Loading

0 comments on commit e3a48dc

Please sign in to comment.