https://github.com/akkartik/mu/blob/master/054static_dispatch.cc
  1 //: Transform to maintain multiple variants of a recipe depending on the
  2 //: number and types of the ingredients and products. Allows us to use nice
  3 //: names like 'print' or 'length' in many mutually extensible ways.
  4 
  5 :(scenario static_dispatch)
  6 def main [
  7   7:num/raw <- test 3
  8 ]
  9 def test a:num -> z:num [
 10   z <- copy 1
 11 ]
 12 def test a:num, b:num -> z:num [
 13   z <- copy 2
 14 ]
 15 +mem: storing 1 in location 7
 16 
 17 //: When loading recipes, accumulate variants if headers don't collide, and
 18 //: flag an error if headers collide.
 19 
:(before "End Globals")
// Maps each overloadable recipe name to the ordinals of its variants, in the
// order they were loaded. Entries equal to -1 are 'ghosts' that dispatch skips.
map<string, vector<recipe_ordinal> > Recipe_variants;
:(before "End One-time Setup")
put(Recipe_variants, "main", vector<recipe_ordinal>());  // since we manually added main to Recipe_ordinal
 24 
:(before "End Globals")
// Copy of Recipe_variants taken by save_snapshots; restore_snapshots copies it
// back, discarding any variants registered in between.
map<string, vector<recipe_ordinal> > Recipe_variants_snapshot;
:(before "End save_snapshots")
Recipe_variants_snapshot = Recipe_variants;
:(before "End restore_snapshots")
Recipe_variants = Recipe_variants_snapshot;
 31 
 32 :(before "End Load Recipe Header(result)")
 33 // there can only ever be one variant for main
 34 if (result.name != "main" && contains_key(Recipe_ordinal, result.name)) {
 35   const recipe_ordinal r = get(Recipe_ordinal, result.name);
 36   if (!contains_key(Recipe, r) || get(Recipe, r).has_header) {
 37     string new_name = matching_variant_name(result);
 38     if (new_name.empty()) {
 39       // variant doesn't already exist
 40       new_name = next_unused_recipe_name(result.name);
 41       put(Recipe_ordinal, new_name, Next_recipe_ordinal++);
 42       get_or_insert(Recipe_variants, result.name).push_back(get(Recipe_ordinal, new_name));
 43     }
 44     trace("load") << "switching " << result.name << " to " << new_name << end();
 45     result.name = new_name;
 46     result.is_autogenerated = true;
 47   }
 48 }
 49 else {
 50   // save first variant
 51   put(Recipe_ordinal, result.name, Next_recipe_ordinal++);
 52   get_or_insert(Recipe_variants, result.name).push_back(get(Recipe_ordinal, result.name));
 53 }
 54 
 55 :(code)
 56 string matching_variant_name(const recipe& rr) {
 57   const vector<recipe_ordinal>& variants = get_or_insert(Recipe_variants, rr.name);
 58   for (int i = 0;  i < SIZE(variants);  ++i) {
 59     if (!contains_key(Recipe, variants.at(i))) continue;
 60     const recipe& candidate = get(Recipe, variants.at(i));
 61     if (!all_reagents_match(rr, candidate)) continue;
 62     return candidate.name;
 63   }
 64   return "";
 65 }
 66 
 67 bool all_reagents_match(const recipe& r1, const recipe& r2) {
 68   if (SIZE(r1.ingredients) != SIZE(r2.ingredients)) return false;
 69   if (SIZE(r1.products) != SIZE(r2.products)) return false;
 70   for (int i = 0;  i < SIZE(r1.ingredients);  ++i) {
 71     expand_type_abbreviations(r1.ingredients.at(i).type);
 72     expand_type_abbreviations(r2.ingredients.at(i).type);
 73     if (!deeply_equal_type_names(r1.ingredients.at(i), r2.ingredients.at(i)))
 74       return false;
 75   }
 76   for (int i = 0;  i < SIZE(r1.products);  ++i) {
 77     expand_type_abbreviations(r1.products.at(i).type);
 78     expand_type_abbreviations(r2.products.at(i).type);
 79     if (!deeply_equal_type_names(r1.products.at(i), r2.products.at(i)))
 80       return false;
 81   }
 82   return true;
 83 }
 84 
