Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.idea
179 changes: 135 additions & 44 deletions bayesian.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package bayesian

import (
"encoding/gob"
"encoding/json"
"errors"
"io"
"log"
"math"
"os"
"path/filepath"
Expand All @@ -14,6 +16,15 @@ import (
// we have not seen before appears in the class.
const defaultProb = 0.00000000001

// Serializer selects the on-disk encoding used when a classifier
// (or a single class's data) is persisted and restored.
// Gob is the default when no serializer is supplied.
type Serializer string

const (
	// Gob encodes with encoding/gob (the default, historical format).
	Gob Serializer = "gob"
	// JSON encodes with encoding/json, using the compact field tags
	// declared on the serializable structs.
	JSON Serializer = "json"
)

// ErrUnderflow is returned when an underflow is detected.
var ErrUnderflow = errors.New("possible underflow detected")

Expand Down Expand Up @@ -44,22 +55,22 @@ type Classifier struct {
// serializableClassifier mirrors Classifier with exported fields that
// are modifiable by reflection and therefore writeable by gob and json.
// The one-letter json tags keep the JSON form compact; omitempty drops
// zero-valued fields. Learned and Seen are stored as float64 so both
// encodings round-trip them uniformly (converted back to ints on load).
type serializableClassifier struct {
	Classes         []Class              `json:"c,omitempty"`
	Learned         float64              `json:"l,omitempty"`
	Seen            float64              `json:"s,omitempty"`
	Datas           map[Class]*classData `json:"d,omitempty"`
	TfIdf           bool                 `json:"t,omitempty"`
	DidConvertTfIdf bool                 `json:"dc,omitempty"`
}

// classData holds the frequency data for words in a
// particular class. In the future, we may replace this
// structure with a trie-like structure for more
// efficient storage.
type classData struct {
	Freqs   map[string]float64   `json:"r,omitempty"`  // word -> summed frequency within the class
	FreqTfs map[string][]float64 `json:"ft,omitempty"` // word -> per-document term frequencies (TF-IDF)
	Total   float64              `json:"t,omitempty"`  // sum of all frequencies in Freqs
}

// newClassData creates a new empty classData node.
Expand Down Expand Up @@ -160,23 +171,52 @@ func NewClassifier(classes ...Class) (c *Classifier) {
// NewClassifierFromFile loads an existing classifier from
// file. The classifier was previously saved with a call
// to c.WriteToFile(string).
func NewClassifierFromFile(name string) (c *Classifier, err error) {
func NewClassifierFromFile(name string, s ...Serializer) (c *Classifier, err error) {
file, err := os.Open(name)
if err != nil {
return nil, err
}
defer file.Close()
defer func() {
err := file.Close()
if err != nil {
log.Print(err)
}
}()

return NewClassifierFromReader(file)
return NewClassifierFromReader(file, s...)
}

// NewClassifierFromReader: This actually does the deserializing of a Gob encoded classifier
func NewClassifierFromReader(r io.Reader) (c *Classifier, err error) {
dec := gob.NewDecoder(r)
// NewClassifierFromReader: This actually does the deserializing of a Gob/JSON encoded classifier
func NewClassifierFromReader(r io.Reader, s ...Serializer) (c *Classifier, err error) {
var ser Serializer
if len(s) == 1 {
ser = s[0]
}

w := new(serializableClassifier)
err = dec.Decode(w)
if ser == JSON {
dec := json.NewDecoder(r)
err = dec.Decode(w)
if err != nil {
return
}
} else {
dec := gob.NewDecoder(r)
err = dec.Decode(w)
if err != nil {
return
}
}

return &Classifier{w.Classes, w.Learned, int32(w.Seen), w.Datas, w.TfIdf, w.DidConvertTfIdf}, err
c = &Classifier{
w.Classes,
int(w.Learned),
int32(w.Seen),
w.Datas,
w.TfIdf,
w.DidConvertTfIdf,
}
return
}

// getPriors returns the prior probabilities for the
Expand All @@ -190,8 +230,8 @@ func (c *Classifier) getPriors() (priors []float64) {
sum := 0
for index, class := range c.Classes {
total := c.datas[class].Total
priors[index] = float64(total)
sum += total
priors[index] = total
sum += int(total)
}
if sum != 0 {
for i := 0; i < n; i++ {
Expand All @@ -213,7 +253,6 @@ func (c *Classifier) Seen() int {
return int(atomic.LoadInt32(&c.seen))
}


// IsTfIdf returns true if we are a classifier of type TfIdf
func (c *Classifier) IsTfIdf() bool {
return c.tfIdf
Expand All @@ -225,17 +264,18 @@ func (c *Classifier) WordCount() (result []int) {
result = make([]int, len(c.Classes))
for inx, class := range c.Classes {
data := c.datas[class]
result[inx] = data.Total
result[inx] = int(data.Total)
}
return
}

// Observe should be used when word-frequencies have been already been learned
// externally (e.g., hadoop)
func (c *Classifier) Observe(word string, count int, which Class) {
	weight := float64(count)
	d := c.datas[which]
	d.Freqs[word] += weight
	d.Total += weight
}

// Learn will accept new training documents for
Expand Down Expand Up @@ -292,7 +332,7 @@ func (c *Classifier) ConvertTermsFreqToTfIdf() {

// we always want a possitive TF-IDF score.
tf := c.datas[className].FreqTfs[wIndex][tfSampleIndex]
c.datas[className].FreqTfs[wIndex][tfSampleIndex] = math.Log1p(tf) * math.Log1p(float64(c.learned)/float64(c.datas[className].Total))
c.datas[className].FreqTfs[wIndex][tfSampleIndex] = math.Log1p(tf) * math.Log1p(float64(c.learned)/c.datas[className].Total)
tfIdfAdder += c.datas[className].FreqTfs[wIndex][tfSampleIndex]
}
// convert the 'counts' to TF-IDF's
Expand Down Expand Up @@ -468,70 +508,121 @@ func (c *Classifier) WordFrequencies(words []string) (freqMatrix [][]float64) {
// WordsByClass returns, for every word seen in class, its relative
// frequency: the word's count divided by the class's total count.
func (c *Classifier) WordsByClass(class Class) (freqMap map[string]float64) {
	data := c.datas[class]
	freqMap = make(map[string]float64, len(data.Freqs))
	for word, count := range data.Freqs {
		freqMap[word] = count / data.Total
	}
	return freqMap
}


// WriteToFile serializes this classifier to a file.
func (c *Classifier) WriteToFile(name string) (err error) {
func (c *Classifier) WriteToFile(name string, s ...Serializer) (err error) {
file, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return err
}
defer file.Close()
defer func() {
err := file.Close()
if err != nil {
log.Print(err)
}
}()

return c.WriteTo(file)
return c.WriteTo(file, s...)
}

// WriteClassesToFile writes all classes to files.
func (c *Classifier) WriteClassesToFile(rootPath string) (err error) {
func (c *Classifier) WriteClassesToFile(rootPath string, s ...Serializer) (err error) {
for name := range c.datas {
c.WriteClassToFile(name, rootPath)
err = c.WriteClassToFile(name, rootPath, s...)
if err != nil {
return
}
}
return
}

// WriteClassToFile writes a single class to file.
func (c *Classifier) WriteClassToFile(name Class, rootPath string) (err error) {
func (c *Classifier) WriteClassToFile(name Class, rootPath string, s ...Serializer) (err error) {
var ser Serializer
if len(s) == 1 {
ser = s[0]
}

data := c.datas[name]
fileName := filepath.Join(rootPath, string(name))
file, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE, 0644)
if err != nil {
return err
}
defer file.Close()
defer func() {
err := file.Close()
if err != nil {
log.Print(err)
}
}()

if ser == JSON {
enc := json.NewEncoder(file)
return enc.Encode(data)
}

enc := gob.NewEncoder(file)
err = enc.Encode(data)
return
return enc.Encode(data)
}

// WriteTo serializes this classifier to JSON and write to Writer.
func (c *Classifier) WriteTo(w io.Writer, s ...Serializer) (err error) {
var ser Serializer
if len(s) == 1 {
ser = s[0]
}

// WriteTo serializes this classifier to GOB and write to Writer.
func (c *Classifier) WriteTo(w io.Writer) (err error) {
enc := gob.NewEncoder(w)
err = enc.Encode(&serializableClassifier{c.Classes, c.learned, int(c.seen), c.datas, c.tfIdf, c.DidConvertTfIdf})
data := &serializableClassifier{
c.Classes,
float64(c.learned),
float64(c.seen),
c.datas,
c.tfIdf,
c.DidConvertTfIdf,
}
if ser == JSON {
enc := json.NewEncoder(w)
return enc.Encode(data)
}

return
enc := gob.NewEncoder(w)
return enc.Encode(data)
}

// ReadClassFromFile loads existing class data from a
// file.
func (c *Classifier) ReadClassFromFile(class Class, location string) (err error) {
func (c *Classifier) ReadClassFromFile(class Class, location string, s ...Serializer) (err error) {
var ser Serializer
if len(s) == 1 {
ser = s[0]
}

fileName := filepath.Join(location, string(class))
file, err := os.Open(fileName)

if err != nil {
return err
}
defer file.Close()
defer func() {
err := file.Close()
if err != nil {
log.Print(err)
}
}()

dec := gob.NewDecoder(file)
w := new(classData)
err = dec.Decode(w)
if ser == JSON {
dec := json.NewDecoder(file)
err = dec.Decode(w)
} else {
dec := gob.NewDecoder(file)
err = dec.Decode(w)
}

c.learned++
c.datas[class] = w
Expand Down
2 changes: 1 addition & 1 deletion bayesian_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const (

// Assert fails the test immediately, reporting args, when condition
// is false; args are spread so each prints as a separate value.
func Assert(t *testing.T, condition bool, args ...interface{}) {
	if condition {
		return
	}
	t.Fatal(args...)
}

Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module github.com/jbrukh/bayesian

go 1.14