# wasm.py
#
# Experimental builder for Wasm binary encoding. Use at your own peril.
#
# Author: David Beazley (@dabeaz)
# Copyright (C) 2019
# http://www.dabeaz.com

import struct
import enum
from collections import defaultdict
import json

# ------------------------------------------------------------
# Low level encoding of values


def encode_unsigned(value):
    '''
    Produce an LEB128 encoded unsigned integer.

    Raises ValueError if value is negative.
    '''
    if value < 0:
        raise ValueError(f'cannot encode negative value {value!r} as unsigned')
    out = bytearray()
    while True:
        low = value & 0x7f
        value >>= 7
        if value:
            # More groups follow: set the continuation bit.
            out.append(low | 0x80)
        else:
            out.append(low)
            return bytes(out)


assert encode_unsigned(624485) == bytes([0xe5, 0x8e, 0x26])


def encode_signed(value):
    '''
    Produce a LEB128 encoded signed integer.

    Works for zero, positive and negative values. A positive value whose
    last 7-bit group has bit 6 set gets an extra 0x00 group so that a
    decoder does not sign-extend it into a negative number.
    '''
    out = bytearray()
    while True:
        low = value & 0x7f
        value >>= 7          # Arithmetic shift; negative values converge to -1
        sign_bit = low & 0x40
        if (value == 0 and not sign_bit) or (value == -1 and sign_bit):
            out.append(low)
            return bytes(out)
        out.append(low | 0x80)


assert encode_signed(-624485) == bytes([0x9b, 0xf1, 0x59])

# NOTE(review): the text between encode_float64's pack format and the
# HexEnum metaclass assignment was lost from the source. The definitions
# of encode_float64, encode_float32, encode_vector, encode_name,
# HexEnumMeta and HexEnum below are reconstructed from how the rest of
# this file uses them - confirm against the upstream file.


def encode_float64(value):
    '''
    Encode a float as 8 bytes, IEEE 754 little-endian.
    '''
    return struct.pack('<d', value)


def encode_float32(value):
    '''
    Encode a float as 4 bytes, IEEE 754 little-endian.
    '''
    return struct.pack('<f', value)


def encode_vector(items):
    '''
    Encode a vector: an unsigned element count followed by the elements.

    items is either a bytes object (a vector of single bytes) or a
    sequence of already-encoded bytes objects.
    '''
    if isinstance(items, (bytes, bytearray)):
        return encode_unsigned(len(items)) + bytes(items)
    items = list(items)
    return encode_unsigned(len(items)) + b''.join(items)


def encode_name(value):
    '''
    Encode a name as a vector of UTF-8 bytes.
    '''
    return encode_vector(value.encode('utf-8'))


# ------------------------------------------------------------
# Instruction encoding enums.
#
# The type names i32, i64, f32, f64 are used both as value types (the
# class itself converts to its type code via int()/bytes()) and as a
# namespace of type-specific opcodes (i32.add, f64.mul, ...).

class HexEnumMeta(enum.EnumMeta):
    '''
    Metaclass that lets an enum class take an `encoding=` class keyword
    and be used directly as an integer (its value-type code).
    '''

    def __int__(cls):
        return int(cls._encoding)

    __index__ = __int__

    def __repr__(cls):
        # Also serves as str(cls); used to look up builder namespaces
        # (e.g. str(i32) == 'i32').
        return cls.__name__

    @classmethod
    def __prepare__(meta, name, bases, encoding=0):
        return super().__prepare__(name, bases)

    def __new__(meta, clsname, bases, methods, encoding=0):
        cls = super().__new__(meta, clsname, bases, methods)
        cls._encoding = encoding
        return cls

    def __init__(cls, clsname, bases, methods, encoding=0):
        super().__init__(clsname, bases, methods)


class HexEnum(enum.IntEnum):
    def __repr__(self):
        return f'<{type(self).__name__}.{self.name}: 0x{int(self):x}>'


