zkFold / zkfold-base

ZkFold's Base library
https://zkfold.io
MIT License
13 stars 6 forks source link

Proposal: Structured I/O in Circuit definition #159

Closed echatav closed 1 month ago

echatav commented 3 months ago

Here's a proposal sketch for structured I/O. It probably needs a lot of work, but at least `evalC idC` seems to work.

-- | A circuit described by polynomial constraints, with structured input
-- @i@ and structured output @o@ over values of type @a@.
--
-- The input variables are numbered in decreasing order n..1
-- which is maintained as an invariant of the circuit.
data Circuit i o a = Circuit
  { systemC   :: Map Natural (Poly a Natural Natural)
    -- ^ The system of polynomial constraints
    --   (NOTE(review): the meaning of the 'Natural' keys is not shown
    --   here, presumably constraint identifiers; confirm)
  , witnessC  :: i a -> Map Natural a
    -- ^ The witness generation function: given a concrete input, produce
    --   the value of each variable, keyed by variable number
  , outputC  :: o Natural
    -- ^ The output variables, as variable numbers arranged in shape @o@
  }

-- | Merge two output-less circuits: the constraint systems are unioned, and
-- the witness generators are run on the same input with their maps unioned.
-- 'Map' '<>' is a left-biased union, so on a key clash the left circuit's
-- entry is kept.
instance o ~ U1 => Semigroup (Circuit i o a) where
  left <> right = Circuit
    { systemC = systemC left <> systemC right
    , witnessC = \vals -> witnessC left vals <> witnessC right vals
    , outputC = U1
    }

-- | The empty output-less circuit: no constraints, no outputs, and a witness
-- produced by 'inputWitness' (defined elsewhere; presumably it records the
-- input values under the input variable numbers, so confirm).
instance (o ~ U1, Traversable i) => Monoid (Circuit i o a) where
  mempty = Circuit
    { systemC = mempty
    , witnessC = inputWitness
    , outputC = U1
    }

-- | Adapt a circuit's interface. The witness generator is pre-composed with
-- a conversion on inputs, and the output variables are post-processed.
-- The constraint system is left untouched.
dimapC
  :: (i1 a -> i0 a)
  -> (o0 Natural -> o1 Natural)
  -> Circuit i0 o0 a -> Circuit i1 o1 a
dimapC pre post circ = circ
  { outputC = post (outputC circ)
  , witnessC = \x -> witnessC circ (pre x)
  }

-- | Combine two circuits over the same input, pairing their outputs.
-- Systems and witnesses are merged with the output-less 'Semigroup'.
joinC :: Circuit i oL a -> Circuit i oR a -> Circuit i (oL :*: oR) a
joinC c1 c2 = (unitL <> unitR) { outputC = outputC c1 :*: outputC c2 }
  where
    -- Strip the outputs so the circuits fit the @o ~ U1@ Semigroup.
    unitL = c1 { outputC = U1 }
    unitR = c2 { outputC = U1 }

-- | Combine a container of circuits over the same input into one circuit
-- whose output is the container of their outputs.
concatC
  :: (Functor v, Foldable v, Traversable i)
  => v (Circuit i o a) -> Circuit i (v :.: o) a
concatC cs = combined { outputC = Comp1 (fmap outputC cs) }
  where
    -- Fold the output-less versions with the 'Monoid' instance
    -- (which needs 'Traversable' @i@ for 'mempty').
    combined = foldMap (\c -> c { outputC = U1 }) cs

-- | Evaluate a circuit on a concrete input. The witness generator runs once
-- on the input, and each output variable is looked up in the resulting map.
--
-- Calls 'error' (via 'Map.!') if an output variable is missing from the
-- witness map.
evalC
  :: forall i o a. Functor o
  => Circuit i o a -> i a -> o a
evalC c i = fmap (witness Map.!) (outputC c)
  where
    -- Bound once so that every output lookup shares the same map, instead
    -- of relying on GHC to float 'witnessC c i' out of a per-variable
    -- lambda.
    witness = witnessC c i

-- | Fix the first component of a product input to a concrete value, leaving
-- a circuit over the remaining input @j@.
applyC :: i a -> Circuit (i :*: j) o a -> Circuit j o a
applyC x circ = circ { witnessC = \rest -> witnessC circ (x :*: rest) }

-- | The input variable numbers for input shape @i@. They start at
-- 'dimV' and are stepped with 'Prelude.pred'. Presumably 'iterateV' fills
-- the shape as @x, f x, f (f x), ...@, giving the n..1 numbering; confirm.
inputC :: forall a i. (VectorSpace a i, Traversable i) => i Natural
inputC = iterateV @a @i Prelude.pred start
  where
    start = dimV @a @i

-- | The identity circuit: no constraints, and its outputs are exactly its
-- input variables.
idC :: forall a i. (VectorSpace a i, Traversable i) => Circuit i i a
idC = base { outputC = inputC @a @i }
  where
    base :: Circuit i U1 a
    base = mempty

