# Source code for paramz.tests.pickle_tests

'''
Created on 13 Mar 2014

@author: maxz
'''
import unittest, pickle, tempfile, os, paramz
import numpy as np
from ..core.index_operations import ParameterIndexOperations, ParameterIndexOperationsView
from ..core.observable_array import ObsAr
from paramz.transformations import Exponent, Logexp
from ..parameterized import Parameterized
from ..param import Param

class ListDictTestCase(unittest.TestCase):
    """TestCase base adding assertions for mappings of sequences and lists of arrays."""

    def assertListDictEquals(self, d1, d2, msg=None):
        """Assert two mappings have the same keys and element-wise equal values.

        Values are converted with ``list()`` before comparison, so numpy
        arrays, lists and other iterables can be compared with each other.

        :param d1: expected mapping of key -> iterable
        :param d2: actual mapping of key -> iterable
        :param msg: optional message forwarded to the failing assertion
        """
        # Compare the key sets first: iterating only over d1 would silently
        # accept extra keys present in d2.
        self.assertSetEqual(set(d1.keys()), set(d2.keys()), msg)
        for k, v in d1.items():
            self.assertListEqual(list(v), list(d2[k]), msg)

    def assertArrayListEquals(self, l1, l2):
        """Assert two sequences of arrays are pairwise equal and of equal length.

        Both arguments are materialised with ``list()`` first, so generators
        and dict views are accepted. The length check is needed because
        ``zip`` alone stops at the shorter input and would ignore extra items.
        """
        l1, l2 = list(l1), list(l2)
        self.assertEqual(len(l1), len(l2))
        for a1, a2 in zip(l1, l2):
            np.testing.assert_array_equal(a1, a2)
class Test(ListDictTestCase):
    """Copy and pickle round-trip tests for paramz core objects."""

    def test_parameter_index_operations(self):
        """Copies and pickles of ParameterIndexOperations (and views) keep their contents."""
        pio = ParameterIndexOperations(dict(test1=np.array([4, 3, 1, 6, 4]),
                                            test2=np.r_[2:130]))
        piov = ParameterIndexOperationsView(pio, 20, 250)
        self.assertListDictEquals(dict(piov.items()), dict(piov.copy().items()))
        self.assertListDictEquals(dict(pio.items()), dict(pio.copy().items()))
        self.assertArrayListEquals(pio.copy().indices(), pio.indices())
        self.assertArrayListEquals(piov.copy().indices(), piov.indices())
        with tempfile.TemporaryFile('w+b') as f:
            pickle.dump(pio, f)
            f.seek(0)
            pio2 = pickle.load(f)
            self.assertListDictEquals(pio._properties, pio2._properties)
        with tempfile.TemporaryFile('w+b') as f:
            pickle.dump(piov, f)
            f.seek(0)
            # Read back through paramz.load rather than pickle.load, so both
            # loaders are exercised by this test.
            pio2 = paramz.load(f)
            self.assertListDictEquals(dict(piov.items()), dict(pio2.items()))

    def test_param(self):
        """A constrained/fixed Param survives copy() and a pickle round trip."""
        param = Param('test', np.arange(4 * 2).reshape(4, 2))
        param[0].constrain_positive()
        param[1].fix()
        pcopy = param.copy()
        self.assertListEqual(param.tolist(), pcopy.tolist())
        self.assertListEqual(str(param).split('\n'), str(pcopy).split('\n'))
        self.assertIsNot(param, pcopy)
        with tempfile.TemporaryFile('w+b') as f:
            pickle.dump(param, f)
            f.seek(0)
            pcopy = paramz.load(f)
        self.assertListEqual(param.tolist(), pcopy.tolist())
        self.assertSequenceEqual(str(param), str(pcopy))

    def test_observable_array(self):
        """An ObsAr survives copy() and a pickle round trip through a file path."""
        obs = ObsAr(np.arange(4 * 2).reshape(4, 2))
        pcopy = obs.copy()
        self.assertListEqual(obs.tolist(), pcopy.tolist())
        # mkstemp creates a unique file in the system temp directory. A random
        # name in the CWD could collide or be left behind. Pre-creating the
        # file also keeps os.remove in `finally` from masking a failure in
        # obs.pickle with a FileNotFoundError.
        fd, tmpfile = tempfile.mkstemp()
        os.close(fd)  # only the path is handed on to obs.pickle / paramz.load
        try:
            obs.pickle(tmpfile)
            pcopy = paramz.load(tmpfile)
        finally:
            os.remove(tmpfile)
        self.assertListEqual(obs.tolist(), pcopy.tolist())
        self.assertSequenceEqual(str(obs), str(pcopy))

    def test_parameterized(self):
        """A nested Parameterized keeps values, constraints and gradient wiring across copy/pickle."""
        par = Parameterized('parameterized')
        p2 = Parameterized('rbf')
        p2.p1 = Param('lengthscale', np.random.uniform(0.1, .5, 3), Exponent())
        p2.link_parameter(p2.p1)
        par.p1 = p2
        par.p2 = Param('linear', np.random.uniform(0.1, .5, 2), Logexp())
        par.link_parameters(par.p1, par.p2)
        par.gradient = 10
        par.randomize()

        pcopy = par.copy()
        # The child's constraints must be a view on the copy's own top-level
        # index operations, not on the original's.
        self.assertIsInstance(pcopy.constraints, ParameterIndexOperations)
        self.assertIsInstance(pcopy.rbf.constraints, ParameterIndexOperationsView)
        self.assertIs(pcopy.constraints, pcopy.rbf.constraints._param_index_ops)
        self.assertIs(pcopy.constraints,
                      pcopy.rbf.lengthscale.constraints._param_index_ops)
        self.assertListEqual(par.param_array.tolist(), pcopy.param_array.tolist())
        pcopy.gradient = 10  # gradient does not get copied anymore
        self.assertListEqual(par.gradient_full.tolist(), pcopy.gradient_full.tolist())
        self.assertSequenceEqual(str(par), str(pcopy))
        self.assertIsNot(par.param_array, pcopy.param_array)
        self.assertIsNot(par.gradient_full, pcopy.gradient_full)

        with tempfile.TemporaryFile('w+b') as f:
            par.pickle(f)
            f.seek(0)
            pcopy = paramz.load(f)
        self.assertListEqual(par.param_array.tolist(), pcopy.param_array.tolist())
        pcopy.gradient = 10
        np.testing.assert_allclose(par.linear.gradient_full, pcopy.linear.gradient_full)
        np.testing.assert_allclose(pcopy.linear.gradient_full, 10)
        self.assertSequenceEqual(str(par), str(pcopy))

    def _callback(self, what, which):
        """Increment ``what.count``; ``which`` is ignored.

        NOTE(review): not referenced in this file; presumably kept as an
        observer callback for other tests — confirm before removing.
        """
        what.count += 1
if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()