Contact
CoCalc Logo Icon
StoreFeaturesDocsShareSupport News AboutSign UpSign In
| Download
Path: gsc.py
Views: 108
1
import itertools
2
import numpy as np
3
import re
4
import numbers
5
import sys
6
from scipy import linalg # CHECK numpy vs. scipy
7
import pandas as pd
8
# import mpl_toolkits.axes_grid1
9
import grammar
10
from collections import OrderedDict
11
import matplotlib.pyplot as plt
12
13
# NOTE(review): presumably the version number of this module -- TODO confirm.
ver = 0.3
14
15
class Fillers(object):
    """Filler symbols derived from a context-free grammar.

    The grammar string is handed to ``grammar.Grammar``; its ``hnfRules``
    mapping (left-hand symbol -> list of right-hand sides, each a list of
    symbols) is the source of every name, treelet and pair built here.

    Attributes set by ``__init__``:
        g: the ``grammar.Grammar`` instance built from ``cfg``.
        names: sorted list of unique filler symbols; the null filler ``'_'``
            is appended last when ``add_null`` is true.
        null: the null filler symbol, or ``None`` when ``add_null`` is false.
        treelets: list of ``[mother, bracketed, daughter1, ...]`` lists.
        pairs: list of ``[[parent, child], 'i_of_n']`` entries.
    """

    # A bracketed symbol ends in "[<digits>]", e.g. "S[1]".  Raw string so
    # that "\[" is a regex escape rather than an invalid string escape.
    _BRACKETED_PATTERN = re.compile(r'.*\[[0-9]+\]$')

    def __init__(self, cfg, add_null=True):
        """Build filler names, treelets and pairs from the grammar ``cfg``.

        Args:
            cfg: grammar specification passed unchanged to
                ``grammar.Grammar``.
            add_null: when true, register ``'_'`` as the null filler and
                append it to ``names``.
        """
        null_filler = '_'

        self.g = grammar.Grammar(cfg)
        self._get_filler_names()
        if add_null:
            self.null = null_filler
            self.names.append(null_filler)
        else:
            self.null = None
        self._construct_treelets()
        self._construct_pairs()

    def _get_filler_names(self):
        """Collect every symbol of ``hnfRules`` into ``self.names`` (sorted)."""
        hnf_rules = self.g.hnfRules
        symbols = set()
        for lhs in hnf_rules:
            symbols.add(lhs)
            for rhs in hnf_rules[lhs]:
                symbols.update(rhs)
        self.names = sorted(symbols)

    def _construct_treelets(self):
        """Build ``self.treelets``.

        Result is a list of lists of fillers (strings):
            [[mother, bracketed, daughter1(, daughter2, ...)],
             [ ... ]]
        A treelet is created for every rule whose right-hand side is a single
        bracketed symbol; the daughters are the symbols that bracketed
        symbol rewrites to.
        """
        hnf_rules = self.g.hnfRules
        # Computed once instead of re-scanning all names for every symbol.
        bracketed = set(self.subset_bracketed())
        treelets = []

        for lhs in hnf_rules:
            for rhs in hnf_rules[lhs]:
                if len(rhs) == 1 and rhs[0] in bracketed:
                    treelet = [lhs, rhs[0]]
                    for expansion in hnf_rules[rhs[0]]:
                        treelet.extend(expansion)
                    treelets.append(treelet)

        self.treelets = treelets

    def _construct_pairs(self):
        """Build ``self.pairs`` (parent/child pairs) from ``self.treelets``.

        Each treelet yields one ``'1_of_1'`` pair (mother, bracketed) and one
        ``'i_of_n'`` pair (bracketed, i-th daughter) per daughter.
        """
        self.pairs = []

        for treelet in self.treelets:
            self.pairs.append([[treelet[0], treelet[1]], '1_of_1'])
            n_daughters = len(treelet) - 2

            for ii in range(n_daughters):
                self.pairs.append([[treelet[1], treelet[ii + 2]],
                                   '%d_of_%d' % (ii + 1, n_daughters)])

    def find(self, filler):
        """Placeholder: not implemented, always returns 0."""
        return 0

    def find_mothers(self, filler):
        """Placeholder: not implemented, always returns 0."""
        return 0

    def find_daughters(self, filler):
        """Placeholder: not implemented, always returns 0."""
        return 0

    def subset_bracketed(self):
        """Return a list of bracketed filler symbols (e.g. ``'S[1]'``)."""
        matches = self._BRACKETED_PATTERN.match
        return [filler for filler in self.names
                if matches(filler) is not None]

    def subset_unbracketed(self):
        """Return a list of unbracketed filler symbols."""
        bracketed = set(self.subset_bracketed())
        return [filler for filler in self.names if filler not in bracketed]

    def subset_root(self):
        """Return whatever ``Grammar.getRootNode`` reports for the HNF rules."""
        return self.g.getRootNode(self.g.hnfRules)

    def subset_terminals(self):
        """Return whatever ``Grammar.getTerminalNodes`` reports for the HNF rules."""
        return self.g.getTerminalNodes(self.g.hnfRules)

    def is_terminal(self, filler):
        """Return True if ``filler`` is among ``subset_terminals()``."""
        return filler in self.subset_terminals()

    def is_bracketed(self, filler):
        """Return True if ``filler`` is among ``subset_bracketed()``."""
        return filler in self.subset_bracketed()

    def is_unbracketed(self, filler):
        """Return True if ``filler`` is among ``subset_unbracketed()``."""
        return filler in self.subset_unbracketed()

    def is_root(self, filler):
        """Return True if ``filler`` is among ``subset_root()``."""
        return filler in self.subset_root()

    def read_treelets(self):
        """Print every treelet as ``mother ( bracketed ( d1 d2 ... ) )``."""
        for treelet in self.treelets:
            temp = (treelet[0] + ' ( ' +
                    treelet[1] + ' ( ' + ' '.join(treelet[2:]) + ' ) )')
            print(temp)
126
127
# class RecursiveRoles(object):
128
# # HNF assumed
129
# # CHANGE the depth: in Nick's program, 'r' is depth 1, and '0r' is depth 2. '00r' is depth 3 and so on.
130
# # In this program, 'r' and '0r' is depth 0, '00r' and '000r' is depth 1, and so on.
131
# # : depth = Number of applications of rewrite rules (in the CNF)
132
133
# def __init__(self, max_depth, num_branch=2):
134
135
# self.max_depth = max_depth
136
# self.num_branch = num_branch
137
# self.root_symbol = 'r'
138
# self.roles = None
139
# self.roles_enhanced = None
140
# self.depth = 0
141
# self.names = []
142
# self.names.append(self.root_symbol)
143
# self._generate(mother=self.root_symbol)
144
# self._sort()
145
# self._construct_treelets()
146
# self._construct_pairs()
147
148
# def _construct_pairs(self):
149
# self.pairs = []
150
# for treelet in self.treelets:
151
# n_daughters = len(treelet) - 1
152
# for ii in range(n_daughters):
153
# self.pairs.append([[treelet[0], treelet[ii + 1]],
154
# '%d_of_%d' % (ii + 1, n_daughters)])
155
156
# def _generate(self, mother):
157
158
# self.depth = self.check_depth(mother)
159
160
# bracketed = '0' + mother
161
# unbracketed = [str(ii) + bracketed for ii in range(self.num_branch)]
162
# self.names.append(bracketed)
163
# [self.names.append(role) for role in unbracketed]
164
165
# self.depth += 1
166
167
# if self.depth < self.max_depth:
168
# for mother in unbracketed:
169
# self._generate(mother)
170
171
# def _sort(self):
172
173
# temp = []
174
# for role_name in self.names:
175
# temp.append([role_name[::-1], len(role_name)])
176
# temp = sorted(temp, key = lambda x: (x[1], x[0]))
177
178
# self.names = [item[0][::-1] for item in temp]
179
180
# def _construct_treelets(self):
181
182
# self.treelets = []
183
# for role in self.names:
184
# depth = self.check_depth(role)
185
# if depth < self.max_depth:
186
# # Non-terminal
187
# daughters = [self.names[idx] for idx in self.find_daughters(role)]
188
# treelet = [role] # add a mother
189
# [treelet.append(daughter) for daughter in daughters]
190
# self.treelets.append(treelet)
191
192
# def find_mother(self, role):
193
# """Return the index of the mother of a given role."""
194
195
# if role == self.root_symbol:
196
# res = []
197
# else:
198
# res = self.find(role[1:])
199
200
# return res
201
202
# def find_daughters(self, role):
203
# """Return the indices of the daughters of a given role."""
204
205
# if self.is_bracketed(role):
206
# res = self.find([str(ii) + role for ii in range(self.num_branch)])
207
# else:
208
# res = self.find('0' + role)
209
# return res
210
211
# def find(self, roles):
212
# """Return a list of indices given a list of roles (tuple)."""
213
214
# if not isinstance(roles, list):
215
# roles = [roles]
216
217
# return [self.names.index(role) for role in roles]
218
219
# def is_bracketed(self, role):
220
221
# return not bool(len(role) % 2)
222
223
# def check_depth(self, role):
224
225
# return (len(role) - 1) // 2
226
227
# def update_max_depth(self):
228
229
# lengths = [len(role_name) for role_name in self.names]
230
# self.max_depth = self.check_depth(self.names[np.argmax(lengths)])
231
232
233
class SpanRoles(object):
    """Span roles for sentences of at most ``max_sent_len`` positions.

    A role is a strictly increasing tuple of positions drawn from
    ``0..max_sent_len``.  Its *width* is ``last - first`` and its *number of
    branches* is ``len(role) - 1``.  A role is *bracketed* when it has more
    than one branch and a width greater than one (e.g. ``(0,1,2)``).

    Attributes built by ``__init__``:
        roles: sorted list of role tuples.
        num_roles: number of generated roles (set in ``_generate``).
        names: role tuples rendered as strings without spaces, e.g. '(0,1,2)'.
        G: OrderedDict mapping each role name to ``{'m': [...], 'd': [...]}``
            (lists of mother / daughter name lists).
        quant_list: list of lists of role names (see ``_check_overlap``).
        treelets: list of ``[mother, bracketed, daughter1, ...]`` name lists.
        pairs: list of ``[[parent, child], 'i_of_n']`` entries.
    """

    # ToDo:
    # Support the setting: use_terminal=True (not yet supported --> fillers)

    def __init__(self, max_sent_len, max_num_branch=2):
        '''Construct a SpanRole object.'''
        self.max_sent_len = max_sent_len
        self.max_num_branch = max_num_branch
        # use_terminal: do you have special terminal symbols?
        self.use_terminal = False
        self._generate()
        self._sort()
        self._name()
        self._graph()
        self._check_overlap()
        self._construct_treelets()
        self._construct_pairs()

    def _generate(self):
        """Enumerate all roles with 1..max_num_branch branches."""
        positions = range(self.max_sent_len + 1)
        generated = []
        for tuple_len in range(2, self.max_num_branch + 2):
            generated.extend(itertools.combinations(positions, tuple_len))

        if self.use_terminal:
            # One extra (i, j, j) role for every width-1 span (i, j).
            generated = generated + [
                (span[0], span[1], span[1])
                for span in generated if self.get_width(span) == 1]

        self.roles = generated
        self.num_roles = len(generated)

    def sort_roles(self, roles_tuple):
        """Return ``roles_tuple`` as a sorted list.

        Sort order:
        (1) width - increasing order
        (2) num_branch - decreasing order
        (3) init_position - increasing order
        """
        def ordering(role):
            return (self.get_width(role),
                    -self.get_num_branch(role),
                    role[0])

        return sorted(roles_tuple, key=ordering)

    def _sort(self):
        """Sort ``self.roles`` in place of the unsorted list."""
        self.roles = self.sort_roles(self.roles)

    def _name(self):
        """Store the string form of every role in ``self.names``."""
        self.names = [self.tuple2str(role) for role in self.roles]

    def _graph(self):
        """Build the mother/daughter graph ``self.G`` over role names."""
        graph = OrderedDict()
        for name in self.names:
            graph[name] = {'m': [], 'd': []}

        for role in self.roles:
            name = self.tuple2str(role)

            if self.is_bracketed(role):
                mother = self.tuple2str(self._find_mother(role))
                daughters = [self.tuple2str(span)
                             for span in self._find_daughter(role)
                             if self.get_width(span) > 0]
                graph[name]['m'].append([mother])
                graph[mother]['d'].append([name])
                graph[name]['d'].append(daughters)
                for daughter in daughters:
                    graph[daughter]['m'].append([name])

            if self.use_terminal and self.is_terminal(role):
                mother = self.tuple2str(self._find_mother(role))
                graph[name]['m'].append([mother])
                graph[mother]['d'].append([name])

        self.G = graph

    def _check_overlap(self):
        """Group role names into ``self.quant_list``.

        It is assumed that span roles are sorted correctly (see sort_roles).
        Bracketed roles that share the same mother (e.g. (0,1,3) and (0,2,3))
        are put into one group: every binding is assumed to compete with
        every other binding in such a set.  Every unbracketed role forms a
        group of its own.
        """
        groups = []  # list of list of role names
        for name in self.G:
            if self.is_bracketed(self.str2tuple(name)):
                continue
            # An unbracketed role cannot have multiple daughters, so each
            # entry of its 'd' list is a list holding a single role name.
            siblings = [entry[0] for entry in self.G[name]['d']
                        if not self.is_terminal(self.str2tuple(entry[0]))]
            if siblings:
                groups.append(siblings)
            groups.append([name])

        def group_ordering(group):
            head = self.str2tuple(group[0])
            return (self.get_width(head),
                    -self.get_num_branch(head),
                    head[0])

        self.quant_list = sorted(groups, key=group_ordering)

    def _construct_treelets(self):
        """Build ``self.treelets`` from the bracketed roles in ``self.G``.

        Example: [(0,2), (0,1,2), (0,1), (1,2)]   # S(S[1](A B))
        Treelets for special terminal roles (use_terminal) are not generated.
        """
        collected = []
        for name in self.G:
            if not self.is_bracketed(self.str2tuple(name)):
                continue
            mothers = self.G[name]['m']
            if len(mothers) > 1:  # error
                sys.exit('CHECK!!!')
            collected.append(mothers[0] + [name] + self.G[name]['d'][0])

        self.treelets = collected

    def _construct_pairs(self):
        """Build ``self.pairs`` (parent/child name pairs) from the treelets."""
        pairs = []
        for treelet in self.treelets:
            mother, bracketed = treelet[0], treelet[1]
            pairs.append([[mother, bracketed], '1_of_1'])
            daughters = treelet[2:]
            total = len(treelet) - 2
            for rank, daughter in enumerate(daughters, 1):
                pairs.append([[bracketed, daughter],
                              '%d_of_%d' % (rank, total)])
        self.pairs = pairs

    def _find_mother(self, role_tuple_bracketed):
        """Return the (first, last) span of a bracketed role.

        The same span is returned for a terminal role when ``use_terminal``
        is set; any other role yields ``None``.
        """
        if (self.is_bracketed(role_tuple_bracketed) or
                (self.use_terminal and
                 self.is_terminal(role_tuple_bracketed))):
            return (role_tuple_bracketed[0], role_tuple_bracketed[-1])
        return None

    def _find_daughter(self, role_tuple_bracketed):
        """Return the adjacent-position spans of a bracketed role.

        Non-bracketed roles yield ``None``.
        """
        if not self.is_bracketed(role_tuple_bracketed):
            return None
        return [(left, right)
                for left, right in zip(role_tuple_bracketed[:-1],
                                       role_tuple_bracketed[1:])]

    def str2tuple(self, role_str):
        """Parse a role name such as '(0,1,2)' into a tuple of ints."""
        return tuple(int(pos) for pos in role_str[1:-1].split(','))

    def tuple2str(self, role_tuple):
        """Render a role tuple as a string with all spaces removed."""
        return str(role_tuple).replace(' ', '')

    def get_width(self, role_tuple):
        """Return last position minus first position."""
        return role_tuple[-1] - role_tuple[0]

    def get_num_branch(self, role_tuple):
        """Return the number of branches (tuple length minus one)."""
        return len(role_tuple) - 1

    def is_bracketed(self, role_tuple):
        """Return True for roles with more than one branch and width > 1.

        If we assume HNF, only bracketed roles can have multiple children.
        """
        return (self.get_num_branch(role_tuple) > 1 and
                self.get_width(role_tuple) > 1)

    def is_terminal(self, role_tuple):
        """Return True for width-1 roles with the expected branch count.

        The expected count is 2 when ``use_terminal`` is set, otherwise 1.
        """
        wanted_branches = 2 if self.use_terminal else 1
        return (self.get_width(role_tuple) == 1 and
                self.get_num_branch(role_tuple) == wanted_branches)

    def subset_terminals(self):
        '''Return a list of terminal roles (tuple)'''
        return [role for role in self.roles if self.is_terminal(role)]

    def subset_bracketed(self):
        '''Return a list of bracketed roles (tuple)'''
        return [role for role in self.roles if self.is_bracketed(role)]

    def read_treelets(self):
        """Print every treelet as ``mother ( bracketed ( d1 d2 ... ) )``."""
        for treelet in self.treelets:
            print(treelet[0] + ' ( ' +
                  treelet[1] + ' ( ' + ' '.join(treelet[2:]) + ' ) )')
