diff --git a/translation_util/translate_util.py b/translation_util/translate_util.py index 849bfcc646f683a69d970db472c0468957e3e357..6aeca6527a42d384169f58ba50017438e898e053 100644 --- a/translation_util/translate_util.py +++ b/translation_util/translate_util.py @@ -182,9 +182,12 @@ def translate_func(inputs, translation_server): i_src,tgt = list(),list() tagged_tgt,tagged_src = list(),list() s_id,n_id = [0000],[0000] + i_s0_src,i_s0_tgt,i_save = list(),list(),list() + try: for i in inputs: + s0_src,s0_tgt,save = "NA","NA",False logger.info(log_with_request_info(i.get("s_id"),LOG_TAGS["input"],i)) if all(v in i for v in ['s_id','n_id']): s_id = [i['s_id']] @@ -195,7 +198,12 @@ def translate_func(inputs, translation_server): out['response_body'] = [] logger.info("either id or src missing in some input") return (out) - + + if any(v in i for v in ['s0_src','s0_tgt','save']): + s0_src,s0_tgt,save = handle_custome_input(i,s0_src,s0_tgt,save) + + i_s0_src.append(s0_src),i_s0_tgt.append(s0_tgt),i_save.append(save) + logger.info("input sentences:{}".format(i['src'])) i_src.append(i['src']) i['src'] = i['src'].strip() @@ -403,7 +411,7 @@ def translate_func(inputs, translation_server): out['response_body'] = [{"tgt": tgt[i], "pred_score": pred_score[i], "s_id": sentence_id[i],"input_subwords": input_subwords[i], "output_subwords":output_subwords[i],"n_id":node_id[i],"src":i_src[i], - "tagged_tgt":tagged_tgt[i],"tagged_src":tagged_src[i]} + "tagged_tgt":tagged_tgt[i],"tagged_src":tagged_src[i],"save":i_save[i],"s0_src":i_s0_src[i],"s0_tgt":i_s0_tgt[i]} for i in range(len(tgt))] except ServerModelError as e: out['status'] = statusCode["SEVER_MODEL_ERR"] @@ -418,3 +426,12 @@ def translate_func(inputs, translation_server): return (out) +def handle_custome_input(i,s0_src,s0_tgt,save): + if 'save' in i: + save = i["save"] + if "s0_src" in i: + s0_src = i["s0_src"] + if "s0_tgt" in i: + s0_tgt = i["s0_tgt"] + + return s0_src,s0_tgt,save \ No newline at end of file