In [1]:
import allel
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go  # type: ignore

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Default parameter values. The relative paths presumably resolve from the
# notebook's own directory (three levels below the project root).
wkdir = "../../.."
cohort_cols = "location,taxon"
metadata_path = "../../../results/config/metadata.qcpass.tsv"
bed_targets_path = "../../../config/ag-vampir.bed"
vcf_path = "../../../results/vcfs/targets/ag-vampir-002.annot.vcf"

In [3]:
# Parameters
# NOTE(review): these values override the defaults set in the previous cell and
# look injected at run time (e.g. by papermill). wkdir is a machine-specific
# absolute path; the other paths are relative (presumably to wkdir).
cohort_cols = "location,taxon"
bed_targets_path = "config/ag-vampir.bed"
vcf_path = "results/vcfs/targets/ag-vampir-002.annot.vcf"
wkdir = "/home/snagi/lstm_projects/ampseq-agvampir002"
metadata_path = "results/config/metadata.qcpass.tsv"


In [4]:
import sys
import os

# Add <wkdir>/workflow to the import path so ampseekertools can be imported.
# Only add it once: re-running this cell in the same kernel used to append a
# duplicate sys.path entry on every run.
workflow_dir = os.path.join(wkdir, 'workflow')
if workflow_dir not in sys.path:
    sys.path.append(workflow_dir)
import ampseekertools as amp

### Species ID

In [5]:
# Cohort grouping used for ordering/colouring: the first of the
# comma-separated cohort columns.
cohort_col = cohort_cols.split(',')[0]

metadata = pd.read_csv(metadata_path, sep="\t")

import json
# JSON is UTF-8 by specification; don't rely on the platform's default text
# encoding (cohort labels in the colour map may contain non-ASCII characters).
with open(f"{wkdir}/config/metadata_colours.json", 'r', encoding='utf-8') as f:
    color_mapping = json.load(f)

# Target BED file (headerless; columns named here).
targets = pd.read_csv(bed_targets_path, sep="\t", header=None)
targets.columns = ['contig', 'start', 'end', 'amplicon', 'mutation', 'ref', 'alt']

# load_vcf also returns the metadata, which replaces the frame read above.
geno, pos, gn_contigs, metadata, refs, alts, ann = amp.load_vcf(vcf_path=vcf_path, metadata=metadata)
samples = metadata['sample_id'].to_list()

# Prepend the reference allele as column 0 so that, following the VCF allele
# numbering (0 = REF, 1.. = ALT), column k of `alts` is allele index k as
# used in the genotype calls.
alts = np.concatenate([refs.reshape(refs.shape[0], -1), alts], axis=1)

In [6]:
def _aims_n_alt(gt, aim_alts, data_alts):
    n_sites = gt.shape[0]
    n_samples = gt.shape[1]
    # create empty array
    aim_n_alt = np.empty((n_sites, n_samples), dtype=np.int8)

    # for every site
    for i in range(n_sites):
        # find the index of the correct tag snp allele
        tagsnp_index = np.where(aim_alts[i] == data_alts[i])[0]
        for j in range(n_samples):
            n_tag_alleles = np.sum(gt[i, j] == tagsnp_index[0])
            n_missing = np.sum(gt[i, j] == -1)
            if n_missing != 0:
                aim_n_alt[i,j] = -1
            else:
                aim_n_alt[i, j] = n_tag_alleles

    return aim_n_alt

contigs = ['2R', '2L', '3R', '3L', 'X']
# Targets whose mutation label contains 'AIM'.
df_aims = targets.query("mutation.str.contains('AIM')", engine='python')

# Select the AIM sites from the genotype data by (contig, position). Matching on
# position alone also selected a variant at the same coordinate on another
# contig, whose "aim_<contig>:<pos>" label is then absent from df_aims and makes
# the df_aims.loc[aim_loc] lookup below raise KeyError.
aim_sites = set(zip(df_aims['contig'], df_aims['end']))
aim_mask = np.fromiter(
    ((c, p) in aim_sites for c, p in zip(gn_contigs, pos)),
    dtype=bool,
    count=len(pos),
)
aim_gn = geno.compress(aim_mask, axis=0)
aim_pos = pos[aim_mask]
aim_contigs = gn_contigs[aim_mask]
aim_alts = alts[aim_mask]

