# Source code for alleleTools.format.hla_group

import os
from typing import Union

import pandas as pd

from ..argtypes import add_out_altable_args
from ..utils.assets import download_file, get_asset_path
from ..allele import AlleleParser
from .alleleTable import AlleleTable
import swifter


def setup_parser(subparsers):
    """Register the ``hla_group`` sub-command.

    Args:
        subparsers: The sub-parser collection returned by
            ``ArgumentParser.add_subparsers()``.

    Returns:
        The configured sub-parser (as returned by ``add_out_altable_args``),
        with ``func`` defaulting to :func:`call_function`.
    """
    sub = subparsers.add_parser(
        name="hla_group",
        help="Group alleles by g- or p-group nomenclature",
        description="Group alleles according to the g- or p-group nomenclature",
        epilog="Author: Nicolás Mendoza Mejía (2025)",
    )

    # Positional input: the allele table to convert.
    input_help = "Allele table with the alleles that you want to convert"
    sub.add_argument("input", type=str, help=input_help)

    # Which nomenclature to group by.
    group_type_opts = dict(
        type=str,
        choices=["g-group", "p-group"],
        help="Type of group nomenclature to use (g-group or p-group)",
        default="g-group",
    )
    sub.add_argument("--group_type", **group_type_opts)

    # Shared output options for allele-table producing commands.
    sub = add_out_altable_args(sub)
    sub.set_defaults(func=call_function)
    return sub
def call_function(args):
    """Entry point of the ``hla_group`` sub-command.

    Reads the allele table at ``args.input``, replaces every allele by its
    group (nomenclature chosen by ``args.group_type``) and writes the result
    to ``args.output``. Alleles whose group cannot be determined are kept
    unchanged.
    """
    group = GrouperHLA(args.group_type)
    alt_before = AlleleTable.open(args.input)

    alt = AlleleTable()
    alt.alleles = _group_alleles(group, alt_before.alleles)
    alt.phenotype = alt_before.phenotype
    alt.phenotype.name = "phenotype"
    alt.to_csv(args.output)


def _group_alleles(group: "GrouperHLA", alleles: pd.DataFrame) -> pd.DataFrame:
    """Return a wide (sample x column) table with alleles replaced by groups.

    Exact matches are resolved in bulk with :func:`quick_lookup`; remaining
    alleles with at least two fields go through ``group.lookup``; anything
    still unresolved keeps its original allele string. Sample and column
    order of ``alleles`` is preserved.
    """
    # Long format: one row per (sample, column) cell.
    df = alleles.melt(ignore_index=False, var_name='col_index', value_name='allele')
    df.index.name = 'sample'
    df = df.reset_index()

    # "GENE*FIELDS" -> gene / allele_stripped. reindex guarantees two columns
    # even when no allele contains '*' (split would then return only one).
    parts = (
        df['allele'].str.split('*', n=1, expand=True)
        .reindex(columns=[0, 1])
        .astype(object)
    )
    df['gene'] = parts[0]
    df['allele_stripped'] = parts[1]

    df = quick_lookup(group, df)
    # Object dtype so string groups can be stored even if every exact lookup missed.
    df['group'] = df['group'].astype(object)

    # Slow lookup on remaining NaNs for alleles with a resolution of 2 or more fields.
    miss_group = df['group'].isna() & (df['allele_stripped'].str.count(':') >= 1)
    if miss_group.any():
        # Guard: apply(axis=1) on an empty frame may return a DataFrame, not a Series.
        df.loc[miss_group, 'group'] = df.loc[miss_group].swifter.apply(
            lambda row: group.lookup(row['gene'], row['allele_stripped']), axis=1
        )

    # Unresolved alleles are kept as they were.
    unresolved = df['group'].isna()
    df.loc[unresolved, 'group'] = df.loc[unresolved, 'allele']

    wide = df.pivot_table(index="sample", columns='col_index', values='group', aggfunc='first')

    # pivot_table sorts both axes; restore the input order of samples and columns
    # (only labels pivot_table kept, so nothing is re-added).
    rows = [s for s in pd.unique(df['sample']) if s in wide.index]
    cols = [c for c in pd.unique(df['col_index']) if c in wide.columns]
    return wide.reindex(
        index=pd.Index(rows, name=wide.index.name),
        columns=pd.Index(cols, name=wide.columns.name),
    )
