import marimo as mo
import matplotlib.pyplot as plt
from novelentitymatcher import Matcher

# Intro cell: notebook title followed by a short description.
# The heading has to be terminated by a blank line; when the title and the
# description run together, Markdown renders the whole text as one heading.
mo.md(
    """
# Training Impact Analyzer

Compare **zero-shot** vs **trained** matching side-by-side. Adjust the number
of training samples and see how accuracy changes for known and tricky inputs.
"""
)
# Canonical entities the matchers resolve queries against:
# (id, display name, known aliases).
_ENTITY_SPECS = (
    ("DE", "Germany", ("Deutschland",)),
    ("FR", "France", ("Frankreich",)),
    ("US", "United States", ("USA", "America")),
    ("JP", "Japan", ("Nippon",)),
    ("CN", "China", ("Zhongguo",)),
)
entities = [
    {"id": _code, "name": _name, "aliases": list(_aliases)}
    for _code, _name, _aliases in _ENTITY_SPECS
]

# Labelled training texts, grouped by entity id. The insertion order here is
# the order of ``full_training``; other cells slice a prefix of that list, so
# the order is part of the data and must not be shuffled.
_TEXTS_BY_LABEL = {
    "DE": ("Germany", "Deutschland", "Deutchland", "GER"),
    "FR": ("France", "French Republic", "La France", "FRA"),
    "US": ("United States", "USA", "America", "U.S.A."),
    "JP": ("Japan", "Nippon", "Nihon"),
    "CN": ("China", "Zhongguo", "PRC"),
}
full_training = [
    {"text": _text, "label": _label}
    for _label, _texts in _TEXTS_BY_LABEL.items()
    for _text in _texts
]

# Evaluation queries as (query text, expected entity id).
test_queries = [
    # These four texts also appear in ``full_training``.
    ("Deutchland", "DE"),
    ("America", "US"),
    ("Nihon", "JP"),
    ("PRC", "CN"),
]
# "Frankreich" is listed as an entity alias but has no training row; it sits
# between "America" and "Nihon" in the evaluation order.
test_queries.insert(2, ("Frankreich", "FR"))
# The remaining queries appear neither in the aliases nor in the training rows.
test_queries.extend(
    [
        ("Bundesrepublik", "DE"),
        ("U.S. of A", "US"),
        ("La Republique", "FR"),
    ]
)

# Baseline matcher: built from the entity list alone and fitted without any
# training data.
zero_matcher = Matcher(entities=entities, mode="zero-shot")
zero_matcher.fit()
def _id_and_score(result):
    """Return ``(id, score)`` read from one ``match`` result.

    Args:
        result: Whatever ``match(query)`` returned.

    Returns:
        ``(result["id"], result["score"])`` when ``result`` is a dict, with
        ``"?"`` standing in for a missing id and ``0`` for a missing or
        ``None`` score. Any non-dict result yields ``("?", 0)``.
    """
    # NOTE(review): the previous code selected the entry with a conditional
    # whose two branches were identical, so only dict results were ever read.
    # If ``match`` can return a list of candidates, the first candidate
    # should probably be used here -- confirm against the Matcher API.
    if not isinstance(result, dict):
        return "?", 0
    # ``or 0`` keeps the percent formatting below from failing on a None score.
    return result.get("id", "?"), result.get("score", 0) or 0


# Build the matcher shown in the "trained" columns. With zero samples there is
# nothing to train on, so the zero-shot matcher is reused and labelled as such.
_n = n_samples.value
if _n > 0:
    _training_subset = full_training[:_n]
    trained_matcher = Matcher(entities=entities, verbose=False)
    trained_matcher.fit(training_data=_training_subset, num_epochs=1)
    _trained_mode = "trained"
    # The slice silently stops at the end of ``full_training``; report the
    # number of rows actually used rather than the requested number.
    _n_used = len(_training_subset)
else:
    trained_matcher = zero_matcher
    _trained_mode = "same (zero-shot)"
    _n_used = 0

# One table row per test query, with both matchers' answers side by side.
_rows = []
for _query, _expected in test_queries:
    _z_id, _z_score = _id_and_score(zero_matcher.match(_query))
    _t_id, _t_score = _id_and_score(trained_matcher.match(_query))
    _rows.append(
        {
            "query": _query,
            "expected": _expected,
            "zero_shot_id": _z_id,
            "zero_shot_score": f"{_z_score:.2%}",
            "zero_shot_correct": "OK" if _z_id == _expected else "MISS",
            "trained_id": _t_id,
            "trained_score": f"{_t_score:.2%}",
            "trained_correct": "OK" if _t_id == _expected else "MISS",
        }
    )

# The table is the last expression so that it is the cell's displayed output.
mo.ui.table(
    _rows,
    label=f"Comparison: zero-shot vs {_trained_mode} ({_n_used} samples)",
)
def _accuracy(matcher):
    """Return the fraction of ``test_queries`` that ``matcher`` gets right.

    A query counts as a hit only when ``matcher.match(query)`` returns a
    truthy result whose id equals the expected id. Non-dict results are
    read as the placeholder id ``"?"``.
    """
    hits = 0
    for query, expected in test_queries:
        result = matcher.match(query)
        predicted = result.get("id", "?") if isinstance(result, dict) else "?"
        if result and predicted == expected:
            hits += 1
    return hits / len(test_queries)


# X axis: every training-set size from 0 up to the slider value, capped at
# the number of available training rows.
_n = n_samples.value
_sample_counts = list(range(min(_n + 1, len(full_training) + 1)))

# ``zero_matcher`` is not refit anywhere in this cell, so its accuracy does
# not depend on the sample count: score it once and repeat the value instead
# of re-running every query for each point on the axis.
_zero_acc = [_accuracy(zero_matcher)] * len(_sample_counts) if _sample_counts else []

_trained_acc = []
for _count, _baseline in zip(_sample_counts, _zero_acc):
    if _count == 0:
        # No training data: the "trained" curve starts at the zero-shot value.
        _trained_acc.append(_baseline)
        continue
    try:
        _tm = Matcher(entities=entities, verbose=False)
        _tm.fit(training_data=full_training[:_count], num_epochs=1)
        _score = _accuracy(_tm)
    except Exception:
        # Deliberate best-effort: if fitting or matching fails for this
        # subset, fall back to the zero-shot value so the chart still renders.
        # NOTE(review): presumably very small subsets can fail to train --
        # confirm, and narrow the exception type if the library documents one.
        _score = _baseline
    _trained_acc.append(_score)

_fig, _ax = plt.subplots(figsize=(8, 4))
_ax.plot(_sample_counts, _zero_acc, "o--", label="Zero-shot", color="#3498db")
_ax.plot(_sample_counts, _trained_acc, "s-", label="Trained", color="#e74c3c")
_ax.set_xlabel("Number of Training Samples")
_ax.set_ylabel("Accuracy")
_ax.set_title("Zero-Shot vs Trained Matching Accuracy")
# Leave a little headroom above 1.0 so markers at full accuracy are not clipped.
_ax.set_ylim(0, 1.05)
_ax.legend()
_ax.grid(True, alpha=0.3)
plt.tight_layout()
# The figure is the last expression so that it is the cell's displayed output.
_fig