HexEnum.__class__ = HexEnumMeta


class i32(HexEnum, encoding=0x7f):
    eqz = 0x45
    eq = 0x46
    ne = 0x47
    lt_s = 0x48
    lt_u = 0x49
    gt_s = 0x4a
    gt_u = 0x4b
    le_s = 0x4c
    le_u = 0x4d
    ge_s = 0x4e
    ge_u = 0x4f
    clz = 0x67
    ctz = 0x68
    popcnt = 0x69
    add = 0x6a
    sub = 0x6b
    mul = 0x6c
    div_s = 0x6d
    div_u = 0x6e
    rem_s = 0x6f
    rem_u = 0x70
    and_ = 0x71
    or_ = 0x72
    xor = 0x73
    shl = 0x74
    shr_s = 0x75
    shr_u = 0x76
    rotl = 0x77
    rotr = 0x78
    wrap_i64 = 0xa7
    trunc_f32_s = 0xa8
    trunc_f32_u = 0xa9
    trunc_f64_s = 0xaa
    trunc_f64_u = 0xab
    reinterpret_f32 = 0xbc
    load = 0x28
    load8_s = 0x2c
    load8_u = 0x2d
    load16_s = 0x2e
    load16_u = 0x2f
    store = 0x36
    store8 = 0x3a
    store16 = 0x3b
    const = 0x41


class i64(HexEnum, encoding=0x7e):
    eqz = 0x50
    eq = 0x51
    ne = 0x52
    lt_s = 0x53
    lt_u = 0x54
    gt_s = 0x55
    gt_u = 0x56
    le_s = 0x57
    le_u = 0x58
    ge_s = 0x59
    ge_u = 0x5a
    clz = 0x79
    ctz = 0x7a
    popcnt = 0x7b
    add = 0x7c
    sub = 0x7d
    mul = 0x7e
    div_s = 0x7f
    div_u = 0x80
    rem_s = 0x81
    rem_u = 0x82
    and_ = 0x83
    or_ = 0x84
    xor = 0x85
    shl = 0x86
    shr_s = 0x87
    shr_u = 0x88
    rotl = 0x89
    rotr = 0x8a
    extend_i32_s = 0xac
    extend_i32_u = 0xad
    trunc_f32_s = 0xae
    trunc_f32_u = 0xaf
    trunc_f64_s = 0xb0
    trunc_f64_u = 0xb1
    reinterpret_f64 = 0xbd
    load = 0x29
    load8_s = 0x30
    load8_u = 0x31
    load16_s = 0x32
    load16_u = 0x33
    load32_s = 0x34
    load32_u = 0x35
    store = 0x37
    store8 = 0x3c
    store16 = 0x3d
    store32 = 0x3e
    const = 0x42


class f32(HexEnum, encoding=0x7d):
    eq = 0x5b
    ne = 0x5c
    lt = 0x5d
    gt = 0x5e
    le = 0x5f
    ge = 0x60
    abs = 0x8b
    neg = 0x8c
    ceil = 0x8d
    floor = 0x8e
    trunc = 0x8f
    nearest = 0x90
    sqrt = 0x91
    add = 0x92
    sub = 0x93
    mul = 0x94
    div = 0x95
    min = 0x96
    max = 0x97
    copysign = 0x98
    convert_i32_s = 0xb2
    convert_i32_u = 0xb3
    convert_i64_s = 0xb4
    convert_i64_u = 0xb5
    demote_f64 = 0xb6
    reinterpret_i32 = 0xbe
    load = 0x2a
    store = 0x38
    const = 0x43


