1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
use std::io;

use dusk_plonk::bls12_381::Scalar as BlsScalar;
use kelvin::{
    annotation,
    annotations::{Cardinality, Count},
    Branch, BranchMut, ByteHash, Compound, Content, Sink, Source,
};
use nstack::NStack;

use crate::merkle_lvl_hash::hash::*;
use crate::merkle_proof::poseidon_branch::extend_scalar;
use crate::ARITY;
use crate::{PoseidonBranch, PoseidonLevel, StorageScalar};

annotation! {
    /// The annotation for the PoseidonTree
    ///
    /// Combines a `StorageScalar` with a `Cardinality<u64>`.
    pub struct PoseidonAnnotation {
        // Scalar part; `PoseidonTree::root` borrows this value from the
        // top-level annotation and uses it as the basis of the root hash.
        scalar: StorageScalar,
        // Cardinality part; presumably what backs the `count()` call used by
        // `PoseidonTree::push` -- TODO confirm against kelvin's `Count`.
        count: Cardinality<u64>,
    }
}

/// A zk-friendly datastructure to store elements
///
/// Elements are kept in an `NStack` annotated with `PoseidonAnnotation`.
/// A reference to any element must be convertible into a `StorageScalar`.
pub struct PoseidonTree<T, H>
where
    T: Content<H>,
    for<'a> &'a T: Into<StorageScalar>,
    H: ByteHash,
{
    /// Depth that `root` and `poseidon_branch` extend their results to.
    branch_depth: u16,
    /// Backing stack that holds the pushed elements.
    inner: NStack<T, PoseidonAnnotation, H>,
}

/// Cloning copies the configured depth and clones the backing stack.
///
/// The bounds on this impl do not require `T: Clone` or `H: Clone`.
impl<T, H> Clone for PoseidonTree<T, H>
where
    T: Content<H>,
    for<'a> &'a T: Into<StorageScalar>,
    H: ByteHash,
{
    fn clone(&self) -> Self {
        let branch_depth = self.branch_depth;
        let inner = self.inner.clone();

        Self {
            branch_depth,
            inner,
        }
    }
}

impl<T, H> Content<H> for PoseidonTree<T, H>
where
    T: Content<H>,
    for<'a> &'a T: Into<StorageScalar>,
    H: ByteHash,
{
    fn persist(&mut self, sink: &mut Sink<H>) -> io::Result<()> {
        self.branch_depth.persist(sink)?;
        self.inner.persist(sink)
    }

    fn restore(source: &mut Source<H>) -> io::Result<Self> {
        Ok(PoseidonTree {
            branch_depth: u16::restore(source)?,
            inner: NStack::restore(source)?,
        })
    }
}

impl<T, H> PoseidonTree<T, H>
where
    T: Content<H>,
    for<'a> &'a T: Into<StorageScalar>,
    H: ByteHash,
{
    /// Constructs a new empty PoseidonTree
    pub fn new(depth: usize) -> Self {
        PoseidonTree {
            branch_depth: depth as u16,
            inner: Default::default(),
        }
    }

    /// Returns the scalar root-hash of the poseidon tree
    ///
    /// This includes padding the value to the correct branch length equivalent
    pub fn root(&self) -> io::Result<BlsScalar> {
        if let Some(ann) = self.inner.annotation() {
            let borrow: &StorageScalar = ann.borrow();
            let scalar: BlsScalar = borrow.clone().into();

            // FIXME, depth could be inferred from the cardinality
            if let Some(branch) = self.get(0)? {
                let depth = branch.levels().len();
                Ok(extend_scalar(scalar, self.branch_depth as usize - depth))
            } else {
                unreachable!("Annotation in empty tree")
            }
        } else {
            // empty case, use an empty level for hashing
            let leaves = [BlsScalar::zero(); ARITY + 1];
            let level = PoseidonLevel { leaves, offset: 0 };
            let root = merkle_level_hash_without_bitflags(&level);
            Ok(extend_scalar(root, self.branch_depth as usize))
        }
    }

    /// Returns a poseidon branch pointing at the specific index
    ///
    /// This includes padding the value to the correct branch length equivalent
    pub fn poseidon_branch(
        &self,
        idx: u64,
    ) -> io::Result<Option<PoseidonBranch>> {
        Ok(self.inner.get(idx)?.map(|ref branch| {
            let mut pbranch: PoseidonBranch = branch.into();
            pbranch.extend(self.branch_depth as usize);
            pbranch
        }))
    }

    /// Push a new item onto the tree
    pub fn push(&mut self, t: T) -> io::Result<u64> {
        let idx = self.inner.count();
        self.inner.push(t)?;
        Ok(idx)
    }

    /// Get a branch reference to the element at index `idx`, if any
    pub fn get(
        &self,
        idx: u64,
    ) -> io::Result<Option<Branch<NStack<T, PoseidonAnnotation, H>, H>>> {
        self.inner.get(idx)
    }

    /// Get a mutable branch reference to the element at index `idx`, if any
    pub fn get_mut(
        &mut self,
        idx: u64,
    ) -> io::Result<Option<BranchMut<NStack<T, PoseidonAnnotation, H>, H>>>
    {
        self.inner.get_mut(idx)
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use kelvin::Blake2b;

    /// Pushing elements one after another must hand back consecutive
    /// indices starting at zero.
    #[test]
    fn insert() {
        let mut tree = PoseidonTree::<_, Blake2b>::new(17);

        for i in 0..128u64 {
            let idx = tree.push(StorageScalar::from(i)).unwrap();
            assert_eq!(idx, i);
        }
    }
}