1
+ use crate :: error:: NTTError ;
1
2
use anchor_lang:: prelude:: * ;
2
3
use bitmaps:: Bitmap as BM ;
3
-
4
+ use std :: result :: Result as StdResult ;
4
5
#[ derive( PartialEq , Eq , Clone , Copy , Debug , AnchorDeserialize , AnchorSerialize , InitSpace ) ]
5
6
pub struct Bitmap {
6
7
map : u128 ,
@@ -13,6 +14,8 @@ impl Default for Bitmap {
13
14
}
14
15
15
16
impl Bitmap {
17
+ pub const BITS : u8 = 128 ;
18
+
16
19
pub fn new ( ) -> Self {
17
20
Bitmap { map : 0 }
18
21
}
@@ -21,19 +24,30 @@ impl Bitmap {
21
24
Bitmap { map : value }
22
25
}
23
26
24
- pub fn set ( & mut self , index : u8 , value : bool ) {
27
+ pub fn set ( & mut self , index : u8 , value : bool ) -> StdResult < ( ) , NTTError > {
28
+ if index >= Self :: BITS {
29
+ return Err ( NTTError :: BitmapIndexOutOfBounds ) ;
30
+ }
25
31
let mut bm = BM :: < 128 > :: from_value ( self . map ) ;
26
- bm. set ( index as usize , value) ;
32
+ bm. set ( usize:: from ( index ) , value) ;
27
33
self . map = * bm. as_value ( ) ;
34
+ Ok ( ( ) )
28
35
}
29
36
30
- pub fn get ( & self , index : u8 ) -> bool {
31
- BM :: < 128 > :: from_value ( self . map ) . get ( index as usize )
37
+ pub fn get ( & self , index : u8 ) -> StdResult < bool , NTTError > {
38
+ if index >= Self :: BITS {
39
+ return Err ( NTTError :: BitmapIndexOutOfBounds ) ;
40
+ }
41
+ Ok ( BM :: < 128 > :: from_value ( self . map ) . get ( usize:: from ( index) ) )
32
42
}
33
43
34
44
pub fn count_enabled_votes ( & self , enabled : Bitmap ) -> u8 {
35
45
let bm = BM :: < 128 > :: from_value ( self . map ) & BM :: < 128 > :: from_value ( enabled. map ) ;
36
- bm. len ( ) as u8
46
+ // Conversion from usize to u8 is safe here. The Bitmap uses u128, so its maximum length
47
+ // (number of true bits) is 128.
48
+ bm. len ( )
49
+ . try_into ( )
50
+ . expect ( "Bitmap length must not exceed the bounds of u8" )
37
51
}
38
52
}
39
53
@@ -46,22 +60,40 @@ mod tests {
46
60
let mut enabled = Bitmap :: from_value ( u128:: MAX ) ;
47
61
let mut bm = Bitmap :: new ( ) ;
48
62
assert_eq ! ( bm. count_enabled_votes( enabled) , 0 ) ;
49
- bm. set ( 0 , true ) ;
63
+ bm. set ( 0 , true ) . unwrap ( ) ;
50
64
assert_eq ! ( bm. count_enabled_votes( enabled) , 1 ) ;
51
- assert ! ( bm. get( 0 ) ) ;
52
- assert ! ( !bm. get( 1 ) ) ;
53
- bm. set ( 1 , true ) ;
65
+ assert ! ( bm. get( 0 ) . unwrap ( ) ) ;
66
+ assert ! ( !bm. get( 1 ) . unwrap ( ) ) ;
67
+ bm. set ( 1 , true ) . unwrap ( ) ;
54
68
assert_eq ! ( bm. count_enabled_votes( enabled) , 2 ) ;
55
- assert ! ( bm. get( 0 ) ) ;
56
- assert ! ( bm. get( 1 ) ) ;
57
- bm. set ( 0 , false ) ;
69
+ assert ! ( bm. get( 0 ) . unwrap ( ) ) ;
70
+ assert ! ( bm. get( 1 ) . unwrap ( ) ) ;
71
+ bm. set ( 0 , false ) . unwrap ( ) ;
58
72
assert_eq ! ( bm. count_enabled_votes( enabled) , 1 ) ;
59
- assert ! ( !bm. get( 0 ) ) ;
60
- assert ! ( bm. get( 1 ) ) ;
61
- bm. set ( 18 , true ) ;
73
+ assert ! ( !bm. get( 0 ) . unwrap ( ) ) ;
74
+ assert ! ( bm. get( 1 ) . unwrap ( ) ) ;
75
+ bm. set ( 18 , true ) . unwrap ( ) ;
62
76
assert_eq ! ( bm. count_enabled_votes( enabled) , 2 ) ;
63
77
64
- enabled. set ( 18 , false ) ;
78
+ enabled. set ( 18 , false ) . unwrap ( ) ;
65
79
assert_eq ! ( bm. count_enabled_votes( enabled) , 1 ) ;
66
80
}
81
+
82
+ #[ test]
83
+ fn test_bitmap_len ( ) {
84
+ let max_bitmap = Bitmap :: from_value ( u128:: MAX ) ;
85
+ assert_eq ! ( 128 , max_bitmap. count_enabled_votes( max_bitmap) ) ;
86
+ }
87
+
88
+ #[ test]
89
+ fn test_bitmap_get_out_of_bounds ( ) {
90
+ let bm = Bitmap :: new ( ) ;
91
+ assert_eq ! ( bm. get( 129 ) , Err ( NTTError :: BitmapIndexOutOfBounds ) ) ;
92
+ }
93
+
94
+ #[ test]
95
+ fn test_bitmap_set_out_of_bounds ( ) {
96
+ let mut bm = Bitmap :: new ( ) ;
97
+ assert_eq ! ( bm. set( 129 , false ) , Err ( NTTError :: BitmapIndexOutOfBounds ) ) ;
98
+ }
67
99
}
0 commit comments