mit-plv / fiat-crypto

Cryptographic Primitive Code Generation by Fiat
http://adam.chlipala.net/papers/FiatCryptoSP19/FiatCryptoSP19.pdf

Need rewrite rules for new base conversion #833

Closed jadephilipoom closed 4 years ago

jadephilipoom commented 4 years ago

As discussed in a couple of recent fiat-crypto meetings and #813 (also related to #809), I've been trying to clean up Core.v by removing the enormous nth_default proofs about carry functions. As it stands, the only one being used is nth_default_chained_carries_no_reduce, which is used in base conversion to prove that convert_bases is equivalent to evaluating the input in the source base system and then evenly partitioning it in the destination base system (convert_bases_partitions). Therefore, removing the need for nth_default in that one proof will remove some of the largest and nastiest proofs from the core arithmetic library.
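The specification that convert_bases_partitions states can be modeled in a few lines of Python. This is a sketch, not fiat-crypto code. The names `eval_limbs`, `partition`, and `convert_bases_spec` are hypothetical. The three-limb weights 2^0, 2^43, 2^85, 2^127 used for 2^127 - 1 are an assumption chosen for illustration.

```python
def uniform_weights(bits_per_limb, n):
    # weights w_0 .. w_n for a uniform radix of 2^bits_per_limb
    return [2 ** (bits_per_limb * i) for i in range(n + 1)]


def eval_limbs(weights, limbs):
    # value of little-endian limbs under the given weights
    return sum(w * x for w, x in zip(weights, limbs))


def partition(weights, n, value):
    # split value evenly into n limbs: limb i = (value mod w_{i+1}) div w_i
    return [(value % weights[i + 1]) // weights[i] for i in range(n)]


def convert_bases_spec(src_weights, dst_weights, n_dst, limbs):
    # reference behaviour: evaluate in the source base, partition in the destination base
    return partition(dst_weights, n_dst, eval_limbs(src_weights, limbs))


# three limbs for 2^127 - 1 (limb boundaries assumed for illustration)
limb_weights = [2 ** 0, 2 ** 43, 2 ** 85, 2 ** 127]
byte_weights = uniform_weights(8, 16)

v = 0x0123456789ABCDEF0FEDCBA987654321  # any value below 2^127
limbs = partition(limb_weights, 3, v)
as_bytes = convert_bases_spec(limb_weights, byte_weights, 16, limbs)  # like to_bytes
back = convert_bases_spec(byte_weights, limb_weights, 3, as_bytes)    # like from_bytes
```

Any implementation of convert_bases, whether based on chained carries or on Columns, should agree with `convert_bases_spec` on every input in range. That agreement is what the proof has to establish.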

To that end, in the columns-base-conversion branch, I've written a definition of convert_bases that uses the saturated arithmetic library (Columns.from_associational and Columns.flatten) to convert between bases. It seems to work in my preliminary tests and the proofs pass. However, because the Columns definitions use Z.add_get_carry_full, the output ends up having add-get-carry statements that split at invalid bit indices, e.g.:

```coq
(uint40, uint8) (Z.add_get_carry (2^40) ((uint47) x) ((uint47) y))
```

Instead, we want an add and then a bitwise div/mod, e.g.:

```coq
dlet xy := uint48 (Z.add (uint47 x) (uint47 y)) in
(uint40 (Z.land (uint48 xy) (2^40-1)), uint8 (Z.shiftr (uint48 xy) (literal 40)))
```
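The two forms compute the same (low part, carry) pair whenever the split point is a power of two. The Python model below assumes `Z.add_get_carry_full s x y` returns `((x + y) mod s, (x + y) div s)`. The function names are hypothetical. The model checks the 47-bit inputs split at 2^40 from the example.

```python
import random


def add_get_carry_full(s, x, y):
    # modeled as (low part, carry) = ((x + y) mod s, (x + y) div s)
    xy = x + y
    return xy % s, xy // s


def add_then_split(log2_s, x, y):
    # one wide add, then Z.land / Z.shiftr at bit log2_s
    xy = x + y
    return xy & ((1 << log2_s) - 1), xy >> log2_s


random.seed(0)
for _ in range(1000):
    x = random.randrange(2 ** 47)  # uint47
    y = random.randrange(2 ** 47)  # uint47
    low, carry = add_then_split(40, x, y)
    assert (low, carry) == add_get_carry_full(2 ** 40, x, y)
    assert low < 2 ** 40 and carry < 2 ** 8  # fits (uint40, uint8)
```

Because x + y < 2^48, the "carry" here can be as large as 255. That is why it is cast to uint8 rather than to a single bit, and why a hardware carry flag is the wrong implementation for it.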

I think existing rewrite rules will take care of translating a Z div/mod into a shiftr/land, but I don't know how to write a rule that will do the div/mod translation in the with-casts rules. Especially problematic is figuring out what the range of x + y should be, because it doesn't appear in the initial expression and I think (correct me if I'm wrong) the rewriter won't allow me to write a range in the new expression as an expression (e.g. ZRange.four_corners rx ry).

I see two main solutions:

  1. Write a rewrite rule that recognizes such add-get-carry statements and changes them into div/mod
  2. Change Columns definitions to use something other than Z.add_get_carry_full when some boolean is set, or parameterize Columns over your desired implementation of div/mod (which we can then plug in as Z.add_get_carry_full for saturated arithmetic and div/mod for base conversion)

I think 1 is preferable if it's possible. @JasonGross , can you write this rewrite rule? If 1 is not possible, I think I know how to do 2, but would like some feedback on whether introducing this extra abstraction to Columns is reasonable (@andres-erbsen ?). I think the answer kind of hinges on how we want to view Z.add_get_carry_full. Is it a definition that serves as a marker for places where we definitely want to use the carry flag? Or is it a shorthand for "div/mod, figure out if this should use carries based on bounds"?
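Option 2 can be sketched as a column-flattening pass that takes the add-and-split operation as a parameter. The Python below only mimics the shape of Columns.from_associational and Columns.flatten. The names, signatures, and column-placement rule are assumptions, not the Coq definitions. It converts 16 bytes into three limbs, once with a carry-style split and once with a land/shiftr split, and checks that the two agree.

```python
def add_get_carry_full(s, x, y):
    # carry-style split: ((x + y) mod s, (x + y) div s)
    v = x + y
    return v % s, v // s


def add_then_land_shiftr(s, x, y):
    # div/mod-style split via land/shiftr; only valid when s is a power of two
    k = s.bit_length() - 1
    assert s == 1 << k, "split point must be a power of two"
    v = x + y
    return v & (s - 1), v >> k


def from_associational(weights, n, terms):
    # put each (weight, coefficient) term in the last column i with weights[i] <= weight,
    # scaled by weight // weights[i] (assumes weights[i] divides weight)
    columns = [[] for _ in range(n)]
    for w, x in terms:
        i = max(j for j in range(n) if weights[j] <= w)
        columns[i].append(x * (w // weights[i]))
    return columns


def flatten(weights, columns, add_split):
    # sum each column into one digit, passing carries upward;
    # add_split(s, a, b) decides how a single addition is split
    digits, carry_in = [], 0
    for i, column in enumerate(columns):
        s = weights[i + 1] // weights[i]
        digit, carry_out = add_split(s, carry_in, 0)
        for term in column:
            digit, c = add_split(s, digit, term)
            carry_out += c
        digits.append(digit)
        carry_in = carry_out
    return digits, carry_in


limb_weights = [2 ** 0, 2 ** 43, 2 ** 85, 2 ** 127]  # assumed limb boundaries
v = 0x0123456789ABCDEF0FEDCBA987654321
terms = [(2 ** (8 * j), b) for j, b in enumerate(v.to_bytes(16, "little"))]
columns = from_associational(limb_weights, 3, terms)

via_carries = flatten(limb_weights, columns, add_get_carry_full)
via_divmod = flatten(limb_weights, columns, add_then_land_shiftr)
```

Under this framing, the design question becomes which `add_split` to pass in. Saturated arithmetic would use a carry-flag primitive, and base conversion would use the add-then-split form.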

JasonGross commented 4 years ago

> because it doesn't appear in the initial expression and I think (correct me if I'm wrong) the rewriter won't allow me to write a range in the new expression as an expression (e.g. ZRange.four_corners rx ry).

As long as you make the ranges literals, then it's fine. This looks like writing ident.cast ('rx) x on the input and ident.cast ('(ZRange.four_corners rx ry)) (ident.cast ('rx) x + ident.cast ('ry) y) on the output.
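The range that the output cast needs can be computed from the input ranges alone. Below is a Python model of the four-corners idea for the uint47 + uint47 case. How ZRange.four_corners behaves is an assumption here; this is not its Coq source.

```python
from itertools import product


def four_corners(f, rx, ry):
    # evaluate f at the four corners of rx x ry and take min/max;
    # this gives a sound range for x + y because + is monotone in each argument
    values = [f(a, b) for a, b in product(rx, ry)]
    return min(values), max(values)


uint47 = (0, 2 ** 47 - 1)
rxy = four_corners(lambda a, b: a + b, uint47, uint47)
```

Here `rxy` is (0, 2^48 - 2), which fits in uint48. That is the cast the intermediate sum needs.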

I think the main question for me is "how do you recognize which Z.add_get_carry the rule should trigger on?"

jadephilipoom commented 4 years ago

It might work to just have it trigger when the carry range is looser than r[0~>1]. But I somewhat doubt things will be quite so straightforward. I might end up writing something that takes advantage of the fact that we only use base conversion when converting to/from bytes (to_bytes will always be splitting at 2^8), or adding the rewrite rule in the with_bitwidth section so we can check if it's splitting at the bitwidth or not.
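Below is a minimal sketch of the first heuristic. It assumes the carry range is computed from the input ranges by interval arithmetic. The names `add_carry_range` and `should_fire` are hypothetical, not rewriter APIs.

```python
def add_carry_range(rx, ry, s):
    # range of (x + y) div s for x in rx, y in ry; floor division is monotone
    return (rx[0] + ry[0]) // s, (rx[1] + ry[1]) // s


def should_fire(rx, ry, s):
    # hypothetical trigger: rewrite only when the carry may fall outside r[0~>1]
    lo, hi = add_carry_range(rx, ry, s)
    return lo < 0 or hi > 1


uint47 = (0, 2 ** 47 - 1)
uint64 = (0, 2 ** 64 - 1)
```

A genuine machine-word add (two uint64 inputs split at 2^64) keeps its carry in r[0~>1] and is left alone. The 47-bit case split at 2^40 has a carry range of r[0~>255] and would be rewritten.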

JasonGross commented 4 years ago

Btw, I think the rule you want is something like (untested)

```coq
(forall rv rs s rc r0 rx x ry y,
          (* some condition about when to trigger the rule -> *)
          s ∈ rs -> 0 ∈ r0 -> (n rx + n ry <= r[0~>s-1])%zrange
          -> (subst_let rxy := (rx + ry)%zrange in
                  cstZZ rv rc (Z.add_with_get_carry_full (cstZ rs ('s)) (cstZ r0 ('0)) (cstZ rx x) (cstZ ry y))
                  = (dlet xy := cstZ rxy (cstZ rx x + cstZ ry y) in
                         cstZZ rv rc (cstZ rv (Z.shiftr (cstZ rxy xy) (cstZ r[8~>8] ('s))),
                                      (cstZ rc (Z.land (cstZ rxy xy) (cstZ r[2^s-1~>2^s-1] ('(2^s-1))%Z)))))))%Z%zrange.
```

jadephilipoom commented 4 years ago

I think I've almost got it on the branch now -- to_bytes works, but from_bytes gives me a single None bound on the output that I don't understand. The printed expression tree shows casts in front of all the outputs.

If you have time, Jason, I'd appreciate you helping me debug this -- the build from Rules.v to something I can test takes 20 minutes on my machine and uses so much memory that I can't run a browser at the same time, so it's very painful to test things out if I don't know exactly what's wrong.

jadephilipoom commented 4 years ago

I'd like to register a more general concern with rewrite rules being opaque and really difficult to debug and iterate on. The fact that I needed to make the casts literals was something that I wouldn't have discovered without asking, which is a concern given the fact that you're about to leave. Is there a place where acceptable expressions and general guidelines for rewrite rules are codified?

My other concern is the slow iteration process. I've spent about five hours today repeatedly running this rewriter build to debug issues with the new rule (the missing bound, the rules firing mistakenly on Fancy code), where the only way for me to tell the effect of my change is to make the change, wait 20 minutes (or more -- at one point I put in a range1 <> range2 precondition and the build expressed dissatisfaction by just hanging until I finally cancelled it) and then see what output is produced. Is it possible to create a piece of code such that, given a particular rewrite rule and an expression, can tell me whether the rule fires on that expression and maybe show me the output without compiling the entire rewriter? That would certainly make this process smoother.

JasonGross commented 4 years ago

> I think I've almost got it on the branch now -- to_bytes works, but from_bytes gives me a single None bound on the output that I don't understand. The printed expression tree shows casts in front of all the outputs.

My first guess is that you forgot a cast somewhere in the output of the rewrite rule? E.g., you write cst (cst x + cst y) = cst (y + x) or cst (cst x + cst y) = (cst y + cst x)? What's the rewrite rule you're using?

> If you have time, Jason, I'd appreciate you helping me debug this

Sure. When? Or where's a pointer to the branch / bad output?

> The fact that I needed to make the casts literals was something that I wouldn't have discovered without asking

You actually don't need to worry about this in the Rules.v file, because the cstZ notations are set up to always make casts literal.

> which is a concern given the fact that you're about to leave. Is there a place where acceptable expressions and general guidelines for rewrite rules are codified?

I think a decent amount of this is scattered throughout the rewriter paper, and some partial work on this that I started in https://github.com/mit-plv/rewriter/blob/master/src/Rewriter/Demo.v. I've created https://github.com/mit-plv/rewriter/issues/14 for this, I'll plan to go flesh it out a bit more shortly.

> Is it possible to create a piece of code such that, given a particular rewrite rule and an expression, can tell me whether the rule fires on that expression and maybe show me the output without compiling the entire rewriter?

I've just pushed a70c8ea9dcb73e5a6746fddb835e3e3dba3c9214, which adds a couple of test rules at https://github.com/mit-plv/fiat-crypto/blob/a70c8ea9dcb73e5a6746fddb835e3e3dba3c9214/src/Rewriter/TestRules.v#L48-L60 and a small example using these test rules at https://github.com/mit-plv/fiat-crypto/blob/a70c8ea9dcb73e5a6746fddb835e3e3dba3c9214/src/Rewriter/Passes/Test.v#L43-L54. You can modify just the test rules file, and then either test it out in Passes/Test.v or else Require Import it in BoundsPipeline to interleave it with other passes to test it on the full pipeline without needing to recompile all the other rewriter passes.

jadephilipoom commented 4 years ago

Thanks for making the test rules! That will be helpful. I'll try again tomorrow (it's late in the UK now).

I think it's definitely worthwhile to make some systematic documentation for the rewriter input language, especially if you're aiming for wider adoption -- I'll happily volunteer to review it! I am not sure if anyone other than me and you is using it right now, and if I'm struggling even with the ability to go to you for answers then a lack of documentation could really deter people from trying it out.

The branch I've been working on is at https://github.com/mit-plv/fiat-crypto/tree/columns-base-conversion I've been testing it with src/ExtractionOCaml/unsaturated_solinas --static 1271 64 '(auto)' '2^127 - 1' from_bytes to_bytes. Fancy is broken on this branch because my changes to convert_bases messed up its let binders in a way I'm also still debugging. If you know how to write rewrite rules that inline let binders then I'm all ears for that issue, but otherwise feel free to ignore it and focus on the unsaturated solinas implementations.

JasonGross commented 4 years ago

> especially if you're aiming for wider adoption

I don't think the rewriter is suitable for wider adoption. It's very, very, very much a research prototype, and I don't see it becoming not a research prototype without going through a complete rewriting. Almost all of it is either a pile of hacks or proofs which are made very painful by inline CPS and dependent types. It's going to be hard to maintain, too, in its current form. I think the ideas are good, but the implementation is just good enough to work for fiat-crypto. (Recall that the prior state-of-the-art was writing the pattern matches on PHOASTs by hand.) An incomplete list of things that I think need to be rewritten from scratch:

JasonGross commented 4 years ago

(Please respond to the above points in https://github.com/mit-plv/rewriter/issues/15 , so that we don't lose responses.)

JasonGross commented 4 years ago

> The branch I've been working on is at https://github.com/mit-plv/fiat-crypto/tree/columns-base-conversion I've been testing it with src/ExtractionOCaml/unsaturated_solinas --static 1271 64 '(auto)' '2^127 - 1' from_bytes to_bytes. Fancy is broken on this branch because my changes to convert_bases messed up its let binders in a way I'm also still debugging. If you know how to write rewrite rules that inline let binders then I'm all ears for that issue, but otherwise feel free to ignore it and focus on the unsaturated solinas implementations.

I'll take a look and try it out when I get a chance, hopefully later tonight.

JasonGross commented 4 years ago

> If you know how to write rewrite rules that inline let binders then I'm all ears for that issue, but otherwise feel free to ignore it and focus on the unsaturated solinas implementations.

I can give you an ident.inline so that when you run the code through RewriteAndEliminateDeadAndInline, things wrapped in ident.inline get inlined. (It won't be interwoven with rewrite rules though.) If you want it interwoven with other rewrite rules, I can also do that, but that's significantly more complicated. But note that you won't be able to match against expressions spanning multiple lines in any case, so are you sure this is what you want?

JasonGross commented 4 years ago

You have

```diff
+               (forall rm rd rs s rx x ry y,
+                   (r[0~>2] <= rm)%zrange -> s ∈ rs -> s = 2 ^ Z.log2 s
+                   -> cstZZ ('rm) ('rd) (Z.add_get_carry_full (cstZ rs ('s)) (cstZ rx x) (cstZ ry y))
+                      = (dlet xy := (cstZ ('(n rx + n ry)%zrange) (Z.add (cstZ rx x) (cstZ ry y))) in
+                             (cstZ rm (Z.land (cstZ ('(n rx + n ry)%zrange) xy) ('(s - 1)%Z)),
+                              cstZ rd (Z.shiftr (cstZ ('(n rx + n ry)%zrange) xy) ('Z.log2 s)))))
```
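The side condition `s = 2 ^ Z.log2 s` in this rule is what makes the right-hand side correct. The land/shiftr split agrees with mod/div only when s is a power of two. The Python check below assumes `Z.add_get_carry_full s x y` means `((x + y) mod s, (x + y) div s)`. The helper names are hypothetical.

```python
def z_log2(s):
    # floor(log2 s) for s > 0, matching Z.log2 on positive inputs
    return s.bit_length() - 1


def lhs(s, x, y):
    # Z.add_get_carry_full s x y, modeled as ((x + y) mod s, (x + y) div s)
    v = x + y
    return v % s, v // s


def rhs(s, x, y):
    # (Z.land xy (s - 1), Z.shiftr xy (Z.log2 s))
    v = x + y
    return v & (s - 1), v >> z_log2(s)


def side_condition(s):
    return s == 2 ** z_log2(s)


# for each s: (side condition holds, lhs == rhs on all small inputs)
results = {
    s: (side_condition(s),
        all(lhs(s, x, y) == rhs(s, x, y) for x in range(64) for y in range(64)))
    for s in (2 ** 8, 2 ** 40, 10)
}
```

For s = 10 the side condition fails, and the rewrite would indeed be wrong. For example, 5 + 5 gives (0, 1) on the left-hand side but (8, 1) on the right.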

First: you don't need to wrap the ranges in literals, because cstZ does that for you. (Sorry for the red herring.)

Second: my guess is that the issue is that you're missing the cstZZ on the output tuple. I'll go try making this change now, and I'll push to the branch in maybe 30 minutes if it works.

JasonGross commented 4 years ago

Btw, the other thing you can do to make the build faster is change the second boolean argument from false to true in lines like https://github.com/mit-plv/fiat-crypto/blob/dc55596187fb67d0ff8e5c422dc1e1463c969d74/src/Rewriter/Passes/ArithWithCasts.v#L21. This tells the rewriter to skip the early reduction, which makes the passes significantly faster to create (though it makes running the pipeline itself somewhat slower, perhaps 2x slower?).

JasonGross commented 4 years ago

Okay, I've found the root of the issue. The issue is three-fold, sort-of:

  1. We check the computed bounds on the intermediate expression. We do this so that the error messages in the code are post-rewrite rules, rather than being very cryptic. This is in https://github.com/mit-plv/fiat-crypto/blob/dc55596187fb67d0ff8e5c422dc1e1463c969d74/src/BoundsPipeline.v#L495-L515
  2. When we extract the bounds information from the (intermediate) expression for checking, we don't assume that casts truncate in https://github.com/mit-plv/fiat-crypto/blob/dc55596187fb67d0ff8e5c422dc1e1463c969d74/src/AbstractInterpretation/AbstractInterpretation.v#L580
  3. The bounds analysis rule for Z.land rounds the bounds up to one less than the next power of two. This results in too-loose bounds on expressions involving Z.land. In particular, here we are taking Z.land of any unsigned 42-bit integer with 0b111111111111111000000000000000000000000000. The bounds given on the rewrite rule are "between 0 and 0b111111111111111000000000000000000000000000", while the computed bound is "between 0 and 2^42-1".
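To make point 3 concrete, here is a Python comparison of a rounded bound and a tighter one for nonnegative ranges. `land_bounds_rounded` is only a simplified model of the rounding described in point 3, not the code in Operations.v. `land_bounds_tight` uses the standard fact that a & b <= min(a, b) for nonnegative a and b. It is one possible tightening, not necessarily the one fiat-crypto adopted.

```python
import random


def round_up_to_mask(u):
    # smallest 2^k - 1 that is >= u, for u >= 0
    return (1 << u.bit_length()) - 1


def land_bounds_rounded(rx, ry):
    # simplified model of the rounding in point 3 (nonnegative ranges only)
    return 0, min(round_up_to_mask(rx[1]), round_up_to_mask(ry[1]))


def land_bounds_tight(rx, ry):
    # for nonnegative a, b: a & b <= min(a, b), so the min of the upper bounds is sound
    return 0, min(rx[1], ry[1])


mask = int("1" * 15 + "0" * 27, 2)  # 0b111111111111111000000000000000000000000000
uint42 = (0, 2 ** 42 - 1)
const_mask = (mask, mask)

# spot-check soundness of the tight bound on random 42-bit inputs
random.seed(1)
assert all(random.randrange(2 ** 42) & mask <= mask for _ in range(1000))
```

With the rounded rule, the mask's top bit forces the result back up to 2^42 - 1. The tight rule recovers exactly the [0, mask] range that the rewrite rule declares.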

We could solve this either by (2) changing the argument to partial.Extract from false to true, so that we assume that casts truncate, or by (3) tightening up the bounds on Z.land.
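A toy model of reading a bound off a cast shows why either fix removes the None. This sketch rests on assumptions and is not fiat-crypto's partial.Extract. It assumes that, when casts are not treated as truncating, a cast keeps a bound only if the computed bound already fits inside the cast's range. All names are hypothetical.

```python
from typing import Optional, Tuple

Range = Tuple[int, int]


def is_within(inner: Range, outer: Range) -> bool:
    return outer[0] <= inner[0] and inner[1] <= outer[1]


def extract_cast(cast_range: Range, computed: Optional[Range],
                 assume_cast_truncates: bool) -> Optional[Range]:
    # toy model of reading a bound off `cast r e`
    if assume_cast_truncates:
        # a truncating cast guarantees its own range regardless of e
        return cast_range
    if computed is not None and is_within(computed, cast_range):
        return computed
    # the cast cannot be justified from e's bound, so the bound is lost
    return None


mask = int("1" * 15 + "0" * 27, 2)
declared = (0, mask)       # range the rewrite rule writes on the cast
loose = (0, 2 ** 42 - 1)   # computed with the rounded Z.land bound
tight = (0, mask)          # computed with a tighter Z.land bound
```

In this model, option (2) corresponds to calling `extract_cast(declared, loose, True)`. Option (3) corresponds to `extract_cast(declared, tight, False)`. Both return a bound instead of None.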

It seems like (3) is a good thing to do in general, so I'll make an issue for it. This would consist of changing the bounds on Z.land in https://github.com/mit-plv/fiat-crypto/blob/dc55596187fb67d0ff8e5c422dc1e1463c969d74/src/Util/ZRange/Operations.v#L157 to something like

```coq
    := four_corners_and_zero (fun x y => if (x <? 0) && (y <? 0) then Z.min x y else Z.max x y)
                             (extend_land_lor_bounds x) (extend_land_lor_bounds y).
```

(the logic is that if x and y are both negative, then we're bounded below by the min; if they're both positive, then we're bounded above by the max. If one is negative and the other is positive, we were already extending the range on the negative one to include -1, so having the negative value actually imposes no truncation on the positive value, and so we can just take the max.)

I'm currently checking whether or not changing the boolean value on partial.Extract is sufficient.


Here's my debugging code (should possibly make it into `src/SlowPrimeSynthesisExamples.v`):

```coq
Module debugging_21271_from_bytes.
  Import Crypto.PushButtonSynthesis.UnsaturatedSolinas.
  Import Stringification.C.
  Import Stringification.C.Compilers.
  Import Stringification.C.Compilers.ToString.

  Section __.
    Local Existing Instance C.OutputCAPI.
    Local Instance static : static_opt := false.
    Local Instance : internal_static_opt := true.
    Local Instance : emit_primitives_opt := false.
    Local Instance : use_mul_for_cmovznz_opt := false.
    Local Instance : widen_carry_opt := false.
    Local Instance : widen_bytes_opt := false.
    Local Instance : only_signed_opt := false.
    Local Instance : no_select_opt := false.
    Local Instance : should_split_mul_opt := false.
    Local Instance : should_split_multiret_opt := false.

    Definition n := 3%nat (*5%nat*).
    Definition s := 2^127 (* 255*).
    Definition c := [(1, 1(*9*))].
    Definition machine_wordsize := 64.

    Import IR.Compilers.ToString.

    Goal True.
      pose (sfrom_bytes n s c machine_wordsize "1271") as v.
      cbv [sfrom_bytes] in v.
      set (k := from_bytes _ _ _ _) in (value of v).
      clear v.
      cbv [from_bytes] in k.
      cbv [Pipeline.BoundsPipeline] in k.
      set (k' := Pipeline.PreBoundsPipeline _ _ _ _ _) in (value of k).
      vm_compute in k'.
      cbv [Rewriter.Util.LetIn.Let_In] in k.
      set (k'' := CheckedPartialEvaluateWithBounds _ _ _ _ _ _) in (value of k).
      vm_compute in k''.
      lazymatch (eval cbv [k''] in k'') with
      | @inl ?A ?B ?v => pose v as V; change k'' with (@inl A B V) in (value of k)
      end.
      cbv beta iota zeta in k.
      clear k''.
      set (e := GeneralizeVar.FromFlat _) in (value of k).
      vm_compute in e.
      set (k'' := CheckedPartialEvaluateWithBounds _ _ _ _ _ _) in (value of k).
      cbv [CheckedPartialEvaluateWithBounds] in k''.
      clear -k''.
      cbv [Rewriter.Util.LetIn.Let_In] in k''.
      set (e' := (GeneralizeVar.FromFlat (GeneralizeVar.ToFlat e))) in (value of k'').
      vm_compute in e'; clear e; rename e' into e.
      set (b := (partial.Extract _ _ _)) in (value of k'').
      clear -b.
      cbv [partial.Extract partial.ident.extract partial.extract_gen type.app_curried partial.extract'] in b.
      subst e.
      cbv beta iota zeta in b.
      Import Rewriter.Util.LetIn.
      cbn [partial.abstract_interp_ident] in b.
      cbv [partial.abstract_interp_ident] in b.
      cbv [ZRange.ident.option.interp] in b.
      cbv [ZRange.ident.option.of_literal] in b.
      cbn [ZRange.ident.option.interp_Z_cast option_map] in b.
      cbv [partial.abstract_domain ZRange.type.base.option.interp type.interp ZRange.type.base.interp] in b.
      cbn [fst snd] in b.
      (do 54 (lazymatch (eval cbv [b] in b) with
              | dlet x := ?v in _ =>
                let v' := (eval vm_compute in v) in
                change v with v' in (value of b)
              end;
              unfold Let_In at 1 in (value of b));
       lazymatch (eval cbv [b] in b) with
       | dlet x := ?v in _ =>
         let v' := (eval vm_compute in v) in
         change v with v' in (value of b)
       end).
      unfold Let_In at 1 in (value of b).
      lazymatch (eval cbv [b] in b) with
      | context[Crypto.Util.Option.bind ?v _] =>
        let v' := (eval vm_compute in v) in
        change v with v' in (value of b)
      end.
      cbn [Crypto.Util.Option.bind] in b.
      set (k' := Operations.ZRange.land_bounds _ _) in (value of b).
      cbv [Operations.ZRange.land_bounds] in k'.
      clear -k'.
      Print Operations.ZRange.land_lor_bounds.
      vm_compute in k'.
    Abort.
  End __.
End debugging_21271_from_bytes.
```

JasonGross commented 4 years ago

Okay, so it turns out that we can't assume that casts truncate only in extraction; we would have to make this change more uniformly... Do we want to commit to assuming that casts truncate everywhere, even more deeply than we already do?

jadephilipoom commented 4 years ago

Fixing the bounds on Z.land definitely seems worth it. But I'm less sure about making the cast-is-truncation assumption more pervasive. Just to check my understanding, does this affect the assumptions we need to make about target architectures, or only our internal reasoning?

I think I'm going to do some exploration on option 2, parameterizing Columns.flatten over the implementation of add-get-carry. It makes conceptual sense to me, might actually fix the current issue with Fancy, and would avoid extensive changes to the rewriter and bounds handling.

jadephilipoom commented 4 years ago

In the columns-base-conversion-with-parameterization2 branch, I've gotten the Fancy issue fixed and unsaturated Solinas to_bytes working by parameterizing Columns.flatten over add-get-carry. I think we should go with this strategy instead of rewriting add-get-carry into div/mod.

However, from_bytes still has the issue with the single None bound, so that needs to be fixed before we can merge it. If #837 is the best way to do that, then I think it's worth it.

JasonGross commented 4 years ago

Solved by #838