class f64(HexEnum, encoding=0x7c):
    eq = 0x61
    ne = 0x62
    lt = 0x63
    gt = 0x64
    le = 0x65
    ge = 0x66
    abs = 0x99
    neg = 0x9a
    ceil = 0x9b
    floor = 0x9c
    trunc = 0x9d
    nearest = 0x9e
    sqrt = 0x9f
    add = 0xa0
    sub = 0xa1
    mul = 0xa2
    div = 0xa3
    min = 0xa4
    max = 0xa5
    copysign = 0xa6
    convert_i32_s = 0xb7
    convert_i32_u = 0xb8
    convert_i64_s = 0xb9
    convert_i64_u = 0xba
    promote_f32 = 0xbb
    reinterpret_i64 = 0xbf
    load = 0x2b
    store = 0x39
    const = 0x44


class local(HexEnum):
    get = 0x20
    set = 0x21
    tee = 0x22


class global_(HexEnum):
    get = 0x23
    set = 0x24


global_.__name__ = 'global'

# Special void type for block returns
void = 0x40

# ------------------------------------------------------------


def encode_function_type(parameters, results):
    '''
    parameters is a vector of value types
    results is a vector of value types
    '''
    enc_parms = bytes(parameters)
    enc_results = bytes(results)
    return b'\x60' + encode_vector(enc_parms) + encode_vector(enc_results)


def encode_limits(min, max=None):
    '''
    Encode limits: flag 0x00 + min, or flag 0x01 + min + max.
    '''
    if max is None:
        return b'\x00' + encode_unsigned(min)
    return b'\x01' + encode_unsigned(min) + encode_unsigned(max)


def encode_table_type(elemtype, min, max=None):
    '''
    Encode a table type. elemtype is accepted for interface symmetry but
    ignored; the element type byte is always 0x70.
    '''
    return b'\x70' + encode_limits(min, max)


def encode_global_type(value_type, mut=True):
    '''
    Encode a global type: the value type code followed by a 0/1
    mutability flag.
    '''
    return bytes([int(value_type), 1 if mut else 0])


# ----------------------------------------------------------------------
# Instruction builders
#
# Wasm instructions are grouped into different namespaces. For example:
#
#     i32.add()
#     local.get()
#     memory.size()
#     ...
#
# The classes that follow implement the namespace for different instruction
# categories.

class SubBuilder:
    '''
    Base class for a namespace of instructions attached to a parent
    InstructionBuilder. Instructions are appended to the parent's code.
    '''

    def __init__(self, builder):
        self._builder = builder

    def _append(self, instr):
        self._builder._code.append(instr)


# Builder for the local.* namespace
class LocalBuilder(SubBuilder):
    def get(self, localidx):
        self._append([local.get, *encode_unsigned(localidx)])

    def set(self, localidx):
        self._append([local.set, *encode_unsigned(localidx)])

    def tee(self, localidx):
        self._append([local.tee, *encode_unsigned(localidx)])


# Builder for the global.* namespace. Accepts an index or an object with
# an .idx attribute (Global / ImportGlobal).
class GlobalBuilder(SubBuilder):
    @staticmethod
    def _index_of(glob):
        return glob if isinstance(glob, int) else glob.idx

    def get(self, glob):
        self._append([global_.get, *encode_unsigned(self._index_of(glob))])

    def set(self, glob):
        self._append([global_.set, *encode_unsigned(self._index_of(glob))])


# Builder for the memory.* namespace
class MemoryBuilder(SubBuilder):
    def size(self):
        self._append([0x3f, 0x00])

    def grow(self):
        self._append([0x40, 0x00])


