diff --git a/src/error.rs b/src/error.rs index b3ab67be2a5fb62669b384157a0cc5cb00a43633..bd25d7aa8425cfd95e940d117be680e431241698 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,6 +12,8 @@ pub enum KomodoError { IncompatibleMatrixShapes(usize, usize, usize, usize), #[error("Expected at least {1}, got {0}")] TooFewShards(usize, usize), + #[error("Blocks are incompatible: {0}")] + IncompatibleBlocks(String), #[error("Another error: {0}")] Other(String), } diff --git a/src/lib.rs b/src/lib.rs index 75be1ca0367b48be3818ce2eaeb97a194f2cff83..7e416931c9fce58a567f7aebab3d821e6dd6d850 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,8 @@ mod field; mod linalg; pub mod setup; +use error::KomodoError; + #[derive(Debug, Default, Clone, PartialEq, CanonicalSerialize, CanonicalDeserialize)] pub struct Block<E: Pairing> { pub shard: fec::Shard<E>, @@ -194,17 +196,48 @@ where prove::<E, P>(commits, hash, bytes.len(), polynomials, &points) } -pub fn recode<E: Pairing>(b1: &Block<E>, b2: &Block<E>) -> Block<E> { +pub fn recode<E: Pairing>(b1: &Block<E>, b2: &Block<E>) -> Result<Block<E>, KomodoError> { let mut rng = rand::thread_rng(); let alpha = E::ScalarField::rand(&mut rng); let beta = E::ScalarField::rand(&mut rng); - Block { + if b1.shard.k != b2.shard.k { + return Err(KomodoError::IncompatibleBlocks(format!( + "k is not the same: {} vs {}", + b1.shard.k, b2.shard.k + ))); + } + if b1.shard.hash != b2.shard.hash { + return Err(KomodoError::IncompatibleBlocks(format!( + "hash is not the same: {:?} vs {:?}", + b1.shard.hash, b2.shard.hash + ))); + } + if b1.shard.size != b2.shard.size { + return Err(KomodoError::IncompatibleBlocks(format!( + "size is not the same: {} vs {}", + b1.shard.size, b2.shard.size + ))); + } + if b1.m != b2.m { + return Err(KomodoError::IncompatibleBlocks(format!( + "m is not the same: {} vs {}", + b1.m, b2.m + ))); + } + if b1.commit != b2.commit { + return Err(KomodoError::IncompatibleBlocks(format!( + "commits are not the same: {:?} vs {:?}", + b1.commit, b2.commit + ))); + } + + Ok(Block { shard: b1.shard.combine(alpha, &b2.shard, beta), commit: b1.commit.clone(), m: b1.m, - } + }) } pub fn verify<E, P>( @@ -431,8 +464,14 @@ mod tests { let powers = setup::random(bytes.len())?; let blocks = encode::<E, P>(bytes, k, n, &powers)?; - assert!(verify::<E, P>(&recode(&blocks[2], &blocks[3]), &powers)?); - assert!(verify::<E, P>(&recode(&blocks[3], &blocks[5]), &powers)?); + assert!(verify::<E, P>( + &recode(&blocks[2], &blocks[3]).unwrap(), + &powers + )?); + assert!(verify::<E, P>( + &recode(&blocks[3], &blocks[5]).unwrap(), + &powers + )?); Ok(()) } @@ -503,7 +542,7 @@ mod tests { let powers = setup::random(bytes.len())?; let blocks = encode::<E, P>(bytes, 3, 5, &powers)?; - let b_0_1 = recode(&blocks[0], &blocks[1]); + let b_0_1 = recode(&blocks[0], &blocks[1]).unwrap(); let shards = vec![ b_0_1.shard, blocks[2].shard.clone(), @@ -511,7 +550,7 @@ mod tests { ]; assert_eq!(bytes, decode::<E>(shards, true).unwrap()); - let b_0_1 = recode(&blocks[0], &blocks[1]); + let b_0_1 = recode(&blocks[0], &blocks[1]).unwrap(); let shards = vec![ blocks[0].shard.clone(), blocks[1].shard.clone(), @@ -519,14 +558,18 @@ mod tests { ]; assert!(decode::<E>(shards, true).is_err()); - let b_0_1 = recode(&blocks[0], &blocks[1]); - let b_2_3 = recode(&blocks[2], &blocks[3]); - let b_1_4 = recode(&blocks[1], &blocks[4]); + let b_0_1 = recode(&blocks[0], &blocks[1]).unwrap(); + let b_2_3 = recode(&blocks[2], &blocks[3]).unwrap(); + let b_1_4 = recode(&blocks[1], &blocks[4]).unwrap(); let shards = vec![b_0_1.shard, b_2_3.shard, b_1_4.shard]; assert_eq!(bytes, decode::<E>(shards, true).unwrap()); let fully_recoded_shards = (0..3) - .map(|_| recode(&recode(&blocks[0], &blocks[1]), &blocks[2]).shard) + .map(|_| { + recode(&recode(&blocks[0], &blocks[1]).unwrap(), &blocks[2]) + .unwrap() + .shard + }) .collect(); assert_eq!(bytes, decode::<E>(fully_recoded_shards, true).unwrap()); diff --git a/src/main.rs b/src/main.rs index b130817d0eca5e3905b6161cc69db283bf45d2a0..4d4080fca9ed4c7fc5657f86f5508cce738fb37b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -272,8 +272,14 @@ fn main() { ); } - dump_blocks(&[recode(&blocks[0].1, &blocks[1].1)], &block_dir) - .unwrap_or_else(|e| throw_error(1, &format!("could not dump block: {}", e))); + dump_blocks( + &[recode(&blocks[0].1, &blocks[1].1).unwrap_or_else(|e| { + throw_error(1, &format!("could not encode block: {}", e)); + unreachable!() + })], + &block_dir, + ) + .unwrap_or_else(|e| throw_error(1, &format!("could not dump block: {}", e))); exit(0); }