2016-04-27 08:00:58 -07:00
"""
2 D rendering framework
"""
from __future__ import division
import os , sys
# (JDS 2016/04/15): avoid bug on Anaconda 2.3.0 / Yosemite
# When sys.version mentions Apple (presumably a macOS build of Python) and
# DYLD_FALLBACK_LIBRARY_PATH is already set, append /usr/lib to it.
# The variable is not created if it was absent.
if "Apple" in sys.version and 'DYLD_FALLBACK_LIBRARY_PATH' in os.environ:
    os.environ['DYLD_FALLBACK_LIBRARY_PATH'] += ':/usr/lib'
2016-05-01 23:38:19 -04:00
from gym . utils import reraise
2016-04-27 08:00:58 -07:00
from gym import error
2016-04-27 11:46:00 -07:00
try :
import pyglet
except ImportError as e :
2016-05-01 23:38:19 -04:00
reraise ( suffix = " HINT: you can install pyglet directly via ' pip install pyglet ' . But if you really just want to install all Gym dependencies and not have to think about it, ' pip install -e .[all] ' or ' pip install gym[all] ' will do it. " )
2016-04-27 11:46:00 -07:00
2016-04-27 08:00:58 -07:00
try :
from pyglet . gl import *
except ImportError as e :
2016-05-01 23:38:19 -04:00
reraise ( prefix = " Error occured while running `from pyglet.gl import *` " , suffix = " HINT: make sure you have OpenGL install. On Ubuntu, you can run ' apt-get install python-opengl ' . If you ' re running on a server, you may need a virtual frame buffer; something like this should work: ' xvfb-run -s \" -screen 0 1400x900x24 \" python <your_script.py> ' " )
2016-04-27 08:00:58 -07:00
import math
import numpy as np
# Multiplier converting radians to degrees (glRotatef expects degrees).
RAD2DEG = 180.0 / math.pi
class Viewer(object):
    """A pyglet window that draws a collection of ``Geom`` objects.

    Geoms registered with ``add_geom`` are drawn on every ``render`` call.
    Geoms registered with ``add_onetime`` (or created via the ``draw_*``
    helpers) are drawn on the next ``render`` only and then discarded.
    """

    def __init__(self, width, height):
        self.width = width
        self.height = height
        self.window = pyglet.window.Window(width=width, height=height)
        self.geoms = []
        self.onetime_geoms = []
        # Transform applied around every geom; identity until set_bounds() is called.
        self.transform = Transform()
        # Blend with source alpha so the alpha component of Color attrs takes effect.
        glEnable(GL_BLEND)
        glBlendFunc(GL_SRC_ALPHA, GL_ONE_MINUS_SRC_ALPHA)

    def close(self):
        """Close the underlying pyglet window."""
        self.window.close()

    @staticmethod
    def _bounds_to_translation_scale(left, right, bottom, top, width, height):
        """Return ``(translation, scale)`` mapping a world box onto the pixel window.

        World point (left, bottom) maps to pixel (0, 0) and (right, top) maps
        to (width, height).
        """
        scalex = width / (right - left)
        scaley = height / (top - bottom)
        # Each axis is offset by its own scale factor; using scalex for the y
        # offset shifts the scene vertically whenever the two scales differ.
        return (-left * scalex, -bottom * scaley), (scalex, scaley)

    def set_bounds(self, left, right, bottom, top):
        """Make the window show the world rectangle [left, right] x [bottom, top]."""
        assert right > left and top > bottom
        translation, scale = self._bounds_to_translation_scale(
            left, right, bottom, top, self.width, self.height)
        self.transform = Transform(translation=translation, scale=scale)

    def add_geom(self, geom):
        """Register *geom* to be drawn on every render."""
        self.geoms.append(geom)

    def add_onetime(self, geom):
        """Register *geom* to be drawn on the next render only."""
        self.onetime_geoms.append(geom)

    def render(self):
        """Clear to white, draw all geoms under the view transform, and flip."""
        glClearColor(1, 1, 1, 1)
        self.window.clear()
        self.window.switch_to()
        self.window.dispatch_events()
        self.transform.enable()
        for geom in self.geoms:
            geom.render()
        for geom in self.onetime_geoms:
            geom.render()
        self.transform.disable()
        self.window.flip()
        # One-time geoms are discarded after a single frame.
        self.onetime_geoms = []

    # Convenience
    def draw_circle(self, radius=10, res=30, filled=True, **attrs):
        """Queue a one-time circle; ``attrs`` may contain ``color``/``linewidth``."""
        geom = make_circle(radius=radius, res=res, filled=filled)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def draw_polygon(self, v, filled=True, **attrs):
        """Queue a one-time polygon through vertices *v*."""
        geom = make_polygon(v=v, filled=filled)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def draw_polyline(self, v, **attrs):
        """Queue a one-time open polyline through vertices *v*."""
        geom = make_polyline(v=v)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    def draw_line(self, start, end, **attrs):
        """Queue a one-time line segment from *start* to *end*."""
        geom = Line(start, end)
        _add_attrs(geom, attrs)
        self.add_onetime(geom)
        return geom

    @staticmethod
    def _rgba_bytes_to_rgb_array(data, width, height):
        """Convert a raw RGBA byte buffer into a top-down (height, width, 3) uint8 array.

        Rows are reversed because the buffer stores the bottom row first.
        """
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        # copy() keeps the result writable, like the array fromstring produced.
        arr = np.frombuffer(data, dtype=np.uint8).reshape(height, width, 4)
        return arr[::-1, :, 0:3].copy()

    def get_array(self):
        """Return the current frame as a (rows, cols, 3) uint8 RGB array."""
        self.window.flip()
        image_data = pyglet.image.get_buffer_manager().get_color_buffer().get_image_data()
        self.window.flip()
        # Use the buffer's own size: the colour buffer can be larger than the
        # logical window (e.g. on high-DPI displays).
        return self._rgba_bytes_to_rgb_array(
            image_data.data, image_data.width, image_data.height)