440
441
442
class HarmonicGrammar(object):
443
"""Construct an object implementing a Harmonic Grammar
444
445
[ [['S[1]/(0,1,2)', 'A/(0,1)'], 2.], # H(S[1]/(0,1,2), A/(0,1)) = 2
446
[['S/(0,2)', 'S[1]/(0,1,2)'], 2.],
447
[['A/(0,1)'], -1.], # H(A/(0,1)) = -1
448
...
449
... # For parsimony,
450
[['A'], -1.], # H(A/r) = -1 for all r in R, a set of roles
451
[['(0,1,2)'], -3.] # H(f/(0,1,2)) = -3 for all f in F, a set of fillers
452
...]
453
"""
454
455
def __init__(self, cfg, size, role_type="span_role",
456
add_null=True, match='pair', penalize='impossible',
457
null_bias=0., add_constraint=False, unary_base='filler'):
458
459
# cfg (sting): context-free grammar
460
# size (int): max_sent_len when span roles are used,
461
# max_depth when recursive roles are used.
462
# role_type (string): 'span_role' or 'recursive_role'
463
# add_null (bool): add a null treelet or not
464
# null_bias (numeric): bias values assigned to null bindings
465
466
self.role_type = role_type
467
self.size = size
468
self.rules = []
469
self.filler_names = None
470
self.role_names = None
471
self.binding_names = None
472
473
# CFG(CNF) -> CFG(HNF)
474
# For now, only binary branching is supported.
475
self.num_branch = 2
476
self.bias_constraint = -10.
477
self.unused_binding_harmony = -10
478
self.null_bias = null_bias
479
480
# Use Nick's program to get a grammar in harmonic normal form.
481
# Later integrate the program with this and
482
# make it more consistent with this program.
483
self.grammar = grammar.Grammar(cfg)
484
self.fillers = Fillers(cfg, add_null=add_null)
485
self.filler_names = self.fillers.names
486
487
if role_type == "span_role":
488
self.roles = SpanRoles(
489
max_sent_len=size, max_num_branch=self.num_branch)
490
self.role_names = self.roles.names
491
self._set_binding_names()
492
493
self._construct_treelets(add_constraint)
494
self._construct_pairs()
495
self._convert_tuple_to_str()
496
self._add_rules_binary(match=match)
497
self._add_rules_unary(which=unary_base)
498
499
if add_constraint:
500
self._add_constraints(penalize=penalize)
501
502
# # unary_base is not filler:
503
# if add_null and (unary_base == 'role'):
504
# self._add_rules_null(null_bias)
505
# self._add_rules_root()
506
507
# elif role_type == "recursive_role":
508
# self.roles = RecursiveRoles(max_depth=size, num_branch=2)
509
# self.treelets = []
510
511
# # there may be multiple start symbols
512
# fillers = self.fillers.subset_root()
513
# role = self.roles.root_symbol
514
# for filler in fillers:
515
# self._generate_treelet_with_recursive_roles(
516
# binding=(filler, role))
517
# self._sort()
518
# self._prune_roles()
519
# self._set_binding_names()
520
# self._convert_tuple_to_str()
521
# self.rules = []
522
# self._add_rules_binary(use='treelets')
523
# self._add_rules_unary(which=unary_base)
524
# if unary_base != 'binding':
525
# if add_null and (unary_base == 'role'):
526
# self._add_rules_null(null_bias)
527
# self._add_rules_root()
528
529
# else:
530
# sys.exit('The role_type argument must be set to either "span_role" or "recursive_role".')
531
532
def _add_constraints(self, penalize='unused'):
533
534
if penalize is 'unused':
535
# penalize unused bindings
536
bindings_used = []
537
for treelet in self.treelets:
538
[bindings_used.append(binding) for binding in treelet]
539
540
for binding in self.binding_names:
541
if binding not in bindings_used:
542
if binding.split('/')[0] != '_':
543
self.rules.append(
544
[[binding], self.unused_binding_harmony])
545
# else:
546
# self.rules.append([[binding], self.null_bias])
547
548
elif penalize is 'impossible':
549
# penalize impossible bindings
550
551
for b in self.binding_names:
552
f, r = b.split('/')
553
r_tuple = self.roles.str2tuple(r)
554
555
if f is not self.fillers.null:
556
557
# terminal fillers - terminal roles
558
if f in self.fillers.subset_terminals():
559
if r_tuple not in self.roles.subset_terminals():
560
self.rules.append([[b], self.bias_constraint])
561
562
# starting fillers - non-terminal, unbracketed roles
563
elif self.fillers.is_root(f):
564
if (r_tuple in self.roles.subset_terminals()) or \
565
(r_tuple in self.roles.subset_bracketed()):
566
self.rules.append([[b], self.bias_constraint])
567
568
# unbracketed fillers - non-terminal, unbracketed roles
569
elif f not in self.get_bracketed_fillers():
570
if (r_tuple in self.roles.subset_terminals()) or \
571
(r_tuple in self.roles.subset_bracketed()):
572
self.rules.append([[b], self.bias_constraint])
573
574
# bracketed fillers - bracketed roles
575
elif f in self.get_bracketed_fillers():
576
if r_tuple not in self.roles.subset_bracketed():
577
self.rules.append([[b], self.bias_constraint])
578
579
def _set_binding_names(self, sep='/'):
580
581
self.binding_names = [f + sep + r for r in self.roles.names
582
for f in self.fillers.names]
583
584
def _construct_treelets(self, add_constraint):
585
586
terminal_fillers = self.get_terminal_fillers()
587
terminal_roles = self.get_terminal_roles()
588
589
if self.fillers.null is not None:
590
null_filler = self.fillers.null
591
else:
592
null_filler = None
593
594
treelet_bindings = []
595
for treelet_f in self.fillers.treelets:
596
for treelet_r in self.roles.treelets:
597
if len(treelet_f) == len(treelet_r):
598
# test terminals and non-terminals
599
treelet_b = list(zip(treelet_f, treelet_r))
600
if add_constraint:
601
count = 0
602
for binding_index in range(len(treelet_b)):
603
is_terminal_f = (treelet_b[binding_index][0]
604
in terminal_fillers)
605
is_terminal_r = (treelet_b[binding_index][1]
606
in terminal_roles)
607
is_null_f = (treelet_b[binding_index][0] ==
608
null_filler)
609
if (is_terminal_f and is_terminal_r) or \
610
((not is_terminal_f) and (not is_terminal_r)) or (is_null_f):
611
count += 1
612
if count == len(treelet_b):
613
treelet_bindings.append(treelet_b)
614
else:
615
treelet_bindings.append(treelet_b)
616
617
self.treelets = treelet_bindings
618
619
def _construct_pairs(self):
620
621
pair_bindings = []
622
for pair_f in self.fillers.pairs:
623
for pair_r in self.roles.pairs:
624
if pair_f[1] == pair_r[1]: # same type
625
pair_bindings.append(list(zip(pair_f[0], pair_r[0])))
626
627
self.pairs = pair_bindings
628
629
def _convert_tuple_to_str(self):
630
631
for ii, treelet in enumerate(self.treelets):
632
for jj, binding in enumerate(treelet):
633
self.treelets[ii][jj] = binding[0] + '/' + binding[1]
634
635
for ii, pair in enumerate(self.pairs):
636
for jj, binding in enumerate(pair):
637
self.pairs[ii][jj] = binding[0] + '/' + binding[1]
638
639
def _add_rules_binary(self, match='pair'):
640
"""Add binary HG rules"""
641
642
if match == 'treelet':
643
for treelet in self.treelets:
644
self.rules.append([[treelet[0], treelet[1]], 2.0])
645
for ii in range(len(treelet) - 2):
646
self.rules.append([[treelet[1], treelet[ii + 2]], 2.0])
647
648
elif match == 'pair':
649
for pair in self.pairs:
650
self.rules.append([[pair[0], pair[1]], 2.0])
651
652
def _add_rules_unary(self, which="filler"):
653
"""Add unary HG rules"""
654
655
# Case 1
656
# role_type: span_role
657
# unary_base: filler
658
for ii, filler in enumerate(self.fillers.names):
659
if filler in self.fillers.subset_bracketed():
660
self.rules.append([[filler], -1.0 - self.num_branch])
661
elif filler in self.fillers.subset_terminals():
662
self.rules.append([[filler], -1.0])
663
elif filler in self.fillers.subset_root():
664
self.rules.append([[filler], -1.0])
665
elif filler == self.fillers.null:
666
self.rules.append([[filler], self.null_bias])
667
else:
668
self.rules.append([[filler], -2.0])
669
670
# if which == "role":
671
# if self.role_type == "span_role":
672
# # Role bias --- but in the Harmonic Grammar book,
673
# # the bias was set based on fillers, not on roles.
674
# for ii, role in enumerate(self.roles.roles):
675
# num_edges = float(len(role)) # 1 mother + [len(role)-1] daughters
676
# width = role[-1] - role[0]
677
# if width == 1:
678
# num_edges = num_edges - 1 # no daughter
679
# # No adjustment for the top span role (e.g., [0, max_sent_len] which does not have a mother
680
# # in the implemented model. But here the role has a mother so num_edges was set to 2
681
# # 1 for a single daughter and 1 for a possible mother in the infinite model
682
# # Consider adding +1 for that case.
683
# self.rules.append([[str(role).replace(' ', '')], -num_edges])
684
685
# elif self.role_type == "recursive_role":
686
# sys.exit('CHECK: We have not yet implemented this case.')
687
# temp = [len(role) for role in self.roles.names]
688
# for role in self.roles.names:
689
# if len(role) == 1:
690
# self.rules.append([[role], -1]) # root node
691
# elif len(role) == max(temp): # terminal nodes
692
# return 0
693
694
# elif which == "filler":
695
# for ii, filler in enumerate(self.fillers.names):
696
# if filler in self.fillers.subset_bracketed():
697
# self.rules.append([[filler], -3.0])
698
# elif filler in self.fillers.subset_terminals():
699
# self.rules.append([[filler], -1.0])
700
# elif filler in self.fillers.subset_root():
701
# self.rules.append([[filler], -1.0])
702
# elif filler == self.fillers.null:
703
# self.rules.append([[filler], 0.0])
704
# else:
705
# self.rules.append([[filler], -2.0])
706
707
# elif which == "binding":
708
# bindings_used = []
709
# for treelet in self.treelets:
710
# [bindings_used.append(binding) for binding in treelet]
711
712
# for binding in self.binding_names:
713
# if binding in bindings_used:
714
# filler, role = binding.split('/')
715
# if self.fillers.is_terminal(filler):
716
# self.rules.append([[binding], -1.0])
717
# elif self.fillers.is_root(filler):
718
# self.rules.append([[binding], -1.0])
719
# elif self.fillers.is_bracketed(filler):
720
# self.rules.append(
721
# [[binding], -1.0 - self.num_branch])
722
# elif self.fillers.is_unbracketed(filler):
723
# self.rules.append([[binding], -2.0])
724
# else:
725
# sys.exit('Error')
726
# else:
727
# if binding.split('/')[0] != '_':
728
# self.rules.append([[binding], unused_binding_harmony])
729
# else:
730
# self.rules.append([[binding], 0])
731
732
# elif which == "binding0":
733
# # Use this when add_constraint = False
734
# # terminal fillers cannot occur in non-terminal roles
735
# # non-terminal fillers cannot occur in terminal roles
736
# # bracketed fillers cannot occur in non-bracketed roles
737
# # non-bracketed fillers cannot occur in bracketed roles
738
# terminal_roles = self.get_terminal_roles()
739
# bracketed_roles = self.get_bracketed_roles()
740
741
# for binding in self.binding_names:
742
# filler, role = binding.split('/')
743
# if filler == '_':
744
# self.rules.append([[binding], 0])
745
# elif self.fillers.is_terminal(filler):
746
# if role in terminal_roles:
747
# self.rules.append([[binding], -1.0])
748
# else:
749
# self.rules.append([[binding], unused_binding_harmony])
750
# elif self.fillers.is_root(filler):
751
# if (role not in bracketed_roles) and (role not in terminal_roles):
752
# self.rules.append([[binding], -1.0])
753
# else:
754
# self.rules.append([[binding], unused_binding_harmony])
755
# elif self.fillers.is_bracketed(filler):
756
# if role in bracketed_roles:
757
# self.rules.append([[binding], -1.0 - self.num_branch])
758
# else:
759
# self.rules.append([[binding], unused_binding_harmony])
760
# elif self.fillers.is_unbracketed(filler):
761
# if (role not in bracketed_roles) and (role not in terminal_roles):
762
# self.rules.append([[binding], -2.0])
763
# else:
764
# self.rules.append([[binding], unused_binding_harmony])
765
# else:
766
# sys.exit('Error')
767
768
# def _prune_roles(self):
769
# """ Remove unused roles """
770
771
# roles_minimal = []
772
# for treelet in self.treelets:
773
# [roles_minimal.append(binding[1]) for binding in treelet]
774
# roles_minimal = list(set(roles_minimal))
775
# self.roles.names = [
776
# role for role in self.roles.names if role in roles_minimal]
777
778
# r_treelets = []
779
# for r_treelet in self.roles.treelets:
780
# if set(r_treelet).issubset(set(self.roles.names)):
781
# r_treelets.append(r_treelet)
782
783
# self.roles.treelets = r_treelets
784
# self.roles.update_max_depth()
785
786
# def _sort(self):
787
788
# # recursive roles
789
# treelets_enhanced = []
790
# for treelet in self.treelets:
791
# treelets_enhanced.append(
792
# [treelet, self.fillers.names.index(treelet[0][0]),
793
# self.roles.names.index(treelet[0][1])])
794
795
# treelets_enhanced = sorted(
796
# treelets_enhanced, key=lambda x: (x[1], x[0]))
797
# self.treelets = [treelet[0] for treelet in treelets_enhanced]
798
799
# def _generate_treelet_with_recursive_roles(self, binding):
800
801
# filler = binding[0]
802
# role = binding[1]
803
# f_treelets = [treelet for treelet in self.fillers.treelets
804
# if treelet[0] == filler]
805
# r_treelet = [treelet for treelet in self.roles.treelets
806
# if treelet[0] == role][0]
807
# new_bindings = []
808
809
# for f_treelet in f_treelets:
810
# if len(r_treelet) == len(f_treelet):
811
# bindings = list(zip(f_treelet, r_treelet))
812
# if bindings not in self.treelets:
813
# self.treelets.append(bindings)
814
# [new_bindings.append(bb) for ii, bb in enumerate(bindings)
815
# if ii > 0 and (bb not in new_bindings)]
816
817
# # Check level
818
# for binding in new_bindings:
819
# if self.roles.check_depth(binding[1]) < self.roles.max_depth:
820
# self._generate_treelet_with_recursive_roles(binding)
821
822
def read_rules(self):
823
824
print("Binary rules:\n")
825
for rule in self.rules:
826
if len(rule[0]) == 2:
827
print('H(' + rule[0][0] + ', ' + rule[0][1] + ') = %.1f' % rule[1])
828
829
print("\nUnary rules:\n")
830
for rule in self.rules:
831
if len(rule[0]) == 1:
832
if rule[0][0] in self.fillers.names:
833
print('H(' + rule[0][0] + '/*) = %.1f' % rule[1])
834
elif rule[0][0] in self.roles.names:
835
print('H(*/' + rule[0][0] + ') = %.1f' % rule[1])
836
elif rule[0][0] in self.binding_names:
837
print('H(' + rule[0][0] + ') = %.1f' % rule[1])
838
else:
839
return 0 # ERROR MESSAGE
840
841
#print("\nFor any conflicting unary rules, \nthe rule on the bottom overwrites the rule on the top.")
842
843
def _add_rules_root(self):
844
# empty bias
845
root_fillers = self.grammar.getRootNode(self.grammar.hnfRules)
846
if not isinstance(root_fillers, list):
847
root_fillers = [root_fillers]
848
for filler in root_fillers:
849
self.rules.append([[filler], -1.])
850
return 0
851
852
# def _add_rules_unary(self, which="role"):
853
# """Add unary HG rules"""
854
855
# unused_binding_harmony = -5.0
856
857
# if which == "role":
858
# if self.role_type == "span_role":
859
# # Role bias --- but in the Harmonic Grammar book, the bias was set based on fillers, not on roles.
860
# for ii, role in enumerate(self.roles.roles):
861
# num_edges = float(len(role)) # 1 mother + [len(role)-1] daughters
862
# width = role[-1] - role[0]
863
# if width == 1:
864
# num_edges = num_edges - 1 # no daughter
865
# # No adjustment for the top span role (e.g., [0, max_sent_len] which does not have a mother
866
# # in the implemented model. But here the role has a mother so num_edges was set to 2
867
# # 1 for a single daughter and 1 for a possible mother in the infinite model
868
# # Consider adding +1 for that case.
869
# self.rules.append([[str(role).replace(' ', '')], -num_edges])
870
871
# elif self.role_type == "recursive_role":
872
# sys.exit('CHECK: We have not yet implemented this case.')
873
# temp = [len(role) for role in self.roles.names]
874
# for role in self.roles.names:
875
876
# if len(role) == 1:
877
# self.rules.append([[role], -1]) # root node
878
# elif len(role) == max(temp): # terminal nodes
879
# return 0
880
881
# elif which == "filler":
882
# for ii, filler in enumerate(self.fillers.names):
883
# if filler in self.fillers.subset_bracketed():
884
# self.rules.append([[filler], -3.0])
885
# elif filler in self.fillers.subset_terminals():
886
# self.rules.append([[filler], -1.0])
887
# elif filler in self.fillers.subset_root():
888
# self.rules.append([[filler], -1.0])
889
# elif filler == self.fillers.null:
890
# self.rules.append([[filler], 0.0])
891
# else:
892
# self.rules.append([[filler], -2.0])
893
894
# elif which == "binding":
895
# bindings_used = []
896
# for treelet in self.treelets:
897
# [bindings_used.append(binding) for binding in treelet]
898
899
# for binding in self.binding_names:
900
# if binding in bindings_used:
901
# filler, role = binding.split('/')
902
# if self.fillers.is_terminal(filler):
903
# self.rules.append([[binding], -1.0])
904
# elif self.fillers.is_root(filler):
905
# self.rules.append([[binding], -1.0])
906
# elif self.fillers.is_bracketed(filler):
907
# self.rules.append(
908
# [[binding], -1.0 - self.num_branch])
909
# elif self.fillers.is_unbracketed(filler):
910
# self.rules.append([[binding], -2.0])
911
# else:
912
# sys.exit('Error')
913
# else:
914
# if binding.split('/')[0] != '_':
915
# self.rules.append([[binding], unused_binding_harmony])
916
# else:
917
# self.rules.append([[binding], 0])
918
919
# elif which == "binding0":
920
# # Use this when add_constraint = False
921
# # terminal fillers cannot occur in non-terminal roles
922
# # non-terminal fillers cannot occur in terminal roles
923
# # bracketed fillers cannot occur in non-bracketed roles
924
# # non-bracketed fillers cannot occur in bracketed roles
925
# terminal_roles = self.get_terminal_roles()
926
# bracketed_roles = self.get_bracketed_roles()
927
928
# for binding in self.binding_names:
929
# filler, role = binding.split('/')
930
# if filler == '_':
931
# self.rules.append([[binding], 0])
932
# elif self.fillers.is_terminal(filler):
933
# if role in terminal_roles:
934
# self.rules.append([[binding], -1.0])
935
# else:
936
# self.rules.append([[binding], unused_binding_harmony])
937
# elif self.fillers.is_root(filler):
938
# if (role not in bracketed_roles) and (role not in terminal_roles):
939
# self.rules.append([[binding], -1.0])
940
# else:
941
# self.rules.append([[binding], unused_binding_harmony])
942
# elif self.fillers.is_bracketed(filler):
943
# if role in bracketed_roles:
944
# self.rules.append([[binding], -1.0 - self.num_branch])
945
# else:
946
# self.rules.append([[binding], unused_binding_harmony])
947
# elif self.fillers.is_unbracketed(filler):
948
# if (role not in bracketed_roles) and (role not in terminal_roles):
949
# self.rules.append([[binding], -2.0])
950
# else:
951
# self.rules.append([[binding], unused_binding_harmony])
952
# else:
953
# sys.exit('Error')
954
955
# def _add_rules_null(self, null_bias):
956
957
# self.rules.append([[self.fillers.null], null_bias])
958
959
def reorder_fillers(self, fillers_ordered):
    """Replace the filler ordering with a user-supplied one.

    When the grammar has a null filler and the caller left it out, it is
    appended to the end of the new ordering. The binding names are rebuilt
    afterwards through ``_set_binding_names``.

    : fillers_ordered : list of str. All filler names, in the new order.

    Exits (``sys.exit``) when the supplied names are not the same set as
    the current filler names.
    """

    # Work on a copy so that the caller's list is never modified.
    new_order = list(fillers_ordered)
    null_filler = self.fillers.null
    if null_filler is not None and null_filler not in new_order:
        new_order.append(null_filler)

    if set(self.fillers.names) != set(new_order):
        sys.exit("The set of filler names constructued ffrom grammar is not same as the set of filler names you provided.")

    self.fillers.names = new_order
    self._set_binding_names()
