File Coverage

blib/lib/Redis/NaiveBayes.pm
Criterion Covered Total %
statement 15 71 21.1
branch 0 14 0.0
condition 0 7 0.0
subroutine 5 19 26.3
pod 6 6 100.0
total 26 117 22.2


line stmt bran cond sub pod time code
1             package Redis::NaiveBayes;
2             # ABSTRACT: A generic Redis-backed NaiveBayes implementation
3             $Redis::NaiveBayes::VERSION = '0.0.4';
4              
5 8     8   361748 use strict;
  8         23  
  8         324  
6 8     8   42 use warnings;
  8         17  
  8         263  
7 8     8   50 use List::Util qw(sum reduce);
  8         16  
  8         1091  
8              
9 8     8   7614 use Redis;
  8         753694  
  8         329  
10              
11             use constant {
12 8         21160 LABELS => 'labels',
13 8     8   96 };
  8         19  
14              
15             # Lua scripts
16             my $LUA_FLUSH_FMT = q{
17             local namespace = '%s'
18             local labels_key = namespace .. '%s'
19             for _, member in ipairs(redis.call('smembers', labels_key)) do
20             redis.call('del', namespace .. member)
21             redis.call('del', namespace .. 'tally_for:' .. member)
22             end
23             redis.call('del', labels_key);
24             };
25              
26             my $LUA_TRAIN_FMT = q{
27             -- ARGV:
28             -- 1: raw label name being trained
29             -- 2: number of tokens being updated
30             -- 3-X: token being updated
31             -- X+1-N: value to increment corresponding token
32              
33             local namespace = '%s'
34             local labels_key = namespace .. '%s'
35             local label = namespace .. ARGV[1]
36             local tally_key = namespace .. 'tally_for:' .. ARGV[1]
37             local num_tokens = ARGV[2]
38             local tot_added = 0
39              
40             redis.call('sadd', labels_key, ARGV[1])
41              
42             for index, token in ipairs(ARGV) do
43             if index > num_tokens + 2 then
44             break
45             end
46             if index > 2 then
47             redis.call('hincrby', label, token, ARGV[index + num_tokens])
48             tot_added = tot_added + ARGV[index + num_tokens]
49             end
50             end
51              
52             local old_tally = redis.call('get', tally_key);
53             if (not old_tally) then
54             old_tally = 0
55             end
56              
57             redis.call('set', tally_key, old_tally + tot_added)
58             };
59              
60             my $LUA_UNTRAIN_FMT = q{
61             -- ARGV:
62             -- 1: raw label name being untrained
63             -- 2: number of tokens being updated
64             -- 3-X: token being updated
65             -- X+1-N: value to increment corresponding token
66              
67             local namespace = '%s'
68             local labels_key = namespace .. '%s'
69             local label = namespace .. ARGV[1]
70             local tally_key = namespace .. 'tally_for:' .. ARGV[1]
71             local num_tokens = ARGV[2]
72              
73             for index, token in ipairs(ARGV) do
74             if index > num_tokens + 2 then
75             break
76             end
77             if index > 2 then
78             local current = redis.call('hget', label, token);
79              
80             if (current and current - ARGV[index + num_tokens] > 0) then
81             redis.call('hincrby', label, token, -1 * ARGV[index + num_tokens])
82             else
83             redis.call('hdel', label, token)
84             end
85             end
86             end
87              
88             local tally = 0
89             for _, value in ipairs(redis.call('hvals', label)) do
90             tally = tally + value
91             end
92              
93             if tally <= 0 then
94             redis.call('del', label)
95             redis.call('srem', labels_key, ARGV[1])
96             redis.call('del', tally_key)
97             else
98             redis.call('set', tally_key, tally)
99             end
100             };
101              
102             my $_LUA_CALCULATE_SCORES = q{
103             -- ARGV
104             -- 1: correction
105             -- 2: number of tokens
106             -- 3-X: tokens
107             -- X+1-N: values for each token
108             -- FIXME: I'm ignoring the scores per token on purpose for now
109              
110             local namespace = '%s'
111             local labels_key = namespace .. '%s'
112             local correction = ARGV[1]
113             local num_tokens = ARGV[2]
114              
115             local scores = {}
116              
117             for index, raw_label in ipairs(redis.call('smembers', labels_key)) do
118             local label = namespace .. raw_label
119              
120             local tally = tonumber(redis.call('get', namespace .. 'tally_for:' .. raw_label))
121              
122             if (tally and tally > 0) then
123             scores[raw_label] = 0.0
124              
125             for idx, token in ipairs(ARGV) do
126             if idx > num_tokens + 2 then
127             break
128             end
129              
130             if idx > 2 then
131             local score = redis.call('hget', label, token)
132              
133             if (not score or score == 0) then
134             score = correction
135             end
136              
137             scores[raw_label] = scores[raw_label] + math.log(score / tally)
138             end
139             end
140             end
141             end
142             };
143              
144             my $LUA_SCORES_FMT = qq{
145             $_LUA_CALCULATE_SCORES
146              
147             local return_crap = {}
148             local index = 1
149             for key, value in pairs(scores) do
150             return_crap[index] = key
151             return_crap[index+1] = value
152             index = index + 2
153             end
154              
155             return return_crap;
156             };
157              
158             my $LUA_CLASSIFY_FMT = qq{
159             $_LUA_CALCULATE_SCORES
160              
161             local best_label = nil
162             local best_score = nil
163             for label, score in pairs(scores) do
164             if (best_score == nil or best_score < score) then
165             best_label = label
166             best_score = score
167             end
168             end
169              
170             return best_label
171             };
172              
173              
174             sub new {
175 0     0 1   my ($class, %args) = @_;
176 0           my $self = bless {}, $class;
177              
178 0   0       $self->{redis} = $args{redis} || Redis->new(%args);
179 0   0       $self->{correction} = $args{correction} || 0.001;
180 0 0         $self->{namespace} = $args{namespace} or die "Missing namespace";
181 0 0         $self->{tokenizer} = $args{tokenizer} or die "Missing tokenizer";
182              
183 0           $self->_load_scripts;
184              
185 0           return $self;
186             }
187              
188             sub _redis_script_load {
189 0     0     my ($self, $script_fmt, @args) = @_;
190              
191 0           my ($sha1) = $self->{redis}->script_load(sprintf($script_fmt, $self->{namespace}, LABELS, @args));
192              
193 0           return $sha1;
194             }
195              
196             sub _load_scripts {
197 0     0     my ($self) = @_;
198              
199 0           $self->{scripts} = {};
200              
201 0           $self->{scripts}->{flush} = $self->_redis_script_load($LUA_FLUSH_FMT);
202 0           $self->{scripts}->{train} = $self->_redis_script_load($LUA_TRAIN_FMT);
203 0           $self->{scripts}->{untrain} = $self->_redis_script_load($LUA_UNTRAIN_FMT);
204 0           $self->{scripts}->{scores} = $self->_redis_script_load($LUA_SCORES_FMT);
205 0           $self->{scripts}->{classify} = $self->_redis_script_load($LUA_CLASSIFY_FMT);
206             }
207              
208             sub _exec {
209 0     0     my ($self, $command, $key, @rest) = @_;
210              
211 0           return $self->{redis}->$command($self->{namespace} . $key, @rest);
212             }
213              
214             sub _run_script {
215 0     0     my ($self, $script, $numkeys, @rest) = @_;
216              
217 0   0       $numkeys ||= 0;
218 0 0         my $sha1 = $self->{scripts}->{$script} or die "Script wasn't loaded: '$script'";
219              
220 0           $self->{redis}->evalsha($sha1, $numkeys, @rest);
221             }
222              
223              
224             sub flush {
225 0     0 1   my ($self) = @_;
226              
227 0           $self->_run_script('flush');
228             }
229              
230             sub _mrproper {
231 0     0     my ($self) = @_;
232              
233 0           my @keys = $self->{redis}->keys($self->{namespace} . '*');
234 0 0         $self->{redis}->del(@keys) if @keys;
235             }
236              
237             sub _train {
238 0     0     my ($self, $label, $item, $script) = @_;
239              
240 0           my $occurrences = $self->{tokenizer}->($item);
241 0 0         die "tokenizer() didn't return a HASHREF" unless ref $occurrences eq 'HASH';
242              
243 0           my @argv = ($label, (scalar keys %$occurrences), keys %$occurrences, values %$occurrences);
244              
245 0           $self->_run_script($script, 0, @argv);
246              
247 0           return $occurrences;
248             }
249              
250              
251             sub train {
252 0     0 1   my ($self, $label, $item) = @_;
253              
254 0           return $self->_train($label, $item, 'train');
255             }
256              
257              
258             sub untrain {
259 0     0 1   my ($self, $label, $item) = @_;
260              
261 0           return $self->_train($label, $item, 'untrain');
262             }
263              
264              
265             sub classify {
266 0     0 1   my ($self, $item) = @_;
267              
268 0           my $occurrences = $self->{tokenizer}->($item);
269 0 0         die "tokenizer() didn't return a HASHREF" unless ref $occurrences eq 'HASH';
270              
271 0           my @argv = ($self->{correction}, scalar keys %$occurrences, keys %$occurrences, values %$occurrences);
272              
273 0           my $best_label = $self->_run_script('classify', 0, @argv);
274              
275 0           return $best_label;
276             }
277              
278              
279             sub scores {
280 0     0 1   my ($self, $item) = @_;
281              
282 0           my $occurrences = $self->{tokenizer}->($item);
283 0 0         die "tokenizer() didn't return a HASHREF" unless ref $occurrences eq 'HASH';
284              
285 0           my @argv = ($self->{correction}, scalar keys %$occurrences, keys %$occurrences, values %$occurrences);
286              
287 0           my %scores = $self->_run_script('scores', 0, @argv);
288              
289 0           return \%scores;
290             }
291              
292             sub _labels {
293 0     0     my ($self) = @_;
294              
295 0           return $self->_exec('smembers', LABELS);
296             }
297              
298             sub _priors {
299 0     0     my ($self, $label) = @_;
300              
301 0           my %data = $self->_exec('hgetall', $label);
302 0           return { %data };
303             }
304              
305              
306             1;
307              
308             __END__