diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..485dee6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+.idea
diff --git a/bayesian.go b/bayesian.go
index a5ebac3..b74e79c 100644
--- a/bayesian.go
+++ b/bayesian.go
@@ -2,8 +2,10 @@ package bayesian
 
 import (
 	"encoding/gob"
+	"encoding/json"
 	"errors"
 	"io"
+	"log"
 	"math"
 	"os"
 	"path/filepath"
@@ -14,6 +16,15 @@ import (
 // we have not seen before appears in the class.
const defaultProb = 0.00000000001
 
+// Serializer selects the encoding used to persist a learned model.
+// Gob is the default.
+type Serializer string
+
+const (
+	Gob  Serializer = "gob"
+	JSON Serializer = "json"
+)
+
 // ErrUnderflow is returned when an underflow is detected.
 var ErrUnderflow = errors.New("possible underflow detected")
 
@@ -44,12 +55,12 @@ type Classifier struct {
 // Classifier objects whose fields are modifiable by
 // reflection and are therefore writeable by gob.
 type serializableClassifier struct {
-	Classes         []Class
-	Learned         int
-	Seen            int
-	Datas           map[Class]*classData
-	TfIdf           bool
-	DidConvertTfIdf bool
+	Classes         []Class              `json:"c,omitempty"`
+	Learned         float64              `json:"l,omitempty"`
+	Seen            float64              `json:"s,omitempty"`
+	Datas           map[Class]*classData `json:"d,omitempty"`
+	TfIdf           bool                 `json:"t,omitempty"`
+	DidConvertTfIdf bool                 `json:"dc,omitempty"`
 }
 
 // classData holds the frequency data for words in a
@@ -57,9 +68,9 @@ type serializableClassifier struct {
 // structure with a trie-like structure for more
 // efficient storage.
 type classData struct {
-	Freqs   map[string]float64
-	FreqTfs map[string][]float64
-	Total   int
+	Freqs   map[string]float64   `json:"r,omitempty"`
+	FreqTfs map[string][]float64 `json:"ft,omitempty"`
+	Total   float64              `json:"t,omitempty"`
 }
 
 // newClassData creates a new empty classData node.
@@ -160,23 +171,52 @@ func NewClassifier(classes ...Class) (c *Classifier) {
 // NewClassifierFromFile loads an existing classifier from
 // file. The classifier was previously saved with a call
 // to c.WriteToFile(string).
-func NewClassifierFromFile(name string) (c *Classifier, err error) {
+func NewClassifierFromFile(name string, s ...Serializer) (c *Classifier, err error) {
 	file, err := os.Open(name)
 	if err != nil {
 		return nil, err
 	}
-	defer file.Close()
+	defer func() {
+		err := file.Close()
+		if err != nil {
+			log.Print(err)
+		}
+	}()
 
-	return NewClassifierFromReader(file)
+	return NewClassifierFromReader(file, s...)
 }
 
-// NewClassifierFromReader: This actually does the deserializing of a Gob encoded classifier
-func NewClassifierFromReader(r io.Reader) (c *Classifier, err error) {
-	dec := gob.NewDecoder(r)
+// NewClassifierFromReader reads and deserializes a gob- or JSON-encoded classifier from r.
+func NewClassifierFromReader(r io.Reader, s ...Serializer) (c *Classifier, err error) {
+	var ser Serializer
+	if len(s) == 1 {
+		ser = s[0]
+	}
+
 	w := new(serializableClassifier)
-	err = dec.Decode(w)
+	if ser == JSON {
+		dec := json.NewDecoder(r)
+		err = dec.Decode(w)
+		if err != nil {
+			return
+		}
+	} else {
+		dec := gob.NewDecoder(r)
+		err = dec.Decode(w)
+		if err != nil {
+			return
+		}
+	}
 
-	return &Classifier{w.Classes, w.Learned, int32(w.Seen), w.Datas, w.TfIdf, w.DidConvertTfIdf}, err
+	c = &Classifier{
+		w.Classes,
+		int(w.Learned),
+		int32(w.Seen),
+		w.Datas,
+		w.TfIdf,
+		w.DidConvertTfIdf,
+	}
+	return
 }
 
 // getPriors returns the prior probabilities for the
@@ -190,8 +230,8 @@ func (c *Classifier) getPriors() (priors []float64) {
 	sum := 0
 	for index, class := range c.Classes {
 		total := c.datas[class].Total
-		priors[index] = float64(total)
-		sum += total
+		priors[index] = total
+		sum += int(total)
 	}
 	if sum != 0 {
 		for i := 0; i < n; i++ {
@@ -213,7 +253,6 @@ func (c *Classifier) Seen() int {
 	return int(atomic.LoadInt32(&c.seen))
 }
 
-
 // IsTfIdf returns true if we are a classifier of type TfIdf
 func (c *Classifier) IsTfIdf() bool {
 	return c.tfIdf
@@ -225,7 +264,7 @@ func (c *Classifier) WordCount() (result []int) {
 	result = make([]int, len(c.Classes))
 	for inx, class := range c.Classes {
 		data := c.datas[class]
-		result[inx] = data.Total
+		result[inx] = int(data.Total)
 	}
 	return
 }
@@ -233,9 +272,10 @@ func (c *Classifier) WordCount() (result []int) {
 // Observe should be used when word-frequencies have been already been learned
 // externally (e.g., hadoop)
 func (c *Classifier) Observe(word string, count int, which Class) {
+	cnt := float64(count)
 	data := c.datas[which]
-	data.Freqs[word] += float64(count)
-	data.Total += count
+	data.Freqs[word] += cnt
+	data.Total += cnt
 }
 
 // Learn will accept new training documents for
@@ -292,7 +332,7 @@ func (c *Classifier) ConvertTermsFreqToTfIdf() {
 			// we always want a possitive TF-IDF score.
 			tf := c.datas[className].FreqTfs[wIndex][tfSampleIndex]
 
-			c.datas[className].FreqTfs[wIndex][tfSampleIndex] = math.Log1p(tf) * math.Log1p(float64(c.learned)/float64(c.datas[className].Total))
+			c.datas[className].FreqTfs[wIndex][tfSampleIndex] = math.Log1p(tf) * math.Log1p(float64(c.learned)/c.datas[className].Total)
 			tfIdfAdder += c.datas[className].FreqTfs[wIndex][tfSampleIndex]
 		}
 		// convert the 'counts' to TF-IDF's
@@ -468,70 +508,121 @@ func (c *Classifier) WordFrequencies(words []string) (freqMatrix [][]float64) {
 func (c *Classifier) WordsByClass(class Class) (freqMap map[string]float64) {
 	freqMap = make(map[string]float64)
 	for word, cnt := range c.datas[class].Freqs {
-		freqMap[word] = float64(cnt) / float64(c.datas[class].Total)
+		freqMap[word] = cnt / c.datas[class].Total
 	}
 	return freqMap
 }
 
-
 // WriteToFile serializes this classifier to a file.
-func (c *Classifier) WriteToFile(name string) (err error) {
+func (c *Classifier) WriteToFile(name string, s ...Serializer) (err error) {
 	file, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE, 0644)
 	if err != nil {
 		return err
 	}
-	defer file.Close()
+	defer func() {
+		err := file.Close()
+		if err != nil {
+			log.Print(err)
+		}
+	}()
 
-	return c.WriteTo(file)
+	return c.WriteTo(file, s...)
 }
 
 // WriteClassesToFile writes all classes to files.
-func (c *Classifier) WriteClassesToFile(rootPath string) (err error) {
+func (c *Classifier) WriteClassesToFile(rootPath string, s ...Serializer) (err error) {
 	for name := range c.datas {
-		c.WriteClassToFile(name, rootPath)
+		err = c.WriteClassToFile(name, rootPath, s...)
+		if err != nil {
+			return
+		}
 	}
 	return
 }
 
 // WriteClassToFile writes a single class to file.
-func (c *Classifier) WriteClassToFile(name Class, rootPath string) (err error) {
+func (c *Classifier) WriteClassToFile(name Class, rootPath string, s ...Serializer) (err error) {
+	var ser Serializer
+	if len(s) == 1 {
+		ser = s[0]
+	}
+
 	data := c.datas[name]
 	fileName := filepath.Join(rootPath, string(name))
 	file, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE, 0644)
 	if err != nil {
 		return err
 	}
-	defer file.Close()
+	defer func() {
+		err := file.Close()
+		if err != nil {
+			log.Print(err)
+		}
+	}()
+
+	if ser == JSON {
+		enc := json.NewEncoder(file)
+		return enc.Encode(data)
+	}
 
 	enc := gob.NewEncoder(file)
-	err = enc.Encode(data)
-	return
+	return enc.Encode(data)
 }
 
+// WriteTo serializes this classifier as gob (default) or JSON and writes it to w.
+func (c *Classifier) WriteTo(w io.Writer, s ...Serializer) (err error) {
+	var ser Serializer
+	if len(s) == 1 {
+		ser = s[0]
+	}
 
-// WriteTo serializes this classifier to GOB and write to Writer.
-func (c *Classifier) WriteTo(w io.Writer) (err error) {
-	enc := gob.NewEncoder(w)
-	err = enc.Encode(&serializableClassifier{c.Classes, c.learned, int(c.seen), c.datas, c.tfIdf, c.DidConvertTfIdf})
+	data := &serializableClassifier{
+		c.Classes,
+		float64(c.learned),
+		float64(c.seen),
+		c.datas,
+		c.tfIdf,
+		c.DidConvertTfIdf,
+	}
+	if ser == JSON {
+		enc := json.NewEncoder(w)
+		return enc.Encode(data)
+	}
 
-	return
+	enc := gob.NewEncoder(w)
+	return enc.Encode(data)
 }
 
 // ReadClassFromFile loads existing class data from a
 // file.
-func (c *Classifier) ReadClassFromFile(class Class, location string) (err error) {
+func (c *Classifier) ReadClassFromFile(class Class, location string, s ...Serializer) (err error) {
+	var ser Serializer
+	if len(s) == 1 {
+		ser = s[0]
+	}
+
 	fileName := filepath.Join(location, string(class))
 	file, err := os.Open(fileName)
 	if err != nil {
 		return err
 	}
-	defer file.Close()
+	defer func() {
+		err := file.Close()
+		if err != nil {
+			log.Print(err)
+		}
+	}()
 
-	dec := gob.NewDecoder(file)
 	w := new(classData)
-	err = dec.Decode(w)
+	if ser == JSON {
+		dec := json.NewDecoder(file)
+		err = dec.Decode(w)
+	} else {
+		dec := gob.NewDecoder(file)
+		err = dec.Decode(w)
+	}
 
 	c.learned++
 	c.datas[class] = w
diff --git a/bayesian_test.go b/bayesian_test.go
index a5d6be1..8981f2f 100644
--- a/bayesian_test.go
+++ b/bayesian_test.go
@@ -11,7 +11,7 @@ const (
 
 func Assert(t *testing.T, condition bool, args ...interface{}) {
 	if !condition {
-		t.Fatal(args)
+		t.Fatal(args...)
 	}
 }
 
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..618ab2e
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,3 @@
+module github.com/jbrukh/bayesian
+
+go 1.14
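Usage sketch (not part of the diff): assuming the change above is applied to github.com/jbrukh/bayesian (the module path declared in the new go.mod), the variadic Serializer argument leaves existing gob call sites unchanged while letting callers opt into JSON. The class names, training words, and the model.json path below are purely illustrative.

```go
package main

import (
	"fmt"
	"log"

	"github.com/jbrukh/bayesian"
)

const (
	Good bayesian.Class = "good"
	Bad  bayesian.Class = "bad"
)

func main() {
	c := bayesian.NewClassifier(Good, Bad)
	c.Learn([]string{"tall", "handsome", "rich"}, Good)
	c.Learn([]string{"poor", "smelly", "ugly"}, Bad)

	// Persist with the new JSON serializer; omitting the argument
	// (or passing bayesian.Gob) keeps the existing gob behaviour.
	if err := c.WriteToFile("model.json", bayesian.JSON); err != nil {
		log.Fatal(err)
	}

	// Reload with the same serializer that was used to write the file.
	restored, err := bayesian.NewClassifierFromFile("model.json", bayesian.JSON)
	if err != nil {
		log.Fatal(err)
	}

	scores, likely, _ := restored.LogScores([]string{"tall", "girl"})
	fmt.Println(scores, restored.Classes[likely])
}
```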
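The per-class persistence helpers accept the same optional Serializer. A second hypothetical sketch: the ./model directory and the spam/ham classes are made up, and the directory is created up front because WriteClassToFile only creates files, not directories.

```go
package main

import (
	"log"
	"os"

	"github.com/jbrukh/bayesian"
)

const (
	Spam bayesian.Class = "spam"
	Ham  bayesian.Class = "ham"
)

func main() {
	c := bayesian.NewClassifier(Spam, Ham)
	c.Learn([]string{"buy", "cheap", "pills"}, Spam)
	c.Learn([]string{"meeting", "at", "noon"}, Ham)

	// Ensure the target directory exists before writing per-class files.
	if err := os.MkdirAll("model", 0755); err != nil {
		log.Fatal(err)
	}

	// One JSON file per class, named after the class, under ./model.
	if err := c.WriteClassesToFile("model", bayesian.JSON); err != nil {
		log.Fatal(err)
	}

	// Load a single class back; the serializer must match what was written.
	fresh := bayesian.NewClassifier(Spam, Ham)
	if err := fresh.ReadClassFromFile(Spam, "model", bayesian.JSON); err != nil {
		log.Fatal(err)
	}
}
```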