970
971
def get_binary(self):
    """Placeholder (not implemented): always returns 0."""

    result = 0
    return result
973
974
def get_unary(self):
    """Placeholder (not implemented): always returns 0."""

    result = 0
    return result
976
977
def sort(self):
    """Placeholder (not implemented): always returns 0."""

    result = 0
    return result
979
980
def get_terminal_fillers(self):
    """Return the terminal filler symbols of the grammar in sorted order.

    The symbols are obtained from ``self.grammar.getTerminalNodes`` applied
    to the grammar's HNF rules; the returned list is sorted in place.
    """

    hnf_rules = self.grammar.hnfRules
    terminals = self.grammar.getTerminalNodes(hnf_rules)
    terminals.sort()
    return terminals
985
986
def get_bracketed_fillers(self):
    """Return a list of bracketed filler symbols.

    A symbol counts as bracketed when it ends in ``[<digits>]`` (for
    example ``S[1]``). Only the keys (left-hand sides) of the HNF rules are
    inspected, in the iteration order of ``self.grammar.hnfRules``.
    """

    # Raw string: in a plain literal '\[' is an invalid escape sequence.
    bracketed = re.compile(r'.*\[[0-9]+\]$')
    return [sym for sym in self.grammar.hnfRules.keys()
            if bracketed.match(sym) is not None]
995
996
def get_terminal_roles(self):
    """Return the terminal roles as strings with all spaces removed."""

    names = []
    for role in self.roles.subset_terminals():
        names.append(str(role).replace(' ', ''))
    return names
999
1000
def get_bracketed_roles(self):
    """Return the bracketed roles as strings with all spaces removed."""

    names = []
    for role in self.roles.subset_bracketed():
        names.append(str(role).replace(' ', ''))
    return names
