about summary refs log tree commit diff stats
path: root/057static_dispatch.cc
diff options
context:
space:
mode:
authorKartik K. Agaram <vc@akkartik.com>2016-02-11 16:13:10 -0800
committerKartik K. Agaram <vc@akkartik.com>2016-02-11 17:29:50 -0800
commit5d67fac7966b6f05611d014420eb4971b8016c31 (patch)
tree4ad84631596f09a994a2154b77e80118e7189b37 /057static_dispatch.cc
parente4b03c6f574ade8c3fe4df18b958da981ac58acb (diff)
downloadmu-5d67fac7966b6f05611d014420eb4971b8016c31.tar.gz
2646 - redo static dispatch algorithm
The old approach of ad hoc boosts and penalties based on various
features was repeatedly running into exceptions and bugs. New
organization: multiple tiered scores interleaved with tie-breaks. The
moment one tier yields one or more candidates, we stop scanning further
tiers. Just break ties and return.
Diffstat (limited to '057static_dispatch.cc')
-rw-r--r--057static_dispatch.cc224
1 files changed, 145 insertions, 79 deletions
diff --git a/057static_dispatch.cc b/057static_dispatch.cc
index fcae37ee..9c71ac38 100644
--- a/057static_dispatch.cc
+++ b/057static_dispatch.cc
@@ -160,35 +160,40 @@ void resolve_ambiguous_calls(recipe_ordinal r) {
   for (long long int index = 0; index < SIZE(caller_recipe.steps); ++index) {
     instruction& inst = caller_recipe.steps.at(index);
     if (inst.is_label) continue;
-    if (get_or_insert(Recipe_variants, inst.name).empty()) continue;
+    if (non_ghost_size(get_or_insert(Recipe_variants, inst.name)) == 0) continue;
+    trace(9992, "transform") << "instruction " << inst.original_string << end();
     resolve_stack.push_front(call(r));
     resolve_stack.front().running_step_index = index;
-    replace_best_variant(inst, caller_recipe);
+    string new_name = best_variant(inst, caller_recipe);
+    if (!new_name.empty())
+      inst.name = new_name;
     assert(resolve_stack.front().running_recipe == r);
     assert(resolve_stack.front().running_step_index == index);
     resolve_stack.pop_front();
   }
 }
 
