Coverage for barbet/embedding.py: 31.99%
322 statements
coverage.py v7.9.1, created at 2025-08-12 04:23 +0000
import gzip
import os
from pathlib import Path
from abc import ABC, abstractmethod
from Bio import SeqIO
import random
import numpy as np
from rich.progress import track
from hierarchicalsoftmax import SoftmaxNode
from hierarchicalsoftmax import TreeDict
import tarfile
import torch
from io import StringIO
from torchapp.cli import CLIApp, tool, method
import typer
from dataclasses import dataclass

from .data import read_memmap, RANKS

21def _open(path, mode='rt', **kwargs):
22 """
23 Open a file normally, or with gzip if it ends in .gz.
25 Args:
26 path (str or Path): The path to the file.
27 mode (str): The mode to open the file with (default 'rt' for reading text).
28 **kwargs: Additional arguments passed to open or gzip.open.
30 Returns:
31 A file object.
32 """
33 path = Path(path)
34 if path.suffix == '.gz':
35 return gzip.open(path, mode, **kwargs)
36 return open(path, mode, **kwargs)
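# Usage sketch (file names are hypothetical): _open dispatches on the suffix, so gzipped
# and plain taxonomy files can be read identically.
#
#   with _open("bac120_taxonomy_r220.tsv.gz") as f:   # routed to gzip.open(path, 'rt')
#       header = f.readline()
#   with _open("bac120_taxonomy_r220.tsv") as f:      # routed to the builtin open(path, 'rt')
#       header = f.readline()
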
def set_validation_rank_to_treedict(
    treedict:TreeDict,
    validation_rank:str="species",
    partitions:int=5,
) -> TreeDict:
    # find the taxonomic rank to use for the validation partition
    validation_rank = validation_rank.lower()
    assert validation_rank in RANKS
    validation_rank_index = RANKS.index(validation_rank)

    partitions_dict = {}
    for key in treedict:
        node = treedict.node(key)

        # Assign validation partition at set rank
        partition_node = node.ancestors[validation_rank_index]
        if partition_node not in partitions_dict:
            partitions_dict[partition_node] = random.randint(0, partitions - 1)

        treedict[key].partition = partitions_dict[partition_node]

    return treedict

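# Sketch of the effect (the key below is a made-up example; assumes a TreeDict populated as
# in preprocess further down): every key whose lineage shares the same ancestor node at
# `validation_rank` receives the same random partition, so an entire species is held out
# together during cross-validation.
#
#   treedict = set_validation_rank_to_treedict(treedict, validation_rank="species", partitions=5)
#   treedict["GCA_000005845.2/TIGR00006"].partition   # e.g. 3, identical for all keys of that species
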
def get_key(accession:str, gene:str) -> str:
    """ Returns the standard format of a key """
    key = f"{accession}/{gene}"
    return key

def get_node(lineage:str, lineage_to_node:dict[str,SoftmaxNode]) -> SoftmaxNode:
    if lineage in lineage_to_node:
        return lineage_to_node[lineage]

    assert ";" in lineage, f"Semi-colon ';' not found in lineage '{lineage}'"
    split_point = lineage.rfind(";")
    parent_lineage = lineage[:split_point]
    name = lineage[split_point+1:]
    parent = get_node(parent_lineage, lineage_to_node)
    node = SoftmaxNode(name=name, parent=parent)
    lineage_to_node[lineage] = node
    return node

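# Usage sketch (the lineage string is an illustrative GTDB-style example): get_node walks the
# lineage right-to-left, reusing cached ancestors and creating any missing SoftmaxNode on the
# way. The cache must already contain the root name so the recursion terminates.
#
#   root = SoftmaxNode("d__Bacteria")
#   lineage_to_node = {"d__Bacteria": root}
#   node = get_node("d__Bacteria;p__Proteobacteria;c__Gammaproteobacteria", lineage_to_node)
#   node.ancestors   # (root, the p__Proteobacteria node created along the way)
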
def generate_overlapping_intervals(total: int, interval_size: int, min_overlap: int, check:bool=True, variable_size:bool=False) -> list[tuple[int, int]]:
    """
    Creates a list of overlapping intervals within a specified range, adjusting the interval size to ensure
    that the overlap is approximately the same across all intervals.

    Args:
        total (int): The total range within which intervals are to be created.
        interval_size (int): The maximum size of each interval.
        min_overlap (int): The minimum number of units by which consecutive intervals overlap.
        check (bool): If True, checks are performed to ensure that the intervals meet the specified conditions.
        variable_size (bool): If True, the interval size is adjusted so that the overlap stays close to
            min_overlap; otherwise the interval size stays fixed and the overlap is adjusted instead.

    Returns:
        list[tuple[int, int]]: A list of tuples where each tuple represents the start (inclusive)
        and end (exclusive) of an interval.

    Example:
        >>> generate_overlapping_intervals(20, 5, 2)
        [(0, 5), (3, 8), (6, 11), (9, 14), (12, 17), (15, 20)]
    """
    intervals = []
    start = 0

    if total == 0:
        return intervals

    max_interval_size = interval_size
    assert interval_size
    assert min_overlap is not None
    assert interval_size > min_overlap, f"Max interval size of {interval_size} must be greater than min overlap of {min_overlap}"

    # Calculate the number of intervals needed to cover the range
    num_intervals, remainder = divmod(total - min_overlap, interval_size - min_overlap)
    if remainder > 0:
        num_intervals += 1

    # Calculate the exact interval size to ensure consistent overlap
    overlap = min_overlap
    if variable_size:
        if num_intervals > 1:
            interval_size, remainder = divmod(total + (num_intervals - 1) * overlap, num_intervals)
            if remainder > 0:
                interval_size += 1
    else:
        # If the size is fixed, then vary the overlap to keep it even
        if num_intervals > 1:
            overlap, remainder = divmod(num_intervals * interval_size - total, num_intervals - 1)
            if overlap < min_overlap:
                overlap = min_overlap

    while True:
        end = start + interval_size
        if end > total:
            end = total
            start = max(end - interval_size, 0)
        intervals.append((start, end))
        start += interval_size - overlap
        if end >= total:
            break

    if check:
        assert intervals[0][0] == 0
        assert intervals[-1][1] == total
        assert len(intervals) == num_intervals, f"Expected {num_intervals} intervals, got {len(intervals)}"

        assert interval_size <= max_interval_size, f"Interval size of {interval_size} exceeds max interval size of {max_interval_size}"
        for interval in intervals:
            assert interval[1] - interval[0] == interval_size, f"Interval size of {interval[1] - interval[0]} is not the expected size {interval_size}"

        for i in range(1, len(intervals)):
            overlap = intervals[i - 1][1] - intervals[i][0]
            assert overlap >= min_overlap, f"Min overlap condition of {min_overlap} not met for intervals {intervals[i - 1]} and {intervals[i]} (overlap {overlap})"

    return intervals

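# Worked examples (illustrative; both calls use the fixed-size branch above):
#
#   >>> generate_overlapping_intervals(10, 6, 2)
#   [(0, 6), (4, 10)]               # overlap stays at the minimum of 2
#   >>> generate_overlapping_intervals(23, 10, 3)
#   [(0, 10), (7, 17), (13, 23)]    # leftover overlap is spread across the joins (3 and 4)
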
@dataclass
class Embedding(CLIApp, ABC):
    """ A class for embedding protein sequences. """
    max_length:int|None=None
    overlap:int=64

    def __post_init__(self):
        super().__init__()

    @abstractmethod
    def embed(self, seq:str) -> torch.Tensor:
        """ Takes a protein sequence as a string and returns an embedding vector. """
        raise NotImplementedError

    def reduce(self, tensor:torch.Tensor) -> torch.Tensor:
        if tensor.ndim == 2:
            tensor = tensor.mean(dim=0)
        assert tensor.ndim == 1
        return tensor

    def __call__(self, seq:str) -> torch.Tensor:
        """ Takes a protein sequence as a string and returns an embedding vector. """
        if not self.max_length or len(seq) <= self.max_length:
            tensor = self.embed(seq)
            return self.reduce(tensor)

        epsilon = 0.1
        intervals = generate_overlapping_intervals(len(seq), self.max_length, self.overlap)
        weights = torch.zeros((len(seq),), device="cpu")
        tensor = None
        for start, end in intervals:
            result = self.embed(seq[start:end]).cpu()

            assert result.shape[0] == end - start
            embedding_size = result.shape[1]

            if tensor is None:
                tensor = torch.zeros((len(seq), embedding_size), device="cpu")

            assert tensor.shape[-1] == embedding_size

            # Position of each residue within the current window (0-based, relative to the window)
            interval_indexes = torch.arange(end - start)
            # Distance to the nearest window edge; residues near an edge are shared with a
            # neighbouring window, so they contribute less to the blend
            distance_from_ends = torch.min(interval_indexes, end - start - interval_indexes - 1)

            weight = epsilon + torch.minimum(distance_from_ends, torch.tensor(self.overlap))

            tensor[start:end] += result * weight.unsqueeze(1)
            weights[start:end] += weight

        tensor = tensor / weights.unsqueeze(1)

        return self.reduce(tensor)

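    # Minimal sketch of a concrete subclass (the per-residue embedding below is a toy
    # stand-in for a real protein language model; names and settings are illustrative):
    #
    #   @dataclass
    #   class ToyEmbedding(Embedding):
    #       def embed(self, seq:str) -> torch.Tensor:
    #           # one 4-dimensional vector per residue, shape (len(seq), 4)
    #           return torch.tensor([[float(ord(c))] * 4 for c in seq])
    #
    #   embedder = ToyEmbedding(max_length=1022, overlap=64)
    #   vector = embedder("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")   # 1-D tensor of size 4
    #
    # Sequences longer than max_length are embedded window by window and blended with the
    # edge-tapered weights above before being averaged by reduce().
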
    @method
    def setup(self, **kwargs):
        pass

    def build_treedict(self, taxonomy:Path) -> tuple[TreeDict,dict[str,SoftmaxNode]]:
        # Create root of tree
        lineage_to_node = {}
        root = None

        # Fill out tree with taxonomy
        accession_to_node = {}
        with _open(taxonomy) as f:
            for line in f:
                accession, lineage = line.rstrip("\n").split("\t")

                if not root:
                    root_name = lineage.split(";")[0]
                    root = SoftmaxNode(root_name)
                    lineage_to_node[root_name] = root

                node = get_node(lineage, lineage_to_node)
                accession_to_node[accession] = node

        treedict = TreeDict(classification_tree=root)
        return treedict, accession_to_node

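    # Expected taxonomy format (a sketch; the accession and lineage are illustrative): a TSV
    # with one genome accession and its semicolon-delimited lineage per line, as in the GTDB
    # taxonomy files, e.g.
    #
    #   GCA_000005845.2<TAB>d__Bacteria;p__Pseudomonadota;...;s__Escherichia coli
    #
    # build_treedict returns a TreeDict (with the classification tree built but no keys yet)
    # plus a mapping from each accession to its leaf SoftmaxNode; keys are added to the
    # TreeDict later in preprocess.
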
    @tool("setup")
    def test_lengths(
        self,
        end:int=5_000,
        start:int=1000,
        retries:int=5,
        **kwargs,
    ):
        def random_amino_acid_sequence(k):
            amino_acids = "ACDEFGHIKLMNPQRSTVWY"  # standard 20 amino acids
            return ''.join(random.choice(amino_acids) for _ in range(k))

        self.max_length = None
        self.setup(**kwargs)
        for ii in track(range(start, end)):
            for _ in range(retries):
                seq = random_amino_acid_sequence(ii)
                try:
                    self(seq)
                except Exception as err:
                    print(f"{ii}: {err}")
                    return

    @tool("setup")
    def build_gene_array(
        self,
        marker_genes:Path=typer.Option(default=..., help="The path to the marker genes tarball (e.g. bac120_msa_marker_genes_all_r220.tar.gz)."),
        family_index:int=typer.Option(default=..., help="The index for the gene family to use. E.g. if there are 120 gene families then this should be a number from 0 to 119."),
        output_dir:Path=typer.Option(default=..., help="A directory to store the output which includes the memmap array, the listing of accessions and an error log."),
        flush_every:int=typer.Option(default=5_000, help="An interval to flush the memmap array as it is generated."),
        max_length:int=None,
        **kwargs,
    ):
        self.max_length = max_length
        self.setup(**kwargs)

        assert marker_genes is not None
        assert family_index is not None
        assert output_dir is not None

        dtype = 'float16'

        memmap_wip_array = None
        output_dir.mkdir(parents=True, exist_ok=True)
        memmap_wip_path = output_dir / f"{family_index}-wip.npy"
        error = output_dir / f"{family_index}-errors.txt"
        accessions_wip = output_dir / f"{family_index}-accessions-wip.txt"

        accessions = []

        print(f"Loading {marker_genes} file.")
        with tarfile.open(marker_genes, "r:gz") as tar, open(error, "w") as error_file, open(accessions_wip, "w") as accessions_wip_file:
            members = [member for member in tar.getmembers() if member.isfile() and member.name.endswith(".faa")]
            prefix_length = len(os.path.commonprefix([Path(member.name).with_suffix("").name for member in members]))

            member = members[family_index]
            print(f"Processing file {family_index} in {marker_genes}")

            f = tar.extractfile(member)
            marker_id = Path(member.name).with_suffix("").name[prefix_length:]

            fasta_io = StringIO(f.read().decode('ascii'))

            total = sum(1 for _ in SeqIO.parse(fasta_io, "fasta"))
            fasta_io.seek(0)
            print(marker_id, total)

            for record in track(SeqIO.parse(fasta_io, "fasta"), total=total):
                species_accession = record.id

                key = get_key(species_accession, marker_id)

                seq = str(record.seq).replace("-", "").replace("*", "")
                try:
                    vector = self(seq)
                except Exception as err:
                    print(f"{key} ({len(seq)}): {err}", file=error_file)
                    print(f"{key} ({len(seq)}): {err}")
                    continue

                if vector is None:
                    print(f"{key} ({len(seq)}): Embedding is None", file=error_file)
                    print(f"{key} ({len(seq)}): Embedding is None")
                    continue

                if torch.isnan(vector).any():
                    print(f"{key} ({len(seq)}): Embedding contains NaN", file=error_file)
                    print(f"{key} ({len(seq)}): Embedding contains NaN")
                    continue

                if memmap_wip_array is None:
                    size = len(vector)
                    shape = (total, size)
                    memmap_wip_array = np.memmap(memmap_wip_path, dtype=dtype, mode='w+', shape=shape)

                index = len(accessions)
                memmap_wip_array[index, :] = vector.cpu().half().numpy()
                if index % flush_every == 0:
                    memmap_wip_array.flush()

                accessions.append(key)
                print(key, file=accessions_wip_file)

        memmap_wip_array.flush()

        accessions_path = output_dir / f"{family_index}.txt"
        with open(accessions_path, "w") as f:
            for accession in accessions:
                print(accession, file=f)

        # Save final memmap array now that we know the final size
        memmap_path = output_dir / f"{family_index}.npy"
        shape = (len(accessions), size)
        print(f"Writing final memmap array of shape {shape}: {memmap_path}")
        memmap_array = np.memmap(memmap_path, dtype=dtype, mode='w+', shape=shape)
        memmap_array[:len(accessions), :] = memmap_wip_array[:len(accessions), :]
        memmap_array.flush()

        # Clean up
        memmap_array._mmap.close()
        memmap_array._mmap = None
        memmap_array = None
        memmap_wip_path.unlink()
        accessions_wip.unlink()

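    # Usage sketch (paths are illustrative; `embedder` stands for an instance of a concrete
    # Embedding subclass as sketched above): build_gene_array embeds one marker-gene family
    # from the tarball and writes a float16 memmap of embeddings, the matching list of
    # "accession/marker" keys, and an error log for that family index.
    #
    #   embedder.build_gene_array(
    #       marker_genes=Path("bac120_msa_marker_genes_all_r220.tar.gz"),
    #       family_index=0,
    #       output_dir=Path("embeddings"),
    #   )
    #   # -> embeddings/0.npy, embeddings/0.txt, embeddings/0-errors.txt
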
    @tool
    def set_validation_rank(
        self,
        treedict:Path=typer.Option(default=..., help="The path to the treedict file."),
        output:Path=typer.Option(default=..., help="The path to save the adapted treedict file."),
        validation_rank:str=typer.Option(default="species", help="The rank to hold out for cross-validation."),
        partitions:int=typer.Option(default=5, help="The number of cross-validation partitions."),
    ) -> TreeDict:
        treedict = TreeDict.load(treedict)
        set_validation_rank_to_treedict(treedict, validation_rank=validation_rank, partitions=partitions)
        treedict.save(output)
        return treedict

    @tool
    def preprocess(
        self,
        taxonomy:Path=typer.Option(default=..., help="The path to the TSV taxonomy file (e.g. bac120_taxonomy_r220.tsv)."),
        marker_genes:Path=typer.Option(default=..., help="The path to the marker genes tarball (e.g. bac120_msa_marker_genes_all_r220.tar.gz)."),
        output_dir:Path=typer.Option(default=..., help="A directory to store the output which includes the memmap array, the listing of accessions and an error log."),
        partitions:int=typer.Option(default=5, help="The number of cross-validation partitions."),
        seed:int=typer.Option(default=42, help="The random seed."),
        treedict_only:bool=typer.Option(default=False, help="Only output the TreeDict file and then exit before concatenating the memmap array."),
    ):
        treedict, accession_to_node = self.build_treedict(taxonomy)

        dtype = 'float16'

        random.seed(seed)

        print(f"Loading {marker_genes} file.")
        with tarfile.open(marker_genes, "r:gz") as tar:
            members = [member for member in tar.getmembers() if member.isfile() and member.name.endswith(".faa")]
            family_count = len(members)
            print(f"{family_count} gene families found.")

        # Read and collect accessions
        print("Building treedict")
        keys = []
        counts = []
        node_to_partition_dict = dict()
        for family_index in track(range(family_count)):
            keys_path = output_dir / f"{family_index}.txt"

            if not keys_path.exists():
                counts.append(0)
                continue

            with open(keys_path) as f:
                family_index_keys = [line.strip() for line in f]
                keys += family_index_keys
                counts.append(len(family_index_keys))

                for key in family_index_keys:
                    genome_accession = key.split("/")[0]
                    node = accession_to_node[genome_accession]
                    partition = node_to_partition_dict.setdefault(node, random.randint(0, partitions - 1))

                    # Add to treedict
                    treedict.add(key, node, partition)

        assert len(counts) == family_count

        # Save treedict
        treedict_path = output_dir / f"{output_dir.name}.td"
        print(f"Saving TreeDict to {treedict_path}")
        treedict.save(treedict_path)

        if treedict_only:
            return

        # Concatenate numpy memmap arrays
        memmap_array = None
        memmap_array_path = output_dir / f"{output_dir.name}.npy"
        print(f"Saving memmap to {memmap_array_path}")
        current_index = 0
        for family_index, family_count in track(enumerate(counts), total=len(counts)):
            my_memmap_path = output_dir / f"{family_index}.npy"

            # Skip gene families whose memmap hasn't been built by build_gene_array
            if not my_memmap_path.exists():
                continue

            my_memmap = read_memmap(my_memmap_path, family_count)

            # Build memmap for output if it doesn't exist
            if memmap_array is None:
                size = my_memmap.shape[1]
                shape = (len(keys), size)
                memmap_array = np.memmap(memmap_array_path, dtype=dtype, mode='w+', shape=shape)

            # Copy memmap for gene family into output memmap
            memmap_array[current_index:current_index + family_count, :] = my_memmap[:, :]

            current_index += family_count

        assert len(keys) == current_index

        memmap_array.flush()

        # Save keys
        keys_path = output_dir / f"{output_dir.name}.txt"
        print(f"Saving keys to {keys_path}")
        with open(keys_path, "w") as f:
            for key in keys:
                print(key, file=f)

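    # Pipeline sketch (directory name is illustrative): preprocess assumes the per-family
    # outputs from build_gene_array already sit in output_dir and stitches them together.
    #
    #   embedder.preprocess(
    #       taxonomy=Path("bac120_taxonomy_r220.tsv"),
    #       marker_genes=Path("bac120_msa_marker_genes_all_r220.tar.gz"),
    #       output_dir=Path("embeddings"),
    #   )
    #   # -> embeddings/embeddings.td   (TreeDict with one key per accession/marker pair)
    #   #    embeddings/embeddings.npy  (per-family memmaps concatenated, float16)
    #   #    embeddings/embeddings.txt  (keys, one per row of the memmap)
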
    @tool
    def prune_to_representatives(self, treedict:Path, representatives:Path, output:Path):
        print("Getting list of representatives from", representatives)
        keys_to_keep = []
        with tarfile.open(representatives, "r:gz") as tar:
            members = [member for member in tar.getmembers() if member.isfile() and member.name.endswith(".faa")]

            print(f"Processing {len(members)} files in {representatives}")

            for member in track(members):
                f = tar.extractfile(member)
                marker_id = Path(member.name.split("_")[-1]).with_suffix("").name

                fasta_io = StringIO(f.read().decode('ascii'))

                for record in SeqIO.parse(fasta_io, "fasta"):
                    species_accession = record.id
                    key = get_key(species_accession, marker_id)
                    keys_to_keep.append(key)

        print(f"Keeping {len(keys_to_keep)} representatives")

        print(f"Loading treedict {treedict}")
        treedict = TreeDict.load(treedict)
        print("Total", len(treedict))
        missing = []
        for key in track(keys_to_keep):
            if key not in treedict:
                missing.append(key)

        print(f"{len(missing)} representatives missing out of {len(keys_to_keep)} (total: {len(treedict)})")
        if len(missing):
            keys_to_keep = [k for k in keys_to_keep if k not in missing]

        new_treedict = TreeDict(treedict.classification_tree)
        new_treedict.update({k: treedict[k] for k in keys_to_keep})
        print("Total after pruning", len(new_treedict))

        print("Saving treedict to", output)
        new_treedict.save(output)