Package pylearn :: Package datasets :: Module tzanetakis
[hide private]

Source Code for Module pylearn.datasets.tzanetakis

  1  """ 
  2  Load Tzanetakis' genre-classification dataset. 
  3   
  4  """ 
  5  from __future__ import absolute_import 
  6   
  7  import os 
  8  import numpy 
  9   
 10  from ..io.amat import AMat 
 11  from .config import data_root 
 12  from .dataset import dataset_factory, Dataset 
def centre_data(x, inplace=False):
    """Standardise every column of `x` to zero mean and roughly unit variance.

    Statistics are taken along axis 0, so each column is centred and scaled
    independently of the others.

    :param x: array to normalise (the in-place arithmetic expects a
        floating-point dtype).
    :param inplace: if True, `x` itself is modified and returned; otherwise
        a copy is normalised and `x` is left untouched.
    :returns: the normalised array (the very same object as `x` when
        `inplace` is True).
    """
    if inplace:
        result = x
    else:
        result = x.copy()

    # centre each column
    column_mean = numpy.mean(result, axis=0)
    result -= column_mean

    # scale each column; the 1e-6 keeps constant columns from dividing by 0
    inv_scale = 1.0 / (1.0e-6 + numpy.std(result, axis=0))
    result *= inv_scale
    return result
21
def mfcc16(segments_per_song=1, include_covariance=True, random_split=0,
           ntrain=700, nvalid=100, ntest=200,
           normalize=True):
    """Load the MFCC16 summary features of Tzanetakis' genre dataset.

    Reads ``<data_root>/tzanetakis/feat_mfcc16_540_1.stat.amat``, which must
    hold 1000 rows of 152 features. It then attaches genre labels, shuffles
    the rows with a seed derived from `random_split`, and cuts consecutive
    train / valid / test splits from the shuffled rows.

    :param segments_per_song: only 1 is supported.
    :param include_covariance: if False, keep only the first 16 feature
        columns.  NOTE(review): presumably these are the MFCC means and the
        other 136 columns the covariance terms -- confirm against the file.
    :param random_split: selects the shuffle; the RNG seed is
        ``random_split + 1``.
    :param ntrain: number of training rows (taken first).
    :param nvalid: number of validation rows (taken next).
    :param ntest: number of test rows (taken last).  The three sizes must
        be non-negative and sum to at most the number of rows.
    :param normalize: if True, each split is passed through `centre_data`
        (zero mean / unit variance per column).
    :returns: a `Dataset` whose ``train``, ``valid`` and ``test`` attributes
        are ``Dataset.Obj(x=inputs, y=integer labels)``, plus
        ``n_classes = 10``.
    :raises NotImplementedError: if `segments_per_song` != 1.
    :raises ValueError: if the split sizes are invalid or the feature file
        does not have the expected shape.
    """
    if segments_per_song != 1:
        raise NotImplementedError()

    n_rows = 1000 * segments_per_song
    # Without this check, oversized requests were silently truncated by slicing.
    if min(ntrain, nvalid, ntest) < 0 or ntrain + nvalid + ntest > n_rows:
        raise ValueError('invalid split sizes ntrain=%s, nvalid=%s, ntest=%s '
                         'for %s examples' % (ntrain, nvalid, ntest, n_rows))

    path = os.path.join(data_root(), 'tzanetakis', 'feat_mfcc16_540_1.stat.amat')
    dat = AMat(path=path)
    all_input = dat.input
    if all_input.shape != (n_rows, 152):
        raise ValueError('unexpected feature shape %s in %s (expected %s)'
                         % (all_input.shape, path, (n_rows, 152)))

    # Row i gets label i // (100 * segments_per_song): the code assumes the
    # file lists songs grouped by genre, 100 per genre, in class-index order.
    all_targ = numpy.repeat(numpy.arange(10), 100 * segments_per_song)

    if not include_covariance:
        all_input = all_input[:, 0:16]

    # Shuffle according to the random split.  A single permutation indexes
    # both arrays, so inputs and labels are guaranteed to stay aligned.
    # Fancy indexing returns copies, so the arrays held by `dat` are not
    # modified.
    perm = numpy.random.RandomState(random_split + 1).permutation(n_rows)
    all_input = all_input[perm]
    all_targ = all_targ[perm]

    def prepx(x):
        # In-place is safe: `x` is a slice of the private shuffled copy, and
        # the three slices do not overlap.
        # NOTE(review): each split is normalised with its own mean/std, not
        # with the training statistics -- confirm this is intended.
        return centre_data(x, inplace=True) if normalize else x

    end_valid = ntrain + nvalid
    end_test = end_valid + ntest

    # construct a dataset to return
    rval = Dataset()
    rval.train = Dataset.Obj(x=prepx(all_input[0:ntrain]),
                             y=all_targ[0:ntrain])
    rval.valid = Dataset.Obj(x=prepx(all_input[ntrain:end_valid]),
                             y=all_targ[ntrain:end_valid])
    rval.test = Dataset.Obj(x=prepx(all_input[end_valid:end_test]),
                            y=all_targ[end_valid:end_test])
    rval.n_classes = 10
    return rval


