117 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			117 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
| #!/usr/bin/python
 | |
| 
 | |
| import math
 | |
| 
 | |
class Table(object):
    """Generate fixed-point sigmoid/tanh lookup tables as C source.

    Every entry is ``f(x)`` (``f`` is :meth:`sigmoid` or :meth:`tanh`)
    quantized with :meth:`fp2q7` or :meth:`fp2q15`.  With
    ``n = table_entry`` and ``r = table_range`` the generated arrays are:

    * ``<f>Table_q7`` / ``<f>Table_q15`` (``n`` entries): ``x = r * k / n``
      with ``k = i`` for ``i < n // 2`` and ``k = i - n`` otherwise, i.e. the
      index is read as a two's-complement number and the table covers
      ``[-r/2, r/2)``.
    * ``<f>LTable_q15`` (``n // 2`` entries): ``x = (r * k / 4) / (n // 2)``
      with ``k = i`` for ``i < n // 4`` and ``k = i - n // 2`` otherwise,
      i.e. twice the resolution of the unified table over ``[-r/8, r/8)``.
    * ``<f>HTable_q15`` (``3 * n // 4`` entries): ``x = r * k / n`` with
      ``k = i + n // 8`` for ``i < 3 * n // 8`` and ``k = i + n // 8 - n``
      otherwise, covering ``[r/8, r/2)`` followed by ``[-r/2, -r/8)``.

    For example n=256, r=16 gives a unified table over [-8, 8) with step
    1/16, a low table over [-2, 2) with step 1/32, and a high table over
    [2, 8) then [-8, -2).
    """

    def __init__(self, table_entry=256, table_range=8):
        # Number of entries in the unified tables (n in the class docstring).
        self.table_entry = table_entry
        # Width of the input interval covered by the unified tables (r).
        self.table_range = table_range

    def sigmoid(self, x):
        """Return the logistic function 1 / (1 + exp(-x)).

        For x below about -709, exp(-x) overflows a double; the limit value
        0.0 is returned instead of raising OverflowError.
        """
        try:
            return 1.0 / (1.0 + math.exp(-x))
        except OverflowError:
            return 0.0

    def tanh(self, x):
        """Return the hyperbolic tangent of x.

        math.tanh saturates to +/-1 for large |x| instead of overflowing the
        way an explicit (exp(2x) - 1) / (exp(2x) + 1) quotient does.
        """
        return math.tanh(x)

    @staticmethod
    def _fp2q(x, frac_bits):
        """Quantize x to a signed fixed-point code with frac_bits fraction bits.

        Rounds half up (floor(x * 2**frac_bits + 0.5)), saturates to the
        signed (frac_bits + 1)-bit range, and returns the two's-complement bit
        pattern as a non-negative int (e.g. -1.0 -> 0x80 for frac_bits=7).
        """
        scale = 1 << frac_bits
        code = int(math.floor(x * scale + 0.5))
        code = max(-scale, min(scale - 1, code))
        # Negative codes are emitted as their unsigned two's-complement value.
        return code if code >= 0 else code + 2 * scale

    def fp2q7(self, x):
        """Return x as a saturated Q7 code in 0x00-0xff (two's complement)."""
        return self._fp2q(x, 7)

    def fp2q15(self, x):
        """Return x as a saturated Q15 code in 0x0000-0xffff (two's complement)."""
        return self._fp2q(x, 15)

    def _unified_inputs(self):
        """Return the x values sampled by the n-entry unified tables."""
        n = self.table_entry
        half = n // 2
        return [float(self.table_range * (i if i < half else i - n)) / n
                for i in range(n)]

    def _low_inputs(self):
        """Return the x values sampled by the (n // 2)-entry low-range table."""
        half = self.table_entry // 2
        quarter = self.table_entry // 4
        # True division: truncating r * k / 4 would collapse neighbouring
        # samples whenever r * k is not a multiple of 4.
        return [self.table_range * (i if i < quarter else i - half) / 4.0 / half
                for i in range(half)]

    def _high_inputs(self):
        """Return the x values sampled by the (3n // 4)-entry high-range table."""
        n = self.table_entry
        eighth = n // 8
        split = 3 * n // 8
        return [float(self.table_range
                      * (i + eighth if i < split else i + eighth - n)) / n
                for i in range(3 * n // 4)]

    def _format_table(self, name, data_type, values):
        """Return a C ``const q<data_type>_t name[]`` definition for values.

        Each value is quantized with fp2q<data_type>, printed as zero-padded
        hex, eight entries per line.
        """
        quantize = getattr(self, "fp2q%d" % data_type)
        digits = (data_type + 1) // 4  # hex digits per (data_type + 1)-bit code
        parts = ["const q%d_t %s[%d] = {\n" % (data_type, name, len(values))]
        for i, value in enumerate(values):
            parts.append("0x%0*x, " % (digits, quantize(value)))
            if i % 8 == 7:
                parts.append("\n")
        parts.append("};\n\n")
        return "".join(parts)

    def table_source(self):
        """Return the complete C source text of the table file."""
        # NOTE: the banner only mentions sigmoid although tanh tables follow;
        # it is kept verbatim so the generated file does not change.
        header = ("/*\n * Common tables for NN\n *\n *\n *\n *\n */\n\n"
                  "#include \"arm_math.h\"\n#include \"NNCommonTable.h\"\n\n"
                  "/*\n * Table for sigmoid\n */\n")
        parts = [header]
        for function_type in ("sigmoid", "tanh"):
            act_func = getattr(self, function_type)
            unified = [act_func(x) for x in self._unified_inputs()]
            for data_type in (7, 15):
                parts.append(self._format_table(
                    "%sTable_q%d" % (function_type, data_type), data_type, unified))
            parts.append(self._format_table(
                "%sLTable_q15" % function_type, 15,
                [act_func(x) for x in self._low_inputs()]))
            parts.append(self._format_table(
                "%sHTable_q15" % function_type, 15,
                [act_func(x) for x in self._high_inputs()]))
        return "".join(parts)

    def table_gen(self, filename="NNCommonTable.c"):
        """Write the generated tables to filename (NNCommonTable.c by default).

        The whole source is built before the file is opened, so an error while
        computing leaves no truncated output.  It is written in binary mode as
        ASCII, giving LF line endings on every platform.
        """
        data = self.table_source().encode("ascii")
        with open(filename, "wb") as outfile:
            outfile.write(data)
 | |
|   
 | |
|   
 | |
if __name__ == "__main__":
    # Write NNCommonTable.c to the current working directory: 256-entry
    # unified tables spanning [-8, 8) (table_range=16).  Guarded so that
    # importing this module does not regenerate the file as a side effect.
    mytable = Table(table_entry=256, table_range=16)
    mytable.table_gen()
 |