|
import os |
|
|
|
import streamlit as st |
|
import hashlib |
|
import uuid |
|
import time |
|
import json |
|
import numpy as np |
|
from concrete.ml.sklearn import SGDClassifier |
|
|
|
from blockchain import Blockchain, print_blockchain_details |
|
|
|
import watermarking |
|
from watermarking import watermark_model |
|
|
|
|
|
def generate_mock_hash():
    """Return a placeholder SHA-256 hex digest derived from the current time.

    The digest is computed over the string form of ``time.time()``, so it is
    a stand-in value rather than a hash of any real content.

    Returns:
        str: The hexadecimal SHA-256 digest (64 characters).
    """
    timestamp_bytes = str(time.time()).encode()
    digest = hashlib.sha256(timestamp_bytes)
    return digest.hexdigest()
|
|
|
|
|
from utils import ( |
|
CLIENT_DIR, |
|
CURRENT_DIR, |
|
DEPLOYMENT_DIR, |
|
KEYS_DIR, |
|
INPUT_BROWSER_LIMIT, |
|
clean_directory, |
|
SERVER_DIR, |
|
) |
|
|
|
from concrete.ml.deployment import FHEModelClient |
|
|
|
# Page-level setup: wide layout, contact sidebar, page title and banner image.
st.set_page_config(layout="wide")

# Sidebar listing the project contacts.
st.sidebar.title("Contact")
st.sidebar.info(
    """
    - Reda Bellafqira
    - Mehdi Ben Ghali
    - Pierre-Elisée Flory
    - Mohammed Lansari
    - Thomas Winninger
    """
)

st.title("Zamark: Secure Watermarking Service")

# Banner image, resolved relative to the working directory of the app.
st.image("watermarking.png")
|
|
|
|
|
def todo():
    """Display a "Not implemented yet" warning in the Streamlit page."""
    message = "Not implemented yet"
    st.warning(message, icon="⚠️")
|
|
|
|
|
def key_gen_fn(client_id):
    """Generate FHE keys for one client, persist them and preview them in the UI.

    The client object is created with ``KEYS_DIR / client_id`` as its key
    directory, and the serialized evaluation key is additionally written to
    ``KEYS_DIR / client_id / evaluation_key``.

    !!! needs a model in DEPLOYMENT_DIR as "client.zip" !!!

    Args:
        client_id (str): The client_id, retrieved from streamlit
    """
    # Helper imported from utils; presumably resets the working directories
    # before new keys are produced -- confirm against utils.clean_directory.
    clean_directory()

    # Load the deployed client model, then create the private/evaluation keys.
    fhe_client = FHEModelClient(
        path_dir=DEPLOYMENT_DIR,
        key_dir=KEYS_DIR / f"{client_id}",
    )
    fhe_client.load()
    fhe_client.generate_private_and_evaluation_keys()

    # The evaluation key must be raw bytes to be written to disk below.
    eval_key_bytes = fhe_client.get_serialized_evaluation_keys()
    assert isinstance(eval_key_bytes, bytes)

    # Persist the evaluation key in the client's key directory.
    with (KEYS_DIR / f"{client_id}/evaluation_key").open("wb") as key_file:
        key_file.write(eval_key_bytes)

    # Only a truncated hex preview is shown; the full key is too large to render.
    hex_preview = eval_key_bytes.hex()[:INPUT_BROWSER_LIMIT]
    size_in_mb = len(eval_key_bytes) / (10**6)

    with st.expander("Generated keys"):
        st.write(f"{size_in_mb:.2f} MB")
        st.code(hex_preview)

    st.success("Keys have been generated!", icon="✅")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def decode_id(binary_rep):
    """Decode a string of bits into the text it encodes.

    The bit string is parsed as a single big-endian integer, converted to the
    smallest number of bytes able to hold it, and decoded with the default
    codec of ``bytes.decode`` (UTF-8, which covers ASCII). Because the value
    goes through an integer, leading zero bits carry no information and any
    leading NUL bytes are not preserved.

    Args:
        binary_rep (str): Binary digits, e.g. ``"0110100001101001"``.

    Returns:
        str: The decoded text, e.g. ``"hi"``. ``"0"`` decodes to ``""``.

    Raises:
        ValueError: If ``binary_rep`` is not a valid base-2 literal, or if the
            resulting bytes cannot be decoded (``UnicodeDecodeError``).
    """
    binary_int = int(binary_rep, 2)

    # Round the bit length up to whole bytes. The parentheses are essential:
    # without them ``7 // 8`` is evaluated first (giving 0), the byte count
    # becomes the bit count, and the result is padded with leading NUL bytes.
    byte_number = (binary_int.bit_length() + 7) // 8

    binary_array = binary_int.to_bytes(byte_number, "big")
    return binary_array.decode()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Client configuration: create and persist the trigger set and datasets ---
