C++でテンプレートクラスをテンプレート関数に渡す

行列を扱うクラスを作った

C++で行列を管理する方法は2つあります。

1つはdata[col][row]のように2次元配列で管理する方法、もう一つは行×列サイズの1次元配列を作りdata[行番号×列幅+列番号]のようにアクセスする方法です。

私は後者で管理する事が多いのですが、pythonのNumpyを扱ってから、その便利さに感動し、C++で行列を扱うクラスを作ってみました。

内容は1次元配列をカプセル化し行・列でアクセスする至って簡単なテンプレートクラスです。


template <class Elem> class CMatrix
{
private:
	UINT m_nRow;
	UINT m_nCol;
public:
	Elem *m_pData;

	CMatrix() {
		m_pData = NULL;
		m_nRow = 0;
		m_nCol = 0;
	}

	CMatrix(UINT row, UINT col) {
		ReSize(row, col);
	}

	~CMatrix() {
		Clear();
	}

	void Clear() {
		if (m_pData) {
			delete[]m_pData;
			m_pData = NULL;
		}
		m_nRow = 0;
		m_nCol = 0;
	}

	void ReSize(UINT row, UINT col) {
		Clear();
		m_pData = new Elem[row * col];
		m_nRow = row;
		m_nCol = col;
	}

	BOOL Get(UINT row, UINT col, Elem &ret) {
		if (m_nRow <= row || m_nCol <= col) {
			return FALSE;
		}
		ret = m_pData[row * m_nCol + col];
		return TRUE;
	}

	BOOL Set(UINT row, UINT col, Elem ret) {
		if (m_nRow <= row || m_nCol <= col) {
			return FALSE;
		}
		Elem data;
		data = ret;
		m_pData[row * m_nCol + col] = data;
		return TRUE;
	}

	UINT GetRowCount() {
		return m_nRow;
	}

	UINT GetColumnCount() {
		return m_nCol;
	}

	BOOL Dupe(CMatrix<Elem> &dst) {
		dst.ReSize(m_nRow, m_nCol);
		int i, cnt = m_nRow * m_nCol;
		for (i = 0; i < cnt; i++) {
			dst.m_pData[i] = m_pData[i];
		}
		return TRUE;
	}

	// 指定行をstd::vector配列として取り出す
	BOOL GetRow(UINT row, vector<Elem>& ret) {
		int i, cnt;
		cnt = m_nCol;
		ret.resize(cnt);
		for (i = 0; i < cnt; i++) {
			ret[i] = m_pData[i + row * m_nCol];
		}
		return TRUE;
	}

	BOOL GetColumn(UINT col, vector<Elem>& ret) {
		int i, cnt;
		cnt = m_nRow;
		ret.resize(cnt);
		for (i = 0; i < cnt; i++) {
			ret[i] = m_pData[col + m_nCol * i];
		}
		return TRUE;
	}

	BOOL CrossVector(vector<int>& vecX, vector<int>& vecY, vector<Elem>& result) {
		int i, cnt;
		cnt = vecX.size();
		if (vecY.size() != cnt)
			return FALSE;
		result.resize(cnt);

		Elem e;
		UINT row, col;
		for (i = 0; i < cnt; i++) {
			row = vecY[i];
			col = vecX[i];

			if (row > m_nRow || col > m_nCol)
				return FALSE;

			Get(row, col, e);
			result[i] = e;
		}
		return TRUE;
	}

	BOOL Fill(Elem e) {
		int i, cnt = m_nRow * m_nCol;
		for (i = 0; i < cnt; i++) {
			m_pData[i] = e;
		}
		return TRUE;
	}

	// 指定行を取り出した行列を作成する(行は配列で指定する)
	BOOL SelectRows(vector<int>& vecRows, CMatrix<Elem>& result) {
		Elem ret;
		int col, index;
		index = 0;
		result.ReSize(vecRows.size(), m_nCol);
		for (auto a : vecRows) {
			for (col = 0; col < m_nCol; col++) {
				Get(a, col, ret);
				result.Set(index, col, ret);
			}
			index++;
		}
		return TRUE;
	}
};

CMatrix<string>とすると文字列の行列として扱えて、CMatrix<int>とすると数値の行列として扱えます。

数値変換

CSVのファイルを読み取って、CMatrix<string>で管理するようにしたものの

要素を数値として取り出そうとすると、実装先で数値変換のコードを書くハメになったので、もっと手軽にテンプレート間で型変換ができる方法が無いかと考えました。

CMatrix<string>を一括してCMatrix<int>に変換できると便利ですよね。

string→intだけでなく、もっと色んな型のデータを自在に変換できたら便利だと思い

テンプレートクラスをテンプレート関数に渡すという方法で実装してみました。

型AのCMatrixと型BのCMatrixを受け取り、変換関数B = C(A)のポインタを3つ目の引数に取るテンプレート関数です。

