280 lines
8.9 KiB
Rust
280 lines
8.9 KiB
Rust
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
// Copyright by contributors to this project.
|
|
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
|
|
|
|
use mls_rs_core::{
|
|
key_package::{KeyPackageData, KeyPackageStorage},
|
|
mls_rs_codec::{MlsDecode, MlsEncode},
|
|
time::MlsTime,
|
|
};
|
|
use rusqlite::{params, Connection, OptionalExtension};
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
use crate::SqLiteDataStorageError;
|
|
|
|
#[derive(Debug, Clone)]
|
|
/// SQLite storage for MLS Key Packages.
|
|
pub struct SqLiteKeyPackageStorage {
|
|
connection: Arc<Mutex<Connection>>,
|
|
}
|
|
|
|
impl SqLiteKeyPackageStorage {
|
|
pub(crate) fn new(connection: Connection) -> SqLiteKeyPackageStorage {
|
|
SqLiteKeyPackageStorage {
|
|
connection: Arc::new(Mutex::new(connection)),
|
|
}
|
|
}
|
|
|
|
fn insert(
|
|
&mut self,
|
|
id: &[u8],
|
|
key_package: KeyPackageData,
|
|
) -> Result<(), SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
|
|
connection
|
|
.execute(
|
|
"INSERT INTO key_package (id, expiration, data) VALUES (?,?,?)",
|
|
params![
|
|
id,
|
|
key_package.expiration,
|
|
key_package
|
|
.mls_encode_to_vec()
|
|
.map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))?
|
|
],
|
|
)
|
|
.map(|_| ())
|
|
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
|
|
}
|
|
|
|
fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
|
|
connection
|
|
.query_row(
|
|
"SELECT data FROM key_package WHERE id = ?",
|
|
params![id],
|
|
|row| {
|
|
Ok(
|
|
KeyPackageData::mls_decode(&mut row.get::<_, Vec<u8>>(0)?.as_slice())
|
|
.unwrap(),
|
|
)
|
|
},
|
|
)
|
|
.optional()
|
|
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
|
|
}
|
|
|
|
/// Delete a specific key package from storage based on it's id.
|
|
pub fn delete(&self, id: &[u8]) -> Result<(), SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
|
|
connection
|
|
.execute("DELETE FROM key_package where id = ?", params![id])
|
|
.map(|_| ())
|
|
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
|
|
}
|
|
|
|
/// Delete key packages that are expired based on the current system clock time.
|
|
pub fn delete_expired(&self) -> Result<(), SqLiteDataStorageError> {
|
|
self.delete_expired_by_time(MlsTime::now().seconds_since_epoch())
|
|
}
|
|
|
|
/// Delete key packages that are expired based on an application provided time in seconds since
|
|
/// unix epoch.
|
|
pub fn delete_expired_by_time(&self, time: u64) -> Result<(), SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
|
|
connection
|
|
.execute(
|
|
"DELETE FROM key_package where expiration < ?",
|
|
params![time],
|
|
)
|
|
.map(|_| ())
|
|
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
|
|
}
|
|
|
|
/// Total number of key packages held in storage.
|
|
pub fn count(&self) -> Result<usize, SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
|
|
connection
|
|
.query_row("SELECT count(*) FROM key_package", params![], |row| {
|
|
row.get(0)
|
|
})
|
|
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
|
|
}
|
|
|
|
/// Total number of key packages that will still remain in storage at a specific application provided
|
|
/// time in seconds since unix epoch. This assumes that the application would also be calling
|
|
/// [SqLiteKeyPackageStorage::delete_expired] at a reasonable cadence to be accurate.
|
|
pub fn count_at_time(&self, time: u64) -> Result<usize, SqLiteDataStorageError> {
|
|
let connection = self.connection.lock().unwrap();
|
|
|
|
connection
|
|
.query_row(
|
|
"SELECT count(*) FROM key_package where expiration >= ?",
|
|
params![time],
|
|
|row| row.get(0),
|
|
)
|
|
.map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
|
|
}
|
|
}
|
|
|
|
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
|
|
#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
|
|
impl KeyPackageStorage for SqLiteKeyPackageStorage {
|
|
type Error = SqLiteDataStorageError;
|
|
|
|
async fn insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error> {
|
|
self.insert(id.as_slice(), pkg)
|
|
}
|
|
|
|
async fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error> {
|
|
self.get(id)
|
|
}
|
|
|
|
async fn delete(&mut self, id: &[u8]) -> Result<(), Self::Error> {
|
|
(*self).delete(id)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::SqLiteKeyPackageStorage;
    use crate::{
        SqLiteDataStorageEngine, SqLiteDataStorageError,
        {connection_strategy::MemoryStrategy, test_utils::gen_rand_bytes},
    };
    use assert_matches::assert_matches;
    use mls_rs_core::{crypto::HpkeSecretKey, key_package::KeyPackageData};

    /// Fresh in-memory key package storage for a single test.
    fn test_storage() -> SqLiteKeyPackageStorage {
        SqLiteDataStorageEngine::new(MemoryStrategy)
            .unwrap()
            .key_package_storage()
            .unwrap()
    }

    /// Random key package id plus a key package with random contents and expiration 123.
    fn test_key_package() -> (Vec<u8>, KeyPackageData) {
        let key_id = gen_rand_bytes(32);
        let key_package = KeyPackageData::new(
            gen_rand_bytes(256),
            HpkeSecretKey::from(gen_rand_bytes(256)),
            HpkeSecretKey::from(gen_rand_bytes(256)),
            123,
        );

        (key_id, key_package)
    }

    #[test]
    fn key_package_insert() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage
            .insert(&key_package_id, key_package.clone())
            .unwrap();

        let from_storage = storage.get(&key_package_id).unwrap().unwrap();
        assert_eq!(from_storage, key_package);
    }

