From a728a23adf37e6e37d8f388888b963c4a7d9d469 Mon Sep 17 00:00:00 2001 From: David Beazley Date: Sun, 17 Feb 2019 19:48:18 -0600 Subject: [PATCH] Added Wasm example --- example/wasm/expr.py | 245 ++++++++++++ example/wasm/test.e | 25 ++ example/wasm/wasm.py | 929 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 1199 insertions(+) create mode 100644 example/wasm/expr.py create mode 100644 example/wasm/test.e create mode 100644 example/wasm/wasm.py diff --git a/example/wasm/expr.py b/example/wasm/expr.py new file mode 100644 index 0000000..c235993 --- /dev/null +++ b/example/wasm/expr.py @@ -0,0 +1,245 @@ +# ----------------------------------------------------------------------------- +# expr.py +# +# Proof-of-concept encoding of functions/expressions into Wasm. +# +# This file implements a mini-language for writing Wasm functions as expressions. +# It only supports integers. +# +# Here are a few examples: +# +# # Some basic function definitions +# add(x, y) = x + y; +# mul(x, y) = x * y; +# dsquare(x, y) = mul(x, x) + mul(y, y); +# +# # A recursive function +# fact(n) = if n < 1 then 1 else n*fact(n-1); +# +# The full grammar: +# +# functions : functions function +# | function +# +# function : NAME ( parms ) = expr ; +# +# expr : expr + expr +# | expr - expr +# | expr * expr +# | expr / expr +# | expr < expr +# | expr <= expr +# | expr > expr +# | expr >= expr +# | expr == expr +# | expr != expr +# | ( expr ) +# | NAME (exprs) +# | if expr then expr else expr +# | NUMBER +# +# Note: This is implemented as a one-pass compiler with no intermediate AST. +# Some of the grammar rules have to be written in a funny way to make this +# work. If doing this for real, I'd probably build an AST and construct +# Wasm code through AST walking. 
class ExprParser(Parser):
    '''
    One-pass compiler from the expression mini-language to Wasm.

    No AST is built.  Each grammar action emits instructions into the
    function currently under construction (self.function) at the moment
    its rule is reduced.  Because operands are reduced before the operator
    that joins them, the emitted code is already in Wasm stack order.

    Attributes set while parsing:
        functions -- maps a function name to its wasm function object
        module    -- the wasm.Module receiving all generated functions
        function  -- the function currently being compiled (or None)
        locals    -- maps a parameter name to its local index
    '''
    tokens = ExprLexer.tokens

    precedence = (
        ('left', IF, ELSE),
        ('left', EQ, NE, LT, LE, GT, GE),
        ('left', PLUS, MINUS),
        ('left', TIMES, DIVIDE),
        ('right', UMINUS)
    )

    def __init__(self):
        self.functions = { }
        self.module = wasm.Module()

    @_('functions function')
    def functions(self, p):
        pass

    @_('function')
    def functions(self, p):
        pass

    @_('function_decl ASSIGN expr SEMI')
    def function(self, p):
        # The body expression has been fully emitted; close the function.
        self.function.block_end()
        self.function = None

    @_('NAME LPAREN parms RPAREN')
    def function_decl(self, p):
        # Parameters become locals 0..n-1, in declaration order.
        self.locals = { name:n for n, name in enumerate(p.parms) }
        self.function = self.module.add_function(p.NAME, [wasm.i32]*len(p.parms), [wasm.i32])
        # Registered before the body is parsed so the function can call itself.
        self.functions[p.NAME] = self.function

    @_('NAME LPAREN RPAREN')
    def function_decl(self, p):
        self.locals = { }
        self.function = self.module.add_function(p.NAME, [], [wasm.i32])
        self.functions[p.NAME] = self.function

    @_('parms COMMA parm')
    def parms(self, p):
        return p.parms + [p.parm]

    @_('parm')
    def parms(self, p):
        return [ p.parm ]

    @_('NAME')
    def parm(self, p):
        return p.NAME

    # Binary operators: both operands are already on the Wasm stack when
    # the rule is reduced, so only the operator itself is emitted.

    @_('expr PLUS expr')
    def expr(self, p):
        self.function.i32.add()

    @_('expr MINUS expr')
    def expr(self, p):
        self.function.i32.sub()

    @_('expr TIMES expr')
    def expr(self, p):
        self.function.i32.mul()

    @_('expr DIVIDE expr')
    def expr(self, p):
        self.function.i32.div_s()

    @_('expr LT expr')
    def expr(self, p):
        self.function.i32.lt_s()

    @_('expr LE expr')
    def expr(self, p):
        self.function.i32.le_s()

    @_('expr GT expr')
    def expr(self, p):
        self.function.i32.gt_s()

    @_('expr GE expr')
    def expr(self, p):
        self.function.i32.ge_s()

    @_('expr EQ expr')
    def expr(self, p):
        self.function.i32.eq()

    @_('expr NE expr')
    def expr(self, p):
        self.function.i32.ne()

    @_('MINUS expr %prec UMINUS')
    def expr(self, p):
        # The operand is already on the stack.  In a one-pass scheme there
        # is no chance to push a 0 *before* it, so negate by multiplying
        # by -1 instead of computing 0 - x.
        self.function.i32.const(-1)
        self.function.i32.mul()

    @_('LPAREN expr RPAREN')
    def expr(self, p):
        pass

    @_('NUMBER')
    def expr(self, p):
        self.function.i32.const(int(p.NUMBER))

    @_('NAME')
    def expr(self, p):
        self.function.local.get(self.locals[p.NAME])

    @_('NAME LPAREN exprlist RPAREN')
    def expr(self, p):
        # Arguments were emitted while reducing exprlist.
        self.function.call(self.functions[p.NAME])

    @_('NAME LPAREN RPAREN')
    def expr(self, p):
        self.function.call(self.functions[p.NAME])

    # if/then/else is split over three rules so that the Wasm "if" and
    # "else" opcodes are emitted between the sub-expressions rather than
    # after all of them:
    #   startthen -> emits "if"   (right after the condition)
    #   thenexpr  -> emits "else" (right after the then-branch)
    #   expr      -> emits "end"  (right after the else-branch)

    @_('IF expr thenexpr ELSE expr')
    def expr(self, p):
        self.function.block_end()

    @_('exprlist COMMA expr')
    def exprlist(self, p):
        pass

    @_('expr')
    def exprlist(self, p):
        pass

    @_('startthen expr')
    def thenexpr(self, p):
        self.function.else_start()

    @_('THEN')
    def startthen(self, p):
        self.function.if_start(wasm.i32)
