Spaces:
Sleeping
Sleeping
# import numpy as np | |
# import h3 | |
# import json | |
# import os | |
# | |
# snt=False | |
# | |
# def get_labels(species, data): | |
# species = str(species) | |
# lat = [] | |
# lon = [] | |
# gt = [] | |
# for hx in data: | |
# cur_lat, cur_lon = h3.h3_to_geo(hx) | |
# if species in data[hx]: | |
# cur_label = int(len(data[hx][species]) > 0) | |
# gt.append(cur_label) | |
# lat.append(cur_lat) | |
# lon.append(cur_lon) | |
# lat = np.array(lat).astype(np.float32) | |
# lon = np.array(lon).astype(np.float32) | |
# obs_locs = np.vstack((lon, lat)).T | |
# gt = np.array(gt).astype(np.float32) | |
# return obs_locs, gt | |
# | |
# def lonlat_to_pixel(lonlat, grid_width, grid_height): | |
# # Convert normalized lon/lat (-1 to 1) to pixel coordinates | |
# x_pixel = np.floor((lonlat[:, 0] + 1) / 2 * (grid_width - 1)).astype(int) | |
# y_pixel = np.floor((1 - (lonlat[:, 1] + 1) / 2) * (grid_height - 1)).astype(int) | |
# return x_pixel, y_pixel | |
# | |
# ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True) | |
# # 1002, 2004 pixels | |
# # 0 in ocean (needs to be masked out) | |
# | |
# if snt: | |
# with open('paths.json', 'r') as f: | |
# paths = json.load(f) | |
# D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True) | |
# D = D.item() | |
# loc_indices_per_species = D['loc_indices_per_species'] | |
# labels_per_species = D['labels_per_species'] | |
# taxa = D['taxa'] | |
# obs_locs = D['obs_locs'] | |
# obs_locs_idx = D['obs_locs_idx'] | |
# else: | |
# with open('paths.json', 'r') as f: | |
# paths = json.load(f) | |
# with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f: | |
# data = json.load(f) | |
# obs_locs = np.array(data['locs'], dtype=np.float32) | |
# taxa = [int(tt) for tt in data['taxa_presence'].keys()] | |
# a = 6 | |
# # data['taxa_presence'] is a dict where keys are "taxa" and then the values are the indices of "obs_locs" where the species is present | |
# # obs locs is in lon, lat with -180 to 180 and -90 to 90 | |
import numpy as np | |
import h3 | |
import json | |
import os | |
import matplotlib.pyplot as plt | |
from mpl_toolkits.axes_grid1 import make_axes_locatable | |
def get_labels(species, data):
    """Collect H3 cell centres and binary labels for one species.

    Args:
        species: Species/taxon identifier. It is passed through ``str()``
            before lookup, so ints and strings are both accepted.
        data: Mapping from H3 cell id to a per-cell mapping keyed by species
            id. Only cells whose per-cell mapping contains ``str(species)``
            contribute a row to the output.

    Returns:
        obs_locs: float32 array of shape (N, 2) holding (lon, lat) of the
            centres of the N contributing cells, as returned by h3 (degrees).
        gt: float32 array of shape (N,): 1.0 where the cell's entry for the
            species is non-empty, 0.0 where it is empty.
    """
    species = str(species)
    # h3-py v4 renamed ``h3_to_geo`` to ``cell_to_latlng``; both return
    # (lat, lng). Resolve whichever the installed version provides.
    cell_to_latlng = getattr(h3, "cell_to_latlng", None) or h3.h3_to_geo

    lat = []
    lon = []
    gt = []
    for hx, cell_species in data.items():
        if species not in cell_species:
            continue
        # Only convert cells that actually contribute a row.
        cur_lat, cur_lon = cell_to_latlng(hx)
        gt.append(int(len(cell_species[species]) > 0))
        lat.append(cur_lat)
        lon.append(cur_lon)

    lat = np.array(lat).astype(np.float32)
    lon = np.array(lon).astype(np.float32)
    # Stack as (lon, lat) columns -> shape (N, 2).
    obs_locs = np.vstack((lon, lat)).T
    gt = np.array(gt).astype(np.float32)
    return obs_locs, gt
