Skip to content

Commit

Permalink
adding mpl jointplot #111
Browse files Browse the repository at this point in the history
  • Loading branch information
timkpaine committed Oct 26, 2017
1 parent 504da64 commit dc1b330
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lantern/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def plot(data, type=None, raw=False, colors=None, **kwargs):
return getattr(_pm[BACKEND], typ.value)(data, type=typ, colors=colors, **kwargs)

# require more than 1 column
if typ in [lookup('pie'), lookup('bubble'), lookup('scatter'), lookup('bar'), lookup('stackedbar'), lookup('horizontalbar'), lookup('horizontalstackedbar'), lookup('box'), lookup('lmplot')]:
if typ in [lookup('pie'), lookup('bubble'), lookup('scatter'), lookup('bar'), lookup('stackedbar'), lookup('horizontalbar'), lookup('horizontalstackedbar'), lookup('box'), lookup('lmplot'), lookup('jointplot')]:
select = [col]
skip.add(col)

Expand All @@ -164,7 +164,7 @@ def plot(data, type=None, raw=False, colors=None, **kwargs):

# bubble specific options
# scatter specific options
if typ in [lookup('bubble'), lookup('scatter'), lookup('bubble3d'), lookup('scatter3d'), lookup('lmplot')]:
if typ in [lookup('bubble'), lookup('scatter'), lookup('bubble3d'), lookup('scatter3d'), lookup('lmplot'), lookup('jointplot')]:
scatter = _parseScatter(kwargs.pop('scatter', {}), col)
x = scatter.get('x', col)
y = scatter.get('y', col)
Expand Down
4 changes: 4 additions & 0 deletions lantern/plotting/plot_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def groupedscatter(data, **kwargs):
def heatmap(data, **kwargs):
raise NotImplementedError()

@staticmethod
def jointplot(data, **kwargs):
raise NotImplementedError()

@staticmethod
def lmplot(data, **kwargs):
raise NotImplementedError()
Expand Down
4 changes: 4 additions & 0 deletions lantern/plotting/plot_cufflinks.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,10 @@ def candlestick(data, **kwargs):
def density(data, **kwargs):
raise NotImplementedError()

@staticmethod
def jointplot(data, **kwargs):
raise NotImplementedError()

@staticmethod
def lmplot(data, **kwargs):
raise NotImplementedError()
Expand Down
14 changes: 14 additions & 0 deletions lantern/plotting/plot_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,20 @@ def horizontalstackedbar(data, **kwargs):
stacked=True,
**kwargs)

@staticmethod
def jointplot(data, **kwargs):
kwargs = MatplotlibPlotMap._wrapper(**kwargs)
scatter = kwargs.pop('scatter', {})
x = scatter.pop('x', data.columns[0])
y = scatter.pop('y', data.columns[0])

color = kwargs.pop('color')
scatter_kws = {'color': kwargs.pop('scatter_color', color)}
line_kws = {'color': kwargs.pop('line_color', color)}
marginal_kws = {'color': kwargs.pop('bar_color', color)}
kind = kwargs.pop('kind', 'reg')
return sns.jointplot(x=x, y=y, data=data, kind=kind, scatter_kws=scatter_kws, line_kws=line_kws, marginal_kws=marginal_kws, **kwargs)

@staticmethod
def line(data, **kwargs):
ax = MatplotlibPlotMap._newAx(x=False, y=kwargs.get('y', 'left') == 'right', y_side=kwargs.pop('y', 'left'), color=kwargs.get('colors'))
Expand Down
5 changes: 5 additions & 0 deletions lantern/plotting/plottypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BasePlotType(Enum):
HISTOGRAM = 'histogram'
HORIZONTALBAR = 'horizontalbar'
HORIZONTALSTACKEDBAR = 'horizontalstackedbar'
JOINTPLOT = 'jointplot'
LINE = 'line'
LMPLOT = 'lmplot'
MULTISCATTER = 'multiscatter'
Expand Down Expand Up @@ -130,6 +131,10 @@ def horizontalbar():
def horizontalstackedbar():
'''plot type'''

@abstractmethod
def jointplot():
'''plot type'''

@abstractstatic
def line():
'''plot type'''
Expand Down
9 changes: 9 additions & 0 deletions tests/plot/test_plot_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def test_lmplot(self):
df = line.sample()
plot(df[[df.columns[0], df.columns[1]]], type='lmplot', scatter={df.columns[0]: {'x': df.columns[0], 'y': df.columns[1]}})

def test_jointplot(self):
with patch('lantern.plotting.plot_matplotlib.in_ipynb', create=True) as mock1:
from lantern.plotting import setBackend
from lantern import line, plot
mock1.return_value = True
setBackend('matplotlib')
df = line.sample()
plot(df[[df.columns[0], df.columns[1]]], type='jointplot', scatter={df.columns[0]: {'x': df.columns[0], 'y': df.columns[1]}})

def test_probplot(self):
with patch('lantern.plotting.plot_matplotlib.in_ipynb', create=True) as mock1:
from lantern.plotting import setBackend
Expand Down

0 comments on commit dc1b330

Please sign in to comment.