1003
1004
1005
class GscNet(object):
1006
1007
def __init__(self, hg=None, encodings=None, opts=None, seed=None):
    """Create a GSC network.

    : hg : HarmonicGrammar or None. When given, weights and biases are set
        from its rules; otherwise filler and role names must be supplied
        through [encodings].
    : encodings : dict or None. Overrides of the default encoding settings.
    : opts : dict or None. Overrides of the default option values.
    : seed : int or None. Passed to ``set_seed`` when not None.
    """

    self._set_opts()
    self._update_opts(opts=opts)

    self._set_encodings()
    self._update_encodings(encodings=encodings)

    # Always define the attribute so that it also exists on unseeded networks.
    self.seed = seed
    if seed is not None:
        self.set_seed(seed)

    self.hg = hg
    self._add_names()
    self._generate_encodings()
    self._compute_TPmat()

    # Weights and biases in the conceptual space; zero unless [hg] sets them.
    self.WC = np.zeros((self.num_bindings, self.num_bindings))
    self.bC = np.zeros(self.num_bindings)
    if hg is not None:
        self._build_model(hg)
    self._set_weights()
    self._set_biases()
    self._set_quant_list()

    # Activation state (conceptual and neural coordinates).
    self.actC = np.zeros(self.num_bindings)
    self.actC_prev = np.zeros(self.num_bindings)
    self.act = self.C2N()
    self.act_prev = self.C2N(actC=self.actC_prev)

    # External input (conceptual and neural coordinates).
    self.extC = np.zeros(self.num_bindings)
    self.extC_prev = np.zeros(self.num_bindings)
    self.ext = self.C2N(actC=self.extC)
    self.ext_prev = self.C2N(actC=self.extC_prev)

    # Bowl parameters
    if isinstance(self.opts['bowl_center'], numbers.Number):
        self.bowl_center = (self.opts['bowl_center'] *
                            np.ones(self.num_bindings))
    else:
        self.bowl_center = self.opts['bowl_center']
    if self.opts['bowl_strength'] is None:
        # No value given: use the recommended minimum plus the offset.
        self.opts['bowl_strength'] = (
            self._compute_recommended_bowl_strength() +
            self.opts['beta_min_offset'])
    else:
        self.check_bowl_strength()
    self.bowl_strength = self.opts['bowl_strength']
    self.zeta = self.C2N(actC=self.bowl_center)

    self.t = 0
    self.speed = None
    self.ema_speed = None
    self.clamped = False
    self.q = self.opts['q_init']
    self.T = self.opts['T_init']
    self.dt = self.opts['dt']

    self.reset()
1066
1067
#self.quant_list = quant_list
1068
#self.grid_points = grid_points
1069
1070
## Generate a set of all grid points?
1071
## Do not set it to True if the model is big.
1072
## The number of grid points increases very quickly as the number of roles increases
1073
#self.getGPset = getGPset # True or False
1074
#if getGPset:
1075
# self.all_grid_points()
1076
1077
########################################################################
1078
#
1079
# Build a model
1080
#
1081
########################################################################
1082
1083
def _set_quant_list(self):
1084
quant_list = []
1085
for role_name in self.role_names:
1086
quant_list.append(self.find_roles(role_name))
1087
self.quant_list = quant_list
1088
1089
def _set_opts(self):
1090
"""Set option variables to default values."""
1091
1092
self.opts = {}
1093
self.opts['trace_varnames'] = [
1094
'act', 'H', 'H0', 'Q', 'q', 'T', 't', 'ema_speed', 'speed']
1095
self.opts['norm_ord'] = np.inf
1096
self.opts['coord'] = 'N'
1097
self.opts['ema_factor'] = 0.001
1098
self.opts['ema_tau'] = -1 / np.log(self.opts['ema_factor'])
1099
self.opts['T_init'] = 1e-3
1100
self.opts['T_min'] = 0.
1101
self.opts['T_decay_rate'] = 1e-3
1102
self.opts['q_init'] = 0.
1103
self.opts['q_max'] = 200.
1104
self.opts['q_rate'] = 10.
1105
self.opts['c'] = 0.5
1106
self.opts['bowl_center'] = 0.5
1107
self.opts['bowl_strength'] = None
1108
self.opts['beta_min_offset'] = 0.1
1109
self.opts['dt'] = 0.001
1110
self.opts['H0_on'] = True
1111
self.opts['H1_on'] = True
1112
self.opts['Hq_on'] = True
1113
self.opts['max_dt'] = 0.01
1114
self.opts['min_dt'] = 0.0005
1115
self.opts['q_policy'] = None
1116
1117
def _update_opts(self, opts):
1118
"""Update option variable values"""
1119
1120
if opts is not None:
1121
for key in opts:
1122
if key in self.opts:
1123
self.opts[key] = opts[key]
1124
if key == 'ema_factor':
1125
self.opts['ema_tau'] = -1 / np.log(self.opts[key])
1126
if key == 'ema_tau':
1127
self.opts['ema_factor'] = np.exp(-1 / self.opts[key])
1128
else:
1129
sys.exit('Check [opts]')
1130
1131
def _set_encodings(self):
1132
"""Set encoding variables to default values."""
1133
1134
self.encodings = {}
1135
self.encodings['dp_f'] = 0.
1136
self.encodings['dp_r'] = 0.
1137
self.encodings['coord_f'] = 'dist'
1138
self.encodings['coord_r'] = 'dist'
1139
self.encodings['dim_f'] = None
1140
self.encodings['dim_r'] = None
1141
self.encodings['filler_names'] = None
1142
self.encodings['role_names'] = None
1143
self.encodings['F'] = None
1144
self.encodings['R'] = None
1145
self.encodings['similarity'] = None
1146
1147
def _update_encodings(self, encodings):
1148
"""Update encoding variables"""
1149
1150
if encodings is not None:
1151
for key in encodings:
1152
if key in self.encodings:
1153
self.encodings[key] = encodings[key]
1154
1155
def _add_names(self):
1156
"""Add filler, role, and binding names to the GscNet object"""
1157
1158
if self.hg is None:
1159
if self.encodings['filler_names'] is None:
1160
sys.exit("Please provide a list of filler names.")
1161
if self.encodings['role_names'] is None:
1162
sys.exit("Please provide a list of role names.")
1163
self.filler_names = self.encodings['filler_names']
1164
self.role_names = self.encodings['role_names']
1165
self.binding_names = [f + '/' + r for r in self.role_names for f in self.filler_names]
1166
else:
1167
if isinstance(self.hg, HarmonicGrammar):
1168
self.filler_names = self.hg.fillers.names
1169
self.role_names = self.hg.roles.names
1170
self.binding_names = self.hg.binding_names
1171
else:
1172
sys.exit('[hg] is not an instance of HarmonicGrammar class.')
1173
1174
self.num_fillers = len(self.filler_names)
1175
self.num_roles = len(self.role_names)
1176
self.num_bindings = len(self.binding_names)
1177
1178
def _generate_encodings(self):
    """Generate vector encodings of fillers and roles.

    When a similarity list is given, it is first converted into pairwise
    dot-product matrices (stored as encodings 'dp_f' and 'dp_r'). The
    encodings come from ``encode_symbols`` unless the user supplied 'F' or
    'R' directly. Also sets dim_f, dim_r, num_units and unit_names.
    """

    similarity = self.encodings['similarity']
    if similarity is not None:
        # Start from identity: every symbol has dot product 1 with itself.
        dp_f = np.diag(np.ones(self.num_fillers))
        dp_r = np.diag(np.ones(self.num_roles))

        for entry in similarity:
            symbols = entry[0]
            if all(sym in self.filler_names for sym in symbols):
                names, dp_mat = self.filler_names, dp_f
            elif all(sym in self.role_names for sym in symbols):
                names, dp_mat = self.role_names, dp_r
            else:
                sys.exit('Cannot find some f/r bindings in your similarity list.')

            # The dot-product matrix is symmetric: fill both entries.
            first = names.index(symbols[0])
            second = names.index(symbols[1])
            dp_mat[first, second] = entry[1]
            dp_mat[second, first] = entry[1]

        self.encodings['dp_f'] = dp_f
        self.encodings['dp_r'] = dp_r

    generated_F = encode_symbols(
        self.num_fillers,
        coord=self.encodings['coord_f'],
        dp=self.encodings['dp_f'],
        dim=self.encodings['dim_f'])
    generated_R = encode_symbols(
        self.num_roles,
        coord=self.encodings['coord_r'],
        dp=self.encodings['dp_r'],
        dim=self.encodings['dim_r'])

    # User-provided F and R take precedence over the generated ones.
    user_F = self.encodings['F']
    user_R = self.encodings['R']
    self.F = generated_F if user_F is None else user_F
    self.R = generated_R if user_R is None else user_R

    self.dim_f = self.F.shape[0]
    self.dim_r = self.R.shape[0]
    self.num_units = self.dim_f * self.dim_r

    # Unit names are 'U' followed by a zero-padded 1-based index.
    ndigits = len(str(self.num_units))
    self.unit_names = ['U' + str(number).zfill(ndigits)
                       for number in range(1, self.num_units + 1)]
1223
1224
def _build_model(self, hg):
    """Set the weight and bias values using a HarmonicGrammar object [hg].

    Each rule is a pair [names, value]. Two names set a connection weight
    between two bindings; a single name sets the bias of a binding, of all
    bindings of a filler, or of all bindings of a role, whichever list the
    name is found in first. Exits (``sys.exit``) on an unknown unary rule,
    a wrong [hg] type, or a non-symmetric weight matrix.
    """

    if not isinstance(hg, HarmonicGrammar):
        sys.exit('The given grammar as hg is not an instance of HarmonicGrammar class.')

    for rule in hg.rules:
        names, value = rule[0], rule[1]
        if len(names) == 2:  # binary rules
            self.set_weight(names[0], names[1], value)
        elif len(names) == 1:  # unary rules
            name = names[0]
            if name in self.binding_names:
                self.set_bias(name, value)
            elif name in self.filler_names:
                self.set_filler_bias(name, value)
            elif name in self.role_names:
                self.set_role_bias(name, value)
            else:
                # [rule] is a list: convert it before concatenating.
                sys.exit('Check the rule in your Harmonic Grammar:' + str(rule))

    # self.W is only created by set_weight; recompute it so that the check
    # below also works for a grammar without any binary rule.
    self._set_weights()
    if not np.allclose(self.W, self.W.T):
        sys.exit("The weight matrix (2D array) is not symmetric. Please check it.")
1245
1246
def _compute_TPmat(self):
1247
"""Compute the matrices of change of basis from conceptual to neural and from neural to conceptual coordinates.
1248
"""
1249
1250
# TP matrix that converts local to distributed representations (conceptual coordinate to neural coordinate).
1251
# See http://en.wikipedia.org/wiki/Vectorization_(mathematics) for justification of kronecker product.
1252
TP = np.kron(self.R, self.F) # Pay attention to the argument order.
1253
if TP.shape[0] == TP.shape[1]:
1254
TPinv = linalg.inv(TP)
1255
else:
1256
TPinv = linalg.pinv(TP) # TP may be a non-square matrix. So use pseudo-inverse.
1257
self.TP = TP
1258
self.TPinv = TPinv
1259
self.Gc = self.TPinv.T.dot(self.TPinv)
1260
1261
def _set_weights(self):
1262
"""Compute the weight values in the neural space (distributed representation)"""
1263
1264
self.W = self.TPinv.T.dot(self.WC).dot(self.TPinv)
1265
1266
def _set_biases(self):
1267
"""Compute the bias values in the neural space (distributed representation)"""
1268
1269
self.b = self.TPinv.T.dot(self.bC)
1270
1271
def _compute_recommended_bowl_strength(self):
1272
"""Compute the recommended value of bowl strength. Note that the value depends on external input."""
1273
1274
eigvals, eigvecs = np.linalg.eigh(self.WC) # WC should be a symmetric matrix. So eigh() was used instead of eig()
1275
eig_max = max(eigvals) # Condition 1: beta > eig_max to be stable
1276
if np.sum(abs(self.bowl_center)) > 0:
1277
if self.num_bindings == 1:
1278
beta1 = -(self.bC + self.extC) / self.bowl_center
1279
beta2 = (self.bC + self.extC + eig_max) / (1 - self.bowl_center)
1280
else:
1281
#beta1 = -min(self.bC+self.extC)/self.bowl_center # Condition 2: beta > beta1
1282
#beta2 = (max(self.bC+self.extC)+eig_max)/(1-self.bowl_center) # Condition 3: beta > beta2
1283
beta1 = -min((self.bC + self.extC) /self.bowl_center) # Condition 2: beta > beta1
1284
beta2 = max((self.bC + self.extC + eig_max) / (1 - self.bowl_center)) # Condition 3: beta > beta2 [CHECK]
1285
val = max(eig_max, beta1, beta2)
1286
else:
1287
val = eig_max
1288
1289
return val
1290
1291
def check_bowl_strength(self, disp=True):
    """Check the current bowl strength against the recommended minimum.

    Exits (``sys.exit``) when opts['bowl_strength'] is not greater than the
    recommended value computed from the weights and biases in the C-space.

    : disp : bool. If True, print the current and the minimum values.
    """

    beta_min = self._compute_recommended_bowl_strength()
    current = self.opts['bowl_strength']
    if current <= beta_min:
        sys.exit("Bowl strength should be greater than %.4f." % beta_min)

    if disp:
        print('(Current bowl strength: %.3f) must be greater than (minimum: %.3f)' % (current, beta_min))