def lonlat_to_pixel(lonlat, grid_width, grid_height):
    """Map normalised (lon, lat) rows in [-1, 1] to integer raster indices.

    Column 0 of ``lonlat`` is longitude and column 1 is latitude, each already
    scaled into [-1, 1]. Longitude -1 maps to column 0 and +1 to column
    ``grid_width - 1``. Latitude +1 maps to row 0 (top) and -1 to row
    ``grid_height - 1`` (bottom), matching the ``origin='upper'`` images drawn
    elsewhere in this file.

    NOTE(review): the span is scaled by ``grid - 1`` and then floored, so only
    an exact +1 reaches the last column/row. Results are not clipped; callers
    clip them.

    Returns:
        A tuple ``(x_pixel, y_pixel)`` of integer column and row index arrays.
    """
    # Fraction of the way across the grid: 0 at west/north, 1 at east/south.
    east_frac = (lonlat[:, 0] + 1) / 2
    south_frac = 1 - (lonlat[:, 1] + 1) / 2
    cols = np.floor(east_frac * (grid_width - 1)).astype(int)
    rows = np.floor(south_frac * (grid_height - 1)).astype(int)
    return cols, rows
# def plot_heatmap(data,save_loc): | |
# # Apply mask if provided | |
# ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True) | |
# # 1002, 2004 pixels | |
# # 0 in ocean (needs to be masked out) | |
# | |
# # Convert ocean_mask to boolean mask | |
# mask = ocean_mask.astype(bool) | |
# mask = mask[::2, ::2] | |
# | |
# if mask is not None: | |
# data = np.where(mask, data, 0) | |
# | |
# # Set NaN values to 0 for plotting | |
# data = np.nan_to_num(data, nan=0.0) | |
# | |
# fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100) | |
# ax.set_xlim(-180, 180) | |
# ax.set_ylim(-90, 90) | |
# ax.axis('off') | |
# | |
# # Use 'magma' colormap with two discrete colors | |
# cmap = plt.get_cmap('magma', 2) | |
# cmap.set_bad(color='none') | |
# plt.rcParams['font.family'] = 'serif' | |
# | |
# cax_im = ax.imshow(data, extent=(-180, 180, -90, 90), origin='upper', cmap=cmap, vmin=0, vmax=1) | |
# | |
# plt.tight_layout() | |
# pdf_save_loc = save_loc + '.pdf' | |
# png_save_loc = save_loc + '.png' | |
# plt.savefig(pdf_save_loc, bbox_inches='tight', pad_inches=0) | |
# plt.savefig(png_save_loc, bbox_inches='tight', pad_inches=0) | |
# plt.close(fig) | |
def plot_heatmap(data, save_loc):
    """Render a two-colour land-only presence map and save it as PDF and PNG.

    Ocean cells (from ``data/masks/ocean_mask.npy``, read relative to the
    working directory) are hidden. The mask is downsampled by 2 in each axis
    before use, so ``data`` must be compatible with that half-resolution
    grid (501 x 1002 for the 1002 x 2004 mask).

    Args:
        data: 2-D array of values in [0, 1] (``vmin=0``, ``vmax=1``).
        save_loc: Output path without extension; ``.pdf`` and ``.png`` are
            appended. Missing parent directories are created.
    """
    # Ocean mask: 1002 x 2004 pixels, 0 over ocean (those cells are hidden).
    ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
    # Keep every second row/column so the mask lines up with the
    # half-resolution grid this function is used with.
    mask = ocean_mask.astype(bool)[::2, ::2]

    # Ocean cells become NaN, then are masked so imshow leaves them blank.
    data = np.where(mask, data, np.nan)
    data_masked = np.ma.array(data, mask=np.isnan(data))

    # 'plasma' resampled to two discrete colours (0 = absent, 1 = present).
    cmap = plt.get_cmap('plasma', 2)
    cmap.set_bad(color='none')  # masked (ocean) cells are transparent

    out_dir = os.path.dirname(save_loc)
    if out_dir:
        # savefig does not create directories; a missing folder (e.g. the
        # default ./images) would otherwise raise FileNotFoundError.
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
    try:
        ax.set_xlim(-180, 180)
        ax.set_ylim(-90, 90)
        ax.axis('off')
        ax.imshow(
            data_masked,
            extent=(-180, 180, -90, 90),
            origin='upper',
            cmap=cmap,
            vmin=0,
            vmax=1,
            interpolation='nearest',
        )
        fig.tight_layout()
        fig.savefig(save_loc + '.pdf', bbox_inches='tight', pad_inches=0)
        fig.savefig(save_loc + '.png', bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if saving fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)
