< Return to JoshWhitman.com Home Page

Source code for the Bayesian Spam Filter

This code is the Machine Learning part of the project that appears in one of my YouTube Videos.

Download the project that lets you drag and drop from Outlook here: BayesianTrainingClient.zip. Note that the app doesn't use the BayesianClassifier class; the code in the app was used to develop the class. I included BayesianClassifier.cs in the zip and added comments showing where you'd probably want to create a couple of instances, in case you want to rewrite the app to use the production class.

    
    /// <summary>
    /// Naive Bayes spam classifier that scores pre-tokenized message subjects and bodies.
    /// Keyword counts are kept in memory and every trained message is also written to a
    /// "spam" or "ham" sub-folder of the data directory so the counts can be rebuilt by
    /// the constructor. The class performs no synchronization of its own.
    /// </summary>
    public class BayesianClassifier
    {
        private const string DataPath = @"D:\SpamFilter\Bayesian";
        private const int MaxKeywordsToEvaluate = 40;
        private const string SpamFolder = "spam";
        private const string HamFolder = "ham";

        private static readonly char[] LineBreakChars = new char[] { '\r', '\n' };

        // Root folder that contains the "spam" and "ham" training folders.
        private readonly string DataDirectory;

        private long SpamTrainingMessages = 0, HamTrainingMessages = 0;

        private readonly Dictionary<string, long> SubjectSpamKeywords = new Dictionary<string, long>();
        private readonly Dictionary<string, long> SubjectHamKeywords = new Dictionary<string, long>();
        private readonly Dictionary<string, long> BodySpamKeywords = new Dictionary<string, long>();
        private readonly Dictionary<string, long> BodyHamKeywords = new Dictionary<string, long>();

        /// <summary>
        /// Creates a classifier backed by the default data directory and loads any
        /// training messages already stored there.
        /// </summary>
        public BayesianClassifier() : this(DataPath)
        {
        }

        /// <summary>
        /// Creates a classifier backed by <paramref name="dataPath"/> and loads any
        /// training messages already stored there.
        /// </summary>
        /// <param name="dataPath">Folder that holds (or will hold) the "spam" and "ham" sub-folders.</param>
        /// <exception cref="ArgumentNullException"><paramref name="dataPath"/> is null.</exception>
        public BayesianClassifier(string dataPath)
        {
            if (dataPath == null)
            {
                throw new ArgumentNullException(nameof(dataPath));
            }

            DataDirectory = dataPath;

            //Read the existing training sets into the dictionaries.
            LoadExisting();
        }

        #region Evaluating unknown messages for spam
        /// <summary>Scores a tokenized subject line.</summary>
        /// <param name="subject">The words of the subject; null is treated as no words.</param>
        /// <returns>The estimated probability (0..1) that the subject belongs to a spam message.</returns>
        public double IsSubjectSpam(string[] subject)
        {
            return GetSpamProbability(subject, SubjectSpamKeywords, SubjectHamKeywords);
        }

        /// <summary>Scores a tokenized message body.</summary>
        /// <param name="body">The words of the body; null is treated as no words.</param>
        /// <returns>The estimated probability (0..1) that the body belongs to a spam message.</returns>
        public double IsBodySpam(string[] body)
        {
            return GetSpamProbability(body, BodySpamKeywords, BodyHamKeywords);
        }

        /// <summary>
        /// Combines the per-word spam probabilities of the most revealing words into a
        /// single spam probability.
        /// </summary>
        /// <param name="words">Words of the text being scored; null is treated as no words.</param>
        /// <param name="spam">Occurrence counts of words seen in spam training messages.</param>
        /// <param name="ham">Occurrence counts of words seen in ham training messages.</param>
        /// <returns>
        /// The combined spam probability. The prior (0.5) is returned when there are no
        /// words, or when either training set is still empty.
        /// </returns>
        private double GetSpamProbability(string[] words, Dictionary<string, long> spam, Dictionary<string, long> ham)
        {
            //Set the prevailing ratio of spam to ham.
            double pSpam = 0.5d; //Set to a different ratio if messages should be prejudiced as being spam.

            //Without examples of BOTH classes the ham-to-spam ratio below would be a division
            //by zero (or a zero divisor later on) and every score would come out as NaN.
            if (words == null || SpamTrainingMessages <= 0 || HamTrainingMessages <= 0)
            {
                return pSpam;
            }

            //Compare ham and spam counts proportionally to the sizes of the two training sets.
            double hamToSpamRatio = (double)HamTrainingMessages / SpamTrainingMessages;

            //Assemble a list of keyword comparisons.
            KeywordComparison[] keywords = new KeywordComparison[words.Length];
            for (int n = 0; n < words.Length; n++)
            {
                //TryGetValue leaves the count at zero for words that were never seen.
                long hcount, scount;
                ham.TryGetValue(words[n], out hcount);
                spam.TryGetValue(words[n], out scount);
                keywords[n] = new KeywordComparison(words[n], hcount, scount, hamToSpamRatio);
            }

            //Sort by their significance.
            Array.Sort(keywords); //Least useful to most.
            Array.Reverse(keywords); //Most useful to least.

            //Accumulate ln((1 - p) / p) for each word instead of multiplying the raw
            //probabilities, so long messages cannot underflow. Start with the prevailing ratio.
            double x = Math.Log(1.0d - pSpam) - Math.Log(pSpam);

            //Look at no more than the most revealing keywords so that large amounts of
            //neutral text cannot de-weight a spam email.
            for (int n = 0; n < keywords.Length && n < MaxKeywordsToEvaluate; n++)
            {
                x += Math.Log(1.0d - keywords[n].SpamProbability) - Math.Log(keywords[n].SpamProbability);
            }

            //1 / (1 + e^x) is the probability of spam over the total probability.
            return 1.0d / (1.0d + Math.Exp(x));
        }

        /// <summary>
        /// Per-word statistics: how often the word was seen in each class, the resulting
        /// spam probability and how far that probability is from the neutral 0.5.
        /// Instances order themselves by that distance (least useful first).
        /// </summary>
        private sealed class KeywordComparison : IComparable<KeywordComparison>
        {
            //Add one to each class to account for not having encountered this word yet;
            //it is still possible it will be encountered in either class in the future.
            private const double Smoothing = 1.0d;

            /// <param name="word">The word being compared.</param>
            /// <param name="hammy">Times the word was seen in ham training messages.</param>
            /// <param name="spammy">Times the word was seen in spam training messages.</param>
            /// <param name="hamToSpamRatio">
            /// Ham training messages divided by spam training messages; the ham count is
            /// divided by it to offset a disparity in the sizes of the training sets.
            /// Must be a finite value greater than zero.
            /// </param>
            public KeywordComparison(string word, long hammy, long spammy, double hamToSpamRatio)
            {
                Word = word;
                Hammy = hammy;
                Spammy = spammy;

                //The ratio of (smoothed) incidents in spam to total (smoothed) incidents.
                SpamProbability = (spammy + Smoothing) / (spammy + Smoothing + (hammy / hamToSpamRatio + Smoothing));
                Usefulness = Math.Abs(0.5d - SpamProbability) * 2;
            }

            public string Word { get; }
            public long Hammy { get; }
            public long Spammy { get; }
            public double SpamProbability { get; }
            private double Usefulness { get; }

            /// <summary>Orders by usefulness, ascending; any instance sorts after null.</summary>
            public int CompareTo(KeywordComparison other)
            {
                if (other == null)
                {
                    return 1;
                }

                return Usefulness.CompareTo(other.Usefulness);
            }
        }

        #endregion
        #region Training, Saving and Loading
        /// <summary>
        /// Adds one spam message to the in-memory counts and saves it to the spam folder.
        /// The counts are updated before the file is written.
        /// </summary>
        /// <param name="subject">Words of the subject; null is treated as no words.</param>
        /// <param name="body">Words of the body; null is treated as no words.</param>
        public void TrainSpam(string[] subject, string[] body)
        {
            TrainList(SubjectSpamKeywords, subject);
            TrainList(BodySpamKeywords, body);
            SpamTrainingMessages++;

            //Save the spam training data.
            SaveItem(subject, body, SpamFolder);
        }

        /// <summary>
        /// Adds one ham message to the in-memory counts and saves it to the ham folder.
        /// The counts are updated before the file is written.
        /// </summary>
        /// <param name="subject">Words of the subject; null is treated as no words.</param>
        /// <param name="body">Words of the body; null is treated as no words.</param>
        public void TrainHam(string[] subject, string[] body)
        {
            TrainList(SubjectHamKeywords, subject);
            TrainList(BodyHamKeywords, body);
            HamTrainingMessages++;

            //Save the ham training data.
            SaveItem(subject, body, HamFolder);
        }

        /// <summary>
        /// Writes one message as a two-line text file (comma-joined subject words, then
        /// comma-joined body words) with a GUID file name in the given sub-folder.
        /// </summary>
        private void SaveItem(string[] subject, string[] body, string folder)
        {
            //TrainList accepts null word lists, so saving must accept them too.
            string info = string.Join(",", subject ?? new string[0]) + "\r\n" + string.Join(",", body ?? new string[0]);

            //CreateDirectory does nothing when the folder already exists.
            string directory = System.IO.Path.Combine(DataDirectory, folder);
            System.IO.Directory.CreateDirectory(directory);

            System.IO.File.WriteAllText(System.IO.Path.Combine(directory, Guid.NewGuid().ToString() + ".txt"), info);
        }

        /// <summary>Rebuilds the keyword counts from the saved spam and ham messages.</summary>
        private void LoadExisting()
        {
            SpamTrainingMessages += LoadFolder(SpamFolder, SubjectSpamKeywords, BodySpamKeywords);
            HamTrainingMessages += LoadFolder(HamFolder, SubjectHamKeywords, BodyHamKeywords);
        }

        /// <summary>
        /// Trains the given dictionaries from every file in one training sub-folder.
        /// </summary>
        /// <returns>The number of files that could be read. A missing folder yields zero.</returns>
        private long LoadFolder(string folder, Dictionary<string, long> subjectKeywords, Dictionary<string, long> bodyKeywords)
        {
            string directory = System.IO.Path.Combine(DataDirectory, folder);
            if (!System.IO.Directory.Exists(directory))
            {
                //Nothing has been trained for this class yet.
                return 0;
            }

            long loaded = 0;
            foreach (string file in System.IO.Directory.GetFiles(directory))
            {
                string content;
                try
                {
                    content = System.IO.File.ReadAllText(file);
                }
                catch (Exception ex) when (ex is System.IO.IOException || ex is UnauthorizedAccessException)
                {
                    //Loading is best effort: one unreadable file must not prevent start-up,
                    //but it should not disappear silently either.
                    System.Diagnostics.Trace.TraceWarning("BayesianClassifier: skipping unreadable training file '{0}': {1}", file, ex.Message);
                    continue;
                }

                //The first line holds the subject words and the remainder the body words.
                //Splitting at the first line break (rather than dropping empty lines) keeps
                //messages whose subject or body is empty from losing their other half.
                int lineBreak = content.IndexOfAny(LineBreakChars);
                string subjectLine = lineBreak < 0 ? content : content.Substring(0, lineBreak);
                string bodyLine = lineBreak < 0 ? string.Empty : content.Substring(lineBreak + 1);

                TrainList(subjectKeywords, ParseWords(subjectLine));
                TrainList(bodyKeywords, ParseWords(bodyLine));
                loaded++;
            }

            return loaded;
        }

        /// <summary>
        /// Splits one saved line into its comma-separated words. Surrounding line-break
        /// characters are ignored and an empty line yields no words.
        /// </summary>
        private static string[] ParseWords(string line)
        {
            string trimmed = line.Trim(LineBreakChars);
            return trimmed.Length == 0 ? new string[0] : trimmed.Split(',');
        }

        /// <summary>Increments the count of every word in <paramref name="words"/>; null adds nothing.</summary>
        private void TrainList(Dictionary<string, long> list, string[] words)
        {
            if (words == null)
            {
                return;
            }

            foreach (string word in words)
            {
                //count stays zero when the word has not been seen before.
                long count;
                list.TryGetValue(word, out count);
                list[word] = count + 1;
            }
        }
        #endregion
    }