import json

import numpy as np
import torch
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer

# Load the pretrained DistilBERT tokenizer and a two-label sequence-classification model.

# Checkpoint shared by the tokenizer and the model so they always match.
MODEL_NAME = 'distilbert-base-uncased'

tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
# 'distilbert-base-uncased' is a base checkpoint, so the 2-way classification
# head is presumably freshly initialized; the fine-tuning loop further down is
# what makes its predictions meaningful.
model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

# Define the key words and the class ids (0 = Kidney Clinic, 1 = Liver Clinic).

# Liver-related key words. NOTE(review): in this file they are only stored in
# model_config; prediction relies solely on the fine-tuned model, not on them.
key_words = ['ascites', 'cirrhosis', 'liver disease']

# Class ids the model predicts: get_clinic maps 1 -> 'Liver Clinic' and
# anything else -> 'Kidney Clinic'.
labels = [0, 1]

# Define a function that tokenizes input text into model-ready tensors.

def preprocess_text(text, max_length=512):
    """Tokenize *text* into PyTorch tensors ready to feed to ``model``.

    Args:
        text: The referral text to encode.
        max_length: Maximum number of tokens, including the special tokens
            added by the tokenizer. Longer inputs are truncated explicitly.
            The default of 512 matches DistilBERT's maximum sequence length.

    Returns:
        The tokenizer's encoding, holding ``input_ids`` and ``attention_mask``
        as PyTorch tensors (``return_tensors='pt'``).
    """
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated ``encode_plus``. ``truncation=True`` makes the max_length cut
    # explicit instead of relying on the library's warned fallback.
    return tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

# Define a function that predicts the class id for a piece of text.

def make_prediction(text):
    """Return the predicted class id for *text* (0 or 1 for this 2-label model).

    The forward pass runs in eval mode without gradient tracking. Afterwards
    the model is put back into whatever train/eval mode it was in before the
    call, so callers (e.g. a training loop) are not affected.
    """
    inputs = preprocess_text(text)
    was_training = model.training
    model.eval()  # disable dropout so repeated predictions are deterministic
    try:
        with torch.no_grad():  # inference only: no autograd graph needed
            outputs = model(
                inputs['input_ids'], attention_mask=inputs['attention_mask']
            )
    finally:
        model.train(was_training)
    # Softmax is monotonic, so the argmax of the logits equals the argmax of
    # the probabilities. Reduce along the class dimension, not the flattened
    # tensor.
    predicted_class = torch.argmax(outputs.logits, dim=-1)
    return predicted_class.item()

# Define a function that maps a prediction to the clinic the referral should be directed to.

def get_clinic(text):
    """Choose the destination clinic for a referral text.

    A predicted class of 1 routes to 'Liver Clinic'; every other class
    routes to 'Kidney Clinic'.
    """
    return 'Liver Clinic' if make_prediction(text) == 1 else 'Kidney Clinic'

# Define the model's configuration (saved to model_config.json).

# Configuration written next to the weights as model_config.json.
model_config = dict(
    model_type='distilbert',
    num_labels=2,
    key_words=key_words,
    labels=labels,
)

# Define the model's metadata (saved to model_metadata.json).

# Human-readable metadata written as model_metadata.json.
# NOTE(review): the description says classification is based on "the presence
# of certain key words", but the code classifies with the fine-tuned DistilBERT
# model and never consults key_words. 'author' is still a placeholder.
model_metadata = dict(
    name='Referral Clinic Classifier',
    description=(
        'A model that classifies referrals to either the Liver Clinic '
        'or Kidney Clinic based on the presence of certain key words.'
    ),
    author='Your Name',
    version='1.0',
)

# Fine-tune the model on a small labelled set.

# Hand-labelled examples as (referral text, class id) pairs:
# 1 = liver-related ('Liver Clinic'), 0 = kidney-related ('Kidney Clinic').
train_data = [
    ('Patient has ascites and cirrhosis.', 1),
    ('Patient has liver disease.', 1),
    ('Patient has kidney disease.', 0),
    ('Patient has liver failure.', 1),
    ('Patient has kidney failure.', 0),
]

# Fine-tune on train_data: a single pass with batch size 1.
# Build the optimizer once so Adam's running moment estimates persist across
# steps. Re-creating it inside the loop reset that state on every example.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# from_pretrained returns the model in eval mode; enable training behaviour
# (dropout) for the updates.
model.train()
for text, label in train_data:
    inputs = preprocess_text(text)
    # Shape (1,) matches the batch dimension of input_ids. The separate name
    # keeps the module-level `labels` list from being overwritten.
    label_tensor = torch.tensor([label])
    outputs = model(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        labels=label_tensor,
    )
    loss = outputs.loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Return to eval mode so later predictions are deterministic.
model.eval()

# Save the model weights, configuration, and metadata to files.

# Persist the fine-tuned weights plus the JSON config and metadata side files.
# (`json` must be imported at the top of the file; it was missing originally.)
torch.save(model.state_dict(), 'referral_clinic_classifier.pth')
with open('model_config.json', 'w', encoding='utf-8') as f:
    json.dump(model_config, f, indent=2)
with open('model_metadata.json', 'w', encoding='utf-8') as f:
    json.dump(model_metadata, f, indent=2)

# Hugging Face Hub model-page notes:
# Downloads last month: -
# Downloads are not tracked for this model (see the Hub's "How to track" guide).
# Inference Providers: this model is not currently available via any of the
# supported Inference Providers. It cannot be deployed to the HF Inference API
# because the model has no library tag.