1301
1302
def update_bowl_center(self, bowl_center):
    """Update the bowl center.

    Usage:

        >>> net.update_bowl_center(0.3)  # 0.3 times the all-ones vector
        >>>
        >>> import numpy as np
        >>> bowl_center = np.random.random(net.num_bindings)
        >>> net.update_bowl_center(bowl_center)

    : bowl_center : float or 1d NumPy array (size=number of bindings).
        A scalar is expanded to a constant vector.

    Exits (``sys.exit``) on any other type or on a wrong array length.
    """

    is_scalar = isinstance(bowl_center, numbers.Number)
    if not (is_scalar or isinstance(bowl_center, np.ndarray)):
        sys.exit('You must provide a scalar or a NumPy array as bowl_center.')

    if is_scalar:
        bowl_center = bowl_center * np.ones(self.num_bindings)

    if bowl_center.shape[0] != self.num_bindings:
        sys.exit('When you provide a NumPy array as bowl_center, it must have the same number of elements as the number of f/r bindings.')

    self.bowl_center = bowl_center
    # Keep the neural-space copy of the bowl center in sync.
    self.zeta = self.C2N(actC=self.bowl_center)
1327
1328
def update_bowl_strength(self, bowl_strength=None):
    """Set the bowl strength.

    Without an argument, the current value is replaced with the
    recommended bowl strength plus opts['beta_min_offset'].

    Usage:

        >>> net = gsc.GscNet(...)
        >>> net.set_weight('a/(0,1)', 'b/(1,2)', 2.0)
        >>> net.update_bowl_strength()

    : bowl_strength : float or None (=default)
    """

    if bowl_strength is not None:
        new_value = bowl_strength
    else:
        recommended = self._compute_recommended_bowl_strength()
        new_value = recommended + self.opts['beta_min_offset']

    self.opts['bowl_strength'] = new_value
    self.bowl_strength = self.opts['bowl_strength']
1348
1349
def set_weight(self, binding_name1, binding_name2, weight, symmetric=True):
    """Set the weight of a connection between binding1 and binding2.

    The entry WC[binding2, binding1] is always set. When symmetric is True
    (default), WC[binding1, binding2] is set to the same value. The
    neural-space weights are recomputed afterwards.
    """

    first = self.find_bindings(binding_name1)
    second = self.find_bindings(binding_name2)

    self.WC[second, first] = weight
    if symmetric:
        self.WC[first, second] = weight

    self._set_weights()
1361
1362
def set_bias(self, binding_name, bias):
    """Set the bias of [binding_name] to [bias] and recompute the
    neural-space biases."""

    positions = self.find_bindings(binding_name)
    self.bC[positions] = bias
    self._set_biases()
1368
1369
def set_filler_bias(self, filler_name, bias):
    """Set the bias of the bindings of all roles with particular fillers
    to [bias].

    : filler_name : str or list of str.
    """

    targets = filler_name if isinstance(filler_name, list) else [filler_name]
    # The filler part of every binding name ('filler/role').
    binding_fillers = [name.split('/')[0] for name in self.binding_names]

    for target in targets:
        hits = [pos for pos, ff in enumerate(binding_fillers) if target == ff]
        self.bC[hits] = bias

    self._set_biases()
1381
1382
def set_role_bias(self, role_name, bias):
    """Set the bias of the bindings of all fillers with particular roles
    to [bias].

    : role_name : str or list of str.
    """

    targets = role_name if isinstance(role_name, list) else [role_name]
    # The role part of every binding name ('filler/role').
    binding_roles = [name.split('/')[1] for name in self.binding_names]

    for target in targets:
        hits = [pos for pos, rr in enumerate(binding_roles) if target == rr]
        self.bC[hits] = bias

    self._set_biases()
1394
1395
######################################################################
1396
#
1397
# Util functions
1398
#
1399
######################################################################
1400
1401
def find_bindings(self, binding_names):
    """Find the indices of the bindings from the list of binding names.

    : binding_names : str or list of str.

    Returns a list of indices into ``self.binding_names``, one per name.
    """

    wanted = binding_names if isinstance(binding_names, list) else [binding_names]
    positions = []
    for name in wanted:
        positions.append(self.binding_names.index(name))
    return positions
1407
1408
def find_fillers(self, filler_name):
    """Find the indices of all bindings whose filler is in [filler_name].

    : filler_name : str or list of str.

    Indices are grouped by filler, in the order the fillers were given.
    """

    wanted = filler_name if isinstance(filler_name, list) else [filler_name]
    binding_fillers = [name.split('/')[0] for name in self.binding_names]

    return [pos
            for target in wanted
            for pos, ff in enumerate(binding_fillers)
            if target == ff]
1420
1421
def find_roles(self, role_name):
    """Find the indices of all bindings whose role is in [role_name].

    : role_name : str or list of str.

    Indices are grouped by role, in the order the roles were given.
    """

    wanted = role_name if isinstance(role_name, list) else [role_name]
    binding_roles = [name.split('/')[1] for name in self.binding_names]

    return [pos
            for target in wanted
            for pos, rr in enumerate(binding_roles)
            if target == rr]
1433
1434
def vec2mat(self, actC=None):
    """Convert an activation state vector to a matrix form in which each
    row corresponds to a filler and each column corresponds to a role.

    : actC : 1d array in conceptual coordinates, or None.
    """

    # NOTE(review): the default relies on self.S2C(), which is not defined
    # in the visible part of this class -- confirm that it still exists.
    state = self.S2C() if actC is None else actC
    shape = (self.num_fillers, self.num_roles)
    # Column-major: consecutive entries of the vector share a role.
    return state.reshape(shape[0], shape[1], order='F')
1442
1443
def C2N(self, actC=None):
    """Change basis: from conceptual/pattern to neural space.

    : actC : vector in conceptual coordinates (default: self.actC).
    """

    source = self.actC if actC is None else actC
    return self.TP.dot(source)
1449
1450
def N2C(self, act=None):
    """Change basis: from neural to conceptual/pattern space.

    : act : vector in neural coordinates (default: self.act).
    """

    source = self.act if act is None else act
    return self.TPinv.dot(source)
1456
1457
def read_weight(self, which='WC'):
    """Print a weight matrix in a readable format.

    : which : name of the attribute to print. Names ending in 'C' are
        labelled with binding names (pattern coordinate), all others with
        unit names.
    """

    in_pattern_coord = which[-1] == 'C'
    labels = self.binding_names if in_pattern_coord else self.unit_names
    table = pd.DataFrame(getattr(self, which), index=labels, columns=labels)
    print(table)
1469
1470
def read_bias(self, which='bC', print_vertical=True):
    """Print a bias vector in a readable format.

    : which : name of the attribute to print. Names ending in 'C' are
        labelled with binding names (pattern coordinate), all others with
        unit names (neural coordinate).
    : print_vertical : bool. If True, print one bias per row; otherwise
        print a single row.
    """

    labels = self.binding_names if which[-1] == 'C' else self.unit_names
    values = getattr(self, which)

    # Let the vector determine its own length: a neural-space vector has
    # one entry per unit, which need not equal the number of bindings.
    if print_vertical:
        table = pd.DataFrame(
            values.reshape(-1, 1), index=labels, columns=["bias"])
    else:
        table = pd.DataFrame(
            values.reshape(1, -1), index=["bias"], columns=labels)
    print(table)
1491
1492
def read_state(self, act=None):
    """Print a state (converted to the C-SPACE) in a readable format:
    fillers as rows, roles as columns. Pandas should be installed.

    : act : vector in neural coordinates (default: self.act).
    """

    neural_state = self.act if act is None else act
    state_matrix = self.vec2mat(self.N2C(neural_state))
    table = pd.DataFrame(
        state_matrix, index=self.filler_names, columns=self.role_names)
    print(table)
1501
1502
def read_grid_point(self, act=None, disp=True, skip=True):
    """Return (and optionally print) a grid point close to a state.

    The grid point is chosen by the snap-it method: in each role, the
    filler with the highest activation value is chosen.

    : act : vector in neural coordinates (default: self.act).
    : disp : bool. If True, print the winners.
    : skip : bool. If True, drop weak winners (activation not greater than
        0.5) and winners that are the null filler of the grammar. Without
        a grammar (``self.hg is None``) only weak winners are dropped.

    Returns a list of 'filler/role' binding names.
    """

    act_min = 0.5

    if act is None:
        act = self.act

    actC = self.vec2mat(self.N2C(act))
    winner_idx = np.argmax(actC, axis=0)
    winners = ["%s/%s" % (self.filler_names[ff], role)
               for ff, role in zip(winner_idx, self.role_names)]

    if skip:
        # A network built without a grammar has no null filler to skip.
        hg = getattr(self, 'hg', None)
        null_filler = hg.fillers.null if hg is not None else None
        # Compare names by value: equal strings need not be the same object.
        winners = [
            binding
            for role_num, (ff, binding) in enumerate(zip(winner_idx, winners))
            if (actC[ff, role_num] > act_min and
                self.filler_names[ff] != null_filler)]

    if disp:
        print(winners)
    return winners
1528
1529
def get_grid_points(self, n=1e7):
    """Get a list of the top [n] grid points with high harmony values.

    H_0 is computed at every grid point: regardless of the [n] value, the
    program visits every grid point. The number of grid points is
    [num_fillers]^[num_roles], which increases explosively as [num_roles]
    increases, so the method exits (``sys.exit``) when there are more than
    1e7 of them.

    The results are stored in ``self.gpset`` (list of grid points, each a
    list of binding names) and ``self.gpset_h`` (their H_0 values), sorted
    in decreasing order of harmony. ``self.set_state`` is called for every
    grid point, so the state of the network is changed by this method.

    : n : number of grid points to keep (int, or a float such as 1e7).
    """

    num_grid_points = self.num_fillers ** self.num_roles
    if num_grid_points > 1e7:
        sys.exit('There are too many grid points (= %d) to check all grid points.' % num_grid_points)

    # [n] may be a float (the default is 1e7), but it is used below as an
    # array size, which must be an integer.
    n = max(int(min(n, num_grid_points)), 0)

    # For every role, the names of the bindings competing for that role.
    quant_list = [[self.binding_names[ii] for ii in self.find_roles(role)]
                  for role in self.role_names]

    gpset = []
    gpset_h = np.zeros(n)
    if num_grid_points > 10000:
        print('Of %d grid points: ' % num_grid_points)

    for ii, gp in enumerate(itertools.product(*quant_list)):

        # Progress report for large grids.
        if (ii + 1) % 1e4 == 0:
            print('[%06d]' % (ii + 1), end='')
        if (ii + 1) % 1e5 == 0:
            print('')

        gp = list(gp)
        self.set_state(gp)
        hh = self.H0()
        if ii < n:
            gpset_h[ii] = hh
            gpset.append(gp)
        elif n > 0:
            # Replace the currently worst kept grid point if this is better.
            worst = np.argmin(gpset_h)
            if hh > gpset_h[worst]:
                gpset[worst] = gp
                gpset_h[worst] = hh

    # Sort the grid points in a decreasing order of harmony
    idx = np.argsort(gpset_h)[::-1]
    self.gpset = [gpset[ii] for ii in idx]
    self.gpset_h = gpset_h[idx]
1574
1575
def set_seed(self, num):
    """Set a random number seed (NumPy's global generator)."""

    seed_value = num
    np.random.seed(seed_value)
1579
1580
######################################################################
1581
#
1582
# Harmony
1583
#
1584
######################################################################
1585
1586
# def H(self, act=None):
1587
# """Evalutate total harmony"""
1588
1589
# return self.Hg(act) + self.opts['Hq_on'] * self.q * self.Q(act)
1590
1591
def H(self, act=None):
    """Evaluate total harmony: H_G + q * Qa.

    The quantization term is switched off when opts['Hq_on'] is False.
    """

    grammar_harmony = self.Hg(act)
    quantization_on = float(self.opts['Hq_on'])
    return grammar_harmony + quantization_on * self.q * self.Qa(act)
1595
1596
def Hg(self, act=None):
    """Evaluate H_G (= H0 + H1).

    Each term is switched on or off by opts['H0_on'] and opts['H1_on'].
    """

    h0_term = float(self.opts['H0_on']) * self.H0(act)
    h1_term = float(self.opts['H1_on']) * self.H1(act)
    return h0_term + h1_term  # + constant
1601
1602
def H0(self, act=None):
    """Evaluate H0 = 0.5 * act.W.act + (b + ext).act

    : act : vector in neural coordinates (default: self.act).
    """

    state = self.act if act is None else act
    quadratic = 0.5 * state.dot(self.W).dot(state)
    linear = (self.b + self.ext).dot(state)
    return quadratic + linear
1608
1609
def H1(self, act=None):
    """Evaluate H1 (bowl harmony).

    H1 = bowl_strength * (-0.5 * (act - zeta).Gc.(act - zeta))

    : act : vector in neural coordinates (default: self.act).
    """

    state = self.act if act is None else act
    bowl = -0.5 * (state - self.zeta).T.dot(self.Gc).dot(state - self.zeta)
    return self.bowl_strength * bowl
1616
1617
def Q(self, act=None):
    """Evaluate quantization harmony Q = c * Q0 + (1-c) * Q1"""

    mix = self.opts['c']
    q0_term = mix * self.Q0(act)
    q1_term = (1 - mix) * self.Q1(act)
    return q0_term + q1_term
1622
1623
def Qa(self, act=None):  # Experimental
    """Evaluate quantization harmony Q = c * Q0 + (1-c) * Q1a, with Q1a
    computed over ``self.quant_list``."""

    mix = self.opts['c']
    q0_term = mix * self.Q0(act)
    q1_term = (1 - mix) * self.Q1a(act=act, quant_list=self.quant_list)
    return q0_term + q1_term