def quick_lookup(group: "GrouperHLA", df: pd.DataFrame) -> pd.DataFrame:
    """Attach the exact-match group of each allele as a new ``group`` column.

    This optimization was not comprehensively tested to see if it improves
    performance, but in theory a single merge is faster than individual
    lookups. It only resolves exact matches; non-matching rows get NaN.

    Groups are returned as ``"<gene>*<group>"``, the same form produced by
    ``GrouperHLA.lookup`` / ``lookup_exact``, so exact and slow lookups can be
    mixed in the same column. Alleles indexed without a group keep their
    missing value.

    Args:
        group (GrouperHLA): Instance whose ``index`` maps
            gene -> {allele_stripped: group}.
        df (pd.DataFrame): DataFrame with 'gene' and 'allele_stripped' columns.

    Returns:
        pd.DataFrame: ``df`` (row order preserved) with an added 'group' column.
    """
    ref_rows = [
        (gene, allele_s, f"{gene}*{grp}" if isinstance(grp, str) else grp)
        for gene, mapping in group.index.items()
        for allele_s, grp in mapping.items()
    ]
    ref_df = pd.DataFrame(ref_rows, columns=['gene', 'allele_stripped', 'group'])

    # Bulk left join on exact (gene, allele) pairs; left order is preserved.
    return df.merge(ref_df, how='left', on=['gene', 'allele_stripped'])
