Commit

refactoring and fixing bagel cache
EvanDietzMorris committed Nov 7, 2024
1 parent 09aff44 commit a85c940
Showing 3 changed files with 75 additions and 128 deletions.

parsers/LitCoin/src/bagel/bagel.py (5 changes: 2 additions & 3 deletions)
@@ -83,7 +83,7 @@ def extract_best_match(bagel_results):
     }
     for result_curie, result in bagel_results.items():
         syn_type = result["synonym_type"]
-        if syn_type == "unrelated":
+        if syn_type not in ranking:
             continue
         rank = min(result.get("name_res_rank", 1000), result.get("sapbert_rank", 1000))
         ranking[syn_type].append({"id": result_curie, "name": result["name"], "rank": rank})
@@ -103,8 +103,7 @@ def extract_best_match(bagel_results):
         return ranking["related"][0], "related"
     else:
         # throw out unrelated
-        return None, "unrelated"
-
+        return None, None
 
 def augment_results(terms, nameres, taxes):
     """Given a dict where the key is a curie, and the value are data about the match, augment the value with
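Note: the net effect of these two hunks is that extract_best_match now skips any synonym type it has no ranking bucket for (not just the literal "unrelated"), and falls through to a uniform (None, None) when nothing matches, so callers can test for None instead of a sentinel string. A minimal sketch of the new guard, with fabricated CURIEs and an invented "other" label standing in for whatever unexpected type the classifier might emit:

```python
# Illustrative only: the buckets mirror the synonym types the GPT prompt allows;
# the bagel_results dict below is fabricated, not repo data.
ranking = {"exact": [], "narrow": [], "broad": [], "related": []}

bagel_results = {
    "MONDO:0005148": {"synonym_type": "exact", "name": "type 2 diabetes mellitus",
                      "name_res_rank": 2, "sapbert_rank": 1},
    "MONDO:0000001": {"synonym_type": "other", "name": "disease",
                      "name_res_rank": 5, "sapbert_rank": 8},
}

for result_curie, result in bagel_results.items():
    syn_type = result["synonym_type"]
    # The old guard (syn_type == "unrelated") let "other" through and crashed
    # on ranking[syn_type]; the membership test skips every unknown label.
    if syn_type not in ranking:
        continue
    rank = min(result.get("name_res_rank", 1000), result.get("sapbert_rank", 1000))
    ranking[syn_type].append({"id": result_curie, "name": result["name"], "rank": rank})

print(ranking["exact"])
# [{'id': 'MONDO:0005148', 'name': 'type 2 diabetes mellitus', 'rank': 1}]
```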

parsers/LitCoin/src/bagel/bagel_gpt.py (17 changes: 8 additions & 9 deletions)
@@ -43,11 +43,11 @@ def ask_classes_and_descriptions(text, term, termlist, abstract_id, requests_ses
"synonymType": ...
}}
]
where the value for synonym is the element from the synonym list, vocabulary class is the
class that I input associated with that synonym, and synonymType is either "exact" or "narrow".
where the value for synonym is the element from the synonym list, vocabulary class is the class that I input
associated with that synonym, and synonymType is either "exact", "narrow", "broad", or "related".
abstract: {text}
query_term: {term}
query term: {term}
possible_synonyms_classes_and_descriptions: {synonym_list}
"""

@@ -68,13 +68,12 @@ class that I input associated with that synonym, and synonymType is either "exac
     for curie in curies:
         termlist[curie]["synonym_Type"] = syntype
 
-    grouped_by_syntype = dict()
+    grouped_by_syntype = defaultdict(list)
     for curie in termlist:
-        syntype = termlist[curie].get("synonym_Type", "unrelated")
-        termlist[curie]["curie"] = curie
-        if syntype not in grouped_by_syntype:
-            grouped_by_syntype[syntype] = []
-        grouped_by_syntype[syntype].append(termlist[curie])
+        syntype = termlist[curie].get("synonym_Type", None)
+        if syntype:
+            termlist[curie]["curie"] = curie
+            grouped_by_syntype[syntype].append(termlist[curie])
     return grouped_by_syntype


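Note: the grouping change here is behavioral, not just cosmetic. Terms that never received a synonym_Type used to default into an "unrelated" bucket; now they are dropped before grouping, and defaultdict(list) replaces the manual bucket initialization. A runnable sketch (the termlist entries are invented for illustration):

```python
from collections import defaultdict

# Fabricated termlist: one classified term, one the LLM never labeled.
termlist = {
    "CHEBI:15377": {"name": "water", "synonym_Type": "exact"},
    "CHEBI:00001": {"name": "mystery compound"},  # no synonym_Type assigned
}

grouped_by_syntype = defaultdict(list)
for curie in termlist:
    syntype = termlist[curie].get("synonym_Type", None)
    if syntype:  # unclassified terms fall out here instead of becoming "unrelated"
        termlist[curie]["curie"] = curie
        grouped_by_syntype[syntype].append(termlist[curie])

print(dict(grouped_by_syntype))
# {'exact': [{'name': 'water', 'synonym_Type': 'exact', 'curie': 'CHEBI:15377'}]}
```

This pairs with the `if syn_type not in ranking` guard added in bagel.py: downstream code no longer sees an "unrelated" group at all.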

parsers/LitCoin/src/loadLitCoin.py (181 changes: 65 additions & 116 deletions)
@@ -121,21 +121,20 @@ def __init__(self, test_mode: bool = False, source_data_dir: str = None):
         self.data_url = 'https://stars.renci.org/var/data_services/litcoin/'
         # self.llm_output_file = 'rare_disease_abstracts_fixed._gpt4_20240320.json'
         self.abstracts_file = 'abstracts_CompAndHeal.json'
-        self.llm_output_file = 'abstracts_CompAndHealpaca_v2.0_20241001.jsonl'
+        self.llm_output_file = 'abstracts_CompAndHealpaca_v2.0_20241001_truncated.jsonl'
         # self.biolink_predicate_vectors_file = 'mapped_predicate_vectors.json'
         # self.data_files = [self.llm_output_file, self.abstracts_file, self.biolink_predicate_vectors_file]
         self.data_files = [self.llm_output_file, self.abstracts_file]
 
-        # dicts of name to id lookups organized by node type (node_name_to_id_lookup[node_type] = dict of names -> id)
-        # self.node_name_to_id_lookup = defaultdict(dict) <--- replaced with bagel
         self.bagel_results_lookup = None
         self.name_res_stats = []
         self.bl_utils = BiolinkUtils()
 
         self.mentions_predicate = "IAO:0000142"
+        self.parsing_metadata = {}
 
     def get_latest_source_version(self) -> str:
-        latest_version = 'v3.0'
+        latest_version = 'v4.0'
         return latest_version
 
     def get_data(self) -> bool:
@@ -152,9 +151,6 @@ def parse_data(self) -> dict:
         :return: ret_val: load_metadata
         """
 
-        # could use cached results for faster dev runs with something like this
-        # with open(os.path.join(self.data_path, "litcoin_name_res_results.json"), "w") as name_res_results_file:
-        #     self.node_name_to_id_lookup = orjson.load(name_res_results_file)
         """
         predicate_vectors_file_path = os.path.join(self.data_path,
                                                    self.biolink_predicate_vectors_file)
@@ -170,22 +166,23 @@ def parse_data(self) -> dict:
         abstracts_file_path = os.path.join(self.data_path, self.abstracts_file)
         abstracts = self.load_abstracts(abstracts_file_path)
 
+        self.parsing_metadata = {
+            'bagelization_errors': 0,
+            'failed_bagelization': 0,
+            'failed_abstract_lookup': 0,
+        }
+
         records = 0
         skipped_records = 0
-        failed_bagelization = 0
-        bagelization_errors = 0
-        bagelized_success = 0
         failed_predicate_mapping = 0
-        failed_abstract_lookup = 0
         missing_abstracts = set()
-        terms_that_could_not_be_bagelized = set()
         predicates_that_could_not_be_mapped = set()
         litcoin_file_path: str = os.path.join(self.data_path, self.llm_output_file)
 
         for litcoin_edge in quick_jsonl_file_iterator(litcoin_file_path):
 
             records += 1
-            if records == 10 and self.test_mode:
+            if records == 2000 and self.test_mode:
                 break
 
             abstract_id = litcoin_edge[LLM.ABSTRACT_ID]
@@ -210,77 +207,33 @@ def parse_data(self) -> dict:
                 continue
 
             try:
-                subject_name = litcoin_edge[LLM.SUBJECT_NAME]
-                if subject_name not in self.bagel_results_lookup:
-                    try:
-                        bagel_results = self.get_bagel_results(text=abstract_text,
-                                                               entity=subject_name,
-                                                               abstract_id=abstract_id)
-                        self.bagel_results_lookup[subject_name] = bagel_results
-                        bagelized_success += 1
-                    except requests.exceptions.HTTPError as e:
-                        self.logger.error(f'Failed Bagelization: {type(e)}:{e}')
-                        skipped_records += 1
-                        bagelization_errors += 1
-                        if e.response.status_code == 429:
-                            raise e
-                        continue
-
+                subject_node, subject_bagel_synonym_type = self.bagelize_entity(entity_name=litcoin_edge[LLM.SUBJECT_NAME],
+                                                                                abstract_id=abstract_id,
+                                                                                abstract_text=abstract_text)
+                if subject_node:
+                    subject_id = subject_node["id"]
+                    subject_name = subject_node["name"]
+                else:
-                bagel_results = self.bagel_results_lookup[subject_name]
-                if 'error' in bagel_results:
                     skipped_records += 1
-                    failed_bagelization += 1
-                    self.logger.info(f'Skipping due to error in bagelization.')
                     continue
-                bagel_subject_node, subject_bagel_synonym_type = extract_best_match(bagel_results)
-                if not bagel_subject_node:
-                    skipped_records += 1
-                    failed_bagelization += 1
-                    terms_that_could_not_be_bagelized.add(subject_name)
-                    self.logger.info(f'Skipping due to bagelization finding no match.')
-                    continue
-                subject_id = bagel_subject_node['id']
-                subject_name = bagel_subject_node['name']
 
-                object_name = litcoin_edge[LLM.OBJECT_NAME]
-                if object_name not in self.bagel_results_lookup:
-                    try:
-                        bagel_results = self.get_bagel_results(text=abstract_text,
-                                                               entity=object_name,
-                                                               abstract_id=abstract_id)
-                        self.bagel_results_lookup[object_name] = bagel_results
-                        bagelized_success += 1
-                    except requests.exceptions.HTTPError as e:
-                        self.logger.error(f'Failed Bagelization: {type(e)}:{e}')
-                        skipped_records += 1
-                        bagelization_errors += 1
-                        if e.response.status_code == 429:
-                            raise e
-                        continue
 
+                object_node, object_bagel_synonym_type = self.bagelize_entity(entity_name=litcoin_edge[LLM.OBJECT_NAME],
+                                                                              abstract_id=abstract_id,
+                                                                              abstract_text=abstract_text)
+                if object_node:
+                    object_id = object_node["id"]
+                    object_name = object_node["name"]
+                else:
-                bagel_results = self.bagel_results_lookup[object_name]
-                if 'error' in bagel_results:
                     skipped_records += 1
-                    failed_bagelization += 1
-                    self.logger.info(f'Skipping due to error in bagelization.')
                     continue
-                bagel_object_node, object_bagel_synonym_type = extract_best_match(bagel_results)
-                if not bagel_object_node:
-                    skipped_records += 1
-                    failed_bagelization += 1
-                    terms_that_could_not_be_bagelized.add(object_name)
-                    self.logger.info(f'Skipping due to bagelization finding no match.')
-                    continue
-                object_id = bagel_object_node['id']
-                object_name = bagel_object_node['name']
 
                 predicate = 'biolink:' + snakify(litcoin_edge[LLM.RELATIONSHIP])
                 # predicate = predicate_mapper.get_mapped_predicate(litcoin_edge[LLM.RELATIONSHIP])
-                if not predicate:
-                    skipped_records += 1
-                    failed_predicate_mapping += 1
-                    self.logger.info(f'Skipping due to failed predicate mapping.')
-                    continue
+                # if not predicate:
+                #     skipped_records += 1
+                #     self.logger.info(f'Skipping due to failed predicate mapping.')
+                #     continue
 
                 self.output_file_writer.write_node(node_id=subject_id,
                                                    node_name=subject_name)
@@ -325,18 +278,46 @@ def parse_data(self) -> dict:
         self.save_bagel_cache()
         self.save_llm_results()
 
-        parsing_metadata = {
+        self.parsing_metadata.update({
             'records': records,
             'skipped_records': skipped_records,
-            'bagelization_errors': bagelization_errors,
-            'failed_bagelization': failed_bagelization,
-            'failed_abstract_lookup': failed_abstract_lookup,
             'missing_abstracts': list(missing_abstracts),
-            'terms_that_could_not_be_bagelized': list(terms_that_could_not_be_bagelized),
             'failed_predicate_mapping': failed_predicate_mapping,
-            'predicates_that_could_not_be_mapped': list(predicates_that_could_not_be_mapped),
-        }
-        return parsing_metadata
+            'predicates_that_could_not_be_mapped': list(predicates_that_could_not_be_mapped)
+        })
+        return self.parsing_metadata
+
+    def bagelize_entity(self, entity_name, abstract_id, abstract_text):
+        abstract_id = str(abstract_id)
+        if abstract_id in self.bagel_results_lookup and entity_name in self.bagel_results_lookup[abstract_id]["terms"]:
+            bagel_results = self.bagel_results_lookup[abstract_id]["terms"][entity_name]
+        else:
+            if abstract_id not in self.bagel_results_lookup:
+                self.bagel_results_lookup[abstract_id] = {"abstract": abstract_text,
+                                                          "terms": {}}
+            try:
+                bagel_results = self.get_bagel_results(text=abstract_text,
+                                                       entity=entity_name,
+                                                       abstract_id=abstract_id)
+                self.bagel_results_lookup[abstract_id]["terms"][entity_name] = bagel_results
+            except requests.exceptions.HTTPError as e:
+                error_message = f'Failed Bagelization: {type(e)}:{e}'
+                self.logger.error(error_message)
+                if e.response.status_code == 429:
+                    raise e
+                self.parsing_metadata['bagelization_errors'] += 1
+                bagel_results = {'error': error_message}
+                self.bagel_results_lookup[abstract_id]["terms"][entity_name] = bagel_results
+
+        if 'error' in bagel_results:
+            return None, None
+
+        bagel_node, bagel_synonym_type = extract_best_match(bagel_results)
+        if not bagel_node:
+            self.parsing_metadata['failed_bagelization'] += 1
+            self.logger.info(f'Skipping due to bagelization finding no match for: {entity_name} in abstract {abstract_id}')
+            return None, None
+        return bagel_node, bagel_synonym_type

     def parse_llm_edge(self, llm_json_edge, logger):
         converted_edge = {}
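Note: the new bagelize_entity helper is the cache fix named in the commit message. Previously the cache was keyed by entity name alone, so the same string appearing in two different abstracts shared one (abstract-dependent) result; the lookup is now keyed by abstract id first, then term, and HTTP failures are cached as {'error': ...} entries so a failing term is not re-requested for every edge of the same abstract. A sketch of the shape the cache takes under this scheme (ids, text, and terms are invented):

```python
# Illustrative cache layout; this is what save_bagel_cache would serialize.
bagel_results_lookup = {
    "12345678": {
        "abstract": "Full abstract text goes here...",
        "terms": {
            # success: per-CURIE candidates, as consumed by extract_best_match
            "metformin": {
                "CHEBI:6801": {"synonym_type": "exact", "name": "metformin",
                               "name_res_rank": 1, "sapbert_rank": 1},
            },
            # failure: cached error, so bagelize_entity returns (None, None) fast
            "some unresolvable phrase": {"error": "Failed Bagelization: ..."},
        },
    },
}

# Cache-hit check, mirroring the top of bagelize_entity:
abstract_id, entity_name = "12345678", "metformin"
cached = (abstract_id in bagel_results_lookup
          and entity_name in bagel_results_lookup[abstract_id]["terms"])
print(cached)  # True
```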

@@ -411,6 +392,7 @@ def load_bagel_cache(self):

     def save_bagel_cache(self):
         bagel_cache_file_path = self.get_bagel_cache_path()
+        print(self.bagel_results_lookup)
         with open(bagel_cache_file_path, "w") as bagel_cache_file:
             return json.dump(self.bagel_results_lookup,
                              bagel_cache_file,
@@ -441,39 +423,6 @@ def save_llm_results(self):
         with open(llm_results_path, "w") as llm_results_file:
             json.dump(previous_llm_results + results_to_add, llm_results_file, indent=4)
 
-    """
-    replaced by bagel
-    def process_llm_node(self, node_name: str, node_type: str):
-        # check if we did name resolution for this name and type already and return it if so
-        if node_name in self.node_name_to_id_lookup[node_type]:
-            return self.node_name_to_id_lookup[node_type][node_name]
-        # otherwise call the name res service and try to find a match
-        # the following node_type string formatting conversion is kind of unnecessary now,
-        # it was intended to produce valid biolink types given the node_type from the llm,
-        # but that doesn't really work well enough to use, now we use the NODE_TYPE_MAPPINGS mappings,
-        # but the keys currently use the post-conversion format so this stays for now
-        biolink_node_type = self.convert_node_type_to_biolink_format(node_type)
-        preferred_biolink_node_type = NODE_TYPE_MAPPINGS.get(biolink_node_type, None)
-        self.logger.info(f'calling name res for {node_name} - {preferred_biolink_node_type}')
-        name_resolution_results = self.name_resolution_function(node_name, preferred_biolink_node_type)
-        standardized_name_res_result = self.standardize_name_resolution_results(name_resolution_results)
-        standardized_name_res_result['queried_type'] = preferred_biolink_node_type
-        self.node_name_to_id_lookup[node_type][node_name] = standardized_name_res_result
-        return standardized_name_res_result
-
-    def convert_node_type_to_biolink_format(self, node_type):
-        try:
-            biolink_node_type = re.sub("[()/]", "", node_type)  # remove parentheses and forward slash
-            biolink_node_type = "".join([node_type_segment[0].upper() + node_type_segment[1:].lower()
-                                         for node_type_segment in biolink_node_type.split()])  # force Pascal case
-            return f'{biolink_node_type}'
-        except TypeError as e:
-            self.logger.error(f'Bad node type provided by llm: {node_type}')
-            return ""
-    """
 
     def name_resolution_function(self, node_name, preferred_biolink_node_type, retries=0):
         return call_name_resolution(node_name,
                                     preferred_biolink_node_type,