    #[test]
    fn duplicate_insert_should_fail() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage
            .insert(&key_package_id, key_package.clone())
            .unwrap();

        let dupe_res = storage.insert(&key_package_id, key_package);

        assert_matches!(dupe_res, Err(SqLiteDataStorageError::SqlEngineError(_)));
    }

    #[test]
    fn key_package_not_found() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage.insert(&key_package_id, key_package).unwrap();

        let (another_package_id, _) = test_key_package();

        assert!(storage.get(&another_package_id).unwrap().is_none());
    }

    #[test]
    fn key_package_delete() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage.insert(&key_package_id, key_package).unwrap();

        storage.delete(&key_package_id).unwrap();
        assert!(storage.get(&key_package_id).unwrap().is_none());
    }

    #[test]
    fn expired_key_package_delete() {
        let mut storage = test_storage();

        // The last expiration is a fixed past timestamp, so `delete_expired` (which uses the
        // current system clock) is expected to remove it as well.
        let data = [1, 15, 30, 1698652376].map(|exp| {
            let mut kp = test_key_package();
            kp.1.expiration = exp;
            kp
        });

        for (id, data) in &data {
            storage.insert(id, data.clone()).unwrap();
        }

        // Expiration strictly below 30 is deleted; expiration == 30 survives.
        storage.delete_expired_by_time(30).unwrap();

        assert!(storage.get(&data[0].0).unwrap().is_none());
        assert!(storage.get(&data[1].0).unwrap().is_none());
        storage.get(&data[2].0).unwrap().unwrap();
        storage.get(&data[3].0).unwrap().unwrap();

        storage.delete_expired().unwrap();

        assert!(storage.get(&data[2].0).unwrap().is_none());
        assert!(storage.get(&data[3].0).unwrap().is_none());
    }

    #[test]
    fn key_count() {
        let mut storage = test_storage();

        for (key_package_id, key_package) in (0..10).map(|_| test_key_package()) {
            storage.insert(&key_package_id, key_package).unwrap();
        }

        assert_eq!(storage.count().unwrap(), 10);
    }

    #[test]
    fn key_count_at_time() {
        let mut storage = test_storage();

        let mut kp_1 = test_key_package();
        kp_1.1.expiration = 1;
        storage.insert(&kp_1.0, kp_1.1).unwrap();

        let mut kp_2 = test_key_package();
        kp_2.1.expiration = 2;
        storage.insert(&kp_2.0, kp_2.1).unwrap();

        assert_eq!(storage.count_at_time(3).unwrap(), 0);
        assert_eq!(storage.count_at_time(2).unwrap(), 1);
        assert_eq!(storage.count_at_time(1).unwrap(), 2);
        assert_eq!(storage.count_at_time(0).unwrap(), 2);
    }
}
|