st.header("Client Configuration", divider=True)

X_trigger = y_trigger = None
if st.button("Generate the trigger set for the watermarking"):
    # Trigger set used for watermarking; np.save appends the ".npy" suffix.
    X_trigger, y_trigger = watermarking.gen_trigger_set()
    for _stem, _array in (("x_trigger", X_trigger), ("y_trigger", y_trigger)):
        np.save(_stem, _array)

    # Train/test database, saved next to the trigger set.
    x_train, y_train, x_test, y_test = watermarking.gen_database()
    for _stem, _array in (
        ("x_train", x_train),
        ("y_train", y_train),
        ("x_test", x_test),
        ("y_test", y_test),
    ):
        np.save(_stem, _array)

    st.success("Trigger set generated and data saved successfully!")

    # Summarise the shapes of everything that was just written to disk.
    st.write(f"Trigger set shape: X={X_trigger.shape}, y={y_trigger.shape}")
    st.write(f"Training data shape: X={x_train.shape}, y={y_train.shape}")
    st.write(f"Test data shape: X={x_test.shape}, y={y_test.shape}")
|
|
|
|
|
# --- Model training and encryption ---
st.header("Model Training and Encryption", divider=True)

model = x_train = y_train = x_test = y_test = None
parameters_range = (-1.0, 1.0)
if st.button("Model Training and Encryption"):
    x_train, y_train, x_test, y_test = watermarking.gen_database()

    # No fit is run in this branch: the classifier is instantiated and its
    # coefficients/intercept are loaded from previously saved .npy files.
    model = SGDClassifier(
        random_state=42,
        max_iter=100,
        fit_encrypted=True,
        parameters_range=parameters_range,
        penalty=None,
        learning_rate="constant",
        verbose=1,
    )
    model.coef_ = np.load("model_coef.npy")
    model.intercept_ = np.load("model_intercept.npy")

    st.success("Model training and encryption completed successfully!")

    # Model summary.
    st.write("Model Information:")
    st.write(f"- Type: {type(model).__name__}")
    st.write(f"- Number of features: {model.coef_.shape[1]}")
    st.write(f"- Parameters range: {parameters_range}")

    # Dataset summary.
    st.write("\nData Information:")
    st.write(f"- Training set shape: X={x_train.shape}, y={y_train.shape}")
    st.write(f"- Test set shape: X={x_test.shape}, y={y_test.shape}")

    # First rows of the coefficient array.
    st.write("\nModel Coefficients Preview:")
    st.write(model.coef_[:5])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Model watermarking ---
st.header("Model Watermarking", divider=True)