class GrouperHLA:
    """Map HLA alleles to their g-group or p-group.

    ``self.index`` maps a gene name with the trailing ``*`` removed to a dict
    from allele names (without the gene prefix) to the group name (without
    the gene prefix). Alleles listed without a group keep the missing value
    read from the file (NaN).
    """

    def __init__(self, reference_file: str = "g-group"):
        """Load and index a group reference file.

        The reference files were obtained from
        https://hla.alleles.org/pages/wmda/g_groups/ and are downloaded on
        first use if not present locally.

        Args:
            reference_file: Group type, ``"g-group"`` or ``"p-group"``.

        Raises:
            ValueError: If ``reference_file`` is not a supported group type.
        """
        file_path = self._get_group_norm_file(reference_file)
        ref = pd.read_csv(file_path, sep=';', comment="#", header=None)
        ref.columns = ["gene", "alleles", "groups"]
        self.index = self._build_index(ref)

    @staticmethod
    def _build_index(ref: pd.DataFrame) -> dict:
        """Build gene -> {allele: group} from the parsed ``gene;alleles;groups`` table.

        The gene column has ``*`` removed (literally, not as a regex) and the
        alleles column is split on ``/``. If an allele appears twice for a
        gene, the later row wins.
        """
        index: dict = {}
        genes = ref["gene"].str.replace("*", "", regex=False)
        for gene, alleles, grp in zip(genes, ref["alleles"], ref["groups"]):
            # Rows without a gene or an allele list cannot be indexed.
            if not isinstance(gene, str) or not isinstance(alleles, str):
                continue
            gene_index = index.setdefault(gene, {})
            for allele in alleles.split('/'):
                gene_index[allele] = grp
        return index

    def _get_group_norm_file(self, group: str) -> str:
        """Return the local path of the reference file for ``group``.

        Downloads the file into the asset directory if it does not exist yet.

        Raises:
            ValueError: If ``group`` is not ``"g-group"`` or ``"p-group"``.
        """
        group_files = {
            "g-group": (
                "hla_nom_g.txt",
                "https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/wmda/hla_nom_g.txt"
            ),
            "p-group": (
                "hla_nom_p.txt",
                "https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/wmda/hla_nom_p.txt"
            ),
        }
        if group not in group_files:
            raise ValueError(f"Unknown group type: {group}. Supported types are {list(group_files.keys())}.")
        file, url = group_files[group]
        path = get_asset_path(file)
        if not os.path.exists(path):
            download_file(
                url=url,
                dest=path,
            )
        return path

    def lookup(self, gene: str, allele_stripped: str) -> Union[str, None]:
        """Lookup the group for the given allele.

        If an exact match is not found, attempt to find a partial match.

        Returns:
            ``"<gene>*<group>"``; ``None`` when the gene is unknown,
            ``allele_stripped`` is not a string (e.g. None/NaN), or no
            unambiguous group exists. The missing value stored in the index
            is returned when the matched alleles have no group.
        """
        gene_ref = self.index.get(gene, None)
        if gene_ref is None or not isinstance(allele_stripped, str):
            return None
        if allele_stripped in gene_ref:
            return self.lookup_exact(gene, allele_stripped)
        return self.lookup_partial(gene, allele_stripped)

    def lookup_exact(self, gene: str, allele_stripped: str) -> Union[str, None]:
        """Find the group for exactly the allele provided.

        Returns ``None`` for an unknown gene or a ``None`` allele.

        Raises:
            KeyError: If the allele is not indexed for a known gene.
        """
        gene_ref = self.index.get(gene, None)
        if gene_ref is None or allele_stripped is None:
            return None
        ret = gene_ref[allele_stripped]
        return gene + '*' + ret if isinstance(ret, str) else ret

    @staticmethod
    def _is_field_prefix(key: str, prefix: str) -> bool:
        """True if ``key`` starts with ``prefix`` at a field/suffix boundary.

        ``"02:10"`` is a prefix of ``"02:10:01:01"`` and ``"02:10N"`` but not
        of ``"02:101:01"`` (a different allele).
        """
        if not key.startswith(prefix):
            return False
        return (
            len(key) == len(prefix)
            or not prefix[-1:].isdigit()  # prefix already ends on a separator
            or not key[len(prefix)].isdigit()
        )

    def lookup_partial(self, gene: str, allele_stripped: str) -> Union[str, None]:
        """Find a group for a lower-resolution (partial) allele.

        The index does not contain all possible combinations of alleles;
        it only has the highest resolution alleles, so a partial match is
        needed. A match must end on a field boundary, so ``02:10`` does not
        match ``02:101:...``. A group is returned only if all matching
        alleles share it. Otherwise the ambiguity is delegated to
        ``_resolve_ambiguity``.
        """
        # TODO: Optimize this function to avoid looping over all entries;
        # could be done by building a trie structure for the alleles.
        gene_ref = self.index.get(gene, None)
        if gene_ref is None or not isinstance(allele_stripped, str):
            return None

        partial = [k for k in gene_ref if self._is_field_prefix(k, allele_stripped)]
        if not partial:
            return None

        uniq_groups = {gene_ref[p] for p in partial}
        if len(uniq_groups) != 1:
            # Ambiguous match.
            return self._resolve_ambiguity(gene, allele_stripped, uniq_groups)

        ret = uniq_groups.pop()
        return gene + '*' + ret if isinstance(ret, str) else ret

    def _resolve_ambiguity(self, gene: str, allele_stripped: str, uniq_groups: set) -> Union[str, None]:
        """Try to resolve several candidate groups into one.

        Each string group is truncated to the number of fields in the query
        allele (as parsed by ``AlleleParser``). The truncated value is
        returned if all groups collapse to the same one; otherwise ``None``.
        """
        parser = AlleleParser("hla")
        allele_p = parser.parse(gene + '*' + allele_stripped)
        # Alleles of one field or less shouldn't be resolved.
        if len(allele_p) <= 1:
            return None
        groups_p = [parser.parse(gene + '*' + grp) for grp in uniq_groups if isinstance(grp, str)]
        truncated_groups = {p.truncate(len(allele_p)) for p in groups_p}
        if len(truncated_groups) != 1:
            return None
        return str(truncated_groups.pop())