template <typename TIn, typename TOut>
BOOL ConvertMatrix(CMatrix<TIn> &mxIn, CMatrix<TOut> &mxOut, TOut(*calc)(TIn)) {
	TIn ret;
	int row, col, cntRow, cntCol;
	cntRow = mxIn.GetRowCount();
	cntCol = mxIn.GetColumnCount();
	mxOut.ReSize(cntRow, cntCol);

	for (row = 0; row < cntRow; row++) {
		for (col = 0; col < cntCol; col++) {
			mxIn.Get(row, col, ret);
			mxOut.Set(row, col, calc(ret));
		}
	}
	return TRUE;
}

これをCMatrixと同じヘッダーファイルに記述しておきます。

次にstringからintに変換する関数to_intを作成します。

int to_int(string s) {
	return atoi(s.data());
}

これを変換関数として第3引数に渡します

データを変換するときはこのように呼び出すだけです。

CMatrix<string> mx_str;
CMatrix<int> mx_int;

ConvertMatrix(mx_str, mx_int, to_int);

コンパイラが型を自動で解釈してくれるので、引数の型を明示する必要もありません。

この関数は実は凄くて、変換関数次第であらゆる一括加工を行う事ができます。(小文字英数を大文字英数に変換したりなど)

テンプレートクラスにテンプレート関数、鬼に金棒ですね。

実際、CSVデータを数値に変換しようと考えた場合、ヘッダーと行頭を除いた数値部分だけをくり抜いて変換する必要があります。

template <typename TIn, typename TOut>
BOOL ConvertMatrixArea(CMatrix<TIn> &mxIn, CMatrix<TOut> &mxOut, TOut(*calc)(TIn), int startRow, int startCol, int cntRow=-1, int cntCol=-1) {
	TIn ret;
	int row, col, dstRow, dstCol;
	
	if (cntRow < 0 || (startRow + cntRow) > mxIn.GetRowCount()) {
		cntRow = mxIn.GetRowCount() - startRow;
	}
	if (cntCol < 0 || (startCol + cntCol) > mxIn.GetColumnCount()) {
		cntCol = mxIn.GetColumnCount() - startCol;
	}

	mxOut.ReSize(cntRow, cntCol);

	dstRow = dstCol = 0;
	for (row = startRow; row < cntRow; row++) {
		dstCol = 0;
		for (col = startCol; col < cntCol; col++) {
			mxIn.Get(row, col, ret);
			mxOut.Set(dstRow, dstCol, calc(ret));
			dstCol++;
		}
		dstRow++;
	}
	return TRUE;
}

先程のテンプレート関数に開始行、開始列、行数、列数を指定したものです。

列数・行数を省略すると末尾までコピーとNumpyを意識してみました。(-1は逆転だったような・・)

ラムダ式で処理

全ての要素に一括で関数処理を掛けたいという場合、クラス側が関数ポインタを受け取って全ての要素にその関数を当てるようにすれば、iだのjだの使って2重ループを書く必要がなくなります。

C++11からラムダ式という即席の無名関数が使えるようになったので、それを使った処理を書いてみたいと思います。

CMatrixにメソッドを追加しました

BOOL FillFunc(void(*func)(Elem&)) {
		int row, col;
		for (row = 0; row < m_nRow; row++) {
			for (col = 0; col < m_nCol; col++) {
				func(m_pData[row * m_nCol + col]);
			}
		}
		return TRUE;
	}

全ての要素を引数の関数に渡すだけの処理です。

CMatrix<string>ならstring&を受け取って処理をする関数のポインタを渡して使います。

これがラムダ式を使った呼び出し法です。

mx_str.FillFunc(
		[](string& s) {s += "1"; }
);

関数の引数の中に関数を書くという、少し気味悪い方法ですがJavaScriptなどではお馴染みですね。

これは全ての要素の末尾に”1″を追加する処理になっています。

2つのラムダ式を引数に取ることで、特定の条件を満たしたものだけを関数に通すようなメソッドも作れます。

BOOL FillFuncIf(void(*func)(Elem&), bool(*func_if)(Elem&)) {
		int row, col;
		for (row = 0; row < m_nRow; row++) {
			for (col = 0; col < m_nCol; col++) {
				if(func_if(m_pData[row * m_nCol + col]))
					func(m_pData[row * m_nCol + col]);
			}
		}
		return TRUE;
	}

第一引数に変換関数、第二引数に判定関数を渡して使います。

実際の使用例です。

mx_str.FillFuncIf(
		[](string& s) {s += "1"; },
		[](string& s) {return s.length() > 5; }
	);

この例は5文字より大きな文字列に対し、末尾に”1″を追加するという処理になります。

複雑なループ処理をクラス内で実装しておいて、データ処理を外部から提供するという形にできるので、

コンテナを扱うテンプレートクラスとラムダ式の相性は抜群ですね。

おすすめ

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です