|
| 1 | +package gin.algorithm.nsgaii; |
| 2 | + |
| 3 | +import com.opencsv.CSVWriter; |
| 4 | +import com.sampullara.cli.Args; |
| 5 | +import com.sampullara.cli.Argument; |
| 6 | +import gin.Patch; |
| 7 | +import gin.SourceFile; |
| 8 | +import gin.edit.Edit; |
| 9 | +import gin.test.UnitTest; |
| 10 | +import gin.test.UnitTestResultSet; |
| 11 | +import gin.util.Sampler; |
| 12 | +import org.apache.commons.rng.simple.JDKRandomBridge; |
| 13 | +import org.apache.commons.rng.simple.RandomSource; |
| 14 | +import org.pmw.tinylog.Logger; |
| 15 | + |
| 16 | +import java.io.File; |
| 17 | +import java.io.FileWriter; |
| 18 | +import java.io.IOException; |
| 19 | +import java.util.*; |
| 20 | + |
| 21 | +public class NSGAII extends Sampler { |
| 22 | + |
| 23 | + private static final long serialVersionUID = 8547883760400442899L; |
| 24 | + |
| 25 | + @Argument(alias = "et", description = "Edit type: this can be a member of the EditType enum (LINE,STATEMENT,MATCHED_STATEMENT,MODIFY_STATEMENT); the fully qualified name of a class that extends gin.edit.Edit, or a comma separated list of both") |
| 26 | + protected String editType = Edit.EditType.STATEMENT.toString(); |
| 27 | + |
| 28 | + @Argument(alias = "gn", description = "Number of generations") |
| 29 | + protected Integer genNumber = 1; |
| 30 | + |
| 31 | + @Argument(alias = "in", description = "Number of individuals") |
| 32 | + protected Integer indNumber = 10; |
| 33 | + |
| 34 | + @Argument(alias = "ms", description = "Random seed for mutation operator selection") |
| 35 | + protected Integer mutationSeed = 123; |
| 36 | + |
| 37 | + @Argument(alias = "is", description = "Random seed for individual selection") |
| 38 | + protected Integer individualSeed = 123; |
| 39 | + |
| 40 | + // Allowed edit types for sampling: parsed from editType |
| 41 | + protected List<Class<? extends Edit>> editTypes; |
| 42 | + |
| 43 | + protected Random mutationRng; |
| 44 | + protected Random individualRng; |
| 45 | + |
| 46 | + private String className; |
| 47 | + protected String methodName; |
| 48 | + private float initTime; |
| 49 | + private long initMem; |
| 50 | + private List<UnitTest> tests; |
| 51 | + |
| 52 | + public NSGAII(String[] args) { |
| 53 | + super(args); |
| 54 | + Args.parseOrExit(this, args); |
| 55 | + setup(); |
| 56 | + printAdditionalArguments(); |
| 57 | + } |
| 58 | + |
| 59 | + // Constructor used for testing |
| 60 | + public NSGAII(File projectDir, File methodFile) { |
| 61 | + super(projectDir, methodFile); |
| 62 | + setup(); |
| 63 | + } |
| 64 | + |
| 65 | + |
| 66 | + public static void main(String[] args) { |
| 67 | + NSGAII sampler = new NSGAII(args); |
| 68 | + sampler.sampleMethods(); |
| 69 | + } |
| 70 | + |
| 71 | + |
| 72 | + private void printAdditionalArguments() { |
| 73 | + Logger.info("Edit types: "+ editTypes); |
| 74 | + Logger.info("Number of generations: "+ genNumber); |
| 75 | + Logger.info("Number of individuals: "+ indNumber); |
| 76 | + Logger.info("Random seed for mutation operator selection: "+ mutationSeed); |
| 77 | + Logger.info("Random seed for individual selection: "+ individualSeed); |
| 78 | + } |
| 79 | + |
| 80 | + private void setup() { |
| 81 | + mutationRng = new JDKRandomBridge(RandomSource.MT, Long.valueOf(mutationSeed)); |
| 82 | + individualRng = new JDKRandomBridge(RandomSource.MT, Long.valueOf(individualSeed)); |
| 83 | + editTypes = Edit.parseEditClassesFromString(editType); |
| 84 | + } |
| 85 | + |
| 86 | + // Implementation of gin.util.Sampler's abstract method |
| 87 | + protected void sampleMethodsHook() { |
| 88 | + |
| 89 | + if ((indNumber < 1) || (genNumber < 1)) { |
| 90 | + Logger.info("Please enter a positive number of generations and individuals."); |
| 91 | + } else { |
| 92 | + |
| 93 | + writeNewHeader(); |
| 94 | + |
| 95 | + for (TargetMethod method : methodData) { |
| 96 | + |
| 97 | + Logger.info("Running NSGAII on method " + method); |
| 98 | + |
| 99 | + // Setup SourceFile for patching |
| 100 | + SourceFile sourceFile = SourceFile.makeSourceFileForEditTypes(editTypes, method.getFileSource().getPath(), Collections.singletonList(method.getMethodName())); |
| 101 | + |
| 102 | + search(method, new Patch(sourceFile)); |
| 103 | + |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + } |
| 108 | + |
| 109 | + private void search(TargetMethod method, Patch origPatch) { |
| 110 | + |
| 111 | + |
| 112 | + className = method.getClassName(); |
| 113 | + methodName = method.toString(); |
| 114 | + tests = method.getGinTests(); |
| 115 | + |
| 116 | + // Run original code |
| 117 | + UnitTestResultSet initRes = initFitness(className, tests, origPatch); |
| 118 | + initMem = initRes.totalMemoryUsage(); |
| 119 | + initTime = initRes.totalExecutionTime() / 1000000.0f; |
| 120 | + writePatch(initRes, methodName); |
| 121 | + |
| 122 | + ArrayList<Integer> dirs = new ArrayList<>(); |
| 123 | + dirs.add(-1); |
| 124 | + dirs.add(-1); |
| 125 | + NSGAIIPop P = new NSGAIIPop(2, dirs); |
| 126 | + UnitTestResultSet resultSet = initRes; |
| 127 | + Logger.info("Generating initial generation"); |
| 128 | + for (int i = 0; i < indNumber; i++) { |
| 129 | + Patch patch = mutate(origPatch); |
| 130 | + resultSet = testPatch(className, tests, patch); |
| 131 | + writePatch(resultSet, methodName); |
| 132 | + ArrayList<Long> fitnesses = new ArrayList<>(); |
| 133 | + if (resultSet.allTestsSuccessful()) { |
| 134 | + fitnesses.add(resultSet.totalExecutionTime()); |
| 135 | + fitnesses.add(resultSet.totalMemoryUsage()); |
| 136 | + } else { |
| 137 | + fitnesses.add(Long.MAX_VALUE); |
| 138 | + fitnesses.add(Long.MAX_VALUE); |
| 139 | + } |
| 140 | + P.addInd(patch, fitnesses); |
| 141 | + |
| 142 | + } |
| 143 | + for (int g = 0; g < genNumber; g++) { |
| 144 | + Logger.info("Generating generation " + g); |
| 145 | + NSGAIIPop Q = NSGAIIOffspring(P, origPatch); |
| 146 | + NSGAIIPop R = new NSGAIIPop(P, Q); |
| 147 | + Logger.info("getting next generation"); |
| 148 | + ArrayList<Patch> patches = R.getNextGen(indNumber); |
| 149 | + P = new NSGAIIPop(2, dirs); |
| 150 | + for (Patch patch : patches) { |
| 151 | + Logger.info("Testing patch: " + patch); |
| 152 | + resultSet = testPatch(className, tests, patch); |
| 153 | + |
| 154 | + writePatch(resultSet, methodName); |
| 155 | + ArrayList<Long> fitnesses = new ArrayList<>(); |
| 156 | + if (resultSet.allTestsSuccessful()) { |
| 157 | + fitnesses.add(resultSet.totalExecutionTime()); |
| 158 | + fitnesses.add(resultSet.totalMemoryUsage()); |
| 159 | + } else { |
| 160 | + fitnesses.add(Long.MAX_VALUE); |
| 161 | + fitnesses.add(Long.MAX_VALUE); |
| 162 | + } |
| 163 | + P.addInd(patch, fitnesses); |
| 164 | + } |
| 165 | + } |
| 166 | + } |
| 167 | + |
| 168 | + public NSGAIIPop NSGAIIOffspring(NSGAIIPop pop, Patch origpatch){ |
| 169 | + Logger.info("Generating offspring"); |
| 170 | + ArrayList<NSGAInd> population = pop.getPopulation(); |
| 171 | + List<Patch> oldPatches = new ArrayList<>(); |
| 172 | + for (NSGAInd ind : population){ |
| 173 | + oldPatches.add(ind.getPatch()); |
| 174 | + } |
| 175 | + List<Patch> patches = new ArrayList<>(); |
| 176 | + //selection |
| 177 | + for (int i = 0; i < population.size(); i++){ |
| 178 | + NSGAInd ind1 = population.get(individualRng.nextInt(population.size())); |
| 179 | + NSGAInd ind2 = population.get(individualRng.nextInt(population.size())); |
| 180 | + if (ind1.getRank() < ind2.getRank()){ |
| 181 | + patches.add(ind1.getPatch().clone()); |
| 182 | + } |
| 183 | + if (ind1.getRank() > ind2.getRank()){ |
| 184 | + patches.add(ind2.getPatch().clone()); |
| 185 | + } |
| 186 | + else{ |
| 187 | + float coinFlip = mutationRng.nextFloat(); |
| 188 | + if(coinFlip < 0.5) { |
| 189 | + patches.add(ind1.getPatch().clone()); |
| 190 | + } |
| 191 | + else { |
| 192 | + patches.add(ind2.getPatch().clone()); |
| 193 | + } |
| 194 | + } |
| 195 | + } |
| 196 | + //crossover |
| 197 | + patches = crossover(patches, origpatch); |
| 198 | + //mutation |
| 199 | + List<Patch> mutatedPatches = new ArrayList<>(); |
| 200 | + for (Patch patch : patches){ |
| 201 | + if (mutationRng.nextFloat() < 0.5){ |
| 202 | + mutatedPatches.add(mutate(patch)); |
| 203 | + } |
| 204 | + } |
| 205 | + patches = mutatedPatches; |
| 206 | + ArrayList<Integer> dirs = new ArrayList<>(); |
| 207 | + dirs.add(-1); |
| 208 | + dirs.add(-1); |
| 209 | + NSGAIIPop Q = new NSGAIIPop(2, dirs); |
| 210 | + //fitness |
| 211 | + for (Patch patch: patches){ |
| 212 | + UnitTestResultSet resultSet; |
| 213 | + resultSet = testPatch(className, tests, patch); |
| 214 | + |
| 215 | + writePatch(resultSet, methodName); |
| 216 | + ArrayList<Long> fitnesses = new ArrayList<>(); |
| 217 | + if (resultSet.allTestsSuccessful()) { |
| 218 | + fitnesses.add(resultSet.totalExecutionTime()); |
| 219 | + fitnesses.add(resultSet.totalMemoryUsage()); |
| 220 | + } else { |
| 221 | + fitnesses.add(Long.MAX_VALUE); |
| 222 | + fitnesses.add(Long.MAX_VALUE); |
| 223 | + } |
| 224 | + Q.addInd(patch, fitnesses); |
| 225 | + } |
| 226 | + |
| 227 | + return Q; |
| 228 | + } |
| 229 | + protected UnitTestResultSet initFitness(String className, List<UnitTest> tests, Patch origPatch) { |
| 230 | + |
| 231 | + UnitTestResultSet results = testPatch(className, tests, origPatch); |
| 232 | + return results; |
| 233 | + } |
| 234 | + |
| 235 | + protected Patch mutate(Patch oldPatch) { |
| 236 | + Patch patch = oldPatch.clone(); |
| 237 | + patch.addRandomEditOfClasses(mutationRng, editTypes); |
| 238 | + return patch; |
| 239 | + } |
| 240 | + |
| 241 | + |
| 242 | + protected List<Patch> crossover(List<Patch> patches, Patch origPatch) { |
| 243 | + |
| 244 | + List<Patch> crossedPatches = new ArrayList<>(); |
| 245 | + |
| 246 | + Collections.shuffle(patches, mutationRng); |
| 247 | + int half = patches.size() / 2; |
| 248 | + for (int i = 0; i < half; i++) { |
| 249 | + |
| 250 | + Patch parent1 = patches.get(i); |
| 251 | + Patch parent2 = patches.get(i + half); |
| 252 | + List<Edit> list1 = parent1.getEdits(); |
| 253 | + List<Edit> list2 = parent2.getEdits(); |
| 254 | + |
| 255 | + Patch child1 = origPatch.clone(); |
| 256 | + Patch child2 = origPatch.clone(); |
| 257 | + |
| 258 | + for (int j = 0; j < list1.size(); j++) { |
| 259 | + if (mutationRng.nextFloat() > 0.5) { |
| 260 | + child1.add(list1.get(j)); |
| 261 | + } |
| 262 | + } |
| 263 | + for (int j = 0; j < list2.size(); j++) { |
| 264 | + if (mutationRng.nextFloat() > 0.5) { |
| 265 | + child1.add(list2.get(j)); |
| 266 | + } |
| 267 | + if (mutationRng.nextFloat() > 0.5) { |
| 268 | + child2.add(list2.get(j)); |
| 269 | + } |
| 270 | + } |
| 271 | + for (int j = 0; j < list1.size(); j++) { |
| 272 | + if (mutationRng.nextFloat() > 0.5) { |
| 273 | + child2.add(list1.get(j)); |
| 274 | + } |
| 275 | + } |
| 276 | + |
| 277 | + crossedPatches.add(parent1); |
| 278 | + crossedPatches.add(parent2); |
| 279 | + crossedPatches.add(child1); |
| 280 | + crossedPatches.add(child2); |
| 281 | + } |
| 282 | + |
| 283 | + return crossedPatches; |
| 284 | + } |
| 285 | + |
| 286 | + |
| 287 | + |
| 288 | + /*============== Helper methods ==============*/ |
| 289 | + |
| 290 | + protected void writeNewHeader() { |
| 291 | + String[] entry = {"MethodName" |
| 292 | + , "Patch" |
| 293 | + , "Compiled" |
| 294 | + , "AllTestsPassed" |
| 295 | + , "TotalExecutionTime(ms)" |
| 296 | + , "ExecutionTimeFitness" |
| 297 | + , "ExecutionTimeFitnessImprovement" |
| 298 | + , "MemoryFitness" |
| 299 | + , "MemoryFitnessImprovement" |
| 300 | + }; |
| 301 | + try { |
| 302 | + outputFileWriter = new CSVWriter(new FileWriter(outputFile)); |
| 303 | + outputFileWriter.writeNext(entry); |
| 304 | + } catch (IOException e) { |
| 305 | + Logger.error(e, "Exception writing results to the output file: " + outputFile.getAbsolutePath()); |
| 306 | + Logger.trace(e); |
| 307 | + System.exit(-1); |
| 308 | + } |
| 309 | + } |
| 310 | + |
| 311 | + protected void writePatch(UnitTestResultSet results, String methodName, double fitnessTime, double improvementTime, double fitnessMem, double improvementMem) { |
| 312 | + String[] entry = {methodName |
| 313 | + , results.getPatch().toString() |
| 314 | + , Boolean.toString(results.getCleanCompile()) |
| 315 | + , Boolean.toString(results.allTestsSuccessful()) |
| 316 | + , Float.toString(results.totalExecutionTime() / 1000000.0f) |
| 317 | + , Double.toString(fitnessTime) |
| 318 | + , Double.toString(improvementTime) |
| 319 | + , Double.toString(fitnessMem) |
| 320 | + , Double.toString(improvementMem) |
| 321 | + }; |
| 322 | + outputFileWriter.writeNext(entry); |
| 323 | + } |
| 324 | + |
| 325 | + protected void writePatch(UnitTestResultSet resultSet, String methodName) { |
| 326 | + float execTime = resultSet.totalExecutionTime() / 1000000.0f; |
| 327 | + if (execTime == 0 || ! resultSet.allTestsSuccessful()) execTime = Float.MAX_VALUE; |
| 328 | + long memoryUsage = resultSet.totalMemoryUsage(); |
| 329 | + if(memoryUsage==0 || ! resultSet.allTestsSuccessful())memoryUsage=Long.MAX_VALUE; |
| 330 | + writePatch(resultSet, methodName, execTime, initTime - execTime, memoryUsage, initMem - memoryUsage); |
| 331 | + } |
| 332 | + |
| 333 | + |
| 334 | + |
| 335 | + |
| 336 | +} |
0 commit comments