def _add_attrs(geom, attrs):
    """Apply the optional ``color`` / ``linewidth`` keyword attrs of a draw_* call.

    ``attrs["color"]`` is an (r, g, b) sequence. It is unpacked because
    ``Geom.set_color`` takes the three components as separate arguments;
    passing the tuple as one argument raised TypeError.
    ``attrs["linewidth"]`` requires a geom that defines ``set_linewidth``.
    """
    if "color" in attrs:
        geom.set_color(*attrs["color"])
    if "linewidth" in attrs:
        geom.set_linewidth(attrs["linewidth"])
class Geom(object):
    """Base class for every drawable primitive.

    A geom owns an ordered list of ``Attr`` objects (colour, transforms,
    line width, ...). ``render`` switches them on, draws via ``render1``,
    and switches them off again.
    """

    def __init__(self):
        # The colour attr always comes first; it starts as opaque black.
        self._color = Color((0, 0, 0, 1.0))
        self.attrs = [self._color]

    def render(self):
        """Enable attrs last-to-first, draw, then disable them first-to-last.

        The most recently added attr is enabled first (outermost), and the
        disable order mirrors the enable order so pushes and pops nest.
        """
        for attribute in self.attrs[::-1]:
            attribute.enable()
        self.render1()
        for attribute in self.attrs:
            attribute.disable()

    def render1(self):
        """Draw the primitive itself; subclasses must override."""
        raise NotImplementedError

    def add_attr(self, attr):
        """Append *attr*; it will be enabled before all previously added attrs."""
        self.attrs.append(attr)

    def set_color(self, r, g, b):
        """Set an opaque RGB colour (alpha fixed at 1)."""
        self._color.vec4 = (r, g, b, 1)
class Attr(object):
    """Base class for render-state attributes attached to a ``Geom``.

    Subclasses must implement ``enable``; ``disable`` is a no-op by default.
    """

    def enable(self):
        """Apply this attribute's GL state; subclasses must override."""
        raise NotImplementedError

    def disable(self):
        """Undo this attribute's GL state (nothing to undo by default)."""
        return None
