Skip to content

Commit

Permalink
fix: assert that the binary decomposition of a variable is less than …
Browse files Browse the repository at this point in the history
…the modulus (#835)

The issue was reported by Marcin Kostrzewa @ Reilabs (@kustosz).

* test: add internal regression test pacakge for tracking filed bugs

* test: add test for reproducing non-unique binary decomposition

* refactor: make comparison against a constant bound public

Even though the method is public this method is not listable as we export
interfaces with smaller method sets. The method can be accessed by implicitly
implementing interface with the method `MustBeLessOrEqCst(aBits
[]frontend.Variable, bound *big.Int, aForDebug frontend.Variable)`.

We use the method for checking in `std/math/bits` package that the binary
decomposition of the bound returned by hint is less than the modulus.

* feat: add option to omit uniqueness check when binary decomposing value

* fix: check that the binary decomposition is unique when nbBits=modlen

* refactor: use bits gadget directly for option control

* docs: describe in documentation alternatives to Cmp and LEQ

* test: add test for math/cmp cases

* feat: add TestEngine checking

* test: reduce binary decomposition length to accomodate tinyfield

* refactor: rename to OmitModulusCheck

* fix: limit decomposition length to fieldbitlen

* feat: allow less than nbBits in constant comparison

* test: update circuit statistics

* feat: allow decomposition length be longer than field length

* Revert "feat: add TestEngine checking"

This reverts commit 8da5d07.

* test: implement constant comparison for test engine

* test: rename file to track issue
  • Loading branch information
ivokub authored Sep 19, 2023
1 parent 7f38c21 commit 59a4087
Show file tree
Hide file tree
Showing 13 changed files with 346 additions and 45 deletions.
15 changes: 13 additions & 2 deletions frontend/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,14 @@ type API interface {
// IsZero returns 1 if a is zero, 0 otherwise
IsZero(i1 Variable) Variable

// Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1<i2
// Cmp returns:
// * 1 if i1>i2,
// * 0 if i1=i2,
// * -1 if i1<i2.
//
// If the absolute difference between the variables i1 and i2 is known, then
// it is more efficient to use the bounded methdods in package
// [github.com/consensys/gnark/std/math/bits].
Cmp(i1, i2 Variable) Variable

// ---------------------------------------------------------------------------------------------
Expand All @@ -115,7 +122,11 @@ type API interface {
// AssertIsBoolean fails if v != 0 ∥ v != 1
AssertIsBoolean(i1 Variable)

// AssertIsLessOrEqual fails if v > bound
// AssertIsLessOrEqual fails if v > bound.
//
// If the absolute difference between the variables b and bound is known, then
// it is more efficient to use the bounded methdods in package
// [github.com/consensys/gnark/std/math/bits].
AssertIsLessOrEqual(v Variable, bound Variable)

// Println behaves like fmt.Println but accepts cd.Variable as parameter
Expand Down
12 changes: 8 additions & 4 deletions frontend/cs/r1cs/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ package r1cs
import (
"errors"
"fmt"
"github.com/consensys/gnark/internal/utils"
"path/filepath"
"reflect"
"runtime"
"strings"

"github.com/consensys/gnark/internal/utils"

"github.com/consensys/gnark/debug"
"github.com/consensys/gnark/frontend/cs"

Expand Down Expand Up @@ -570,9 +571,12 @@ func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable {
// Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1<i2
func (builder *builder) Cmp(i1, i2 frontend.Variable) frontend.Variable {

vars, _ := builder.toVariables(i1, i2)
bi1 := builder.ToBinary(vars[0], builder.cs.FieldBitLen())
bi2 := builder.ToBinary(vars[1], builder.cs.FieldBitLen())
nbBits := builder.cs.FieldBitLen()
// in AssertIsLessOrEq we omitted comparison against modulus for the left
// side as if `a+r<b` implies `a<b`, then here we compute the inequality
// directly.
bi1 := bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits))
bi2 := bits.ToBinary(builder, i2, bits.WithNbDigits(nbBits))

res := builder.cstZero()

Expand Down
27 changes: 17 additions & 10 deletions frontend/cs/r1cs/api_assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.
}
}

nbBits := builder.cs.FieldBitLen()
vBits := bits.ToBinary(builder, v, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())

// bound is constant
if bConst {
vv := builder.toVariable(v)
builder.mustBeLessOrEqCst(vv, builder.cs.ToBigInt(cb))
builder.MustBeLessOrEqCst(vBits, builder.cs.ToBigInt(cb), v)
return
}

Expand All @@ -119,8 +121,8 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) {

nbBits := builder.cs.FieldBitLen()

aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())
boundBits := builder.ToBinary(bound, nbBits)
aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs(), bits.OmitModulusCheck())
boundBits := bits.ToBinary(builder, bound, bits.WithNbDigits(nbBits))

