275 lines
8.1 KiB
Rust
275 lines
8.1 KiB
Rust
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
// Copyright by contributors to this project.
|
|
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
|
|
|
|
use std::{
|
|
fmt::{self, Debug},
|
|
sync::{Arc, Mutex},
|
|
};
|
|
|
|
use rusqlite::{params, Connection, OptionalExtension};
|
|
|
|
use crate::SqLiteDataStorageError;
|
|
|
|
/// Upsert statement used by the insert operations of this module: adds a `(key, value)` row to
/// `kvs`, or replaces the stored value when a row for `key` already exists.
const INSERT_SQL: &str =
    "INSERT INTO kvs (key, value) VALUES (?,?) ON CONFLICT(key) DO UPDATE SET value=excluded.value";
|
|
|
|
#[derive(Debug, Clone)]
|
|
/// SQLite key-value storage for application specific data.
|
|
pub struct SqLiteApplicationStorage {
|
|
connection: Arc<Mutex<Connection>>,
|
|
}
|
|
|
|
impl SqLiteApplicationStorage {
|
|
pub(crate) fn new(connection: Connection) -> SqLiteApplicationStorage {
|
|
SqLiteApplicationStorage {
|
|
connection: Arc::new(Mutex::new(connection)),
|
|
}
|
|
}
|
|
|
|
/// Insert `value` into storage indexed by `key`.
|
|
///
|
|
/// If a value already exists for `key` it will be overwritten.
|
|
pub fn insert(&self, key: &str, value: &[u8]) -> Result<(), SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
|
|
// Upsert into the database
|
|
connection
|
|
.execute(INSERT_SQL, params![key, value])
|
|
.map(|_| ())
|
|
.map_err(sql_engine_error)
|
|
}
|
|
|
|
/// Execute multiple [`SqLiteApplicationStorage::insert`] operations in a transaction.
|
|
pub fn transact_insert(&self, items: &[Item]) -> Result<(), SqLiteDataStorageError> {
|
|
let mut connection = self.connection.lock().unwrap();
|
|
|
|
// Upsert into the database
|
|
let tx = connection.transaction().map_err(sql_engine_error)?;
|
|
|
|
items.iter().try_for_each(|item| {
|
|
tx.execute(INSERT_SQL, params![item.key, item.value])
|
|
.map_err(sql_engine_error)
|
|
.map(|_| ())
|
|
})?;
|
|
|
|
tx.commit().map_err(sql_engine_error)?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Get a value from storage based on its `key`.
|
|
pub fn get(&self, key: &str) -> Result<Option<Vec<u8>>, SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
|
|
connection
|
|
.query_row("SELECT value FROM kvs WHERE key = ?", params![key], |row| {
|
|
row.get(0)
|
|
})
|
|
.optional()
|
|
.map_err(sql_engine_error)
|
|
}
|
|
|
|
/// Delete a value from storage based on its `key`.
|
|
pub fn delete(&self, key: &str) -> Result<(), SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
|
|
connection
|
|
.execute("DELETE FROM kvs WHERE key = ?", params![key])
|
|
.map(|_| ())
|
|
.map_err(sql_engine_error)
|
|
}
|
|
|
|
/// Get all keys and values from storage for which key starts with `key_prefix`.
|
|
pub fn get_by_prefix(&self, key_prefix: &str) -> Result<Vec<Item>, SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
let mut key_prefix = sanitize(key_prefix);
|
|
key_prefix.push('%');
|
|
|
|
let mut stmt = connection
|
|
.prepare("SELECT key, value FROM kvs WHERE key LIKE ? ESCAPE '$'")
|
|
.map_err(sql_engine_error)?;
|
|
|
|
let rows = stmt
|
|
.query(params![key_prefix])
|
|
.map_err(sql_engine_error)?
|
|
.mapped(|row| Ok(Item::new(row.get(0)?, row.get(1)?)));
|
|
|
|
rows.collect::<Result<_, _>>().map_err(sql_engine_error)
|
|
}
|
|
|
|
/// Delete all values from storage for which key starts with `key_prefix`.
|
|
pub fn delete_by_prefix(&self, key_prefix: &str) -> Result<(), SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
let mut key_prefix = sanitize(key_prefix);
|
|
key_prefix.push('%');
|
|
|
|
connection
|
|
.execute(
|
|
"DELETE FROM kvs WHERE key LIKE ? ESCAPE '$'",
|
|
params![key_prefix],
|
|
)
|
|
.map(|_| ())
|
|
.map_err(sql_engine_error)
|
|
}
|
|
}
|
|
|
|
/// Escape `string` so that it matches literally inside a `LIKE ... ESCAPE '$'` pattern.
///
/// The wildcards `_` and `%` are prefixed with the escape character `$`. The escape character
/// itself is escaped too (`$` becomes `$$`); otherwise a `$` in the input would be read as the
/// start of an escape sequence and change the meaning of the character that follows it
/// (including a wildcard appended by the caller).
///
/// Returns the escaped copy; input without special characters is returned unchanged.
fn sanitize(string: &str) -> String {
    let mut escaped = String::with_capacity(string.len());

    for c in string.chars() {
        if matches!(c, '$' | '_' | '%') {
            escaped.push('$');
        }

        escaped.push(c);
    }

    escaped
}
|
|
|
|
fn sql_engine_error(e: rusqlite::Error) -> SqLiteDataStorageError {
|
|
SqLiteDataStorageError::SqlEngineError(e.into())
|
|
}
|
|
|
|
/// A key together with the value stored under it.
///
/// The derived ordering compares `key` first and then `value`.
#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Item {
    /// Key the value is indexed by.
    pub key: String,
    /// Bytes stored under `key`.
    pub value: Vec<u8>,
}
|
|
|
|
impl Debug for Item {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
f.debug_struct("Item")
|
|
.field("key", &self.key)
|
|
.field("value", &mls_rs_core::debug::pretty_bytes(&self.value))
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
impl Item {
|
|
pub fn new(key: String, value: Vec<u8>) -> Self {
|
|
Self { key, value }
|
|
}
|
|
|
|
pub fn key(&self) -> &str {
|
|
&self.key
|
|
}
|
|
|
|
pub fn value(&self) -> &[u8] {
|
|
&self.value
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use crate::{
        application::Item, connection_strategy::MemoryStrategy, test_utils::gen_rand_bytes,
        SqLiteDataStorageEngine,
    };

    use super::SqLiteApplicationStorage;

    /// Hex-encoded random key paired with a random value.
    fn test_kv() -> (String, Vec<u8>) {
        (hex::encode(gen_rand_bytes(32)), gen_rand_bytes(64))
    }

    /// Application storage obtained from an engine built with `MemoryStrategy`.
    fn test_storage() -> SqLiteApplicationStorage {
        let engine = SqLiteDataStorageEngine::new(MemoryStrategy).unwrap();

        engine.application_data_storage().unwrap()
    }

    /// Item with a random hex key and a random value.
    fn test_item() -> Item {
        Item::new(hex::encode(gen_rand_bytes(5)), gen_rand_bytes(5))
    }

    #[test]
    fn test_insert() {
        let storage = test_storage();
        let (key, value) = test_kv();

        storage.insert(&key, &value).unwrap();

        assert_eq!(storage.get(&key).unwrap(), Some(value));
    }

    #[test]
    fn test_insert_existing_overwrite() {
        let storage = test_storage();
        let (key, old_value) = test_kv();
        let (_, new_value) = test_kv();

        // The second insert under the same key must win.
        for value in [&old_value, &new_value] {
            storage.insert(&key, value).unwrap();
        }

        assert_eq!(storage.get(&key).unwrap(), Some(new_value));
    }

    #[test]
    fn test_delete() {
        let storage = test_storage();
        let (key, value) = test_kv();

        storage.insert(&key, &value).unwrap();
        storage.delete(&key).unwrap();

        assert_eq!(storage.get(&key).unwrap(), None);
    }

    #[test]
    fn test_by_prefix() {
        let keys = ["prefix one", "prefix two", "prefiy ", "prefiw "].map(ToString::to_string);
        let value = gen_rand_bytes(5);
        let storage = test_storage();

        for key in &keys {
            storage.insert(key, &value).unwrap();
        }

        // Only the first two keys start with "prefix".
        let mut expected: Vec<Item> = keys[..2]
            .iter()
            .map(|key| Item::new(key.clone(), value.clone()))
            .collect();
        expected.sort();

        let mut matching = storage.get_by_prefix("prefix").unwrap();
        matching.sort();
        assert_eq!(matching, expected);

        // No key starts with "a".
        assert!(storage.get_by_prefix("a").unwrap().is_empty());

        // The empty prefix selects every key.
        assert_eq!(storage.get_by_prefix("").unwrap().len(), keys.len());

        storage.delete_by_prefix("prefix").unwrap();

        let remaining = storage.get_by_prefix("").unwrap();
        assert_eq!(remaining.len(), 2);

        for survivor in ["prefiy ", "prefiw "] {
            assert!(remaining.contains(&Item::new(survivor.to_string(), value.clone())));
        }
    }

    #[test]
    fn test_special_characters() {
        let storage = test_storage();

        for key in ["%$_ƕ❤_$%", "%$_ƕ❤a$%", "%$_ƕ❤Ḉ$%"] {
            storage.insert(key, &gen_rand_bytes(5)).unwrap();
        }

        // The `_` that ends the prefix must be matched literally, not as a wildcard.
        let keys: Vec<String> = storage
            .get_by_prefix("%$_ƕ❤_")
            .unwrap()
            .into_iter()
            .map(|item| item.key)
            .collect();

        assert_eq!(vec!["%$_ƕ❤_$%".to_string()], keys);
    }

    #[test]
    fn batch_insert() {
        let storage = test_storage();
        let items: Vec<Item> = (0..3).map(|_| test_item()).collect();

        storage.transact_insert(&items).unwrap();

        for Item { key, value } in items {
            assert_eq!(storage.get(&key).unwrap(), Some(value));
        }
    }
}
|