-
Notifications
You must be signed in to change notification settings - Fork 5
/
model_library.py
231 lines (200 loc) · 9.21 KB
/
model_library.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
from keras import initializers, regularizers, activations, constraints
from keras.engine import Layer, InputSpec
from keras.layers import Reshape, Concatenate, Conv2D, MaxPooling2D, interfaces
from keras import backend as K
class Highway(Layer):
    """Densely-connected highway layer for 2-D inputs.

    For an input ``x`` of shape ``(batch, input_dim)`` the layer computes::

        t = sigmoid(K.dot(x, W_carry) + b_carry)
        h = activation(K.dot(x, W) + b)
        output = h * t + x * (1 - t)

    Both weight matrices are square, so the output has the same shape as the
    input. Despite the ``*_carry`` naming, ``t`` gates the transformed
    activation ``h`` and ``1 - t`` gates the carried input ``x``.

    Args:
        init: initializer for ``W`` and ``W_carry`` (resolved with
            ``initializers.get``).
        activation: activation applied to ``K.dot(x, W) + b`` (resolved with
            ``activations.get``).
        weights: optional list of initial weight arrays, loaded with
            ``set_weights`` at the end of ``build``.
        W_regularizer: regularizer for ``W`` (``W_carry`` is unregularized).
        b_regularizer: regularizer for ``b`` (``b_carry`` is unregularized).
        activity_regularizer: stored on the layer; presumably applied to the
            output by the Keras base ``Layer`` -- confirm for your Keras version.
        W_constraint: constraint for ``W``.
        b_constraint: constraint for ``b``.
        bias: whether to add the biases ``b`` and ``b_carry``.
        input_dim: optional; when truthy it is forwarded to the base layer as
            ``input_shape=(input_dim,)``.
        **kwargs: forwarded to ``Layer.__init__``.
    """

    def __init__(self,
                 init='glorot_uniform',
                 activation=None,
                 weights=None,
                 W_regularizer=None,
                 b_regularizer=None,
                 activity_regularizer=None,
                 W_constraint=None,
                 b_constraint=None,
                 bias=True,
                 input_dim=None,
                 **kwargs):
        # input_shape must be in kwargs before the base constructor runs.
        if input_dim:
            kwargs['input_shape'] = (input_dim,)
        super(Highway, self).__init__(**kwargs)
        # Everything below is assigned after the base constructor: Keras's
        # Layer.__init__ initialises attributes such as ``input_spec`` itself
        # and would otherwise overwrite the values set here.
        self.init = initializers.get(init)
        self.activation = activations.get(activation)
        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)
        self.bias = bias
        self.initial_weights = weights
        self.input_spec = InputSpec(ndim=2)
        self.input_dim = input_dim

    def build(self, input_shape):
        """Create ``W``, ``W_carry`` and (if ``bias``) ``b``, ``b_carry``."""
        input_dim = input_shape[1]
        self.input_spec = InputSpec(dtype=K.floatx(),
                                    shape=(None, input_dim))
        # Keyword arguments are accepted by every Keras 2 ``add_weight``
        # signature; a positional shape relies on a legacy-argument shim.
        self.W = self.add_weight(shape=(input_dim, input_dim),
                                 initializer=self.init,
                                 name='W',
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        self.W_carry = self.add_weight(shape=(input_dim, input_dim),
                                       initializer=self.init,
                                       name='W_carry')
        if self.bias:
            self.b = self.add_weight(shape=(input_dim,),
                                     initializer='zeros',
                                     name='b',
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
            # b_carry starts at ones, which pushes the gate ``t`` (and hence
            # the transform path) up at initialisation.
            # NOTE(review): the Highway Networks paper recommends a negative
            # initial gate bias that favours carrying; confirm this is intended.
            self.b_carry = self.add_weight(shape=(input_dim,),
                                           initializer='ones',
                                           name='b_carry')
        else:
            self.b_carry = None
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights
        self.built = True

    def call(self, x):
        """Return ``activation(xW + b) * t + x * (1 - t)``."""
        gate_logits = K.dot(x, self.W_carry)
        if self.bias:
            gate_logits += self.b_carry
        transform_weight = activations.sigmoid(gate_logits)

        pre_activation = K.dot(x, self.W)
        if self.bias:
            pre_activation += self.b
        act = self.activation(pre_activation)
        act *= transform_weight
        return act + (1 - transform_weight) * x

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = {'init': initializers.serialize(self.init),
                  'activation': activations.serialize(self.activation),
                  'W_regularizer': regularizers.serialize(self.W_regularizer),
                  'b_regularizer': regularizers.serialize(self.b_regularizer),
                  'activity_regularizer':
                      regularizers.serialize(self.activity_regularizer),
                  'W_constraint': constraints.serialize(self.W_constraint),
                  'b_constraint': constraints.serialize(self.b_constraint),
                  'bias': self.bias,
                  'input_dim': self.input_dim}
        base_config = super(Highway, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def dot_product(x, kernel):
    """Take ``K.dot(x, kernel)`` in a way intended to work on every backend.

    On the TensorFlow backend the kernel is given a trailing unit axis with
    ``K.expand_dims`` before ``K.dot``, and that axis is squeezed off the
    result. This appears intended to let a 1-D kernel (a vector of
    per-feature weights) be dotted with a 3-D ``x``. On any other backend
    the result is simply ``K.dot(x, kernel)``.

    Args:
        x: input tensor.
        kernel: weight tensor.

    Returns:
        The backend tensor holding the dot product.

    NOTE(review): for a 2-D kernel, the TensorFlow branch depends on how
    ``K.dot`` contracts N-D operands and may not equal plain
    ``K.dot(x, kernel)``. Confirm this before relying on weights being
    portable between backends.
    """
    if K.backend() != 'tensorflow':
        return K.dot(x, kernel)
    kernel_column = K.expand_dims(kernel)
    return K.squeeze(K.dot(x, kernel_column), axis=-1)
class AttentionWithContext(Layer):
    """Attention pooling over time steps, scored against a learned context vector.

    Follows Yang et al., "Hierarchical Attention Networks for Document
    Classification" (NAACL 2016, https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf)::

        u_it = tanh(dot_product(x_it, W) + b)
        a_it = exp(dot_product(u_it, u)), normalised over the time axis
        output_i = sum_t a_it * x_it

    Input shape:
        3D tensor ``(samples, steps, features)``.
    Output shape:
        2D tensor ``(samples, features)``.

    Masking is supported: masked time steps get zero attention weight. The
    mask is consumed here and not passed on to the next layer.

    Put it on top of an RNN layer (GRU/LSTM/SimpleRNN) with
    ``return_sequences=True``. All dimensions are inferred from the input
    shape. Example::

        model.add(LSTM(64, return_sequences=True))
        model.add(AttentionWithContext())

    Args:
        W_regularizer, u_regularizer, b_regularizer: weight regularizers.
        W_constraint, u_constraint, b_constraint: weight constraints.
        W_dropout: dropout rate for the input before the ``W`` projection.
            Clipped to ``[0, 1]``; only applied (in the training phase) when
            strictly between 0 and 1.
        u_dropout: dropout rate for the ``tanh`` hidden output before it is
            scored against ``u``. Same clipping and training-only rule.
        bias: whether to add the bias ``b``.
        **kwargs: forwarded to ``Layer.__init__``.
    """

    def __init__(self,
                 W_regularizer=None, u_regularizer=None, b_regularizer=None,
                 W_constraint=None, u_constraint=None, b_constraint=None,
                 W_dropout=0., u_dropout=0., bias=True, **kwargs):
        super(AttentionWithContext, self).__init__(**kwargs)
        # Assigned after the base constructor: Keras 2's Layer.__init__ sets
        # supports_masking = False, which previously overwrote this flag.
        self.supports_masking = True
        self.W_init = initializers.get('orthogonal')
        self.u_init = initializers.get('glorot_uniform')
        self.W_regularizer = regularizers.get(W_regularizer)
        self.u_regularizer = regularizers.get(u_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.u_constraint = constraints.get(u_constraint)
        self.b_constraint = constraints.get(b_constraint)
        self.W_dropout = min(1., max(0., W_dropout))
        self.u_dropout = min(1., max(0., u_dropout))
        self.bias = bias

    def build(self, input_shape):
        """Create ``W`` (features x features), optional ``b``, and ``u``."""
        assert len(input_shape) == 3
        features = input_shape[-1]
        # Keyword arguments are accepted by every Keras 2 ``add_weight``
        # signature; a positional shape relies on a legacy-argument shim.
        self.W = self.add_weight(shape=(features, features),
                                 initializer=self.W_init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight(shape=(features,),
                                     initializer='zeros',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        self.u = self.add_weight(shape=(features,),
                                 initializer=self.u_init,
                                 name='{}_u'.format(self.name),
                                 regularizer=self.u_regularizer,
                                 constraint=self.u_constraint)
        super(AttentionWithContext, self).build(input_shape)

    def compute_mask(self, input, input_mask=None):
        """Consume the incoming mask: the pooled output has no time axis."""
        return None

    def call(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over the time axis."""
        if 0. < self.W_dropout < 1.:
            def dropped_inputs():
                return K.dropout(x, self.W_dropout)
            x_dp = K.in_train_phase(dropped_inputs, x)
        else:
            x_dp = x
        uit = dot_product(x_dp, self.W)
        if self.bias:
            uit += self.b
        uit = K.tanh(uit)

        if 0. < self.u_dropout < 1.:
            def dropped_hidden():
                return K.dropout(uit, self.u_dropout)
            u_dp = K.in_train_phase(dropped_hidden, uit)
        else:
            u_dp = uit
        # One unnormalised score per (sample, step).
        ait = dot_product(u_dp, self.u)

        a = K.exp(ait)
        if mask is not None:
            # Zero the scores of masked steps before normalising so they get
            # no attention weight. Assumes the mask is (samples, steps), as
            # produced by Keras masking layers. The cast keeps the dtype at
            # floatx instead of upcasting.
            a *= K.cast(mask, K.floatx())
        # epsilon keeps the division finite when the sum is ~0 (e.g. early in
        # training, or when every step of a sample is masked).
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        """Drop the time axis: ``(samples, features)``."""
        return input_shape[0], input_shape[-1]

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = {'W_regularizer': regularizers.serialize(self.W_regularizer),
                  'u_regularizer': regularizers.serialize(self.u_regularizer),
                  'b_regularizer': regularizers.serialize(self.b_regularizer),
                  'W_constraint': constraints.serialize(self.W_constraint),
                  'u_constraint': constraints.serialize(self.u_constraint),
                  'b_constraint': constraints.serialize(self.b_constraint),
                  'W_dropout': self.W_dropout,
                  'u_dropout': self.u_dropout,
                  'bias': self.bias}
        base_config = super(AttentionWithContext, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def CNN(seq_length, length, input_size, feature_maps, kernels, x):
    """Apply row-wise convolutions with max-over-width pooling, then concatenate.

    For each ``(feature_map, kernel)`` pair, a ``Conv2D`` with ``feature_map``
    filters, a ``(1, kernel)`` window and ``tanh`` activation is applied to
    ``x``. It is followed by a max-pool spanning the whole remaining width
    ``length - kernel + 1``, so each row collapses to one value per filter.
    The pooled branches are joined on the last (channel) axis and reshaped
    to ``(seq_length, sum(feature_maps))``.

    Args:
        seq_length: number of rows in ``x``; the first non-batch output dim.
        length: width of each row (the axis being convolved over).
        input_size: unused; kept for call compatibility.
        feature_maps: number of filters for each convolution.
        kernels: convolution widths, paired element-wise with
            ``feature_maps``.
        x: channels-last 4-D Keras tensor, presumably
            ``(batch, seq_length, length, channels)`` -- TODO confirm.

    Returns:
        Keras tensor of shape ``(batch, seq_length, sum(feature_maps))``.

    Raises:
        ValueError: if ``feature_maps`` is empty, its length differs from
            ``kernels``, or a kernel is wider than ``length``.
    """
    # Materialise once: a generator would otherwise be exhausted by zip()
    # before sum() runs.
    feature_maps = list(feature_maps)
    kernels = list(kernels)
    if len(feature_maps) != len(kernels):
        raise ValueError(
            'feature_maps and kernels must have the same length, got %d and %d'
            % (len(feature_maps), len(kernels)))
    if not feature_maps:
        raise ValueError('at least one (feature_map, kernel) pair is required')

    pooled = []
    for feature_map, kernel in zip(feature_maps, kernels):
        reduced_l = length - kernel + 1
        if reduced_l < 1:
            raise ValueError(
                'kernel width %d exceeds input length %d' % (kernel, length))
        conv = Conv2D(feature_map, (1, kernel), activation='tanh',
                      data_format="channels_last")(x)
        maxp = MaxPooling2D((1, reduced_l), data_format="channels_last")(conv)
        pooled.append(maxp)

    # Keras' Concatenate layer rejects fewer than two inputs, so a single
    # branch is used as-is.
    con = Concatenate()(pooled) if len(pooled) > 1 else pooled[0]
    return Reshape((seq_length, sum(feature_maps)))(con)