def encode_signed(value):
    '''
    Produce a LEB128 encoded signed integer.

    The value is emitted seven bits at a time, least significant group
    first.  Every byte except the last has its high bit (0x80) set.
    Encoding stops once the remaining bits are pure sign extension AND
    bit 6 (0x40) of the group just produced already carries that sign,
    since a decoder sign-extends from bit 6 of the final byte.

    Unlike the unsigned form, a positive value whose last group has bit 6
    set (64, 127, 8192, ...) therefore needs one extra byte; reusing the
    unsigned encoding for such values would decode as a negative number.

    value  : Python int (any magnitude)
    returns: bytes
    '''
    out = bytearray()
    while True:
        group = value & 0x7f
        value >>= 7            # arithmetic shift: negatives converge to -1
        sign_bit = group & 0x40
        if (value == 0 and not sign_bit) or (value == -1 and sign_bit):
            out.append(group)
            return bytes(out)
        out.append(group | 0x80)
def encode_global_type(value_type, mut=True):
    '''
    Encode a global type as two bytes: the value type followed by the
    mutability flag (1 when mut is true, 0 otherwise).
    '''
    mutability = mut
    encoded = bytes((value_type, mutability))
    return encoded
class GlobalBuilder(SubBuilder):
    '''
    Builder for the global.* instruction namespace.

    Both methods accept either a raw integer global index or an object
    exposing the index as an .idx attribute.
    '''
    def get(self, glob):
        index = glob if isinstance(glob, int) else glob.idx
        instruction = [global_.get, *encode_unsigned(index)]
        self._append(instruction)

    def set(self, glob):
        index = glob if isinstance(glob, int) else glob.idx
        instruction = [global_.set, *encode_unsigned(index)]
        self._append(instruction)
