Skip to content

Commit

Permalink
Run Sieve scripts in async context
Browse files Browse the repository at this point in the history
  • Loading branch information
mdecimus committed May 24, 2024
1 parent ffdb7d7 commit 4e7087d
Show file tree
Hide file tree
Showing 19 changed files with 366 additions and 461 deletions.
116 changes: 49 additions & 67 deletions crates/common/src/scripts/plugins/bayes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use nlp::{
};
use sieve::{runtime::Variable, FunctionMap};
use store::{write::key::KeySerializer, LookupStore, U64_LEN};
use tokio::runtime::Handle;

use super::PluginContext;

Expand All @@ -50,15 +49,15 @@ pub fn register_is_balanced(plugin_id: u32, fnc_map: &mut FunctionMap) {
fnc_map.set_external_function("bayes_is_balanced", plugin_id, 3);
}

pub fn exec_train(ctx: PluginContext<'_>) -> Variable {
train(ctx, true)
pub async fn exec_train(ctx: PluginContext<'_>) -> Variable {
train(ctx, true).await
}

pub fn exec_untrain(ctx: PluginContext<'_>) -> Variable {
train(ctx, false)
pub async fn exec_untrain(ctx: PluginContext<'_>) -> Variable {
train(ctx, false).await
}

fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable {
async fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable {
let span: &tracing::Span = ctx.span;
let store = match &ctx.arguments[0] {
Variable::String(v) if !v.is_empty() => ctx.core.storage.lookups.get(v.as_ref()),
Expand All @@ -82,7 +81,6 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable {
if text.is_empty() {
return false.into();
}
let handle = ctx.handle;

// Train the model
let mut model = BayesModel::default();
Expand All @@ -109,18 +107,17 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable {
let bayes_cache = &ctx.core.sieve.bayes_cache;
if is_train {
for (hash, weights) in model.weights {
if handle
.block_on(
store.counter_incr(
KeySerializer::new(U64_LEN)
.write(hash.h1)
.write(hash.h2)
.finalize(),
weights.into(),
None,
false,
),
if store
.counter_incr(
KeySerializer::new(U64_LEN)
.write(hash.h1)
.write(hash.h2)
.finalize(),
weights.into(),
None,
false,
)
.await
.is_err()
{
return false.into();
Expand All @@ -134,18 +131,17 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable {
} else {
Weights { spam: 0, ham: 1 }
};
if handle
.block_on(
store.counter_incr(
KeySerializer::new(U64_LEN)
.write(0u64)
.write(0u64)
.finalize(),
weights.into(),
None,
false,
),
if store
.counter_incr(
KeySerializer::new(U64_LEN)
.write(0u64)
.write(0u64)
.finalize(),
weights.into(),
None,
false,
)
.await
.is_err()
{
return false.into();
Expand All @@ -160,7 +156,7 @@ fn train(ctx: PluginContext<'_>, is_train: bool) -> Variable {
true.into()
}

pub fn exec_classify(ctx: PluginContext<'_>) -> Variable {
pub async fn exec_classify(ctx: PluginContext<'_>) -> Variable {
let span = ctx.span;
let store = match &ctx.arguments[0] {
Variable::String(v) if !v.is_empty() => ctx.core.storage.lookups.get(v.as_ref()),
Expand Down Expand Up @@ -200,12 +196,10 @@ pub fn exec_classify(ctx: PluginContext<'_>) -> Variable {
}
}

let handle = ctx.handle;

// Obtain training counts
let bayes_cache = &ctx.core.sieve.bayes_cache;
let (spam_learns, ham_learns) =
if let Some(weights) = bayes_cache.get_or_update(TokenHash::default(), handle, store) {
if let Some(weights) = bayes_cache.get_or_update(TokenHash::default(), store).await {
(weights.spam, weights.ham)
} else {
tracing::warn!(
Expand All @@ -231,27 +225,25 @@ pub fn exec_classify(ctx: PluginContext<'_>) -> Variable {
}

// Classify the text
let mut tokens = Vec::new();
for token in OsbTokenizer::<_, TokenHash>::new(
BayesTokenizer::new(text.as_ref(), &ctx.core.smtp.resolvers.psl),
5,
) {
if let Some(weights) = bayes_cache.get_or_update(token.inner, store).await {
tokens.push(OsbToken {
inner: weights,
idx: token.idx,
});
}
}
classifier
.classify(
OsbTokenizer::<_, TokenHash>::new(
BayesTokenizer::new(text.as_ref(), &ctx.core.smtp.resolvers.psl),
5,
)
.filter_map(|t| {
OsbToken {
inner: bayes_cache.get_or_update(t.inner, handle, store)?,
idx: t.idx,
}
.into()
}),
ham_learns,
spam_learns,
)
.classify(tokens.into_iter(), ham_learns, spam_learns)
.map(Variable::from)
.unwrap_or_default()
}

pub fn exec_is_balanced(ctx: PluginContext<'_>) -> Variable {
pub async fn exec_is_balanced(ctx: PluginContext<'_>) -> Variable {
let min_balance = match &ctx.arguments[2] {
Variable::Float(n) => *n,
Variable::Integer(n) => *n as f64,
Expand Down Expand Up @@ -282,10 +274,9 @@ pub fn exec_is_balanced(ctx: PluginContext<'_>) -> Variable {
let learn_spam = ctx.arguments[1].to_bool();

// Obtain training counts
let handle = ctx.handle;
let bayes_cache = &ctx.core.sieve.bayes_cache;
let (spam_learns, ham_learns) =
if let Some(weights) = bayes_cache.get_or_update(TokenHash::default(), handle, store) {
if let Some(weights) = bayes_cache.get_or_update(TokenHash::default(), store).await {
(weights.spam as f64, weights.ham as f64)
} else {
tracing::warn!(
Expand Down Expand Up @@ -321,31 +312,22 @@ pub fn exec_is_balanced(ctx: PluginContext<'_>) -> Variable {
}

trait LookupOrInsert {
fn get_or_update(
&self,
hash: TokenHash,
handle: &Handle,
get_token: &LookupStore,
) -> Option<Weights>;
async fn get_or_update(&self, hash: TokenHash, get_token: &LookupStore) -> Option<Weights>;
}

impl LookupOrInsert for BayesTokenCache {
fn get_or_update(
&self,
hash: TokenHash,
handle: &Handle,
get_token: &LookupStore,
) -> Option<Weights> {
async fn get_or_update(&self, hash: TokenHash, get_token: &LookupStore) -> Option<Weights> {
if let Some(weights) = self.get(&hash) {
weights.unwrap_or_default().into()
} else if let Ok(num) = handle.block_on(
get_token.counter_get(
} else if let Ok(num) = get_token
.counter_get(
KeySerializer::new(U64_LEN)
.write(hash.h1)
.write(hash.h2)
.finalize(),
),
) {
)
.await
{
if num != 0 {
let weights = Weights::from(num);
self.insert_positive(hash, weights);
Expand Down
90 changes: 52 additions & 38 deletions crates/common/src/scripts/plugins/dns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,19 @@ pub fn register_exists(plugin_id: u32, fnc_map: &mut FunctionMap) {
fnc_map.set_external_function("dns_exists", plugin_id, 2);
}

pub fn exec(ctx: PluginContext<'_>) -> Variable {
pub async fn exec(ctx: PluginContext<'_>) -> Variable {
let entry = ctx.arguments[0].to_string();
let record_type = ctx.arguments[1].to_string();

if record_type.eq_ignore_ascii_case("ip") {
match ctx.handle.block_on(ctx.core.smtp.resolvers.dns.ip_lookup(
entry.as_ref(),
IpLookupStrategy::Ipv4thenIpv6,
10,
)) {
match ctx
.core
.smtp
.resolvers
.dns
.ip_lookup(entry.as_ref(), IpLookupStrategy::Ipv4thenIpv6, 10)
.await
{
Ok(result) => result
.iter()
.map(|ip| Variable::from(ip.to_string()))
Expand All @@ -54,10 +57,7 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable {
Err(err) => err.short_error().into(),
}
} else if record_type.eq_ignore_ascii_case("mx") {
match ctx
.handle
.block_on(ctx.core.smtp.resolvers.dns.mx_lookup(entry.as_ref()))
{
match ctx.core.smtp.resolvers.dns.mx_lookup(entry.as_ref()).await {
Ok(result) => result
.iter()
.flat_map(|mx| {
Expand All @@ -78,18 +78,19 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable {
}

match ctx
.handle
.block_on(ctx.core.smtp.resolvers.dns.txt_raw_lookup(entry.as_ref()))
.core
.smtp
.resolvers
.dns
.txt_raw_lookup(entry.as_ref())
.await
{
Ok(result) => Variable::from(String::from_utf8(result).unwrap_or_default()),
Err(err) => err.short_error().into(),
}
} else if record_type.eq_ignore_ascii_case("ptr") {
if let Ok(addr) = entry.parse::<IpAddr>() {
match ctx
.handle
.block_on(ctx.core.smtp.resolvers.dns.ptr_lookup(addr))
{
match ctx.core.smtp.resolvers.dns.ptr_lookup(addr).await {
Ok(result) => result
.iter()
.map(|host| Variable::from(host.to_string()))
Expand All @@ -110,8 +111,12 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable {
}

match ctx
.handle
.block_on(ctx.core.smtp.resolvers.dns.ipv4_lookup(entry.as_ref()))
.core
.smtp
.resolvers
.dns
.ipv4_lookup(entry.as_ref())
.await
{
Ok(result) => result
.iter()
Expand All @@ -122,8 +127,12 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable {
}
} else if record_type.eq_ignore_ascii_case("ipv6") {
match ctx
.handle
.block_on(ctx.core.smtp.resolvers.dns.ipv6_lookup(entry.as_ref()))
.core
.smtp
.resolvers
.dns
.ipv6_lookup(entry.as_ref())
.await
{
Ok(result) => result
.iter()
Expand All @@ -137,35 +146,32 @@ pub fn exec(ctx: PluginContext<'_>) -> Variable {
}
}

pub fn exec_exists(ctx: PluginContext<'_>) -> Variable {
pub async fn exec_exists(ctx: PluginContext<'_>) -> Variable {
let entry = ctx.arguments[0].to_string();
let record_type = ctx.arguments[1].to_string();

if record_type.eq_ignore_ascii_case("ip") {
match ctx.handle.block_on(ctx.core.smtp.resolvers.dns.ip_lookup(
entry.as_ref(),
IpLookupStrategy::Ipv4thenIpv6,
10,
)) {
match ctx
.core
.smtp
.resolvers
.dns
.ip_lookup(entry.as_ref(), IpLookupStrategy::Ipv4thenIpv6, 10)
.await
{
Ok(result) => i64::from(!result.is_empty()),
Err(Error::DnsRecordNotFound(_)) => 0,
Err(_) => -1,
}
} else if record_type.eq_ignore_ascii_case("mx") {
match ctx
.handle
.block_on(ctx.core.smtp.resolvers.dns.mx_lookup(entry.as_ref()))
{
match ctx.core.smtp.resolvers.dns.mx_lookup(entry.as_ref()).await {
Ok(result) => i64::from(result.iter().any(|mx| !mx.exchanges.is_empty())),
Err(Error::DnsRecordNotFound(_)) => 0,
Err(_) => -1,
}
} else if record_type.eq_ignore_ascii_case("ptr") {
if let Ok(addr) = entry.parse::<IpAddr>() {
match ctx
.handle
.block_on(ctx.core.smtp.resolvers.dns.ptr_lookup(addr))
{
match ctx.core.smtp.resolvers.dns.ptr_lookup(addr).await {
Ok(result) => i64::from(!result.is_empty()),
Err(Error::DnsRecordNotFound(_)) => 0,
Err(_) => -1,
Expand All @@ -182,17 +188,25 @@ pub fn exec_exists(ctx: PluginContext<'_>) -> Variable {
}

match ctx
.handle
.block_on(ctx.core.smtp.resolvers.dns.ipv4_lookup(entry.as_ref()))
.core
.smtp
.resolvers
.dns
.ipv4_lookup(entry.as_ref())
.await
{
Ok(result) => i64::from(!result.is_empty()),
Err(Error::DnsRecordNotFound(_)) => 0,
Err(_) => -1,
}
} else if record_type.eq_ignore_ascii_case("ipv6") {
match ctx
.handle
.block_on(ctx.core.smtp.resolvers.dns.ipv6_lookup(entry.as_ref()))
.core
.smtp
.resolvers
.dns
.ipv6_lookup(entry.as_ref())
.await
{
Ok(result) => i64::from(!result.is_empty()),
Err(Error::DnsRecordNotFound(_)) => 0,
Expand Down
Loading

0 comments on commit 4e7087d

Please sign in to comment.