-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDomainClassifier.java
More file actions
130 lines (109 loc) · 3.66 KB
/
DomainClassifier.java
File metadata and controls
130 lines (109 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/*
* chatbot.component is added for Assignment 3 (Language Understanding)
*
* DomainClassifier.java is added for Assignment 3 (Language
* Understanding)
*/
package chatbot.component;
/**
 * Keyword-based domain classifier for the chatbot.
 *
 * <p>Each incoming message is scored against every known domain and the
 * label of the highest-scoring domain is returned. Scores and the list of
 * domains are echoed to {@code System.out} for inspection.
 */
public class DomainClassifier {

    /** Domain labels; a label's index is also the index of its score. */
    private static final String[] DOMAIN_DICTIONARY = {"Other", "Gym"};

    /** Position of the "Gym" label inside {@link #DOMAIN_DICTIONARY}. */
    private static final int GYM_INDEX = 1;

    /**
     * Lower-case keywords that each add one point to the Gym domain.
     * Matching is by substring, so a message containing "gyms" matches both
     * "gym" and "gyms" and earns two points.
     */
    private static final String[] GYM_DICTIONARY = {
        "gym", "gyms", "reps", "sets", "time", "muscle", "exercise", "location"
    };

    /**
     * Creates a classifier and prints the list of known domains to
     * {@code System.out}.
     */
    public DomainClassifier() {
        initializeDomainDictionary();
    }

    /**
     * Prints the known domains as {@code Domains: (Other, Gym)}.
     */
    private void initializeDomainDictionary() {
        printList("Domains", DOMAIN_DICTIONARY);
    }

    /**
     * Prints {@code label: (v0, v1, ...)} followed by a line break to
     * {@code System.out}.
     *
     * @param label  text printed before the opening parenthesis
     * @param values elements to print, separated by ", "
     */
    private static void printList(String label, Object[] values) {
        StringBuilder line = new StringBuilder(label).append(": (");
        for (int i = 0; i < values.length; i++) {
            if (i > 0) {
                line.append(", ");
            }
            line.append(values[i]);
        }
        System.out.println(line.append(")"));
    }

    /**
     * Calculates the given message's score for each domain. The chatbot
     * selects the domain with the <em>highest</em> score.
     *
     * <p>The initial score of each domain is 0.0. Only the Gym domain has
     * keywords, so the "Other" score always stays at 0.0.
     *
     * @param nowInputText an English message sent from the user; {@code null}
     *                     is treated as an empty message
     * @return a Double array that contains the score of each domain, in the
     *         same order as the domain labels
     */
    private Double[] calculateDomainScores(String nowInputText) {
        // Initialise all the scores to 0.0.
        Double[] scoreArray = new Double[DOMAIN_DICTIONARY.length];
        for (int i = 0; i < scoreArray.length; i++) {
            scoreArray[i] = Double.valueOf(0.0);
        }

        // Normalise once, outside the keyword loop. Locale.ROOT keeps the
        // result independent of the JVM default locale (under a Turkish
        // locale "TIME" would otherwise lower-case to "tıme" and never
        // match the keyword "time").
        String normalizedText = nowInputText == null
                ? ""
                : nowInputText.toLowerCase(java.util.Locale.ROOT);

        // Count keywords from the Gym dictionary.
        for (String gymKeyword : GYM_DICTIONARY) {
            if (normalizedText.contains(gymKeyword)) {
                scoreArray[GYM_INDEX] = scoreArray[GYM_INDEX].doubleValue() + 1.0;
            }
        }

        // Sanity checks before returning the scoreArray.
        if (scoreArray.length != DOMAIN_DICTIONARY.length) {
            System.err.println("The score array size does not equal to the domain array size.");
            System.exit(1);
        }
        for (Double nowValue : scoreArray) {
            if (nowValue == null) {
                System.err.println("The score array contains null values.");
                System.exit(1);
            }
        }
        return scoreArray;
    }

    /**
     * Classifies a user message and returns the name of its domain.
     *
     * <p>The per-domain scores are printed to {@code System.out} as
     * {@code Domain Scores: (s0, s1, ...)}. When several domains share the
     * highest score, the one listed first wins, so a message without any
     * keyword is labelled "Other".
     *
     * @param nowInputText an English message sent from the user; {@code null}
     *                     is treated as an empty message
     * @return the name of the domain
     */
    public String getLabel(String nowInputText) {
        Double[] intentScores = calculateDomainScores(nowInputText);
        printList("Domain Scores", intentScores);

        // Strict '>' keeps the earliest index on ties.
        int nowMaxIndex = 0;
        for (int i = 1; i < intentScores.length; i++) {
            if (intentScores[i].doubleValue() > intentScores[nowMaxIndex].doubleValue()) {
                nowMaxIndex = i;
            }
        }
        return DOMAIN_DICTIONARY[nowMaxIndex];
    }
}