# High-level class that allows instructions to be easily encoded.
class InstructionBuilder:
    '''
    Accumulates a sequence of Wasm instructions.

    Each emitted instruction is appended to self._code as a list whose
    first item is the opcode and whose remaining items are immediate
    bytes.  Typed and namespaced instructions are reached through the
    sub-builders (self.i32, self.local, ...); control-flow instructions
    are methods of this class.

    self._control mirrors the nesting of open blocks.  Entry 0 stands for
    the function body itself; every block/loop/if pushes its (optional)
    label and block_end() pops it.
    '''
    def __init__(self):
        self._code = [ ]
        self.local = LocalBuilder(self)
        self.global_ = GlobalBuilder(self)
        self.i32 = I32OpBuilder(self)
        self.i64 = I64OpBuilder(self)
        self.f32 = F32OpBuilder(self)
        self.f64 = F64OpBuilder(self)

        # Control-flow stack.
        self._control = [ None ]

    def __iter__(self):
        return iter(self._code)

    # Resolve a human-readable label into control-stack index
    def _resolve_label(self, label):
        '''
        Turn a label into a relative branch depth.

        Integers are taken to be depths already and returned unchanged.
        Anything else is looked up on the control stack; the innermost
        open block has depth 0.  Raises ValueError for an unknown label.
        '''
        if isinstance(label, int):
            return label
        index = self._control.index(label)
        # Depth is measured from the top of the control stack, not from
        # the length of the label itself.
        return len(self._control) - 1 - index

    # Control flow instructions
    def unreachable(self):
        # Opcode 0x00 (0x01 is nop)
        self._code.append([0x00])

    def nop(self):
        self._code.append([0x01])

    def block_start(self, result_type, label=None):
        self._code.append([0x02, result_type])
        self._control.append(label)
        return len(self._control)

    def block_end(self):
        # Also used to terminate if/else constructs and function bodies.
        self._code.append([0x0b])
        self._control.pop()

    def loop_start(self, result_type, label=None):
        self._code.append([0x03, result_type])
        self._control.append(label)
        return len(self._control)

    def if_start(self, result_type, label=None):
        self._code.append([0x04, result_type])
        self._control.append(label)

    def else_start(self):
        # Stays inside the block opened by if_start; no stack change.
        self._code.append([0x05])

    def br(self, label):
        labelidx = self._resolve_label(label)
        self._code.append([0x0c, *encode_unsigned(labelidx)])

    def br_if(self, label):
        labelidx = self._resolve_label(label)
        self._code.append([0x0d, *encode_unsigned(labelidx)])

    def br_table(self, labels, label):
        # labels: the jump table entries; label: the default target
        enc_labels = [encode_unsigned(self._resolve_label(idx)) for idx in labels]
        self._code.append([0x0e, *encode_vector(enc_labels), *encode_unsigned(self._resolve_label(label))])

    def return_(self):
        self._code.append([0x0f])

    def call(self, func):
        # Accepts a function object (its _idx is used) or a raw index.
        if isinstance(func, (ImportFunction,Function)):
            self._code.append([0x10, *encode_unsigned(func._idx)])
        else:
            self._code.append([0x10, *encode_unsigned(func)])

    def call_indirect(self, typesig):
        # Accepts a Type object (its idx is used) or a raw type index.
        if isinstance(typesig, Type):
            typeidx = typesig.idx
        else:
            typeidx = typesig
        self._code.append([0x11, *encode_unsigned(typeidx), 0x00])

    def drop(self):
        self._code.append([0x1a])

    def select(self):
        self._code.append([0x1b])
