ahatamiz committed on
Commit
46ed26c
·
verified ·
1 Parent(s): 1043cb3

Create hf_model.py

Browse files
Files changed (1) hide show
  1. hf_model.py +52 -0
hf_model.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import namedtuple
16
+ from typing import Optional, List, Union
17
+
18
+ import torch
19
+ from transformers import PretrainedConfig, PreTrainedModel
20
+ from .mamba_vision import *
21
+ from timm.models import create_model, load_checkpoint
22
+
23
+
24
class MambaVisionConfig(PretrainedConfig):
    """Configuration object for ``MambaVisionModel``.

    The only field added on top of ``PretrainedConfig`` is ``args``, an
    optional dictionary stored verbatim on the instance. All remaining
    keyword arguments are forwarded to ``PretrainedConfig.__init__``.
    """

    def __init__(self, args: Optional[dict] = None, **kwargs) -> None:
        """Store ``args`` and hand every other keyword to the base class.

        Args:
            args: Optional dictionary of model arguments; kept as-is
                (``None`` when not supplied).
            **kwargs: Passed through unchanged to ``PretrainedConfig``.
        """
        # Set the custom field first, then let the base class process the
        # standard configuration keywords.
        self.args = args
        super().__init__(**kwargs)
33
+
34
+
35
class MambaVisionModel(PreTrainedModel):
    """Pretrained Hugging Face model for MambaVision.

    This class inherits from ``PreTrainedModel``, which provides
    Hugging Face's functionality for loading and saving models. The
    wrapped network is built with timm's ``create_model`` from the name
    stored under ``config.args["model"]``.
    """

    config_class = MambaVisionConfig

    def __init__(self, config):
        """Create the wrapped network named by ``config.args["model"]``.

        Args:
            config: Configuration whose ``args`` mapping holds a ``"model"``
                entry with the model name to pass to ``create_model``.

        Raises:
            ValueError: If ``config.args`` is missing/empty or has no
                ``"model"`` entry.
        """
        super().__init__(config)

        # Read the mapping directly rather than converting it to a
        # namedtuple: only the "model" entry is needed, and a namedtuple
        # rejects keys that are not valid Python identifiers.
        model_args = getattr(config, "args", None)
        if not model_args or "model" not in model_args:
            raise ValueError(
                "MambaVisionModel requires config.args to be a mapping "
                "with a 'model' entry naming the model to create."
            )

        self.config = config
        self.model = create_model(model_args["model"])

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the wrapped network and return its output."""
        # Call the module itself, not ``.forward``, so that any hooks
        # registered on the wrapped network are executed.
        return self.model(x)