]> Nishi Git Mirror - gwion.git/commitdiff
:art: improve template inference
authorfennecdjay <fennecdjay@gmail.com>
Sun, 14 Aug 2022 18:57:25 +0000 (20:57 +0200)
committerfennecdjay <fennecdjay@gmail.com>
Sun, 14 Aug 2022 18:57:25 +0000 (20:57 +0200)
src/parse/check.c

index e14a99e7f91a9251f7e83c361f9138eeb54cf46f..7dc6ce2780d3162db502358ef565ead934137292 100644 (file)
@@ -475,14 +475,27 @@ ANN static inline Type type_list_base(const Gwion gwion, const Type type) {
   return !(is_func(gwion, type)) ? type : type_list_base_func(type);
 }
 
-ANN static Type_Decl* mk_td(const Env env, const Arg *arg,
+ANN static Type_Decl*  mk_td(const Env env, Type_Decl *td,
                                   const Type type, const loc_t pos) {
   const Type base = type_list_base(env->gwion, type);
-  const Type t    = !arg->td->array ?
-          base : array_type(env, base, arg->td->array->depth);
+  const Type t    = !td->array ?
+          base : array_type(env, base, td->array->depth);
   return type2td(env->gwion, t, pos);
 }
 
+ANN static Type tmpl_arg_match(const Env env, const Symbol xid, const Symbol tgt, const Type t) {
+  if (xid == tgt) return t;
+  if(!tflag(t, tflag_cdef)) return NULL;
+  const uint32_t len = mp_vector_len(t->info->cdef->base.tmpl->list);
+  for(uint32_t i = 0; i < len; i++) {
+    Specialized *spec = mp_vector_at(t->info->cdef->base.tmpl->list, Specialized, i);
+    const Type base = nspc_lookup_type1(t->nspc, spec->xid);
+    const Type t = tmpl_arg_match(env, xid, spec->xid, base);
+    if(t) return t;
+  }
+  return NULL;
+}
+
 ANN2(1, 2)
 static Func find_func_match_actual(const Env env, const Func f, const Exp exp,
                                    const bool implicit, const bool specific) {
@@ -676,10 +689,10 @@ ANN static Type_List check_template_args(const Env env, Exp_Call *exp,
   if(exp->other) {
     for(uint32_t i = 0; i < len; i++) {
       Specialized *spec = mp_vector_at(sl, Specialized, i);
-      if (spec->xid == fdef->base->td->xid) { // check no next?
+      if (tmpl_arg_match(env, spec->xid, fdef->base->td->xid, fdef->base->ret_type)) {
         CHECK_OO(check_exp(env, exp->other));
          if(!is_func(env->gwion, exp->other->type)) {
-           Type_Decl *td = type2td(env->gwion, exp->other->type, fdef->base->pos);
+           Type_Decl *td = type2td(env->gwion, exp->other->type, fdef->base->td->pos);
            mp_vector_set(tl, Type_Decl*, 0, td);
          } else {
            Func func = exp->other->type->info->func;
@@ -705,9 +718,9 @@ ANN static Type_List check_template_args(const Env env, Exp_Call *exp,
     uint32_t count = 0;
     while (count < args_len && template_arg) {
       Arg *arg = mp_vector_at(args, Arg, count);
-      if (spec->xid == arg->td->xid) {
+      if (tmpl_arg_match(env, spec->xid, arg->td->xid, template_arg->type)) {
         mp_vector_set(tl, Type_Decl*, args_number,
-          mk_td(env, arg, template_arg->type, fdef->base->pos));
+          mk_td(env, arg->td, template_arg->type, fdef->base->pos));
         ++args_number;
         break;
       }