# fs_sinr/models.py
import torch
import torch.utils.data
import torch.nn as nn
import math
import csv
import numpy as np
import json
import os
def get_model(params, inference_only=False):
if params['model'] == 'ResidualFCNet':
return ResidualFCNet(params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), params['num_classes'] + (20 if 'env' in params['loss'] else 0), params['num_filts'], params['depth'])
elif params['model'] == 'LinNet':
return LinNet(params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] else 0) + (1 if params['noise_time'] else 0), params['num_classes'])
elif params['model'] == 'HyperNet':
return HyperNet(params, params['input_dim'] + (20 if 'env' in params['input_enc'] else 0), params['num_classes'], params['num_filts'], params['depth'],
params['species_dim'], params['species_enc_depth'], params['species_filts'], params['species_enc'], inference_only=inference_only)
# chris models
elif params['model'] == 'MultiInputModel':
return MultiInputModel(num_inputs=params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0),
num_filts=params['num_filts'], num_classes=params['num_classes'] + (20 if 'env' in params['loss'] else 0),
depth=params['depth'], ema_factor=params['ema_factor'], nhead=params['num_heads'], num_encoder_layers=params['species_enc_depth'],
dim_feedforward=params['species_filts'], dropout=params['transformer_dropout'],
batch_first=True, token_dim=(params['species_dim'] + (20 if 'env' in params['transformer_input_enc'] else 0)),
sinr_inputs='sinr' in params['transformer_input_enc'],
register=params['use_register'], use_pretrained_sinr=params['use_pretrained_sinr'],
freeze_sinr=params['freeze_sinr'], pretrained_loc=params['pretrained_loc'],
text_inputs=params['use_text_inputs'], class_token_transformation=params['class_token_transformation'])
elif params['model'] == 'VariableInputModel':
return VariableInputModel(num_inputs=params['input_dim'] + params['input_time_dim'] + (20 if 'env' in params['input_enc'] and 'contrastive' not in params['input_enc'] else 0) + (1 if params['noise_time'] else 0),
num_filts=params['num_filts'], num_classes=params['num_classes'] + (20 if 'env' in params['loss'] else 0),
depth=params['depth'], ema_factor=params['ema_factor'], nhead=params['num_heads'], num_encoder_layers=params['species_enc_depth'],
dim_feedforward=params['species_filts'], dropout=params['transformer_dropout'],
batch_first=True, token_dim=(params['species_dim'] + (20 if 'env' in params['transformer_input_enc'] else 0)),
sinr_inputs='sinr' in params['transformer_input_enc'],
register=params['use_register'], use_pretrained_sinr=params['use_pretrained_sinr'],
freeze_sinr=params['freeze_sinr'], pretrained_loc=params['pretrained_loc'],
text_inputs=params['use_text_inputs'], image_inputs=params['use_image_inputs'],
env_inputs=params['use_env_inputs'],
class_token_transformation=params['class_token_transformation'])
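# Usage sketch for get_model (illustrative only; the params values below are assumptions,
# not defaults shipped with this repo):
#   params = {'model': 'ResidualFCNet', 'input_dim': 4, 'input_time_dim': 0,
#             'input_enc': 'sin_cos', 'noise_time': False, 'loss': 'an_full',
#             'num_classes': 10, 'num_filts': 256, 'depth': 4}
#   model = get_model(params)   # -> ResidualFCNet(4, 10, 256, 4)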
class ResLayer(nn.Module):
def __init__(self, linear_size, activation=nn.ReLU, p=0.5):
super(ResLayer, self).__init__()
self.l_size = linear_size
self.nonlin1 = activation()
self.nonlin2 = activation()
self.dropout1 = nn.Dropout(p=p)
self.w1 = nn.Linear(self.l_size, self.l_size)
self.w2 = nn.Linear(self.l_size, self.l_size)
def forward(self, x):
y = self.w1(x)
y = self.nonlin1(y)
y = self.dropout1(y)
y = self.w2(y)
y = self.nonlin2(y)
out = x + y
return out
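# A ResLayer computes out = x + act(W2(dropout(act(W1 x)))): a two-layer MLP with
# dropout whose output is added back to its input, keeping the feature width fixed.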
class ResidualFCNet(nn.Module):
def __init__(self, num_inputs, num_classes, num_filts, depth=4, nonlin='relu', lowrank=0, dropout_p=0.5):
super(ResidualFCNet, self).__init__()
self.inc_bias = False
if lowrank < num_filts and lowrank != 0:
l1 = nn.Linear(num_filts if depth != -1 else num_inputs, lowrank, bias=self.inc_bias)
l2 = nn.Linear(lowrank, num_classes, bias=self.inc_bias)
self.class_emb = nn.Sequential(l1, l2)
else:
self.class_emb = nn.Linear(num_filts if depth != -1 else num_inputs, num_classes, bias=self.inc_bias)
if nonlin == 'relu':
activation = nn.ReLU
elif nonlin == 'silu':
activation = nn.SiLU
else:
raise NotImplementedError('Invalid nonlinearity specified.')
layers = []
if depth != -1:
layers.append(nn.Linear(num_inputs, num_filts))
layers.append(activation())
for i in range(depth):
layers.append(ResLayer(num_filts, activation=activation))
else:
layers.append(nn.Identity())
self.feats = torch.nn.Sequential(*layers)
def forward(self, x, class_of_interest=None, return_feats=False):
loc_emb = self.feats(x)
if return_feats:
return loc_emb
if class_of_interest is None:
class_pred = self.class_emb(loc_emb)
else:
class_pred = self.eval_single_class(loc_emb, class_of_interest)
last_class_pred = self.eval_single_class(loc_emb, -1)
return torch.sigmoid(class_pred), torch.sigmoid(last_class_pred)
return torch.sigmoid(class_pred)
def eval_single_class(self, x, class_of_interest):
if self.inc_bias:
return x @ self.class_emb.weight[class_of_interest, :] + self.class_emb.bias[class_of_interest]
else:
return x @ self.class_emb.weight[class_of_interest, :]
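# Usage sketch for ResidualFCNet (illustrative; shapes are assumptions):
#   model = ResidualFCNet(num_inputs=4, num_classes=10, num_filts=256, depth=4)
#   x = torch.randn(8, 4)                  # batch of 8 encoded locations
#   probs = model(x)                       # (8, 10) per-class sigmoid probabilities
#   feats = model(x, return_feats=True)    # (8, 256) location embeddings
#   model(x, class_of_interest=3)          # -> (probs for class 3, probs for the last class)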
class SimpleFCNet(ResidualFCNet):
def forward(self, x, return_feats=True):
assert return_feats
loc_emb = self.feats(x)
class_pred = self.class_emb(loc_emb)
return class_pred
class MockTransformer(nn.Module):
def __init__(self, num_classes, num_dims):
super(MockTransformer, self).__init__()
self.species_emb = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_dims)
def forward(self, class_ids):
return self.species_emb(class_ids)
class CombinedModel(nn.Module):
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1):
super(CombinedModel, self).__init__()
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank)
if lowrank < num_filts and lowrank != 0:
self.transformer_model = MockTransformer(num_classes, lowrank)
else:
self.transformer_model = MockTransformer(num_classes, num_filts)
self.ema_factor = ema_factor
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=lowrank if (lowrank < num_filts and lowrank != 0) else num_filts)
self.ema_embeddings.weight.data.copy_(self.transformer_model.species_emb.weight.data) # Initialize EMA with the same values as transformer
# this will have to change when I start using the actual transformer
def forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None):
# Process input through the headless model to get feature embeddings
feature_embeddings = self.headless_model(x)
if return_feats:
return feature_embeddings
else:
if class_of_interest is None:
# Get class-specific embeddings based on class_ids
class_embeddings = self.transformer_model(class_ids)
if return_class_embeddings:
return class_embeddings
else:
# Update EMA embeddings for these class IDs
if self.training:
self.update_ema_embeddings(class_ids, class_embeddings)
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
device = self.ema_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
print(f'using EMA estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
def update_ema_embeddings(self, class_ids, current_embeddings):
if self.training:
# Get current EMA embeddings for the class IDs
ema_current = self.ema_embeddings(class_ids)
# Calculate new EMA values
ema_new = self.ema_factor * current_embeddings + (1 - self.ema_factor) * ema_current
# Update the EMA embeddings
self.ema_embeddings.weight.data[class_ids] = ema_new.detach() # Detach to prevent gradients from flowing here
def get_ema_embeddings(self, class_ids):
# Method to access EMA embeddings
return self.ema_embeddings(class_ids)
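# The EMA buffer tracks a slowly-moving copy of the class embeddings:
#   ema_new = ema_factor * current + (1 - ema_factor) * ema_old
# e.g. with ema_factor=0.1, each update moves the stored embedding 10% of the way
# towards the freshly computed one; at inference the EMA estimate stands in for a
# class whose context is unavailable.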
class HeadlessSINR(nn.Module):
def __init__(self, num_inputs, num_filts, depth=4, nonlin='relu', lowrank=0, dropout_p=0.5):
super(HeadlessSINR, self).__init__()
self.inc_bias = False
self.low_rank_feats = None
if lowrank < num_filts and lowrank != 0:
l1 = nn.Linear(num_filts if depth != -1 else num_inputs, lowrank, bias=self.inc_bias)
self.low_rank_feats = l1
# else:
# self.class_emb = nn.Linear(num_filts if depth != -1 else num_inputs, num_classes, bias=self.inc_bias)
if nonlin == 'relu':
activation = nn.ReLU
elif nonlin == 'silu':
activation = nn.SiLU
else:
raise NotImplementedError('Invalid nonlinearity specified.')
# Create the layers list for feature extraction
layers = []
if depth != -1:
layers.append(nn.Linear(num_inputs, num_filts))
layers.append(activation())
for i in range(depth):
layers.append(ResLayer(num_filts, activation=activation, p=dropout_p))
else:
layers.append(nn.Identity())
# Include low-rank features in the sequential model if it is defined
if self.low_rank_feats is not None:
# Apply initial layers then low-rank features
layers.append(self.low_rank_feats)
# Set up the features as a sequential model
self.feats = nn.Sequential(*layers)
def forward(self, x):
loc_emb = self.feats(x)
return loc_emb
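# Usage sketch for HeadlessSINR (illustrative; dimensions are assumptions):
#   backbone = HeadlessSINR(num_inputs=4, num_filts=256, depth=4)
#   z = backbone(torch.randn(8, 4))   # (8, 256) feature embeddings, no class head
# With lowrank=64, a final linear layer maps the 256-d features down to 64-d.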
class TransformerEncoderModel(nn.Module):
def __init__(self, d_model=256, nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, activation='relu',
batch_first=True, output_dim=256): # BATCH FIRST MIGHT HAVE TO CHANGE
super(TransformerEncoderModel, self).__init__()
self.input_layer_norm = nn.LayerNorm(normalized_shape=d_model)
# Create an encoder layer
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
batch_first=batch_first
)
# Stack the encoder layers into an encoder module
self.transformer_encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_encoder_layers
)
# Output projection applied to the class-token embedding
self.output_layer = nn.Linear(d_model, output_dim)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""
Args:
src: the sequence to the encoder (shape: [seq_length, batch_size, d_model])
src_mask: the mask for the src sequence (shape: [seq_length, seq_length])
src_key_padding_mask: the mask for the padding tokens (shape: [batch_size, seq_length])
Returns:
output of the transformer encoder
"""
# Pass the input through the transformer encoder
encoder_input = self.input_layer_norm(src)
encoder_output = self.transformer_encoder(encoder_input, src_key_padding_mask=src_key_padding_mask, mask=src_mask)
# # Pass the encoder output through the output layer
# output = self.output_layer(encoder_output)
# Assuming the class token is the first in the sequence
# batch_first so we have (batch, sequence, dim)
if encoder_output.ndim == 2:
# in situations where we don't have a batch
encoder_output = encoder_output.unsqueeze(0)
class_token_embedding = encoder_output[:, 0, :]
output = self.output_layer(class_token_embedding) # Process only the class token embedding
return output
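# Usage sketch (illustrative; shapes are assumptions). With batch_first=True the encoder
# expects (batch, seq, d_model) and returns the projected class-token (first position) embedding:
#   enc = TransformerEncoderModel(d_model=256, nhead=8, num_encoder_layers=4, output_dim=256)
#   tokens = torch.randn(8, 5, 256)                        # class token + 4 context tokens
#   pad_mask = torch.zeros(8, 5, dtype=torch.bool)         # True marks padded positions
#   cls_emb = enc(tokens, src_key_padding_mask=pad_mask)   # (8, 256)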
class MultiInputModel(nn.Module):
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1,
nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256,
sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='',
text_inputs=False, class_token_transformation='identity'):
super(MultiInputModel, self).__init__()
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
self.ema_factor = ema_factor
self.class_token_transformation = class_token_transformation
# Load pretrained state_dict if use_pretrained_sinr is set to True
if use_pretrained_sinr:
#pretrained_state_dict = torch.load(pretrained_loc, weights_only=False)['state_dict']
pretrained_state_dict = torch.load(pretrained_loc, map_location=torch.device('cpu'))['state_dict']
filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('class_emb')}
self.headless_model.load_state_dict(filtered_state_dict, strict=False)
#print(f'Using pretrained sinr from {pretrained_loc}')
# Freeze the SINR model if freeze_sinr is set to True
if freeze_sinr:
for param in self.headless_model.parameters():
param.requires_grad = False
print("Freezing SINR model parameters")
# self.transformer_model = MockTransformer(num_classes, num_filts)
self.transformer_model = TransformerEncoderModel(d_model=token_dim,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=batch_first,
output_dim=num_filts)
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
# this is just a workaround for now to load eval embeddings - probably not needed long term
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
self.ema_embeddings.weight.requires_grad = False
self.eval_embeddings.weight.requires_grad = False
self.num_filts = num_filts
self.token_dim = token_dim
# nn.init.xavier_uniform_(self.ema_embeddings.weight) # not needed I think
self.sinr_inputs = sinr_inputs
if self.sinr_inputs:
if self.num_filts != self.token_dim and self.class_token_transformation == 'identity':
raise ValueError("If using sinr inputs to the transformer with the identity class token transformation, "
"token_dim of the transformer must equal num_filts of the sinr model")
# Add a class token
self.class_token = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.class_token)
if register:
# Add a register token initialized with Xavier uniform initialization
self.register = nn.Parameter(torch.empty(1, self.token_dim))
# self.register = (self.register / 2)
nn.init.xavier_uniform_(self.register)
else:
self.register = None
self.text_inputs = text_inputs
if self.text_inputs:
# NOTE: currently just a HeadlessSINR standing in for the text encoder
self.text_model = HeadlessSINR(num_inputs=4096, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
else:
self.text_model = None
# Type-specific embeddings for class, register, location, and text tokens
self.class_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.class_type_embedding)
if register:
self.register_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.register_type_embedding)
self.location_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.location_type_embedding)
if text_inputs:
self.text_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.text_type_embedding)
# Instantiate the class token transformation module
if class_token_transformation == 'identity':
self.class_token_transform = Identity(token_dim, num_filts)
elif class_token_transformation == 'linear':
self.class_token_transform = LinearTransformation(token_dim, num_filts)
elif class_token_transformation == 'single_layer_nn':
self.class_token_transform = SingleLayerNN(token_dim, num_filts, dropout_p=dropout)
elif class_token_transformation == 'two_layer_nn':
self.class_token_transform = TwoLayerNN(token_dim, num_filts, dropout_p=dropout)
elif class_token_transformation == 'sinr':
self.class_token_transform = HeadlessSINR(token_dim, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
else:
raise ValueError(f"Unknown class_token_transformation: {class_token_transformation}")
def forward(self, x, context_sequence, context_mask, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, use_eval_embeddings=False, text_emb=None):
# Process input through the headless model to get feature embeddings
feature_embeddings = self.headless_model(x)
if return_feats:
return feature_embeddings
if context_sequence.dim() == 2:
context_sequence = context_sequence.unsqueeze(0) # Add batch dimension if missing
context_sequence = context_sequence[:, 1:, :]
if self.sinr_inputs:
# Pass through the headless model
context_sequence = self.headless_model(context_sequence)
# Add type-specific embedding to each location token
# print("SEE IF THIS WORKS")
context_sequence += self.location_type_embedding
batch_size = context_sequence.size(0)
# Expand the class token to match the batch size and add its type-specific embedding
class_token_expanded = self.class_token.expand(batch_size, -1, -1) + self.class_type_embedding
if self.text_inputs and (text_emb is not None):
text_mask = (text_emb.sum(dim=1) == 0)
text_emb = self.text_model(text_emb)
text_emb += self.text_type_embedding
text_emb[text_mask] = 0
# Reshape text_emb to have the shape (batch_size, 1, embedding_dim)
text_emb = text_emb.unsqueeze(1)
if self.register is None:
# context sequence = learnable class_token + rest of sequence
if self.text_inputs:
# Add the class token and text embeddings to the context sequence
context_sequence = torch.cat((class_token_expanded, text_emb, context_sequence), dim=1)
# Pad the context mask to account for the added text embeddings
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
# Update the new part of the mask with the text_mask
context_mask[:, 1] = text_mask # Apply mask directly
else:
context_sequence = torch.cat((class_token_expanded, context_sequence), dim=1)
else:
# Expand the register token to match the batch size and add its type-specific embedding
register_expanded = self.register.expand(batch_size, -1, -1) + self.register_type_embedding
if self.text_inputs:
# Add all components: class token, register, text embeddings, and context
context_sequence = torch.cat((class_token_expanded, register_expanded, text_emb, context_sequence),
dim=1)
# Double pad the context mask: first for register, then for text embeddings
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
# Update the new part of the mask for text embeddings
context_mask[:, register_expanded.size(1) + 1] = text_mask # Apply mask directly
else:
context_sequence = torch.cat((class_token_expanded, register_expanded, context_sequence), dim=1)
# Update the context mask to account for the register token
context_mask = nn.functional.pad(context_mask, pad=(1, 0), value=False)
if not use_eval_embeddings:
if class_of_interest is None:
# Get class-specific embeddings based on class_ids
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
# pass these through the class token transformation
class_embeddings = self.class_token_transform(class_token_output) # Shape: (batch_size, num_filts)
if return_class_embeddings:
return class_embeddings
else:
# Update EMA embeddings for these class IDs
with torch.no_grad():
if self.training:
self.update_ema_embeddings(class_ids, class_embeddings)
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
device = self.ema_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
print(f'using EMA estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
else:
self.eval()
if not hasattr(self, 'eval_embeddings'):
self.eval_embeddings = self.ema_embeddings
if class_of_interest is None:
# Get class-specific embeddings based on class_ids
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
class_embeddings = self.class_token_transform(class_token_output)
# Update EMA embeddings for these class IDs
self.generate_eval_embeddings(class_ids, class_embeddings)
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
device = self.ema_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
print(f'using eval embedding for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
def init_eval_embeddings(self, num_classes):
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.num_filts)
nn.init.xavier_uniform_(self.eval_embeddings.weight)
def get_ema_embeddings(self, class_ids):
# Method to access EMA embeddings
return self.ema_embeddings(class_ids)
def get_eval_embeddings(self, class_ids):
# Method to access eval embeddings
return self.eval_embeddings(class_ids)
def update_ema_embeddings(self, class_ids, current_embeddings):
if self.training:
# Get unique class IDs and their counts
unique_class_ids, inverse_indices, counts = class_ids.unique(return_counts=True, return_inverse=True)
# Get current EMA embeddings for unique class IDs
ema_current = self.ema_embeddings(unique_class_ids)
# Initialize a placeholder for new EMA values
ema_new = torch.zeros_like(ema_current)
# Compute the average of current embeddings for each unique class ID
current_sum = torch.zeros_like(ema_current)
current_sum.index_add_(0, inverse_indices, current_embeddings)
current_avg = current_sum / counts.unsqueeze(1)
# Apply EMA update formula
ema_new = self.ema_factor * current_avg + (1 - self.ema_factor) * ema_current
# Update the EMA embeddings for unique class IDs
self.ema_embeddings.weight.data[unique_class_ids] = ema_new.detach() # Detach to prevent gradients
def generate_eval_embeddings(self, class_id, current_embedding):
self.eval_embeddings.weight.data[class_id, :] = current_embedding.detach() # Detach to prevent gradients
# self.eval_embeddings.weight.data[class_id] = self.ema_embeddings.weight.data[class_id] # Detach to prevent gradients
def embedding_forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, eval=False):
# forward method that uses ema or eval embeddings rather than context sequence
# Process input through the headless model to get feature embeddings
feature_embeddings = self.headless_model(x)
if return_feats:
return feature_embeddings
else:
if class_of_interest is None:
# Get class-specific embeddings based on class_ids
if not eval:
class_embeddings = self.get_ema_embeddings(class_ids=class_ids)
else:
class_embeddings = self.get_eval_embeddings(class_ids=class_ids)
if return_class_embeddings:
return class_embeddings
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
if not eval:
device = self.ema_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
print(f'using EMA estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
else:
device = self.eval_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
#print(f'using eval estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
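# MultiInputModel forward pass, in brief: x is an encoded location scored against a class
# embedding; context_sequence/context_mask hold the per-class context tokens (the class token
# and optional register/text tokens are prepended inside forward). During training the
# transformer's class-token output serves as the class embedding and also refreshes the EMA
# buffer; passing class_of_interest instead uses the stored EMA/eval embedding and skips the
# transformer entirely.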
class VariableInputModel(nn.Module):
def __init__(self, num_inputs, num_filts, num_classes, depth=4, nonlin='relu', lowrank=0, ema_factor=0.1,
nhead=8, num_encoder_layers=4, dim_feedforward=2048, dropout=0.1, batch_first=True, token_dim=256,
sinr_inputs=False, register=False, use_pretrained_sinr=False, freeze_sinr=False, pretrained_loc='',
text_inputs=False, image_inputs=False, env_inputs=False, class_token_transformation='identity'):
super(VariableInputModel, self).__init__()
self.headless_model = HeadlessSINR(num_inputs, num_filts, depth, nonlin, lowrank, dropout_p=dropout)
self.ema_factor = ema_factor
self.class_token_transformation = class_token_transformation
# Load pretrained state_dict if use_pretrained_sinr is set to True
if use_pretrained_sinr:
pretrained_state_dict = torch.load(pretrained_loc, weights_only=False)['state_dict']
filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if not k.startswith('class_emb')}
self.headless_model.load_state_dict(filtered_state_dict, strict=False)
#print(f'Using pretrained sinr from {pretrained_loc}')
# Freeze the SINR model if freeze_sinr is set to True
if freeze_sinr:
for param in self.headless_model.parameters():
param.requires_grad = False
print("Freezing SINR model parameters")
# self.transformer_model = MockTransformer(num_classes, num_filts)
self.transformer_model = TransformerEncoderModel(d_model=token_dim,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=batch_first,
output_dim=num_filts)
self.ema_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
# this is just a workaround for now to load eval embeddings - probably not needed long term
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=num_filts)
self.ema_embeddings.weight.requires_grad = False
self.eval_embeddings.weight.requires_grad = False
self.num_filts = num_filts
self.token_dim = token_dim
# nn.init.xavier_uniform_(self.ema_embeddings.weight) # not needed I think
self.sinr_inputs = sinr_inputs
if self.sinr_inputs:
if self.num_filts != self.token_dim and self.class_token_transformation == 'identity':
raise ValueError("If using sinr inputs to the transformer with the identity class token transformation, "
"token_dim of the transformer must equal num_filts of the sinr model")
# Add a class token
self.class_token = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.class_token)
if register:
# Add a register token initialized with Xavier uniform initialization
self.register = nn.Parameter(torch.empty(1, self.token_dim))
# self.register = (self.register / 2)
nn.init.xavier_uniform_(self.register)
else:
self.register = None
self.text_inputs = text_inputs
if self.text_inputs:
# NOTE: currently just a HeadlessSINR standing in for the text encoder
self.text_model = HeadlessSINR(num_inputs=4096, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
else:
self.text_model = None
self.image_inputs = image_inputs
if self.image_inputs:
# NOTE: currently just a HeadlessSINR standing in for the image encoder
self.image_model = HeadlessSINR(num_inputs=1024, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
else:
self.image_model = None
self.env_inputs = env_inputs
if self.env_inputs:
# NOTE: currently just a HeadlessSINR standing in for the env encoder
self.env_model = HeadlessSINR(num_inputs=20, num_filts=512, depth=2, nonlin=nonlin, lowrank=token_dim, dropout_p=dropout)
else:
self.env_model = None
# Type-specific embeddings for class, register, location, text, image and env tokens
self.class_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.class_type_embedding)
if register:
self.register_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.register_type_embedding)
self.location_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.location_type_embedding)
if text_inputs:
self.text_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.text_type_embedding)
if image_inputs:
self.image_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.image_type_embedding)
if env_inputs:
self.env_type_embedding = nn.Parameter(torch.empty(1, self.token_dim))
nn.init.xavier_uniform_(self.env_type_embedding)
# Instantiate the class token transformation module
if class_token_transformation == 'identity':
self.class_token_transform = Identity(token_dim, num_filts)
elif class_token_transformation == 'linear':
self.class_token_transform = LinearTransformation(token_dim, num_filts)
elif class_token_transformation == 'single_layer_nn':
self.class_token_transform = SingleLayerNN(token_dim, num_filts, dropout_p=dropout)
elif class_token_transformation == 'two_layer_nn':
self.class_token_transform = TwoLayerNN(token_dim, num_filts, dropout_p=dropout)
elif class_token_transformation == 'sinr':
self.class_token_transform = HeadlessSINR(token_dim, num_filts, 2, nonlin, lowrank, dropout_p=dropout)
else:
raise ValueError(f"Unknown class_token_transformation: {class_token_transformation}")
def forward(self, x, context_sequence, context_mask, class_ids=None, return_feats=False,
return_class_embeddings=False, class_of_interest=None, use_eval_embeddings=False, text_emb=None,
image_emb=None, env_emb=None):
# Process input through the headless model to get feature embeddings
feature_embeddings = self.headless_model(x)
if return_feats:
return feature_embeddings
if context_sequence.dim() == 2:
context_sequence = context_sequence.unsqueeze(0) # Add batch dimension if missing
context_sequence = context_sequence[:, 1:, :]
context_mask = context_mask[:, 1:]
if self.sinr_inputs:
context_sequence = self.headless_model(context_sequence)
# Add type-specific embedding to each location token
context_sequence += self.location_type_embedding
batch_size = context_sequence.size(0)
# Initialize lists for tokens and masks
tokens = []
masks = []
# Process class token
class_token_expanded = self.class_token.expand(batch_size, -1, -1) + self.class_type_embedding
tokens.append(class_token_expanded)
# The class token is always present, so mask is False (i.e., not masked out)
class_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=context_sequence.device)
masks.append(class_mask)
# Process register token if present
if self.register is not None:
register_expanded = self.register.expand(batch_size, -1, -1) + self.register_type_embedding
tokens.append(register_expanded)
register_mask = torch.zeros(batch_size, 1, dtype=torch.bool, device=context_sequence.device)
masks.append(register_mask)
# Process text embeddings
if self.text_inputs and (text_emb is not None):
text_mask = (text_emb.sum(dim=1) == 0)
text_emb = self.text_model(text_emb)
text_emb += self.text_type_embedding
# Set embeddings to zero where mask is True
text_emb[text_mask] = 0
text_emb = text_emb.unsqueeze(1)
tokens.append(text_emb)
# Expand text_mask to match sequence dimensions
text_mask = text_mask.unsqueeze(1)
masks.append(text_mask)
# Process image embeddings
if self.image_inputs and (image_emb is not None):
image_mask = (image_emb.sum(dim=1) == 0)
image_emb = self.image_model(image_emb)
image_emb += self.image_type_embedding
image_emb[image_mask] = 0
image_emb = image_emb.unsqueeze(1)
tokens.append(image_emb)
image_mask = image_mask.unsqueeze(1)
masks.append(image_mask)
# Process env embeddings if needed (can be added similarly)
if self.env_inputs and (env_emb is not None):
env_mask = context_mask
env_emb = self.env_model(env_emb)
env_emb += self.env_type_embedding
env_emb[env_mask] = 0
env_emb = env_emb.unsqueeze(1)
tokens.append(env_emb)
env_mask = env_mask.unsqueeze(1)
masks.append(env_mask)
# Process location tokens
tokens.append(context_sequence)
masks.append(context_mask)
# Concatenate all tokens and masks
context_sequence = torch.cat(tokens, dim=1)
context_mask = torch.cat(masks, dim=1)
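# Resulting token order (masks aligned 1:1 with tokens):
#   [class] [register?] [text?] [image?] [env?] [location tokens ...]
# where the optional tokens appear only when the corresponding modality is enabled and provided.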
if not use_eval_embeddings:
if class_of_interest is None:
# Get class-specific embeddings based on class_ids
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
# pass these through the class token transformation
class_embeddings = self.class_token_transform(class_token_output) # Shape: (batch_size, num_filts)
if return_class_embeddings:
return class_embeddings
else:
# Update EMA embeddings for these class IDs
with torch.no_grad():
if self.training:
self.update_ema_embeddings(class_ids, class_embeddings)
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
device = self.ema_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
print(f'using EMA estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
else:
self.eval()
if not hasattr(self, 'eval_embeddings'):
print('No Eval Embeddings for this species?!')
self.eval_embeddings = self.ema_embeddings
if class_of_interest is None:
# Get class-specific embeddings based on class_ids
class_token_output = self.transformer_model(src=context_sequence, src_key_padding_mask=context_mask)
class_embeddings = self.class_token_transform(class_token_output)
# Update EMA embeddings for these class IDs
self.generate_eval_embeddings(class_ids, class_embeddings)
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
device = self.ema_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
print(f'using eval embedding for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
def get_loc_emb(self, x):
feature_embeddings = self.headless_model(x)
return feature_embeddings
def init_eval_embeddings(self, num_classes):
self.eval_embeddings = nn.Embedding(num_embeddings=num_classes, embedding_dim=self.num_filts)
nn.init.xavier_uniform_(self.eval_embeddings.weight)
def get_ema_embeddings(self, class_ids):
# Method to access EMA embeddings
return self.ema_embeddings(class_ids)
def get_eval_embeddings(self, class_ids):
# Method to access eval embeddings
return self.eval_embeddings(class_ids)
def update_ema_embeddings(self, class_ids, current_embeddings):
if self.training:
# Get unique class IDs and their counts
unique_class_ids, inverse_indices, counts = class_ids.unique(return_counts=True, return_inverse=True)
# Get current EMA embeddings for unique class IDs
ema_current = self.ema_embeddings(unique_class_ids)
# Initialize a placeholder for new EMA values
ema_new = torch.zeros_like(ema_current)
# Compute the average of current embeddings for each unique class ID
current_sum = torch.zeros_like(ema_current)
current_sum.index_add_(0, inverse_indices, current_embeddings)
current_avg = current_sum / counts.unsqueeze(1)
# Apply EMA update formula
ema_new = self.ema_factor * current_avg + (1 - self.ema_factor) * ema_current
# Update the EMA embeddings for unique class IDs
self.ema_embeddings.weight.data[unique_class_ids] = ema_new.detach() # Detach to prevent gradients
def generate_eval_embeddings(self, class_id, current_embedding):
self.eval_embeddings.weight.data[class_id, :] = current_embedding.detach() # Detach to prevent gradients
# self.eval_embeddings.weight.data[class_id] = self.ema_embeddings.weight.data[class_id] # Detach to prevent gradients
def embedding_forward(self, x, class_ids=None, return_feats=False, return_class_embeddings=False, class_of_interest=None, eval=False):
# forward method that uses ema or eval embeddings rather than context sequence
# Process input through the headless model to get feature embeddings
feature_embeddings = self.headless_model(x)
if return_feats:
return feature_embeddings
else:
if class_of_interest is None:
# Get class-specific embeddings based on class_ids
if not eval:
class_embeddings = self.get_ema_embeddings(class_ids=class_ids)
else:
class_embeddings = self.get_eval_embeddings(class_ids=class_ids)
if return_class_embeddings:
return class_embeddings
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embeddings.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
return probabilities
else:
if not eval:
device = self.ema_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_ema_embeddings(class_of_interest_tensor)
print(f'using EMA estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
else:
device = self.eval_embeddings.weight.device
class_of_interest_tensor = torch.tensor([class_of_interest]).to(device)
class_embedding = self.get_eval_embeddings(class_of_interest_tensor)
#print(f'using eval estimate for class {class_of_interest}')
if return_class_embeddings:
return class_embedding
else:
# Matrix multiplication to produce logits
logits = feature_embeddings @ class_embedding.T
# Apply sigmoid to convert logits to probabilities
probabilities = torch.sigmoid(logits)
probabilities = probabilities.squeeze()
return probabilities
class LinNet(nn.Module):
def __init__(self, num_inputs, num_classes):
super(LinNet, self).__init__()
self.num_layers = 0
self.inc_bias = False
self.class_emb = nn.Linear(num_inputs, num_classes, bias=self.inc_bias)
self.feats = nn.Identity() # does not do anything
def forward(self, x, class_of_interest=None, return_feats=False):
loc_emb = self.feats(x)
if return_feats:
return loc_emb
if class_of_interest is None:
class_pred = self.class_emb(loc_emb)
else:
class_pred = self.eval_single_class(loc_emb, class_of_interest)
return torch.sigmoid(class_pred)
def eval_single_class(self, x, class_of_interest):
if self.inc_bias:
return x @ self.class_emb.weight[class_of_interest, :] + self.class_emb.bias[class_of_interest]
else:
return x @ self.class_emb.weight[class_of_interest, :]
class ParallelMulti(torch.nn.Module):
def __init__(self, x: list[torch.nn.Module]):
super(ParallelMulti, self).__init__()
self.layers = nn.ModuleList(x)
def forward(self, xs, **kwargs):
out = torch.cat([self.layers[i](x, **kwargs) for i,x in enumerate(xs)], dim=1)
return out
class SequentialMulti(torch.nn.Sequential):
def forward(self, *inputs, **kwargs):
for module in self._modules.values():
if isinstance(inputs, tuple):
inputs = module(*inputs, **kwargs)
else:
inputs = module(inputs)
return inputs
# Chris's transformation classes
class Identity(nn.Module):
def __init__(self, in_dim, out_dim):
super(Identity, self).__init__()
# No parameters needed for identity transformation
def forward(self, x):
return x
class LinearTransformation(nn.Module):
def __init__(self, in_dim, out_dim, bias=True):
super(LinearTransformation, self).__init__()
self.linear = nn.Linear(in_dim, out_dim, bias=bias)
def forward(self, x):
return self.linear(x)
class SingleLayerNN(nn.Module):
def __init__(self, in_dim, out_dim, dropout_p=0.1, bias=True):
super(SingleLayerNN, self).__init__()
hidden_dim = (in_dim + out_dim) // 2 # Choose an appropriate hidden dimension
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_dim, bias=bias),
nn.ReLU(),
nn.Dropout(p=dropout_p),
nn.Linear(hidden_dim, out_dim, bias=bias)
)
def forward(self, x):
return self.net(x)
class TwoLayerNN(nn.Module):
def __init__(self, in_dim, out_dim, dropout_p=0.1, bias=True):
super(TwoLayerNN, self).__init__()
hidden_dim = (in_dim + out_dim) // 2 # Choose an appropriate hidden dimension
self.net = nn.Sequential(
nn.Linear(in_dim, hidden_dim, bias=bias),
nn.ReLU(),
nn.Dropout(p=dropout_p),
nn.Linear(hidden_dim, hidden_dim, bias=bias),
nn.ReLU(),
nn.Dropout(p=dropout_p),
nn.Linear(hidden_dim, out_dim, bias=bias)
)
def forward(self, x):
return self.net(x)
class HyperNet(nn.Module):
'''
Hypernetwork species-distribution model: a species encoder produces, for each class,
a weight vector and bias (num_filts + 1 values) that are applied to the location
embedding from the positional encoder to give a per-class logit.
'''
def __init__(self, params, num_inputs, num_classes, num_filts, pos_enc_depth, species_dim, species_enc_depth, species_filts, species_enc='embed', inference_only=False):
super(HyperNet, self).__init__()
if species_enc == 'embed':
self.species_emb = nn.Embedding(num_classes, species_dim)
self.species_emb.weight.data *= 0.01
elif species_enc == 'taxa':
self.species_emb = TaxaEncoder(params, './data/inat_taxa_info.csv', species_dim)
elif species_enc == 'text':
self.species_emb = TextEncoder(params, params['text_emb_path'], species_dim, './data/inat_taxa_info.csv')
elif species_enc == 'wiki':
self.species_emb = WikiEncoder(params, params['text_emb_path'], species_dim, inference_only=inference_only)
if species_enc_depth == -1:
self.species_enc = nn.Identity()
elif species_enc_depth == 0:
self.species_enc = nn.Linear(species_dim, num_filts+1)
else:
self.species_enc = SimpleFCNet(species_dim, num_filts+1, species_filts, depth=species_enc_depth)
if 'geoprior' in params['loss']:
self.species_params = nn.Parameter(torch.randn(num_classes, species_dim))
self.species_params.data *= 0.0386
self.pos_enc = SimpleFCNet(num_inputs, num_filts, num_filts, depth=pos_enc_depth)
def forward(self, x, y):
ys, indmap = torch.unique(y, return_inverse=True)
species = self.species_enc(self.species_emb(ys))
species_w, species_b = species[...,:-1], species[...,-1:]
pos = self.pos_enc(x)
out = torch.bmm(species_w[indmap],pos[...,None])
out = (out + 0*species_b[indmap]).squeeze(-1) #TODO
if hasattr(self, 'species_params'):
out2 = torch.bmm(self.species_params[ys][indmap],pos[...,None])
out2 = out2.squeeze(-1)
out3 = (species_w, self.species_params[ys], ys)
return out, out2, out3
else:
return out
def zero_shot(self, x, species_emb):
species = self.species_enc(self.species_emb.zero_shot(species_emb))
species_w, _ = species[...,:-1], species[...,-1:]
pos = self.pos_enc(x)
out = pos @ species_w.T
return out
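# HyperNet scoring, in brief: the species encoder emits a (num_filts + 1)-d vector per class,
# split into weights species_w and a bias species_b; the logit for (location, species) is the
# dot product of species_w with the location embedding pos (the bias term is currently zeroed
# out, see the TODO above). zero_shot applies the same dot product with weights derived from an
# externally supplied species embedding.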
class TaxaEncoder(nn.Module):
def __init__(self, params, fpath, embedding_dim):
super(TaxaEncoder, self).__init__()
import datasets
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
obs_file = os.path.join(data_dir, params['obs_file'])
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
class_to_taxa = unique_taxa.tolist()
self.fpath = fpath
ids = []
rows = []
with open(fpath, newline='') as csvfile:
spamreader = csv.reader(csvfile, delimiter=',')
for row in spamreader:
if row[0] == 'taxon_id':
continue
ids.append(int(row[0]))
rows.append(row[3:])
rows = np.array(rows)
rows = [np.unique(rows[:,i], return_inverse=True)[1] for i in range(rows.shape[1])]
rows = torch.from_numpy(np.vstack(rows).T)
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
embs = [nn.Embedding(rows[:,i].max()+2, embedding_dim, 0) for i in range(rows.shape[1])]
embs[-1] = nn.Embedding(len(class_to_taxa), embedding_dim)
rows2 = torch.zeros((len(class_to_taxa), 7), dtype=rows.dtype)
startind = rows[:,-1].max()
for i in range(len(class_to_taxa)):
if class_to_taxa[i] in ids:
rows2[i] = rows[ids.index(class_to_taxa[i])]+1
rows2[i,-1] -= 1
else:
rows2[i,-1] = startind
startind += 1
self.register_buffer('rows', rows2)
for e in embs:
e.weight.data *= 0.01
self.embs = nn.ModuleList(embs)
def forward(self, x):
inds = self.rows[x]
out = sum([self.embs[i](inds[...,i]) for i in range(inds.shape[-1])])
return out
class TextEncoder(nn.Module):
def __init__(self, params, path, embedding_dim, fpath='inat_taxa_info.csv'):
super(TextEncoder, self).__init__()
import datasets
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
obs_file = os.path.join(data_dir, params['obs_file'])
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
class_to_taxa = unique_taxa.tolist()
self.fpath = fpath
ids = []
with open(fpath, newline='') as csvfile:
spamreader = csv.reader(csvfile, delimiter=',')
for row in spamreader:
if row[0] == 'taxon_id':
continue
ids.append(int(row[0]))
embs = torch.load(path)
if len(embs) != len(ids):
print("Warning: Number of embeddings doesn't match number of species")
ids = ids[:embs.shape[0]]
if isinstance(embs, list):
embs = torch.stack(embs)
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
indmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
embmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
self.missing_emb = nn.Embedding(len(class_to_taxa)-embs.shape[0], embedding_dim)
startind = 0
for i in range(len(class_to_taxa)):
if class_to_taxa[i] in ids:
indmap[i] = ids.index(class_to_taxa[i])
else:
embmap[i] = startind
startind += 1
self.scales = nn.Parameter(torch.zeros(len(class_to_taxa), 1))
self.register_buffer('indmap', indmap, persistent=False)
self.register_buffer('embmap', embmap, persistent=False)
self.register_buffer('embs', embs, persistent=False)
if params['text_hidden_dim'] == 0:
self.linear1 = nn.Linear(embs.shape[1], embedding_dim)
else:
self.linear1 = nn.Linear(embs.shape[1], params['text_hidden_dim'])
self.linear2 = nn.Linear(params['text_hidden_dim'], embedding_dim)
self.act = nn.SiLU()
if params['text_learn_dim'] > 0:
self.learned_emb = nn.Embedding(len(class_to_taxa), params['text_learn_dim'])
self.learned_emb.weight.data *= 0.01
self.linear_learned = nn.Linear(params['text_learn_dim'], embedding_dim)
def forward(self, x):
inds = self.indmap[x]
out = self.embs[self.indmap[x].cpu()]
out = self.linear1(out)
if hasattr(self, 'linear2'):
out = self.linear2(self.act(out))
out = self.scales[x] * (out / (out.std(dim=1)[:, None]))
out[inds == -1] = self.missing_emb(self.embmap[x[inds == -1]])
if hasattr(self, 'learned_emb'):
out2 = self.learned_emb(x)
out2 = self.linear_learned(out2)
out = out+out2
return out
class WikiEncoder(nn.Module):
def __init__(self, params, path, embedding_dim, inference_only=False):
super(WikiEncoder, self).__init__()
self.path = path
if not inference_only:
import datasets
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
obs_file = os.path.join(data_dir, params['obs_file'])
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
taxa_of_interest = datasets.get_taxa_of_interest(params['species_set'], params['num_aux_species'],
params['aux_species_seed'], params['taxa_file'], taxa_file_snt)
locs, labels, _, dates, _, _ = datasets.load_inat_data(obs_file, taxa_of_interest)
if params['zero_shot']:
with open('paths.json', 'r') as f:
paths = json.load(f)
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
data = json.load(f)
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
D = D.item()
taxa_snt = D['taxa'].tolist()
taxa = [int(tt) for tt in data['taxa_presence'].keys()]
taxa = list(set(taxa + taxa_snt))
mask = labels != taxa[0]
for i in range(1, len(taxa)):
mask &= (labels != taxa[i])
locs = locs[mask]
dates = dates[mask]
labels = labels[mask]
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
class_to_taxa = unique_taxa.tolist()
embs = torch.load(path)
ids = embs['taxon_id'].tolist()
if 'keys' in embs:
taxa_counts = torch.zeros(len(ids), dtype=torch.int32)
for i,k in embs['keys']:
taxa_counts[i] += 1
else:
taxa_counts = torch.ones(len(ids), dtype=torch.int32)
count_sum = torch.cumsum(taxa_counts, dim=0) - taxa_counts
embs = embs['data']
self.taxa2row = {taxaid:i for i, taxaid in enumerate(ids)}
indmap = -1+torch.zeros(len(class_to_taxa), dtype=torch.int)
countmap = torch.zeros(len(class_to_taxa), dtype=torch.int)
self.species_emb = nn.Embedding(len(class_to_taxa), embedding_dim)
self.species_emb.weight.data *= 0.01
for i in range(len(class_to_taxa)):
if class_to_taxa[i] in ids:
i2 = ids.index(class_to_taxa[i])
indmap[i] = count_sum[i2]
countmap[i] = taxa_counts[i2]
self.register_buffer('indmap', indmap, persistent=False)
self.register_buffer('countmap', countmap, persistent=False)
self.register_buffer('embs', embs, persistent=False)
assert embs.shape[1] == 4096
self.scale = nn.Parameter(torch.zeros(1))
if params['species_dropout'] > 0:
self.dropout = nn.Dropout(p=params['species_dropout'])
if params['text_hidden_dim'] == 0:
self.linear1 = nn.Linear(4096, embedding_dim)
else:
self.linear1 = nn.Linear(4096, params['text_hidden_dim'])
if params['text_batchnorm']:
self.bn1 = nn.BatchNorm1d(params['text_hidden_dim'])
for l in range(params['text_num_layers']-1):
setattr(self, f'linear{l+2}', nn.Linear(params['text_hidden_dim'], params['text_hidden_dim']))
if params['text_batchnorm']:
setattr(self, f'bn{l+2}', nn.BatchNorm1d(params['text_hidden_dim']))
setattr(self, f'linear{params["text_num_layers"]+1}', nn.Linear(params['text_hidden_dim'], embedding_dim))
self.act = nn.SiLU()
if params['text_learn_dim'] > 0:
self.learned_emb = nn.Embedding(len(class_to_taxa), params['text_learn_dim'])
self.learned_emb.weight.data *= 0.01
self.linear_learned = nn.Linear(params['text_learn_dim'], embedding_dim)
def forward(self, x):
inds = self.indmap[x] + (torch.rand(x.shape,device=x.device)*self.countmap[x]).floor().int()
out = self.embs[inds]
if hasattr(self, 'dropout'):
out = self.dropout(out)
out = self.linear1(out)
if hasattr(self, 'linear2'):
out = self.act(out)
if hasattr(self, 'bn1'):
out = self.bn1(out)
i = 2
while hasattr(self, f'linear{i}'):
if hasattr(self, f'linear{i}'):
out = self.act(getattr(self, f'linear{i}')(out))
if hasattr(self, f'bn{i}'):
out = getattr(self, f'bn{i}')(out)
i += 1
#out = self.scale * (out / (out.std(dim=1)[:, None]))
out2 = self.species_emb(x)
chosen = torch.rand((out.shape[0],), device=x.device)
chosen = 1+0*chosen #TODO fix this
chosen[inds == -1] = 0
out = chosen[:,None] * out + (1-chosen[:,None])*out2
if hasattr(self, 'learned_emb'):
out2 = self.learned_emb(x)
out2 = self.linear_learned(out2)
out = out+out2
return out
def zero_shot(self, species_emb):
out = species_emb
out = self.linear1(out)
if hasattr(self, 'linear2'):
out = self.act(out)
if hasattr(self, 'bn1'):
out = self.bn1(out)
i = 2
while hasattr(self, f'linear{i}'):
if hasattr(self, f'linear{i}'):
out = self.act(getattr(self, f'linear{i}')(out))
if hasattr(self, f'bn{i}'):
out = getattr(self, f'bn{i}')(out)
i += 1
return out
def zero_shot_old(self, species_emb):
out = species_emb
out = self.linear1(out)
if hasattr(self, 'linear2'):
out = self.linear2(self.act(out))
out = self.scale * (out / (out.std(dim=-1, keepdim=True)))
return out
# Note: only intended for the multi-input models and not currently used anywhere;
# a HeadlessSINR is used as the text encoder for now.
class MultiInputTextEncoder(nn.Module):
def __init__(self, token_dim, dropout, input_dim=4096, depth=2, hidden_dim=512, nonlin='relu', batch_norm=True, layer_norm=False):
super(MultiInputTextEncoder, self).__init__()
print("THINK ABOUT IF SOME OF THESE HYPERPARAMETERS SHOULD BE DISTINCT FROM THE TRANSFORMER VERSION")
print("DEPTH / NUM_ENCODER_LAYERS, DROPOUT, DIM_FEEDFORWARD, ETC")
print("AT PRESENT WE JUST HAVE A SORT OF BASIC VERSION IMPLEMENTED THAT ATTEMPTS TO BE LIKE MAX'S VERSION")
print("ALSO, OPTION TO HAVE IT PRETRAINED? ADD RESIDUAL LAYERS?")
self.token_dim = token_dim
self.dropout = dropout
self.input_dim = input_dim
self.depth = depth
self.hidden_dim = hidden_dim
self.batch_norm = batch_norm
self.layer_norm = layer_norm
if nonlin == 'relu':
activation = nn.ReLU
elif nonlin == 'silu':
activation = nn.SiLU
else:
raise NotImplementedError('Invalid nonlinearity specified.')
self.dropout_layer = nn.Dropout(p=self.dropout)
if self.depth <= 1:
self.linear1 = nn.Linear(self.input_dim, self.token_dim)
else:
self.linear1 = nn.Linear(self.input_dim, self.hidden_dim)
if self.batch_norm:
self.bn1 = nn.BatchNorm1d(self.hidden_dim)
# if self.layer_norm:
# self.ln1 = nn.LayerNorm(self.hidden_dim)