import theano
class TzanetakisExample(theano.Op):
    """Return the i'th (file path, genre label) pair from the Tzanetakis dataset.

    A theano Op with one integer scalar input (the track index) and two
    outputs: the track's path (a generic variable) and its genre as an
    int64 scalar class index.

    Constructing an instance reads ``<data_root>/tzanetakis/tracklist.txt``.
    """

    # Genre names in class-index order.
    class_names = ('blues', 'classical', 'country', 'disco', 'hiphop',
                   'jazz', 'metal', 'pop', 'reggae', 'rock')

    # Class-level genre -> index map with plain ints.  Instances shadow it in
    # __init__ with int64 0-d arrays suited to the lscalar label output.
    class_idx_dict = dict((name, i) for i, name in enumerate(class_names))

    @staticmethod
    def read_tracklist(alt_path_root=None):
        """Read the Tzanetakis track list.

        Each line of ``<data_root>/tzanetakis/tracklist.txt`` is split on
        whitespace.  The first token is the track path and the second its
        genre label; any further tokens are ignored.

        :param alt_path_root: if given, each track path is replaced by
            `alt_path_root` joined with the path's last '/'-separated
            component (its file name).
        :rtype: (list, list)
        :returns: paths, labels
        :raises IndexError: on a line with fewer than two tokens; the
            offending line is printed first.
        :raises ValueError: if the file does not list exactly 1000 tracks.
        """
        paths = []
        labels = []
        tracklist_path = os.path.join(data_root(), 'tzanetakis', 'tracklist.txt')
        # `with` guarantees the file is closed, even if a bad line raises.
        with open(tracklist_path) as tracklist:
            for line in tracklist:
                toks = line.split()
                try:
                    track_path, track_label = toks[0], toks[1]
                except IndexError:
                    print('BAD LINE IN TZANETAKIS TRACKLIST')
                    print('%r %r' % (line, toks))
                    raise
                if alt_path_root is not None:
                    file_name = track_path.split('/')[-1]
                    track_path = os.path.join(alt_path_root, file_name)
                paths.append(track_path)
                labels.append(track_label)
        if len(paths) != 1000:
            raise ValueError('expected 1000 tracks in %s, found %d'
                             % (tracklist_path, len(paths)))
        return paths, labels

    def __init__(self, alt_path_root=None):
        """Load the track list (see `read_tracklist` for `alt_path_root`)."""
        self.path, self.label = self.read_tracklist(alt_path_root)
        # int64 0-d arrays, matching the lscalar label output of make_node.
        self.class_idx_dict = dict(
            (name, numpy.asarray(i, dtype='int64'))
            for i, name in enumerate(self.class_names))

    @property
    def n_examples(self):
        """Number of tracks read from the track list."""
        return len(self.path)

    @property
    def nclasses(self):
        """Number of genres (10)."""
        return len(self.class_names)

    def make_node(self, idx):
        """Build the Apply node that looks up track `idx`.

        :param idx: anything `theano.tensor.as_tensor_variable` accepts; its
            type must be one of `theano.tensor.int_types`.
        :raises TypeError: if `idx` does not have an integer scalar type.
        """
        idx_ = theano.tensor.as_tensor_variable(idx)
        if idx_.type not in theano.tensor.int_types:
            raise TypeError(idx)
        return theano.Apply(self,
                            [idx_],
                            [theano.generic('tzanetakis_path'),
                             theano.tensor.lscalar('tzanetakis_label')])

    def perform(self, node, inputs, output_storage):
        """Store the path and class index of track ``inputs[0]``."""
        # Explicit unpacking replaces tuple parameters, which are a
        # SyntaxError in Python 3 (PEP 3113).
        (idx,) = inputs
        path, label = output_storage
        i = int(idx)
        path[0] = self.path[i]
        # Hand out a copy, so the output buffer never aliases the entry in
        # self.class_idx_dict; otherwise an in-place op on the label would
        # silently rewrite the lookup table.
        label[0] = self.class_idx_dict[self.label[i]].copy()

    def grad(self, inputs, g_output):
        """The index input is not differentiable: one None per input."""
        return [None] * len(inputs)

# tzanetakis_example = TzanetakisExample()  # requires reading a data file