1627
1628
def Qb(self, act=None):  # Experimental
    """Evaluate quantization harmony Q = c * Q0 + (1-c) * Q1b, with Q1b
    computed over ``self.quant_list``."""

    mix = self.opts['c']
    q0_term = mix * self.Q0(act)
    q1_term = (1 - mix) * self.Q1b(act=act, quant_list=self.quant_list)
    return q0_term + q1_term
1632
1633
def Q0(self, act=None):
    """Evaluate Q0 = -sum(a^2 * (1 - a)^2) over the conceptual-space
    activations a.

    : act : vector in neural coordinates (default: self.act).
    """

    state = self.act if act is None else act
    actC = self.N2C(state)
    return -np.sum(actC**2 * (1 - actC)**2)
1640
1641
def Q1(self, act=None):
    """Evaluate Q1: for each role (matrix column), the squared deviation of
    the sum of squared activations from 1, summed and negated.

    : act : vector in neural coordinates (default: self.act).
    """

    state = self.act if act is None else act
    state_matrix = self.vec2mat(self.N2C(state))
    return -np.sum((np.sum(state_matrix**2, axis=0) - 1)**2)
1647
1648
def Q1a(self, act=None, quant_list=None):
    """Evaluate Q1 (sum of squared = 1).

    For each index group in [quant_list], the squared deviation of the sum
    of squared conceptual activations from 1 is accumulated; the negated
    total is returned.

    : act : vector in neural coordinates (default: self.act).
    : quant_list : list of index lists (default: self.quant_list).
    """

    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list

    actC = self.N2C(state)
    penalty = 0
    for group in groups:
        members = actC[group]
        sum_of_squares = members.dot(members)
        penalty += (sum_of_squares - 1)**2

    return -penalty
1664
1665
def Q1b(self, act=None, quant_list=None):
    """Evaluate Q1 (sum of squared = 0 or 1).

    For each index group in [quant_list] with sum of squared conceptual
    activations s, the term (s - 1)^2 * s^2 is accumulated; the negated
    total is returned.

    : act : vector in neural coordinates (default: self.act).
    : quant_list : list of index lists (default: self.quant_list).
    """

    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list

    actC = self.N2C(state)
    penalty = 0
    for group in groups:
        members = actC[group]
        sum_of_squares = members.dot(members)
        penalty += (sum_of_squares - 1)**2 * sum_of_squares**2

    return -penalty
1681
1682
######################################################################
1683
#
1684
# Harmony Gradient
1685
#
1686
######################################################################
1687
1688
# def HGrad(self, act=None):
1689
# '''Compute the harmony gradient evaluated at the current state'''
1690
1691
# return self.HgGrad(act) + self.opts['Hq_on'] * self.q * self.QGrad(act)
1692
1693
def HGrad(self, act=None):
    """Compute the gradient of total harmony: grad H_G + q * grad Qa.

    The quantization term is switched off when opts['Hq_on'] is False.
    """

    grammar_grad = self.HgGrad(act)
    quantization_on = float(self.opts['Hq_on'])
    return grammar_grad + quantization_on * self.q * self.QaGrad(act)
1697
1698
def HgGrad(self, act=None):
    """Compute the gradient of grammar harmony H_G.

    Each term is switched on or off by opts['H0_on'] and opts['H1_on'].
    """

    h0_term = float(self.opts['H0_on']) * self.H0Grad(act)
    h1_term = float(self.opts['H1_on']) * self.H1Grad(act)
    return h0_term + h1_term
1703
1704
def H0Grad(self, act=None):
    """Compute the gradient of H0: W.act + b + ext

    : act : vector in neural coordinates (default: self.act).
    """

    state = self.act if act is None else act
    return self.W.dot(state) + self.b + self.ext
1709
1710
def H1Grad(self, act=None):
    """Compute the gradient of the bowl harmony H1:
    bowl_strength * (-Gc.act + Gc.zeta)

    : act : vector in neural coordinates (default: self.act).
    """

    state = self.act if act is None else act
    pull_to_center = -self.Gc.dot(state) + self.Gc.dot(self.zeta)
    return self.bowl_strength * pull_to_center
1716
1717
def QGrad(self, act=None):
    """Compute the gradient of Q = c * Q0 + (1-c) * Q1"""

    mix = self.opts['c']
    q0_term = mix * self.Q0Grad(act)
    q1_term = (1 - mix) * self.Q1Grad(act)
    return q0_term + q1_term
1721
1722
def QaGrad(self, act=None):
    """Compute the gradient of Qa = c * Q0 + (1-c) * Q1a, with Q1a taken
    over ``self.quant_list``."""

    mix = self.opts['c']
    q0_term = mix * self.Q0Grad(act)
    q1_term = (1 - mix) * self.Q1aGrad(act=act, quant_list=self.quant_list)
    return q0_term + q1_term
1725
1726
def QbGrad(self, act=None):
    """Compute the gradient of Qb = c * Q0 + (1-c) * Q1b, with Q1b taken
    over ``self.quant_list``."""

    mix = self.opts['c']
    q0_term = mix * self.Q0Grad(act)
    q1_term = (1 - mix) * self.Q1bGrad(act=act, quant_list=self.quant_list)
    return q0_term + q1_term
1729
1730
def Q0Grad(self, act=None):
    """Compute the gradient of Q0 with respect to the neural-space state.

    : act : vector in neural coordinates (default: self.act).
    """

    state = self.act if act is None else act
    actC = self.N2C(state)
    # Derivative of a^2 (1 - a)^2 for every binding (g_{fr} vectorized).
    per_binding = 2 * actC * (1 - actC) * (1 - 2 * actC)
    # Chain rule through the neural-to-conceptual change of basis.
    return -np.einsum('ij,i', self.TPinv, per_binding)
1737
1738
def Q1Grad(self, act=None):
    """Compute the gradient of quantization harmony (Q1).

    : act : vector in neural coordinates (default: self.act).
    """

    state = self.act if act is None else act

    # Index order of the reshaped TPinv (and of term2 below):
    #   i: filler index (f)
    #   j: role index (r)
    #   k: unit index (phi-rho pair)
    shape = (self.num_fillers, self.num_roles, self.num_units)
    TPinv_by_binding = self.TPinv.reshape(shape, order='F')

    state_matrix = self.vec2mat(self.N2C(state))
    # Per role: sum of squared activations minus 1.
    term1 = np.einsum('ij->j', state_matrix**2) - 1
    term2 = np.einsum('ij,ijk->jk', state_matrix, TPinv_by_binding)
    return -4 * np.einsum('j,jk', term1, term2)
1754
1755
def Q1aGrad(self, act=None, quant_list=None):
    """Compute the gradient of quantization harmony (Q1a).

    : act : vector in neural coordinates (default: self.act).
    : quant_list : list of index lists (default: self.quant_list).
    """

    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list

    actC = self.N2C(state)

    total = 0
    for group in groups:
        members = actC[group]
        member_rows = self.TPinv[group, :]

        # Sum of squared activations in the group minus 1.
        deviation = (members**2).sum() - 1
        direction = np.einsum('i,ij->j', members, member_rows)
        total += deviation * direction

    return -4 * total
1777
1778
def Q1bGrad(self, act=None, quant_list=None):
    """Compute the gradient of quantization harmony (Q1b).

    : act : vector in neural coordinates (default: self.act).
    : quant_list : list of index lists (default: self.quant_list).
    """

    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list

    actC = self.N2C(state)

    total = 0
    for group in groups:
        members = actC[group]
        member_rows = self.TPinv[group, :]

        sum_of_squares = members.dot(members)
        scale = sum_of_squares * (sum_of_squares - 1) * (2*sum_of_squares - 1)
        direction = np.einsum('i,ij->j', members, member_rows)  # CHECK
        total += scale * direction

    return -4 * total
1801
1802
######################################################################
1803
#
1804
# Log traces
1805
#
1806
######################################################################
1807
1808
def initialize_traces(self, trace_list):
    """Create storage for traces and log the current values once.

    : trace_list : 'all' or a list of variable names, each of which must
        be in opts['trace_varnames'].

    Traces that already exist are converted back to lists so that logging
    can continue; a variable that was not traced before starts with an
    empty list. Exits (``sys.exit``) when [trace_list] is not a list or
    contains an unknown variable name.
    """

    available = self.opts['trace_varnames']

    if trace_list == 'all':
        trace_list = available
    elif not isinstance(trace_list, list):
        sys.exit(('Check [trace_list] that should be a list object. \n'
                  'If you want to log a single variable (e.g., H), \n'
                  'you must provide ["H"], not "H", as the value of [trace_list].'))

    unknown = [var for var in trace_list if var not in available]
    if len(unknown) > 0:
        # The available names are a list: join them before concatenating.
        sys.exit(('Check [trace_list]. You provided variable name(s) that are not availalbe in the software.\n'
                  'Currently, the following variables are available:\n' + ', '.join(available)))

    if not hasattr(self, 'traces'):
        self.traces = {}
    for key in trace_list:
        # .get(): [key] may be missing from traces created by an earlier call.
        self.traces[key] = list(self.traces.get(key, []))

    self.update_traces()
1833
1834
def update_traces(self):
    """Log traces: append the current value of every traced variable to
    its list in ``self.traces``."""

    # One reader per loggable variable; a reader is only called when its
    # variable is actually being traced.
    readers = {
        'act': lambda: list(self.act),
        'H': lambda: self.H(),
        'H0': lambda: self.H0(),
        'Q': lambda: self.Q(),
        'q': lambda: self.q,
        't': lambda: self.t,
        'T': lambda: self.T,
        'ema_speed': lambda: self.ema_speed,
        'speed': lambda: self.speed,
    }

    for name, read in readers.items():
        if name in self.traces:
            self.traces[name].append(read())
1855
1856
def finalize_traces(self):
    """Convert list objects of traces to NumPy array objects (in place)."""

    for name, logged in self.traces.items():
        self.traces[name] = np.array(logged)
1861
1862
######################################################################
1863
#
1864
# Input (clamp vs. external input)
1865
#
1866
######################################################################
1867
1868
def _compute_projmat(self, A):
1869
"""Compute a projection matrix of a given matrix A. A is an n x m
1870
matrix of basis (column) vectors of the subspace. This function
1871
works only when the rank of A is equal to the nunmber of columns of A.
1872
"""
1873
1874
return A.dot(linalg.inv(A.T.dot(A))).dot(A.T)
1875
1876
def clamp(self, binding_names,
          clamp_vals=1.0, clamp_comp=False):  # [CHECK]
    """Clamp f/r bindings to [clamp_vals].

    : binding_names : str or list of str. Bindings to clamp.
    : clamp_vals : number or list of numbers. A single value is applied to
        every binding; otherwise one value per binding is required.
    : clamp_comp : bool. If True, all other bindings in the roles of the
        clamped bindings are clamped to 0.

    Sets clamped, binding_names_clamped, clampvecC, clampvec and projmat,
    then projects the current state (act, actC) accordingly.
    """

    values = clamp_vals if isinstance(clamp_vals, list) else [clamp_vals]
    names = binding_names if isinstance(binding_names, list) else [binding_names]
    if len(values) > 1 and len(values) != len(names):
        sys.exit('The number of bindings clamped is not equal to the number of values provided.')

    self.clamped = True
    self.binding_names_clamped = names

    # Clamped values in conceptual coordinates.
    clampvecC = np.zeros(self.num_bindings)
    fixed = self.find_bindings(names)
    if clamp_comp:
        competitors = self.find_roles([b.split('/')[1] for b in names])
        clampvecC[competitors] = 0.0
    clampvecC[fixed] = values
    self.clampvecC = clampvecC

    if clamp_comp:
        fixed += competitors  # CHECK
        fixed.sort()

    # Unclamped bindings are free to vary; their columns of TP are the
    # basis vectors of the subspace the state is projected onto.
    free = [bb for bb in np.arange(self.num_bindings) if bb not in fixed]
    if len(free) > 0:
        self.projmat = self._compute_projmat(self.TP[:, free])
    else:
        self.projmat = np.zeros((self.num_units, self.num_units))

    self.clampvec = self.C2N(clampvecC)
    self.act = self.act_clamped(self.act)
    self.actC = self.N2C()
1917
1918
def unclamp(self):
    """Unclamp: remove the clamping attributes set by clamp().

    Does nothing when the model is not currently clamped.
    """

    # Truthiness test, consistent with the check in update_state().
    if self.clamped:
        del self.clampvec
        del self.clampvecC
        del self.projmat
        del self.binding_names_clamped
        self.clamped = False
1927
1928
def act_clamped(self, act=None):
    """Get a new activation vector after projecting an activation vector
    to a subspace.

    : act : activation vector; the current state (self.act) if None.

    Returns projmat.dot(act) + clampvec. Requires projmat and clampvec,
    which are set by clamp().
    """

    if act is None:
        act = self.act
    return self.projmat.dot(act) + self.clampvec
1935
1936
def set_input(self, binding_names, ext_vals, inhib_comp=False):
    """Set external input.

    : binding_names : string or list of strings, bindings receiving input
    : ext_vals : number or list of numbers. A single value is applied to
    :     every binding; otherwise one value per binding is required.
    : inhib_comp : currently unused (inhibition of the competing bindings
    :     in the same roles is not implemented here).

    Any previous external input is cleared first. Updates extC and ext.
    Calls sys.exit() if the numbers of bindings and values differ.
    """

    if not isinstance(ext_vals, list):
        ext_vals = [ext_vals]
    if not isinstance(binding_names, list):
        binding_names = [binding_names]
    if len(ext_vals) > 1 and len(binding_names) != len(ext_vals):
        sys.exit("binding_names and ext_vals have different lengths.")

    self.clear_input()  # Consider removing this line.

    idx = self.find_bindings(binding_names)
    self.extC[idx] = ext_vals
    self.ext = self.C2N(self.extC)