class OpBuilder(SubBuilder):
    '''
    Builder for a numeric type namespace (i32.*, i64.*, f32.*, f64.*).

    Memory instructions take (align, offset) immediates and are defined
    explicitly; every other opcode in the opcode table is exposed as a
    zero-argument method through __getattr__.
    '''
    _optable = None     # To be supplied by subclasses

    def _memop(self, name, align, offset):
        # Raises AttributeError if this type has no such memory opcode
        # (for example i32.load32_s).
        opcode = getattr(self._optable, name)
        self._append([opcode, *encode_unsigned(align), *encode_unsigned(offset)])

    # Memory ops
    def load(self, align, offset):
        self._memop('load', align, offset)

    def load8_s(self, align, offset):
        self._memop('load8_s', align, offset)

    def load8_u(self, align, offset):
        self._memop('load8_u', align, offset)

    def load16_s(self, align, offset):
        self._memop('load16_s', align, offset)

    def load16_u(self, align, offset):
        self._memop('load16_u', align, offset)

    def load32_s(self, align, offset):
        self._memop('load32_s', align, offset)

    def load32_u(self, align, offset):
        self._memop('load32_u', align, offset)

    def store(self, align, offset):
        self._memop('store', align, offset)

    def store8(self, align, offset):
        self._memop('store8', align, offset)

    def store16(self, align, offset):
        self._memop('store16', align, offset)

    def store32(self, align, offset):
        self._memop('store32', align, offset)

    def __getattr__(self, key):
        # Never synthesize private/dunder attributes; this keeps copy,
        # pickle and hasattr() probes well behaved.
        if key.startswith('_'):
            raise AttributeError(key)
        # Resolve the opcode now so that a misspelled instruction fails
        # at attribute access rather than when the result is called.
        opcode = getattr(self._optable, key)
        if not isinstance(opcode, self._optable):
            raise AttributeError(key)

        def call():
            self._append([opcode])
        return call


class I32OpBuilder(OpBuilder):
    _optable = i32

    def const(self, value):
        self._append([self._optable.const, *encode_signed(value)])


class I64OpBuilder(OpBuilder):
    _optable = i64

    def const(self, value):
        self._append([self._optable.const, *encode_signed(value)])


class F32OpBuilder(OpBuilder):
    _optable = f32

    def const(self, value):
        self._append([self._optable.const, *encode_float32(value)])


class F64OpBuilder(OpBuilder):
    _optable = f64

    def const(self, value):
        self._append([self._optable.const, *encode_float64(value)])


def _flatten(instr):
    '''
    Yield the items of an arbitrarily nested list of instructions.
    '''
    for x in instr:
        if isinstance(x, list):
            yield from _flatten(x)
        else:
            yield x


# High-level class that allows instructions to be easily encoded.
class InstructionBuilder:
    def __init__(self):
        self._code = []
        self.local = LocalBuilder(self)
        self.global_ = GlobalBuilder(self)
        self.memory = MemoryBuilder(self)
        self.i32 = I32OpBuilder(self)
        self.i64 = I64OpBuilder(self)
        self.f32 = F32OpBuilder(self)
        self.f64 = F64OpBuilder(self)

        # Control-flow stack. The initial entry is the implicit outermost
        # block, which is closed by the final block_end().
        self._control = [None]

    def __iter__(self):
        return iter(self._code)

    # Resolve a human-readable label into control-stack index
    def _resolve_label(self, label):
        '''
        Return the relative branch depth for label (0 = innermost block).

        An int is taken to already be a depth and returned unchanged.
        Raises ValueError if the label is not on the control stack.
        '''
        if isinstance(label, int):
            return label
        # Innermost-first so that a re-used label binds to the nearest
        # enclosing block.
        for depth, name in enumerate(reversed(self._control)):
            if name == label:
                return depth
        raise ValueError(f'unknown label {label!r}')

    # Control flow instructions
    def unreachable(self):
        self._code.append([0x00])

    def nop(self):
        self._code.append([0x01])

    def block_start(self, result_type, label=None):
        self._code.append([0x02, result_type])
        self._control.append(label)
        return len(self._control)

    def block_end(self):
        self._code.append([0x0b])
        self._control.pop()

    def loop_start(self, result_type, label=None):
        self._code.append([0x03, result_type])
        self._control.append(label)
        return len(self._control)

    def if_start(self, result_type, label=None):
        self._code.append([0x04, result_type])
        self._control.append(label)
        return len(self._control)

    def else_start(self):
        self._code.append([0x05])

    def br(self, label):
        labelidx = self._resolve_label(label)
        self._code.append([0x0c, *encode_unsigned(labelidx)])

    def br_if(self, label):
        labelidx = self._resolve_label(label)
        self._code.append([0x0d, *encode_unsigned(labelidx)])

    def br_table(self, labels, label):
        enc_labels = [encode_unsigned(self._resolve_label(idx)) for idx in labels]
        self._code.append([0x0e, *encode_vector(enc_labels),
                           *encode_unsigned(self._resolve_label(label))])

    def return_(self):
        self._code.append([0x0f])

    def call(self, func):
        # func is a Function/ImportFunction object or a raw function index.
        if isinstance(func, (ImportFunction, Function)):
            self._code.append([0x10, *encode_unsigned(func._idx)])
        else:
            self._code.append([0x10, *encode_unsigned(func)])

    def call_indirect(self, typesig):
        if isinstance(typesig, Type):
            typeidx = typesig.idx
        else:
            typeidx = typesig
        self._code.append([0x11, *encode_unsigned(typeidx), 0x00])

    def drop(self):
        self._code.append([0x1a])

    def select(self):
        self._code.append([0x1b])