:(before "End Globals")
// Type names that a 'literal' type atom is allowed to match in
// deeply_equal_type_names.
set<string> Literal_type_names;
:(before "End One-time Setup")
Literal_type_names.insert("number");
Literal_type_names.insert("character");
 90 :(code)
 91 bool deeply_equal_type_names(const reagent& a, const reagent& b) {
 92   return deeply_equal_type_names(a.type, b.type);
 93 }
 94 bool deeply_equal_type_names(const type_tree* a, const type_tree* b) {
 95   if (!a) return !b;
 96   if (!b) return !a;
 97   if (a->atom != b->atom) return false;
 98   if (a->atom) {
 99     if (a->name == "literal" && b->name == "literal")
100       return true;
101     if (a->name == "literal")
102       return Literal_type_names.find(b->name) != Literal_type_names.end();
103     if (b->name == "literal")
104       return Literal_type_names.find(a->name) != Literal_type_names.end();
105     return a->name == b->name;
106   }
107   return deeply_equal_type_names(a->left, b->left)
108       && deeply_equal_type_names(a->right, b->right);
109 }
110 
111 string next_unused_recipe_name(const string& recipe_name) {
112   for (int i = 2;  /*forever*/;  ++i) {
113     ostringstream out;
114     out << recipe_name << '_' << i;
115     if (!contains_key(Recipe_ordinal, out.str()))
116       return out.str();
117   }
118 }
119 
120 //: Once all the recipes are loaded, transform their bodies to replace each
121 //: call with the most suitable variant.
122 
123 :(scenario static_dispatch_picks_most_similar_variant)
124 def main [
125   7:num/raw <- test 3, 4, 5
126 ]
127 def test a:num -> z:num [
128   z <- copy 1
129 ]
130 def test a:num, b:num -> z:num [
131   z <- copy 2
132 ]
133 +mem: storing 2 in location 7
134 
135 //: support recipe headers in a previous transform to fill in missing types
:(before "End check_or_set_invalid_types")
// run the same check over the types declared in the recipe header itself
for (int i = 0;  i < SIZE(caller.ingredients);  ++i)
  check_or_set_invalid_types(caller.ingredients.at(i).type, maybe(caller.name), "recipe header ingredient");
for (int i = 0;  i < SIZE(caller.products);  ++i)
  check_or_set_invalid_types(caller.products.at(i).type, maybe(caller.name), "recipe header product");