class Function(InstructionBuilder):
    '''
    A function defined inside the module.

    Instructions for the body are emitted through the inherited
    InstructionBuilder interface.

    _name    -- function name (used when exporting)
    _typesig -- Type object describing parameters and results
    _locals  -- value types of all locals; the parameters come first,
                followed by anything added with alloc()
    _export  -- whether the function should be exported
    _idx     -- index of the function within the module
    '''
    def __init__(self, name, typesig, idx, export=True):
        super().__init__()
        self._name = name
        self._typesig = typesig
        self._locals = list(typesig.parms)
        self._export = export
        self._idx = idx

    def __repr__(self):
        return f'Function({self._name}, {self._typesig}, {self._idx})'

    # Allocate a new local variable of a given type
    def alloc(self, valuetype):
        '''
        Add a local of the given value type and return its local index.
        '''
        self._locals.append(valuetype)
        # The list is stored as _locals; there is no "locals" attribute.
        return len(self._locals) - 1
def add_type(self, parms, results):
    '''
    Register a function signature and return a Type describing it.

    Encoded signatures are stored once in self.type_section; asking for
    a signature that is already present yields its existing index.
    '''
    signature = encode_function_type(parms, results)
    try:
        typeidx = self.type_section.index(signature)
    except ValueError:
        # Not seen before: it takes the next free slot.
        typeidx = len(self.type_section)
        self.type_section.append(signature)
    return Type(parms, results, typeidx)
import_memtype(self, module, name, min, max=None): + code = encode_name(module) + encode_name(name) + b'\x02' + encode_limits(min, max) + self.import_section.append(code) + self.js_imports[module][name] = "memory:" + self.memoryidx += 1 + return self.memoryidx - 1 + + def import_global(self, module, name, value_type): + if len(self.global_section) > 0: + raise RuntimeError('global imports must go before first global definition') + + code = encode_name(module) + encode_name(name) + b'\x03' + encode_global_type(value_type, False) + self.import_section.append(code) + self.js_imports[module][name] = f"global: {value_type}" + self.globalidx += 1 + return ImportGlobal(f'{module}.{name}', value_type, self.globalidx - 1) + + def add_function(self, name, parms, results, export=True): + typesig = self.add_type(parms, results) + func = Function(name, typesig, self.funcidx, export) + self.funcidx += 1 + self.functions.append(func) + self.function_section.append(encode_unsigned(typesig.idx)) + self.html_exports += f'

{name}({", ".join(str(p) for p in parms)}) -> {results[0]!s}

def add_global(self, name, value_type, initializer, mutable=True, export=True):
    '''
    Define a global variable in the module and return a Global for it.

    name        -- name used for the export entry (when export is true)
    value_type  -- value type of the global
    initializer -- constant used to initialise it
    mutable     -- whether the global may be modified
    export      -- whether an export entry is added as well
    '''
    code = encode_global_type(value_type, mutable)

    # The initializer is a constant expression: "<type>.const value"
    # terminated by the end opcode, which block_end() emits.
    # NOTE(review): assumes str(value_type) names the matching builder
    # attribute ('i32', 'f64', ...) -- confirm against the HexEnum metaclass.
    expr = InstructionBuilder()
    getattr(expr, str(value_type)).const(initializer)
    expr.block_end()

    # Instructions are nested lists; flatten them to raw bytes the same
    # way function bodies are encoded.
    code += bytes(_flatten(expr))
    self.global_section.append(code)
    if export:
        self.export_global(name, self.globalidx)
    self.globalidx += 1
    return Global(name, value_type, initializer, self.globalidx - 1)