// constraint added
added := make([]int, 0, nbBits)
Expand Down Expand Up @@ -166,9 +168,18 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) {

}

func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound *big.Int) {
// MustBeLessOrEqCst asserts that value represented using its bit decomposition
// aBits is less or equal than constant bound. The method boolean constraints
// the bits in aBits, so the caller can provide unconstrained bits.
func (builder *builder) MustBeLessOrEqCst(aBits []frontend.Variable, bound *big.Int, aForDebug frontend.Variable) {

nbBits := builder.cs.FieldBitLen()
if len(aBits) > nbBits {
panic("more input bits than field bit length")
}
for i := len(aBits); i < nbBits; i++ {
aBits = append(aBits, 0)
}

// ensure the bound is positive, it's bit-len doesn't matter
if bound.Sign() == -1 {
Expand All @@ -179,11 +190,7 @@ func (builder *builder) mustBeLessOrEqCst(a expr.LinearExpression, bound *big.In
}

// debug info
debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", builder.toVariable(bound))

// note that at this stage, we didn't boolean-constraint these new variables yet
// (as opposed to ToBinary)
aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())
debug := builder.newDebugInfo("mustBeLessOrEq", aForDebug, " <= ", builder.toVariable(bound))

// t trailing bits in the bound
t := 0
Expand Down
8 changes: 6 additions & 2 deletions frontend/cs/scs/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,12 @@ func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable {
// Cmp returns 1 if i1>i2, 0 if i1=i2, -1 if i1<i2
func (builder *builder) Cmp(i1, i2 frontend.Variable) frontend.Variable {

bi1 := builder.ToBinary(i1, builder.cs.FieldBitLen())
bi2 := builder.ToBinary(i2, builder.cs.FieldBitLen())
nbBits := builder.cs.FieldBitLen()
// in AssertIsLessOrEq we omitted comparison against modulus for the left
// side as if `a+r<b` implies `a<b`, then here we compute the inequality
// directly.
bi1 := bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits))
bi2 := bits.ToBinary(builder, i2, bits.WithNbDigits(nbBits))

var res frontend.Variable
res = 0
Expand Down
57 changes: 36 additions & 21 deletions frontend/cs/scs/api_assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/consensys/gnark/debug"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/frontend/internal/expr"
"github.com/consensys/gnark/internal/utils"
"github.com/consensys/gnark/std/math/bits"
)

Expand Down Expand Up @@ -131,11 +130,30 @@ func (builder *builder) AssertIsBoolean(i1 frontend.Variable) {

// AssertIsLessOrEqual fails if v > bound
func (builder *builder) AssertIsLessOrEqual(v frontend.Variable, bound frontend.Variable) {
switch b := bound.(type) {
case expr.Term:
cv, vConst := builder.constantValue(v)
cb, bConst := builder.constantValue(bound)

// both inputs are constants
if vConst && bConst {
bv, bb := builder.cs.ToBigInt(cv), builder.cs.ToBigInt(cb)
if bv.Cmp(bb) == 1 {
panic(fmt.Sprintf("AssertIsLessOrEqual: %s > %s", bv.String(), bb.String()))
}
}

nbBits := builder.cs.FieldBitLen()
vBits := bits.ToBinary(builder, v, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())

// bound is constant
if bConst {
builder.MustBeLessOrEqCst(vBits, builder.cs.ToBigInt(cb), v)
return
}

if b, ok := bound.(expr.Term); ok {
builder.mustBeLessOrEqVar(v, b)
default:
builder.mustBeLessOrEqCst(v, utils.FromInterface(b))
} else {
panic(fmt.Sprintf("expected bound type expr.Term, got %T", bound))
}
}

Expand All @@ -145,8 +163,8 @@ func (builder *builder) mustBeLessOrEqVar(a frontend.Variable, bound expr.Term)

nbBits := builder.cs.FieldBitLen()

aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())
boundBits := builder.ToBinary(bound, nbBits)
aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs(), bits.OmitModulusCheck())
boundBits := bits.ToBinary(builder, bound, bits.WithNbDigits(nbBits)) // enforces range check against modulus

p := make([]frontend.Variable, nbBits+1)
p[nbBits] = 1
Expand Down Expand Up @@ -191,9 +209,18 @@ func (builder *builder) mustBeLessOrEqVar(a frontend.Variable, bound expr.Term)

}