def plot_heatmap_2(data, save_loc):
    """Render a full-resolution land-only presence map; save as PDF and PNG.

    Same as the half-resolution plot, except that the ocean mask
    (``data/masks/ocean_mask.npy``, 1002 x 2004) is applied without
    downsampling. ``data`` must therefore be compatible with the full mask
    shape. After saving, the figure is shown non-blocking and then closed.

    Args:
        data: 2-D array of values in [0, 1] (``vmin=0``, ``vmax=1``).
        save_loc: Output path without extension; ``.pdf`` and ``.png`` are
            appended. Missing parent directories are created.
    """
    # Ocean mask: 1002 x 2004 pixels, 0 over ocean (those cells are hidden).
    ocean_mask = np.load("data/masks/ocean_mask.npy", allow_pickle=True)
    mask = ocean_mask.astype(bool)  # used at full resolution

    # Ocean cells become NaN, then are masked so imshow leaves them blank.
    data = np.where(mask, data, np.nan)
    data_masked = np.ma.array(data, mask=np.isnan(data))

    # 'plasma' resampled to two discrete colours (0 = absent, 1 = present).
    cmap = plt.get_cmap('plasma', 2)
    cmap.set_bad(color='none')  # masked (ocean) cells are transparent

    out_dir = os.path.dirname(save_loc)
    if out_dir:
        # savefig does not create directories; a missing folder would
        # otherwise raise FileNotFoundError.
        os.makedirs(out_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(20.04, 10.02), dpi=100)
    try:
        ax.set_xlim(-180, 180)
        ax.set_ylim(-90, 90)
        ax.axis('off')
        ax.imshow(
            data_masked,
            extent=(-180, 180, -90, 90),
            origin='upper',
            cmap=cmap,
            vmin=0,
            vmax=1,
            interpolation='nearest',
        )
        fig.tight_layout()
        fig.savefig(save_loc + '.pdf', bbox_inches='tight', pad_inches=0)
        fig.savefig(save_loc + '.png', bbox_inches='tight', pad_inches=0)
        plt.show(block=False)
    finally:
        # Release the figure even if saving fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)
def _load_species_locs(taxa_id, snt):
    """Return the (lon, lat) locations listed as present for ``taxa_id``.

    Data roots come from ``paths.json`` in the working directory. With ``snt``
    the ``snt_res_5.npy`` archive is used and only locations whose label is 1
    are kept. Otherwise every location that ``iucn_res_5.json`` lists for the
    taxon is returned.

    Raises:
        ValueError: (SNT) ``taxa_id`` is not in the archive's ``taxa``.
        KeyError: (IUCN) ``taxa_id`` has no entry in ``taxa_presence``.
    """
    with open('paths.json', 'r') as f:
        paths = json.load(f)

    if snt:
        # The file holds a 0-d object array wrapping a dict, hence
        # allow_pickle and .item().
        D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True).item()
        # asarray so the equality test is element-wise even if taxa is a list.
        taxa = np.asarray(D['taxa'])
        class_indices = np.where(taxa == taxa_id)[0]
        if len(class_indices) == 0:
            raise ValueError(f"taxa_id {taxa_id} not found in taxa")
        class_index = class_indices[0]
        loc_indices = np.array(D['loc_indices_per_species'][class_index])
        labels = np.array(D['labels_per_species'][class_index])
        species_locs = D['obs_locs'][loc_indices]
        # Keep only locations labelled 1 (presence).
        return species_locs[labels == 1]

    with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
        data = json.load(f)
    obs_locs = np.array(data['locs'], dtype=np.float32)
    # taxa_presence maps str(taxon id) -> indices into 'locs' where present.
    taxa_presence = data['taxa_presence']
    key = str(taxa_id)
    if key not in taxa_presence:
        raise KeyError(f"taxa_id {taxa_id} not found in IUCN taxa_presence")
    return obs_locs[taxa_presence[key]]  # shape (N, 2)