class Type:
    def __init__(self, parms, results, idx):
        self.parms = parms
        self.results = results
        self.idx = idx

    def __repr__(self):
        return f'{self.parms!r} -> {self.results!r}'


class ImportFunction:
    def __init__(self, name, typesig, idx):
        self._name = name
        self._typesig = typesig
        self._idx = idx

    def __repr__(self):
        return f'ImportFunction({self._name}, {self._typesig}, {self._idx})'


class Function(InstructionBuilder):
    def __init__(self, name, typesig, idx, export=True):
        super().__init__()
        self._name = name
        self._typesig = typesig
        # Local index space: parameters first, then allocated locals.
        self._locals = list(typesig.parms)
        self._export = export
        self._idx = idx

    def __repr__(self):
        return f'Function({self._name}, {self._typesig}, {self._idx})'

    # Allocate a new local variable of a given type
    def alloc(self, valuetype):
        '''
        Add a local of the given value type and return its local index.
        '''
        self._locals.append(valuetype)
        return len(self._locals) - 1


class ImportGlobal:
    def __init__(self, name, valtype, idx):
        self.name = name
        self.valtype = valtype
        self.idx = idx

    def __repr__(self):
        return f'ImportGlobal({self.name}, {self.valtype}, {self.idx})'


class Global:
    def __init__(self, name, valtype, initializer, idx):
        self.name = name
        self.valtype = valtype
        self.initializer = initializer
        self.idx = idx

    def __repr__(self):
        return f'Global({self.name}, {self.valtype}, {self.initializer}, {self.idx})'


