BackPropagationNeuralNetworks.py

Go to the documentation of this file.
00001 # Back-Propagation Neural Networks
00002 # 
00003 # Written in Python.  See http://www.python.org/
00004 # Placed in the public domain.
00005 # Neil Schemenauer <nas@arctrix.com>
00006 
00007 import math
00008 import random
00009 import string
00010 
# Seed the module-level generator with a constant so the random initial
# weights drawn through rand() -- and therefore every run -- are reproducible.
random.seed(a=0)
00012 
00013 ##
00014 # calculate a random number where:  a <= rand < b
00015 def rand(a, b):
00016     return (b-a)*random.random() + a
00017 
00018 ##
00019 # Make a matrix (we could use NumPy to speed this up)
00020 def makeMatrix(I, J, fill=0.0):
00021     m = []
00022     for i in range(I):
00023         m.append([fill]*J)
00024     return m
00025 
00026 ##
00027 # our sigmoid function, tanh is a little nicer than the standard 1/(1+e^-x)
00028 def sigmoid(x):
00029     return math.tanh(x)
00030 
00031 ##
00032 # derivative of our sigmoid function
00033 def dsigmoid(y):
00034     return 1.0-y*y
00035 
class NN:
    """Three-layer (input, hidden, output) network trained by back-propagation.

    Every neuron uses ``sigmoid`` as its activation.  The input layer has
    one extra bias node whose activation stays at 1.0.  The hidden and
    output layers have no bias node.  Weight updates include a momentum
    term.
    """

    def __init__(self, ni, nh, no):
        """Build the network and randomize its weights.

        ni -- number of input nodes (a bias node is added on top of these)
        nh -- number of hidden nodes
        no -- number of output nodes
        """
        # Node counts; self.ni includes the extra bias node.
        self.ni = ni + 1
        self.nh = nh
        self.no = no

        # Node activations.  Every slot starts at 1.0 and update() only
        # overwrites the first ni-1 input slots, so the last input slot acts
        # as the constant bias input.
        self.ai = [1.0] * self.ni
        self.ah = [1.0] * self.nh
        self.ao = [1.0] * self.no

        # Weight matrices: wi[i][j] connects input i to hidden j and
        # wo[j][k] connects hidden j to output k.  Set them to random
        # values drawn from [-2, 2).
        self.wi = makeMatrix(self.ni, self.nh)
        self.wo = makeMatrix(self.nh, self.no)
        for i in range(self.ni):
            for j in range(self.nh):
                self.wi[i][j] = rand(-2.0, 2.0)
        for j in range(self.nh):
            for k in range(self.no):
                self.wo[j][k] = rand(-2.0, 2.0)

        # Last change in weights, used for the momentum term.
        self.ci = makeMatrix(self.ni, self.nh)
        self.co = makeMatrix(self.nh, self.no)

    def update(self, inputs):
        """Run a forward pass and return a copy of the output activations.

        inputs -- sequence of exactly self.ni - 1 values; the bias input is
                  supplied internally.

        The input, hidden and output activations are stored on the instance,
        where backPropagate() reads them.  Raises ValueError if ``inputs``
        has the wrong length.
        """
        if len(inputs) != self.ni - 1:
            raise ValueError('wrong number of inputs')

        # Input activations are the raw inputs; the final (bias) slot keeps
        # its value of 1.0.
        for i in range(self.ni - 1):
            self.ai[i] = inputs[i]

        # Hidden activations.
        for j in range(self.nh):
            total = 0.0
            for i in range(self.ni):
                total = total + self.ai[i] * self.wi[i][j]
            self.ah[j] = sigmoid(total)

        # Output activations.
        for k in range(self.no):
            total = 0.0
            for j in range(self.nh):
                total = total + self.ah[j] * self.wo[j][k]
            self.ao[k] = sigmoid(total)

        return self.ao[:]

    def backPropagate(self, targets, N, M):
        """Move the weights toward ``targets`` and return the pattern error.

        Call this right after update() so the stored activations belong to
        the pattern being learned.

        targets -- sequence of exactly self.no target output values
        N       -- learning rate
        M       -- momentum factor

        Returns 0.5 * the sum over outputs of (target - output)**2, using the
        output activations from the preceding forward pass.  Raises
        ValueError if ``targets`` has the wrong length.
        """
        if len(targets) != self.no:
            raise ValueError('wrong number of target values')

        # Error terms (deltas) for the output layer.
        output_deltas = [0.0] * self.no
        for k in range(self.no):
            error = targets[k] - self.ao[k]
            output_deltas[k] = dsigmoid(self.ao[k]) * error

        # Error terms for the hidden layer, propagated back through the
        # output weights *before* those weights are modified below.
        hidden_deltas = [0.0] * self.nh
        for j in range(self.nh):
            error = 0.0
            for k in range(self.no):
                error = error + output_deltas[k] * self.wo[j][k]
            hidden_deltas[j] = dsigmoid(self.ah[j]) * error

        # Update hidden->output weights.
        # NOTE(review): co/ci store the raw gradient term ``change``, not the
        # full step applied (N*change + M*previous).  So the momentum term is
        # M times the previous gradient rather than the previous weight
        # update.  Kept as-is to preserve the existing training behaviour.
        for j in range(self.nh):
            for k in range(self.no):
                change = output_deltas[k] * self.ah[j]
                self.wo[j][k] = self.wo[j][k] + N * change + M * self.co[j][k]
                self.co[j][k] = change

        # Update input->hidden weights.
        for i in range(self.ni):
            for j in range(self.nh):
                change = hidden_deltas[j] * self.ai[i]
                self.wi[i][j] = self.wi[i][j] + N * change + M * self.ci[i][j]
                self.ci[i][j] = change

        # Squared error for this pattern, from the forward-pass outputs.
        error = 0.0
        for k in range(len(targets)):
            error = error + 0.5 * (targets[k] - self.ao[k]) ** 2
        return error

    def test(self, patterns):
        """Print each pattern's inputs next to the network's outputs.

        patterns -- sequence of [inputs, targets] pairs; only the inputs
                    are used.
        """
        for p in patterns:
            # Single-argument print() behaves the same on Python 2 and 3.
            print('%s -> %s' % (p[0], self.update(p[0])))

    def weights(self):
        """Print the input->hidden and hidden->output weight matrices."""
        print('Input weights:')
        for row in self.wi:
            print(row)
        print('')
        print('Output weights:')
        for row in self.wo:
            print(row)

    def train(self, patterns, iterations=1000, N=0.5, M=0.1):
        """Train on ``patterns`` for a fixed number of passes (epochs).

        patterns   -- sequence of [inputs, targets] pairs
        iterations -- number of passes over ``patterns``
        N          -- learning rate
        M          -- momentum factor

        Prints the summed pattern error of every 100th pass, starting with
        the first.
        """
        for i in range(iterations):
            error = 0.0
            for p in patterns:
                inputs = p[0]
                targets = p[1]
                self.update(inputs)
                error = error + self.backPropagate(targets, N, M)
            if i % 100 == 0:
                print('error %-14f' % error)
00154 
00155 
def demo():
    """Teach a small network the XOR function, then print its answers."""
    # One [inputs, targets] pair per row of the XOR truth table, in the
    # order (0,0), (0,1), (1,0), (1,1).
    pat = [[[a, b], [a ^ b]] for a in (0, 1) for b in (0, 1)]

    # Two input, two hidden and one output node.
    net = NN(2, 2, 1)
    net.train(pat)
    net.test(pat)
00172 
00173 
00174 
if __name__ == '__main__':
    # Executed as a script rather than imported: run the XOR demonstration.
    demo()
00177 
00178 
00179 

© Copyright 2008-2009 Vyper Logix Corp., All Right Reserved; If you reference this document or any part of this document you must use the citation verbatim (including the link) "© Copyright 2008-2009 Vyper Logix Corp., All Right Reserved."

Notice: This source code contained in this document is NOT open source and is NOT being distributed as open source.

122,241 lines of code and growing...