Initial commit
This commit is contained in:
360
main.py
Normal file
360
main.py
Normal file
@@ -0,0 +1,360 @@
|
||||
import random
|
||||
import math
|
||||
import colorama
|
||||
|
||||
UP = (0, 1)
|
||||
LEFT = (-1, 0)
|
||||
DOWN = (0, -1)
|
||||
RIGHT = (1, 0)
|
||||
DIRS = [UP, DOWN, LEFT, RIGHT]
|
||||
|
||||
|
||||
class CompatibilityOracle(object):
|
||||
|
||||
"""The CompatibilityOracle class is responsible for telling us
|
||||
which combinations of tiles and directions are compatible. It's
|
||||
so simple that it perhaps doesn't need to be a class, but I think
|
||||
it helps keep things clear.
|
||||
"""
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def check(self, tile1, tile2, direction):
|
||||
return (tile1, tile2, direction) in self.data
|
||||
|
||||
|
||||
class Wavefunction(object):
|
||||
|
||||
"""The Wavefunction class is responsible for storing which tiles
|
||||
are permitted and forbidden in each location of an output image.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def mk(size, weights):
|
||||
"""Initialize a new Wavefunction for a grid of `size`,
|
||||
where the different tiles have overall weights `weights`.
|
||||
|
||||
Arguments:
|
||||
size -- a 2-tuple of (width, height)
|
||||
weights -- a dict of tile -> weight of tile
|
||||
"""
|
||||
coefficients = Wavefunction.init_coefficients(size, weights.keys())
|
||||
return Wavefunction(coefficients, weights)
|
||||
|
||||
@staticmethod
|
||||
def init_coefficients(size, tiles):
|
||||
"""Initializes a 2-D wavefunction matrix of coefficients.
|
||||
The matrix has size `size`, and each element of the matrix
|
||||
starts with all tiles as possible. No tile is forbidden yet.
|
||||
|
||||
NOTE: coefficients is a slight misnomer, since they are a
|
||||
set of possible tiles instead of a tile -> number/bool dict. This
|
||||
makes the code a little simpler. We keep the name `coefficients`
|
||||
for consistency with other descriptions of Wavefunction Collapse.
|
||||
|
||||
Arguments:
|
||||
size -- a 2-tuple of (width, height)
|
||||
tiles -- a set of all the possible tiles
|
||||
|
||||
Returns:
|
||||
A 2-D matrix in which each element is a set
|
||||
"""
|
||||
coefficients = []
|
||||
|
||||
for x in range(size[0]):
|
||||
row = []
|
||||
for y in range(size[1]):
|
||||
row.append(set(tiles))
|
||||
coefficients.append(row)
|
||||
|
||||
return coefficients
|
||||
|
||||
def __init__(self, coefficients, weights):
|
||||
self.coefficients = coefficients
|
||||
self.weights = weights
|
||||
|
||||
def get(self, co_ords):
|
||||
"""Returns the set of possible tiles at `co_ords`"""
|
||||
x, y = co_ords
|
||||
return self.coefficients[x][y]
|
||||
|
||||
def get_collapsed(self, co_ords):
|
||||
"""Returns the only remaining possible tile at `co_ords`.
|
||||
If there is not exactly 1 remaining possible tile then
|
||||
this method raises an exception.
|
||||
"""
|
||||
opts = self.get(co_ords)
|
||||
assert(len(opts) == 1)
|
||||
return next(iter(opts))
|
||||
|
||||
def get_all_collapsed(self):
|
||||
"""Returns a 2-D matrix of the only remaining possible
|
||||
tiles at each location in the wavefunction. If any location
|
||||
does not have exactly 1 remaining possible tile then
|
||||
this method raises an exception.
|
||||
"""
|
||||
width = len(self.coefficients)
|
||||
height = len(self.coefficients[0])
|
||||
|
||||
collapsed = []
|
||||
for x in range(width):
|
||||
row = []
|
||||
for y in range(height):
|
||||
row.append(self.get_collapsed((x,y)))
|
||||
collapsed.append(row)
|
||||
|
||||
return collapsed
|
||||
|
||||
def shannon_entropy(self, co_ords):
|
||||
"""Calculates the Shannon Entropy of the wavefunction at
|
||||
`co_ords`.
|
||||
"""
|
||||
x, y = co_ords
|
||||
|
||||
sum_of_weights = 0
|
||||
sum_of_weight_log_weights = 0
|
||||
for opt in self.coefficients[x][y]:
|
||||
weight = self.weights[opt]
|
||||
sum_of_weights += weight
|
||||
sum_of_weight_log_weights += weight * math.log(weight)
|
||||
|
||||
return math.log(sum_of_weights) - (sum_of_weight_log_weights / sum_of_weights)
|
||||
|
||||
|
||||
def is_fully_collapsed(self):
|
||||
"""Returns true if every element in Wavefunction is fully
|
||||
collapsed, and false otherwise.
|
||||
"""
|
||||
for x, row in enumerate(self.coefficients):
|
||||
for y, sq in enumerate(row):
|
||||
if len(sq) > 1:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def collapse(self, co_ords):
|
||||
"""Collapses the wavefunction at `co_ords` to a single, definite
|
||||
tile. The tile is chosen randomly from the remaining possible tiles
|
||||
at `co_ords`, weighted according to the Wavefunction's global
|
||||
`weights`.
|
||||
|
||||
This method mutates the Wavefunction, and does not return anything.
|
||||
"""
|
||||
x, y = co_ords
|
||||
opts = self.coefficients[x][y]
|
||||
valid_weights = {tile: weight for tile, weight in self.weights.iteritems() if tile in opts}
|
||||
|
||||
total_weights = sum(valid_weights.values())
|
||||
rnd = random.random() * total_weights
|
||||
|
||||
chosen = None
|
||||
for tile, weight in valid_weights.iteritems():
|
||||
rnd -= weight
|
||||
if rnd < 0:
|
||||
chosen = tile
|
||||
break
|
||||
|
||||
self.coefficients[x][y] = set(chosen)
|
||||
|
||||
def constrain(self, co_ords, forbidden_tile):
|
||||
"""Removes `forbidden_tile` from the list of possible tiles
|
||||
at `co_ords`.
|
||||
|
||||
This method mutates the Wavefunction, and does not return anything.
|
||||
"""
|
||||
x, y = co_ords
|
||||
self.coefficients[x][y].remove(forbidden_tile)
|
||||
|
||||
|
||||
class Model(object):
|
||||
|
||||
"""The Model class is responsible for orchestrating the
|
||||
Wavefunction Collapse algorithm.
|
||||
"""
|
||||
|
||||
def __init__(self, output_size, weights, compatibility_oracle):
|
||||
self.output_size = output_size
|
||||
self.compatibility_oracle = compatibility_oracle
|
||||
|
||||
self.wavefunction = Wavefunction.mk(output_size, weights)
|
||||
|
||||
def run(self):
|
||||
"""Collapses the Wavefunction until it is fully collapsed,
|
||||
then returns a 2-D matrix of the final, collapsed state.
|
||||
"""
|
||||
while not self.wavefunction.is_fully_collapsed():
|
||||
self.iterate()
|
||||
|
||||
return self.wavefunction.get_all_collapsed()
|
||||
|
||||
def iterate(self):
|
||||
"""Performs a single iteration of the Wavefunction Collapse
|
||||
Algorithm.
|
||||
"""
|
||||
# 1. Find the co-ordinates of minimum entropy
|
||||
co_ords = self.min_entropy_co_ords()
|
||||
# 2. Collapse the wavefunction at these co-ordinates
|
||||
self.wavefunction.collapse(co_ords)
|
||||
# 3. Propagate the consequences of this collapse
|
||||
self.propagate(co_ords)
|
||||
|
||||
def propagate(self, co_ords):
|
||||
"""Propagates the consequences of the wavefunction at `co_ords`
|
||||
collapsing. If the wavefunction at (x,y) collapses to a fixed tile,
|
||||
then some tiles may not longer be theoretically possible at
|
||||
surrounding locations.
|
||||
|
||||
This method keeps propagating the consequences of the consequences,
|
||||
and so on until no consequences remain.
|
||||
"""
|
||||
stack = [co_ords]
|
||||
|
||||
while len(stack) > 0:
|
||||
cur_coords = stack.pop()
|
||||
# Get the set of all possible tiles at the current location
|
||||
cur_possible_tiles = self.wavefunction.get(cur_coords)
|
||||
|
||||
# Iterate through each location immediately adjacent to the
|
||||
# current location.
|
||||
for d in valid_dirs(cur_coords, self.output_size):
|
||||
other_coords = (cur_coords[0] + d[0], cur_coords[1] + d[1])
|
||||
|
||||
# Iterate through each possible tile in the adjacent location's
|
||||
# wavefunction.
|
||||
for other_tile in set(self.wavefunction.get(other_coords)):
|
||||
# Check whether the tile is compatible with any tile in
|
||||
# the current location's wavefunction.
|
||||
other_tile_is_possible = any([
|
||||
self.compatibility_oracle.check(cur_tile, other_tile, d) for cur_tile in cur_possible_tiles
|
||||
])
|
||||
# If the tile is not compatible with any of the tiles in
|
||||
# the current location's wavefunction then it is impossible
|
||||
# for it to ever get chosen. We therefore remove it from
|
||||
# the other location's wavefunction.
|
||||
if not other_tile_is_possible:
|
||||
self.wavefunction.constrain(other_coords, other_tile)
|
||||
stack.append(other_coords)
|
||||
|
||||
def min_entropy_co_ords(self):
|
||||
"""Returns the co-ords of the location whose wavefunction has
|
||||
the lowest entropy.
|
||||
"""
|
||||
min_entropy = None
|
||||
min_entropy_coords = None
|
||||
|
||||
width, height = self.output_size
|
||||
for x in range(width):
|
||||
for y in range(height):
|
||||
if len(self.wavefunction.get((x,y))) == 1:
|
||||
continue
|
||||
|
||||
entropy = self.wavefunction.shannon_entropy((x, y))
|
||||
# Add some noise to mix things up a little
|
||||
entropy_plus_noise = entropy - (random.random() / 1000)
|
||||
if min_entropy is None or entropy_plus_noise < min_entropy:
|
||||
min_entropy = entropy_plus_noise
|
||||
min_entropy_coords = (x, y)
|
||||
|
||||
return min_entropy_coords
|
||||
|
||||
|
||||
def render_colors(matrix, colors):
|
||||
"""Render the fully collapsed `matrix` using the given `colors.
|
||||
|
||||
Arguments:
|
||||
matrix -- 2-D matrix of tiles
|
||||
colors -- dict of tile -> `colorama` color
|
||||
"""
|
||||
for row in matrix:
|
||||
output_row = []
|
||||
for val in row:
|
||||
color = colors[val]
|
||||
output_row.append(color + val + colorama.Style.RESET_ALL)
|
||||
|
||||
print "".join(output_row)
|
||||
|
||||
|
||||
def valid_dirs(cur_co_ord, matrix_size):
|
||||
"""Returns the valid directions from `cur_co_ord` in a matrix
|
||||
of `matrix_size`. Ensures that we don't try to take step to the
|
||||
left when we are already on the left edge of the matrix.
|
||||
"""
|
||||
x, y = cur_co_ord
|
||||
width, height = matrix_size
|
||||
dirs = []
|
||||
|
||||
if x > 0: dirs.append(LEFT)
|
||||
if x < width-1: dirs.append(RIGHT)
|
||||
if y > 0: dirs.append(DOWN)
|
||||
if y < height-1: dirs.append(UP)
|
||||
|
||||
return dirs
|
||||
|
||||
|
||||
def parse_example_matrix(matrix):
|
||||
"""Parses an example `matrix`. Extracts:
|
||||
|
||||
1. Tile compatibilities - which pairs of tiles can be placed next
|
||||
to each other and in which directions
|
||||
2. Tile weights - how common different tiles are
|
||||
|
||||
Arguments:
|
||||
matrix -- a 2-D matrix of tiles
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
* A set of compatibile tile combinations, where each combination is of
|
||||
the form (tile1, tile2, direction)
|
||||
* A dict of weights of the form tile -> weight
|
||||
"""
|
||||
compatibilities = set()
|
||||
matrix_width = len(matrix)
|
||||
matrix_height = len(matrix[0])
|
||||
|
||||
weights = {}
|
||||
|
||||
for x, row in enumerate(matrix):
|
||||
for y, cur_tile in enumerate(row):
|
||||
if cur_tile not in weights:
|
||||
weights[cur_tile] = 0
|
||||
weights[cur_tile] += 1
|
||||
|
||||
for d in valid_dirs((x,y), (matrix_width, matrix_height)):
|
||||
other_tile = matrix[x+d[0]][y+d[1]]
|
||||
compatibilities.add((cur_tile, other_tile, d))
|
||||
|
||||
return compatibilities, weights
|
||||
|
||||
|
||||
input_matrix = [
|
||||
['L','L','L','L'],
|
||||
['L','L','L','L'],
|
||||
['L','L','L','L'],
|
||||
['L','C','C','L'],
|
||||
['C','S','S','C'],
|
||||
['S','S','S','S'],
|
||||
['S','S','S','S'],
|
||||
]
|
||||
input_matrix2 = [
|
||||
['A','A','A','A'],
|
||||
['A','A','A','A'],
|
||||
['A','A','A','A'],
|
||||
['A','C','C','A'],
|
||||
['C','B','B','C'],
|
||||
['C','B','B','C'],
|
||||
['A','C','C','A'],
|
||||
]
|
||||
|
||||
compatibilities, weights = parse_example_matrix(input_matrix)
|
||||
compatibility_oracle = CompatibilityOracle(compatibilities)
|
||||
model = Model((10, 50), weights, compatibility_oracle)
|
||||
output = model.run()
|
||||
|
||||
colors = {
|
||||
'L': colorama.Fore.GREEN,
|
||||
'S': colorama.Fore.BLUE,
|
||||
'C': colorama.Fore.YELLOW,
|
||||
}
|
||||
|
||||
render_colors(output, colors)
|
||||
Reference in New Issue
Block a user