# Reorder the AIM target rows to follow the order of the sites in the genotype data.
aim_loc = ["aim_" + c + ":" + str(p) for c, p in zip(aim_contigs, aim_pos)]
df_aims = df_aims.assign(loc=lambda x: "aim_" + x.contig + ":" + x.end.astype(str)).set_index('loc')
df_aims = df_aims.loc[aim_loc]

# Append one column per sample holding its tag-allele count (-1 = missing).
aim_gn_alt = _aims_n_alt(aim_gn, aim_alts=df_aims.alt.to_list(), data_alts=aim_alts)
df_aims = pd.concat([df_aims, pd.DataFrame(aim_gn_alt, columns=samples, index=aim_loc)], axis=1)
# Keep only the listed contigs, grouped in that order.
df_aims = pd.concat([df_aims.query(f"contig == '{contig}'") for contig in contigs])

# Order the samples for the AIM heatmap: by cohort (in order of first appearance
# in the metadata) and, within each cohort, by mean AIM genotype ascending
# (missing calls, -1, are excluded from the mean).
cohort_values = metadata[cohort_col]
aimplot_sample_order = []
for coh in cohort_values.unique():
    # Boolean masks rather than DataFrame.query: the old quoted query was a
    # syntax error for labels containing a quote (e.g. "Cote d'Ivoire"), and it
    # never matched NaN or non-string cohort values, silently dropping samples.
    in_cohort = cohort_values.isna() if pd.isna(coh) else cohort_values == coh
    coh_samples = metadata.loc[in_cohort, 'sample_id'].to_list()
    coh_mean_aim = df_aims.loc[:, coh_samples].replace({-1: np.nan}).mean()
    aimplot_sample_order.extend(coh_mean_aim.sort_values(ascending=True).index.to_list())

# Optional filter (disabled): exclude samples with many missing AIM calls.
# n_missing = df_aims.replace({-1: np.nan}).iloc[:, 7:].isna().sum(axis=0).sort_values(ascending=False)
# missing_samples = n_missing[n_missing > 20].index.to_list()
# aimplot_sample_order = [s for s in aimplot_sample_order if s not in missing_samples]
from plotly.subplots import make_subplots

# Each contig's panel width is set relative to its number of AIM sites.
col_widths = [np.count_nonzero(aim_contigs == contig) for contig in contigs]

# A single row of panels, one per contig, sharing the sample (y) axis.
subplot_layout = dict(
    rows=1,
    cols=len(contigs),
    column_widths=col_widths,
    column_titles=contigs,
    row_titles=None,
    shared_yaxes=True,
    horizontal_spacing=0.01,
    vertical_spacing=0.01,
    x_title=None,
    y_title=None,
)
fig = make_subplots(**subplot_layout)

species = "gamb_vs_colu".split("_vs_")
sp_a, sp_b = species

# Colour bar for the AIM heatmaps; the tick values are the per-sample
# tag-allele counts, with -1 standing for a missing call.
colorbar = dict(
    title="AIM genotype",
    tickmode="array",
    tickvals=[-1, 0, 1, 2],
    ticktext=["missing", f"{sp_a}/{sp_a}", f"{sp_a}/{sp_b}", f"{sp_b}/{sp_b}"],
    len=100,
    lenmode="pixels",
    y=1,
    yanchor="top",
    outlinewidth=1,
    outlinecolor="black",
)

# Default AIM colours, taken from plotly's T10 qualitative palette.
colors = px.colors.qualitative.T10
color_gamb = colors[0]
color_colu = colors[2]
color_gamb_colu_het = colors[5]
color_gambcolu = colors[6]  # not referenced by the plotting code in this notebook
color_missing = "white"
# Ordered to match the genotype values -1 (missing), 0, 1, 2.
palette = (color_missing, color_gamb, color_gamb_colu_het, color_colu)

# Step colour scale: palette entry k fills the band [k/4, (k+1)/4] of [0, 1].
colorscale = []
for band, band_colour in enumerate(palette):
    colorscale.append([band / 4, band_colour])
    colorscale.append([(band + 1) / 4, band_colour])