1957
1958
def clear_input(self):
    """Remove external input: set extC to zeros and recompute ext."""

    self.extC = np.zeros(self.num_bindings)
    self.ext = self.C2N(self.extC)
1963
1964
#######################################################################
#
# Set state
#
#######################################################################
1969
1970
def reset(self):
    """Reset the model. q and T will be set to their initial values.

    Also sets t to 0, randomizes the activation state, zeroes the previous
    external input, removes clamping and external input, and deletes the
    traces of a previous run if there are any.
    """

    self.q = self.opts['q_init']
    self.T = self.opts['T_init']
    self.t = 0
    self.randomize_state()
    self.actC = self.N2C()

    self.extC_prev = np.zeros(self.num_bindings)
    self.ext_prev = self.TP.dot(self.extC_prev)
    self.unclamp()
    self.clear_input()
    # NOTE(review): ema_speed is not reset here, so the value from a
    # previous run is carried over into update_speed() -- confirm intended.
    if hasattr(self, 'traces'):
        del self.traces
1985
1986
def set_state(self, binding_names, vals=1.0):
    """Set state to a particular vector at which the activation values
    of the given bindings are set to [vals] (default=1.0)
    and the activation values of the other bindings are set to 0."""

    idx = self.find_bindings(binding_names)
    self.actC = np.zeros(self.num_bindings)
    self.actC[idx] = vals
    # Pass actC explicitly, as randomize_state() does.
    self.act = self.C2N(self.actC)
1995
1996
def set_init_state(self, mu=0.5, sd=0.2):
    """Set initial state.

    Each binding activation (actC) is drawn independently from a normal
    distribution with mean [mu] and standard deviation [sd]; act is then
    computed from actC.
    """

    self.actC = np.random.normal(loc=mu, scale=sd, size=self.num_bindings)
    # Pass actC explicitly, as randomize_state() does.
    self.act = self.C2N(self.actC)
2001
2002
def randomize_state(self, minact=0, maxact=1):
    """Set the activation state to a random vector
    inside a hypercube of [minact, maxact]^num_bindings.

    actC is sampled uniformly; act is then computed from actC.
    """

    self.actC = np.random.uniform(minact, maxact, self.num_bindings)
    self.act = self.C2N(self.actC)
2008
2009
#######################################################################
#
# Update
#
#######################################################################
2014
2015
def run(self, duration, update_T=True, update_q=True, log_trace=True,
        trace_list='all', plot=False, tol=None, testvar='ema_speed',
        grayscale=False, colorbar=True):
    """Run simulations for a given amount of time [duration].

    : duration : the model is updated until self.t reaches the current
    :     time plus [duration], or until convergence if [tol] is given.
    : update_T, update_q : passed on to update()
    : log_trace : bool, whether to record traces during the run
    : trace_list : passed on to initialize_traces()
    : plot : bool, plot the activation trace (conceptual coordinates) as
    :     a heatmap after the run. Requires log_trace=True; otherwise it
    :     is silently ignored.
    : tol, testvar : convergence criterion passed on to
    :     check_convergence(); no convergence check if tol is None.
    : grayscale, colorbar : passed on to heatmap()

    Sets converged, step, rt, extC_prev and ext_prev.
    """

    self.converged = False
    t_max = self.t + duration

    # Log initial state
    self.step = 0
    if log_trace:
        self.initialize_traces(trace_list)

    while self.t < t_max:
        self.update(update_T=update_T, update_q=update_q)
        if log_trace:
            self.update_traces()

        if tol is not None:
            self.check_convergence(tol=tol, testvar=testvar)
            if self.converged:
                break

    self.rt = self.t
    self.extC_prev[:] = self.extC  # this is done in the sim method.
    self.ext_prev[:] = self.TP.dot(self.extC_prev)

    if log_trace:
        self.finalize_traces()

    if log_trace and plot:
        actC_trace = self.N2C(self.traces['act'].T).T
        times = self.traces['t']
        # The update step size varies, so resample every binding's trace
        # on equally spaced time points before drawing the image.
        times_new = np.linspace(times[0], times[-1], times.shape[0])
        resampled = np.array([
            np.interp(times_new, times, actC_trace[:, col])
            for col in range(actC_trace.shape[1])])  # bindings x time
        heatmap(
            resampled,
            xlabel="Time", xtick=False,
            ylabel="Bindings", yticklabels=self.binding_names,
            grayscale=grayscale, colorbar=colorbar, val_range=[0, 1])
2072
2073
def update(self, update_T=True, update_q=True):
    """Update state, speed, ema_speed (and optionally T, q, dt).

    : update_T : bool. T is updated only if this is True and
    :     opts['T_decay_rate'] is positive.
    : update_q : bool, whether to update q.
    """

    # Store the current state (in place) before it is changed; the
    # difference is used by update_speed().
    self.act_prev[:] = self.act
    self.actC_prev[:] = self.actC
    self.update_state()
    self.update_speed()

    if update_T and (self.opts['T_decay_rate'] > 0):
        self.update_T()
    if update_q:
        self.update_q()
2087
2088
def update_state(self):
    """Update state (with noise).

    Takes one step of size dt along the gradient returned by HGrad(),
    advances the time t by dt, adds noise and, if the model is clamped,
    replaces the state by its clamped version. actC is recomputed at
    the end.
    """

    grad = self.HGrad()
    grad_mag = np.sqrt(grad.dot(grad))
    if grad_mag > 0:  # adaptive step size
        # dt shrinks as the gradient grows but stays within
        # [min_dt, max_dt]. With a zero gradient the previous dt is kept.
        self.dt = min(self.opts['max_dt'], self.opts['max_dt'] / grad_mag)
        self.dt = max(self.opts['min_dt'], self.dt)

    self.t += self.dt  # update time
    self.act += self.dt * grad
    self.add_noise()
    if self.clamped:
        self.act = self.act_clamped()
    self.actC = self.N2C()  # update actC; will be used to compute HqGrad
2103
2104
def add_noise(self):
    """Add noise to state in neural coordinates.

    Each unit receives independent Gaussian noise with standard deviation
    sqrt(2 * T * dt); act is modified in place.
    """

    noise_sd = np.sqrt(2 * self.T * self.dt)
    self.act += noise_sd * np.random.randn(self.num_units)
2109
2110
def update_T(self):
    """Update temperature.

    T decays exponentially toward opts['T_min'] at rate
    opts['T_decay_rate'] over the current step size dt.
    """

    decay = np.exp(-self.opts['T_decay_rate'] * self.dt)
    self.T = decay * (self.T - self.opts['T_min']) + self.opts['T_min']
2115
2116
def update_q(self):
    """Update quantization strength.

    If opts['q_policy'] is given, q is linearly interpolated at the
    current time t from the policy, a two-column table whose first column
    holds time points and whose second column holds q values (a 2D numpy
    array or a nested list). Otherwise q grows by opts['q_rate'] * dt
    and is kept within [0, opts['q_max']].
    """

    q_policy = self.opts['q_policy']
    if q_policy is not None:
        # Accept nested lists as well as arrays.
        q_policy = np.asarray(q_policy, dtype=float)
        self.q = np.interp(self.t, q_policy[:, 0], q_policy[:, 1])
    else:
        q_next = self.q + self.opts['q_rate'] * self.dt
        self.q = max(min(q_next, self.opts['q_max']), 0)
2126
2127
def update_speed(self):
    """Update speed and ema_speed.

    speed is the norm (order opts['norm_ord']) of the state change since
    the previous step divided by |dt|. The change is measured in neural
    coordinates (opts['coord'] == 'N') or conceptual coordinates ('C').
    ema_speed is an exponential moving average of speed; it is
    initialized to the first speed value.

    Raises ValueError if opts['coord'] is neither 'N' nor 'C'.
    """

    coord = self.opts['coord']
    if coord == 'N':
        diff = self.act - self.act_prev
    elif coord == 'C':
        diff = self.actC - self.actC_prev
    else:
        raise ValueError("opts['coord'] must be 'N' or 'C', got %r" % (coord,))

    self.speed = linalg.norm(
        diff, ord=self.opts['norm_ord']) / abs(self.dt)
    if self.ema_speed is None:
        self.ema_speed = self.speed
    else:
        # The weight of the old average depends on the (variable) step
        # size, so that the averaging time scale does not depend on dt.
        ema_weight = self.opts['ema_factor'] ** abs(self.dt)
        self.ema_speed = (ema_weight * self.ema_speed +
                          (1 - ema_weight) * self.speed)

    # See EMA_{eq} in the following document: http://www.eckner.com/papers/ts_alg.pdf
    # : tau = -1 / log(ema_factor).
    # : ema_factor = exp(-1/ema_tau)
2147
2148
def check_convergence(self, tol, testvar='ema_speed'):
    """Check if the convergence criterion has been satisfied.

    : tol : threshold
    : testvar : 'ema_speed' (converged when ema_speed < tol) or
    :     'Q' (converged when self.Q() > tol). Any other value leaves
    :     [converged] untouched.

    Sets self.converged to True when the criterion is met; never sets it
    back to False.
    """

    if testvar == 'ema_speed':
        if self.ema_speed < tol:
            self.converged = True
    elif testvar == 'Q':
        if self.Q() > tol:
            self.converged = True
2162
2163
def plot_state(self, act=None, actC=None, coord='C',
               colorbar=True, disp=True, grayscale=True):
    """Plot an activation state in a heatmap.

    : act : state in neural coordinates
    : actC : state in conceptual coordinates
    :     Pass at most one of [act] and [actC]; the other is computed
    :     from it. If neither is given, the current state is plotted.
    : coord : 'C' (fillers x roles, default) or 'N' (dim_f x dim_r units).
    :     Nothing is plotted for any other value.
    : colorbar, disp, grayscale : passed on to heatmap()

    Calls sys.exit() if both [act] and [actC] are given.
    """

    if (act is not None) and (actC is not None):
        sys.exit('Error. You must pass either act or actC but not both to the function.')

    if (act is None) and (actC is None):
        act = self.act
        actC = self.actC
    elif act is None:
        act = self.C2N(actC)
    else:
        actC = self.N2C(act)

    if coord == 'C':
        heatmap(
            self.vec2mat(actC), xticklabels=self.role_names,
            yticklabels=self.filler_names, grayscale=grayscale,
            colorbar=colorbar, disp=disp, val_range=[0, 1])
    elif coord == 'N':
        act_mat = act.reshape((self.dim_f, self.dim_r), order='F')
        yticklabels = ['f' + str(ii) for ii in range(self.dim_f)]
        xticklabels = ['r' + str(ii) for ii in range(self.dim_r)]
        heatmap(
            act_mat, xticklabels=xticklabels, yticklabels=yticklabels,
            grayscale=grayscale, colorbar=colorbar, disp=disp)
2189
2190
def plot_trace(self, varname):
    """Plot the trace of a given variable against time.

    : varname : key of self.traces, or 'actC', in which case the 'act'
    :     trace is converted to conceptual coordinates before plotting.
    """

    x = self.traces['t']
    # Compare by value: identity comparison ("is") with a string literal
    # is not guaranteed to succeed for equal strings.
    if varname == 'actC':
        y = self.N2C(self.traces['act'].T).T
    else:
        y = self.traces[varname]

    plt.plot(x, y)
    plt.xlabel('Time', fontsize=16)
    plt.ylabel(varname, fontsize=16)
    plt.grid(True)
    plt.show()
2204
2205
#def update_dist(self, grid_points):
2206
2207
# if not any(isinstance(tt, list) for tt in grid_points):
2208
# grid_points = [grid_points]
2209
2210
# dist = np.zeros(len(grid_points))
2211
# for ii, grid_point in enumerate(grid_points):
2212
# dist[ii] = self.compute_dist(ref_point=grid_point)
2213
2214
# self.dist = dist
2215
2216
#def compute_dist(self, ref_point):
2217
# """
2218
# Compute the distance of the current state from a grid point.
2219
# [grid point] is a set of bindings.
2220
# """
2221
2222
# idx = self.find_bindings(ref_point)
2223
# destC = np.zeros(self.num_bindings)
2224
# destC[idx] = 1.0
2225
2226
# if self.opts['coord'] == 'N':
2227
# state1 = self.act
2228
# state2 = self.N2S(destC)
2229
# elif self.opts['coord'] == 'C':
2230
# state1 = self.N2C(self.act)
2231
# state2 = destC
2232
2233
# return np.linalg.norm(state1-state2, ord=self.opts['norm_ord'])
2234
2235
2236
def encode_symbols(num_symbols, coord='dist', dp=0., dim=None):
    """Generate the vector encodings of [num_symbols] symbols assuming a given similarity structure.
    Each column vector will represent a unique symbol.

    Usage:

        >>> gsc.encode_symbols(2)
        >>> gsc.encode_symbols(3, coord='dist', dp=0.3, dim=5)

    : num_symbols : int, number of symbols to encode
    : coord       : string, 'dist' (distributed representation, default) or 'local' (local representation)
    : dp          : float (0 [default] <= dp <= 1) or 2D-numpy array of pairwise similarity (dot product)
    : dim         : int, number of dimensions to encode a symbol. must not be smaller than [num_symbols]
    :
    : [dp] and [dim] values are ignored if coord is set to 'local'.

    Returns a 2D numpy array: the identity matrix for 'local', otherwise
    the [dim] x [num_symbols] matrix returned by dot_products().
    Calls sys.exit() if [dim] is smaller than [num_symbols].
    """

    if coord == 'local':
        return np.eye(num_symbols)

    if dim is None:
        dim = num_symbols
    elif dim < num_symbols:
        sys.exit("The [dim] value must be same as or greater than the [num_symbols] value.")

    if isinstance(dp, numbers.Number):
        # Expand a scalar similarity into a full matrix: [dp] everywhere
        # off the main diagonal and 1 on it.
        dp = (dp * np.ones((num_symbols, num_symbols)) +
              (1 - dp) * np.eye(num_symbols, num_symbols))

    return dot_products(num_symbols, dim, dp)