def _rasterize_presence(species_locs, grid_width, grid_height):
    """Burn (lon, lat) degree locations into a (grid_height, grid_width) 0/1 grid."""
    # Normalise a copy so the caller's array can be reused for another grid.
    normalized = species_locs.copy()
    normalized[:, 0] = normalized[:, 0] / 180  # lon -> [-1, 1]
    normalized[:, 1] = normalized[:, 1] / 90   # lat -> [-1, 1]
    x_pixel, y_pixel = lonlat_to_pixel(normalized, grid_width, grid_height)
    # Keep indices inside the grid.
    x_pixel = np.clip(x_pixel, 0, grid_width - 1)
    y_pixel = np.clip(y_pixel, 0, grid_height - 1)
    grid = np.zeros((grid_height, grid_width))
    grid[y_pixel, x_pixel] = 1
    return grid


def generate_ground_truth(taxa_id, snt=True, grid_height=501, grid_width=1002):
    """Render ground-truth presence maps for one taxon at two resolutions.

    The taxon's locations are loaded once (see ``_load_species_locs``) and
    rasterised twice:

    * at ``grid_height`` x ``grid_width``, saved via ``plot_heatmap`` to
      ``./images/species_presence_{taxa_id}.{pdf,png}``;
    * at the full 1002 x 2004 mask resolution, saved via ``plot_heatmap_2``
      to ``./images/species_presence_hr_{taxa_id}.{pdf,png}``.

    Args:
        taxa_id: Taxon id (int); matched numerically for SNT and as
            ``str(taxa_id)`` for IUCN.
        snt: Use the SNT archive if True, otherwise the IUCN json.
        grid_height, grid_width: Size of the low-resolution raster.

    Returns:
        True once both maps have been written.

    Raises:
        ValueError / KeyError: propagated from ``_load_species_locs`` when
            the taxon is unknown.
    """
    print(taxa_id)
    # Load once; the data does not depend on the output resolution.
    species_locs = _load_species_locs(taxa_id, snt)

    low_res = _rasterize_presence(species_locs, grid_width, grid_height)
    plot_heatmap(low_res, f"./images/species_presence_{taxa_id}")

    # Full ocean-mask resolution used by plot_heatmap_2.
    hr_height, hr_width = 1002, 2004
    high_res = _rasterize_presence(species_locs, hr_width, hr_height)
    plot_heatmap_2(high_res, f"./images/species_presence_hr_{taxa_id}")
    return True
if __name__ == '__main__':
    # Taxon to render. Keep it an int: it is compared numerically against the
    # SNT taxa array and formatted with str() for the IUCN lookup.
    taxa_id = 11901
    snt = True  # True: SNT archive; False: IUCN json
    grid_height, grid_width = 501, 1002  # low-resolution raster size
    # TODO: why snt=True? Ground truth could not be generated for hyacinth
    # macaw (18938), yellow baboon (67683), pika (43188) and southern flying
    # squirrel (46272).
    generate_ground_truth(
        taxa_id=taxa_id,
        snt=snt,
        grid_height=grid_height,
        grid_width=grid_width,
    )