-void replace_best_variant(instruction& inst, const recipe& caller_recipe) {
-  trace(9992, "transform") << "instruction " << inst.original_string << end();
+string best_variant(instruction& inst, const recipe& caller_recipe) {
   vector<recipe_ordinal>& variants = get(Recipe_variants, inst.name);
-//?   trace(9992, "transform") << "checking base: " << get(Recipe_ordinal, inst.name) << end();
-  long long int best_score = variant_score(inst, get(Recipe_ordinal, inst.name));
-  trace(9992, "transform") << "score for base: " << best_score << end();
-  for (long long int i = 0; i < SIZE(variants); ++i) {
-    if (variants.at(i) == -1) continue;
-    trace(9992, "transform") << "checking variant " << i << ": " << header_label(variants.at(i)) << end();
-    long long int current_score = variant_score(inst, variants.at(i));
-    trace(9992, "transform") << "score for variant " << i << ": " << current_score << end();
-    if (current_score > best_score) {
-      trace(9993, "transform") << "switching " << inst.name << " to " << get(Recipe, variants.at(i)).name << end();
-      inst.name = get(Recipe, variants.at(i)).name;
-      best_score = current_score;
-    }
-  }
-  // End Instruction Dispatch(inst, best_score)
-  if (best_score == -1 && get(Recipe_ordinal, inst.name) >= MAX_PRIMITIVE_RECIPES) {
+  vector<recipe_ordinal> candidates;
+
+  // Static Dispatch Phase 1
+  candidates = strictly_matching_variants(inst, variants);
+  if (!candidates.empty()) return best_variant(inst, candidates).name;
+
+  // Static Dispatch Phase 2 (shape-shifting recipes in a later layer)
+  // End Static Dispatch Phase 2
+
+  // Static Dispatch Phase 3
+  candidates = strictly_matching_variants_except_literal_against_boolean(inst, variants);
+  if (!candidates.empty()) return best_variant(inst, candidates).name;
+
+  // Static Dispatch Phase 4
+  candidates = matching_variants(inst, variants);
+  if (!candidates.empty()) return best_variant(inst, candidates).name;
+
+  // error messages
+  if (get(Recipe_ordinal, inst.name) >= MAX_PRIMITIVE_RECIPES) {  // we currently don't check types for primitive variants
     raise_error << maybe(caller_recipe.name) << "failed to find a matching call for '" << inst.to_string() << "'\n" << end();
     for (list<call>::iterator p = /*skip*/++resolve_stack.begin(); p != resolve_stack.end(); ++p) {
       const recipe& specializer_recipe = get(Recipe, p->running_recipe);
@@ -209,83 +214,144 @@ void replace_best_variant(instruction& inst, const recipe& caller_recipe) {
       }
     }
   }
+  return "";
 }
 
-bool next_stash(const call& c, instruction* stash_inst) {
-  const recipe& specializer_recipe = get(Recipe, c.running_recipe);
-  long long int index = c.running_step_index;
-  for (++index; index < SIZE(specializer_recipe.steps); ++index) {
-    const instruction& inst = specializer_recipe.steps.at(index);
-    if (inst.name == "stash") {
-      *stash_inst = inst;
-      return true;
-    }
+// phase 1
+vector<recipe_ordinal> strictly_matching_variants(const instruction& inst, vector<recipe_ordinal>& variants) {
+  vector<recipe_ordinal> result;
+  for (long long int i = 0; i < SIZE(variants); ++i) {
+    if (variants.at(i) == -1) continue;
+    trace(9992, "transform") << "checking variant (strict) " << i << ": " << header_label(variants.at(i)) << end();
+    if (all_header_reagents_strictly_match(inst, get(Recipe, variants.at(i))))
+      result.push_back(variants.at(i));
   }
-  return false;
+  return result;
 }
 
-long long int variant_score(const instruction& inst, recipe_ordinal variant) {
-  long long int result = 100;
-  if (variant == -1) return -1;  // ghost from a previous test
-//?   cerr << "variant score: " << inst.to_string() << '\n';
-  if (!contains_key(Recipe, variant)) {
-    assert(variant < MAX_PRIMITIVE_RECIPES);
-    return -1;
+bool all_header_reagents_strictly_match(const instruction& inst, const recipe& variant) {
+  for (long long int i = 0; i < min(SIZE(inst.ingredients), SIZE(variant.ingredients)); ++i) {
+    if (!types_strictly_match(variant.ingredients.at(i), inst.ingredients.at(i))) {
+      trace(9993, "transform") << "strict match failed: ingredient " << i << end();
+      return false;
+    }
   }
-  const vector<reagent>& header_ingredients = get(Recipe, variant).ingredients;
-//?   cerr << "=== checking ingredients\n";
-  for (long long int i = 0; i < min(SIZE(inst.ingredients), SIZE(header_ingredients)); ++i) {
-    if (!types_match(header_ingredients.at(i), inst.ingredients.at(i))) {
-      trace(9993, "transform") << "mismatch: ingredient " << i << end();
-//?       cerr << "mismatch: ingredient " << i << '\n';
-      return -1;
+  for (long long int i = 0; i < min(SIZE(inst.products), SIZE(variant.products)); ++i) {
+    if (is_dummy(inst.products.at(i))) continue;
+    if (!types_strictly_match(variant.products.at(i), inst.products.at(i))) {
+      trace(9993, "transform") << "strict match failed: product " << i << end();
+      return false;
     }
-    if (types_strictly_match(header_ingredients.at(i), inst.ingredients.at(i))) {
-      trace(9993, "transform") << "strict match: ingredient " << i << end();
-//?       cerr << "strict match: ingredient " << i << '\n';
+  }
+  return true;
+}
+
+// phase 3
+vector<recipe_ordinal> strictly_matching_variants_except_literal_against_boolean(const instruction& inst, vector<recipe_ordinal>& variants) {
+  vector<recipe_ordinal> result;
+  for (long long int i = 0; i < SIZE(variants); ++i) {
+    if (variants.at(i) == -1) continue;
+    trace(9992, "transform") << "checking variant (strict except literals-against-booleans) " << i << ": " << header_label(variants.at(i)) << end();
+    if (all_header_reagents_strictly_match_except_literal_against_boolean(inst, get(Recipe, variants.at(i))))
+      result.push_back(variants.at(i));
+  }
+  return result;
+}
+
+bool all_header_reagents_strictly_match_except_literal_against_boolean(const instruction& inst, const recipe& variant) {
+  for (long long int i = 0; i < min(SIZE(inst.ingredients), SIZE(variant.ingredients)); ++i) {
+    if (!types_strictly_match_except_literal_against_boolean(variant.ingredients.at(i), inst.ingredients.at(i))) {
+      trace(9993, "transform") << "strict match failed: ingredient " << i << end();
+      return false;
     }
-    else if (boolean_matches_literal(header_ingredients.at(i), inst.ingredients.at(i))) {
-      // slight penalty for coercing literal to boolean (prefer direct conversion to number if possible)
-      trace(9993, "transform") << "boolean matches literal: ingredient " << i << end();
-      result--;
+  }
+  for (long long int i = 0; i < min(SIZE(variant.products), SIZE(inst.products)); ++i) {
+    if (is_dummy(inst.products.at(i))) continue;
+    if (!types_strictly_match_except_literal_against_boolean(variant.products.at(i), inst.products.at(i))) {
+      trace(9993, "transform") << "strict match failed: product " << i << end();
+      return false;
     }
-    else {
-      // slightly larger penalty for modifying type in other ways
-      trace(9993, "transform") << "non-strict match: ingredient " << i << end();
-//?       cerr << "non-strict match: ingredient " << i << '\n';
-      result-=10;
+  }
+  return true;
+}
+
+// phase 4
+vector<recipe_ordinal> matching_variants(const instruction& inst, vector<recipe_ordinal>& variants) {
+  vector<recipe_ordinal> result;
+  for (long long int i = 0; i < SIZE(variants); ++i) {
+    if (variants.at(i) == -1) continue;
+    trace(9992, "transform") << "checking variant " << i << ": " << header_label(variants.at(i)) << end();
+    if (all_header_reagents_match(inst, get(Recipe, variants.at(i))))
+      result.push_back(variants.at(i));
+  }
+  return result;
+}
+
+bool all_header_reagents_match(const instruction& inst, const recipe& variant) {
+  for (long long int i = 0; i < min(SIZE(inst.ingredients), SIZE(variant.ingredients)); ++i) {
+    if (!types_match(variant.ingredients.at(i), inst.ingredients.at(i))) {
+      trace(9993, "transform") << "strict match failed: ingredient " << i << end();
+      return false;
     }
   }
-//?   cerr << "=== done checking ingredients\n";
-  const vector<reagent>& header_products = get(Recipe, variant).products;
-  for (long long int i = 0; i < min(SIZE(header_products), SIZE(inst.products)); ++i) {
+  for (long long int i = 0; i < min(SIZE(variant.products), SIZE(inst.products)); ++i) {
     if (is_dummy(inst.products.at(i))) continue;
-    if (!types_match(header_products.at(i), inst.products.at(i))) {
-      trace(9993, "transform") << "mismatch: product " << i << end();
-//?       cerr << "mismatch: product " << i << '\n';
-      return -1;
-    }
-    if (types_strictly_match(header_products.at(i), inst.products.at(i))) {
-      trace(9993, "transform") << "strict match: product " << i << end();
-//?       cerr << "strict match: product " << i << '\n';
+    if (!types_match(variant.products.at(i), inst.products.at(i))) {
+      trace(9993, "transform") << "strict match failed: product " << i << end();
+      return false;
     }
-    else if (boolean_matches_literal(header_products.at(i), inst.products.at(i))) {
-      // slight penalty for coercing literal to boolean (prefer direct conversion to number if possible)
-      trace(9993, "transform") << "boolean matches literal: product " << i << end();
-      result--;
+  }
+  return true;
+}
+
+// tie-breaker for each phase
+const recipe& best_variant(const instruction& inst, vector<recipe_ordinal>& candidates) {
+  assert(!candidates.empty());
+  long long int min_score = 999;
+  long long int min_index = 0;
+  for (long long int i = 0; i < SIZE(candidates); ++i) {
+    const recipe& candidate = get(Recipe, candidates.at(i));
+    long long int score = abs(SIZE(candidate.products)-SIZE(inst.products))
+                          + abs(SIZE(candidate.ingredients)-SIZE(inst.ingredients));
+    assert(score < 999);
+    if (score < min_score) {
+      min_score = score;
+      min_index = i;
     }
-    else {
-      // slightly larger penalty for modifying type in other ways
-      trace(9993, "transform") << "non-strict match: product " << i << end();
-//?       cerr << "non-strict match: product " << i << '\n';
-      result-=10;
+  }
+  return get(Recipe, candidates.at(min_index));
+}
+
+long long int non_ghost_size(vector<recipe_ordinal>& variants) {
+  long long int result = 0;
+  for (long long int i = 0; i < SIZE(variants); ++i)
+    if (variants.at(i) != -1) ++result;
+  return result;
+}
+
+bool next_stash(const call& c, instruction* stash_inst) {
+  const recipe& specializer_recipe = get(Recipe, c.running_recipe);
+  long long int index = c.running_step_index;
+  for (++index; index < SIZE(specializer_recipe.steps); ++index) {
+    const instruction& inst = specializer_recipe.steps.at(index);
+    if (inst.name == "stash") {
+      *stash_inst = inst;
+      return true;
     }
   }
-  // the greater the number of unused ingredients/products, the lower the score
-  return result - abs(SIZE(get(Recipe, variant).products)-SIZE(inst.products))
-                - abs(SIZE(inst.ingredients)-SIZE(get(Recipe, variant).ingredients));
+  return false;
 }
 
+:(scenario static_dispatch_disabled_in_recipe_without_variants)
+recipe main [
+  1:number <- test 3
+]
+recipe test [
+  2:number <- next-ingredient  # ensure no header
+  reply 34
+]
++mem: storing 34 in location 1
+
 :(scenario static_dispatch_disabled_on_headerless_definition)
 % Hide_warnings = true;
 recipe test a:number -> z:number [