-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexample_run_script.py
More file actions
36 lines (29 loc) · 932 Bytes
/
example_run_script.py
File metadata and controls
36 lines (29 loc) · 932 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from trainer import Trainer
BATCH_SIZE = 512
EPOCH_NUM = 500
# Name of the model class to instantiate. The class must be present in this
# module's globals (i.e. imported at the top of the file) for lookup to work.
MODEL_NAME = 'LocgloModel'


def build_args():
    """Return the hyper-parameter dict shared by ``Trainer`` and the model.

    ``fea_dim`` is intentionally absent: it is only known after the dataset
    has been loaded, so ``main`` adds it later.
    """
    return {
        'input_path': 'input_dataset.pickle',
        'output_path': 'output_path',
        # NOTE(review): the next four values are string placeholders; they
        # presumably need to be replaced with numeric values before a real run.
        'lr': 'learning_rate',
        'weight_decay': 'weight_decay',
        # Previously set to 'weight_decay' (copy-paste slip); this key gets
        # its own placeholder.
        'rnn_dim': 'rnn_dim',
        'film_rnn_dim': 'film_rnndim',
        'global': True,
        'local': True,
        'bidirect': True,
    }


def resolve_model_class(model_name, namespace=None):
    """Return the object bound to *model_name* in *namespace*.

    Args:
        model_name: Name of the model class to look up.
        namespace: Mapping to search; defaults to this module's globals().

    Raises:
        NameError: if *model_name* is not defined, with a hint to import it
            rather than the bare KeyError a raw dict lookup would give.
    """
    scope = globals() if namespace is None else namespace
    try:
        return scope[model_name]
    except KeyError:
        raise NameError(
            f"model class {model_name!r} is not defined in this script; "
            "import it (e.g. next to the Trainer import) before running"
        ) from None


def main():
    """Build the config, train the configured model, then run the test pass."""
    # Local imports keep the warning filter scoped to script execution.
    from sklearn.exceptions import UndefinedMetricWarning
    import warnings
    # Silence sklearn's ill-defined-metric warnings, presumably emitted by
    # metric computations inside Trainer.
    warnings.filterwarnings(action='ignore', category=UndefinedMetricWarning)

    args = build_args()
    # Resolve the class before loading data so a missing import fails fast.
    model_cls = resolve_model_class(MODEL_NAME)

    trainer = Trainer(args)
    trainer.load_dataset(args['input_path'])
    # The feature dimension comes from the loaded training data; args is
    # updated after Trainer(args) was built, preserving the original order.
    args['fea_dim'] = trainer.X_train.shape[-1]
    trainer.gru_model = model_cls(args)
    trainer.train(
        epoch_num=EPOCH_NUM,
        batch_size=BATCH_SIZE,
    )
    trainer.test()


if __name__ == '__main__':
    main()