Coverage for barbet/data.py: 63.98%
161 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-12 04:23 +0000
1import random
2import os
3import numpy as np
4from collections import defaultdict
5from typing import Iterable
6import torch
7from torch.utils.data import DataLoader
8from torch.utils.data import Dataset
9import lightning as L
10from dataclasses import dataclass, field
11from hierarchicalsoftmax import TreeDict
# Taxonomic ranks the model predicts, ordered from broadest (phylum) to most specific (species).
RANKS = ["phylum", "class", "order", "family", "genus", "species"]
def read_memmap(path, count, dtype:str="float16") -> np.memmap:
    """
    Open a flat binary file of embeddings as a read-only 2D memory map.

    The embedding size is inferred from the file size: with ``count`` rows,
    each row holds ``file_size / itemsize / count`` values.

    Args:
        path: Path to a raw binary file (e.g. written by ``ndarray.tofile``).
        count: Number of rows (e.g. number of accessions) stored in the file.
        dtype: NumPy dtype string of the stored values (default ``"float16"``).

    Returns:
        A read-only ``np.memmap`` of shape ``(count, embedding_size)``.

    Raises:
        ValueError: If ``count`` is not positive, or the file size is not an
            exact multiple of ``count`` elements. The previous behaviour
            floored the division, silently truncating any trailing data.
    """
    file_size = os.path.getsize(path)
    dtype_size = np.dtype(dtype).itemsize
    if count <= 0:
        raise ValueError(f"count must be positive, got {count}")
    if file_size % dtype_size != 0:
        raise ValueError(
            f"File size {file_size} of {path!r} is not a multiple of the "
            f"{dtype} item size ({dtype_size} bytes)"
        )
    num_elements = file_size // dtype_size
    if num_elements % count != 0:
        raise ValueError(
            f"File {path!r} holds {num_elements} elements of dtype {dtype}, "
            f"which is not divisible by count={count}"
        )
    embedding_size = num_elements // count
    shape = (count, embedding_size)
    return np.memmap(path, dtype=dtype, mode='r', shape=shape)
def gene_id_from_accession(accession:str):
    """Return the gene ID: the text after the final '/' in *accession*.

    If there is no '/', the whole accession is returned unchanged.
    """
    _, _, gene_id = accession.rpartition("/")
    return gene_id
def choose_k_from_n(lst, k) -> list[int]:
    """Draw exactly *k* items from *lst*.

    Every item appears ``k // len(lst)`` times, and a random sample (without
    replacement) of size ``k % len(lst)`` fills the rest. Returns [] for an
    empty input.
    """
    if not lst:
        return []
    repetitions, remainder = divmod(k, len(lst))
    return lst * repetitions + random.sample(lst, remainder)
@dataclass(kw_only=True)
class BarbetStack():
    """A single stack of gene embeddings belonging to one genome.

    Holds the row indices into the shared embedding array that make up
    this stack; the rows themselves are fetched lazily by the dataset.
    """
    # Identifier of the genome this stack belongs to.
    genome:str
    # 1D array of row indices into the embedding array (enforced below).
    array_indices:np.ndarray

    def __post_init__(self):
        # Guard against accidentally passing a 2D slice or scalar.
        assert self.array_indices.ndim == 1, "Stack indices must be a 1D array"
