# fs_sinr/datasets.py
import os
import numpy as np
import json
import pandas as pd
from calendar import monthrange
import torch
import utils
import random
#from h3.unstable import vect
import h3.api.numpy_int as h3
from torch.nn.utils.rnn import pad_sequence
from functools import partial
import re
class LocationDataset(torch.utils.data.Dataset):
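    """Basic location dataset.

    Each item is (loc_feat, loc, class_id); when time_dim > 0 a time encoding
    is appended to loc_feat and the date is appended to the returned location.
    """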
    # Note: an earlier version of this constructor also took a dummy
    # `num_context` argument; it is no longer needed and has been removed.
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False):
# handle input encoding:
self.input_enc = input_enc
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# define some properties:
self.locs = locs
self.loc_feats = self.enc.encode(self.locs)
self.labels = labels
self.classes = classes
self.class_to_taxa = class_to_taxa
if dates is not None:
self.dates = dates
self.enc_time = utils.TimeEncoder()
# useful numbers:
self.num_classes = len(np.unique(labels))
self.input_dim = input_dim
self.time_dim = time_dim
self.noise_time = noise_time
if self.enc.raster is not None:
self.enc.raster = self.enc.raster.to(device)
def __len__(self):
return self.loc_feats.shape[0]
def __getitem__(self, index):
loc_feat = self.loc_feats[index, :]
loc = self.locs[index, :]
class_id = self.labels[index]
if self.time_dim > 0:
date = self.dates[index]
# add noise
if self.noise_time:
noise_level = random.random()
noise = (2*random.random() - 1) * (0.5*noise_level)
loc_feat = torch.cat([loc_feat, self.enc_time.encode_fast([date.item()+noise,noise_level])])
            else:
                raise NotImplementedError()
                # Unreachable: retained from an earlier version that encoded the
                # date without added noise.
                loc_feat = torch.cat([loc_feat, torch.tensor(self.enc_time.encode([2*date.item()-1], normalize=False))])
return loc_feat, torch.cat([loc, date[None]]), class_id
else:
return loc_feat, loc, class_id
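
# Example usage (sketch; argument values are illustrative and `locs`, `labels`,
# `classes`, `class_to_taxa` are assumed to come from this repo's data loading):
#
#   dataset = LocationDataset(locs, labels, classes, class_to_taxa,
#                             input_enc='sin_cos', device='cpu')
#   loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)
#   loc_feat, loc, class_id = dataset[0]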
# Used only for the transformer-based models: each main training point is
# returned together with a set of same-class "context" points.
class TransformerLocationDataset(torch.utils.data.Dataset):
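    """Location dataset for the transformer models.

    Each item returns the target point plus up to `num_context` other
    observations of the same class, prepended with a zero "class token", so a
    model can condition on same-class context.
    """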
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0,
noise_time=False, num_context=50, transformer_input_enc=None, token_dim=None, jitter=False,
variable_context_length=False):
# Handle input encoding:
self.input_enc = input_enc
self.jitter = jitter
self.variable_context_length = variable_context_length
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# Handle transformer input encoding:
self.transformer_input_enc = transformer_input_enc
if self.transformer_input_enc in ['env', 'sin_cos_env']:
transformer_raster = load_env()
else:
transformer_raster = None
if self.transformer_input_enc == 'sinr':
self.transformer_enc = self.enc
else:
self.transformer_enc = utils.CoordEncoder(
transformer_input_enc, transformer_raster, input_dim=token_dim)
# Define properties:
self.locs = locs # Keep on CPU
self.loc_feats = self.enc.encode(self.locs, normalize=True)
transformer_loc_feats = self.transformer_enc.encode(
self.locs, normalize=False)
self.labels = labels # Keep on CPU
self.classes = classes
self.class_to_taxa = class_to_taxa
if dates is not None:
self.dates = dates
self.enc_time = utils.TimeEncoder()
# Useful numbers:
self.num_classes = len(np.unique(labels))
self.input_dim = input_dim
self.time_dim = time_dim
self.noise_time = noise_time
self.num_context = num_context
self.token_dim = token_dim
        # Rasters are intentionally left on the CPU here (no .to(device));
        # everything stays on the CPU until batching.
# Organize the data into dictionaries
per_class_location_dict = organize_data_by_labels(
np.array(labels), np.array(locs))
per_class_loc_feats_dict = organize_data_by_labels(
np.array(labels), np.array(transformer_loc_feats))
for key, value in per_class_location_dict.items():
per_class_location_dict[key] = torch.tensor(
np.array(value)) # Keep on CPU
for key, value in per_class_loc_feats_dict.items():
per_class_loc_feats_dict[key] = torch.tensor(
np.array(value)) # Keep on CPU
self.per_class_locs = per_class_location_dict
self.per_class_loc_feats = per_class_loc_feats_dict
def __len__(self):
return self.loc_feats.shape[0]
def __getitem__(self, index):
# Retrieve the feature and class of the original point
loc_feat = self.loc_feats[index, :] # On CPU
loc = self.locs[index, :] # On CPU
class_id = self.labels[index] # On CPU
class_id_int = class_id.item()
# Fetch all locations for the given class
all_class_locs = self.per_class_locs[class_id_int] # On CPU
all_class_loc_feats = self.per_class_loc_feats[class_id_int] # On CPU
        # Zero feature vector used as the class token
class_token_feature = torch.zeros(
(1, len(all_class_loc_feats[0]))) # CPU tensor
# Broadcast and compare to find all matching locations
matches = (all_class_locs == loc).all(dim=1)
# Find the index of the original location
local_index = torch.where(matches)[0]
if len(local_index) > 1:
local_index = local_index[0]
# Exclude the original location's index
filtered_local_indices = torch.arange(
len(all_class_locs)) != local_index
if self.variable_context_length:
num_context = random.randint(1, self.num_context)
else:
num_context = self.num_context
# Select indices for context
if filtered_local_indices.sum() > num_context:
selected_indices = filtered_local_indices.nonzero().squeeze()
perm = torch.randperm(selected_indices.size(0))
selected_indices = selected_indices[perm][:num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
if self.jitter:
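            # Note: torch.full_like adds a constant offset of noise_std to every
            # feature; random jitter would instead use something like
            # torch.randn_like(context_loc_feats) * noise_std.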
noise_std = 0.001
noise = torch.full_like(context_loc_feats, noise_std)
context_loc_feats += noise
context_sequence = torch.cat(
[class_token_feature, context_loc_feats], dim=0)
return loc_feat, loc, class_id, context_sequence, context_locs
def collate_fn(self, batch):
# Unpack the batch
loc_feats, locs, class_ids, context_sequences, context_locss = zip(*batch)
# Pad sequences
padded_sequences = pad_sequence(
context_sequences, batch_first=True, padding_value=-10)
padded_context_locs = pad_sequence(
context_locss, batch_first=True, padding_value=-10)
# Convert lists to tensors
loc_feats = torch.stack(loc_feats)
locs = torch.stack(locs)
class_ids = torch.tensor(class_ids)
# Create a mask for sequences based on padding
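        # (True marks padded context positions, i.e. rows whose features all
        # equal the padding value of -10.)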
sequence_mask = (padded_sequences == -10).all(dim=-1)
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask
def get_item_from_class(self, class_id):
# Fetch locations and features for the class
all_class_locs = self.per_class_locs[class_id]
all_class_loc_feats = self.per_class_loc_feats[class_id]
# Randomly select an index
index = np.random.choice(len(all_class_locs))
# Retrieve selected location and features
loc = all_class_locs[index]
if loc.ndim == 1:
loc = loc.unsqueeze(0)
loc_feat = self.enc.encode(loc, normalize=False)
        # Zero feature vector used as the class token
class_token_feature = torch.zeros((1, self.token_dim)) # CPU tensor
# Exclude selected index from context
filtered_local_indices = torch.arange(len(all_class_locs)) != index
# Select indices for context
if filtered_local_indices.sum() > self.num_context:
selected_indices = filtered_local_indices.nonzero().squeeze()
perm = torch.randperm(selected_indices.size(0))
selected_indices = selected_indices[perm][:self.num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
context_sequence = torch.cat(
[class_token_feature, context_loc_feats], dim=0)
return loc_feat, loc, class_id, context_sequence, context_locs
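
# Example usage (sketch; argument values are illustrative and the inputs are
# assumed to come from this repo's data-loading pipeline):
#
#   dataset = TransformerLocationDataset(locs, labels, classes, class_to_taxa,
#                                        input_enc='sin_cos', device='cpu',
#                                        num_context=50,
#                                        transformer_input_enc='sinr', token_dim=4)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True,
#                                        collate_fn=dataset.collate_fn)
#   loc_feats, locs_b, class_ids, ctx_seqs, ctx_locs, mask = next(iter(loader))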
# Used only for the transformer-based models: each main training point is
# returned together with same-class "context" points and a per-class text embedding.
class TransformerLocationTextDataset(torch.utils.data.Dataset):
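    """Transformer location dataset that also returns per-class text embeddings.

    Each item additionally carries a text embedding sampled from those
    available for the class (a zero vector if the class has none).
    """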
def __init__(self, locs, labels, classes, class_to_taxa, embs, embs_ids, embs_keys, input_enc, device, dates=None, input_dim=4, time_dim=0,
noise_time=False, num_context=50, transformer_input_enc=None, token_dim=None, jitter=False,
variable_context_length=False):
# Handle input encoding:
self.input_enc = input_enc
self.jitter = jitter
self.variable_context_length = variable_context_length
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# Handle transformer input encoding:
self.transformer_input_enc = transformer_input_enc
if self.transformer_input_enc in ['env', 'sin_cos_env']:
transformer_raster = load_env()
else:
transformer_raster = None
if self.transformer_input_enc == 'sinr':
self.transformer_enc = self.enc
else:
self.transformer_enc = utils.CoordEncoder(transformer_input_enc, transformer_raster, input_dim=token_dim)
# Define some properties:
self.locs = locs # Keep on CPU
# Below line also normalizes locs as well as making loc feats
self.loc_feats = self.enc.encode(self.locs, normalize=True)
transformer_loc_feats = self.transformer_enc.encode(self.locs, normalize=False)
self.labels = labels # Keep on CPU
self.classes = classes
self.class_to_taxa = class_to_taxa
if dates is not None:
self.dates = dates
self.enc_time = utils.TimeEncoder()
# Useful numbers:
self.num_classes = len(np.unique(labels))
self.input_dim = input_dim
self.time_dim = time_dim
self.noise_time = noise_time
self.num_context = num_context
self.token_dim = token_dim
        # Rasters are intentionally left on the CPU here (no .to(device));
        # everything stays on the CPU until batching.
# Text embeddings
self.embs = embs # Keep on CPU
self.embs_ids = embs_ids.tolist()
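        # Map each embedding's taxon id to its class index; -1 marks taxa that
        # are not part of this dataset.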
self.embs_class_ids = [class_to_taxa.index(taxa) if taxa in class_to_taxa else -1 for taxa in self.embs_ids]
self.embs_keys = embs_keys
# Initialize an empty dictionary to store the result
class_emb_dict = {}
# Populate the dictionary
for i, (index, description) in enumerate(embs_keys):
# Find the class using the index from the class_list
class_id = self.embs_class_ids[index]
            # Skip taxa that are not present in this dataset (class_id == -1)
if class_id == -1:
continue
# Check if the class_id is already a key in the dictionary
if class_id not in class_emb_dict:
# Initialize with empty lists if class_id is not already in the dictionary
class_emb_dict[class_id] = ([], [])
# Append the description and the index of embs_keys to the corresponding lists
class_emb_dict[class_id][0].append(i)
class_emb_dict[class_id][1].append(description)
self.class_emb_dict = class_emb_dict
# Organize the data into dictionaries
per_class_location_dict = organize_data_by_labels(np.array(labels), np.array(locs))
per_class_loc_feats_dict = organize_data_by_labels(np.array(labels), np.array(transformer_loc_feats))
for key, value in per_class_location_dict.items():
per_class_location_dict[key] = torch.tensor(np.array(value)) # Keep on CPU
for key, value in per_class_loc_feats_dict.items():
per_class_loc_feats_dict[key] = torch.tensor(np.array(value)) # Keep on CPU
self.per_class_locs = per_class_location_dict
self.per_class_loc_feats = per_class_loc_feats_dict
def __len__(self):
return self.loc_feats.shape[0]
def __getitem__(self, index):
# Retrieve the feature and class of the original point
loc_feat = self.loc_feats[index, :]
loc = self.locs[index, :]
class_id = self.labels[index]
class_id_int = class_id.item()
# Fetch all locations for the given class
all_class_locs = self.per_class_locs[class_id_int]
all_class_loc_feats = self.per_class_loc_feats[class_id_int]
        # Zero feature vector used as the class token
class_token_feature = torch.zeros((1, len(all_class_loc_feats[0]))) # CPU tensor
# Find the index of the original location
matches = (all_class_locs == loc).all(dim=1)
local_index = torch.where(matches)[0]
if len(local_index) > 1:
local_index = local_index[0]
# Exclude the original location's index
filtered_local_indices = torch.arange(len(all_class_locs)) != local_index
if self.variable_context_length:
num_context = random.randint(1, self.num_context)
else:
num_context = self.num_context
# Select random or all indices depending on availability
if filtered_local_indices.sum() > num_context:
selected_indices = filtered_local_indices.nonzero().squeeze().numpy()
np.random.shuffle(selected_indices)
selected_indices = selected_indices[:num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and their features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
if self.jitter:
noise_std = 0.001
noise = torch.full_like(context_loc_feats, noise_std)
context_loc_feats = context_loc_feats + noise
context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
# Text embeddings
if class_id_int in self.class_emb_dict:
embs_indexes, descriptions = self.class_emb_dict[class_id_int]
selected_index = random.choice(embs_indexes)
emb = self.embs[selected_index]
else:
emb = torch.zeros(4096) # Adjust size if necessary
return loc_feat, loc, class_id, context_sequence, context_locs, emb
def collate_fn(self, batch):
# Unzip the batch
loc_feats, locs, class_ids, context_sequences, context_locs_list, embs = zip(*batch)
# Convert list of sequences to tensors with padding
padded_sequences = pad_sequence(context_sequences, batch_first=True, padding_value=-10)
padded_context_locs = pad_sequence(context_locs_list, batch_first=True, padding_value=-10)
# Convert lists to tensors
loc_feats = torch.stack(loc_feats)
locs = torch.stack(locs)
class_ids = torch.tensor(class_ids)
embs = torch.stack(embs)
# Create a mask for sequences based on padding
sequence_mask = (padded_sequences == -10).all(dim=-1)
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, embs
def get_item_from_class(self, class_id):
# Fetch all locations and features for the given class
all_class_locs = self.per_class_locs[class_id]
all_class_loc_feats = self.per_class_loc_feats[class_id]
# Randomly select an index for the class
index = np.random.choice(len(all_class_locs))
# Retrieve the selected location and its features
loc = all_class_locs[index]
if loc.ndim == 1:
loc = loc.unsqueeze(0)
loc_feat = self.enc.encode(loc, normalize=False)
        # Zero feature vector used as the class token
class_token_feature = torch.zeros((1, self.token_dim)) # CPU tensor
# Exclude the selected index from the context
filtered_local_indices = torch.arange(len(all_class_locs)) != index
# Select random or all indices depending on availability
if filtered_local_indices.sum() > self.num_context:
selected_indices = filtered_local_indices.nonzero().squeeze().numpy()
np.random.shuffle(selected_indices)
selected_indices = selected_indices[:self.num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
# Text embeddings
if class_id in self.class_emb_dict:
embs_indexes, descriptions = self.class_emb_dict[class_id]
selected_index = random.choice(embs_indexes)
emb = self.embs[selected_index]
else:
emb = torch.zeros(4096) # Adjust size if necessary
return loc_feat, loc, class_id, context_sequence, context_locs, emb
# def select_text_section(self, text_section):
# # Initialize an empty dictionary to store the result
# text_class_emb_dict = {}
# # Populate the dictionary
# for i, (index, description) in enumerate(self.text_embs_keys):
# # Find the class using the index from the class_list
# class_id = self.text_embs_class_ids[index]
# # Skip this iteration if class_id is -1 - which should correspond to classes not in dataset
# if class_id == -1:
# continue
# if description != text_section:
# continue
# # Check if the class_id is already a key in the dictionary
# if class_id not in text_class_emb_dict:
# # Initialize with empty lists if class_id is not already in the dictionary
# text_class_emb_dict[class_id] = ([], [])
#
# # Append the description and the index of embs_keys to the corresponding lists
# text_class_emb_dict[class_id][0].append(i)
# text_class_emb_dict[class_id][1].append(description)
# self.text_class_emb_dict = text_class_emb_dict
class TransformerLocationTextDatasetRandomizeOutputs(torch.utils.data.Dataset):
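    """Variant of TransformerLocationTextDataset that randomizes which inputs appear.

    For each item, `roll_output` picks "obs", "text" or "obs_text" according to
    `just_obs_prob`, `just_text_prob` and their complement; the choice controls
    whether context observations and/or the text embedding are returned, with
    missing parts replaced by zero tensors or an empty context.
    """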
def __init__(self, locs, labels, classes, class_to_taxa, embs, embs_ids, embs_keys, input_enc, device, dates=None, input_dim=4, time_dim=0,
noise_time=False, num_context=50, transformer_input_enc=None, token_dim=None, jitter=False,
variable_context_length=False, just_obs_prob=0.2, just_text_prob=0.1):
# Handle input encoding:
self.input_enc = input_enc
self.jitter = jitter
self.variable_context_length = variable_context_length
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# Handle transformer input encoding:
self.transformer_input_enc = transformer_input_enc
if self.transformer_input_enc in ['env', 'sin_cos_env']:
transformer_raster = load_env()
else:
transformer_raster = None
if self.transformer_input_enc == 'sinr':
self.transformer_enc = self.enc
else:
self.transformer_enc = utils.CoordEncoder(transformer_input_enc, transformer_raster, input_dim=token_dim)
# Define some properties:
self.locs = locs # Keep on CPU
# Below line also normalizes locs as well as making loc feats
self.loc_feats = self.enc.encode(self.locs, normalize=True)
transformer_loc_feats = self.transformer_enc.encode(self.locs, normalize=False)
self.labels = labels # Keep on CPU
self.classes = classes
self.class_to_taxa = class_to_taxa
if dates is not None:
self.dates = dates
self.enc_time = utils.TimeEncoder()
# Useful numbers:
self.num_classes = len(np.unique(labels))
self.input_dim = input_dim
self.time_dim = time_dim
self.noise_time = noise_time
self.num_context = num_context
self.token_dim = token_dim
        # Rasters are intentionally left on the CPU here (no .to(device));
        # everything stays on the CPU until batching.
# Text embeddings
self.embs = embs # Keep on CPU
self.embs_ids = embs_ids.tolist()
self.embs_class_ids = [class_to_taxa.index(taxa) if taxa in class_to_taxa else -1 for taxa in self.embs_ids]
self.embs_keys = embs_keys
# Initialize an empty dictionary to store the result
class_emb_dict = {}
# Populate the dictionary
for i, (index, description) in enumerate(embs_keys):
class_id = self.embs_class_ids[index]
if class_id == -1:
continue
if class_id not in class_emb_dict:
class_emb_dict[class_id] = ([], [])
class_emb_dict[class_id][0].append(i)
class_emb_dict[class_id][1].append(description)
self.class_emb_dict = class_emb_dict
# Organize the data into dictionaries
per_class_location_dict = organize_data_by_labels(np.array(labels), np.array(locs))
per_class_loc_feats_dict = organize_data_by_labels(np.array(labels), np.array(transformer_loc_feats))
for key, value in per_class_location_dict.items():
per_class_location_dict[key] = torch.tensor(np.array(value)) # Keep on CPU
for key, value in per_class_loc_feats_dict.items():
per_class_loc_feats_dict[key] = torch.tensor(np.array(value)) # Keep on CPU
self.per_class_locs = per_class_location_dict
self.per_class_loc_feats = per_class_loc_feats_dict
self.just_obs_prob = just_obs_prob
self.just_text_prob = just_text_prob
self.both_prob = 1 - (just_obs_prob + just_text_prob)
if self.both_prob < 0:
raise ValueError('Probability of "just text" and "just observations" must sum to less than 1')
def __len__(self):
return self.loc_feats.shape[0]
def __getitem__(self, index):
# Determine the type of data to return
data_type = self.roll_output()
# Retrieve the feature and class of the original point
loc_feat = self.loc_feats[index, :]
loc = self.locs[index, :]
class_id = self.labels[index]
class_id_int = class_id.item()
if 'text' in data_type:
# Text embeddings
if class_id_int in self.class_emb_dict:
embs_indexes, descriptions = self.class_emb_dict[class_id_int]
selected_index = random.choice(embs_indexes)
emb = self.embs[selected_index]
else:
emb = torch.zeros(4096) # Adjust size if necessary
if data_type == 'text':
data_type = 'obs_text'
else:
emb = torch.zeros(4096) # Adjust size if necessary
if 'obs' in data_type:
# Fetch all locations for the given class
all_class_locs = self.per_class_locs[class_id_int]
all_class_loc_feats = self.per_class_loc_feats[class_id_int]
            # Zero feature vector used as the class token
class_token_feature = torch.zeros((1, len(all_class_loc_feats[0]))) # CPU tensor
# Find the index of the original location
matches = (all_class_locs == loc).all(dim=1)
local_index = torch.where(matches)[0]
if len(local_index) > 1:
local_index = local_index[0]
# Exclude the original location's index
filtered_local_indices = torch.arange(len(all_class_locs)) != local_index
if self.variable_context_length:
num_context = random.randint(1, self.num_context)
else:
num_context = self.num_context
# Select random or all indices depending on availability
if filtered_local_indices.sum() > num_context:
selected_indices = filtered_local_indices.nonzero().squeeze().numpy()
np.random.shuffle(selected_indices)
selected_indices = selected_indices[:num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
if self.jitter:
noise_std = 0.001
noise = torch.full_like(context_loc_feats, noise_std)
context_loc_feats = context_loc_feats + noise
context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
else:
class_token_feature = torch.zeros((1, len(loc_feat))) # CPU tensor
context_sequence = torch.cat([class_token_feature], dim=0)
context_locs = torch.empty((0, loc.size(0))) # CPU tensor with correct dimensions
return loc_feat, loc, class_id, context_sequence, context_locs, emb
def roll_output(self):
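        # Randomly pick which modalities this sample exposes ("obs", "text" or
        # "obs_text"), weighted by just_obs_prob, just_text_prob and both_prob.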
return random.choices(
["obs", "text", "obs_text"],
[self.just_obs_prob, self.just_text_prob, self.both_prob]
)[0]
def collate_fn(self, batch):
# Unzip the batch
loc_feats, locs, class_ids, context_sequences, context_locs_list, embs = zip(*batch)
# Convert list of sequences to tensors with padding
padded_sequences = pad_sequence(context_sequences, batch_first=True, padding_value=-10)
padded_context_locs = pad_sequence(context_locs_list, batch_first=True, padding_value=-10)
# Convert lists to tensors
loc_feats = torch.stack(loc_feats)
locs = torch.stack(locs)
class_ids = torch.tensor(class_ids)
embs = torch.stack(embs)
# Create a mask for sequences based on padding
sequence_mask = (padded_sequences == -10).all(dim=-1)
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, embs
def get_item_from_class(self, class_id):
# Fetch all locations and features for the given class
all_class_locs = self.per_class_locs[class_id]
all_class_loc_feats = self.per_class_loc_feats[class_id]
# Randomly select an index for the class
index = np.random.choice(len(all_class_locs))
# Retrieve the selected location and its features
loc = all_class_locs[index]
if loc.ndim == 1:
loc = loc.unsqueeze(0)
loc_feat = self.enc.encode(loc, normalize=False)
        # Zero feature vector used as the class token
class_token_feature = torch.zeros((1, self.token_dim)) # CPU tensor
# Exclude the selected index from the context
filtered_local_indices = torch.arange(len(all_class_locs)) != index
# Select random or all indices depending on availability
if filtered_local_indices.sum() > self.num_context:
selected_indices = filtered_local_indices.nonzero().squeeze().numpy()
np.random.shuffle(selected_indices)
selected_indices = selected_indices[:self.num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
# Text embeddings
if class_id in self.class_emb_dict:
embs_indexes, descriptions = self.class_emb_dict[class_id]
selected_index = random.choice(embs_indexes)
emb = self.embs[selected_index]
else:
emb = torch.zeros(4096) # Adjust size if necessary
return loc_feat, loc, class_id, context_sequence, context_locs, emb
class TransformerDatasetVariableTokens(torch.utils.data.Dataset):
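    """Transformer dataset that can mix location, text, image and env tokens.

    A token type is enabled when its probability (`loc_prob`, `text_prob`,
    `image_prob`, `env_prob`) is non-zero; all embeddings and rasters are kept
    on the CPU until batching.
    """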
def __init__(self, locs, labels, classes, class_to_taxa, text_embs, text_embs_ids, text_embs_keys, image_embs,
image_embs_ids, image_embs_keys, input_enc, device, dates=None, input_dim=4, time_dim=0,
noise_time=False, num_context=50, transformer_input_enc=None, token_dim=None, jitter=False,
variable_context_length=False, loc_prob=1.0, text_prob=0.0, image_prob=0.0, env_prob=0.0,
eval_mode=False):
# handle input encoding:
# ensure all stuff is on cpu until batching
self.input_enc = input_enc
self.jitter = jitter
self.variable_context_length = variable_context_length
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# set up what token types and probs of use we have
self.loc_prob = loc_prob
self.text_prob = text_prob
self.image_prob = image_prob
self.env_prob = env_prob
        self.use_loc = self.loc_prob > 0.0
        self.use_text = self.text_prob > 0.0
        self.use_image = self.image_prob > 0.0
        self.use_env = self.env_prob > 0.0
# handle transformer input encoding:
self.transformer_input_enc = transformer_input_enc
if self.transformer_input_enc in ['env', 'sin_cos_env']:
transformer_raster = load_env()
else:
transformer_raster = None
if self.transformer_input_enc == 'sinr':
self.transformer_enc = self.enc
else:
self.transformer_enc = utils.CoordEncoder(transformer_input_enc, transformer_raster, input_dim=token_dim)
        # Rasters are intentionally left on the CPU here (no .to(device));
        # everything stays on the CPU until batching.
# define some properties:
self.locs = locs # Keep on CPU
# Below line also normalises locs as well as making loc feats
self.loc_feats = self.enc.encode(self.locs, normalize=True)
transformer_loc_feats = self.transformer_enc.encode(self.locs, normalize=False)
self.labels = labels # Keep on CPU
self.classes = classes
self.class_to_taxa = class_to_taxa
if dates is not None:
self.dates = dates
self.enc_time = utils.TimeEncoder()
# useful numbers:
self.num_classes = len(np.unique(labels))
self.input_dim = input_dim
self.time_dim = time_dim
self.noise_time = noise_time
self.num_context = num_context
self.token_dim = token_dim
self.env_emb_size = 20
self.image_emb_size = 1024
self.text_emb_size = 4096
self.device = device  # retained for reference; tensors stay on CPU until collation
# text stuff
if self.use_text:
self.text_embs = text_embs # Keep on CPU
self.text_embs_ids = text_embs_ids.tolist()
self.text_embs_class_ids = [class_to_taxa.index(taxa) if taxa in class_to_taxa else -1 for taxa in self.text_embs_ids]
self.text_embs_keys = text_embs_keys
# Initialize an empty dictionary to store the result
class_emb_dict = {}
# Populate the dictionary
for i, (index, description) in enumerate(text_embs_keys):
# Find the class using the index from the class_list
class_id = self.text_embs_class_ids[index]
# Skip this iteration if class_id is -1 - which should correspond to classes not in dataset
if class_id == -1:
continue
# Check if the class_id is already a key in the dictionary
if class_id not in class_emb_dict:
# Initialize with empty lists if class_id is not already in the dictionary
class_emb_dict[class_id] = ([], [])
# Append the description and the index of embs_keys to the corresponding lists
class_emb_dict[class_id][0].append(i)
class_emb_dict[class_id][1].append(description)
self.text_class_emb_dict = class_emb_dict
else:
self.text_embs = None
self.text_embs_ids = None
self.text_embs_class_ids = None
self.text_embs_keys = None
self.text_class_emb_dict = None
# image stuff
if self.use_image:
self.image_embs = image_embs # Keep on CPU
self.image_embs_ids = image_embs_ids.tolist()
self.image_embs_class_ids = [class_to_taxa.index(taxa) if taxa in class_to_taxa else -1 for taxa in self.image_embs_ids]
self.image_embs_keys = image_embs_keys
# Initialize an empty dictionary to store the result
class_emb_dict = {}
# Populate the dictionary
for i, (index, description) in enumerate(image_embs_keys):
# Find the class using the index from the class_list
class_id = self.image_embs_class_ids[index]
# Skip this iteration if class_id is -1 - which should correspond to classes not in dataset
if class_id == -1:
continue
# Check if the class_id is already a key in the dictionary
if class_id not in class_emb_dict:
# Initialize with empty lists if class_id is not already in the dictionary
class_emb_dict[class_id] = ([], [])
# Append the description and the index of embs_keys to the corresponding lists
class_emb_dict[class_id][0].append(i)
class_emb_dict[class_id][1].append(description)
self.image_class_emb_dict = class_emb_dict
else:
self.image_embs = None
self.image_embs_ids = None
self.image_embs_class_ids = None
self.image_embs_keys = None
self.image_class_emb_dict = None
# Organize the data into the dictionary
per_class_location_dict = organize_data_by_labels(np.array(labels), np.array(self.locs))
per_class_loc_feats_dict = organize_data_by_labels(np.array(labels), np.array(transformer_loc_feats))
for key, value in per_class_location_dict.items():
per_class_location_dict[key] = torch.tensor(np.array(value)) # Keep on CPU
for key, value in per_class_loc_feats_dict.items():
per_class_loc_feats_dict[key] = torch.tensor(np.array(value)) # Keep on CPU
self.per_class_locs = per_class_location_dict
self.per_class_loc_feats = per_class_loc_feats_dict
# env stuff
if self.use_env:
env_raster = load_env()
self.env_enc = utils.CoordEncoder('env', env_raster, input_dim=0)
# Note: the env raster is deliberately kept on CPU (no .to(device) call)
env_feats = self.env_enc.encode(self.locs, normalize=False)
per_class_env_feats_dict = organize_data_by_labels(np.array(labels), np.array(env_feats))
for key, value in per_class_env_feats_dict.items():
per_class_env_feats_dict[key] = torch.tensor(np.array(value)) # Keep on CPU
self.per_class_env_feats = per_class_env_feats_dict
else:
self.env_enc = None
self.per_class_env_feats = None
self.eval_mode = eval_mode
if eval_mode:
print('Using eval dataset. One example per class to generate eval embeddings')
# Select a single example per class
unique_labels, unique_indices = np.unique(labels, return_index=True)
self.locs = self.locs[unique_indices]
self.labels = labels[unique_indices]
self.loc_feats = self.loc_feats[unique_indices]
def __len__(self):
return self.loc_feats.shape[0]
def __getitem__(self, index):
# Decide which token types (loc / text / image / env) to include for this example
type = self.roll_output()
# Retrieve the feature and class of the original point
loc_feat = self.loc_feats[index, :]
loc = self.locs[index, :]
class_id = self.labels[index]
class_id_int = class_id.item()
if 'text' in type:
# text stuff
# Get the embedding for the right class
if class_id_int in self.text_class_emb_dict:
text_embs_indexes, descriptions = self.text_class_emb_dict[class_id_int]
# Randomly select an index from the list of indices
selected_index = random.choice(text_embs_indexes)
# Use the selected index to retrieve the corresponding element from embs
text_emb = self.text_embs[selected_index]
else:
# Set emb to all zeros
text_emb = torch.zeros(self.text_emb_size)
else:
text_emb = torch.zeros(self.text_emb_size)
if 'image' in type:
# image stuff
if class_id_int in self.image_class_emb_dict:
image_embs_indexes, descriptions = self.image_class_emb_dict[class_id_int]
selected_index = random.choice(image_embs_indexes)
image_emb = self.image_embs[selected_index]
else:
image_emb = torch.zeros(self.image_emb_size)
else:
image_emb = torch.zeros(self.image_emb_size)
if 'loc' in type:
# Fetch all locations for the given class
all_class_locs = self.per_class_locs[class_id_int]
all_class_loc_feats = self.per_class_loc_feats[class_id_int]
# Create a zero feature vector to act as the class token
class_token_feature = torch.zeros((1, len(all_class_loc_feats[0]))) # CPU tensor
# Broadcast and compare to find all matching locations
matches = (all_class_locs == loc).all(dim=1)
# Find the index of the original location
local_index = torch.where(matches)[0]
if len(local_index) > 1:
local_index = local_index[0]
# Exclude the original location's index
filtered_local_indices = torch.arange(len(all_class_locs)) != local_index
if self.variable_context_length:
num_context = random.randint(1, self.num_context)
else:
num_context = self.num_context
if filtered_local_indices.sum() > num_context:
selected_indices = filtered_local_indices.nonzero().squeeze()
perm = torch.randperm(selected_indices.size(0))
selected_indices = selected_indices[perm][:num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and their features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Check if context_loc_feats has 1 dimension and add another if it does
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if self.jitter:
# Jitter the context features with small random noise (std = noise_std)
noise_std = 0.001
noise = torch.randn_like(context_loc_feats) * noise_std
context_loc_feats = context_loc_feats + noise
# Check if context_locs has 1 dimension and add another if it does
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
else:
class_token_feature = torch.zeros((1, len(loc_feat))) # CPU tensor
context_sequence = torch.cat([class_token_feature], dim=0)
context_locs = torch.empty((0, loc.size(0))) # Empty CPU tensor
if ('loc' in type) and ('env' not in type) and self.use_env:
env_feats = torch.zeros((len(selected_indices), self.env_emb_size))
elif ('env' in type) and ('loc' in type):
all_class_env_feats = self.per_class_env_feats[class_id_int]
env_feats = all_class_env_feats[selected_indices]
if env_feats.dim() == 1:
env_feats = env_feats.unsqueeze(0)
env_feats = torch.cat([env_feats], dim=0)
else:
env_feats = torch.zeros((1, self.env_emb_size))
return loc_feat, loc, class_id, context_sequence, context_locs, text_emb, image_emb, env_feats
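# Rough shape summary for one item (a sketch based on the sizes set in __init__;
# exact dimensions depend on the chosen encoders and data):
#   loc_feat: (input_dim [+ time_dim],)   loc: (2,)   class_id: scalar tensor
#   context_sequence: (1 + n_context, token_dim)   context_locs: (n_context, 2)
#   text_emb: (text_emb_size,)   image_emb: (image_emb_size,)
#   env_feats: (n_context, env_emb_size) or (1, env_emb_size)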
def roll_output(self):
output = []
if random.random() < self.loc_prob:
output.append('loc')
if random.random() < self.image_prob:
output.append('image')
if random.random() < self.text_prob:
output.append('text')
if random.random() < self.env_prob:
output.append('env')
# Guarantee at least one token type: fall back to 'loc' if nothing was sampled
if (len(output) == 0) and self.loc_prob != 0.0:
output.append('loc')
return output
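# Illustrative example (assumed config, not taken from the training scripts): with
# loc_prob=1.0, text_prob=0.25, image_prob=0.25, env_prob=0.0, every item gets a
# 'loc' token and, independently, 'text' and 'image' roughly a quarter of the time
# each, e.g. ['loc'], ['loc', 'image'] or ['loc', 'image', 'text'].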
def collate_fn(self, batch):
# Unzip the batch
loc_feats, locs, class_ids, context_sequences, context_locss, text_embs, image_embs, env_feats = zip(*batch)
# Convert list of sequences to a tensor with padding
padded_sequences = pad_sequence(context_sequences, batch_first=True, padding_value=-10)
# Convert list of class IDs to a tensor
class_ids = torch.tensor(class_ids)
# Convert loc_feats and locs to tensors
loc_feats = torch.stack(loc_feats)
locs = torch.stack(locs)
text_embs = torch.stack(text_embs)
image_embs = torch.stack(image_embs)
padded_context_locs = pad_sequence(context_locss, batch_first=True, padding_value=-10)
if self.use_env:
padded_env_feats = pad_sequence(env_feats, batch_first=True, padding_value=-10)
else:
padded_env_feats = None
# Create a mask for sequences based on padding
sequence_mask = (padded_sequences == -10).all(dim=-1)
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, text_embs, image_embs, padded_env_feats
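# Minimal usage sketch (hypothetical: `dataset` is an instance of this class and the
# batch size is illustrative). collate_fn pads the variable-length context with -10
# and returns the matching padding mask:
# loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True,
#                                      collate_fn=dataset.collate_fn)
# (loc_feats, locs, class_ids, padded_seqs, padded_ctx_locs,
#  seq_mask, text_embs, image_embs, env_feats) = next(iter(loader))
# seq_mask is True wherever a context row is pure padding (all values == -10).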
def select_text_section(self, text_section):
# Initialize an empty dictionary to store the result
text_class_emb_dict = {}
# Populate the dictionary
for i, (index, description) in enumerate(self.text_embs_keys):
# Find the class using the index from the class_list
class_id = self.text_embs_class_ids[index]
# Skip this iteration if class_id is -1 - which should correspond to classes not in dataset
if class_id == -1:
continue
if description != text_section:
continue
# Check if the class_id is already a key in the dictionary
if class_id not in text_class_emb_dict:
# Initialize with empty lists if class_id is not already in the dictionary
text_class_emb_dict[class_id] = ([], [])
# Append the description and the index of embs_keys to the corresponding lists
text_class_emb_dict[class_id][0].append(i)
text_class_emb_dict[class_id][1].append(description)
self.text_class_emb_dict = text_class_emb_dict
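# Hypothetical example of restricting text tokens to a single section after
# construction (the section name is illustrative; valid names are whatever appears
# in text_embs_keys):
# dataset.select_text_section('habitat')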
# class EvalTransformerDatasetVariableTokens(torch.utils.data.Dataset):
# def __init__(self, locs, labels, classes, class_to_taxa, text_embs, text_embs_ids, text_embs_keys, image_embs,
# image_embs_ids, image_embs_keys, input_enc, device, dates=None, input_dim=4, time_dim=0,
# noise_time=False, num_context=50, transformer_input_enc=None, token_dim=None, jitter=False,
# variable_context_length=False, loc_prob=1.0, text_prob=0.0, image_prob=0.0, env_prob=0.0):
# # handle input encoding:
# self.input_enc = input_enc
# self.jitter=jitter
# self.variable_context_length=variable_context_length
# if self.input_enc in ['env', 'sin_cos_env']:
# raster = load_env()
# else:
# raster = None
# self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
#
# # set up what token types and probs of use we have
# self.loc_prob = loc_prob
# self.text_prob = text_prob
# self.image_prob = image_prob
# self.env_prob = env_prob
#
# if self.loc_prob == 0.0:
# self.use_loc = False
# else:
# self.use_loc = True
# if self.text_prob == 0.0:
# self.use_text = False
# else:
# self.use_text = True
# if self.image_prob == 0.0:
# self.use_image = False
# else:
# self.use_image = True
# if self.env_prob == 0.0:
# self.use_env = False
# else:
# self.use_env = True
#
# # handle transformer input encoding:
# self.transformer_input_enc = transformer_input_enc
# if self.transformer_input_enc in ['env', 'sin_cos_env']:
# transformer_raster = load_env()
# else:
# transformer_raster = None
# if self.transformer_input_enc == 'sinr':
# self.transformer_enc = self.enc
# else:
# self.transformer_enc = utils.CoordEncoder(transformer_input_enc, transformer_raster, input_dim=token_dim)
#
# # env stuff - semi useless for now - I intend to give "paired" locs and env data but that is for later on
# # additionally we also need the env data if self.use_env is true
# if self.use_env:
# env_raster = load_env()
# self.env_enc = utils.CoordEncoder('env', env_raster, input_dim=0)
# else:
# self.env_enc = None
#
# # define some properties:
# self.locs = locs
# # Below line also normalises locs as well as making loc feats
# self.loc_feats = self.enc.encode(self.locs, normalize=True)
# transformer_loc_feats = self.transformer_enc.encode(self.locs, normalize=False)
# self.labels = labels
# self.classes = classes
# self.class_to_taxa = class_to_taxa
# if dates is not None:
# self.dates = dates
# self.enc_time = utils.TimeEncoder()
#
# # useful numbers:
# self.num_classes = len(np.unique(labels))
# self.input_dim = input_dim
# self.time_dim = time_dim
# self.noise_time = noise_time
# self.num_context = num_context
# self.token_dim = token_dim
#
# if self.enc.raster is not None:
# self.enc.raster = self.enc.raster.to(device)
#
# if self.transformer_enc.raster is not None:
# self.transformer_enc.raster = self.transformer_enc.raster.to(device)
#
# if self.env_enc is not None:
# self.env_enc.raster = self.env_enc.raster.to(device)
#
# # text stuff
# if self.use_text:
# self.text_embs = text_embs
# self.text_embs_ids = text_embs_ids.tolist()
# self.text_embs_class_ids = [class_to_taxa.index(taxa) if taxa in class_to_taxa else -1 for taxa in self.text_embs_ids]
# self.text_embs_keys = text_embs_keys
# self.text_emb_size = 4096
#
# # Initialize an empty dictionary to store the result
# class_emb_dict = {}
# # Populate the dictionary
# for i, (index, description) in enumerate(text_embs_keys):
# # Find the class using the index from the class_list
# class_id = self.image_embs_class_ids[index]
# # Skip this iteration if class_id is -1 - which should correspond to classes not in dataset
# if class_id == -1:
# continue
# # Check if the class_id is already a key in the dictionary
# if class_id not in class_emb_dict:
# # Initialize with empty lists if class_id is not already in the dictionary
# class_emb_dict[class_id] = ([], [])
#
# # Append the description and the index of embs_keys to the corresponding lists
# class_emb_dict[class_id][0].append(i)
# class_emb_dict[class_id][1].append(description)
# self.text_class_emb_dict = class_emb_dict
# else:
# self.text_embs = None
# self.text_embs_ids = None
# self.text_embs_class_ids = None
# self.text_embs_keys = None
# self.text_class_emb_dict = None
# self.text_emb_size = None
#
# # image stuff
# if self.use_image:
# self.image_embs = image_embs
# self.image_embs_ids = image_embs_ids.tolist()
# self.image_embs_class_ids = [class_to_taxa.index(taxa) if taxa in class_to_taxa else -1 for taxa in self.image_embs_ids]
# self.image_embs_keys = image_embs_keys
# self.image_emb_size = 1024
#
# # Initialize an empty dictionary to store the result
# class_emb_dict = {}
# # Populate the dictionary
# for i, (index, description) in enumerate(image_embs_keys):
# # Find the class using the index from the class_list
# class_id = self.image_embs_class_ids[index]
# # Skip this iteration if class_id is -1 - which should correspond to classes not in dataset
# if class_id == -1:
# continue
# # Check if the class_id is already a key in the dictionary
# if class_id not in class_emb_dict:
# # Initialize with empty lists if class_id is not already in the dictionary
# class_emb_dict[class_id] = ([], [])
#
# # Append the description and the index of embs_keys to the corresponding lists
# class_emb_dict[class_id][0].append(i)
# class_emb_dict[class_id][1].append(description)
# self.image_class_emb_dict = class_emb_dict
# else:
# self.image_embs = None
# self.image_embs_ids = None
# self.image_embs_class_ids = None
# self.image_embs_keys = None
# self.image_class_emb_dict = None
# self.image_emb_size = None
#
# # Organize the data into the dictionary
# per_class_location_dict = organize_data_by_labels(np.array(labels), np.array(locs))
# per_class_loc_feats_dict = organize_data_by_labels(np.array(labels), np.array(transformer_loc_feats))
# for key, value in per_class_location_dict.items():
# per_class_location_dict[key] = torch.tensor(np.array(value))
# for key, value in per_class_loc_feats_dict.items():
# per_class_loc_feats_dict[key] = torch.tensor(np.array(value))
# self.per_class_locs = per_class_location_dict
# self.per_class_loc_feats = per_class_loc_feats_dict
#
# # MOD FOR EVAL MOD FOR EVAL
# # Select a single example per class
# unique_labels, unique_indices = np.unique(labels, return_index=True)
# self.locs = locs[unique_indices]
# self.labels = labels[unique_indices]
# self.loc_feats = self.loc_feats[unique_indices]
#
# # self.just_obs_prob = just_obs_prob
# # self.just_text_prob = just_text_prob
# # self.both_prob = 1 - (just_obs_prob + just_text_prob)
# # if self.both_prob < 0:
# # raise ValueError('Probability of "just text" and "just observations" must sum to less than 1')
#
# def __len__(self):
# return self.loc_feats.shape[0]
#
# def __getitem__(self, index):
# # See what needs getting
# type = self.roll_output()
#
# # Retrieve the feature and class of the original point
# loc_feat = self.loc_feats[index, :]
# loc = self.locs[index, :]
# class_id = self.labels[index]
# class_id_int = class_id.item()
#
# if 'text' in type:
# # text stuff
# # Get the embedding for the right class
# if class_id_int in self.text_class_emb_dict:
# text_embs_indexes, descriptions = self.text_class_emb_dict[class_id_int]
# # Randomly select an index from the list of indices
# selected_index = random.choice(text_embs_indexes)
# # Use the selected index to retrieve the corresponding element from embs
# text_emb = self.text_embs[selected_index]
# else:
# # If the class_id_int is not in the dictionary, set emb to all zeros for filtering later?
# # emb = None
# text_emb = torch.zeros(self.text_emb_size)
# # if type == 'text':
# # type = 'obs_text'
# else:
# text_emb = torch.zeros(self.text_emb_size)
#
# if 'image' in type:
# # text stuff
# # Get the embedding for the right class
# if class_id_int in self.image_class_emb_dict:
# image_embs_indexes, descriptions = self.image_class_emb_dict[class_id_int]
# # Randomly select an index from the list of indices
# selected_index = random.choice(image_embs_indexes)
# # Use the selected index to retrieve the corresponding element from embs
# image_emb = self.image_embs[selected_index]
# else:
# # If the class_id_int is not in the dictionary, set emb to all zeros for filtering later?
# # emb = None
# image_emb = torch.zeros(self.image_emb_size)
#
# else:
# image_emb = torch.zeros(self.image_emb_size)
#
# if 'loc' in type:
# # Fetch all locations for the given class
# all_class_locs = self.per_class_locs[class_id_int]
# all_class_loc_feats = self.per_class_loc_feats[class_id_int]
#
# # Define a unique class token index
# class_token_feature = torch.zeros((1, len(all_class_loc_feats[0]))) # Create a zero vector for the class token
#
# # Broadcast and compare to find all matching locations
# matches = (all_class_locs == loc).all(dim=1)
#
# # Find the index of the original location
# local_index = torch.where(matches)[0]
# if len(local_index) > 1:
# local_index = local_index[0]
#
# # Exclude the original location's index
# filtered_local_indices = torch.arange(len(all_class_locs)) != local_index
#
# if self.variable_context_length:
# num_context = random.randint(1, self.num_context)
# else:
# num_context = self.num_context
#
# # Select random or all indices depending on the availability relative to `num_context`
# if filtered_local_indices.sum() > num_context:
# selected_indices = filtered_local_indices.nonzero().squeeze().numpy()
# np.random.shuffle(selected_indices)
# selected_indices = selected_indices[:num_context]
# else:
# selected_indices = filtered_local_indices.nonzero().squeeze()
#
# # Get context locations and their features
# context_loc_feats = all_class_loc_feats[selected_indices]
# context_locs = all_class_locs[selected_indices]
#
# # Check if context_loc_feats has 1 dimension and add another if it does
# if context_loc_feats.dim() == 1:
# context_loc_feats = context_loc_feats.unsqueeze(0)
#
# if self.jitter:
# noise_std = 0.001
# noise = torch.full_like(context_loc_feats, noise_std)
# context_loc_feats = context_loc_feats + noise
#
# # Check if context_locs has 1 dimension and add another if it does
# if context_locs.dim() == 1:
# context_locs = context_locs.unsqueeze(0)
#
# context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
# else:
# class_token_feature = torch.zeros((1, len(loc_feat))) # Create a zero vector for the class token
# # print('good chance this doesnt work and need to squeeze / unsqueeze instead. Check the dims')
# # nope, it is good.
# context_sequence = torch.cat([class_token_feature], dim=0)
# context_locs = torch.empty((0, loc.size(0))) # An empty tensor with the right number of dimensions
#
# return loc_feat, loc, class_id, context_sequence, context_locs, text_emb, image_emb
#
# def roll_output(self):
# import random
# output = []
#
# if random.random() < self.loc_prob:
# output.append('loc')
# if random.random() < self.image_prob:
# output.append('image')
# if random.random() < self.text_prob:
# output.append('text')
# if random.random() < self.env_prob:
# output.append('env')
# # if len(output) == 0:
# # output.append('loc')
# return output
#
# def collate_fn(self, batch):
# # Unzip the batch
# loc_feats, locs, class_ids, context_sequences, context_locss, text_embs, image_embs = zip(*batch)
#
# # Convert list of sequences to a tensor with padding
# padded_sequences = pad_sequence(context_sequences, batch_first=True, padding_value=-10)
#
# # Convert list of class IDs to a tensor
# class_ids = torch.tensor(class_ids)
# # Convert loc_feats and locs to tensors
# loc_feats = torch.stack(loc_feats)
# locs = torch.stack(locs)
# text_embs = torch.stack(text_embs)
# image_embs = torch.stack(image_embs)
#
# padded_context_locs = pad_sequence(context_locss, batch_first=True, padding_value=-10)
#
# # Create a mask for sequences based on padding
# # sequence_mask = (padded_sequences == 0) # Create a mask where there's padding (-10s)
# sequence_mask = (padded_sequences == -10).all(dim=-1)
#
# # return padded_sequences, padded_locs, class_ids, sequence_mask
# return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, text_embs, image_embs
#
# def get_item_from_class(self, class_id):
# # Fetch all locations and features for the given class
# all_class_locs = self.per_class_locs[class_id]
# all_class_loc_feats = self.per_class_loc_feats[class_id]
#
# # Randomly select an index for the class
# index = np.random.choice(len(all_class_locs))
#
# # Retrieve the selected location and its features
# loc = all_class_locs[index]
# if loc.ndim == 1:
# loc = loc.unsqueeze(0)
# # loc = loc.unsqueeze(0)
# loc_feat = self.enc.encode(loc, normalize=False)
# # loc_feat = all_class_loc_feats[index]
#
# # Define a unique class token index
# class_token_feature = torch.zeros((1, self.token_dim)) # Create a zero vector for the class token
#
# # Exclude the selected index from the context
# filtered_local_indices = torch.arange(len(all_class_locs)) != index
#
# # Select random or all indices depending on the availability relative to `num_context`
# if filtered_local_indices.sum() > self.num_context:
# selected_indices = filtered_local_indices.nonzero().squeeze().numpy()
# np.random.shuffle(selected_indices)
# selected_indices = selected_indices[:self.num_context]
# else:
# selected_indices = filtered_local_indices.nonzero().squeeze()
#
# # Get context locations and their features
# context_loc_feats = all_class_loc_feats[selected_indices]
# context_locs = all_class_locs[selected_indices]
#
# # Check if context_loc_feats has 1 dimension and add another if it does
# if context_loc_feats.dim() == 1:
# context_loc_feats = context_loc_feats.unsqueeze(0)
# # Check if context_locs has 1 dimension and add another if it does
# if context_locs.dim() == 1:
# context_locs = context_locs.unsqueeze(0)
#
# context_sequence = torch.cat([class_token_feature, context_loc_feats], dim=0)
#
# # text stuff
# # Get the embedding for the right class
# if class_id in self.text_class_emb_dict:
# text_embs_indexes, descriptions = self.text_class_emb_dict[class_id]
# # Randomly select an index from the list of indices
# selected_index = random.choice(text_embs_indexes)
# # Use the selected index to retrieve the corresponding element from embs
# text_emb = self.text_embs[selected_index]
# else:
# # If the class_id_int is not in the dictionary, set emb to all zeros for filtering later?
# # emb = None
# text_emb = torch.zeros(self.text_emb_size)
#
# # image stuff
# # Get the embedding for the right class
# if class_id in self.image_class_emb_dict:
# image_embs_indexes, descriptions = self.image_class_emb_dict[class_id]
# # Randomly select an index from the list of indices
# selected_index = random.choice(image_embs_indexes)
# # Use the selected index to retrieve the corresponding element from embs
# image_emb = self.image_embs[selected_index]
# else:
# # If the class_id_int is not in the dictionary, set emb to all zeros for filtering later?
# # emb = None
# image_emb = torch.zeros(self.image_emb_size)
#
# return loc_feat, loc, class_id, context_sequence, context_locs, text_emb, image_emb
#
# def select_text_section(self, text_section):
# # Initialize an empty dictionary to store the result
# text_class_emb_dict = {}
# # Populate the dictionary
# for i, (index, description) in enumerate(self.text_embs_keys):
# # Find the class using the index from the class_list
# class_id = self.text_embs_class_ids[index]
# # Skip this iteration if class_id is -1 - which should correspond to classes not in dataset
# if class_id == -1:
# continue
# if description != text_section:
# continue
# # Check if the class_id is already a key in the dictionary
# if class_id not in text_class_emb_dict:
# # Initialize with empty lists if class_id is not already in the dictionary
# text_class_emb_dict[class_id] = ([], [])
#
# # Append the description and the index of embs_keys to the corresponding lists
# text_class_emb_dict[class_id][0].append(i)
# text_class_emb_dict[class_id][1].append(description)
# self.text_class_emb_dict = text_class_emb_dict
# MINE MINE MINE MINE - should only be used for my models
# This eval version always has only a single example from each class (with the appropriate number of context points),
# so iterating through the dataset lets us create the model's "eval embeddings" for every class.
# Each eval embedding is generated from a single forward pass over "num_context" context points.
# A brief usage sketch follows the class definition below.
class EvalTransformerLocationDataset(torch.utils.data.Dataset):
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0,
noise_time=False, num_context=50, transformer_input_enc=None, token_dim=None, jitter=False,
variable_context_length=False):
# Handle input encoding
self.input_enc = input_enc
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# Handle transformer input encoding
self.transformer_input_enc = transformer_input_enc
if self.transformer_input_enc in ['env', 'sin_cos_env']:
transformer_raster = load_env()
else:
transformer_raster = None
if self.transformer_input_enc == 'sinr':
self.transformer_enc = self.enc
else:
self.transformer_enc = utils.CoordEncoder(
transformer_input_enc, transformer_raster, input_dim=token_dim)
# Define properties
self.locs = locs # Keep on CPU
self.labels = labels # Keep on CPU
self.classes = classes
self.class_to_taxa = class_to_taxa
# Normalize locs and create loc_feats
self.loc_feats = self.enc.encode(self.locs, normalize=True)
transformer_loc_feats = self.transformer_enc.encode(
self.locs, normalize=False)
if dates is not None:
self.dates = dates
self.enc_time = utils.TimeEncoder()
# Useful numbers
self.num_classes = len(np.unique(labels))
self.input_dim = input_dim
self.time_dim = time_dim
self.noise_time = noise_time
self.num_context = num_context
self.token_dim = token_dim
# Note: rasters are deliberately kept on CPU here (no .to(device) calls)
# Organize data into dictionaries
per_class_location_dict = organize_data_by_labels(
np.array(labels), np.array(locs))
per_class_loc_feats_dict = organize_data_by_labels(
np.array(labels), np.array(transformer_loc_feats))
for key, value in per_class_location_dict.items():
per_class_location_dict[key] = torch.tensor(
np.array(value)) # Keep on CPU
for key, value in per_class_loc_feats_dict.items():
per_class_loc_feats_dict[key] = torch.tensor(
np.array(value)) # Keep on CPU
self.per_class_locs = per_class_location_dict
self.per_class_loc_feats = per_class_loc_feats_dict
# Select a single example per class
unique_labels, unique_indices = np.unique(labels, return_index=True)
self.locs = locs[unique_indices]
self.labels = labels[unique_indices]
self.loc_feats = self.loc_feats[unique_indices]
def __len__(self):
return self.loc_feats.shape[0]
def __getitem__(self, index):
# Retrieve feature and class of the original point
loc_feat = self.loc_feats[index, :] # On CPU
loc = self.locs[index, :] # On CPU
class_id = self.labels[index] # On CPU
class_id_int = class_id.item()
# Fetch all locations for the class
all_class_locs = self.per_class_locs[class_id_int] # On CPU
all_class_loc_feats = self.per_class_loc_feats[class_id_int] # On CPU
# Create a zero feature vector to act as the class token
class_token_feature = torch.zeros(
(1, len(all_class_loc_feats[0]))) # CPU tensor
# Find the index of the original location
matches = (all_class_locs == loc).all(dim=1)
local_index = torch.where(matches)[0]
if len(local_index) > 1:
local_index = local_index[0]
# Exclude the original location's index
filtered_local_indices = torch.arange(
len(all_class_locs)) != local_index
# Select indices for context
if filtered_local_indices.sum() > self.num_context:
selected_indices = filtered_local_indices.nonzero().squeeze()
# No shuffling for evaluation
selected_indices = selected_indices[:self.num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
context_sequence = torch.cat(
[class_token_feature, context_loc_feats], dim=0)
return loc_feat, loc, class_id, context_sequence, context_locs
def collate_fn(self, batch):
# Unpack the batch
loc_feats, locs, class_ids, context_sequences, context_locss = zip(*batch)
# Pad sequences
padded_sequences = pad_sequence(
context_sequences, batch_first=True, padding_value=-10)
padded_context_locs = pad_sequence(
context_locss, batch_first=True, padding_value=-10)
# Convert lists to tensors
loc_feats = torch.stack(loc_feats)
locs = torch.stack(locs)
class_ids = torch.tensor(class_ids)
# Create a mask for sequences based on padding
sequence_mask = (padded_sequences == -10).all(dim=-1)
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask
def get_item_from_class(self, class_id):
# Fetch locations and features for the class
all_class_locs = self.per_class_locs[class_id]
all_class_loc_feats = self.per_class_loc_feats[class_id]
# Randomly select an index
index = np.random.choice(len(all_class_locs))
# Retrieve selected location and features
loc = all_class_locs[index]
if loc.ndim == 1:
loc = loc.unsqueeze(0)
loc_feat = self.enc.encode(loc, normalize=False)
# Create a zero feature vector to act as the class token
class_token_feature = torch.zeros((1, self.token_dim)) # CPU tensor
# Exclude selected index from context
filtered_local_indices = torch.arange(len(all_class_locs)) != index
# Select indices for context
if filtered_local_indices.sum() > self.num_context:
selected_indices = filtered_local_indices.nonzero().squeeze()
perm = torch.randperm(selected_indices.size(0))
selected_indices = selected_indices[perm][:self.num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
context_sequence = torch.cat(
[class_token_feature, context_loc_feats], dim=0)
return loc_feat, loc, class_id, context_sequence, context_locs
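# Sketch of how per-class eval embeddings could be produced with this dataset (all
# names below are assumptions; the real eval loop lives in the training/eval code):
# eval_loader = torch.utils.data.DataLoader(eval_ds, batch_size=128,
#                                           collate_fn=eval_ds.collate_fn)
# with torch.no_grad():
#     for _, _, class_ids, seqs, ctx_locs, mask in eval_loader:
#         # one forward pass per class; `model.embed_context` is a hypothetical method
#         class_embs = model.embed_context(seqs.to(device), mask.to(device))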
# MINE MINE MINE MINE - should only be used for my models
# This eval version always has only a single example from each class (with the appropriate number of context points),
# so iterating through the dataset lets us create the model's "eval embeddings" for every class.
# Each eval embedding is generated from a single forward pass over "num_context" context points + text.
class EvalTransformerTextLocationDataset(torch.utils.data.Dataset):
def __init__(self, locs, labels, classes, class_to_taxa, embs, embs_ids, embs_keys, input_enc, device, dates=None,
input_dim=4, time_dim=0, noise_time=False, num_context=50, transformer_input_enc=None,
token_dim=None, jitter=False, variable_context_length=False):
# Handle input encoding
self.input_enc = input_enc
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# Handle transformer input encoding
self.transformer_input_enc = transformer_input_enc
if self.transformer_input_enc in ['env', 'sin_cos_env']:
transformer_raster = load_env()
else:
transformer_raster = None
if self.transformer_input_enc == 'sinr':
self.transformer_enc = self.enc
else:
self.transformer_enc = utils.CoordEncoder(
transformer_input_enc, transformer_raster, input_dim=token_dim)
# Define properties
self.locs = locs # Keep on CPU
self.labels = labels # Keep on CPU
self.classes = classes
self.class_to_taxa = class_to_taxa
# Normalize locs and create loc_feats
self.loc_feats = self.enc.encode(self.locs, normalize=True)
transformer_loc_feats = self.transformer_enc.encode(
self.locs, normalize=False)
if dates is not None:
self.dates = dates
self.enc_time = utils.TimeEncoder()
# Useful numbers
self.num_classes = len(np.unique(labels))
self.input_dim = input_dim
self.time_dim = time_dim
self.noise_time = noise_time
self.num_context = num_context
self.token_dim = token_dim
# Remove device assignments from rasters
# if self.enc.raster is not None:
# self.enc.raster = self.enc.raster.to(device)
# if self.transformer_enc.raster is not None:
# self.transformer_enc.raster = self.transformer_enc.raster.to(device)
# Text embeddings
self.embs = embs # Keep on CPU
self.embs_ids = embs_ids.tolist()
self.embs_class_ids = [class_to_taxa.index(
taxa) if taxa in class_to_taxa else -1 for taxa in self.embs_ids]
self.embs_keys = embs_keys
# Initialize class embedding dictionary
class_emb_dict = {}
for i, (index, description) in enumerate(embs_keys):
class_id = self.embs_class_ids[index]
if class_id == -1:
continue
if class_id not in class_emb_dict:
class_emb_dict[class_id] = ([], [])
class_emb_dict[class_id][0].append(i)
class_emb_dict[class_id][1].append(description)
self.class_emb_dict = class_emb_dict
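        # class_emb_dict maps class_id -> ([indices into self.embs], [text-section descriptions]),
        # e.g. {7: ([3, 41], ['range', 'habitat'])} (values here are purely illustrative)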
# Organize data into dictionaries
per_class_location_dict = organize_data_by_labels(
np.array(labels), np.array(locs))
per_class_loc_feats_dict = organize_data_by_labels(
np.array(labels), np.array(transformer_loc_feats))
for key, value in per_class_location_dict.items():
per_class_location_dict[key] = torch.tensor(
np.array(value)) # Keep on CPU
for key, value in per_class_loc_feats_dict.items():
per_class_loc_feats_dict[key] = torch.tensor(
np.array(value)) # Keep on CPU
self.per_class_locs = per_class_location_dict
self.per_class_loc_feats = per_class_loc_feats_dict
# Select a single example per class
unique_labels, unique_indices = np.unique(labels, return_index=True)
self.locs = locs[unique_indices]
self.labels = labels[unique_indices]
self.loc_feats = self.loc_feats[unique_indices]
def __len__(self):
return self.loc_feats.shape[0]
def __getitem__(self, index):
# Retrieve feature and class of the original point
loc_feat = self.loc_feats[index, :] # On CPU
loc = self.locs[index, :] # On CPU
class_id = self.labels[index] # On CPU
class_id_int = class_id.item()
# Fetch all locations for the class
all_class_locs = self.per_class_locs[class_id_int] # On CPU
all_class_loc_feats = self.per_class_loc_feats[class_id_int] # On CPU
        # Create a zero feature vector for the class token (prepended to the context sequence)
class_token_feature = torch.zeros(
(1, len(all_class_loc_feats[0]))) # CPU tensor
# Find the index of the original location
matches = (all_class_locs == loc).all(dim=1)
local_index = torch.where(matches)[0]
if len(local_index) > 1:
local_index = local_index[0]
# Exclude the original location's index
filtered_local_indices = torch.arange(
len(all_class_locs)) != local_index
# Select indices for context
if filtered_local_indices.sum() > self.num_context:
selected_indices = filtered_local_indices.nonzero().squeeze()
# No shuffling for evaluation
selected_indices = selected_indices[:self.num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
context_sequence = torch.cat(
[class_token_feature, context_loc_feats], dim=0)
# Text embeddings
if class_id_int in self.class_emb_dict:
embs_indexes, descriptions = self.class_emb_dict[class_id_int]
selected_index = random.choice(embs_indexes)
emb = self.embs[selected_index]
else:
emb = torch.zeros(4096) # On CPU
return loc_feat, loc, class_id, context_sequence, context_locs, emb
def collate_fn(self, batch):
# Unpack the batch
loc_feats, locs, class_ids, context_sequences, context_locss, embs = zip(*batch)
# Pad sequences
padded_sequences = pad_sequence(
context_sequences, batch_first=True, padding_value=-10)
padded_context_locs = pad_sequence(
context_locss, batch_first=True, padding_value=-10)
# Convert lists to tensors
loc_feats = torch.stack(loc_feats)
locs = torch.stack(locs)
class_ids = torch.tensor(class_ids)
embs = torch.stack(embs)
# Create a mask for sequences based on padding
sequence_mask = (padded_sequences == -10).all(dim=-1)
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, embs
def get_item_from_class(self, class_id):
# Fetch locations and features for the class
all_class_locs = self.per_class_locs[class_id]
all_class_loc_feats = self.per_class_loc_feats[class_id]
# Randomly select an index
index = np.random.choice(len(all_class_locs))
# Retrieve selected location and features
loc = all_class_locs[index]
if loc.ndim == 1:
loc = loc.unsqueeze(0)
loc_feat = self.enc.encode(loc, normalize=False)
        # Create a zero feature vector for the class token
class_token_feature = torch.zeros((1, self.token_dim)) # CPU tensor
# Exclude selected index from context
filtered_local_indices = torch.arange(len(all_class_locs)) != index
# Select indices for context
if filtered_local_indices.sum() > self.num_context:
selected_indices = filtered_local_indices.nonzero().squeeze()
perm = torch.randperm(selected_indices.size(0))
selected_indices = selected_indices[perm][:self.num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
context_sequence = torch.cat(
[class_token_feature, context_loc_feats], dim=0)
# Text embeddings
if class_id in self.class_emb_dict:
embs_indexes, descriptions = self.class_emb_dict[class_id]
selected_index = random.choice(embs_indexes)
emb = self.embs[selected_index]
else:
emb = torch.zeros(4096) # On CPU
return loc_feat, loc, class_id, context_sequence, context_locs, emb
def select_text_section(self, text_section):
# Initialize an empty dictionary to store the result
class_emb_dict = {}
# Populate the dictionary
for i, (index, description) in enumerate(self.embs_keys):
# Find the class using the index from the class_list
class_id = self.embs_class_ids[index]
            # Skip this iteration if class_id is -1, which corresponds to taxa not present in the dataset
if class_id == -1:
continue
if description != text_section:
continue
# Check if the class_id is already a key in the dictionary
if class_id not in class_emb_dict:
# Initialize with empty lists if class_id is not already in the dictionary
class_emb_dict[class_id] = ([], [])
# Append the description and the index of embs_keys to the corresponding lists
class_emb_dict[class_id][0].append(i)
class_emb_dict[class_id][1].append(description)
self.class_emb_dict = class_emb_dict
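# Illustrative usage sketch for the eval dataset above (commented out, not executed here;
# the batch size, encoder names, and token_dim are placeholder values, not project defaults):
# eval_ds = EvalTransformerTextLocationDataset(locs, labels, classes, class_to_taxa,
#                                              embs, embs_ids, embs_keys,
#                                              input_enc='sin_cos', device='cpu',
#                                              num_context=50, transformer_input_enc='sinr',
#                                              token_dim=256)
# loader = torch.utils.data.DataLoader(eval_ds, batch_size=64, shuffle=False,
#                                      collate_fn=eval_ds.collate_fn)
# for loc_feats, locs_b, class_ids, ctx_seqs, ctx_locs, pad_mask, text_embs in loader:
#     pass  # one forward pass per class builds that class's eval embedding
# The following class is a variant of the eval dataset above whose __getitem__ always returns a
# zero "dummy" text embedding, so runs can be made without real text embeddings (e.g. for ablations).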
class EvalTransformerDummyTextLocationDataset(torch.utils.data.Dataset):
def __init__(self, locs, labels, classes, class_to_taxa, embs, embs_ids, embs_keys, input_enc, device, dates=None,
input_dim=4, time_dim=0, noise_time=False, num_context=50, transformer_input_enc=None,
token_dim=None, jitter=False, variable_context_length=False):
# Handle input encoding
self.input_enc = input_enc
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# Handle transformer input encoding
self.transformer_input_enc = transformer_input_enc
if self.transformer_input_enc in ['env', 'sin_cos_env']:
transformer_raster = load_env()
else:
transformer_raster = None
if self.transformer_input_enc == 'sinr':
self.transformer_enc = self.enc
else:
self.transformer_enc = utils.CoordEncoder(
transformer_input_enc, transformer_raster, input_dim=token_dim)
# Define properties
self.locs = locs # Keep on CPU
self.labels = labels # Keep on CPU
self.classes = classes
self.class_to_taxa = class_to_taxa
# Normalize locs and create loc_feats
self.loc_feats = self.enc.encode(self.locs, normalize=True)
transformer_loc_feats = self.transformer_enc.encode(
self.locs, normalize=False)
if dates is not None:
self.dates = dates
self.enc_time = utils.TimeEncoder()
# Useful numbers
self.num_classes = len(np.unique(labels))
self.input_dim = input_dim
self.time_dim = time_dim
self.noise_time = noise_time
self.num_context = num_context
self.token_dim = token_dim
# Remove device assignments from rasters
# if self.enc.raster is not None:
# self.enc.raster = self.enc.raster.to(device)
# if self.transformer_enc.raster is not None:
# self.transformer_enc.raster = self.transformer_enc.raster.to(device)
# Text embeddings
self.embs = embs # Keep on CPU
self.embs_ids = embs_ids.tolist()
self.embs_class_ids = [class_to_taxa.index(
taxa) if taxa in class_to_taxa else -1 for taxa in self.embs_ids]
self.embs_keys = embs_keys
# Initialize class embedding dictionary
class_emb_dict = {}
for i, (index, description) in enumerate(embs_keys):
class_id = self.embs_class_ids[index]
if class_id == -1:
continue
if class_id not in class_emb_dict:
class_emb_dict[class_id] = ([], [])
class_emb_dict[class_id][0].append(i)
class_emb_dict[class_id][1].append(description)
self.class_emb_dict = class_emb_dict
# Organize data into dictionaries
per_class_location_dict = organize_data_by_labels(
np.array(labels), np.array(locs))
per_class_loc_feats_dict = organize_data_by_labels(
np.array(labels), np.array(transformer_loc_feats))
for key, value in per_class_location_dict.items():
per_class_location_dict[key] = torch.tensor(
np.array(value)) # Keep on CPU
for key, value in per_class_loc_feats_dict.items():
per_class_loc_feats_dict[key] = torch.tensor(
np.array(value)) # Keep on CPU
self.per_class_locs = per_class_location_dict
self.per_class_loc_feats = per_class_loc_feats_dict
# Select a single example per class
unique_labels, unique_indices = np.unique(labels, return_index=True)
self.locs = locs[unique_indices]
self.labels = labels[unique_indices]
self.loc_feats = self.loc_feats[unique_indices]
def __len__(self):
return self.loc_feats.shape[0]
def __getitem__(self, index):
# Retrieve feature and class of the original point
loc_feat = self.loc_feats[index, :] # On CPU
loc = self.locs[index, :] # On CPU
class_id = self.labels[index] # On CPU
class_id_int = class_id.item()
# Fetch all locations for the class
all_class_locs = self.per_class_locs[class_id_int] # On CPU
all_class_loc_feats = self.per_class_loc_feats[class_id_int] # On CPU
        # Create a zero feature vector for the class token (prepended to the context sequence)
class_token_feature = torch.zeros(
(1, len(all_class_loc_feats[0]))) # CPU tensor
# Find the index of the original location
matches = (all_class_locs == loc).all(dim=1)
local_index = torch.where(matches)[0]
if len(local_index) > 1:
local_index = local_index[0]
# Exclude the original location's index
filtered_local_indices = torch.arange(
len(all_class_locs)) != local_index
# Select indices for context
if filtered_local_indices.sum() > self.num_context:
selected_indices = filtered_local_indices.nonzero().squeeze()
# No shuffling for evaluation
selected_indices = selected_indices[:self.num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
context_sequence = torch.cat(
[class_token_feature, context_loc_feats], dim=0)
# Text embeddings
emb = torch.zeros(4096) # On CPU
return loc_feat, loc, class_id, context_sequence, context_locs, emb
def collate_fn(self, batch):
# Unpack the batch
loc_feats, locs, class_ids, context_sequences, context_locss, embs = zip(*batch)
# Pad sequences
padded_sequences = pad_sequence(
context_sequences, batch_first=True, padding_value=-10)
padded_context_locs = pad_sequence(
context_locss, batch_first=True, padding_value=-10)
# Convert lists to tensors
loc_feats = torch.stack(loc_feats)
locs = torch.stack(locs)
class_ids = torch.tensor(class_ids)
embs = torch.stack(embs)
# Create a mask for sequences based on padding
sequence_mask = (padded_sequences == -10).all(dim=-1)
return loc_feats, locs, class_ids, padded_sequences, padded_context_locs, sequence_mask, embs
def get_item_from_class(self, class_id):
# Fetch locations and features for the class
all_class_locs = self.per_class_locs[class_id]
all_class_loc_feats = self.per_class_loc_feats[class_id]
# Randomly select an index
index = np.random.choice(len(all_class_locs))
# Retrieve selected location and features
loc = all_class_locs[index]
if loc.ndim == 1:
loc = loc.unsqueeze(0)
loc_feat = self.enc.encode(loc, normalize=False)
        # Create a zero feature vector for the class token
class_token_feature = torch.zeros((1, self.token_dim)) # CPU tensor
# Exclude selected index from context
filtered_local_indices = torch.arange(len(all_class_locs)) != index
# Select indices for context
if filtered_local_indices.sum() > self.num_context:
selected_indices = filtered_local_indices.nonzero().squeeze()
perm = torch.randperm(selected_indices.size(0))
selected_indices = selected_indices[perm][:self.num_context]
else:
selected_indices = filtered_local_indices.nonzero().squeeze()
# Get context locations and features
context_loc_feats = all_class_loc_feats[selected_indices]
context_locs = all_class_locs[selected_indices]
# Adjust dimensions if necessary
if context_loc_feats.dim() == 1:
context_loc_feats = context_loc_feats.unsqueeze(0)
if context_locs.dim() == 1:
context_locs = context_locs.unsqueeze(0)
context_sequence = torch.cat(
[class_token_feature, context_loc_feats], dim=0)
# Text embeddings
if class_id in self.class_emb_dict:
embs_indexes, descriptions = self.class_emb_dict[class_id]
selected_index = random.choice(embs_indexes)
emb = self.embs[selected_index]
else:
emb = torch.zeros(4096) # On CPU
return loc_feat, loc, class_id, context_sequence, context_locs, emb
def select_text_section(self, text_section):
# Initialize an empty dictionary to store the result
class_emb_dict = {}
# Populate the dictionary
for i, (index, description) in enumerate(self.embs_keys):
# Find the class using the index from the class_list
class_id = self.embs_class_ids[index]
            # Skip this iteration if class_id is -1, which corresponds to taxa not present in the dataset
if class_id == -1:
continue
if description != text_section:
continue
# Check if the class_id is already a key in the dictionary
if class_id not in class_emb_dict:
# Initialize with empty lists if class_id is not already in the dictionary
class_emb_dict[class_id] = ([], [])
# Append the description and the index of embs_keys to the corresponding lists
class_emb_dict[class_id][0].append(i)
class_emb_dict[class_id][1].append(description)
self.class_emb_dict = class_emb_dict
def load_env():
with open('paths.json', 'r') as f:
paths = json.load(f)
raster = load_context_feats(os.path.join(paths['env'],'bioclim_elevation_scaled.npy'))
return raster
def load_context_feats(data_path):
context_feats = np.load(data_path).astype(np.float32)
context_feats = torch.from_numpy(context_feats)
return context_feats
# MAX MAX MAX MAX MAX MAX
_file_cache = {}
def load_inat_data(ip_file, taxa_of_interest=None):
if os.path.exists('.datacache.pt'):
print('\nLoading cached data')
if '.datacache.pt' not in _file_cache:
# If not in the cache, read the file and store its content in the cache
#_file_cache['.datacache.pt'] = torch.load('.datacache.pt', weights_only=False)
_file_cache['.datacache.pt'] = torch.load('.datacache.pt')
locs, taxa, users, dates, years, obs_ids = _file_cache['.datacache.pt']
else:
print('\nLoading ' + ip_file)
data = pd.read_csv(ip_file)
# remove outliers
num_obs = data.shape[0]
data = data[((data['latitude'] <= 90) & (data['latitude'] >= -90) & (data['longitude'] <= 180) & (data['longitude'] >= -180) )]
if (num_obs - data.shape[0]) > 0:
print(num_obs - data.shape[0], 'items filtered due to invalid locations')
if 'accuracy' in data.columns:
data.drop(['accuracy'], axis=1, inplace=True)
if 'positional_accuracy' in data.columns:
data.drop(['positional_accuracy'], axis=1, inplace=True)
if 'geoprivacy' in data.columns:
data.drop(['geoprivacy'], axis=1, inplace=True)
if 'observed_on' in data.columns:
data.rename(columns = {'observed_on':'date'}, inplace=True)
num_obs_orig = data.shape[0]
data = data.dropna()
size_diff = num_obs_orig - data.shape[0]
if size_diff > 0:
print(size_diff, 'observation(s) with a NaN entry out of' , num_obs_orig, 'removed')
# keep only taxa of interest:
if taxa_of_interest is not None:
num_obs_orig = data.shape[0]
data = data[data['taxon_id'].isin(taxa_of_interest)]
print(num_obs_orig - data.shape[0], 'observation(s) out of' , num_obs_orig, 'from different taxa removed')
print('Number of unique classes {}'.format(np.unique(data['taxon_id'].values).shape[0]))
locs = np.vstack((data['longitude'].values, data['latitude'].values)).T.astype(np.float32)
taxa = data['taxon_id'].values.astype(np.int64)
if 'user_id' in data.columns:
users = data['user_id'].values.astype(np.int64)
_, users = np.unique(users, return_inverse=True)
elif 'observer_id' in data.columns:
users = data['observer_id'].values.astype(np.int64)
_, users = np.unique(users, return_inverse=True)
else:
users = np.ones(taxa.shape[0], dtype=np.int64)*-1
# Note - assumes that dates are in format YYYY-MM-DD
temp = np.array(data['date'], dtype='S10')
temp = temp.view('S1').reshape((temp.size, -1))
years = temp[:,:4].view('S4').astype(int)[:,0]
months = temp[:,5:7].view('S2').astype(int)[:,0]
days = temp[:,8:10].view('S2').astype(int)[:,0]
days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)])
dates = days_per_month[months-1] + days-1
dates = np.round((dates) / 364.0, 4).astype(np.float32)
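        # e.g. a date of '2018-03-05' gives months=3, days=5, so
        # dates = days_per_month[2] + 5 - 1 = 59 + 4 = 63 -> 63/364 ≈ 0.1731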
if 'id' in data.columns:
obs_ids = data['id'].values
elif 'observation_uuid' in data.columns:
obs_ids = data['observation_uuid'].values
torch.save((locs, taxa, users, dates, years, obs_ids), '.datacache.pt')
return locs, taxa, users, dates, years, obs_ids
def load_eval_inat_data(ip_file):
if os.path.exists('.eval_datacache.pt'):
print('\nLoading cached eval data')
if '.eval_datacache.pt' not in _file_cache:
# If not in the cache, read the file and store its content in the cache
#_file_cache['.eval_datacache.pt'] = torch.load('.eval_datacache.pt', weights_only=False)
_file_cache['.eval_datacache.pt'] = torch.load('.eval_datacache.pt')
locs, taxa, users, dates, years, obs_ids = _file_cache['.eval_datacache.pt']
else:
print('\nLoading ' + ip_file)
# data = pd.read_csv(ip_file)
data = np.load(ip_file, allow_pickle=True)
locs = data['locs']
labels = data['labels']
class_to_taxa = data['class_to_taxa']
classes = data['classes'].item()
# outliers already removed
# create 'taxa'
# taxa = [class_to_taxa[clss] for clss in labels]
taxa = class_to_taxa[labels]
print('Warning: Setting default value (-1) for users, dates, years, obs_ids')
users = np.ones(taxa.shape[0], dtype=np.int64)*-1
dates = np.ones(taxa.shape[0], dtype=np.int64)*-1
years = np.ones(taxa.shape[0], dtype=np.int64)*-1
obs_ids = np.ones(taxa.shape[0], dtype=np.int64)*-1
torch.save((locs, taxa, users, dates, years, obs_ids), '.eval_datacache.pt')
return locs, taxa, users, dates, years, obs_ids
def choose_aux_species(current_species, num_aux_species, aux_species_seed, taxa_file):
if num_aux_species == 0:
return []
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
taxa_file = os.path.join(data_dir, taxa_file)
with open(taxa_file, 'r') as f:
inat_large_metadata = json.load(f)
aux_species_candidates = [x['taxon_id'] for x in inat_large_metadata]
aux_species_candidates = np.setdiff1d(aux_species_candidates, current_species)
print(f'choosing {num_aux_species} species to add from {len(aux_species_candidates)} candidates')
rng = np.random.default_rng(aux_species_seed)
idx_rand_aux_species = rng.permutation(len(aux_species_candidates))
aux_species = list(aux_species_candidates[idx_rand_aux_species[:num_aux_species]])
return aux_species
def get_taxa_of_interest(species_set='all', num_aux_species=0, aux_species_seed=123, taxa_file=None, taxa_file_snt=None):
if species_set == 'all':
return None
if species_set == 'snt_birds':
assert taxa_file_snt is not None
with open(taxa_file_snt, 'r') as f: #
taxa_subsets = json.load(f)
taxa_of_interest = list(taxa_subsets['snt_birds'])
else:
raise NotImplementedError
# optionally add some other species back in:
aux_species = choose_aux_species(taxa_of_interest, num_aux_species, aux_species_seed, taxa_file)
taxa_of_interest.extend(aux_species)
return taxa_of_interest
def get_idx_subsample_observations(labels, hard_cap=-1, hard_cap_seed=123, subset=None, subset_cap=-1):
if hard_cap == -1:
if subset_cap != -1:
raise NotImplementedError('subset_cap set but not hard_cap')
return np.arange(len(labels))
print(f'subsampling (up to) {hard_cap} per class for the training set')
ids, counts = np.unique(labels, return_counts=True)
count_ind = np.cumsum(counts)
count_ind[1:] = count_ind[:-1]
count_ind[0] = 0
ss_rng = np.random.default_rng(hard_cap_seed)
idx_rand = ss_rng.permutation(len(labels))
ordered_inds = np.argsort(labels[idx_rand], kind='stable')
caps = hard_cap + np.zeros_like(counts)
if subset is not None and subset_cap != -1:
caps[subset] = subset_cap
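    # labels[idx_rand][ordered_inds] is grouped by class (stable sort of a random permutation), so
    # slicing each group [count_ind[c] : count_ind[c] + min(counts[c], caps[c])] keeps at most `cap`
    # randomly chosen observations per class; idx_rand then maps positions back to original indices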
idx_ss = idx_rand[np.concatenate([ordered_inds[i:i+min(limit, cap)] for i, limit, cap in zip(count_ind, counts, caps)])]
print(f'final training set size: {len(idx_ss)}')
return idx_ss
def get_idx_subsample_observations_eval(labels, hard_cap=-1):
if hard_cap == -1:
return np.arange(len(labels))
print(f'Selecting (up to) {hard_cap} per class for the training set')
labels = np.asarray(labels)
unique_labels = np.unique(labels)
idx_list = []
for label in unique_labels:
idx = np.where(labels == label)[0]
cap = min(len(idx), hard_cap)
idx_list.append(idx[:cap])
idx_ss = np.concatenate(idx_list)
print(f'Final training set size: {len(idx_ss)}')
return idx_ss
def uniform_sample_h3(cells, low, high):
    '''Uniformly sample one point inside each of a batch of h3 cells.
    Points are drawn uniformly within each cell's lat/lng bounding box and
    redrawn (rejection sampling) until they fall inside the corresponding cell.'''
out = np.empty((len(cells), 2))
invalid_mask = np.arange(len(cells))
cell_ids_buffer = np.empty(len(cells), dtype='uint64')
while len(invalid_mask) > 0:
#print(len(invalid_mask))
pts = np.random.random((len(invalid_mask), 2))
pts = high + pts*(low - high)
cell_ids_buffer = h3.latlng_to_cell(pts[:,0], pts[:,1], 5)
valid_mask = (cell_ids_buffer[:len(cells)] == cells)
out[invalid_mask[valid_mask]] = pts[valid_mask]
neg_mask = ~valid_mask
invalid_mask = invalid_mask[neg_mask]
low = low[neg_mask]
high = high[neg_mask]
cells = cells[neg_mask]
return out
class LocationIUCNDataset(torch.utils.data.Dataset):
# MINE MINE MINE MINE - I have included this dummy "num_context" which is probably not needed now
# def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False, num_context=0):
# MAX MAX MAX MAX
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False):
# handle input encoding:
self.input_enc = input_enc
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
if os.path.exists('iucndataset_nocap.pt'):
mask = torch.load('iucndataset_nocap.pt')
else:
from tqdm import tqdm
# load iucn data
with open('paths.json', 'r') as f:
paths = json.load(f)
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
data = json.load(f)
obs_locs = np.array(data['locs'], dtype=np.float32)
taxa = {int(tt):tk for tt, tk in data['taxa_presence'].items()}
cells = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5)
mask = np.zeros(len(locs), dtype=bool)
for i, data_taxa in tqdm(enumerate(class_to_taxa)):
if data_taxa in taxa:
data_cells = h3.latlng_to_cell(locs[labels==i, 1], locs[labels==i, 0], 5)
data = np.array(cells[taxa[data_taxa]])
data_inds = data.argsort()
search = np.searchsorted(data[data_inds], data_cells)
search = search.clip(0, len(data)-1)
mask[labels==i] = data[data_inds][search] == data_cells
else:
mask[labels==i] = False
torch.save(mask, 'iucndataset_nocap.pt')
print('Reduced Size: ', mask.sum())
# remove locations that are not in the iucn dataset
locs = locs[mask]
labels = labels[mask]
if dates is not None:
dates = dates[mask]
labels_uniq, labels = np.unique(labels, return_inverse=True)
classes = {class_to_taxa[i]: classes[class_to_taxa[i]] for i in labels_uniq}
class_to_taxa = [class_to_taxa[i] for i in labels_uniq]
# define some properties:
self.locs = locs
self.loc_feats = self.enc.encode(self.locs)
self.labels = torch.from_numpy(labels)
self.classes = classes
self.class_to_taxa = class_to_taxa
if dates is not None:
self.dates = dates
self.enc_time = utils.TimeEncoder()
# useful numbers:
self.num_classes = len(classes)
self.input_dim = input_dim
self.time_dim = time_dim
self.noise_time = noise_time
        # note: time handling is dispatched inside __getitem__ below; assigning
        # self.__getitem__ on the instance would be ignored because Python looks up
        # special methods on the class, not the instance
if self.enc.raster is not None:
self.enc.raster = self.enc.raster.to(device)
def viz_map(self, taxa_id, high_res=False):
from matplotlib import pyplot as plt
# load params
with open('paths.json', 'r') as f:
paths = json.load(f)
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
data = json.load(f)
obs_locs = np.array(data['locs'], dtype=np.float32)
taxa = {int(tt): tk for tt, tk in data['taxa_presence'].items()}
taxa_cells = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5)
# load taxa of interest
if taxa_id in self.class_to_taxa:
class_of_interest = self.class_to_taxa.index(taxa_id)
else:
print(f'Error: Taxa specified that is not in the model: {taxa_id}')
return False
print(f'Loading taxa: {taxa_id}')
# load ocean mask
if high_res:
mask = np.load(os.path.join(paths['masks'], 'ocean_mask_hr.npy'))
else:
mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
mask_shape = mask.shape
mask_inds = np.where(mask.reshape(-1) == 1)[0]
# generate input features
locs = utils.coord_grid(mask_shape)
locs_cells = h3.latlng_to_cell(locs[:,1], locs[:,0], 5)
# mask iucn
iucn_cells = np.sort(taxa_cells[taxa[taxa_id]])
mask = iucn_cells[np.searchsorted(iucn_cells, locs_cells).clip(max=len(iucn_cells)-1)] == locs_cells
mask_inds = np.where(mask == 1)[0]
mask = mask.reshape(mask_shape)
cell_inds, cell_counts = np.unique(h3.latlng_to_cell(90*self.locs[self.labels==class_of_interest, 1], 180*self.locs[self.labels==class_of_interest, 0], 5),return_counts=True)
search_inds = np.searchsorted(cell_inds, locs_cells).clip(max=len(cell_inds)-1)
preds = np.zeros(len(locs))
cell_mask = cell_inds[search_inds] == locs_cells
preds[cell_mask] = cell_counts[search_inds][cell_mask]
preds = preds/preds.sum()
# Convert preds to log scale
preds = np.log(preds)
center = np.median(preds[np.isfinite(preds)])
preds = 0.5*preds/(preds.max()-center)
preds = (preds + 1 - preds.max()).clip(min=0)
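        # this rescaling maps the maximum log-density to 1 and the median of the finite
        # log-densities to 0.5, with lower values clipped to 0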
# mask data
op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan # set to NaN
op_im[mask_inds] = preds[mask_inds]
# reshape and create masked array for visualization
op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
op_im = np.ma.masked_invalid(op_im)
# set color for masked values
cmap = plt.cm.plasma
cmap.set_bad(color='none')
vmax = np.max(op_im)
# save image
save_loc = os.path.join('./images/', str(taxa_id) + '_map.png')
print(f'Saving image to {save_loc}')
plt.imsave(fname=save_loc, arr=op_im, vmin=0, vmax=vmax, cmap=cmap)
def __len__(self):
return self.loc_feats.shape[0]
    def __getitem__(self, index):
        if self.time_dim > 0:
            return self._get_item_time(index)
        return index
def collate_fn(self, batch):
if isinstance(batch[0], int):
loc_feat = self.loc_feats[batch, :]
loc = self.locs[batch, :]
class_id = self.labels[batch]
return loc_feat, loc, class_id
else:
return torch.utils.data.default_collate(batch)
def _get_item_time(self, index):
loc_feat = self.loc_feats[index, :]
loc = self.locs[index, :]
class_id = self.labels[index]
date = self.dates[index]
# add noise
if self.noise_time:
noise_level = random.random()
# noise = (2*random.random() - 1) * (0.5*(365 ** (noise_level - 1)))
noise = (2 * random.random() - 1) * (0.5 * noise_level)
loc_feat = torch.cat([loc_feat, self.enc_time.encode_fast([date.item() + noise, noise_level])])
else:
raise NotImplementedError()
loc_feat = torch.cat([loc_feat, torch.tensor(self.enc_time.encode([2 * date.item() - 1], normalize=False))])
return loc_feat, torch.cat([loc, date[None]]), class_id
class UniformH3Dataset(torch.utils.data.Dataset):
# MINE MINE MINE MINE - I have included this dummy "num_context" which is probably not needed now
# def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False, num_context=0, snt=False):
# MAX MAX MAX MAX
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False, snt=False):
if dates is not None or time_dim > 0 or noise_time:
raise NotImplementedError()
# handle input encoding:
self.input_enc = input_enc
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
self._enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# load h3 data:
with open('paths.json', 'r') as f:
paths = json.load(f)
if snt:
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
D = D.item()
loc_indices_per_species = D['loc_indices_per_species']
taxa = D['taxa']
loc_indices_per_species = {t:ls for t,ls in zip(taxa, loc_indices_per_species)}
obs_locs = D['obs_locs']
else:
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
data = json.load(f)
obs_locs = np.array(data['locs'], dtype=np.float32)
loc_indices_per_species = data['taxa_presence']
self.taxa = {int(tt):tk for tt, tk in loc_indices_per_species.items()}
self.cells = h3.latlng_to_cell(obs_locs[:,1], obs_locs[:,0], 5)
self.ind_to_cell = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5)
self.low_b = np.stack([np.array(h3.h3_to_geo_boundary(c)).min(axis=0) for c in self.cells])
self.high_b = np.stack([np.array(h3.h3_to_geo_boundary(c)).max(axis=0) for c in self.cells])
if os.path.exists('iucndataset_nocap.pt') and not snt:
mask = torch.load('iucndataset_nocap.pt')
elif os.path.exists('sntdataset_nocap.pt') and snt:
mask = torch.load('sntdataset_nocap.pt')
else:
from tqdm import tqdm
mask = np.zeros(len(locs), dtype=bool)
for i, data_taxa in tqdm(enumerate(class_to_taxa)):
if data_taxa in self.taxa:
data_cells = h3.latlng_to_cell(locs[labels==i, 1], locs[labels==i, 0], 5)
data = np.array(self.cells[self.taxa[data_taxa]])
data_inds = data.argsort()
search = np.searchsorted(data[data_inds], data_cells)
search = search.clip(0, len(data)-1)
mask[labels==i] = data[data_inds][search] == data_cells
else:
mask[labels==i] = False
torch.save(mask, 'sntdataset_nocap.pt' if snt else 'iucndataset_nocap.pt')
print('Reduced Size: ', mask.sum())
# remove locations that are not in the iucn dataset
locs = locs[mask]
labels = labels[mask]
labels_uniq, labels = np.unique(labels, return_inverse=True)
classes = {class_to_taxa[i]: classes[class_to_taxa[i]] for i in labels_uniq}
class_to_taxa = [class_to_taxa[i] for i in labels_uniq]
# calculate species statistics
_, counts = np.unique(labels, return_counts=True)
self.num_obs = counts.sum()
self.counts = counts / self.num_obs
# define some properties:
self.classes = classes
self.class_to_taxa = class_to_taxa
# useful numbers:
self.num_classes = len(self.class_to_taxa)
self.input_dim = input_dim
if self.enc.raster is not None:
self.enc.raster = self.enc.raster.to(device)
def __len__(self):
return self.num_obs
def __getitem__(self, index, species=None):
if species is None:
class_id = np.random.choice(np.arange(self.num_classes), p=self.counts)
species = self.class_to_taxa[class_id]
else:
class_id = -1
ind = random.choice(self.taxa[species])
cell = self.cells[ind]
high, low = self.high_b[ind], self.low_b[ind]
return cell, high, low, class_id
def collate_fn(self, batch):
cell, high, low, class_id = zip(*batch)
cell = np.array(cell)
high = np.stack(high)
low = np.stack(low)
class_id = torch.tensor(class_id, dtype=torch.long)
pts = torch.from_numpy(uniform_sample_h3(cell, high, low)).flip(1)
return self._enc.encode(pts, normalize=True).float(), pts, class_id
class MultiUniformH3Dataset(torch.utils.data.Dataset):
# MINE MINE MINE MINE - I have included this dummy "num_context" which is probably not needed now
# def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False, num_context=0):
# MAX MAX MAX MAX MAX MAX MAX
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False):
if dates is not None or time_dim > 0 or noise_time:
raise NotImplementedError()
# handle input encoding:
self.input_enc = input_enc
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
self._enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# load h3 data:
with open('paths.json', 'r') as f:
paths = json.load(f)
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
data = json.load(f)
obs_locs = np.array(data['locs'], dtype=np.float32)
self.taxa = {int(tt): tk for tt, tk in data['taxa_presence'].items()}
self.cells = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5)
self.ind_to_cell = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5)
self.low_b = np.stack([np.array(h3.h3_to_geo_boundary(c)).min(axis=0) for c in self.cells])
self.high_b = np.stack([np.array(h3.h3_to_geo_boundary(c)).max(axis=0) for c in self.cells])
if os.path.exists('iucndataset_nocap.pt'):
mask = torch.load('iucndataset_nocap.pt')
else:
from tqdm import tqdm
mask = np.zeros(len(locs), dtype=bool)
for i, data_taxa in tqdm(enumerate(class_to_taxa)):
if data_taxa in self.taxa:
data_cells = h3.latlng_to_cell(locs[labels == i, 1], locs[labels == i, 0], 5)
data = np.array(self.cells[self.taxa[data_taxa]])
data_inds = data.argsort()
search = np.searchsorted(data[data_inds], data_cells)
search = search.clip(0, len(data) - 1)
mask[labels == i] = data[data_inds][search] == data_cells
else:
mask[labels == i] = False
torch.save(mask, 'iucndataset_nocap.pt')
print('Reduced Size: ', mask.sum())
# remove locations that are not in the iucn dataset
locs = locs[mask]
labels = labels[mask]
labels_uniq, labels = np.unique(labels, return_inverse=True)
classes = {class_to_taxa[i]: classes[class_to_taxa[i]] for i in labels_uniq}
class_to_taxa = [class_to_taxa[i] for i in labels_uniq]
# calculate species statistics
_, counts = np.unique(labels, return_counts=True)
self.num_obs = counts.sum()
self.counts = counts / self.num_obs
# define some properties:
self.classes = classes
self.class_to_taxa = class_to_taxa
if os.path.exists('taxa_inverse.pt'):
self.taxa_inverse = torch.load('taxa_inverse.pt')
else:
from collections import defaultdict
ctt_i = {c: i for i, c in enumerate(class_to_taxa)}
self.taxa_inverse = defaultdict(list)
for k, vs in self.taxa.items():
for v in vs:
self.taxa_inverse[v].append(ctt_i[k])
torch.save(self.taxa_inverse, 'taxa_inverse.pt')
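        # taxa_inverse maps an index into obs_locs (i.e. an h3 cell slot) to the list of
        # class ids whose IUCN range includes that cell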
# useful numbers:
self.num_classes = len(self.class_to_taxa)
self.input_dim = input_dim
if self.enc.raster is not None:
self.enc.raster = self.enc.raster.to(device)
def __len__(self):
return self.num_obs
def __getitem__(self, index):
cell_ind = np.random.randint(len(self.cells))
cell = self.cells[cell_ind]
label = np.zeros(len(self.classes), dtype=np.float32)
label[self.taxa_inverse[cell_ind]] = 1
high, low = self.high_b[cell_ind], self.low_b[cell_ind]
return cell, high, low, label
def collate_fn(self, batch):
cell, high, low, label = zip(*batch)
cell = np.array(cell)
high = np.stack(high)
low = np.stack(low)
labels = torch.from_numpy(np.stack(label))
pts = torch.from_numpy(uniform_sample_h3(cell, high, low)).flip(1)
return self._enc.encode(pts, normalize=True).float(), pts.float(), labels
class MultiLocationIUCNDataset(torch.utils.data.Dataset):
# MINE MINE MINE MINE - I have included this dummy "num_context" which is probably not needed now
# def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False, num_context=0):
# MAX MAX MAX MAX MAX MAX
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device, dates=None, input_dim=4, time_dim=0, noise_time=False):
# handle input encoding:
self.input_enc = input_enc
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster, input_dim=input_dim)
# load h3 data:
with open('paths.json', 'r') as f:
paths = json.load(f)
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
data = json.load(f)
obs_locs = np.array(data['locs'], dtype=np.float32)
self.taxa = {int(tt): tk for tt, tk in data['taxa_presence'].items()}
self.cells = h3.latlng_to_cell(obs_locs[:, 1], obs_locs[:, 0], 5)
self.cells_inverse = {c: i for i, c in enumerate(self.cells)}
if os.path.exists('iucndataset_nocap.pt'):
mask = torch.load('iucndataset_nocap.pt')
else:
from tqdm import tqdm
mask = np.zeros(len(locs), dtype=bool)
for i, data_taxa in tqdm(enumerate(class_to_taxa)):
if data_taxa in self.taxa:
data_cells = h3.latlng_to_cell(locs[labels==i, 1], locs[labels==i, 0], 5)
data = np.array(self.cells[self.taxa[data_taxa]])
data_inds = data.argsort()
search = np.searchsorted(data[data_inds], data_cells)
search = search.clip(0, len(data)-1)
mask[labels==i] = data[data_inds][search] == data_cells
else:
mask[labels==i] = False
torch.save(mask, 'iucndataset_nocap.pt')
print('Reduced Size: ', mask.sum())
# remove locations that are not in the iucn dataset
locs = locs[mask]
labels = labels[mask]
if dates is not None:
dates = dates[mask]
labels_uniq, labels = np.unique(labels, return_inverse=True)
classes = {class_to_taxa[i]: classes[class_to_taxa[i]] for i in labels_uniq}
class_to_taxa = [class_to_taxa[i] for i in labels_uniq]
# define some properties:
self.locs = locs
self.loc_cells = h3.latlng_to_cell(locs[:, 1], locs[:, 0], 5)
self.loc_feats = self.enc.encode(self.locs)
self.labels = labels
self.classes = classes
self.class_to_taxa = class_to_taxa
if dates is not None:
self.dates = dates
self.enc_time = utils.TimeEncoder()
if os.path.exists('taxa_inverse.pt'):
self.taxa_inverse = torch.load('taxa_inverse.pt')
else:
from collections import defaultdict
ctt_i = {c: i for i, c in enumerate(class_to_taxa)}
self.taxa_inverse = defaultdict(list)
for k, vs in self.taxa.items():
for v in vs:
self.taxa_inverse[v].append(ctt_i[k])
torch.save(self.taxa_inverse, 'taxa_inverse.pt')
# useful numbers:
self.num_classes = len(classes)
self.input_dim = input_dim
self.time_dim = time_dim
self.noise_time = noise_time
if self.enc.raster is not None:
self.enc.raster = self.enc.raster.to(device)
def __len__(self):
return self.loc_feats.shape[0]
def __getitem__(self, index):
loc_feat = self.loc_feats[index, :]
loc = self.locs[index, :]
cell = self.loc_cells[index]
label = np.zeros(len(self.classes), dtype=np.float32)
label[self.taxa_inverse[self.cells_inverse[cell]]] = 1
if self.time_dim > 0:
date = self.dates[index]
# add noise
if self.noise_time:
noise_level = random.random()
#noise = (2*random.random() - 1) * (0.5*(365 ** (noise_level - 1)))
noise = (2*random.random() - 1) * (0.5*noise_level)
loc_feat = torch.cat([loc_feat, self.enc_time.encode_fast([date.item()+noise,noise_level])])
else:
raise NotImplementedError()
loc_feat = torch.cat([loc_feat, torch.tensor(self.enc_time.encode([2*date.item()-1], normalize=False))])
return loc_feat, torch.cat([loc, date[None]]), label
else:
return loc_feat, loc, label
# My function for creating the per class dicts
def organize_data_by_labels(labels, locs):
label_dict = {} # Initialize an empty dictionary
for label, loc in zip(labels, locs): # Loop through labels and locations
if label in label_dict:
label_dict[label].append(loc) # Append the location
else:
            label_dict[label] = [loc] # Start a new list of locations for this label
return label_dict
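# e.g. organize_data_by_labels(np.array([0, 1, 0]), np.array([[1., 2.], [3., 4.], [5., 6.]]))
# gives {0: [array([1., 2.]), array([5., 6.])], 1: [array([3., 4.])]} (keys are numpy ints; illustrative only)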
# MINE MINE - I have added new datasets that I need for my new models
dataset_classes = {'inat': LocationDataset,
'iucn_inat': LocationIUCNDataset,
'iucn_uniform': UniformH3Dataset,
'multi_iucn_uniform': MultiUniformH3Dataset,
'multi_iucn_inat': MultiLocationIUCNDataset,
'transformer': TransformerLocationDataset,
'eval_transformer': EvalTransformerLocationDataset,
'text_transformer': TransformerLocationTextDataset,
'random_text_transformer': TransformerLocationTextDatasetRandomizeOutputs,
'eval_text_transformer': EvalTransformerTextLocationDataset,
'eval_dummy_text_transformer': EvalTransformerDummyTextLocationDataset,
'variable_tokens': TransformerDatasetVariableTokens,
}
def get_dataset_class(dataset_name):
# First, try to get the class directly from the dictionary
if dataset_name in dataset_classes:
return dataset_classes[dataset_name]
else:
# Check if the dataset name matches the pattern for probabilities
pattern = r'random_text_transformer_(\d+\.\d+)_(\d+\.\d+)'
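        # e.g. 'random_text_transformer_0.25_0.5' -> just_obs_prob=0.25, just_text_prob=0.5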
match = re.match(pattern, dataset_name)
if match:
# Extract the probabilities from the dataset name
just_obs_prob = float(match.group(1))
just_text_prob = float(match.group(2))
# Use functools.partial to create a new class with preset parameters
return partial(
TransformerLocationTextDatasetRandomizeOutputs,
just_obs_prob=just_obs_prob,
just_text_prob=just_text_prob
)
else:
# If the dataset name is not recognized, raise an error
raise ValueError(f"Unknown dataset name: {dataset_name}")
def get_train_data_old(params):
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
obs_file = os.path.join(data_dir, params['obs_file'])
taxa_file = os.path.join(data_dir, params['taxa_file'])
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
taxa_of_interest = get_taxa_of_interest(params['species_set'], params['num_aux_species'], params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
locs, labels, _, dates, _, _ = load_inat_data(obs_file, taxa_of_interest)
if params['zero_shot']:
with open('paths.json', 'r') as f:
paths = json.load(f)
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
data = json.load(f)
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
D = D.item()
taxa_snt = D['taxa'].tolist()
taxa = [int(tt) for tt in data['taxa_presence'].keys()]
taxa = list(set(taxa + taxa_snt))
mask = labels != taxa[0]
# MINE MINE MINE MINE MINE
# for i in range(0, len(taxa)):
# MAX MAX MAX MAX
for i in range(1, len(taxa)):
mask &= (labels != taxa[i])
locs = locs[mask]
dates = dates[mask]
labels = labels[mask]
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
class_to_taxa = unique_taxa.tolist()
# load class names
class_info_file = json.load(open(taxa_file, 'r'))
class_names_file = [cc['latin_name'] for cc in class_info_file]
taxa_ids_file = [cc['taxon_id'] for cc in class_info_file]
classes = dict(zip(taxa_ids_file, class_names_file))
subset = None
if params['subset_cap_name'] is not None:
if params['subset_cap_name'] == 'iucn':
with open('paths.json', 'r') as f:
paths = json.load(f)
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
data = json.load(f)
taxa = [int(tt) for tt in data['taxa_presence'].keys()]
# get classes to eval
subset = np.zeros((len(taxa),), dtype=int)
for tt_id, tt in enumerate(taxa):
class_of_interest = np.where(np.array(class_to_taxa) == tt)[0]
if len(class_of_interest) != 0:
subset[tt_id] = class_of_interest
else:
            raise NotImplementedError(f'Unknown subset name: {params["subset_cap_name"]}')
idx_ss = get_idx_subsample_observations(labels, params['hard_cap_num_per_class'], params['hard_cap_seed'], subset, params['subset_cap_num_per_class'])
locs = torch.from_numpy(np.array(locs)[idx_ss]) # convert to Tensor
labels = torch.from_numpy(np.array(class_ids)[idx_ss])
dates = 364/365*torch.from_numpy(np.array(dates)[idx_ss]) if params['input_time'] else None
# MINE MINE MINE MINE - there are differences here in which dataset to load
# if ((params['dataset'] == 'transformer') or (params['dataset'] == 'eval_transformer')
# or (params['dataset'] == 'text_transformer') or (params['dataset'] == 'eval_text_transformer')):
if 'variable' in params['dataset']:
if 'eval' in params['dataset']:
thing = np.load('./data/transformer_eval_dataset_1.npz', allow_pickle=True)
locs = torch.from_numpy(thing['locs'])
labels = torch.from_numpy(thing['labels'])
classes = thing['classes'].item()
class_to_taxa = list(thing['class_to_taxa'])
# Load the text embeddings
text_embs_dict = torch.load(params['text_emb_path'])
text_embs = text_embs_dict['data']
text_embs_ids = text_embs_dict['taxon_id']
text_embs_keys = text_embs_dict['keys']
# Load the image embeddings
image_embs_dict = torch.load(params['image_emb_path'])
image_embs = image_embs_dict['data']
image_embs_ids = image_embs_dict['taxon_id']
image_embs_keys = image_embs_dict['keys']
if 'eval' in params['dataset']:
ds = dataset_classes[params['dataset']](locs, labels, classes, class_to_taxa, text_embs=text_embs, text_embs_ids=text_embs_ids,
text_embs_keys=text_embs_keys, image_embs=image_embs, image_embs_ids=image_embs_ids,
image_embs_keys=image_embs_keys, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'],
num_context=params['num_context'],
transformer_input_enc=params['transformer_input_enc'],
token_dim=params['species_dim'],
jitter=params['add_location_noise'],
variable_context_length=params['variable_context_length'],
loc_prob=params['loc_prob'], text_prob=params['text_prob'],
image_prob=params['image_prob'], env_prob=0.0, eval_mode=True)
else:
ds = dataset_classes[params['dataset']](locs, labels, classes, class_to_taxa, text_embs=text_embs, text_embs_ids=text_embs_ids,
text_embs_keys=text_embs_keys, image_embs=image_embs, image_embs_ids=image_embs_ids,
image_embs_keys=image_embs_keys, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'],
num_context=params['num_context'],
transformer_input_enc=params['transformer_input_enc'],
token_dim=params['species_dim'],
jitter=params['add_location_noise'],
variable_context_length=params['variable_context_length'],
loc_prob=params['loc_prob'], text_prob=params['text_prob'],
image_prob=params['image_prob'], env_prob=0.0, eval_mode=False)
elif 'transformer' in params['dataset']:
# holdover to make earlier models work
for key in ['add_location_noise', 'variable_context_length']:
if key not in params:
params[key] = False
# print("THIS PART WILL NEED CHANGING FOR YOUR NEW DATASET")
# # ensuring I load the same eval points each time
# # TODO make a few different versions and see if anything changes for my LS models
if params['dataset'] == 'eval_transformer' or params['dataset'] == 'eval_text_transformer':
thing = np.load('./data/transformer_eval_dataset_1.npz', allow_pickle=True)
locs = torch.from_numpy(thing['locs'])
labels = torch.from_numpy(thing['labels'])
classes = thing['classes'].item()
class_to_taxa = list(thing['class_to_taxa'])
# loading text data for my models
if 'text' in params['dataset']:
# Load the text embeddings
embs_dict = torch.load(params['text_emb_path'])
# Extract the embeddings (assumed to be stored under the 'data' key)
embs = embs_dict['data'] # This should be a tensor of shape [num_taxa, 4096]
# Extract the taxon IDs (assumed to be stored under the 'taxon_id' key)
embs_ids = embs_dict['taxon_id'] # This could be a list or a tensor
embs_keys = embs_dict['keys']
ds = dataset_classes[params['dataset']](locs, labels, classes, class_to_taxa, embs=embs, embs_ids=embs_ids,
embs_keys=embs_keys, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'],
num_context=params['num_context'],
transformer_input_enc=params['transformer_input_enc'],
token_dim=params['species_dim'],
jitter=params['add_location_noise'],
variable_context_length=params['variable_context_length'])
else:
ds = dataset_classes[params['dataset']](locs, labels, classes, class_to_taxa, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'],
num_context=params['num_context'], transformer_input_enc=params['transformer_input_enc'],
token_dim=params['species_dim'], jitter=params['add_location_noise'],
variable_context_length=params['variable_context_length'])
else:
# this version - which should be accessed if we are not using one of my datasets - is the same as Max's
ds = dataset_classes[params['dataset']](locs, labels, classes, class_to_taxa, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'])
return ds
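# Illustrative sketch (hypothetical helper, not called by the training pipeline): the cached
# eval files loaded by get_train_data / get_train_data_old (e.g. 'positive_eval_data.npz' and
# './data/transformer_eval_dataset_1.npz') are assumed to be .npz archives holding 'locs',
# 'labels', 'classes' and 'class_to_taxa'. A file with that layout could be written roughly
# like this; the function name is an assumption, not part of the original code.
def save_eval_cache_example(path, locs, labels, classes, class_to_taxa):
    """Write a cached eval dataset in the layout that get_train_data reads back."""
    np.savez(
        path,
        locs=locs.numpy() if torch.is_tensor(locs) else np.asarray(locs),
        labels=labels.numpy() if torch.is_tensor(labels) else np.asarray(labels),
        # a dict is stored as a 0-d object array, hence `thing['classes'].item()` on load
        classes=classes,
        class_to_taxa=np.asarray(class_to_taxa),
    )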
def get_train_data(params):
print('Using cached eval data where available to reduce loading during eval')
if 'eval' in params['dataset']:
print('Creating eval dataset')
with open('paths.json', 'r') as f:
paths = json.load(f)
eval_data_path = os.path.join(paths['data'], 'positive_eval_data.npz')
thing = np.load(eval_data_path, allow_pickle=True)
locs = torch.from_numpy(thing['locs'])
labels = torch.from_numpy(thing['labels'])
classes = thing['classes'].item()
class_to_taxa = list(thing['class_to_taxa'])
# for now dates is just set to None. Change if I start doing time stuff.
dates = None
if 'transformer' in params['dataset']:
# holdover to make earlier models work
for key in ['add_location_noise', 'variable_context_length']:
if key not in params:
params[key] = False
# loading text data for my models
if 'text' in params['dataset']:
# Load the text embeddings
#embs_dict = torch.load(params['text_emb_path'], map_location='cpu', weights_only=False)
embs_dict = torch.load(params['text_emb_path'], map_location='cpu')
# Extract the embeddings (assumed to be stored under the 'data' key)
embs = embs_dict['data'] # This should be a tensor of shape [num_taxa, 4096]
# Extract the taxon IDs (assumed to be stored under the 'taxon_id' key)
embs_ids = embs_dict['taxon_id'] # This could be a list or a tensor
embs_keys = embs_dict['keys']
ds = get_dataset_class(params['dataset'])(
locs, labels, classes, class_to_taxa, embs=embs, embs_ids=embs_ids,
embs_keys=embs_keys, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'],
num_context=params['num_context'],
transformer_input_enc=params['transformer_input_enc'],
token_dim=params['species_dim'],
jitter=params['add_location_noise'],
variable_context_length=params['variable_context_length']
)
else:
ds = get_dataset_class(params['dataset'])(
locs, labels, classes, class_to_taxa, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'],
num_context=params['num_context'],
transformer_input_enc=params['transformer_input_enc'],
token_dim=params['species_dim'],
jitter=params['add_location_noise'],
variable_context_length=params['variable_context_length']
)
elif 'variable' in params['dataset']:
# Load the text embeddings
text_embs_dict = torch.load(params['text_emb_path'], map_location='cpu', weights_only=False)
text_embs = text_embs_dict['data']
text_embs_ids = text_embs_dict['taxon_id']
text_embs_keys = text_embs_dict['keys']
# Load the image embeddings
image_embs_dict = torch.load(params['image_emb_path'], map_location='cpu', weights_only=False)
image_embs = image_embs_dict['data']
image_embs_ids = image_embs_dict['taxon_id']
image_embs_keys = image_embs_dict['keys']
ds = dataset_classes['variable_tokens'](
locs, labels, classes, class_to_taxa, text_embs=text_embs,
text_embs_ids=text_embs_ids, text_embs_keys=text_embs_keys, image_embs=image_embs,
image_embs_ids=image_embs_ids, image_embs_keys=image_embs_keys, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'],
num_context=params['num_context'], transformer_input_enc=params['transformer_input_enc'],
token_dim=params['species_dim'], jitter=params['add_location_noise'],
variable_context_length=params['variable_context_length'], loc_prob=params['loc_prob'],
text_prob=params['text_prob'], image_prob=params['image_prob'], env_prob=0.0, eval_mode=True
)
else:
print('Creating train dataset')
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
obs_file = os.path.join(data_dir, params['obs_file'])
taxa_file = os.path.join(data_dir, params['taxa_file'])
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
taxa_of_interest = get_taxa_of_interest(
params['species_set'], params['num_aux_species'], params['aux_species_seed'],
params['taxa_file'], taxa_file_snt
)
locs, labels, _, dates, _, _ = load_inat_data(obs_file, taxa_of_interest)
if params['zero_shot']:
print('Loading the cached eval taxa list to avoid loading the SNT and IUCN data')
with open('paths.json', 'r') as f:
paths = json.load(f)
eval_taxa_path = os.path.join(paths['data'], 'eval_taxa_list.npy')
taxa = np.load(eval_taxa_path, allow_pickle=True)
mask = labels != taxa[0]
for i in range(1, len(taxa)):
mask &= (labels != taxa[i])
locs = locs[mask]
dates = dates[mask]
labels = labels[mask]
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
class_to_taxa = unique_taxa.tolist()
# load class names
class_info_file = json.load(open(taxa_file, 'r'))
class_names_file = [cc['latin_name'] for cc in class_info_file]
taxa_ids_file = [cc['taxon_id'] for cc in class_info_file]
classes = dict(zip(taxa_ids_file, class_names_file))
subset = None
if params['subset_cap_name'] is not None:
if params['subset_cap_name'] == 'iucn':
with open('paths.json', 'r') as f:
paths = json.load(f)
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
data = json.load(f)
taxa = [int(tt) for tt in data['taxa_presence'].keys()]
# get classes to eval
subset = np.zeros((len(taxa),), dtype=int)
for tt_id, tt in enumerate(taxa):
class_of_interest = np.where(np.array(class_to_taxa) == tt)[0]
if len(class_of_interest) != 0:
subset[tt_id] = class_of_interest
else:
raise NotImplementedError(f'Unknown subset name: {params["subset_cap_name"]}')
idx_ss = get_idx_subsample_observations(
labels, params['hard_cap_num_per_class'], params['hard_cap_seed'], subset,
params['subset_cap_num_per_class']
)
locs = torch.from_numpy(np.array(locs)[idx_ss]) # convert to Tensor
labels = torch.from_numpy(np.array(class_ids)[idx_ss])
dates = 364/365*torch.from_numpy(np.array(dates)[idx_ss]) if params['input_time'] else None
if 'variable' in params['dataset']:
# Load the text embeddings
text_embs_dict = torch.load(params['text_emb_path'], map_location='cpu', weights_only=False)
text_embs = text_embs_dict['data']
text_embs_ids = text_embs_dict['taxon_id']
text_embs_keys = text_embs_dict['keys']
# Load the image embeddings
image_embs_dict = torch.load(params['image_emb_path'], map_location='cpu', weights_only=False)
image_embs = image_embs_dict['data']
image_embs_ids = image_embs_dict['taxon_id']
image_embs_keys = image_embs_dict['keys']
ds = dataset_classes['variable_tokens'](
locs, labels, classes, class_to_taxa, text_embs=text_embs, text_embs_ids=text_embs_ids,
text_embs_keys=text_embs_keys, image_embs=image_embs, image_embs_ids=image_embs_ids,
image_embs_keys=image_embs_keys, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'],
num_context=params['num_context'], transformer_input_enc=params['transformer_input_enc'],
token_dim=params['species_dim'], jitter=params['add_location_noise'],
variable_context_length=params['variable_context_length'], loc_prob=params['loc_prob'],
text_prob=params['text_prob'], image_prob=params['image_prob'], env_prob=0.0, eval_mode=False
)
elif 'transformer' in params['dataset']:
# holdover to make earlier models work
for key in ['add_location_noise', 'variable_context_length']:
if key not in params:
params[key] = False
# loading text data for my models
if 'text' in params['dataset']:
# Load the text embeddings
embs_dict = torch.load(params['text_emb_path'], map_location='cpu', weights_only=False)
# Extract the embeddings (assumed to be stored under the 'data' key)
embs = embs_dict['data'] # This should be a tensor of shape [num_taxa, 4096]
# Extract the taxon IDs (assumed to be stored under the 'taxon_id' key)
embs_ids = embs_dict['taxon_id'] # This could be a list or a tensor
embs_keys = embs_dict['keys']
ds = get_dataset_class(params['dataset'])(
locs, labels, classes, class_to_taxa, embs=embs, embs_ids=embs_ids,
embs_keys=embs_keys, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'],
num_context=params['num_context'],
transformer_input_enc=params['transformer_input_enc'],
token_dim=params['species_dim'], jitter=params['add_location_noise'],
variable_context_length=params['variable_context_length']
)
else:
ds = dataset_classes[params['dataset']](
locs, labels, classes, class_to_taxa, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time'],
num_context=params['num_context'], transformer_input_enc=params['transformer_input_enc'],
token_dim=params['species_dim'], jitter=params['add_location_noise'],
variable_context_length=params['variable_context_length']
)
else:
# this version - which should be accessed if we are not using one of my datasets - is the same as Max's
ds = dataset_classes[params['dataset']](
locs, labels, classes, class_to_taxa, input_enc=params['input_enc'],
device=params['device'], dates=dates, input_dim=params['input_dim'],
time_dim=params['input_time_dim'], noise_time=params['noise_time']
)
return ds
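# Illustrative sketch (hypothetical helper, not called anywhere): the text / image embedding
# files that get_train_data loads with torch.load are assumed to be dicts with a 'data'
# tensor (one row per embedding), a 'taxon_id' sequence aligned with those rows, and a
# 'keys' entry. Writing such a file is then just:
def save_embedding_file_example(path, data, taxon_ids, keys):
    """Save an embedding file in the layout expected by get_train_data."""
    torch.save({'data': data, 'taxon_id': taxon_ids, 'keys': keys}, path)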
def test_dataset():
import setup
from tqdm import tqdm
train_params = {}
train_params['species_set'] = 'all'
train_params['hard_cap_num_per_class'] = -1
train_params['num_aux_species'] = 0
train_params['input_enc'] = 'sin_cos_env'
train_params['input_dim'] = 8
train_params['input_time'] = False
train_params['input_time_dim'] = 0
train_params['num_epochs'] = 50
train_params['noise_time'] = False
train_params['loss'] = 'an_full'
train_params['dataset'] = 'iucn_inat'
params = setup.get_default_params_train(train_params)
train_dataset = get_train_data(params)
train_dataset.viz_map(10070)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=params['batch_size'],
shuffle=True,
num_workers=0,
collate_fn=getattr(train_dataset, 'collate_fn', None))
for _ in tqdm(train_loader):
pass
if __name__ == '__main__':
test_dataset()