您的位置:首页 > 博客中心 > 网络系统 >

[Machine Learning]朴素贝叶斯(NaiveBayes)

时间:2022-04-03 11:14

C++ 描述:

 1 #include <iostream>
 2 #include <string>
 3 #include <fstream>
 4 #include <sstream>
 5 #include <vector>
 6 #include <map>
 7 #include <set>
 8 
 9 using namespace std;
10 
/// Naive Bayes classifier for integer-coded categorical features.
///
/// A training sample is a row of integers: every element except the last is a
/// feature value, and the last element is the class label. Probabilities are
/// plain relative frequencies (no smoothing).
class NaiveBayes {
public:
    /// Appends the samples stored in the text file at `path` to the training
    /// set. Prints "open file error" and terminates the process with exit
    /// status 1 if the file cannot be opened.
    void load_data(std::string path);

    /// (Re)builds the prior and conditional probability tables from the
    /// currently loaded samples. Safe to call more than once.
    void train_model();

    /// Returns the label with the highest posterior score for `item`.
    /// `item[j]` is the value of feature j; elements beyond the number of
    /// features seen in training (e.g. a trailing label slot) are ignored.
    /// Returns 0 if the model has not been trained on any sample.
    int predict(const std::vector<int> &item) const;

private:
    std::vector<std::vector<int>> data;  // loaded samples: features..., label
    // Conditional probabilities: c_p[j][{value, label}] = P(X_j = value | Y = label).
    std::vector<std::map<std::pair<int, int>, double>> c_p;
    std::map<int, double> p_p;  // prior probabilities: p_p[label] = P(Y = label)
};

void NaiveBayes::load_data(std::string path) {
    std::ifstream fin(path.c_str());
    if (!fin) {
        std::cerr << "open file error" << std::endl;
        std::exit(1);
    }

    std::string line;
    while (std::getline(fin, line)) {
        std::stringstream sin(line);
        std::vector<int> row;
        int elem;
        while (sin >> elem) {
            row.push_back(elem);
        }
        // A usable sample needs at least one feature plus the label. This also
        // drops blank lines and lines that do not start with integers, which
        // would otherwise be stored as empty rows.
        if (row.size() >= 2) {
            data.push_back(row);
        }
    }
    // fin is closed by its destructor.
}

void NaiveBayes::train_model() {
    // Start from scratch so that repeated training does not accumulate.
    p_p.clear();
    c_p.clear();

    // Work with exact integer counts and divide once at the end.
    std::map<int, std::size_t> label_count;
    std::vector<std::map<std::pair<int, int>, std::size_t>> value_count;
    std::size_t total = 0;

    for (const auto &row : data) {
        if (row.empty()) {
            continue;  // nothing to learn from; also guards row.back()
        }
        const int label = row.back();
        const std::size_t n_features = row.size() - 1;  // label column excluded

        ++total;
        ++label_count[label];
        if (value_count.size() < n_features) {
            value_count.resize(n_features);
        }
        // Count each value only under the label of its own row, and per
        // feature column, so equal values in different columns stay separate.
        for (std::size_t j = 0; j < n_features; ++j) {
            ++value_count[j][std::make_pair(row[j], label)];
        }
    }

    if (total == 0) {
        return;
    }

    for (const auto &lc : label_count) {
        p_p[lc.first] = static_cast<double>(lc.second) / static_cast<double>(total);
    }

    c_p.resize(value_count.size());
    for (std::size_t j = 0; j < value_count.size(); ++j) {
        for (const auto &vc : value_count[j]) {
            const int label = vc.first.second;
            c_p[j][vc.first] = static_cast<double>(vc.second) /
                               static_cast<double>(label_count[label]);
        }
    }
}

int NaiveBayes::predict(const std::vector<int> &item) const {
    int result = 0;
    bool have_result = false;
    double max_prob = 0.0;

    // Use every supplied feature the model knows about (no off-by-one, and no
    // unsigned underflow when item is empty).
    const std::size_t n_features = item.size() < c_p.size() ? item.size() : c_p.size();

    for (const auto &p : p_p) {
        const int label = p.first;
        double prob = p.second;
        for (std::size_t j = 0; j < n_features; ++j) {
            // find() instead of operator[] so prediction never inserts into
            // the trained tables; an unseen (value, label) pair has probability 0.
            const auto it = c_p[j].find(std::make_pair(item[j], label));
            prob *= (it == c_p[j].end()) ? 0.0 : it->second;
        }

        // The first label is always accepted so that a defined label is
        // returned even when every score is 0; ties keep the earlier label.
        if (!have_result || prob > max_prob) {
            have_result = true;
            max_prob = prob;
            result = label;
        }
    }

    return result;
}
80 
81 int main() {
82     NaiveBayes naive_bayes;
83     naive_bayes.load_data(string("result.txt"));
84     naive_bayes.train_model();
85 
86     vector<int> item{2, 4};
87     cout << naive_bayes.predict(item);
88     return 0;
89 }

数据集(保存为 result.txt;每行以空格分隔,最后一列为类别标签,其余各列为特征取值):

1 4 -1
1 5 -1
1 5 1
1 4 1
1 4 -1
2 4 -1
2 5 -1
2 5 1
2 6 1
2 6 1
3 6 1
3 5 1
3 5 1
3 6 1
3 6 -1

 

本类排行

今日推荐

热门手游