func (builder *builder) mustBeLessOrEqCst(a frontend.Variable, bound big.Int) {
// MustBeLessOrEqCst asserts that value represented using its bit decomposition
// aBits is less or equal than constant bound. The method boolean constraints
// the bits in aBits, so the caller can provide unconstrained bits.
func (builder *builder) MustBeLessOrEqCst(aBits []frontend.Variable, bound *big.Int, aForDebug frontend.Variable) {

nbBits := builder.cs.FieldBitLen()
if len(aBits) > nbBits {
panic("more input bits than field bit length")
}
for i := len(aBits); i < nbBits; i++ {
aBits = append(aBits, 0)
}

// ensure the bound is positive, it's bit-len doesn't matter
if bound.Sign() == -1 {
Expand All @@ -203,20 +230,8 @@ func (builder *builder) mustBeLessOrEqCst(a frontend.Variable, bound big.Int) {
panic("AssertIsLessOrEqual: bound is too large, constraint will never be satisfied")
}

if ca, ok := builder.constantValue(a); ok {
// a is constant, compare the big int values
ba := builder.cs.ToBigInt(ca)
if ba.Cmp(&bound) == 1 {
panic(fmt.Sprintf("AssertIsLessOrEqual: %s > %s", ba.String(), bound.String()))
}
}

// debug info
debug := builder.newDebugInfo("mustBeLessOrEq", a, " <= ", bound)

// note that at this stage, we didn't boolean-constraint these new variables yet
// (as opposed to ToBinary)
aBits := bits.ToBinary(builder, a, bits.WithNbDigits(nbBits), bits.WithUnconstrainedOutputs())
debug := builder.newDebugInfo("mustBeLessOrEq", aForDebug, " <= ", bound)

// t trailing bits in the bound
t := 0
Expand Down
2 changes: 1 addition & 1 deletion internal/backend/circuits/hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (circuit *recursiveHint) Define(api frontend.API) error {
// api.ToBinary calls another hint (bits.NBits) with linearExpression as input
// however, when the solver will resolve bits[...] it will need to detect w1 as a dependency
// in order to compute the correct linearExpression value
bits := api.ToBinary(linearExpression, 10)
bits := api.ToBinary(linearExpression, 6)

a := api.FromBinary(bits...)

Expand Down
2 changes: 2 additions & 0 deletions internal/regression_tests/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package regressiontests includes tests to avoid re-introducing regressions.
package regressiontests
Loading

0 comments on commit 59a4087

Please sign in to comment.