Building a command line tool to generate number sequences

By John Lekberg on June 13, 2020.


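This week's post is about building a command line tool to generate number sequences. You will learn how to use the ast module to parse Python text into an abstract syntax tree (AST), and how to safely turn that AST into a callable function.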

Script source code

#!/usr/bin/env python3

import ast
import functools
import itertools
import operator
import sys

# Exceptions


class ScriptError(Exception):
    pass


class InvalidPatternError(ScriptError):
    pass


class InvalidBinOpError(ScriptError):
    pass


class InvalidVariableError(ScriptError):
    pass


class InvalidNodeError(ScriptError):
    pass


# Allowed binary operations

valid_binop = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
    ast.FloorDiv: operator.floordiv,
}

# node_to_func


@functools.singledispatch
def node_to_func(node, var):
    """Convert an ast.AST node into a unary function
    of one variable.
    
    node -- ast.AST. The node to convert.
    var -- str. The name of the variable. E.g.
        "n", "x".
        
    """
    raise InvalidNodeError()


@node_to_func.register(ast.Expression)
def _(node, var):
    return node_to_func(node.body, var)


@node_to_func.register(ast.BinOp)
def _(node, var):
    op = valid_binop.get(type(node.op))
    if op is None:
        raise InvalidBinOpError()
    left = node_to_func(node.left, var)
    right = node_to_func(node.right, var)
    return lambda n: op(left(n), right(n))


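@node_to_func.register(ast.Constant)
def _(node, var):
    # ast.Constant replaces ast.Num, which was deprecated in
    # Python 3.8 and removed in 3.12. It also covers strings,
    # None, etc., so only numeric literals are accepted.
    if not isinstance(node.value, (int, float, complex)):
        raise InvalidNodeError()
    return lambda n: node.value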


@node_to_func.register(ast.Name)
def _(node, var):
    if node.id == var:
        return lambda n: n
    else:
        raise InvalidVariableError()


# pattern_to_func


def pattern_to_func(pattern, var):
    """Convert a text pattern into a unary
    function of one variable.
    
    pattern -- str. The text pattern to convert.
        E.g. "2 * n + 3"
    var -- str. The name of the variable. E.g.
        "n", "x".
        
    """
    try:
        node = ast.parse(pattern, mode="eval")
    except SyntaxError:
        raise InvalidPatternError()
    return node_to_func(node, var)


# pattern_values


def pattern_values(pattern, *, up_to, var):
    """Generates values for a pattern.
    
    pattern -- str. A text pattern. E.g.
        "2 * n + 3"
    up_to -- int or None. If None, count forever.
        Otherwise, count up to just below this
        number.
    var -- str. The name of the variable. E.g.
        "n", "x".
        
    """
    if up_to is None:
        numbers = itertools.count()
    else:
        numbers = range(up_to)

    f = pattern_to_func(pattern, var)

    for n in numbers:
        yield f(n)


# Command line interface

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "pattern",
        help='Pattern to generate sequence. E.g. "2 * n + 1".',
    )
    parser.add_argument(
        "--up-to",
        type=int,
        metavar="N",
        help="Count from 0 to N-1. (Default: count forever.)",
    )
    parser.add_argument(
        "--var",
        metavar="X",
        default="n",
        help='Name of pattern variable. (Default: "n".)',
    )
    args = parser.parse_args()

    try:
        values = pattern_values(
            args.pattern, up_to=args.up_to, var=args.var
        )
        for value in values:
            print(value)
    except InvalidPatternError:
        print(
            f"ERROR: {args.pattern!r} uses invalid Python syntax.",
            file=sys.stderr,
        )
    except InvalidVariableError:
        print(
            f"ERROR: Only {args.var!r} is allowed as a variable.",
            file=sys.stderr,
        )
    except InvalidBinOpError:
        print(
            "ERROR: Only allowed operations are +, -, *, "
            "/, %, **, and //.",
            file=sys.stderr,
        )
    except InvalidNodeError:
        print(
            "ERROR: Pattern must be simple expression with "
            "only allowed operations, numbers, and variables.",
            file=sys.stderr,
        )

$ ./seq-pattern --help
usage: seq-pattern [-h] [--up-to N] [--var X] pattern

positional arguments:
  pattern     Pattern to generate sequence. E.g. "2 * n + 1".

optional arguments:
  -h, --help  show this help message and exit
  --up-to N   Count from 0 to N-1. (Default: count forever.)
  --var X     Name of pattern variable. (Default: "n".)

Using the script to generate number sequences

I generate the numbers 1 to 10:

$ ./seq-pattern 'n+1' --up-to 10
1
2
3
4
5
6
7
8
9
10

I generate the first 10 powers of 2:

$ ./seq-pattern '2**n' --up-to 10
1
2
4
8
16
32
64
128
256
512

I generate an infinite stream of 0s and 1s:

$ ./seq-pattern '(n//3)%2'
0
0
0
1
1
1
0
0
0
1
[...]

How the script works

The script uses the ast module to parse the text into an abstract syntax tree (AST).
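
To make the parsing step concrete, here is a minimal sketch of what ast.parse returns for a pattern like "2 * n + 3". The pattern is only an example input, and the node type names in the comments assume Python 3.8 or later (older versions report Num instead of Constant).

```python
import ast

# Parse the pattern in "eval" mode, which accepts a single expression.
tree = ast.parse("2 * n + 3", mode="eval")

# The root is an ast.Expression; its body is the expression itself.
body = tree.body
print(type(tree).__name__)        # Expression
print(type(body).__name__)        # BinOp (the "+")
print(type(body.op).__name__)     # Add
print(type(body.left).__name__)   # BinOp (the "2 * n")
print(type(body.right).__name__)  # Constant on Python 3.8+ (the "3")
print(body.left.right.id)         # n
```

Nothing is executed at this point: the tree only describes the structure of the text, which is why the script can inspect every node before deciding whether to evaluate it.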

The function node_to_func takes this AST and turns it into a callable function. It uses functools.singledispatch to call different helper functions based on the type of AST node being processed.
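
As a small illustration of the dispatch mechanism, the sketch below registers handlers for two node types and lets everything else fall through to the default. The function name describe is hypothetical and is not part of the script; it only shows how functools.singledispatch picks a handler from the type of the first argument.

```python
import ast
import functools


@functools.singledispatch
def describe(node):
    # Fallback for node types without a registered handler.
    return "unsupported"


@describe.register(ast.BinOp)
def _(node):
    return "binary operation"


@describe.register(ast.Name)
def _(node):
    return f"variable {node.id}"


body = ast.parse("n + 1", mode="eval").body
print(describe(body))        # binary operation
print(describe(body.left))   # variable n
print(describe(body.right))  # unsupported (no handler for numbers here)
```

In the script, the fallback raises InvalidNodeError instead of returning a string, so any node type that was not explicitly allowed is rejected.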

If no upper bound is specified, I use itertools.count to generate an infinite sequence of numbers. Otherwise, I use range to generate a finite sequence of numbers up to the upper bound.
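
The choice between the two can be sketched as follows. The helper name numbers is hypothetical (the script does this inline inside pattern_values), and itertools.islice is used here only to cut the infinite case short for display.

```python
import itertools


def numbers(up_to=None):
    """Return 0, 1, 2, ... forever, or stop just below up_to."""
    if up_to is None:
        return itertools.count()
    return range(up_to)


print(list(numbers(5)))                      # [0, 1, 2, 3, 4]
# An infinite iterator must be cut off explicitly, e.g. with islice.
print(list(itertools.islice(numbers(), 3)))  # [0, 1, 2]
```

Both objects can be used in the same for loop, so the rest of the code does not need to know which one it received.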

To keep the error handling simple, the parsing code raises one of several custom exceptions (all subclasses of ScriptError), and I catch those at the top level to print an error message.
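
Here is a reduced sketch of that pattern. The exception names match the script, but check and run are hypothetical stand-ins for the real parsing and command line code, and the messages are shortened.

```python
class ScriptError(Exception):
    pass


class InvalidPatternError(ScriptError):
    pass


class InvalidVariableError(ScriptError):
    pass


def check(pattern):
    # Hypothetical stand-in for the real parsing logic.
    if not pattern:
        raise InvalidPatternError()
    if "x" in pattern:
        raise InvalidVariableError()
    return "ok"


def run(pattern):
    # Deeper code only raises; this top-level function alone
    # decides what message the user sees.
    try:
        return check(pattern)
    except InvalidPatternError:
        return "ERROR: invalid syntax"
    except InvalidVariableError:
        return "ERROR: invalid variable"


print(run("n + 1"))  # ok
print(run(""))       # ERROR: invalid syntax
print(run("x + 1"))  # ERROR: invalid variable
```

Because each failure has its own exception class, the helper functions never need to format messages or know about sys.stderr.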

In conclusion...

In this week's post, you learned how to use the ast module to parse Python text into an abstract syntax tree (AST), and then safely evaluate that AST to turn it into a callable function.

My challenge to you:

Modify node_to_func to support several functions from the math module, like math.sin and math.cos.

If you enjoyed this week's post, share it with your friends and stay tuned for next week's post. See you then!