-- | An indexed monad for assembling circuits over values of type @a@.
-- The functional dependency means @m@ determines @a@.
--
-- In @m i j x@ the indices @i@ and @j@ are the input shapes before and after
-- the step. 'newInput' extends the input with a new component, and 'apply'
-- removes one by supplying its value.
class (forall i j. Functor (m i j))
  => MonadCircuit a m | m -> a where
    -- | Yield a value without changing the input shape.
    return :: x -> m i i x
    -- | Sequence two steps, threading the input-shape index through.
    (>>=) :: m i j x -> (x -> m j k y) -> m i k y
    -- | Sequence two steps, discarding the first result.
    (>>) :: m i j x -> m j k y -> m i k y
    x >> y = x >>= \_ -> y
    -- | Indexed analogue of '<*>'. The default is defined via '>>=' and
    -- 'fmap'.
    (<¢>) :: m i j (x -> y) -> m j k x -> m i k y
    f <¢> x = f >>= (<$> x)
    -- | Supply a concrete value for the leftmost input component,
    -- removing it from the input shape.
    apply :: i a -> m (i :*: j) j ()
    -- | The variables of the current input. By default they are those
    -- given by 'inputC'.
    input :: (VectorSpace a i, Traversable i) => m i i (i Natural)
    input = return (inputC @a)
    -- | Prepend a new input component @i@ to the current input @j@ and
    -- return the variables assigned to it.
    newInput ::
      ( VectorSpace a i, Traversable i
      , VectorSpace a j, Traversable j
      ) => m j (i :*: j) (i Natural)
    -- | Run a computation against a concrete input, yielding output values.
    eval :: (Traversable i, Functor o) => m i i (o Natural) -> i a -> o a
    -- | Embed an existing 'Circuit' over the current input, returning its
    -- output variables.
    runCircuit :: Circuit i o a -> m i i (o Natural)

-- | Build a 'Circuit' from a polymorphic 'MonadCircuit' description.
-- It is interpreted in 'Blueprint' starting from 'mempty', and the produced
-- output variables are attached.
circuit
  :: Traversable i
  => (forall m. MonadCircuit a m => m i i (o Natural))
  -> Circuit i o a
circuit x = attach (runBlueprint x mempty)
  where
    -- Strict pattern match on the pair, same as the original 'case'.
    attach (outs, built) = built { outputC = outs }

-- | The concrete 'MonadCircuit' used by 'circuit'. It is a state transformer
-- over output-less circuits: a step takes a circuit with input shape @i@ to
-- one with input shape @j@ and yields a result of type @x@.
newtype Blueprint a i j x = Blueprint
  {runBlueprint :: Circuit i U1 a -> (x, Circuit j U1 a)}
  deriving Functor

-- | 'Applicative' for the input-preserving fragment of 'Blueprint'. It reuses
-- the indexed 'return' and '<¢>' from its 'MonadCircuit' instance.
instance Applicative (Blueprint a i i) where
  pure x = return x
  mf <*> mx = mf <¢> mx

-- | 'Blueprint' as a 'MonadCircuit'. Each step threads an output-less
-- circuit through like a state monad whose state type (the input shape)
-- may change from step to step.
instance MonadCircuit a (Blueprint a) where
  -- Yield a value and pass the circuit through unchanged.
  return x = Blueprint $ \c -> (x,c)
  -- Run @m@, then feed its result and updated circuit into @f@.
  m >>= f = Blueprint $ \c ->
    let
      (x, c') = runBlueprint m c
    in
      runBlueprint (f x) c'
  -- Fix the leftmost input component via 'applyC'.
  apply i = Blueprint $ \c -> ((), applyC i c)
  -- Prepend a new input component @i@ to the existing input @j@.
  newInput = input >>= \oldVars ->
    let
      -- How many variables the existing input @j@ occupies (n..1).
      oldLen = Prelude.fromIntegral (length oldVars)
      -- Variables for the new component, numbered above the old ones so
      -- the combined input @i :*: j@ is still numbered in decreasing order.
      newVars = fmap (oldLen +) (inputC @a)
    in
      Blueprint $ \c ->
        ( newVars
          -- The witness records the new component's values under 'newVars'.
          -- Without this, outputs that depend on a new input variable could
          -- not be looked up by 'evalC' after 'apply'.
          -- NOTE(review): 'newVars' may coincide with non-input variables
          -- already present in @c@. Confirm the variable numbering scheme.
        , c { witnessC = \(new :*: old) ->
                Map.fromList
                  (Prelude.zip
                    (Prelude.foldr (:) [] newVars)
                    (Prelude.foldr (:) [] new))
                <> witnessC c old
            }
        )
  -- Build the circuit from 'mempty' and evaluate its outputs on @i@.
  eval x i = case runBlueprint x mempty of
    (o, c) -> evalC (c {outputC = o}) i
  -- Merge an existing circuit into the blueprint, returning its outputs.
  runCircuit c = Blueprint $ \c' ->
    (outputC c, c {outputC = U1} <> c')
echatav commented 3 months ago

Optimization possibility:

-   , witnessC  :: i a -> Map Natural a
+   , witnessC  :: Map Natural (i a -> a)
echatav commented 3 months ago

It also makes the variable sets easier to calculate, without needing to supply an input. The set of all variables is the key set of `witnessC`, the structured output variables are `outputC`, and the structured input variables are `inputC`.

TurtlePU commented 1 month ago

Done in #214