def encode(self):
    '''
    Encode the whole module and return it as bytes.

    Function bodies and function exports are generated here, from
    self.functions, so this is meant to be called once per module:
    each call appends to the code and export sections again.
    '''
    for func in self.functions:
        # Function._locals starts with the parameter types.  Parameters
        # are implicitly locals 0..n-1 in Wasm, so only the entries after
        # them may be declared in the code section; declaring them again
        # would shift the types of every allocated local.
        nparms = len(func._typesig.parms)
        self.add_function_code(func._locals[nparms:], func._code)
        if func._export:
            self.export_function(func._name, func._idx)

    # Encode the whole module
    code = b'\x00\x61\x73\x6d\x01\x00\x00\x00'
    code += self._encode_section_vector(1, self.type_section)
    code += self._encode_section_vector(2, self.import_section)
    code += self._encode_section_vector(3, self.function_section)
    code += self._encode_section_vector(4, self.table_section)
    code += self._encode_section_vector(5, self.memory_section)
    code += self._encode_section_vector(6, self.global_section)
    code += self._encode_section_vector(7, self.export_section)
    if self.start_section is not None:
        # Like every section: id byte, byte length of the contents, then
        # the contents (here the already-encoded function index).
        # NOTE(review): start_function() assigns self.start, not
        # self.start_section, so this only fires if start_section was
        # set directly.
        code += bytes([8]) + encode_unsigned(len(self.start_section)) + self.start_section
    code += self._encode_section_vector(9, self.element_section)
    code += self._encode_section_vector(10, self.code_section)
    code += self._encode_section_vector(11, self.data_section)
    return code
exports_html=self.html_exports, + ) + ) + +js_template = ''' + + + + +

module {module}

+ +

+The following exports are made. Access from the JS console. +

+ +{exports_html} + + +''' + +def test1(): + mod = Module() + + # An external function import. Note: All imports MUST go first. + # Indexing affects function indexing for functions defined in the module. + + # Import some functions from JS + # math_sin = mod.import_function('util', 'sin', [f64], [f64]) + # math_cos = mod.import_function('util', 'cos', [f64], [f64]) + + # Import a function from another module entirely + # fact = mod.import_function('recurse', 'fact', [i32], [i32]) + + # Import a global variable (from JS?) + # FOO = mod.import_global('util', 'FOO', f64) + + # A more complicated function + dsquared_func = mod.add_function('dsquared', [f64, f64], [f64]) + dsquared_func.local.get(0) + dsquared_func.local.get(0) + dsquared_func.f64.mul() + dsquared_func.local.get(1) + dsquared_func.local.get(1) + dsquared_func.f64.mul() + dsquared_func.f64.add() + dsquared_func.block_end() + + # A function calling another function + distance = mod.add_function('distance', [f64, f64], [f64]) + distance.local.get(0) + distance.local.get(1) + distance.call(dsquared_func) + distance.f64.sqrt() + distance.block_end() + + # A function calling out to JS + # ext = mod.add_function('ext', [f64, f64], [f64]) + # ext.local.get(0) + # ext.call(math_sin) + # ext.local.get(1) + # ext.call(math_cos) + # ext.f64.add() + # ext.block_end() + + # A function calling across modules + # tenf = mod.add_function('tenfact', [i32], [i32]) + # tenf.local.get(0) + # tenf.call(fact) + # tenf.i32.const(10) + # tenf.i32.mul() + # tenf.block_end() + + # A function accessing an imported global variable + # gf = mod.add_function('gf', [f64], [f64]) + # gf.global_.get(FOO) + # gf.local.get(0) + # gf.f64.mul() + # gf.block_end() + + # Memory + mod.add_memory(1) + mod.export_memory('memory', 0) + + + # Function that returns a byte value + getval = mod.add_function('getval', [i32], [i32]) + getval.local.get(0) + getval.i32.load8_u(0, 0) + getval.block_end() + + # Function that sets a byte value + setval = 
mod.add_function('setval', [i32,i32], [i32]) + setval.local.get(0) # Memory address + setval.local.get(1) # value + setval.i32.store8(0,0) + setval.i32.const(1) + setval.block_end() + return mod + + +def test2(): + mod = Module() + + fact = mod.add_function('fact', [i32], [i32]) + fact.local.get(0) + fact.i32.const(1) + fact.i32.lt_s() + fact.if_start(i32) + fact.i32.const(1) + fact.else_start() + fact.local.get(0) + fact.local.get(0) + fact.i32.const(1) + fact.i32.sub() + fact.call(fact) + fact.i32.mul() + fact.block_end() + fact.block_end() + + return mod + +if __name__ == '__main__': + mod = test1() + + mod.write_wasm('test') + mod.write_html('test') + + + + + + + + + + + + + + + + + + + + + +