import os
import numpy as np
import json
import pandas as pd
from calendar import monthrange
import torch
import utils
import random
# from h3.unstable import vect
import h3.api.numpy_int as h3
from torch.nn.utils.rnn import pad_sequence
from functools import partial
import re
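
# `load_env` and `organize_data_by_labels` are referenced throughout this module
# but are defined elsewhere in the project. A minimal sketch of
# `organize_data_by_labels` is given below for readability; it is an assumption
# inferred from the call sites (grouping rows per class label), not the
# project's actual implementation, and any later definition will override it.
def organize_data_by_labels(labels, data):
    """Group rows of `data` into a dict mapping label -> list of rows."""
    per_class = {}
    for label, row in zip(labels, data):
        per_class.setdefault(int(label), []).append(row)
    return per_class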


class LocationDataset(torch.utils.data.Dataset):
    # NOTE: an earlier revision added a dummy "num_context" argument to this
    # signature; it was unused and has been dropped.
    def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None,
                 input_dim=4, time_dim=0, noise_time=False):
        # handle input encoding:
        self.input_enc = input_enc
        if self.input_enc in ['env', 'sin_cos_env']:
            raster = load_env()
        else:
            raster = None
        self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
        # define some properties:
        self.locs = locs
        self.loc_feats = self.enc.encode(self.locs)
        self.labels = labels
        self.classes = classes
        self.class_to_taxa = class_to_taxa
        if dates is not None:
            self.dates = dates
            self.enc_time = utils.TimeEncoder()
        # useful numbers:
        self.num_classes = len(np.unique(labels))
        self.input_dim = input_dim
        self.time_dim = time_dim
        self.noise_time = noise_time
        if self.enc.raster is not None:
            self.enc.raster = self.enc.raster.to(device)

    def __len__(self):
        return self.loc_feats.shape[0]

    def __getitem__(self, index):
        loc_feat = self.loc_feats[index, :]
        loc = self.locs[index, :]
        class_id = self.labels[index]
        if self.time_dim > 0:
            date = self.dates[index]
            # add noise to the date
            if self.noise_time:
                noise_level = random.random()
                noise = (2 * random.random() - 1) * (0.5 * noise_level)
                loc_feat = torch.cat([loc_feat, self.enc_time.encode_fast([date.item() + noise, noise_level])])
            else:
                # The noiseless time path is not implemented; the encode call
                # below is unreachable and kept only for reference.
                raise NotImplementedError()
                loc_feat = torch.cat([loc_feat, torch.tensor(self.enc_time.encode([2 * date.item() - 1], normalize=False))])
            return loc_feat, torch.cat([loc, date[None]]), class_id
        else:
            return loc_feat, loc, class_id
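

# A minimal usage sketch for LocationDataset (illustrative only): the coordinate
# range, the number of classes, and the 'sin_cos' encoder name are assumptions,
# not values taken from the actual training pipeline.
def _demo_location_dataset():
    locs = torch.rand(100, 2) * 2 - 1          # hypothetical coords scaled to [-1, 1]
    labels = torch.randint(0, 5, (100,))       # five hypothetical classes
    ds = LocationDataset(locs, labels, classes=list(range(5)),
                         class_to_taxa=list(range(5)), input_enc='sin_cos',
                         device='cpu')
    loc_feat, loc, class_id = ds[0]            # time_dim == 0 path
    return loc_feat.shape, loc.shape, class_id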


# NOTE: the transformer datasets below are intended only for the transformer-based
# models; each training point is returned together with same-class "context" points.
class TransformerLocationDataset(torch.utils.data.Dataset):
    def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None,
                 input_dim=4, time_dim=0, noise_time=False, num_context=50,
                 transformer_input_enc=None, token_dim=None, jitter=False,
                 variable_context_length=False):
        # Handle input encoding:
        self.input_enc = input_enc
        self.jitter = jitter
        self.variable_context_length = variable_context_length
        if self.input_enc in ['env', 'sin_cos_env']:
            raster = load_env()
        else:
            raster = None
        self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
        # Handle transformer input encoding:
        self.transformer_input_enc = transformer_input_enc
        if self.transformer_input_enc in ['env', 'sin_cos_env']:
            transformer_raster = load_env()
        else:
            transformer_raster = None
        if self.transformer_input_enc == 'sinr':
            self.transformer_enc = self.enc
        else:
            self.transformer_enc = utils.CoordEncoder(transformer_input_enc, transformer_raster,
                                                      input_dim=token_dim)
        # Define some properties:
        self.locs = locs  # kept on CPU
        # Note: encode(..., normalize=True) also normalizes self.locs in place.
        self.loc_feats = self.enc.encode(self.locs, normalize=True)
        transformer_loc_feats = self.transformer_enc.encode(self.locs, normalize=False)
        self.labels = labels  # kept on CPU
        self.classes = classes
        self.class_to_taxa = class_to_taxa
        if dates is not None:
            self.dates = dates
            self.enc_time = utils.TimeEncoder()
        # Useful numbers:
        self.num_classes = len(np.unique(labels))
        self.input_dim = input_dim
        self.time_dim = time_dim
        self.noise_time = noise_time
        self.num_context = num_context
        self.token_dim = token_dim
        # Rasters deliberately stay on the CPU here (no .to(device) on the
        # raster encoders); everything this dataset returns stays on the CPU.
        # Organize the data into per-class dictionaries:
        per_class_location_dict = organize_data_by_labels(np.array(labels), np.array(locs))
        per_class_loc_feats_dict = organize_data_by_labels(np.array(labels), np.array(transformer_loc_feats))
        for key, value in per_class_location_dict.items():
            per_class_location_dict[key] = torch.tensor(np.array(value))  # kept on CPU
        for key, value in per_class_loc_feats_dict.items():
            per_class_loc_feats_dict[key] = torch.tensor(np.array(value))  # kept on CPU
        self.per_class_locs = per_class_location_dict
        self.per_class_loc_feats = per_class_loc_feats_dict

    def __len__(self):
        return self.loc_feats.shape[0]

    def __getitem__(self, index):
        # Retrieve the feature and class of the original point
        loc_feat = self.loc_feats[index, :]
        loc = self.locs[index, :]
        class_id = self.labels[index]
        class_id_int = class_id.item()
        # Fetch all locations for the given class
        all_class_locs = self.per_class_locs[class_id_int]
        all_class_loc_feats = self.per_class_loc_feats[class_id_int]
        # Zero vector used as the class token
        class_token_feature = torch.zeros((1, len(all_class_loc_feats[0])))
        # Broadcast and compare to find all locations matching this point
        matches = (all_class_locs == loc).all(dim=1)
        # Find the index of the original location
        local_index = torch.where(matches)[0]
        if len(local_index) > 1:
            local_index = local_index[0]
        # Exclude the original location's index
        filtered_local_indices = torch.arange(len(all_class_locs)) != local_index
        if self.variable_context_length:
            num_context = random.randint(1, self.num_context)
        else:
            num_context = self.num_context
        # Randomly subsample context indices if more are available than needed
        if filtered_local_indices.sum() > num_context:
            selected_indices = filtered_local_indices.nonzero().squeeze()
            perm = torch.randperm(selected_indices.size(0))
            selected_indices = selected_indices[perm][:num_context]
        else:
            selected_indices = filtered_local_indices.nonzero().squeeze()
        # Get context locations and features
        context_loc_feats = all_class_loc_feats[selected_indices]
        context_locs = all_class_locs[selected_indices]
        # Restore the batch dimension if a single context point was selected
        if context_loc_feats.dim() == 1:
            context_loc_feats = context_loc_feats.unsqueeze(0)
        if context_locs.dim() == 1:
            context_locs = context_locs.unsqueeze(0)
        if self.jitter:
            # NOTE: torch.full_like adds a constant offset of noise_std to every
            # feature rather than random noise; random jitter would be
            # torch.randn_like(context_loc_feats) * noise_std.
            noise_std = 0.001
            noise = torch.full_like(context_loc_feats, noise_std)
            context_loc_feats = context_loc_feats + noise
        context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
        return loc_feat, loc, class_id, context_sequence, context_locs

    def collate_fn(self, batch):
        # Unpack the batch
        loc_feats, locs, class_ids, context_sequences, context_locs_list = zip(*batch)
        # Pad variable-length context sequences with -10
        padded_sequences = pad_sequence(context_sequences, batch_first=True, padding_value=-10)
        padded_context_locs = pad_sequence(context_locs_list, batch_first=True, padding_value=-10)
        # Convert lists to tensors
        loc_feats = torch.stack(loc_feats)
        locs = torch.stack(locs)
        class_ids = torch.tensor(class_ids)
        # Mask is True wherever a row is entirely -10 padding
        sequence_mask = (padded_sequences == -10).all(dim=-1)
        return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask

    def get_item_from_class(self, class_id):
        # Fetch locations and features for the class
        all_class_locs = self.per_class_locs[class_id]
        all_class_loc_feats = self.per_class_loc_feats[class_id]
        # Randomly select an index
        index = np.random.choice(len(all_class_locs))
        # Retrieve the selected location and its features
        loc = all_class_locs[index]
        if loc.ndim == 1:
            loc = loc.unsqueeze(0)
        # per_class_locs were built from locs already normalized in __init__
        loc_feat = self.enc.encode(loc, normalize=False)
        # Zero vector used as the class token
        class_token_feature = torch.zeros((1, self.token_dim))
        # Exclude the selected index from the context
        filtered_local_indices = torch.arange(len(all_class_locs)) != index
        # Randomly subsample context indices if more are available than needed
        if filtered_local_indices.sum() > self.num_context:
            selected_indices = filtered_local_indices.nonzero().squeeze()
            perm = torch.randperm(selected_indices.size(0))
            selected_indices = selected_indices[perm][:self.num_context]
        else:
            selected_indices = filtered_local_indices.nonzero().squeeze()
        # Get context locations and features
        context_loc_feats = all_class_loc_feats[selected_indices]
        context_locs = all_class_locs[selected_indices]
        # Restore the batch dimension if a single context point was selected
        if context_loc_feats.dim() == 1:
            context_loc_feats = context_loc_feats.unsqueeze(0)
        if context_locs.dim() == 1:
            context_locs = context_locs.unsqueeze(0)
        context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
        return loc_feat, loc, class_id, context_sequence, context_locs
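

# Sketch of wiring this dataset into a DataLoader via its custom collate_fn
# (illustrative; `ds` is assumed to be an already-constructed
# TransformerLocationDataset and the batch size is arbitrary).
def _demo_transformer_location_loader(ds):
    loader = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True,
                                         collate_fn=ds.collate_fn)
    loc_feats, locs, class_ids, ctx_seqs, ctx_locs, mask = next(iter(loader))
    # ctx_seqs has shape (B, 1 + longest context, token_dim): a zero class token
    # at position 0 followed by context features; mask is True on -10 padding rows.
    return ctx_seqs.shape, mask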


# NOTE: also intended only for the transformer-based models; returns "context"
# points and a text embedding alongside each main training point.
class TransformerLocationTextDataset(torch.utils.data.Dataset):
    def __init__(self, locs, labels, classes, class_to_taxa, embs, embs_ids, embs_keys,
                 input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False,
                 num_context=50, transformer_input_enc=None, token_dim=None, jitter=False,
                 variable_context_length=False):
        # Handle input encoding:
        self.input_enc = input_enc
        self.jitter = jitter
        self.variable_context_length = variable_context_length
        if self.input_enc in ['env', 'sin_cos_env']:
            raster = load_env()
        else:
            raster = None
        self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
        # Handle transformer input encoding:
        self.transformer_input_enc = transformer_input_enc
        if self.transformer_input_enc in ['env', 'sin_cos_env']:
            transformer_raster = load_env()
        else:
            transformer_raster = None
        if self.transformer_input_enc == 'sinr':
            self.transformer_enc = self.enc
        else:
            self.transformer_enc = utils.CoordEncoder(transformer_input_enc, transformer_raster,
                                                      input_dim=token_dim)
        # Define some properties:
        self.locs = locs  # kept on CPU
        # Note: encode(..., normalize=True) also normalizes self.locs in place.
        self.loc_feats = self.enc.encode(self.locs, normalize=True)
        transformer_loc_feats = self.transformer_enc.encode(self.locs, normalize=False)
        self.labels = labels  # kept on CPU
        self.classes = classes
        self.class_to_taxa = class_to_taxa
        if dates is not None:
            self.dates = dates
            self.enc_time = utils.TimeEncoder()
        # Useful numbers:
        self.num_classes = len(np.unique(labels))
        self.input_dim = input_dim
        self.time_dim = time_dim
        self.noise_time = noise_time
        self.num_context = num_context
        self.token_dim = token_dim
        # Rasters deliberately stay on the CPU (no .to(device)).
        # Text embeddings (kept on CPU):
        self.embs = embs
        self.embs_ids = embs_ids.tolist()
        self.embs_class_ids = [class_to_taxa.index(taxa) if taxa in class_to_taxa else -1
                               for taxa in self.embs_ids]
        self.embs_keys = embs_keys
        # Map each class id to (embedding indices, descriptions)
        class_emb_dict = {}
        for i, (index, description) in enumerate(embs_keys):
            class_id = self.embs_class_ids[index]
            # class_id == -1 corresponds to taxa not present in this dataset
            if class_id == -1:
                continue
            if class_id not in class_emb_dict:
                class_emb_dict[class_id] = ([], [])
            # Record the position in embs_keys and the description
            class_emb_dict[class_id][0].append(i)
            class_emb_dict[class_id][1].append(description)
        self.class_emb_dict = class_emb_dict
        # Organize the data into per-class dictionaries:
        per_class_location_dict = organize_data_by_labels(np.array(labels), np.array(locs))
        per_class_loc_feats_dict = organize_data_by_labels(np.array(labels), np.array(transformer_loc_feats))
        for key, value in per_class_location_dict.items():
            per_class_location_dict[key] = torch.tensor(np.array(value))  # kept on CPU
        for key, value in per_class_loc_feats_dict.items():
            per_class_loc_feats_dict[key] = torch.tensor(np.array(value))  # kept on CPU
        self.per_class_locs = per_class_location_dict
        self.per_class_loc_feats = per_class_loc_feats_dict

    def __len__(self):
        return self.loc_feats.shape[0]

    def __getitem__(self, index):
        # Retrieve the feature and class of the original point
        loc_feat = self.loc_feats[index, :]
        loc = self.locs[index, :]
        class_id = self.labels[index]
        class_id_int = class_id.item()
        # Fetch all locations for the given class
        all_class_locs = self.per_class_locs[class_id_int]
        all_class_loc_feats = self.per_class_loc_feats[class_id_int]
        # Zero vector used as the class token
        class_token_feature = torch.zeros((1, len(all_class_loc_feats[0])))
        # Find the index of the original location
        matches = (all_class_locs == loc).all(dim=1)
        local_index = torch.where(matches)[0]
        if len(local_index) > 1:
            local_index = local_index[0]
        # Exclude the original location's index
        filtered_local_indices = torch.arange(len(all_class_locs)) != local_index
        if self.variable_context_length:
            num_context = random.randint(1, self.num_context)
        else:
            num_context = self.num_context
        # Randomly subsample context indices if more are available than needed
        if filtered_local_indices.sum() > num_context:
            selected_indices = filtered_local_indices.nonzero().squeeze().numpy()
            np.random.shuffle(selected_indices)
            selected_indices = selected_indices[:num_context]
        else:
            selected_indices = filtered_local_indices.nonzero().squeeze()
        # Get context locations and their features
        context_loc_feats = all_class_loc_feats[selected_indices]
        context_locs = all_class_locs[selected_indices]
        # Restore the batch dimension if a single context point was selected
        if context_loc_feats.dim() == 1:
            context_loc_feats = context_loc_feats.unsqueeze(0)
        if context_locs.dim() == 1:
            context_locs = context_locs.unsqueeze(0)
        if self.jitter:
            # NOTE: this adds a constant offset of noise_std rather than random noise
            noise_std = 0.001
            noise = torch.full_like(context_loc_feats, noise_std)
            context_loc_feats = context_loc_feats + noise
        context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
        # Pick one of the class's text embeddings at random
        if class_id_int in self.class_emb_dict:
            embs_indexes, descriptions = self.class_emb_dict[class_id_int]
            selected_index = random.choice(embs_indexes)
            emb = self.embs[selected_index]
        else:
            # All-zero fallback so classes without text can be filtered downstream
            emb = torch.zeros(4096)  # adjust size if the embedding width changes
        return loc_feat, loc, class_id, context_sequence, context_locs, emb

    def collate_fn(self, batch):
        # Unpack the batch
        loc_feats, locs, class_ids, context_sequences, context_locs_list, embs = zip(*batch)
        # Pad variable-length context sequences with -10
        padded_sequences = pad_sequence(context_sequences, batch_first=True, padding_value=-10)
        padded_context_locs = pad_sequence(context_locs_list, batch_first=True, padding_value=-10)
        # Convert lists to tensors
        loc_feats = torch.stack(loc_feats)
        locs = torch.stack(locs)
        class_ids = torch.tensor(class_ids)
        embs = torch.stack(embs)
        # Mask is True wherever a row is entirely -10 padding
        sequence_mask = (padded_sequences == -10).all(dim=-1)
        return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, embs

    def get_item_from_class(self, class_id):
        # Fetch all locations and features for the given class
        all_class_locs = self.per_class_locs[class_id]
        all_class_loc_feats = self.per_class_loc_feats[class_id]
        # Randomly select an index for the class
        index = np.random.choice(len(all_class_locs))
        # Retrieve the selected location and its features
        loc = all_class_locs[index]
        if loc.ndim == 1:
            loc = loc.unsqueeze(0)
        # per_class_locs were built from locs already normalized in __init__
        loc_feat = self.enc.encode(loc, normalize=False)
        # Zero vector used as the class token
        class_token_feature = torch.zeros((1, self.token_dim))
        # Exclude the selected index from the context
        filtered_local_indices = torch.arange(len(all_class_locs)) != index
        # Randomly subsample context indices if more are available than needed
        if filtered_local_indices.sum() > self.num_context:
            selected_indices = filtered_local_indices.nonzero().squeeze().numpy()
            np.random.shuffle(selected_indices)
            selected_indices = selected_indices[:self.num_context]
        else:
            selected_indices = filtered_local_indices.nonzero().squeeze()
        # Get context locations and features
        context_loc_feats = all_class_loc_feats[selected_indices]
        context_locs = all_class_locs[selected_indices]
        # Restore the batch dimension if a single context point was selected
        if context_loc_feats.dim() == 1:
            context_loc_feats = context_loc_feats.unsqueeze(0)
        if context_locs.dim() == 1:
            context_locs = context_locs.unsqueeze(0)
        context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
        # Pick one of the class's text embeddings at random
        if class_id in self.class_emb_dict:
            embs_indexes, descriptions = self.class_emb_dict[class_id]
            selected_index = random.choice(embs_indexes)
            emb = self.embs[selected_index]
        else:
            emb = torch.zeros(4096)  # all-zero fallback
        return loc_feat, loc, class_id, context_sequence, context_locs, emb

    # def select_text_section(self, text_section):
    #     # Build a class -> embedding-index dict restricted to one text section
    #     text_class_emb_dict = {}
    #     for i, (index, description) in enumerate(self.text_embs_keys):
    #         class_id = self.text_embs_class_ids[index]
    #         # class_id == -1 corresponds to taxa not present in this dataset
    #         if class_id == -1:
    #             continue
    #         if description != text_section:
    #             continue
    #         if class_id not in text_class_emb_dict:
    #             text_class_emb_dict[class_id] = ([], [])
    #         text_class_emb_dict[class_id][0].append(i)
    #         text_class_emb_dict[class_id][1].append(description)
    #     self.text_class_emb_dict = text_class_emb_dict
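

# Sketch of consuming a batch that carries text embeddings alongside context
# points (illustrative; the 4096-dim width comes from the zero-vector fallback
# above, and the batch size is arbitrary).
def _demo_transformer_location_text_loader(ds):
    loader = torch.utils.data.DataLoader(ds, batch_size=16, collate_fn=ds.collate_fn)
    loc_feats, locs, class_ids, ctx_seqs, ctx_locs, mask, embs = next(iter(loader))
    # Rows of `embs` are all-zero for classes without a text embedding, so they
    # can be filtered out downstream.
    has_text = embs.abs().sum(dim=1) > 0
    return embs.shape, has_text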


class TransformerLocationTextDatasetRandomizeOutputs(torch.utils.data.Dataset):
    def __init__(self, locs, labels, classes, class_to_taxa, embs, embs_ids, embs_keys,
                 input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False,
                 num_context=50, transformer_input_enc=None, token_dim=None, jitter=False,
                 variable_context_length=False, just_obs_prob=0.2, just_text_prob=0.1):
        # Handle input encoding:
        self.input_enc = input_enc
        self.jitter = jitter
        self.variable_context_length = variable_context_length
        if self.input_enc in ['env', 'sin_cos_env']:
            raster = load_env()
        else:
            raster = None
        self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
        # Handle transformer input encoding:
        self.transformer_input_enc = transformer_input_enc
        if self.transformer_input_enc in ['env', 'sin_cos_env']:
            transformer_raster = load_env()
        else:
            transformer_raster = None
        if self.transformer_input_enc == 'sinr':
            self.transformer_enc = self.enc
        else:
            self.transformer_enc = utils.CoordEncoder(transformer_input_enc, transformer_raster,
                                                      input_dim=token_dim)
        # Define some properties:
        self.locs = locs  # kept on CPU
        # Note: encode(..., normalize=True) also normalizes self.locs in place.
        self.loc_feats = self.enc.encode(self.locs, normalize=True)
        transformer_loc_feats = self.transformer_enc.encode(self.locs, normalize=False)
        self.labels = labels  # kept on CPU
        self.classes = classes
        self.class_to_taxa = class_to_taxa
        if dates is not None:
            self.dates = dates
            self.enc_time = utils.TimeEncoder()
        # Useful numbers:
        self.num_classes = len(np.unique(labels))
        self.input_dim = input_dim
        self.time_dim = time_dim
        self.noise_time = noise_time
        self.num_context = num_context
        self.token_dim = token_dim
        # Rasters deliberately stay on the CPU (no .to(device)).
        # Text embeddings (kept on CPU):
        self.embs = embs
        self.embs_ids = embs_ids.tolist()
        self.embs_class_ids = [class_to_taxa.index(taxa) if taxa in class_to_taxa else -1
                               for taxa in self.embs_ids]
        self.embs_keys = embs_keys
        # Map each class id to (embedding indices, descriptions)
        class_emb_dict = {}
        for i, (index, description) in enumerate(embs_keys):
            class_id = self.embs_class_ids[index]
            # class_id == -1 corresponds to taxa not present in this dataset
            if class_id == -1:
                continue
            if class_id not in class_emb_dict:
                class_emb_dict[class_id] = ([], [])
            class_emb_dict[class_id][0].append(i)
            class_emb_dict[class_id][1].append(description)
        self.class_emb_dict = class_emb_dict
        # Organize the data into per-class dictionaries:
        per_class_location_dict = organize_data_by_labels(np.array(labels), np.array(locs))
        per_class_loc_feats_dict = organize_data_by_labels(np.array(labels), np.array(transformer_loc_feats))
        for key, value in per_class_location_dict.items():
            per_class_location_dict[key] = torch.tensor(np.array(value))  # kept on CPU
        for key, value in per_class_loc_feats_dict.items():
            per_class_loc_feats_dict[key] = torch.tensor(np.array(value))  # kept on CPU
        self.per_class_locs = per_class_location_dict
        self.per_class_loc_feats = per_class_loc_feats_dict
        # Probabilities for which modalities a sample returns
        self.just_obs_prob = just_obs_prob
        self.just_text_prob = just_text_prob
        self.both_prob = 1 - (just_obs_prob + just_text_prob)
        if self.both_prob < 0:
            raise ValueError('just_obs_prob and just_text_prob must sum to at most 1')

    def __len__(self):
        return self.loc_feats.shape[0]

    def __getitem__(self, index):
        # Decide which modalities this sample returns: 'obs', 'text', or 'obs_text'
        data_type = self.roll_output()
        # Retrieve the feature and class of the original point
        loc_feat = self.loc_feats[index, :]
        loc = self.locs[index, :]
        class_id = self.labels[index]
        class_id_int = class_id.item()
        if 'text' in data_type:
            # Pick one of the class's text embeddings at random
            if class_id_int in self.class_emb_dict:
                embs_indexes, descriptions = self.class_emb_dict[class_id_int]
                selected_index = random.choice(embs_indexes)
                emb = self.embs[selected_index]
            else:
                emb = torch.zeros(4096)  # all-zero fallback
                # No text is available for this class, so text-only samples
                # fall back to also returning observation context.
                if data_type == 'text':
                    data_type = 'obs_text'
        else:
            emb = torch.zeros(4096)  # all-zero fallback
        if 'obs' in data_type:
            # Fetch all locations for the given class
            all_class_locs = self.per_class_locs[class_id_int]
            all_class_loc_feats = self.per_class_loc_feats[class_id_int]
            # Zero vector used as the class token
            class_token_feature = torch.zeros((1, len(all_class_loc_feats[0])))
            # Find the index of the original location
            matches = (all_class_locs == loc).all(dim=1)
            local_index = torch.where(matches)[0]
            if len(local_index) > 1:
                local_index = local_index[0]
            # Exclude the original location's index
            filtered_local_indices = torch.arange(len(all_class_locs)) != local_index
            if self.variable_context_length:
                num_context = random.randint(1, self.num_context)
            else:
                num_context = self.num_context
            # Randomly subsample context indices if more are available than needed
            if filtered_local_indices.sum() > num_context:
                selected_indices = filtered_local_indices.nonzero().squeeze().numpy()
                np.random.shuffle(selected_indices)
                selected_indices = selected_indices[:num_context]
            else:
                selected_indices = filtered_local_indices.nonzero().squeeze()
            # Get context locations and features
            context_loc_feats = all_class_loc_feats[selected_indices]
            context_locs = all_class_locs[selected_indices]
            # Restore the batch dimension if a single context point was selected
            if context_loc_feats.dim() == 1:
                context_loc_feats = context_loc_feats.unsqueeze(0)
            if context_locs.dim() == 1:
                context_locs = context_locs.unsqueeze(0)
            if self.jitter:
                # NOTE: this adds a constant offset of noise_std rather than random noise
                noise_std = 0.001
                noise = torch.full_like(context_loc_feats, noise_std)
                context_loc_feats = context_loc_feats + noise
            context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
        else:
            # Text-only sample: the context is just the class token, with no
            # observation points.
            class_token_feature = torch.zeros((1, len(loc_feat)))
            context_sequence = torch.cat([class_token_feature], dim=0)
            context_locs = torch.empty((0, loc.size(0)))
        return loc_feat, loc, class_id, context_sequence, context_locs, emb

    def roll_output(self):
        # Sample which modalities to return, weighted by the configured probabilities
        return random.choices(
            ["obs", "text", "obs_text"],
            [self.just_obs_prob, self.just_text_prob, self.both_prob],
        )[0]

    def collate_fn(self, batch):
        # Unpack the batch
        loc_feats, locs, class_ids, context_sequences, context_locs_list, embs = zip(*batch)
        # Pad variable-length context sequences with -10
        padded_sequences = pad_sequence(context_sequences, batch_first=True, padding_value=-10)
        padded_context_locs = pad_sequence(context_locs_list, batch_first=True, padding_value=-10)
        # Convert lists to tensors
        loc_feats = torch.stack(loc_feats)
        locs = torch.stack(locs)
        class_ids = torch.tensor(class_ids)
        embs = torch.stack(embs)
        # Mask is True wherever a row is entirely -10 padding
        sequence_mask = (padded_sequences == -10).all(dim=-1)
        return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, embs

    def get_item_from_class(self, class_id):
        # Fetch all locations and features for the given class
        all_class_locs = self.per_class_locs[class_id]
        all_class_loc_feats = self.per_class_loc_feats[class_id]
        # Randomly select an index for the class
        index = np.random.choice(len(all_class_locs))
        # Retrieve the selected location and its features
        loc = all_class_locs[index]
        if loc.ndim == 1:
            loc = loc.unsqueeze(0)
        loc_feat = self.enc.encode(loc, normalize=False)
        # Zero vector used as the class token
        class_token_feature = torch.zeros((1, self.token_dim))
        # Exclude the selected index from the context
        filtered_local_indices = torch.arange(len(all_class_locs)) != index
        # Randomly subsample context indices if more are available than needed
        if filtered_local_indices.sum() > self.num_context:
            selected_indices = filtered_local_indices.nonzero().squeeze().numpy()
            np.random.shuffle(selected_indices)
            selected_indices = selected_indices[:self.num_context]
        else:
            selected_indices = filtered_local_indices.nonzero().squeeze()
        # Get context locations and features
        context_loc_feats = all_class_loc_feats[selected_indices]
        context_locs = all_class_locs[selected_indices]
        # Restore the batch dimension if a single context point was selected
        if context_loc_feats.dim() == 1:
            context_loc_feats = context_loc_feats.unsqueeze(0)
        if context_locs.dim() == 1:
            context_locs = context_locs.unsqueeze(0)
        context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
        # Pick one of the class's text embeddings at random
        if class_id in self.class_emb_dict:
            embs_indexes, descriptions = self.class_emb_dict[class_id]
            selected_index = random.choice(embs_indexes)
            emb = self.embs[selected_index]
        else:
            emb = torch.zeros(4096)  # all-zero fallback
        return loc_feat, loc, class_id, context_sequence, context_locs, emb
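

# Quick sanity sketch for the modality mixing: empirically, roll_output should
# return "obs", "text", and "obs_text" in roughly the configured proportions
# (illustrative helper; `ds` is an already-constructed instance).
def _demo_roll_output_distribution(ds, n=1000):
    from collections import Counter
    counts = Counter(ds.roll_output() for _ in range(n))
    return {k: v / n for k, v in counts.items()}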
class TransformerDatasetVariableTokens(torch.utils.data.Dataset): | |
def __init__(self, locs, labels, classes, class_to_taxa, text_embs, text_embs_ids, text_embs_keys, image_embs, | |
image_embs_ids, image_embs_keys, input_enc, device, dates=None, input_dim=4, time_dim=0, | |
noise_time=False, num_context=50, transformer_input_enc=None, token_dim=None, jitter=False, | |
variable_context_length=False, loc_prob=1.0, text_prob=0.0, image_prob=0.0, env_prob=0.0, | |
eval_mode=False): | |
# handle input encoding: | |
# ensure everything stays on the CPU until batching
self.input_enc = input_enc | |
self.jitter = jitter | |
self.variable_context_length = variable_context_length | |
if self.input_enc in ['env', 'sin_cos_env']: | |
raster = load_env() | |
else: | |
raster = None | |
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim) | |
# set up what token types and probs of use we have | |
self.loc_prob = loc_prob | |
self.text_prob = text_prob | |
self.image_prob = image_prob | |
self.env_prob = env_prob | |
# A token type with zero probability is disabled outright
self.use_loc = self.loc_prob > 0.0
self.use_text = self.text_prob > 0.0
self.use_image = self.image_prob > 0.0
self.use_env = self.env_prob > 0.0
# handle transformer input encoding: | |
self.transformer_input_enc = transformer_input_enc | |
if self.transformer_input_enc in ['env', 'sin_cos_env']: | |
transformer_raster = load_env() | |
else: | |
transformer_raster = None | |
if self.transformer_input_enc == 'sinr': | |
self.transformer_enc = self.enc | |
else: | |
self.transformer_enc = utils.CoordEncoder(transformer_input_enc, transformer_raster, input_dim=token_dim) | |
# NOTE: raster encoders stay on the CPU until batching
# define some properties: | |
self.locs = locs # Keep on CPU | |
# The call below normalizes the locations in addition to producing the encoded features
self.loc_feats = self.enc.encode(self.locs, normalize=True) | |
transformer_loc_feats = self.transformer_enc.encode(self.locs, normalize=False) | |
self.labels = labels # Keep on CPU | |
self.classes = classes | |
self.class_to_taxa = class_to_taxa | |
if dates is not None: | |
self.dates = dates | |
self.enc_time = utils.TimeEncoder() | |
# useful numbers: | |
self.num_classes = len(np.unique(labels)) | |
self.input_dim = input_dim | |
self.time_dim = time_dim | |
self.noise_time = noise_time | |
self.num_context = num_context | |
self.token_dim = token_dim | |
# Embedding sizes expected from the precomputed env/image/text features
self.env_emb_size = 20
self.image_emb_size = 1024
self.text_emb_size = 4096
self.device = device # retained for convenience; tensors stay on the CPU until collation
# text stuff | |
if self.use_text: | |
self.text_embs = text_embs # Keep on CPU | |
self.text_embs_ids = text_embs_ids.tolist() | |
# O(1) lookup table (class_to_taxa entries are assumed unique)
taxa_to_class = {taxa: i for i, taxa in enumerate(class_to_taxa)}
self.text_embs_class_ids = [taxa_to_class.get(taxa, -1) for taxa in self.text_embs_ids]
self.text_embs_keys = text_embs_keys | |
# Build class_emb_dict: class_id -> ([indices into text_embs], [section descriptions])
class_emb_dict = {}
for i, (index, description) in enumerate(text_embs_keys): | |
# Map the embedding's taxa index to a dataset class id
class_id = self.text_embs_class_ids[index]
# Skip classes not present in this dataset (mapped to -1 above)
if class_id == -1: | |
continue | |
# Create the entry on first sight of this class
class_emb_dict.setdefault(class_id, ([], []))
# Append the description and the index of embs_keys to the corresponding lists | |
class_emb_dict[class_id][0].append(i) | |
class_emb_dict[class_id][1].append(description) | |
self.text_class_emb_dict = class_emb_dict | |
else: | |
self.text_embs = None | |
self.text_embs_ids = None | |
self.text_embs_class_ids = None | |
self.text_embs_keys = None | |
self.text_class_emb_dict = None | |
# image stuff | |
if self.use_image: | |
self.image_embs = image_embs # Keep on CPU | |
self.image_embs_ids = image_embs_ids.tolist() | |
# O(1) lookup table (class_to_taxa entries are assumed unique)
taxa_to_class = {taxa: i for i, taxa in enumerate(class_to_taxa)}
self.image_embs_class_ids = [taxa_to_class.get(taxa, -1) for taxa in self.image_embs_ids]
self.image_embs_keys = image_embs_keys | |
# Build class_emb_dict: class_id -> ([indices into image_embs], [descriptions])
class_emb_dict = {}
for i, (index, description) in enumerate(image_embs_keys): | |
# Map the embedding's taxa index to a dataset class id
class_id = self.image_embs_class_ids[index]
# Skip classes not present in this dataset (mapped to -1 above)
if class_id == -1: | |
continue | |
# Create the entry on first sight of this class
class_emb_dict.setdefault(class_id, ([], []))
# Append the description and the index of embs_keys to the corresponding lists | |
class_emb_dict[class_id][0].append(i) | |
class_emb_dict[class_id][1].append(description) | |
self.image_class_emb_dict = class_emb_dict | |
else: | |
self.image_embs = None | |
self.image_embs_ids = None | |
self.image_embs_class_ids = None | |
self.image_embs_keys = None | |
self.image_class_emb_dict = None | |
# Group locations and transformer features by class
per_class_location_dict = organize_data_by_labels(np.array(labels), np.array(self.locs)) | |
per_class_loc_feats_dict = organize_data_by_labels(np.array(labels), np.array(transformer_loc_feats)) | |
for key, value in per_class_location_dict.items(): | |
per_class_location_dict[key] = torch.tensor(np.array(value)) # Keep on CPU | |
for key, value in per_class_loc_feats_dict.items(): | |
per_class_loc_feats_dict[key] = torch.tensor(np.array(value)) # Keep on CPU | |
self.per_class_locs = per_class_location_dict | |
self.per_class_loc_feats = per_class_loc_feats_dict | |
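# Shape note: per_class_locs[class_id] holds all coordinate pairs observed for that
# class and per_class_loc_feats[class_id] the matching transformer features, i.e.
# tensors of shape (N_c, coord_dim) and (N_c, token_dim) for a class with N_c records.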
# env stuff | |
if self.use_env: | |
env_raster = load_env() | |
self.env_enc = utils.CoordEncoder('env', env_raster, input_dim=0) | |
env_feats = self.env_enc.encode(self.locs, normalize=False) | |
per_class_env_feats_dict = organize_data_by_labels(np.array(labels), np.array(env_feats)) | |
for key, value in per_class_env_feats_dict.items(): | |
per_class_env_feats_dict[key] = torch.tensor(np.array(value)) # Keep on CPU | |
self.per_class_env_feats = per_class_env_feats_dict | |
else: | |
self.env_enc = None | |
self.per_class_env_feats = None | |
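# Note: env_emb_size set above is assumed to match the number of bands the 'env'
# raster encoder produces per location; adjust it if load_env() changes.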
self.eval_mode = eval_mode | |
if eval_mode: | |
print('Using eval dataset. One example per class to generate eval embeddings') | |
# Select a single example per class | |
unique_labels, unique_indices = np.unique(labels, return_index=True) | |
self.locs = self.locs[unique_indices] | |
self.labels = labels[unique_indices] | |
self.loc_feats = self.loc_feats[unique_indices] | |
def __len__(self): | |
return self.loc_feats.shape[0] | |
def __getitem__(self, index): | |
# See what token types need getting
token_types = self.roll_output()
# Retrieve the feature and class of the original point | |
loc_feat = self.loc_feats[index, :] | |
loc = self.locs[index, :] | |
class_id = self.labels[index] | |
class_id_int = class_id.item() | |
if 'text' in token_types:
# text stuff | |
# Get the embedding for the right class | |
if class_id_int in self.text_class_emb_dict: | |
text_embs_indexes, descriptions = self.text_class_emb_dict[class_id_int] | |
# Randomly select an index from the list of indices | |
selected_index = random.choice(text_embs_indexes) | |
# Use the selected index to retrieve the corresponding element from embs | |
text_emb = self.text_embs[selected_index] | |
else: | |
# Set emb to all zeros | |
text_emb = torch.zeros(self.text_emb_size) | |
else: | |
text_emb = torch.zeros(self.text_emb_size) | |
if 'image' in token_types:
# image stuff | |
if class_id_int in self.image_class_emb_dict: | |
image_embs_indexes, descriptions = self.image_class_emb_dict[class_id_int] | |
selected_index = random.choice(image_embs_indexes) | |
image_emb = self.image_embs[selected_index] | |
else: | |
image_emb = torch.zeros(self.image_emb_size) | |
else: | |
image_emb = torch.zeros(self.image_emb_size) | |
if 'loc' in token_types:
# Fetch all locations for the given class | |
all_class_locs = self.per_class_locs[class_id_int] | |
all_class_loc_feats = self.per_class_loc_feats[class_id_int] | |
# Zero feature vector acting as the class token
class_token_feature = torch.zeros((1, len(all_class_loc_feats[0]))) # CPU tensor
# Broadcast and compare to find all matching locations | |
matches = (all_class_locs == loc).all(dim=1) | |
# Find the index of the original location | |
local_index = torch.where(matches)[0] | |
if len(local_index) > 1: | |
local_index = local_index[0] | |
# Exclude the original location's index | |
filtered_local_indices = torch.arange(len(all_class_locs)) != local_index | |
if self.variable_context_length: | |
num_context = random.randint(1, self.num_context) | |
else: | |
num_context = self.num_context | |
# Randomly subsample context points when more than num_context are available
selected_indices = filtered_local_indices.nonzero().squeeze()
if filtered_local_indices.sum() > num_context:
perm = torch.randperm(selected_indices.size(0))
selected_indices = selected_indices[perm][:num_context]
# Get context locations and their features | |
context_loc_feats = all_class_loc_feats[selected_indices] | |
context_locs = all_class_locs[selected_indices] | |
# Check if context_loc_feats has 1 dimension and add another if it does | |
if context_loc_feats.dim() == 1: | |
context_loc_feats = context_loc_feats.unsqueeze(0) | |
if self.jitter:
# Add small Gaussian jitter to the context features
noise_std = 0.001
context_loc_feats = context_loc_feats + torch.randn_like(context_loc_feats) * noise_std
# Check if context_locs has 1 dimension and add another if it does | |
if context_locs.dim() == 1: | |
context_locs = context_locs.unsqueeze(0) | |
context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0) | |
else: | |
class_token_feature = torch.zeros((1, len(loc_feat))) # CPU tensor | |
context_sequence = torch.cat([class_token_feature], dim=0) | |
context_locs = torch.empty((0, loc.size(0))) # Empty CPU tensor | |
if ('loc' in token_types) and ('env' not in token_types) and self.use_env:
env_feats = torch.zeros((len(selected_indices), self.env_emb_size))
elif ('env' in token_types) and ('loc' in token_types):
all_class_env_feats = self.per_class_env_feats[class_id_int]
env_feats = all_class_env_feats[selected_indices]
if env_feats.dim() == 1:
env_feats = env_feats.unsqueeze(0)
else: | |
env_feats = torch.zeros((1, self.env_emb_size)) | |
return loc_feat, loc, class_id, context_sequence, context_locs, text_emb, image_emb, env_feats | |
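# For reference, __getitem__ returns per sample:
# loc_feat (encoded target features), loc (normalized target coordinates),
# class_id (integer label),
# context_sequence ((1 + n_context, token_dim): zero class token followed by context features),
# context_locs (coordinates of the context points),
# text_emb / image_emb (class-level embeddings, zeros when not rolled or unavailable),
# env_feats (per-context environmental features, zeros when env tokens are not rolled).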
def roll_output(self):
# Independently roll each token type against its configured probability
output = []
if random.random() < self.loc_prob: | |
output.append('loc') | |
if random.random() < self.image_prob: | |
output.append('image') | |
if random.random() < self.text_prob: | |
output.append('text') | |
if random.random() < self.env_prob: | |
output.append('env') | |
# Guarantee at least a 'loc' token whenever locations are enabled
if (len(output) == 0) and self.loc_prob != 0.0:
output.append('loc')
return output | |
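# Illustrative example: with loc_prob=1.0, text_prob=0.5, image_prob=0.0 and
# env_prob=0.25, every sample includes 'loc', roughly half include 'text', none
# include 'image', and about a quarter include 'env'; the final check above means
# a sample can only be token-free when loc_prob is exactly 0.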
def collate_fn(self, batch): | |
# Unzip the batch | |
loc_feats, locs, class_ids, context_sequences, context_locss, text_embs, image_embs, env_feats = zip(*batch) | |
# Convert list of sequences to a tensor with padding | |
padded_sequences = pad_sequence(context_sequences, batch_first=True, padding_value=-10) | |
# Convert list of class IDs to a tensor | |
class_ids = torch.tensor(class_ids) | |
# Convert loc_feats and locs to tensors | |
loc_feats = torch.stack(loc_feats) | |
locs = torch.stack(locs) | |
text_embs = torch.stack(text_embs) | |
image_embs = torch.stack(image_embs) | |
padded_context_locs = pad_sequence(context_locss, batch_first=True, padding_value=-10) | |
if self.use_env: | |
padded_env_feats = pad_sequence(env_feats, batch_first=True, padding_value=-10) | |
else: | |
padded_env_feats = None | |
# Create a mask for sequences based on padding | |
sequence_mask = (padded_sequences == -10).all(dim=-1) | |
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, text_embs, image_embs, padded_env_feats | |
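# Minimal usage sketch (illustrative batch size / worker count), wiring the custom
# collate_fn into a standard PyTorch DataLoader:
# loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True,
#                                      num_workers=4, collate_fn=dataset.collate_fn)
# for loc_feats, locs, class_ids, seqs, ctx_locs, mask, text_embs, image_embs, env_feats in loader:
#     ...  # mask is True at padded context positions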
def select_text_section(self, text_section): | |
# Rebuild text_class_emb_dict keeping only embeddings whose description matches text_section
text_class_emb_dict = {}
for i, (index, description) in enumerate(self.text_embs_keys): | |
# Map the embedding's taxa index to a dataset class id
class_id = self.text_embs_class_ids[index]
# Skip classes not present in this dataset (mapped to -1 above)
if class_id == -1: | |
continue | |
# Keep only embeddings from the requested text section
if description != text_section:
continue
# Create the entry on first sight of this class
text_class_emb_dict.setdefault(class_id, ([], []))
# Append the description and the index of embs_keys to the corresponding lists | |
text_class_emb_dict[class_id][0].append(i) | |
text_class_emb_dict[class_id][1].append(description) | |
self.text_class_emb_dict = text_class_emb_dict | |
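# Usage note (illustrative section name): after
# dataset.select_text_section('habitat')
# only 'habitat' text embeddings are sampled; classes without that section fall
# back to the zero text embedding in __getitem__.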
# MINE MINE MINE MINE - should only be used for my models | |
# this eval version always has exactly one example per class (with the appropriate number of context points),
# so iterating through the dataset lets us create the model's "eval embeddings" for every class;
# each eval embedding is generated from a single forward pass over "num_context" context points
class EvalTransformerLocationDataset(torch.utils.data.Dataset): | |
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, | |
noise_time=False, num_context=50, transformer_input_enc=None, token_dim=None, jitter=False, | |
variable_context_length=False): | |
# Handle input encoding | |
self.input_enc = input_enc | |
if self.input_enc in ['env', 'sin_cos_env']: | |
raster = load_env() | |
else: | |
raster = None | |
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim) | |
# Handle transformer input encoding | |
self.transformer_input_enc = transformer_input_enc | |
if self.transformer_input_enc in ['env', 'sin_cos_env']: | |
transformer_raster = load_env() | |
else: | |
transformer_raster = None | |
if self.transformer_input_enc == 'sinr': | |
self.transformer_enc = self.enc | |
else: | |
self.transformer_enc = utils.CoordEncoder( | |
transformer_input_enc, transformer_raster, input_dim=token_dim) | |
# Define properties | |
self.locs = locs # Keep on CPU | |
self.labels = labels # Keep on CPU | |
self.classes = classes | |
self.class_to_taxa = class_to_taxa | |
# Normalize locs and create loc_feats | |
self.loc_feats = self.enc.encode(self.locs, normalize=True) | |
transformer_loc_feats = self.transformer_enc.encode( | |
self.locs, normalize=False) | |
if dates is not None: | |
self.dates = dates | |
self.enc_time = utils.TimeEncoder() | |
# Useful numbers | |
self.num_classes = len(np.unique(labels)) | |
self.input_dim = input_dim | |
self.time_dim = time_dim | |
self.noise_time = noise_time | |
self.num_context = num_context | |
self.token_dim = token_dim | |
# NOTE: rasters stay on the CPU; data moves to the device at batch time
# Organize data into dictionaries | |
per_class_location_dict = organize_data_by_labels( | |
np.array(labels), np.array(locs)) | |
per_class_loc_feats_dict = organize_data_by_labels( | |
np.array(labels), np.array(transformer_loc_feats)) | |
for key, value in per_class_location_dict.items(): | |
per_class_location_dict[key] = torch.tensor( | |
np.array(value)) # Keep on CPU | |
for key, value in per_class_loc_feats_dict.items(): | |
per_class_loc_feats_dict[key] = torch.tensor( | |
np.array(value)) # Keep on CPU | |
self.per_class_locs = per_class_location_dict | |
self.per_class_loc_feats = per_class_loc_feats_dict | |
# Select a single example per class | |
unique_labels, unique_indices = np.unique(labels, return_index=True) | |
self.locs = locs[unique_indices] | |
self.labels = labels[unique_indices] | |
self.loc_feats = self.loc_feats[unique_indices] | |
def __len__(self): | |
return self.loc_feats.shape[0] | |
def __getitem__(self, index): | |
# Retrieve feature and class of the original point | |
loc_feat = self.loc_feats[index, :] # On CPU | |
loc = self.locs[index, :] # On CPU | |
class_id = self.labels[index] # On CPU | |
class_id_int = class_id.item() | |
# Fetch all locations for the class | |
all_class_locs = self.per_class_locs[class_id_int] # On CPU | |
all_class_loc_feats = self.per_class_loc_feats[class_id_int] # On CPU | |
# Zero feature vector acting as the class token
class_token_feature = torch.zeros((1, len(all_class_loc_feats[0]))) # CPU tensor
# Find the index of the original location | |
matches = (all_class_locs == loc).all(dim=1) | |
local_index = torch.where(matches)[0] | |
if len(local_index) > 1: | |
local_index = local_index[0] | |
# Exclude the original location's index | |
filtered_local_indices = torch.arange( | |
len(all_class_locs)) != local_index | |
# Select context indices deterministically (no shuffling for evaluation)
selected_indices = filtered_local_indices.nonzero().squeeze()
if filtered_local_indices.sum() > self.num_context:
selected_indices = selected_indices[:self.num_context]
# Get context locations and features | |
context_loc_feats = all_class_loc_feats[selected_indices] | |
context_locs = all_class_locs[selected_indices] | |
# Adjust dimensions if necessary | |
if context_loc_feats.dim() == 1: | |
context_loc_feats = context_loc_feats.unsqueeze(0) | |
if context_locs.dim() == 1: | |
context_locs = context_locs.unsqueeze(0) | |
context_sequence = torch.cat( | |
[class_token_feature, context_loc_feats], dim=0) | |
return loc_feat, loc, class_id, context_sequence, context_locs | |
def collate_fn(self, batch): | |
# Unpack the batch | |
loc_feats, locs, class_ids, context_sequences, context_locss = zip(*batch) | |
# Pad sequences | |
padded_sequences = pad_sequence( | |
context_sequences, batch_first=True, padding_value=-10) | |
padded_context_locs = pad_sequence( | |
context_locss, batch_first=True, padding_value=-10) | |
# Convert lists to tensors | |
loc_feats = torch.stack(loc_feats) | |
locs = torch.stack(locs) | |
class_ids = torch.tensor(class_ids) | |
# Create a mask for sequences based on padding | |
sequence_mask = (padded_sequences == -10).all(dim=-1) | |
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask | |
def get_item_from_class(self, class_id): | |
# Fetch locations and features for the class | |
all_class_locs = self.per_class_locs[class_id] | |
all_class_loc_feats = self.per_class_loc_feats[class_id] | |
# Randomly select an index | |
index = np.random.choice(len(all_class_locs)) | |
# Retrieve selected location and features | |
loc = all_class_locs[index] | |
if loc.ndim == 1: | |
loc = loc.unsqueeze(0) | |
loc_feat = self.enc.encode(loc, normalize=False) | |
# Define a unique class token index | |
class_token_feature = torch.zeros((1, self.token_dim)) # CPU tensor | |
# Exclude selected index from context | |
filtered_local_indices = torch.arange(len(all_class_locs)) != index | |
# Select indices for context | |
if filtered_local_indices.sum() > self.num_context: | |
selected_indices = filtered_local_indices.nonzero().squeeze() | |
perm = torch.randperm(selected_indices.size(0)) | |
selected_indices = selected_indices[perm][:self.num_context] | |
else: | |
selected_indices = filtered_local_indices.nonzero().squeeze() | |
# Get context locations and features | |
context_loc_feats = all_class_loc_feats[selected_indices] | |
context_locs = all_class_locs[selected_indices] | |
# Adjust dimensions if necessary | |
if context_loc_feats.dim() == 1: | |
context_loc_feats = context_loc_feats.unsqueeze(0) | |
if context_locs.dim() == 1: | |
context_locs = context_locs.unsqueeze(0) | |
context_sequence = torch.cat( | |
[class_token_feature, context_loc_feats], dim=0) | |
return loc_feat, loc, class_id, context_sequence, context_locs | |
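# Usage sketch (illustrative only; argument values are placeholders, and `locs`,
# `labels`, `classes`, and `class_to_taxa` are assumed to be prepared as in the
# loaders further below):
#
#   dataset = TransformerLocationDataset(
#       locs, labels, classes, class_to_taxa,
#       input_enc='sin_cos', device='cpu',
#       num_context=50, transformer_input_enc='sinr', token_dim=4)
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=16, collate_fn=dataset.collate_fn)
#   loc_feats, locs_b, class_ids, ctx_seqs, ctx_locs, mask = next(iter(loader))
#   # `mask` is True at padded context positions (padding_value == -10)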
# NOTE: only intended for the transformer-based models.
# This eval dataset keeps a single example per class (with the appropriate number of
# context points), so iterating over it yields one "eval embedding" per class.
# Each eval embedding comes from a single forward pass over `num_context` context points + text.
class EvalTransformerTextLocationDataset(torch.utils.data.Dataset): | |
def __init__(self, locs, labels, classes, class_to_taxa, embs, embs_ids, embs_keys, input_enc, device, dates=None, | |
input_dim=4, time_dim=0, noise_time=False, num_context=50, transformer_input_enc=None, | |
token_dim=None, jitter=False, variable_context_length=False): | |
# Handle input encoding | |
self.input_enc = input_enc | |
if self.input_enc in ['env', 'sin_cos_env']: | |
raster = load_env() | |
else: | |
raster = None | |
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim) | |
# Handle transformer input encoding | |
self.transformer_input_enc = transformer_input_enc | |
if self.transformer_input_enc in ['env', 'sin_cos_env']: | |
transformer_raster = load_env() | |
else: | |
transformer_raster = None | |
if self.transformer_input_enc == 'sinr': | |
self.transformer_enc = self.enc | |
else: | |
self.transformer_enc = utils.CoordEncoder( | |
transformer_input_enc, transformer_raster, input_dim=token_dim) | |
# Define properties | |
self.locs = locs # Keep on CPU | |
self.labels = labels # Keep on CPU | |
self.classes = classes | |
self.class_to_taxa = class_to_taxa | |
# Normalize locs and create loc_feats | |
self.loc_feats = self.enc.encode(self.locs, normalize=True) | |
transformer_loc_feats = self.transformer_enc.encode( | |
self.locs, normalize=False) | |
if dates is not None: | |
self.dates = dates | |
self.enc_time = utils.TimeEncoder() | |
# Useful numbers | |
self.num_classes = len(np.unique(labels)) | |
self.input_dim = input_dim | |
self.time_dim = time_dim | |
self.noise_time = noise_time | |
self.num_context = num_context | |
self.token_dim = token_dim | |
        # NOTE: rasters are intentionally left on the CPU here (no .to(device));
        # they can be moved to the training device by the caller if needed.
# Text embeddings | |
self.embs = embs # Keep on CPU | |
self.embs_ids = embs_ids.tolist() | |
self.embs_class_ids = [class_to_taxa.index( | |
taxa) if taxa in class_to_taxa else -1 for taxa in self.embs_ids] | |
self.embs_keys = embs_keys | |
# Initialize class embedding dictionary | |
class_emb_dict = {} | |
for i, (index, description) in enumerate(embs_keys): | |
class_id = self.embs_class_ids[index] | |
if class_id == -1: | |
continue | |
if class_id not in class_emb_dict: | |
class_emb_dict[class_id] = ([], []) | |
class_emb_dict[class_id][0].append(i) | |
class_emb_dict[class_id][1].append(description) | |
self.class_emb_dict = class_emb_dict | |
# Organize data into dictionaries | |
per_class_location_dict = organize_data_by_labels( | |
np.array(labels), np.array(locs)) | |
per_class_loc_feats_dict = organize_data_by_labels( | |
np.array(labels), np.array(transformer_loc_feats)) | |
for key, value in per_class_location_dict.items(): | |
per_class_location_dict[key] = torch.tensor( | |
np.array(value)) # Keep on CPU | |
for key, value in per_class_loc_feats_dict.items(): | |
per_class_loc_feats_dict[key] = torch.tensor( | |
np.array(value)) # Keep on CPU | |
self.per_class_locs = per_class_location_dict | |
self.per_class_loc_feats = per_class_loc_feats_dict | |
# Select a single example per class | |
unique_labels, unique_indices = np.unique(labels, return_index=True) | |
self.locs = locs[unique_indices] | |
self.labels = labels[unique_indices] | |
self.loc_feats = self.loc_feats[unique_indices] | |
def __len__(self): | |
return self.loc_feats.shape[0] | |
def __getitem__(self, index): | |
# Retrieve feature and class of the original point | |
loc_feat = self.loc_feats[index, :] # On CPU | |
loc = self.locs[index, :] # On CPU | |
class_id = self.labels[index] # On CPU | |
class_id_int = class_id.item() | |
# Fetch all locations for the class | |
all_class_locs = self.per_class_locs[class_id_int] # On CPU | |
all_class_loc_feats = self.per_class_loc_feats[class_id_int] # On CPU | |
        # Create a zero feature vector to act as the class token
        class_token_feature = torch.zeros(
            (1, len(all_class_loc_feats[0])))  # CPU tensor
# Find the index of the original location | |
matches = (all_class_locs == loc).all(dim=1) | |
local_index = torch.where(matches)[0] | |
if len(local_index) > 1: | |
local_index = local_index[0] | |
# Exclude the original location's index | |
filtered_local_indices = torch.arange( | |
len(all_class_locs)) != local_index | |
# Select indices for context | |
if filtered_local_indices.sum() > self.num_context: | |
selected_indices = filtered_local_indices.nonzero().squeeze() | |
# No shuffling for evaluation | |
selected_indices = selected_indices[:self.num_context] | |
else: | |
selected_indices = filtered_local_indices.nonzero().squeeze() | |
# Get context locations and features | |
context_loc_feats = all_class_loc_feats[selected_indices] | |
context_locs = all_class_locs[selected_indices] | |
# Adjust dimensions if necessary | |
if context_loc_feats.dim() == 1: | |
context_loc_feats = context_loc_feats.unsqueeze(0) | |
if context_locs.dim() == 1: | |
context_locs = context_locs.unsqueeze(0) | |
context_sequence = torch.cat( | |
[class_token_feature, context_loc_feats], dim=0) | |
# Text embeddings | |
if class_id_int in self.class_emb_dict: | |
embs_indexes, descriptions = self.class_emb_dict[class_id_int] | |
selected_index = random.choice(embs_indexes) | |
emb = self.embs[selected_index] | |
        else:
            # class has no text embedding: fall back to a zero placeholder
            # (4096 matches the width of the precomputed text embeddings)
            emb = torch.zeros(4096)  # On CPU
return loc_feat, loc, class_id, context_sequence, context_locs, emb | |
def collate_fn(self, batch): | |
# Unpack the batch | |
loc_feats, locs, class_ids, context_sequences, context_locss, embs = zip(*batch) | |
# Pad sequences | |
padded_sequences = pad_sequence( | |
context_sequences, batch_first=True, padding_value=-10) | |
padded_context_locs = pad_sequence( | |
context_locss, batch_first=True, padding_value=-10) | |
# Convert lists to tensors | |
loc_feats = torch.stack(loc_feats) | |
locs = torch.stack(locs) | |
class_ids = torch.tensor(class_ids) | |
embs = torch.stack(embs) | |
# Create a mask for sequences based on padding | |
sequence_mask = (padded_sequences == -10).all(dim=-1) | |
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, embs | |
def get_item_from_class(self, class_id): | |
# Fetch locations and features for the class | |
all_class_locs = self.per_class_locs[class_id] | |
all_class_loc_feats = self.per_class_loc_feats[class_id] | |
# Randomly select an index | |
index = np.random.choice(len(all_class_locs)) | |
# Retrieve selected location and features | |
loc = all_class_locs[index] | |
if loc.ndim == 1: | |
loc = loc.unsqueeze(0) | |
loc_feat = self.enc.encode(loc, normalize=False) | |
        # Create a zero feature vector to act as the class token
        class_token_feature = torch.zeros((1, self.token_dim))  # CPU tensor
# Exclude selected index from context | |
filtered_local_indices = torch.arange(len(all_class_locs)) != index | |
# Select indices for context | |
if filtered_local_indices.sum() > self.num_context: | |
selected_indices = filtered_local_indices.nonzero().squeeze() | |
perm = torch.randperm(selected_indices.size(0)) | |
selected_indices = selected_indices[perm][:self.num_context] | |
else: | |
selected_indices = filtered_local_indices.nonzero().squeeze() | |
# Get context locations and features | |
context_loc_feats = all_class_loc_feats[selected_indices] | |
context_locs = all_class_locs[selected_indices] | |
# Adjust dimensions if necessary | |
if context_loc_feats.dim() == 1: | |
context_loc_feats = context_loc_feats.unsqueeze(0) | |
if context_locs.dim() == 1: | |
context_locs = context_locs.unsqueeze(0) | |
context_sequence = torch.cat( | |
[class_token_feature, context_loc_feats], dim=0) | |
# Text embeddings | |
if class_id in self.class_emb_dict: | |
embs_indexes, descriptions = self.class_emb_dict[class_id] | |
selected_index = random.choice(embs_indexes) | |
emb = self.embs[selected_index] | |
        else:
            emb = torch.zeros(4096)  # zero placeholder: class has no text embedding
return loc_feat, loc, class_id, context_sequence, context_locs, emb | |
def select_text_section(self, text_section): | |
# Initialize an empty dictionary to store the result | |
class_emb_dict = {} | |
# Populate the dictionary | |
for i, (index, description) in enumerate(self.embs_keys): | |
# Find the class using the index from the class_list | |
            class_id = self.embs_class_ids[index]
            # Skip classes that are not present in the dataset (class_id == -1)
            if class_id == -1:
                continue
            # Keep only embeddings belonging to the requested text section
            if description != text_section:
                continue
# Check if the class_id is already a key in the dictionary | |
if class_id not in class_emb_dict: | |
# Initialize with empty lists if class_id is not already in the dictionary | |
class_emb_dict[class_id] = ([], []) | |
            # Record the embedding index and its description for this class
            class_emb_dict[class_id][0].append(i)
            class_emb_dict[class_id][1].append(description)
self.class_emb_dict = class_emb_dict | |
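# Usage sketch (illustrative; 'habitat' is a hypothetical section name, and the
# `embs*` arrays are assumed to come from a precomputed text-embedding file):
#
#   dataset.select_text_section('habitat')  # restrict to one description section
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=16, collate_fn=dataset.collate_fn)
#   for loc_feats, locs_b, class_ids, ctx_seqs, ctx_locs, mask, embs_b in loader:
#       ...  # one context sequence + one text embedding per class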
class EvalTransformerDummyTextLocationDataset(torch.utils.data.Dataset): | |
def __init__(self, locs, labels, classes, class_to_taxa, embs, embs_ids, embs_keys, input_enc, device, dates=None, | |
input_dim=4, time_dim=0, noise_time=False, num_context=50, transformer_input_enc=None, | |
token_dim=None, jitter=False, variable_context_length=False): | |
# Handle input encoding | |
self.input_enc = input_enc | |
if self.input_enc in ['env', 'sin_cos_env']: | |
raster = load_env() | |
else: | |
raster = None | |
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim) | |
# Handle transformer input encoding | |
self.transformer_input_enc = transformer_input_enc | |
if self.transformer_input_enc in ['env', 'sin_cos_env']: | |
transformer_raster = load_env() | |
else: | |
transformer_raster = None | |
if self.transformer_input_enc == 'sinr': | |
self.transformer_enc = self.enc | |
else: | |
self.transformer_enc = utils.CoordEncoder( | |
transformer_input_enc, transformer_raster, input_dim=token_dim) | |
# Define properties | |
self.locs = locs # Keep on CPU | |
self.labels = labels # Keep on CPU | |
self.classes = classes | |
self.class_to_taxa = class_to_taxa | |
# Normalize locs and create loc_feats | |
self.loc_feats = self.enc.encode(self.locs, normalize=True) | |
transformer_loc_feats = self.transformer_enc.encode( | |
self.locs, normalize=False) | |
if dates is not None: | |
self.dates = dates | |
self.enc_time = utils.TimeEncoder() | |
# Useful numbers | |
self.num_classes = len(np.unique(labels)) | |
self.input_dim = input_dim | |
self.time_dim = time_dim | |
self.noise_time = noise_time | |
self.num_context = num_context | |
self.token_dim = token_dim | |
        # NOTE: rasters are intentionally left on the CPU here (no .to(device));
        # they can be moved to the training device by the caller if needed.
# Text embeddings | |
self.embs = embs # Keep on CPU | |
self.embs_ids = embs_ids.tolist() | |
self.embs_class_ids = [class_to_taxa.index( | |
taxa) if taxa in class_to_taxa else -1 for taxa in self.embs_ids] | |
self.embs_keys = embs_keys | |
# Initialize class embedding dictionary | |
class_emb_dict = {} | |
for i, (index, description) in enumerate(embs_keys): | |
class_id = self.embs_class_ids[index] | |
if class_id == -1: | |
continue | |
if class_id not in class_emb_dict: | |
class_emb_dict[class_id] = ([], []) | |
class_emb_dict[class_id][0].append(i) | |
class_emb_dict[class_id][1].append(description) | |
self.class_emb_dict = class_emb_dict | |
# Organize data into dictionaries | |
per_class_location_dict = organize_data_by_labels( | |
np.array(labels), np.array(locs)) | |
per_class_loc_feats_dict = organize_data_by_labels( | |
np.array(labels), np.array(transformer_loc_feats)) | |
for key, value in per_class_location_dict.items(): | |
per_class_location_dict[key] = torch.tensor( | |
np.array(value)) # Keep on CPU | |
for key, value in per_class_loc_feats_dict.items(): | |
per_class_loc_feats_dict[key] = torch.tensor( | |
np.array(value)) # Keep on CPU | |
self.per_class_locs = per_class_location_dict | |
self.per_class_loc_feats = per_class_loc_feats_dict | |
# Select a single example per class | |
unique_labels, unique_indices = np.unique(labels, return_index=True) | |
self.locs = locs[unique_indices] | |
self.labels = labels[unique_indices] | |
self.loc_feats = self.loc_feats[unique_indices] | |
def __len__(self): | |
return self.loc_feats.shape[0] | |
def __getitem__(self, index): | |
# Retrieve feature and class of the original point | |
loc_feat = self.loc_feats[index, :] # On CPU | |
loc = self.locs[index, :] # On CPU | |
class_id = self.labels[index] # On CPU | |
class_id_int = class_id.item() | |
# Fetch all locations for the class | |
all_class_locs = self.per_class_locs[class_id_int] # On CPU | |
all_class_loc_feats = self.per_class_loc_feats[class_id_int] # On CPU | |
        # Create a zero feature vector to act as the class token
        class_token_feature = torch.zeros(
            (1, len(all_class_loc_feats[0])))  # CPU tensor
# Find the index of the original location | |
matches = (all_class_locs == loc).all(dim=1) | |
local_index = torch.where(matches)[0] | |
if len(local_index) > 1: | |
local_index = local_index[0] | |
# Exclude the original location's index | |
filtered_local_indices = torch.arange( | |
len(all_class_locs)) != local_index | |
# Select indices for context | |
if filtered_local_indices.sum() > self.num_context: | |
selected_indices = filtered_local_indices.nonzero().squeeze() | |
# No shuffling for evaluation | |
selected_indices = selected_indices[:self.num_context] | |
else: | |
selected_indices = filtered_local_indices.nonzero().squeeze() | |
# Get context locations and features | |
context_loc_feats = all_class_loc_feats[selected_indices] | |
context_locs = all_class_locs[selected_indices] | |
# Adjust dimensions if necessary | |
if context_loc_feats.dim() == 1: | |
context_loc_feats = context_loc_feats.unsqueeze(0) | |
if context_locs.dim() == 1: | |
context_locs = context_locs.unsqueeze(0) | |
context_sequence = torch.cat( | |
[class_token_feature, context_loc_feats], dim=0) | |
        # Text embeddings: this dummy variant always returns a zero embedding
        emb = torch.zeros(4096)  # On CPU
return loc_feat, loc, class_id, context_sequence, context_locs, emb | |
def collate_fn(self, batch): | |
# Unpack the batch | |
loc_feats, locs, class_ids, context_sequences, context_locss, embs = zip(*batch) | |
# Pad sequences | |
padded_sequences = pad_sequence( | |
context_sequences, batch_first=True, padding_value=-10) | |
padded_context_locs = pad_sequence( | |
context_locss, batch_first=True, padding_value=-10) | |
# Convert lists to tensors | |
loc_feats = torch.stack(loc_feats) | |
locs = torch.stack(locs) | |
class_ids = torch.tensor(class_ids) | |
embs = torch.stack(embs) | |
# Create a mask for sequences based on padding | |
sequence_mask = (padded_sequences == -10).all(dim=-1) | |
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, embs | |
def get_item_from_class(self, class_id): | |
# Fetch locations and features for the class | |
all_class_locs = self.per_class_locs[class_id] | |
all_class_loc_feats = self.per_class_loc_feats[class_id] | |
# Randomly select an index | |
index = np.random.choice(len(all_class_locs)) | |
# Retrieve selected location and features | |
loc = all_class_locs[index] | |
if loc.ndim == 1: | |
loc = loc.unsqueeze(0) | |
loc_feat = self.enc.encode(loc, normalize=False) | |
        # Create a zero feature vector to act as the class token
        class_token_feature = torch.zeros((1, self.token_dim))  # CPU tensor
# Exclude selected index from context | |
filtered_local_indices = torch.arange(len(all_class_locs)) != index | |
# Select indices for context | |
if filtered_local_indices.sum() > self.num_context: | |
selected_indices = filtered_local_indices.nonzero().squeeze() | |
perm = torch.randperm(selected_indices.size(0)) | |
selected_indices = selected_indices[perm][:self.num_context] | |
else: | |
selected_indices = filtered_local_indices.nonzero().squeeze() | |
# Get context locations and features | |
context_loc_feats = all_class_loc_feats[selected_indices] | |
context_locs = all_class_locs[selected_indices] | |
# Adjust dimensions if necessary | |
if context_loc_feats.dim() == 1: | |
context_loc_feats = context_loc_feats.unsqueeze(0) | |
if context_locs.dim() == 1: | |
context_locs = context_locs.unsqueeze(0) | |
context_sequence = torch.cat( | |
[class_token_feature, context_loc_feats], dim=0) | |
# Text embeddings | |
if class_id in self.class_emb_dict: | |
embs_indexes, descriptions = self.class_emb_dict[class_id] | |
selected_index = random.choice(embs_indexes) | |
emb = self.embs[selected_index] | |
        else:
            emb = torch.zeros(4096)  # zero placeholder: class has no text embedding
return loc_feat, loc, class_id, context_sequence, context_locs, emb | |
def select_text_section(self, text_section): | |
# Initialize an empty dictionary to store the result | |
class_emb_dict = {} | |
# Populate the dictionary | |
for i, (index, description) in enumerate(self.embs_keys): | |
# Find the class using the index from the class_list | |
            class_id = self.embs_class_ids[index]
            # Skip classes that are not present in the dataset (class_id == -1)
            if class_id == -1:
                continue
            # Keep only embeddings belonging to the requested text section
            if description != text_section:
                continue
# Check if the class_id is already a key in the dictionary | |
if class_id not in class_emb_dict: | |
# Initialize with empty lists if class_id is not already in the dictionary | |
class_emb_dict[class_id] = ([], []) | |
            # Record the embedding index and its description for this class
            class_emb_dict[class_id][0].append(i)
            class_emb_dict[class_id][1].append(description)
self.class_emb_dict = class_emb_dict | |
def load_env(): | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
raster = load_context_feats(os.path.join(paths['env'],'bioclim_elevation_scaled.npy')) | |
return raster | |
def load_context_feats(data_path): | |
context_feats = np.load(data_path).astype(np.float32) | |
context_feats = torch.from_numpy(context_feats) | |
return context_feats | |
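# The loaders in this file resolve data locations through `paths.json`. A minimal
# sketch of its layout, restricted to the keys actually read here (the paths
# themselves are placeholders):
#
#   {
#     "env":   "data/env",
#     "train": "data/train",
#     "iucn":  "data/eval/iucn",
#     "snt":   "data/eval/snt",
#     "masks": "data/masks"
#   }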
_file_cache = {}  # module-level cache: avoids re-reading data files across calls
def load_inat_data(ip_file, taxa_of_interest=None): | |
if os.path.exists('.datacache.pt'): | |
print('\nLoading cached data') | |
if '.datacache.pt' not in _file_cache: | |
# If not in the cache, read the file and store its content in the cache | |
            _file_cache['.datacache.pt'] = torch.load('.datacache.pt')
locs, taxa, users, dates, years, obs_ids = _file_cache['.datacache.pt'] | |
else: | |
print('\nLoading ' + ip_file) | |
data = pd.read_csv(ip_file) | |
# remove outliers | |
num_obs = data.shape[0] | |
data = data[((data['latitude'] <= 90) & (data['latitude'] >= -90) & (data['longitude'] <= 180) & (data['longitude'] >= -180) )] | |
if (num_obs - data.shape[0]) > 0: | |
print(num_obs - data.shape[0], 'items filtered due to invalid locations') | |
if 'accuracy' in data.columns: | |
data.drop(['accuracy'], axis=1, inplace=True) | |
if 'positional_accuracy' in data.columns: | |
data.drop(['positional_accuracy'], axis=1, inplace=True) | |
if 'geoprivacy' in data.columns: | |
data.drop(['geoprivacy'], axis=1, inplace=True) | |
if 'observed_on' in data.columns: | |
data.rename(columns = {'observed_on':'date'}, inplace=True) | |
num_obs_orig = data.shape[0] | |
data = data.dropna() | |
size_diff = num_obs_orig - data.shape[0] | |
        if size_diff > 0:
            print(size_diff, 'observation(s) with a NaN entry out of', num_obs_orig, 'removed')
        # keep only taxa of interest:
        if taxa_of_interest is not None:
            num_obs_orig = data.shape[0]
            data = data[data['taxon_id'].isin(taxa_of_interest)]
            print(num_obs_orig - data.shape[0], 'observation(s) out of', num_obs_orig, 'from different taxa removed')
print('Number of unique classes {}'.format(np.unique(data['taxon_id'].values).shape[0])) | |
locs = np.vstack((data['longitude'].values, data['latitude'].values)).T.astype(np.float32) | |
taxa = data['taxon_id'].values.astype(np.int64) | |
if 'user_id' in data.columns: | |
users = data['user_id'].values.astype(np.int64) | |
_, users = np.unique(users, return_inverse=True) | |
elif 'observer_id' in data.columns: | |
users = data['observer_id'].values.astype(np.int64) | |
_, users = np.unique(users, return_inverse=True) | |
else: | |
users = np.ones(taxa.shape[0], dtype=np.int64)*-1 | |
# Note - assumes that dates are in format YYYY-MM-DD | |
temp = np.array(data['date'], dtype='S10') | |
temp = temp.view('S1').reshape((temp.size, -1)) | |
years = temp[:,:4].view('S4').astype(int)[:,0] | |
months = temp[:,5:7].view('S2').astype(int)[:,0] | |
days = temp[:,8:10].view('S2').astype(int)[:,0] | |
days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)]) | |
dates = days_per_month[months-1] + days-1 | |
dates = np.round((dates) / 364.0, 4).astype(np.float32) | |
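        # e.g. '2018-03-01' -> months=3, days=1 -> (59 + 0)/364 ~= 0.1621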
        if 'id' in data.columns:
            obs_ids = data['id'].values
        elif 'observation_uuid' in data.columns:
            obs_ids = data['observation_uuid'].values
        else:
            # ensure obs_ids is always defined (mirrors the users fallback above)
            obs_ids = np.ones(taxa.shape[0], dtype=np.int64)*-1
torch.save((locs, taxa, users, dates, years, obs_ids), '.datacache.pt') | |
return locs, taxa, users, dates, years, obs_ids | |
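# Usage sketch (illustrative; the CSV path is a placeholder - in practice it is
# resolved via `paths.json`):
#
#   locs, taxa, users, dates, years, obs_ids = load_inat_data(
#       'data/train/observations.csv', taxa_of_interest=None)
#
# Repeated calls are served from `.datacache.pt`, so the CSV is parsed only once.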
def load_eval_inat_data(ip_file): | |
if os.path.exists('.eval_datacache.pt'): | |
print('\nLoading cached eval data') | |
if '.eval_datacache.pt' not in _file_cache: | |
# If not in the cache, read the file and store its content in the cache | |
            _file_cache['.eval_datacache.pt'] = torch.load('.eval_datacache.pt')
locs, taxa, users, dates, years, obs_ids = _file_cache['.eval_datacache.pt'] | |
else: | |
print('\nLoading ' + ip_file) | |
        data = np.load(ip_file, allow_pickle=True)
locs = data['locs'] | |
labels = data['labels'] | |
class_to_taxa = data['class_to_taxa'] | |
classes = data['classes'].item() | |
        # outliers already removed upstream; map class labels back to taxon ids
        taxa = class_to_taxa[labels]
print('Warning: Setting default value (-1) for users, dates, years, obs_ids') | |
users = np.ones(taxa.shape[0], dtype=np.int64)*-1 | |
dates = np.ones(taxa.shape[0], dtype=np.int64)*-1 | |
years = np.ones(taxa.shape[0], dtype=np.int64)*-1 | |
obs_ids = np.ones(taxa.shape[0], dtype=np.int64)*-1 | |
torch.save((locs, taxa, users, dates, years, obs_ids), '.eval_datacache.pt') | |
return locs, taxa, users, dates, years, obs_ids | |
def choose_aux_species(current_species, num_aux_species, aux_species_seed, taxa_file): | |
if num_aux_species == 0: | |
return [] | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
data_dir = paths['train'] | |
taxa_file = os.path.join(data_dir, taxa_file) | |
with open(taxa_file, 'r') as f: | |
inat_large_metadata = json.load(f) | |
aux_species_candidates = [x['taxon_id'] for x in inat_large_metadata] | |
aux_species_candidates = np.setdiff1d(aux_species_candidates, current_species) | |
print(f'choosing {num_aux_species} species to add from {len(aux_species_candidates)} candidates') | |
rng = np.random.default_rng(aux_species_seed) | |
idx_rand_aux_species = rng.permutation(len(aux_species_candidates)) | |
aux_species = list(aux_species_candidates[idx_rand_aux_species[:num_aux_species]]) | |
return aux_species | |
def get_taxa_of_interest(species_set='all', num_aux_species=0, aux_species_seed=123, taxa_file=None, taxa_file_snt=None): | |
if species_set == 'all': | |
return None | |
if species_set == 'snt_birds': | |
assert taxa_file_snt is not None | |
        with open(taxa_file_snt, 'r') as f:
taxa_subsets = json.load(f) | |
taxa_of_interest = list(taxa_subsets['snt_birds']) | |
else: | |
raise NotImplementedError | |
# optionally add some other species back in: | |
aux_species = choose_aux_species(taxa_of_interest, num_aux_species, aux_species_seed, taxa_file) | |
taxa_of_interest.extend(aux_species) | |
return taxa_of_interest | |
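# Usage sketch (illustrative; file names are placeholders):
#
#   taxa = get_taxa_of_interest(species_set='snt_birds', num_aux_species=0,
#                               taxa_file_snt='taxa_subsets.json')
#   locs, taxa_ids, users, dates, years, obs_ids = load_inat_data(csv_path, taxa)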
def get_idx_subsample_observations(labels, hard_cap=-1, hard_cap_seed=123, subset=None, subset_cap=-1): | |
if hard_cap == -1: | |
if subset_cap != -1: | |
raise NotImplementedError('subset_cap set but not hard_cap') | |
return np.arange(len(labels)) | |
print(f'subsampling (up to) {hard_cap} per class for the training set') | |
ids, counts = np.unique(labels, return_counts=True) | |
count_ind = np.cumsum(counts) | |
count_ind[1:] = count_ind[:-1] | |
count_ind[0] = 0 | |
ss_rng = np.random.default_rng(hard_cap_seed) | |
idx_rand = ss_rng.permutation(len(labels)) | |
ordered_inds = np.argsort(labels[idx_rand], kind='stable') | |
caps = hard_cap + np.zeros_like(counts) | |
if subset is not None and subset_cap != -1: | |
caps[subset] = subset_cap | |
idx_ss = idx_rand[np.concatenate([ordered_inds[i:i+min(limit, cap)] for i, limit, cap in zip(count_ind, counts, caps)])] | |
print(f'final training set size: {len(idx_ss)}') | |
return idx_ss | |
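# Worked example: with labels = [0, 0, 0, 1, 1] and hard_cap = 2, two randomly
# chosen observations of class 0 and both observations of class 1 are kept, so
# len(idx_ss) == 4. With hard_cap == -1 every index is returned unchanged.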
def get_idx_subsample_observations_eval(labels, hard_cap=-1): | |
if hard_cap == -1: | |
return np.arange(len(labels)) | |
    print(f'Selecting (up to) {hard_cap} per class for the eval set')
labels = np.asarray(labels) | |
unique_labels = np.unique(labels) | |
idx_list = [] | |
for label in unique_labels: | |
idx = np.where(labels == label)[0] | |
cap = min(len(idx), hard_cap) | |
idx_list.append(idx[:cap]) | |
idx_ss = np.concatenate(idx_list) | |
print(f'Final training set size: {len(idx_ss)}') | |
return idx_ss | |
def uniform_sample_h3(cells, low, high): | |
'''uniformly sample points in a batch of h3 cells''' | |
out = np.empty((len(cells), 2)) | |
invalid_mask = np.arange(len(cells)) | |
cell_ids_buffer = np.empty(len(cells), dtype='uint64') | |
    while len(invalid_mask) > 0:
        # rejection sampling: draw points uniformly in each cell's bounding box
        # and keep only those that actually land inside the target cell
        pts = np.random.random((len(invalid_mask), 2))
        pts = high + pts*(low - high)
        cell_ids_buffer = h3.latlng_to_cell(pts[:,0], pts[:,1], 5)
valid_mask = (cell_ids_buffer[:len(cells)] == cells) | |
out[invalid_mask[valid_mask]] = pts[valid_mask] | |
neg_mask = ~valid_mask | |
invalid_mask = invalid_mask[neg_mask] | |
low = low[neg_mask] | |
high = high[neg_mask] | |
cells = cells[neg_mask] | |
return out | |
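# Usage sketch (illustrative; mirrors how UniformH3Dataset.collate_fn calls this,
# and assumes the vectorized h3 calls used elsewhere in this file):
#
#   cells = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5)
#   low   = np.stack([np.array(h3.h3_to_geo_boundary(c)).min(axis=0) for c in cells])
#   high  = np.stack([np.array(h3.h3_to_geo_boundary(c)).max(axis=0) for c in cells])
#   pts   = uniform_sample_h3(cells, low, high)  # (N, 2) points, one per cell
#
# The two bound arguments are interchangeable: samples are interpolated between
# them, so either order yields uniform draws from each cell's bounding box.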
class LocationIUCNDataset(torch.utils.data.Dataset): | |
    # NOTE: an earlier version of this signature also accepted a dummy `num_context`
    # argument; it is no longer needed.
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False): | |
# handle input encoding: | |
self.input_enc = input_enc | |
if self.input_enc in ['env', 'sin_cos_env']: | |
raster = load_env() | |
else: | |
raster = None | |
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim) | |
if os.path.exists('iucndataset_nocap.pt'): | |
mask = torch.load('iucndataset_nocap.pt') | |
else: | |
from tqdm import tqdm | |
# load iucn data | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
data = json.load(f) | |
obs_locs = np.array(data['locs'], dtype=np.float32) | |
taxa = {int(tt):tk for tt, tk in data['taxa_presence'].items()} | |
cells = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5) | |
mask = np.zeros(len(locs), dtype=bool) | |
for i, data_taxa in tqdm(enumerate(class_to_taxa)): | |
if data_taxa in taxa: | |
data_cells = h3.latlng_to_cell(locs[labels==i, 1], locs[labels==i, 0], 5) | |
data = np.array(cells[taxa[data_taxa]]) | |
data_inds = data.argsort() | |
search = np.searchsorted(data[data_inds], data_cells) | |
search = search.clip(0, len(data)-1) | |
mask[labels==i] = data[data_inds][search] == data_cells | |
else: | |
mask[labels==i] = False | |
torch.save(mask, 'iucndataset_nocap.pt') | |
print('Reduced Size: ', mask.sum()) | |
# remove locations that are not in the iucn dataset | |
locs = locs[mask] | |
labels = labels[mask] | |
if dates is not None: | |
dates = dates[mask] | |
labels_uniq, labels = np.unique(labels, return_inverse=True) | |
classes = {class_to_taxa[i]: classes[class_to_taxa[i]] for i in labels_uniq} | |
class_to_taxa = [class_to_taxa[i] for i in labels_uniq] | |
# define some properties: | |
self.locs = locs | |
self.loc_feats = self.enc.encode(self.locs) | |
self.labels = torch.from_numpy(labels) | |
self.classes = classes | |
self.class_to_taxa = class_to_taxa | |
if dates is not None: | |
self.dates = dates | |
self.enc_time = utils.TimeEncoder() | |
# useful numbers: | |
self.num_classes = len(classes) | |
self.input_dim = input_dim | |
self.time_dim = time_dim | |
self.noise_time = noise_time | |
        # NOTE: time-aware indexing is dispatched inside __getitem__ below;
        # assigning self.__getitem__ here would have no effect, since Python
        # looks up special methods on the class, not the instance.
if self.enc.raster is not None: | |
self.enc.raster = self.enc.raster.to(device) | |
def viz_map(self, taxa_id, high_res=False): | |
from matplotlib import pyplot as plt | |
# load params | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
data = json.load(f) | |
obs_locs = np.array(data['locs'], dtype=np.float32) | |
taxa = {int(tt): tk for tt, tk in data['taxa_presence'].items()} | |
taxa_cells = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5) | |
# load taxa of interest | |
if taxa_id in self.class_to_taxa: | |
class_of_interest = self.class_to_taxa.index(taxa_id) | |
else: | |
print(f'Error: Taxa specified that is not in the model: {taxa_id}') | |
return False | |
print(f'Loading taxa: {taxa_id}') | |
# load ocean mask | |
if high_res: | |
mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy')) | |
else: | |
mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy')) | |
mask_shape = mask.shape | |
mask_inds = np.where(mask.reshape(-1) == 1)[0] | |
# generate input features | |
locs = utils.coord_grid(mask_shape) | |
locs_cells = h3.latlng_to_cell(locs[:,1], locs[:,0], 5) | |
# mask iucn | |
iucn_cells = np.sort(taxa_cells[taxa[taxa_id]]) | |
mask = iucn_cells[np.searchsorted(iucn_cells, locs_cells).clip(max=len(iucn_cells)-1)] == locs_cells | |
mask_inds = np.where(mask == 1)[0] | |
mask = mask.reshape(mask_shape) | |
cell_inds, cell_counts = np.unique(h3.latlng_to_cell(90*self.locs[self.labels==class_of_interest, 1], 180*self.locs[self.labels==class_of_interest, 0], 5),return_counts=True) | |
search_inds = np.searchsorted(cell_inds, locs_cells).clip(max=len(cell_inds)-1) | |
preds = np.zeros(len(locs)) | |
cell_mask = cell_inds[search_inds] == locs_cells | |
preds[cell_mask] = cell_counts[search_inds][cell_mask] | |
preds = preds/preds.sum() | |
# Convert preds to log scale | |
preds = np.log(preds) | |
center = np.median(preds[np.isfinite(preds)]) | |
preds = 0.5*preds/(preds.max()-center) | |
preds = (preds + 1 - preds.max()).clip(min=0) | |
# mask data | |
op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan # set to NaN | |
op_im[mask_inds] = preds[mask_inds] | |
# reshape and create masked array for visualization | |
op_im = op_im.reshape((mask.shape[0], mask.shape[1])) | |
op_im = np.ma.masked_invalid(op_im) | |
# set color for masked values | |
        cmap = plt.cm.plasma.copy()  # copy so the registered colormap is not mutated
        cmap.set_bad(color='none')
vmax = np.max(op_im) | |
# save image | |
save_loc = os.path.join('./images/', str(taxa_id) + '_map.png') | |
print(f'Saving image to {save_loc}') | |
plt.imsave(fname=save_loc, arr=op_im, vmin=0, vmax=vmax, cmap=cmap) | |
def __len__(self): | |
return self.loc_feats.shape[0] | |
    def __getitem__(self, index):
        if self.time_dim > 0:
            return self._get_item_time(index)
        return index
def collate_fn(self, batch): | |
if isinstance(batch[0], int): | |
loc_feat = self.loc_feats[batch, :] | |
loc = self.locs[batch, :] | |
class_id = self.labels[batch] | |
return loc_feat, loc, class_id | |
else: | |
return torch.utils.data.default_collate(batch) | |
def _get_item_time(self, index): | |
loc_feat = self.loc_feats[index, :] | |
loc = self.locs[index, :] | |
class_id = self.labels[index] | |
date = self.dates[index] | |
# add noise | |
if self.noise_time: | |
noise_level = random.random() | |
# noise = (2*random.random() - 1) * (0.5*(365 ** (noise_level - 1))) | |
noise = (2 * random.random() - 1) * (0.5 * noise_level) | |
loc_feat = torch.cat([loc_feat, self.enc_time.encode_fast([date.item() + noise, noise_level])]) | |
        else:
            # non-noise time encoding is currently disabled; the line below is
            # unreachable until this branch is implemented
            raise NotImplementedError()
            loc_feat = torch.cat([loc_feat, torch.tensor(self.enc_time.encode([2 * date.item() - 1], normalize=False))])
return loc_feat, torch.cat([loc, date[None]]), class_id | |
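# Usage sketch (illustrative; the taxa id is a placeholder):
#
#   dataset = LocationIUCNDataset(locs, labels, classes, class_to_taxa,
#                                 input_enc='sin_cos', device='cpu')
#   dataset.viz_map(taxa_id=1234)  # writes ./images/1234_map.png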
class UniformH3Dataset(torch.utils.data.Dataset): | |
    # NOTE: an earlier version of this signature also accepted a dummy `num_context`
    # argument; it is no longer needed.
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False, snt=False): | |
if dates is not None or time_dim > 0 or noise_time: | |
raise NotImplementedError() | |
# handle input encoding: | |
self.input_enc = input_enc | |
if self.input_enc in ['env', 'sin_cos_env']: | |
raster = load_env() | |
else: | |
raster = None | |
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim) | |
self._enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim) | |
# load h3 data: | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
if snt: | |
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True) | |
D = D.item() | |
loc_indices_per_species = D['loc_indices_per_species'] | |
taxa = D['taxa'] | |
loc_indices_per_species = {t:ls for t,ls in zip(taxa, loc_indices_per_species)} | |
obs_locs = D['obs_locs'] | |
else: | |
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
data = json.load(f) | |
obs_locs = np.array(data['locs'], dtype=np.float32) | |
loc_indices_per_species = data['taxa_presence'] | |
self.taxa = {int(tt):tk for tt, tk in loc_indices_per_species.items()} | |
self.cells = h3.latlng_to_cell(obs_locs[:,1], obs_locs[:,0], 5) | |
        self.ind_to_cell = self.cells  # identical mapping; avoid recomputing
        self.low_b = np.stack([np.array(h3.cell_to_boundary(c)).min(axis=0) for c in self.cells])
        self.high_b = np.stack([np.array(h3.cell_to_boundary(c)).max(axis=0) for c in self.cells])
if os.path.exists('iucndataset_nocap.pt') and not snt: | |
mask = torch.load('iucndataset_nocap.pt') | |
elif os.path.exists('sntdataset_nocap.pt') and snt: | |
mask = torch.load('sntdataset_nocap.pt') | |
else: | |
from tqdm import tqdm | |
mask = np.zeros(len(locs), dtype=bool) | |
for i, data_taxa in tqdm(enumerate(class_to_taxa)): | |
if data_taxa in self.taxa: | |
data_cells = h3.latlng_to_cell(locs[labels==i, 1], locs[labels==i, 0], 5) | |
data = np.array(self.cells[self.taxa[data_taxa]]) | |
data_inds = data.argsort() | |
search = np.searchsorted(data[data_inds], data_cells) | |
search = search.clip(0, len(data)-1) | |
mask[labels==i] = data[data_inds][search] == data_cells | |
else: | |
mask[labels==i] = False | |
torch.save(mask, 'sntdataset_nocap.pt' if snt else 'iucndataset_nocap.pt') | |
print('Reduced Size: ', mask.sum()) | |
# remove locations that are not in the iucn dataset | |
locs = locs[mask] | |
labels = labels[mask] | |
labels_uniq, labels = np.unique(labels, return_inverse=True) | |
classes = {class_to_taxa[i]: classes[class_to_taxa[i]] for i in labels_uniq} | |
class_to_taxa = [class_to_taxa[i] for i in labels_uniq] | |
# calculate species statistics | |
_, counts = np.unique(labels, return_counts=True) | |
self.num_obs = counts.sum() | |
self.counts = counts / self.num_obs | |
# define some properties: | |
self.classes = classes | |
self.class_to_taxa = class_to_taxa | |
# useful numbers: | |
self.num_classes = len(self.class_to_taxa) | |
self.input_dim = input_dim | |
if self.enc.raster is not None: | |
self.enc.raster = self.enc.raster.to(device) | |
def __len__(self): | |
return self.num_obs | |
def __getitem__(self, index, species=None): | |
if species is None: | |
class_id = np.random.choice(np.arange(self.num_classes), p=self.counts) | |
species = self.class_to_taxa[class_id] | |
else: | |
class_id = -1 | |
ind = random.choice(self.taxa[species]) | |
cell = self.cells[ind] | |
high, low = self.high_b[ind], self.low_b[ind] | |
return cell, high, low, class_id | |
def collate_fn(self, batch): | |
cell, high, low, class_id = zip(*batch) | |
cell = np.array(cell) | |
high = np.stack(high) | |
low = np.stack(low) | |
class_id = torch.tensor(class_id, dtype=torch.long) | |
pts = torch.from_numpy(uniform_sample_h3(cell, high, low)).flip(1) | |
return self._enc.encode(pts, normalize=True).float(), pts, class_id | |
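# Hedged sketch of the externally defined uniform_sample_h3 helper used in
# collate_fn above (the real implementation may differ): rejection-sample a
# (lat, lng) point inside each cell's bounding box until the point actually
# falls in that cell. Within a single resolution-5 cell the bounding box is
# small enough that sampling in lat/lng is approximately uniform.
def _uniform_sample_h3_sketch(cells, high, low, res=5):
    out = np.empty((len(cells), 2), dtype=np.float32)
    for i, (cell, hi, lo) in enumerate(zip(cells, high, low)):
        while True:
            lat, lng = lo + (hi - lo) * np.random.rand(2)  # uniform in the box
            if h3.latlng_to_cell(lat, lng, res) == cell:  # keep only points inside the cell
                out[i] = (lat, lng)
                break
    return out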
class MultiUniformH3Dataset(torch.utils.data.Dataset): | |
# MINE MINE MINE MINE - I have included this dummy "num_context" which is probably not needed now | |
# def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False, num_context=0): | |
# MAX MAX MAX MAX MAX MAX MAX | |
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False): | |
if dates is not None or time_dim > 0 or noise_time: | |
raise NotImplementedError() | |
# handle input encoding: | |
self.input_enc = input_enc | |
if self.input_enc in ['env', 'sin_cos_env']: | |
raster = load_env() | |
else: | |
raster = None | |
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim) | |
self._enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim) | |
# load h3 data: | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
data = json.load(f) | |
obs_locs = np.array(data['locs'], dtype=np.float32) | |
self.taxa = {int(tt): tk for tt, tk in data['taxa_presence'].items()} | |
self.cells = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5) | |
        self.ind_to_cell = self.cells  # identical mapping; avoid recomputing
        self.low_b = np.stack([np.array(h3.cell_to_boundary(c)).min(axis=0) for c in self.cells])
        self.high_b = np.stack([np.array(h3.cell_to_boundary(c)).max(axis=0) for c in self.cells])
if os.path.exists('iucndataset_nocap.pt'): | |
mask = torch.load('iucndataset_nocap.pt') | |
else: | |
from tqdm import tqdm | |
mask = np.zeros(len(locs), dtype=bool) | |
for i, data_taxa in tqdm(enumerate(class_to_taxa)): | |
if data_taxa in self.taxa: | |
data_cells = h3.latlng_to_cell(locs[labels == i, 1], locs[labels == i, 0], 5) | |
data = np.array(self.cells[self.taxa[data_taxa]]) | |
data_inds = data.argsort() | |
search = np.searchsorted(data[data_inds], data_cells) | |
search = search.clip(0, len(data) - 1) | |
mask[labels == i] = data[data_inds][search] == data_cells | |
else: | |
mask[labels == i] = False | |
torch.save(mask, 'iucndataset_nocap.pt') | |
print('Reduced Size: ', mask.sum()) | |
# remove locations that are not in the iucn dataset | |
locs = locs[mask] | |
labels = labels[mask] | |
labels_uniq, labels = np.unique(labels, return_inverse=True) | |
classes = {class_to_taxa[i]: classes[class_to_taxa[i]] for i in labels_uniq} | |
class_to_taxa = [class_to_taxa[i] for i in labels_uniq] | |
# calculate species statistics | |
_, counts = np.unique(labels, return_counts=True) | |
self.num_obs = counts.sum() | |
self.counts = counts / self.num_obs | |
# define some properties: | |
self.classes = classes | |
self.class_to_taxa = class_to_taxa | |
if os.path.exists('taxa_inverse.pt'): | |
self.taxa_inverse = torch.load('taxa_inverse.pt') | |
else: | |
from collections import defaultdict | |
ctt_i = {c: i for i, c in enumerate(class_to_taxa)} | |
self.taxa_inverse = defaultdict(list) | |
for k, vs in self.taxa.items(): | |
for v in vs: | |
self.taxa_inverse[v].append(ctt_i[k]) | |
torch.save(self.taxa_inverse, 'taxa_inverse.pt') | |
# useful numbers: | |
self.num_classes = len(self.class_to_taxa) | |
self.input_dim = input_dim | |
if self.enc.raster is not None: | |
self.enc.raster = self.enc.raster.to(device) | |
def __len__(self): | |
return self.num_obs | |
def __getitem__(self, index): | |
cell_ind = np.random.randint(len(self.cells)) | |
cell = self.cells[cell_ind] | |
label = np.zeros(len(self.classes), dtype=np.float32) | |
label[self.taxa_inverse[cell_ind]] = 1 | |
high, low = self.high_b[cell_ind], self.low_b[cell_ind] | |
return cell, high, low, label | |
def collate_fn(self, batch): | |
cell, high, low, label = zip(*batch) | |
cell = np.array(cell) | |
high = np.stack(high) | |
low = np.stack(low) | |
labels = torch.from_numpy(np.stack(label)) | |
pts = torch.from_numpy(uniform_sample_h3(cell, high, low)).flip(1) | |
return self._enc.encode(pts, normalize=True).float(), pts.float(), labels | |
class MultiLocationIUCNDataset(torch.utils.data.Dataset): | |
# MINE MINE MINE MINE - I have included this dummy "num_context" which is probably not needed now | |
# def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False, num_context=0): | |
# MAX MAX MAX MAX MAX MAX | |
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False): | |
# handle input encoding: | |
self.input_enc = input_enc | |
if self.input_enc in ['env', 'sin_cos_env']: | |
raster = load_env() | |
else: | |
raster = None | |
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim) | |
# load h3 data: | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
data = json.load(f) | |
obs_locs = np.array(data['locs'], dtype=np.float32) | |
self.taxa = {int(tt): tk for tt, tk in data['taxa_presence'].items()} | |
self.cells = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5) | |
self.cells_inverse = {c: i for i, c in enumerate(self.cells)} | |
if os.path.exists('iucndataset_nocap.pt'): | |
mask = torch.load('iucndataset_nocap.pt') | |
else: | |
from tqdm import tqdm | |
mask = np.zeros(len(locs), dtype=bool) | |
for i, data_taxa in tqdm(enumerate(class_to_taxa)): | |
if data_taxa in self.taxa: | |
data_cells = h3.latlng_to_cell(locs[labels==i, 1], locs[labels==i, 0], 5) | |
data = np.array(self.cells[self.taxa[data_taxa]]) | |
data_inds = data.argsort() | |
search = np.searchsorted(data[data_inds], data_cells) | |
search = search.clip(0, len(data)-1) | |
mask[labels==i] = data[data_inds][search] == data_cells | |
else: | |
mask[labels==i] = False | |
torch.save(mask, 'iucndataset_nocap.pt') | |
print('Reduced Size: ', mask.sum()) | |
# remove locations that are not in the iucn dataset | |
locs = locs[mask] | |
labels = labels[mask] | |
if dates is not None: | |
dates = dates[mask] | |
labels_uniq, labels = np.unique(labels, return_inverse=True) | |
classes = {class_to_taxa[i]: classes[class_to_taxa[i]] for i in labels_uniq} | |
class_to_taxa = [class_to_taxa[i] for i in labels_uniq] | |
# define some properties: | |
self.locs = locs | |
self.loc_cells = h3.latlng_to_cell(locs[:, 1], locs[:, 0], 5) | |
self.loc_feats = self.enc.encode(self.locs) | |
self.labels = labels | |
self.classes = classes | |
self.class_to_taxa = class_to_taxa | |
if dates is not None: | |
self.dates = dates | |
self.enc_time = utils.TimeEncoder() | |
if os.path.exists('taxa_inverse.pt'): | |
self.taxa_inverse = torch.load('taxa_inverse.pt') | |
else: | |
from collections import defaultdict | |
ctt_i = {c: i for i, c in enumerate(class_to_taxa)} | |
self.taxa_inverse = defaultdict(list) | |
for k, vs in self.taxa.items(): | |
for v in vs: | |
self.taxa_inverse[v].append(ctt_i[k]) | |
torch.save(self.taxa_inverse, 'taxa_inverse.pt') | |
# useful numbers: | |
self.num_classes = len(classes) | |
self.input_dim = input_dim | |
self.time_dim = time_dim | |
self.noise_time = noise_time | |
if self.enc.raster is not None: | |
self.enc.raster = self.enc.raster.to(device) | |
def __len__(self): | |
return self.loc_feats.shape[0] | |
def __getitem__(self, index): | |
loc_feat = self.loc_feats[index, :] | |
loc = self.locs[index, :] | |
cell = self.loc_cells[index] | |
label = np.zeros(len(self.classes), dtype=np.float32) | |
label[self.taxa_inverse[self.cells_inverse[cell]]] = 1 | |
if self.time_dim > 0: | |
date = self.dates[index] | |
# add noise | |
if self.noise_time: | |
noise_level = random.random() | |
#noise = (2*random.random() - 1) * (0.5*(365 ** (noise_level - 1))) | |
noise = (2*random.random() - 1) * (0.5*noise_level) | |
loc_feat = torch.cat([loc_feat, self.enc_time.encode_fast([date.item()+noise,noise_level])]) | |
else: | |
raise NotImplementedError() | |
loc_feat = torch.cat([loc_feat, torch.tensor(self.enc_time.encode([2*date.item()-1], normalize=False))]) | |
return loc_feat, torch.cat([loc, date[None]]), label | |
else: | |
return loc_feat, loc, label | |
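# Toy illustration (made-up values) of the multi-hot label construction in
# MultiLocationIUCNDataset.__getitem__: cells_inverse maps an H3 cell to its
# index, and taxa_inverse maps that index to every class present in the cell.
def _multi_hot_sketch(cell_index, taxa_inverse, num_classes):
    label = np.zeros(num_classes, dtype=np.float32)
    label[taxa_inverse[cell_index]] = 1  # set 1 at every present class
    return label
# e.g. _multi_hot_sketch(0, {0: [2, 5]}, 8) is 1.0 at positions 2 and 5.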
# My function for creating the per-class location dicts
def organize_data_by_labels(labels, locs):
    label_dict = {}  # map each label to the list of its locations
    for label, loc in zip(labels, locs):
        if label in label_dict:
            label_dict[label].append(loc)  # append to the existing list
        else:
            label_dict[label] = [loc]  # start a new list for this label
    return label_dict
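# Illustrative trace with made-up values:
# organize_data_by_labels([0, 1, 0], [(10., 20.), (30., 40.), (50., 60.)])
# -> {0: [(10., 20.), (50., 60.)], 1: [(30., 40.)]}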
# MINE MINE - I have added new datasets that I need for my new models | |
dataset_classes = {'inat': LocationDataset, | |
'iucn_inat': LocationIUCNDataset, | |
'iucn_uniform': UniformH3Dataset, | |
'multi_iucn_uniform': MultiUniformH3Dataset, | |
'multi_iucn_inat': MultiLocationIUCNDataset, | |
'transformer': TransformerLocationDataset, | |
'eval_transformer': EvalTransformerLocationDataset, | |
'text_transformer': TransformerLocationTextDataset, | |
'random_text_transformer': TransformerLocationTextDatasetRandomizeOutputs, | |
'eval_text_transformer': EvalTransformerTextLocationDataset, | |
'eval_dummy_text_transformer': EvalTransformerDummyTextLocationDataset, | |
'variable_tokens': TransformerDatasetVariableTokens, | |
} | |
def get_dataset_class(dataset_name): | |
# First, try to get the class directly from the dictionary | |
if dataset_name in dataset_classes: | |
return dataset_classes[dataset_name] | |
else: | |
# Check if the dataset name matches the pattern for probabilities | |
pattern = r'random_text_transformer_(\d+\.\d+)_(\d+\.\d+)' | |
match = re.match(pattern, dataset_name) | |
if match: | |
# Extract the probabilities from the dataset name | |
just_obs_prob = float(match.group(1)) | |
just_text_prob = float(match.group(2)) | |
# Use functools.partial to create a new class with preset parameters | |
return partial( | |
TransformerLocationTextDatasetRandomizeOutputs, | |
just_obs_prob=just_obs_prob, | |
just_text_prob=just_text_prob | |
) | |
else: | |
# If the dataset name is not recognized, raise an error | |
raise ValueError(f"Unknown dataset name: {dataset_name}") | |
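# Example (hypothetical probabilities): a name matching the pattern yields a
# functools.partial with the parsed probabilities baked in, so it can be
# constructed exactly like a plain dataset class:
#   cls = get_dataset_class('random_text_transformer_0.1_0.2')
#   # -> partial(TransformerLocationTextDatasetRandomizeOutputs,
#   #            just_obs_prob=0.1, just_text_prob=0.2)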
def get_train_data_old(params): | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
data_dir = paths['train'] | |
obs_file = os.path.join(data_dir, params['obs_file']) | |
taxa_file = os.path.join(data_dir, params['taxa_file']) | |
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') | |
taxa_of_interest = get_taxa_of_interest(params['species_set'], params['num_aux_species'], params['aux_species_seed'], params['taxa_file'], taxa_file_snt) | |
locs, labels, _, dates, _, _ = load_inat_data(obs_file, taxa_of_interest) | |
if params['zero_shot']: | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
data = json.load(f) | |
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True) | |
D = D.item() | |
taxa_snt = D['taxa'].tolist() | |
taxa = [int(tt) for tt in data['taxa_presence'].keys()] | |
taxa = list(set(taxa + taxa_snt)) | |
mask = labels != taxa[0] | |
# MINE MINE MINE MINE MINE | |
# for i in range(0, len(taxa)): | |
# MAX MAX MAX MAX | |
for i in range(1, len(taxa)): | |
mask &= (labels != taxa[i]) | |
locs = locs[mask] | |
dates = dates[mask] | |
labels = labels[mask] | |
unique_taxa, class_ids = np.unique(labels, return_inverse=True) | |
class_to_taxa = unique_taxa.tolist() | |
# load class names | |
class_info_file = json.load(open(taxa_file, 'r')) | |
class_names_file = [cc['latin_name'] for cc in class_info_file] | |
taxa_ids_file = [cc['taxon_id'] for cc in class_info_file] | |
classes = dict(zip(taxa_ids_file, class_names_file)) | |
subset = None | |
if params['subset_cap_name'] is not None: | |
if params['subset_cap_name'] == 'iucn': | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
data = json.load(f) | |
taxa = [int(tt) for tt in data['taxa_presence'].keys()] | |
# get classes to eval | |
subset = np.zeros((len(taxa),), dtype=int) | |
for tt_id, tt in enumerate(taxa): | |
class_of_interest = np.where(np.array(class_to_taxa) == tt)[0] | |
if len(class_of_interest) != 0: | |
subset[tt_id] = class_of_interest | |
else: | |
            raise NotImplementedError(f'Unknown subset name: {params["subset_cap_name"]}')
idx_ss = get_idx_subsample_observations(labels, params['hard_cap_num_per_class'], params['hard_cap_seed'], subset, params['subset_cap_num_per_class']) | |
locs = torch.from_numpy(np.array(locs)[idx_ss]) # convert to Tensor | |
labels = torch.from_numpy(np.array(class_ids)[idx_ss]) | |
dates = 364/365*torch.from_numpy(np.array(dates)[idx_ss]) if params['input_time'] else None | |
# MINE MINE MINE MINE - there are differences here in which dataset to load | |
# if ((params['dataset'] == 'transformer') or (params['dataset'] == 'eval_transformer') | |
# or (params['dataset'] == 'text_transformer') or (params['dataset'] == 'eval_text_transformer')): | |
if 'variable' in params['dataset']: | |
if 'eval' in params['dataset']: | |
thing = np.load('./data/transformer_eval_dataset_1.npz', allow_pickle=True) | |
locs = torch.from_numpy(thing['locs']) | |
labels = torch.from_numpy(thing['labels']) | |
classes = thing['classes'].item() | |
class_to_taxa = list(thing['class_to_taxa']) | |
# Load the text embeddings | |
text_embs_dict = torch.load(params['text_emb_path']) | |
text_embs = text_embs_dict['data'] | |
text_embs_ids = text_embs_dict['taxon_id'] | |
text_embs_keys = text_embs_dict['keys'] | |
# Load the image embeddings | |
image_embs_dict = torch.load(params['image_emb_path']) | |
image_embs = image_embs_dict['data'] | |
image_embs_ids = image_embs_dict['taxon_id'] | |
image_embs_keys = image_embs_dict['keys'] | |
if 'eval' in params['dataset']: | |
ds = dataset_classes[params['dataset']](locs, labels, classes, class_to_taxa, text_embs=text_embs, text_embs_ids=text_embs_ids, | |
text_embs_keys=text_embs_keys, image_embs=image_embs, image_embs_ids=image_embs_ids, | |
image_embs_keys=image_embs_keys, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'], | |
num_context=params['num_context'], | |
transformer_input_enc=params['transformer_input_enc'], | |
token_dim=params['species_dim'], | |
jitter=params['add_location_noise'], | |
variable_context_length=params['variable_context_length'], | |
loc_prob=params['loc_prob'], text_prob=params['text_prob'], | |
image_prob=params['image_prob'], env_prob=0.0, eval_mode=True) | |
else: | |
ds = dataset_classes[params['dataset']](locs, labels, classes, class_to_taxa, text_embs=text_embs, text_embs_ids=text_embs_ids, | |
text_embs_keys=text_embs_keys, image_embs=image_embs, image_embs_ids=image_embs_ids, | |
image_embs_keys=image_embs_keys, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'], | |
num_context=params['num_context'], | |
transformer_input_enc=params['transformer_input_enc'], | |
token_dim=params['species_dim'], | |
jitter=params['add_location_noise'], | |
variable_context_length=params['variable_context_length'], | |
loc_prob=params['loc_prob'], text_prob=params['text_prob'], | |
image_prob=params['image_prob'], env_prob=0.0, eval_mode=False) | |
elif 'transformer' in params['dataset']: | |
# holdover to make earlier models work | |
for key in ['add_location_noise', 'variable_context_length']: | |
if key not in params: | |
params[key] = False | |
# print("THIS PART WILL NEED CHANGING FOR YOUR NEW DATASET") | |
# # ensuring I load the same eval points each time | |
# # TODO make a few different versions and see if anything changes for my LS models | |
if params['dataset'] == 'eval_transformer' or params['dataset'] == 'eval_text_transformer': | |
thing = np.load('./data/transformer_eval_dataset_1.npz', allow_pickle=True) | |
locs = torch.from_numpy(thing['locs']) | |
labels = torch.from_numpy(thing['labels']) | |
classes = thing['classes'].item() | |
class_to_taxa = list(thing['class_to_taxa']) | |
# loading text data for my models | |
if 'text' in params['dataset']: | |
# Load the text embeddings | |
embs_dict = torch.load(params['text_emb_path']) | |
# Extract the embeddings (assumed to be stored under the 'data' key) | |
embs = embs_dict['data'] # This should be a tensor of shape [num_taxa, 4096] | |
# Extract the taxon IDs (assumed to be stored under the 'taxon_id' key) | |
embs_ids = embs_dict['taxon_id'] # This could be a list or a tensor | |
embs_keys = embs_dict['keys'] | |
ds = dataset_classes[params['dataset']](locs, labels, classes, class_to_taxa, embs=embs, embs_ids=embs_ids, | |
embs_keys=embs_keys, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'], | |
num_context=params['num_context'], | |
transformer_input_enc=params['transformer_input_enc'], | |
token_dim=params['species_dim'], | |
jitter=params['add_location_noise'], | |
variable_context_length=params['variable_context_length']) | |
else: | |
ds = dataset_classes[params['dataset']](locs, labels, classes, class_to_taxa, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'], | |
num_context=params['num_context'], transformer_input_enc=params['transformer_input_enc'], | |
token_dim=params['species_dim'], jitter=params['add_location_noise'], | |
variable_context_length=params['variable_context_length']) | |
else: | |
# this version - which should be accessed if we are not using one of my datasets - is the same as Max's | |
ds = dataset_classes[params['dataset']](locs, labels, classes, class_to_taxa, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time']) | |
return ds | |
def get_train_data(params): | |
print('modified this to hopefully reduce loading during eval') | |
if 'eval' in params['dataset']: | |
print('Creating eval dataset') | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
eval_data_path = os.path.join(paths['data'], 'positive_eval_data.npz') | |
thing = np.load(eval_data_path, allow_pickle=True) | |
locs = torch.from_numpy(thing['locs']) | |
labels = torch.from_numpy(thing['labels']) | |
classes = thing['classes'].item() | |
class_to_taxa = list(thing['class_to_taxa']) | |
# for now dates is just set to None. Change if I start doing time stuff. | |
dates = None | |
if 'transformer' in params['dataset']: | |
# holdover to make earlier models work | |
for key in ['add_location_noise', 'variable_context_length']: | |
if key not in params: | |
params[key] = False | |
# loading text data for my models | |
if 'text' in params['dataset']: | |
# Load the text embeddings | |
                embs_dict = torch.load(params['text_emb_path'], map_location='cpu', weights_only=False)
# Extract the embeddings (assumed to be stored under the 'data' key) | |
embs = embs_dict['data'] # This should be a tensor of shape [num_taxa, 4096] | |
# Extract the taxon IDs (assumed to be stored under the 'taxon_id' key) | |
embs_ids = embs_dict['taxon_id'] # This could be a list or a tensor | |
embs_keys = embs_dict['keys'] | |
ds = get_dataset_class(params['dataset'])( | |
locs, labels, classes, class_to_taxa, embs=embs, embs_ids=embs_ids, | |
embs_keys=embs_keys, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'], | |
num_context=params['num_context'], | |
transformer_input_enc=params['transformer_input_enc'], | |
token_dim=params['species_dim'], | |
jitter=params['add_location_noise'], | |
variable_context_length=params['variable_context_length'] | |
) | |
else: | |
ds = get_dataset_class(params['dataset'])( | |
locs, labels, classes, class_to_taxa, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'], | |
num_context=params['num_context'], | |
transformer_input_enc=params['transformer_input_enc'], | |
token_dim=params['species_dim'], | |
jitter=params['add_location_noise'], | |
variable_context_length=params['variable_context_length'] | |
) | |
elif 'variable' in params['dataset']: | |
# Load the text embeddings | |
text_embs_dict = torch.load(params['text_emb_path'], map_location='cpu', weights_only=False) | |
text_embs = text_embs_dict['data'] | |
text_embs_ids = text_embs_dict['taxon_id'] | |
text_embs_keys = text_embs_dict['keys'] | |
# Load the image embeddings | |
image_embs_dict = torch.load(params['image_emb_path'], map_location='cpu', weights_only=False) | |
image_embs = image_embs_dict['data'] | |
image_embs_ids = image_embs_dict['taxon_id'] | |
image_embs_keys = image_embs_dict['keys'] | |
ds = dataset_classes['variable_tokens']( | |
locs, labels, classes, class_to_taxa, text_embs=text_embs, | |
text_embs_ids=text_embs_ids, text_embs_keys=text_embs_keys, image_embs=image_embs, | |
image_embs_ids=image_embs_ids, image_embs_keys=image_embs_keys, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'], | |
num_context=params['num_context'], transformer_input_enc=params['transformer_input_enc'], | |
token_dim=params['species_dim'], jitter=params['add_location_noise'], | |
variable_context_length=params['variable_context_length'], loc_prob=params['loc_prob'], | |
text_prob=params['text_prob'], image_prob=params['image_prob'], env_prob=0.0, eval_mode=True | |
) | |
else: | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
data_dir = paths['train'] | |
obs_file = os.path.join(data_dir, params['obs_file']) | |
taxa_file = os.path.join(data_dir, params['taxa_file']) | |
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json') | |
taxa_of_interest = get_taxa_of_interest( | |
params['species_set'], params['num_aux_species'], params['aux_species_seed'], | |
params['taxa_file'], taxa_file_snt | |
) | |
locs, labels, _, dates, _, _ = load_inat_data(obs_file, taxa_of_interest) | |
if params['zero_shot']: | |
print('Hopefully we are now avoiding loading the SNT and IUCN data to find the taxa in it') | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
eval_taxa_path = os.path.join(paths['data'], 'eval_taxa_list.npy') | |
taxa = np.load(eval_taxa_path, allow_pickle=True) | |
mask = labels != taxa[0] | |
for i in range(1, len(taxa)): | |
mask &= (labels != taxa[i]) | |
locs = locs[mask] | |
dates = dates[mask] | |
labels = labels[mask] | |
unique_taxa, class_ids = np.unique(labels, return_inverse=True) | |
class_to_taxa = unique_taxa.tolist() | |
# load class names | |
class_info_file = json.load(open(taxa_file, 'r')) | |
class_names_file = [cc['latin_name'] for cc in class_info_file] | |
taxa_ids_file = [cc['taxon_id'] for cc in class_info_file] | |
classes = dict(zip(taxa_ids_file, class_names_file)) | |
subset = None | |
if params['subset_cap_name'] is not None: | |
if params['subset_cap_name'] == 'iucn': | |
with open('paths.json', 'r') as f: | |
paths = json.load(f) | |
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
data = json.load(f) | |
taxa = [int(tt) for tt in data['taxa_presence'].keys()] | |
# get classes to eval | |
subset = np.zeros((len(taxa),), dtype=int) | |
for tt_id, tt in enumerate(taxa): | |
class_of_interest = np.where(np.array(class_to_taxa) == tt)[0] | |
if len(class_of_interest) != 0: | |
subset[tt_id] = class_of_interest | |
else: | |
                raise NotImplementedError(f'Unknown subset name: {params["subset_cap_name"]}')
idx_ss = get_idx_subsample_observations( | |
labels, params['hard_cap_num_per_class'], params['hard_cap_seed'], subset, | |
params['subset_cap_num_per_class'] | |
) | |
locs = torch.from_numpy(np.array(locs)[idx_ss]) # convert to Tensor | |
labels = torch.from_numpy(np.array(class_ids)[idx_ss]) | |
dates = 364/365*torch.from_numpy(np.array(dates)[idx_ss]) if params['input_time'] else None | |
if 'variable' in params['dataset']: | |
# Load the text embeddings | |
text_embs_dict = torch.load(params['text_emb_path'], map_location='cpu', weights_only=False) | |
text_embs = text_embs_dict['data'] | |
text_embs_ids = text_embs_dict['taxon_id'] | |
text_embs_keys = text_embs_dict['keys'] | |
# Load the image embeddings | |
image_embs_dict = torch.load(params['image_emb_path'], map_location='cpu', weights_only=False) | |
image_embs = image_embs_dict['data'] | |
image_embs_ids = image_embs_dict['taxon_id'] | |
image_embs_keys = image_embs_dict['keys'] | |
ds = dataset_classes['variable_tokens']( | |
locs, labels, classes, class_to_taxa, text_embs=text_embs, text_embs_ids=text_embs_ids, | |
text_embs_keys=text_embs_keys, image_embs=image_embs, image_embs_ids=image_embs_ids, | |
image_embs_keys=image_embs_keys, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'], | |
num_context=params['num_context'], transformer_input_enc=params['transformer_input_enc'], | |
token_dim=params['species_dim'], jitter=params['add_location_noise'], | |
variable_context_length=params['variable_context_length'], loc_prob=params['loc_prob'], | |
text_prob=params['text_prob'], image_prob=params['image_prob'], env_prob=0.0, eval_mode=False | |
) | |
elif 'transformer' in params['dataset']: | |
# holdover to make earlier models work | |
for key in ['add_location_noise', 'variable_context_length']: | |
if key not in params: | |
params[key] = False | |
# loading text data for my models | |
if 'text' in params['dataset']: | |
# Load the text embeddings | |
embs_dict = torch.load(params['text_emb_path'], map_location='cpu', weights_only=False) | |
# Extract the embeddings (assumed to be stored under the 'data' key) | |
embs = embs_dict['data'] # This should be a tensor of shape [num_taxa, 4096] | |
# Extract the taxon IDs (assumed to be stored under the 'taxon_id' key) | |
embs_ids = embs_dict['taxon_id'] # This could be a list or a tensor | |
embs_keys = embs_dict['keys'] | |
ds = get_dataset_class(params['dataset'])( | |
locs, labels, classes, class_to_taxa, embs=embs, embs_ids=embs_ids, | |
embs_keys=embs_keys, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'], | |
num_context=params['num_context'], | |
transformer_input_enc=params['transformer_input_enc'], | |
token_dim=params['species_dim'], jitter=params['add_location_noise'], | |
variable_context_length=params['variable_context_length'] | |
) | |
else: | |
ds = dataset_classes[params['dataset']]( | |
locs, labels, classes, class_to_taxa, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'], | |
num_context=params['num_context'], transformer_input_enc=params['transformer_input_enc'], | |
token_dim=params['species_dim'], jitter=params['add_location_noise'], | |
variable_context_length=params['variable_context_length'] | |
) | |
else: | |
# this version - which should be accessed if we are not using one of my datasets - is the same as Max's | |
ds = dataset_classes[params['dataset']]( | |
locs, labels, classes, class_to_taxa, input_enc=params['input_enc'], | |
device=params['device'], dates=dates, input_dim=params['input_dim'], | |
time_dim=params['input_time_dim'], noise_time=params['noise_time'] | |
) | |
return ds | |
def test_dataset(): | |
import setup | |
from tqdm import tqdm | |
train_params = {} | |
train_params['species_set'] = 'all' | |
train_params['hard_cap_num_per_class'] = -1 | |
train_params['num_aux_species'] = 0 | |
train_params['input_enc'] = 'sin_cos_env' | |
train_params['input_dim'] = 8 | |
train_params['input_time'] = False | |
train_params['input_time_dim'] = 0 | |
train_params['num_epochs'] = 50 | |
train_params['noise_time'] = False | |
train_params['loss'] = 'an_full' | |
train_params['dataset'] = 'iucn_inat' | |
params = setup.get_default_params_train(train_params) | |
train_dataset = get_train_data(params) | |
train_dataset.viz_map(10070) | |
train_loader = torch.utils.data.DataLoader( | |
train_dataset, | |
batch_size=params['batch_size'], | |
shuffle=True, | |
num_workers=0, | |
collate_fn=getattr(train_dataset, 'collate_fn', None)) | |
for _ in tqdm(train_loader): | |
pass | |
if __name__ == '__main__': | |
test_dataset() | |