diff --git a/strAPI/main.py b/strAPI/main.py
index 4559a00..feb602e 100644
--- a/strAPI/main.py
+++ b/strAPI/main.py
@@ -176,6 +176,12 @@ def show_allele_freqs(repeat_id: int, db: Session = Depends(get_db)):
     #    return []
     return allfreqs
+@app.get("/allseq/", response_model=List[schemas.AlleleSequence], tags=["Repeats"])
+def show_allele_seq(repeat_id: int, db: Session = Depends(get_db)):
+    allseq = db.query(models.AlleleSequence).filter(models.AlleleSequence.repeat_id == repeat_id).all()
+    return allseq
+
+
 
 """
 Retrieve repeat info given a repeat id
 
@@ -457,4 +463,4 @@ def get_crc_expr_repeatlen_corr(db: Session = Depends(get_db), limit = 7000):
         joinedload(models.CRCExprRepeatLenCorr.gene),
         joinedload(models.CRCExprRepeatLenCorr.repeat)
     ).all()
-    return [{**c.gene.__dict__, **c.repeat.__dict__, **c.__dict__} for c in correlations]
\ No newline at end of file
+    return [{**c.gene.__dict__, **c.repeat.__dict__, **c.__dict__} for c in correlations]
diff --git a/strAPI/repeats/models.py b/strAPI/repeats/models.py
index e5a235a..86b8165 100644
--- a/strAPI/repeats/models.py
+++ b/strAPI/repeats/models.py
@@ -39,7 +39,7 @@ class Transcript(SQLModel, table=True):
     end: int = Field(nullable=False)
 
     # one to many Gene -> Transcripts
-    gene_id = Field(Integer, foreign_key ="genes.id")
+    gene_id: int = Field(Integer, foreign_key ="genes.id")
     gene: "Gene" = Relationship(back_populates="transcripts")
 
 
@@ -183,7 +183,23 @@ class AlleleFrequency(SQLModel, table=True):
     # One to many, Repeat -> AlleleFrequency
     repeat_id: int = Field(foreign_key = "repeats.id")
     repeat: "Repeat" = Relationship(back_populates="allfreqs")
-    
+
+
+class AlleleSequence(SQLModel, table=True):
+    __tablename__ = "allele_sequences"
+
+    id: int = Field(primary_key=True)
+
+    population: str = Field(nullable=False)
+    n_effective: int = Field(nullable=False)
+    frequency: float = Field(nullable=False)
+    num_called: Optional[int] = Field(nullable=True)
+    sequence: str = Field(nullable=False)
+
+    # One to many, Repeat -> AlleleSequence
+    repeat_id: int = Field(foreign_key = "repeats.id")
+    repeat: "Repeat" = Relationship(back_populates="allseq")
+
 
 class Repeat(SQLModel, table=True):
     __tablename__ = "repeats"
@@ -221,7 +237,12 @@ class Repeat(SQLModel, table=True):
     allfreqs: Optional[List["AlleleFrequency"]] = Relationship(
        back_populates="repeat"
    )
-    
+
+    # Add relationship directive to Repeat class for one to many Repeat - AlleleSequence
+    allseq: Optional[List["AlleleSequence"]] = Relationship(
+        back_populates="repeat"
+    )
+
     # many to many Repeats <-> Transcripts
     transcripts: List["Transcript"] = Relationship(back_populates="repeats", link_model = RepeatTranscriptsLink)
 
diff --git a/strAPI/repeats/schemas.py b/strAPI/repeats/schemas.py
index a64c8f7..0f322b8 100644
--- a/strAPI/repeats/schemas.py
+++ b/strAPI/repeats/schemas.py
@@ -83,6 +83,17 @@ class AlleleFrequency(BaseModel):
     class Config:
         orm_mode = True
 
+class AlleleSequence(BaseModel):
+    population: str
+    n_effective: int
+    frequency: float
+    num_called: int
+    repeat_id: int
+    sequence: str
+
+    class Config:
+        orm_mode = True
+
 class CRCVariation(BaseModel):
     tcga_barcode: str
     sample_type: str
@@ -123,4 +134,4 @@
     class Config:
         orm_mode = True
 
-    
\ No newline at end of file
+    
diff --git a/update_tf.py b/update_tf.py
new file mode 100644
index 0000000..991d121
--- /dev/null
+++ b/update_tf.py
@@ -0,0 +1,145 @@
+import argparse
+from sqlalchemy import create_engine
+import pandas as pd
+import numpy as np
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
+from strAPI.repeats.models import Repeat, Gene, TRPanel, AlleleSequence
+import math
+import os
+from tqdm import tqdm
+
+#chrom start end acount-AFR acount-AMR acount-EAS acount-EUR acount-SAS
+#chr10 89246 89285 CACATACACACACACACACACACACACACACAC:2,CACATACACACACACACACACACACACACACACAC:261,CACATACACACACACACACACACACACACACA
+def round_sf(number):
+    if number == 0:
+        return 0.0
+    else:
+        # using 2 because we decided on 2 significant figures
+        precision = 2 - int(math.floor(math.log10(abs(number)))) - 1
+        return round(number, precision)
+
+
+def connection_setup(db_path):
+    engine = create_engine(db_path, echo=False)
+    Session = sessionmaker(bind=engine)
+    session = Session()
+    return engine, session
+
+
+def process_file(filepath, session, error_log_file, skipped_lines_file):
+    with open(filepath, "r") as file:
+        pop_indices = {"AFR": 3, "AMR": 4, "EAS": 5, "EUR": 6, "SAS": 7}
+        total_lines = sum(1 for line in file) - 1  # Count data lines (exclude the header)
+
+        # Reset file pointer to the beginning, then skip the header line so it is not
+        # looked up in the database (and logged as an error) by the loop below
+        file.seek(0)
+        next(file)
+
+        # Create tqdm progress bar for the current file
+        progress_bar = tqdm(total=total_lines, desc=f"Processing {os.path.basename(filepath)}", unit="line")
+
+        for line_number, line in enumerate(file, start=1):
+            progress_bar.update(1)  # Increment progress bar by 1 for each line processed
+            columns = line.strip().split()
+
+            try:
+                #print("Columns:", columns)
+                db_repeat = session.query(Repeat).filter(
+                    Repeat.source == 'EnsembleTR',
+                    Repeat.start == int(columns[1]),
+                    Repeat.end.between(int(columns[2]) - 2, int(columns[2]) + 2),
+                    Repeat.chr == columns[0]
+                ).one()
+                # print("Database Repeat:", db_repeat)
+
+            except ValueError as ex:
+                # Handle the case where the start column cannot be converted to an integer
+                error_message = f"Invalid value for start column: {ex} with line contents: {columns}"
+                #print(error_message)
+                error_log_file.write(error_message + '\n')  # Write error message to log file
+                continue
+
+            except IndexError as ex:
+                # Handle the case where there are not enough columns in the line
+                error_message = f"Not enough columns in line: {line}"
+                #print(error_message)
+                error_log_file.write(error_message + '\n')  # Write error message to log file
+                continue
+
+            except MultipleResultsFound as ex:
+                error_message = f"Multiple results found for chr: {columns[0]}:{columns[1]}-{columns[2]}"
+                error_log_file.write(error_message + '\n')  # Write error message to log file
+                continue
+
+            except NoResultFound as ex:
+                error_message = f"No matching record found for chr: {columns[0]}:{columns[1]}-{columns[2]}"
+                error_log_file.write(error_message + '\n')  # Write error message to log file
+                continue
+
+##chrom start end acount-AFR acount-AMR acount-EAS acount-EUR acount-SAS
+#chr10 89246 89285 CACATACACACACACACACACACACACACACAC:2,CACATACACACACACACACACACACACACACACAC:261,
+            for pop in ["AFR", "AMR", "EAS", "EUR", "SAS"]:
+                popdata = columns[pop_indices[pop]]
+                if popdata != ".":  # Check if there is any data for this population
+                    popdata_items = popdata.split(",")
+                    #CACATACACACACACACACACACACACACACAC:2
+
+                    # Calculate the total allele count, whether the population has one entry or several
+                    if len(popdata_items) == 1 and ':' in popdata_items[0]:
+                        total = int(popdata_items[0].split(":")[1])
+                    else:
+                        total = sum(int(item.split(":")[1]) for item in popdata_items if ':' in item)
+
+                    for a in popdata_items:
+                        allele_items = a.split(":")
+                        if len(allele_items) == 2:
+                            aseq, acount = allele_items
+                            acount = int(acount)
+                            freq = round_sf(acount / total) if total > 0 else 0
+                            db_seq = AlleleSequence(
+                                repeat_id=db_repeat.id,
+                                population=pop,
+                                n_effective=db_repeat.n_effective,
+                                frequency=freq,
+                                num_called=acount,
+                                sequence=aseq
+                            )
+                            db_repeat.allseq.append(db_seq)
+                        else:
+                            # Log or handle the case where an allele doesn't have the expected format
+                            skipped_lines_file.write(f"Skipped or malformed allele entry at line {line_number}: {line}\n")
+
+    # Close the tqdm progress bar
+    progress_bar.close()
+
+
+def main():
+    default_path = "postgresql://webstr:webstr@localhost:5432/strdb"
+    default_error_log = "error_chr1.log"
+    default_skipped_lines_file = "skipped_lines_chr1.txt"
+    default_file = "/gymreklab-tscc/creeve/chr/chr1.tab"
+    parser = argparse.ArgumentParser(description="Insert data into PostgreSQL database")
+    parser.add_argument("--db_path", type=str, default=default_path, help="PostgreSQL connection URL")
+    parser.add_argument("--file", type=str, default=default_file, help="File to process")
+    parser.add_argument("--error_log", type=str, default=default_error_log, help="Path to error log file")
+    parser.add_argument("--skipped_lines_file", type=str, default=default_skipped_lines_file, help="Path to skipped lines file")
+
+    args = parser.parse_args()
+    engine, session = connection_setup(args.db_path)
+
+    with open(args.error_log, 'w') as error_log_file, \
+         open(args.skipped_lines_file, 'w') as skipped_lines_file:
+        process_file(args.file, session, error_log_file, skipped_lines_file)
+
+    session.commit()
+    session.close()
+    engine.dispose()
+    print("Data inserted successfully")
+
+if __name__ == "__main__":
+    main()
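
Note for reviewers: below is a minimal sketch of how the new /allseq/ endpoint can be smoke-tested once update_tf.py has populated the allele_sequences table (for example, python update_tf.py --file <chromosome .tab file> against the default database). The base URL and the repeat_id value are assumptions, not part of this patch; adjust them to the local deployment.

import requests

BASE_URL = "http://localhost:8000"   # assumption: wherever the strAPI service is running
REPEAT_ID = 1                        # assumption: any repeat id that exists in the repeats table

# Query the new endpoint; each item in the response follows the AlleleSequence
# schema added in strAPI/repeats/schemas.py
response = requests.get(f"{BASE_URL}/allseq/", params={"repeat_id": REPEAT_ID})
response.raise_for_status()
for allele in response.json():
    print(allele["population"], allele["sequence"], allele["frequency"], allele["num_called"])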