fs_sinr / get_gt.py
angelazhu96
code for viz
9ff98d7
# import numpy as np
# import h3
# import json
# import os
#
# snt=False
#
# def get_labels(species, data):
# species = str(species)
# lat = []
# lon = []
# gt = []
# for hx in data:
# cur_lat, cur_lon = h3.h3_to_geo(hx)
# if species in data[hx]:
# cur_label = int(len(data[hx][species]) > 0)
# gt.append(cur_label)
# lat.append(cur_lat)
# lon.append(cur_lon)
# lat = np.array(lat).astype(np.float32)
# lon = np.array(lon).astype(np.float32)
# obs_locs = np.vstack((lon, lat)).T
# gt = np.array(gt).astype(np.float32)
# return obs_locs, gt
#
# def lonlat_to_pixel(lonlat, grid_width, grid_height):
# # Convert normalized lon/lat (-1 to 1) to pixel coordinates
# x_pixel = np.floor((lonlat[:, 0] + 1) / 2 * (grid_width - 1)).astype(int)
# y_pixel = np.floor((1 - (lonlat[:, 1] + 1) / 2) * (grid_height - 1)).astype(int)
# return x_pixel, y_pixel
#
# ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
# # 1002, 2004 pixels
# # 0 in ocean (needs to be masked out)
#
# if snt:
# with open('paths.json', 'r') as f:
# paths = json.load(f)
# D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
# D = D.item()
# loc_indices_per_species = D['loc_indices_per_species']
# labels_per_species = D['labels_per_species']
# taxa = D['taxa']
# obs_locs = D['obs_locs']
# obs_locs_idx = D['obs_locs_idx']
# else:
# with open('paths.json', 'r') as f:
# paths = json.load(f)
# with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
# data = json.load(f)
# obs_locs = np.array(data['locs'], dtype=np.float32)
# taxa = [int(tt) for tt in data['taxa_presence'].keys()]
# a = 6
# # data['taxa_presence'] is a dict where keys are "taxa" and then the values are the indices of "obs_locs" where the species is present
# # obs locs is in lon, lat with -180 to 180 and -90 to 90
import numpy as np
import h3
import json
import os
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
def get_labels(species, data):
    """Gather cell-centre coordinates and presence labels for one species.

    Parameters
    ----------
    species : object
        Species identifier. It is converted with ``str()`` before lookup, so
        the per-cell mappings in ``data`` are expected to use string keys.
    data : mapping
        Maps an H3 cell index (passed to ``h3.h3_to_geo``) to a mapping of
        species id -> sized collection (``len`` is taken on it).

    Returns
    -------
    obs_locs : np.ndarray, float32, shape (K, 2)
        ``(lon, lat)`` of every cell whose mapping contains ``species``.
        Cells without the species key are skipped entirely.
    gt : np.ndarray, float32, shape (K,)
        1.0 where that cell's collection for the species is non-empty,
        0.0 where it is empty.
    """
    key = str(species)
    lats, lons, labels = [], [], []
    for cell in data:
        # Coordinates are looked up for every cell, matching the order of the
        # original lookups (an invalid cell raises even if it is skipped).
        cell_lat, cell_lon = h3.h3_to_geo(cell)
        entry = data[cell]
        if key not in entry:
            continue
        labels.append(int(len(entry[key]) > 0))
        lats.append(cell_lat)
        lons.append(cell_lon)
    lat_arr = np.array(lats).astype(np.float32)
    lon_arr = np.array(lons).astype(np.float32)
    # Stack as columns so each row is (lon, lat).
    obs_locs = np.vstack((lon_arr, lat_arr)).T
    gt = np.array(labels).astype(np.float32)
    return obs_locs, gt
def lonlat_to_pixel(lonlat, grid_width, grid_height):
    """Map normalised (lon, lat) pairs in [-1, 1] to integer pixel indices.

    Parameters
    ----------
    lonlat : np.ndarray, shape (N, 2)
        Column 0 is normalised longitude, column 1 normalised latitude.
    grid_width, grid_height : int
        Raster size in pixels.

    Returns
    -------
    (x_pixel, y_pixel) : tuple of int arrays, each shape (N,)
        ``x = floor((lon + 1) / 2 * (grid_width - 1))`` and
        ``y = floor((1 - (lat + 1) / 2) * (grid_height - 1))``, so lat = +1
        lands on row 0 (top). No clipping is done here; inputs outside
        [-1, 1] yield out-of-range indices.
    """
    lon_frac = (lonlat[:, 0] + 1) / 2
    lat_frac = (lonlat[:, 1] + 1) / 2
    x_pixel = np.floor(lon_frac * (grid_width - 1)).astype(int)
    # Flip vertically: image rows grow downward while latitude grows upward.
    y_pixel = np.floor((1 - lat_frac) * (grid_height - 1)).astype(int)
    return x_pixel, y_pixel