141 
142 //: save original name of recipes before renaming them
:(before "End recipe Fields")
// the name as written in the source, kept so error messages can show it even
// after the recipe is renamed to a variant name like 'test_2'
string original_name;
//: original name is only set during load
:(before "End Load Recipe Name")
result.original_name = result.name;
148 
149 //: after filling in all missing types (because we'll be introducing 'blank' types in this transform in a later layer, for shape-shifting recipes)
:(after "Transform.push_back(transform_names)")
// rename calls to overloaded recipes to the chosen variant
Transform.push_back(resolve_ambiguous_calls);  // idempotent
152 
153 //: In a later layer we'll introduce recursion in resolve_ambiguous_calls, by
154 //: having it generate code for shape-shifting recipes and then transform such
155 //: code. This data structure will help error messages be more useful.
156 //:
157 //: We're punning the 'call' data structure just because it has slots for
158 //: calling recipe and calling instruction.
:(before "End Globals")
// Innermost first: resolve_ambiguous_call pushes (recipe, step index) of the
// instruction it is resolving onto the front, and pops it when done.
list<call> Resolve_stack;
161 
162 :(code)
163 void resolve_ambiguous_calls(const recipe_ordinal r) {
164   recipe& caller_recipe = get(Recipe, r);
165   trace(9991, "transform") << "--- resolve ambiguous calls for recipe " << caller_recipe.name << end();
166   for (int index = 0;  index < SIZE(caller_recipe.steps);  ++index) {
167     instruction& inst = caller_recipe.steps.at(index);
168     if (inst.is_label) continue;
169     resolve_ambiguous_call(r, index, inst, caller_recipe);
170   }
171 }
172 
173 void resolve_ambiguous_call(const recipe_ordinal r, int index, instruction& inst, const recipe& caller_recipe) {
174   // End resolve_ambiguous_call(r, index, inst, caller_recipe) Special-cases
175   if (non_ghost_size(get_or_insert(Recipe_variants, inst.name)) == 0) return;
176   trace(9992, "transform") << "instruction " << to_original_string(inst) << end();
177   Resolve_stack.push_front(call(r, index));
178   string new_name = best_variant(inst, caller_recipe);
179   if (!new_name.empty())
180     inst.name = new_name;
181   assert(Resolve_stack.front().running_recipe == r);
182   assert(Resolve_stack.front().running_step_index == index);
183   Resolve_stack.pop_front();
184 }
185 
186 string best_variant(const instruction& inst, const recipe& caller_recipe) {
187   const vector<recipe_ordinal>& variants = get(Recipe_variants, inst.name);
188   vector<recipe_ordinal> candidates;
189 
190   // Static Dispatch Phase 1
191 //?   cerr << inst.name << " phase 1\n";
192   candidates = strictly_matching_variants(inst, variants);
193   if (!candidates.empty()) return best_variant(inst, candidates).name;
194 
195 //?   cerr << inst.name << " phase 3\n";
196   // Static Dispatch Phase 2
197   //: (shape-shifting recipes in a later layer)
198   // End Static Dispatch Phase 2
199 
200   // Static Dispatch Phase 3
201 //?   cerr << inst.name << " phase 4\n";
202   candidates = matching_variants(inst, variants);
203   if (!candidates.empty()) return best_variant(inst, candidates).name;
204 
205   // error messages
206   if (!is_primitive(get(Recipe_ordinal, inst.name))) {  // we currently don't check types for primitive variants
207     if (SIZE(variants) == 1) {
208       raise << maybe(caller_recipe.name) << "types don't match in call for '" << to_original_string(inst) << "'\n" << end();
209       raise << "  which tries to call '" << original_header_label(get(Recipe, variants.at(0))) << "'\n" << end();
210     }
211     else {
212       raise << maybe(caller_recipe.name) << "failed to find a matching call for '" << to_original_string(inst) << "'\n" << end();
213       raise << "  available variants are:\n" << end();
214       for (int i = 0;  i < SIZE(variants);  ++i)
215         raise << "    " << original_header_label(get(Recipe, variants.at(i))) << '\n' << end();
216     }
217     for (list<call>::iterator p = /*skip*/++Resolve_stack.begin();  p != Resolve_stack.end();  ++p) {
218       const recipe& specializer_recipe = get(Recipe, p->running_recipe);
219       const instruction& specializer_inst = specializer_recipe.steps.at(p->running_step_index);
220       if (specializer_recipe.name != "interactive")
221         raise << "  (from '" << to_original_string(specializer_inst) << "' in " << specializer_recipe.name << ")\n" << end();
222       else
223         raise << "  (from '" << to_original_string(specializer_inst) << "')\n" << end();
224       // One special-case to help with the rewrite_stash transform. (cross-layer)
225       if (specializer_inst.products.at(0).name.find("stash_") == 0) {
226         instruction stash_inst;
227         if (next_stash(*p, &stash_inst)) {
228           if (specializer_recipe.name != "interactive")
229             raise << "  (part of '" << to_original_string(stash_inst) << "' in " << specializer_recipe.name << ")\n" << end();
230           else
231             raise << "  (part of '" << to_original_string(stash_inst) << "')\n" << end();
232         }
233       }
234     }
235   }
236   return "";
237 }
238 
239 // phase 1
240 vector<recipe_ordinal> strictly_matching_variants(const instruction& inst, const vector<recipe_ordinal>& variants) {
241   vector<recipe_ordinal> result;
242   for (int i = 0;  i < SIZE(variants);  ++i) {
243     if (variants.at(i) == -1) continue;
244     trace(9992, "transform") << "checking variant (strict) " << i << ": " << header_label(variants.at(i)) << end();
245     if (all_header_reagents_strictly_match(inst, get(Recipe, variants.at(i))))
246       result.push_back(variants.at(i));
247   }
248   return result;
249 }
250 
251 bool all_header_reagents_strictly_match(const instruction& inst, const recipe& variant) {
252   for (int i = 0;  i < min(SIZE(inst.ingredients), SIZE(variant.ingredients));  ++i) {
253     if (!types_strictly_match(variant.ingredients.at(i), inst.ingredients.at(i))) {
254       trace(9993, "transform") << "strict match failed: ingredient " << i << end();
255       return false;
256     }
257   }
258   for (int i = 0;  i < min(SIZE(inst.products), SIZE(variant.products));  ++i) {
259     if (is_dummy(inst.products.at(i))) continue;
260     if (!types_strictly_match(variant.products.at(i), inst.products.at(i))) {
261       trace(9993, "transform") << "strict match failed: product " << i << end();
262       return false;
263     }
264   }
265   return true;
266 }
267 
268 // phase 3
269 vector<recipe_ordinal> matching_variants(const instruction& inst, const vector<recipe_ordinal>& variants) {
270   vector<recipe_ordinal> result;
271   for (int i = 0;  i < SIZE(variants);  ++i) {
272     if (variants.at(i) == -1) continue;
273     trace(9992, "transform") << "checking variant " << i << ": " << header_label(variants.at(i)) << end();
274     if (all_header_reagents_match(inst, get(Recipe, variants.at(i))))
275       result.push_back(variants.at(i));
276   }
277   return result;
278 }
279 
280 bool all_header_reagents_match(const instruction& inst, const recipe& variant) {
281   for (int i = 0;  i < min(SIZE(inst.ingredients), SIZE(variant.ingredients));  ++i) {
282     if (!types_match(variant.ingredients.at(i), inst.ingredients.at(i))) {
283       trace(9993, "transform") << "match failed: ingredient " << i << end();
284       return false;
285     }
286   }
287   for (int i = 0;  i < min(SIZE(variant.products), SIZE(inst.products));  ++i) {
288     if (is_dummy(inst.products.at(i))) continue;
289     if (!types_match(variant.products.at(i), inst.products.at(i))) {
290       trace(9993, "transform") << "match failed: product " << i << end();
291       return false;
292     }
293   }
294   return true;
295 }
296 
297 // tie-breaker for each phase
298 const recipe& best_variant(const instruction& inst, vector<recipe_ordinal>& candidates) {
299   assert(!candidates.empty());
300   if (SIZE(candidates) == 1) return get(Recipe, candidates.at(0));
301   int min_score = 999;
302   int min_index = 0;
303   for (int i = 0;  i < SIZE(candidates);  ++i) {
304     const recipe& candidate = get(Recipe, candidates.at(i));
305     // prefer variants without extra or missing ingredients or products
306     int score = abs(SIZE(candidate.products)-SIZE(inst.products))
307                           + abs(SIZE(candidate.ingredients)-SIZE(inst.ingredients));
308     // prefer variants with non-address ingredients or products
309     for (int j = 0;  j < SIZE(candidate.ingredients);  ++j) {
310       if (is_mu_address(candidate.ingredients.at(j)))
311         ++score;
312     }
313     for (int j = 0;  j < SIZE(candidate.products);  ++j) {
314       if (is_mu_address(candidate.products.at(j)))
315         ++score;
316     }
317     assert(score < 999);
318     if (score < min_score) {
319       min_score = score;
320       min_index = i;
321     }
322   }
323   return get(Recipe, candidates.at(min_index));
324 }
325 
326 int non_ghost_size(vector<recipe_ordinal>& variants) {
327   int result = 0;
328   for (int i = 0;  i < SIZE(variants);  ++i)
329     if (variants.at(i) != -1) ++result;
330   return result;
331 }
332 
333 bool next_stash(const call& c, instruction* stash_inst) {
334   const recipe& specializer_recipe = get(Recipe, c.running_recipe);
335   int index = c.running_step_index;
336   for (++index;  index < SIZE(specializer_recipe.steps);  ++index) {
337     const instruction& inst = specializer_recipe.steps.at(index);
338     if (inst.name == "stash") {
339       *stash_inst = inst;
340       return true;
341     }
342   }
343   return false;
344 }
345 
346 :(scenario static_dispatch_disabled_in_recipe_without_variants)
347 def main [
348   1:num <- test 3
349 ]
350 def test [
351   2:num <- next-ingredient  # ensure no header
352   return 34
353 ]
354 +mem: storing 34 in location 1
355 
356 :(scenario static_dispatch_disabled_on_headerless_definition)
357 % Hide_errors = true;
358 def test a:num -> z:num [
359   z <- copy 1
360 ]
361 def test [
362   return 34
363 ]
364 +error: redefining recipe test
365 
366 :(scenario static_dispatch_disabled_on_headerless_definition_2)
367 % Hide_errors = true;
368 def test [
369   return 34
370 ]
371 def test a:num -> z:num [
372   z <- copy 1
373 ]
374 +error: redefining recipe test
375 
376 :(scenario static_dispatch_on_primitive_names)
377 def main [
378   1:num <- copy 34
379   2:num <- copy 34
380   3:bool <- equal 1:num, 2:num
381   4:bool <- copy false
382   5:bool <- copy false
383   6:bool <- equal 4:bool, 5:bool
384 ]
385 # temporarily hardcode number equality to always fail
386 def equal x:num, y:num -> z:bool [
387   local-scope
388   load-ingredients
389   z <- copy false
390 ]
391 # comparing numbers used overload
392 +mem: storing 0 in location 3
393 # comparing booleans continues to use primitive
394 +mem: storing 1 in location 6
395 
396 :(scenario static_dispatch_works_with_dummy_results_for_containers)
397 def main [
398   _ <- test 3, 4
399 ]
400 def test a:num -> z:point [
401   local-scope
402   load-ingredients
403   z <- merge a, 0
404 ]
405 def test a:num, b:num -> z:point [
406   local-scope
407   load-ingredients
408   z <- merge a, b
409 ]
410 $error: 0
411 
412 :(scenario static_dispatch_works_with_compound_type_containing_container_defined_after_first_use)
413 def main [
414   x:&:foo <- new foo:type
415   test x
416 ]
417 container foo [
418   x:num
419 ]
420 def test a:&:foo -> z:num [
421   local-scope
422   load-ingredients
423   z:num <- get *a, x:offset
424 ]
425 $error: 0
426 
427 :(scenario static_dispatch_works_with_compound_type_containing_container_defined_after_second_use)
428 def main [
429   x:&:foo <- new foo:type
430   test x
431 ]
432 def test a:&:foo -> z:num [
433   local-scope
434   load-ingredients
435   z:num <- get *a, x:offset
436 ]
437 container foo [
438   x:num
439 ]
440 $error: 0
441 
442 :(scenario static_dispatch_on_non_literal_character_ignores_variant_with_numbers)
443 % Hide_errors = true;
444 def main [
445   local-scope
446   x:char <- copy 10/newline
447   1:num/raw <- foo x
448 ]
449 def foo x:num -> y:num [
450   load-ingredients
451   return 34
452 ]
453 +error: main: ingredient 0 has the wrong type at '1:num/raw <- foo x'
454 -mem: storing 34 in location 1
455 
456 :(scenario static_dispatch_dispatches_literal_to_character)
457 def main [
458   1:num/raw <- foo 97
459 ]
460 def foo x:char -> y:num [
461   local-scope
462   load-ingredients
463   return 34
464 ]
465 # character variant is preferred
466 +mem: storing 34 in location 1
467 
468 :(scenario static_dispatch_dispatches_literal_to_number_if_at_all_possible)
469 def main [
470   1:num/raw <- foo 97
471 ]
472 def foo x:char -> y:num [
473   local-scope
474   load-ingredients
475   return 34
476 ]
477 def foo x:num -> y:num [
478   local-scope
479   load-ingredients
480   return 35
481 ]
482 # number variant is preferred
483 +mem: storing 35 in location 1
484 
:(replace{} "string header_label(const recipe_ordinal r)")
string header_label(const recipe_ordinal r) {
  // delegate to the overload that takes the recipe itself
  return header_label(get(Recipe, r));
}
489 :(code)
490 string header_label(const recipe& caller) {
491   ostringstream out;
492   out << "recipe " << caller.name;
493   for (int i = 0;  i < SIZE(caller.ingredients);  ++i)
494     out << ' ' << to_string(caller.ingredients.at(i));
495   if (!caller.products.empty()) out << " ->";
496   for (int i = 0;  i < SIZE(caller.products);  ++i)
497     out << ' ' << to_string(caller.products.at(i));
498   return out.str();
499 }
500 
501 string original_header_label(const recipe& caller) {
502   ostringstream out;
503   out << "recipe " << caller.original_name;
504   for (int i = 0;  i < SIZE(caller.ingredients);  ++i)
505     out << ' ' << caller.ingredients.at(i).original_string;
506   if (!caller.products.empty()) out << " ->";
507   for (int i = 0;  i < SIZE(caller.products);  ++i)
508     out << ' ' << caller.products.at(i).original_string;
509   return out.str();
510 }
511 
512 :(scenario reload_variant_retains_other_variants)
513 def main [
514   1:num <- copy 34
515   2:num <- foo 1:num
516 ]
517 def foo x:num -> y:num [
518   local-scope
519   load-ingredients
520   return 34
521 ]
522 def foo x:&:num -> y:num [
523   local-scope
524   load-ingredients
525   return 35
526 ]
527 def! foo x:&:num -> y:num [
528   local-scope
529   load-ingredients
530   return 36
531 ]
532 +mem: storing 34 in location 2
533 $error: 0
534 
535 :(scenario dispatch_errors_come_after_unknown_name_errors)
536 % Hide_errors = true;
537 def main [
538   y:num <- foo x
539 ]
540 def foo a:num -> b:num [
541   local-scope
542   load-ingredients
543   return 34
544 ]
545 def foo a:bool -> b:num [
546   local-scope
547   load-ingredients
548   return 35
549 ]
550 +error: main: missing type for 'x' in 'y:num <- foo x'
551 +error: main: failed to find a matching call for 'y:num <- foo x'
552 
553 :(scenario override_methods_with_type_abbreviations)
554 def main [
555   local-scope
556   s:text <- new [abc]
557   1:num/raw <- foo s
558 ]
559 def foo a:address:array:character -> result:number [
560   return 34
561 ]
562 # identical to previous variant once you take type abbreviations into account
563 def! foo a:text -> result:num [
564   return 35
565 ]
566 +mem: storing 35 in location 1
567 
568 :(scenario ignore_static_dispatch_in_type_errors_without_overloading)
569 % Hide_errors = true;
570 def main [
571   local-scope
572   x:&:num <- copy 0
573   foo x
574 ]
575 def foo x:&:char [
576   local-scope
577   load-ingredients
578 ]
579 +error: main: types don't match in call for 'foo x'
580 +error:   which tries to call 'recipe foo x:&:char'
581 
582 :(scenario show_available_variants_in_dispatch_errors)
583 % Hide_errors = true;
584 def main [
585   local-scope
586   x:&:num <- copy 0
587   foo x
588 ]
589 def foo x:&:char [
590   local-scope
591   load-ingredients
592 ]
593 def foo x:&:bool [
594   local-scope
595   load-ingredients
596 ]
597 +error: main: failed to find a matching call for 'foo x'
598 +error:   available variants are:
599 +error:     recipe foo x:&:char
600 +error:     recipe foo x:&:bool
601 
602 :(before "End Includes")
603 using std::abs;