class Module:
    def __init__(self):
        # Vector of function type signatures. Signatures are reused
        # if more than one function has the same signature.
        self.type_section = []

        # Vector of imported entities. These can be functions, globals,
        # tables, and memories
        self.import_section = []

        # There are 4 basic entities within a Wasm file. Functions,
        # globals, memories, and tables. Each kind of entity is
        # stored in a separate list and is indexed by an integer
        # index starting at 0. Imported entities must always
        # go before entities defined in the Wasm module itself.
        self.funcidx = 0
        self.globalidx = 0
        self.memoryidx = 0
        self.tableidx = 0

        self.function_section = []      # Vector of typeidx
        self.global_section = []        # Vector of globals
        self.table_section = []         # Vector of tables
        self.memory_section = []        # Vector of memories

        # Exported entities. A module may export functions, globals,
        # tables, and memories
        self.export_section = []        # Vector of exports

        # Optional start function. A function that executes upon loading
        self.start_section = None       # Encoded funcidx, or None

        # Initialization of table elements
        self.element_section = []

        # Code section for function bodies.
        self.code_section = []

        # Data section contains data segments
        self.data_section = []

        # List of function objects (to help with encoding)
        self.functions = []

        # Number of entries of self.functions whose code/export entries
        # have already been emitted by encode().
        self._encoded_functions = 0

        # Output for JS/Html
        self.js_exports = ""
        self.html_exports = ""
        self.js_imports = defaultdict(dict)

    def add_type(self, parms, results):
        '''
        Register a function signature (re-using an identical existing
        one) and return a Type carrying its type index.
        '''
        enc = encode_function_type(parms, results)
        if enc in self.type_section:
            return Type(parms, results, self.type_section.index(enc))
        self.type_section.append(enc)
        return Type(parms, results, len(self.type_section) - 1)

    def import_function(self, module, name, parms, results):
        if len(self.function_section) > 0:
            raise RuntimeError('function imports must go before first function definition')

        typesig = self.add_type(parms, results)
        code = encode_name(module) + encode_name(name) + b'\x00' + encode_unsigned(typesig.idx)
        self.import_section.append(code)
        self.js_imports[module][name] = f"function: {typesig}"
        self.funcidx += 1
        return ImportFunction(f'{module}.{name}', typesig, self.funcidx - 1)

    def import_table(self, module, name, elemtype, min, max=None):
        code = encode_name(module) + encode_name(name) + b'\x01' + encode_table_type(elemtype, min, max)
        self.import_section.append(code)
        self.js_imports[module][name] = "table:"
        self.tableidx += 1
        return self.tableidx - 1

    def import_memtype(self, module, name, min, max=None):
        code = encode_name(module) + encode_name(name) + b'\x02' + encode_limits(min, max)
        self.import_section.append(code)
        self.js_imports[module][name] = "memory:"
        self.memoryidx += 1
        return self.memoryidx - 1

    def import_global(self, module, name, value_type):
        if len(self.global_section) > 0:
            raise RuntimeError('global imports must go before first global definition')

        code = encode_name(module) + encode_name(name) + b'\x03' + encode_global_type(value_type, False)
        self.import_section.append(code)
        self.js_imports[module][name] = f"global: {value_type}"
        self.globalidx += 1
        return ImportGlobal(f'{module}.{name}', value_type, self.globalidx - 1)

    def add_function(self, name, parms, results, export=True):
        '''
        Define a new function and return the Function builder for its
        body. The body and (if export is true) the export entry are
        emitted later by encode().
        '''
        typesig = self.add_type(parms, results)
        func = Function(name, typesig, self.funcidx, export)
        self.funcidx += 1
        self.functions.append(func)
        self.function_section.append(encode_unsigned(typesig.idx))
        signature = ", ".join(str(p) for p in parms)
        # A function may have no result at all.
        returns = ", ".join(str(r) for r in results) or "void"
        # NOTE(review): the HTML tags of this string were lost from the
        # source; <p><tt> is a reconstruction - confirm the markup.
        self.html_exports += f'<p><tt>{name}({signature}) -> {returns}</tt></p>\n'
        return func

    def add_table(self, elemtype, min, max=None):
        self.table_section.append(encode_table_type(elemtype, min, max))
        self.tableidx += 1
        return self.tableidx - 1

    def add_memory(self, min, max=None):
        self.memory_section.append(encode_limits(min, max))
        self.memoryidx += 1
        return self.memoryidx - 1

    def add_global(self, name, value_type, initializer, mutable=True, export=True):
        '''
        Define a global initialized by `<value_type>.const initializer`
        and return a Global describing it.
        '''
        code = encode_global_type(value_type, mutable)
        expr = InstructionBuilder()
        # str(value_type) is the namespace name ('i32', 'f64', ...).
        getattr(expr, str(value_type)).const(initializer)
        expr.block_end()            # Terminate the constant expression
        code += bytes(_flatten(expr))
        self.global_section.append(code)
        if export:
            self.export_global(name, self.globalidx)
        self.globalidx += 1
        return Global(name, value_type, initializer, self.globalidx - 1)

    def export_function(self, name, funcidx):
        code = encode_name(name) + b'\x00' + encode_unsigned(funcidx)
        self.export_section.append(code)
        self.js_exports += f'window.{name} = results.instance.exports.{name};\n'

    def export_table(self, name, tableidx):
        code = encode_name(name) + b'\x01' + encode_unsigned(tableidx)
        self.export_section.append(code)

    def export_memory(self, name, memidx):
        code = encode_name(name) + b'\x02' + encode_unsigned(memidx)
        self.export_section.append(code)

    def export_global(self, name, globalidx):
        code = encode_name(name) + b'\x03' + encode_unsigned(globalidx)
        self.export_section.append(code)

    def start_function(self, funcidx):
        '''
        Set the module's start function. funcidx is a function index or a
        Function/ImportFunction object.
        '''
        if isinstance(funcidx, (ImportFunction, Function)):
            funcidx = funcidx._idx
        self.start_section = encode_unsigned(funcidx)

    def add_element(self, tableidx, expr, funcidxs):
        # expr is an iterable of instructions (e.g. an InstructionBuilder)
        # giving the offset; it must already include its terminating end
        # opcode (block_end()).
        code = encode_unsigned(tableidx) + bytes(_flatten(expr))
        code += encode_vector([encode_unsigned(i) for i in funcidxs])
        self.element_section.append(code)

    def add_function_code(self, locals, expr):
        # Locals is a list of valtypes [i32, i32, etc...]
        # expr is an expression representing the actual code (InstructionBuilder)
        # Each local is emitted as its own group of count 1.
        locs = [encode_unsigned(1) + bytes([loc]) for loc in locals]
        locs_code = encode_vector(locs)
        func_code = locs_code + bytes(_flatten(expr))
        code = encode_unsigned(len(func_code)) + func_code
        self.code_section.append(code)

    def add_data(self, memidx, expr, data):
        # data is bytes. expr is an iterable of instructions giving the
        # offset; it must already include its terminating end opcode.
        code = encode_unsigned(memidx) + bytes(_flatten(expr)) + encode_vector(bytes(data))
        self.data_section.append(code)

    def _encode_section_vector(self, sectionid, contents):
        '''
        Encode a section whose payload is a vector of entries. An empty
        section is omitted entirely.
        '''
        if not contents:
            return b''
        contents_code = encode_vector(contents)
        return bytes([sectionid]) + encode_unsigned(len(contents_code)) + contents_code

    def encode(self):
        '''
        Return the module in Wasm binary format as bytes.

        Function bodies and function exports are emitted here. Functions
        already emitted by an earlier call are skipped, so encode() may
        be called more than once.
        '''
        for func in self.functions[self._encoded_functions:]:
            # Parameters are part of the local index space but must not
            # be re-declared in the body; declare only allocated locals.
            declared = func._locals[len(func._typesig.parms):]
            self.add_function_code(declared, func._code)
            if func._export:
                self.export_function(func._name, func._idx)
        self._encoded_functions = len(self.functions)

        # Encode the whole module
        code = b'\x00\x61\x73\x6d\x01\x00\x00\x00'
        code += self._encode_section_vector(1, self.type_section)
        code += self._encode_section_vector(2, self.import_section)
        code += self._encode_section_vector(3, self.function_section)
        code += self._encode_section_vector(4, self.table_section)
        code += self._encode_section_vector(5, self.memory_section)
        code += self._encode_section_vector(6, self.global_section)
        code += self._encode_section_vector(7, self.export_section)
        if self.start_section:
            # Section id, payload size, payload (a single funcidx).
            code += bytes([8]) + encode_unsigned(len(self.start_section)) + self.start_section
        code += self._encode_section_vector(9, self.element_section)
        code += self._encode_section_vector(10, self.code_section)
        code += self._encode_section_vector(11, self.data_section)
        return code

    def write_wasm(self, modname):
        with open(f'{modname}.wasm', 'wb') as f:
            f.write(self.encode())

    def write_html(self, modname):
        # js_exports is filled in by encode(); call write_wasm()/encode()
        # first or the page will not bind any function exports.
        with open(f'{modname}.html', 'wt', encoding='utf-8') as f:
            f.write(js_template.format(
                module=modname,
                imports=json.dumps(self.js_imports, indent=4),
                exports=self.js_exports,
                exports_html=self.html_exports,
            ))