2269
2270
2271
def dot_products(num_symbols, dim, dp_mat, max_iter=100000):
    """Generate a 2D numpy array of random numbers such that the pairwise dot
    products of column vectors are close to the numbers specified in [dp_mat].

    Don Matthias wrote the original script in MATLAB for the LDNet program.
    He explains how this program works as follows:

    Given square matrix dpMatrix of dimension N-by-N, find N
    dim-dimensional unit vectors whose pairwise dot products match
    dpMatrix. Results are returned in the columns of M. itns is the
    number of iterations of search required, and may be ignored.

    Algorithm: Find a matrix M such that M'*M = dpMatrix. This is done
    via gradient descent on a cost function that is the square of the
    frobenius norm of (M'*M-dpMatrix).

    NOTE: It has trouble finding more than about 16 vectors, possibly for
    dumb numerical reasons (like stepsize and tolerance), which might be
    fixable if necessary.

    : num_symbols : int, number of column vectors (N above)
    : dim : int, number of rows of the result
    : dp_mat : symmetric 2D numpy array with ones on the main diagonal
    : max_iter : int, maximum number of gradient steps

    Returns the [dim] x [num_symbols] array (only this array, not the
    iteration count). Prints a message if the search did not converge
    within [max_iter] iterations. Calls sys.exit() if [dp_mat] is not
    symmetric or has a diagonal entry different from 1.
    """

    if not (dp_mat.T == dp_mat).all():
        sys.exit('dot_products: dp_mat must be symmetric')

    if (np.diag(dp_mat) != 1).any():
        sys.exit('dot_products: dp_mat must have all ones on the main diagonal')

    sym_mat = np.random.uniform(
        size=dim * num_symbols).reshape(dim, num_symbols, order='F')
    min_step = .1
    tol = 1e-6
    for _ in range(max_iter):
        inc = sym_mat.dot(sym_mat.T.dot(sym_mat) - dp_mat)
        largest = abs(inc).max()
        # Limit the largest change of a single entry to .01 per step.
        # A zero increment would otherwise cause a division by zero.
        step = min_step if largest == 0 else min(min_step, .01 / largest)
        sym_mat = sym_mat - step * inc
        max_diff = abs(sym_mat.T.dot(sym_mat) - dp_mat).max()
        if max_diff <= tol:
            break
    else:
        print("Didn't converge after %d iterations" % max_iter)

    return sym_mat
2316
2317
2318
def heatmap(data, xlabel=None, ylabel=None, xticklabels=None, yticklabels=None,
            grayscale=False, colorbar=True, rotate_xticklabels=False,
            xtick=True, ytick=True, disp=True, val_range=None):
    """Plot a 2D array as a heatmap.

    : data : 2D array to draw
    : xlabel, ylabel : axis labels (omitted if None)
    : xticklabels, yticklabels : tick labels, placed at 0, 1, 2, ...
    : grayscale : use the "gray_r" colormap instead of "Reds"
    : colorbar : add a colorbar
    : rotate_xticklabels : draw the x tick labels vertically
    : xtick, ytick : if False, hide the ticks and tick labels of that axis
    : disp : call plt.show() at the end
    : val_range : [min, max] of the color scale; chosen from the data if None
    """

    # plt.get_cmap is used rather than plt.cm.get_cmap, which recent
    # Matplotlib releases no longer provide.
    cmap = plt.get_cmap("gray_r" if grayscale else "Reds")

    if val_range is not None:
        plt.imshow(data, cmap=cmap, vmin=val_range[0], vmax=val_range[1],
                   interpolation="nearest", aspect='auto')
    else:
        plt.imshow(data, cmap=cmap, interpolation="nearest", aspect='auto')

    if xlabel is not None:
        plt.xlabel(xlabel, fontsize=16)
    if ylabel is not None:
        plt.ylabel(ylabel, fontsize=16)
    if xticklabels is not None:
        if rotate_xticklabels:
            plt.xticks(
                np.arange(len(xticklabels)), xticklabels,
                rotation='vertical')
        else:
            plt.xticks(np.arange(len(xticklabels)), xticklabels)

    if yticklabels is not None:
        plt.yticks(np.arange(len(yticklabels)), yticklabels)

    # tick_params expects booleans; the strings 'on'/'off' are not
    # interpreted as switches by current Matplotlib.
    if xtick is False:
        plt.tick_params(
            axis='x',            # changes apply to the x-axis
            which='both',        # both major and minor ticks are affected
            bottom=False,        # ticks along the bottom edge are off
            top=False,           # ticks along the top edge are off
            labelbottom=False)   # labels along the bottom edge are off
    if ytick is False:
        plt.tick_params(
            axis='y',            # changes apply to the y-axis
            which='both',        # both major and minor ticks are affected
            left=False,          # ticks along the left edge are off
            right=False,         # ticks along the right edge are off
            labelleft=False)     # labels along the left edge are off

    if colorbar:
        plt.colorbar()

    if disp:
        plt.show()
2369
2370
2371
def plot_TP(vec1, vec2, figsize=None):
    """Compute the outer product of two vectors and present it in a diagram.

    Every number is drawn as a circle whose gray level encodes its
    absolute value (darker = larger). [vec1] is shown in the first
    column, [vec2] in the first row and their outer product in the
    remaining grid. Negative values are drawn as a smaller filled circle
    inside an empty outer ring. The figure is not shown by this function.

    : vec1, vec2 : 1D numpy arrays
    : figsize : passed on to plt.subplots()
    """

    nrow = vec1.shape[0]
    ncol = vec2.shape[0]
    radius = 0.4

    arr = np.zeros((nrow + 1, ncol + 1))
    arr[1:, 1:] = np.outer(vec1, vec2)
    arr[0, 1:] = vec2
    arr[1:, 0] = vec1

    # figsize=None makes subplots() fall back to the default figure size.
    fig, ax = plt.subplots(figsize=figsize)

    for ii in range(nrow + 1):
        for jj in range(ncol + 1):
            if (ii == 0) and (jj == 0):
                continue  # the top-left corner is left empty

            center = (jj, -ii)
            shade = plt.cm.gray(1 - abs(arr[ii, jj]))

            if arr[ii, jj] >= 0:
                # filled disc with a black outline
                ax.add_artist(plt.Circle(
                    center, radius, color=shade, alpha=1))
                ax.add_artist(plt.Circle(
                    center, radius, color='k', fill=False))
            else:
                # empty outer ring around a smaller filled disc
                ax.add_artist(plt.Circle(
                    center, radius, color='k', fill=False))
                ax.add_artist(plt.Circle(
                    center, radius - 0.1, color=shade, alpha=1))
                ax.add_artist(plt.Circle(
                    center, radius - 0.1, color='k', fill=False))

    ax.axis([
        0 - radius - 0.6, ncol + radius + 0.6,
        - nrow - radius - 0.6, 0 + radius + 0.6])
    ax.set_aspect('equal', adjustable='box')
    ax.axis('off')
2427
2428
2429
# ============================================================================
# The functions below are not part of the GSC simulator. They were
# used to compute the gradient of quantization harmony in an earlier
# version. I included the functions (1) to check if the newly implemented
# functions return the same values, and (2) to compare the computation
# speed in both versions.
#
# Q0GradE and Q1GradE (elementwise computation) must return the same values
# as Q0GradV and Q1GradV (partial vectorization).
# ============================================================================
2439
2440
def b_ind(f, r, net):
    """Get a binding index.

    Returns f + r * net.num_fillers, the position of the binding of
    filler index [f] and role index [r] in a binding vector.
    """
    return f + r * net.num_fillers
2443
2444
2445
def u_ind(phi, rho, net):
    """Get a unit index.

    Returns phi + rho * net.dim_f, the position of the unit with filler
    dimension index [phi] and role dimension index [rho] in an
    activation vector.
    """
    return phi + rho * net.dim_f
2448
2449
2450
def w(f, r, phi, rho, net):
    """Return the entry of net.TPinv in the row of binding (f, r) and the
    column of unit (phi, rho)."""
    return net.TPinv[b_ind(f, r, net), u_ind(phi, rho, net)]
2453
2454
2455
def get_a(n, net, f, r):
    """Check the activation value of a f/r binding.

    Elementwise computation: the sum over all units (phi, rho) of
    w(f, r, phi, rho) times the activation of that unit in [n].
    """
    act = 0
    for phi in range(net.dim_f):
        for rho in range(net.dim_r):
            act += w(f, r, phi, rho, net) * n[u_ind(phi, rho, net)]
    return act
2462
2463
2464
def n2a(n, net, f=None, r=None):
    """Compute binding activations from the unit activation vector [n],
    element by element.

    : f, r : filler and role index (None = all)
    :     both None  -> vector of all bindings, ordered by b_ind()
    :     only r given -> activations of all fillers in role r
    :     only f given -> activations of filler f in all roles
    :     both given -> the activation of the single binding (f, r)
    """
    if (f is None) and (r is None):
        avec = np.zeros(net.num_bindings)
        for f_i in range(net.num_fillers):
            for r_i in range(net.num_roles):
                avec[b_ind(f_i, r_i, net)] = get_a(n, net, f_i, r_i)
        return avec

    if f is None:
        avec = np.zeros(net.num_fillers)
        for f_i in range(net.num_fillers):
            avec[f_i] = get_a(n, net, f_i, r)
        return avec

    if r is None:
        avec = np.zeros(net.num_roles)
        for r_i in range(net.num_roles):
            avec[r_i] = get_a(n, net, f, r_i)
        return avec

    return get_a(n, net, f, r)
2484
2485
2486
def Q0E(net, n):
    """Elementwise computation of Q0: minus the sum over all bindings of
    a^2 * (1 - a)^2, where a is the activation of the binding."""
    q0 = 0.0
    for f in range(net.num_fillers):
        for r in range(net.num_roles):
            a_fr = n2a(n, net, f=f, r=r)  # computed once per binding
            q0 += a_fr**2 * (1 - a_fr)**2
    return -q0
2492
2493
2494
def Q0GradE(net, n):
    """Gradient of Q0 with respect to the unit activations [n].

    Elementwise computation. Very slow.
    Based on the first derivation.
    """
    q0grad = np.zeros(net.num_units)
    for phi in range(net.dim_f):
        for rho in range(net.dim_r):
            unit = u_ind(phi, rho, net)
            for f in range(net.num_fillers):
                for r in range(net.num_roles):
                    a_fr = n2a(n, net, f, r)
                    g_fr = 2 * a_fr * (1 - a_fr) * (1 - 2 * a_fr)
                    q0grad[unit] += w(f, r, phi, rho, net) * g_fr
    return -q0grad
2508
2509
2510
def Q1E(net, n):
    """Compute Q1: minus the sum, over the columns of the binding
    activation matrix net.vec2mat(a), of (sum of squares - 1)^2."""
    # The original also accumulated the same quantity role by role in a
    # loop whose result was never used; only the returned expression is
    # kept.
    amat = net.vec2mat(n2a(n, net))
    return -np.sum((np.sum(amat**2, axis=0) - 1)**2)
2515
2516
2517
def Q1GradE(net, n):  # Elementwise computation
    """Gradient of Q1 with respect to the unit activations [n],
    computed element by element (slow reference implementation)."""
    q1grad = np.zeros(net.num_units)
    for phi in range(net.dim_f):
        for rho in range(net.dim_r):
            unit_grad = 0.0
            for r in range(net.num_roles):
                # (sum of squared activations in role r) - 1
                var1 = np.sum(n2a(n, net, r=r)**2) - 1
                var2 = 0.0
                for f in range(net.num_fillers):
                    var2 += n2a(n, net, f, r) * w(f, r, phi, rho, net)
                unit_grad += 4 * var1 * var2
            q1grad[u_ind(phi, rho, net)] = unit_grad
    return -q1grad
2530
2531
2532
def Q0GradV(net, n):
    """Gradient of Q0 with respect to the unit activations [n]
    (vectorized). Based on the first derivation."""
    a = net.N2C(n)
    g = 2 * a * (1 - a) * (1 - 2 * a)  # a vectorized version of g_{fr}
    # Weight every row (binding) of TPinv by its g value and sum over the
    # bindings; broadcasting replaces the explicit np.tile copy.
    q0grad = np.sum(net.TPinv * g[:, np.newaxis], axis=0)
    return -q0grad
2540
2541
2542
def Q1GradV(net, n):
    """Gradient of Q1 with respect to the unit activations [n]
    (partially vectorized: one loop over the roles)."""
    a = net.N2C(n)
    # Array accumulator, so that an array is returned even if there are
    # no roles.
    q1grad = np.zeros(net.num_units)
    for rr in net.role_names:
        curr_binding_ind = net.find_roles(rr)
        a_role = a[curr_binding_ind]
        term1 = np.sum(a_role ** 2) - 1
        term2 = np.sum(
            net.TPinv[curr_binding_ind, :] * a_role[:, np.newaxis], axis=0)
        q1grad += term1 * term2
    q1grad = 4 * q1grad
    return -q1grad
2553
2554