# def plot_heatmap(data,save_loc):
# # Apply mask if provided
# ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
# # 1002, 2004 pixels
# # 0 in ocean (needs to be masked out)
#
# # Convert ocean_mask to boolean mask
# mask = ocean_mask.astype(bool)
# mask = mask[::2, ::2]
#
# if mask is not None:
# data = np.where(mask, data, 0)
#
# # Set NaN values to 0 for plotting
# data = np.nan_to_num(data, nan=0.0)
#
# fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
# ax.set_xlim(-180, 180)
# ax.set_ylim(-90, 90)
# ax.axis('off')
#
# # Use 'magma' colormap with two discrete colors
# cmap = plt.get_cmap('magma', 2)
# cmap.set_bad(color='none')
# plt.rcParams['font.family'] = 'serif'
#
# cax_im = ax.imshow(data, extent=(-180, 180, -90, 90), origin='upper', cmap=cmap, vmin=0, vmax=1)
#
# plt.tight_layout()
# pdf_save_loc = save_loc + '.pdf'
# png_save_loc = save_loc + '.png'
# plt.savefig(pdf_save_loc, bbox_inches='tight', pad_inches=0)
# plt.savefig(png_save_loc, bbox_inches='tight', pad_inches=0)
# plt.close(fig)
# Ocean mask on disk: 1002 x 2004 pixels, 0 over the ocean (to be masked out).
_OCEAN_MASK_PATH = "data/masks/ocean_mask.npy"


def _save_masked_heatmap(data, save_loc, mask_stride, show, mask_path):
    """Render ``data`` with ocean pixels hidden; save ``<save_loc>.pdf``/``.png``.

    Parameters
    ----------
    data : array-like
        2-D grid of values in [0, 1]. Its shape must broadcast against the
        mask after it has been sub-sampled by ``mask_stride``.
    save_loc : str
        Output path without extension. Its parent directory is created if it
        does not exist.
    mask_stride : int
        Take every ``mask_stride``-th row/column of the ocean mask (1 = full
        resolution, 2 = half resolution).
    show : bool
        Call ``plt.show(block=False)`` after saving, before closing the figure.
    mask_path : str
        Path of the ``.npy`` ocean mask.
    """
    # NOTE(review): allow_pickle is kept from the original; a plain numeric
    # mask should load without it — consider dropping it if the file allows.
    ocean_mask = np.load(mask_path, allow_pickle=True)
    land = ocean_mask.astype(bool)[::mask_stride, ::mask_stride]
    # Ocean -> NaN, then mask the NaNs so the colormap's "bad" colour applies.
    data = np.where(land, data, np.nan)
    data_masked = np.ma.array(data, mask=np.isnan(data))

    # savefig fails with FileNotFoundError when the target folder is missing.
    save_dir = os.path.dirname(save_loc)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
    try:
        ax.set_xlim(-180, 180)
        ax.set_ylim(-90, 90)
        ax.axis('off')
        # Two-colour 'plasma' colormap: 0 = absent, 1 = present.
        cmap = plt.get_cmap('plasma', 2)
        # 'none' makes masked (ocean) pixels transparent.
        cmap.set_bad(color='none')
        ax.imshow(
            data_masked,
            extent=(-180, 180, -90, 90),
            origin='upper',
            cmap=cmap,
            vmin=0,
            vmax=1,
            interpolation='nearest',
        )
        fig.tight_layout()
        fig.savefig(save_loc + '.pdf', bbox_inches='tight', pad_inches=0)
        fig.savefig(save_loc + '.png', bbox_inches='tight', pad_inches=0)
        if show:
            plt.show(block=False)
    finally:
        # Always release the figure, even if saving raised.
        plt.close(fig)


def plot_heatmap(data, save_loc, mask_path=_OCEAN_MASK_PATH):
    """Save a land-masked presence map at half the mask's resolution.

    The mask is sub-sampled by 2 in both axes (1002 x 2004 -> 501 x 1002 for
    the default mask), so ``data`` must have that half-resolution shape.
    Writes ``<save_loc>.pdf`` and ``<save_loc>.png``.
    """
    _save_masked_heatmap(data, save_loc, mask_stride=2, show=False,
                         mask_path=mask_path)


def plot_heatmap_2(data, save_loc, mask_path=_OCEAN_MASK_PATH):
    """Save a land-masked presence map at the mask's full resolution.

    ``data`` must match the mask shape (1002 x 2004 for the default mask).
    Writes ``<save_loc>.pdf`` and ``<save_loc>.png``, then calls
    ``plt.show(block=False)`` before closing the figure.
    """
    _save_masked_heatmap(data, save_loc, mask_stride=1, show=True,
                         mask_path=mask_path)