wat_model = None
parameters_range = (-1.0, 1.0)
if st.button("Model Watermarking"):
    # The watermarked model is rebuilt from saved parameters: a classifier is
    # instantiated and its coefficients/intercept are loaded from .npy files.
    wat_model = SGDClassifier(
        random_state=42,
        max_iter=100,
        fit_encrypted=True,
        parameters_range=parameters_range,
        penalty=None,
        learning_rate="constant",
        verbose=1,
    )
    wat_model.coef_ = np.load("wat_model_coef.npy")
    wat_model.intercept_ = np.load("wat_model_intercept.npy")

    st.success("Model watermarking completed successfully!")

    # Summary of the watermarked model.
    st.write("Watermarked Model Information:")
    st.write(f"- Type: {type(wat_model).__name__}")
    st.write(f"- Number of features: {wat_model.coef_.shape[1]}")
    st.write(f"- Parameters range: {parameters_range}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Blockchain update: append a block describing the watermarked model ---
st.header("Update Blockchain", divider=True)

# Keep the latest block in session state so it survives Streamlit reruns.
if 'block_data' not in st.session_state:
    st.session_state.block_data = None

if st.button("Update Blockchain"):
    try:
        loaded_blockchain, data = Blockchain.load_from_file("blockchain.json")

        # Refuse to extend a chain that fails its own validity check.
        is_valid = loaded_blockchain.is_chain_valid()
        st.write(f"Loaded blockchain is valid: {is_valid}")

        if is_valid:
            # Rebuild the watermarked model from its saved parameters.
            parameters_range = (-1.0, 1.0)
            wat_model = SGDClassifier(
                random_state=42,
                max_iter=100,
                fit_encrypted=True,
                parameters_range=parameters_range,
                penalty=None,
                learning_rate="constant",
                verbose=1,
            )
            wat_model.coef_ = np.load("wat_model_coef.npy")
            wat_model.intercept_ = np.load("wat_model_intercept.npy")

            X_trigger = np.load("x_trigger.npy")
            y_trigger = np.load("y_trigger.npy")

            # Both trigger-set hashes are computed from the same
            # (X_trigger, y_trigger) pair.
            watermarked_model_hash = watermarking.get_model_hash(wat_model)
            trigger_set_hf = watermarking.get_trigger_hash(X_trigger, y_trigger)
            trigger_set_client = watermarking.get_trigger_hash(X_trigger, y_trigger)

            # Append the new block, persist the chain, and remember the block.
            new_block = loaded_blockchain.add_block(
                trigger_set_hf,
                trigger_set_client,
                watermarked_model_hash,
            )
            loaded_blockchain.save_to_file("blockchain.json")
            st.session_state.block_data = new_block.to_dict()

            st.success("Blockchain updated successfully!")

            st.subheader("New Block Information")
            st.write(f"Block ID: {new_block.counter}")
            st.write(f"Timestamp: {new_block.timestamp}")
            st.write(f"Previous Hash: {new_block.previous_hash}")
            st.write(f"Current Hash: {new_block.hash}")

            st.subheader("Blockchain Statistics")
            st.write(f"Total Blocks: {len(loaded_blockchain.chain)}")
            st.write(f"Blockchain File Size: {os.path.getsize('blockchain.json') / 1024:.2f} KB")
        else:
            st.warning("The loaded blockchain is not valid. Please check data integrity.")
    except Exception as e:
        # UI boundary: report any failure to the user instead of crashing the page.
        st.error(f"An error occurred while updating the blockchain: {e!s}")
|
|
|
|
|
# Show the most recently added block (kept in session state) as formatted JSON.
if st.session_state.block_data:
    st.subheader("Latest Block Data (JSON)")
    st.code(
        json.dumps(st.session_state.block_data, indent=2),
        language='json',
    )
|
|
|
|
|
# Offer the persisted chain for download. The section is only rendered when
# "blockchain.json" exists, so a missing file no longer raises
# FileNotFoundError while the page is being drawn.
if os.path.exists("blockchain.json"):
    st.subheader("Download Blockchain")

    # Read the bytes up front so the file handle is closed before rendering.
    with open("blockchain.json", "rb") as file:
        blockchain_bytes = file.read()

    st.download_button(
        label="Download Blockchain JSON",
        data=blockchain_bytes,
        file_name="blockchain.json",
        mime="application/json",
    )