# NOTE(review): the markup and script of this template were lost from the
# source (only the two text lines and {exports_html} survived). The rest is
# a reconstruction based on the keys passed by write_html - confirm.
js_template = '''
<html>
<body>
<script>
var imports = {imports};

fetch("{module}.wasm").then(response =>
  response.arrayBuffer()
).then(bytes =>
  WebAssembly.instantiate(bytes, imports)
).then(results => {{
{exports}
}});
</script>

<h3>module {module}</h3>

<p>
The following exports are made. Access from the JS console.
</p>

{exports_html}
</body>
</html>
'''


def test1():
    mod = Module()

    # An external function import. Note: All imports MUST go first.
    # Indexing affects function indexing for functions defined in the module.

    # Import some functions from JS
    # math_sin = mod.import_function('util', 'sin', [f64], [f64])
    # math_cos = mod.import_function('util', 'cos', [f64], [f64])

    # Import a function from another module entirely
    # fact = mod.import_function('recurse', 'fact', [i32], [i32])

    # Import a global variable (from JS?)
    # FOO = mod.import_global('util', 'FOO', f64)

    # A more complicated function
    dsquared_func = mod.add_function('dsquared', [f64, f64], [f64])
    dsquared_func.local.get(0)
    dsquared_func.local.get(0)
    dsquared_func.f64.mul()
    dsquared_func.local.get(1)
    dsquared_func.local.get(1)
    dsquared_func.f64.mul()
    dsquared_func.f64.add()
    dsquared_func.block_end()

    # A function calling another function
    distance = mod.add_function('distance', [f64, f64], [f64])
    distance.local.get(0)
    distance.local.get(1)
    distance.call(dsquared_func)
    distance.f64.sqrt()
    distance.block_end()

    # A function calling out to JS
    # ext = mod.add_function('ext', [f64, f64], [f64])
    # ext.local.get(0)
    # ext.call(math_sin)
    # ext.local.get(1)
    # ext.call(math_cos)
    # ext.f64.add()
    # ext.block_end()

    # A function calling across modules
    # tenf = mod.add_function('tenfact', [i32], [i32])
    # tenf.local.get(0)
    # tenf.call(fact)
    # tenf.i32.const(10)
    # tenf.i32.mul()
    # tenf.block_end()

    # A function accessing an imported global variable
    # gf = mod.add_function('gf', [f64], [f64])
    # gf.global_.get(FOO)
    # gf.local.get(0)
    # gf.f64.mul()
    # gf.block_end()

    # Memory
    mod.add_memory(1)
    mod.export_memory('memory', 0)

    # Function that returns a byte value
    getval = mod.add_function('getval', [i32], [i32])
    getval.local.get(0)
    getval.i32.load8_u(0, 0)
    getval.block_end()

    # Function that sets a byte value
    setval = mod.add_function('setval', [i32, i32], [i32])
    setval.local.get(0)         # Memory address
    setval.local.get(1)         # value
    setval.i32.store8(0, 0)
    setval.i32.const(1)
    setval.block_end()
    return mod


def test2():
    mod = Module()

    fact = mod.add_function('fact', [i32], [i32])
    fact.local.get(0)
    fact.i32.const(1)
    fact.i32.lt_s()
    fact.if_start(i32)
    fact.i32.const(1)
    fact.else_start()
    fact.local.get(0)
    fact.local.get(0)
    fact.i32.const(1)
    fact.i32.sub()
    fact.call(fact)
    fact.i32.mul()
    fact.block_end()
    fact.block_end()
    return mod


if __name__ == '__main__':
    mod = test1()
    mod.write_wasm('test')
    mod.write_html('test')