Skip to content

Commit

Permalink
Consume gas before verifying range proof (#1987)
Browse files Browse the repository at this point in the history
* Consume gas before verifying range proof

* bump timeout

* uint64

* test issue
  • Loading branch information
mj850 authored Dec 13, 2024
1 parent 0999e20 commit e19c16c
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 91 deletions.
2 changes: 1 addition & 1 deletion proto/confidentialtransfers/params.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ option go_package = "github.com/sei-protocol/sei-chain/x/confidentialtransfers/t
// Params defines the parameters for the confidential tokens module.
message Params {
bool enable_ct_module = 1;
uint32 range_proof_gas_multiplier = 2;
uint64 range_proof_gas_cost = 2;
}
8 changes: 4 additions & 4 deletions x/confidentialtransfers/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type Keeper interface {
SetParams(ctx sdk.Context, params types.Params)

IsCtModuleEnabled(ctx sdk.Context) bool
GetRangeProofGasMultiplier(ctx sdk.Context) uint32
GetRangeProofGasCost(ctx sdk.Context) uint64

BankKeeper() types.BankKeeper

Expand Down Expand Up @@ -153,9 +153,9 @@ func (k BaseKeeper) IsCtModuleEnabled(ctx sdk.Context) bool {
return isCtModuleEnabled
}

// GetRangeProofGasMultiplier retrieves the value of the RangeProofGasMultiplier param from the parameter store
func (k BaseKeeper) GetRangeProofGasMultiplier(ctx sdk.Context) uint32 {
var rangeProofGas uint32
// GetRangeProofGasCost retrieves the value of the RangeProofGasCost param from the parameter store
func (k BaseKeeper) GetRangeProofGasCost(ctx sdk.Context) uint64 {
var rangeProofGas uint64
k.paramSpace.Get(ctx, types.KeyRangeProofGas, &rangeProofGas)
return rangeProofGas
}
Expand Down
31 changes: 14 additions & 17 deletions x/confidentialtransfers/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,13 @@ func (m msgServer) Withdraw(goCtx context.Context, req *types.MsgWithdraw) (*typ
// Verify that the account has sufficient funds (Remaining balance after making the transfer is greater than or equal to zero.)
// This range proof verification is performed on the RemainingBalanceCommitment sent by the user.
// An additional check is required to ensure that this matches the remaining balance calculated by the server.

// Consume additional gas as range proofs are computationally expensive.
cost := m.Keeper.GetRangeProofGasCost(ctx)
if cost > 0 {
ctx.GasMeter().ConsumeGas(cost, "range proof verification")
}

verified, _ := zkproofs.VerifyRangeProof(instruction.Proofs.RemainingBalanceRangeProof, instruction.RemainingBalanceCommitment, 128, m.CachedRangeVerifierFactory)
if !verified {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "range proof verification failed")
Expand Down Expand Up @@ -257,14 +264,6 @@ func (m msgServer) Withdraw(goCtx context.Context, req *types.MsgWithdraw) (*typ
return nil, sdkerrors.Wrapf(sdkerrors.ErrInsufficientFunds, "insufficient funds to withdraw %s %s", req.Amount, req.Denom)
}

gasSoFar := ctx.GasMeter().GasConsumed()
multiplier := m.Keeper.GetRangeProofGasMultiplier(ctx)

// Consume additional gas according to the multiplier as range proofs are computationally expensive.
if multiplier > 1 {
ctx.GasMeter().ConsumeGas(gasSoFar*uint64(multiplier-1), "range proof verification")
}

// Emit any required events
//TODO: Look into whether we can use EmitTypedEvents instead since EmitEvents is deprecated
ctx.EventManager().EmitEvents(sdk.Events{
Expand Down Expand Up @@ -444,6 +443,13 @@ func (m msgServer) Transfer(goCtx context.Context, req *types.MsgTransfer) (*typ
}

// Validate proofs
rangeProofGasCost := m.Keeper.GetRangeProofGasCost(ctx)

// Consume additional gas as range proofs are computationally expensive.
if rangeProofGasCost > 0 {
ctx.GasMeter().ConsumeGas(rangeProofGasCost, "range proof verification")
}

err = types.VerifyTransferProofs(instruction, &senderAccount.PublicKey, &recipientAccount.PublicKey, newSenderBalanceCiphertext, m.CachedRangeVerifierFactory)
if err != nil {
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, err.Error())
Expand Down Expand Up @@ -498,15 +504,6 @@ func (m msgServer) Transfer(goCtx context.Context, req *types.MsgTransfer) (*typ
return nil, sdkerrors.Wrap(sdkerrors.ErrInvalidRequest, "error setting recipient account")
}

gasSoFar := ctx.GasMeter().GasConsumed()
multiplier := m.Keeper.GetRangeProofGasMultiplier(ctx)

// Consume additional gas according to the multiplier as range proofs are computationally expensive.
// gasSoFar + ((multiplier-1) x gasSoFar) = multiplier x gasSoFar
if multiplier > 1 {
ctx.GasMeter().ConsumeGas(gasSoFar*uint64(multiplier-1), "range proof verification")
}

// Emit any required events
ctx.EventManager().EmitEvents(sdk.Events{
sdk.NewEvent(
Expand Down
2 changes: 1 addition & 1 deletion x/confidentialtransfers/types/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (
var (
AccountsKeyPrefix = []byte{0x01}
KeyEnableCtModule = []byte("EnableCtModule")
KeyRangeProofGas = []byte("RangeProofGasMultiplier")
KeyRangeProofGas = []byte("RangeProofGasCost")
)

// GetAddressPrefix generates the prefix for all accounts under a specific address
Expand Down
19 changes: 8 additions & 11 deletions x/confidentialtransfers/types/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
// DefaultEnableCtModule is the default value for the EnableCtModule flag.
const DefaultEnableCtModule = true

// DefaultRangeProofGasMultiplier is the default value for RangeProofGasMultiplier param.
const DefaultRangeProofGasMultiplier = uint32(10)
// DefaultRangeProofGasCost is the default value for RangeProofGasCost param.
const DefaultRangeProofGasCost = uint64(1000000)

// ParamKeyTable ParamTable for confidential transfers module.
func ParamKeyTable() paramtypes.KeyTable {
Expand All @@ -20,8 +20,8 @@ func ParamKeyTable() paramtypes.KeyTable {
// DefaultParams default confidential transfers module parameters.
func DefaultParams() Params {
return Params{
EnableCtModule: DefaultEnableCtModule,
RangeProofGasMultiplier: DefaultRangeProofGasMultiplier,
EnableCtModule: DefaultEnableCtModule,
RangeProofGasCost: DefaultRangeProofGasCost,
}
}

Expand All @@ -31,7 +31,7 @@ func (p *Params) Validate() error {
return err
}

if err := validateRangeProofGasMultiplier(p.RangeProofGasMultiplier); err != nil {
if err := validateRangeProofGasCost(p.RangeProofGasCost); err != nil {
return err
}

Expand All @@ -42,7 +42,7 @@ func (p *Params) Validate() error {
func (p *Params) ParamSetPairs() paramtypes.ParamSetPairs {
return paramtypes.ParamSetPairs{
paramtypes.NewParamSetPair(KeyEnableCtModule, &p.EnableCtModule, validateEnableCtModule),
paramtypes.NewParamSetPair(KeyRangeProofGas, &p.RangeProofGasMultiplier, validateRangeProofGasMultiplier),
paramtypes.NewParamSetPair(KeyRangeProofGas, &p.RangeProofGasCost, validateRangeProofGasCost),
}
}

Expand All @@ -56,14 +56,11 @@ func validateEnableCtModule(i interface{}) error {
}

// Validator for the parameter.
func validateRangeProofGasMultiplier(i interface{}) error {
multiplier, ok := i.(uint32)
func validateRangeProofGasCost(i interface{}) error {
_, ok := i.(uint64)
if !ok {
return fmt.Errorf("invalid parameter type: %T", i)
}

if multiplier < 1 {
return fmt.Errorf("range proof gas multiplier must be greater than 0")
}
return nil
}
52 changes: 26 additions & 26 deletions x/confidentialtransfers/types/params.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

53 changes: 22 additions & 31 deletions x/confidentialtransfers/types/params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ func TestDefaultParams(t *testing.T) {
{
name: "default params",
want: Params{
EnableCtModule: DefaultEnableCtModule,
RangeProofGasMultiplier: DefaultRangeProofGasMultiplier,
EnableCtModule: DefaultEnableCtModule,
RangeProofGasCost: DefaultRangeProofGasCost,
},
},
}
Expand All @@ -33,8 +33,8 @@ func TestDefaultParams(t *testing.T) {

func TestParams_Validate(t *testing.T) {
type fields struct {
EnableCtModule bool
RangeProofGasMultiplier uint32
EnableCtModule bool
RangeProofGasCost uint64
}
tests := []struct {
name string
Expand All @@ -45,26 +45,17 @@ func TestParams_Validate(t *testing.T) {
{
name: "valid params",
fields: fields{
EnableCtModule: true,
RangeProofGasMultiplier: 10,
EnableCtModule: true,
RangeProofGasCost: 1000000,
},
wantErr: false,
},
{
name: "invalid params",
fields: fields{
EnableCtModule: true,
RangeProofGasMultiplier: 0,
},
wantErr: true,
errMsg: "range proof gas multiplier must be greater than 0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := Params{
EnableCtModule: tt.fields.EnableCtModule,
RangeProofGasMultiplier: tt.fields.RangeProofGasMultiplier,
EnableCtModule: tt.fields.EnableCtModule,
RangeProofGasCost: tt.fields.RangeProofGasCost,
}
err := p.Validate()
if (err != nil) != tt.wantErr {
Expand All @@ -91,34 +82,34 @@ func TestValidateEnableCtModule(t *testing.T) {
})
}

func TestValidateRangeProofGasMultiplier(t *testing.T) {
t.Run("valid multiplier", func(t *testing.T) {
multiplier := uint32(10)
err := validateRangeProofGasMultiplier(multiplier)
func TestValidateRangeProofGasCost(t *testing.T) {
t.Run("valid cost", func(t *testing.T) {
cost := uint64(1000000)
err := validateRangeProofGasCost(cost)
assert.Nil(t, err)
})

t.Run("valid but useless multiplier value", func(t *testing.T) {
flag := uint32(1)
err := validateRangeProofGasMultiplier(flag)
t.Run("valid but useless gas cost", func(t *testing.T) {
flag := uint64(0)
err := validateRangeProofGasCost(flag)
assert.Nil(t, err)
})

t.Run("invalid multiplier value", func(t *testing.T) {
flag := uint32(0)
err := validateRangeProofGasMultiplier(flag)
t.Run("invalid gas cost", func(t *testing.T) {
flag := -1
err := validateRangeProofGasCost(flag)
assert.Error(t, err)
})

t.Run("invalid multiplier type", func(t *testing.T) {
t.Run("invalid gas cost type", func(t *testing.T) {
flag := "True"
err := validateRangeProofGasMultiplier(flag)
err := validateRangeProofGasCost(flag)
assert.Error(t, err)
})
}

func TestParams_ParamSetPairs(t *testing.T) {
params := &Params{EnableCtModule: DefaultEnableCtModule, RangeProofGasMultiplier: DefaultRangeProofGasMultiplier}
params := &Params{EnableCtModule: DefaultEnableCtModule, RangeProofGasCost: DefaultRangeProofGasCost}
tests := []struct {
name string
want types.ParamSetPairs
Expand All @@ -127,7 +118,7 @@ func TestParams_ParamSetPairs(t *testing.T) {
name: "valid param set pairs",
want: types.ParamSetPairs{
types.NewParamSetPair(KeyEnableCtModule, &params.EnableCtModule, validateEnableCtModule),
types.NewParamSetPair(KeyRangeProofGas, &params.RangeProofGasMultiplier, validateRangeProofGasMultiplier),
types.NewParamSetPair(KeyRangeProofGas, &params.RangeProofGasCost, validateRangeProofGasCost),
},
},
}
Expand Down

0 comments on commit e19c16c

Please sign in to comment.