@dataclass(kw_only=True)
class BarbetPredictionDataset(Dataset):
    """
    Dataset that packs gene embeddings into fixed-size per-genome stacks for inference.

    Accessions are expected in the form "<genome>/<gene>" (asserted below).
    Gene rows are grouped by genome and packed into stacks of ``stack_size``
    indices into ``array``; each gene is included in roughly ``repeats``
    passes over the genome. Sampling with replacement is used only when a
    genome has fewer genes than needed to fill a stack.
    """
    # 2D embedding array; rows correspond to positions in `accessions`.
    array:np.memmap|np.ndarray
    # One accession string per row of `array`, formatted "<genome>/<gene>".
    accessions: list[str]
    # Number of gene rows per stack.
    stack_size:int
    # Number of full passes over each genome's genes.
    repeats:int = 2
    # Seed for the global `random` module so stack composition is reproducible.
    seed:int = 42
    # Populated by __post_init__; not a constructor argument.
    stacks: list[BarbetStack] = field(init=False)
    # If given, only genomes in this set are stacked; others are skipped.
    genome_filter:set[str]|None = None

    def __post_init__(self):
        """Group row indices by genome, then pack them into stacks."""
        # Map each genome to the set of array row indices of its genes.
        genome_to_array_indices = defaultdict(set)
        for index, accession in enumerate(self.accessions):
            slash_position = accession.rfind("/")
            assert slash_position != -1
            genome = accession[:slash_position]
            if self.genome_filter and genome not in self.genome_filter:
                continue
            genome_to_array_indices[genome].add(index)

        # Build stacks
        random.seed(self.seed)
        self.stacks = []
        for genome, genome_array_indices in genome_to_array_indices.items():
            stack_indices = []  # NOTE(review): unused local; appears to be leftover.
            # `remainder` carries indices not yet placed in a stack during the
            # current pass; it is topped up from the full index set each pass.
            remainder = []
            for repeat_index in range(self.repeats + 1):
                # One extra iteration is allowed solely to flush a leftover
                # partial remainder; stop early if there is nothing to flush.
                if len(remainder) == 0 and repeat_index >= self.repeats:
                    break

                # Finish Remainder
                genome_array_indices_set = set(genome_array_indices)
                available = genome_array_indices_set - set(remainder)
                needed = self.stack_size - len(remainder)
                available_list = list(available)

                if len(available_list) >= needed:
                    to_add = random.sample(available_list, needed) # without replacement
                else:
                    to_add = random.choices(available_list, k=needed) # with replacement
                to_add_set = set(to_add)
                # `to_add` was drawn from indices outside `remainder`, so the
                # combined stack never repeats an index from the remainder.
                assert not set(remainder) & to_add_set, "remainder and to_add should be disjoint"

                self.add_stack(genome, remainder + to_add)
                # Whatever was not used in this mixed stack becomes the new
                # remainder for this pass, in random order.
                remainder = list(genome_array_indices_set - to_add_set)
                random.shuffle(remainder)

                # If we have already added each item the required number of times, then stop
                if repeat_index >= self.repeats:
                    break

                # Drain the remainder into as many full stacks as it can fill;
                # any leftover (< stack_size) is completed on the next pass.
                while len(remainder) >= self.stack_size:
                    self.add_stack(genome, remainder[:self.stack_size])
                    remainder = remainder[self.stack_size:]

    def add_stack(self, genome:str, indices:Iterable[int]) -> BarbetStack:
        """
        Add a new stack to the dataset.
        """
        # Sorted so rows are read from the memmap in ascending order.
        indices = np.array(sorted(indices))
        stack = BarbetStack(genome=genome, array_indices=indices)
        self.stacks.append(stack)
        return stack

    def __len__(self):
        # One item per pre-built stack.
        return len(self.stacks)

    def __getitem__(self, idx):
        """Return the (stack_size, embedding_dim) float16 tensor for one stack."""
        stack = self.stacks[idx]
        array_indices = stack.array_indices

        assert len(array_indices) > 0, f"Stack has no array indices"
        with torch.no_grad():
            # .copy() detaches the slice from the underlying memmap so the
            # returned tensor does not keep the mapped file alive.
            data = np.asarray(self.array[array_indices, :]).copy()
            embeddings = torch.from_numpy(data).to(torch.float16)
            del data

        return embeddings