def _load_species_locs(taxa_id, snt):
    """Return the raw ``(lon, lat)`` presence locations of ``taxa_id``, shape (N, 2).

    Dataset roots are read from ``paths.json``. With ``snt=True`` the pickled
    dict in ``<paths['snt']>/snt_res_5.npy`` is used, keeping only locations
    labelled 1. Otherwise ``<paths['iucn']>/iucn_res_5.json`` is used, where
    ``taxa_presence[str(taxa_id)]`` lists row indices into ``locs``. IUCN locs
    are lon/lat degrees (-180..180, -90..90); presumably the same for SNT —
    TODO confirm.

    Raises
    ------
    ValueError
        (SNT) ``taxa_id`` is not present in ``taxa``.
    KeyError
        (IUCN) ``taxa_id`` has no ``taxa_presence`` entry.
    """
    with open('paths.json', 'r') as f:
        paths = json.load(f)

    if snt:
        # allow_pickle is required: the file stores a dict unwrapped via .item().
        D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True).item()
        # asarray so the comparison is element-wise even if taxa is a list.
        taxa = np.asarray(D['taxa'])
        class_indices = np.where(taxa == taxa_id)[0]
        if len(class_indices) == 0:
            raise ValueError(f"taxa_id {taxa_id} not found in taxa")
        class_index = class_indices[0]
        # Explicit integer dtype: np.array([]) defaults to float64, which numpy
        # refuses as an index, so taxa with no listed locations used to crash.
        species_loc_indices = np.asarray(D['loc_indices_per_species'][class_index], dtype=np.int64)
        presence_labels = np.asarray(D['labels_per_species'][class_index])
        species_locs = np.asarray(D['obs_locs'])[species_loc_indices]
        return species_locs[presence_labels == 1]

    with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
        data = json.load(f)
    obs_locs = np.array(data['locs'], dtype=np.float32)
    indices = np.asarray(data['taxa_presence'][str(taxa_id)], dtype=np.int64)
    return obs_locs[indices]


def _rasterize_presence(species_locs, grid_height, grid_width):
    """Burn ``(lon, lat)`` degree locations into a 0/1 grid of shape (grid_height, grid_width)."""
    locs = np.asarray(species_locs)
    # Promote to a floating dtype (astype copies) so the in-place division
    # cannot truncate integer input; float input keeps its precision.
    normalized = locs.astype(np.result_type(locs.dtype, np.float32))
    normalized[:, 0] /= 180  # lon -> [-1, 1]
    normalized[:, 1] /= 90   # lat -> [-1, 1]
    x_pixel, y_pixel = lonlat_to_pixel(normalized, grid_width, grid_height)
    # Guard against coordinates slightly outside the nominal range.
    x_pixel = np.clip(x_pixel, 0, grid_width - 1)
    y_pixel = np.clip(y_pixel, 0, grid_height - 1)
    grid = np.zeros((grid_height, grid_width))
    grid[y_pixel, x_pixel] = 1
    return grid


def generate_ground_truth(taxa_id, snt=True, grid_height=501, grid_width=1002):
    """Render ground-truth presence maps for one taxon at two resolutions.

    Parameters
    ----------
    taxa_id : int
        Numeric taxon id. It is compared with the SNT ``taxa`` array, or looked
        up as ``str(taxa_id)`` in the IUCN data.
    snt : bool
        Use the SNT dataset (True) or the IUCN dataset (False).
    grid_height, grid_width : int
        Size of the first map. ``plot_heatmap`` halves the 1002 x 2004 ocean
        mask, so these must be 501 x 1002 for the default mask.

    Returns
    -------
    bool
        Always True.

    Raises
    ------
    ValueError, KeyError
        Propagated from the data loading when ``taxa_id`` is unknown.

    Side effects: prints ``taxa_id``. Writes
    ``./images/species_presence_{taxa_id}.pdf/.png`` and
    ``./images/species_presence_hr_{taxa_id}.pdf/.png``.
    """
    print(taxa_id)
    # Load once; both renders use the same locations.
    species_locs = _load_species_locs(taxa_id, snt)

    low_res = _rasterize_presence(species_locs, grid_height, grid_width)
    plot_heatmap(low_res, f"./images/species_presence_{taxa_id}")

    # plot_heatmap_2 uses the un-downsampled mask, so this grid is pinned to
    # the mask's 1002 x 2004 size regardless of the grid_* arguments.
    hr_grid_height, hr_grid_width = 1002, 2004
    high_res = _rasterize_presence(species_locs, hr_grid_height, hr_grid_width)
    plot_heatmap_2(high_res, f"./images/species_presence_hr_{taxa_id}")
    return True
if __name__ == '__main__':
    import argparse

    # Running with no arguments reproduces the previous hard-coded defaults:
    # SNT data, taxon 11901, 501 x 1002 first grid.
    parser = argparse.ArgumentParser(
        description="Render ground-truth species-presence maps for one taxon.")
    parser.add_argument('--taxa-id', type=int, default=11901,
                        help="numeric taxon id (an int: the SNT lookup compares "
                             "against numeric ids, so a string never matches)")
    parser.add_argument('--iucn', action='store_true',
                        help="use the IUCN data instead of SNT (default: SNT)")
    args = parser.parse_args()

    snt = not args.iucn
    # Must stay 501 x 1002: plot_heatmap halves the 1002 x 2004 ocean mask.
    grid_height = 501
    grid_width = 1002
    # TODO: why snt true? can't generate gt for (hyacinth macaw(18938), yellow baboon(67683), pika(43188), southernflyingsquirrel (46272))
    generate_ground_truth(taxa_id=args.taxa_id, snt=snt,
                          grid_height=grid_height, grid_width=grid_width)