class Transform(Attr):
    """Translate, rotate (radians), then scale everything drawn while enabled.

    ``enable`` pushes the GL matrix and ``disable`` pops it, so transforms nest.
    """

    def __init__(self, translation=(0.0, 0.0), rotation=0.0, scale=(1, 1)):
        self.set_translation(*translation)
        self.set_rotation(rotation)
        self.set_scale(*scale)

    def enable(self):
        tx, ty = self.translation[0], self.translation[1]
        sx, sy = self.scale[0], self.scale[1]
        glPushMatrix()
        glTranslatef(tx, ty, 0)  # translate to GL loc point
        # rotation is stored in radians; glRotatef takes degrees.
        glRotatef(RAD2DEG * self.rotation, 0, 0, 1.0)
        glScalef(sx, sy, 1)

    def disable(self):
        glPopMatrix()

    def set_translation(self, newx, newy):
        """Store the (x, y) offset as floats."""
        self.translation = (float(newx), float(newy))

    def set_rotation(self, new):
        """Store the rotation angle (radians) as a float."""
        self.rotation = float(new)

    def set_scale(self, newx, newy):
        """Store the (x, y) scale factors as floats."""
        self.scale = (float(newx), float(newy))
class Color(Attr):
    """Sets the current GL colour from ``vec4`` (r, g, b, a) via glColor4f.

    There is no ``disable`` override, so the colour stays set after the geom
    is drawn.
    """

    def __init__(self, vec4):
        self.vec4 = vec4

    def enable(self):
        components = self.vec4
        glColor4f(*components)
class LineStyle(Attr):
    """Enables GL line stippling with pattern ``style`` while active.

    ``style`` is passed as the pattern argument of glLineStipple (presumably
    a 16-bit bit mask); the repeat factor is fixed at 1.
    """

    def __init__(self, style):
        self.style = style

    def enable(self):
        glEnable(GL_LINE_STIPPLE)
        glLineStipple(1, self.style)

    def disable(self):
        glDisable(GL_LINE_STIPPLE)
class LineWidth(Attr):
    """Sets the GL line width to ``stroke`` when enabled.

    There is no ``disable`` override, so the width is not restored afterwards.
    """

    def __init__(self, stroke):
        self.stroke = stroke

    def enable(self):
        width = self.stroke
        glLineWidth(width)
class Point(Geom):
    """A single GL point at the local origin (position it with a Transform).

    Construction is inherited unchanged from ``Geom``.
    """

    def render1(self):
        glBegin(GL_POINTS)  # draw point
        glVertex3f(0.0, 0.0, 0.0)
        glEnd()
class FilledPolygon(Geom):
    """Solid shape through the 2D vertices ``v``.

    Uses GL_QUADS for exactly 4 vertices, GL_POLYGON for more than 4, and
    GL_TRIANGLES otherwise.
    """

    def __init__(self, v):
        Geom.__init__(self)
        self.v = v

    def render1(self):
        n_vertices = len(self.v)
        if n_vertices > 4:
            mode = GL_POLYGON
        elif n_vertices == 4:
            mode = GL_QUADS
        else:
            mode = GL_TRIANGLES
        glBegin(mode)
        for vertex in self.v:
            glVertex3f(vertex[0], vertex[1], 0)  # draw each vertex
        glEnd()
def make_circle(radius=10, res=30, filled=True):
    """Approximate a circle centred at the origin with ``res`` evenly spaced vertices.

    Returns a FilledPolygon when *filled*, otherwise a closed PolyLine.
    """
    points = [
        (math.cos(2 * math.pi * i / res) * radius,
         math.sin(2 * math.pi * i / res) * radius)
        for i in range(res)
    ]
    return FilledPolygon(points) if filled else PolyLine(points, True)
def make_polygon(v, filled=True):
    """Return a FilledPolygon through *v*, or a closed PolyLine outline if not *filled*."""
    return FilledPolygon(v) if filled else PolyLine(v, True)
def make_polyline(v):
    """Return an open (not looped back) PolyLine through the vertices *v*."""
    return PolyLine(v, close=False)
def make_capsule(length, width):
    """Build a capsule: a length x width box from x=0 to x=length with round caps.

    The box spans y in [-width/2, width/2]. Circles of radius width/2 sit at
    x=0 and x=length.
    """
    radius = width / 2
    left, right, top, bottom = 0, length, radius, -width / 2
    body = make_polygon([(left, bottom), (left, top), (right, top), (right, bottom)])
    start_cap = make_circle(radius)
    end_cap = make_circle(radius)
    end_cap.add_attr(Transform(translation=(length, 0)))
    return Compound([body, start_cap, end_cap])