@dataclass(kw_only=True)
class BarbetTrainingDataset(Dataset):
    """
    Dataset yielding (embedding, node_id) pairs for hierarchical-softmax training.

    Each item looks up the embedding rows for an accession and the target
    node ID from the tree dictionary.
    """
    # Accession strings; one dataset item per accession.
    accessions: list[str]
    # Maps accession -> SequenceDetail with a `node_id` attribute (used below).
    treedict: TreeDict
    # 2D embedding array shared with the data module.
    array:np.memmap|np.ndarray
    # Maps accession -> list of row indices into `array`. Fixed annotation:
    # values are lists, not ints — they are passed to len(), choose_k_from_n()
    # and used to slice rows (and BarbetDataModule annotates them as list[int]).
    accession_to_array_index:dict[str,list[int]]|None=None
    # If non-zero, each item is resampled to exactly this many rows.
    stack_size:int = 0

    def __len__(self):
        return len(self.accessions)

    def __getitem__(self, idx):
        """Return (float16 embedding tensor, integer target node ID) for item *idx*."""
        accession = self.accessions[idx]
        # NOTE(review): the fallback uses the bare int `idx`, which would fail
        # the len() below — presumably the mapping is always provided; confirm.
        array_indices = self.accession_to_array_index[accession] if self.accession_to_array_index else idx
        if self.stack_size:
            array_indices = choose_k_from_n(array_indices, self.stack_size)

        assert len(array_indices) > 0, f"Accession {accession} has no array indices"
        with torch.no_grad():
            # copy=False avoids an extra copy of the memmap slice; torch.tensor
            # copies the data (and converts dtype) anyway.
            data = np.array(self.array[array_indices, :], copy=False)
            embedding = torch.tensor(data, dtype=torch.float16)
            del data

        seq_detail = self.treedict[accession]
        node_id = int(seq_detail.node_id)
        del seq_detail

        return embedding, node_id
@dataclass
class BarbetDataModule(L.LightningDataModule):
    """
    Lightning data module producing train/validation splits by partition number.

    NOTE(review): the class defines an explicit ``__init__`` below, so
    ``@dataclass`` does not generate one (dataclasses never overwrite an
    existing ``__init__``); the decorator effectively only contributes
    ``__repr__``/``__eq__`` from the field declarations. ``stack_size`` is set
    in ``__init__`` but not declared as a field — confirm this is intentional.
    """
    # Maps accession -> details with `partition` and `node_id` attributes.
    treedict: TreeDict
    # seqbank: SeqBank
    # 2D embedding array shared by all datasets created here.
    array:np.memmap|np.ndarray
    # Maps accession -> list of row indices into `array` (see __init__ annotation).
    accession_to_array_index:dict[str,list[int]]
    # If non-zero, cap on the number of training accessions collected in setup().
    max_items: int = 0
    batch_size: int = 16
    num_workers: int = 0
    # Partition number routed to the validation split.
    validation_partition:int = 0
    # Partition number excluded entirely (held-out test set).
    test_partition:int = -1
    # If True, validation accessions are also included in training.
    train_all:bool = False

    def __init__(
        self,
        treedict: TreeDict,
        array:np.memmap|np.ndarray,
        accession_to_array_index:dict[str,list[int]],
        max_items: int = 0,
        batch_size: int = 16,
        num_workers: int|None = None,
        validation_partition:int = 0,
        test_partition:int=-1,
        stack_size:int=0,
        train_all:bool=False,
    ):
        super().__init__()
        self.array = array
        self.accession_to_array_index = accession_to_array_index
        self.treedict = treedict
        self.max_items = max_items
        self.batch_size = batch_size
        self.validation_partition = validation_partition
        self.test_partition = test_partition
        # Default worker count: one per CPU, capped at 8.
        self.num_workers = min(os.cpu_count(), 8) if num_workers is None else num_workers
        self.stack_size = stack_size
        self.train_all = train_all

    def setup(self, stage=None):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        self.training = []
        self.validation = []

        for accession, details in self.treedict.items():
            partition = details.partition
            # Test-partition accessions are excluded from both splits.
            if partition == self.test_partition:
                continue

            dataset = self.validation if partition == self.validation_partition else self.training
            dataset.append( accession )

            # Early exit once enough training items are collected, but only
            # after at least one validation item exists.
            if self.max_items and len(self.training) >= self.max_items and len(self.validation) > 0:
                break

        if self.train_all:
            self.training += self.validation

        self.train_dataset = self.create_dataset(self.training)
        self.val_dataset = self.create_dataset(self.validation)

    def create_dataset(self, accessions:list[str]) -> BarbetTrainingDataset:
        """Build a training dataset over *accessions* sharing this module's array."""
        return BarbetTrainingDataset(
            accessions=accessions,
            treedict=self.treedict,
            array=self.array,
            accession_to_array_index=self.accession_to_array_index,
            stack_size=self.stack_size,
        )

    def train_dataloader(self):
        # Shuffled each epoch; batches stack the per-accession tensors.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        # Deterministic order for comparable validation metrics.
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)