Hello, I am running `train_model.py`, but `get_easse_report_from_exp_dir` fails with the following error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-15-5dc913ca0d14> in <module>()
----> 1 result = fairseq_train_and_evaluate_with_parametrization(**kwargs)
13 frames
/content/drive/MyDrive/muss/muss/fairseq/main.py in fairseq_train_and_evaluate_with_parametrization(dataset, **kwargs)
228 kwargs['preprocessor_kwargs'] = recommended_preprocessors_kwargs
229 # Evaluation
--> 230 scores = print_running_time(fairseq_evaluate_and_save)(exp_dir, **kwargs)
231 score = combine_metrics(scores['bleu'], scores['sari'], scores['fkgl'], kwargs.get('metrics_coefs', [0, 1, 0]))
232 # TODO: This is a redundant hack with what happens in fairseq_evaluate_and_save (predict_files and evaluate_kwargs), it should be fixed
/content/drive/MyDrive/muss/muss/utils/helpers.py in wrapped_func(*args, **kwargs)
468 function_name = getattr(func, '__name__', repr(func))
469 with log_action(function_name):
--> 470 return func(*args, **kwargs)
471
472 return wrapped_func
/content/drive/MyDrive/muss/muss/fairseq/main.py in fairseq_evaluate_and_save(exp_dir, **kwargs)
104 print(f'scores={scores}')
105 report_path = exp_dir / 'easse_report.html'
--> 106 shutil.move(get_easse_report_from_exp_dir(exp_dir, **kwargs), report_path)
107 print(f'report_path={report_path}')
108 predict_files = kwargs.get(
/content/drive/MyDrive/muss/muss/fairseq/main.py in get_easse_report_from_exp_dir(exp_dir, **kwargs)
97 def get_easse_report_from_exp_dir(exp_dir, **kwargs):
98 simplifier = fairseq_get_simplifier(exp_dir, **kwargs)
---> 99 return get_easse_report(simplifier, **kwargs.get('evaluate_kwargs', {'test_set': 'asset_valid'}))
100
101
/content/drive/MyDrive/muss/muss/evaluation/general.py in get_easse_report(simplifier, test_set, orig_sents_path, refs_sents_paths)
40 orig_sents_path=orig_sents_path,
41 refs_sents_paths=refs_sents_paths,
---> 42 report_path=report_path,
43 )
44 return report_path
/usr/local/lib/python3.7/dist-packages/easse/cli.py in report(test_set, sys_sents_path, orig_sents_path, refs_sents_paths, report_path, tokenizer, lowercase, metrics)
302 lowercase=lowercase,
303 tokenizer=tokenizer,
--> 304 metrics=metrics,
305 )
306
/usr/local/lib/python3.7/dist-packages/easse/report.py in write_html_report(filepath, *args, **kwargs)
477 def write_html_report(filepath, *args, **kwargs):
478 with open(filepath, 'w') as f:
--> 479 f.write(get_html_report(*args, **kwargs) + '\n')
480
481
/usr/local/lib/python3.7/dist-packages/easse/report.py in get_html_report(orig_sents, sys_sents, refs_sents, test_set, lowercase, tokenizer, metrics)
471 doc.stag('hr')
472 with doc.tag('div', klass='container-fluid'):
--> 473 doc.asis(get_qualitative_examples_html(orig_sents, sys_sents, refs_sents))
474 return indent(doc.getvalue())
475
/usr/local/lib/python3.7/dist-packages/easse/report.py in get_qualitative_examples_html(orig_sents, sys_sents, refs_sents)
154 sample_generator = sorted(
155 zip(orig_sents, sys_sents, zip(*refs_sents)),
--> 156 key=lambda args: sort_key(*args),
157 )
158 # Samples displayed by default
/usr/local/lib/python3.7/dist-packages/easse/report.py in <lambda>(args)
154 sample_generator = sorted(
155 zip(orig_sents, sys_sents, zip(*refs_sents)),
--> 156 key=lambda args: sort_key(*args),
157 )
158 # Samples displayed by default
/usr/local/lib/python3.7/dist-packages/easse/report.py in <lambda>(c, s, refs)
91 (
92 'Best simplifications according to SARI',
---> 93 lambda c, s, refs: -corpus_sari([c], [s], [refs]),
94 lambda value: f'SARI={-value:.2f}',
95 ),
/usr/local/lib/python3.7/dist-packages/easse/sari.py in corpus_sari(*args, **kwargs)
264
265 def corpus_sari(*args, **kwargs):
--> 266 add_score, keep_score, del_score = get_corpus_sari_operation_scores(*args, **kwargs)
267 return (add_score + keep_score + del_score) / 3
/usr/local/lib/python3.7/dist-packages/easse/sari.py in get_corpus_sari_operation_scores(orig_sents, sys_sents, refs_sents, lowercase, tokenizer, legacy, use_f1_for_deletion, use_paper_version)
254 refs_sents = [[utils_prep.normalize(sent, lowercase, tokenizer) for sent in ref_sents] for ref_sents in refs_sents]
255
--> 256 stats = compute_ngram_stats(orig_sents, sys_sents, refs_sents)
257
258 if not use_paper_version:
/usr/local/lib/python3.7/dist-packages/easse/sari.py in compute_ngram_stats(orig_sents, sys_sents, refs_sents)
110 assert all(
111 len(ref_sents) == len(orig_sents) for ref_sents in refs_sents
--> 112 ), "Reference sentences don't have the shape (n_references, n_samples)"
113 add_sys_correct = [0] * NGRAM_ORDER
114 add_sys_total = [0] * NGRAM_ORDER
AssertionError: Reference sentences don't have the shape (n_references, n_samples)
I printed the lengths at the point where the error occurs, and got:
len(refs_sents)=1
len(ref_sents)=10
len(orig_sents)=1
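I believe they should instead be as shown below. The references for the single sample appear to be transposed. `zip(*refs_sents)` in `easse/report.py` (line 155) yields one tuple of 10 references per sample. The SARI sort key (line 93) then wraps that tuple as `[refs]`. This gives shape (n_samples=1, n_references=10) instead of the expected (n_references, n_samples):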
len(refs_sents)=10
len(ref_sents)=1
len(orig_sents)=1
I am not sure how to make this change without affecting the results of the code. I'd appreciate any advice. Thank you in advance!