class Compound(Geom):
    """A geom made of child geoms ``gs`` rendered together.

    Each child's Color attrs are stripped on construction, so the children
    inherit the compound's own colour when rendered.
    """

    def __init__(self, gs):
        Geom.__init__(self)
        self.gs = gs
        for child in self.gs:
            child.attrs = [attr for attr in child.attrs if not isinstance(attr, Color)]

    def render1(self):
        for child in self.gs:
            child.render()
class PolyLine(Geom):
    """Line strip through the vertices ``v``; drawn as a closed loop when ``close`` is true."""

    def __init__(self, v, close):
        Geom.__init__(self)
        self.v = v
        self.close = close
        # Keep a handle on the width attr so set_linewidth can change it later.
        self.linewidth = LineWidth(1)
        self.add_attr(self.linewidth)

    def render1(self):
        mode = GL_LINE_LOOP if self.close else GL_LINE_STRIP
        glBegin(mode)
        for vertex in self.v:
            glVertex3f(vertex[0], vertex[1], 0)  # draw each vertex
        glEnd()

    def set_linewidth(self, x):
        """Change the stroke width used on subsequent renders."""
        self.linewidth.stroke = x
class Line(Geom):
    """Single line segment from ``start`` to ``end`` (2D points)."""

    def __init__(self, start=(0.0, 0.0), end=(0.0, 0.0)):
        Geom.__init__(self)
        self.start = start
        self.end = end
        # Keep a handle on the width attr so set_linewidth can change it later.
        self.linewidth = LineWidth(1)
        self.add_attr(self.linewidth)

    def render1(self):
        glBegin(GL_LINES)
        glVertex2f(*self.start)
        glVertex2f(*self.end)
        glEnd()

    def set_linewidth(self, x):
        """Change the stroke width used on subsequent renders.

        Added so Line supports the ``linewidth`` attr that Viewer.draw_line
        forwards to _add_attrs, matching PolyLine.set_linewidth.
        """
        self.linewidth.stroke = x
class Image(Geom):
    """Image file *fname* drawn as a width x height rectangle centred on the origin."""

    def __init__(self, fname, width, height):
        Geom.__init__(self)
        self.width = width
        self.height = height
        self.img = pyglet.image.load(fname)
        # Not read by render1; kept as a public attribute.
        self.flip = False

    def render1(self):
        x = -self.width / 2
        y = -self.height / 2
        self.img.blit(x, y, width=self.width, height=self.height)
# ================================================================
class SimpleImageViewer(object):
    """Minimal pyglet window that displays (height, width, 3) RGB arrays.

    The window is created lazily by the first ``imshow`` and sized to that
    image. Later images must have the same shape until ``close`` is called.
    """

    def __init__(self):
        self.window = None
        self.isopen = False

    def imshow(self, arr):
        """Display *arr*, creating the window on first use.

        Assumes *arr* holds 8-bit channel values (e.g. dtype uint8), since the
        raw bytes are handed to pyglet as 3 bytes per pixel — TODO confirm callers.
        """
        if self.window is None:
            height, width, _channels = arr.shape
            self.window = pyglet.window.Window(width=width, height=height)
            self.width = width
            self.height = height
            self.isopen = True
        expected = (self.height, self.width, 3)
        assert arr.shape == expected, (
            "You passed in an image with the wrong shape: expected {}, got {}".format(
                expected, arr.shape))
        # pyglet needs the exact format string 'RGB'. A negative pitch tells it
        # the rows are stored top-to-bottom, as in a numpy image.
        image = pyglet.image.ImageData(
            self.width, self.height, 'RGB', arr.tobytes(), pitch=self.width * -3)
        self.window.clear()
        self.window.switch_to()
        self.window.dispatch_events()
        image.blit(0, 0)
        self.window.flip()

    def close(self):
        """Close the window if open; a later imshow() opens a fresh one."""
        if self.isopen:
            self.window.close()
            self.isopen = False
            # Drop the dead window so imshow() does not draw into a closed window.
            self.window = None

    def __del__(self):
        self.close()