# Add one genotype heatmap per contig: rows are samples (in
# aimplot_sample_order), columns are that contig's AIM sites.
for j, contig in enumerate(contigs):

    # Exact match on the contig column (as used when df_aims was grouped by
    # contig) instead of a substring match on the row labels, and select the
    # sample columns by name instead of the magic `iloc[:, 7:]` offset.
    in_contig = (df_aims['contig'] == contig).to_numpy()
    df_aims_contig = df_aims.loc[in_contig, aimplot_sample_order].T

    fig.add_trace(
        go.Heatmap(
            y=df_aims_contig.index,
            z=df_aims_contig,
            x=df_aims_contig.columns,
            colorscale=colorscale,
            # -1.5..2.5 puts the values -1, 0, 1, 2 in the centres of the
            # four equal colour bands.
            zmin=-1.5,
            zmax=2.5,
            xgap=0,
            ygap=0.5,  # this creates faint lines between rows
            colorbar=colorbar,
        ),
        row=1,
        col=j + 1,
    )

fig.update_layout(
    title="AIMs - gambiae vs coluzzii",
    # Grow the figure with the number of samples (one heatmap row each).
    height=max(600, 1.2 * len(samples) + 300),
)
fig.write_image(f"{wkdir}/results/aims_gamb_vs_colu.png", scale=2)

fig.show()

#### Species assignments by cohorts

In [7]:
# Assign a taxon to each sample from its mean AIM genotype.
# Solely based on the X chromosome AIMs; all the others look unreliable.
df = df_aims.loc[(df_aims['contig'] == 'X').to_numpy(), samples]
x_aim_gt = df.replace(-1, np.nan)  # -1 marks a missing call

# Samples with more than this many missing X AIM calls get no mean (NaN).
max_missing_x_aims = 7
mean_aims = x_aim_gt.mean(axis=0)  # per-sample mean, skipping missing calls
mean_aims[x_aim_gt.isna().sum(axis=0) > max_missing_x_aims] = np.nan
aims = mean_aims.loc[metadata['sample_id']]
metadata = metadata.assign(mean_aim_genotype=aims.values)

# Mean AIM genotype cut-offs: below gambiae_max_aim -> gambiae, at or above
# coluzzii_min_aim -> coluzzii, in between or missing -> uncertain.
gambiae_max_aim = 0.5
coluzzii_min_aim = 1.5


def _taxon_from_mean_aim(mean_aim):
    """Map one sample's mean AIM genotype to a taxon label."""
    # pd.isna rather than `== np.nan`: NaN never compares equal to anything, so
    # the old test was always False and missing samples fell through to a NaN
    # taxon instead of 'uncertain'.
    if pd.isna(mean_aim):
        return 'uncertain'
    if mean_aim < gambiae_max_aim:
        return 'gambiae'
    if mean_aim >= coluzzii_min_aim:
        return 'coluzzii'
    return 'uncertain'


taxon = [_taxon_from_mean_aim(v) for v in metadata['mean_aim_genotype']]
new_metadata = metadata.assign(taxon=taxon)

fig = px.histogram(
    new_metadata,
    nbins=100,
    x='mean_aim_genotype',
    color=cohort_col,
    color_discrete_map=color_mapping[cohort_col],
    width=750,
    height=400,
    template='plotly_white',
    title='AIM genotype distribution'
    )
fig.show()

In [8]:
from IPython.display import display, Markdown

# Output locations. NOTE(review): the metadata write replaces
# results/config/metadata.qcpass.tsv, presumably the same file metadata_path
# points to when the notebook runs from wkdir, so re-runs read these columns back.
aims_dir = f"{wkdir}/results/ag-vampir/aims"
taxon_aims_path = f"{aims_dir}/taxon_aims.tsv"
os.makedirs(aims_dir, exist_ok=True)  # to_csv does not create missing directories

new_metadata.to_csv(f"{wkdir}/results/config/metadata.qcpass.tsv", sep="\t", index=False)
new_metadata[['sample_id', 'taxon', 'mean_aim_genotype']].to_csv(taxon_aims_path, sep="\t", index=False)
# Quote the href so the link still works when the path contains spaces.
display(Markdown(f'<a href="{taxon_aims_path}">Sample aims and taxon assignment (.tsv)</a>'))

<a href=/home/snagi/lstm_projects/ampseq-agvampir002/results/ag-vampir/aims/taxon_aims.tsv>Sample aims and taxon assignment (.tsv)</a>